| """端到端自动驾驶模型 E2EAVModel。 |
| |
| forward 流程 |
| 1. ``DINOv3`` 提取 8 帧 patch 特征。 |
| 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV, |
| 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。 |
| 3. 用 corrected_intr / corrected_extr / corrected_ego 计算 |
| - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。 |
| - 8 个 ego token(symlog 后线性投影)。 |
| 4. 2×2×2 时空压缩 -> 1536 视觉 token。 |
| 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。 |
| 非视觉切片各自加可学习 PE。 |
| 6. 18 层主干(仅视觉切片应用 3D RoPE)。 |
| 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .backbone import Backbone, BackboneOutput |
| from .calibration import OnlineCalibration, CalibrationOutput |
| from .encoders import DINOv3Wrapper |
| from .heads import ( |
| ControlHead, |
| ControlOutput, |
| DetectionTrajHead, |
| DetectionTrajOutput, |
| ) |
| from .modules.learned_pe import LearnedTokenPE |
| from .modules.normalization import symlog |
| from .modules.pos_encoding import RoPE3D |
| from .modules.rays import compute_ego_rays |
| from .modules.temporal_compress import TemporalCompress2x2x2 |
|
|
|
|
@dataclass
class E2EOutput:
    """Complete model output."""

    # Output of ``DetectionTrajHead`` (detection + trajectory predictions).
    detection: DetectionTrajOutput
    # Output of ``ControlHead``.
    control: ControlOutput
    # Raw backbone output; its ``hidden_states`` are what the heads consumed.
    backbone_out: BackboneOutput
    # Online-calibration output (corrected ego / intrinsics / extrinsics).
    calibration: CalibrationOutput
|
|
|
|
class E2EAVModel(nn.Module):
    """End-to-end autonomous-driving model.

    Pipeline: DINOv3 patch features -> online calibration (symlog residuals
    -> corrected ego/intr/extr) -> 2x2x2 spatio-temporal token compression ->
    sequence [vision | ego | det | ctrl | extra] -> backbone (3D RoPE on the
    vision slice only) -> ``DetectionTrajHead`` and ``ControlHead``.
    """

    def __init__(
        self,
        dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
        backbone_dim: int = 768,
        num_heads: int = 12,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed_experts: int = 7,
        num_shared_experts: int = 1,
        topk_experts: int = 3,
        ffn_mult: int = 4,
        # Token budget per modality.
        num_history_frames: int = 8,
        num_detection_tokens: int = 1024,
        num_control_tokens: int = 24,
        num_ego_tokens: int = 8,
        num_extra_tokens: int = 256,
        # Input image geometry.
        image_h: int = 384,
        image_w: int = 1024,
        patch_size: int = 16,
        # Head configuration.
        num_classes: int = 22,
        traj_horizon: int = 24,
        det_head_hidden: int = 384,
        ctrl_head_hidden: int = 384,
        # Online-calibration module configuration.
        calib_dim: int = 256,
        calib_num_query: int = 256,
        calib_num_blocks: int = 2,
        calib_num_self_per_block: int = 2,
        calib_num_heads: int = 8,
        calib_residual_range: float = 0.1,
        calib_intr_dim: int = 11,
        # Encoder options.
        freeze_dinov3: bool = True,
        attn_implementation: str = "sdpa",
    ) -> None:
        super().__init__()
        self.image_h = image_h
        self.image_w = image_w
        self.patch_size = patch_size
        self.num_history = num_history_frames
        self.num_det = num_detection_tokens
        self.num_ctrl = num_control_tokens
        self.num_ego = num_ego_tokens
        self.num_extra = num_extra_tokens

        # DINOv3 patch-feature extractor (frozen by default).
        self.dinov3 = DINOv3Wrapper(
            pretrained_path=dinov3_path,
            attn_implementation=attn_implementation,
            freeze=freeze_dinov3,
        )
        dino_dim = self.dinov3.hidden_size

        # Online calibration: predicts symlog-space residuals for
        # ego / intrinsics / extrinsics using DINOv3 patch features as KV.
        self.calib = OnlineCalibration(
            dino_dim=dino_dim,
            hidden_dim=calib_dim,
            num_query_tokens=calib_num_query,
            num_blocks=calib_num_blocks,
            num_self_attn_per_block=calib_num_self_per_block,
            num_heads=calib_num_heads,
            residual_range=calib_residual_range,
            num_history_frames=num_history_frames,
            intr_dim=calib_intr_dim,
        )

        # 2x2x2 (time x height x width) token compression.
        self.compress = TemporalCompress2x2x2(dim=dino_dim)
        # Per-frame patch-grid size.
        self.gh = image_h // patch_size
        self.gw = image_w // patch_size

        # Ego-state projection and learnable query tokens for the heads.
        self.ego_proj = nn.Linear(6, backbone_dim)
        self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
        nn.init.trunc_normal_(self.det_tokens, std=0.02)
        self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
        nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
        self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
        nn.init.trunc_normal_(self.extra_tokens, std=0.02)

        # Learned positional embeddings, one per non-vision slice.
        self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
        self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
        self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
        self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)

        # 3D RoPE over the compressed (T/2, H/2, W/2) vision-token grid.
        self.rope = RoPE3D(
            num_heads=num_heads,
            head_dim=backbone_dim // num_heads,
            time_size=num_history_frames // 2,
            height_size=self.gh // 2,
            width_size=self.gw // 2,
        )

        # Main transformer backbone (dense + MoE layers).
        self.backbone = Backbone(
            dim=backbone_dim,
            num_heads=num_heads,
            ffn_mult=ffn_mult,
            num_dense_layers=num_dense_layers,
            num_moe_layers=num_moe_layers,
            num_routed=num_routed_experts,
            num_shared=num_shared_experts,
            topk=topk_experts,
        )

        # Output heads.
        self.det_traj_head = DetectionTrajHead(
            in_dim=backbone_dim,
            hidden_size=det_head_hidden,
            num_classes=num_classes,
            traj_horizon=traj_horizon,
        )
        self.ctrl_head = ControlHead(
            in_dim=backbone_dim,
            hidden_size=ctrl_head_hidden,
            num_traj_tokens=12,
            num_action_tokens=num_control_tokens - 12,
            ego_traj_horizon=traj_horizon,
        )

    @property
    def num_visual_tokens(self) -> int:
        """Number of vision tokens after the 2x2x2 compression."""
        return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)

    def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
        """``[B, T, 6]`` corrected ego states -> symlog -> Linear -> ``[B, T, D]``."""
        return self.ego_proj(symlog(ego_6d_corrected))

    def _build_visual_rays(
        self,
        intr_corrected: torch.Tensor,
        extr_corrected_se3: torch.Tensor,
        compressed_thw: tuple[int, int, int],
    ) -> torch.Tensor:
        """Compute ray directions for the compressed vision-token grid.

        After 2x2x2 compression each vision token covers a 2x2 patch region
        over 2 time frames.  The ray through the centre of that region is used
        as an approximation, and the same (h, w) ray is shared across all
        compressed time frames: the camera pose is rigid in the camera frame
        over the 8 frames, and per-frame ego-motion differences are conveyed
        through the ego tokens instead.

        Returns:
            ``[B, T'*H'*W', 3]`` unit rays in the ego (vehicle) frame.
        """
        b = intr_corrected.shape[0]
        t_, h_, w_ = compressed_thw
        rays_grid = compute_ego_rays(
            intr_vec=intr_corrected,
            cam2vehicle=extr_corrected_se3,
            height=self.image_h,
            width=self.image_w,
            grid_h=h_,
            grid_w=w_,
            device=intr_corrected.device,
            dtype=intr_corrected.dtype,
        )
        # Broadcast the single-frame ray grid across the compressed time axis,
        # then flatten to one ray per vision token.
        rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
        rays = rays.reshape(b, t_ * h_ * w_, 3)
        return rays

    def forward(
        self,
        images: torch.Tensor,
        ego_6d_raw: torch.Tensor,
        intr_raw: torch.Tensor,
        extr_6d_raw: torch.Tensor,
    ) -> E2EOutput:
        """Run the full model.

        Args:
            images: ``[B, T, C, H, W]`` history frames; ``T`` must equal
                ``num_history_frames`` and ``(H, W)`` the configured size.
            ego_6d_raw: raw ego states fed to ``OnlineCalibration``.
            intr_raw: raw camera intrinsics vector.
            extr_6d_raw: raw camera extrinsics (6-d rotation representation).

        Returns:
            ``E2EOutput`` bundling detection, control, backbone and
            calibration outputs.

        Raises:
            ValueError: if the temporal or spatial input size mismatches the
                model configuration.
        """
        b, t, _, h, w = images.shape
        # Explicit raises instead of ``assert``, which is stripped under -O.
        if t != self.num_history:
            raise ValueError(f"history frames mismatch: {t} vs {self.num_history}")
        # A spatial mismatch would silently corrupt the ray geometry
        # (``compute_ego_rays`` uses the configured image size), so reject it.
        if h != self.image_h or w != self.image_w:
            raise ValueError(
                f"image size mismatch: ({h}, {w}) vs ({self.image_h}, {self.image_w})"
            )

        # 1. Per-frame DINOv3 patch features.
        dino_feats = self.dinov3(images)

        # 2. Online calibration -> corrected ego / intrinsics / extrinsics.
        calib_out: CalibrationOutput = self.calib(
            dino_feats=dino_feats,
            ego_raw=ego_6d_raw,
            intr_raw=intr_raw,
            extr_raw=extr_6d_raw,
        )
        corrected_ego = calib_out.corrected_ego
        corrected_intr = calib_out.corrected_intr
        corrected_extr_6d = calib_out.corrected_extr

        # NOTE(review): deferred import — presumably kept function-local to
        # avoid a circular import with ``.data.se3``; confirm before hoisting.
        from .data.se3 import six_d_to_matrix

        cam2veh_corrected = six_d_to_matrix(corrected_extr_6d)

        # 3./4. 2x2x2 spatio-temporal compression of the patch features.
        compressed, thw = self.compress(dino_feats)
        n_v = compressed.shape[1]

        # Ego-frame rays for the vision tokens drive the 3D RoPE frequencies.
        rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
        rope_cos, rope_sin = self.rope.compute_freqs(rays)

        # Non-vision tokens: projected ego states plus learnable queries.
        ego_tok = self._build_ego_tokens(corrected_ego)
        det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
        ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
        extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)

        # Each non-vision slice gets its own learned positional embedding.
        ego_tok = self.ego_pe(ego_tok)
        det_tok = self.det_pe(det_tok)
        ctrl_tok = self.ctrl_pe(ctrl_tok)
        extra_tok = self.extra_pe(extra_tok)

        # 5. Sequence layout: [vision | ego | det | ctrl | extra].
        seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
        visual_slice = (0, n_v)

        # 6. Backbone; RoPE is applied to the vision slice only.
        bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)

        # 7. Slice out the detection and control tokens for the heads.
        offset_det = n_v + self.num_ego
        offset_ctrl = offset_det + self.num_det

        det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
        ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]

        det_out = self.det_traj_head(det_feats)
        ctrl_out = self.ctrl_head(ctrl_feats)

        return E2EOutput(
            detection=det_out,
            control=ctrl_out,
            backbone_out=bb_out,
            calibration=calib_out,
        )
|
|