| |
| from typing import Optional |
|
|
| import torch |
| from mmengine.structures import InstanceData |
|
|
| from mmdet.models.task_modules.assigners.assign_result import AssignResult |
| from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner |
| from mmdet.registry import TASK_UTILS |
|
|
|
|
@TASK_UTILS.register_module()
class TransMaxIoUAssigner(MaxIoUAssigner):
    """:class:`MaxIoUAssigner` variant that matches x/y-swapped priors.

    Before any IoU/IoF is computed, every prior ``(c0, c1, c2, c3)`` is
    rewritten as ``(c1, c0, c3, c2)``. For ``(x1, y1, x2, y2)`` boxes this
    mirrors the prior across the image diagonal, i.e. transposes it. The
    ground-truth and ignored boxes are used as given. Thresholds, ignore
    handling, the CPU fallback and the final assignment are all inherited
    unchanged from :class:`MaxIoUAssigner`.

    NOTE(review): presumably the gt boxes live in a transposed
    (height/width swapped) image frame while the priors do not; confirm
    against the callers of this assigner.
    """

    @staticmethod
    def _transpose_boxes(boxes: torch.Tensor) -> torch.Tensor:
        """Swap columns 0<->1 and 2<->3 and return an ``(N, 4)`` tensor.

        Leading dimensions are flattened and any columns after the fourth
        are dropped. The result is a new contiguous tensor on the same
        device and with the same dtype as ``boxes``.
        """
        return boxes[..., [1, 0, 3, 2]].reshape(-1, 4)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt boxes to x/y-transposed priors.

        The priors are transposed with :meth:`_transpose_boxes` and then
        matched exactly as in :meth:`MaxIoUAssigner.assign`:

        - Overlaps between ``gt_instances.bboxes`` and the transposed
          priors are computed with ``self.iou_calculator``.
        - Priors whose IoF with ignored boxes exceeds ``ignore_iof_thr``
          are masked out.
        - The decision is delegated to
          :meth:`MaxIoUAssigner.assign_wrt_overlaps`.

        In the returned ``gt_inds`` (parent-class convention), ``0`` marks a
        negative (background) prior, and a positive value ``i`` means the
        prior is assigned to the ``i``-th gt box (1-based). ``-1`` marks a
        prior that is neither positive nor negative (ignored). Per the
        parent class the decision proceeds, in order, as:

        1. mark every prior as ignored (-1);
        2. priors whose max IoU with all gts is below ``neg_iou_thr``
           become negatives (0);
        3. priors whose IoU with their nearest gt is >= ``pos_iou_thr``
           are assigned to that gt;
        4. when low-quality matching is enabled, each gt is also assigned
           its nearest prior(s) (there may be more than one).

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Must contain ``priors`` of shape (n, 4); these
                can be anchors, points or boxes predicted by a previous
                stage. They are transposed before matching.
            gt_instances (:obj:`InstanceData`): Ground-truth annotations
                with ``bboxes`` of shape (k, 4) and ``labels`` of shape
                (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                whose ``bboxes`` should be ignored during training.
                Defaults to None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            :obj:`AssignResult`: The assign result. Its rows follow the
            order of ``pred_instances.priors``. If the assignment was
            computed on CPU (see ``gpu_assign_thr``), its tensors are moved
            back to the priors' original device.

        Example:
            >>> from mmengine.structures import InstanceData
            >>> self = TransMaxIoUAssigner(0.5, 0.5)
            >>> pred_instances = InstanceData()
            >>> # The first prior is tall; after transposition it is wide
            >>> # and matches the gt exactly (IoU 1.0 instead of 1/3).
            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 20],
            ...                                       [10, 10, 20, 20]])
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 20, 10]])
            >>> gt_instances.labels = torch.LongTensor([0])
            >>> assign_result = self.assign(pred_instances, gt_instances)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        priors = pred_instances.priors
        gt_bboxes_ignore = (
            gt_instances_ignore.bboxes
            if gt_instances_ignore is not None else None)

        # With many gts the pairwise overlap matrix gets large, so the
        # parent design falls back to CPU above ``gpu_assign_thr``.
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_bboxes.shape[0] > self.gpu_assign_thr)
        device = priors.device
        if assign_on_cpu:
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()

        # Only the priors are transposed; gt and ignored boxes are used
        # as given.
        trans_priors = self._transpose_boxes(priors)
        overlaps = self.iou_calculator(gt_bboxes, trans_priors)

        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0
                and trans_priors.numel() > 0):
            # Either way we end up with one max IoF value per prior.
            if self.ignore_wrt_candidates:
                ignore_overlaps = self.iou_calculator(
                    trans_priors, gt_bboxes_ignore, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            else:
                ignore_overlaps = self.iou_calculator(
                    gt_bboxes_ignore, trans_priors, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            # Overwrite the whole prior column with -1 so the parent's
            # assign_wrt_overlaps treats these priors as ignored.
            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        if assign_on_cpu:
            # Return results on the device the priors originally lived on.
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
|
|