File size: 4,047 Bytes
d19bd3e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | import torch
from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
from ..utils import denormalize_bbox
@BBOX_CODERS.register_module()
class NMSFreeCoder(BaseBBoxCoder):
    """Bbox coder for NMS-free detector.

    Decoding keeps the top-scoring (query, class) pairs, denormalizes the
    matching box predictions and filters them by center range and,
    optionally, by score.

    Args:
        pc_range (list[float]): Range of point cloud.
        voxel_size (list[float], optional): Voxel size. Only stored on the
            instance; not read by any method of this class. Default: None.
        post_center_range (list[float], optional): Limit of the center,
            given as six values: three lower bounds followed by three
            upper bounds for the first three box channels. Decoding
            raises ``NotImplementedError`` when this is None.
            Default: None.
        max_num (int): Max number to be kept. Default: 100.
        score_threshold (float, optional): Threshold to filter boxes based
            on score. Default: None.
        num_classes (int): Number of classes, used to split a flattened
            score index into a query index and a label. Default: 10.
    """

    def __init__(self,
                 pc_range,
                 voxel_size=None,
                 post_center_range=None,
                 max_num=100,
                 score_threshold=None,
                 num_classes=10):
        self.pc_range = pc_range
        self.voxel_size = voxel_size
        self.post_center_range = post_center_range
        self.max_num = max_num
        self.score_threshold = score_threshold
        self.num_classes = num_classes

    def encode(self):
        """Encoding is not implemented for this coder; returns None."""
        pass

    def decode_single(self, cls_scores, bbox_preds):
        """Decode bboxes for a single sample.

        Args:
            cls_scores (Tensor): Outputs from the classification head
                (logits; sigmoid is applied here), shape
                [num_query, cls_out_channels]. The index arithmetic below
                assumes cls_out_channels equals ``self.num_classes``.
            bbox_preds (Tensor): Outputs from the regression head with
                normalized coordinate format
                (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy).
                Shape [num_query, code_size] (the layout above lists 10
                channels -- TODO confirm against the head).

        Returns:
            dict: Decoded predictions with keys ``'bboxes'``, ``'scores'``
            and ``'labels'``.

        Raises:
            NotImplementedError: If ``post_center_range`` is None.
        """
        # Fail fast: only the range-filtered output path is supported.
        if self.post_center_range is None:
            raise NotImplementedError(
                'Need to reorganize output as a batch, only '
                'support post_center_range is not None for now!'
            )

        # reshape (not view) so non-contiguous score tensors also work.
        flat_scores = cls_scores.sigmoid().reshape(-1)
        # topk raises if k exceeds the number of candidates, so clamp it.
        num_keep = min(self.max_num, flat_scores.numel())
        scores, indexs = flat_scores.topk(num_keep)

        # A flat index is query_index * num_classes + label.
        labels = indexs % self.num_classes
        bbox_index = torch.div(
            indexs, self.num_classes, rounding_mode='trunc')
        final_box_preds = denormalize_bbox(bbox_preds[bbox_index])

        # Keep boxes whose first three channels lie inside the range.
        limit = torch.tensor(self.post_center_range, device=scores.device)
        centers = final_box_preds[..., :3]
        mask = (centers >= limit[:3]).all(1)
        mask &= (centers <= limit[3:]).all(1)

        # Compare against None so a threshold of 0 is still applied.
        if self.score_threshold is not None:
            mask &= scores > self.score_threshold

        return {
            'bboxes': final_box_preds[mask],
            'scores': scores[mask],
            'labels': labels[mask],
        }

    def decode(self, preds_dicts):
        """Decode bboxes for every sample of a batch.

        Only the last entry along the first dimension of each prediction
        tensor is decoded.

        Args:
            preds_dicts (dict): Head outputs containing:

                - ``'all_cls_scores'`` (Tensor): Outputs from the
                  classification head, shape
                  [nb_dec, bs, num_query, cls_out_channels].
                - ``'all_bbox_preds'`` (Tensor): Outputs from the
                  regression head with normalized coordinate format
                  (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy),
                  shape [nb_dec, bs, num_query, code_size].

        Returns:
            list[dict]: Decoded boxes, one dict per sample as returned by
            :meth:`decode_single`.
        """
        all_cls_scores = preds_dicts['all_cls_scores'][-1]
        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]

        batch_size = all_cls_scores.size()[0]
        return [
            self.decode_single(all_cls_scores[i], all_bbox_preds[i])
            for i in range(batch_size)
        ]
|