File size: 9,031 Bytes
b4d7ac8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
"""
from __future__ import annotations
from collections.abc import Callable
import torch
from torch import Tensor
from monai.data.box_utils import batched_nms, box_iou, clip_boxes_to_image
from monai.transforms.utils_pytorch_numpy_unification import floor_divide
class BoxSelector:
    """
    Box selector which selects the predicted boxes.

    The box selection is performed with the following steps:

    #. For each level, discard boxes with scores less than or equal to self.score_thresh.
    #. For each level, keep boxes with top self.topk_candidates_per_level scores.
    #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
    #. For the whole image, keep boxes with top self.detections_per_img scores.

    Args:
        box_overlap_metric: function computing the overlap between boxes, forwarded to ``batched_nms``
            as its ``box_overlap_metric`` during NMS. Defaults to ``box_iou``.
        apply_sigmoid: whether to apply sigmoid to get scores from classification logits
        score_thresh: only boxes with scores strictly greater than score_thresh are kept
        topk_candidates_per_level: max number of boxes to keep for each level
        nms_thresh: box overlapping threshold for NMS
        detections_per_img: max number of boxes to keep for each image

    Example:

        .. code-block:: python

            input_param = {
                "apply_sigmoid": True,
                "score_thresh": 0.1,
                "topk_candidates_per_level": 2,
                "nms_thresh": 0.1,
                "detections_per_img": 5,
            }
            box_selector = BoxSelector(**input_param)
            boxes = [torch.randn([3,6]), torch.randn([7,6])]
            logits = [torch.randn([3,3]), torch.randn([7,3])]
            spatial_size = (8,8,8)
            selected_boxes, selected_scores, selected_labels = box_selector.select_boxes_per_image(
                boxes, logits, spatial_size
            )
    """

    def __init__(
        self,
        box_overlap_metric: Callable = box_iou,
        apply_sigmoid: bool = True,
        score_thresh: float = 0.05,
        topk_candidates_per_level: int = 1000,
        nms_thresh: float = 0.5,
        detections_per_img: int = 300,
    ):
        self.box_overlap_metric = box_overlap_metric
        self.apply_sigmoid = apply_sigmoid
        self.score_thresh = score_thresh
        self.topk_candidates_per_level = topk_candidates_per_level
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def select_top_score_idx_per_level(self, logits: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """
        Select indices with highest scores.

        The indices selection is performed with the following steps:

        #. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.
        #. Discard indices with scores less than or equal to self.score_thresh
        #. Keep indices with top self.topk_candidates_per_level scores

        Args:
            logits: predicted classification logits, Tensor sized (N, num_classes)

        Return:
            - topk_idxs: selected M indices, Tensor sized (M, )
            - selected_scores: selected M scores, Tensor sized (M, )
            - selected_labels: selected M labels, Tensor sized (M, )
        """
        num_classes = logits.shape[-1]

        # apply sigmoid to classification logits if asked
        if self.apply_sigmoid:
            scores = torch.sigmoid(logits.to(torch.float32)).flatten()
        else:
            scores = logits.flatten()

        # remove low scoring boxes (strict comparison: a score equal to the threshold is dropped too)
        keep_idxs = scores > self.score_thresh
        scores = scores[keep_idxs]
        flatten_topk_idxs = torch.where(keep_idxs)[0]

        # keep only topk scoring predictions; num_topk may be 0, in which case topk returns empty tensors
        num_topk = min(self.topk_candidates_per_level, flatten_topk_idxs.size(0))
        selected_scores, idxs = scores.to(torch.float32).topk(
            num_topk
        )  # half precision not implemented for cpu float16
        flatten_topk_idxs = flatten_topk_idxs[idxs]

        # a flattened index f addresses row f // num_classes and class column f % num_classes
        selected_labels = flatten_topk_idxs % num_classes
        topk_idxs = floor_divide(flatten_topk_idxs, num_classes)
        return topk_idxs, selected_scores, selected_labels  # type: ignore

    def select_boxes_per_image(
        self, boxes_list: list[Tensor], logits_list: list[Tensor], spatial_size: list[int] | tuple[int, ...]
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Postprocessing to generate detection result from classification logits and boxes.

        The box selection is performed with the following steps:

        #. For each level, discard boxes with scores less than or equal to self.score_thresh.
        #. For each level, keep boxes with top self.topk_candidates_per_level scores.
        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
        #. For the whole image, keep boxes with top self.detections_per_img scores.

        Args:
            boxes_list: list of predicted boxes from a single image,
                each element i is a Tensor sized (N_i, 2*spatial_dims)
            logits_list: list of predicted classification logits from a single image,
                each element i is a Tensor sized (N_i, num_classes)
            spatial_size: spatial size of the image

        Return:
            - selected boxes, Tensor sized (P, 2*spatial_dims)
            - selected_scores, Tensor sized (P, )
            - selected_labels, Tensor sized (P, )

        Raises:
            ValueError: if ``boxes_list`` and ``logits_list`` differ in length, or are empty.
        """
        if len(boxes_list) != len(logits_list):
            raise ValueError(
                "len(boxes_list) should equal to len(logits_list). "
                f"Got len(boxes_list)={len(boxes_list)}, len(logits_list)={len(logits_list)}"
            )
        # the output dtypes are taken from the first level, so at least one level is required
        if len(boxes_list) == 0:
            raise ValueError("boxes_list and logits_list should contain at least one level, got empty lists.")

        image_boxes = []
        image_scores = []
        image_labels = []

        # remember input dtypes so the outputs can be cast back after float32 scoring
        boxes_dtype = boxes_list[0].dtype
        logits_dtype = logits_list[0].dtype

        for boxes_per_level, logits_per_level in zip(boxes_list, logits_list):
            # select topk boxes for each level
            topk_idxs: Tensor
            topk_idxs, scores_per_level, labels_per_level = self.select_top_score_idx_per_level(logits_per_level)
            boxes_per_level = boxes_per_level[topk_idxs]

            # clip to the image and drop boxes that become empty; `keep` selects the surviving rows
            keep: Tensor
            boxes_per_level, keep = clip_boxes_to_image(  # type: ignore
                boxes_per_level, spatial_size, remove_empty=True
            )
            image_boxes.append(boxes_per_level)
            image_scores.append(scores_per_level[keep])
            image_labels.append(labels_per_level[keep])

        image_boxes_t: Tensor = torch.cat(image_boxes, dim=0)
        image_scores_t: Tensor = torch.cat(image_scores, dim=0)
        image_labels_t: Tensor = torch.cat(image_labels, dim=0)

        # non-maximum suppression on detected boxes from all levels
        keep_t: Tensor = batched_nms(  # type: ignore
            image_boxes_t,
            image_scores_t,
            image_labels_t,
            self.nms_thresh,
            box_overlap_metric=self.box_overlap_metric,
            max_proposals=self.detections_per_img,
        )

        selected_boxes = image_boxes_t[keep_t].to(boxes_dtype)
        selected_scores = image_scores_t[keep_t].to(logits_dtype)
        selected_labels = image_labels_t[keep_t]

        return selected_boxes, selected_scores, selected_labels
|