Spaces:
Configuration error
Configuration error
File size: 5,776 Bytes
c29babb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | import os
import traceback
import torch
from lightning import Trainer
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import loggers as pl_loggers
from rich import traceback as rich_traceback
from src import dataset as datasets
from src.config import Config
from src.model.base import BaseDeepakeDetectionModel
from src.utils import logger
from src.utils.checks import checks
from src.utils.model_checkpoint import ModelCheckpointParallel
# Install rich's traceback handler at import time so uncaught exceptions in this
# process are rendered by rich instead of the default Python traceback printer.
rich_traceback.install()
def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the third-party detector selected by ``config.checkpoint``.

    The checkpoint path is matched by substring against a fixed, ordered list of
    markers; the first marker found decides which model class is used. Each
    model module is imported lazily, so only the selected one is ever loaded.

    Args:
        config: Run configuration. ``config.checkpoint`` selects the model, and
            the whole ``config`` is passed to the chosen model's constructor.

    Returns:
        The constructed third-party model.

    Raises:
        ValueError: If none of the known markers occurs in ``config.checkpoint``.
    """

    def _effort():
        # Download: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort

        return Effort

    def _forada():
        # Download: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda

        return ForAda

    def _fsfm():
        from src.model.FSFM import FSFM

        return FSFM

    def _gendhf():
        # https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF

        return GenDHF

    # Order matters: the first marker contained in the path wins.
    registry = (
        ("weights/Effort", _effort),
        ("weights/ForAda", _forada),
        ("weights/FS-VFM/", _fsfm),
        ("yermandy/", _gendhf),
    )
    for marker, resolve_class in registry:
        if marker in config.checkpoint:
            model_cls = resolve_class()
            return model_cls(config)

    raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Build the detection model selected by ``config.checkpoint``.

    A missing (``None``) or empty checkpoint selects GenD. Otherwise the path is
    offered to ``load_third_party_model``; if that reports the path as not
    belonging to any known third-party model, GenD is used as the fallback.

    Args:
        config: Run configuration; passed unchanged to the model constructor.

    Returns:
        The constructed model (third-party or GenD).

    Raises:
        ValueError: Any ``ValueError`` raised while constructing a recognised
            third-party model is propagated instead of being mistaken for an
            "unknown checkpoint" and silently replaced by GenD.
    """
    if config.checkpoint:
        try:
            return load_third_party_model(config)
        except ValueError as err:
            # load_third_party_model signals "not a third-party checkpoint" with a
            # ValueError whose message starts with this prefix. Any other
            # ValueError comes from building a recognised third-party model and
            # is a real error that must not be hidden behind a GenD fallback.
            if not str(err).startswith("Unknown third party model"):
                raise

    # No checkpoint, or a checkpoint that is not a known third-party model:
    # use GenD as default.
    from src.model.GenD import GenD

    return GenD(config, verbose=True)
def init_loggers(config: Config) -> list:
    """Create the Lightning loggers for this run.

    A ``CSVLogger`` rooted at ``config.run_dir`` (name ``config.run_name``,
    empty version string) is always created. When ``config.wandb`` is enabled,
    a ``WandbLogger`` for project ``"deepfake"`` is added after it, saving
    under ``<run_dir>/<run_name>``.

    Args:
        config: Run configuration.

    Returns:
        The list of loggers to hand to the ``Trainer``.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")
    if not config.wandb:
        return [csv_logger]

    run_path = f"{config.run_dir}/{config.run_name}"
    wandb_logger = pl_loggers.WandbLogger(
        project="deepfake",
        name=config.run_name,
        save_dir=run_path,
        tags=set(config.wandb_tags),
        group=config.wandb_group,
    )
    return [csv_logger, wandb_logger]
def init_callbacks(config: Config) -> list:
    """Create the Lightning callbacks for this run.

    Always includes a rich progress bar and a ``ModelCheckpointParallel`` that
    saves to ``config.checkpoint_name`` while tracking ``config.monitor_metric``
    in ``config.monitor_metric_mode``. Early stopping on the same metric and
    mode is appended when ``config.early_stopping_patience`` is positive.

    Args:
        config: Run configuration.

    Returns:
        The list of callbacks to hand to the ``Trainer``.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointer = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    callbacks = [progress_bar, checkpointer]

    if config.early_stopping_patience > 0:
        early_stopper = pl_callbacks.EarlyStopping(
            monitor=config.monitor_metric,
            patience=config.early_stopping_patience,
            mode=config.monitor_metric_mode,
            verbose=True,
        )
        callbacks.append(early_stopper)

    return callbacks
def finish_wandb_run(trainer, config: Config):
    """Finalize and close the Weights & Biases run attached to ``trainer``.

    Does nothing unless ``config.wandb`` is enabled and one of
    ``trainer.loggers`` is a ``WandbLogger``. If several are attached, only the
    first one is finalized (with status ``"success"``) and finished.

    Args:
        trainer: The Lightning trainer whose loggers are inspected.
        config: Run configuration.
    """
    if not config.wandb:
        return

    wandb_logger = next(
        (
            candidate
            for candidate in trainer.loggers
            if isinstance(candidate, pl_loggers.WandbLogger)
        ),
        None,
    )
    if wandb_logger is None:
        return

    wandb_logger.finalize("success")
    wandb_logger.experiment.finish()
def _build_trainer(config: Config) -> Trainer:
    """Create the Lightning ``Trainer`` configured from ``config``."""
    return Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        # Effective batch size = mini_batch_size * accumulate_grad_batches.
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )


def _record_training_failure(save_dir: str, error: Exception) -> None:
    """Report a training failure and append its traceback to ``<save_dir>/failed.log``.

    Must be called from inside an ``except`` block: the traceback written is
    the one of the exception currently being handled.
    """
    traceback.print_exc()  # Print complete exception traceback
    logger.print_error(f"Training failed: {error}")
    # The run directory may not exist yet (e.g. training failed before any
    # logger or checkpoint wrote to it); create it so writing the failure log
    # cannot itself raise and mask the original error.
    os.makedirs(save_dir, exist_ok=True)
    with open(f"{save_dir}/failed.log", "a", encoding="utf-8") as f:
        f.write(f"Training failed: {error}\n")
        f.write(traceback.format_exc())


def _test_trained_checkpoint(trainer: Trainer, model, data_module, config: Config, save_dir: str) -> None:
    """Load the checkpoint written during training and run the test loop on it.

    Logs an error and skips testing if the expected checkpoint file is missing.
    """
    logger.print_info("Training finished. Starting testing")
    ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"
    if not os.path.exists(ckpt_path):
        logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
        return
    model.load_checkpoint(ckpt_path)
    trainer.test(model, data_module)


def main(config: Config, train: bool):
    """Run one experiment: train then test, or test only.

    Args:
        config: Full run configuration.
        train: If true, fit the model and afterwards test the checkpoint saved
            under ``<run_dir>/<run_name>/checkpoints`` (testing is attempted
            even if training was interrupted or failed). If false, test the
            model loaded from ``config.checkpoint``.

    Raises:
        ValueError: If ``train`` is false and ``config.checkpoint`` is ``None``
            or empty.
    """
    # Performs initial checks
    checks(config)

    # Testing needs real weights. Reject an empty string as well as None:
    # load_model treats "" as "no checkpoint" and would otherwise silently test
    # a freshly initialised GenD. Checked explicitly (not via assert, which is
    # stripped under -O) and before the expensive setup below.
    if not train and not config.checkpoint:
        raise ValueError("Checkpoint is required for testing")

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Loads the checkpoint if provided
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())
    save_dir = f"{config.run_dir}/{config.run_name}"
    trainer = _build_trainer(config)

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            _record_training_failure(save_dir, e)
        finally:
            # Test whatever checkpoint training produced, even after an
            # interruption or a failure.
            _test_trained_checkpoint(trainer, model, data_module, config, save_dir)
    else:
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)
|