Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import torch | |
| from mmdet.registry import TASK_UTILS | |
| from .random_sampler import RandomSampler | |
| class InstanceBalancedPosSampler(RandomSampler): | |
| """Instance balanced sampler that samples equal number of positive samples | |
| for each instance.""" | |
| def _sample_pos(self, assign_result, num_expected, **kwargs): | |
| """Sample positive boxes. | |
| Args: | |
| assign_result (:obj:`AssignResult`): The assigned results of boxes. | |
| num_expected (int): The number of expected positive samples | |
| Returns: | |
| Tensor or ndarray: sampled indices. | |
| """ | |
| pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) | |
| if pos_inds.numel() != 0: | |
| pos_inds = pos_inds.squeeze(1) | |
| if pos_inds.numel() <= num_expected: | |
| return pos_inds | |
| else: | |
| unique_gt_inds = assign_result.gt_inds[pos_inds].unique() | |
| num_gts = len(unique_gt_inds) | |
| num_per_gt = int(round(num_expected / float(num_gts)) + 1) | |
| sampled_inds = [] | |
| for i in unique_gt_inds: | |
| inds = torch.nonzero( | |
| assign_result.gt_inds == i.item(), as_tuple=False) | |
| if inds.numel() != 0: | |
| inds = inds.squeeze(1) | |
| else: | |
| continue | |
| if len(inds) > num_per_gt: | |
| inds = self.random_choice(inds, num_per_gt) | |
| sampled_inds.append(inds) | |
| sampled_inds = torch.cat(sampled_inds) | |
| if len(sampled_inds) < num_expected: | |
| num_extra = num_expected - len(sampled_inds) | |
| extra_inds = np.array( | |
| list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) | |
| if len(extra_inds) > num_extra: | |
| extra_inds = self.random_choice(extra_inds, num_extra) | |
| extra_inds = torch.from_numpy(extra_inds).to( | |
| assign_result.gt_inds.device).long() | |
| sampled_inds = torch.cat([sampled_inds, extra_inds]) | |
| elif len(sampled_inds) > num_expected: | |
| sampled_inds = self.random_choice(sampled_inds, num_expected) | |
| return sampled_inds | |