import math
import torch
import torch.nn as nn
from mmcv.runner import force_fp32
from mmdet.core import multi_apply, reduce_mean
from mmdet.models import HEADS
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
from .bbox.utils import normalize_bbox, encode_bbox
from .utils import VERSION


@HEADS.register_module()
class SparseBEVHead(DETRHead):
    """SparseBEV 3D detection head built on top of mmdet's ``DETRHead``.

    Holds ``num_query`` learnable query boxes, feeds them (plus label
    embeddings) through ``self.transformer`` together with the multi-level
    image features, and, during training, optionally prepends DN-DETR style
    denoising queries built from noised ground-truth boxes.

    Args:
        num_classes (int): Number of foreground classes. Index ``num_classes``
            is used as the background / "no object" label.
        in_channels (int): Input channels; also used as the query embedding
            width (``self.embed_dims``).
        query_denoising (bool): Enable denoising queries during training.
        query_denoising_groups (int): Number of noised copies of the GT set.
        bbox_coder (dict): Config for ``build_bbox_coder``; its ``pc_range``
            (``[x_min, y_min, z_min, x_max, y_max, z_max]``) is used to map
            predicted centers out of the normalized range.
        code_size (int): Stored as ``self.code_size``.
        code_weights (list[float]): Per-channel weights of the regression
            loss. The default list is never mutated; it is copied into a
            frozen ``nn.Parameter``.
        train_cfg (dict): Training config forwarded to ``DETRHead``.
        test_cfg (dict): Test config forwarded to ``DETRHead``.

    Extra positional ``*args`` are accepted but ignored. ``**kwargs`` are
    forwarded to ``DETRHead``.
    """

    def __init__(self,
                 *args,
                 num_classes,
                 in_channels,
                 query_denoising=True,
                 query_denoising_groups=10,
                 bbox_coder=None,
                 code_size=10,
                 code_weights=[1.0] * 10,
                 train_cfg=dict(),
                 test_cfg=dict(max_per_img=100),
                 **kwargs):
        # These attributes are set before the parent constructor runs,
        # presumably because it calls ``_init_layers`` (which reads them).
        self.code_size = code_size
        self.code_weights = code_weights
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = in_channels

        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)

        # Non-trainable so it moves with the module (device/dtype) but is not optimized.
        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range

        # Query denoising hyper-parameters.
        self.dn_enabled = query_denoising
        self.dn_group_num = query_denoising_groups
        self.dn_weight = 1.0
        self.dn_bbox_noise_scale = 0.5
        self.dn_label_noise_scale = 0.5

    def _init_layers(self):
        """Create the learnable query boxes and the label embedding.

        Query boxes are initialized on a uniform ``sqrt(Q) x sqrt(Q)`` grid in
        normalized x/y, with z = 0, channel 5 (h) = 1.5 and velocity = 0. The
        remaining channels keep ``nn.Embedding``'s default initialization.
        """
        self.init_query_bbox = nn.Embedding(self.num_query, 10)  # (x, y, z, w, l, h, sin, cos, vx, vy)
        # One extra entry (index num_classes) for matching queries. The last
        # channel of the query feature is an indicator bit appended later.
        self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)  # DAB-DETR

        nn.init.zeros_(self.init_query_bbox.weight[:, 2:3])
        nn.init.zeros_(self.init_query_bbox.weight[:, 8:10])
        nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5)

        grid_size = int(math.sqrt(self.num_query))
        assert grid_size * grid_size == self.num_query, 'num_query must be a perfect square'
        x = y = torch.arange(grid_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')  # [0, grid_size - 1]
        xy = torch.cat([xx[..., None], yy[..., None]], dim=-1)
        xy = (xy + 0.5) / grid_size  # [0.5, grid_size - 0.5] / grid_size ~= (0, 1)
        with torch.no_grad():
            self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2)  # [Q, 2]

    def init_weights(self):
        """Delegate weight initialization to the transformer."""
        self.transformer.init_weights()

    def forward(self, mlvl_feats, img_metas):
        """Decode the (optionally denoising-augmented) queries.

        Args:
            mlvl_feats (list[Tensor]): Multi-level image features. Only
                ``mlvl_feats[0].shape[0]`` (the batch size) is read directly.
            img_metas (list[dict]): Per-sample meta info. When training with
                query denoising each dict must carry ``gt_bboxes_3d`` and
                ``gt_labels_3d``.

        Returns:
            dict: ``all_cls_scores`` / ``all_bbox_preds`` for the matching
            queries, indexed as ``[layer, batch, query, channel]``. Boxes are
            ordered ``[cx, cy, w, l, cz, h, sin, cos, vx, vy]``, with centers
            mapped from the normalized range into ``pc_range``.
            ``enc_cls_scores`` / ``enc_bbox_preds`` are always ``None``.
            ``dn_mask_dict`` is present only when denoising queries were added.
        """
        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]

        # query denoising
        B = mlvl_feats[0].shape[0]
        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)

        cls_scores, bbox_preds = self.transformer(
            query_bbox,
            query_feat,
            mlvl_feats,
            attn_mask=attn_mask,
            img_metas=img_metas,
        )

        # Map normalized centers into the point-cloud range.
        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]

        # Reorder (x, y, z, w, l, h, ...) -> (x, y, w, l, z, h, ...).
        bbox_preds = torch.cat([
            bbox_preds[..., 0:2],
            bbox_preds[..., 3:5],
            bbox_preds[..., 2:3],
            bbox_preds[..., 5:10],
        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]

        if mask_dict is not None and mask_dict['pad_size'] > 0:  # if using query denoising
            # The first pad_size queries are denoising queries; split them off.
            pad_size = mask_dict['pad_size']
            output_known_cls_scores = cls_scores[:, :, :pad_size, :]
            output_known_bbox_preds = bbox_preds[:, :, :pad_size, :]
            output_cls_scores = cls_scores[:, :, pad_size:, :]
            output_bbox_preds = bbox_preds[:, :, pad_size:, :]
            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
            outs = {
                'all_cls_scores': output_cls_scores,
                'all_bbox_preds': output_bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
                'dn_mask_dict': mask_dict,
            }
        else:
            outs = {
                'all_cls_scores': cls_scores,
                'all_bbox_preds': bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
            }

        return outs

    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
        """Build decoder input queries, prepending denoising queries when training.

        Mostly borrowed from:
          - https://github.com/IDEA-Research/DN-DETR/blob/main/models/DN_DAB_DETR/dn_components.py
          - https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petrv2_dnhead.py

        Args:
            batch_size (int): Batch size ``B``.
            init_query_bbox (Tensor): Learnable query boxes, shape ``[Q, 10]``.
            label_enc (nn.Embedding): Label embedding. Entry ``num_classes``
                is used for the matching queries.
            img_metas (list[dict]): Per-sample metas. GT is read only in
                training with denoising enabled.

        Returns:
            tuple: ``(input_query_bbox, input_query_feat, attn_mask, mask_dict)``
            with query tensors of shape ``[B, P + Q, ...]``, where ``P`` is the
            denoising pad size. ``attn_mask`` is a ``[P + Q, P + Q]`` bool
            mask (True = blocked). When denoising is off, ``P`` is 0 and
            ``attn_mask`` and ``mask_dict`` are ``None``.
        """
        device = init_query_bbox.device
        # Matching queries: background label embedding + indicator bit 0.
        indicator0 = torch.zeros([self.num_query, 1], device=device)
        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)

        if self.training and self.dn_enabled:
            # GT boxes as (gravity center, remaining box params). Moved to the
            # query device rather than a hard-coded ``.cuda()``.
            targets = [{
                'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center, m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).to(device),
                'labels': m['gt_labels_3d'].to(device).long()
            } for m in img_metas]
            known = [torch.ones_like(t['labels'], device=device) for t in targets]
            # Number of GT objects per sample, as plain ints.
            known_num = [int(k.numel()) for k in known]

            # can be modified to selectively denoise some labels or boxes; also known label prediction
            unmask_bbox = unmask_label = torch.cat(known)
            labels = torch.cat([t['labels'] for t in targets]).clone()
            bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)

            # Replicate the GT set once per denoising group before adding noise.
            known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
            known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
            known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
            known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
            known_labels_expand = known_labels.clone()
            known_bbox_expand = known_bboxs.clone()

            # Noise on the box: shift centers by up to +-(dims / 2) * scale,
            # using box channels 3:6 as the per-axis size.
            if self.dn_bbox_noise_scale > 0:
                wlh = known_bbox_expand[..., 3:6].clone()
                rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
                known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale
                # Disabled alternatives (size / yaw noise):
                # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale
                # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale

            # encode_bbox presumably normalizes centers w.r.t. pc_range, hence the [0, 1] clamp.
            known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
            known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)
            # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)

            # Noise on the label: replace a random subset with random class ids.
            if self.dn_label_noise_scale > 0:
                p = torch.rand_like(known_labels_expand.float())
                chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
                new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
                known_labels_expand.scatter_(0, chosen_indice, new_label)

            known_feat_expand = label_enc(known_labels_expand)
            indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
            known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)

            # Construct final queries: each group gets a slot of size max(#GT per sample).
            dn_single_pad = int(max(known_num))
            dn_pad_size = int(dn_single_pad * self.dn_group_num)
            dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
            dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
            input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
            input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)

            if len(known_num):
                # Position of each GT inside its sample's slot, e.g. [0,1, 0,1,2].
                map_known_indice = torch.cat([torch.arange(num, device=device) for num in known_num])
                map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()

            if len(known_bid):
                input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
                input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand

            total_size = dn_pad_size + self.num_query
            attn_mask = torch.ones([total_size, total_size], device=device) < 0  # all False

            # Matching queries cannot see the reconstruction (denoising) queries.
            attn_mask[dn_pad_size:, :dn_pad_size] = True
            # Each denoising group cannot see any other denoising group.
            for i in range(self.dn_group_num):
                if i == 0:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                if i == self.dn_group_num - 1:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
                else:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True

            mask_dict = {
                'known_indice': torch.as_tensor(known_indice).long(),
                'batch_idx': torch.as_tensor(batch_idx).long(),
                'map_known_indice': torch.as_tensor(map_known_indice).long(),
                'known_lbs_bboxes': (known_labels, known_bboxs),
                'pad_size': dn_pad_size
            }
        else:
            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
            attn_mask = None
            mask_dict = None

        return input_query_bbox, input_query_feat, attn_mask, mask_dict

    def prepare_for_dn_loss(self, mask_dict):
        """Gather denoising-query predictions aligned with their GT targets.

        Returns:
            tuple: ``(known_labels, known_bboxs, cls_scores, bbox_preds,
            num_tgt)``. Predictions are re-indexed from ``[layer, batch,
            pad, C]`` to ``[layer, num_known, C]`` using the batch ids and
            slot indices stored in ``mask_dict``.
        """
        cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']
        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
        map_known_indice = mask_dict['map_known_indice'].long()
        known_indice = mask_dict['known_indice'].long()
        batch_idx = mask_dict['batch_idx'].long()
        bid = batch_idx[known_indice]
        num_tgt = known_indice.numel()

        if len(cls_scores) > 0:
            # [L, B, P, C] -> [B, P, L, C] -> [N, L, C] -> [L, N, C]
            cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
            bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)

        return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt

    def dn_loss_single(self,
                       cls_scores,
                       bbox_preds,
                       known_bboxs,
                       known_labels,
                       num_total_pos=None):
        """Classification + L1 regression loss of denoising queries for one layer.

        Both losses are scaled by ``self.dn_weight``. NaN entries in either
        loss are replaced via ``torch.nan_to_num``.
        """
        # Compute the average number of gt boxes across all gpus
        num_total_pos = cls_scores.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()

        # cls loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        bbox_weights = torch.ones_like(bbox_preds)
        label_weights = torch.ones_like(known_labels)
        loss_cls = self.loss_cls(
            cls_scores,
            known_labels.long(),
            label_weights,
            avg_factor=num_total_pos
        )

        # regression L1 loss (rows with non-finite normalized targets are skipped)
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(known_bboxs)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights

        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)
        loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)

        return loss_cls, loss_bbox

    @force_fp32(apply_to=('preds_dicts',))
    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
        """Add per-layer denoising losses to ``loss_dict`` and return it.

        The last layer's losses are stored as ``loss_cls_dn`` / ``loss_bbox_dn``
        and earlier layers as ``d{i}.loss_cls_dn`` / ``d{i}.loss_bbox_dn``.
        """
        known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \
            self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])

        all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
        all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
        all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]

        dn_losses_cls, dn_losses_bbox = multi_apply(
            self.dn_loss_single, cls_scores, bbox_preds,
            all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)

        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]

        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i
            num_dec_layer += 1

        return loss_dict

    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_labels,
                           gt_bboxes,
                           gt_bboxes_ignore=None):
        """Assign GT to the predictions of one image and build its targets.

        Unmatched queries get label ``num_classes`` (background) and zero
        regression weight. Box targets have 9 channels (raw GT layout).
        """
        num_bboxes = bbox_pred.size(0)

        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)
        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0

        # DETR
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Compute per-image targets for a batch and count positives/negatives."""
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        num_imgs = len(cls_scores_list)
        gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
            self._get_target_single, cls_scores_list, bbox_preds_list,
            gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Classification + L1 regression loss of the matching queries for one layer."""
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list,
                                           gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        # (bg_cls_weight / sync_cls_avg_factor presumably come from DETRHead)
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))

        cls_avg_factor = max(cls_avg_factor, 1)
        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # regression L1 loss (rows with non-finite normalized targets are skipped)
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(bbox_targets)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights

        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)
        return loss_cls, loss_bbox

    @force_fp32(apply_to=('preds_dicts',))
    def loss(self,
             gt_bboxes_list,
             gt_labels_list,
             preds_dicts,
             gt_bboxes_ignore=None):
        """Compute all training losses.

        Produces per-decoder-layer matching losses, the optional encoder
        losses, and the denoising losses when ``dn_mask_dict`` is present.

        Returns:
            dict: ``loss_cls`` / ``loss_bbox`` for the last layer and
            ``d{i}.loss_cls`` / ``d{i}.loss_bbox`` for earlier layers, plus
            the optional ``enc_*`` and ``*_dn`` entries.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'for gt_bboxes_ignore setting to None.'

        all_cls_scores = preds_dicts['all_cls_scores']
        all_bbox_preds = preds_dicts['all_bbox_preds']
        enc_cls_scores = preds_dicts['enc_cls_scores']
        enc_bbox_preds = preds_dicts['enc_bbox_preds']

        num_dec_layers = len(all_cls_scores)
        device = gt_labels_list[0].device
        # GT boxes as (gravity center, remaining box params), on the label device.
        gt_bboxes_list = [torch.cat(
            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
            dim=1).to(device) for gt_bboxes in gt_bboxes_list]

        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]

        losses_cls, losses_bbox = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss of proposal generated from encode feature map
        if enc_cls_scores is not None:
            # One all-zero label tensor per image (not per decoder layer).
            binary_labels_list = [
                torch.zeros_like(gt_labels) for gt_labels in gt_labels_list
            ]
            enc_loss_cls, enc_losses_bbox = \
                self.loss_single(enc_cls_scores, enc_bbox_preds,
                                 gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox

        if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:
            loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)

        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]

        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            num_dec_layer += 1
        return loss_dict

    @force_fp32(apply_to=('preds_dicts',))
    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
        """Decode predictions into ``LiDARInstance3DBoxes`` per sample.

        Returns:
            list[list]: One ``[bboxes, scores, labels]`` entry per sample.
            ``bboxes`` is a 9-dim ``LiDARInstance3DBoxes``.
        """
        preds_dicts = self.bbox_coder.decode(preds_dicts)
        num_samples = len(preds_dicts)
        ret_list = []
        for i in range(num_samples):
            preds = preds_dicts[i]
            bboxes = preds['bboxes']
            # Gravity-center z -> bottom-center z (subtract half of channel 5).
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5

            if VERSION.name == 'v0.17.1':
                # Swap channels 3/4 and convert yaw for this mmdet3d version.
                # clone() copies only the column, unlike deepcopy of a view.
                w, l = bboxes[:, 3].clone(), bboxes[:, 4].clone()
                bboxes[:, 3], bboxes[:, 4] = l, w
                bboxes[:, 6] = -bboxes[:, 6] - math.pi / 2

            bboxes = LiDARInstance3DBoxes(bboxes, 9)
            scores = preds['scores']
            labels = preds['labels']
            ret_list.append([bboxes, scores, labels])
        return ret_list