Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import logging | |
| import torch | |
| from maskrcnn_benchmark.modeling.box_coder import BoxCoder | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList, _onnx_clip_boxes_to_image | |
| from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist | |
| from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms | |
| from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms | |
| from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes | |
| from ..utils import permute_and_flatten | |
| import pdb | |
class RPNPostProcessor(torch.nn.Module):
    """
    Converts raw RPN outputs (objectness + box deltas) into per-image
    proposal BoxLists that are then passed on to the heads.
    """

    def __init__(
        self,
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
        min_size,
        box_coder=None,
        fpn_post_nms_top_n=None,
        onnx=False
    ):
        """
        Arguments:
            pre_nms_top_n (int): cap on the number of top-scoring anchors
                kept per image and feature level before NMS
            post_nms_top_n (int): forwarded to boxlist_nms as max_proposals
            nms_thresh (float): threshold forwarded to boxlist_nms
            min_size (int): forwarded to remove_small_boxes
            box_coder (BoxCoder): decoder for the regression output; a
                BoxCoder with unit weights is created when omitted
            fpn_post_nms_top_n (int): number of proposals kept when merging
                several feature levels; falls back to post_nms_top_n
            onnx (bool): when set, boxes are clipped with
                _onnx_clip_boxes_to_image instead of BoxList.clip_to_image
        """
        super(RPNPostProcessor, self).__init__()
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.onnx = onnx

        self.box_coder = (
            BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) if box_coder is None else box_coder
        )
        self.fpn_post_nms_top_n = (
            post_nms_top_n if fpn_post_nms_top_n is None else fpn_post_nms_top_n
        )

    def add_gt_proposals(self, proposals, targets):
        """
        Append the ground-truth boxes of every image to its proposals.

        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]
        """
        # All new tensors are created on the device the proposals live on.
        device = proposals[0].bbox.device

        # Strip every extra field from the targets; only the boxes are needed.
        stripped_targets = [target.copy_with_fields([]) for target in targets]

        # cat_boxlist needs identical fields on both operands, so the
        # ground-truth boxes get a dummy "objectness" of one.
        for stripped in stripped_targets:
            stripped.add_field(
                "objectness", torch.ones(len(stripped), device=device)
            )

        return [
            cat_boxlist((proposal, stripped))
            for proposal, stripped in zip(proposals, stripped_targets)
        ]

    def forward_for_single_feature_map(self, anchors, objectness, box_regression):
        """
        Decode, clip, filter and NMS the proposals of one feature level.

        Arguments:
            anchors: list[BoxList]
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 4, H, W
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        # Bring both predictions into the anchors' (H, W, A) ordering.
        objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1).sigmoid()
        box_regression = (
            box_regression.view(N, -1, 4, H, W)
            .permute(0, 3, 4, 1, 2)
            .reshape(N, -1, 4)
        )

        # Never ask topk for more entries than there are anchors.
        keep_n = min(self.pre_nms_top_n, A * H * W)
        objectness, topk_idx = objectness.topk(keep_n, dim=1, sorted=True)

        # Column vector of image indices so it broadcasts against topk_idx.
        batch_idx = torch.arange(N, device=device)[:, None]
        box_regression = box_regression[batch_idx, topk_idx]

        image_shapes = [anchor_boxlist.size for anchor_boxlist in anchors]
        concat_anchors = torch.cat(
            [anchor_boxlist.bbox for anchor_boxlist in anchors], dim=0
        )
        concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx]

        proposals = self.box_coder.decode(
            box_regression.view(-1, 4), concat_anchors.view(-1, 4)
        ).view(N, -1, 4)

        result = []
        for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
            # ONNX mode clips the raw tensor up front; otherwise the BoxList
            # clips itself after construction.
            if self.onnx:
                proposal = _onnx_clip_boxes_to_image(proposal, im_shape)
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            if not self.onnx:
                boxlist = boxlist.clip_to_image(remove_empty=False)

            boxlist.add_field("objectness", score)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            result.append(
                boxlist_nms(
                    boxlist,
                    self.nms_thresh,
                    max_proposals=self.post_nms_top_n,
                    score_field="objectness",
                )
            )
        return result

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        num_levels = len(objectness)

        # Regroup the anchors so each entry holds one level across all images.
        per_level_anchors = list(zip(*anchors))
        sampled_boxes = [
            self.forward_for_single_feature_map(a, o, b)
            for a, o, b in zip(per_level_anchors, objectness, box_regression)
        ]

        # Regroup back to one entry per image and merge its levels.
        boxlists = [cat_boxlist(per_image) for per_image in zip(*sampled_boxes)]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        return boxlists

    def select_over_all_levels(self, boxlists):
        """
        Keep only the highest-objectness proposals after merging the levels.

        The list is updated in place and also returned.
        """
        # different behavior during training and during testing:
        # during training, post_nms_top_n is over *all* the proposals combined, while
        # during testing, it is over the proposals for each image
        # TODO resolve this difference and make it consistent. It should be per image,
        # and not per batch
        if self.training:
            objectness = torch.cat(
                [per_image.get_field("objectness") for per_image in boxlists], dim=0
            )
            box_sizes = [len(per_image) for per_image in boxlists]
            keep_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(objectness, keep_n, dim=0, sorted=True)

            # Mark the batch-wide winners, then cut the mask back per image.
            inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
            inds_mask[inds_sorted] = 1
            for i, per_image_mask in enumerate(inds_mask.split(box_sizes)):
                boxlists[i] = boxlists[i][per_image_mask]
        else:
            for i, per_image in enumerate(boxlists):
                objectness = per_image.get_field("objectness")
                keep_n = min(self.fpn_post_nms_top_n, len(objectness))
                _, inds_sorted = torch.topk(
                    objectness, keep_n, dim=0, sorted=True
                )
                boxlists[i] = per_image[inds_sorted]
        return boxlists
def make_rpn_postprocessor(config, rpn_box_coder, is_train):
    """
    Build an RPNPostProcessor from the MODEL.RPN section of ``config``.

    The *_TRAIN limits are used when ``is_train`` is truthy, the *_TEST
    limits otherwise. ``rpn_box_coder`` is passed through as the box coder.
    """
    model_cfg = config.MODEL
    rpn_cfg = model_cfg.RPN

    # Start from the training limits and switch to the test ones if needed.
    fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TRAIN
    pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TRAIN
    post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TRAIN
    if not is_train:
        fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TEST
        pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TEST
        post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TEST

    return RPNPostProcessor(
        pre_nms_top_n=pre_nms_top_n,
        post_nms_top_n=post_nms_top_n,
        nms_thresh=rpn_cfg.NMS_THRESH,
        min_size=rpn_cfg.MIN_SIZE,
        box_coder=rpn_box_coder,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        onnx=model_cfg.ONNX
    )
class RetinaPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the RetinaNet boxes.
    This is only used in the testing.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder=None,
    ):
        """
        Arguments:
            pre_nms_thresh (float): scores must exceed this to become
                candidates
            pre_nms_top_n (int): cap on candidates kept per image and level
            nms_thresh (float): threshold forwarded to boxlist_nms
            fpn_post_nms_top_n (int): max detections kept per image over all
                classes; the limit is only applied when it is > 0
            min_size (int): forwarded to remove_small_boxes
            num_classes (int): number of classes, background included
                (labels 1 .. num_classes - 1 are processed)
            box_coder (BoxCoder): decoder for the regression output; defaults
                to BoxCoder(weights=(10., 10., 5., 5.))
        """
        super(RetinaPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

        if box_coder is None:
            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.box_coder = box_coder

    def forward_for_single_feature_map(self, anchors, box_cls, box_regression):
        """
        Select and decode the candidate detections of one feature level.

        Arguments:
            anchors: list[BoxList]
            box_cls: tensor of size N, A * C, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one BoxList per image with "labels" and "scores"
        """
        N, _, H, W = box_cls.shape
        A = box_regression.size(1) // 4
        C = box_cls.size(1) // A

        # put in the same format as anchors
        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
        box_cls = box_cls.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        candidate_inds = box_cls > self.pre_nms_thresh

        # Per image: number of candidates, capped at pre_nms_top_n.
        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        for per_box_cls, per_box_regression, per_pre_nms_top_n, \
                per_candidate_inds, per_anchors in zip(
                    box_cls,
                    box_regression,
                    pre_nms_top_n,
                    candidate_inds,
                    anchors):
            # Sort and select TopN
            # TODO most of this can be made out of the loop for
            # all images.
            # TODO:Yang: Not easy to do. Because the numbers of detections are
            # different in each image. Therefore, this part needs to be done
            # per image.
            per_box_cls = per_box_cls[per_candidate_inds]

            per_box_cls, top_k_indices = \
                per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            # Rows of nonzero() are (anchor position, class index) pairs.
            per_candidate_nonzeros = \
                per_candidate_inds.nonzero()[top_k_indices, :]

            per_box_loc = per_candidate_nonzeros[:, 0]
            # Shift the class index by one so that label 0 stays background.
            per_class = per_candidate_nonzeros[:, 1] + 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4),
                per_anchors.bbox[per_box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run per-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image, carrying "scores" and
                "labels" fields

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        results = []
        for per_image in boxlists:
            scores = per_image.get_field("scores")
            labels = per_image.get_field("labels")
            boxes = per_image.bbox

            per_class_results = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                boxlist_for_class = BoxList(
                    boxes[inds, :].view(-1, 4), per_image.size, mode="xyxy"
                )
                boxlist_for_class.add_field("scores", scores[inds])
                boxlist_for_class = boxlist_nms(
                    boxlist_for_class, self.nms_thresh,
                    score_field="scores"
                )
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels", torch.full((num_labels,), j,
                                         dtype=torch.int64,
                                         device=scores.device)
                )
                per_class_results.append(boxlist_for_class)

            result = cat_boxlist(per_class_results)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # torch.kthvalue is not implemented for 'Half', so the scores
                # are cast to float32 before looking up the cut-off score.
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu().float(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]
            targets: accepted for interface compatibility; not used here

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        # Regroup the anchors so each entry holds one level across all images.
        anchors = list(zip(*anchors))
        for a, o, b in zip(anchors, objectness, box_regression):
            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

        # Regroup back to one entry per image and merge its levels.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        boxlists = self.select_over_all_levels(boxlists)

        return boxlists
def make_retina_postprocessor(config, rpn_box_coder, is_train):
    """
    Build a RetinaPostProcessor from the MODEL.RETINANET section of ``config``.

    ``rpn_box_coder`` is passed through as the box coder and the minimum box
    size is fixed to 0. ``is_train`` is accepted but not consulted.
    """
    retina_cfg = config.MODEL.RETINANET
    return RetinaPostProcessor(
        pre_nms_thresh=retina_cfg.INFERENCE_TH,
        pre_nms_top_n=retina_cfg.PRE_NMS_TOP_N,
        nms_thresh=retina_cfg.NMS_TH,
        fpn_post_nms_top_n=retina_cfg.DETECTIONS_PER_IMG,
        min_size=0,
        num_classes=retina_cfg.NUM_CLASSES,
        box_coder=rpn_box_coder,
    )
class FCOSPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the FCOS head: candidate
    selection, box decoding from per-location distances, multi-class NMS and
    a per-image detection cap.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        bbox_aug_enabled=False
    ):
        """
        Arguments:
            pre_nms_thresh (float): class scores must exceed this to become
                candidates
            pre_nms_top_n (int): cap on candidates kept per image and level
            nms_thresh (float): threshold forwarded to boxlist_ml_nms
            fpn_post_nms_top_n (int): max detections kept per image; the
                limit is only applied when it is > 0
            min_size (int): forwarded to remove_small_boxes
            num_classes (int): stored on the module; not read by the methods
                of this class
            bbox_aug_enabled (bool): when set, forward() skips
                select_over_all_levels and returns the merged candidates
        """
        super(FCOSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes
        self.bbox_aug_enabled = bbox_aug_enabled

    def forward_for_single_feature_map(
            self, locations, box_cls,
            box_regression, centerness,
            image_sizes):
        """
        Select and decode the candidate detections of one feature level.

        Arguments:
            locations: tensor indexed by flattened H * W position; columns 0
                and 1 are used as the x and y of each location
            box_cls: tensor of size N, C, H, W
            box_regression: tensor of size N, 4, H, W
            centerness: tensor of size N, 1, H, W
            image_sizes: list[(h, w)]

        Returns:
            list[BoxList]: one BoxList per image with "centers", "labels" and
                "scores" fields
        """
        N, C, H, W = box_cls.shape

        # put in the same format as locations
        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
        box_cls = box_cls.reshape(N, -1, C).sigmoid()
        box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
        box_regression = box_regression.reshape(N, -1, 4)
        centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
        centerness = centerness.reshape(N, -1).sigmoid()

        # Candidates are picked on the plain class score, before the
        # centerness weighting below.
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        results = []
        for i in range(N):
            per_box_cls = box_cls[i]
            per_candidate_inds = candidate_inds[i]
            per_box_cls = per_box_cls[per_candidate_inds]

            # Rows of nonzero() are (location index, class index) pairs.
            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1] + 1

            per_box_regression = box_regression[i]
            per_box_regression = per_box_regression[per_box_loc]
            per_locations = locations[per_box_loc]

            per_pre_nms_top_n = pre_nms_top_n[i]

            # Only run topk when there are more candidates than allowed.
            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_box_cls, top_k_indices = \
                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                per_locations = per_locations[top_k_indices]

            # The four regression channels are subtracted from / added to the
            # location to obtain x1, y1, x2, y2.
            detections = torch.stack([
                per_locations[:, 0] - per_box_regression[:, 0],
                per_locations[:, 1] - per_box_regression[:, 1],
                per_locations[:, 0] + per_box_regression[:, 2],
                per_locations[:, 1] + per_box_regression[:, 3],
            ], dim=1)

            # image_sizes holds (h, w) while BoxList is built with (w, h).
            h, w = image_sizes[i]
            boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
            boxlist.add_field('centers', per_locations)
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", torch.sqrt(per_box_cls))
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
        """
        Arguments:
            locations: list[tensor], one entry per feature level
            box_cls: list[tensor]
            box_regression: list[tensor]
            centerness: list[tensor]
            image_sizes: list[(h, w)]

        Returns:
            boxlists (list[BoxList]): the post-processed detections, after
                applying box decoding and (unless bbox_aug_enabled) NMS
        """
        sampled_boxes = []
        for l, o, b, c in zip(locations, box_cls, box_regression, centerness):
            sampled_boxes.append(
                self.forward_for_single_feature_map(
                    l, o, b, c, image_sizes
                )
            )

        # Regroup to one entry per image and merge its levels.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        if not self.bbox_aug_enabled:
            boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run multi-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        results = []
        for per_image in boxlists:
            # multiclass nms
            result = boxlist_ml_nms(per_image, self.nms_thresh)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # torch.kthvalue is not implemented for 'Half', so the scores
                # are cast to float32 before looking up the cut-off score.
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu().float(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
def make_fcos_postprocessor(config, is_train=False):
    """
    Build an FCOSPostProcessor from the MODEL.FCOS section of ``config``.

    When ``is_train`` is truthy the *_TRAIN score threshold and top-n limits
    replace the inference ones. The minimum box size is fixed to 0.
    """
    fcos_cfg = config.MODEL.FCOS

    # Inference settings first; training overrides them below.
    pre_nms_thresh = fcos_cfg.INFERENCE_TH
    if is_train:
        pre_nms_thresh = fcos_cfg.INFERENCE_TH_TRAIN

    pre_nms_top_n = fcos_cfg.PRE_NMS_TOP_N
    fpn_post_nms_top_n = fcos_cfg.DETECTIONS_PER_IMG
    if is_train:
        pre_nms_top_n = fcos_cfg.PRE_NMS_TOP_N_TRAIN
        fpn_post_nms_top_n = fcos_cfg.POST_NMS_TOP_N_TRAIN

    return FCOSPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=fcos_cfg.NMS_TH,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=0,
        num_classes=fcos_cfg.NUM_CLASSES,
    )
class ATSSPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the ATSS head.

    Class scores can come from plain classification logits, from token
    logits or from dot-product logits; the latter two are mapped back to
    per-class scores through ``positive_map``.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder,
        bbox_aug_enabled=False,
        bbox_aug_vote=False,
        score_agg='MEAN',
        mdetr_style_aggregate_class_num=-1
    ):
        """
        Arguments:
            pre_nms_thresh (float): class scores must exceed this to become
                candidates
            pre_nms_top_n (int): cap on candidates kept per image and level
            nms_thresh (float): threshold forwarded to boxlist_ml_nms
            fpn_post_nms_top_n (int): max detections kept per image; the
                limit is only applied when it is > 0
            min_size (int): forwarded to remove_small_boxes
            num_classes (int): stored on the module; not read by the methods
                of this class
            box_coder: object whose ``decode`` turns regression output and
                anchors into boxes
            bbox_aug_enabled (bool), bbox_aug_vote (bool): forward() skips
                select_over_all_levels only when augmentation is enabled and
                voting is disabled
            score_agg (str): aggregation mode forwarded to the
                convert_grounding_to_od_logits* helpers
            mdetr_style_aggregate_class_num (int): when not -1, dot-product
                logits are converted with convert_grounding_to_od_logits_v2
                using this number of classes
        """
        super(ATSSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes
        self.bbox_aug_enabled = bbox_aug_enabled
        self.box_coder = box_coder
        self.bbox_aug_vote = bbox_aug_vote
        self.score_agg = score_agg
        self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num

    def forward_for_single_feature_map(self, box_regression, centerness, anchors,
                                       box_cls=None,
                                       token_logits=None,
                                       dot_product_logits=None,
                                       positive_map=None,
                                       ):
        """
        Select and decode the candidate detections of one feature level.

        Arguments:
            box_regression: tensor of size N, A * 4, H, W
            centerness: tensor reshaped here to N, A, 1, H, W order
            anchors: list[BoxList], one per image
            box_cls: optional classification logits (channel dim A * C)
            token_logits: optional token logits (channel dim A * T)
            dot_product_logits: optional dot-product logits, used as given
                (only a sigmoid is applied before the conversion)
            positive_map: mapping forwarded to the grounding-to-class
                conversion helpers

        Returns:
            list[BoxList]: one BoxList per image with "labels" and "scores"

        Raises:
            ValueError: if none of box_cls, token_logits and
                dot_product_logits is provided.
        """
        N, _, H, W = box_regression.shape
        A = box_regression.size(1) // 4

        # put in the same format as anchors
        if box_cls is not None:
            C = box_cls.size(1) // A
            box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
            box_cls = box_cls.sigmoid()

        # binary focal loss version
        if token_logits is not None:
            T = token_logits.size(1) // A
            token_logits = permute_and_flatten(token_logits, N, A, T, H, W)
            token_logits = token_logits.sigmoid()
            # turn back to original classes
            box_cls = convert_grounding_to_od_logits(
                logits=token_logits, box_cls=box_cls, positive_map=positive_map,
                score_agg=self.score_agg)

        # binary dot product focal version
        if dot_product_logits is not None:
            dot_product_logits = dot_product_logits.sigmoid()
            if self.mdetr_style_aggregate_class_num != -1:
                box_cls = convert_grounding_to_od_logits_v2(
                    logits=dot_product_logits,
                    num_class=self.mdetr_style_aggregate_class_num,
                    positive_map=positive_map,
                    score_agg=self.score_agg,
                    disable_minus_one=False)
            else:
                box_cls = convert_grounding_to_od_logits(
                    logits=dot_product_logits, box_cls=box_cls,
                    positive_map=positive_map,
                    score_agg=self.score_agg)

        # Without any score source the comparison below would fail with an
        # opaque "None > float" TypeError; report the real problem instead.
        if box_cls is None:
            raise ValueError(
                "ATSSPostProcessor needs at least one of box_cls, "
                "token_logits or dot_product_logits"
            )

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        # Candidates are picked on the plain class score, before the
        # centerness weighting below.
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        centerness = permute_and_flatten(centerness, N, A, 1, H, W)
        centerness = centerness.reshape(N, -1).sigmoid()

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        results = []
        for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors \
                in zip(box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors):
            per_box_cls = per_box_cls[per_candidate_inds]

            per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            # Rows of nonzero() are (anchor position, class index) pairs.
            per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :]

            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1] + 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4),
                per_anchors.bbox[per_box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            # Final score is sqrt(class score * centerness).
            boxlist.add_field("scores", torch.sqrt(per_box_cls))
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(self, box_regression, centerness, anchors,
                box_cls=None,
                token_logits=None,
                dot_product_logits=None,
                positive_map=None,
                ):
        """
        Arguments:
            box_regression: list[tensor], one entry per feature level
            centerness: list[tensor], one entry per feature level
            anchors: list[list[BoxList]]
            box_cls, token_logits, dot_product_logits: optional per-level
                lists of logits; each one that is given is indexed by level
            positive_map: forwarded unchanged to every level

        Returns:
            boxlists (list[BoxList]): per-image detections; NMS and the
                per-image cap are skipped when bbox_aug_enabled is set and
                bbox_aug_vote is not
        """
        sampled_boxes = []
        # Regroup the anchors so each entry holds one level across all images.
        anchors = list(zip(*anchors))
        for idx, (b, c, a) in enumerate(zip(box_regression, centerness, anchors)):
            o = box_cls[idx] if box_cls is not None else None
            t = token_logits[idx] if token_logits is not None else None
            d = dot_product_logits[idx] if dot_product_logits is not None else None

            sampled_boxes.append(
                self.forward_for_single_feature_map(b, c, a, o, t, d, positive_map)
            )

        # Regroup back to one entry per image and merge its levels.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        if not (self.bbox_aug_enabled and not self.bbox_aug_vote):
            boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run multi-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        results = []
        for per_image in boxlists:
            # multiclass nms
            result = boxlist_ml_nms(per_image, self.nms_thresh)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    # TODO: confirm with Pengchuan and Xiyang, torch.kthvalue is not implemented for 'Half'
                    cls_scores.cpu().float(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
def convert_grounding_to_od_logits(logits, box_cls, positive_map, score_agg=None):
    """
    Aggregate per-token scores into per-class scores.

    Arguments:
        logits: 3-D tensor of token scores; the last axis is indexed by the
            token positions stored in ``positive_map``
        box_cls: 3-D tensor; only ``box_cls.shape[2]`` is read, as the number
            of output classes
        positive_map: mapping from 1-based class label to the token
            position(s) of that class (a sequence, or a single int), or None
        score_agg (str): "MEAN", "MAX" or "ONEHOT"

    Returns:
        Tensor of shape (logits.shape[0], logits.shape[1], box_cls.shape[2]).
        It is all zeros when ``positive_map`` is None. For "ONEHOT" the first
        ``len(positive_map)`` token channels of ``logits`` are returned
        instead.

    Raises:
        NotImplementedError: if ``positive_map`` is given and ``score_agg``
            is not one of the supported modes.
    """
    scores = torch.zeros(
        logits.shape[0], logits.shape[1], box_cls.shape[2], device=logits.device
    )
    if positive_map is None:
        return scores

    # score aggregation method
    if score_agg == "ONEHOT":
        # one hot
        return logits[:, :, :len(positive_map)]
    if score_agg not in ("MEAN", "MAX"):
        raise NotImplementedError(
            "unsupported score_agg: {!r}".format(score_agg)
        )

    # 256 -> 80, aggregate the token scores of each class
    for label_j in positive_map:
        # reshape(-1) turns a bare int into a one-element index. Passing an
        # int to torch.LongTensor would instead allocate an uninitialised
        # tensor of that length.
        token_idx = torch.as_tensor(
            positive_map[label_j], dtype=torch.long, device=logits.device
        ).reshape(-1)
        class_logits = logits[:, :, token_idx]
        if score_agg == "MEAN":
            scores[:, :, label_j - 1] = class_logits.mean(-1)
        else:
            # torch.max() returns (values, indices)
            scores[:, :, label_j - 1] = class_logits.max(-1)[0]
    return scores
def convert_grounding_to_od_logits_v2(logits, num_class, positive_map, score_agg=None, disable_minus_one = True):
    """
    Aggregate per-token scores into ``num_class`` per-class scores.

    Arguments:
        logits: 3-D tensor of token scores; the last axis is indexed by the
            token positions stored in ``positive_map``
        num_class (int): size of the class axis of the result
        positive_map: mapping from class label to the token position(s) of
            that class (a sequence, or a single int), or None
        score_agg (str): "MEAN", "POWER" (geometric mean), "MAX" or "ONEHOT"
        disable_minus_one (bool): when True the label is used directly as the
            output index, otherwise ``label - 1`` is used

    Returns:
        Tensor of shape (logits.shape[0], logits.shape[1], num_class). It is
        all zeros when ``positive_map`` is None. For "ONEHOT" the first
        ``len(positive_map)`` token channels of ``logits`` are returned
        instead.

    Raises:
        NotImplementedError: if ``positive_map`` is given and ``score_agg``
            is not one of the supported modes.
    """
    scores = torch.zeros(
        logits.shape[0], logits.shape[1], num_class, device=logits.device
    )
    if positive_map is None:
        return scores

    # score aggregation method
    if score_agg == "ONEHOT":
        # one hot
        return logits[:, :, :len(positive_map)]
    if score_agg not in ("MEAN", "POWER", "MAX"):
        raise NotImplementedError(
            "unsupported score_agg: {!r}".format(score_agg)
        )

    # 256 -> 80, aggregate the token scores of each class
    for label_j in positive_map:
        # reshape(-1) turns a bare int into a one-element index, so every
        # aggregation mode accepts both ints and sequences.
        token_idx = torch.as_tensor(
            positive_map[label_j], dtype=torch.long, device=logits.device
        ).reshape(-1)
        class_logits = logits[:, :, token_idx]
        out_idx = label_j if disable_minus_one else label_j - 1

        if score_agg == "MEAN":
            scores[:, :, out_idx] = class_logits.mean(-1)
        elif score_agg == "POWER":
            # Geometric mean. prod() already removes the token axis, so no
            # further squeeze is applied: squeezing would also drop a
            # size-1 anchor axis and break the assignment.
            probability = torch.prod(class_logits, dim=-1)
            scores[:, :, out_idx] = torch.pow(probability, 1 / token_idx.numel())
        else:
            # torch.max() returns (values, indices)
            scores[:, :, out_idx] = class_logits.max(-1)[0]
    return scores
def make_atss_postprocessor(config, box_coder, is_train=False):
    """
    Build an ATSSPostProcessor from ``config``.

    Thresholds and top-n limits come from MODEL.ATSS (the *_TRAIN variants
    when ``is_train`` is truthy), the score aggregation mode from
    MODEL.DYHEAD.SCORE_AGG, and the multi-scale / class-aggregation settings
    from TEST. The minimum box size is fixed to 0.
    """
    atss_cfg = config.MODEL.ATSS

    # Inference settings first; training overrides them below.
    pre_nms_thresh = atss_cfg.INFERENCE_TH
    if is_train:
        pre_nms_thresh = atss_cfg.INFERENCE_TH_TRAIN

    pre_nms_top_n = atss_cfg.PRE_NMS_TOP_N
    fpn_post_nms_top_n = atss_cfg.DETECTIONS_PER_IMG
    if is_train:
        pre_nms_top_n = atss_cfg.PRE_NMS_TOP_N_TRAIN
        fpn_post_nms_top_n = atss_cfg.POST_NMS_TOP_N_TRAIN

    nms_thresh = atss_cfg.NMS_TH
    score_agg = config.MODEL.DYHEAD.SCORE_AGG

    return ATSSPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=nms_thresh,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=0,
        num_classes=atss_cfg.NUM_CLASSES,
        box_coder=box_coder,
        bbox_aug_enabled=config.TEST.USE_MULTISCALE,
        score_agg=score_agg,
        mdetr_style_aggregate_class_num=config.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
    )