File size: 5,776 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import traceback

import torch
from lightning import Trainer
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import loggers as pl_loggers
from rich import traceback as rich_traceback

from src import dataset as datasets
from src.config import Config
from src.model.base import BaseDeepakeDetectionModel
from src.utils import logger
from src.utils.checks import checks
from src.utils.model_checkpoint import ModelCheckpointParallel

# Install rich's traceback handler once, at import time, so that uncaught
# exceptions raised anywhere in this script are rendered by rich.
rich_traceback.install()


def _find_third_party_model_class(checkpoint):
    """Return the third party model class matching ``checkpoint``, or ``None``.

    The match is a plain substring test on the checkpoint path. Model modules
    are imported lazily, so only the module of the selected model is loaded.

    Args:
        checkpoint: Checkpoint path / identifier (``None`` or empty is allowed).

    Returns:
        The model class to instantiate, or ``None`` when the checkpoint is
        missing or does not match any known third party model.
    """
    if not checkpoint:
        return None

    if "weights/Effort" in checkpoint:
        # Download: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort

        return Effort

    if "weights/ForAda" in checkpoint:
        # Download: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda

        return ForAda

    if "weights/FS-VFM/" in checkpoint:
        from src.model.FSFM import FSFM

        return FSFM

    if "yermandy/" in checkpoint:
        # https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF

        return GenDHF

    return None


def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the third party model selected by ``config.checkpoint``.

    Args:
        config: Run configuration; only ``config.checkpoint`` is inspected
            here, the whole object is passed to the model constructor.

    Returns:
        The constructed third party model.

    Raises:
        ValueError: If ``config.checkpoint`` is empty/``None`` or does not
            match any known third party model.
    """
    model_cls = _find_third_party_model_class(config.checkpoint)
    if model_cls is None:
        raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")
    return model_cls(config)


def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the model for this run.

    A known third party model is used when ``config.checkpoint`` matches one;
    otherwise (no checkpoint, or an unrecognised path) GenD is the default.

    Args:
        config: Run configuration passed to the model constructor.

    Returns:
        The constructed model.
    """
    model_cls = _find_third_party_model_class(config.checkpoint)
    if model_cls is not None:
        # Construct outside of any try/except: an error raised by the model's
        # own constructor must surface instead of silently falling back to GenD.
        return model_cls(config)

    # No checkpoint, or not a third party model: use GenD as default
    from src.model.GenD import GenD

    return GenD(config, verbose=True)


def init_loggers(config: Config) -> list:
    """Build the list of experiment loggers for the run.

    A CSV logger rooted at ``config.run_dir`` is always created. When
    ``config.wandb`` is truthy a Weights & Biases logger is appended after it.

    Args:
        config: Run configuration (run directory/name and wandb settings).

    Returns:
        The loggers, CSV logger first.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")
    active: list = [csv_logger]

    if not config.wandb:
        return active

    active.append(
        pl_loggers.WandbLogger(
            project="deepfake",
            name=config.run_name,
            save_dir=f"{config.run_dir}/{config.run_name}",
            tags=set(config.wandb_tags),
            group=config.wandb_group,
        )
    )
    return active


def init_callbacks(config: Config) -> list:
    """Build the trainer callbacks for the run.

    The progress bar and the checkpoint callback are always present; early
    stopping is added only when ``config.early_stopping_patience`` is positive.

    Args:
        config: Run configuration (checkpoint name, monitored metric and its
            mode, early stopping patience).

    Returns:
        The callbacks in the order: progress bar, checkpointing, and
        optionally early stopping.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointing = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    # pl_callbacks.LearningRateFinder(1e-5, 1e-2) could be added here as well.

    optional = []
    if config.early_stopping_patience > 0:
        early_stopping = pl_callbacks.EarlyStopping(
            monitor=config.monitor_metric,
            patience=config.early_stopping_patience,
            mode=config.monitor_metric_mode,
            verbose=True,
        )
        optional.append(early_stopping)

    return [progress_bar, checkpointing, *optional]


def finish_wandb_run(trainer, config: Config):
    """Finalize and close the Weights & Biases run attached to ``trainer``.

    Does nothing when wandb is disabled in ``config`` or when none of the
    trainer's loggers is a ``WandbLogger``. Otherwise the first such logger is
    finalized with status ``"success"`` and its experiment is finished.

    Args:
        trainer: Object exposing a ``loggers`` iterable.
        config: Run configuration; only ``config.wandb`` is read.
    """
    if not config.wandb:
        return

    first_wandb = next(
        (candidate for candidate in trainer.loggers if isinstance(candidate, pl_loggers.WandbLogger)),
        None,
    )
    if first_wandb is None:
        return

    first_wandb.finalize("success")
    first_wandb.experiment.finish()


def main(config: Config, train: bool):
    """Run training followed by testing, or testing only.

    The model is built from ``config``, the checkpoint in ``config.checkpoint``
    is loaded into it, and a Lightning ``Trainer`` is configured from the
    remaining settings.

    When ``train`` is true the model is fitted. A keyboard interrupt is logged
    as a warning; any other ``Exception`` is printed, logged and appended to
    ``<run_dir>/<run_name>/failed.log`` without being re-raised. In every case
    the test stage then runs on
    ``<run_dir>/<run_name>/checkpoints/<checkpoint_name>.ckpt`` if that file
    exists; otherwise an error is logged and testing is skipped.

    When ``train`` is false the model is only tested.

    Args:
        config: Run configuration.
        train: Whether to fit the model before testing.

    Raises:
        AssertionError: In test-only mode when ``config.checkpoint`` is ``None``.
    """
    # Performs initial checks
    checks(config)

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Loads the checkpoint if provided
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())

    save_dir = f"{config.run_dir}/{config.run_name}"

    trainer = Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        # Integer ratio of the configured batch size to the mini-batch size
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Training failed: {e}")
            # The run directory may not exist yet if the failure happened
            # before anything was written to it. Create it so that saving the
            # traceback cannot itself fail and mask the original error.
            os.makedirs(save_dir, exist_ok=True)
            # Save the exception traceback to a file
            with open(f"{save_dir}/failed.log", "a", encoding="utf-8") as f:
                f.write(f"Training failed: {e}\n")
                f.write(traceback.format_exc())
        finally:
            # Testing is attempted even after an interrupted or failed fit, as
            # long as a checkpoint was written.
            logger.print_info("Training finished. Starting testing")
            ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"
            if not os.path.exists(ckpt_path):
                logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
            else:
                model.load_checkpoint(ckpt_path)
                trainer.test(model, data_module)

    else:
        assert config.checkpoint is not None, "Checkpoint is required for testing"
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)