Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from mmdet.registry import TASK_UTILS | |
| from mmdet.utils import ConfigType | |
| from .assign_result import AssignResult | |
| from .base_assigner import BaseAssigner | |
# Sentinel overlap written for (prior, gt) pairs that are not positive
# candidates; ``assign`` compares against ``-INF`` to detect unmatched priors.
INF = 100000000


class TaskAlignedAssigner(BaseAssigner):
    """Task aligned assigner used in the paper:
    `TOOD: Task-aligned One-stage Object Detection.
    <https://arxiv.org/abs/2108.07755>`_.

    Assign a corresponding gt bbox or background to each predicted bbox.
    Each bbox will be assigned with `0` or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        topk (int): Number of candidate bboxes selected for each gt. The
            selection runs over all priors at once and is clamped to the
            number of available bboxes.
        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
            calculator. Defaults to ``dict(type='BboxOverlaps2D')``
    """

    def __init__(self,
                 topk: int,
                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
        assert topk >= 1
        self.topk = topk
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               alpha: int = 1,
               beta: int = 6) -> AssignResult:
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute alignment metric between all bbox (bbox of all pyramid
           levels) and gt
        2. select top-k bbox as candidates for each gt
        3. limit the positive sample's center in gt (because the anchor-free
           detector only can predict positive distance)
        4. if a bbox is a positive candidate of several gts, keep the gt
           with the highest IoU

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, ``bboxes`` and
                ``scores``. The priors can be anchors, points, or bboxes
                predicted by the model, shape(n, 4); only their first four
                columns are used.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` and ``labels`` attributes are read.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Accepted for interface
                compatibility; this method does not read it.
                Defaults to None.
            alpha (int): Exponent applied to the classification score in
                the alignment metric ``score**alpha * iou**beta``.
                Defaults to 1.
            beta (int): Exponent applied to the IoU in the alignment
                metric. Defaults to 6.

        Returns:
            :obj:`AssignResult`: The assign result. It additionally carries
            an ``assign_metrics`` attribute holding the alignment metric of
            every positive bbox (0 elsewhere). In the non-empty case
            ``max_overlaps`` is ``-INF`` for bboxes without an assigned gt.
        """
        priors = pred_instances.priors[:, :4]
        decode_bboxes = pred_instances.bboxes
        pred_scores = pred_instances.scores
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        num_gt, num_bboxes = gt_bboxes.size(0), priors.size(0)

        # Everything is background (index 0) until proven otherwise.
        assigned_gt_inds = priors.new_full((num_bboxes, ),
                                           0,
                                           dtype=torch.long)
        assign_metrics = priors.new_zeros((num_bboxes, ))

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or no boxes: nothing can be matched, so return
            # the all-background result before doing any IoU work.
            max_overlaps = priors.new_zeros((num_bboxes, ))
            assigned_labels = priors.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
            assign_result = AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
            assign_result.assign_metrics = assign_metrics
            return assign_result

        # Alignment metric between every bbox and every gt; both factors are
        # detached so the assignment itself carries no gradient.
        overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
        bbox_scores = pred_scores[:, gt_labels].detach()
        alignment_metrics = bbox_scores**alpha * overlaps**beta

        # Select the top-k bboxes (along dim 0) as candidates for each gt.
        # ``topk`` already returns the candidates' metric values, so no
        # separate gather is needed.
        topk = min(self.topk, alignment_metrics.size(0))
        candidate_metrics, candidate_idxs = alignment_metrics.topk(
            topk, dim=0, largest=True)
        is_pos = candidate_metrics > 0

        # Limit the positive sample's center in gt: gather the candidate
        # prior centers directly, one column per gt.
        priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
        priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
        cand_cx = priors_cx[candidate_idxs]
        cand_cy = priors_cy[candidate_idxs]

        # Left, top, right, bottom distance between candidate center and the
        # sides of its gt; the center is inside when the smallest one exceeds
        # a small margin.
        l_ = cand_cx - gt_bboxes[:, 0]
        t_ = cand_cy - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - cand_cx
        b_ = gt_bboxes[:, 3] - cand_cy
        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
        is_pos = is_pos & is_in_gts

        # If a bbox is a positive candidate of multiple gts, the one with the
        # highest IoU is selected. Work on a gt-major flattened copy of the
        # IoU matrix in which every non-candidate pair holds -INF.
        gt_offsets = torch.arange(
            num_gt, device=candidate_idxs.device) * num_bboxes
        flat_idxs = candidate_idxs + gt_offsets
        flat_overlaps = overlaps.t().contiguous().view(-1)
        overlaps_inf = torch.full_like(flat_overlaps, -INF)
        index = flat_idxs[is_pos]
        overlaps_inf[index] = flat_overlaps[index]
        overlaps_inf = overlaps_inf.view(num_gt, -1).t()

        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
        # A row that is still -INF had no positive candidate pairing at all.
        matched = max_overlaps != -INF
        matched_gt = argmax_overlaps[matched]
        assigned_gt_inds[matched] = matched_gt + 1
        assign_metrics[matched] = alignment_metrics[matched, matched_gt]

        # Labels default to -1; matched bboxes take the label of their gt.
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[matched] = gt_labels[matched_gt]

        assign_result = AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
        assign_result.assign_metrics = assign_metrics
        return assign_result