import math
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_

from .detr3d_transformer import inverse_sigmoid, feature_sampling


@ATTENTION.register_module()
class Detr3DCamRadarCrossAtten(BaseModule):
    """Camera + radar cross-attention module used in Detr3d.

    Image features are sampled at the projected reference points (via
    ``feature_sampling``) and weighted per camera / point / level. Radar
    features of the ``radar_topk`` radar points closest (in xy) to each
    reference point are weighted and summed. Both results are concatenated
    and fused by a small MLP.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Only used to validate
            that ``embed_dims`` is divisible by it and to emit the
            power-of-2 warning. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 5.
        num_cams (int): The number of cameras. Default: 6.
        radar_dims (int): Number of radar feature channels, i.e. the
            channels of ``radar_feats`` left after removing the two
            leading xy channels and the trailing mask channel.
            Default: 3.
        radar_topk (int): Number of nearest radar points attended to by
            each query. Default: 8.
        im2col_step (int): The step used in image_to_column. Stored only.
            Default: 64.
        pc_range (list | None): Point cloud range forwarded to
            ``feature_sampling``. Default: None.
        dropout (float): A Dropout layer applied to the fused output
            before the residual is added. Default: 0.1.
        norm_cfg (dict | None): Stored only. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Stored only; ``forward`` always treats
            ``query`` as (num_query, bs, embed_dims). Default: False.
    """

    def __init__(
        self,
        embed_dims=256,
        num_heads=8,
        num_levels=4,
        num_points=5,
        num_cams=6,
        radar_dims=3,
        radar_topk=8,
        im2col_step=64,
        pc_range=None,
        dropout=0.1,
        norm_cfg=None,
        init_cfg=None,
        batch_first=False,
    ):
        super(Detr3DCamRadarCrossAtten, self).__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_heads, "
                f"but got {embed_dims} and {num_heads}"
            )
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
                )
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                "MultiScaleDeformAttention to make "
                "the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation."
            )

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.num_cams = num_cams
        # One scalar weight per (camera, level, point) for every query.
        self.attention_weights = nn.Linear(
            embed_dims, num_cams * num_levels * num_points
        )

        self.radar_dims = radar_dims
        # One scalar weight per selected radar point for every query.
        self.attention_weights_radar = nn.Linear(embed_dims, radar_topk)
        self.radar_topk = radar_topk

        self.img_output_proj = nn.Linear(embed_dims, embed_dims)
        self.radar_output_proj = nn.Linear(self.radar_dims, self.radar_dims)

        # Fuses the concatenated [image | radar] feature back to embed_dims.
        self.img_radar_fusion = nn.Sequential(
            nn.Linear(embed_dims + radar_dims, embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
        )

        # Encodes the 3-channel reference point into an embed_dims feature.
        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )
        self.batch_first = batch_first

        self.init_weight()

    def init_weight(self):
        """Default initialization for Parameters of Module."""
        # Zero weights and biases: every attention logit starts at 0, so
        # the sigmoid applied in `forward` starts at 0.5.
        constant_init(self.attention_weights, val=0.0, bias=0.0)
        constant_init(self.attention_weights_radar, val=0.0, bias=0.0)
        xavier_init(self.img_output_proj, distribution="uniform", bias=0.0)
        xavier_init(self.radar_output_proj, distribution="uniform", bias=0.0)
        # NOTE(review): `img_radar_fusion` is an nn.Sequential. mmcv's
        # `xavier_init` presumably only touches a module's own `weight` /
        # `bias`, in which case this call is a no-op and the inner Linear
        # layers keep PyTorch's default init -- confirm. Left unchanged so
        # that training behaviour is not altered.
        xavier_init(self.img_radar_fusion, distribution="uniform", bias=0.0)

    def forward(
        self,
        query,
        key,
        value,
        residual=None,
        query_pos=None,
        key_padding_mask=None,
        reference_points=None,
        ref_size=None,
        spatial_shapes=None,
        level_start_index=None,
        radar_feats=None,
        **kwargs,
    ):
        """Forward Function of Detr3DCamRadarCrossAtten.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`. Only used as the fallback
                for `value`.
            value (Tensor): The image features handed to
                `feature_sampling`; multi-level features of shape
                (B, N, C, H, W).
            residual (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None, `query`
                (before `query_pos` is added) will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_padding_mask (Tensor): Unused by this module.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, 3),
                all elements is range in [0, 1].
            ref_size (Tensor): the wlh(bbox size) associated with each
                query, shape (bs, num_query, 3), value in log space.
                Unused by this module.
            spatial_shapes (Tensor): Unused by this module.
            level_start_index (Tensor): Unused by this module.
            radar_feats (Tensor): Radar points with shape
                (bs, M, 2 + radar_dims + 1). The first two channels are
                the xy position (compared directly against
                `reference_points[..., :2]`), the last channel is the
                validity mask and the channels in between are the radar
                features. Required.
            **kwargs: Must contain `img_metas`, forwarded to
                `feature_sampling`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].

        Raises:
            TypeError: If `radar_feats` is None.
        """
        if radar_feats is None:
            # Fail fast with a clear message instead of the opaque
            # "'NoneType' object is not subscriptable" raised further down.
            raise TypeError(
                "radar_feats must be provided to Detr3DCamRadarCrossAtten.forward"
            )

        if key is None:
            key = query
        if value is None:
            value = key

        # Previously `inp_residual` was only bound when `residual` was None,
        # so passing an explicit residual crashed at the return statement.
        inp_residual = query if residual is None else residual
        if query_pos is not None:
            query = query + query_pos

        # change to (bs, num_query, embed_dims)
        query = query.permute(1, 0, 2)

        bs, num_query, _ = query.size()

        # ---- image branch -------------------------------------------------
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels
        )

        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs["img_metas"]
        )
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)

        # `mask` comes from feature_sampling; multiplying by it removes the
        # contribution of the masked-out samples.
        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        # [bs, embed_dim, num_query]
        output = output.sum(-1).sum(-1).sum(-1)
        # change to [num_query, bs, embed_dims]
        output = output.permute(2, 0, 1)

        output = self.img_output_proj(output)

        # ---- radar branch -------------------------------------------------
        # Last channel is the validity mask, first two are xy, rest features.
        radar_feats, radar_mask = radar_feats[:, :, :-1], radar_feats[:, :, -1]
        radar_xy = radar_feats[:, :, :2]
        ref_xy = reference_points[:, :, :2]
        radar_feats = radar_feats[:, :, 2:]

        # Shift masked-out radar points far away (+1000) so that they are
        # selected by top-k only when there are not enough valid points.
        pad_xy = torch.ones_like(radar_xy) * 1000.0
        radar_xy = radar_xy + (1.0 - radar_mask.unsqueeze(dim=-1).type(torch.float)) * (
            pad_xy
        )

        # [B, num_query, M]; negated so that top-k picks the nearest points.
        ref_radar_dist = -1.0 * torch.cdist(ref_xy, radar_xy)

        # [B, num_query, topk]
        _value, indices = torch.topk(ref_radar_dist, self.radar_topk)

        # [B, num_query, M]; `expand` is a broadcast view, so no per-query
        # copy of the mask is materialised.
        radar_mask = radar_mask.unsqueeze(dim=1).expand(-1, num_query, -1)

        # [B, num_query, topk]
        top_mask = torch.gather(radar_mask, 2, indices)

        # [B, num_query, M, radar_dim] (broadcast view)
        radar_feats = radar_feats.unsqueeze(dim=1).expand(-1, num_query, -1, -1)
        radar_dim = radar_feats.size(-1)

        # [B, num_query, topk, radar_dim] (broadcast view)
        indices_pad = indices.unsqueeze(dim=-1).expand(-1, -1, -1, radar_dim)

        # [B, num_query, topk, radar_dim]
        radar_feats_topk = torch.gather(
            radar_feats, dim=2, index=indices_pad, sparse_grad=False
        )

        attention_weights_radar = self.attention_weights_radar(query).view(
            bs, num_query, self.radar_topk
        )

        # [B, num_query, topk]; padded (invalid) radar points get weight 0.
        attention_weights_radar = attention_weights_radar.sigmoid() * top_mask
        # [B, num_query, topk, radar_dim]
        radar_out = radar_feats_topk * attention_weights_radar.unsqueeze(dim=-1)
        # [bs, num_query, radar_dim]
        radar_out = radar_out.sum(dim=2)

        # change to (num_query, bs, radar_dim)
        radar_out = radar_out.permute(1, 0, 2)

        radar_out = self.radar_output_proj(radar_out)

        # ---- fusion ---------------------------------------------------------
        output = torch.cat((output, radar_out), dim=-1)
        output = self.img_radar_fusion(output)

        # (num_query, bs, embed_dims)
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(
            1, 0, 2
        )

        return self.dropout(output) + inp_residual + pos_feat


# @ATTENTION.register_module()
# class Detr3DCrossAtten(BaseModule):
#     """An attention module used in Detr3d.
#     Args:
#         embed_dims (int): The embedding dimension of Attention.
#             Default: 256.
#         num_heads (int): Parallel attention heads. Default: 64.
#         num_levels (int): The number of feature map used in
#             Attention. Default: 4.
#         num_points (int): The number of sampling points for
#             each query in each head. Default: 4.
#         im2col_step (int): The step used in image_to_column.
#             Default: 64.
#         dropout (float): A Dropout layer on `inp_residual`.
#             Default: 0..
#         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
#             Default: None.
# """ # def __init__( # self, # embed_dims=256, # num_heads=8, # num_levels=4, # num_points=5, # num_cams=6, # im2col_step=64, # pc_range=None, # dropout=0.1, # norm_cfg=None, # init_cfg=None, # batch_first=False, # ): # super(Detr3DCrossAtten, self).__init__(init_cfg) # if embed_dims % num_heads != 0: # raise ValueError( # f"embed_dims must be divisible by num_heads, " # f"but got {embed_dims} and {num_heads}" # ) # dim_per_head = embed_dims // num_heads # self.norm_cfg = norm_cfg # self.init_cfg = init_cfg # self.dropout = nn.Dropout(dropout) # self.pc_range = pc_range # # you'd better set dim_per_head to a power of 2 # # which is more efficient in the CUDA implementation # def _is_power_of_2(n): # if (not isinstance(n, int)) or (n < 0): # raise ValueError( # "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)) # ) # return (n & (n - 1) == 0) and n != 0 # if not _is_power_of_2(dim_per_head): # warnings.warn( # "You'd better set embed_dims in " # "MultiScaleDeformAttention to make " # "the dimension of each attention head a power of 2 " # "which is more efficient in our CUDA implementation." 
# ) # self.im2col_step = im2col_step # self.embed_dims = embed_dims # self.num_levels = num_levels # self.num_heads = num_heads # self.num_points = num_points # self.num_cams = num_cams # self.attention_weights = nn.Linear( # embed_dims, num_cams * num_levels * num_points # ) # self.output_proj = nn.Linear(embed_dims, embed_dims) # self.position_encoder = nn.Sequential( # nn.Linear(3, self.embed_dims), # nn.LayerNorm(self.embed_dims), # nn.ReLU(inplace=True), # nn.Linear(self.embed_dims, self.embed_dims), # nn.LayerNorm(self.embed_dims), # nn.ReLU(inplace=True), # ) # self.batch_first = batch_first # self.init_weight() # def init_weight(self): # """Default initialization for Parameters of Module.""" # constant_init(self.attention_weights, val=0.0, bias=0.0) # xavier_init(self.output_proj, distribution="uniform", bias=0.0) # def forward( # self, # query, # key, # value, # residual=None, # query_pos=None, # key_padding_mask=None, # reference_points=None, # spatial_shapes=None, # level_start_index=None, # **kwargs, # ): # """Forward Function of Detr3DCrossAtten. # Args: # query (Tensor): Query of Transformer with shape # (num_query, bs, embed_dims). # key (Tensor): The key tensor with shape # `(num_key, bs, embed_dims)`. # value (Tensor): The value tensor with shape # `(num_key, bs, embed_dims)`. (B, N, C, H, W) # residual (Tensor): The tensor used for addition, with the # same shape as `x`. Default None. If None, `x` will be used. # query_pos (Tensor): The positional encoding for `query`. # Default: None. # key_pos (Tensor): The positional encoding for `key`. Default # None. # reference_points (Tensor): The normalized reference # points with shape (bs, num_query, 4), # all elements is range in [0, 1], top-left (0,0), # bottom-right (1, 1), including padding area. # or (N, Length_{query}, num_levels, 4), add # additional two dimensions is (w, h) to # form reference boxes. # key_padding_mask (Tensor): ByteTensor for `query`, with # shape [bs, num_key]. 
# spatial_shapes (Tensor): Spatial shape of features in # different level. With shape (num_levels, 2), # last dimension represent (h, w). # level_start_index (Tensor): The start index of each level. # A tensor has shape (num_levels) and can be represented # as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. # Returns: # Tensor: forwarded results with shape [num_query, bs, embed_dims]. # """ # if key is None: # key = query # if value is None: # value = key # if residual is None: # inp_residual = query # if query_pos is not None: # query = query + query_pos # # change to (bs, num_query, embed_dims) # query = query.permute(1, 0, 2) # bs, num_query, _ = query.size() # attention_weights = self.attention_weights(query).view( # bs, 1, num_query, self.num_cams, self.num_points, self.num_levels # ) # reference_points_3d, output, mask = feature_sampling( # value, reference_points, self.pc_range, kwargs["img_metas"] # ) # output = torch.nan_to_num(output) # mask = torch.nan_to_num(mask) # attention_weights = attention_weights.sigmoid() * mask # output = output * attention_weights # output = output.sum(-1).sum(-1).sum(-1) # output = output.permute(2, 0, 1) # output = self.output_proj(output) # # (num_query, bs, embed_dims) # pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute( # 1, 0, 2 # ) # return self.dropout(output) + inp_residual + pos_feat # def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas): # lidar2img = [] # for img_meta in img_metas: # lidar2img.append(img_meta["lidar2img"]) # lidar2img = np.asarray(lidar2img) # lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) # reference_points = reference_points.clone() # reference_points_3d = reference_points.clone() # reference_points[..., 0:1] = ( # reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0] # ) # reference_points[..., 1:2] = ( # reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1] # ) # reference_points[..., 2:3] = ( # 
#         reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
#     )
#     # reference_points (B, num_queries, 4)
#     reference_points = torch.cat(
#         (reference_points, torch.ones_like(reference_points[..., :1])), -1
#     )
#     B, num_query = reference_points.size()[:2]
#     num_cam = lidar2img.size(1)
#     reference_points = (
#         reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
#     )
#     lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
#     reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
#     eps = 1e-5
#     mask = reference_points_cam[..., 2:3] > eps
#     reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
#         reference_points_cam[..., 2:3],
#         torch.ones_like(reference_points_cam[..., 2:3]) * eps,
#     )
#     reference_points_cam[..., 0] /= img_metas[0]["img_shape"][0][0][1]
#     reference_points_cam[..., 1] /= img_metas[0]["img_shape"][0][0][0]
#     reference_points_cam = (reference_points_cam - 0.5) * 2
#     mask = (
#         mask
#         & (reference_points_cam[..., 0:1] > -1.0)
#         & (reference_points_cam[..., 0:1] < 1.0)
#         & (reference_points_cam[..., 1:2] > -1.0)
#         & (reference_points_cam[..., 1:2] < 1.0)
#     )
#     mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
#     mask = torch.nan_to_num(mask)
#     sampled_feats = []
#     for lvl, feat in enumerate(mlvl_feats):
#         B, N, C, H, W = feat.size()
#         feat = feat.view(B * N, C, H, W)
#         reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
#         sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
#         sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
#         sampled_feats.append(sampled_feat)
#     sampled_feats = torch.stack(sampled_feats, -1)
#     sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
#     return reference_points_3d, sampled_feats, mask