import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate

from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention

from mmdet3d_plugin.uniad.custom_modules.peft import (
    LoRALinear, ZeroAdapter, LoRACLAdapter, LoRAMoECLAdapter, MOELoRALinear,
    finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper)
from mmdet3d_plugin.utils import get_logger

logger = get_logger(__name__)


@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
    """Implements the Detr3D transformer.

    Args:
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        num_cams (int): Number of cameras. Default: 6.
        two_stage_num_proposals (int): Number of proposals when queries
            are generated from encoder features. Default: 300.
    """

    def __init__(
        self,
        num_feature_levels=4,
        num_cams=6,
        two_stage_num_proposals=300,
        encoder=None,
        decoder=None,
        embed_dims=256,
        rotate_prev_bev=True,
        use_shift=True,
        use_can_bus=True,
        can_bus_norm=True,
        use_cams_embeds=True,
        drop_decoder=False,
        rotate_center=[100, 100],
        use_lora=False,
        lora_rank=16,
        lora_drop=0.1,
        moe_lora=False,
        num_task=6,
        fix_temporal_shift=False,
        **kwargs
    ):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        if not drop_decoder:
            self.decoder = build_transformer_layer_sequence(decoder)
        else:
            logger.info('DET decoder is dropped')
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        # LoRA / MoE-LoRA fine-tuning options.
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_drop = lora_drop
        self.moe_lora = moe_lora
        self.num_task = num_task

        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center

        self.fix_temporal_shift = fix_temporal_shift

    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 3)
        # Embeds the 18-dim CAN bus signal into the query feature space.
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims))

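        # Note: a sketch of the intended PEFT mechanics, inferred from the
        # call signatures in mmdet3d_plugin.uniad.custom_modules.peft rather
        # than confirmed from its source: `lora_wrapper` presumably builds
        # low-rank adapters mirroring each nn.Linear inside `can_bus_mlp`
        # (the MoE variant routing by task index), and `finetuning_detach`
        # presumably freezes the wrapped base weights so only the adapters
        # receive gradients during fine-tuning.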
        if self.use_lora:
            lora_layer = MOELoRALinear if self.moe_lora else LoRALinear
            self.can_bus_mlp_lora = lora_wrapper(
                self.can_bus_mlp, lora_layer, self.lora_rank,
                dropout=self.lora_drop, num_task=self.num_task)
            finetuning_detach(self)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, (MSDeformableAttention3D, TemporalSelfAttention,
                              CustomMSDeformableAttention)):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution="uniform", bias=0.0)
        xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0)

    @auto_fp16(apply_to=("mlvl_feats", "bev_queries", "prev_bev", "bev_pos"))
    def get_bev_features(
        self,
        mlvl_feats,
        bev_queries,
        bev_h,
        bev_w,
        grid_length=(0.512, 0.512),
        bev_pos=None,
        prev_bev=None,
        img_metas=None,
        task_idx=None,
        forward_origin=False,
    ):
        """Obtain BEV features from multi-camera, multi-level image features."""
        bs = mlvl_feats[0].size(0)
        # (num_query, embed_dims) -> (num_query, bs, embed_dims)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
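        # Compensate ego motion between frames: the CAN bus translation
        # (global frame) is converted into a BEV-grid shift so that prev_bev
        # and the current queries stay aligned. The legacy branch
        # reconstructs the shift from the ego heading and the translation
        # angle; the fixed branch rotates the global delta into the lidar
        # frame directly.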
        if not self.fix_temporal_shift:
            delta_x = np.array([each["can_bus"][0] for each in img_metas])
            delta_y = np.array([each["can_bus"][1] for each in img_metas])
            ego_angle = np.array(
                [each["can_bus"][-2] / np.pi * 180 for each in img_metas])
            translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
            translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
            bev_angle = ego_angle - translation_angle
            # Translation projected onto the BEV axes, normalized to grid
            # fractions by the cell size and the BEV extent.
            shift_y = (translation_length * np.cos(bev_angle / 180 * np.pi)
                       / grid_length_y / bev_h)
            shift_x = (translation_length * np.sin(bev_angle / 180 * np.pi)
                       / grid_length_x / bev_w)
            shift_y = shift_y * self.use_shift
            shift_x = shift_x * self.use_shift
            # (2, bs) -> (bs, 2)
            shift = bev_queries.new_tensor([shift_x, shift_y]).permute(1, 0)
        else:
            # Rotate the global-frame translation into the lidar frame, then
            # normalize by the grid resolution.
            delta_global = np.array([each['can_bus'][:3] for each in img_metas])
            lidar2global_rotation = np.array(
                [each['lidar2global_rotation'] for each in img_metas])
            delta_lidar = []
            for i in range(bs):
                delta_lidar.append(
                    np.linalg.inv(lidar2global_rotation[i]) @ delta_global[i])
            delta_lidar = np.array(delta_lidar)
            shift_y = delta_lidar[:, 1] / grid_length_y / bev_h
            shift_x = delta_lidar[:, 0] / grid_length_x / bev_w
            shift_y = shift_y * self.use_shift
            shift_x = shift_x * self.use_shift
            shift = bev_queries.new_tensor([shift_x, shift_y]).permute(1, 0)
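
        # Rotate the previous BEV features so their orientation matches the
        # current ego heading; can_bus[-1] holds the frame-to-frame yaw
        # change (in degrees, as torchvision's `rotate` expects).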
        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)

            if self.rotate_prev_bev:
                for i in range(bs):
                    rotation_angle = img_metas[i]["can_bus"][-1].astype('float64')
                    tmp_prev_bev = (
                        prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1))
                    tmp_prev_bev = rotate(
                        tmp_prev_bev, rotation_angle, center=self.rotate_center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # Embed the raw CAN bus signals and add them to the BEV queries.
        can_bus = [each["can_bus"] for each in img_metas]
        can_bus = bev_queries.new_tensor(can_bus)
        if self.use_lora and not forward_origin:
            can_bus = peft_wrapper_forward(
                can_bus, self.can_bus_mlp, self.can_bus_mlp_lora)[None, :, :]
        else:
            can_bus = self.can_bus_mlp(can_bus)[None, :, :]

        bev_queries = bev_queries + can_bus * self.use_can_bus

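        # Flatten each FPN level to (num_cam, bs, H*W, C), tag it with camera
        # and level embeddings, and record its spatial shape for deformable
        # attention.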
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        # Start offset of each level in the flattened feature sequence.
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)),
             spatial_shapes.prod(1).cumsum(0)[:-1]))

        # (num_cam, bs, sum(H*W), C) -> (num_cam, sum(H*W), bs, C)
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            img_metas=img_metas,
            task_idx=task_idx,
            forward_origin=forward_origin,
        )

        return bev_embed

    def get_states_and_refs(
        self,
        bev_embed,
        object_query_embed,
        bev_h,
        bev_w,
        reference_points=None,
        reg_branches=None,
        cls_branches=None,
        img_metas=None,
    ):
        """Decode object queries against the BEV features."""
        bs = bev_embed.shape[1]
        # Each object query embedding packs positional and content parts.
        query_pos, query = torch.split(object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        if reference_points is not None:
            reference_points = reference_points.unsqueeze(0).expand(bs, -1, -1)
        else:
            reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            img_metas=img_metas,
        )
        inter_references_out = inter_references

        return inter_states, init_reference_out, inter_references_out
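

# The sketch below is illustrative only and not part of the original module:
# it reproduces the legacy ego-motion shift computation from
# `get_bev_features` on synthetic CAN bus values so the geometry can be
# checked in isolation. All numbers are made up; `bev_h`, `bev_w`, and the
# grid lengths mirror common BEVFormer settings but are assumptions here.
if __name__ == "__main__":
    bev_h, bev_w = 200, 200
    grid_length_y, grid_length_x = 0.512, 0.512

    # Synthetic CAN bus entries: global translation (meters) and ego yaw
    # (radians), matching the indices read above
    # (can_bus[0], can_bus[1], can_bus[-2]).
    delta_x = np.array([1.0])
    delta_y = np.array([0.0])
    ego_angle = np.array([0.0]) / np.pi * 180

    translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
    translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
    bev_angle = ego_angle - translation_angle
    shift_y = (translation_length * np.cos(bev_angle / 180 * np.pi)
               / grid_length_y / bev_h)
    shift_x = (translation_length * np.sin(bev_angle / 180 * np.pi)
               / grid_length_x / bev_w)

    # Moving 1 m straight ahead with zero yaw should shift the BEV grid by
    # ~1 / 0.512 / 200 ≈ 0.0098 of its extent along y and ~0 along x.
    print("shift_x:", shift_x, "shift_y:", shift_y)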