| |
| from typing import Optional |
|
|
| import torch |
| from mmengine.structures import InstanceData |
|
|
| from mmdet.registry import TASK_UTILS |
| from mmdet.structures.bbox import bbox_xyxy_to_cxcywh |
| from mmdet.utils import ConfigType |
| from .assign_result import AssignResult |
| from .base_assigner import BaseAssigner |
|
|
|
|
@TASK_UTILS.register_module()
class UniformAssigner(BaseAssigner):
    """Uniform Matching between the priors and gt boxes, which can achieve
    balance in positive priors, and gt_bboxes_ignore was not considered for
    now.

    Every gt box is matched to its nearest candidates, measured by the L1
    distance between ``(cx, cy, w, h)`` boxes. Candidates are picked once
    from the decoded predictions and once from the priors, so every gt box
    receives the same number of candidates.

    Args:
        pos_ignore_thr (float): Matched candidates whose prior has an IoU
            with the matched gt box below this threshold are marked as
            ignored (``-1``).
        neg_ignore_thr (float): Predicted boxes whose maximum IoU over all
            gt boxes is above this threshold are marked as ignored
            (``-1``), unless they are selected as a match candidate.
        match_times (int): Number of nearest candidates picked for each gt
            box from the predictions, and again from the priors.
            Defaults to 4.
        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
            calculator. Defaults to ``dict(type='BboxOverlaps2D')``
    """

    def __init__(self,
                 pos_ignore_thr: float,
                 neg_ignore_thr: float,
                 match_times: int = 4,
                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
        self.match_times = match_times
        self.pos_ignore_thr = pos_ignore_thr
        self.neg_ignore_thr = neg_ignore_thr
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def assign(
        self,
        pred_instances: InstanceData,
        gt_instances: InstanceData,
        gt_instances_ignore: Optional[InstanceData] = None
    ) -> AssignResult:
        """Assign gt to priors.

        The assignment is done in following steps

        1. assign gt index 0 (background) and label -1 by default
        2. compute the L1 cost between boxes. Note that we use priors and
           predict boxes both
        3. compute the ignore indexes use gt_bboxes and predict boxes
        4. compute the ignore indexes of positive sample use priors and
           predict boxes

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. ``priors`` and ``decoder_priors`` are read
                from it; ``priors`` can be priors, points, or bboxes
                predicted by the model, shape(n, 4).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` and ``labels`` are read from it.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Currently not used by this assigner. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. In addition to the
            regular fields it carries three extra properties, all indexed
            by match candidate: ``pos_idx`` (bool mask of candidates that
            were not ignored), ``pos_predicted_boxes`` (the predicted box
            of each candidate) and ``target_boxes`` (the gt box of each
            candidate).
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        priors = pred_instances.priors
        bbox_pred = pred_instances.decoder_priors

        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. Default state: every prediction is background (gt index 0)
        # and has no label (-1).
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              0,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # Nothing to match: return the all-background defaults together
            # with empty extra properties.
            assign_result = AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
            assign_result.set_extra_property(
                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
            assign_result.set_extra_property('pos_predicted_boxes',
                                             bbox_pred.new_empty((0, 4)))
            assign_result.set_extra_property('target_boxes',
                                             bbox_pred.new_empty((0, 4)))
            return assign_result

        # 2. L1 cost between boxes in (cx, cy, w, h) form, computed for
        # both the predicted boxes and the priors against the gt boxes.
        gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
        cost_bbox = torch.cdist(
            bbox_xyxy_to_cxcywh(bbox_pred), gt_cxcywh, p=1)
        cost_bbox_priors = torch.cdist(
            bbox_xyxy_to_cxcywh(priors), gt_cxcywh, p=1)

        # NOTE(review): top-k is deliberately run on CPU copies of the cost
        # matrices; presumably to keep the selection identical across
        # devices -- confirm before moving it back to the GPU.
        C = cost_bbox.cpu()
        C1 = cost_bbox_priors.cpu()

        # ``torch.topk`` raises when k exceeds the size of the reduced
        # dimension, so never request more candidates than there are rows.
        num_candidates = min(self.match_times, C.size(0), C1.size(0))

        # Row indices of the closest predictions per gt:
        # (num_candidates, num_gts).
        index = torch.topk(C, k=num_candidates, dim=0, largest=False)[1]
        # Row indices of the closest priors per gt:
        # (num_candidates, num_gts).
        index1 = torch.topk(C1, k=num_candidates, dim=0, largest=False)[1]
        # Concatenate along the gt axis and flatten row by row, so the gt
        # ids of the flattened candidates cycle 0..num_gts-1 repeatedly.
        indexes = torch.cat((index, index1),
                            dim=1).reshape(-1).to(bbox_pred.device)

        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
        anchor_overlaps = self.iou_calculator(priors, gt_bboxes)
        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
        # NOTE(review): this reduces over dim 0 and therefore yields one
        # value per gt box, whereas ``pred_max_overlaps`` above has one
        # value per prediction. Kept as is to preserve the returned shape;
        # confirm this is what consumers of ``max_overlaps`` expect.
        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)

        # 3. Predictions overlapping any gt box by more than
        # ``neg_ignore_thr`` are ignored. Step 4 may overwrite this for
        # predictions selected as match candidates.
        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
        assigned_gt_inds[ignore_idx] = -1

        # 4. Candidates whose prior overlaps its matched gt box by less
        # than ``pos_ignore_thr`` are ignored; the others become positive.
        # ``pos_gt_index`` is the gt id of every flattened candidate.
        pos_gt_index = torch.arange(
            0, C1.size(1),
            device=bbox_pred.device).repeat(num_candidates * 2)
        pos_ious = anchor_overlaps[indexes, pos_gt_index]
        pos_ignore_idx = pos_ious < self.pos_ignore_thr

        # Assigned gt indices are 1-based (0 means background).
        pos_gt_index_with_ignore = pos_gt_index + 1
        pos_gt_index_with_ignore[pos_ignore_idx] = -1
        assigned_gt_inds[indexes] = pos_gt_index_with_ignore

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            # ``squeeze(1)`` keeps the index tensor 1-D even when there is
            # exactly one positive.
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze(1)
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        assign_result = AssignResult(
            num_gts,
            assigned_gt_inds,
            anchor_max_overlaps,
            labels=assigned_labels)
        assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
        assign_result.set_extra_property('pos_predicted_boxes',
                                         bbox_pred[indexes])
        assign_result.set_extra_property('target_boxes',
                                         gt_bboxes[pos_gt_index])
        return assign_result
|
|