File size: 10,708 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""端到端自动驾驶模型 E2EAVModel。

forward 流程
  1. ``DINOv3`` 提取 8 帧 patch 特征。
  2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV,
     输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。
  3. 用 corrected_intr / corrected_extr / corrected_ego 计算
     - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。
     - 8 个 ego token(symlog 后线性投影)。
  4. 2×2×2 时空压缩 -> 1536 视觉 token。
  5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。
     非视觉切片各自加可学习 PE。
  6. 18 层主干(仅视觉切片应用 3D RoPE)。
  7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import Backbone, BackboneOutput
from .calibration import OnlineCalibration, CalibrationOutput
from .encoders import DINOv3Wrapper
from .heads import (
    ControlHead,
    ControlOutput,
    DetectionTrajHead,
    DetectionTrajOutput,
)
from .modules.learned_pe import LearnedTokenPE
from .modules.normalization import symlog
from .modules.pos_encoding import RoPE3D
from .modules.rays import compute_ego_rays
from .modules.temporal_compress import TemporalCompress2x2x2


@dataclass
class E2EOutput:
    """Complete output of one ``E2EAVModel.forward`` call."""

    # Result of the detection / trajectory head (``DetectionTrajHead``).
    detection: DetectionTrajOutput
    # Result of the control head (``ControlHead``).
    control: ControlOutput
    # Backbone output; its ``hidden_states`` are what the heads slice from.
    backbone_out: BackboneOutput
    # Online-calibration output (carries corrected ego / intrinsics / extrinsics).
    calibration: CalibrationOutput


class E2EAVModel(nn.Module):
    """End-to-end autonomous-driving model.

    Pipeline executed by :meth:`forward`:

    1. ``DINOv3Wrapper`` extracts patch features for every history frame.
    2. ``OnlineCalibration`` produces corrected ego pose, intrinsics and
       extrinsics.
    3. ``TemporalCompress2x2x2`` compresses the patch features into the
       visual tokens.
    4. Ego-frame rays built from the corrected calibration drive the 3D RoPE
       that the backbone applies to the visual slice only.
    5. The sequence ``[vision | ego | det | ctrl | extra]`` goes through the
       backbone; the det / ctrl slices are fed to ``DetectionTrajHead`` and
       ``ControlHead``.
    """

    # Number of control tokens handed to ``ControlHead`` as trajectory
    # tokens; the remaining control tokens are its action tokens.
    _NUM_CTRL_TRAJ_TOKENS = 12

    def __init__(
        self,
        dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
        backbone_dim: int = 768,
        num_heads: int = 12,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed_experts: int = 7,
        num_shared_experts: int = 1,
        topk_experts: int = 3,
        ffn_mult: int = 4,
        # token counts
        num_history_frames: int = 8,
        num_detection_tokens: int = 1024,
        num_control_tokens: int = 24,
        num_ego_tokens: int = 8,
        num_extra_tokens: int = 256,
        # input resolution
        image_h: int = 384,
        image_w: int = 1024,
        patch_size: int = 16,
        # head hyper-parameters
        num_classes: int = 22,
        traj_horizon: int = 24,
        det_head_hidden: int = 384,
        ctrl_head_hidden: int = 384,
        # calibration
        calib_dim: int = 256,
        calib_num_query: int = 256,
        calib_num_blocks: int = 2,
        calib_num_self_per_block: int = 2,
        calib_num_heads: int = 8,
        calib_residual_range: float = 0.1,
        calib_intr_dim: int = 11,
        # DINOv3
        freeze_dinov3: bool = True,
        attn_implementation: str = "sdpa",
    ) -> None:
        """Build all sub-modules.

        Raises:
            ValueError: if ``patch_size`` is not positive, if the frame count
                or either patch-grid side is not a positive even number (a
                requirement of the 2x2x2 compression), or if
                ``num_control_tokens`` is smaller than the number of
                trajectory tokens reserved for ``ControlHead``.
        """
        super().__init__()

        # Validate up front, before the (potentially expensive) DINOv3 load.
        # Floor division below would otherwise silently truncate odd sizes
        # and leave RoPE / token bookkeeping sized from the floored values.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        grid_h = image_h // patch_size
        grid_w = image_w // patch_size
        for label, value in (
            ("num_history_frames", num_history_frames),
            ("image_h // patch_size", grid_h),
            ("image_w // patch_size", grid_w),
        ):
            if value <= 0 or value % 2:
                raise ValueError(
                    f"{label} must be a positive even number for 2x2x2 "
                    f"compression, got {value}"
                )
        if num_control_tokens < self._NUM_CTRL_TRAJ_TOKENS:
            raise ValueError(
                f"num_control_tokens must be >= {self._NUM_CTRL_TRAJ_TOKENS}, "
                f"got {num_control_tokens}"
            )

        self.image_h = image_h
        self.image_w = image_w
        self.patch_size = patch_size
        self.num_history = num_history_frames
        self.num_det = num_detection_tokens
        self.num_ctrl = num_control_tokens
        self.num_ego = num_ego_tokens
        self.num_extra = num_extra_tokens

        # === 1) DINOv3 ===
        self.dinov3 = DINOv3Wrapper(
            pretrained_path=dinov3_path,
            attn_implementation=attn_implementation,
            freeze=freeze_dinov3,
        )
        dino_dim = self.dinov3.hidden_size

        # === 2) Online calibration ===
        self.calib = OnlineCalibration(
            dino_dim=dino_dim,
            hidden_dim=calib_dim,
            num_query_tokens=calib_num_query,
            num_blocks=calib_num_blocks,
            num_self_attn_per_block=calib_num_self_per_block,
            num_heads=calib_num_heads,
            residual_range=calib_residual_range,
            num_history_frames=num_history_frames,
            intr_dim=calib_intr_dim,
        )

        # === 3) Spatio-temporal compression ===
        self.compress = TemporalCompress2x2x2(dim=dino_dim)
        # Patch-grid size (validated above to be even).
        self.gh = grid_h
        self.gw = grid_w

        # === 4) Token banks + learnable positional encodings ===
        self.ego_proj = nn.Linear(6, backbone_dim)  # 6D pose -> backbone dim
        self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
        nn.init.trunc_normal_(self.det_tokens, std=0.02)
        self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
        nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
        self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
        nn.init.trunc_normal_(self.extra_tokens, std=0.02)

        self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
        self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
        self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
        self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)

        # === 5) 3D RoPE (visual tokens only; sized for the compressed grid,
        #        i.e. 4 x 12 x 32 with the default arguments) ===
        self.rope = RoPE3D(
            num_heads=num_heads,
            head_dim=backbone_dim // num_heads,
            time_size=num_history_frames // 2,
            height_size=self.gh // 2,
            width_size=self.gw // 2,
        )

        # === 6) Backbone (dense layers + MoE layers) ===
        self.backbone = Backbone(
            dim=backbone_dim,
            num_heads=num_heads,
            ffn_mult=ffn_mult,
            num_dense_layers=num_dense_layers,
            num_moe_layers=num_moe_layers,
            num_routed=num_routed_experts,
            num_shared=num_shared_experts,
            topk=topk_experts,
        )

        # === 7) Heads ===
        self.det_traj_head = DetectionTrajHead(
            in_dim=backbone_dim,
            hidden_size=det_head_hidden,
            num_classes=num_classes,
            traj_horizon=traj_horizon,
        )
        self.ctrl_head = ControlHead(
            in_dim=backbone_dim,
            hidden_size=ctrl_head_hidden,
            num_traj_tokens=self._NUM_CTRL_TRAJ_TOKENS,
            num_action_tokens=num_control_tokens - self._NUM_CTRL_TRAJ_TOKENS,
            ego_traj_horizon=traj_horizon,
        )

    # ---------- helpers ----------

    @property
    def num_visual_tokens(self) -> int:
        """Number of visual tokens expected after the 2x2x2 compression."""
        return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)

    def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
        """Map corrected ego poses ``[B, 8, 6]`` to ego tokens ``[B, 8, D]``.

        The poses are passed through ``symlog`` and then the linear
        projection ``ego_proj``.
        """
        return self.ego_proj(symlog(ego_6d_corrected))

    def _build_visual_rays(
        self,
        intr_corrected: torch.Tensor,    # [B, calib_intr_dim]
        extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle
        compressed_thw: tuple[int, int, int],
    ) -> torch.Tensor:
        """Compute ray directions for the compressed visual-token grid.

        After the 2x2x2 compression every visual token stands for a 2x2
        region of the original patch grid and 2 time frames. One ray per
        compressed ``(h, w)`` cell is requested from ``compute_ego_rays``
        and that same ray is reused for every compressed time step (the
        original author's rationale: the camera is rigid in the camera frame
        across the history frames, and ego-motion differences are conveyed
        through the ego tokens).

        Args:
            intr_corrected: corrected intrinsics vector.
            extr_corrected_se3: corrected camera-to-vehicle transform.
            compressed_thw: ``(T, H, W)`` of the compressed token grid.

        Returns:
            Rays flattened to ``[B, T * H * W, 3]``.
        """
        b = intr_corrected.shape[0]
        t_, h_, w_ = compressed_thw
        rays_grid = compute_ego_rays(
            intr_vec=intr_corrected,
            cam2vehicle=extr_corrected_se3,
            height=self.image_h,
            width=self.image_w,
            grid_h=h_,
            grid_w=w_,
            device=intr_corrected.device,
            dtype=intr_corrected.dtype,
        )  # [B, h_, w_, 3]
        # Broadcast over time: [B, T_, h_, w_, 3], then flatten to [B, N_v, 3].
        rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
        rays = rays.reshape(b, t_ * h_ * w_, 3)
        return rays

    # ---------- forward ----------

    def forward(
        self,
        images: torch.Tensor,        # [B, T=8, 3, H, W]
        ego_6d_raw: torch.Tensor,     # [B, 8, 6]
        intr_raw: torch.Tensor,        # [B, calib_intr_dim], must match the constructor value
        extr_6d_raw: torch.Tensor,     # [B, 6]
    ) -> E2EOutput:
        """Run the full model on one batch.

        Args:
            images: history frames, 5-D with the frame count on axis 1.
            ego_6d_raw: raw ego poses for the history frames.
            intr_raw: raw camera intrinsics vector.
            extr_6d_raw: raw camera extrinsics in 6D form.

        Returns:
            ``E2EOutput`` bundling the detection, control, backbone and
            calibration outputs.

        Raises:
            AssertionError: if the number of frames in ``images`` differs
                from ``num_history_frames`` given at construction.
        """
        # Unpacking five names also enforces that ``images`` is 5-D.
        b, t, _, _, _ = images.shape
        assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}"

        # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
        dino_feats = self.dinov3(images)

        # 2) Calibration (residual in symlog space, restored via symexp)
        calib_out: CalibrationOutput = self.calib(
            dino_feats=dino_feats,
            ego_raw=ego_6d_raw,
            intr_raw=intr_raw,
            extr_raw=extr_6d_raw,
        )
        corrected_ego = calib_out.corrected_ego
        corrected_intr = calib_out.corrected_intr
        corrected_extr_6d = calib_out.corrected_extr

        # 3) Convert the corrected 6D extrinsics into a 4x4 matrix.
        # NOTE(review): imported lazily here — presumably to avoid an import
        # cycle with the data package; confirm before hoisting to module level.
        from .data.se3 import six_d_to_matrix
        cam2veh_corrected = six_d_to_matrix(corrected_extr_6d)  # [B, 4, 4]

        # 4) 2x2x2 spatio-temporal compression
        compressed, thw = self.compress(dino_feats)  # [B, N_v, D]
        n_v = compressed.shape[1]

        # 5) Visual rays (from corrected intrinsics / extrinsics) -> RoPE tables
        rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
        rope_cos, rope_sin = self.rope.compute_freqs(rays)

        # 6) Build the non-visual tokens; learned banks are broadcast over batch.
        ego_tok = self._build_ego_tokens(corrected_ego)         # [B, 8, D]
        det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
        ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
        extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)

        ego_tok = self.ego_pe(ego_tok)
        det_tok = self.det_pe(det_tok)
        ctrl_tok = self.ctrl_pe(ctrl_tok)
        extra_tok = self.extra_pe(extra_tok)

        # 7) Concatenate the sequence: [vision | ego | det | ctrl | extra]
        seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
        visual_slice = (0, n_v)

        # 8) Backbone (RoPE tables apply to ``visual_slice``)
        bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)

        # 9) Slice the det / ctrl tokens back out, following the concat order.
        offset_det = n_v + self.num_ego
        offset_ctrl = offset_det + self.num_det

        det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
        ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]

        det_out = self.det_traj_head(det_feats)
        ctrl_out = self.ctrl_head(ctrl_feats)

        return E2EOutput(
            detection=det_out,
            control=ctrl_out,
            backbone_out=bb_out,
            calibration=calib_out,
        )