Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Union | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from numpy import ndarray | |
| from torch import Tensor | |
| from mmdet.registry import TASK_UTILS | |
| from ..assigners import AssignResult | |
| from .multi_instance_sampling_result import MultiInstanceSamplingResult | |
| from .random_sampler import RandomSampler | |
class MultiInsRandomSampler(RandomSampler):
    """Random sampler for the multi-instance setting.

    Note:
        Multi-instance means predicting multiple detection boxes with
        one proposal box. `AssignResult` may assign multiple gt boxes
        to each proposal box; in that case `RandomSampler` should be
        replaced by `MultiInsRandomSampler`.
    """

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick positive samples.

        A sample is treated as positive when the first column of
        ``assign_result.labels`` is greater than zero.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        is_positive = assign_result.labels[:, 0] > 0
        candidates = torch.nonzero(is_positive, as_tuple=False)
        # ``nonzero`` yields an (N, 1) tensor; flatten it unless it is empty.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Subsample only when there are more candidates than requested.
        if candidates.numel() > num_expected:
            return self.random_choice(candidates, num_expected)
        return candidates

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick negative samples.

        A sample is treated as negative when the first column of
        ``assign_result.labels`` equals zero.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        is_negative = assign_result.labels[:, 0] == 0
        candidates = torch.nonzero(is_negative, as_tuple=False)
        # ``nonzero`` yields an (N, 1) tensor; flatten it unless it is empty.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        # Subsample only when there are more candidates than requested.
        if len(candidates) > num_expected:
            return self.random_choice(candidates, num_expected)
        return candidates

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData,
               **kwargs) -> MultiInstanceSamplingResult:
        """Sample positive and negative bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results from
                MultiInstanceAssigner.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), and ``labels``, with shape (k, ).
            **kwargs: Must contain ``batch_gt_instances_ignore``; its
                ``bboxes`` are concatenated after the gt bboxes.

        Returns:
            :obj:`MultiInstanceSamplingResult`: Sampling result.
        """
        assert 'batch_gt_instances_ignore' in kwargs, \
            'batch_gt_instances_ignore is necessary for MultiInsRandomSampler'

        # Gt boxes come first, followed by the ignore boxes.
        gt_and_ignore_bboxes = torch.cat(
            [gt_instances.bboxes, kwargs['batch_gt_instances_ignore'].bboxes],
            dim=0)

        # Bring the priors to 2-D and keep only the first four columns.
        proposals = pred_instances.priors
        if len(proposals.shape) < 2:
            proposals = proposals[None, :]
        proposals = proposals[:, :4]

        # Flags: 0 for every original prior, 1 for every appended gt/ignore
        # box, in the same order as the concatenated priors.
        proposal_flags = proposals.new_zeros(
            (proposals.shape[0], ), dtype=torch.uint8)
        priors = torch.cat([proposals, gt_and_ignore_bboxes], dim=0)
        appended_flags = priors.new_ones(
            gt_and_ignore_bboxes.shape[0], dtype=torch.uint8)
        gt_flags = torch.cat([proposal_flags, appended_flags])

        # Positive samples.
        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(assign_result,
                                                num_expected_pos)
        # Sampled indices occasionally contain duplicated items (possibly a
        # PyTorch bug), so deduplicate them.
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        # Negative samples fill the remaining budget, optionally capped at
        # ``neg_pos_ub`` times the number of positives (counted as >= 1).
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            num_expected_neg = min(
                num_expected_neg,
                int(self.neg_pos_ub * max(1, num_sampled_pos)))
        neg_inds = self.neg_sampler._sample_neg(assign_result,
                                                num_expected_neg)
        neg_inds = neg_inds.unique()

        return MultiInstanceSamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            priors=priors,
            gt_and_ignore_bboxes=gt_and_ignore_bboxes,
            assign_result=assign_result,
            gt_flags=gt_flags)