# File size: 2,458 Bytes
# 891e05c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import sys
import hydra
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    RichProgressBar,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig
from pathlib import Path

from models.loupe import LoupeModel, LoupeConfig
from data_module import DataModule
from lit_model import LitModel
from utils import convert_deepspeed_checkpoint

# Put the current working directory first on the import path for any
# imports resolved after this point.
sys.path.insert(0, ".")

# Directory two levels above this file's resolved location.
project_root = Path(__file__).resolve().parents[1]

# Output directory: $CATTINO_TASK_HOME when set, otherwise "<project_root>/results".
save_path = os.getenv("CATTINO_TASK_HOME", f"{project_root}/results")


@hydra.main(
    config_path=str(project_root / "configs"),
    config_name="train",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Run one training stage described by the Hydra config.

    Seeds the random generators, builds the checkpoint callback, the
    TensorBoard logger and the Lightning ``Trainer``, constructs the Loupe
    model and the data module, and calls ``trainer.fit``. When the configured
    strategy mentions ``"deepspeed"`` and checkpointing is enabled,
    ``convert_deepspeed_checkpoint`` is called afterwards with a target path
    under ``<project_root>/checkpoints`` named after the best checkpoint file.

    Args:
        cfg: Config composed by Hydra from ``configs/train``. Keys read here:
            ``seed``, ``stage.name``, ``ckpt.saver``, ``trainer``,
            ``hparams.epoch``, ``hparams.accumulate_grad_batches``,
            ``hparams.grad_clip_val`` and ``model``.

    Raises:
        ValueError: If ``cfg.stage.name`` is ``"test"``.
    """
    pl.seed_everything(cfg.seed)

    if cfg.stage.name == "test":
        raise ValueError(
            "This script is for training only. Please use one of cls, seg, or cls_seg stages."
        )

    checkpoint_callback = hydra.utils.instantiate(
        cfg.ckpt.saver,
        dirpath=os.path.join(save_path, "checkpoints"),
    )
    logger = TensorBoardLogger(
        save_dir=save_path,
        name=cfg.stage.name,
        default_hp_metric=False,
    )

    # Script-level defaults; every key present in cfg.trainer overrides them.
    trainer_overrides = dict(
        devices=1 if cfg.trainer.fast_dev_run else "auto",
        max_epochs=cfg.hparams.epoch,
        val_check_interval=0.05,
        log_every_n_steps=2,
        accumulate_grad_batches=cfg.hparams.accumulate_grad_batches,
        gradient_clip_val=cfg.hparams.grad_clip_val,
    )
    trainer_overrides.update(cfg.trainer)
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[LearningRateMonitor(), RichProgressBar(), checkpoint_callback],
        **trainer_overrides,
    )
    torch.set_float32_matmul_precision("medium")
    loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
    loupe = LoupeModel(loupe_config)

    model = LitModel(cfg, loupe)
    data_module = DataModule(cfg, loupe_config)

    trainer.fit(model, data_module)

    # `strategy` may be missing or explicitly null in the config. Normalise a
    # None value to "" so the membership test below cannot raise TypeError
    # after training has already completed.
    strategy = cfg.trainer.get("strategy", "") or ""
    # NOTE(review): a missing `enable_checkpointing` key is treated as disabled
    # here, whereas Lightning's Trainer enables checkpointing by default --
    # confirm the trainer config always sets this key explicitly.
    checkpointing_enabled = trainer_overrides.get("enable_checkpointing", False)

    if "deepspeed" in strategy and checkpointing_enabled:
        convert_deepspeed_checkpoint(
            cfg,
            checkpoint_callback,
            os.path.join(
                project_root,
                "checkpoints",
                os.path.basename(checkpoint_callback.best_model_path),
            ),
        )


# Script entry point. `main` is wrapped by `@hydra.main`, which supplies the
# `cfg` argument, so it is invoked here without arguments.
if __name__ == "__main__":
    main()