"""端到端自动驾驶模型 E2EAVModel。 forward 流程 1. ``DINOv3`` 提取 8 帧 patch 特征。 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV, 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。 3. 用 corrected_intr / corrected_extr / corrected_ego 计算 - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。 - 8 个 ego token(symlog 后线性投影)。 4. 2×2×2 时空压缩 -> 1536 视觉 token。 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。 非视觉切片各自加可学习 PE。 6. 18 层主干(仅视觉切片应用 3D RoPE)。 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。 """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from .backbone import Backbone, BackboneOutput from .calibration import OnlineCalibration, CalibrationOutput from .encoders import DINOv3Wrapper from .heads import ( ControlHead, ControlOutput, DetectionTrajHead, DetectionTrajOutput, ) from .modules.learned_pe import LearnedTokenPE from .modules.normalization import symlog from .modules.pos_encoding import RoPE3D from .modules.rays import compute_ego_rays from .modules.temporal_compress import TemporalCompress2x2x2 @dataclass class E2EOutput: """模型完整输出。""" detection: DetectionTrajOutput control: ControlOutput backbone_out: BackboneOutput calibration: CalibrationOutput class E2EAVModel(nn.Module): def __init__( self, dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m", backbone_dim: int = 768, num_heads: int = 12, num_dense_layers: int = 9, num_moe_layers: int = 9, num_routed_experts: int = 7, num_shared_experts: int = 1, topk_experts: int = 3, ffn_mult: int = 4, # token 数量 num_history_frames: int = 8, num_detection_tokens: int = 1024, num_control_tokens: int = 24, num_ego_tokens: int = 8, num_extra_tokens: int = 256, # 输入分辨率 image_h: int = 384, image_w: int = 1024, patch_size: int = 16, # 头超参 num_classes: int = 22, traj_horizon: int = 24, det_head_hidden: int = 384, ctrl_head_hidden: int = 384, # 校准 calib_dim: int = 256, calib_num_query: int = 256, calib_num_blocks: int = 2, calib_num_self_per_block: int = 2, calib_num_heads: int = 8, calib_residual_range: float = 0.1, calib_intr_dim: int = 11, # DINOv3 freeze_dinov3: bool = True, attn_implementation: str = "sdpa", ) -> None: super().__init__() self.image_h = image_h self.image_w = image_w self.patch_size = patch_size self.num_history = num_history_frames self.num_det = num_detection_tokens self.num_ctrl = num_control_tokens self.num_ego = num_ego_tokens self.num_extra = num_extra_tokens # === 1) DINOv3 === self.dinov3 = DINOv3Wrapper( pretrained_path=dinov3_path, attn_implementation=attn_implementation, freeze=freeze_dinov3, ) dino_dim = self.dinov3.hidden_size # === 2) 在线校准 === self.calib = OnlineCalibration( dino_dim=dino_dim, hidden_dim=calib_dim, num_query_tokens=calib_num_query, num_blocks=calib_num_blocks, num_self_attn_per_block=calib_num_self_per_block, num_heads=calib_num_heads, residual_range=calib_residual_range, num_history_frames=num_history_frames, intr_dim=calib_intr_dim, ) # === 3) 时空压缩 === self.compress = TemporalCompress2x2x2(dim=dino_dim) # patch 网格大小(必须能被 2 整除) self.gh = image_h // patch_size self.gw = image_w // patch_size # === 4) 各类 token + 可学习 PE === self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim)) nn.init.trunc_normal_(self.det_tokens, std=0.02) self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim)) nn.init.trunc_normal_(self.ctrl_tokens, std=0.02) self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim)) nn.init.trunc_normal_(self.extra_tokens, std=0.02) self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim) self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim) self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim) self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim) # === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)=== self.rope = RoPE3D( num_heads=num_heads, head_dim=backbone_dim // num_heads, time_size=num_history_frames // 2, height_size=self.gh // 2, width_size=self.gw // 2, ) # === 6) 主干 18 层 === self.backbone = Backbone( dim=backbone_dim, num_heads=num_heads, ffn_mult=ffn_mult, num_dense_layers=num_dense_layers, num_moe_layers=num_moe_layers, num_routed=num_routed_experts, num_shared=num_shared_experts, topk=topk_experts, ) # === 7) 头 === self.det_traj_head = DetectionTrajHead( in_dim=backbone_dim, hidden_size=det_head_hidden, num_classes=num_classes, traj_horizon=traj_horizon, ) self.ctrl_head = ControlHead( in_dim=backbone_dim, hidden_size=ctrl_head_hidden, num_traj_tokens=12, num_action_tokens=num_control_tokens - 12, ego_traj_horizon=traj_horizon, ) # ---------- 工具 ---------- @property def num_visual_tokens(self) -> int: # 2×2×2 压缩后 return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2) def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor: """``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。""" return self.ego_proj(symlog(ego_6d_corrected)) def _build_visual_rays( self, intr_corrected: torch.Tensor, # [B, calib_intr_dim] extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle compressed_thw: tuple[int, int, int], ) -> torch.Tensor: """计算压缩后视觉 token 网格的射线方向。 在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 + 2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似, 所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是 rigid 的相机系;自车运动差异会通过 ego token 传递)。 """ b = intr_corrected.shape[0] t_, h_, w_ = compressed_thw rays_grid = compute_ego_rays( intr_vec=intr_corrected, cam2vehicle=extr_corrected_se3, height=self.image_h, width=self.image_w, grid_h=h_, grid_w=w_, device=intr_corrected.device, dtype=intr_corrected.dtype, ) # [B, h_, w_, 3] # 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3] rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous() rays = rays.reshape(b, t_ * h_ * w_, 3) return rays # ---------- 前向 ---------- def forward( self, images: torch.Tensor, # [B, T=8, 3, H, W] ego_6d_raw: torch.Tensor, # [B, 8, 6] intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致 extr_6d_raw: torch.Tensor, # [B, 6] ) -> E2EOutput: b, t, _, h, w = images.shape assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}" # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino] dino_feats = self.dinov3(images) # 2) 校准(symlog 空间残差 + symexp 还原) calib_out: CalibrationOutput = self.calib( dino_feats=dino_feats, ego_raw=ego_6d_raw, intr_raw=intr_raw, extr_raw=extr_6d_raw, ) corrected_ego = calib_out.corrected_ego corrected_intr = calib_out.corrected_intr corrected_extr_6d = calib_out.corrected_extr # 3) 把 corrected_extr 6D 转成 4x4 from .data.se3 import six_d_to_matrix cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4] # 4) 2x2x2 时空压缩 compressed, thw = self.compress(dino_feats) # [B, N_v, D] n_v = compressed.shape[1] # 5) 视觉射线(用 corrected_intr / corrected_extr) rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw) rope_cos, rope_sin = self.rope.compute_freqs(rays) # 6) 构造非视觉 token ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D] det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1) ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1) extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1) ego_tok = self.ego_pe(ego_tok) det_tok = self.det_pe(det_tok) ctrl_tok = self.ctrl_pe(ctrl_tok) extra_tok = self.extra_pe(extra_tok) # 7) 拼接序列:[vision | ego | det | ctrl | extra] seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1) visual_slice = (0, n_v) # 8) 主干 bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) # 9) 切片送入头 offset_det = n_v + self.num_ego offset_ctrl = offset_det + self.num_det det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det] ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl] det_out = self.det_traj_head(det_feats) ctrl_out = self.ctrl_head(ctrl_feats) return E2EOutput( detection=det_out, control=ctrl_out, backbone_out=bb_out, calibration=calib_out, )