Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from mmpose.registry import MODELS | |
class AssociativeEmbeddingLoss(nn.Module):
    """Associative Embedding loss.

    Details can be found in
    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_

    Note:

        - batch size: B
        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - embedding tag dimension: L
        - heatmap size: [W, H]

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0
        push_loss_factor (float): A factor that controls the weight between
            the push loss and the pull loss. Defaults to 0.5
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 push_loss_factor: float = 0.5) -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.push_loss_factor = push_loss_factor

    def _ae_loss_per_image(self, tags: Tensor, keypoint_indices: Tensor):
        """Compute associative embedding loss for one image.

        Args:
            tags (Tensor): Tagging heatmaps in shape (L*K, H, W). The channel
                axis is L-major, i.e. channel ``l * K + k`` holds the l-th
                tag dimension of the k-th keypoint
            keypoint_indices (Tensor): Ground-truth keypoint position indices
                in shape (N, K, 2). The last axis is ``[i, v]``: ``i`` indexes
                the flattened (H*W) heatmap and a non-zero ``v`` marks the
                keypoint as visible

        Returns:
            tuple:
            - pull_loss (Tensor): Scalar. Sum over instances of the MSE
                between each visible keypoint tag and the instance's mean
                tag, divided by the number of instances with at least one
                visible keypoint
            - push_loss (Tensor): Scalar. Sum of ``exp(-diff^2)`` over all
                ordered pairs of instance mean tags, divided by
                ``N * (N - 1)``. Zero when fewer than two instances have
                visible keypoints
        """
        K = keypoint_indices.shape[1]
        C, H, W = tags.shape
        L = C // K

        # ``reshape`` instead of ``view``: a view is returned whenever the
        # memory layout allows it, and a copy is made otherwise, so
        # non-contiguous tag maps no longer raise a RuntimeError.
        tags = tags.reshape(L, K, H * W)

        instance_tags = []
        instance_kpt_tags = []

        for keypoint_indices_n in keypoint_indices:
            # Gather the tags of all visible keypoints of this instance with
            # a single indexing operation rather than one tensor truth-test
            # per keypoint.
            visible = keypoint_indices_n[:, 1] != 0
            kpt_ids = torch.nonzero(visible, as_tuple=True)[0]
            if kpt_ids.numel() == 0:
                # instances without any visible keypoint are ignored
                continue

            positions = keypoint_indices_n[kpt_ids, 0].long()
            # (L, M) -> (M, L), where M is the number of visible keypoints
            kpt_tags = tags[:, kpt_ids, positions].t()
            instance_kpt_tags.append(kpt_tags)
            instance_tags.append(kpt_tags.mean(dim=0))

        N = len(instance_kpt_tags)  # number of instances with valid keypoints

        if N == 0:
            pull_loss = tags.new_zeros(size=(), requires_grad=True)
            push_loss = tags.new_zeros(size=(), requires_grad=True)
        else:
            # pull: draw every keypoint tag towards its instance mean tag
            pull_loss = sum(
                F.mse_loss(_kpt_tags, _tag.expand_as(_kpt_tags))
                for (_kpt_tags, _tag) in zip(instance_kpt_tags, instance_tags))

            if N == 1:
                push_loss = tags.new_zeros(size=(), requires_grad=True)
            else:
                # push: separate the mean tags of different instances.
                # NOTE: the pairwise matrix includes the self-pairs (i == j),
                # whose difference is zero, so each adds exp(0) = 1 per tag
                # dimension to the sum regardless of the tag values.
                tag_mat = torch.stack(instance_tags)  # (N, L)
                diff = tag_mat[None] - tag_mat[:, None]  # (N, N, L)
                push_loss = torch.sum(torch.exp(-diff.pow(2)))

            # normalization; eps keeps the N == 1 push denominator non-zero
            eps = 1e-6
            pull_loss = pull_loss / (N + eps)
            push_loss = push_loss / ((N - 1) * N + eps)

        return pull_loss, push_loss

    def forward(self, tags: Tensor, keypoint_indices: Union[List[Tensor],
                                                            Tensor]):
        """Compute associative embedding loss on a batch of data.

        Args:
            tags (Tensor): Tagging heatmaps in shape (B, L*K, H, W)
            keypoint_indices (Tensor|List[Tensor]): Ground-truth keypoint
                position indices represented by a Tensor in shape
                (B, N, K, 2), or a list of B Tensors in shape (N_i, K, 2)
                Each keypoint's index is represented as [i, v], where i is the
                position index in the heatmap (:math:`i=y*w+x`) and v is the
                visibility

        Returns:
            tuple:
            - pull_loss (Tensor): Per-image pull losses summed over the
                batch and scaled by ``loss_weight``
            - push_loss (Tensor): Per-image push losses summed over the
                batch and scaled by ``loss_weight * push_loss_factor``
        """
        assert tags.shape[0] == len(keypoint_indices)

        pull_loss = 0.
        push_loss = 0.

        for tags_i, keypoint_indices_i in zip(tags, keypoint_indices):
            _pull, _push = self._ae_loss_per_image(tags_i, keypoint_indices_i)
            pull_loss += _pull * self.loss_weight
            push_loss += _push * self.loss_weight * self.push_loss_factor

        return pull_loss, push_loss