import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmcv.cnn import bias_init_with_prob
from mmcv.cnn.bricks.transformer import MultiheadAttention, FFN
from mmdet.models.utils.builder import TRANSFORMER
from .bbox.utils import decode_bbox
from .utils import inverse_sigmoid, DUMP
from .sparsebev_sampling import sampling_4d, make_sample_points
from .checkpoint import checkpoint as cp
from .csrc.wrapper import MSMV_CUDA


@TRANSFORMER.register_module()
class SparseBEVTransformer(BaseModule):
    """Top-level SparseBEV transformer: runs the decoder and sanitises its outputs.

    Args:
        embed_dims (int): Query feature dimension.
        num_frames (int): Number of frames each query samples from.
        num_points (int): Number of sampling points per group.
        num_layers (int): Number of decoder iterations (one shared layer is reused).
        num_levels (int): Number of feature levels in ``mlvl_feats``.
        num_classes (int): Size of the classification output.
        code_size (int): Size of the regressed box code.
        pc_range (list | None): Range forwarded to ``decode_bbox`` / ``make_sample_points``.
            ``None`` (the default) is treated as an empty list.
        init_cfg: Must be ``None``.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4,
                 num_classes=10, code_size=10, pc_range=None, init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)
        # A fresh list per instance instead of a shared mutable default argument.
        pc_range = [] if pc_range is None else pc_range

        self.embed_dims = embed_dims
        self.pc_range = pc_range
        self.decoder = SparseBEVTransformerDecoder(
            embed_dims, num_frames, num_points, num_layers, num_levels,
            num_classes, code_size, pc_range=pc_range)

    @torch.no_grad()
    def init_weights(self):
        self.decoder.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """Run the decoder and replace NaN/Inf in its outputs via ``torch.nan_to_num``.

        Returns:
            tuple: ``(cls_scores, bbox_preds)``, each stacked over decoder layers
            (leading dim ``num_layers``).
        """
        cls_scores, bbox_preds = self.decoder(query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)
        cls_scores = torch.nan_to_num(cls_scores)
        bbox_preds = torch.nan_to_num(bbox_preds)
        return cls_scores, bbox_preds


class SparseBEVTransformerDecoder(BaseModule):
    """Iterative decoder: applies a single shared decoder layer ``num_layers`` times.

    After every iteration the predicted boxes (detached) become the query boxes of
    the next iteration.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4,
                 num_classes=10, code_size=10, pc_range=None, init_cfg=None):
        super(SparseBEVTransformerDecoder, self).__init__(init_cfg)
        pc_range = [] if pc_range is None else pc_range

        self.num_layers = num_layers
        self.pc_range = pc_range

        # params are shared across all decoder layers
        self.decoder_layer = SparseBEVTransformerDecoderLayer(
            embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
        )

    @torch.no_grad()
    def init_weights(self):
        self.decoder_layer.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """Decode boxes iteratively.

        Side effects:
            * Writes ``img_metas[0]['time_diff']`` ([B, F] tensor) and
              ``img_metas[0]['lidar2img']`` ([B, N, 4, 4] tensor) unless a
              ``time_diff`` tensor is already present. This happens when an export
              wrapper injected them or when a previous call already ran with the
              same ``img_metas``.
            * Replaces every entry of ``mlvl_feats`` in place with its regrouped,
              contiguous version (see the grouping loop below).

        Returns:
            tuple: ``(cls_scores, bbox_preds)`` stacked over the ``num_layers`` iterations.
        """
        cls_scores, bbox_preds = [], []

        # ONNX export path: tensors are pre-computed and injected by the wrapper, so
        # time_diff and lidar2img are already set in img_metas[0]. Otherwise derive
        # them from the per-sample metadata.
        if not isinstance(img_metas[0].get('time_diff'), torch.Tensor):
            # calculate time difference according to timestamps
            timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
            # NOTE(review): 6 is a hard-coded number of views per frame.
            timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
            # Offset of every frame relative to frame 0, averaged over the views.
            time_diff = timestamps[:, :1, :] - timestamps
            time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
            time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
            img_metas[0]['time_diff'] = time_diff

            # organize projections matrix and copy to the device of the queries
            lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
            lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
            img_metas[0]['lidar2img'] = lidar2img

        # group image features in advance for sampling, see `sampling_4d` for more details
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            # 6 views per frame and 4 channel groups are hard-coded here.
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)

            if MSMV_CUDA:  # Our CUDA operator requires channel_last
                feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
                feat = feat.reshape(B*T*G, N, H, W, C)
            else:  # Torch's grid_sample requires channel_first
                feat = feat.permute(0, 1, 3, 4, 2, 5, 6)  # [B, T, G, C, N, H, W]
                feat = feat.reshape(B*T*G, C, N, H, W)

            # NOTE: mutates the caller's list in place.
            mlvl_feats[lvl] = feat.contiguous()

        for i in range(self.num_layers):
            DUMP.stage_count = i
            query_feat, cls_score, bbox_pred = self.decoder_layer(
                query_bbox, query_feat, mlvl_feats, attn_mask, img_metas
            )
            # Refined boxes seed the next iteration; gradients are not propagated
            # through the box chain across iterations.
            query_bbox = bbox_pred.clone().detach()
            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)

        cls_scores = torch.stack(cls_scores)
        bbox_preds = torch.stack(bbox_preds)

        return cls_scores, bbox_preds


class SparseBEVTransformerDecoderLayer(BaseModule):
    """One SparseBEV decoder step.

    The step runs: position encoding -> self-attention -> spatio-temporal sampling
    -> adaptive mixing -> FFN -> classification and box regression heads.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10,
                 code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=None, init_cfg=None):
        super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)
        pc_range = [] if pc_range is None else pc_range

        self.embed_dims = embed_dims
        self.num_classes = num_classes
        self.code_size = code_size
        self.pc_range = pc_range

        # Encodes the first three box coordinates into a query positional embedding.
        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )

        self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)
        self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4,
                                          num_points=num_points, num_levels=num_levels,
                                          pc_range=pc_range)
        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames,
                                     n_groups=4, out_points=128)
        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)

        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

        cls_branch = []
        for _ in range(num_cls_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        self.cls_branch = nn.Sequential(*cls_branch)

        reg_branch = []
        for _ in range(num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU(inplace=True))
        reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
        self.reg_branch = nn.Sequential(*reg_branch)

    @torch.no_grad()
    def init_weights(self):
        self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()

        # Bias the final classification layer toward a prior probability of 0.01.
        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.cls_branch[-1].bias, bias_init)

    def refine_bbox(self, bbox_proposal, bbox_delta):
        """Apply a predicted delta to the proposal's first three coordinates.

        Those coordinates are in sigmoid space. The delta is added in logit space
        (``inverse_sigmoid``) and mapped back with ``sigmoid``. All remaining entries
        are taken directly from ``bbox_delta``.
        """
        xyz = inverse_sigmoid(bbox_proposal[..., 0:3])
        xyz_delta = bbox_delta[..., 0:3]
        xyz_new = torch.sigmoid(xyz_delta + xyz)

        return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """
        query_bbox: [B, Q, 10] [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]

        Returns:
            tuple: ``(query_feat, cls_score, bbox_pred)`` where ``cls_score`` is
            [B, Q, num_classes] and ``bbox_pred`` is [B, Q, code_size].
        """
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos

        query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))
        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))

        cls_score = self.cls_branch(query_feat)  # [B, Q, num_classes]
        bbox_pred = self.reg_branch(query_feat)  # [B, Q, code_size]
        bbox_pred = self.refine_bbox(query_bbox, bbox_pred)

        # calculate absolute velocity according to time difference
        time_diff = img_metas[0]['time_diff']  # [B, F]
        if time_diff.shape[1] > 1:
            # Replace (near-)zero intervals by 1 to avoid dividing by zero.
            # torch.where keeps this out-of-place.
            time_diff = torch.where(time_diff < 1e-5, torch.ones_like(time_diff), time_diff)
            bbox_pred = torch.cat([
                bbox_pred[..., :8],
                bbox_pred[..., 8:] / time_diff[:, 1:2, None],
            ], dim=-1)

        if DUMP.enabled:
            query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
            bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)
            cls_score_sig = torch.sigmoid(cls_score)
            torch.save(query_bbox_dec.cpu(), '{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(bbox_pred_dec.cpu(), '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(cls_score_sig.cpu(), '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        return query_feat, cls_score, bbox_pred


class SparseBEVSelfAttention(BaseModule):
    """Scale-adaptive Self Attention"""

    def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=None, init_cfg=None):
        super().__init__(init_cfg)
        pc_range = [] if pc_range is None else pc_range
        self.pc_range = pc_range

        self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)
        # Predicts one distance-scaling factor (tau) per head for every query.
        self.gen_tau = nn.Linear(embed_dims, num_heads)

    @torch.no_grad()
    def init_weights(self):
        # tau starts as a query-independent value drawn from U(0, 2) per head.
        nn.init.zeros_(self.gen_tau.weight)
        nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)

    def inner_forward(self, query_bbox, query_feat, pre_attn_mask):
        """
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]

        The additive attention mask is the negative pairwise centre distance scaled
        per head by ``tau``. Masked pairs from ``pre_attn_mask`` are set to ``-inf``.
        """
        dist = self.calc_bbox_dists(query_bbox)
        tau = self.gen_tau(query_feat)  # [B, Q, num_heads]

        if DUMP.enabled:
            torch.save(tau.cpu(), '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        tau = tau.permute(0, 2, 1)  # [B, num_heads, Q]
        attn_mask = dist[:, None, :, :] * tau[..., None]  # [B, num_heads, Q, Q]
        if pre_attn_mask is not None:  # for query denoising
            attn_mask[:, :, pre_attn_mask] = float('-inf')
        attn_mask = attn_mask.flatten(0, 1)  # [B*num_heads, Q, Q]
        return self.attention(query_feat, attn_mask=attn_mask)

    def forward(self, query_bbox, query_feat, pre_attn_mask):
        # Gradient checkpointing only while training with gradients flowing.
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, pre_attn_mask)

    @torch.no_grad()
    def calc_bbox_dists(self, bboxes):
        """Return the negated pairwise Euclidean distance [B, Q, Q].

        Distances are computed between the first two coordinates of the decoded
        boxes.
        """
        centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]
        dist = torch.norm(centers.unsqueeze(2) - centers.unsqueeze(1), dim=-1)  # [B, Q, Q]

        return -dist


class SparseBEVSampling(BaseModule):
    """Adaptive Spatio-temporal Sampling"""

    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4,
                 pc_range=None, init_cfg=None):
        super().__init__(init_cfg)
        pc_range = [] if pc_range is None else pc_range

        self.num_frames = num_frames
        self.num_points = num_points
        self.num_groups = num_groups
        self.num_levels = num_levels
        self.pc_range = pc_range

        # Per query: a 3-D offset for every (group, point), and one weight per level.
        self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)
        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)

    @torch.no_grad()
    def init_weights(self):
        # Offsets start query-independent, with biases drawn from U(-0.5, 0.5).
        bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)
        nn.init.zeros_(self.sampling_offset.weight)
        nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)

    def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        '''
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]

        Reads ``img_metas[0]['img_shape']``, ``['time_diff']`` and ``['lidar2img']``.
        Returns the sampled features [B, Q, G, FP, C].
        '''
        B, Q = query_bbox.shape[:2]
        image_h, image_w, _ = img_metas[0]['img_shape'][0]

        # sampling offset of all frames
        sampling_offset = self.sampling_offset(query_feat)
        sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)
        sampling_points = make_sample_points(query_bbox, sampling_offset, self.pc_range)  # [B, Q, GP, 3]
        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)

        # warp sample points based on velocity: shift x/y by vel * time offset of each frame
        time_diff = img_metas[0]['time_diff']  # [B, F]
        time_diff = time_diff[:, None, :, None]  # [B, 1, F, 1]
        vel = query_bbox[..., 8:].detach()  # [B, Q, 2]
        vel = vel[:, :, None, :]  # [B, Q, 1, 2]
        dist = vel * time_diff  # [B, Q, F, 2]
        dist = dist[:, :, :, None, None, :]  # [B, Q, F, 1, 1, 2]
        sampling_points = torch.cat([
            sampling_points[..., 0:2] - dist,
            sampling_points[..., 2:3]
        ], dim=-1)

        # scale weights: softmax over levels, shared across frames
        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
        scale_weights = torch.softmax(scale_weights, dim=-1)
        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)

        # sampling
        sampled_feats = sampling_4d(
            sampling_points,
            mlvl_feats,
            scale_weights,
            img_metas[0]['lidar2img'],
            image_h, image_w
        )  # [B, Q, G, FP, C]

        return sampled_feats

    def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        # Gradient checkpointing only while training with gradients flowing.
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)


class AdaptiveMixing(nn.Module):
    """Adaptive Mixing

    Mixes sampled features with per-query dynamic weights. Both weight sets are
    generated from the query by a single linear layer:
      * a channel-mixing matrix M of shape [eff_in_dim, eff_out_dim] per group;
      * a point-mixing matrix S of shape [out_points, in_points] per group.
    The mixed result is projected back to ``query_dim`` and added to the query.
    """

    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixing, self).__init__()

        out_dim = out_dim if out_dim is not None else in_dim
        out_points = out_points if out_points is not None else in_points
        query_dim = query_dim if query_dim is not None else in_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points

        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        self.m_parameters = self.eff_in_dim * self.eff_out_dim
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.m_parameters + self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        """x: [B, Q, G, P, C] sampled features; query: [B, Q, query_dim]."""
        B, Q, G, P, C = x.shape
        assert G == self.n_groups
        assert P == self.in_points
        assert C == self.eff_in_dim

        # generate mixing parameters
        params = self.parameter_generator(query)
        params = params.reshape(B*Q, G, -1)
        out = x.reshape(B*Q, G, P, C)

        M, S = params.split([self.m_parameters, self.s_parameters], 2)
        M = M.reshape(B*Q, G, self.eff_in_dim, self.eff_out_dim)
        S = S.reshape(B*Q, G, self.out_points, self.in_points)

        # adaptive channel mixing
        out = torch.matmul(out, M)
        out = F.layer_norm(out, [out.size(-2), out.size(-1)])
        out = self.act(out)

        # adaptive point mixing
        out = torch.matmul(S, out)  # implicitly transpose and matmul
        out = F.layer_norm(out, [out.size(-2), out.size(-1)])
        out = self.act(out)

        # linear transformation to query dim, plus residual
        out = out.reshape(B, Q, -1)
        out = self.out_proj(out)
        out = query + out

        return out

    def forward(self, x, query):
        # Gradient checkpointing only while training with gradients flowing.
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)