# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate

from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from mmdet3d_plugin.uniad.custom_modules.peft import (LoRALinear, ZeroAdapter,
                                                      LoRACLAdapter,
                                                      LoRAMoECLAdapter,
                                                      MOELoRALinear,
                                                      finetuning_detach,
                                                      frozen_grad,
                                                      peft_wrapper_forward,
                                                      lora_wrapper)
from mmdet3d_plugin.utils import get_logger

logger = get_logger(__name__)


@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
    """Implements the Detr3D transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """

    def __init__(
        self,
        num_feature_levels=4,
        num_cams=6,
        two_stage_num_proposals=300,
        encoder=None,
        decoder=None,
        embed_dims=256,
        rotate_prev_bev=True,
        use_shift=True,
        use_can_bus=True,
        can_bus_norm=True,
        use_cams_embeds=True,
        drop_decoder=False,
        rotate_center=[100, 100],
        use_lora=False,
        lora_rank=16,
        lora_drop=0.1,
        moe_lora=False,
        num_task=6,
        fix_temporal_shift=False,
        **kwargs
    ):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        if not drop_decoder:
            self.decoder = build_transformer_layer_sequence(decoder)
        else:
            logger.info('DET decoder is dropped')
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_drop = lora_drop
        self.moe_lora = moe_lora
        self.num_task = num_task

        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center
        self.fix_temporal_shift = fix_temporal_shift

    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims)
        )
        self.cams_embeds = nn.Parameter(torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 3)
        # The 18-dim CAN bus signal (ego pose, velocity, acceleration, etc.)
        # is embedded into the query space by a small MLP.
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims))
        if self.use_lora:
            lora_layer = MOELoRALinear if self.moe_lora else LoRALinear
            self.can_bus_mlp_lora = lora_wrapper(self.can_bus_mlp,
                                                 lora_layer,
                                                 self.lora_rank,
                                                 dropout=self.lora_drop,
                                                 num_task=self.num_task)
            finetuning_detach(self)
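
    # A minimal sketch of how the LoRA branch is combined with the frozen MLP
    # at call time (see the peft_wrapper_forward call in get_bev_features
    # below). This assumes the common LoRA formulation, y = base(x) +
    # low_rank_update(x); the exact behavior is defined in
    # mmdet3d_plugin.uniad.custom_modules.peft:
    #
    #     frozen_out = self.can_bus_mlp(x)        # pretrained path, frozen
    #     lora_out   = self.can_bus_mlp_lora(x)   # trainable low-rank update
    #     out        = frozen_out + lora_out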

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(
                m,
                (
                    MSDeformableAttention3D,
                    TemporalSelfAttention,
                    CustomMSDeformableAttention,
                ),
            ):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution="uniform", bias=0.0)
        xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0)

    @auto_fp16(apply_to=("mlvl_feats", "bev_queries", "prev_bev", "bev_pos"))
    def get_bev_features(
        self,
        mlvl_feats,
        bev_queries,
        bev_h,
        bev_w,
        grid_length=[0.512, 0.512],
        bev_pos=None,
        prev_bev=None,
        img_metas=None,
        task_idx=None,
        forward_origin=False,
    ):
        """Obtain BEV features."""
        bs = mlvl_feats[0].size(0)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)  # (H*W, bs, C)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)  # (H*W, bs, C)

        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        if not self.fix_temporal_shift:
            # Project the ego translation since the previous frame onto the
            # BEV grid, normalized by the cell length and grid extent.
            delta_x = np.array([each["can_bus"][0] for each in img_metas])
            delta_y = np.array([each["can_bus"][1] for each in img_metas])
            ego_angle = np.array(
                [each["can_bus"][-2] / np.pi * 180 for each in img_metas]
            )
            translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
            translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
            bev_angle = ego_angle - translation_angle
            shift_y = (
                translation_length
                * np.cos(bev_angle / 180 * np.pi)
                / grid_length_y
                / bev_h
            )
            shift_x = (
                translation_length
                * np.sin(bev_angle / 180 * np.pi)
                / grid_length_x
                / bev_w
            )
            shift_y = shift_y * self.use_shift
            shift_x = shift_x * self.use_shift
            shift = bev_queries.new_tensor([shift_x, shift_y]).permute(
                1, 0
            )  # xy, bs -> bs, xy
        else:
            # BEVFormer assumes x-right and y-forward coordinates for the
            # nuScenes lidar, but nuPlan's coordinates are x-forward and
            # y-left. This branch works for any lidar coordinates: the shift
            # is obtained by rotating the global translation into the lidar
            # frame with the lidar2global rotation matrix.
            delta_global = np.array([each['can_bus'][:3] for each in img_metas])
            lidar2global_rotation = np.array(
                [each['lidar2global_rotation'] for each in img_metas]
            )
            delta_lidar = []
            for i in range(bs):
                delta_lidar.append(
                    np.linalg.inv(lidar2global_rotation[i]) @ delta_global[i]
                )
            delta_lidar = np.array(delta_lidar)
            shift_y = delta_lidar[:, 1] / grid_length_y / bev_h
            shift_x = delta_lidar[:, 0] / grid_length_x / bev_w
            shift_y = shift_y * self.use_shift
            shift_x = shift_x * self.use_shift
            shift = bev_queries.new_tensor([shift_x, shift_y]).permute(
                1, 0
            )  # xy, bs -> bs, xy
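
        # shift_x / shift_y are dimensionless: an ego translation spanning the
        # full BEV extent maps to a shift of 1.0. The encoder uses this value
        # to align prev_bev with the current frame; a runnable numeric sketch
        # of this computation is at the bottom of this file.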

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)  # -> (H*W, bs, C)
            if self.rotate_prev_bev:
                # Undo the ego rotation so that prev_bev is axis-aligned with
                # the current BEV grid before temporal self-attention.
                for i in range(bs):
                    rotation_angle = img_metas[i]["can_bus"][-1].astype('float64')
                    tmp_prev_bev = (
                        prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1)
                    )
                    tmp_prev_bev = rotate(
                        tmp_prev_bev, rotation_angle, center=self.rotate_center
                    )
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1
                    )
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # Add CAN bus signals to every BEV query.
        can_bus = [each["can_bus"] for each in img_metas]
        can_bus = bev_queries.new_tensor(can_bus)  # (bs, 18)
        if self.use_lora and not forward_origin:
            can_bus = peft_wrapper_forward(
                can_bus, self.can_bus_mlp, self.can_bus_mlp_lora
            )[None, :, :]
        else:
            can_bus = self.can_bus_mlp(can_bus)[None, :, :]  # (1, bs, C)
        bev_queries = bev_queries + can_bus * self.use_can_bus

        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)  # (num_cam, bs, h*w, C)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device
        )
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )

        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3
        )  # (num_cam, H*W, bs, embed_dims)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            img_metas=img_metas,
            task_idx=task_idx,
            forward_origin=forward_origin,
        )
        return bev_embed

    def get_states_and_refs(
        self,
        bev_embed,
        object_query_embed,
        bev_h,
        bev_w,
        reference_points=None,
        reg_branches=None,
        cls_branches=None,
        img_metas=None,
    ):
        bs = bev_embed.shape[1]
        # Split each learned embedding into a positional and a content part.
        query_pos, query = torch.split(object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)

        if reference_points is not None:
            reference_points = reference_points.unsqueeze(0).expand(bs, -1, -1)
        else:
            reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)

        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            img_metas=img_metas,
        )
        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
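

# A minimal, self-contained sketch (not part of the training pipeline) of the
# ego-motion -> BEV-grid shift computed in get_bev_features(), for a single
# hypothetical frame. The translation, yaw, grid size, and cell length below
# are made-up values for illustration only.
if __name__ == "__main__":
    bev_h = bev_w = 200                    # BEV grid size (cells)
    grid_length_y = grid_length_x = 0.512  # metres per cell
    delta_x, delta_y = 1.2, 0.3            # ego translation since prev frame (m)
    ego_angle = 30.0                       # ego yaw (degrees)

    translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
    translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
    bev_angle = ego_angle - translation_angle
    # Normalized shift: a value of 1.0 would slide prev_bev by the full grid.
    shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
    shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
    print(f"normalized BEV shift: x={shift_x:.6f}, y={shift_y:.6f}")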