File size: 1,676 Bytes
436b829 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import hydra
import pytorch_lightning as pl
import os
from omegaconf import DictConfig
from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path
from typing import Tuple
from tqdm.auto import tqdm
from ppd.utils.logger import Log
def setup_trainer(cfg: DictConfig) -> Tuple[pl.Trainer, pl.LightningModule, pl.LightningDataModule]:
    """
    Set up the PyTorch Lightning trainer, model, and data module.

    Args:
        cfg: Hydra/OmegaConf config. This function reads ``print_cfg``,
            ``seed``, ``pretrained_model``, ``logger`` and ``pl_trainer``
            directly; the whole config is also forwarded to ``get_data``,
            ``get_model`` and ``get_callbacks``.

    Returns:
        A ``(trainer, model, datamodule)`` tuple. The model has already had
        ``load_pretrained_model_eval(cfg.pretrained_model)`` called on it.

    Raises:
        FileNotFoundError: If ``cfg.pretrained_model`` is empty or unset.
    """
    if cfg.print_cfg:
        print_cfg(cfg, use_rich=True)
    pl.seed_everything(cfg.seed)

    # Fail fast: without a checkpoint path there is nothing to evaluate, so
    # check before the datamodule and model are constructed.
    if not cfg.pretrained_model:
        raise FileNotFoundError("Pretrained model not found. Please specify 'pretrained_model' path in config.")

    # preparation
    # NOTE(review): wo_train=True presumably skips building the training
    # split -- confirm against get_data.
    datamodule = get_data(cfg, wo_train=True)
    model = get_model(cfg)
    model.load_pretrained_model_eval(cfg.pretrained_model)

    # PL callbacks and logger
    callbacks = get_callbacks(cfg)
    # ``logger=False`` is how pl.Trainer is told to disable logging; it is used
    # when the logger config is null or instantiates to nothing.
    logger = False
    logger_node = cfg.logger
    if logger_node is not None:
        # Work on a copy so the 'version' override does not leak into ``cfg``.
        cfg_logger = logger_node.copy()
        cfg_logger.update({'version': 'val_metrics'})
        instantiated = hydra.utils.instantiate(cfg_logger, _recursive_=False)
        if instantiated is not None:
            logger = instantiated

    # PL-Trainer
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )
    return trainer, model, datamodule
def val(cfg: DictConfig) -> None:
    """
    Validate the model.

    Builds the trainer, model and datamodule with ``setup_trainer`` and calls
    ``trainer.validate`` on the datamodule's validation dataloader. Whatever
    ``validate`` returns is discarded.
    """
    pl_trainer, net, data = setup_trainer(cfg)
    val_loader = data.val_dataloader()
    pl_trainer.validate(net, val_loader)
def predict(cfg: DictConfig) -> None:
    """
    Predict using the model.

    Builds the trainer, model and datamodule with ``setup_trainer`` and calls
    ``trainer.predict`` on the datamodule's *validation* dataloader (that is
    the loader this entry point uses). Whatever ``predict`` returns is
    discarded.
    """
    pl_trainer, net, data = setup_trainer(cfg)
    loader = data.val_dataloader()
    pl_trainer.predict(net, loader)
|