File size: 3,331 Bytes
891e05c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import sys
import hydra
import torch
import pytorch_lightning as pl

from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig
from pathlib import Path

from models.loupe import LoupeModel, LoupeConfig
from data_module import DataModule
from lit_model import LitModel
from utils import CustomWriter, convert_deepspeed_checkpoint

# Put the current working directory at the front of the module search path.
# NOTE(review): this runs *after* the local imports above (models, data_module,
# lit_model, utils), so it cannot influence how those are resolved; it only
# affects imports performed later (e.g. dynamic imports) — confirm intent.
sys.path.insert(0, ".")

# Repository root: two directory levels above this script's resolved path.
project_root = Path(__file__).resolve().parents[1]

# Where logs/checkpoints go; CATTINO_TASK_HOME overrides the in-repo default.
_default_save_path = f"{project_root}/results"
save_path = os.getenv("CATTINO_TASK_HOME", _default_save_path)


@hydra.main(
    config_path=str(project_root / "configs"),
    config_name="infer",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Evaluate, or adapt at test time, a Loupe model in the ``test`` stage.

    Exactly one of three runs is performed:

    * ``cfg.stage.enable_tta`` truthy -> test-time adaptation via
      ``trainer.fit``: TensorBoard logging under ``save_path``, training
      hyper-parameters from ``cfg.hparams``, and, when checkpointing is
      enabled, a checkpoint callback built from ``cfg.ckpt.saver``. With a
      DeepSpeed strategy the best checkpoint is then passed to
      ``convert_deepspeed_checkpoint``.
    * ``cfg.stage.pred_output_dir`` set -> ``trainer.predict``; results are
      written per batch by ``CustomWriter`` instead of being returned.
    * otherwise -> ``trainer.test``.

    Args:
        cfg: Config composed by Hydra from ``configs/infer``.

    Raises:
        ValueError: If ``cfg.stage.name`` is not ``"test"``.
    """
    pl.seed_everything(cfg.seed)

    if cfg.stage.name != "test":
        raise ValueError("This script is for testing only. Please use the test stage.")

    callbacks = [RichProgressBar()]
    trainer_overrides = dict(
        # fast_dev_run is a quick smoke run; pin it to a single device.
        devices=1 if cfg.trainer.fast_dev_run else "auto",
    )
    if cfg.stage.enable_tta:
        # If TTA is enabled, turn the trainer into a trainable one.
        trainer_overrides.update(
            logger=TensorBoardLogger(
                save_dir=save_path,
                name=cfg.stage.name,
                default_hp_metric=False,
            ),
            max_epochs=cfg.hparams.epoch,
            gradient_clip_val=cfg.hparams.grad_clip_val,
            val_check_interval=0.2,
            log_every_n_steps=2,
            accumulate_grad_batches=cfg.hparams.accumulate_grad_batches,
        )
        # Mutate the model config before LoupeConfig is built below so that
        # TTA runs with freeze_seg disabled and conditional queries enabled.
        cfg.model.freeze_seg = False
        cfg.model.enable_conditional_queries = True
    else:
        # Pure inference: no logger and no checkpoints by default.
        trainer_overrides.update(
            logger=False,
            enable_checkpointing=False,
        )
        if cfg.stage.pred_output_dir:
            callbacks.append(CustomWriter(cfg=cfg, write_interval="batch"))
    # Explicit values in cfg.trainer take precedence over the defaults above.
    trainer_overrides.update(cfg.trainer)

    # NOTE: a missing ``enable_checkpointing`` key is treated as *disabled* here,
    # so no ``cfg.ckpt.saver`` callback is added in that case.
    checkpointing_enabled = trainer_overrides.get("enable_checkpointing", False)
    checkpoint_callback = None
    if checkpointing_enabled:
        checkpoint_callback = hydra.utils.instantiate(
            cfg.ckpt.saver,
            dirpath=os.path.join(save_path, "checkpoints"),
        )
        callbacks.append(checkpoint_callback)

    trainer = pl.Trainer(
        callbacks=callbacks,
        **trainer_overrides,
    )
    torch.set_float32_matmul_precision("medium")
    loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
    loupe = LoupeModel(loupe_config)
    model = LitModel(cfg=cfg, loupe=loupe)
    data_module = DataModule(cfg, loupe_config)

    if cfg.stage.enable_tta:
        trainer.fit(
            model,
            data_module,
        )
        # ``strategy`` may be null or a nested config rather than a string;
        # only a string strategy name can be matched against "deepspeed"
        # (``"deepspeed" in None`` would raise TypeError).
        strategy = cfg.trainer.get("strategy", "")
        is_deepspeed = isinstance(strategy, str) and "deepspeed" in strategy
        if is_deepspeed and checkpointing_enabled:
            # The converted checkpoint goes to <project_root>/checkpoints,
            # while the raw checkpoints live under <save_path>/checkpoints.
            convert_deepspeed_checkpoint(
                cfg,
                checkpoint_callback,
                os.path.join(
                    project_root,
                    "checkpoints",
                    os.path.basename(checkpoint_callback.best_model_path),
                ),
            )
    elif cfg.stage.pred_output_dir:
        # Predictions are persisted by CustomWriter; don't keep them in memory.
        trainer.predict(
            model,
            data_module,
            return_predictions=False,
        )
    else:
        trainer.test(
            model,
            data_module,
        )


if __name__ == "__main__":
    # ``main`` is wrapped by ``hydra.main``, which composes the config and
    # passes it in as ``cfg`` — hence the call takes no arguments.
    main()