Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Union | |
| import torch | |
| from numpy import ndarray | |
| from torch import Tensor | |
| from mmdet.registry import TASK_UTILS | |
| from ..assigners import AssignResult | |
| from .base_sampler import BaseSampler | |
# NOTE(review): ``TASK_UTILS`` is imported by this module but is not used to
# register this class. If configs build samplers by name (e.g.
# ``dict(type='RandomSampler')``), a ``@TASK_UTILS.register_module()``
# decorator is presumably needed here -- confirm the class is not already
# registered elsewhere first, since registering the same name twice would
# likely raise.
class RandomSampler(BaseSampler):
    """Random sampler.

    Positive (``gt_inds > 0``) and negative (``gt_inds == 0``) candidates of
    an :obj:`AssignResult` are sub-sampled uniformly at random, without
    replacement.

    Args:
        num (int): Number of samples
        pos_fraction (float): Fraction of positive samples
        neg_pos_ub (int): Upper bound number of negative and
            positive samples. Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Extra keyword arguments. Only ``rng`` is read; it is passed
            through ``ensure_rng`` and stored as ``self.rng``. Note that
            :meth:`random_choice` draws from torch's global RNG, not from
            ``self.rng``.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs):
        # Imported inside the function, presumably to avoid a circular import
        # with ``sampling_result`` -- TODO confirm.
        from .sampling_result import ensure_rng
        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        self.rng = ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        """Randomly select ``num`` elements from ``gallery`` without
        replacement.

        If ``gallery`` is a Tensor, the returned indices are a Tensor on the
        same device; if it is an ndarray or list, the returned indices are
        an ndarray.

        Args:
            gallery (Tensor | ndarray | list): indices pool.
            num (int): expected sample num. Must not exceed ``len(gallery)``.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num, (
            f'Cannot sample {num} elements from a gallery of '
            f'{len(gallery)} elements.')
        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # The result is handed back as a CPU ndarray, and the permutation
            # below is generated on the CPU, so keep the intermediate tensor
            # on the CPU as well. Moving it to the GPU would only add a
            # host/device round trip (and force CUDA initialisation) without
            # changing the sampled values.
            gallery = torch.tensor(gallery, dtype=torch.long, device='cpu')
        # This is a temporary fix. We can revert the following code
        # when PyTorch fixes the abnormal return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.numpy()
        return rand_inds

    def _random_subsample(self, inds: Tensor,
                          num_expected: int) -> Union[Tensor, ndarray]:
        """Flatten ``torch.nonzero`` output and cap it at ``num_expected``.

        Args:
            inds (Tensor): Output of ``torch.nonzero(mask, as_tuple=False)``
                for a mask over ``assign_result.gt_inds`` (assumed 1-D, which
                makes ``inds`` shaped ``(k, 1)`` -- TODO confirm).
            num_expected (int): Maximum number of indices to keep.

        Returns:
            Tensor: 1-D indices. All of them when there are at most
            ``num_expected``, otherwise a random subset of that size.
        """
        # Squeeze unconditionally so that the "no candidates" case is also
        # returned as a 1-D ``(0,)`` tensor instead of a ``(0, 1)`` one.
        inds = inds.squeeze(1)
        if inds.numel() <= num_expected:
            return inds
        return self.random_choice(inds, num_expected)

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some positive samples.

        A proposal is positive when its ``gt_inds`` entry is greater than 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples

        Returns:
            Tensor or ndarray: sampled indices (every positive index when
            there are no more than ``num_expected`` of them).
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        return self._random_subsample(pos_inds, num_expected)

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some negative samples.

        A proposal is negative when its ``gt_inds`` entry equals 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples

        Returns:
            Tensor or ndarray: sampled indices (every negative index when
            there are no more than ``num_expected`` of them).
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        return self._random_subsample(neg_inds, num_expected)