LiDAR-Perfect-Depth / code /ppd /entrys /validation.py
chenming-wu's picture
code
436b829 verified
import hydra
import pytorch_lightning as pl
import os
from omegaconf import DictConfig
from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path
from typing import Tuple
from tqdm.auto import tqdm
from ppd.utils.logger import Log
def setup_trainer(cfg: DictConfig) -> Tuple[pl.Trainer, pl.LightningModule, pl.LightningDataModule]:
"""
Set up the PyTorch Lightning trainer, model, and data module.
"""
if cfg.print_cfg: print_cfg(cfg, use_rich=True)
pl.seed_everything(cfg.seed)
# preparation
datamodule = get_data(cfg, wo_train=True)
model = get_model(cfg)
if cfg.pretrained_model:
model.load_pretrained_model_eval(cfg.pretrained_model)
else:
raise FileNotFoundError("Pretrained model not found. Please specify 'pretrained_model' path in config.")
# PL callbacks and logger
callbacks = get_callbacks(cfg)
cfg_logger = DictConfig.copy(cfg.logger)
cfg_logger.update({'version': 'val_metrics'})
logger = hydra.utils.instantiate(cfg_logger, _recursive_=False)
# PL-Trainer
trainer = pl.Trainer(
accelerator="gpu",
logger=logger if logger is not None else False,
callbacks=callbacks,
**cfg.pl_trainer,
)
return trainer, model, datamodule
def val(cfg: DictConfig) -> None:
"""
Validate the model.
"""
trainer, model, datamodule = setup_trainer(cfg)
trainer.validate(model, datamodule.val_dataloader())
def predict(cfg: DictConfig) -> None:
"""
Predict using the model.
"""
trainer, model, datamodule = setup_trainer(cfg)
trainer.predict(model, datamodule.val_dataloader())