Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """Compute the sigmoid focal loss used in RetinaNet for dense detection.

    See https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding the raw
            (pre-sigmoid) predictions for each example.
        targets: A float tensor with the same shape as ``inputs``. Stores the
            binary classification label for each element in ``inputs``
            (0 for the negative class and 1 for the positive class).
        alpha: Weighting factor in range (0, 1) to balance positive vs
            negative examples. Any negative value (conventionally ``-1``)
            disables the weighting. Default: ``0.25``.
        gamma: Exponent of the modulating factor ``(1 - p_t)`` to balance
            easy vs hard examples. Default: ``2``.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'``: No reduction will be applied to the output.
            ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``.

    Returns:
        Loss tensor with the reduction option applied: same shape as
        ``inputs`` for ``'none'``, a scalar tensor otherwise.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``
            or ``'sum'``.
    """
    # Original implementation from
    # https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    # Validate up front so an invalid mode fails before any tensor work.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # p_t is the predicted probability of the ground-truth class.
    p_t = p * targets + (1 - p) * (1 - targets)

    # Down-weight well-classified (easy) examples.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # Class-balance term: alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss