| """PyTorch Lightning callbacks utilities.""" |
|
|
| from __future__ import annotations |
|
|
| import lightning.pytorch as pl |
| from torch import nn |
|
|
| from vis4d.engine.loss_module import LossModule |
| from vis4d.engine.training_module import TrainingModule |
|
|
|
|
def get_model(model: pl.LightningModule) -> nn.Module:
    """Unwrap the underlying model from a Lightning module.

    Args:
        model (pl.LightningModule): Lightning module to inspect.

    Returns:
        nn.Module: ``model.model`` when ``model`` is a ``TrainingModule``;
            otherwise ``model`` itself, unchanged.
    """
    # Anything that is not a TrainingModule is handed back as-is.
    if not isinstance(model, TrainingModule):
        return model
    return model.model
|
|
|
|
def get_loss_module(loss_module: pl.LightningModule) -> LossModule:
    """Fetch the ``loss_module`` attribute of a Lightning module.

    Args:
        loss_module (pl.LightningModule): Lightning module expected to carry
            a ``loss_module`` attribute that is a ``LossModule`` instance.

    Returns:
        LossModule: The ``loss_module`` attribute of the given module.

    Raises:
        AssertionError: If the attribute is missing or is not a
            ``LossModule`` instance.
    """
    # A missing attribute resolves to None, which fails the isinstance check
    # below exactly like the absent-attribute case.
    candidate = getattr(loss_module, "loss_module", None)
    assert isinstance(
        candidate, LossModule
    ), "Loss module is not set in the training module."
    return candidate
|
|