# GenD-Sentinel / run.py
# Repository viewer residue: "yermandy's picture", commit "init" (c29babb)
import os
import traceback
import torch
from lightning import Trainer
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import loggers as pl_loggers
from rich import traceback as rich_traceback
from src import dataset as datasets
from src.config import Config
from src.model.base import BaseDeepakeDetectionModel
from src.utils import logger
from src.utils.checks import checks
from src.utils.model_checkpoint import ModelCheckpointParallel
# Install rich's traceback handler at import time so uncaught exceptions are rendered by rich.
rich_traceback.install()
def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Build the third-party detector that matches ``config.checkpoint``.

    The checkpoint path is matched by substring, in this order:
    ``weights/Effort``, ``weights/ForAda``, ``weights/FS-VFM/``, ``yermandy/``.
    The module of the selected model is imported only when its marker matches,
    so unused third-party model code is never imported.

    Args:
        config: Run configuration. ``config.checkpoint`` is used with the
            ``in`` operator, so it must be a string here.

    Returns:
        The selected model, constructed from ``config``.

    Raises:
        ValueError: If none of the known markers occurs in the checkpoint path.
    """
    checkpoint_path = config.checkpoint

    if "weights/Effort" in checkpoint_path:
        # Download: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort as ModelClass
    elif "weights/ForAda" in checkpoint_path:
        # Download: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda as ModelClass
    elif "weights/FS-VFM/" in checkpoint_path:
        from src.model.FSFM import FSFM as ModelClass
    elif "yermandy/" in checkpoint_path:
        # https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF as ModelClass
    else:
        raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")

    return ModelClass(config)
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Pick and construct the model for this run.

    GenD is the default: it is used when no checkpoint is configured
    (``None`` or an empty string) and when the checkpoint path is not
    recognised as a third-party model.

    Args:
        config: Run configuration; only ``config.checkpoint`` is read here,
            the whole object is forwarded to the model constructor.

    Returns:
        The constructed model. Checkpoint weights are NOT loaded here.
    """

    def build_default() -> BaseDeepakeDetectionModel:
        # Imported lazily, only when the default model is actually needed.
        from src.model.GenD import GenD

        return GenD(config, verbose=True)

    checkpoint = config.checkpoint
    if checkpoint is None or checkpoint == "":
        return build_default()

    try:
        return load_third_party_model(config)
    except ValueError:
        # Not a known third-party checkpoint path: fall back to GenD.
        # NOTE(review): this also catches a ValueError raised inside a
        # third-party model constructor, which would silently select GenD.
        return build_default()
def init_loggers(config: Config) -> list:
    """Create the experiment loggers for a run.

    A CSV logger is always created, rooted at ``config.run_dir`` with the run
    name as its ``name`` and an empty ``version``. When ``config.wandb`` is
    truthy a Weights & Biases logger is appended, saving under
    ``{run_dir}/{run_name}``.

    Args:
        config: Run configuration (``run_dir``, ``run_name``, ``wandb``,
            ``wandb_tags``, ``wandb_group`` are read).

    Returns:
        A list with the CSV logger first, optionally followed by the W&B logger.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")
    if not config.wandb:
        return [csv_logger]

    run_output_dir = f"{config.run_dir}/{config.run_name}"
    return [
        csv_logger,
        pl_loggers.WandbLogger(
            project="deepfake",
            name=config.run_name,
            save_dir=run_output_dir,
            tags=set(config.wandb_tags),  # set() drops duplicate tags
            group=config.wandb_group,
        ),
    ]
def init_callbacks(config: Config) -> list:
    """Create the trainer callbacks for a run.

    Always includes a rich progress bar and the checkpoint callback, which
    monitors ``config.monitor_metric``. Early stopping on the same metric is
    added only when ``config.early_stopping_patience`` is greater than zero.

    Args:
        config: Run configuration (``checkpoint_name``, ``monitor_metric``,
            ``monitor_metric_mode``, ``early_stopping_patience`` are read).

    Returns:
        The callbacks in the order: progress bar, checkpointing, and
        (optionally) early stopping.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointing = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    # Not enabled: pl_callbacks.LearningRateFinder(1e-5, 1e-2)

    optional_callbacks = []
    if config.early_stopping_patience > 0:
        optional_callbacks.append(
            pl_callbacks.EarlyStopping(
                monitor=config.monitor_metric,
                patience=config.early_stopping_patience,
                mode=config.monitor_metric_mode,
                verbose=True,
            )
        )

    return [progress_bar, checkpointing, *optional_callbacks]
def finish_wandb_run(trainer, config: Config):
    """Close the Weights & Biases run attached to ``trainer``, if there is one.

    Does nothing when ``config.wandb`` is falsy or when none of the trainer's
    loggers is a ``WandbLogger``. Otherwise the first such logger is finalized
    with status ``"success"`` and its experiment is finished.

    Args:
        trainer: Object exposing a ``loggers`` iterable.
        config: Run configuration; only ``config.wandb`` is read.
    """
    if not config.wandb:
        return

    wandb_logger = next(
        (candidate for candidate in trainer.loggers if isinstance(candidate, pl_loggers.WandbLogger)),
        None,
    )
    if wandb_logger is None:
        return

    wandb_logger.finalize("success")
    wandb_logger.experiment.finish()
def main(config: Config, train: bool):
    """Run training followed by testing, or testing only.

    Steps: validate the config, build the model and load ``config.checkpoint``
    into it, build the data module and the Lightning ``Trainer``, then either

    * ``train=True``: fit the model, then reload the checkpoint written to
      ``{run_dir}/{run_name}/checkpoints/{checkpoint_name}.ckpt`` and test it.
      Testing is attempted even if training was interrupted or failed, as long
      as that checkpoint file exists; or
    * ``train=False``: test the model as loaded from ``config.checkpoint``.

    Finally the W&B run (if any) is closed.

    Args:
        config: Run configuration.
        train: Whether to train before testing.

    Raises:
        AssertionError: If ``train`` is false and ``config.checkpoint`` is None.
    """
    # Performs initial checks
    checks(config)

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Called unconditionally, also when config.checkpoint is None or empty.
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())

    save_dir = f"{config.run_dir}/{config.run_name}"

    trainer = Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        # Number of mini-batches accumulated per optimizer step.
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Training failed: {e}")

            # Save the exception traceback to a file. The run directory may not
            # exist yet if fit() failed early, and open(..., "a") does not create
            # parent directories, so create it first; otherwise the handler
            # itself would raise and the failure log would be lost.
            os.makedirs(save_dir, exist_ok=True)
            with open(f"{save_dir}/failed.log", "a", encoding="utf-8") as f:
                f.write(f"Training failed: {e}\n")
                f.write(traceback.format_exc())
        finally:
            # Runs after success, interruption and failure alike.
            logger.print_info("Training finished. Starting testing")
            ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"
            if not os.path.exists(ckpt_path):
                logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
            else:
                model.load_checkpoint(ckpt_path)
                trainer.test(model, data_module)
    else:
        assert config.checkpoint is not None, "Checkpoint is required for testing"
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)