| |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from mmengine.structures import InstanceData |
| from torch import Tensor |
|
|
| from mmdet.registry import TASK_UTILS |
| from mmdet.utils import ConfigType |
| from .assign_result import AssignResult |
| from .base_assigner import BaseAssigner |
|
|
# Large finite value used as a penalty in the SimOTA cost matrix for
# candidate priors not inside both a gt box and its center region; its
# negation is the fill value for ``max_overlaps`` of non-positive priors.
INF = 1e5
# Small constant added to IoUs before taking the log in the IoU cost.
EPS = 1e-7
|
|
|
|
@TASK_UTILS.register_module()
class SimOTAAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth using SimOTA.

    Args:
        center_radius (float): Ground truth center size
            to judge whether a prior is in center. Defaults to 2.5.
        candidate_topk (int): The candidate top-k which used to
            get top-k ious to calculate dynamic-k. Defaults to 10.
        iou_weight (float): The scale factor for regression
            iou cost. Defaults to 3.0.
        cls_weight (float): The scale factor for classification
            cost. Defaults to 1.0.
        iou_calculator (ConfigType, optional): Config of overlaps
            Calculator. Defaults to None, which builds
            ``dict(type='BboxOverlaps2D')``.
    """

    def __init__(self,
                 center_radius: float = 2.5,
                 candidate_topk: int = 10,
                 iou_weight: float = 3.0,
                 cls_weight: float = 1.0,
                 iou_calculator: Optional[ConfigType] = None) -> None:
        self.center_radius = center_radius
        self.candidate_topk = candidate_topk
        self.iou_weight = iou_weight
        self.cls_weight = cls_weight
        # ``None`` sentinel instead of a mutable ``dict`` default, so no
        # config object is shared between instances / calls.
        if iou_calculator is None:
            iou_calculator = dict(type='BboxOverlaps2D')
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    @staticmethod
    def _get_empty_result(num_gt: int, assigned_gt_inds: Tensor,
                          decoded_bboxes: Tensor) -> AssignResult:
        """Build an ``AssignResult`` in which no prior is assigned.

        ``max_overlaps`` is all zeros and every label is -1;
        ``assigned_gt_inds`` is passed through unchanged.
        """
        num_bboxes = decoded_bboxes.size(0)
        max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
        assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
                                                  -1,
                                                  dtype=torch.long)
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to priors using SimOTA.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places. ``scores`` are fed to
                ``F.binary_cross_entropy``, which requires values in
                [0, 1] (presumably post-sigmoid scores).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Accepted for interface
                compatibility; not used by this implementation.
                Defaults to None.

        Returns:
            obj:`AssignResult`: The assigned result. ``gt_inds`` is 0 for
            priors that are not assigned and ``i + 1`` for a prior matched
            to the i-th gt. ``labels`` holds the matched gt label, or -1.
            ``max_overlaps`` holds the IoU of each positive prior with its
            matched gt and ``-INF`` for the other priors; when nothing can
            be matched it is all zeros instead.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        num_gt = gt_bboxes.size(0)

        decoded_bboxes = pred_instances.bboxes
        pred_scores = pred_instances.scores
        priors = pred_instances.priors
        num_bboxes = decoded_bboxes.size(0)

        # 0 means "not assigned"; positives later get 1-based gt indices.
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
                                                   0,
                                                   dtype=torch.long)
        if num_gt == 0 or num_bboxes == 0:
            # Nothing to match against: every prior stays unassigned.
            return self._get_empty_result(num_gt, assigned_gt_inds,
                                          decoded_bboxes)

        valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
            priors, gt_bboxes)
        valid_decoded_bbox = decoded_bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        num_valid = valid_decoded_bbox.size(0)
        if num_valid == 0:
            # No prior lies in any gt box or gt center region.
            return self._get_empty_result(num_gt, assigned_gt_inds,
                                          decoded_bboxes)

        pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
        iou_cost = -torch.log(pairwise_ious + EPS)

        # One-hot gt labels and candidate scores, both tiled to
        # (num_valid, num_gt, num_classes) so BCE gets equal-sized inputs.
        gt_onehot_label = (
            F.one_hot(gt_labels.to(torch.int64),
                      pred_scores.shape[-1]).float().unsqueeze(0).repeat(
                          num_valid, 1, 1))

        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
        # F.binary_cross_entropy is unsafe to autocast, so it is computed in
        # float32 with autocast disabled and cast back to the score dtype.
        with torch.cuda.amp.autocast(enabled=False):
            cls_cost = (
                F.binary_cross_entropy(
                    valid_pred_scores.to(dtype=torch.float32),
                    gt_onehot_label,
                    reduction='none',
                ).sum(-1).to(dtype=valid_pred_scores.dtype))

        # (num_valid, num_gt) cost; candidates not inside both a gt's box
        # and its center region get an INF penalty for that gt so they rank
        # behind in-region candidates in the top-k selection.
        cost_matrix = (
            cls_cost * self.cls_weight + iou_cost * self.iou_weight +
            (~is_in_boxes_and_center) * INF)

        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, pairwise_ious, num_gt, valid_mask)

        # dynamic_k_matching narrowed valid_mask in place to the positive
        # priors, so the writes below only touch positives.
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def get_in_gt_and_in_center_info(
            self, priors: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]:
        """Get the information of which prior is in gt bboxes and gt center
        priors.

        Columns 0/1 of ``priors`` are used as the prior x/y position and
        columns 2/3 as the x/y strides that scale ``center_radius``.
        ``gt_bboxes`` columns are used as (left, top, right, bottom).

        Returns:
            tuple[Tensor, Tensor]:
                - Bool mask of shape (num_priors, ): prior lies strictly
                  inside at least one gt box or one gt center region.
                - Bool tensor of shape (num_selected, num_gt) for the
                  priors selected by the mask: prior lies inside both that
                  gt's box and its center region.
        """
        # (num_priors, 1) columns broadcast against the (num_gt, ) gt rows,
        # giving (num_priors, num_gt) results without explicit repeats.
        prior_x = priors[:, 0].unsqueeze(1)
        prior_y = priors[:, 1].unsqueeze(1)
        stride_x = priors[:, 2].unsqueeze(1)
        stride_y = priors[:, 3].unsqueeze(1)

        # Distances from the prior point to each gt box edge; the point is
        # inside the box only if all four are positive.
        l_ = prior_x - gt_bboxes[:, 0]
        t_ = prior_y - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - prior_x
        b_ = gt_bboxes[:, 3] - prior_y

        deltas = torch.stack([l_, t_, r_, b_], dim=1)
        is_in_gts = deltas.min(dim=1).values > 0
        is_in_gts_all = is_in_gts.any(dim=1)

        # Center region: a box of half-size center_radius * stride around
        # each gt center.
        gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        ct_box_l = gt_cxs - self.center_radius * stride_x
        ct_box_t = gt_cys - self.center_radius * stride_y
        ct_box_r = gt_cxs + self.center_radius * stride_x
        ct_box_b = gt_cys + self.center_radius * stride_y

        cl_ = prior_x - ct_box_l
        ct_ = prior_y - ct_box_t
        cr_ = ct_box_r - prior_x
        cb_ = ct_box_b - prior_y

        ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
        is_in_cts = ct_deltas.min(dim=1).values > 0
        is_in_cts_all = is_in_cts.any(dim=1)

        # Candidate priors: inside any gt box or any gt center region.
        is_in_gts_or_centers = is_in_gts_all | is_in_cts_all

        # For candidates only: inside both the box and center region of a gt.
        is_in_boxes_and_centers = (
            is_in_gts[is_in_gts_or_centers, :]
            & is_in_cts[is_in_gts_or_centers, :])
        return is_in_gts_or_centers, is_in_boxes_and_centers

    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
                           num_gt: int,
                           valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets.

        Args:
            cost (Tensor): Matching cost of shape (num_valid, num_gt), as
                built by ``assign``.
            pairwise_ious (Tensor): IoUs of shape (num_valid, num_gt).
            num_gt (int): Number of gt boxes.
            valid_mask (Tensor): Bool mask over all priors selecting the
                ``num_valid`` candidate rows. It is updated IN PLACE so that
                afterwards only the priors that became positive stay True.

        Returns:
            tuple[Tensor, Tensor]: IoU of each positive prior with its
            matched gt, and the 0-based matched gt index of each positive
            prior, both ordered like the True entries of the updated
            ``valid_mask``.
        """
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
        # Can't take more candidates than there are valid priors.
        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # Per-gt k = integer part of the sum of its top IoUs, at least 1.
        # Converted to Python ints once instead of once per gt.
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1).tolist()
        del topk_ious
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1

        # A prior picked by several gts keeps only its lowest-cost gt.
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.any():
            _, cost_argmin = torch.min(
                cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] = 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        # In-place update of the caller's mask (see docstring); the clone
        # keeps the index mask from aliasing the tensor being written.
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(1)[fg_mask_inboxes]
        return matched_pred_ious, matched_gt_inds
|
|