import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from mmdet.registry import MODELS
from .accuracy import accuracy
from .utils import weight_reduce_loss


def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
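
    Example (an illustrative sketch; shapes and values are assumed, not
    part of the API contract):
        >>> pred = torch.randn(4, 3)                      # logits, (N, C)
        >>> target = torch.randint(0, 2, (4, 3)).float()  # one-hot targets
        >>> loss = py_sigmoid_focal_loss(pred, target)    # scalar ('mean')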
| """ |
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability assigned to the wrong class, so hard,
    # misclassified examples receive a large modulating factor pt ** gamma.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ), which
                # means it does not have the second axis num_class.
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors * num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from `py_sigmoid_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            where C is the number of classes.
        target (torch.Tensor): The learning label of the prediction.
            The target shape supports (N, C) or (N, ), where (N, C) means
            one-hot form.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
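
    Example (an illustrative sketch; shapes and values are assumed):
        >>> prob = torch.rand(4, 3)             # probabilities in [0, 1)
        >>> labels = torch.randint(0, 3, (4,))  # class indices, shape (N, )
        >>> loss = py_focal_loss_with_prob(prob, labels)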
| """ |
    if pred.dim() != target.dim():
        # Class-index targets: convert to one-hot form. The background label
        # equals num_classes, so one extra column is produced and dropped.
        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes + 1)
        target = target[:, :num_classes]

    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ), which
                # means it does not have the second axis num_class.
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors * num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
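
    Example (an illustrative sketch; requires a CUDA device and a CUDA
    build of mmcv, and the shapes and label values are assumed):
        >>> pred = torch.randn(4, 3, device='cuda')           # logits
        >>> target = torch.randint(0, 3, (4,), device='cuda')  # class ids
        >>> loss = sigmoid_focal_loss(pred, target)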
| """ |
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable here.
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ), which
                # means it does not have the second axis num_class.
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors * num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@MODELS.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none", "mean"
                and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as
                logits. Defaults to False.
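
        Example (an illustrative sketch; shapes, the class count and the
        background label are assumed):
            >>> loss_cls = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
            >>> pred = torch.randn(8, 20)            # logits, (N, C)
            >>> target = torch.randint(0, 21, (8,))  # 20 is the bg label
            >>> loss = loss_cls(pred, target)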
| """ |
| super(FocalLoss, self).__init__() |
| assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' |
| self.use_sigmoid = use_sigmoid |
| self.gamma = gamma |
| self.alpha = alpha |
| self.reduction = reduction |
| self.loss_weight = loss_weight |
| self.activated = activated |
|
|
    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
                The target shape supports (N, C) or (N, ), where (N, C)
                means one-hot form.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if pred.dim() == target.dim():
                    # this means that target is already in one-hot form
                    calculate_loss_func = py_sigmoid_focal_loss
                elif torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls


@MODELS.register_module()
class FocalCustomLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 num_classes=-1,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss for V3Det <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            num_classes (int): Number of classes to classify.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none", "mean"
                and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as
                logits. Defaults to False.
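
        Example (an illustrative sketch; shapes, the class count and the
        background label are assumed):
            >>> loss_cls = FocalCustomLoss(use_sigmoid=True, num_classes=20)
            >>> pred = torch.randn(8, 20)            # logits, (N, C)
            >>> target = torch.randint(0, 21, (8,))  # 20 is the bg label
            >>> loss = loss_cls(pred, target)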
| """ |
| super(FocalCustomLoss, self).__init__() |
| assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' |
| self.use_sigmoid = use_sigmoid |
| self.num_classes = num_classes |
| self.gamma = gamma |
| self.alpha = alpha |
| self.reduction = reduction |
| self.loss_weight = loss_weight |
| self.activated = activated |
|
|
| assert self.num_classes != -1 |
|
|
| |
| self.custom_cls_channels = True |
| |
| self.custom_activation = True |
| |
| self.custom_accuracy = True |

    def get_cls_channels(self, num_classes):
        """Return the channel number of classification outputs."""
        assert num_classes == self.num_classes
        return num_classes

    def get_activation(self, cls_score):
        """Convert raw classification scores to per-class probabilities."""
        fine_cls_score = cls_score[:, :self.num_classes]
        score_classes = fine_cls_score.sigmoid()
        return score_classes

    def get_accuracy(self, cls_score, labels):
        """Compute classification accuracy on foreground samples only."""
        fine_cls_score = cls_score[:, :self.num_classes]
        # Labels equal to ``num_classes`` denote background and are ignored.
        pos_inds = labels < self.num_classes
        acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_classes'] = acc_classes
        return acc

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            # Targets are class indices; convert them to one-hot form and
            # drop the extra background column (label == num_classes).
            num_classes = pred.size(1)
            target = F.one_hot(target, num_classes=num_classes + 1)
            target = target[:, :num_classes]
            calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls