| |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
| from mmdet.models.losses.utils import weighted_loss |
| from mmdet.registry import MODELS |
|
|
|
|
@weighted_loss
def huber_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
    """Compute the element-wise Huber loss between ``pred`` and ``target``.

    With the absolute residual ``d = |pred - target|``, each element maps to
    ``0.5 * d**2`` when ``d < beta`` and to ``beta * d - 0.5 * beta**2``
    otherwise. The two pieces agree in value and slope at ``d == beta``.
    This function takes no weight/reduction arguments itself. The
    ``weighted_loss`` decorator wraps it, and ``HuberLoss.forward`` calls it
    with ``weight``, ``reduction`` and ``avg_factor`` as well.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target; must have the same size as
            ``pred`` unless it is empty.
        beta (float, optional): Threshold at which the loss switches from
            quadratic to linear. Must be positive. Defaults to 1.0.

    Returns:
        Tensor: Element-wise loss with the same size as ``pred``, or the
        scalar ``pred.sum() * 0`` when ``target`` has no elements.
    """
    assert beta > 0

    # Empty target: return a zero derived from ``pred`` so the result stays
    # connected to ``pred``'s autograd graph instead of being a constant.
    if not target.numel():
        return pred.sum() * 0

    assert pred.size() == target.size()
    residual = (pred - target).abs()
    quadratic_part = 0.5 * residual * residual
    linear_part = beta * residual - 0.5 * beta * beta
    return torch.where(residual < beta, quadratic_part, linear_part)
|
|
|
|
@MODELS.register_module()
class HuberLoss(nn.Module):
    """Huber loss wrapped as a registry-buildable ``nn.Module``.

    Args:
        beta (float, optional): Threshold at which the loss switches from
            quadratic to linear; forwarded to ``huber_loss``.
            Defaults to 1.0.
        reduction (str, optional): Default method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): Scalar multiplier applied to the
            computed loss. Defaults to 1.0.
    """

    def __init__(self,
                 beta: float = 1.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        # Stored as-is; they are only read later in ``forward``.
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Compute the scaled Huber loss for a batch of predictions.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): Per-prediction loss weight, passed
                through to ``huber_loss``. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): If given (and non-empty),
                replaces ``self.reduction`` for this call only. Must be one
                of None, "none", "mean" or "sum". Defaults to None.
            **kwargs: Extra keyword arguments forwarded to ``huber_loss``.

        Returns:
            Tensor: The output of ``huber_loss`` multiplied by
            ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')

        # A truthy per-call override wins over the configured reduction.
        reduction = reduction_override or self.reduction

        scale = self.loss_weight
        unscaled = huber_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return scale * unscaled
|
|