File size: 1,676 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import hydra
import pytorch_lightning as pl
import os

from omegaconf import DictConfig
from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path
from typing import Tuple
from tqdm.auto import tqdm
from ppd.utils.logger import Log

def setup_trainer(cfg: DictConfig) -> Tuple[pl.Trainer, pl.LightningModule, pl.LightningDataModule]:
    """
    Set up the PyTorch Lightning trainer, model, and data module for evaluation.

    Args:
        cfg: Hydra config. Reads `pretrained_model` (checkpoint path, required),
            `print_cfg`, `seed`, `logger`, and `pl_trainer`.

    Returns:
        Tuple of (trainer, model, datamodule).

    Raises:
        FileNotFoundError: if `cfg.pretrained_model` is empty/unset.
    """
    # Fail fast on a missing checkpoint path BEFORE building the datamodule
    # and model — the original raised only after that expensive setup.
    if not cfg.pretrained_model:
        raise FileNotFoundError("Pretrained model not found. Please specify 'pretrained_model' path in config.")

    if cfg.print_cfg: print_cfg(cfg, use_rich=True)
    pl.seed_everything(cfg.seed)
    # preparation: evaluation only, so skip the training split (wo_train=True)
    datamodule = get_data(cfg, wo_train=True)
    model = get_model(cfg)
    model.load_pretrained_model_eval(cfg.pretrained_model)

    # PL callbacks and logger; copy cfg.logger so the original config is not mutated
    callbacks = get_callbacks(cfg)
    cfg_logger = DictConfig.copy(cfg.logger)
    cfg_logger.update({'version': 'val_metrics'})
    logger = hydra.utils.instantiate(cfg_logger, _recursive_=False)

    # PL-Trainer; logger=False disables logging when instantiation yields None
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger if logger is not None else False,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )

    return trainer, model, datamodule

def val(cfg: DictConfig) -> None:
    """Run validation: build trainer/model/data, then evaluate on the val split."""
    trainer, model, datamodule = setup_trainer(cfg)
    val_loader = datamodule.val_dataloader()
    trainer.validate(model, val_loader)

def predict(cfg: DictConfig) -> None:
    """Run prediction with a pretrained model over the validation split."""
    trainer, model, datamodule = setup_trainer(cfg)
    # NOTE(review): prediction deliberately reuses the *validation* dataloader,
    # mirroring val() above — confirm a dedicated predict split isn't intended.
    val_loader = datamodule.val_dataloader()
    trainer.predict(model, val_loader)