| | import yaml |
| | import os |
| | import subprocess |
| | import argparse |
| |
|
def _expected_bin_files(dataset_name, save_dir, chunks, n_folds, test_folds, train_folds):
    """Yield the path of every .bin file the given dataset requires.

    Two families of files are expected (as established by the original
    path-building code):
      * one raw chunk file per chunk index:  {name}_{i}.bin
      * for each non-empty fold list (test, then train), one prebatched file
        per chunk:  {name}_prebatched_padded_{i}_n_{n_folds}_f_{folds}.bin
    """
    # Raw chunk files.
    for i in range(chunks):
        yield os.path.join(save_dir, f"{dataset_name}_{i}.bin")
    # Prebatched files, one family per non-empty fold selection
    # (test first, then train — same order as the original check).
    for folds in (test_folds, train_folds):
        if not folds:
            continue
        foldlist_str = '_'.join(map(str, folds))
        for i in range(chunks):
            yield os.path.join(
                save_dir,
                f"{dataset_name}_prebatched_padded_{i}_n_{n_folds}_f_{foldlist_str}.bin",
            )


def check_dataset_files(yaml_file, rerun=False):
    """
    Check if all required .bin files exist for each dataset in the YAML file.

    Reads the 'Datasets' section of *yaml_file* and, for every dataset,
    verifies that all chunk and prebatched .bin files are present on disk
    (see _expected_bin_files for the exact naming scheme).

    Parameters
    ----------
    yaml_file : str
        Path to a YAML config with a top-level 'Datasets' mapping; each
        dataset entry must provide args.save_dir and args.chunks, and may
        provide a 'folding' mapping (n_folds / test / train).
    rerun : bool, optional
        If True, invoke bash/prep_data.sh for each dataset with missing
        files to regenerate them (best-effort: failures are reported, not
        raised).
    """
    # Narrow error handling around the one genuinely fallible step
    # (file I/O + YAML parse). The original wrapped the entire function in
    # `except Exception`, which also hid coding bugs such as KeyErrors.
    try:
        with open(yaml_file, 'r') as file:
            config = yaml.safe_load(file)
    except (OSError, yaml.YAMLError) as e:
        print(f"Error processing {yaml_file}: {e}")
        return

    # An empty YAML file parses to None; treat that like a missing section
    # instead of letting `'Datasets' not in None` raise TypeError.
    if not config or 'Datasets' not in config:
        print(f"No 'Datasets' section found in {yaml_file}.")
        return

    all_files_exist = True

    for dataset_name, dataset_config in config['Datasets'].items():
        # A malformed dataset entry no longer aborts the whole scan;
        # report it and keep checking the remaining datasets.
        try:
            save_dir = dataset_config['args']['save_dir']
            chunks = dataset_config['args']['chunks']
        except (TypeError, KeyError) as e:
            all_files_exist = False
            print(f"Error processing {yaml_file}: {e}")
            continue

        folding = dataset_config.get('folding', {})
        n_folds = folding.get('n_folds', 0)
        test_folds = folding.get('test', [])
        train_folds = folding.get('train', [])

        print(f"\n== Checking dataset: {dataset_name} ==")
        print(f" save_dir: {save_dir}")
        print(f" chunks: {chunks}")
        print(f" n_folds: {n_folds}")
        print(f" test_folds: {test_folds}")
        print(f" train_folds: {train_folds}")

        missing_files = [
            path
            for path in _expected_bin_files(
                dataset_name, save_dir, chunks, n_folds, test_folds, train_folds
            )
            if not os.path.exists(path)
        ]

        if missing_files:
            all_files_exist = False
            print(f" Missing files for dataset '{dataset_name}':")
            for missing_file in missing_files:
                print(f" - {missing_file}")

            if rerun:
                print(f" Reprocessing dataset '{dataset_name}' ...")
                try:
                    # List-form argv without shell=True: arguments are passed
                    # verbatim, so dataset names / paths containing spaces or
                    # shell metacharacters cannot break or inject the command.
                    subprocess.run(
                        ["bash/prep_data.sh", yaml_file, dataset_name, str(chunks)],
                        check=True,
                    )
                except (subprocess.CalledProcessError, OSError) as e:
                    # OSError covers a missing/non-executable prep script,
                    # which shell=True used to mask as exit code 127.
                    print(f" Could NOT reprocess '{dataset_name}': {e}")
        else:
            print(f" All files exist for dataset '{dataset_name}'.")

    if all_files_exist:
        print("\nAll required files exist for all datasets.")
    else:
        print("\nSome files are missing.")
| |
|
def main(pargs):
    """
    Resolve YAML config paths relative to ./configs and check each one.

    Parameters
    ----------
    pargs : argparse.Namespace
        Parsed CLI arguments with attributes:
        - configs: comma-separated config paths relative to the base
          directory, or None/empty to use the built-in default list.
        - rerun: bool, forwarded to check_dataset_files.
    """
    # os.path.join instead of string concatenation: portable and avoids
    # fragile trailing-slash handling.
    base_directory = os.path.join(os.getcwd(), "configs")

    if pargs.configs:
        configs = [p.strip() for p in pargs.configs.split(',')]
    else:
        # Default set of configs to check when none are given on the CLI.
        configs = [
            "attention/ttH_CP_even_vs_odd.yaml",

            "stats_100K/finetuning_ttH_CP_even_vs_odd.yaml",
            "stats_100K/pretraining_multiclass.yaml",
            "stats_100K/ttH_CP_even_vs_odd.yaml",

            "stats_all/finetuning_ttH_CP_even_vs_odd.yaml",
            "stats_all/pretraining_multiclass.yaml",
            "stats_all/ttH_CP_even_vs_odd.yaml",
        ]

    for config in configs:
        yaml_file = os.path.join(base_directory, config)
        if os.path.exists(yaml_file):
            print(f"\nProcessing file: {config}")
            check_dataset_files(yaml_file, pargs.rerun)
        else:
            print(f"File not found: {yaml_file}")
| |
|
if __name__ == "__main__":
    # CLI entry point: build the argument parser, parse, and delegate.
    arg_parser = argparse.ArgumentParser(description="Check YAML config files")
    arg_parser.add_argument(
        "--configs", "-c",
        type=str,
        required=False,
        help="Comma-separated list of YAML config paths relative to base directory",
    )
    arg_parser.add_argument(
        "--rerun", "-r",
        action='store_true',
        help="Automatically re-run data processing to fix missing files",
    )
    main(arg_parser.parse_args())