# GenD-Sentinel / run_exp.py
# Commit c29babb ("init") by yermandy
import os
import traceback
from copy import deepcopy

import fire

from run import main
from src import config as C
from src.config import Config
from src.exp import experiments
from src.utils import files, logger
def get_val_files():
    """Return the flat list of validation files used during training.

    The result concatenates, in this order: ``files.DeepSpeak_v2.my_val``,
    ``files.DeepSpeak_v1_1.my_val``, ``files.CDFv2.val`` and ``files.FFIW.val``.
    """
    splits = (
        files.DeepSpeak_v2.my_val,
        files.DeepSpeak_v1_1.my_val,
        files.CDFv2.val,
        files.FFIW.val,
    )
    merged = []
    for split in splits:
        # Each split is iterated element by element, like ``[*split]`` would.
        merged.extend(split)
    return merged
def get_test_files():
    """Return a mapping of benchmark label -> test file collection.

    The fixed benchmarks come first, in declaration order. The CDFv3 test
    splits from ``files.CDFv3.get_test_dict()`` follow. Their paths are
    rewritten so that the ``/CDFv3/`` segment points at
    ``/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/``. On a label clash,
    the CDFv3 entry wins, which matches dict-union semantics.
    """
    benchmarks = {
        "FF": files.FF.test,
        "FF-DF": files.FF.DF.test,
        "FF-F2F": files.FF.F2F.test,
        "FF-FS": files.FF.FS.test,
        "FF-NT": files.FF.NT.test,
        "CDF": files.CDFv2.test,
        "FaceFusion": files.FaceFusion.CDF.test,
        "DFD": files.DFD.test,
        "DFDC": files.DFDC.test,
        "FSh": files.FSh.test,
        # NOTE(review): label "UADFD" is backed by files.UADFV; confirm the label is intended.
        "UADFD": files.UADFV.test,
        "DFDM": files.DFDM.test,
        "FFIW": files.FFIW.test,
        "DeepSpeak-1.1": files.DeepSpeak_v1_1.test,
        "DeepSpeak-2.0": files.DeepSpeak_v2.test,
        "KoDF": files.KoDF.test,
        "KoDF-adv": files.KoDF.adversarial,
        "FakeAVCeleb": files.FakeAVCeleb.test,
        "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test,
        "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test,
        "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test,
        "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test,
        "PolyGlotFake": files.PolyGlotFake.test,
        "IDForge-v1": files.IDForge_v1.test,
    }

    raw_root = "/CDFv3/"
    subset_root = "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/"
    for label, paths in files.CDFv3.get_test_dict().items():
        # ``paths`` exposes ``.map`` (presumably a pandas Series; TODO confirm).
        benchmarks[label] = paths.map(lambda p: p.replace(raw_root, subset_root))
    return benchmarks
def get_default_train_config() -> Config:
    """Build the baseline training config shared by all training experiments.

    Starts from ``Config()`` and overrides run/logging settings, data-loading
    settings, the backbone, the optimizer schedule and the train/val/test
    file lists. Experiment modifiers are applied on top of this later.
    """
    cfg = Config()

    # Run bookkeeping and W&B logging.
    cfg.run_dir = "runs/rebuttal"
    cfg.wandb = True
    cfg.wandb_tags.append("rebuttal")
    cfg.throw_exception_if_run_exists = True

    # Data loading / hardware.
    cfg.num_workers = 12
    cfg.devices = "auto"

    # Model.
    cfg.backbone = C.Backbone.CLIP_L_14
    cfg.freeze_feature_extractor = True
    cfg.num_classes = 2

    # Optimization (batch_size is assigned before mini_batch_size, as before).
    cfg.batch_size = 128
    cfg.mini_batch_size = 128
    cfg.lr_scheduler = "cosine"
    cfg.lr = 3e-4
    cfg.min_lr = 1e-5
    cfg.weight_decay = 0
    cfg.max_epochs = 1 + 50  # presumably 1 warm-up epoch + 50 regular epochs
    cfg.warmup_epochs = 1

    # Data splits.
    cfg.trn_files = files.FF.train
    cfg.val_files = get_val_files()
    cfg.tst_files = get_test_files()

    return cfg
def get_default_test_config(orig_run_name, new_run_name) -> Config:
    """Load an existing run's config and prepare it for evaluation.

    Args:
        orig_run_name: Name of the existing run; resolved via ``files.find_run_dir``.
        new_run_name: Name under which the evaluation run is stored.

    Returns:
        The config loaded from ``<run_dir>/hparams.yaml``. It points at the run's
        ``checkpoints/best_mAP.ckpt``, writes to ``runs/test``, and evaluates on
        ``get_test_files()``.
    """
    source_dir = files.find_run_dir(orig_run_name)
    hparams_path = f"{source_dir}/hparams.yaml"
    ckpt_name = "best_mAP.ckpt"  # checkpoint used by default

    cfg = C.load_config(hparams_path)

    # Identity and output location of the evaluation run.
    cfg.run_name = new_run_name
    cfg.run_dir = "runs/test"
    cfg.checkpoint = f"{source_dir}/checkpoints/{ckpt_name}"

    # Logging.
    cfg.wandb = True
    cfg.wandb_tags.append("test")

    # Evaluation throughput and data.
    cfg.num_workers = 12
    cfg.batch_size = 1024
    cfg.mini_batch_size = 1024
    cfg.devices = "auto"
    cfg.tst_files = get_test_files()

    return cfg
def get_debug_config(config: Config) -> Config:
    """Shrink ``config`` in place for a quick debug run and return it.

    The run is redirected to ``runs/tmp/tmp``, limited to one epoch and 12
    batches per stage, and uses FF train/val splits only. The FF val split
    also serves as the test set.
    """
    # Applied in this order via setattr (equivalent to plain attribute assignment).
    # Other knobs that can help when debugging (left disabled on purpose):
    #   num_workers = 0, batch_size = mini_batch_size = 2,
    #   deterministic = True, detect_anomaly = True
    overrides = {
        "run_dir": "runs/tmp",
        "run_name": "tmp",
        "max_epochs": 1,
        "limit_train_batches": 12,
        "limit_val_batches": 12,
        "limit_test_batches": 12,
        "trn_files": files.FF.train,
        "val_files": files.FF.val,
        "tst_files": files.FF.val,
    }
    for field_name, value in overrides.items():
        setattr(config, field_name, value)
    return config
# Registry: experiment name -> sequence of modifiers (Config instances or
# callables taking and returning a Config). This is a shallow copy of the
# registry imported from src.exp, so local entries can be added here without
# mutating the imported dict.
experiments = dict(experiments)
def _normalize_exp_names(exp_names) -> list[str]:
    """Return ``exp_names`` as a list of stripped string names."""
    # A single name arrives as a str. python-fire may also hand over a list or
    # tuple, and may parse purely numeric names as numbers, hence ``str()``.
    if isinstance(exp_names, str):
        exp_names = [exp_names]
    return [str(name).strip() for name in exp_names]


def _build_base_config(names: list[str], test: bool, from_exp: str | None) -> Config:
    """Select the starting config shared by every requested experiment."""
    if not test:
        return get_default_train_config()
    if from_exp is None:
        logger.print_warning("Running in test mode, but 'from_exp' is not provided. Using default test config.")
        return C.Config()
    if len(names) != 1:
        raise Exception("When running in test mode, you can provide only one experiment name.")
    return get_default_test_config(from_exp, names[0])


def _build_experiment_config(base: Config, exp_name: str, debug: bool, overrides: dict) -> Config:
    """Apply one experiment's modifiers, optional debug settings and CLI overrides."""
    config_exp = deepcopy(base)
    config_exp.run_name = exp_name
    for modify in experiments[exp_name]:
        if isinstance(modify, Config):
            # Config modifier: merge only the fields explicitly set on it.
            difference = modify.model_dump(exclude_unset=True)
            # TODO: maybe set_values_from_dict(difference)?
            config_exp = Config(**config_exp.model_copy(update=difference).model_dump())
        else:
            # Callable modifier: receives the current config, returns the new one.
            config_exp = modify(config_exp)
    config_exp = Config(**config_exp.model_dump())  # Parse and validate config
    if debug:
        # get_debug_config mutates config_exp in place and returns it.
        config_exp = config_exp.model_copy(update=get_debug_config(config_exp).model_dump())
    # CLI kwargs are applied last so they override everything above.
    config_exp.set_values_from_dict(overrides)
    # Revalidate the config - checks if user provided valid values
    return Config(**config_exp.model_dump())


def _save_failure_log(config_exp: Config, error: Exception, formatted_tb: str) -> None:
    """Append ``error`` and its traceback to ``<run_dir>/<run_name>/failed.log``.

    Best-effort: the directory is created if it does not exist yet. An I/O
    error while writing the log is reported rather than raised, so the
    remaining experiments still run.
    """
    save_dir = f"{config_exp.run_dir}/{config_exp.run_name}"
    try:
        os.makedirs(save_dir, exist_ok=True)
        with open(f"{save_dir}/failed.log", "a", encoding="utf-8") as f:
            f.write(f"\nTraining failed: {error}\n")
            f.write(formatted_tb)
    except OSError as log_error:
        logger.print_error(f"Could not write failure log to '{save_dir}': {log_error}")


def entry(
    exp_names: str | list[str] | tuple[str, ...],
    debug: bool = False,
    test: bool = False,
    from_exp: str | None = None,
    **kwargs,
):
    """Run one or more registered experiments (CLI entry point via python-fire).

    Args:
        exp_names: One experiment name, or a list/tuple of names. Each name must
            be a key of ``experiments``; unknown names are reported and skipped.
        debug: If True, shrink each config with ``get_debug_config``.
        test: If True, call ``main(config, False)`` (evaluation); otherwise
            ``main(config, True)`` starting from ``get_default_train_config()``.
        from_exp: In test mode, the existing run whose config and checkpoint
            are loaded via ``get_default_test_config``. Exactly one experiment
            name is allowed then. Without it, a bare ``C.Config()`` is used.
        **kwargs: Config overrides applied last via ``set_values_from_dict``.

    Raises:
        Exception: In test mode with ``from_exp`` if not exactly one name is given.

    A failing experiment does not stop the others. Its traceback is printed
    and appended to ``<run_dir>/<run_name>/failed.log``.
    """
    names = _normalize_exp_names(exp_names)
    config = _build_base_config(names, test, from_exp)

    for exp_name in names:
        if exp_name not in experiments:
            logger.print_error(f"Experiment '{exp_name}' is not defined in 'src/exp/__init__.py:1'")
            logger.print(f"Available experiments: {list(experiments.keys())}")
            continue

        config_exp = _build_experiment_config(config, exp_name, debug, kwargs)

        try:
            main(config_exp, not test)
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Error occurred while running experiment '{exp_name}':")
            logger.print(e)
            # format_exc() must be captured while the exception is being handled.
            _save_failure_log(config_exp, e, traceback.format_exc())
if __name__ == "__main__":
    # Expose `entry` as the CLI; python-fire maps command-line args to its parameters.
    fire.Fire(component=entry)