Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from functools import partial | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmpose.registry import MODELS | |
| from mmpose.structures.bbox import bbox_overlaps | |
class IoULoss(nn.Module):
    """IoU loss between predicted and target bounding boxes.

    The element-wise IoU of each (output, target) pair is computed with
    ``bbox_overlaps(..., is_aligned=True)`` and clamped from below by
    ``eps``. It is then turned into a loss according to ``mode``:

    - ``'linear'``: ``1 - IoU``
    - ``'square'``: ``1 - IoU ** 2``
    - ``'log'``: ``-log(IoU)``

    Args:
        reduction (str): Options are "none", "mean" and "sum".
            Default: 'mean'.
        mode (str): Loss scaling mode, one of "linear", "square" and "log".
            Default: 'log'.
        eps (float): Lower bound applied to the IoU so that ``log(0)`` is
            never evaluated in ``'log'`` mode. Default: 1e-16.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 reduction='mean',
                 mode='log',
                 eps: float = 1e-16,
                 loss_weight=1.):
        super().__init__()
        assert reduction in ('mean', 'sum', 'none'), f'the argument ' \
            f'`reduction` should be either \'mean\', \'sum\' or \'none\', ' \
            f'but got {reduction}'
        # The message must name `mode` (not `reduction`) so the user is
        # pointed at the argument that was actually invalid.
        assert mode in ('linear', 'square', 'log'), f'the argument ' \
            f'`mode` should be either \'linear\', \'square\' or ' \
            f'\'log\', but got {mode}'
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.mode = mode
        self.eps = eps

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted bounding boxes. They are passed
                unchanged to ``bbox_overlaps``, so they must follow whatever
                box layout that function expects. Presumably shape [N, 4];
                confirm against ``bbox_overlaps``.
            target (torch.Tensor): Target bounding boxes with the same shape
                as ``output``. ``is_aligned=True`` pairs them element-wise.
            target_weight (torch.Tensor, optional): Per-box weights. Trailing
                singleton dims are appended until it has as many dims as the
                per-box loss, and then it is multiplied in. Default: None.

        Returns:
            torch.Tensor: A scalar for ``reduction='mean'`` or ``'sum'``.
            For ``'none'``, the un-reduced per-box loss. In all cases the
            result is multiplied by ``loss_weight``.
        """
        # Clamp so that -log(IoU) stays finite for non-overlapping boxes.
        ious = bbox_overlaps(
            output, target, is_aligned=True).clamp(min=self.eps)

        if self.mode == 'linear':
            loss = 1 - ious
        elif self.mode == 'square':
            loss = 1 - ious.pow(2)
        elif self.mode == 'log':
            loss = -ious.log()
        else:
            # Only reachable if the __init__ assert was stripped (python -O)
            # or `self.mode` was changed after construction.
            raise NotImplementedError

        if target_weight is not None:
            # Right-pad the weight with singleton dims so it broadcasts
            # against the per-box loss.
            for _ in range(loss.ndim - target_weight.ndim):
                target_weight = target_weight.unsqueeze(-1)
            loss = loss * target_weight

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss * self.loss_weight