| import hydra |
| import pytorch_lightning as pl |
| import os |
|
|
| from omegaconf import DictConfig |
| from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path |
| from typing import Tuple |
| from tqdm.auto import tqdm |
| from ppd.utils.logger import Log |
|
|
def setup_trainer(cfg: DictConfig) -> Tuple[pl.Trainer, pl.LightningModule, pl.LightningDataModule]:
    """Build the Lightning trainer, model, and data module for evaluation.

    Seeds all RNGs, constructs the (train-less) data module and the model,
    loads mandatory pretrained weights, instantiates the configured logger
    with its version forced to ``val_metrics``, and assembles a GPU trainer.

    Args:
        cfg: Hydra/OmegaConf run configuration.

    Returns:
        A ``(trainer, model, datamodule)`` triple ready for
        ``trainer.validate`` / ``trainer.predict``.

    Raises:
        FileNotFoundError: If ``cfg.pretrained_model`` is not set.
    """
    if cfg.print_cfg:
        print_cfg(cfg, use_rich=True)
    pl.seed_everything(cfg.seed)

    # Evaluation-only entry point: skip building the training split.
    datamodule = get_data(cfg, wo_train=True)
    model = get_model(cfg)

    # A pretrained checkpoint is mandatory for validation / prediction.
    if not cfg.pretrained_model:
        raise FileNotFoundError("Pretrained model not found. Please specify 'pretrained_model' path in config.")
    model.load_pretrained_model_eval(cfg.pretrained_model)

    callbacks = get_callbacks(cfg)

    # Log evaluation metrics under a dedicated version directory so they
    # don't mix with training runs.
    logger_cfg = DictConfig.copy(cfg.logger)
    logger_cfg.update({'version': 'val_metrics'})
    eval_logger = hydra.utils.instantiate(logger_cfg, _recursive_=False)

    trainer = pl.Trainer(
        accelerator="gpu",
        # Lightning expects `False` (not None) to disable logging.
        logger=eval_logger if eval_logger is not None else False,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )
    return trainer, model, datamodule
|
|
def val(cfg: DictConfig) -> None:
    """Run validation of a pretrained model over the validation split."""
    trainer, model, datamodule = setup_trainer(cfg)
    loader = datamodule.val_dataloader()
    trainer.validate(model, loader)
|
|
def predict(cfg: DictConfig) -> None:
    """Run prediction with a pretrained model over the validation split.

    NOTE(review): this deliberately feeds ``val_dataloader`` (not a
    dedicated predict loader) to ``trainer.predict`` — confirm that the
    validation set is the intended prediction set.
    """
    trainer, model, datamodule = setup_trainer(cfg)
    loader = datamodule.val_dataloader()
    trainer.predict(model, loader)
|
|