# weikaih's picture
# WildDet3D Gradio demo
# f71ac1d verified
"""PyTorch Lightning callbacks utilities."""
from __future__ import annotations
import lightning.pytorch as pl
from torch import nn
from vis4d.engine.loss_module import LossModule
from vis4d.engine.training_module import TrainingModule
def get_model(model: pl.LightningModule) -> nn.Module:
    """Unwrap the underlying model from a Lightning module.

    Args:
        model: Lightning module to inspect.

    Returns:
        ``model.model`` when ``model`` is a ``TrainingModule`` instance;
        otherwise ``model`` itself, unchanged.
    """
    return model.model if isinstance(model, TrainingModule) else model
def get_loss_module(loss_module: pl.LightningModule) -> LossModule:
    """Get loss_module from pl module.

    Args:
        loss_module: Lightning module expected to expose a ``loss_module``
            attribute that is a ``LossModule`` instance.

    Returns:
        The ``loss_module`` attribute of the given module.

    Raises:
        AssertionError: If the ``loss_module`` attribute is missing or is
            not a ``LossModule`` instance.
    """
    # Read the attribute once; a missing attribute falls back to None, which
    # fails the isinstance check below just like the missing-attribute case.
    candidate = getattr(loss_module, "loss_module", None)
    # Explicit raise rather than an ``assert`` statement so the validation is
    # not stripped under ``python -O``. AssertionError and the message are
    # kept so callers observe the same failure as before.
    if not isinstance(candidate, LossModule):
        raise AssertionError(
            "Loss module is not set in the training module."
        )
    return candidate