| | |
"""Copied from
https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
| |
|
| | import torch |
| | from mmengine.structures import InstanceData |
| |
|
| | from mmdet.registry import TASK_UTILS |
| | from ..assigners import AssignResult |
| | from .base_sampler import BaseSampler |
| | from .mask_sampling_result import MaskSamplingResult |
| |
|
| |
|
@TASK_UTILS.register_module()
class MaskPseudoSampler(BaseSampler):
    """Pass-through mask sampler: keeps every assigned prediction as-is.

    No sub-sampling is performed. ``sample`` only splits the prediction
    indices into positives and negatives according to the assignment
    result and packs them into a ``MaskSamplingResult``.
    """

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments.

        The base-class initializer is not called here.
        """

    def _sample_pos(self, **kwargs):
        """Positive sampling hook; not implemented by this sampler."""
        raise NotImplementedError

    def _sample_neg(self, **kwargs):
        """Negative sampling hook; not implemented by this sampler."""
        raise NotImplementedError

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData, *args, **kwargs):
        """Split predictions into positives and negatives without sampling.

        Args:
            assign_result (:obj:`AssignResult`): Mask assigning results.
                Its ``gt_inds`` decides the split: entries ``> 0`` are
                positives, entries ``== 0`` are negatives.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only its ``masks`` attribute is read here.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. Only its ``masks`` attribute is read here.

        Returns:
            :obj:`MaskSamplingResult`: The sampling result holding the
            positive/negative indices, predicted masks and ground truth
            masks.
        """
        predicted = pred_instances.masks
        targets = gt_instances.masks
        assigned = assign_result.gt_inds

        # De-duplicated indices of predictions matched to a ground truth.
        is_positive = assigned > 0
        positive_idx = torch.nonzero(is_positive, as_tuple=False)
        positive_idx = positive_idx.squeeze(-1).unique()

        # De-duplicated indices of predictions assigned to background.
        is_negative = assigned == 0
        negative_idx = torch.nonzero(is_negative, as_tuple=False)
        negative_idx = negative_idx.squeeze(-1).unique()

        # All-zero uint8 flags, one entry per predicted mask.
        num_predictions = predicted.shape[0]
        flags = predicted.new_zeros(num_predictions, dtype=torch.uint8)

        return MaskSamplingResult(
            pos_inds=positive_idx,
            neg_inds=negative_idx,
            masks=predicted,
            gt_masks=targets,
            assign_result=assign_result,
            gt_flags=flags,
            avg_factor_with_neg=False)
| |
|