|
|
from mmpose.models import KeypointMSELoss
from mmpose.registry import MODELS
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class ExampleLoss(KeypointMSELoss):
    """Example loss that subclasses ``KeypointMSELoss``.

    The class is registered in the ``MODELS`` registry through the
    ``register_module`` decorator. Implement the loss just like a normal
    PyTorch module: configure it in ``__init__`` and compute it in
    ``forward``. This example adds no logic of its own and simply defers
    to the parent class.
    """

    def __init__(self, **kwargs) -> None:
        """Announce construction, then initialize the parent loss.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to
                ``KeypointMSELoss.__init__``.
        """
        # The message is emitted before the parent initializer runs.
        print('Initializing ExampleLoss...')
        super().__init__(**kwargs)

    def forward(self, output, target, target_weights=None, mask=None):
        """Forward function of loss. The input arguments should match those
        given in the ``head.loss`` function.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of
                different keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise). Defaults to ``None``
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask
                will be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss.
        """
        # Pure delegation: all four arguments go to the parent positionally,
        # in the same order as this method's signature.
        return super().forward(output, target, target_weights, mask)
|
|
|