File size: 1,418 Bytes
7bed60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from mmpose.models import KeypointMSELoss
from mmpose.registry import MODELS


# Register your loss in the `MODELS` registry.
@MODELS.register_module()
class ExampleLoss(KeypointMSELoss):
    """Implements an example loss.

    Implement the loss just like a normal PyTorch module. This example
    subclasses :class:`KeypointMSELoss` and delegates all computation to it,
    so it can serve as a template for writing a custom loss.
    """

    def __init__(self, **kwargs) -> None:
        """Initialize the loss.

        Args:
            **kwargs: Forwarded unchanged to ``KeypointMSELoss.__init__``.
        """
        # Printed only to make it visible when an instance is constructed.
        print('Initializing ExampleLoss...')
        super().__init__(**kwargs)

    def forward(self, output, target, target_weights=None, mask=None):
        """Forward function of loss. The input arguments should match those
        given in `head.loss` function.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``.

        Returns:
            Tensor: The calculated loss, as computed by
            ``KeypointMSELoss.forward``.
        """
        # No custom logic: the parent MSE loss does all the work.
        return super().forward(output, target, target_weights, mask)