Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from mmdet.registry import TASK_UTILS | |
| from .assign_result import AssignResult | |
| from .base_assigner import BaseAssigner | |
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point is assigned ``0`` or a positive integer indicating the
    ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int): Divisor applied to the gt width and height before
            taking ``log2`` to pick the level a gt is matched on.
            Defaults to 4.
        pos_num (int): Number of nearest points (within the matched level)
            that are candidates for each gt. Defaults to 3.
    """

    def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
        self.scale = scale
        self.pos_num = pos_num

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to points.

        Every point receives a gt index and a label:

        - ``gt_inds``: ``0`` for background, otherwise the 1-based index of
          the assigned gt.
        - ``labels``: ``-1`` for background, otherwise the label of the
          assigned gt.

        The assignment is done in the following steps; the order matters.

        1. Assign every point to background.
        2. For each gt, pick the level from its size, then assign a point
           of that level to the gt if

           (i) the point is among the ``pos_num`` points of that level
               closest to the gt center (distance normalized by the gt
               width and height), and
           (ii) that distance is smaller than the distance recorded for
                any gt previously assigned to the point.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read; each row is used as
                ``(x, y, stride)``.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` is read as ``(x1, y1, x2, y2)``
                rows and ``labels`` is indexed by gt index.
            gt_instances_ignore (:obj:`InstanceData`, optional): Accepted
                for interface compatibility; not used by this assigner.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` is
            always ``None``.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        # Points to be assigned; the last dimension is (x, y, stride).
        points = pred_instances.priors
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # Nothing to match: everything is background.
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            assigned_labels = points.new_full((num_points, ),
                                              -1,
                                              dtype=torch.long)
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        # Level of each point is log2 of its stride (e.g. stride 8 -> 3).
        points_lvl = torch.log2(points_stride).int()
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # Center, size and matched level of every gt box.
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # Clamp so degenerate boxes do not produce log2(0) or divide by 0.
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        # Restrict gt levels to the range of levels present in the points.
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # Assigned gt index (1-based, 0 = background) of each point.
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # Distance from each point to its currently assigned gt.
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Built on the same device as the points so it can be indexed with
        # masks derived from them.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            # Mask and global indices of the points on this gt's level.
            lvl_idx = gt_bboxes_lvl[idx] == points_lvl
            points_index = points_range[lvl_idx]
            lvl_points = points_xy[lvl_idx, :]

            # Center and size of the current gt, kept 2-D for broadcasting.
            gt_point = gt_bboxes_xy[[idx], :]
            gt_wh = gt_bboxes_wh[[idx], :]

            # Size-normalized distance from the gt center to every point
            # on this level.
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)

            # A level may hold fewer than ``pos_num`` points (or none, when
            # an intermediate level is absent); never ask topk for more
            # elements than exist.
            num_candidates = min(self.pos_num, points_gt_dist.numel())
            if num_candidates == 0:
                continue

            # The nearest candidates to the gt center on this level.
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            min_dist_points_index = points_index[min_dist_index]

            # Keep only candidates that are closer to this gt than to the
            # gt they were previously assigned to (if any).
            closer_mask = min_dist < assigned_gt_dist[min_dist_points_index]
            min_dist_points_index = min_dist_points_index[closer_mask]

            # Record the assignment and the new best distance.
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[closer_mask]

        assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
        # squeeze(1) keeps pos_inds 1-D even when there is one positive.
        pos_inds = torch.nonzero(
            assigned_gt_inds > 0, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            # gt indices are 1-based, labels are indexed 0-based.
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
                                                  1]

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)