| | import copy |
| | import warnings |
| |
|
| | import numpy as np |
| | import torch |
| | from mmcv import ConfigDict |
| | from mmcv.ops import nms |
| |
|
| | from ..bbox import bbox_mapping_back |
| |
|
| |
|
def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.).

    Each set of proposals is mapped back to the original image with
    ``bbox_mapping_back``, all sets are concatenated, NMS is applied and the
    highest scoring ``cfg.max_per_img`` results are kept.

    Args:
        aug_proposals (list[Tensor]): proposals from different testing
            schemes, shape (n, 5). The first four columns are the box and
            the last column is the score. Note that they are not rescaled
            to the original image size.
        img_metas (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip', 'flip_direction', and may
            also contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        cfg (dict): rpn test config. It is deep-copied, so the caller's
            object is never modified. The deprecated keys ``nms_thr`` and
            ``max_num`` are still accepted (with a warning) and translated
            to ``nms.iou_threshold`` and ``max_per_img``.

    Returns:
        Tensor: shape (k, 5), the proposals kept by NMS at the original
        image scale, sorted by descending score (column 4) and truncated
        to at most ``cfg.max_per_img`` rows.

    Raises:
        AssertionError: if a deprecated key and its replacement are both
            given with different values.
    """
    # Work on a private copy: the legacy-key translation below writes to cfg.
    cfg = copy.deepcopy(cfg)

    # Deprecated-argument warning.
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, (
                f'You set max_num and max_per_img at the same time, '
                f'but get {cfg.max_num} and {cfg.max_per_img} '
                f'respectively. Please delete max_num which will be '
                f'deprecated.')
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, (
            f'You set iou_threshold in nms and nms_thr at the same time, '
            f'but get {cfg.nms.iou_threshold} and {cfg.nms_thr} '
            f'respectively. Please delete the nms_thr which will be '
            f'deprecated.')

    recovered_proposals = []
    for proposals, img_info in zip(aug_proposals, img_metas):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        # Clone so the caller's tensors are not overwritten in place.
        _proposals = proposals.clone()
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        recovered_proposals.append(_proposals)
    aug_proposals = torch.cat(recovered_proposals, dim=0)
    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
                              aug_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    # Column 4 of the NMS output holds the score; keep the top max_per_img.
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals
| |
|
| |
|
def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
    """Average detection bboxes (and scores) over test-time augmentations.

    Every augmented bbox tensor is first mapped back with
    ``bbox_mapping_back`` using the meta info of its augmentation, then the
    recovered tensors are stacked and averaged element-wise.

    Args:
        aug_bboxes (list[Tensor]): shape (n, 4*#class)
        aug_scores (list[Tensor] or None): shape (n, #class)
        img_metas (list[list[dict]]): one entry per augmentation; only the
            first dict of each entry is read, and it must provide
            'img_shape', 'scale_factor', 'flip' and 'flip_direction'.
        rcnn_test_cfg (dict): rcnn test config. Not used by this function.

    Returns:
        Tensor or tuple: the averaged bboxes alone when ``aug_scores`` is
        None, otherwise ``(bboxes, scores)``.
    """
    restored = []
    for boxes, metas in zip(aug_bboxes, img_metas):
        meta = metas[0]
        restored.append(
            bbox_mapping_back(boxes, meta['img_shape'],
                              meta['scale_factor'], meta['flip'],
                              meta['flip_direction']))
    merged_bboxes = torch.stack(restored).mean(dim=0)

    if aug_scores is None:
        return merged_bboxes

    merged_scores = torch.stack(aug_scores).mean(dim=0)
    return merged_bboxes, merged_scores
| |
|
| |
|
def merge_aug_scores(aug_scores):
    """Average bbox scores over augmentations.

    The type of the first element selects the backend: a list of torch
    tensors is stacked and averaged with torch, anything else is averaged
    with ``np.mean`` along axis 0.
    """
    if not isinstance(aug_scores[0], torch.Tensor):
        return np.mean(aug_scores, axis=0)
    stacked = torch.stack(aug_scores)
    return stacked.mean(dim=0)
| |
|
| |
|
def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
    """Merge mask predictions from several test-time augmentations.

    Masks coming from a flipped augmentation are flipped back along the
    matching spatial axis, then all masks are averaged.

    Args:
        aug_masks (list[ndarray]): shape (n, #class, h, w)
        img_metas (list[list[dict]]): one entry per augmentation; only the
            first dict of each entry is read, and it must provide 'flip'
            and 'flip_direction'.
        rcnn_test_cfg (dict): rcnn test config. Not used by this function.
        weights (list or ndarray, optional): per-augmentation weights
            passed to ``np.average``. When None a plain mean is taken.

    Returns:
        ndarray: the merged masks, averaged over the augmentation axis.

    Raises:
        ValueError: if a flipped augmentation has a direction other than
            'horizontal' or 'vertical'.
    """
    restored = []
    for aug_mask, metas in zip(aug_masks, img_metas):
        meta = metas[0]
        is_flipped = meta['flip']
        direction = meta['flip_direction']
        if not is_flipped:
            restored.append(aug_mask)
            continue
        if direction == 'horizontal':
            # Undo a left-right flip: reverse the last (width) axis.
            restored.append(aug_mask[:, :, :, ::-1])
        elif direction == 'vertical':
            # Undo an up-down flip: reverse the third (height) axis.
            restored.append(aug_mask[:, :, ::-1, :])
        else:
            raise ValueError(f"Invalid flipping direction '{direction}'")

    if weights is not None:
        return np.average(
            np.array(restored), axis=0, weights=np.array(weights))
    return np.mean(restored, axis=0)
| |
|