import yaml
import os
import subprocess
import argparse


def check_dataset_files(yaml_file, rerun=False):
    """
    Check if all required .bin files exist for each dataset in the YAML file.
    """
    try:
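        # Load the YAML configuration.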
        with open(yaml_file, 'r') as file:
            config = yaml.safe_load(file)

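        # The config must contain a top-level 'Datasets' section.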
        if 'Datasets' not in config:
            print(f"No 'Datasets' section found in {yaml_file}.")
            return

        datasets = config['Datasets']
        all_files_exist = True

        for dataset_name, dataset_config in datasets.items():
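            # Pull the save location, chunk count, and folding setup for this dataset.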
            save_dir = dataset_config['args']['save_dir']
            chunks = dataset_config['args']['chunks']
            folding = dataset_config.get('folding', {})
            n_folds = folding.get('n_folds', 0)
            test_folds = folding.get('test', [])
            train_folds = folding.get('train', [])

            print(f"\n== Checking dataset: {dataset_name} ==")
            print(f"  save_dir: {save_dir}")
            print(f"  chunks: {chunks}")
            print(f"  n_folds: {n_folds}")
            print(f"  test_folds: {test_folds}")
            print(f"  train_folds: {train_folds}")

            missing_files = []

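            # Check that every raw chunk file exists.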
            for chunk in range(chunks):
                chunk_file = os.path.join(save_dir, f"{dataset_name}_{chunk}.bin")
                if not os.path.exists(chunk_file):
                    missing_files.append(chunk_file)

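            # Check the prebatched, padded files for each configured fold split.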
            fold_types = [('test', test_folds), ('train', train_folds)]
            for fold_type, folds in fold_types:
                if not folds:
                    continue
                foldlist_str = '_'.join(map(str, folds))
                for i in range(chunks):
                    prebatched_file = os.path.join(
                        save_dir,
                        f"{dataset_name}_prebatched_padded_{i}_n_{n_folds}_f_{foldlist_str}.bin"
                    )
                    if not os.path.exists(prebatched_file):
                        missing_files.append(prebatched_file)

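            # Report any files found missing for this dataset.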
            if missing_files:
                all_files_exist = False
                print(f"  Missing files for dataset '{dataset_name}':")
                for missing_file in missing_files:
                    print(f"    - {missing_file}")

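                # Optionally re-run the data-preparation script to regenerate the missing files.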
                if rerun:
                    print(f"  Reprocessing dataset '{dataset_name}' ...")
                    prep_command = f"bash/prep_data.sh {yaml_file} {dataset_name} {chunks}"
                    try:
                        subprocess.run(prep_command, shell=True, check=True)
                    except subprocess.CalledProcessError as e:
                        print(f"  Could NOT reprocess '{dataset_name}': {e}")
            else:
                print(f"  All files exist for dataset '{dataset_name}'.")

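        # Print an overall summary across all datasets.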
        if all_files_exist:
            print("\nAll required files exist for all datasets.")
        else:
            print("\nSome files are missing.")

    except Exception as e:
        print(f"Error processing {yaml_file}: {e}")


def main(pargs):
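    # Config paths are resolved relative to ./configs under the current working directory.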
    base_directory = os.path.join(os.getcwd(), "configs")

    if pargs.configs:
        configs = [p.strip() for p in pargs.configs.split(',')]
    else:
        configs = [
            "attention/ttH_CP_even_vs_odd.yaml",

            "stats_100K/finetuning_ttH_CP_even_vs_odd.yaml",
            "stats_100K/pretraining_multiclass.yaml",
            "stats_100K/ttH_CP_even_vs_odd.yaml",

            "stats_all/finetuning_ttH_CP_even_vs_odd.yaml",
            "stats_all/pretraining_multiclass.yaml",
            "stats_all/ttH_CP_even_vs_odd.yaml",
        ]

    for config in configs:
        yaml_file = os.path.join(base_directory, config)
        if os.path.exists(yaml_file):
            print(f"\nProcessing file: {config}")
            check_dataset_files(yaml_file, pargs.rerun)
        else:
            print(f"File not found: {yaml_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check that the dataset files referenced by YAML configs exist"
    )
    parser.add_argument(
        "--configs", "-c",
        type=str,
        required=False,
        help="Comma-separated list of YAML config paths relative to the base directory"
    )
    parser.add_argument(
        "--rerun", "-r",
        action='store_true',
        help="Automatically re-run data processing to fix missing files"
    )
    args = parser.parse_args()
    main(args)
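
# Example invocation (script filename assumed here; adjust to match the repository):
#   python check_dataset_files.py --configs stats_100K/ttH_CP_even_vs_odd.yaml --rerun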