| | |
| | from typing import Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from mmengine.model import BaseModule |
| | from torch import Tensor |
| |
|
| | from mmdet.registry import MODELS |
| | from .mse_loss import mse_loss |
| |
|
| |
|
@MODELS.register_module()
class MarginL2Loss(BaseModule):
    """L2 (MSE) loss with similarity margins and optional negative mining.

    The loss is ``loss_weight * mse_loss(pred, target)`` where, before the
    MSE is computed, ``update_weight`` may (1) shift predictions by a margin
    for positive/negative samples, (2) subsample negatives so the
    negative:positive ratio does not exceed ``neg_pos_ub`` (randomly, or by
    hard mining on per-element MSE cost), and (3) zero out the weight of the
    discarded negatives.

    Args:
        neg_pos_ub (int, optional): The upper bound of negative to positive
            samples in hard mining. A non-positive value disables the
            subsampling. Defaults to -1.
        pos_margin (float, optional): The similarity margin for positive
            samples in hard mining. A non-positive value disables it.
            Defaults to -1.
        neg_margin (float, optional): The similarity margin for negative
            samples in hard mining. A non-positive value disables it.
            Defaults to -1.
        hard_mining (bool, optional): Whether to use hard mining (keep the
            negatives with the largest MSE cost) instead of random sampling
            when subsampling negatives. Defaults to False.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 neg_pos_ub: int = -1,
                 pos_margin: float = -1,
                 neg_margin: float = -1,
                 hard_mining: bool = False,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0):
        super(MarginL2Loss, self).__init__()
        self.neg_pos_ub = neg_pos_ub
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.hard_mining = hard_mining
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        NOTE: ``update_weight`` modifies ``pred``, ``target`` and ``weight``
        in place (see its docstring), and its returned ``avg_factor``
        replaces the one passed in.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
                Entries are compared against 1 (positive) and 0 (negative)
                by ``update_weight``.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        pred, weight, avg_factor = self.update_weight(pred, target, weight,
                                                      avg_factor)
        loss_bbox = self.loss_weight * mse_loss(
            pred,
            target.float(),
            weight.float(),
            reduction=reduction,
            avg_factor=avg_factor)
        return loss_bbox

    def update_weight(self, pred: Tensor, target: Tensor,
                      weight: Optional[Tensor],
                      avg_factor: Optional[float]) -> Tuple[Tensor, Tensor, Tensor]:
        """Update the weight according to targets.

        WARNING: this method mutates its tensor arguments in place:
        ``target`` entries with non-positive weight are set to -1, margins
        are subtracted from ``pred`` in place, and discarded negatives get
        weight 0 in ``weight``. The incoming ``avg_factor`` is ignored and
        recomputed from the updated weights.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. If None, an all-ones weight is created.
            avg_factor (float, optional): Unused; always overwritten by the
                count of entries with positive weight.

        Returns:
            tuple[torch.Tensor]: The updated prediction, weight and average
            factor (a 0-dim tensor counting entries with positive weight).
        """
        if weight is None:
            weight = target.new_ones(target.size())

        # In-place: mark samples with non-positive weight as invalid so they
        # match neither the positive (==1) nor the negative (==0) masks.
        invalid_inds = weight <= 0
        target[invalid_inds] = -1
        pos_inds = target == 1
        neg_inds = target == 0

        # In-place: shift predicted similarities by the configured margins,
        # then clamp back into the valid similarity range [0, 1].
        if self.pos_margin > 0:
            pred[pos_inds] -= self.pos_margin
        if self.neg_margin > 0:
            pred[neg_inds] -= self.neg_margin
        pred = torch.clamp(pred, min=0, max=1)

        num_pos = int((target == 1).sum())
        num_neg = int((target == 0).sum())
        # Subsample negatives when they exceed neg_pos_ub * num_pos.
        # NOTE: the 2-column indexing below (``neg_idx[:, 0], neg_idx[:, 1]``)
        # only works when ``target`` is 2-D, so this path assumes a 2-D
        # (e.g. pairwise-similarity) target.
        if self.neg_pos_ub > 0 and num_neg / (num_pos +
                                              1e-6) > self.neg_pos_ub:
            num_neg = num_pos * self.neg_pos_ub
            neg_idx = torch.nonzero(target == 0, as_tuple=False)

            if self.hard_mining:
                # Keep the num_neg negatives with the largest per-element
                # MSE cost; detach so mining does not affect gradients.
                costs = mse_loss(
                    pred, target.float(),
                    reduction='none')[neg_idx[:, 0], neg_idx[:, 1]].detach()
                neg_idx = neg_idx[costs.topk(num_neg)[1], :]
            else:
                neg_idx = self.random_choice(neg_idx, num_neg)

            new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()
            new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True

            # In-place: negatives that were not selected get zero weight.
            invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)
            weight[invalid_neg_inds] = 0

        # Average over however many samples still carry positive weight.
        avg_factor = (weight > 0).sum()
        return pred, weight, avg_factor

    @staticmethod
    def random_choice(gallery: Union[list, np.ndarray, Tensor],
                      num: int) -> Union[np.ndarray, Tensor]:
        """Randomly select ``num`` elements from the gallery.

        It seems that Pytorch's implementation is slower than numpy so we use
        numpy to randperm the indices.

        Args:
            gallery (list | np.ndarray | torch.Tensor): The gallery from
                which to sample. Must contain at least ``num`` elements.
            num (int): The number of elements to sample.

        Returns:
            np.ndarray | torch.Tensor: The sampled elements; a tensor (on
            the gallery's device) when the gallery is a tensor, otherwise
            an ndarray.
        """
        assert len(gallery) >= num
        if isinstance(gallery, list):
            gallery = np.array(gallery)
        cands = np.arange(len(gallery))
        np.random.shuffle(cands)
        rand_inds = cands[:num]
        if not isinstance(gallery, np.ndarray):
            # Gallery is a torch.Tensor: index with a tensor on its device.
            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
        return gallery[rand_inds]
| |
|