| |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
| from mmdet.structures.bbox import bbox_overlaps |
| from ..task_modules.coders import BaseBBoxCoder |
| from ..task_modules.samplers import SamplingResult |
|
|
|
|
def isr_p(cls_score: Tensor,
          bbox_pred: Tensor,
          bbox_targets: Tuple[Tensor],
          rois: Tensor,
          sampling_results: List[SamplingResult],
          loss_cls: nn.Module,
          bbox_coder: BaseBBoxCoder,
          k: float = 2,
          bias: float = 0,
          num_class: int = 80) -> tuple:
    """Importance-based Sample Reweighting (ISR_P), positive part.

    Positive samples are ranked by the IoU between their decoded predicted
    box and their decoded target box, first within each ground truth and
    then within each class. The rank is mapped to an importance weight that
    replaces the classification label weight of the sample.

    Args:
        cls_score (Tensor): Predicted classification scores.
        bbox_pred (Tensor): Predicted bbox deltas.
        bbox_targets (tuple[Tensor]): A tuple of bbox targets, they are
            labels, label_weights, bbox_targets, bbox_weights, respectively.
        rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
            (two_stage) in shape (n, 5).
        sampling_results (list[:obj:`SamplingResult`]): Sampling results,
            one per image. Only ``pos_assigned_gt_inds`` is read.
        loss_cls (:obj:`nn.Module`): Classification loss func of the head.
            It is called with ``reduction_override='none'``.
        bbox_coder (:obj:`BaseBBoxCoder`): BBox coder of the head.
        k (float): Power of the non-linear mapping. Defaults to 2.
        bias (float): Shift of the non-linear mapping. Defaults to 0.
        num_class (int): Number of classes, defaults to 80.

    Returns:
        tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
        bbox_target_weights. ``label_weights`` is updated in place, so the
        returned tensor is the same object that was passed in.
    """
    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
    # Labels in [0, num_class) are treated as positive samples.
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    pos_labels = labels[pos_label_inds]

    # Nothing to reweight: hand the targets back untouched.
    num_pos = float(pos_label_inds.size(0))
    if num_pos == 0:
        return labels, label_weights, bbox_targets, bbox_weights

    # Merge the per-image assigned gt indices into one tensor. Every image
    # is shifted by the running maximum of all previously shifted indices so
    # that gt ids stay unique across the whole batch. Using only the
    # previous image's local maximum would let ids from different images
    # collide once there are three or more images.
    gts = []
    gt_offset = 0
    for sampling_result in sampling_results:
        gt_i = sampling_result.pos_assigned_gt_inds
        shifted_gt_i = gt_i + gt_offset
        gts.append(shifted_gt_i)
        if len(gt_i) != 0:
            gt_offset = shifted_gt_i.max() + 1
    gts = torch.cat(gts)
    assert len(gts) == num_pos

    # The importance weights are computed on detached predictions.
    cls_score = cls_score.detach()
    bbox_pred = bbox_pred.detach()

    # A 5-column roi carries one extra leading column that is dropped here,
    # leaving the 4 box coordinates.
    if rois.size(-1) == 5:
        pos_rois = rois[pos_label_inds][:, 1:]
    else:
        pos_rois = rois[pos_label_inds]

    # More than 4 columns means one set of deltas per class: select the
    # deltas that belong to each sample's own label.
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
    else:
        pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)

    # Decode predicted and target deltas against the same rois and compare
    # the resulting boxes pairwise.
    pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
    pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
    target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
    ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)

    # Indexing with an index tensor yields a copy, so the in-place updates
    # below do not touch ``label_weights`` until the final write-back.
    pos_imp_weights = label_weights[pos_label_inds]

    # Two-level ranking. ``max_l_num`` is the largest number of positives
    # that share one class label.
    max_l_num = pos_labels.bincount().max()
    for label in pos_labels.unique():
        l_inds = (pos_labels == label).nonzero().view(-1)
        l_gts = gts[l_inds]
        for t in l_gts.unique():
            t_inds = l_inds[l_gts == t]
            t_ious = ious[t_inds]
            # Sorting the sort indices turns them into ranks (0 = best).
            _, t_iou_rank_idx = t_ious.sort(descending=True)
            _, t_iou_rank = t_iou_rank_idx.sort()
            # Add an integer bonus that decreases with the within-gt rank,
            # so the class-level sort below orders by within-gt rank first
            # (assumes IoU values stay within [0, 1] - TODO confirm).
            ious[t_inds] += max_l_num - t_iou_rank.float()
        l_ious = ious[l_inds]
        _, l_iou_rank_idx = l_ious.sort(descending=True)
        _, l_iou_rank = l_iou_rank_idx.sort()
        # Linearly map the class-level rank to a weight in (0, 1].
        pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num

    # Non-linear mapping of the linear weights.
    pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)

    # Rescale the importance weights so that the summed positive
    # classification loss equals the one under the original label weights.
    pos_loss_cls = loss_cls(
        cls_score[pos_label_inds], pos_labels, reduction_override='none')
    if pos_loss_cls.dim() > 1:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
                                                                        None]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
    else:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights
    pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
    pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
    # Write-back modifies the caller's ``label_weights`` tensor in place.
    label_weights[pos_label_inds] = pos_imp_weights

    bbox_targets = labels, label_weights, bbox_targets, bbox_weights
    return bbox_targets
|
|
|
|
def carl_loss(cls_score: Tensor,
              labels: Tensor,
              bbox_pred: Tensor,
              bbox_targets: Tensor,
              loss_bbox: nn.Module,
              k: float = 1,
              bias: float = 0.2,
              avg_factor: Optional[int] = None,
              sigmoid: bool = False,
              num_class: int = 80) -> dict:
    """Classification-Aware Regression Loss (CARL).

    The regression loss of every positive sample is weighted by a function
    of the classification score predicted for the sample's own label.

    Args:
        cls_score (Tensor): Predicted classification scores.
        labels (Tensor): Targets of classification. Entries in
            ``[0, num_class)`` are treated as positive samples.
        bbox_pred (Tensor): Predicted bbox deltas. When the last dimension
            is larger than 4 it is viewed as ``(n, -1, 4)`` and the deltas
            of each sample's label are selected.
        bbox_targets (Tensor): Target of bbox regression.
        loss_bbox (func): Regression loss func of the head. It is called
            with ``reduction_override='none'``.
        k (float): Power of the non-linear mapping. Defaults to 1.
        bias (float): Shift of the non-linear mapping. Defaults to 0.2.
        avg_factor (int, optional): Average factor used in regression loss.
            Defaults to ``bbox_targets.size(0)`` when not given.
        sigmoid (bool): Activation of the classification score. ``True``
            applies sigmoid, ``False`` applies softmax over the last dim.
        num_class (int): Number of classes, defaults to 80.

    Returns:
        dict: CARL loss dict with the single key ``loss_carl`` holding a
        tensor of shape (1,).
    """
    is_positive = (labels >= 0) & (labels < num_class)
    pos_inds = is_positive.nonzero().reshape(-1)
    if pos_inds.numel() == 0:
        # No positives: return a zero loss that is still derived from
        # ``cls_score``.
        return dict(loss_carl=cls_score.sum()[None] * 0.)
    pos_classes = labels[pos_inds]

    # Score predicted for each positive sample's own label.
    activated = cls_score.sigmoid() if sigmoid else cls_score.softmax(-1)
    pos_scores = activated[pos_inds, pos_classes]
    raw_weights = (bias + (1 - bias) * pos_scores).pow(k)

    # Normalise the weights so that they sum to the number of positives.
    n_pos = float(pos_scores.size(0))
    weights = raw_weights * (n_pos / raw_weights.sum())

    if avg_factor is None:
        avg_factor = bbox_targets.size(0)

    # Pick the deltas of the labelled class when predictions are per class.
    if bbox_pred.size(-1) > 4:
        per_class_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_pred = per_class_pred[pos_inds, pos_classes]
    else:
        pos_pred = bbox_pred[pos_inds]

    unweighted = loss_bbox(
        pos_pred, bbox_targets[pos_inds],
        reduction_override='none') / avg_factor
    weighted_total = (unweighted * weights[:, None]).sum()
    return dict(loss_carl=weighted_total[None])
|
|