| import torch |
|
|
| from mmdet.core import bbox2result, bbox_mapping_back |
| from ..builder import DETECTORS |
| from .single_stage import SingleStageDetector |
|
|
|
|
@DETECTORS.register_module()
class CornerNet(SingleStageDetector):
    """CornerNet detector.

    Implementation of the paper `CornerNet: Detecting Objects as Paired
    Keypoints <https://arxiv.org/abs/1808.01244>`_ .
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Pass all arguments, unchanged and in order, to the parent class."""
        super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained)

    def merge_aug_results(self, aug_results, img_metas):
        """Combine the detections of all augmented views into one result.

        The first four columns of each view's detections are passed through
        ``bbox_mapping_back`` together with that view's ``img_shape``,
        ``scale_factor`` and ``flip`` entries; the last column (the score)
        is re-attached afterwards. All views are then concatenated and, if
        anything is left, filtered by ``self.bbox_head._bboxes_nms``.

        Args:
            aug_results (list[list[Tensor]]): One ``(det_bboxes,
                det_labels)`` pair per augmented image.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc. Only the first dict
                of every inner list is read.

        Returns:
            tuple: (bboxes, labels)
        """
        mapped_dets = []
        label_chunks = []
        for det_pair, metas in zip(aug_results, img_metas):
            meta = metas[0]
            img_shape = meta['img_shape']
            scale_factor = meta['scale_factor']
            flip = meta['flip']

            dets, det_labels = det_pair
            coords = dets[:, :4]
            scores = dets[:, -1:]
            coords = bbox_mapping_back(coords, img_shape, scale_factor, flip)

            mapped_dets.append(torch.cat([coords, scores], dim=-1))
            label_chunks.append(det_labels)

        merged_bboxes = torch.cat(mapped_dets, dim=0)
        merged_labels = torch.cat(label_chunks)

        # With zero detections there is nothing to suppress; return the
        # (empty) concatenated tensors untouched.
        if merged_bboxes.shape[0] == 0:
            return merged_bboxes, merged_labels

        kept_bboxes, kept_labels = self.bbox_head._bboxes_nms(
            merged_bboxes, merged_labels, self.bbox_head.test_cfg)
        return kept_bboxes, kept_labels

    def aug_test(self, imgs, img_metas, rescale=False):
        """Run test-time augmentation for CornerNet.

        Images are consumed two at a time as (image, flipped image) pairs:
        indices (0, 1), (2, 3), and so on. A trailing image without a
        partner is not processed.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): Accepted for interface compatibility. This
                method does not read it; boxes always go through
                :meth:`merge_aug_results`. Default: False.

        Note:
            ``imgs`` must include flipped image pairs. Only the ``flip``
            entries of the first two images are checked.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        num_imgs = len(imgs)

        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')

        aug_results = []
        for first in range(0, num_imgs - 1, 2):
            second = first + 1
            # Stack both views of the pair so a single forward pass
            # covers them.
            paired_imgs = torch.cat([imgs[first], imgs[second]])
            feats = self.extract_feat(paired_imgs)
            head_outs = self.bbox_head(feats)
            pair_metas = [img_metas[first], img_metas[second]]
            # NOTE(review): the two trailing ``False`` flags are presumably
            # ``rescale`` and ``with_nms`` -- confirm against the head's
            # ``get_bboxes`` signature.
            pair_dets = self.bbox_head.get_bboxes(*head_outs, pair_metas,
                                                  False, False)
            aug_results.extend((pair_dets[0], pair_dets[1]))

        merged_bboxes, merged_labels = self.merge_aug_results(
            aug_results, img_metas)
        per_class_results = bbox2result(merged_bboxes, merged_labels,
                                        self.bbox_head.num_classes)

        return [per_class_results]
|
|