Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from mmpose.registry import MODELS | |
class KeypointMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_keypoint: bool = False,
                per_pixel: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_keypoint (bool): If ``True`` (and ``per_pixel`` is ``False``),
                average only over the spatial dims and return a [B, K]
                tensor. Defaults to ``False``
            per_pixel (bool): If ``True``, return the unreduced element-wise
                loss. Takes precedence over ``per_keypoint``. Defaults to
                ``False``

        Returns:
            Tensor: The calculated loss (a scalar unless ``per_keypoint`` or
            ``per_pixel`` is set).
        """
        _mask = self._get_mask(target, target_weights, mask)

        # Always start from the element-wise loss so that ``loss`` is defined
        # even when no mask is produced.
        loss = F.mse_loss(output, target, reduction='none')
        if _mask is not None:
            loss = loss * _mask

        if per_pixel:
            # keep the element-wise loss as is
            pass
        elif per_keypoint:
            loss = loss.mean(dim=(2, 3))
        else:
            loss = loss.mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} v.s.'
                    f'{target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # append singleton dims so the weights broadcast over the target
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.view(target_weights.shape +
                                        (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero value
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        return mask
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_channels: C
            - heatmaps height: H
            - heatmaps width: W
            - num_keypoints: K
            Here, C = 3 * K

        Args:
            output (Tensor): The output feature maps with shape [B, C, H, W].
            target (Tensor): The target feature maps with shape [B, C, H, W].
            target_weights (Tensor): The target weights of different keypoints,
                with shape [B, K]. Only read when ``use_target_weight`` is
                ``True``.

        Returns:
            Tensor: The calculated loss.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        # Flatten the spatial dims: [B, C, H*W]
        preds = output.reshape((batch_size, num_channels, -1))
        gts = target.reshape((batch_size, num_channels, -1))

        loss = 0.
        num_joints = num_channels // 3
        for idx in range(num_joints):
            base = idx * 3
            # Index the channel axis directly so every slice is exactly
            # [B, H*W]. A bare ``squeeze()`` would also drop the batch or
            # spatial dim when it has size 1 and break the broadcasting with
            # the [B, 1] target weight below.
            heatmap_pred = preds[:, base]
            heatmap_gt = gts[:, base]
            offset_x_pred = preds[:, base + 1]
            offset_x_gt = gts[:, base + 1]
            offset_y_pred = preds[:, base + 2]
            offset_y_gt = gts[:, base + 2]

            if self.use_target_weight:
                target_weight = target_weights[:, idx, None]
                heatmap_pred = heatmap_pred * target_weight
                heatmap_gt = heatmap_gt * target_weight

            # classification loss
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # regression loss, weighted by the (possibly re-weighted)
            # ground-truth response map
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)

        return loss / num_joints * self.loss_weight
class KeypointOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        topk (int): Only top k joint losses are kept. Defaults to 8
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 topk: int = 8,
                 loss_weight: float = 1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, losses: Tensor) -> Tensor:
        """Online hard keypoint mining.

        Note:
            - batch_size: B
            - num_keypoints: K

        Args:
            losses (Tensor): The losses with shape [B, K]

        Returns:
            Tensor: The mean over the batch of each sample's average top-k
            keypoint loss.
        """
        # Select the k largest keypoint losses of every sample in one call
        # instead of looping over the batch.
        topk_losses, _ = torch.topk(losses, k=self.topk, dim=1, sorted=False)
        return (topk_losses.sum(dim=1) / self.topk).mean()

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W].
            target (Tensor): The target heatmaps with shape [B, K, H, W].
            target_weights (Tensor): The target weights of different keypoints,
                with shape [B, K]. Only read when ``use_target_weight`` is
                ``True``.

        Returns:
            Tensor: The calculated loss.

        Raises:
            ValueError: If ``topk`` is larger than the number of keypoints.
        """
        num_keypoints = output.size(1)
        if num_keypoints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not be '
                             f'larger than num_keypoints ({num_keypoints}).')

        if self.use_target_weight:
            # [B, K] -> [B, K, 1, 1] so the weights broadcast over H and W
            weight = target_weights[:, :, None, None]
            losses = self.criterion(output * weight, target * weight)
        else:
            losses = self.criterion(output, target)

        # Per-keypoint loss: average over the spatial dims -> [B, K]
        losses = losses.mean(dim=(2, 3))
        return self._ohkm(losses) * self.loss_weight
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
    Alignment via Heatmap Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: The mean of the element-wise loss.
        """
        delta = (target - pred).abs()

        # Terms shared by A and C are computed once.
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        A = self.omega * (1 / (1 + ratio_pow)) * exponent * (torch.pow(
            ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        # Logarithmic branch for errors below theta, linear branch otherwise.
        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 + torch.pow(delta / self.epsilon, exponent)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[N, K, H, W]): Output heatmaps.
            target (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weights (torch.Tensor[N, K] or [N, K, H, W], optional):
                Weights across different joint types. Required when
                ``use_target_weight`` is ``True``.

        Returns:
            torch.Tensor: The calculated loss.

        Raises:
            ValueError: If ``use_target_weight`` is ``True`` but
                ``target_weights`` is ``None``.
        """
        if self.use_target_weight:
            if target_weights is None:
                raise ValueError('target_weights must be provided when '
                                 'use_target_weight is True')
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # append singleton dims so the weights broadcast over the heatmap
            ndim_pad = target.ndim - target_weights.ndim
            target_weights = target_weights.view(target_weights.shape +
                                                 (1, ) * ndim_pad)
            loss = self.criterion(output * target_weights,
                                  target * target_weights)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
class FocalHeatmapLoss(KeypointMSELoss):
    """A class for calculating the modified focal loss for heatmap prediction.

    This loss function is exactly the same as the one used in CornerNet. It
    runs faster and costs a little bit more memory.

    `CornerNet: Detecting Objects as Paired Keypoints
    arXiv: <https://arxiv.org/abs/1808.01244>`_.

    Arguments:
        alpha (int): The alpha parameter in the focal loss equation.
        beta (int): The beta parameter in the focal loss equation.
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha: int = 2,
                 beta: int = 4,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.0):
        super().__init__(use_target_weight, skip_empty_channel, loss_weight)
        self.alpha = alpha
        self.beta = beta

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Calculate the modified focal loss for heatmap prediction.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss.
        """
        _mask = self._get_mask(target, target_weights, mask)

        # Positive pixels are those whose target is exactly 1; every other
        # pixel is treated as negative.
        pos_inds = target.eq(1).float()
        neg_inds = target.lt(1).float()

        if _mask is not None:
            pos_inds = pos_inds * _mask
            neg_inds = neg_inds * _mask

        # Keep the prediction strictly inside (0, 1). Otherwise a saturated
        # prediction gives ``log(0) = -inf`` and ``-inf * 0`` turns the whole
        # loss into NaN, even on pixels whose indicator is zero.
        prob = output
        if output.is_floating_point():
            eps = torch.finfo(output.dtype).eps
            prob = output.clamp(min=eps, max=1 - eps)

        # Negatives close to a positive peak are down-weighted.
        neg_weights = torch.pow(1 - target, self.beta)

        pos_loss = torch.log(prob) * torch.pow(1 - prob,
                                               self.alpha) * pos_inds
        neg_loss = torch.log(1 - prob) * torch.pow(
            prob, self.alpha) * neg_weights * neg_inds

        num_pos = pos_inds.sum()
        if num_pos == 0:
            # No positives: only the negative term contributes and there is
            # nothing to normalize by.
            loss = -neg_loss.sum()
        else:
            loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss * self.loss_weight
class MLECCLoss(nn.Module):
    """Maximum-likelihood loss for coordinate classification.

    The likelihood of a target is the product, over all given axes, of the
    predicted distribution summed against the target distribution. The loss
    decreases as that likelihood grows.

    Args:
        reduction (str): How to reduce the per-keypoint loss:
            'none' | 'mean' | 'sum'. Default: 'mean'.
        mode (str): How the likelihood is turned into a loss:
            'linear' | 'square' | 'log'. Default: 'log'.
        use_target_weight (bool): If True, the loss is multiplied by the
            given target weights. Defaults to False.
        loss_weight (float): Weight of the loss. Defaults to 1.0.

    Raises:
        AssertionError: If the `reduction` or `mode` arguments are not in the
            expected choices.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 mode: str = 'log',
                 use_target_weight: bool = False,
                 loss_weight: float = 1.0):
        super().__init__()
        assert reduction in ('mean', 'sum', 'none'), \
            f"`reduction` should be either 'mean', 'sum', or 'none', " \
            f'but got {reduction}'
        assert mode in ('linear', 'square', 'log'), \
            f"`mode` should be either 'linear', 'square', or 'log', " \
            f'but got {mode}'

        self.loss_weight = loss_weight
        self.use_target_weight = use_target_weight
        self.mode = mode
        self.reduction = reduction

    def forward(self, outputs, targets, target_weight=None):
        """Compute the loss.

        Args:
            outputs (Sequence[torch.Tensor]): Predicted distributions, one
                tensor per axis.
            targets (Sequence[torch.Tensor]): Target distributions, one
                tensor per axis, matching ``outputs``.
            target_weight (torch.Tensor, optional): Weights multiplied onto
                the loss when ``use_target_weight`` is set.

        Returns:
            torch.Tensor: The loss after the configured reduction ('mean' and
            'sum' reduce over every dim except the first), scaled by
            ``loss_weight``.
        """
        assert len(outputs) == len(targets), \
            'Outputs and targets must have the same length'

        # Joint likelihood: product over axes of <prediction, target>.
        prob = 1.0
        for pred, gt in zip(outputs, targets):
            prob *= (pred * gt).sum(dim=-1)

        if self.mode == 'log':
            loss = -torch.log(prob + 1e-4)
        elif self.mode == 'square':
            loss = 1.0 - prob.pow(2)
        elif self.mode == 'linear':
            loss = 1.0 - prob

        # NaN entries contribute nothing to the loss.
        nan_entries = torch.isnan(loss)
        loss = loss.masked_fill(nan_entries, 0.0)

        if self.use_target_weight:
            assert target_weight is not None
            # Append trailing singleton dims until the ranks match.
            missing = loss.ndim - target_weight.ndim
            while missing > 0:
                target_weight = target_weight.unsqueeze(-1)
                missing -= 1
            loss = loss * target_weight

        if self.reduction == 'mean':
            loss = loss.flatten(1).mean(dim=1)
        elif self.reduction == 'sum':
            loss = loss.flatten(1).sum(dim=1)

        return loss * self.loss_weight
class OKSHeatmapLoss(nn.Module):
    """OKS-based loss for heatmaps.

    The loss is a weighted sum of three terms: an OKS term built from the
    product of the prediction and the target, a smoothness term built from
    the Sobel gradient of the prediction, and an MSE term.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        smoothing_weight (float): Weight of the smoothness (gradient) term.
            Defaults to 0.2
        gaussian_weight (float): Weight of the MSE term. Defaults to 0.0
        loss_weight (float): Weight of the loss. Defaults to 1.0
        oks_type (str): Which OKS term to use (case-insensitive):
            ``"minus"`` is ``output * (1 - target)``, ``"plus"`` is
            ``(1 - output) * target`` and ``"both"`` is their average.
            Defaults to ``"minus"``
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 smoothing_weight: float = 0.2,
                 gaussian_weight: float = 0.0,
                 loss_weight: float = 1.,
                 oks_type: str = "minus"):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight
        self.smoothing_weight = smoothing_weight
        self.gaussian_weight = gaussian_weight
        self.oks_type = oks_type.lower()

        assert self.oks_type in ["minus", "plus", "both"]

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_pixel: bool = False,
                per_keypoint: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W],
                with values in [0, 1]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_pixel (bool): If ``True``, return the element-wise loss with
                shape [B, K, H, W]. Takes precedence over ``per_keypoint``.
                Defaults to ``False``
            per_keypoint (bool): If ``True``, return one value per keypoint
                with shape [B, K]. Defaults to ``False``

        Returns:
            Tensor: The calculated loss.
        """
        assert target.max() <= 1, 'target should be normalized'
        assert target.min() >= 0, 'target should be normalized'

        B, K, H, W = output.shape

        _mask = self._get_mask(target, target_weights, mask)

        if self.oks_type == "both":
            oks = (output * (1 - target) + (1 - output) * target) / 2
        elif self.oks_type == "minus":
            oks = output * (1 - target)
        elif self.oks_type == "plus":
            oks = (1 - output) * target
        else:
            raise ValueError(f"oks_type {self.oks_type} not recognized")

        mse = F.mse_loss(output, target, reduction='none')

        # Smoothness loss: squared Sobel gradient magnitude of the prediction.
        # The kernels are created with the prediction's dtype and device so
        # that conv2d also works for non-float32 inputs.
        sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]],
                               dtype=output.dtype,
                               device=output.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
                               dtype=output.dtype,
                               device=output.device).view(1, 1, 3, 3)
        # Every heatmap is filtered independently as a one-channel image.
        flat_output = output.reshape(B * K, 1, H, W)
        gradient_x = F.conv2d(flat_output, sobel_x, padding='same')
        gradient_y = F.conv2d(flat_output, sobel_y, padding='same')
        gradient = (gradient_x**2 + gradient_y**2).reshape(B, K, H, W)

        if _mask is not None:
            oks = oks * _mask
            mse = mse * _mask
            gradient = gradient * _mask

        # The OKS term takes whatever weight the other two terms leave.
        oks_weight = 1 - self.smoothing_weight - self.gaussian_weight

        if per_pixel:
            loss = (
                self.smoothing_weight * gradient +
                oks_weight * oks +
                self.gaussian_weight * mse
            )
        else:
            # Keypoint level: summed OKS, peak gradient and mean MSE.
            max_gradient, _ = gradient.reshape((B, K, H * W)).max(dim=-1)
            loss = (
                oks_weight * oks.sum(dim=(2, 3)) +
                self.smoothing_weight * max_gradient +
                self.gaussian_weight * mse.mean(dim=(2, 3))
            )
            if not per_keypoint:
                loss = loss.mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} v.s.'
                    f'{target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # append singleton dims so the weights broadcast over the target
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.view(target_weights.shape +
                                        (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero value
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        return mask
class CalibrationLoss(nn.Module):
    """Calibration loss for heatmaps.

    The loss is the negative log of the predicted heatmap multiplied by the
    target heatmap, either per pixel or after summing the product over the
    spatial dims of every keypoint.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
        ignore_bottom_percentile (float): Stored on the module but not used
            by ``forward``. Defaults to 0.7
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.,
                 ignore_bottom_percentile: float = 0.7):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight
        self.ignore_bottom_percentile = ignore_bottom_percentile

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_pixel: bool = False,
                per_keypoint: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W],
                with values in [0, 1]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_pixel (bool): If ``True``, return the element-wise loss with
                shape [B, K, H, W]. Takes precedence over ``per_keypoint``.
                Defaults to ``False``
            per_keypoint (bool): If ``True``, return one value per keypoint
                with shape [B, K]. Defaults to ``False``

        Returns:
            Tensor: The calculated loss.
        """
        assert target.max() <= 1, 'target should be normalized'
        assert target.min() >= 0, 'target should be normalized'

        _mask = self._get_mask(target, target_weights, mask)

        pred_probs = output * target

        if per_pixel:
            loss = -torch.log(pred_probs + 1e-10)
            if _mask is not None:
                loss = loss * _mask
            return loss * self.loss_weight

        # Keypoint-level cross entropy with shape [B, K]
        loss = -torch.log(pred_probs.sum(dim=(2, 3)) + 1e-10)
        if _mask is not None:
            # The mask is 4-D. Bring it to keypoint level first; multiplying
            # [B, K] by [B, K, 1, 1] directly would broadcast to [B, K, B, K].
            loss = loss * self._keypoint_weights(_mask, target, loss.dtype)

        if not per_keypoint:
            loss = loss.mean()

        return loss * self.loss_weight

    @staticmethod
    def _keypoint_weights(mask: Tensor, target: Tensor,
                          dtype: torch.dtype) -> Tensor:
        """Reduce a 4-D mask to one weight per keypoint.

        A mask without spatial extent is simply reshaped to 2-D. A spatial
        mask is broadcast to the target shape and averaged over H and W.
        """
        mask = mask.to(dtype)
        if all(dim == 1 for dim in mask.shape[2:]):
            return mask.reshape(mask.shape[:2])
        return mask.expand_as(target).flatten(2).mean(dim=2)

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} v.s.'
                    f'{target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # append singleton dims so the weights broadcast over the target
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.view(target_weights.shape +
                                        (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero value
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)
            mask = _mask if mask is None else mask * _mask

        return mask