| import os |
| import hydra |
| import pytorch_lightning as pl |
| from omegaconf import DictConfig |
| from ppd.utils.logger import Log |
| from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path |
|
|
|
|
def train_net(cfg: DictConfig) -> None:
    """
    Instantiate the trainer, then train the model.

    Builds callbacks, logger, and a GPU ``pl.Trainer`` from the Hydra config,
    seeds all RNGs, instantiates the datamodule and model, and — if a previous
    checkpoint exists in the configured checkpoint directory — resumes
    training from it with full optimizer/scheduler/epoch state.

    Args:
        cfg: Hydra configuration. Reads ``print_cfg``, ``logger``,
            ``pl_trainer``, ``seed``, and
            ``callbacks.model_checkpoint.dirpath``.
    """
    if cfg.print_cfg:
        print_cfg(cfg, use_rich=True)

    callbacks = get_callbacks(cfg)
    # _recursive_=False: hand nested sub-configs to the logger class as-is
    # instead of letting Hydra instantiate them eagerly.
    logger = hydra.utils.instantiate(cfg.logger, _recursive_=False)
    # NOTE(review): if cfg.pl_trainer also defines `accelerator`, the
    # hard-coded "gpu" below raises a duplicate-keyword TypeError — confirm
    # the config never sets it.
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger if logger is not None else False,  # False disables logging
        callbacks=callbacks,
        **cfg.pl_trainer,
    )

    pl.seed_everything(cfg.seed)
    datamodule: pl.LightningDataModule = get_data(cfg)
    model: pl.LightningModule = get_model(cfg)

    # Resume from the newest checkpoint in the configured directory, if any.
    ckpt_path = find_last_ckpt_path(cfg.callbacks.model_checkpoint.dirpath)
    if ckpt_path:
        # NOTE(review): this may be redundant now that the checkpoint is also
        # passed to `fit` below — confirm whether load_pretrained_model does
        # anything beyond loading the state dict (e.g. partial/non-strict
        # loading) before removing it.
        model.load_pretrained_model(ckpt_path)

    # BUG FIX: previously `ckpt_path=None` discarded the checkpoint found
    # above, so only the weights were restored and training silently
    # restarted from epoch 0 with a fresh optimizer. Passing the path lets
    # Lightning restore the full training state; it is None when no
    # checkpoint exists, which starts a fresh run as before.
    trainer.fit(model, datamodule, ckpt_path=ckpt_path)