File size: 728 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""PyTorch Lightning callbacks utilities."""

from __future__ import annotations

import lightning.pytorch as pl
from torch import nn

from vis4d.engine.loss_module import LossModule
from vis4d.engine.training_module import TrainingModule


def get_model(model: pl.LightningModule) -> nn.Module:
    """Unwrap the network held by a Lightning module.

    A ``TrainingModule`` wraps the actual network in its ``model`` attribute;
    that inner network is what gets returned. Any other Lightning module is
    handed back as-is.

    Args:
        model: Lightning module to unwrap.

    Returns:
        ``model.model`` when ``model`` is a ``TrainingModule``, otherwise
        ``model`` itself.
    """
    # Guard clause: only TrainingModule carries a wrapped network.
    if not isinstance(model, TrainingModule):
        return model
    return model.model


def get_loss_module(loss_module: pl.LightningModule) -> LossModule:
    """Return the ``LossModule`` held by a Lightning training module.

    Args:
        loss_module: Lightning module expected to expose a ``loss_module``
            attribute holding a ``LossModule`` instance.

    Returns:
        The ``LossModule`` stored on ``loss_module.loss_module``.

    Raises:
        AssertionError: If the attribute is missing or is not a
            ``LossModule``.
    """
    candidate = getattr(loss_module, "loss_module", None)
    # A missing attribute and an explicit ``None`` are both "not set".
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped under ``python -O``. AssertionError is kept so existing
    # callers see the same exception type and message as before.
    if candidate is None or not isinstance(candidate, LossModule):
        raise AssertionError("Loss module is not set in the training module.")
    return candidate