Spaces:
Configuration error
Configuration error
| import os | |
| import traceback | |
| import torch | |
| from lightning import Trainer | |
| from lightning.pytorch import callbacks as pl_callbacks | |
| from lightning.pytorch import loggers as pl_loggers | |
| from rich import traceback as rich_traceback | |
| from src import dataset as datasets | |
| from src.config import Config | |
| from src.model.base import BaseDeepakeDetectionModel | |
| from src.utils import logger | |
| from src.utils.checks import checks | |
| from src.utils.model_checkpoint import ModelCheckpointParallel | |
# Runs at import time: calls rich's traceback `install()` hook
# (presumably so uncaught exceptions are rendered by rich — confirm against the rich docs).
rich_traceback.install()
def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the third-party detector that matches ``config.checkpoint``.

    The checkpoint path is tested against a fixed sequence of known substrings.
    The first one found decides which model class is imported (lazily, so that
    only the selected model's module is loaded) and constructed from ``config``.

    Args:
        config: Run configuration. ``config.checkpoint`` must support the
            substring (``in``) test, i.e. it is expected to be a string here.

    Returns:
        The constructed ``Effort``, ``ForAda``, ``FSFM`` or ``GenDHF`` model.

    Raises:
        ValueError: If the checkpoint path contains none of the known markers.
    """
    if "weights/Effort" in config.checkpoint:
        # Weights: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort

        model_cls = Effort
    elif "weights/ForAda" in config.checkpoint:
        # Weights: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda

        model_cls = ForAda
    elif "weights/FS-VFM/" in config.checkpoint:
        from src.model.FSFM import FSFM

        model_cls = FSFM
    elif "yermandy/" in config.checkpoint:
        # Models: https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF

        model_cls = GenDHF
    else:
        raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")

    return model_cls(config)
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Pick and build the model for this run.

    ``GenD`` is the fallback in two situations: when no checkpoint is configured
    (``None`` or the empty string), and when ``load_third_party_model`` raises
    ``ValueError`` for the configured checkpoint.

    Args:
        config: Run configuration; only ``config.checkpoint`` is inspected here,
            the whole object is forwarded to the model constructor.

    Returns:
        A third-party model when ``load_third_party_model`` succeeds, otherwise
        a ``GenD`` instance built with ``verbose=True``.
    """

    def _build_default() -> BaseDeepakeDetectionModel:
        # Imported lazily so GenD is only loaded when it is actually needed.
        from src.model.GenD import GenD

        return GenD(config, verbose=True)

    checkpoint = config.checkpoint
    has_checkpoint = checkpoint is not None and checkpoint != ""
    if not has_checkpoint:
        return _build_default()

    try:
        third_party_model = load_third_party_model(config)
    except ValueError:
        # Any ValueError from the third-party loader selects the GenD fallback.
        return _build_default()
    return third_party_model
def init_loggers(config: Config) -> list:
    """Create the logger list handed to the Lightning ``Trainer``.

    A ``CSVLogger`` (``config.run_dir`` / ``config.run_name``, empty version
    string) is always present. When ``config.wandb`` is truthy, a
    ``WandbLogger`` for the "deepfake" project is appended after it.

    Args:
        config: Run configuration providing ``run_dir``, ``run_name``, ``wandb``
            and, when wandb is enabled, ``wandb_tags`` and ``wandb_group``.

    Returns:
        ``[csv_logger]`` or ``[csv_logger, wandb_logger]``.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")
    if not config.wandb:
        return [csv_logger]

    wandb_options = {
        "project": "deepfake",
        "name": config.run_name,
        "save_dir": f"{config.run_dir}/{config.run_name}",
        # Passing through set() removes duplicate tags.
        "tags": set(config.wandb_tags),
        "group": config.wandb_group,
    }
    return [csv_logger, pl_loggers.WandbLogger(**wandb_options)]
def init_callbacks(config: Config) -> list:
    """Assemble the Lightning callbacks for a run.

    The list always holds a ``RichProgressBar`` followed by a
    ``ModelCheckpointParallel`` that tracks ``config.monitor_metric``. An
    ``EarlyStopping`` callback on the same metric is appended only when
    ``config.early_stopping_patience`` is greater than zero.

    Args:
        config: Run configuration providing ``checkpoint_name``,
            ``monitor_metric``, ``monitor_metric_mode`` and
            ``early_stopping_patience``.

    Returns:
        The callbacks in the order described above.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointing = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    # Note: a LearningRateFinder(1e-5, 1e-2) callback is intentionally not enabled.
    active_callbacks = [progress_bar, checkpointing]

    wants_early_stopping = config.early_stopping_patience > 0
    if wants_early_stopping:
        early_stopping = pl_callbacks.EarlyStopping(
            monitor=config.monitor_metric,
            patience=config.early_stopping_patience,
            mode=config.monitor_metric_mode,
            verbose=True,
        )
        active_callbacks.append(early_stopping)

    return active_callbacks
def finish_wandb_run(trainer, config: Config):
    """Close the wandb run attached to ``trainer``, if there is one.

    Nothing happens when ``config.wandb`` is falsy or when none of
    ``trainer.loggers`` is a ``WandbLogger``. Otherwise the first
    ``WandbLogger`` found is finalized with status "success" and its
    underlying experiment is finished.

    Args:
        trainer: Object exposing a ``loggers`` iterable (only read when wandb
            is enabled).
        config: Run configuration; only ``config.wandb`` is read.
    """
    if not config.wandb:
        return

    # Single pass: pick the first wandb logger, or None when there is none.
    wandb_logger = next(
        (candidate for candidate in trainer.loggers if isinstance(candidate, pl_loggers.WandbLogger)),
        None,
    )
    if wandb_logger is None:
        return

    wandb_logger.finalize("success")
    wandb_logger.experiment.finish()
def main(config: Config, train: bool):
    """Run training followed by testing, or testing only, as described by ``config``.

    Steps:
      1. Run the project's ``checks`` on the configuration.
      2. Build the model via ``load_model`` and call its ``load_checkpoint``
         with ``config.checkpoint``.
      3. Build the data module and a Lightning ``Trainer``.
      4. If ``train`` is true, fit the model. Whatever the outcome (success,
         ``KeyboardInterrupt`` or another exception), the checkpoint
         ``{run_dir}/{run_name}/checkpoints/{checkpoint_name}.ckpt`` is then
         loaded and tested if it exists. Training exceptions are printed and
         appended to ``{run_dir}/{run_name}/failed.log`` instead of being
         re-raised.
      5. If ``train`` is false, a checkpoint is required and the model is
         tested directly.
      6. Finish the wandb run, if any.

    Args:
        config: Run configuration.
        train: ``True`` to train and then test; ``False`` to only test.

    Raises:
        AssertionError: In test-only mode when ``config.checkpoint`` is ``None``.
    """
    # Performs initial checks
    checks(config)

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Hands the configured checkpoint to the model's own loader
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())

    save_dir = f"{config.run_dir}/{config.run_name}"

    trainer = Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        # One optimizer step per `batch_size` samples, fed in `mini_batch_size` chunks
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Training failed: {e}")

            # Capture the traceback now, while the training exception is the active one.
            failure_trace = traceback.format_exc()

            # Save the exception traceback to a file. The run directory may not
            # exist yet if training failed very early, so create it first.
            # Writing this log is best-effort: an I/O problem here must not
            # mask the training failure or abort the remaining steps.
            try:
                os.makedirs(save_dir, exist_ok=True)
                with open(f"{save_dir}/failed.log", "a", encoding="utf-8") as f:
                    f.write(f"Training failed: {e}\n")
                    f.write(failure_trace)
            except OSError as log_error:
                logger.print_error(f"Could not write {save_dir}/failed.log: {log_error}")
        finally:
            # Runs after success, interruption and failure alike: test whatever
            # checkpoint was saved, if any.
            logger.print_info("Training finished. Starting testing")
            ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"

            if not os.path.exists(ckpt_path):
                logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
            else:
                model.load_checkpoint(ckpt_path)
                trainer.test(model, data_module)
    else:
        assert config.checkpoint is not None, "Checkpoint is required for testing"
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)