File size: 2,823 Bytes
f71ac1d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | """Base module for callbacks."""
from __future__ import annotations
import lightning.pytorch as pl
from torch import Tensor
from vis4d.common.typing import DictStrArrNested
from vis4d.data.typing import DictData
from vis4d.engine.connectors import CallbackConnector
class Callback(pl.Callback):
    """Base class for Callbacks."""

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: None | CallbackConnector = None,
        test_connector: None | CallbackConnector = None,
    ) -> None:
        """Init callback.

        Args:
            epoch_based (bool, optional): Whether the callback is epoch based.
                Defaults to True.
            train_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during training for different callbacks. Defaults
                to None.
            test_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during testing for different callbacks. Defaults
                to None.
        """
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Setup callback.

        The base implementation does nothing.

        Args:
            trainer (pl.Trainer): Trainer passed to the hook. Unused here.
            pl_module (pl.LightningModule): Module passed to the hook. Unused
                here.
            stage (str): Stage passed to the hook. Unused here.
        """

    def get_train_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for training.

        Calls the train connector with the model outputs and the batch and
        returns whatever the connector produces.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If train connector is None.
        """
        if self.train_connector is None:
            # Raised explicitly instead of via `assert` so the check is still
            # enforced when Python runs with optimizations (-O) enabled.
            raise AssertionError("Train connector is None.")
        return self.train_connector(outputs, batch)

    def get_test_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for inference.

        Calls the test connector with the model outputs and the batch and
        returns whatever the connector produces.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If test connector is None.
        """
        if self.test_connector is None:
            # Raised explicitly instead of via `assert` so the check is still
            # enforced when Python runs with optimizations (-O) enabled.
            raise AssertionError("Test connector is None.")
        return self.test_connector(outputs, batch)
|