| import math |
| import torch |
| import torch.nn as nn |
| from mmcv.runner import force_fp32 |
| from mmdet.core import multi_apply, reduce_mean |
| from mmdet.models import HEADS |
| from mmdet.models.dense_heads import DETRHead |
| from mmdet3d.core.bbox.coders import build_bbox_coder |
| from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes |
| from .bbox.utils import normalize_bbox, encode_bbox |
| from .utils import VERSION |
|
|
|
|
| @HEADS.register_module() |
| class SparseBEVHead(DETRHead): |
    def __init__(self,
                 *args,
                 num_classes,
                 in_channels,
                 query_denoising=True,
                 query_denoising_groups=10,
                 bbox_coder=None,
                 code_size=10,
                 code_weights=[1.0] * 10,
                 train_cfg=dict(),
                 test_cfg=dict(max_per_img=100),
                 **kwargs):
        """Query-based 3D detection head with optional query denoising (DN).

        Args:
            num_classes (int): Number of object categories (background excluded).
            in_channels (int): Channels of the input features; also used as the
                query embedding dimension.
            query_denoising (bool): Enable DN-style noisy GT queries at train time.
            query_denoising_groups (int): Number of independent denoising groups.
            bbox_coder (dict): Box-coder config; the built coder provides
                ``pc_range`` used for (de)normalizing box coordinates.
            code_size (int): Dimensionality of the box encoding (10 here).
            code_weights (list[float]): Per-dimension regression loss weights.
            train_cfg (dict): Training config forwarded to the parent head.
            test_cfg (dict): Testing config forwarded to the parent head.

        NOTE(review): ``code_weights=[1.0] * 10`` and ``train_cfg=dict()`` are
        mutable defaults; they are only read here, never mutated, so this is
        safe as written — keep that in mind when editing.
        """
        # These attributes must be assigned *before* the parent constructor
        # runs: DETRHead.__init__ calls self._init_layers(), which reads
        # self.num_classes and self.embed_dims set up here.
        self.code_size = code_size
        self.code_weights = code_weights
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = in_channels

        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)

        # Replace the plain list with a frozen (non-trainable) parameter so it
        # moves with the module across devices/dtypes.
        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range

        # Query-denoising (DN-DETR style) hyper-parameters.
        self.dn_enabled = query_denoising
        self.dn_group_num = query_denoising_groups
        self.dn_weight = 1.0
        self.dn_bbox_noise_scale = 0.5
        self.dn_label_noise_scale = 0.5
|
|
| def _init_layers(self): |
| self.init_query_bbox = nn.Embedding(self.num_query, 10) |
| self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1) |
|
|
| nn.init.zeros_(self.init_query_bbox.weight[:, 2:3]) |
| nn.init.zeros_(self.init_query_bbox.weight[:, 8:10]) |
| nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5) |
|
|
| grid_size = int(math.sqrt(self.num_query)) |
| assert grid_size * grid_size == self.num_query |
| x = y = torch.arange(grid_size) |
| xx, yy = torch.meshgrid(x, y, indexing='ij') |
| xy = torch.cat([xx[..., None], yy[..., None]], dim=-1) |
| xy = (xy + 0.5) / grid_size |
| with torch.no_grad(): |
| self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2) |
|
|
    def init_weights(self):
        """Initialize weights; only the transformer needs explicit init here
        (query embeddings are initialized in ``_init_layers``)."""
        self.transformer.init_weights()
|
|
| def forward(self, mlvl_feats, img_metas): |
| query_bbox = self.init_query_bbox.weight.clone() |
| |
|
|
| |
| B = mlvl_feats[0].shape[0] |
| query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas) |
|
|
| cls_scores, bbox_preds = self.transformer( |
| query_bbox, |
| query_feat, |
| mlvl_feats, |
| attn_mask=attn_mask, |
| img_metas=img_metas, |
| ) |
|
|
| bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0] |
| bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1] |
| bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2] |
|
|
| bbox_preds = torch.cat([ |
| bbox_preds[..., 0:2], |
| bbox_preds[..., 3:5], |
| bbox_preds[..., 2:3], |
| bbox_preds[..., 5:10], |
| ], dim=-1) |
|
|
| if mask_dict is not None and mask_dict['pad_size'] > 0: |
| output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :] |
| output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :] |
| output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :] |
| output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :] |
| mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds) |
| outs = { |
| 'all_cls_scores': output_cls_scores, |
| 'all_bbox_preds': output_bbox_preds, |
| 'enc_cls_scores': None, |
| 'enc_bbox_preds': None, |
| 'dn_mask_dict': mask_dict, |
| } |
| else: |
| outs = { |
| 'all_cls_scores': cls_scores, |
| 'all_bbox_preds': bbox_preds, |
| 'enc_cls_scores': None, |
| 'enc_bbox_preds': None, |
| } |
|
|
| return outs |
|
|
| def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas): |
| |
| |
| |
|
|
| device = init_query_bbox.device |
| indicator0 = torch.zeros([self.num_query, 1], device=device) |
| init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1) |
| init_query_feat = torch.cat([init_query_feat, indicator0], dim=1) |
|
|
| if self.training and self.dn_enabled: |
| targets = [{ |
| 'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center, |
| m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(), |
| 'labels': m['gt_labels_3d'].cuda().long() |
| } for m in img_metas] |
|
|
| known = [torch.ones_like(t['labels'], device=device) for t in targets] |
| known_num = [sum(k) for k in known] |
|
|
| |
| unmask_bbox = unmask_label = torch.cat(known) |
| labels = torch.cat([t['labels'] for t in targets]).clone() |
| bboxes = torch.cat([t['bboxes'] for t in targets]).clone() |
| batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) |
|
|
| known_indice = torch.nonzero(unmask_label + unmask_bbox) |
| known_indice = known_indice.view(-1) |
|
|
| |
| known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1) |
| known_labels = labels.repeat(self.dn_group_num, 1).view(-1) |
| known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1) |
| known_bboxs = bboxes.repeat(self.dn_group_num, 1) |
| known_labels_expand = known_labels.clone() |
| known_bbox_expand = known_bboxs.clone() |
|
|
| |
| if self.dn_bbox_noise_scale > 0: |
| wlh = known_bbox_expand[..., 3:6].clone() |
| rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0 |
| known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale |
| |
| |
|
|
| known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range) |
| known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0) |
| |
|
|
| |
| if self.dn_label_noise_scale > 0: |
| p = torch.rand_like(known_labels_expand.float()) |
| chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1) |
| new_label = torch.randint_like(chosen_indice, 0, self.num_classes) |
| known_labels_expand.scatter_(0, chosen_indice, new_label) |
|
|
| known_feat_expand = label_enc(known_labels_expand) |
| indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device) |
| known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1) |
|
|
| |
| dn_single_pad = int(max(known_num)) |
| dn_pad_size = int(dn_single_pad * self.dn_group_num) |
| dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device) |
| dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device) |
| input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1) |
| input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1) |
|
|
| if len(known_num): |
| map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num]) |
| map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long() |
|
|
| if len(known_bid): |
| input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand |
| input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand |
|
|
| total_size = dn_pad_size + self.num_query |
| attn_mask = torch.ones([total_size, total_size], device=device) < 0 |
|
|
| |
| attn_mask[dn_pad_size:, :dn_pad_size] = True |
| for i in range(self.dn_group_num): |
| if i == 0: |
| attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True |
| if i == self.dn_group_num - 1: |
| attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True |
| else: |
| attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True |
| attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True |
|
|
| mask_dict = { |
| 'known_indice': torch.as_tensor(known_indice).long(), |
| 'batch_idx': torch.as_tensor(batch_idx).long(), |
| 'map_known_indice': torch.as_tensor(map_known_indice).long(), |
| 'known_lbs_bboxes': (known_labels, known_bboxs), |
| 'pad_size': dn_pad_size |
| } |
| else: |
| input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1) |
| input_query_feat = init_query_feat.repeat(batch_size, 1, 1) |
| attn_mask = None |
| mask_dict = None |
|
|
| return input_query_bbox, input_query_feat, attn_mask, mask_dict |
|
|
| def prepare_for_dn_loss(self, mask_dict): |
| cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes'] |
| known_labels, known_bboxs = mask_dict['known_lbs_bboxes'] |
| map_known_indice = mask_dict['map_known_indice'].long() |
| known_indice = mask_dict['known_indice'].long() |
| batch_idx = mask_dict['batch_idx'].long() |
| bid = batch_idx[known_indice] |
| num_tgt = known_indice.numel() |
|
|
| if len(cls_scores) > 0: |
| cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2) |
| bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2) |
|
|
| return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt |
|
|
| def dn_loss_single(self, |
| cls_scores, |
| bbox_preds, |
| known_bboxs, |
| known_labels, |
| num_total_pos=None): |
| |
| num_total_pos = cls_scores.new_tensor([num_total_pos]) |
| num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item() |
|
|
| |
| cls_scores = cls_scores.reshape(-1, self.cls_out_channels) |
| bbox_weights = torch.ones_like(bbox_preds) |
| label_weights = torch.ones_like(known_labels) |
| loss_cls = self.loss_cls( |
| cls_scores, |
| known_labels.long(), |
| label_weights, |
| avg_factor=num_total_pos |
| ) |
|
|
| |
| bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1)) |
| normalized_bbox_targets = normalize_bbox(known_bboxs) |
| isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1) |
| bbox_weights = bbox_weights * self.code_weights |
| loss_bbox = self.loss_bbox( |
| bbox_preds[isnotnan, :10], |
| normalized_bbox_targets[isnotnan, :10], |
| bbox_weights[isnotnan, :10], |
| avg_factor=num_total_pos |
| ) |
|
|
| loss_cls = self.dn_weight * torch.nan_to_num(loss_cls) |
| loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox) |
|
|
| return loss_cls, loss_bbox |
|
|
| @force_fp32(apply_to=('preds_dicts')) |
| def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers): |
| known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \ |
| self.prepare_for_dn_loss(preds_dicts['dn_mask_dict']) |
|
|
| all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)] |
| all_known_labels_list = [known_labels for _ in range(num_dec_layers)] |
| all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)] |
|
|
| dn_losses_cls, dn_losses_bbox = multi_apply( |
| self.dn_loss_single, cls_scores, bbox_preds, |
| all_known_bboxs_list, all_known_labels_list, all_num_tgts_list) |
|
|
| loss_dict['loss_cls_dn'] = dn_losses_cls[-1] |
| loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1] |
|
|
| num_dec_layer = 0 |
| for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]): |
| loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i |
| loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i |
| num_dec_layer += 1 |
|
|
| return loss_dict |
|
|
| def _get_target_single(self, |
| cls_score, |
| bbox_pred, |
| gt_labels, |
| gt_bboxes, |
| gt_bboxes_ignore=None): |
| num_bboxes = bbox_pred.size(0) |
|
|
| |
| assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True) |
| sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes) |
| pos_inds = sampling_result.pos_inds |
| neg_inds = sampling_result.neg_inds |
|
|
| |
| labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long) |
| labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] |
| label_weights = gt_bboxes.new_ones(num_bboxes) |
|
|
| |
| bbox_targets = torch.zeros_like(bbox_pred)[..., :9] |
| bbox_weights = torch.zeros_like(bbox_pred) |
| bbox_weights[pos_inds] = 1.0 |
| |
| |
| bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes |
| return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds) |
|
|
| def get_targets(self, |
| cls_scores_list, |
| bbox_preds_list, |
| gt_bboxes_list, |
| gt_labels_list, |
| gt_bboxes_ignore_list=None): |
| assert gt_bboxes_ignore_list is None, \ |
| 'Only supports for gt_bboxes_ignore setting to None.' |
| num_imgs = len(cls_scores_list) |
| gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)] |
|
|
| (labels_list, label_weights_list, bbox_targets_list, |
| bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply( |
| self._get_target_single, cls_scores_list, bbox_preds_list, |
| gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list) |
| num_total_pos = sum((inds.numel() for inds in pos_inds_list)) |
| num_total_neg = sum((inds.numel() for inds in neg_inds_list)) |
| return (labels_list, label_weights_list, bbox_targets_list, |
| bbox_weights_list, num_total_pos, num_total_neg) |
|
|
    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Classification + regression loss for one decoder layer.

        Args:
            cls_scores (Tensor): Per-image classification logits; dim 0 is the
                batch.
            bbox_preds (Tensor): Per-image box predictions (10 channels used).
            gt_bboxes_list (list[Tensor]): Per-image 9-dim GT boxes.
            gt_labels_list (list[Tensor]): Per-image GT labels.
            gt_bboxes_ignore_list: Must be None (enforced in ``get_targets``).

        Returns:
            tuple[Tensor, Tensor]: classification loss and bbox loss.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
            gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        # Flatten per-image targets so they align with the flattened preds.
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # Classification loss; background contributes via bg_cls_weight.
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)

        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            # Average the normalizer across distributed workers.
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))

        cls_avg_factor = max(cls_avg_factor, 1)
        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Average positive count across workers for the bbox normalizer,
        # floored at 1 to avoid division by zero with empty GT.
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # Box regression on finite targets only (normalize_bbox can produce
        # non-finite values for degenerate boxes), weighted per dimension.
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(bbox_targets)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights

        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        # Squash NaN/Inf losses (possible under fp16 or with no positives).
        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)

        return loss_cls, loss_bbox
|
|
| @force_fp32(apply_to=('preds_dicts')) |
| def loss(self, |
| gt_bboxes_list, |
| gt_labels_list, |
| preds_dicts, |
| gt_bboxes_ignore=None): |
| assert gt_bboxes_ignore is None, \ |
| f'{self.__class__.__name__} only supports ' \ |
| f'for gt_bboxes_ignore setting to None.' |
|
|
| all_cls_scores = preds_dicts['all_cls_scores'] |
| all_bbox_preds = preds_dicts['all_bbox_preds'] |
| enc_cls_scores = preds_dicts['enc_cls_scores'] |
| enc_bbox_preds = preds_dicts['enc_bbox_preds'] |
|
|
| num_dec_layers = len(all_cls_scores) |
| device = gt_labels_list[0].device |
| gt_bboxes_list = [torch.cat( |
| (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), |
| dim=1).to(device) for gt_bboxes in gt_bboxes_list] |
|
|
| all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] |
| all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] |
| all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)] |
|
|
| losses_cls, losses_bbox = multi_apply( |
| self.loss_single, all_cls_scores, all_bbox_preds, |
| all_gt_bboxes_list, all_gt_labels_list, |
| all_gt_bboxes_ignore_list) |
|
|
| loss_dict = dict() |
| |
| if enc_cls_scores is not None: |
| binary_labels_list = [ |
| torch.zeros_like(gt_labels_list[i]) |
| for i in range(len(all_gt_labels_list)) |
| ] |
| enc_loss_cls, enc_losses_bbox = \ |
| self.loss_single(enc_cls_scores, enc_bbox_preds, |
| gt_bboxes_list, binary_labels_list, gt_bboxes_ignore) |
| loss_dict['enc_loss_cls'] = enc_loss_cls |
| loss_dict['enc_loss_bbox'] = enc_losses_bbox |
|
|
| if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None: |
| loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers) |
|
|
| |
| loss_dict['loss_cls'] = losses_cls[-1] |
| loss_dict['loss_bbox'] = losses_bbox[-1] |
|
|
| |
| num_dec_layer = 0 |
| for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]): |
| loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i |
| loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i |
| num_dec_layer += 1 |
| return loss_dict |
|
|
| @force_fp32(apply_to=('preds_dicts')) |
| def get_bboxes(self, preds_dicts, img_metas, rescale=False): |
| preds_dicts = self.bbox_coder.decode(preds_dicts) |
| num_samples = len(preds_dicts) |
| ret_list = [] |
| for i in range(num_samples): |
| preds = preds_dicts[i] |
| bboxes = preds['bboxes'] |
| bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 |
|
|
| if VERSION.name == 'v0.17.1': |
| import copy |
| w, l = copy.deepcopy(bboxes[:, 3]), copy.deepcopy(bboxes[:, 4]) |
| bboxes[:, 3], bboxes[:, 4] = l, w |
| bboxes[:, 6] = -bboxes[:, 6] - math.pi / 2 |
|
|
| bboxes = LiDARInstance3DBoxes(bboxes, 9) |
| scores = preds['scores'] |
| labels = preds['labels'] |
| ret_list.append([bboxes, scores, labels]) |
| return ret_list |
|
|