| import torch |
| from torch.nn import functional as F |
|
|
@torch.jit.script
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Taken from
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions (logits, before sigmoid) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
            positive vs negative examples. Default = 0.25.
            A negative value disables the alpha weighting entirely.
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples. Default = 2.0.
        reduction: 'none' | 'mean' | 'sum'
            'none': No reduction will be applied to the output.
            'mean': The output will be averaged. An empty input yields 0
                (not NaN), matching the other losses in this module.
            'sum': The output will be summed.
            Any other value is treated like 'none'.
    Returns:
        Loss tensor with the reduction option applied.
    """
    # Compute in float32 so that half-precision or integer inputs/labels work.
    inputs = inputs.float()
    targets = targets.float()

    p = torch.sigmoid(inputs)
    # Element-wise BCE computed from logits for numerical stability.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: the probability the model assigns to the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Down-weight well-classified examples (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        # mean() of an empty tensor is NaN; fall back to a zero computed from
        # loss so the result is still a tensor derived from it.
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
|
|
|
|
@torch.jit.script
def ctr_giou_loss_1d(
    input_offsets: torch.Tensor,
    target_offsets: torch.Tensor,
    reduction: str = 'none',
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Generalized IoU loss (Rezatofighi et al., https://arxiv.org/abs/1902.09630)
    for 1D segments that share a center point.

    Each segment is encoded by its distances from a common center ``c``:
    ``(t1, t2) = (c - o_1, c + o_2)`` with ``o_i >= 0``. Both segments contain
    ``c``, so they always overlap. Their union therefore equals the smallest
    enclosing segment, the GIoU penalty term is identically zero, and the
    loss reduces to ``1 - IoU``.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py

    Args:
        input_offsets (Tensor): predicted (left, right) offsets, read as
            columns ``[:, 0]`` and ``[:, 1]`` (expected size (N, 2)).
        target_offsets (Tensor): ground-truth offsets in the same layout.
        reduction: 'none' | 'mean' | 'sum'
            'none': return the per-segment loss.
            'mean': average over segments; an empty input yields 0, not NaN.
            'sum': sum over segments.
            Any other value behaves like 'none'.
        eps (float): lower bound on the union length, preventing division
            by zero.

    Returns:
        Loss tensor with the requested reduction applied.

    Fails an ``assert`` if any predicted or ground-truth offset is negative.
    """
    pred = input_offsets.float()
    gt = target_offsets.float()

    assert torch.all(pred >= 0.0), "predicted offsets must be non-negative"
    assert torch.all(gt >= 0.0), "GT offsets must be non-negative"

    pred_left, pred_right = pred[:, 0], pred[:, 1]
    gt_left, gt_right = gt[:, 0], gt[:, 1]

    # On each side of the shared center, the overlap reaches as far as the
    # shorter of the two offsets.
    overlap = torch.minimum(pred_right, gt_right) + torch.minimum(pred_left, gt_left)
    union = (pred_left + pred_right) + (gt_left + gt_right) - overlap
    iou = overlap / union.clamp(min=eps)

    # The GIoU enclosing-segment penalty is zero here, so only IoU remains.
    loss = 1.0 - iou

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        if loss.numel() == 0:
            return 0.0 * loss.sum()
        return loss.mean()
    return loss
|
|
@torch.jit.script
def ctr_diou_loss_1d(
    input_offsets: torch.Tensor,
    target_offsets: torch.Tensor,
    reduction: str = 'none',
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Distance-IoU loss (Zheng et al., https://arxiv.org/abs/1911.08287) for
    1D segments that share a center point.

    Each segment is encoded by its distances from a common center ``c``:
    ``(t1, t2) = (c - o_1, c + o_2)`` with ``o_i >= 0``. The loss is
    ``1 - IoU + (d / L)^2``, where ``d`` is the signed distance between the
    predicted and ground-truth segment midpoints and ``L`` is the length of
    the smallest segment enclosing both.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py

    Args:
        input_offsets (Tensor): predicted (left, right) offsets, read as
            columns ``[:, 0]`` and ``[:, 1]`` (expected size (N, 2)).
        target_offsets (Tensor): ground-truth offsets in the same layout.
        reduction: 'none' | 'mean' | 'sum'
            'none': return the per-segment loss.
            'mean': average over segments; an empty input yields 0, not NaN.
            'sum': sum over segments.
            Any other value behaves like 'none'.
        eps (float): lower bound on the union and enclosing lengths,
            preventing division by zero.

    Returns:
        Loss tensor with the requested reduction applied.

    Fails an ``assert`` if any predicted or ground-truth offset is negative.
    """
    pred = input_offsets.float()
    gt = target_offsets.float()

    assert torch.all(pred >= 0.0), "predicted offsets must be non-negative"
    assert torch.all(gt >= 0.0), "GT offsets must be non-negative"

    pred_left, pred_right = pred[:, 0], pred[:, 1]
    gt_left, gt_right = gt[:, 0], gt[:, 1]

    # IoU term: on each side of the center the overlap is the shorter offset.
    overlap = torch.minimum(pred_right, gt_right) + torch.minimum(pred_left, gt_left)
    union = (pred_left + pred_right) + (gt_left + gt_right) - overlap
    iou = overlap / union.clamp(min=eps)

    # Enclosing segment: the longer offset on each side of the center.
    enclosing = torch.maximum(pred_left, gt_left) + torch.maximum(pred_right, gt_right)

    # Midpoints relative to c are (right - left) / 2 for each segment, so
    # their difference is 0.5 * ((pred_right - pred_left) - (gt_right - gt_left)).
    center_gap = 0.5 * (pred_right - pred_left - gt_right + gt_left)
    distance_penalty = torch.square(center_gap / enclosing.clamp(min=eps))

    loss = 1.0 - iou + distance_penalty

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        if loss.numel() == 0:
            return 0.0 * loss.sum()
        return loss.mean()
    return loss
|
|