Spaces:
Configuration error
Configuration error
| import traceback | |
| from copy import deepcopy | |
| import fire | |
| from run import main | |
| from src import config as C | |
| from src.config import Config | |
| from src.exp import experiments | |
| from src.utils import files, logger | |
def get_val_files():
    """Return the validation files as one flat list.

    The list is the concatenation, in this order, of the validation
    entries of DeepSpeak v2, DeepSpeak v1.1, CDFv2 and FFIW.
    """
    sources = (
        files.DeepSpeak_v2.my_val,
        files.DeepSpeak_v1_1.my_val,
        files.CDFv2.val,
        files.FFIW.val,
    )

    merged = []
    for source in sources:
        merged.extend(source)
    return merged
def get_test_files():
    """Return a mapping from test-set name to its test files.

    The result merges two groups:

    * a fixed table of named benchmark test sets, and
    * every entry of ``files.CDFv3.get_test_dict()``, with each path
      rewritten so that ``/CDFv3/`` points at the
      ``/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/`` variant.

    On a key collision the CDFv3 entry wins, because it is the right-hand
    operand of the dict union.
    """

    def _to_frame_subset(path):
        # presumably `path` is a str file path -- confirm against files.CDFv3
        return path.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/")

    named_sets = {
        "FF": files.FF.test,
        "FF-DF": files.FF.DF.test,
        "FF-F2F": files.FF.F2F.test,
        "FF-FS": files.FF.FS.test,
        "FF-NT": files.FF.NT.test,
        "CDF": files.CDFv2.test,
        "FaceFusion": files.FaceFusion.CDF.test,
        "DFD": files.DFD.test,
        "DFDC": files.DFDC.test,
        "FSh": files.FSh.test,
        # NOTE(review): the key is spelled "UADFD" while the attribute read is
        # `UADFV` -- confirm the key is intentional before renaming it.
        "UADFD": files.UADFV.test,
        "DFDM": files.DFDM.test,
        "FFIW": files.FFIW.test,
        "DeepSpeak-1.1": files.DeepSpeak_v1_1.test,
        "DeepSpeak-2.0": files.DeepSpeak_v2.test,
        "KoDF": files.KoDF.test,
        "KoDF-adv": files.KoDF.adversarial,
        "FakeAVCeleb": files.FakeAVCeleb.test,
        "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test,
        "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test,
        "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test,
        "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test,
        "PolyGlotFake": files.PolyGlotFake.test,
        "IDForge-v1": files.IDForge_v1.test,
    }

    cdfv3_sets = {}
    for name, paths in files.CDFv3.get_test_dict().items():
        cdfv3_sets[name] = paths.map(_to_frame_subset)

    return named_sets | cdfv3_sets
def get_default_train_config() -> Config:
    """Build the base configuration shared by all training experiments.

    Starts from a default ``Config`` and overrides run bookkeeping,
    hardware, model, optimisation and dataset settings. Experiment
    modifiers are applied on top of this config by the caller.
    """
    cfg = Config()

    # Run bookkeeping / logging
    cfg.run_dir = "runs/rebuttal"
    cfg.wandb = True
    cfg.wandb_tags.append("rebuttal")
    cfg.throw_exception_if_run_exists = True

    # Hardware
    cfg.num_workers = 12
    cfg.devices = "auto"

    # Model
    cfg.backbone = C.Backbone.CLIP_L_14
    cfg.freeze_feature_extractor = True
    cfg.num_classes = 2

    # Optimisation: batch size and mini-batch size are kept equal
    cfg.batch_size = 128
    cfg.mini_batch_size = 128
    cfg.lr_scheduler = "cosine"
    cfg.lr = 3e-4
    cfg.min_lr = 1e-5
    cfg.weight_decay = 0
    # presumably 1 warm-up epoch + 50 regular epochs -- confirm
    cfg.max_epochs = 1 + 50
    cfg.warmup_epochs = 1

    # Data: train on FF, validate/test on the shared file sets
    cfg.trn_files = files.FF.train
    cfg.val_files = get_val_files()
    cfg.tst_files = get_test_files()

    return cfg
def get_default_test_config(orig_run_name, new_run_name) -> Config:
    """Build a test configuration from a previously saved run.

    Args:
        orig_run_name: Name of the existing run; its directory is located
            with ``files.find_run_dir`` and its ``hparams.yaml`` is loaded.
        new_run_name: Run name assigned to the returned config.

    Returns:
        The loaded config, redirected to ``runs/test``, pointing at the
        original run's ``best_mAP.ckpt`` checkpoint and using the shared
        test file sets.
    """
    source_dir = files.find_run_dir(orig_run_name)

    # Start from the hyper-parameters the original run was launched with
    cfg = C.load_config(f"{source_dir}/hparams.yaml")

    # Re-home the run so test outputs do not land in the original run
    cfg.run_name = new_run_name
    cfg.run_dir = "runs/test"

    checkpoint_name = "best_mAP.ckpt"  # Default checkpoint name
    cfg.checkpoint = f"{source_dir}/checkpoints/{checkpoint_name}"

    cfg.wandb = True
    cfg.wandb_tags.extend(["test"])

    cfg.num_workers = 12
    cfg.batch_size = 1024
    cfg.mini_batch_size = 1024
    cfg.devices = "auto"

    cfg.tst_files = get_test_files()
    return cfg
def get_debug_config(config: Config) -> Config:
    """Shrink ``config`` into a quick smoke-test run.

    The passed object is modified in place and also returned. The run is
    redirected to ``runs/tmp/tmp``, limited to one epoch and 12 batches per
    stage, and uses only FF data (the FF validation split doubles as the
    test set).
    """
    #! Debug
    config.run_dir = "runs/tmp"
    config.run_name = "tmp"

    # Keep the run short: a single epoch over a handful of batches
    config.max_epochs = 1
    for limit_name in ("limit_train_batches", "limit_val_batches", "limit_test_batches"):
        setattr(config, limit_name, 12)

    # Optional knobs for deeper debugging:
    # config.num_workers = 0
    # config.batch_size = config.mini_batch_size = 2
    # config.deterministic = True
    # config.detect_anomaly = True

    config.trn_files = files.FF.train
    config.val_files = files.FF.val
    config.tst_files = files.FF.val

    return config
# Local experiment registry: a shallow copy of every experiment defined in
# src.exp. Entries added to this dict do not mutate the imported mapping.
experiments = dict(experiments)
def entry(
    exp_names: str | list[str],
    debug: bool = False,
    test: bool = False,
    from_exp: str | None = None,
    **kwargs,
):
    """Run one or more named experiments.

    For every requested experiment the base config is deep-copied, the
    experiment's modifiers are applied in order, optional debug overrides and
    command-line ``kwargs`` are layered on top, the result is re-validated,
    and ``main`` is invoked. A failing experiment is logged (console and
    ``<run_dir>/<run_name>/failed.log``) and does not stop the remaining ones.

    Args:
        exp_names: A single experiment name or a sequence of names. Each name
            must be a key of ``experiments``; unknown names are reported and
            skipped.
        debug: When true, apply ``get_debug_config`` on top of the experiment
            config.
        test: When true, start from a test config instead of the default
            training config, and pass ``False`` as the second argument of
            ``main`` (presumably the "train" flag -- confirm against
            ``run.main``).
        from_exp: In test mode, name of the existing run whose saved config
            and checkpoint are used. Requires exactly one experiment name.
        **kwargs: Extra config values, applied last through
            ``Config.set_values_from_dict``.

    Raises:
        Exception: In test mode with ``from_exp`` set, if the number of
            experiment names is not exactly one.
    """
    import os  # local import: only needed for the failure-log directory

    # Normalise to a list up front. Previously this happened only after the
    # test-mode branch, so a single name passed as `str` together with
    # `from_exp` was never handled correctly there.
    if isinstance(exp_names, str):
        exp_names = [exp_names]

    if test:
        if from_exp is not None:
            if len(exp_names) != 1:
                raise Exception("When running in test mode, you can provide only one experiment name.")
            config = get_default_test_config(from_exp, exp_names[0])
        else:
            logger.print_warning("Running in test mode, but 'from_exp' is not provided. Using default test config.")
            config = C.Config()
    else:
        config = get_default_train_config()

    for exp_name in exp_names:
        exp_name = exp_name.strip()

        if exp_name not in experiments:
            logger.print_error(f"Experiment '{exp_name}' is not defined in 'src/exp/__init__.py:1'")
            logger.print(f"Available experiments: {list(experiments.keys())}")
            continue

        modifiers = experiments[exp_name]

        # Every experiment starts from its own copy of the base config
        config_exp = deepcopy(config)
        config_exp.run_name = exp_name

        for modify in modifiers:
            if isinstance(modify, Config):
                # If the modifier is a Config object, change only different values
                difference = modify.model_dump(exclude_unset=True)
                # TODO: maybe set_values_from_dict(difference)?
                config_exp = Config(**config_exp.model_copy(update=difference).model_dump())
            else:
                # Otherwise the modifier is a callable: Config -> Config
                config_exp = modify(config_exp)

            config_exp = Config(**config_exp.model_dump())  # Parse and validate config

        if debug:
            config_exp = config_exp.model_copy(update=get_debug_config(config_exp).model_dump())

        # Update config with kwargs
        config_exp.set_values_from_dict(kwargs)

        # Revalidate the config - checks if user provided valid values
        config_exp = Config(**config_exp.model_dump())

        try:
            main(config_exp, not test)
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Error occurred while running experiment '{exp_name}':")
            logger.print(e)

            save_dir = f"{config_exp.run_dir}/{config_exp.run_name}"
            # The run may have failed before its directory was created; make
            # sure it exists so that writing the log cannot itself raise and
            # abort the remaining experiments.
            os.makedirs(save_dir, exist_ok=True)

            # Save the exception traceback to a file
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"\nTraining failed: {e}\n")
                f.write(traceback.format_exc())
if __name__ == "__main__":
    # Expose `entry` as the command-line interface: its parameters become
    # CLI arguments and flags.
    fire.Fire(entry)