File size: 2,823 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Base module for callbacks."""

from __future__ import annotations

import lightning.pytorch as pl
from torch import Tensor

from vis4d.common.typing import DictStrArrNested
from vis4d.data.typing import DictData
from vis4d.engine.connectors import CallbackConnector


class Callback(pl.Callback):
    """Base class for Callbacks.

    Stores an ``epoch_based`` flag plus optional train/test connectors, and
    provides helpers that apply those connectors to model outputs and batch
    data.
    """

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: None | CallbackConnector = None,
        test_connector: None | CallbackConnector = None,
    ) -> None:
        """Init callback.

        Args:
            epoch_based (bool, optional): Whether the callback is epoch based.
                Defaults to True.
            train_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during training for different callbacks. Defaults
                to None.
            test_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during testing for different callbacks. Defaults
                to None.
        """
        super().__init__()
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Setup callback.

        Does nothing by default; subclasses may override it.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The Lightning module.
            stage (str): The stage string passed in by Lightning.
        """

    def get_train_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for training.

        Calls the train connector with the model outputs and the batch and
        returns whatever dictionary it produces.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If train connector is None.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if self.train_connector is None:
            raise AssertionError("Train connector is None.")

        return self.train_connector(outputs, batch)

    def get_test_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for inference.

        Calls the test connector with the model outputs and the batch and
        returns whatever dictionary it produces.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If test connector is None.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if self.test_connector is None:
            raise AssertionError("Test connector is None.")

        return self.test_connector(outputs, batch)