"""End-to-end autonomous-driving model ``E2EAVModel``.

Forward pipeline:

1. ``DINOv3`` extracts patch features for the 8 frames.
2. ``OnlineCalibration`` takes the raw ego/intr/extr inputs (symlog) with the
   DINOv3 patches as KV, outputs residuals in symlog space, adds them, and
   applies symexp to obtain ``corrected_*``.
3. ``corrected_intr`` / ``corrected_extr`` / ``corrected_ego`` are used to compute
   - a unit ray in the ego-vehicle frame for every token (used only by the
     first group of RoPE heads on the visual tokens);
   - 8 ego tokens (symlog followed by a linear projection).
4. 2x2x2 spatio-temporal compression -> 1536 visual tokens.
5. Concatenate [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)]
   = 2848 tokens. Each non-visual slice gets its own learnable PE.
6. 18-layer backbone (3D RoPE is applied to the visual slice only).
7. Slices are fed to ``DetectionTrajHead`` and ``ControlHead``.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import Backbone, BackboneOutput
from .calibration import OnlineCalibration, CalibrationOutput
from .encoders import DINOv3Wrapper
from .heads import (
ControlHead,
ControlOutput,
DetectionTrajHead,
DetectionTrajOutput,
)
from .modules.learned_pe import LearnedTokenPE
from .modules.normalization import symlog
from .modules.pos_encoding import RoPE3D
from .modules.rays import compute_ego_rays
from .modules.temporal_compress import TemporalCompress2x2x2
@dataclass
class E2EOutput:
    """Bundle of everything ``E2EAVModel.forward`` returns."""

    # Result produced by the detection / trajectory head.
    detection: DetectionTrajOutput
    # Result produced by the control head.
    control: ControlOutput
    # Object returned by the backbone call.
    backbone_out: BackboneOutput
    # Object returned by the online-calibration module.
    calibration: CalibrationOutput
class E2EAVModel(nn.Module):
    """End-to-end autonomous-driving model.

    ``forward`` wires the submodules together as follows:

    1. ``DINOv3Wrapper`` turns the history frames into patch features.
    2. ``OnlineCalibration`` receives the patch features plus the raw ego /
       intrinsic / extrinsic inputs and returns corrected versions of them.
    3. ``TemporalCompress2x2x2`` compresses the patch features into the
       visual token sequence.
    4. Per-token rays computed from the corrected calibration feed
       ``RoPE3D.compute_freqs``; the resulting cos/sin tables are handed to
       the backbone together with the visual slice bounds.
    5. The sequence ``[vision | ego | det | ctrl | extra]`` runs through
       ``Backbone``; every non-visual slice first gets a learned PE.
    6. The detection and control slices of the backbone output are passed to
       ``DetectionTrajHead`` and ``ControlHead``.
    """

    def __init__(
        self,
        dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
        backbone_dim: int = 768,
        num_heads: int = 12,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed_experts: int = 7,
        num_shared_experts: int = 1,
        topk_experts: int = 3,
        ffn_mult: int = 4,
        # token counts
        num_history_frames: int = 8,
        num_detection_tokens: int = 1024,
        num_control_tokens: int = 24,
        num_ego_tokens: int = 8,
        num_extra_tokens: int = 256,
        # input resolution
        image_h: int = 384,
        image_w: int = 1024,
        patch_size: int = 16,
        # head hyper-parameters
        num_classes: int = 22,
        traj_horizon: int = 24,
        det_head_hidden: int = 384,
        ctrl_head_hidden: int = 384,
        # calibration
        calib_dim: int = 256,
        calib_num_query: int = 256,
        calib_num_blocks: int = 2,
        calib_num_self_per_block: int = 2,
        calib_num_heads: int = 8,
        calib_residual_range: float = 0.1,
        calib_intr_dim: int = 11,
        # DINOv3
        freeze_dinov3: bool = True,
        attn_implementation: str = "sdpa",
    ) -> None:
        """Build all submodules.

        Args:
            dinov3_path: Forwarded to ``DINOv3Wrapper`` as ``pretrained_path``.
            backbone_dim: Token width used by the backbone, the learned
                tokens / PEs, the ego projection and both heads.
            num_heads: Attention head count for the backbone and ``RoPE3D``.
            num_dense_layers, num_moe_layers, num_routed_experts,
            num_shared_experts, topk_experts, ffn_mult: Forwarded to
                ``Backbone``.
            num_history_frames: Expected number of input frames ``T``.
            num_detection_tokens, num_control_tokens, num_ego_tokens,
            num_extra_tokens: Sizes of the non-visual token slices.
            image_h, image_w, patch_size: Input resolution and patch size;
                the resulting patch grid must be even in both directions.
            num_classes, traj_horizon, det_head_hidden, ctrl_head_hidden:
                Forwarded to the heads.
            calib_*: Forwarded to ``OnlineCalibration``.
            freeze_dinov3, attn_implementation: Forwarded to
                ``DINOv3Wrapper``.

        Raises:
            ValueError: If the patch grid height or width is not divisible
                by 2.
        """
        super().__init__()

        self.image_h = image_h
        self.image_w = image_w
        self.patch_size = patch_size
        self.num_history = num_history_frames
        self.num_det = num_detection_tokens
        self.num_ctrl = num_control_tokens
        self.num_ego = num_ego_tokens
        self.num_extra = num_extra_tokens

        # Patch grid size. It is halved below (RoPE sizes, visual token count),
        # so an odd grid would silently disagree with the 2x2x2 compression.
        # Checked up front so a bad configuration fails before any heavy
        # submodule is constructed.
        self.gh = image_h // patch_size
        self.gw = image_w // patch_size
        if self.gh % 2 != 0 or self.gw % 2 != 0:
            raise ValueError(
                f"patch grid must be divisible by 2, got {self.gh}x{self.gw} "
                f"(image {image_h}x{image_w}, patch_size {patch_size})"
            )

        # === 1) DINOv3 ===
        self.dinov3 = DINOv3Wrapper(
            pretrained_path=dinov3_path,
            attn_implementation=attn_implementation,
            freeze=freeze_dinov3,
        )
        dino_dim = self.dinov3.hidden_size

        # === 2) Online calibration ===
        self.calib = OnlineCalibration(
            dino_dim=dino_dim,
            hidden_dim=calib_dim,
            num_query_tokens=calib_num_query,
            num_blocks=calib_num_blocks,
            num_self_attn_per_block=calib_num_self_per_block,
            num_heads=calib_num_heads,
            residual_range=calib_residual_range,
            num_history_frames=num_history_frames,
            intr_dim=calib_intr_dim,
        )

        # === 3) Spatio-temporal compression ===
        # NOTE(review): the compressed tokens are concatenated directly with
        # backbone_dim-wide tokens in forward(); presumably the compressor's
        # output width equals backbone_dim -- confirm.
        self.compress = TemporalCompress2x2x2(dim=dino_dim)

        # === 4) Learned tokens and their learned positional encodings ===
        self.ego_proj = nn.Linear(6, backbone_dim)  # 6D pose -> backbone dim
        self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
        nn.init.trunc_normal_(self.det_tokens, std=0.02)
        self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
        nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
        self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
        nn.init.trunc_normal_(self.extra_tokens, std=0.02)

        self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
        self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
        self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
        self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)

        # === 5) 3D RoPE (visual tokens only), sized for the compressed grid ===
        self.rope = RoPE3D(
            num_heads=num_heads,
            head_dim=backbone_dim // num_heads,
            time_size=num_history_frames // 2,
            height_size=self.gh // 2,
            width_size=self.gw // 2,
        )

        # === 6) Backbone ===
        self.backbone = Backbone(
            dim=backbone_dim,
            num_heads=num_heads,
            ffn_mult=ffn_mult,
            num_dense_layers=num_dense_layers,
            num_moe_layers=num_moe_layers,
            num_routed=num_routed_experts,
            num_shared=num_shared_experts,
            topk=topk_experts,
        )

        # === 7) Heads ===
        self.det_traj_head = DetectionTrajHead(
            in_dim=backbone_dim,
            hidden_size=det_head_hidden,
            num_classes=num_classes,
            traj_horizon=traj_horizon,
        )
        # The control slice is split into 12 trajectory tokens and the
        # remaining tokens as action tokens.
        self.ctrl_head = ControlHead(
            in_dim=backbone_dim,
            hidden_size=ctrl_head_hidden,
            num_traj_tokens=12,
            num_action_tokens=num_control_tokens - 12,
            ego_traj_horizon=traj_horizon,
        )

    # ---------- helpers ----------
    @property
    def num_visual_tokens(self) -> int:
        """Number of visual tokens after halving time, height and width."""
        return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)

    def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
        """Map corrected ego poses ``[B, 8, 6]`` to ego tokens ``[B, 8, D]``.

        The poses go through ``symlog`` and then the linear ``ego_proj``.
        """
        return self.ego_proj(symlog(ego_6d_corrected))

    def _build_visual_rays(
        self,
        intr_corrected: torch.Tensor,  # [B, calib_intr_dim]
        extr_corrected_se3: torch.Tensor,  # [B, 4, 4] cam2vehicle
        compressed_thw: tuple[int, int, int],
    ) -> torch.Tensor:
        """Compute one ray direction per compressed visual token.

        After the 2x2x2 compression each visual token stands for a 2x2 patch
        region and two time steps. Rays are computed once on the compressed
        ``(h, w)`` grid by ``compute_ego_rays`` and the same grid is reused
        for every compressed time step; per the original author's note, the
        camera is treated as rigid across the history frames and ego-motion
        differences are conveyed through the ego tokens instead.

        Args:
            intr_corrected: Corrected intrinsics vector, ``[B, calib_intr_dim]``.
            extr_corrected_se3: Corrected camera-to-vehicle transform,
                ``[B, 4, 4]``.
            compressed_thw: ``(T, H, W)`` of the compressed token grid.

        Returns:
            Rays flattened to ``[B, T * H * W, 3]``.
        """
        b = intr_corrected.shape[0]
        t_, h_, w_ = compressed_thw
        rays_grid = compute_ego_rays(
            intr_vec=intr_corrected,
            cam2vehicle=extr_corrected_se3,
            height=self.image_h,
            width=self.image_w,
            grid_h=h_,
            grid_w=w_,
            device=intr_corrected.device,
            dtype=intr_corrected.dtype,
        )  # [B, h_, w_, 3]
        # Broadcast over time: [B, T_, h_, w_, 3], then flatten to [B, N_v, 3].
        rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
        rays = rays.reshape(b, t_ * h_ * w_, 3)
        return rays

    # ---------- forward ----------
    def forward(
        self,
        images: torch.Tensor,  # [B, T=8, 3, H, W]
        ego_6d_raw: torch.Tensor,  # [B, 8, 6]
        intr_raw: torch.Tensor,  # [B, calib_intr_dim], must match the constructor
        extr_6d_raw: torch.Tensor,  # [B, 6]
    ) -> E2EOutput:
        """Run the full model.

        Args:
            images: History frames, ``[B, T, 3, H, W]`` with ``T`` equal to
                the configured number of history frames.
            ego_6d_raw: Raw ego poses, ``[B, 8, 6]``.
            intr_raw: Raw intrinsics vector, ``[B, calib_intr_dim]``.
            extr_6d_raw: Raw extrinsics in 6D form, ``[B, 6]``.

        Returns:
            ``E2EOutput`` holding the detection, control, backbone and
            calibration outputs.

        Raises:
            ValueError: If ``images`` is not 5-dimensional or its frame count
                differs from the configured number of history frames.
        """
        b, t, _, _, _ = images.shape
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if t != self.num_history:
            raise ValueError(f"history frames mismatch: {t} vs {self.num_history}")

        # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
        dino_feats = self.dinov3(images)

        # 2) Calibration: corrected ego / intrinsics / extrinsics
        calib_out: CalibrationOutput = self.calib(
            dino_feats=dino_feats,
            ego_raw=ego_6d_raw,
            intr_raw=intr_raw,
            extr_raw=extr_6d_raw,
        )
        corrected_ego = calib_out.corrected_ego
        corrected_intr = calib_out.corrected_intr
        corrected_extr_6d = calib_out.corrected_extr

        # 3) Convert the corrected 6D extrinsics to a 4x4 matrix.
        # Kept as a function-local import, as in the original.
        from .data.se3 import six_d_to_matrix

        cam2veh_corrected = six_d_to_matrix(corrected_extr_6d)  # [B, 4, 4]

        # 4) 2x2x2 spatio-temporal compression
        compressed, thw = self.compress(dino_feats)  # [B, N_v, D]
        n_v = compressed.shape[1]

        # 5) Visual rays from the corrected calibration -> RoPE tables
        rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
        rope_cos, rope_sin = self.rope.compute_freqs(rays)

        # 6) Non-visual tokens, each slice with its own learned PE
        ego_tok = self._build_ego_tokens(corrected_ego)  # [B, 8, D]
        det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
        ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
        extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)

        ego_tok = self.ego_pe(ego_tok)
        det_tok = self.det_pe(det_tok)
        ctrl_tok = self.ctrl_pe(ctrl_tok)
        extra_tok = self.extra_pe(extra_tok)

        # 7) Sequence layout: [vision | ego | det | ctrl | extra]
        seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
        visual_slice = (0, n_v)

        # 8) Backbone
        bb_out = self.backbone(
            seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice
        )

        # 9) Slice out the detection and control tokens for the heads
        offset_det = n_v + self.num_ego
        offset_ctrl = offset_det + self.num_det
        det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
        ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]

        det_out = self.det_traj_head(det_feats)
        ctrl_out = self.ctrl_head(ctrl_feats)

        return E2EOutput(
            detection=det_out,
            control=ctrl_out,
            backbone_out=bb_out,
            calibration=calib_out,
        )
|