# LiDAR-Perfect-Depth: code/ppd/entrys/train_net.py
# Last commit: 436b829 ("code") by chenming-wu
import os
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from ppd.utils.logger import Log
from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path
def train_net(cfg: DictConfig) -> None:
    """
    Build the Lightning trainer, datamodule and model from ``cfg``, then train.

    Steps:
      1. Optionally pretty-print the config (``cfg.print_cfg``).
      2. Create callbacks and the logger, then the ``pl.Trainer``. The
         accelerator defaults to ``"gpu"`` but may be overridden by
         ``cfg.pl_trainer.accelerator``.
      3. Seed RNGs with ``cfg.seed``, then build the datamodule and model.
      4. If ``find_last_ckpt_path`` finds a checkpoint under
         ``cfg.callbacks.model_checkpoint.dirpath``, pass it to
         ``model.load_pretrained_model``.
      5. Run ``trainer.fit``.

    Args:
        cfg: Hydra config. Keys read directly here: ``print_cfg``, ``logger``,
            ``pl_trainer``, ``seed`` and ``callbacks.model_checkpoint.dirpath``.
            ``get_callbacks``/``get_data``/``get_model`` read further keys.
    """
    if cfg.print_cfg:
        print_cfg(cfg, use_rich=True)

    callbacks = get_callbacks(cfg)
    logger = hydra.utils.instantiate(cfg.logger, _recursive_=False)

    # Default to GPU, but let cfg.pl_trainer override it. Passing
    # accelerator="gpu" next to **cfg.pl_trainer raised
    # "got multiple values for keyword argument 'accelerator'" whenever the
    # config also set an accelerator.
    trainer_kwargs = {"accelerator": "gpu", **cfg.pl_trainer}
    trainer = pl.Trainer(
        # Lightning treats logger=False as "disable logging".
        logger=logger if logger is not None else False,
        callbacks=callbacks,
        **trainer_kwargs,
    )

    # Seed everything (pl.seed_everything) before loading data and building
    # the model.
    pl.seed_everything(cfg.seed)
    datamodule: pl.LightningDataModule = get_data(cfg)
    model: pl.LightningModule = get_model(cfg)

    # Warm-start from the last checkpoint in the checkpoint directory, if any.
    # NOTE(review): presumably find_last_ckpt_path returns a falsy value when
    # no checkpoint exists, and load_pretrained_model loads weights only —
    # confirm.
    ckpt_path = find_last_ckpt_path(cfg.callbacks.model_checkpoint.dirpath)
    if ckpt_path:
        model.load_pretrained_model(ckpt_path)

    # ckpt_path=None: Lightning does not restore trainer state (optimizer,
    # LR schedulers, epoch/step counters) from a checkpoint, so training
    # starts a fresh run from whatever weights were loaded above.
    trainer.fit(model, datamodule, ckpt_path=None)