File size: 2,924 Bytes
7bed60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from mmpose.models import HeatmapHead
from mmpose.registry import MODELS


# Register the head in the `MODELS` registry.
@MODELS.register_module()
class ExampleHead(HeatmapHead):
    """Implements an example head.

    Implement the model head just like a normal pytorch module. This class
    subclasses ``HeatmapHead`` and is registered in ``MODELS``. Every method
    here delegates to the parent implementation, so the class serves as a
    template: override the methods whose behaviour should differ.
    """

    def __init__(self, **kwargs) -> None:
        """Initialize the head.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to
                ``HeatmapHead.__init__``.
        """
        print('Initializing ExampleHead...')
        super().__init__(**kwargs)

    def forward(self, feats):
        """Forward the network. The input is multi scale feature maps; the
        output is whatever ``HeatmapHead.forward`` produces for them.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: The result of ``HeatmapHead.forward`` (presumably the
            predicted heatmaps -- confirm against the parent class).
        """
        return super().forward(feats)

    def predict(self, feats, batch_data_samples, test_cfg=None):
        """Predict results from outputs. The behaviour of head during testing
        should be defined in this function.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): A list of
                data samples for instances in a batch
            test_cfg (dict, optional): The runtime config for testing
                process. Defaults to None, which is treated as an empty
                dict.

        Returns:
            Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
            ``test_cfg['output_heatmap']==True``, return both pose and heatmap
            prediction; otherwise only return the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)

            The heatmap prediction is a list of ``PixelData``, each contains
            the following fields:

                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
        """
        # Build a fresh dict per call instead of sharing one mutable default
        # object across every invocation.
        if test_cfg is None:
            test_cfg = {}
        return super().predict(feats, batch_data_samples, test_cfg)

    def loss(self, feats, batch_data_samples, train_cfg=None) -> dict:
        """Calculate losses from a batch of inputs and data samples. The
        behaviour of head during training should be defined in this function.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): A list of
                data samples for instances in a batch
            train_cfg (dict, optional): The runtime config for training
                process. Defaults to None, which is treated as an empty
                dict.

        Returns:
            dict: A dictionary of losses.
        """
        # Build a fresh dict per call instead of sharing one mutable default
        # object across every invocation.
        if train_cfg is None:
            train_cfg = {}
        return super().loss(feats, batch_data_samples, train_cfg)