# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from mmcv.runner import force_fp32, auto_fp16
from mmdet3d_plugin.uniad.custom_modules.peft import (LoRALinear, ZeroAdapter, LoRACLAdapter, LoRAMoECLAdapter, MOELoRALinear,
finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper)
from mmdet3d_plugin.utils import get_logger
logger = get_logger(__name__)


@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""

    def __init__(
self,
num_feature_levels=4,
num_cams=6,
two_stage_num_proposals=300,
encoder=None,
decoder=None,
embed_dims=256,
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
can_bus_norm=True,
use_cams_embeds=True,
drop_decoder=False,
rotate_center=[100, 100],
use_lora=False,
lora_rank=16,
lora_drop=0.1,
moe_lora=False,
num_task=6,
fix_temporal_shift=False,
**kwargs
):
super(PerceptionTransformer, self).__init__(**kwargs)
self.encoder = build_transformer_layer_sequence(encoder)
if not drop_decoder:
self.decoder = build_transformer_layer_sequence(decoder)
else:
            logger.info('DET decoder is dropped')
self.embed_dims = embed_dims
self.num_feature_levels = num_feature_levels
self.num_cams = num_cams
self.fp16_enabled = False
self.rotate_prev_bev = rotate_prev_bev
self.use_shift = use_shift
self.use_can_bus = use_can_bus
self.can_bus_norm = can_bus_norm
self.use_cams_embeds = use_cams_embeds
self.use_lora = use_lora
self.lora_rank = lora_rank
self.lora_drop = lora_drop
self.moe_lora = moe_lora
self.num_task = num_task
self.two_stage_num_proposals = two_stage_num_proposals
self.init_layers()
self.rotate_center = rotate_center
self.fix_temporal_shift = fix_temporal_shift

    def init_layers(self):
"""Initialize layers of the Detr3DTransformer."""
self.level_embeds = nn.Parameter(
torch.Tensor(self.num_feature_levels, self.embed_dims)
)
self.cams_embeds = nn.Parameter(torch.Tensor(self.num_cams, self.embed_dims))
self.reference_points = nn.Linear(self.embed_dims, 3)
self.can_bus_mlp = nn.Sequential(
nn.Linear(18, self.embed_dims // 2),
nn.ReLU(inplace=True),
nn.Linear(self.embed_dims // 2, self.embed_dims),
nn.ReLU(inplace=True),
)
if self.can_bus_norm:
self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims))
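        # Optionally wrap the can-bus MLP with (MoE-)LoRA adapters for
        # parameter-efficient finetuning; finetuning_detach below is then
        # expected to freeze the original (non-adapter) parameters.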
if self.use_lora:
lora_layer = MOELoRALinear if self.moe_lora else LoRALinear
self.can_bus_mlp_lora = lora_wrapper(self.can_bus_mlp, lora_layer, self.lora_rank, dropout=self.lora_drop, num_task=self.num_task)
finetuning_detach(self)

    def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if (
isinstance(m, MSDeformableAttention3D)
or isinstance(m, TemporalSelfAttention)
or isinstance(m, CustomMSDeformableAttention)
):
try:
m.init_weight()
except AttributeError:
m.init_weights()
normal_(self.level_embeds)
normal_(self.cams_embeds)
xavier_init(self.reference_points, distribution="uniform", bias=0.0)
xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0)

    @auto_fp16(apply_to=("mlvl_feats", "bev_queries", "prev_bev", "bev_pos"))
def get_bev_features(
self,
mlvl_feats,
bev_queries,
bev_h,
bev_w,
grid_length=[0.512, 0.512],
bev_pos=None,
prev_bev=None,
img_metas=None,
task_idx=None,
forward_origin=False,
):
"""
obtain bev features.
"""
bs = mlvl_feats[0].size(0)
bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
grid_length_y = grid_length[0]
grid_length_x = grid_length[1]
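        # Convert the ego translation since the previous frame into a shift of
        # the BEV grid (normalized by grid resolution and BEV size); this is
        # used to align the previous BEV features with the current frame.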
if not self.fix_temporal_shift:
delta_x = np.array([each["can_bus"][0] for each in img_metas])
delta_y = np.array([each["can_bus"][1] for each in img_metas])
ego_angle = np.array([each["can_bus"][-2] / np.pi * 180 for each in img_metas])
translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = (
translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
)
shift_x = (
translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
)
shift_y = shift_y * self.use_shift
shift_x = shift_x * self.use_shift
shift = bev_queries.new_tensor([shift_x, shift_y]).permute(
1, 0
) # xy, bs -> bs, xy
else:
            # BEVFormer assumes x-right / y-forward lidar coordinates (nuScenes),
            # but nuPlan uses x-forward / y-left. This branch works for any lidar
            # frame: the ego translation, given in global coordinates, is rotated
            # back into the lidar frame via the lidar2global rotation matrix.
delta_global = np.array([each['can_bus'][:3] for each in img_metas])
lidar2global_rotation = np.array([each['lidar2global_rotation'] for each in img_metas])
delta_lidar = []
for i in range(bs):
delta_lidar.append(np.linalg.inv(lidar2global_rotation[i]) @ delta_global[i])
delta_lidar = np.array(delta_lidar)
shift_y = delta_lidar[:, 1] / grid_length_y / bev_h
shift_x = delta_lidar[:, 0] / grid_length_x / bev_w
shift_y = shift_y * self.use_shift
shift_x = shift_x * self.use_shift
shift = bev_queries.new_tensor([shift_x, shift_y]).permute(1, 0) # xy, bs -> bs, xy
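        # Rotate the previous BEV features by the ego yaw change so that they
        # remain aligned with the current BEV grid orientation.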
if prev_bev is not None:
if prev_bev.shape[1] == bev_h * bev_w:
prev_bev = prev_bev.permute(1, 0, 2)
if self.rotate_prev_bev:
for i in range(bs):
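                # can_bus[-1] holds the ego yaw change in degrees, matching the
                # degree convention expected by torchvision's rotate.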
rotation_angle = img_metas[i]["can_bus"][-1].astype('float64')
tmp_prev_bev = (
prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1)
)
tmp_prev_bev = rotate(
tmp_prev_bev, rotation_angle, center=self.rotate_center
)
tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
bev_h * bev_w, 1, -1
)
prev_bev[:, i] = tmp_prev_bev[:, 0]
        # add can-bus signals as an extra embedding on the BEV queries
        can_bus = [each["can_bus"] for each in img_metas]
        can_bus = bev_queries.new_tensor(can_bus)  # (bs, 18)
        if self.use_lora and not forward_origin:
            can_bus = peft_wrapper_forward(
                can_bus, self.can_bus_mlp, self.can_bus_mlp_lora
            )[None, :, :]
        else:
            can_bus = self.can_bus_mlp(can_bus)[None, :, :]  # (1, bs, embed_dims)
        # bev_queries: (H*W, bs, embed_dims)
        bev_queries = bev_queries + can_bus * self.use_can_bus
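        # Flatten each feature level to (num_cam, bs, H*W, C) and add
        # per-camera and per-level embeddings before concatenating the levels.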
feat_flatten = []
spatial_shapes = []
for lvl, feat in enumerate(mlvl_feats):
bs, num_cam, c, h, w = feat.shape
spatial_shape = (h, w)
feat = feat.flatten(3).permute(1, 0, 3, 2)
if self.use_cams_embeds:
feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
feat = feat + self.level_embeds[None, None, lvl : lvl + 1, :].to(feat.dtype)
spatial_shapes.append(spatial_shape)
feat_flatten.append(feat)
feat_flatten = torch.cat(feat_flatten, 2)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=bev_pos.device
)
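        # Start offset of each feature level within the flattened H*W sequence.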
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
feat_flatten = feat_flatten.permute(
0, 2, 1, 3
) # (num_cam, H*W, bs, embed_dims)
bev_embed = self.encoder(
bev_queries,
feat_flatten,
feat_flatten,
bev_h=bev_h,
bev_w=bev_w,
bev_pos=bev_pos,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
prev_bev=prev_bev,
shift=shift,
img_metas=img_metas,
task_idx=task_idx,
forward_origin=forward_origin
)
return bev_embed

    def get_states_and_refs(
self,
bev_embed,
object_query_embed,
bev_h,
bev_w,
reference_points=None,
reg_branches=None,
cls_branches=None,
img_metas=None,
):
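        """Run the detection decoder on top of the BEV features.

        Splits `object_query_embed` into positional and content queries,
        derives (or reuses) normalized 3D reference points, and iteratively
        refines states and references through the decoder layers.
        """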
bs = bev_embed.shape[1]
query_pos, query = torch.split(object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
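        # Predict normalized 3D reference points from the positional queries
        # unless they are provided by the caller.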
if reference_points is not None:
reference_points = reference_points.unsqueeze(0).expand(bs, -1, -1)
else:
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points
query = query.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=bev_embed,
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
cls_branches=cls_branches,
spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
level_start_index=torch.tensor([0], device=query.device),
img_metas=img_metas,
)
inter_references_out = inter_references
return inter_states, init_reference_out, inter_references_out