| """Base module for callbacks.""" |
|
|
| from __future__ import annotations |
|
|
| import lightning.pytorch as pl |
| from torch import Tensor |
|
|
| from vis4d.common.typing import DictStrArrNested |
| from vis4d.data.typing import DictData |
| from vis4d.engine.connectors import CallbackConnector |
|
|
|
|
class Callback(pl.Callback):
    """Base class for Callbacks.

    Stores whether the callback is epoch based, plus optional train/test
    ``CallbackConnector`` objects. Provides helpers that apply those
    connectors to model outputs and batch data. Subclasses override the
    Lightning ``pl.Callback`` hooks they need.
    """

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: None | CallbackConnector = None,
        test_connector: None | CallbackConnector = None,
    ) -> None:
        """Init callback.

        Args:
            epoch_based (bool, optional): Whether the callback is epoch based.
                Defaults to True.
            train_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during training for different callbacks. Defaults
                to None.
            test_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during testing for different callbacks. Defaults
                to None.
        """
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Setup callback.

        No-op by default; subclasses may override it to prepare state before
        the given ``stage`` runs.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The Lightning module being run.
            stage (str): The stage passed in by Lightning.
        """

    def get_train_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for training.

        Applies the train connector to the model outputs and the batch data
        to extract the inputs this callback needs.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If train connector is None.
        """
        # Raise explicitly instead of using ``assert`` so the check (and its
        # documented AssertionError) still applies under ``python -O``.
        if self.train_connector is None:
            raise AssertionError("Train connector is None.")
        return self.train_connector(outputs, batch)

    def get_test_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for inference.

        Applies the test connector to the model outputs and the batch data
        to extract the inputs this callback needs.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If test connector is None.
        """
        # Raise explicitly instead of using ``assert`` so the check (and its
        # documented AssertionError) still applies under ``python -O``.
        if self.test_connector is None:
            raise AssertionError("Test connector is None.")
        return self.test_connector(outputs, batch)
|
|