weikaih's picture
WildDet3D Gradio demo
f71ac1d verified
"""Base module for callbacks."""
from __future__ import annotations
import lightning.pytorch as pl
from torch import Tensor
from vis4d.common.typing import DictStrArrNested
from vis4d.data.typing import DictData
from vis4d.engine.connectors import CallbackConnector
class Callback(pl.Callback):
    """Base class for Callbacks.

    Holds whether the callback is epoch based and the optional connectors
    that extract callback inputs from model outputs and batch data during
    training and testing.
    """

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: None | CallbackConnector = None,
        test_connector: None | CallbackConnector = None,
    ) -> None:
        """Init callback.

        Args:
            epoch_based (bool, optional): Whether the callback is epoch based.
                Defaults to True.
            train_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during training for different callbacks. Defaults
                to None.
            test_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during testing for different callbacks. Defaults
                to None.
        """
        super().__init__()
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Setup callback.

        No-op in this base class; subclasses may override it to prepare state
        before `stage` runs.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The module being trained/tested.
            stage (str): The stage being set up.
        """

    def get_train_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for training.

        It extracts the required data from the model outputs and the batch
        data and passes it to the next component under the provided new key.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If train connector is None.
        """
        if self.train_connector is None:
            # Explicit raise instead of `assert` so the check is not stripped
            # when Python runs with -O (which would otherwise surface as a
            # confusing "NoneType is not callable" TypeError).
            raise AssertionError("Train connector is None.")
        return self.train_connector(outputs, batch)

    def get_test_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for inference.

        It extracts the required data from the model outputs and the batch
        data and passes it to the next component under the provided new key.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If test connector is None.
        """
        if self.test_connector is None:
            # Explicit raise instead of `assert` so the check survives -O.
            raise AssertionError("Test connector is None.")
        return self.test_connector(outputs, batch)