| |
| from detectron2.layers import batched_nms |
| from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads |
| from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads |
| from detectron2.structures import Instances |
|
|
|
|
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.

    For every image, the per-branch detections are concatenated and filtered with
    class-aware non-maximum suppression (NMS) on the predicted boxes. The surviving
    detections keep all of their other fields (e.g. masks), if any.

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store
            detection results for N images, ordered branch-major: the result for
            image ``i`` from branch ``j`` is at index ``i + N * j``.
        num_branch (int): Number of branches whose results are merged for each image.
            When 1, ``instances`` is returned unchanged (no extra NMS / top-k).
        nms_thresh (float): The IoU threshold used for box NMS. Value in [0, 1].
        topk_per_image (int): The number of top-scoring detections to keep per image.
            Set < 0 to keep all detections that survive NMS.

    Returns:
        list[Instances]: A list of N instances, one for each image in the batch,
        holding the top-k most confident detections after merging the results of
        all branches.

    Raises:
        ValueError: If ``len(instances)`` is not a multiple of ``num_branch``.
    """
    if num_branch == 1:
        return instances

    if len(instances) % num_branch != 0:
        # Otherwise the trailing entries would be silently ignored by the floor division.
        raise ValueError(
            f"Expected a multiple of num_branch={num_branch} instances, got {len(instances)}."
        )

    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather this image's detections from every branch (branch-major layout).
        instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])

        # Class-aware NMS: boxes with different predicted classes do not suppress each
        # other. The returned indices are sorted by decreasing score.
        keep = batched_nms(
            instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
        )
        # A negative topk means "keep everything". Slicing with it (e.g. keep[:-1])
        # would instead silently drop the lowest-scoring detections.
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]
        results.append(instance[keep])

    return results
|
|
|
|
@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    TridentNet ROI heads for the typical "C4" R-CNN model.

    Box prediction itself is delegated to :class:`Res5ROIHeads`. At inference time
    the per-branch detections of each image are merged with
    :func:`merge_branch_instances`. See :class:`Res5ROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        # Number of TridentNet branches used in training and in full (multi-branch) inference.
        self.num_branch = trident_cfg.NUM_BRANCH
        # Any TEST_BRANCH_IDX other than -1 enables "fast" inference, where the heads
        # treat their inputs as coming from a single branch.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        The target list is repeated once per branch before calling the parent.
        In inference mode, the detections of all branches belonging to the same
        image are merged; nothing is merged when only one branch is in use.
        """
        single_branch = self.trident_fast and not self.training
        num_branch = 1 if single_branch else self.num_branch

        if targets is None:
            all_targets = None
        else:
            all_targets = targets * num_branch

        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Release references that are not needed past this point.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        merged = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return merged, {}
|
|
|
|
@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.

    Box prediction is delegated to :class:`StandardROIHeads`. At inference time the
    per-branch detections of each image are merged with :func:`merge_branch_instances`.
    See :class:`StandardROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        # Number of TridentNet branches used in training and in full (multi-branch) inference.
        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # Any TEST_BRANCH_IDX other than -1 enables "fast" inference, where the heads
        # treat their inputs as coming from a single branch.
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        The target list is repeated once per branch before calling the parent.
        In inference mode, the detections of all branches belonging to the same
        image are merged; nothing is merged when only one branch is in use.
        """
        # Only fast inference collapses to a single branch; training always uses all of them.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Release references that are not needed past this point.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        pred_instances = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return pred_instances, {}
|
|