fuzirui commited on
Commit
196e6e0
·
verified ·
1 Parent(s): 09c5dcd

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. MEMORY.md +63 -63
  2. README.md +63 -63
  3. configs/default.yaml +192 -192
  4. pyproject.toml +40 -40
  5. scripts/download_data.py +76 -76
  6. scripts/estimate_memory.py +203 -203
  7. scripts/ingest_hub_to_bucket.py +234 -207
  8. scripts/push_cpu_ingest_job.py +148 -141
  9. scripts/push_to_jobs.py +196 -196
  10. scripts/push_to_sandbox.py +185 -185
  11. scripts/sandbox_real_data.py +236 -236
  12. scripts/smoke_test.py +78 -78
  13. scripts/smoke_train.py +152 -152
  14. scripts/update_deps.py +123 -123
  15. src/wjad/__init__.py +5 -5
  16. src/wjad/backbone/__init__.py +6 -6
  17. src/wjad/backbone/backbone.py +110 -110
  18. src/wjad/backbone/blocks.py +79 -79
  19. src/wjad/calibration/__init__.py +5 -5
  20. src/wjad/calibration/online_calib.py +196 -196
  21. src/wjad/data/__init__.py +39 -39
  22. src/wjad/data/cosmos_dataset.py +439 -439
  23. src/wjad/data/ftheta_proj.py +62 -62
  24. src/wjad/data/hdmap.py +247 -247
  25. src/wjad/data/label_paths.py +218 -218
  26. src/wjad/data/se3.py +111 -111
  27. src/wjad/data/targets.py +214 -214
  28. src/wjad/data/transforms.py +86 -86
  29. src/wjad/encoders/__init__.py +5 -5
  30. src/wjad/encoders/dinov3_wrapper.py +104 -104
  31. src/wjad/heads/__init__.py +11 -11
  32. src/wjad/heads/control.py +100 -100
  33. src/wjad/heads/detection_traj.py +106 -106
  34. src/wjad/losses/__init__.py +24 -24
  35. src/wjad/losses/calib_reg.py +21 -21
  36. src/wjad/losses/control.py +25 -25
  37. src/wjad/losses/detection.py +213 -213
  38. src/wjad/losses/moe_aux.py +33 -33
  39. src/wjad/losses/nll.py +47 -47
  40. src/wjad/losses/trajectory.py +43 -43
  41. src/wjad/model.py +289 -289
  42. src/wjad/modules/__init__.py +28 -28
  43. src/wjad/modules/ffn.py +30 -30
  44. src/wjad/modules/gate_attention.py +181 -181
  45. src/wjad/modules/learned_pe.py +24 -24
  46. src/wjad/modules/moe.py +129 -129
  47. src/wjad/modules/normalization.py +22 -22
  48. src/wjad/modules/pos_encoding.py +224 -224
  49. src/wjad/modules/rays.py +182 -182
  50. src/wjad/modules/temporal_compress.py +34 -34
MEMORY.md CHANGED
@@ -1,63 +1,63 @@
1
- # 显存与内存估算(bf16 AMP + GradNorm + PCGrad)
2
-
3
- 由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。
4
-
5
- | 项目 | 数值 |
6
- |---|---|
7
- | 总参数 | **725.62 M** |
8
- | 可训练 (Stage1, DINOv3 冻结) | 639.96 M |
9
- | 可训练 (Stage2, DINOv3 解冻) | 725.62 M |
10
- | 序列长度(拼接后) | 2848 |
11
-
12
- ## 显存(含 15% 余量)
13
-
14
- | Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 |
15
- |---:|---:|---|---|
16
- | 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` |
17
- | 2 | ~18 GB | L4 24GB | `l4x1` |
18
- | 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` |
19
- | **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** |
20
- | 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` |
21
-
22
- 显存细分(BS=8 Stage2):
23
- - 权重 (bf16): 1.35 GB
24
- - 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB
25
- - 主激活 (bf16, 18 + 6 层): ~12.6 GB
26
- - PCGrad retain_graph 开销: ~6.3 GB
27
- - 缓冲 / cuDNN workspace / 碎片: ~2 GB
28
-
29
- 如显存不足:
30
- - 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB)
31
- - BS=4 + `grad_accum_steps=2` 等价 BS=8 训练
32
- - 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量
33
-
34
- ## 主机内存 / 磁盘
35
-
36
- | 项目 | 数值(BS=8) |
37
- |---|---|
38
- | 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) |
39
- | 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB |
40
- | 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB |
41
- | 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB |
42
-
43
- ## 设备选择建议
44
-
45
- - **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。
46
- - **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。
47
- - **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。
48
-
49
- ## 复现命令
50
-
51
- ```bash
52
- # 升级依赖到最新(写入 requirements.lock.txt)
53
- python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124
54
-
55
- # 估算
56
- python scripts/estimate_memory.py
57
-
58
- # Sandbox 推送
59
- python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large
60
-
61
- # Jobs 全量
62
- python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1
63
- ```
 
1
+ # 显存与内存估算(bf16 AMP + GradNorm + PCGrad)
2
+
3
+ 由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。
4
+
5
+ | 项目 | 数值 |
6
+ |---|---|
7
+ | 总参数 | **725.62 M** |
8
+ | 可训练 (Stage1, DINOv3 冻结) | 639.96 M |
9
+ | 可训练 (Stage2, DINOv3 解冻) | 725.62 M |
10
+ | 序列长度(拼接后) | 2848 |
11
+
12
+ ## 显存(含 15% 余量)
13
+
14
+ | Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 |
15
+ |---:|---:|---|---|
16
+ | 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` |
17
+ | 2 | ~18 GB | L4 24GB | `l4x1` |
18
+ | 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` |
19
+ | **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** |
20
+ | 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` |
21
+
22
+ 显存细分(BS=8 Stage2):
23
+ - 权重 (bf16): 1.35 GB
24
+ - 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB
25
+ - 主激活 (bf16, 18 + 6 层): ~12.6 GB
26
+ - PCGrad retain_graph 开销: ~6.3 GB
27
+ - 缓冲 / cuDNN workspace / 碎片: ~2 GB
28
+
29
+ 如显存不足:
30
+ - 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB)
31
+ - BS=4 + `grad_accum_steps=2` 等价 BS=8 训练
32
+ - 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量
33
+
34
+ ## 主机内存 / 磁盘
35
+
36
+ | 项目 | 数值(BS=8) |
37
+ |---|---|
38
+ | 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) |
39
+ | 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB |
40
+ | 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB |
41
+ | 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB |
42
+
43
+ ## 设备选择建议
44
+
45
+ - **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。
46
+ - **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。
47
+ - **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。
48
+
49
+ ## 复现命令
50
+
51
+ ```bash
52
+ # 升级依赖到最新(写入 requirements.lock.txt)
53
+ python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124
54
+
55
+ # 估算
56
+ python scripts/estimate_memory.py
57
+
58
+ # Sandbox 推送
59
+ python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large
60
+
61
+ # Jobs 全量
62
+ python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1
63
+ ```
README.md CHANGED
@@ -1,63 +1,63 @@
1
- ---
2
- title: WJAD Sandbox
3
- emoji: 🚗
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: docker
7
- app_port: 7860
8
- pinned: false
9
- ---
10
-
11
- # WJAD - 端到端自动驾驶模型
12
-
13
- 基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。
14
-
15
- ## 架构概览
16
-
17
- - **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。
18
- - **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。
19
- - **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。
20
- - **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。
21
- - **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。
22
- - **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。
23
- - **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。
24
- - **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。
25
- - **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。
26
-
27
- ## 三步训练路径
28
-
29
- ```bash
30
- # 1. 本地跑通(纯随机张量)
31
- python -m scripts.smoke_test
32
-
33
- # 2. HF Sandbox 微小训练
34
- python -m scripts.push_to_sandbox
35
-
36
- # 3. HF Jobs 全量训练
37
- python -m scripts.push_to_jobs
38
- ```
39
-
40
- ## 数据准备
41
-
42
- ```bash
43
- python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap
44
- ```
45
-
46
- ## 项目结构
47
-
48
- ```
49
- src/wjad/
50
- ├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/...
51
- ├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩
52
- ├── calibration/ # 在线校准网络
53
- ├── backbone/ # 18 层主干
54
- ├── heads/ # 检测+预测头、控制头
55
- ├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广
56
- ├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则
57
- ├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度
58
- └── model.py # 顶层 E2EAVModel
59
- ```
60
-
61
- ## License
62
-
63
- 代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。
 
1
+ ---
2
+ title: WJAD Sandbox
3
+ emoji: 🚗
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # WJAD - 端到端自动驾驶模型
12
+
13
+ 基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。
14
+
15
+ ## 架构概览
16
+
17
+ - **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。
18
+ - **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。
19
+ - **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。
20
+ - **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。
21
+ - **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。
22
+ - **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。
23
+ - **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。
24
+ - **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。
25
+ - **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。
26
+
27
+ ## 三步训练路径
28
+
29
+ ```bash
30
+ # 1. 本地跑通(纯随机张量)
31
+ python -m scripts.smoke_test
32
+
33
+ # 2. HF Sandbox 微小训练
34
+ python -m scripts.push_to_sandbox
35
+
36
+ # 3. HF Jobs 全量训练
37
+ python -m scripts.push_to_jobs
38
+ ```
39
+
40
+ ## 数据准备
41
+
42
+ ```bash
43
+ python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap
44
+ ```
45
+
46
+ ## 项目结构
47
+
48
+ ```
49
+ src/wjad/
50
+ ├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/...
51
+ ├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩
52
+ ├── calibration/ # 在线校准网络
53
+ ├── backbone/ # 18 层主干
54
+ ├── heads/ # 检测+预测头、控制头
55
+ ├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广
56
+ ├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则
57
+ ├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度
58
+ └── model.py # 顶层 E2EAVModel
59
+ ```
60
+
61
+ ## License
62
+
63
+ 代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。
configs/default.yaml CHANGED
@@ -1,192 +1,192 @@
1
- # 端到端自动驾驶模型默认配置(与 Design.md 对齐)
2
-
3
- # === 全局 ===
4
- seed: 42
5
- device: cuda
6
- mixed_precision: bf16 # H100 推荐 bf16
7
- gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速
8
-
9
- # === 输入 ===
10
- input:
11
- image_height: 384 # 已裁去上半天空的高度
12
- image_width: 1024
13
- num_history_frames: 8 # t-7..t(含当前)
14
- num_future_frames: 24 # 预测窗口
15
- camera_name: camera_front_wide_120fov # 当前唯一开放的视角
16
-
17
- # === 视觉编码器(DINOv3)===
18
- dinov3:
19
- pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m"
20
- hidden_size: 768
21
- patch_size: 16
22
- num_register_tokens: 4
23
- attn_implementation: sdpa
24
- freeze_in_stage1: true # Stage1 冻结 DINOv3
25
- finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率
26
-
27
- # === 时空压缩 ===
28
- temporal_compress:
29
- kernel: [2, 2, 2]
30
- stride: [2, 2, 2]
31
-
32
- # === 主干 ===
33
- backbone:
34
- hidden_size: 768
35
- num_heads: 12 # head_dim = 64
36
- ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D)
37
- num_dense_layers: 9
38
- num_moe_layers: 9 # 共 18 层
39
- dropout: 0.0
40
- prenorm: true
41
-
42
- # === MoE ===
43
- moe:
44
- num_routed_experts: 7
45
- num_shared_experts: 1
46
- topk: 3 # Stage2 激活专家数
47
- router_temperature_init: 0.5
48
- router_temperature_final: 1.0
49
- load_balance_weight: 0.01
50
- boundary_weight: 0.001 # mean(logits^2) 防越界
51
-
52
- # === Token 数量 ===
53
- tokens:
54
- num_detection: 1024
55
- num_control: 24
56
- num_ego: 8
57
- num_extra: 256
58
-
59
- # === 在线校准 ===
60
- calibration:
61
- intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14
62
- hidden_size: 256
63
- num_query_tokens: 256
64
- num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self
65
- num_blocks: 2 # 总计 6 层
66
- num_heads: 8 # head_dim = 32
67
- residual_range: 0.1 # Tanh * range
68
- init_zero_output: true
69
-
70
- # === 检测+未来预测头 ===
71
- det_traj_head:
72
- num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap)
73
- box_dim: 7 # x,y,z,l,w,h,yaw
74
- traj_horizon: 24 # 未来 24 帧
75
- traj_dim: 3 # dx,dy,dyaw
76
- hidden_size: 384
77
- log_sigma_clamp: [-7.0, 7.0]
78
-
79
- # === 控制头 ===
80
- control_head:
81
- num_traj_tokens: 12 # 解码 24 帧 ego 轨迹
82
- num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ
83
- ego_traj_dim: 3 # x,y,yaw
84
- action_dim: 3
85
- hidden_size: 384
86
- log_sigma_clamp: [-7.0, 7.0]
87
-
88
- # === 检测目标筛选 ===
89
- detection:
90
- max_distance_m: 48.0
91
- occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米)
92
- min_box_pixels: 8
93
- dynamic_classes:
94
- - Automobile
95
- - Heavy_truck
96
- - Bus
97
- - Train_or_tram_car
98
- - Trolley_bus
99
- - Other_vehicle
100
- - Trailer
101
- - Person
102
- - Stroller
103
- - Rider
104
- - Animal
105
- - Protruding_object
106
-
107
- # === 数据 ===
108
- data:
109
- root: "./data/cosmos"
110
- hdmap_subdir: "rds_hq"
111
- synthetic_subdir: "cosmos_synthetic/single_view"
112
- use_synthetic: true
113
- use_real: false
114
- weather:
115
- - Sunny
116
- - Morning
117
- - Golden_hour
118
- - Night
119
- - Rainy
120
- - Snowy
121
- - Foggy
122
- num_workers: 4
123
- pin_memory: true
124
- prefetch_factor: 2
125
- augmentation:
126
- gaussian_noise_std: 0.01
127
- color_jitter: 0.1
128
- perturb_translation_std_m: 0.1 # Stage1 中期开启
129
- perturb_rotation_std_deg: 0.5
130
- perturb_intrinsic_std: 0.005
131
- perturb_extrinsic_std: 0.005
132
-
133
- # === 损失权重(GradNorm 自适应的 1-6 任务初值)===
134
- loss:
135
- cls_weight: 1.0
136
- box_weight: 1.0
137
- isdyn_weight: 1.0
138
- traj_obj_weight: 1.0
139
- traj_ego_weight: 1.0
140
- ctrl_weight: 1.0
141
- moe_weight: 1.0 # 固定权重正则
142
- calib_weight: 0.1 # 固定权重正则
143
- giou_weight: 0.5 # 在 box 内组合
144
- focal_alpha: 0.25
145
- focal_gamma: 2.0
146
- matcher_cls_cost: 2.0
147
- matcher_l1_cost: 5.0
148
- matcher_giou_cost: 2.0
149
-
150
- # === 多任务训练 ===
151
- # PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用:
152
- # 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl)
153
- # 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。
154
- multitask:
155
- enable_gradnorm: true
156
- enable_pcgrad: true
157
- gradnorm_alpha: 1.5
158
- gradnorm_lr: 0.025
159
- pcgrad_shuffle: true
160
-
161
- # === 训练 ===
162
- train:
163
- batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps
164
- grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps
165
- ckpt_dir: outputs/checkpoints
166
- total_steps: 100000
167
- warmup_steps: 1000
168
- base_lr: 2.0e-4
169
- min_lr: 1.0e-6
170
- weight_decay: 0.05
171
- optimizer: adamw
172
- betas: [0.9, 0.95]
173
- grad_clip: 1.0
174
- log_interval: 20
175
- ckpt_interval: 1000
176
- eval_interval: 5000
177
- stage1_steps: 60000 # Stage1 步数
178
- stage1_perturb_start: 20000 # 中期开始扰动
179
- grad_monitor_threshold: 1.0e-7
180
- param_groups:
181
- dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供
182
- backbone_lr_mult: 1.0
183
- calibration_lr_mult: 0.1
184
- head_lr_mult: 1.0
185
- gate_lr_mult: 0.1 # 门控参数低 LR
186
-
187
- # === 部署 ===
188
- deploy:
189
- hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库
190
- push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt
191
- hub_ckpt_prefix: checkpoints # Hub 上子目录
192
- hf_sandbox_space: "fuzirui/wjad-sandbox"
 
1
+ # 端到端自动驾驶模型默认配置(与 Design.md 对齐)
2
+
3
+ # === 全局 ===
4
+ seed: 42
5
+ device: cuda
6
+ mixed_precision: bf16 # H100 推荐 bf16
7
+ gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速
8
+
9
+ # === 输入 ===
10
+ input:
11
+ image_height: 384 # 已裁去上半天空的高度
12
+ image_width: 1024
13
+ num_history_frames: 8 # t-7..t(含当前)
14
+ num_future_frames: 24 # 预测窗口
15
+ camera_name: camera_front_wide_120fov # 当前唯一开放的视角
16
+
17
+ # === 视觉编码器(DINOv3)===
18
+ dinov3:
19
+ pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m"
20
+ hidden_size: 768
21
+ patch_size: 16
22
+ num_register_tokens: 4
23
+ attn_implementation: sdpa
24
+ freeze_in_stage1: true # Stage1 冻结 DINOv3
25
+ finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率
26
+
27
+ # === 时空压缩 ===
28
+ temporal_compress:
29
+ kernel: [2, 2, 2]
30
+ stride: [2, 2, 2]
31
+
32
+ # === 主干 ===
33
+ backbone:
34
+ hidden_size: 768
35
+ num_heads: 12 # head_dim = 64
36
+ ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D)
37
+ num_dense_layers: 9
38
+ num_moe_layers: 9 # 共 18 层
39
+ dropout: 0.0
40
+ prenorm: true
41
+
42
+ # === MoE ===
43
+ moe:
44
+ num_routed_experts: 7
45
+ num_shared_experts: 1
46
+ topk: 3 # Stage2 激活专家数
47
+ router_temperature_init: 0.5
48
+ router_temperature_final: 1.0
49
+ load_balance_weight: 0.01
50
+ boundary_weight: 0.001 # mean(logits^2) 防越界
51
+
52
+ # === Token 数量 ===
53
+ tokens:
54
+ num_detection: 1024
55
+ num_control: 24
56
+ num_ego: 8
57
+ num_extra: 256
58
+
59
+ # === 在线校准 ===
60
+ calibration:
61
+ intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14
62
+ hidden_size: 256
63
+ num_query_tokens: 256
64
+ num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self
65
+ num_blocks: 2 # 总计 6 层
66
+ num_heads: 8 # head_dim = 32
67
+ residual_range: 0.1 # Tanh * range
68
+ init_zero_output: true
69
+
70
+ # === 检测+未来预测头 ===
71
+ det_traj_head:
72
+ num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap)
73
+ box_dim: 7 # x,y,z,l,w,h,yaw
74
+ traj_horizon: 24 # 未来 24 帧
75
+ traj_dim: 3 # dx,dy,dyaw
76
+ hidden_size: 384
77
+ log_sigma_clamp: [-7.0, 7.0]
78
+
79
+ # === 控制头 ===
80
+ control_head:
81
+ num_traj_tokens: 12 # 解码 24 帧 ego 轨迹
82
+ num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ
83
+ ego_traj_dim: 3 # x,y,yaw
84
+ action_dim: 3
85
+ hidden_size: 384
86
+ log_sigma_clamp: [-7.0, 7.0]
87
+
88
+ # === 检测目标筛选 ===
89
+ detection:
90
+ max_distance_m: 48.0
91
+ occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米)
92
+ min_box_pixels: 8
93
+ dynamic_classes:
94
+ - Automobile
95
+ - Heavy_truck
96
+ - Bus
97
+ - Train_or_tram_car
98
+ - Trolley_bus
99
+ - Other_vehicle
100
+ - Trailer
101
+ - Person
102
+ - Stroller
103
+ - Rider
104
+ - Animal
105
+ - Protruding_object
106
+
107
+ # === 数据 ===
108
+ data:
109
+ root: "./data/cosmos"
110
+ hdmap_subdir: "rds_hq"
111
+ synthetic_subdir: "cosmos_synthetic/single_view"
112
+ use_synthetic: true
113
+ use_real: false
114
+ weather:
115
+ - Sunny
116
+ - Morning
117
+ - Golden_hour
118
+ - Night
119
+ - Rainy
120
+ - Snowy
121
+ - Foggy
122
+ num_workers: 4
123
+ pin_memory: true
124
+ prefetch_factor: 2
125
+ augmentation:
126
+ gaussian_noise_std: 0.01
127
+ color_jitter: 0.1
128
+ perturb_translation_std_m: 0.1 # Stage1 中期开启
129
+ perturb_rotation_std_deg: 0.5
130
+ perturb_intrinsic_std: 0.005
131
+ perturb_extrinsic_std: 0.005
132
+
133
+ # === 损失权重(GradNorm 自适应的 1-6 任务初值)===
134
+ loss:
135
+ cls_weight: 1.0
136
+ box_weight: 1.0
137
+ isdyn_weight: 1.0
138
+ traj_obj_weight: 1.0
139
+ traj_ego_weight: 1.0
140
+ ctrl_weight: 1.0
141
+ moe_weight: 1.0 # 固定权重正则
142
+ calib_weight: 0.1 # 固定权重正则
143
+ giou_weight: 0.5 # 在 box 内组合
144
+ focal_alpha: 0.25
145
+ focal_gamma: 2.0
146
+ matcher_cls_cost: 2.0
147
+ matcher_l1_cost: 5.0
148
+ matcher_giou_cost: 2.0
149
+
150
+ # === 多任务训练 ===
151
+ # PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用:
152
+ # 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl)
153
+ # 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。
154
+ multitask:
155
+ enable_gradnorm: true
156
+ enable_pcgrad: true
157
+ gradnorm_alpha: 1.5
158
+ gradnorm_lr: 0.025
159
+ pcgrad_shuffle: true
160
+
161
+ # === 训练 ===
162
+ train:
163
+ batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps
164
+ grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps
165
+ ckpt_dir: outputs/checkpoints
166
+ total_steps: 100000
167
+ warmup_steps: 1000
168
+ base_lr: 2.0e-4
169
+ min_lr: 1.0e-6
170
+ weight_decay: 0.05
171
+ optimizer: adamw
172
+ betas: [0.9, 0.95]
173
+ grad_clip: 1.0
174
+ log_interval: 20
175
+ ckpt_interval: 1000
176
+ eval_interval: 5000
177
+ stage1_steps: 60000 # Stage1 步数
178
+ stage1_perturb_start: 20000 # 中期开始扰动
179
+ grad_monitor_threshold: 1.0e-7
180
+ param_groups:
181
+ dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供
182
+ backbone_lr_mult: 1.0
183
+ calibration_lr_mult: 0.1
184
+ head_lr_mult: 1.0
185
+ gate_lr_mult: 0.1 # 门控参数低 LR
186
+
187
+ # === 部署 ===
188
+ deploy:
189
+ hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库
190
+ push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt
191
+ hub_ckpt_prefix: checkpoints # Hub 上子目录
192
+ hf_sandbox_space: "fuzirui/wjad-sandbox"
pyproject.toml CHANGED
@@ -1,40 +1,40 @@
1
- [build-system]
2
- requires = ["setuptools>=68", "wheel"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "wjad"
7
- version = "0.1.0"
8
- description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration."
9
- requires-python = ">=3.10"
10
- dependencies = [
11
- "torch>=2.4",
12
- "transformers>=4.56",
13
- "safetensors>=0.4",
14
- "numpy>=1.24",
15
- "opencv-python-headless>=4.8",
16
- "einops>=0.7",
17
- "scipy>=1.11",
18
- "pyyaml>=6.0",
19
- "tqdm>=4.66",
20
- "huggingface_hub>=0.24",
21
- "pillow>=10.0",
22
- "av>=12.0",
23
- ]
24
-
25
- [project.optional-dependencies]
26
- dev = [
27
- "pytest>=7",
28
- "pytest-cov>=4",
29
- "ruff>=0.5",
30
- ]
31
-
32
- [tool.setuptools]
33
- package-dir = {"" = "src"}
34
-
35
- [tool.setuptools.packages.find]
36
- where = ["src"]
37
-
38
- [tool.ruff]
39
- line-length = 120
40
- target-version = "py310"
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "wjad"
7
+ version = "0.1.0"
8
+ description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration."
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "torch>=2.4",
12
+ "transformers>=4.56",
13
+ "safetensors>=0.4",
14
+ "numpy>=1.24",
15
+ "opencv-python-headless>=4.8",
16
+ "einops>=0.7",
17
+ "scipy>=1.11",
18
+ "pyyaml>=6.0",
19
+ "tqdm>=4.66",
20
+ "huggingface_hub>=0.24",
21
+ "pillow>=10.0",
22
+ "av>=12.0",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ dev = [
27
+ "pytest>=7",
28
+ "pytest-cov>=4",
29
+ "ruff>=0.5",
30
+ ]
31
+
32
+ [tool.setuptools]
33
+ package-dir = {"" = "src"}
34
+
35
+ [tool.setuptools.packages.find]
36
+ where = ["src"]
37
+
38
+ [tool.ruff]
39
+ line-length = 120
40
+ target-version = "py310"
scripts/download_data.py CHANGED
@@ -1,76 +1,76 @@
1
- """下载 NVIDIA Cosmos-Drive-Dreams 数据集。
2
-
3
- 直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。
4
-
5
- 用法:
6
- # 完整下载(synthetic + lidar + hdmap),约 3TB
7
- python scripts/download_data.py --odir ./data/cosmos --workers 8
8
-
9
- # 仅烟囱:限制 clip 数量(约几 GB,取决于 N)
10
- python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2
11
- """
12
-
13
- from __future__ import annotations
14
-
15
- import argparse
16
- import subprocess
17
- import sys
18
- import urllib.request
19
- from pathlib import Path
20
-
21
-
22
- NV_DOWNLOAD_URL = (
23
- "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
24
- )
25
-
26
-
27
- def _ensure_official_script(local_path: Path) -> None:
28
- if local_path.exists():
29
- return
30
- print(f"[download_data] 下载 NVIDIA download.py -> {local_path}")
31
- local_path.parent.mkdir(parents=True, exist_ok=True)
32
- with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f:
33
- f.write(resp.read())
34
-
35
-
36
- def main() -> None:
37
- parser = argparse.ArgumentParser()
38
- parser.add_argument("--odir", required=True, help="数据输出目录")
39
- parser.add_argument(
40
- "--file_types",
41
- default="synthetic,lidar,hdmap",
42
- help="数据类型逗号分隔列表",
43
- )
44
- parser.add_argument("--workers", type=int, default=4)
45
- parser.add_argument(
46
- "--limit",
47
- type=int,
48
- default=None,
49
- metavar="N",
50
- help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)",
51
- )
52
- parser.add_argument("--clean_cache", action="store_true")
53
- args = parser.parse_args()
54
-
55
- odir = Path(args.odir)
56
- nv_script = odir / ".nvidia_download.py"
57
- _ensure_official_script(nv_script)
58
-
59
- cmd = [
60
- sys.executable,
61
- str(nv_script),
62
- "--odir", str(odir),
63
- "--file_types", args.file_types,
64
- "--workers", str(args.workers),
65
- ]
66
- if args.limit is not None:
67
- cmd.extend(["--limit", str(args.limit)])
68
- if args.clean_cache:
69
- cmd.append("--clean_cache")
70
- print(f"[download_data] $ {' '.join(cmd)}")
71
- rc = subprocess.call(cmd)
72
- sys.exit(rc)
73
-
74
-
75
- if __name__ == "__main__":
76
- main()
 
1
+ """下载 NVIDIA Cosmos-Drive-Dreams 数据集。
2
+
3
+ 直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。
4
+
5
+ 用法:
6
+ # 完整下载(synthetic + lidar + hdmap),约 3TB
7
+ python scripts/download_data.py --odir ./data/cosmos --workers 8
8
+
9
+ # 仅烟囱:限制 clip 数量(约几 GB,取决于 N)
10
+ python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import subprocess
17
+ import sys
18
+ import urllib.request
19
+ from pathlib import Path
20
+
21
+
22
+ NV_DOWNLOAD_URL = (
23
+ "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
24
+ )
25
+
26
+
27
+ def _ensure_official_script(local_path: Path) -> None:
28
+ if local_path.exists():
29
+ return
30
+ print(f"[download_data] 下载 NVIDIA download.py -> {local_path}")
31
+ local_path.parent.mkdir(parents=True, exist_ok=True)
32
+ with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f:
33
+ f.write(resp.read())
34
+
35
+
36
+ def main() -> None:
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--odir", required=True, help="数据输出目录")
39
+ parser.add_argument(
40
+ "--file_types",
41
+ default="synthetic,lidar,hdmap",
42
+ help="数据类型逗号分隔列表",
43
+ )
44
+ parser.add_argument("--workers", type=int, default=4)
45
+ parser.add_argument(
46
+ "--limit",
47
+ type=int,
48
+ default=None,
49
+ metavar="N",
50
+ help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)",
51
+ )
52
+ parser.add_argument("--clean_cache", action="store_true")
53
+ args = parser.parse_args()
54
+
55
+ odir = Path(args.odir)
56
+ nv_script = odir / ".nvidia_download.py"
57
+ _ensure_official_script(nv_script)
58
+
59
+ cmd = [
60
+ sys.executable,
61
+ str(nv_script),
62
+ "--odir", str(odir),
63
+ "--file_types", args.file_types,
64
+ "--workers", str(args.workers),
65
+ ]
66
+ if args.limit is not None:
67
+ cmd.extend(["--limit", str(args.limit)])
68
+ if args.clean_cache:
69
+ cmd.append("--clean_cache")
70
+ print(f"[download_data] $ {' '.join(cmd)}")
71
+ rc = subprocess.call(cmd)
72
+ sys.exit(rc)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
scripts/estimate_memory.py CHANGED
@@ -1,203 +1,203 @@
1
- """估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
2
-
3
- 输出
4
- - 各模块参数数量
5
- - 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
6
- - 推荐设备(HF Sandbox / Jobs)
7
- - 主机内存与磁盘开销
8
-
9
- 公式说明(粗略上界)
10
- - 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
11
- - AdamW 一阶/二阶矩 (fp32): 8 B/p
12
- - 梯度 (fp32): 4 B/p
13
- - bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
14
- - DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
15
- - 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
16
- SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
17
- 需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
18
- - PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
19
- 每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
20
- (GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import sys
26
- from pathlib import Path
27
-
28
- ROOT = Path(__file__).resolve().parent.parent
29
- sys.path.insert(0, str(ROOT / "src"))
30
-
31
- from dataclasses import dataclass
32
-
33
- from wjad.model import E2EAVModel
34
-
35
-
36
- @dataclass
37
- class MemoryReport:
38
- bs: int
39
- seq_len: int
40
- dim: int
41
- layers: int
42
- params_total: int
43
- params_trainable_stage1: int
44
- params_trainable_stage2: int
45
- weights_gb_stage1: float
46
- weights_gb_stage2: float
47
- optim_gb_stage1: float
48
- optim_gb_stage2: float
49
- activations_gb: float
50
- pcgrad_overhead_gb: float
51
- total_stage1_gb: float
52
- total_stage2_gb: float
53
- host_ram_gb: float
54
- disk_gb: float
55
-
56
-
57
- def count_params(model) -> tuple[int, dict[str, int]]:
58
- total = 0
59
- by_module: dict[str, int] = {}
60
- for name, child in model.named_children():
61
- n = sum(p.numel() for p in child.parameters())
62
- by_module[name] = n
63
- total += n
64
- return total, by_module
65
-
66
-
67
- def estimate(bs: int = 8) -> MemoryReport:
68
- model = E2EAVModel(
69
- dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
70
- # 完整规模
71
- backbone_dim=768,
72
- num_heads=12,
73
- num_dense_layers=9,
74
- num_moe_layers=9,
75
- num_routed_experts=7,
76
- num_shared_experts=1,
77
- topk_experts=3,
78
- ffn_mult=4,
79
- num_history_frames=8,
80
- num_detection_tokens=1024,
81
- num_control_tokens=24,
82
- num_ego_tokens=8,
83
- num_extra_tokens=256,
84
- image_h=384,
85
- image_w=1024,
86
- patch_size=16,
87
- num_classes=22,
88
- traj_horizon=24,
89
- freeze_dinov3=True,
90
- )
91
-
92
- total, by_module = count_params(model)
93
- dinov3_n = by_module.get("dinov3", 0)
94
- trainable_stage1 = total - dinov3_n
95
- trainable_stage2 = total
96
-
97
- # 序列长度(拼接后总 token 数 + 上下文)
98
- n_visual = (8 // 2) * (24 // 2) * (64 // 2)
99
- seq_len = n_visual + 8 + 1024 + 24 + 256
100
-
101
- # === 显存 ===
102
- # 单位:GB(除以 1024**3)
103
- GB = 1024 ** 3
104
- weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16
105
- weights_stage2 = (total * 2) / GB
106
- optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v
107
- optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB
108
-
109
- # 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠)
110
- base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6
111
- # MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家)
112
- moe_act = bs * seq_len * 3072 * 2 * 8 * 9
113
- # DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值
114
- dino_act = 2.0 * GB
115
- activations_gb = (base_act + moe_act + dino_act) / GB
116
-
117
- # PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会
118
- # 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算
119
- pcgrad_overhead_gb = 0.5 * activations_gb
120
-
121
- total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
122
- total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0
123
-
124
- # === 主机 RAM ===
125
- # DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析
126
- host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2
127
-
128
- # === 磁盘 ===
129
- # 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部)
130
- disk = 50.0
131
-
132
- return MemoryReport(
133
- bs=bs,
134
- seq_len=seq_len,
135
- dim=768,
136
- layers=18,
137
- params_total=total,
138
- params_trainable_stage1=trainable_stage1,
139
- params_trainable_stage2=trainable_stage2,
140
- weights_gb_stage1=weights_stage1,
141
- weights_gb_stage2=weights_stage2,
142
- optim_gb_stage1=optim_stage1,
143
- optim_gb_stage2=optim_stage2,
144
- activations_gb=activations_gb,
145
- pcgrad_overhead_gb=pcgrad_overhead_gb,
146
- total_stage1_gb=total_stage1,
147
- total_stage2_gb=total_stage2,
148
- host_ram_gb=host_ram,
149
- disk_gb=disk,
150
- )
151
-
152
-
153
- def recommend_device(stage_max_gb: float) -> tuple[str, str]:
154
- """根据 Stage2 峰值显存推荐 GPU。"""
155
- margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace)
156
- need = stage_max_gb * margin
157
- candidates = [
158
- ("T4 16GB", 16),
159
- ("L4 24GB", 24),
160
- ("A10G 24GB", 24),
161
- ("A10G Large 48GB", 48),
162
- ("A100 40GB", 40),
163
- ("L40S 48GB", 48),
164
- ("A100 80GB", 80),
165
- ("H100 80GB", 80),
166
- ]
167
- fit = [c for c in candidates if c[1] >= need]
168
- if not fit:
169
- return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)"
170
- return fit[0][0], f"需要 ≥{need:.1f} GB"
171
-
172
-
173
- def main() -> None:
174
- print("=" * 72)
175
- print(" WJAD 训练显存/内存估算 (bf16 AMP)")
176
- print("=" * 72)
177
- for bs in (1, 2, 4, 8, 16):
178
- r = estimate(bs)
179
- print(f"\n--- BS = {bs} ---")
180
- print(f" 总参数 : {r.params_total / 1e6:8.2f} M")
181
- print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M")
182
- print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M")
183
- print(f" 序列长度 : {r.seq_len}")
184
- print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
185
- print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
186
- print(f" 激活 : {r.activations_gb:6.2f} GB")
187
- print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB")
188
- print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB")
189
- print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值")
190
- gpu, note = recommend_device(r.total_stage2_gb)
191
- print(f" 推荐 GPU : {gpu} ({note})")
192
- print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB")
193
- print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB")
194
-
195
- print()
196
- print("说明:")
197
- print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
198
- print(" - 开 ``gradient_checkpointing`` 可把激活降约 1/3,BS 可成倍提升。")
199
- print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")
200
-
201
-
202
- if __name__ == "__main__":
203
- main()
 
1
+ """估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
2
+
3
+ 输出
4
+ - 各模块参数数量
5
+ - 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
6
+ - 推荐设备(HF Sandbox / Jobs)
7
+ - 主机内存与磁盘开销
8
+
9
+ 公式说明(粗略上界)
10
+ - 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
11
+ - AdamW 一阶/二阶矩 (fp32): 8 B/p
12
+ - 梯度 (fp32): 4 B/p
13
+ - bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
14
+ - DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
15
+ - 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
16
+ SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
17
+ 需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
18
+ - PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
19
+ 每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
20
+ (GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ ROOT = Path(__file__).resolve().parent.parent
29
+ sys.path.insert(0, str(ROOT / "src"))
30
+
31
+ from dataclasses import dataclass
32
+
33
+ from wjad.model import E2EAVModel
34
+
35
+
36
+ @dataclass
37
+ class MemoryReport:
38
+ bs: int
39
+ seq_len: int
40
+ dim: int
41
+ layers: int
42
+ params_total: int
43
+ params_trainable_stage1: int
44
+ params_trainable_stage2: int
45
+ weights_gb_stage1: float
46
+ weights_gb_stage2: float
47
+ optim_gb_stage1: float
48
+ optim_gb_stage2: float
49
+ activations_gb: float
50
+ pcgrad_overhead_gb: float
51
+ total_stage1_gb: float
52
+ total_stage2_gb: float
53
+ host_ram_gb: float
54
+ disk_gb: float
55
+
56
+
57
+ def count_params(model) -> tuple[int, dict[str, int]]:
58
+ total = 0
59
+ by_module: dict[str, int] = {}
60
+ for name, child in model.named_children():
61
+ n = sum(p.numel() for p in child.parameters())
62
+ by_module[name] = n
63
+ total += n
64
+ return total, by_module
65
+
66
+
67
+ def estimate(bs: int = 8) -> MemoryReport:
68
+ model = E2EAVModel(
69
+ dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
70
+ # 完整规模
71
+ backbone_dim=768,
72
+ num_heads=12,
73
+ num_dense_layers=9,
74
+ num_moe_layers=9,
75
+ num_routed_experts=7,
76
+ num_shared_experts=1,
77
+ topk_experts=3,
78
+ ffn_mult=4,
79
+ num_history_frames=8,
80
+ num_detection_tokens=1024,
81
+ num_control_tokens=24,
82
+ num_ego_tokens=8,
83
+ num_extra_tokens=256,
84
+ image_h=384,
85
+ image_w=1024,
86
+ patch_size=16,
87
+ num_classes=22,
88
+ traj_horizon=24,
89
+ freeze_dinov3=True,
90
+ )
91
+
92
+ total, by_module = count_params(model)
93
+ dinov3_n = by_module.get("dinov3", 0)
94
+ trainable_stage1 = total - dinov3_n
95
+ trainable_stage2 = total
96
+
97
+ # 序列长度(拼接后总 token 数 + 上下文)
98
+ n_visual = (8 // 2) * (24 // 2) * (64 // 2)
99
+ seq_len = n_visual + 8 + 1024 + 24 + 256
100
+
101
+ # === 显存 ===
102
+ # 单位:GB(除以 1024**3)
103
+ GB = 1024 ** 3
104
+ weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16
105
+ weights_stage2 = (total * 2) / GB
106
+ optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v
107
+ optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB
108
+
109
+ # 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠)
110
+ base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6
111
+ # MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家)
112
+ moe_act = bs * seq_len * 3072 * 2 * 8 * 9
113
+ # DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值
114
+ dino_act = 2.0 * GB
115
+ activations_gb = (base_act + moe_act + dino_act) / GB
116
+
117
+ # PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会
118
+ # 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算
119
+ pcgrad_overhead_gb = 0.5 * activations_gb
120
+
121
+ total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
122
+ total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0
123
+
124
+ # === 主机 RAM ===
125
+ # DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析
126
+ host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2
127
+
128
+ # === 磁盘 ===
129
+ # 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部)
130
+ disk = 50.0
131
+
132
+ return MemoryReport(
133
+ bs=bs,
134
+ seq_len=seq_len,
135
+ dim=768,
136
+ layers=18,
137
+ params_total=total,
138
+ params_trainable_stage1=trainable_stage1,
139
+ params_trainable_stage2=trainable_stage2,
140
+ weights_gb_stage1=weights_stage1,
141
+ weights_gb_stage2=weights_stage2,
142
+ optim_gb_stage1=optim_stage1,
143
+ optim_gb_stage2=optim_stage2,
144
+ activations_gb=activations_gb,
145
+ pcgrad_overhead_gb=pcgrad_overhead_gb,
146
+ total_stage1_gb=total_stage1,
147
+ total_stage2_gb=total_stage2,
148
+ host_ram_gb=host_ram,
149
+ disk_gb=disk,
150
+ )
151
+
152
+
153
+ def recommend_device(stage_max_gb: float) -> tuple[str, str]:
154
+ """根据 Stage2 峰值显存推荐 GPU。"""
155
+ margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace)
156
+ need = stage_max_gb * margin
157
+ candidates = [
158
+ ("T4 16GB", 16),
159
+ ("L4 24GB", 24),
160
+ ("A10G 24GB", 24),
161
+ ("A10G Large 48GB", 48),
162
+ ("A100 40GB", 40),
163
+ ("L40S 48GB", 48),
164
+ ("A100 80GB", 80),
165
+ ("H100 80GB", 80),
166
+ ]
167
+ fit = [c for c in candidates if c[1] >= need]
168
+ if not fit:
169
+ return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)"
170
+ return fit[0][0], f"需要 ≥{need:.1f} GB"
171
+
172
+
173
+ def main() -> None:
174
+ print("=" * 72)
175
+ print(" WJAD 训练显存/内存估算 (bf16 AMP)")
176
+ print("=" * 72)
177
+ for bs in (1, 2, 4, 8, 16):
178
+ r = estimate(bs)
179
+ print(f"\n--- BS = {bs} ---")
180
+ print(f" 总参数 : {r.params_total / 1e6:8.2f} M")
181
+ print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M")
182
+ print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M")
183
+ print(f" 序列长度 : {r.seq_len}")
184
+ print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
185
+ print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
186
+ print(f" 激活 : {r.activations_gb:6.2f} GB")
187
+ print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB")
188
+ print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB")
189
+ print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值")
190
+ gpu, note = recommend_device(r.total_stage2_gb)
191
+ print(f" 推荐 GPU : {gpu} ({note})")
192
+ print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB")
193
+ print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB")
194
+
195
+ print()
196
+ print("说明:")
197
+ print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
198
+ print(" - 开 ``gradient_checkpointing`` 可把激活降��约 1/3,BS 可成倍提升。")
199
+ print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")
200
+
201
+
202
+ if __name__ == "__main__":
203
+ main()
scripts/ingest_hub_to_bucket.py CHANGED
@@ -1,207 +1,234 @@
1
- """将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。
2
-
3
- 解压输出默认写入 **另一棵目录树**( ``--dest-prefix`` 平级的 ``{dest-prefix}_unpacked/``)
4
- 相对路径与镜像里的 ``.tar`` 一致,避免在源树旁叠 ``*_extracted/`` 导致 ``rglob`` 反复扫到嵌套 tar
5
-
6
- 示例(本地或 Job 内,且已挂载 bucket 到 ``/mnt/cosmos``)::
7
-
8
- python scripts/ingest_hub_to_bucket.py \\
9
- --bucket fuzirui/my-cosmos-bucket \\
10
- --dest-prefix cosmos_hub_mirror \\
11
- --bucket-mount /mnt/cosmos \\
12
- --extract-tars
13
-
14
- 仅拷贝、不解压::
15
-
16
- python scripts/ingest_hub_to_bucket.py \\
17
- --bucket fuzirui/my-cosmos-bucket \\
18
- --source 'hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/' \\
19
- --dest-prefix raw \\
20
- --copy-only
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import argparse
26
- import sys
27
- import tarfile
28
- from pathlib import Path
29
-
30
- from huggingface_hub import HfApi, create_bucket
31
-
32
-
33
- def _ensure_trailing_slash_hf_url(url: str) -> str:
34
- s = url.strip()
35
- if s.endswith("/"):
36
- return s
37
- return s + "/"
38
-
39
-
40
- def _archive_stem(path: Path) -> str:
41
- """``foo.tar.gz`` -> ``foo``;``bar.tar`` -> ``bar``。"""
42
- n = path.name
43
- for ext in (".tar.gz", ".tar.xz", ".tgz", ".tar"):
44
- if n.endswith(ext):
45
- return n[: -len(ext)]
46
- return path.stem
47
-
48
-
49
- def _is_under_path(path: Path, parent: Path) -> bool:
50
- try:
51
- path.resolve().relative_to(parent.resolve())
52
- return True
53
- except ValueError:
54
- return False
55
-
56
-
57
- def _collect_archives(
58
- root: Path,
59
- patterns: tuple[str, ...],
60
- *,
61
- exclude_under: Path | None = None,
62
- ) -> list[Path]:
63
- """收集待解压归档,排除历史 ``*_extracted`` 目录及解压输出树,避免嵌套/重复扫描。"""
64
- out: list[Path] = []
65
- seen: set[Path] = set()
66
- for pat in patterns:
67
- for p in root.rglob(pat):
68
- if not p.is_file():
69
- continue
70
- rp = p.resolve()
71
- if rp in seen:
72
- continue
73
- if any(part.endswith("_extracted") or part == "_extracted" for part in p.parts):
74
- continue
75
- if exclude_under is not None and _is_under_path(p, exclude_under):
76
- continue
77
- seen.add(rp)
78
- out.append(p)
79
- return sorted(out)
80
-
81
-
82
- def main() -> None:
83
- parser = argparse.ArgumentParser(description="Hub copy_files → Bucket,可选按镜像目录解压 tar")
84
- parser.add_argument(
85
- "--source",
86
- default="hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/",
87
- help="hf:// 源(仓库或 bucket 前缀),目录建议以 / 结尾",
88
- )
89
- parser.add_argument(
90
- "--bucket",
91
- required=True,
92
- help='目标 bucket id,如 "user/my-bucket"(不要写 hf://)',
93
- )
94
- parser.add_argument(
95
- "--dest-prefix",
96
- default="cosmos_hub_mirror",
97
- help="copy_files 写入 bucket 内的子路径(不要用前导 /)",
98
- )
99
- parser.add_argument(
100
- "--ensure-bucket",
101
- action="store_true",
102
- help="若不存在则 create_bucket(..., exist_ok=True)",
103
- )
104
- parser.add_argument(
105
- "--copy-only",
106
- action="store_true",
107
- help="只做 copy_files,解压",
108
- )
109
- parser.add_argument(
110
- "--bucket-mount",
111
- default=None,
112
- help="Job 内 bucket 挂载点(如 /mnt/cosmos);若设 --extract-tars 则必填",
113
- )
114
- parser.add_argument(
115
- "--extract-tars",
116
- action="store_true",
117
- help="解压 mirror 树下的 tar;输出见 --extract-out-prefix",
118
- )
119
- parser.add_argument(
120
- "--extract-out-prefix",
121
- default=None,
122
- metavar="NAME",
123
- help="解压根目录(bucket 内相对路径,与 dest-prefix 平级)。默认 {dest-prefix}_unpacked",
124
- )
125
- parser.add_argument(
126
- "--extract-beside-tar",
127
- action="store_true",
128
- help="旧行为:在每条 tar 旁建 ``{name}_extracted``(易与 rglob 嵌套 tar 纠缠,一般不推荐)",
129
- )
130
- parser.add_argument(
131
- "--max-tars",
132
- type=int,
133
- default=None,
134
- help="最多处理多少个 tar(烟囱/限流)",
135
- )
136
- args = parser.parse_args()
137
-
138
- src = _ensure_trailing_slash_hf_url(args.source)
139
- dest_prefix = args.dest_prefix.strip().strip("/")
140
- dest = f"hf://buckets/{args.bucket}/{dest_prefix}/"
141
-
142
- api = HfApi()
143
- if args.ensure_bucket:
144
- create_bucket(args.bucket, exist_ok=True)
145
- print(f"[ingest] bucket ready: {args.bucket}", flush=True)
146
-
147
- print(f"[ingest] copy_files\n {src}\n -> {dest}", flush=True)
148
- api.copy_files(src, dest)
149
- print("[ingest] copy_files 完成", flush=True)
150
-
151
- if args.copy_only or not args.extract_tars:
152
- return
153
-
154
- if not args.bucket_mount:
155
- print("[ingest] 错误: --extract-tars 需要 --bucket-mount", file=sys.stderr)
156
- sys.exit(2)
157
-
158
- root = Path(args.bucket_mount) / dest_prefix
159
- out_rel = args.extract_out_prefix
160
- if out_rel is None:
161
- out_rel = f"{dest_prefix}_unpacked"
162
- out_rel = out_rel.strip().strip("/")
163
- extract_base = Path(args.bucket_mount) / out_rel
164
-
165
- if not root.is_dir():
166
- print(f"[ingest] 警告: 镜像路径不存在或尚不可见: {root}", flush=True)
167
-
168
- patterns = ("*.tar", "*.tar.gz", "*.tar.xz", "*.tgz")
169
- archives = _collect_archives(root, patterns, exclude_under=extract_base)
170
-
171
- if args.max_tars is not None:
172
- archives = archives[: args.max_tars]
173
-
174
- mode = "beside-tar" if args.extract_beside_tar else f"mirror -> {extract_base}"
175
- print(f"[ingest] 将解压 {len(archives)} 个归档 under {root}(模式: {mode})", flush=True)
176
-
177
- for i, tar_path in enumerate(archives):
178
- if args.extract_beside_tar:
179
- out_dir = tar_path.parent / f"{tar_path.name}_extracted"
180
- else:
181
- rel = tar_path.relative_to(root)
182
- out_dir = extract_base / rel.parent / _archive_stem(tar_path)
183
-
184
- if out_dir.exists() and any(out_dir.iterdir()):
185
- print(f"[ingest] ({i + 1}/{len(archives)}) 跳过(已存在非空) {out_dir}", flush=True)
186
- continue
187
- out_dir.mkdir(parents=True, exist_ok=True)
188
- print(f"[ingest] ({i + 1}/{len(archives)}) {tar_path} -> {out_dir}", flush=True)
189
- try:
190
- with tarfile.open(tar_path, mode="r:*") as tf:
191
- _extract(tf, out_dir)
192
- except Exception as e:
193
- print(f"[ingest] 解压失败 {tar_path}: {e}", flush=True)
194
- raise
195
-
196
- print("[ingest] 全部完成", flush=True)
197
-
198
-
199
- def _extract(tf: tarfile.TarFile, out_dir: Path) -> None:
200
- if sys.version_info >= (3, 12):
201
- tf.extractall(out_dir, filter="data")
202
- else:
203
- tf.extractall(out_dir)
204
-
205
-
206
- if __name__ == "__main__":
207
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。
2
+
3
+ HF Jobs 容器 **根分区**(ephemeral)常有 **~50GiB** 上限。``copy_files`` Hub 侧完成几乎不占根盘;
4
+ 解压、``pip``、Hub 客户端默认缓存会写 ``/tmp````~/.cache``,易触发 eviction
5
+ 若设 ``--bucket-mount``,会把 ``TMPDIR``、``HF_HOME`` 等重定向到挂载点下 ``.wjad_ephemeral/``,
6
+ 大块临时数据落在 **Bucket**。训练时用 Volume 挂载 Bucket,``WJAD_DATA_ROOT`` 指到 mirror 或解压树即可,
7
+ 无需把整库先下载到根盘。
8
+
9
+ 解压输出默认写入 **另一棵目录树**(与 ``--dest-prefix`` 平级的 ``{dest-prefix}_unpacked/``),
10
+ 相对路径与镜像里的 ``.tar`` 一致,避免在源树旁叠 ``*_extracted/`` 导致 ``rglob`` 反复扫到嵌套 tar。
11
+
12
+ 示例(本地或 Job 内,且已挂载 bucket 到 ``/mnt/cosmos``)::
13
+
14
+ python scripts/ingest_hub_to_bucket.py \\
15
+ --bucket fuzirui/my-cosmos-bucket \\
16
+ --dest-prefix cosmos_hub_mirror \\
17
+ --bucket-mount /mnt/cosmos \\
18
+ --extract-tars
19
+
20
+ 仅拷贝、不解压::
21
+
22
+ python scripts/ingest_hub_to_bucket.py \\
23
+ --bucket fuzirui/my-cosmos-bucket \\
24
+ --source 'hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/' \\
25
+ --dest-prefix raw \\
26
+ --copy-only
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import argparse
32
+ import os
33
+ import sys
34
+ import tarfile
35
+ from pathlib import Path
36
+
37
+ from huggingface_hub import HfApi, create_bucket
38
+
39
+
40
+ def _ensure_trailing_slash_hf_url(url: str) -> str:
41
+ s = url.strip()
42
+ if s.endswith("/"):
43
+ return s
44
+ return s + "/"
45
+
46
+
47
+ def _archive_stem(path: Path) -> str:
48
+ """``foo.tar.gz`` -> ``foo``;``bar.tar`` -> ``bar``。"""
49
+ n = path.name
50
+ for ext in (".tar.gz", ".tar.xz", ".tgz", ".tar"):
51
+ if n.endswith(ext):
52
+ return n[: -len(ext)]
53
+ return path.stem
54
+
55
+
56
+ def _is_under_path(path: Path, parent: Path) -> bool:
57
+ try:
58
+ path.resolve().relative_to(parent.resolve())
59
+ return True
60
+ except ValueError:
61
+ return False
62
+
63
+
64
+ def _collect_archives(
65
+ root: Path,
66
+ patterns: tuple[str, ...],
67
+ *,
68
+ exclude_under: Path | None = None,
69
+ ) -> list[Path]:
70
+ """收集待解压归档,排除历史 ``*_extracted`` 目录及解压输出树,避免嵌套/重复扫描。"""
71
+ out: list[Path] = []
72
+ seen: set[Path] = set()
73
+ for pat in patterns:
74
+ for p in root.rglob(pat):
75
+ if not p.is_file():
76
+ continue
77
+ rp = p.resolve()
78
+ if rp in seen:
79
+ continue
80
+ if any(part.endswith("_extracted") or part == "_extracted" for part in p.parts):
81
+ continue
82
+ if exclude_under is not None and _is_under_path(p, exclude_under):
83
+ continue
84
+ seen.add(rp)
85
+ out.append(p)
86
+ return sorted(out)
87
+
88
+
89
+ def _redirect_ephemeral_to_bucket(bucket_mount: Path) -> None:
90
+ """把临时文件与 HF 缓存写到 Bucket 挂载点,避免撑爆 Job 50G 根分区。"""
91
+ base = bucket_mount / ".wjad_ephemeral"
92
+ tmp = base / "tmp"
93
+ hf_home = base / "hf_home"
94
+ xdg = base / "xdg_cache"
95
+ for d in (tmp, hf_home, hf_home / "hub", xdg):
96
+ d.mkdir(parents=True, exist_ok=True)
97
+ os.environ["TMPDIR"] = str(tmp)
98
+ os.environ["TMP"] = str(tmp)
99
+ os.environ["TEMP"] = str(tmp)
100
+ os.environ["HF_HOME"] = str(hf_home)
101
+ os.environ["HF_HUB_CACHE"] = str(hf_home / "hub")
102
+ os.environ["XDG_CACHE_HOME"] = str(xdg)
103
+ print(f"[ingest] 临时/缓存 -> {base}", flush=True)
104
+
105
+
106
+ def main() -> None:
107
+ parser = argparse.ArgumentParser(description="Hub copy_files → Bucket可选按镜像目录解压 tar")
108
+ parser.add_argument(
109
+ "--source",
110
+ default="hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/",
111
+ help="hf:// 源(仓库或 bucket 前缀),目录建议以 / 结尾",
112
+ )
113
+ parser.add_argument(
114
+ "--bucket",
115
+ required=True,
116
+ help='目标 bucket id,如 "user/my-bucket"(不要写 hf://)',
117
+ )
118
+ parser.add_argument(
119
+ "--dest-prefix",
120
+ default="cosmos_hub_mirror",
121
+ help="copy_files 写入 bucket 内的子路径(不要用前导 /)",
122
+ )
123
+ parser.add_argument(
124
+ "--ensure-bucket",
125
+ action="store_true",
126
+ help="若不存在则 create_bucket(..., exist_ok=True)",
127
+ )
128
+ parser.add_argument(
129
+ "--copy-only",
130
+ action="store_true",
131
+ help="只做 copy_files,不解压",
132
+ )
133
+ parser.add_argument(
134
+ "--bucket-mount",
135
+ default=None,
136
+ help="Job 内 bucket 挂载点(如 /mnt/cosmos);若设 --extract-tars 则必填",
137
+ )
138
+ parser.add_argument(
139
+ "--extract-tars",
140
+ action="store_true",
141
+ help="解压 mirror 树下的 tar;输出见 --extract-out-prefix",
142
+ )
143
+ parser.add_argument(
144
+ "--extract-out-prefix",
145
+ default=None,
146
+ metavar="NAME",
147
+ help="解压根目录(bucket 内相对路径,与 dest-prefix 平级)。默认 {dest-prefix}_unpacked",
148
+ )
149
+ parser.add_argument(
150
+ "--extract-beside-tar",
151
+ action="store_true",
152
+ help="旧行为:在每条 tar 旁建 ``{name}_extracted``(易与 rglob 嵌套 tar 纠缠,一般不推荐)",
153
+ )
154
+ parser.add_argument(
155
+ "--max-tars",
156
+ type=int,
157
+ default=None,
158
+ help="最多处理多少个 tar(烟囱/限流)",
159
+ )
160
+ args = parser.parse_args()
161
+
162
+ if args.bucket_mount:
163
+ _redirect_ephemeral_to_bucket(Path(args.bucket_mount))
164
+
165
+ src = _ensure_trailing_slash_hf_url(args.source)
166
+ dest_prefix = args.dest_prefix.strip().strip("/")
167
+ dest = f"hf://buckets/{args.bucket}/{dest_prefix}/"
168
+
169
+ api = HfApi()
170
+ if args.ensure_bucket:
171
+ create_bucket(args.bucket, exist_ok=True)
172
+ print(f"[ingest] bucket ready: {args.bucket}", flush=True)
173
+
174
+ print(f"[ingest] copy_files\n {src}\n -> {dest}", flush=True)
175
+ api.copy_files(src, dest)
176
+ print("[ingest] copy_files 完成", flush=True)
177
+
178
+ if args.copy_only or not args.extract_tars:
179
+ return
180
+
181
+ if not args.bucket_mount:
182
+ print("[ingest] 错误: --extract-tars 需要 --bucket-mount", file=sys.stderr)
183
+ sys.exit(2)
184
+
185
+ root = Path(args.bucket_mount) / dest_prefix
186
+ out_rel = args.extract_out_prefix
187
+ if out_rel is None:
188
+ out_rel = f"{dest_prefix}_unpacked"
189
+ out_rel = out_rel.strip().strip("/")
190
+ extract_base = Path(args.bucket_mount) / out_rel
191
+
192
+ if not root.is_dir():
193
+ print(f"[ingest] 警告: 镜像路径不存在或尚不可见: {root}", flush=True)
194
+
195
+ patterns = ("*.tar", "*.tar.gz", "*.tar.xz", "*.tgz")
196
+ archives = _collect_archives(root, patterns, exclude_under=extract_base)
197
+
198
+ if args.max_tars is not None:
199
+ archives = archives[: args.max_tars]
200
+
201
+ mode = "beside-tar" if args.extract_beside_tar else f"mirror -> {extract_base}"
202
+ print(f"[ingest] 将解压 {len(archives)} 个归档 under {root}(模式: {mode})", flush=True)
203
+
204
+ for i, tar_path in enumerate(archives):
205
+ if args.extract_beside_tar:
206
+ out_dir = tar_path.parent / f"{tar_path.name}_extracted"
207
+ else:
208
+ rel = tar_path.relative_to(root)
209
+ out_dir = extract_base / rel.parent / _archive_stem(tar_path)
210
+
211
+ if out_dir.exists() and any(out_dir.iterdir()):
212
+ print(f"[ingest] ({i + 1}/{len(archives)}) 跳过(已存在非空) {out_dir}", flush=True)
213
+ continue
214
+ out_dir.mkdir(parents=True, exist_ok=True)
215
+ print(f"[ingest] ({i + 1}/{len(archives)}) {tar_path} -> {out_dir}", flush=True)
216
+ try:
217
+ with tarfile.open(tar_path, mode="r:*") as tf:
218
+ _extract(tf, out_dir)
219
+ except Exception as e:
220
+ print(f"[ingest] 解压失败 {tar_path}: {e}", flush=True)
221
+ raise
222
+
223
+ print("[ingest] 全部完成", flush=True)
224
+
225
+
226
+ def _extract(tf: tarfile.TarFile, out_dir: Path) -> None:
227
+ if sys.version_info >= (3, 12):
228
+ tf.extractall(out_dir, filter="data")
229
+ else:
230
+ tf.extractall(out_dir)
231
+
232
+
233
+ if __name__ == "__main__":
234
+ main()
scripts/push_cpu_ingest_job.py CHANGED
@@ -1,141 +1,148 @@
1
- """提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。
2
-
3
- - 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。
4
- - 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。
5
- - 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。
6
-
7
- 用法::
8
-
9
- python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data
10
- python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow
11
- python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/'
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import argparse
17
- import os
18
- import sys
19
-
20
- from huggingface_hub import HfApi, Volume, create_bucket
21
-
22
- try:
23
- from huggingface_hub.cli._cli_utils import parse_env_map
24
- except Exception: # pragma: no cover
25
- parse_env_map = None
26
-
27
- DEFAULT_CODE_REPO = "fuzirui/WJAD"
28
- DEFAULT_BUCKET = "fuzirui/WJAD"
29
- DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
30
- DEFAULT_DEST_PREFIX = "cosmos_hub_mirror"
31
- DEFAULT_IMAGE = "python:3.12"
32
- DEFAULT_TIMEOUT = "7d"
33
-
34
-
35
- def _secrets_for_job() -> dict | None:
36
- if parse_env_map is not None:
37
- try:
38
- m = parse_env_map(["HF_TOKEN"])
39
- if m.get("HF_TOKEN"):
40
- return m
41
- except Exception:
42
- pass
43
- t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
44
- return {"HF_TOKEN": t} if t else None
45
-
46
-
47
- def main() -> None:
48
- parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压")
49
- parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)")
50
- parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id")
51
- parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset"))
52
- parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录")
53
- parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径")
54
- parser.add_argument(
55
- "--skip-create-bucket",
56
- action="store_true",
57
- help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)",
58
- )
59
- parser.add_argument(
60
- "--no-extract",
61
- action="store_true",
62
- help="只做 copy_files,不解压 tar",
63
- )
64
- parser.add_argument(
65
- "--max-tars",
66
- type=int,
67
- default=None,
68
- help="传给 ingest_hub_to_bucket.py --max-tars",
69
- )
70
- parser.add_argument(
71
- "--extract-out-prefix",
72
- default=None,
73
- metavar="NAME",
74
- help="解压输出子路径(默认 {dest-prefix}_unpacked)",
75
- )
76
- parser.add_argument(
77
- "--extract-beside-tar",
78
- action="store_true",
79
- help="旧行为:在每条 tar 旁解压为 _extracted",
80
- )
81
- parser.add_argument("--image", default=DEFAULT_IMAGE)
82
- parser.add_argument("--timeout", default=DEFAULT_TIMEOUT)
83
- parser.add_argument("--follow", action="store_true")
84
- parser.add_argument("--no-secrets", action="store_true")
85
- args = parser.parse_args()
86
-
87
- bucket_mount = "/mnt/cosmos"
88
- code_mount = "/workspace"
89
-
90
- max_tars = ""
91
- if args.max_tars is not None:
92
- max_tars = f" --max-tars {args.max_tars}"
93
-
94
- extract_flag = "" if args.no_extract else " --extract-tars"
95
-
96
- extract_beside = " --extract-beside-tar" if args.extract_beside_tar else ""
97
- out_prefix = ""
98
- if args.extract_out_prefix:
99
- out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'"
100
-
101
- script = f"""set -euo pipefail
102
- pip install --root-user-action=ignore --no-cache-dir 'huggingface_hub>=0.30'
103
- python {code_mount}/scripts/ingest_hub_to_bucket.py \\
104
- --bucket '{args.bucket}' \\
105
- --source '{args.source}' \\
106
- --dest-prefix '{args.dest_prefix}' \\
107
- --bucket-mount '{bucket_mount}'{extract_flag}{max_tars}{out_prefix}{extract_beside}
108
- """
109
-
110
- secrets = None if args.no_secrets else _secrets_for_job()
111
- if secrets is None and not args.no_secrets:
112
- print("[push_cpu_ingest] 警告: 无 HF_TOKEN,gated 数据会失败。", file=sys.stderr)
113
-
114
- if not args.skip_create_bucket:
115
- create_bucket(args.bucket, exist_ok=True)
116
- print(f"[push_cpu_ingest] bucket 已确保存在(或已存在): {args.bucket}")
117
-
118
- volumes = [
119
- Volume(type=args.code_type, source=args.code_repo, mount_path=code_mount),
120
- Volume(type="bucket", source=args.bucket, mount_path=bucket_mount),
121
- ]
122
-
123
- api = HfApi()
124
- job = api.run_job(
125
- image=args.image,
126
- command=["bash", "-lc", script],
127
- flavor="cpu-basic",
128
- volumes=volumes,
129
- secrets=secrets,
130
- timeout=args.timeout,
131
- )
132
- print(f"[push_cpu_ingest] Job ID: {job.id}")
133
- print(f"[push_cpu_ingest] URL: {job.url}")
134
-
135
- if args.follow:
136
- for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
137
- print(line, end="" if str(line).endswith("\n") else "\n")
138
-
139
-
140
- if __name__ == "__main__":
141
- main()
 
 
 
 
 
 
 
 
1
+ """提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。
2
+
3
+ - 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。
4
+ - 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。
5
+ - 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。
6
+
7
+ 用法::
8
+
9
+ python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data
10
+ python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow
11
+ python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/'
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import os
18
+ import sys
19
+
20
+ from huggingface_hub import HfApi, Volume, create_bucket
21
+
22
+ try:
23
+ from huggingface_hub.cli._cli_utils import parse_env_map
24
+ except Exception: # pragma: no cover
25
+ parse_env_map = None
26
+
27
+ DEFAULT_CODE_REPO = "fuzirui/WJAD"
28
+ DEFAULT_BUCKET = "fuzirui/WJAD"
29
+ DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
30
+ DEFAULT_DEST_PREFIX = "cosmos_hub_mirror"
31
+ DEFAULT_IMAGE = "python:3.12"
32
+ DEFAULT_TIMEOUT = "7d"
33
+
34
+
35
+ def _secrets_for_job() -> dict | None:
36
+ if parse_env_map is not None:
37
+ try:
38
+ m = parse_env_map(["HF_TOKEN"])
39
+ if m.get("HF_TOKEN"):
40
+ return m
41
+ except Exception:
42
+ pass
43
+ t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
44
+ return {"HF_TOKEN": t} if t else None
45
+
46
+
47
+ def main() -> None:
48
+ parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压")
49
+ parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)")
50
+ parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id")
51
+ parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset"))
52
+ parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录")
53
+ parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径")
54
+ parser.add_argument(
55
+ "--skip-create-bucket",
56
+ action="store_true",
57
+ help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)",
58
+ )
59
+ parser.add_argument(
60
+ "--no-extract",
61
+ action="store_true",
62
+ help="只做 copy_files,不解压 tar",
63
+ )
64
+ parser.add_argument(
65
+ "--max-tars",
66
+ type=int,
67
+ default=None,
68
+ help="传给 ingest_hub_to_bucket.py --max-tars",
69
+ )
70
+ parser.add_argument(
71
+ "--extract-out-prefix",
72
+ default=None,
73
+ metavar="NAME",
74
+ help="解压输出子路径(默认 {dest-prefix}_unpacked)",
75
+ )
76
+ parser.add_argument(
77
+ "--extract-beside-tar",
78
+ action="store_true",
79
+ help="旧行为:在每条 tar 旁解压为 _extracted",
80
+ )
81
+ parser.add_argument("--image", default=DEFAULT_IMAGE)
82
+ parser.add_argument("--timeout", default=DEFAULT_TIMEOUT)
83
+ parser.add_argument("--follow", action="store_true")
84
+ parser.add_argument("--no-secrets", action="store_true")
85
+ args = parser.parse_args()
86
+
87
+ bucket_mount = "/mnt/cosmos"
88
+ code_mount = "/workspace"
89
+
90
+ max_tars = ""
91
+ if args.max_tars is not None:
92
+ max_tars = f" --max-tars {args.max_tars}"
93
+
94
+ extract_flag = "" if args.no_extract else " --extract-tars"
95
+
96
+ extract_beside = " --extract-beside-tar" if args.extract_beside_tar else ""
97
+ out_prefix = ""
98
+ if args.extract_out_prefix:
99
+ out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'"
100
+
101
+ script = f"""set -euo pipefail
102
+ Eph="{bucket_mount}/.wjad_ephemeral"
103
+ mkdir -p "$Eph/tmp" "$Eph/hf_home/hub" "$Eph/xdg_cache"
104
+ export TMPDIR="$Eph/tmp"
105
+ export TMP="$TMPDIR" TEMP="$TMPDIR"
106
+ export HF_HOME="$Eph/hf_home"
107
+ export HF_HUB_CACHE="$HF_HOME/hub"
108
+ export XDG_CACHE_HOME="$Eph/xdg_cache"
109
+ pip install --root-user-action=ignore --no-cache-dir 'huggingface_hub>=0.30'
110
+ python {code_mount}/scripts/ingest_hub_to_bucket.py \\
111
+ --bucket '{args.bucket}' \\
112
+ --source '{args.source}' \\
113
+ --dest-prefix '{args.dest_prefix}' \\
114
+ --bucket-mount '{bucket_mount}'{extract_flag}{max_tars}{out_prefix}{extract_beside}
115
+ """
116
+
117
+ secrets = None if args.no_secrets else _secrets_for_job()
118
+ if secrets is None and not args.no_secrets:
119
+ print("[push_cpu_ingest] 警告: 无 HF_TOKEN,gated 数据会失败。", file=sys.stderr)
120
+
121
+ if not args.skip_create_bucket:
122
+ create_bucket(args.bucket, exist_ok=True)
123
+ print(f"[push_cpu_ingest] bucket 已确保存在(或已存在): {args.bucket}")
124
+
125
+ volumes = [
126
+ Volume(type=args.code_type, source=args.code_repo, mount_path=code_mount),
127
+ Volume(type="bucket", source=args.bucket, mount_path=bucket_mount),
128
+ ]
129
+
130
+ api = HfApi()
131
+ job = api.run_job(
132
+ image=args.image,
133
+ command=["bash", "-lc", script],
134
+ flavor="cpu-basic",
135
+ volumes=volumes,
136
+ secrets=secrets,
137
+ timeout=args.timeout,
138
+ )
139
+ print(f"[push_cpu_ingest] Job ID: {job.id}")
140
+ print(f"[push_cpu_ingest] URL: {job.url}")
141
+
142
+ if args.follow:
143
+ for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
144
+ print(line, end="" if str(line).endswith("\n") else "\n")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
scripts/push_to_jobs.py CHANGED
@@ -1,196 +1,196 @@
1
- """提交 Hugging Face Jobs 正式训练。
2
-
3
- - 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。
4
- - **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树
5
- (``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载**
6
- 不是这棵树,``build_clip_index`` 会得到 0 条样本。
7
- - **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录
8
- ``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用
9
- ``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把
10
- ``WJAD_DATA_ROOT`` 指到挂载路径。
11
- - **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动
12
- 流式需另用 ``IterableDataset`` + shard,属后续工作。
13
-
14
- 用法:
15
-
16
- python scripts/push_to_jobs.py
17
- python scripts/push_to_jobs.py --follow
18
- python scripts/push_to_jobs.py --download-limit 0
19
- python scripts/push_to_jobs.py --skip-download
20
- python scripts/push_to_jobs.py --mount-hub-dataset
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import argparse
26
- import sys
27
-
28
- from huggingface_hub import HfApi, Volume
29
-
30
- try:
31
- from huggingface_hub.cli._cli_utils import parse_env_map
32
- except Exception: # pragma: no cover
33
- parse_env_map = None
34
-
35
-
36
- DEFAULT_FLAVOR = "a10g-large"
37
- DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
38
- DEFAULT_REPO = "fuzirui/WJAD"
39
- DEFAULT_MOUNT = "/workspace"
40
- DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams"
41
- DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos"
42
- DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos"
43
-
44
-
45
- def _secrets_for_job() -> dict | None:
46
- if parse_env_map is not None:
47
- try:
48
- m = parse_env_map(["HF_TOKEN"])
49
- if m.get("HF_TOKEN"):
50
- return m
51
- except Exception:
52
- pass
53
- import os
54
-
55
- t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
56
- return {"HF_TOKEN": t} if t else None
57
-
58
-
59
- def main() -> None:
60
- parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练")
61
- parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id")
62
- parser.add_argument(
63
- "--repo-type",
64
- default="model",
65
- choices=("model", "space", "dataset"),
66
- help="仓库类型",
67
- )
68
- parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径")
69
- parser.add_argument("--flavor", default=DEFAULT_FLAVOR)
70
- parser.add_argument("--image", default=DEFAULT_IMAGE)
71
- parser.add_argument(
72
- "--follow",
73
- action="store_true",
74
- help="跟随 Job 日志直到结束",
75
- )
76
- parser.add_argument("--no-secrets", action="store_true")
77
- parser.add_argument("--timeout", default=None, help="如 168h、7d")
78
- # —— 数据:默认 NVIDIA 下载到可写目录 ——
79
- parser.add_argument(
80
- "--data-prep-dir",
81
- default=DEFAULT_DATA_PREP_DIR,
82
- help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖",
83
- )
84
- parser.add_argument(
85
- "--download-workers",
86
- type=int,
87
- default=8,
88
- help="download_data.py --workers",
89
- )
90
- parser.add_argument(
91
- "--download-limit",
92
- type=int,
93
- default=8,
94
- metavar="N",
95
- help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)",
96
- )
97
- parser.add_argument(
98
- "--skip-download",
99
- action="store_true",
100
- help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)",
101
- )
102
- # —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)——
103
- parser.add_argument(
104
- "--mount-hub-dataset",
105
- action="store_true",
106
- help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)",
107
- )
108
- parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID")
109
- parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT)
110
-
111
- args = parser.parse_args()
112
-
113
- code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount)
114
- ro_mount = args.mount
115
- work = "/tmp/wjad-run"
116
- volumes: list[Volume] = [code_vol]
117
-
118
- if args.mount_hub_dataset:
119
- volumes.append(
120
- Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount)
121
- )
122
-
123
- use_auto_download = not args.skip_download and not args.mount_hub_dataset
124
- data_root_default = (
125
- args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir
126
- )
127
-
128
- limit_tail = ""
129
- if use_auto_download and args.download_limit > 0:
130
- limit_tail = f" --limit {args.download_limit}"
131
-
132
- download_block = ""
133
- if use_auto_download:
134
- download_block = f"""
135
- mkdir -p "$DATA_ROOT"
136
- python scripts/download_data.py --odir "$DATA_ROOT" \\
137
- --file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail}
138
- """
139
-
140
- script = f"""set -euo pipefail
141
- rm -rf {work}
142
- cp -a {ro_mount} {work}
143
- cd {work}
144
- export PIP_ROOT_USER_ACTION=ignore
145
- pip install --root-user-action=ignore --no-cache-dir -e .
146
- export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}"
147
- DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}"
148
- {download_block}
149
- python -m wjad.train.runner_local \\
150
- --device cuda \\
151
- --config configs/default.yaml \\
152
- --data_root "$DATA_ROOT" \\
153
- --dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}"
154
- """
155
-
156
- secrets = None if args.no_secrets else _secrets_for_job()
157
- if secrets is None and not args.no_secrets:
158
- print(
159
- "[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。",
160
- file=sys.stderr,
161
- )
162
-
163
- if args.mount_hub_dataset:
164
- print(
165
- f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};"
166
- "若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset,"
167
- "改用默认自动 download。"
168
- )
169
- elif use_auto_download:
170
- lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)"
171
- print(
172
- f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。"
173
- "全量请 --download-limit 0 并保证磁盘或 Bucket。"
174
- )
175
- else:
176
- print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。")
177
-
178
- api = HfApi()
179
- job = api.run_job(
180
- image=args.image,
181
- command=["bash", "-lc", script],
182
- flavor=args.flavor,
183
- volumes=volumes,
184
- secrets=secrets,
185
- timeout=args.timeout,
186
- )
187
- print(f"[push_to_jobs] Job ID: {job.id}")
188
- print(f"[push_to_jobs] URL: {job.url}")
189
-
190
- if args.follow:
191
- for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
192
- print(line, end="" if str(line).endswith("\n") else "\n")
193
-
194
-
195
- if __name__ == "__main__":
196
- main()
 
1
+ """提交 Hugging Face Jobs 正式训练。
2
+
3
+ - 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。
4
+ - **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树
5
+ (``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载**
6
+ 不是这棵树,``build_clip_index`` 会得到 0 条样本。
7
+ - **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录
8
+ ``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用
9
+ ``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把
10
+ ``WJAD_DATA_ROOT`` 指到挂载路径。
11
+ - **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动
12
+ 流式需另用 ``IterableDataset`` + shard,属后续工作。
13
+
14
+ 用法:
15
+
16
+ python scripts/push_to_jobs.py
17
+ python scripts/push_to_jobs.py --follow
18
+ python scripts/push_to_jobs.py --download-limit 0
19
+ python scripts/push_to_jobs.py --skip-download
20
+ python scripts/push_to_jobs.py --mount-hub-dataset
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import sys
27
+
28
+ from huggingface_hub import HfApi, Volume
29
+
30
+ try:
31
+ from huggingface_hub.cli._cli_utils import parse_env_map
32
+ except Exception: # pragma: no cover
33
+ parse_env_map = None
34
+
35
+
36
+ DEFAULT_FLAVOR = "a10g-large"
37
+ DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
38
+ DEFAULT_REPO = "fuzirui/WJAD"
39
+ DEFAULT_MOUNT = "/workspace"
40
+ DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams"
41
+ DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos"
42
+ DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos"
43
+
44
+
45
+ def _secrets_for_job() -> dict | None:
46
+ if parse_env_map is not None:
47
+ try:
48
+ m = parse_env_map(["HF_TOKEN"])
49
+ if m.get("HF_TOKEN"):
50
+ return m
51
+ except Exception:
52
+ pass
53
+ import os
54
+
55
+ t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
56
+ return {"HF_TOKEN": t} if t else None
57
+
58
+
59
+ def main() -> None:
60
+ parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练")
61
+ parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id")
62
+ parser.add_argument(
63
+ "--repo-type",
64
+ default="model",
65
+ choices=("model", "space", "dataset"),
66
+ help="仓库类型",
67
+ )
68
+ parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径")
69
+ parser.add_argument("--flavor", default=DEFAULT_FLAVOR)
70
+ parser.add_argument("--image", default=DEFAULT_IMAGE)
71
+ parser.add_argument(
72
+ "--follow",
73
+ action="store_true",
74
+ help="跟随 Job 日志直到结束",
75
+ )
76
+ parser.add_argument("--no-secrets", action="store_true")
77
+ parser.add_argument("--timeout", default=None, help="如 168h、7d")
78
+ # —— 数据:默认 NVIDIA 下载到可写目录 ——
79
+ parser.add_argument(
80
+ "--data-prep-dir",
81
+ default=DEFAULT_DATA_PREP_DIR,
82
+ help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖",
83
+ )
84
+ parser.add_argument(
85
+ "--download-workers",
86
+ type=int,
87
+ default=8,
88
+ help="download_data.py --workers",
89
+ )
90
+ parser.add_argument(
91
+ "--download-limit",
92
+ type=int,
93
+ default=8,
94
+ metavar="N",
95
+ help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)",
96
+ )
97
+ parser.add_argument(
98
+ "--skip-download",
99
+ action="store_true",
100
+ help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)",
101
+ )
102
+ # —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)——
103
+ parser.add_argument(
104
+ "--mount-hub-dataset",
105
+ action="store_true",
106
+ help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)",
107
+ )
108
+ parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID")
109
+ parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT)
110
+
111
+ args = parser.parse_args()
112
+
113
+ code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount)
114
+ ro_mount = args.mount
115
+ work = "/tmp/wjad-run"
116
+ volumes: list[Volume] = [code_vol]
117
+
118
+ if args.mount_hub_dataset:
119
+ volumes.append(
120
+ Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount)
121
+ )
122
+
123
+ use_auto_download = not args.skip_download and not args.mount_hub_dataset
124
+ data_root_default = (
125
+ args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir
126
+ )
127
+
128
+ limit_tail = ""
129
+ if use_auto_download and args.download_limit > 0:
130
+ limit_tail = f" --limit {args.download_limit}"
131
+
132
+ download_block = ""
133
+ if use_auto_download:
134
+ download_block = f"""
135
+ mkdir -p "$DATA_ROOT"
136
+ python scripts/download_data.py --odir "$DATA_ROOT" \\
137
+ --file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail}
138
+ """
139
+
140
+ script = f"""set -euo pipefail
141
+ rm -rf {work}
142
+ cp -a {ro_mount} {work}
143
+ cd {work}
144
+ export PIP_ROOT_USER_ACTION=ignore
145
+ pip install --root-user-action=ignore --no-cache-dir -e .
146
+ export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}"
147
+ DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}"
148
+ {download_block}
149
+ python -m wjad.train.runner_local \\
150
+ --device cuda \\
151
+ --config configs/default.yaml \\
152
+ --data_root "$DATA_ROOT" \\
153
+ --dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}"
154
+ """
155
+
156
+ secrets = None if args.no_secrets else _secrets_for_job()
157
+ if secrets is None and not args.no_secrets:
158
+ print(
159
+ "[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。",
160
+ file=sys.stderr,
161
+ )
162
+
163
+ if args.mount_hub_dataset:
164
+ print(
165
+ f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};"
166
+ "若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset,"
167
+ "改用默认自动 download。"
168
+ )
169
+ elif use_auto_download:
170
+ lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)"
171
+ print(
172
+ f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。"
173
+ "全量请 --download-limit 0 并保证磁盘或 Bucket。"
174
+ )
175
+ else:
176
+ print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。")
177
+
178
+ api = HfApi()
179
+ job = api.run_job(
180
+ image=args.image,
181
+ command=["bash", "-lc", script],
182
+ flavor=args.flavor,
183
+ volumes=volumes,
184
+ secrets=secrets,
185
+ timeout=args.timeout,
186
+ )
187
+ print(f"[push_to_jobs] Job ID: {job.id}")
188
+ print(f"[push_to_jobs] URL: {job.url}")
189
+
190
+ if args.follow:
191
+ for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
192
+ print(line, end="" if str(line).endswith("\n") else "\n")
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()
scripts/push_to_sandbox.py CHANGED
@@ -1,185 +1,185 @@
1
- """推送代码到 HF Space,做 sandbox 微训练。
2
-
3
- 依据 ``estimate_memory.py`` 的估算:
4
- - BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存;
5
- - 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致;
6
- - 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。
7
-
8
- 本脚本:
9
- 1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space,
10
- Space SDK = ``docker``;
11
- 2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等);
12
- 3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。
13
-
14
- 要求:先在本地 ``hf auth login``。
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- import argparse
20
- from pathlib import Path
21
-
22
- from huggingface_hub import HfApi, create_repo
23
-
24
- ROOT = Path(__file__).resolve().parent.parent
25
- DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
26
-
27
- ENV DEBIAN_FRONTEND=noninteractive
28
- ENV PYTHONUNBUFFERED=1
29
- RUN apt-get update && apt-get install -y --no-install-recommends \\
30
- python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\
31
- && rm -rf /var/lib/apt/lists/*
32
-
33
- # HF Space 默认用户(避免权限问题)
34
- RUN useradd -m -u 1000 user
35
- USER user
36
- ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
37
-
38
- WORKDIR /app
39
- COPY --chown=user pyproject.toml /app/
40
- COPY --chown=user src /app/src
41
- COPY --chown=user scripts /app/scripts
42
- COPY --chown=user configs /app/configs
43
- COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m
44
- COPY --chown=user app.py /app/app.py
45
-
46
- RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\
47
- && python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\
48
- && python3 -m pip install --user --no-cache-dir -e .
49
-
50
- EXPOSE 7860
51
- CMD ["python3", "app.py"]
52
- """
53
-
54
- APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。
55
-
56
- 启动后:
57
- 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)
58
- 2. 主进程开 HTTP server on :7860,返回最新日志
59
-
60
- 阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。
61
- 阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。
62
- """
63
- import os
64
- import subprocess
65
- import sys
66
- import threading
67
- from http.server import BaseHTTPRequestHandler, HTTPServer
68
-
69
- LOG_PATH = "/tmp/wjad.log"
70
- PORT = 7860
71
- # 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。
72
- _MODE = os.environ.get("SANDBOX_MODE", "smoke")
73
- if _MODE == "real_data":
74
- LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"]
75
- else:
76
- LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"]
77
-
78
-
79
- def _print_env(f):
80
- f.write("=" * 72 + "\\n")
81
- f.write(" WJAD HF Sandbox\\n")
82
- f.write("=" * 72 + "\\n")
83
- f.write(f"Python: {sys.version}\\n")
84
- try:
85
- import torch
86
- f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n")
87
- if torch.cuda.is_available():
88
- p = torch.cuda.get_device_properties(0)
89
- f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n")
90
- except Exception as e:
91
- f.write(f"torch import failed: {e}\\n")
92
- f.flush()
93
-
94
-
95
- def run_training():
96
- with open(LOG_PATH, "w", buffering=1) as f:
97
- _print_env(f)
98
- f.write(f"$ {' '.join(LAUNCH_CMD)}\\n")
99
- f.flush()
100
- p = subprocess.Popen(
101
- LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app"
102
- )
103
- rc = p.wait()
104
- f.write(f"\\n[exit code = {rc}]\\n")
105
-
106
-
107
- class Handler(BaseHTTPRequestHandler):
108
- def do_GET(self):
109
- try:
110
- with open(LOG_PATH, "r") as f:
111
- body = f.read()
112
- except FileNotFoundError:
113
- body = "starting..."
114
- self.send_response(200)
115
- self.send_header("Content-Type", "text/plain; charset=utf-8")
116
- self.end_headers()
117
- self.wfile.write(body.encode("utf-8"))
118
-
119
- def log_message(self, fmt, *args):
120
- return
121
-
122
-
123
- if __name__ == "__main__":
124
- threading.Thread(target=run_training, daemon=True).start()
125
- HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
126
- '''
127
-
128
-
129
- def main() -> None:
130
- parser = argparse.ArgumentParser()
131
- parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox")
132
- parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)")
133
- parser.add_argument("--private", action="store_true")
134
- parser.add_argument(
135
- "--mode",
136
- choices=["smoke", "real_data"],
137
- default="smoke",
138
- help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer",
139
- )
140
- args = parser.parse_args()
141
-
142
- api = HfApi()
143
- print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})")
144
- create_repo(
145
- args.repo,
146
- repo_type="space",
147
- space_sdk="docker",
148
- space_hardware=args.gpu,
149
- private=args.private,
150
- exist_ok=True,
151
- )
152
-
153
- # 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings
154
- # -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。
155
- api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode)
156
- if args.mode == "real_data":
157
- print(
158
- "[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets "
159
- "里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-"
160
- "Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。"
161
- )
162
-
163
- # 落盘 Dockerfile / app.py
164
- (ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8")
165
- (ROOT / "app.py").write_text(APP_PY, encoding="utf-8")
166
-
167
- print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...")
168
- api.upload_folder(
169
- folder_path=str(ROOT),
170
- repo_id=args.repo,
171
- repo_type="space",
172
- ignore_patterns=[
173
- ".venv/*",
174
- "data/*",
175
- "**/__pycache__/*",
176
- "*.pyc",
177
- "agent-tools/*",
178
- ".git/*",
179
- ],
180
- )
181
- print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}")
182
-
183
-
184
- if __name__ == "__main__":
185
- main()
 
1
+ """推送代码到 HF Space,做 sandbox 微训练。
2
+
3
+ 依据 ``estimate_memory.py`` 的估算:
4
+ - BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存;
5
+ - 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致;
6
+ - 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。
7
+
8
+ 本脚本:
9
+ 1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space,
10
+ Space SDK = ``docker``;
11
+ 2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等);
12
+ 3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。
13
+
14
+ 要求:先在本地 ``hf auth login``。
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ from pathlib import Path
21
+
22
+ from huggingface_hub import HfApi, create_repo
23
+
24
+ ROOT = Path(__file__).resolve().parent.parent
25
+ DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
26
+
27
+ ENV DEBIAN_FRONTEND=noninteractive
28
+ ENV PYTHONUNBUFFERED=1
29
+ RUN apt-get update && apt-get install -y --no-install-recommends \\
30
+ python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\
31
+ && rm -rf /var/lib/apt/lists/*
32
+
33
+ # HF Space 默认用户(避免权限问题)
34
+ RUN useradd -m -u 1000 user
35
+ USER user
36
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
37
+
38
+ WORKDIR /app
39
+ COPY --chown=user pyproject.toml /app/
40
+ COPY --chown=user src /app/src
41
+ COPY --chown=user scripts /app/scripts
42
+ COPY --chown=user configs /app/configs
43
+ COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m
44
+ COPY --chown=user app.py /app/app.py
45
+
46
+ RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\
47
+ && python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\
48
+ && python3 -m pip install --user --no-cache-dir -e .
49
+
50
+ EXPOSE 7860
51
+ CMD ["python3", "app.py"]
52
+ """
53
+
54
+ APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。
55
+
56
+ 启动后:
57
+ 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)
58
+ 2. 主进程开 HTTP server on :7860,返回最新日志
59
+
60
+ 阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。
61
+ 阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。
62
+ """
63
+ import os
64
+ import subprocess
65
+ import sys
66
+ import threading
67
+ from http.server import BaseHTTPRequestHandler, HTTPServer
68
+
69
+ LOG_PATH = "/tmp/wjad.log"
70
+ PORT = 7860
71
+ # 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。
72
+ _MODE = os.environ.get("SANDBOX_MODE", "smoke")
73
+ if _MODE == "real_data":
74
+ LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"]
75
+ else:
76
+ LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"]
77
+
78
+
79
+ def _print_env(f):
80
+ f.write("=" * 72 + "\\n")
81
+ f.write(" WJAD HF Sandbox\\n")
82
+ f.write("=" * 72 + "\\n")
83
+ f.write(f"Python: {sys.version}\\n")
84
+ try:
85
+ import torch
86
+ f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n")
87
+ if torch.cuda.is_available():
88
+ p = torch.cuda.get_device_properties(0)
89
+ f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n")
90
+ except Exception as e:
91
+ f.write(f"torch import failed: {e}\\n")
92
+ f.flush()
93
+
94
+
95
+ def run_training():
96
+ with open(LOG_PATH, "w", buffering=1) as f:
97
+ _print_env(f)
98
+ f.write(f"$ {' '.join(LAUNCH_CMD)}\\n")
99
+ f.flush()
100
+ p = subprocess.Popen(
101
+ LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app"
102
+ )
103
+ rc = p.wait()
104
+ f.write(f"\\n[exit code = {rc}]\\n")
105
+
106
+
107
+ class Handler(BaseHTTPRequestHandler):
108
+ def do_GET(self):
109
+ try:
110
+ with open(LOG_PATH, "r") as f:
111
+ body = f.read()
112
+ except FileNotFoundError:
113
+ body = "starting..."
114
+ self.send_response(200)
115
+ self.send_header("Content-Type", "text/plain; charset=utf-8")
116
+ self.end_headers()
117
+ self.wfile.write(body.encode("utf-8"))
118
+
119
+ def log_message(self, fmt, *args):
120
+ return
121
+
122
+
123
+ if __name__ == "__main__":
124
+ threading.Thread(target=run_training, daemon=True).start()
125
+ HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
126
+ '''
127
+
128
+
129
+ def main() -> None:
130
+ parser = argparse.ArgumentParser()
131
+ parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox")
132
+ parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)")
133
+ parser.add_argument("--private", action="store_true")
134
+ parser.add_argument(
135
+ "--mode",
136
+ choices=["smoke", "real_data"],
137
+ default="smoke",
138
+ help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer",
139
+ )
140
+ args = parser.parse_args()
141
+
142
+ api = HfApi()
143
+ print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})")
144
+ create_repo(
145
+ args.repo,
146
+ repo_type="space",
147
+ space_sdk="docker",
148
+ space_hardware=args.gpu,
149
+ private=args.private,
150
+ exist_ok=True,
151
+ )
152
+
153
+ # 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings
154
+ # -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。
155
+ api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode)
156
+ if args.mode == "real_data":
157
+ print(
158
+ "[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets "
159
+ "里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-"
160
+ "Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。"
161
+ )
162
+
163
+ # 落盘 Dockerfile / app.py
164
+ (ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8")
165
+ (ROOT / "app.py").write_text(APP_PY, encoding="utf-8")
166
+
167
+ print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...")
168
+ api.upload_folder(
169
+ folder_path=str(ROOT),
170
+ repo_id=args.repo,
171
+ repo_type="space",
172
+ ignore_patterns=[
173
+ ".venv/*",
174
+ "data/*",
175
+ "**/__pycache__/*",
176
+ "*.pyc",
177
+ "agent-tools/*",
178
+ ".git/*",
179
+ ],
180
+ )
181
+ print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}")
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()
scripts/sandbox_real_data.py CHANGED
@@ -1,236 +1,236 @@
1
- """Sandbox 真实数据微验证脚本。
2
-
3
- 由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17
4
- 个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。
5
- 因此本脚本采用混合方案:
6
-
7
- 1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的
8
- 全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB;
9
- 2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配
10
- ``wjad.data.cosmos_dataset`` 期待的布局;
11
- 3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频
12
- (文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768);
13
- 4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 +
14
- 伪造视觉的训练。
15
-
16
- 这样能验证:
17
- - 数据集索引(``build_clip_index``)
18
- - 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参)
19
- - LIDAR 加载与遮挡过滤
20
- - Hungarian 匹配 + DETR loss
21
- - 端到端 forward / GradNorm / PCGrad / 反传
22
-
23
- 但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。
24
- """
25
-
26
- from __future__ import annotations
27
-
28
- import os
29
- import shutil
30
- import subprocess
31
- import sys
32
- import tarfile
33
- import urllib.request
34
- from pathlib import Path
35
-
36
- ROOT = Path(__file__).resolve().parent.parent
37
- sys.path.insert(0, str(ROOT / "src"))
38
-
39
- DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos"))
40
- NV_DOWNLOAD_URL = (
41
- "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
42
- )
43
-
44
-
45
- def _print_section(title: str) -> None:
46
- bar = "=" * 60
47
- print(f"\n{bar}\n{title}\n{bar}", flush=True)
48
-
49
-
50
- def step1_download_labels() -> None:
51
- """用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。"""
52
- _print_section("STEP 1 下载真实标签(1 个 clip)")
53
- DATA_ROOT.mkdir(parents=True, exist_ok=True)
54
- nv_script = DATA_ROOT / ".nvidia_download.py"
55
- if not nv_script.exists():
56
- print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True)
57
- with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f:
58
- f.write(r.read())
59
- # 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载,
60
- # 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。
61
- cmd = [
62
- sys.executable,
63
- str(nv_script),
64
- "--odir", str(DATA_ROOT),
65
- "--file_types", "lidar,hdmap",
66
- "--workers", "4",
67
- "--limit", "1",
68
- ]
69
- print(f"$ {' '.join(cmd)}", flush=True)
70
- rc = subprocess.call(cmd)
71
- if rc != 0:
72
- sys.exit(f"download.py 失败 rc={rc}")
73
-
74
-
75
- def _hoist_single_subdir(out_dir: Path) -> None:
76
- """若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。"""
77
- if not out_dir.is_dir():
78
- return
79
- subs = [p for p in out_dir.iterdir() if p.is_dir()]
80
- files = [p for p in out_dir.iterdir() if p.is_file()]
81
- if len(subs) == 1 and not files:
82
- child = subs[0]
83
- for item in child.iterdir():
84
- dest = out_dir / item.name
85
- if dest.exists():
86
- continue
87
- shutil.move(str(item), str(dest))
88
- try:
89
- child.rmdir()
90
- except OSError:
91
- pass
92
-
93
-
94
- def step2_reorganize_labels() -> str:
95
- """把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。
96
-
97
- 返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。
98
- """
99
- _print_section("STEP 2 解压标签到 labels/<clip_id>/<folder> 布局")
100
-
101
- common_folders = [
102
- "all_object_info",
103
- "captions",
104
- "car_mask_coarse",
105
- "ftheta_intrinsic",
106
- "pinhole_intrinsic",
107
- "pose",
108
- "vehicle_pose",
109
- "lidar_raw",
110
- # HDMap 9 类
111
- "3d_lanes",
112
- "3d_lanelines",
113
- "3d_road_boundaries",
114
- "3d_wait_lines",
115
- "3d_crosswalks",
116
- "3d_road_markings",
117
- "3d_poles",
118
- "3d_traffic_lights",
119
- "3d_traffic_signs",
120
- ]
121
-
122
- clip_id_full: str | None = None # {clip_id}_{start}_{end}
123
- clip_id: str | None = None
124
-
125
- for folder in common_folders:
126
- src = DATA_ROOT / folder
127
- if not src.exists():
128
- print(f" - skip {folder} (not downloaded)", flush=True)
129
- continue
130
- tars = sorted(src.glob("*.tar"))
131
- if not tars:
132
- print(f" - skip {folder} (no .tar)", flush=True)
133
- continue
134
- if clip_id_full is None:
135
- clip_id_full = tars[0].stem
136
- clip_id = clip_id_full.rsplit("_", 2)[0]
137
- print(f" -> chosen clip_id_full = {clip_id_full}", flush=True)
138
- print(f" -> video / symlink clip_id = {clip_id}", flush=True)
139
- use_tars = [t for t in tars if t.stem == clip_id_full]
140
- if not use_tars:
141
- print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True)
142
- continue
143
- tar_path = use_tars[0]
144
- # 目标目录
145
- out_dir = DATA_ROOT / "labels" / clip_id_full / folder
146
- out_dir.mkdir(parents=True, exist_ok=True)
147
- with tarfile.open(tar_path, "r") as tf:
148
- tf.extractall(out_dir)
149
- _hoist_single_subdir(out_dir)
150
- # 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...)
151
- _hoist_single_subdir(out_dir)
152
- # 列几个样例
153
- members = sorted(out_dir.rglob("*"))[:3]
154
- for m in members:
155
- print(f" {m.relative_to(DATA_ROOT)}", flush=True)
156
- print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True)
157
-
158
- if clip_id_full is None:
159
- sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集")
160
-
161
- # 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是
162
- # {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。
163
- short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type]
164
- if not short_dir.exists():
165
- try:
166
- short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True)
167
- except OSError:
168
- shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir)
169
- return clip_id # type: ignore[return-value]
170
-
171
-
172
- def step3_make_fake_video(clip_id: str) -> None:
173
- """合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。"""
174
- _print_section("STEP 3 合成占位视频(随机噪声 mp4)")
175
- import numpy as np
176
- import cv2
177
-
178
- syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation"
179
- syn_dir.mkdir(parents=True, exist_ok=True)
180
- out_path = syn_dir / f"{clip_id}_0_Sunny.mp4"
181
-
182
- H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768
183
- rng = np.random.default_rng(0)
184
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
185
- writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H))
186
- if not writer.isOpened():
187
- sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}")
188
- for _ in range(T):
189
- frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)
190
- writer.write(frame)
191
- writer.release()
192
- print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True)
193
-
194
-
195
- def step4_run_trainer(clip_id: str) -> None:
196
- """跑 runner_local --tiny --max_steps 4。"""
197
- _print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)")
198
- cmd = [
199
- sys.executable,
200
- "-m",
201
- "wjad.train.runner_local",
202
- "--config", str(ROOT / "configs" / "default.yaml"),
203
- "--data_root", str(DATA_ROOT),
204
- "--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
205
- "--device", "cuda" if _has_cuda() else "cpu",
206
- "--tiny",
207
- "--max_steps", "4",
208
- ]
209
- env = os.environ.copy()
210
- env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "")
211
- print(f"$ {' '.join(cmd)}", flush=True)
212
- rc = subprocess.call(cmd, env=env)
213
- if rc != 0:
214
- sys.exit(f"trainer 失败 rc={rc}")
215
-
216
-
217
- def _has_cuda() -> bool:
218
- try:
219
- import torch
220
- return torch.cuda.is_available()
221
- except Exception:
222
- return False
223
-
224
-
225
- def main() -> None:
226
- _print_section("WJAD Sandbox Real-Data Tiny Test")
227
- print(f"DATA_ROOT = {DATA_ROOT}", flush=True)
228
- step1_download_labels()
229
- clip_id = step2_reorganize_labels()
230
- step3_make_fake_video(clip_id)
231
- step4_run_trainer(clip_id)
232
- _print_section("DONE")
233
-
234
-
235
- if __name__ == "__main__":
236
- main()
 
1
+ """Sandbox 真实数据微验证脚本。
2
+
3
+ 由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17
4
+ 个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。
5
+ 因此本脚本采用混合方案:
6
+
7
+ 1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的
8
+ 全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB;
9
+ 2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配
10
+ ``wjad.data.cosmos_dataset`` 期待的布局;
11
+ 3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频
12
+ (文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768);
13
+ 4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 +
14
+ 伪造视觉的训练。
15
+
16
+ 这样能验证:
17
+ - 数据集索引(``build_clip_index``)
18
+ - 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参)
19
+ - LIDAR 加载与遮挡过滤
20
+ - Hungarian 匹配 + DETR loss
21
+ - 端到端 forward / GradNorm / PCGrad / 反传
22
+
23
+ 但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import os
29
+ import shutil
30
+ import subprocess
31
+ import sys
32
+ import tarfile
33
+ import urllib.request
34
+ from pathlib import Path
35
+
36
+ ROOT = Path(__file__).resolve().parent.parent
37
+ sys.path.insert(0, str(ROOT / "src"))
38
+
39
+ DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos"))
40
+ NV_DOWNLOAD_URL = (
41
+ "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
42
+ )
43
+
44
+
45
+ def _print_section(title: str) -> None:
46
+ bar = "=" * 60
47
+ print(f"\n{bar}\n{title}\n{bar}", flush=True)
48
+
49
+
50
+ def step1_download_labels() -> None:
51
+ """用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。"""
52
+ _print_section("STEP 1 下载真实标签(1 个 clip)")
53
+ DATA_ROOT.mkdir(parents=True, exist_ok=True)
54
+ nv_script = DATA_ROOT / ".nvidia_download.py"
55
+ if not nv_script.exists():
56
+ print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True)
57
+ with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f:
58
+ f.write(r.read())
59
+ # 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载,
60
+ # 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。
61
+ cmd = [
62
+ sys.executable,
63
+ str(nv_script),
64
+ "--odir", str(DATA_ROOT),
65
+ "--file_types", "lidar,hdmap",
66
+ "--workers", "4",
67
+ "--limit", "1",
68
+ ]
69
+ print(f"$ {' '.join(cmd)}", flush=True)
70
+ rc = subprocess.call(cmd)
71
+ if rc != 0:
72
+ sys.exit(f"download.py 失败 rc={rc}")
73
+
74
+
75
+ def _hoist_single_subdir(out_dir: Path) -> None:
76
+ """若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。"""
77
+ if not out_dir.is_dir():
78
+ return
79
+ subs = [p for p in out_dir.iterdir() if p.is_dir()]
80
+ files = [p for p in out_dir.iterdir() if p.is_file()]
81
+ if len(subs) == 1 and not files:
82
+ child = subs[0]
83
+ for item in child.iterdir():
84
+ dest = out_dir / item.name
85
+ if dest.exists():
86
+ continue
87
+ shutil.move(str(item), str(dest))
88
+ try:
89
+ child.rmdir()
90
+ except OSError:
91
+ pass
92
+
93
+
94
+ def step2_reorganize_labels() -> str:
95
+ """把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。
96
+
97
+ 返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。
98
+ """
99
+ _print_section("STEP 2 解压标签到 labels/<clip_id>/<folder> 布局")
100
+
101
+ common_folders = [
102
+ "all_object_info",
103
+ "captions",
104
+ "car_mask_coarse",
105
+ "ftheta_intrinsic",
106
+ "pinhole_intrinsic",
107
+ "pose",
108
+ "vehicle_pose",
109
+ "lidar_raw",
110
+ # HDMap 9 类
111
+ "3d_lanes",
112
+ "3d_lanelines",
113
+ "3d_road_boundaries",
114
+ "3d_wait_lines",
115
+ "3d_crosswalks",
116
+ "3d_road_markings",
117
+ "3d_poles",
118
+ "3d_traffic_lights",
119
+ "3d_traffic_signs",
120
+ ]
121
+
122
+ clip_id_full: str | None = None # {clip_id}_{start}_{end}
123
+ clip_id: str | None = None
124
+
125
+ for folder in common_folders:
126
+ src = DATA_ROOT / folder
127
+ if not src.exists():
128
+ print(f" - skip {folder} (not downloaded)", flush=True)
129
+ continue
130
+ tars = sorted(src.glob("*.tar"))
131
+ if not tars:
132
+ print(f" - skip {folder} (no .tar)", flush=True)
133
+ continue
134
+ if clip_id_full is None:
135
+ clip_id_full = tars[0].stem
136
+ clip_id = clip_id_full.rsplit("_", 2)[0]
137
+ print(f" -> chosen clip_id_full = {clip_id_full}", flush=True)
138
+ print(f" -> video / symlink clip_id = {clip_id}", flush=True)
139
+ use_tars = [t for t in tars if t.stem == clip_id_full]
140
+ if not use_tars:
141
+ print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True)
142
+ continue
143
+ tar_path = use_tars[0]
144
+ # 目标目录
145
+ out_dir = DATA_ROOT / "labels" / clip_id_full / folder
146
+ out_dir.mkdir(parents=True, exist_ok=True)
147
+ with tarfile.open(tar_path, "r") as tf:
148
+ tf.extractall(out_dir)
149
+ _hoist_single_subdir(out_dir)
150
+ # 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...)
151
+ _hoist_single_subdir(out_dir)
152
+ # 列几个样例
153
+ members = sorted(out_dir.rglob("*"))[:3]
154
+ for m in members:
155
+ print(f" {m.relative_to(DATA_ROOT)}", flush=True)
156
+ print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True)
157
+
158
+ if clip_id_full is None:
159
+ sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集")
160
+
161
+ # 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是
162
+ # {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。
163
+ short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type]
164
+ if not short_dir.exists():
165
+ try:
166
+ short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True)
167
+ except OSError:
168
+ shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir)
169
+ return clip_id # type: ignore[return-value]
170
+
171
+
172
+ def step3_make_fake_video(clip_id: str) -> None:
173
+ """合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。"""
174
+ _print_section("STEP 3 合成占位视频(随机噪声 mp4)")
175
+ import numpy as np
176
+ import cv2
177
+
178
+ syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation"
179
+ syn_dir.mkdir(parents=True, exist_ok=True)
180
+ out_path = syn_dir / f"{clip_id}_0_Sunny.mp4"
181
+
182
+ H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768
183
+ rng = np.random.default_rng(0)
184
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
185
+ writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H))
186
+ if not writer.isOpened():
187
+ sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}")
188
+ for _ in range(T):
189
+ frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)
190
+ writer.write(frame)
191
+ writer.release()
192
+ print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True)
193
+
194
+
195
+ def step4_run_trainer(clip_id: str) -> None:
196
+ """跑 runner_local --tiny --max_steps 4。"""
197
+ _print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)")
198
+ cmd = [
199
+ sys.executable,
200
+ "-m",
201
+ "wjad.train.runner_local",
202
+ "--config", str(ROOT / "configs" / "default.yaml"),
203
+ "--data_root", str(DATA_ROOT),
204
+ "--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
205
+ "--device", "cuda" if _has_cuda() else "cpu",
206
+ "--tiny",
207
+ "--max_steps", "4",
208
+ ]
209
+ env = os.environ.copy()
210
+ env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "")
211
+ print(f"$ {' '.join(cmd)}", flush=True)
212
+ rc = subprocess.call(cmd, env=env)
213
+ if rc != 0:
214
+ sys.exit(f"trainer 失败 rc={rc}")
215
+
216
+
217
+ def _has_cuda() -> bool:
218
+ try:
219
+ import torch
220
+ return torch.cuda.is_available()
221
+ except Exception:
222
+ return False
223
+
224
+
225
+ def main() -> None:
226
+ _print_section("WJAD Sandbox Real-Data Tiny Test")
227
+ print(f"DATA_ROOT = {DATA_ROOT}", flush=True)
228
+ step1_download_labels()
229
+ clip_id = step2_reorganize_labels()
230
+ step3_make_fake_video(clip_id)
231
+ step4_run_trainer(clip_id)
232
+ _print_section("DONE")
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()
scripts/smoke_test.py CHANGED
@@ -1,78 +1,78 @@
1
- """本地烟囱测试:用随机张量验证 forward + backward。
2
-
3
- 运行:
4
- python -m scripts.smoke_test
5
- 或:
6
- python scripts/smoke_test.py
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import sys
12
- from pathlib import Path
13
-
14
- # 允许直接运行
15
- ROOT = Path(__file__).resolve().parent.parent
16
- sys.path.insert(0, str(ROOT / "src"))
17
-
18
- import torch
19
-
20
- from wjad.model import E2EAVModel
21
-
22
-
23
- def main() -> None:
24
- torch.manual_seed(0)
25
- device = "cpu"
26
-
27
- print("[smoke_test] 构建模型...")
28
- model = E2EAVModel(
29
- dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
30
- # 减小测试规模以适配 CPU
31
- num_dense_layers=2,
32
- num_moe_layers=2,
33
- num_detection_tokens=64,
34
- num_extra_tokens=32,
35
- num_classes=22,
36
- ).to(device)
37
-
38
- # 切到 sparse 验证 Top-3 路径也通
39
- model.backbone.set_moe_mode("sparse")
40
-
41
- B, T = 1, 8
42
- images = torch.randn(B, T, 3, 384, 1024, device=device)
43
- ego_6d = torch.zeros(B, T, 6, device=device)
44
- ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进
45
- intr_vec = torch.tensor([[
46
- 512.0, 192.0, 1024, 384, # cx, cy, w, h
47
- 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly
48
- 1.0, # is_bw_poly(与 Cosmos 11 维一致)
49
- ]], device=device)
50
- extr_6d = torch.zeros(B, 6, device=device)
51
-
52
- print("[smoke_test] 前向...")
53
- out = model(images, ego_6d, intr_vec, extr_6d)
54
- print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
55
- print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
56
- print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
57
- print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
58
- print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
59
- print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")
60
-
61
- # 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss
62
- loss = (
63
- out.detection.cls_logits.float().abs().mean()
64
- + out.detection.box3d_mu.float().abs().mean()
65
- + out.detection.traj_mu.float().abs().mean()
66
- + out.control.ego_traj_mu.float().abs().mean()
67
- )
68
- print(f"[smoke_test] loss = {loss.item():.6f}")
69
- loss.backward()
70
- grad_norm = sum(
71
- p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None
72
- )
73
- print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
74
- print("[smoke_test] OK")
75
-
76
-
77
- if __name__ == "__main__":
78
- main()
 
1
+ """本地烟囱测试:用随机张量验证 forward + backward。
2
+
3
+ 运行:
4
+ python -m scripts.smoke_test
5
+ 或:
6
+ python scripts/smoke_test.py
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # 允许直接运行
15
+ ROOT = Path(__file__).resolve().parent.parent
16
+ sys.path.insert(0, str(ROOT / "src"))
17
+
18
+ import torch
19
+
20
+ from wjad.model import E2EAVModel
21
+
22
+
23
+ def main() -> None:
24
+ torch.manual_seed(0)
25
+ device = "cpu"
26
+
27
+ print("[smoke_test] 构建模型...")
28
+ model = E2EAVModel(
29
+ dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
30
+ # 减小测试规模以适配 CPU
31
+ num_dense_layers=2,
32
+ num_moe_layers=2,
33
+ num_detection_tokens=64,
34
+ num_extra_tokens=32,
35
+ num_classes=22,
36
+ ).to(device)
37
+
38
+ # 切到 sparse 验证 Top-3 路径也通
39
+ model.backbone.set_moe_mode("sparse")
40
+
41
+ B, T = 1, 8
42
+ images = torch.randn(B, T, 3, 384, 1024, device=device)
43
+ ego_6d = torch.zeros(B, T, 6, device=device)
44
+ ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进
45
+ intr_vec = torch.tensor([[
46
+ 512.0, 192.0, 1024, 384, # cx, cy, w, h
47
+ 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly
48
+ 1.0, # is_bw_poly(与 Cosmos 11 维一致)
49
+ ]], device=device)
50
+ extr_6d = torch.zeros(B, 6, device=device)
51
+
52
+ print("[smoke_test] 前向...")
53
+ out = model(images, ego_6d, intr_vec, extr_6d)
54
+ print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
55
+ print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
56
+ print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
57
+ print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
58
+ print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
59
+ print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")
60
+
61
+ # 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss
62
+ loss = (
63
+ out.detection.cls_logits.float().abs().mean()
64
+ + out.detection.box3d_mu.float().abs().mean()
65
+ + out.detection.traj_mu.float().abs().mean()
66
+ + out.control.ego_traj_mu.float().abs().mean()
67
+ )
68
+ print(f"[smoke_test] loss = {loss.item():.6f}")
69
+ loss.backward()
70
+ grad_norm = sum(
71
+ p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None
72
+ )
73
+ print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
74
+ print("[smoke_test] OK")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
scripts/smoke_train.py CHANGED
@@ -1,152 +1,152 @@
1
- """端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。
2
-
3
- 不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- import os
9
- import sys
10
- from pathlib import Path
11
-
12
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
13
-
14
- ROOT = Path(__file__).resolve().parent.parent
15
- sys.path.insert(0, str(ROOT / "src"))
16
-
17
- import logging
18
-
19
- import numpy as np
20
- import torch
21
-
22
- from wjad.model import E2EAVModel
23
- from wjad.train.trainer import Trainer, TrainerConfig
24
-
25
-
26
- def _make_dummy_batch(
27
- B: int = 1,
28
- T: int = 8,
29
- H: int = 64,
30
- W: int = 128,
31
- num_classes: int = 22,
32
- num_objects: int = 3,
33
- ) -> dict:
34
- """构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
35
- images = torch.randn(B, T, 3, H, W)
36
- ego_6d = torch.zeros(B, T, 6)
37
- intr_vec = torch.tensor([[
38
- W / 2, H / 2, W, H,
39
- 0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
40
- 1.0,
41
- ]] * B)
42
- extr_6d = torch.zeros(B, 6)
43
- ego_future = torch.zeros(B, 24, 3)
44
- ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
45
-
46
- targets = []
47
- for _ in range(B):
48
- boxes = torch.zeros(num_objects, 7)
49
- boxes[:, 3:6] = 2.0
50
- targets.append({
51
- "labels": torch.randint(1, num_classes, (num_objects,)),
52
- "boxes": boxes,
53
- "is_dynamic": torch.ones(num_objects, dtype=torch.long),
54
- "future_traj": torch.zeros(num_objects, 24, 3),
55
- "future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
56
- })
57
-
58
- return {
59
- "images": images,
60
- "ego_6d": ego_6d,
61
- "intr_vec": intr_vec,
62
- "extr_6d": extr_6d,
63
- "ego_future": ego_future,
64
- "ego_future_valid": ego_future_valid,
65
- "targets": targets,
66
- "meta": [{}] * B,
67
- }
68
-
69
-
70
- def main() -> None:
71
- logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
72
- torch.manual_seed(0)
73
-
74
- has_cuda = torch.cuda.is_available()
75
- device = "cuda" if has_cuda else "cpu"
76
- if has_cuda:
77
- # GPU 上跑接近真实规模:full 384x1024 + 完整 18 层
78
- # a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定
79
- H, W = 384, 1024
80
- B = 2
81
- num_dense, num_moe = 9, 9
82
- num_det = 1024
83
- num_extra = 256
84
- amp = "bf16"
85
- n_steps = 4
86
- use_grad_ckpt = True
87
- else:
88
- # CPU 上跑极小规模仅做 sanity
89
- H, W = 64, 128
90
- B = 1
91
- num_dense, num_moe = 2, 2
92
- num_det = 32
93
- num_extra = 16
94
- amp = "fp32"
95
- n_steps = 4
96
- use_grad_ckpt = False
97
-
98
- print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
99
- model = E2EAVModel(
100
- dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
101
- num_dense_layers=num_dense,
102
- num_moe_layers=num_moe,
103
- num_detection_tokens=num_det,
104
- num_control_tokens=24,
105
- num_ego_tokens=8,
106
- num_extra_tokens=num_extra,
107
- num_classes=22,
108
- image_h=H,
109
- image_w=W,
110
- patch_size=16,
111
- )
112
- if use_grad_ckpt:
113
- model.backbone.set_gradient_checkpointing(True)
114
- # sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可
115
- # 验证两阶段路径切换。完整训练交给 H100 Jobs。
116
- model.dinov3.freeze()
117
-
118
- cfg = TrainerConfig(
119
- total_steps=n_steps,
120
- warmup_steps=1,
121
- base_lr=1e-4,
122
- log_interval=1,
123
- stage1_steps=2, # 跑到 stage2 验证切换路径
124
- stage1_perturb_start=1,
125
- enable_gradnorm=True,
126
- enable_pcgrad=True, # 全程启用 PCGrad
127
- mixed_precision=amp,
128
- unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可
129
- )
130
- trainer = Trainer(model, cfg, num_classes=22, device=device)
131
- rng = np.random.default_rng(0)
132
-
133
- if has_cuda:
134
- torch.cuda.reset_peak_memory_stats()
135
-
136
- for step in range(n_steps):
137
- batch = _make_dummy_batch(B=B, H=H, W=W)
138
- info = trainer.train_step(batch, rng)
139
- print(
140
- f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
141
- f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
142
- f"weights={[f'{w:.2f}' for w in info['weights']]}"
143
- )
144
-
145
- if has_cuda:
146
- peak_gb = torch.cuda.max_memory_allocated() / 1024**3
147
- print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
148
- print("[smoke_train] OK")
149
-
150
-
151
- if __name__ == "__main__":
152
- main()
 
1
+ """端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。
2
+
3
+ 不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
13
+
14
+ ROOT = Path(__file__).resolve().parent.parent
15
+ sys.path.insert(0, str(ROOT / "src"))
16
+
17
+ import logging
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from wjad.model import E2EAVModel
23
+ from wjad.train.trainer import Trainer, TrainerConfig
24
+
25
+
26
+ def _make_dummy_batch(
27
+ B: int = 1,
28
+ T: int = 8,
29
+ H: int = 64,
30
+ W: int = 128,
31
+ num_classes: int = 22,
32
+ num_objects: int = 3,
33
+ ) -> dict:
34
+ """构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
35
+ images = torch.randn(B, T, 3, H, W)
36
+ ego_6d = torch.zeros(B, T, 6)
37
+ intr_vec = torch.tensor([[
38
+ W / 2, H / 2, W, H,
39
+ 0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
40
+ 1.0,
41
+ ]] * B)
42
+ extr_6d = torch.zeros(B, 6)
43
+ ego_future = torch.zeros(B, 24, 3)
44
+ ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
45
+
46
+ targets = []
47
+ for _ in range(B):
48
+ boxes = torch.zeros(num_objects, 7)
49
+ boxes[:, 3:6] = 2.0
50
+ targets.append({
51
+ "labels": torch.randint(1, num_classes, (num_objects,)),
52
+ "boxes": boxes,
53
+ "is_dynamic": torch.ones(num_objects, dtype=torch.long),
54
+ "future_traj": torch.zeros(num_objects, 24, 3),
55
+ "future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
56
+ })
57
+
58
+ return {
59
+ "images": images,
60
+ "ego_6d": ego_6d,
61
+ "intr_vec": intr_vec,
62
+ "extr_6d": extr_6d,
63
+ "ego_future": ego_future,
64
+ "ego_future_valid": ego_future_valid,
65
+ "targets": targets,
66
+ "meta": [{}] * B,
67
+ }
68
+
69
+
70
+ def main() -> None:
71
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
72
+ torch.manual_seed(0)
73
+
74
+ has_cuda = torch.cuda.is_available()
75
+ device = "cuda" if has_cuda else "cpu"
76
+ if has_cuda:
77
+ # GPU 上跑接近真实规模:full 384x1024 + 完整 18 层
78
+ # a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定
79
+ H, W = 384, 1024
80
+ B = 2
81
+ num_dense, num_moe = 9, 9
82
+ num_det = 1024
83
+ num_extra = 256
84
+ amp = "bf16"
85
+ n_steps = 4
86
+ use_grad_ckpt = True
87
+ else:
88
+ # CPU 上跑极小规模仅做 sanity
89
+ H, W = 64, 128
90
+ B = 1
91
+ num_dense, num_moe = 2, 2
92
+ num_det = 32
93
+ num_extra = 16
94
+ amp = "fp32"
95
+ n_steps = 4
96
+ use_grad_ckpt = False
97
+
98
+ print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
99
+ model = E2EAVModel(
100
+ dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
101
+ num_dense_layers=num_dense,
102
+ num_moe_layers=num_moe,
103
+ num_detection_tokens=num_det,
104
+ num_control_tokens=24,
105
+ num_ego_tokens=8,
106
+ num_extra_tokens=num_extra,
107
+ num_classes=22,
108
+ image_h=H,
109
+ image_w=W,
110
+ patch_size=16,
111
+ )
112
+ if use_grad_ckpt:
113
+ model.backbone.set_gradient_checkpointing(True)
114
+ # sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可
115
+ # 验证两阶段路径切换。完整训练交给 H100 Jobs。
116
+ model.dinov3.freeze()
117
+
118
+ cfg = TrainerConfig(
119
+ total_steps=n_steps,
120
+ warmup_steps=1,
121
+ base_lr=1e-4,
122
+ log_interval=1,
123
+ stage1_steps=2, # 跑到 stage2 验证切换路径
124
+ stage1_perturb_start=1,
125
+ enable_gradnorm=True,
126
+ enable_pcgrad=True, # 全程启用 PCGrad
127
+ mixed_precision=amp,
128
+ unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可
129
+ )
130
+ trainer = Trainer(model, cfg, num_classes=22, device=device)
131
+ rng = np.random.default_rng(0)
132
+
133
+ if has_cuda:
134
+ torch.cuda.reset_peak_memory_stats()
135
+
136
+ for step in range(n_steps):
137
+ batch = _make_dummy_batch(B=B, H=H, W=W)
138
+ info = trainer.train_step(batch, rng)
139
+ print(
140
+ f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
141
+ f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
142
+ f"weights={[f'{w:.2f}' for w in info['weights']]}"
143
+ )
144
+
145
+ if has_cuda:
146
+ peak_gb = torch.cuda.max_memory_allocated() / 1024**3
147
+ print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
148
+ print("[smoke_train] OK")
149
+
150
+
151
+ if __name__ == "__main__":
152
+ main()
scripts/update_deps.py CHANGED
@@ -1,123 +1,123 @@
1
- """自动把项目依赖升级到 PyPI 最新版。
2
-
3
- 特点:
4
- - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
5
- ``project.optional-dependencies``;
6
- - 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
7
- - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
8
- 选项(``--torch-index https://download.pytorch.org/whl/cu124``);
9
- - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
10
- 在 HF Sandbox / Jobs 环境中复现。
11
-
12
- 注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
13
- """
14
-
15
- from __future__ import annotations
16
-
17
- import argparse
18
- import shutil
19
- import subprocess
20
- import sys
21
- from pathlib import Path
22
-
23
- try:
24
- import tomllib # py3.11+
25
- except ImportError: # pragma: no cover
26
- import tomli as tomllib # type: ignore
27
-
28
-
29
- ROOT = Path(__file__).resolve().parent.parent
30
- PYPROJECT = ROOT / "pyproject.toml"
31
- LOCK_FILE = ROOT / "requirements.lock.txt"
32
- TORCH_PKGS = {"torch", "torchvision", "torchaudio"}
33
-
34
-
35
- def parse_pyproject() -> tuple[list[str], list[str]]:
36
- """返回 (主依赖, dev 依赖) 的纯包名列表。"""
37
- data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
38
- main = [
39
- _strip_spec(d) for d in data.get("project", {}).get("dependencies", [])
40
- ]
41
- dev = [
42
- _strip_spec(d)
43
- for d in data.get("project", {})
44
- .get("optional-dependencies", {})
45
- .get("dev", [])
46
- ]
47
- return main, dev
48
-
49
-
50
- def _strip_spec(req: str) -> str:
51
- """去掉版本约束,只留包名。"""
52
- name = req.split(";")[0] # 去掉 environment marker
53
- for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"):
54
- if sym in name:
55
- name = name.split(sym)[0]
56
- return name.strip()
57
-
58
-
59
- def run(cmd: list[str], dry_run: bool = False) -> int:
60
- print("$", " ".join(cmd))
61
- if dry_run:
62
- return 0
63
- return subprocess.call(cmd)
64
-
65
-
66
- def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
67
- base = [sys.executable, "-m", "pip", "install", "--upgrade"]
68
- if with_pre:
69
- base.append("--pre")
70
- if extra_index:
71
- base += ["--extra-index-url", extra_index]
72
- rc = run(base + pkgs, dry_run=dry_run)
73
- if rc != 0:
74
- sys.exit(rc)
75
-
76
-
77
- def main() -> None:
78
- parser = argparse.ArgumentParser()
79
- parser.add_argument(
80
- "--torch-index",
81
- default=None,
82
- help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
83
- )
84
- parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
85
- parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
86
- parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
87
- args = parser.parse_args()
88
-
89
- if not PYPROJECT.exists():
90
- print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
91
- sys.exit(1)
92
-
93
- main_deps, dev_deps = parse_pyproject()
94
-
95
- # 把 torch 系列单独处理(用专用索引)
96
- torch_deps = [p for p in main_deps if p in TORCH_PKGS]
97
- other_deps = [p for p in main_deps if p not in TORCH_PKGS]
98
-
99
- print(f"[update_deps] 升级 pip / setuptools / wheel ...")
100
- upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)
101
-
102
- if torch_deps:
103
- print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
104
- upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)
105
-
106
- if other_deps:
107
- print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
108
- upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
109
-
110
- if dev_deps and not args.no_dev:
111
- print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
112
- upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
113
-
114
- print("[update_deps] 写入锁定文件 ...")
115
- if not args.dry_run:
116
- with open(LOCK_FILE, "w", encoding="utf-8") as f:
117
- subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True)
118
- print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
119
- print("[update_deps] OK")
120
-
121
-
122
- if __name__ == "__main__":
123
- main()
 
1
+ """自动把项目依赖升级到 PyPI 最新版。
2
+
3
+ 特点:
4
+ - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
5
+ ``project.optional-dependencies``;
6
+ - 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
7
+ - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
8
+ 选项(``--torch-index https://download.pytorch.org/whl/cu124``);
9
+ - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
10
+ 在 HF Sandbox / Jobs 环境中复现。
11
+
12
+ 注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import shutil
19
+ import subprocess
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ try:
24
+ import tomllib # py3.11+
25
+ except ImportError: # pragma: no cover
26
+ import tomli as tomllib # type: ignore
27
+
28
+
29
+ ROOT = Path(__file__).resolve().parent.parent
30
+ PYPROJECT = ROOT / "pyproject.toml"
31
+ LOCK_FILE = ROOT / "requirements.lock.txt"
32
+ TORCH_PKGS = {"torch", "torchvision", "torchaudio"}
33
+
34
+
35
+ def parse_pyproject() -> tuple[list[str], list[str]]:
36
+ """返回 (主依赖, dev 依赖) 的纯包名列表。"""
37
+ data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
38
+ main = [
39
+ _strip_spec(d) for d in data.get("project", {}).get("dependencies", [])
40
+ ]
41
+ dev = [
42
+ _strip_spec(d)
43
+ for d in data.get("project", {})
44
+ .get("optional-dependencies", {})
45
+ .get("dev", [])
46
+ ]
47
+ return main, dev
48
+
49
+
50
+ def _strip_spec(req: str) -> str:
51
+ """去掉版本约束,只留包名。"""
52
+ name = req.split(";")[0] # 去掉 environment marker
53
+ for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"):
54
+ if sym in name:
55
+ name = name.split(sym)[0]
56
+ return name.strip()
57
+
58
+
59
+ def run(cmd: list[str], dry_run: bool = False) -> int:
60
+ print("$", " ".join(cmd))
61
+ if dry_run:
62
+ return 0
63
+ return subprocess.call(cmd)
64
+
65
+
66
+ def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
67
+ base = [sys.executable, "-m", "pip", "install", "--upgrade"]
68
+ if with_pre:
69
+ base.append("--pre")
70
+ if extra_index:
71
+ base += ["--extra-index-url", extra_index]
72
+ rc = run(base + pkgs, dry_run=dry_run)
73
+ if rc != 0:
74
+ sys.exit(rc)
75
+
76
+
77
+ def main() -> None:
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument(
80
+ "--torch-index",
81
+ default=None,
82
+ help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
83
+ )
84
+ parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
85
+ parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
86
+ parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
87
+ args = parser.parse_args()
88
+
89
+ if not PYPROJECT.exists():
90
+ print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
91
+ sys.exit(1)
92
+
93
+ main_deps, dev_deps = parse_pyproject()
94
+
95
+ # 把 torch 系列单独处理(用专用索引)
96
+ torch_deps = [p for p in main_deps if p in TORCH_PKGS]
97
+ other_deps = [p for p in main_deps if p not in TORCH_PKGS]
98
+
99
+ print(f"[update_deps] 升级 pip / setuptools / wheel ...")
100
+ upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)
101
+
102
+ if torch_deps:
103
+ print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
104
+ upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)
105
+
106
+ if other_deps:
107
+ print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
108
+ upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
109
+
110
+ if dev_deps and not args.no_dev:
111
+ print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
112
+ upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
113
+
114
+ print("[update_deps] 写入锁定文件 ...")
115
+ if not args.dry_run:
116
+ with open(LOCK_FILE, "w", encoding="utf-8") as f:
117
+ subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True)
118
+ print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
119
+ print("[update_deps] OK")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
src/wjad/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- """WJAD: 端到端自动驾驶模型主包。"""
2
-
3
- from __future__ import annotations
4
-
5
- __version__ = "0.1.0"
 
1
+ """WJAD: 端到端自动驾驶模型主包。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ __version__ = "0.1.0"
src/wjad/backbone/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
- """18 层主干。"""
2
-
3
- from .backbone import Backbone, BackboneOutput
4
- from .blocks import DenseBlock, MoEBlockWithAttn
5
-
6
- __all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"]
 
1
+ """18 层主干。"""
2
+
3
+ from .backbone import Backbone, BackboneOutput
4
+ from .blocks import DenseBlock, MoEBlockWithAttn
5
+
6
+ __all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"]
src/wjad/backbone/backbone.py CHANGED
@@ -1,110 +1,110 @@
1
- """18 层主干:前 9 Dense + 后 9 MoE。"""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass, field
6
- from typing import Optional
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.utils.checkpoint as cp
11
-
12
- from ..modules.moe import MoEStats
13
- from .blocks import DenseBlock, MoEBlockWithAttn
14
-
15
-
16
- @dataclass
17
- class BackboneOutput:
18
- """主干输出。"""
19
-
20
- hidden_states: torch.Tensor # [B, N, D]
21
- moe_stats: list[MoEStats] = field(default_factory=list)
22
-
23
-
24
- class Backbone(nn.Module):
25
- """端到端主干。
26
-
27
- 输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用
28
- 可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。
29
- """
30
-
31
- def __init__(
32
- self,
33
- dim: int = 768,
34
- num_heads: int = 12,
35
- ffn_mult: int = 4,
36
- num_dense_layers: int = 9,
37
- num_moe_layers: int = 9,
38
- num_routed: int = 7,
39
- num_shared: int = 1,
40
- topk: int = 3,
41
- dropout: float = 0.0,
42
- ) -> None:
43
- super().__init__()
44
- self.dim = dim
45
- self.num_heads = num_heads
46
- self.num_dense_layers = num_dense_layers
47
- self.num_moe_layers = num_moe_layers
48
-
49
- self.dense_layers = nn.ModuleList([
50
- DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
51
- for _ in range(num_dense_layers)
52
- ])
53
- self.moe_layers = nn.ModuleList([
54
- MoEBlockWithAttn(
55
- dim,
56
- num_heads,
57
- num_routed=num_routed,
58
- num_shared=num_shared,
59
- topk=topk,
60
- ffn_mult=ffn_mult,
61
- dropout=dropout,
62
- )
63
- for _ in range(num_moe_layers)
64
- ])
65
- self.final_norm = nn.LayerNorm(dim)
66
- # 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存
67
- self.gradient_checkpointing = False
68
-
69
- def set_gradient_checkpointing(self, enabled: bool) -> None:
70
- """开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。"""
71
- self.gradient_checkpointing = enabled
72
-
73
- def set_moe_mode(self, mode: str) -> None:
74
- """切换所有 MoE 层模式('dense' / 'sparse')。"""
75
- for blk in self.moe_layers:
76
- blk.set_mode(mode)
77
-
78
- def set_router_temperature(self, t: float) -> None:
79
- for blk in self.moe_layers:
80
- blk.set_temperature(t)
81
-
82
- def forward(
83
- self,
84
- x: torch.Tensor,
85
- rope_cos: Optional[torch.Tensor] = None,
86
- rope_sin: Optional[torch.Tensor] = None,
87
- visual_slice: Optional[tuple[int, int]] = None,
88
- ) -> BackboneOutput:
89
- moe_stats: list[MoEStats] = []
90
- use_ckpt = self.gradient_checkpointing and self.training
91
-
92
- for blk in self.dense_layers:
93
- if use_ckpt:
94
- x = cp.checkpoint(
95
- blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
96
- )
97
- else:
98
- x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
99
-
100
- for blk in self.moe_layers:
101
- if use_ckpt:
102
- x, stats = cp.checkpoint(
103
- blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
104
- )
105
- else:
106
- x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
107
- moe_stats.append(stats)
108
-
109
- x = self.final_norm(x)
110
- return BackboneOutput(hidden_states=x, moe_stats=moe_stats)
 
1
+ """18 层主干:前 9 Dense + 后 9 MoE。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as cp
11
+
12
+ from ..modules.moe import MoEStats
13
+ from .blocks import DenseBlock, MoEBlockWithAttn
14
+
15
+
16
+ @dataclass
17
+ class BackboneOutput:
18
+ """主干输出。"""
19
+
20
+ hidden_states: torch.Tensor # [B, N, D]
21
+ moe_stats: list[MoEStats] = field(default_factory=list)
22
+
23
+
24
+ class Backbone(nn.Module):
25
+ """端到端主干。
26
+
27
+ 输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用
28
+ 可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ dim: int = 768,
34
+ num_heads: int = 12,
35
+ ffn_mult: int = 4,
36
+ num_dense_layers: int = 9,
37
+ num_moe_layers: int = 9,
38
+ num_routed: int = 7,
39
+ num_shared: int = 1,
40
+ topk: int = 3,
41
+ dropout: float = 0.0,
42
+ ) -> None:
43
+ super().__init__()
44
+ self.dim = dim
45
+ self.num_heads = num_heads
46
+ self.num_dense_layers = num_dense_layers
47
+ self.num_moe_layers = num_moe_layers
48
+
49
+ self.dense_layers = nn.ModuleList([
50
+ DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
51
+ for _ in range(num_dense_layers)
52
+ ])
53
+ self.moe_layers = nn.ModuleList([
54
+ MoEBlockWithAttn(
55
+ dim,
56
+ num_heads,
57
+ num_routed=num_routed,
58
+ num_shared=num_shared,
59
+ topk=topk,
60
+ ffn_mult=ffn_mult,
61
+ dropout=dropout,
62
+ )
63
+ for _ in range(num_moe_layers)
64
+ ])
65
+ self.final_norm = nn.LayerNorm(dim)
66
+ # 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存
67
+ self.gradient_checkpointing = False
68
+
69
+ def set_gradient_checkpointing(self, enabled: bool) -> None:
70
+ """开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。"""
71
+ self.gradient_checkpointing = enabled
72
+
73
+ def set_moe_mode(self, mode: str) -> None:
74
+ """切换所有 MoE 层模式('dense' / 'sparse')。"""
75
+ for blk in self.moe_layers:
76
+ blk.set_mode(mode)
77
+
78
+ def set_router_temperature(self, t: float) -> None:
79
+ for blk in self.moe_layers:
80
+ blk.set_temperature(t)
81
+
82
+ def forward(
83
+ self,
84
+ x: torch.Tensor,
85
+ rope_cos: Optional[torch.Tensor] = None,
86
+ rope_sin: Optional[torch.Tensor] = None,
87
+ visual_slice: Optional[tuple[int, int]] = None,
88
+ ) -> BackboneOutput:
89
+ moe_stats: list[MoEStats] = []
90
+ use_ckpt = self.gradient_checkpointing and self.training
91
+
92
+ for blk in self.dense_layers:
93
+ if use_ckpt:
94
+ x = cp.checkpoint(
95
+ blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
96
+ )
97
+ else:
98
+ x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
99
+
100
+ for blk in self.moe_layers:
101
+ if use_ckpt:
102
+ x, stats = cp.checkpoint(
103
+ blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
104
+ )
105
+ else:
106
+ x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
107
+ moe_stats.append(stats)
108
+
109
+ x = self.final_norm(x)
110
+ return BackboneOutput(hidden_states=x, moe_stats=moe_stats)
src/wjad/backbone/blocks.py CHANGED
@@ -1,79 +1,79 @@
1
- """主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。"""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Optional
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
- from ..modules.ffn import SwiGLUFFN
11
- from ..modules.gate_attention import GateSelfAttention
12
- from ..modules.moe import MoEBlock, MoEStats
13
-
14
-
15
- class DenseBlock(nn.Module):
16
- """PreNorm GateSelfAttention + PreNorm SwiGLU FFN。"""
17
-
18
- def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None:
19
- super().__init__()
20
- self.norm1 = nn.LayerNorm(dim)
21
- self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
22
- self.norm2 = nn.LayerNorm(dim)
23
- self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)
24
-
25
- def forward(
26
- self,
27
- x: torch.Tensor,
28
- rope_cos: Optional[torch.Tensor] = None,
29
- rope_sin: Optional[torch.Tensor] = None,
30
- visual_slice: Optional[tuple[int, int]] = None,
31
- ) -> torch.Tensor:
32
- x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
33
- x = x + self.ffn(self.norm2(x))
34
- return x
35
-
36
-
37
- class MoEBlockWithAttn(nn.Module):
38
- """PreNorm GateSelfAttention + PreNorm MoE FFN。"""
39
-
40
- def __init__(
41
- self,
42
- dim: int,
43
- num_heads: int,
44
- num_routed: int = 7,
45
- num_shared: int = 1,
46
- topk: int = 3,
47
- ffn_mult: int = 4,
48
- dropout: float = 0.0,
49
- ) -> None:
50
- super().__init__()
51
- self.norm1 = nn.LayerNorm(dim)
52
- self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
53
- self.norm2 = nn.LayerNorm(dim)
54
- self.moe = MoEBlock(
55
- dim,
56
- num_routed=num_routed,
57
- num_shared=num_shared,
58
- topk=topk,
59
- ffn_mult=ffn_mult,
60
- dropout=dropout,
61
- )
62
-
63
- def set_mode(self, mode: str) -> None:
64
- self.moe.set_mode(mode)
65
-
66
- def set_temperature(self, t: float) -> None:
67
- self.moe.set_temperature(t)
68
-
69
- def forward(
70
- self,
71
- x: torch.Tensor,
72
- rope_cos: Optional[torch.Tensor] = None,
73
- rope_sin: Optional[torch.Tensor] = None,
74
- visual_slice: Optional[tuple[int, int]] = None,
75
- ) -> tuple[torch.Tensor, MoEStats]:
76
- x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
77
- moe_out, stats = self.moe(self.norm2(x))
78
- x = x + moe_out
79
- return x, stats
 
1
+ """主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from ..modules.ffn import SwiGLUFFN
11
+ from ..modules.gate_attention import GateSelfAttention
12
+ from ..modules.moe import MoEBlock, MoEStats
13
+
14
+
15
+ class DenseBlock(nn.Module):
16
+ """PreNorm GateSelfAttention + PreNorm SwiGLU FFN。"""
17
+
18
+ def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None:
19
+ super().__init__()
20
+ self.norm1 = nn.LayerNorm(dim)
21
+ self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
22
+ self.norm2 = nn.LayerNorm(dim)
23
+ self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)
24
+
25
+ def forward(
26
+ self,
27
+ x: torch.Tensor,
28
+ rope_cos: Optional[torch.Tensor] = None,
29
+ rope_sin: Optional[torch.Tensor] = None,
30
+ visual_slice: Optional[tuple[int, int]] = None,
31
+ ) -> torch.Tensor:
32
+ x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
33
+ x = x + self.ffn(self.norm2(x))
34
+ return x
35
+
36
+
37
+ class MoEBlockWithAttn(nn.Module):
38
+ """PreNorm GateSelfAttention + PreNorm MoE FFN。"""
39
+
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ num_heads: int,
44
+ num_routed: int = 7,
45
+ num_shared: int = 1,
46
+ topk: int = 3,
47
+ ffn_mult: int = 4,
48
+ dropout: float = 0.0,
49
+ ) -> None:
50
+ super().__init__()
51
+ self.norm1 = nn.LayerNorm(dim)
52
+ self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
53
+ self.norm2 = nn.LayerNorm(dim)
54
+ self.moe = MoEBlock(
55
+ dim,
56
+ num_routed=num_routed,
57
+ num_shared=num_shared,
58
+ topk=topk,
59
+ ffn_mult=ffn_mult,
60
+ dropout=dropout,
61
+ )
62
+
63
+ def set_mode(self, mode: str) -> None:
64
+ self.moe.set_mode(mode)
65
+
66
+ def set_temperature(self, t: float) -> None:
67
+ self.moe.set_temperature(t)
68
+
69
+ def forward(
70
+ self,
71
+ x: torch.Tensor,
72
+ rope_cos: Optional[torch.Tensor] = None,
73
+ rope_sin: Optional[torch.Tensor] = None,
74
+ visual_slice: Optional[tuple[int, int]] = None,
75
+ ) -> tuple[torch.Tensor, MoEStats]:
76
+ x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
77
+ moe_out, stats = self.moe(self.norm2(x))
78
+ x = x + moe_out
79
+ return x, stats
src/wjad/calibration/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- """在线校准模块。"""
2
-
3
- from .online_calib import OnlineCalibration, CalibrationOutput
4
-
5
- __all__ = ["OnlineCalibration", "CalibrationOutput"]
 
1
+ """在线校准模块。"""
2
+
3
+ from .online_calib import OnlineCalibration, CalibrationOutput
4
+
5
+ __all__ = ["OnlineCalibration", "CalibrationOutput"]
src/wjad/calibration/online_calib.py CHANGED
@@ -1,196 +1,196 @@
1
- """在线校准网络。
2
-
3
- 输入
4
- - DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。
5
- - 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。
6
- - f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``);
7
- README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。
8
- - 相机外参 6D ``[B, 6]``。
9
-
10
- 流程
11
- - 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外
12
- 条件 token 与 256 个可学习 query token 拼接。
13
- - 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。
14
- - 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出
15
- symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。
16
-
17
- 输出
18
- - ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、
19
- ``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- from dataclasses import dataclass
25
-
26
- import torch
27
- import torch.nn as nn
28
-
29
- from ..modules.gate_attention import GateCrossAttention, GateSelfAttention
30
- from ..modules.learned_pe import LearnedTokenPE
31
- from ..modules.normalization import symexp, symlog
32
-
33
-
34
- @dataclass
35
- class CalibrationOutput:
36
- """校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。"""
37
-
38
- ego_residual: torch.Tensor # [B, 8, 6]
39
- intr_residual: torch.Tensor # [B, intr_dim]
40
- extr_residual: torch.Tensor # [B, 6]
41
- corrected_ego: torch.Tensor # [B, 8, 6],真实空间
42
- corrected_intr: torch.Tensor # [B, intr_dim]
43
- corrected_extr: torch.Tensor # [B, 6]
44
-
45
-
46
- class _CalibBlock(nn.Module):
47
- """单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。"""
48
-
49
- def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None:
50
- super().__init__()
51
- self.cross_norm = nn.LayerNorm(dim)
52
- self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads)
53
- self.cross_drop = nn.Dropout(0.0)
54
- self.self_blocks = nn.ModuleList()
55
- for _ in range(num_self):
56
- self.self_blocks.append(
57
- nn.ModuleDict({
58
- "norm": nn.LayerNorm(dim),
59
- "attn": GateSelfAttention(dim, num_heads=num_heads),
60
- })
61
- )
62
-
63
- def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
64
- # 1) Cross
65
- q = q + self.cross(self.cross_norm(q), kv)
66
- # 2) Self ×2
67
- for blk in self.self_blocks:
68
- q = q + blk["attn"](blk["norm"](q))
69
- return q
70
-
71
-
72
- class OnlineCalibration(nn.Module):
73
- def __init__(
74
- self,
75
- dino_dim: int = 768,
76
- hidden_dim: int = 256,
77
- num_query_tokens: int = 256,
78
- num_blocks: int = 2,
79
- num_self_attn_per_block: int = 2,
80
- num_heads: int = 8,
81
- residual_range: float = 0.1,
82
- ego_dim: int = 6, # 3 平移 + 3 轴角
83
- intr_dim: int = 11,
84
- extr_dim: int = 6,
85
- num_history_frames: int = 8,
86
- init_zero_output: bool = True,
87
- ) -> None:
88
- super().__init__()
89
- self.hidden_dim = hidden_dim
90
- self.residual_range = residual_range
91
- self.ego_dim = ego_dim
92
- self.intr_dim = intr_dim
93
- self.extr_dim = extr_dim
94
- self.num_history_frames = num_history_frames
95
-
96
- # 256 可学习 query token
97
- self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim))
98
- nn.init.trunc_normal_(self.query_tokens, std=0.02)
99
- self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim)
100
-
101
- # 条件 token 编码(symlog 空间)
102
- self.ego_proj = nn.Linear(ego_dim, hidden_dim)
103
- self.intr_proj = nn.Linear(intr_dim, hidden_dim)
104
- self.extr_proj = nn.Linear(extr_dim, hidden_dim)
105
- # 条件 token 也加可学习 PE(与 query 区分)
106
- num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr
107
- self.cond_pe = LearnedTokenPE(num_cond, hidden_dim)
108
- self.num_cond = num_cond
109
- self.num_query = num_query_tokens
110
-
111
- # KV 上下文:DINOv3 patch 特征投影到 hidden_dim
112
- self.kv_proj = nn.Linear(dino_dim, hidden_dim)
113
- self.kv_norm = nn.LayerNorm(hidden_dim)
114
-
115
- # 校准 block × num_blocks
116
- self.blocks = nn.ModuleList([
117
- _CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block)
118
- for _ in range(num_blocks)
119
- ])
120
-
121
- # 输出 MLP:取 last token -> 残差向量
122
- residual_total = num_history_frames * ego_dim + intr_dim + extr_dim
123
- self.out_norm = nn.LayerNorm(hidden_dim)
124
- self.out_mlp = nn.Sequential(
125
- nn.Linear(hidden_dim, hidden_dim),
126
- nn.GELU(),
127
- nn.Linear(hidden_dim, residual_total),
128
- )
129
- # 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0
130
- if init_zero_output:
131
- nn.init.zeros_(self.out_mlp[-1].weight)
132
- nn.init.zeros_(self.out_mlp[-1].bias)
133
-
134
- def forward(
135
- self,
136
- dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino]
137
- ego_raw: torch.Tensor, # [B, 8, 6] 真实空间
138
- intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致
139
- extr_raw: torch.Tensor, # [B, 6]
140
- ) -> CalibrationOutput:
141
- b = dino_feats.shape[0]
142
- if intr_raw.shape[-1] != self.intr_dim:
143
- raise ValueError(
144
- f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。"
145
- f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。"
146
- )
147
- # === 上下文 K/V ===
148
- # 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D]
149
- kv = dino_feats.reshape(b, -1, dino_feats.shape[-1])
150
- kv = self.kv_norm(self.kv_proj(kv))
151
-
152
- # === 条件 token(symlog 空间)===
153
- ego_sym = symlog(ego_raw)
154
- intr_sym = symlog(intr_raw)
155
- extr_sym = symlog(extr_raw)
156
-
157
- ego_tok = self.ego_proj(ego_sym) # [B, 8, D]
158
- intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D]
159
- extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D]
160
- cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D]
161
- cond = self.cond_pe(cond)
162
-
163
- # === 拼接 query token ===
164
- q = self.query_tokens.unsqueeze(0).expand(b, -1, -1)
165
- q = self.query_pe(q)
166
- seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D]
167
-
168
- # === 6 层 block ===
169
- for blk in self.blocks:
170
- seq = blk(seq, kv)
171
-
172
- # === 取 last token ===
173
- last = self.out_norm(seq[:, -1, :]) # [B, D]
174
- residual_flat = self.out_mlp(last)
175
- # Tanh + 缩放
176
- residual_flat = torch.tanh(residual_flat) * self.residual_range
177
-
178
- # 拆分
179
- n_ego = self.num_history_frames * self.ego_dim
180
- ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim)
181
- intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim]
182
- extr_res = residual_flat[:, n_ego + self.intr_dim :]
183
-
184
- # symlog 空间叠加 + symexp 还原
185
- corrected_ego = symexp(symlog(ego_raw) + ego_res)
186
- corrected_intr = symexp(symlog(intr_raw) + intr_res)
187
- corrected_extr = symexp(symlog(extr_raw) + extr_res)
188
-
189
- return CalibrationOutput(
190
- ego_residual=ego_res,
191
- intr_residual=intr_res,
192
- extr_residual=extr_res,
193
- corrected_ego=corrected_ego,
194
- corrected_intr=corrected_intr,
195
- corrected_extr=corrected_extr,
196
- )
 
1
+ """在线校准网络。
2
+
3
+ 输入
4
+ - DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。
5
+ - 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。
6
+ - f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``);
7
+ README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。
8
+ - 相机外参 6D ``[B, 6]``。
9
+
10
+ 流程
11
+ - 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外
12
+ 条件 token 与 256 个可学习 query token 拼接。
13
+ - 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。
14
+ - 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出
15
+ symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。
16
+
17
+ 输出
18
+ - ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、
19
+ ``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from dataclasses import dataclass
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ from ..modules.gate_attention import GateCrossAttention, GateSelfAttention
30
+ from ..modules.learned_pe import LearnedTokenPE
31
+ from ..modules.normalization import symexp, symlog
32
+
33
+
34
+ @dataclass
35
+ class CalibrationOutput:
36
+ """校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。"""
37
+
38
+ ego_residual: torch.Tensor # [B, 8, 6]
39
+ intr_residual: torch.Tensor # [B, intr_dim]
40
+ extr_residual: torch.Tensor # [B, 6]
41
+ corrected_ego: torch.Tensor # [B, 8, 6],真实空间
42
+ corrected_intr: torch.Tensor # [B, intr_dim]
43
+ corrected_extr: torch.Tensor # [B, 6]
44
+
45
+
46
+ class _CalibBlock(nn.Module):
47
+ """单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。"""
48
+
49
+ def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None:
50
+ super().__init__()
51
+ self.cross_norm = nn.LayerNorm(dim)
52
+ self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads)
53
+ self.cross_drop = nn.Dropout(0.0)
54
+ self.self_blocks = nn.ModuleList()
55
+ for _ in range(num_self):
56
+ self.self_blocks.append(
57
+ nn.ModuleDict({
58
+ "norm": nn.LayerNorm(dim),
59
+ "attn": GateSelfAttention(dim, num_heads=num_heads),
60
+ })
61
+ )
62
+
63
+ def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
64
+ # 1) Cross
65
+ q = q + self.cross(self.cross_norm(q), kv)
66
+ # 2) Self ×2
67
+ for blk in self.self_blocks:
68
+ q = q + blk["attn"](blk["norm"](q))
69
+ return q
70
+
71
+
72
+ class OnlineCalibration(nn.Module):
73
+ def __init__(
74
+ self,
75
+ dino_dim: int = 768,
76
+ hidden_dim: int = 256,
77
+ num_query_tokens: int = 256,
78
+ num_blocks: int = 2,
79
+ num_self_attn_per_block: int = 2,
80
+ num_heads: int = 8,
81
+ residual_range: float = 0.1,
82
+ ego_dim: int = 6, # 3 平移 + 3 轴角
83
+ intr_dim: int = 11,
84
+ extr_dim: int = 6,
85
+ num_history_frames: int = 8,
86
+ init_zero_output: bool = True,
87
+ ) -> None:
88
+ super().__init__()
89
+ self.hidden_dim = hidden_dim
90
+ self.residual_range = residual_range
91
+ self.ego_dim = ego_dim
92
+ self.intr_dim = intr_dim
93
+ self.extr_dim = extr_dim
94
+ self.num_history_frames = num_history_frames
95
+
96
+ # 256 可学习 query token
97
+ self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim))
98
+ nn.init.trunc_normal_(self.query_tokens, std=0.02)
99
+ self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim)
100
+
101
+ # 条件 token 编码(symlog 空间)
102
+ self.ego_proj = nn.Linear(ego_dim, hidden_dim)
103
+ self.intr_proj = nn.Linear(intr_dim, hidden_dim)
104
+ self.extr_proj = nn.Linear(extr_dim, hidden_dim)
105
+ # 条件 token 也加可学习 PE(与 query 区分)
106
+ num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr
107
+ self.cond_pe = LearnedTokenPE(num_cond, hidden_dim)
108
+ self.num_cond = num_cond
109
+ self.num_query = num_query_tokens
110
+
111
+ # KV 上下文:DINOv3 patch 特征投影到 hidden_dim
112
+ self.kv_proj = nn.Linear(dino_dim, hidden_dim)
113
+ self.kv_norm = nn.LayerNorm(hidden_dim)
114
+
115
+ # 校准 block × num_blocks
116
+ self.blocks = nn.ModuleList([
117
+ _CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block)
118
+ for _ in range(num_blocks)
119
+ ])
120
+
121
+ # 输出 MLP:取 last token -> 残差向量
122
+ residual_total = num_history_frames * ego_dim + intr_dim + extr_dim
123
+ self.out_norm = nn.LayerNorm(hidden_dim)
124
+ self.out_mlp = nn.Sequential(
125
+ nn.Linear(hidden_dim, hidden_dim),
126
+ nn.GELU(),
127
+ nn.Linear(hidden_dim, residual_total),
128
+ )
129
+ # 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0
130
+ if init_zero_output:
131
+ nn.init.zeros_(self.out_mlp[-1].weight)
132
+ nn.init.zeros_(self.out_mlp[-1].bias)
133
+
134
+ def forward(
135
+ self,
136
+ dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino]
137
+ ego_raw: torch.Tensor, # [B, 8, 6] 真实空间
138
+ intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致
139
+ extr_raw: torch.Tensor, # [B, 6]
140
+ ) -> CalibrationOutput:
141
+ b = dino_feats.shape[0]
142
+ if intr_raw.shape[-1] != self.intr_dim:
143
+ raise ValueError(
144
+ f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。"
145
+ f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。"
146
+ )
147
+ # === 上下文 K/V ===
148
+ # 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D]
149
+ kv = dino_feats.reshape(b, -1, dino_feats.shape[-1])
150
+ kv = self.kv_norm(self.kv_proj(kv))
151
+
152
+ # === 条件 token(symlog 空间)===
153
+ ego_sym = symlog(ego_raw)
154
+ intr_sym = symlog(intr_raw)
155
+ extr_sym = symlog(extr_raw)
156
+
157
+ ego_tok = self.ego_proj(ego_sym) # [B, 8, D]
158
+ intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D]
159
+ extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D]
160
+ cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D]
161
+ cond = self.cond_pe(cond)
162
+
163
+ # === 拼接 query token ===
164
+ q = self.query_tokens.unsqueeze(0).expand(b, -1, -1)
165
+ q = self.query_pe(q)
166
+ seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D]
167
+
168
+ # === 6 层 block ===
169
+ for blk in self.blocks:
170
+ seq = blk(seq, kv)
171
+
172
+ # === 取 last token ===
173
+ last = self.out_norm(seq[:, -1, :]) # [B, D]
174
+ residual_flat = self.out_mlp(last)
175
+ # Tanh + 缩放
176
+ residual_flat = torch.tanh(residual_flat) * self.residual_range
177
+
178
+ # 拆分
179
+ n_ego = self.num_history_frames * self.ego_dim
180
+ ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim)
181
+ intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim]
182
+ extr_res = residual_flat[:, n_ego + self.intr_dim :]
183
+
184
+ # symlog 空间叠加 + symexp 还原
185
+ corrected_ego = symexp(symlog(ego_raw) + ego_res)
186
+ corrected_intr = symexp(symlog(intr_raw) + intr_res)
187
+ corrected_extr = symexp(symlog(extr_raw) + extr_res)
188
+
189
+ return CalibrationOutput(
190
+ ego_residual=ego_res,
191
+ intr_residual=intr_res,
192
+ extr_residual=extr_res,
193
+ corrected_ego=corrected_ego,
194
+ corrected_intr=corrected_intr,
195
+ corrected_extr=corrected_extr,
196
+ )
src/wjad/data/__init__.py CHANGED
@@ -1,39 +1,39 @@
1
- """Cosmos-Drive-Dreams 数据加载与目标构建。"""
2
-
3
- from .se3 import (
4
- matrix_to_6d,
5
- six_d_to_matrix,
6
- invert_se3,
7
- rotation_matrix_to_axis_angle,
8
- axis_angle_to_rotation_matrix,
9
- )
10
- from .ftheta_proj import project_points_ftheta
11
- from .transforms import (
12
- crop_top_half,
13
- normalize_image,
14
- add_gaussian_noise,
15
- perturb_kinematics,
16
- )
17
- from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo
18
- from .hdmap import parse_hdmap_clip, HDMAP_SOURCES
19
- from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index
20
-
21
- __all__ = [
22
- "matrix_to_6d",
23
- "six_d_to_matrix",
24
- "invert_se3",
25
- "rotation_matrix_to_axis_angle",
26
- "axis_angle_to_rotation_matrix",
27
- "project_points_ftheta",
28
- "crop_top_half",
29
- "normalize_image",
30
- "add_gaussian_noise",
31
- "perturb_kinematics",
32
- "build_detection_targets",
33
- "build_ego_future_target",
34
- "ObjectTrackInfo",
35
- "CosmosDriveDreamsDataset",
36
- "build_clip_index",
37
- "parse_hdmap_clip",
38
- "HDMAP_SOURCES",
39
- ]
 
1
+ """Cosmos-Drive-Dreams 数据加载与目标构建。"""
2
+
3
+ from .se3 import (
4
+ matrix_to_6d,
5
+ six_d_to_matrix,
6
+ invert_se3,
7
+ rotation_matrix_to_axis_angle,
8
+ axis_angle_to_rotation_matrix,
9
+ )
10
+ from .ftheta_proj import project_points_ftheta
11
+ from .transforms import (
12
+ crop_top_half,
13
+ normalize_image,
14
+ add_gaussian_noise,
15
+ perturb_kinematics,
16
+ )
17
+ from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo
18
+ from .hdmap import parse_hdmap_clip, HDMAP_SOURCES
19
+ from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index
20
+
21
+ __all__ = [
22
+ "matrix_to_6d",
23
+ "six_d_to_matrix",
24
+ "invert_se3",
25
+ "rotation_matrix_to_axis_angle",
26
+ "axis_angle_to_rotation_matrix",
27
+ "project_points_ftheta",
28
+ "crop_top_half",
29
+ "normalize_image",
30
+ "add_gaussian_noise",
31
+ "perturb_kinematics",
32
+ "build_detection_targets",
33
+ "build_ego_future_target",
34
+ "ObjectTrackInfo",
35
+ "CosmosDriveDreamsDataset",
36
+ "build_clip_index",
37
+ "parse_hdmap_clip",
38
+ "HDMAP_SOURCES",
39
+ ]
src/wjad/data/cosmos_dataset.py CHANGED
@@ -1,439 +1,439 @@
1
- """Cosmos-Drive-Dreams 数据集加载器(真实实现)。
2
-
3
- 期待目录结构(从 NVIDIA 提供的 .tar 解压):
4
-
5
- data_root/
6
- synthetic/single_view/
7
- generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频
8
- labels/{clip_id}/
9
- vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU
10
- pose/000000.pose.{camera}.npy # 30 FPS, OpenCV
11
- ftheta_intrinsic/ftheta_intrinsic.{camera}.npy
12
- all_object_info/000000.all_object_info.json
13
- lidar_raw/000000.lidar_raw.npz # 10 FPS
14
-
15
- 每段 clip 提供:
16
- - 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。
17
- - 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。
18
- """
19
-
20
- from __future__ import annotations
21
-
22
- import json
23
- from dataclasses import dataclass
24
- from pathlib import Path
25
- from typing import Sequence
26
-
27
- import cv2
28
- import numpy as np
29
- import torch
30
- from torch.utils.data import Dataset
31
-
32
- from ..modules.normalization import symlog
33
- from ..modules.rays import FThetaCamera
34
- from .label_paths import resolve_clip_file
35
- from .hdmap import parse_hdmap_clip
36
- from .se3 import matrix_to_6d
37
- from .targets import (
38
- ObjectTrackInfo,
39
- build_detection_targets,
40
- build_ego_future_target,
41
- )
42
- from .transforms import DINOV3_MEAN, DINOV3_STD
43
-
44
-
45
- # 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。
46
- DEFAULT_DYNAMIC_CLASSES = [
47
- "Automobile",
48
- "Heavy_truck",
49
- "Bus",
50
- "Train_or_tram_car",
51
- "Trolley_bus",
52
- "Other_vehicle",
53
- "Trailer",
54
- "Person",
55
- "Stroller",
56
- "Rider",
57
- "Animal",
58
- "Protruding_object",
59
- ]
60
-
61
- # 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。
62
- DEFAULT_STRUCTURED_CLASSES = [
63
- "lane",
64
- "laneline",
65
- "road_boundary",
66
- "wait_line",
67
- "crosswalk",
68
- "road_marking",
69
- "pole",
70
- "traffic_light",
71
- "traffic_sign",
72
- ]
73
-
74
-
75
- @dataclass
76
- class ClipSample:
77
- """clip 索引项。"""
78
-
79
- clip_id: str
80
- chunk_id: int
81
- weather: str
82
- video_path: Path
83
- labels_dir: Path
84
- anchor_t: int # 当前帧(含),范围 [7, 96]
85
- chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121)
86
-
87
-
88
- def build_clip_index(
89
- data_root: str | Path,
90
- weathers: Sequence[str] = ("Sunny",),
91
- chunk_ids: Sequence[int] = (0, 1),
92
- camera_name: str = "camera_front_wide_120fov",
93
- stride: int = 8,
94
- anchor_min: int = 7,
95
- anchor_max: int = 96,
96
- max_clips: int | None = None,
97
- ) -> list[ClipSample]:
98
- """枚举所有可用 (clip, chunk, weather, anchor_t) 样本。
99
-
100
- 锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧
101
- ``chunk_offset + t``(chunk_offset = chunk_id * 121)。
102
- """
103
- root = Path(data_root)
104
- syn_dir = root / "synthetic" / "single_view" / "generation"
105
- labels_dir = root / "labels"
106
-
107
- samples: list[ClipSample] = []
108
- if not syn_dir.exists():
109
- return samples
110
-
111
- for video_path in sorted(syn_dir.glob("*.mp4")):
112
- # 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4
113
- # clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析
114
- stem = video_path.stem
115
- parts = stem.rsplit("_", 2)
116
- if len(parts) != 3:
117
- continue
118
- clip_id, chunk_str, weather = parts
119
- try:
120
- chunk_id = int(chunk_str)
121
- except ValueError:
122
- continue
123
- if chunk_id not in chunk_ids or weather not in weathers:
124
- continue
125
-
126
- clip_label_dir = labels_dir / clip_id
127
- if not clip_label_dir.exists():
128
- continue
129
-
130
- chunk_offset = chunk_id * 121
131
- for t in range(anchor_min, anchor_max + 1, stride):
132
- samples.append(
133
- ClipSample(
134
- clip_id=clip_id,
135
- chunk_id=chunk_id,
136
- weather=weather,
137
- video_path=video_path,
138
- labels_dir=clip_label_dir,
139
- anchor_t=t,
140
- chunk_offset=chunk_offset,
141
- )
142
- )
143
- if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips:
144
- break
145
-
146
- return samples
147
-
148
-
149
- def _load_video_frames(
150
- video_path: Path,
151
- frame_indices: Sequence[int],
152
- target_h: int,
153
- target_w: int,
154
- ) -> torch.Tensor:
155
- """从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。"""
156
- cap = cv2.VideoCapture(str(video_path))
157
- if not cap.isOpened():
158
- raise FileNotFoundError(f"无法打开视频: {video_path}")
159
- frames = []
160
- for idx in frame_indices:
161
- cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
162
- ok, bgr = cap.read()
163
- if not ok:
164
- cap.release()
165
- raise RuntimeError(f"读取帧 {idx} 失败: {video_path}")
166
- rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
167
- rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA)
168
- # 裁去上半部分(天空)后高度变为 target_h
169
- rgb = rgb[target_h:, :, :]
170
- rgb = rgb.astype(np.float32) / 255.0
171
- frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W]
172
- cap.release()
173
- return torch.stack(frames, dim=0)
174
-
175
-
176
- def _load_npy(path: Path) -> np.ndarray:
177
- return np.load(path, allow_pickle=False)
178
-
179
-
180
- def _load_object_info(path: Path) -> list[ObjectTrackInfo]:
181
- """解析单帧 all_object_info JSON。"""
182
- if not path.exists():
183
- return []
184
- data = json.loads(path.read_text())
185
- out = []
186
- for tid, info in data.items():
187
- T = torch.tensor(info["object_to_world"], dtype=torch.float32)
188
- lwh = torch.tensor(info["object_lwh"], dtype=torch.float32)
189
- out.append(
190
- ObjectTrackInfo(
191
- tracking_id=tid,
192
- object_to_world=T,
193
- lwh=lwh,
194
- is_moving=bool(info.get("object_is_moving", False)),
195
- object_type=str(info.get("object_type", "")),
196
- )
197
- )
198
- return out
199
-
200
-
201
- def _load_lidar_self_frame(
202
- labels_dir: Path,
203
- label_idx: int,
204
- vehicle_pose: torch.Tensor,
205
- max_history: int = 3,
206
- ) -> torch.Tensor | None:
207
- """读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。
208
-
209
- LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、
210
- ``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。
211
- """
212
- lidar_idx = (label_idx // 3) * 3
213
- search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0]
214
- p: Path | None = None
215
- for idx_try in search_order:
216
- try:
217
- p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz")
218
- break
219
- except FileNotFoundError:
220
- continue
221
- if p is None:
222
- return None
223
- arr = np.load(p, allow_pickle=False)
224
- xyz_lidar = arr["xyz"] # [N, 3] in lidar frame
225
- lidar_to_world = arr["lidar_to_world"] # [4, 4]
226
- # 转到 world 后再转 self
227
- pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3]
228
- inv_pose = torch.linalg.inv(vehicle_pose)
229
- pts_w_t = torch.from_numpy(pts_w).float()
230
- pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3]
231
- return pts_self
232
-
233
-
234
- class CosmosDriveDreamsDataset(Dataset):
235
- """端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。"""
236
-
237
- def __init__(
238
- self,
239
- data_root: str | Path,
240
- samples: list[ClipSample] | None = None,
241
- weathers: Sequence[str] = ("Sunny",),
242
- camera_name: str = "camera_front_wide_120fov",
243
- image_h: int = 384,
244
- image_w: int = 1024,
245
- num_history: int = 8,
246
- future_horizon: int = 24,
247
- max_distance_m: float = 48.0,
248
- occlusion_tol: float = 0.5,
249
- dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES,
250
- structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES,
251
- do_normalize: bool = True,
252
- use_lidar_occlusion: bool = True,
253
- use_hdmap: bool = True,
254
- ) -> None:
255
- super().__init__()
256
- self.data_root = Path(data_root)
257
- self.samples = samples if samples is not None else build_clip_index(
258
- data_root, weathers=weathers, camera_name=camera_name
259
- )
260
- self.camera_name = camera_name
261
- self.image_h = image_h
262
- self.image_w = image_w
263
- self.num_history = num_history
264
- self.future_horizon = future_horizon
265
- self.max_distance_m = max_distance_m
266
- self.occlusion_tol = occlusion_tol
267
- self.dynamic_classes = list(dynamic_classes)
268
- self.structured_classes = list(structured_classes)
269
- self.do_normalize = do_normalize
270
- self.use_lidar_occlusion = use_lidar_occlusion
271
- self.use_hdmap = use_hdmap
272
- # HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析
273
- self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {}
274
- self._hdmap_cache_max = 32
275
-
276
- def __len__(self) -> int:
277
- return len(self.samples)
278
-
279
- def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor:
280
- p = resolve_clip_file(
281
- sample.labels_dir,
282
- "ftheta_intrinsic",
283
- f"ftheta_intrinsic.{self.camera_name}.npy",
284
- )
285
- return torch.from_numpy(_load_npy(p)).float()
286
-
287
- def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
288
- p = resolve_clip_file(
289
- sample.labels_dir,
290
- "pose",
291
- f"{label_idx:06d}.pose.{self.camera_name}.npy",
292
- )
293
- return torch.from_numpy(_load_npy(p)).float()
294
-
295
- def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
296
- p = resolve_clip_file(
297
- sample.labels_dir,
298
- "vehicle_pose",
299
- f"{label_idx:06d}.vehicle_pose.npy",
300
- )
301
- return torch.from_numpy(_load_npy(p)).float()
302
-
303
- def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]:
304
- if not self.use_hdmap:
305
- return []
306
- key = str(clip_dir)
307
- cached = self._hdmap_cache.get(key)
308
- if cached is not None:
309
- return cached
310
- objs = parse_hdmap_clip(clip_dir)
311
- if len(self._hdmap_cache) >= self._hdmap_cache_max:
312
- self._hdmap_cache.pop(next(iter(self._hdmap_cache)))
313
- self._hdmap_cache[key] = objs
314
- return objs
315
-
316
- def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]:
317
- p = resolve_clip_file(
318
- sample.labels_dir,
319
- "all_object_info",
320
- f"{label_idx:06d}.all_object_info.json",
321
- )
322
- dynamic = _load_object_info(p)
323
- # HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同),
324
- # 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差,
325
- # 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。
326
- return dynamic + self._load_hdmap_static(sample.labels_dir)
327
-
328
- def __getitem__(self, idx: int) -> dict:
329
- s = self.samples[idx]
330
- # 视频帧索引(chunk 内 0-based)
331
- t = s.anchor_t
332
- history_frames = list(range(t - self.num_history + 1, t + 1))
333
- # 标签索引:chunk_offset + chunk-local idx
334
- history_label_idx = [s.chunk_offset + f for f in history_frames]
335
- future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)]
336
-
337
- # === 1) 加载图像 ===
338
- # 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率);
339
- # 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。
340
- images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w)
341
- # [T, 3, H, W],[0, 1]
342
- if self.do_normalize:
343
- images = (images - DINOV3_MEAN) / DINOV3_STD
344
-
345
- # === 2) 加载内参 / 外参 ===
346
- intr_vec = self._load_intrinsic(s) # [14]
347
-
348
- # 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle
349
- pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t)
350
- pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t)
351
- # cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world
352
- inv_veh = torch.linalg.inv(pose_veh_world)
353
- cam2veh = inv_veh @ pose_cam_world
354
- extr_6d = matrix_to_6d(cam2veh) # [6]
355
-
356
- # === 3) 历史 8 帧 ego pose(vehicle 6D)===
357
- ego_6d = []
358
- for li in history_label_idx:
359
- T_vw = self._load_pose_vehicle(s, li)
360
- ego_6d.append(matrix_to_6d(T_vw))
361
- ego_6d = torch.stack(ego_6d, dim=0) # [8, 6]
362
-
363
- # === 4) 检测 / 未来轨迹标签 ===
364
- # objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。
365
- objs_t = self._load_objects(s, s.chunk_offset + t)
366
- objs_future = [self._load_objects(s, li) for li in future_label_idx]
367
- veh_pose_future = []
368
- for li in future_label_idx:
369
- try:
370
- veh_pose_future.append(self._load_pose_vehicle(s, li))
371
- except FileNotFoundError:
372
- break
373
-
374
- cam = FThetaCamera.from_vector(intr_vec)
375
- lidar_self = None
376
- if self.use_lidar_occlusion:
377
- try:
378
- lidar_self = _load_lidar_self_frame(
379
- s.labels_dir,
380
- s.chunk_offset + t,
381
- pose_veh_world,
382
- )
383
- except Exception:
384
- lidar_self = None
385
-
386
- det_targets = build_detection_targets(
387
- objects_t=objs_t,
388
- objects_future=objs_future,
389
- vehicle_pose_t=pose_veh_world,
390
- vehicle_pose_future=veh_pose_future,
391
- cam_intrinsic=cam,
392
- cam2vehicle=cam2veh,
393
- image_h=self.image_h,
394
- image_w=self.image_w,
395
- max_distance_m=self.max_distance_m,
396
- occlusion_depth_tolerance=self.occlusion_tol,
397
- lidar_points_self=lidar_self,
398
- dynamic_classes=self.dynamic_classes,
399
- structured_classes=self.structured_classes,
400
- future_horizon=self.future_horizon,
401
- )
402
-
403
- ego_future, ego_future_valid = build_ego_future_target(
404
- pose_veh_world, veh_pose_future, horizon=self.future_horizon
405
- )
406
-
407
- sample_out = {
408
- "images": images,
409
- "ego_6d": ego_6d,
410
- "intr_vec": intr_vec,
411
- "extr_6d": extr_6d,
412
- "ego_future": ego_future,
413
- "ego_future_valid": ego_future_valid,
414
- "targets": det_targets,
415
- "meta": {
416
- "clip_id": s.clip_id,
417
- "chunk_id": s.chunk_id,
418
- "weather": s.weather,
419
- "anchor_t": s.anchor_t,
420
- },
421
- }
422
- return sample_out
423
-
424
-
425
- def collate_samples(batch: list[dict]) -> dict:
426
- """自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack;
427
- targets 列表保留为 list(便于匈牙利匹配处理变长 N);
428
- meta 也保留为 list。"""
429
- out = {
430
- "images": torch.stack([b["images"] for b in batch], dim=0),
431
- "ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0),
432
- "intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0),
433
- "extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0),
434
- "ego_future": torch.stack([b["ego_future"] for b in batch], dim=0),
435
- "ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0),
436
- "targets": [b["targets"] for b in batch],
437
- "meta": [b["meta"] for b in batch],
438
- }
439
- return out
 
1
+ """Cosmos-Drive-Dreams 数据集加载器(真实实现)。
2
+
3
+ 期待目录结构(从 NVIDIA 提供的 .tar 解压):
4
+
5
+ data_root/
6
+ synthetic/single_view/
7
+ generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频
8
+ labels/{clip_id}/
9
+ vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU
10
+ pose/000000.pose.{camera}.npy # 30 FPS, OpenCV
11
+ ftheta_intrinsic/ftheta_intrinsic.{camera}.npy
12
+ all_object_info/000000.all_object_info.json
13
+ lidar_raw/000000.lidar_raw.npz # 10 FPS
14
+
15
+ 每段 clip 提供:
16
+ - 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。
17
+ - 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+ from typing import Sequence
26
+
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ from torch.utils.data import Dataset
31
+
32
+ from ..modules.normalization import symlog
33
+ from ..modules.rays import FThetaCamera
34
+ from .label_paths import resolve_clip_file
35
+ from .hdmap import parse_hdmap_clip
36
+ from .se3 import matrix_to_6d
37
+ from .targets import (
38
+ ObjectTrackInfo,
39
+ build_detection_targets,
40
+ build_ego_future_target,
41
+ )
42
+ from .transforms import DINOV3_MEAN, DINOV3_STD
43
+
44
+
45
+ # 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。
46
+ DEFAULT_DYNAMIC_CLASSES = [
47
+ "Automobile",
48
+ "Heavy_truck",
49
+ "Bus",
50
+ "Train_or_tram_car",
51
+ "Trolley_bus",
52
+ "Other_vehicle",
53
+ "Trailer",
54
+ "Person",
55
+ "Stroller",
56
+ "Rider",
57
+ "Animal",
58
+ "Protruding_object",
59
+ ]
60
+
61
+ # 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。
62
+ DEFAULT_STRUCTURED_CLASSES = [
63
+ "lane",
64
+ "laneline",
65
+ "road_boundary",
66
+ "wait_line",
67
+ "crosswalk",
68
+ "road_marking",
69
+ "pole",
70
+ "traffic_light",
71
+ "traffic_sign",
72
+ ]
73
+
74
+
75
+ @dataclass
76
+ class ClipSample:
77
+ """clip 索引项。"""
78
+
79
+ clip_id: str
80
+ chunk_id: int
81
+ weather: str
82
+ video_path: Path
83
+ labels_dir: Path
84
+ anchor_t: int # 当前帧(含),范围 [7, 96]
85
+ chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121)
86
+
87
+
88
+ def build_clip_index(
89
+ data_root: str | Path,
90
+ weathers: Sequence[str] = ("Sunny",),
91
+ chunk_ids: Sequence[int] = (0, 1),
92
+ camera_name: str = "camera_front_wide_120fov",
93
+ stride: int = 8,
94
+ anchor_min: int = 7,
95
+ anchor_max: int = 96,
96
+ max_clips: int | None = None,
97
+ ) -> list[ClipSample]:
98
+ """枚举所有可用 (clip, chunk, weather, anchor_t) 样本。
99
+
100
+ 锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧
101
+ ``chunk_offset + t``(chunk_offset = chunk_id * 121)。
102
+ """
103
+ root = Path(data_root)
104
+ syn_dir = root / "synthetic" / "single_view" / "generation"
105
+ labels_dir = root / "labels"
106
+
107
+ samples: list[ClipSample] = []
108
+ if not syn_dir.exists():
109
+ return samples
110
+
111
+ for video_path in sorted(syn_dir.glob("*.mp4")):
112
+ # 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4
113
+ # clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析
114
+ stem = video_path.stem
115
+ parts = stem.rsplit("_", 2)
116
+ if len(parts) != 3:
117
+ continue
118
+ clip_id, chunk_str, weather = parts
119
+ try:
120
+ chunk_id = int(chunk_str)
121
+ except ValueError:
122
+ continue
123
+ if chunk_id not in chunk_ids or weather not in weathers:
124
+ continue
125
+
126
+ clip_label_dir = labels_dir / clip_id
127
+ if not clip_label_dir.exists():
128
+ continue
129
+
130
+ chunk_offset = chunk_id * 121
131
+ for t in range(anchor_min, anchor_max + 1, stride):
132
+ samples.append(
133
+ ClipSample(
134
+ clip_id=clip_id,
135
+ chunk_id=chunk_id,
136
+ weather=weather,
137
+ video_path=video_path,
138
+ labels_dir=clip_label_dir,
139
+ anchor_t=t,
140
+ chunk_offset=chunk_offset,
141
+ )
142
+ )
143
+ if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips:
144
+ break
145
+
146
+ return samples
147
+
148
+
149
+ def _load_video_frames(
150
+ video_path: Path,
151
+ frame_indices: Sequence[int],
152
+ target_h: int,
153
+ target_w: int,
154
+ ) -> torch.Tensor:
155
+ """从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。"""
156
+ cap = cv2.VideoCapture(str(video_path))
157
+ if not cap.isOpened():
158
+ raise FileNotFoundError(f"无法打开视频: {video_path}")
159
+ frames = []
160
+ for idx in frame_indices:
161
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
162
+ ok, bgr = cap.read()
163
+ if not ok:
164
+ cap.release()
165
+ raise RuntimeError(f"读取帧 {idx} 失败: {video_path}")
166
+ rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
167
+ rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA)
168
+ # 裁去上半部分(天空)后高度变为 target_h
169
+ rgb = rgb[target_h:, :, :]
170
+ rgb = rgb.astype(np.float32) / 255.0
171
+ frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W]
172
+ cap.release()
173
+ return torch.stack(frames, dim=0)
174
+
175
+
176
+ def _load_npy(path: Path) -> np.ndarray:
177
+ return np.load(path, allow_pickle=False)
178
+
179
+
180
+ def _load_object_info(path: Path) -> list[ObjectTrackInfo]:
181
+ """解析单帧 all_object_info JSON。"""
182
+ if not path.exists():
183
+ return []
184
+ data = json.loads(path.read_text())
185
+ out = []
186
+ for tid, info in data.items():
187
+ T = torch.tensor(info["object_to_world"], dtype=torch.float32)
188
+ lwh = torch.tensor(info["object_lwh"], dtype=torch.float32)
189
+ out.append(
190
+ ObjectTrackInfo(
191
+ tracking_id=tid,
192
+ object_to_world=T,
193
+ lwh=lwh,
194
+ is_moving=bool(info.get("object_is_moving", False)),
195
+ object_type=str(info.get("object_type", "")),
196
+ )
197
+ )
198
+ return out
199
+
200
+
201
+ def _load_lidar_self_frame(
202
+ labels_dir: Path,
203
+ label_idx: int,
204
+ vehicle_pose: torch.Tensor,
205
+ max_history: int = 3,
206
+ ) -> torch.Tensor | None:
207
+ """读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。
208
+
209
+ LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、
210
+ ``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。
211
+ """
212
+ lidar_idx = (label_idx // 3) * 3
213
+ search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0]
214
+ p: Path | None = None
215
+ for idx_try in search_order:
216
+ try:
217
+ p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz")
218
+ break
219
+ except FileNotFoundError:
220
+ continue
221
+ if p is None:
222
+ return None
223
+ arr = np.load(p, allow_pickle=False)
224
+ xyz_lidar = arr["xyz"] # [N, 3] in lidar frame
225
+ lidar_to_world = arr["lidar_to_world"] # [4, 4]
226
+ # 转到 world 后再转 self
227
+ pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3]
228
+ inv_pose = torch.linalg.inv(vehicle_pose)
229
+ pts_w_t = torch.from_numpy(pts_w).float()
230
+ pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3]
231
+ return pts_self
232
+
233
+
234
+ class CosmosDriveDreamsDataset(Dataset):
235
+ """端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。"""
236
+
237
+ def __init__(
238
+ self,
239
+ data_root: str | Path,
240
+ samples: list[ClipSample] | None = None,
241
+ weathers: Sequence[str] = ("Sunny",),
242
+ camera_name: str = "camera_front_wide_120fov",
243
+ image_h: int = 384,
244
+ image_w: int = 1024,
245
+ num_history: int = 8,
246
+ future_horizon: int = 24,
247
+ max_distance_m: float = 48.0,
248
+ occlusion_tol: float = 0.5,
249
+ dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES,
250
+ structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES,
251
+ do_normalize: bool = True,
252
+ use_lidar_occlusion: bool = True,
253
+ use_hdmap: bool = True,
254
+ ) -> None:
255
+ super().__init__()
256
+ self.data_root = Path(data_root)
257
+ self.samples = samples if samples is not None else build_clip_index(
258
+ data_root, weathers=weathers, camera_name=camera_name
259
+ )
260
+ self.camera_name = camera_name
261
+ self.image_h = image_h
262
+ self.image_w = image_w
263
+ self.num_history = num_history
264
+ self.future_horizon = future_horizon
265
+ self.max_distance_m = max_distance_m
266
+ self.occlusion_tol = occlusion_tol
267
+ self.dynamic_classes = list(dynamic_classes)
268
+ self.structured_classes = list(structured_classes)
269
+ self.do_normalize = do_normalize
270
+ self.use_lidar_occlusion = use_lidar_occlusion
271
+ self.use_hdmap = use_hdmap
272
+ # HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析
273
+ self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {}
274
+ self._hdmap_cache_max = 32
275
+
276
+ def __len__(self) -> int:
277
+ return len(self.samples)
278
+
279
+ def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor:
280
+ p = resolve_clip_file(
281
+ sample.labels_dir,
282
+ "ftheta_intrinsic",
283
+ f"ftheta_intrinsic.{self.camera_name}.npy",
284
+ )
285
+ return torch.from_numpy(_load_npy(p)).float()
286
+
287
+ def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
288
+ p = resolve_clip_file(
289
+ sample.labels_dir,
290
+ "pose",
291
+ f"{label_idx:06d}.pose.{self.camera_name}.npy",
292
+ )
293
+ return torch.from_numpy(_load_npy(p)).float()
294
+
295
+ def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
296
+ p = resolve_clip_file(
297
+ sample.labels_dir,
298
+ "vehicle_pose",
299
+ f"{label_idx:06d}.vehicle_pose.npy",
300
+ )
301
+ return torch.from_numpy(_load_npy(p)).float()
302
+
303
+ def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]:
304
+ if not self.use_hdmap:
305
+ return []
306
+ key = str(clip_dir)
307
+ cached = self._hdmap_cache.get(key)
308
+ if cached is not None:
309
+ return cached
310
+ objs = parse_hdmap_clip(clip_dir)
311
+ if len(self._hdmap_cache) >= self._hdmap_cache_max:
312
+ self._hdmap_cache.pop(next(iter(self._hdmap_cache)))
313
+ self._hdmap_cache[key] = objs
314
+ return objs
315
+
316
+ def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]:
317
+ p = resolve_clip_file(
318
+ sample.labels_dir,
319
+ "all_object_info",
320
+ f"{label_idx:06d}.all_object_info.json",
321
+ )
322
+ dynamic = _load_object_info(p)
323
+ # HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同),
324
+ # 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差,
325
+ # 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。
326
+ return dynamic + self._load_hdmap_static(sample.labels_dir)
327
+
328
+ def __getitem__(self, idx: int) -> dict:
329
+ s = self.samples[idx]
330
+ # 视频帧索引(chunk 内 0-based)
331
+ t = s.anchor_t
332
+ history_frames = list(range(t - self.num_history + 1, t + 1))
333
+ # 标签索引:chunk_offset + chunk-local idx
334
+ history_label_idx = [s.chunk_offset + f for f in history_frames]
335
+ future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)]
336
+
337
+ # === 1) 加载图像 ===
338
+ # 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率);
339
+ # 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。
340
+ images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w)
341
+ # [T, 3, H, W],[0, 1]
342
+ if self.do_normalize:
343
+ images = (images - DINOV3_MEAN) / DINOV3_STD
344
+
345
+ # === 2) 加载内参 / 外参 ===
346
+ intr_vec = self._load_intrinsic(s) # [14]
347
+
348
+ # 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle
349
+ pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t)
350
+ pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t)
351
+ # cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world
352
+ inv_veh = torch.linalg.inv(pose_veh_world)
353
+ cam2veh = inv_veh @ pose_cam_world
354
+ extr_6d = matrix_to_6d(cam2veh) # [6]
355
+
356
+ # === 3) 历史 8 帧 ego pose(vehicle 6D)===
357
+ ego_6d = []
358
+ for li in history_label_idx:
359
+ T_vw = self._load_pose_vehicle(s, li)
360
+ ego_6d.append(matrix_to_6d(T_vw))
361
+ ego_6d = torch.stack(ego_6d, dim=0) # [8, 6]
362
+
363
+ # === 4) 检测 / 未来轨迹标签 ===
364
+ # objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。
365
+ objs_t = self._load_objects(s, s.chunk_offset + t)
366
+ objs_future = [self._load_objects(s, li) for li in future_label_idx]
367
+ veh_pose_future = []
368
+ for li in future_label_idx:
369
+ try:
370
+ veh_pose_future.append(self._load_pose_vehicle(s, li))
371
+ except FileNotFoundError:
372
+ break
373
+
374
+ cam = FThetaCamera.from_vector(intr_vec)
375
+ lidar_self = None
376
+ if self.use_lidar_occlusion:
377
+ try:
378
+ lidar_self = _load_lidar_self_frame(
379
+ s.labels_dir,
380
+ s.chunk_offset + t,
381
+ pose_veh_world,
382
+ )
383
+ except Exception:
384
+ lidar_self = None
385
+
386
+ det_targets = build_detection_targets(
387
+ objects_t=objs_t,
388
+ objects_future=objs_future,
389
+ vehicle_pose_t=pose_veh_world,
390
+ vehicle_pose_future=veh_pose_future,
391
+ cam_intrinsic=cam,
392
+ cam2vehicle=cam2veh,
393
+ image_h=self.image_h,
394
+ image_w=self.image_w,
395
+ max_distance_m=self.max_distance_m,
396
+ occlusion_depth_tolerance=self.occlusion_tol,
397
+ lidar_points_self=lidar_self,
398
+ dynamic_classes=self.dynamic_classes,
399
+ structured_classes=self.structured_classes,
400
+ future_horizon=self.future_horizon,
401
+ )
402
+
403
+ ego_future, ego_future_valid = build_ego_future_target(
404
+ pose_veh_world, veh_pose_future, horizon=self.future_horizon
405
+ )
406
+
407
+ sample_out = {
408
+ "images": images,
409
+ "ego_6d": ego_6d,
410
+ "intr_vec": intr_vec,
411
+ "extr_6d": extr_6d,
412
+ "ego_future": ego_future,
413
+ "ego_future_valid": ego_future_valid,
414
+ "targets": det_targets,
415
+ "meta": {
416
+ "clip_id": s.clip_id,
417
+ "chunk_id": s.chunk_id,
418
+ "weather": s.weather,
419
+ "anchor_t": s.anchor_t,
420
+ },
421
+ }
422
+ return sample_out
423
+
424
+
425
+ def collate_samples(batch: list[dict]) -> dict:
426
+ """自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack;
427
+ targets 列表保留为 list(便于匈牙利匹配处理变长 N);
428
+ meta 也保留为 list。"""
429
+ out = {
430
+ "images": torch.stack([b["images"] for b in batch], dim=0),
431
+ "ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0),
432
+ "intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0),
433
+ "extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0),
434
+ "ego_future": torch.stack([b["ego_future"] for b in batch], dim=0),
435
+ "ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0),
436
+ "targets": [b["targets"] for b in batch],
437
+ "meta": [b["meta"] for b in batch],
438
+ }
439
+ return out
src/wjad/data/ftheta_proj.py CHANGED
@@ -1,62 +1,62 @@
1
- """f-theta 正向投影:3D 点(相机系) -> 像素。
2
-
3
- 仅支持 backward polynomial 形式(与 NVIDIA 工具一致),
4
- forward 形式可在内部用牛顿迭代反推。
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import torch
10
-
11
- from ..modules.rays import FThetaCamera
12
-
13
-
14
- def project_points_ftheta(
15
- points_cam: torch.Tensor, # [..., 3],相机系下 3D 点
16
- cam: FThetaCamera,
17
- ) -> tuple[torch.Tensor, torch.Tensor]:
18
- """正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。
19
-
20
- 返回
21
- ----
22
- uv : [..., 2]
23
- depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算
24
- 但调用方需用 ``depth > 0`` 做有效性筛选)。
25
- """
26
- x = points_cam[..., 0]
27
- y = points_cam[..., 1]
28
- z = points_cam[..., 2]
29
- norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6)
30
- cos_theta = z / norm
31
- cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
32
- theta = torch.acos(cos_theta)
33
- phi = torch.atan2(y, x)
34
-
35
- if cam.intr.is_bw_poly:
36
- # backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。
37
- # 用牛顿迭代:希望 _eval_poly(r) = theta
38
- r = theta.clone() # 初始猜测
39
- for _ in range(8):
40
- f = cam._eval_poly(r) - theta
41
- df = cam._eval_poly_grad(r).clamp_min(1e-6)
42
- r = r - f / df
43
- r_pix = r
44
- else:
45
- r_pix = cam._eval_poly(theta)
46
-
47
- cos_p = torch.cos(phi)
48
- sin_p = torch.sin(phi)
49
- du = r_pix * cos_p
50
- dv = r_pix * sin_p
51
- # 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆
52
- c = cam.intr.linear_cde[0]
53
- d = cam.intr.linear_cde[1]
54
- e = cam.intr.linear_cde[2]
55
- # 简化:忽略 linear_cde 的修正(与 unproject 中近似一致)
56
- du0 = du
57
- dv0 = dv
58
- u = du0 + cam.intr.cx
59
- v = dv0 + cam.intr.cy
60
- uv = torch.stack([u, v], dim=-1)
61
- depth = z.unsqueeze(-1)
62
- return uv, depth
 
1
+ """f-theta 正向投影:3D 点(相机系) -> 像素。
2
+
3
+ 仅支持 backward polynomial 形式(与 NVIDIA 工具一致),
4
+ forward 形式可在内部用牛顿迭代反推。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+
11
+ from ..modules.rays import FThetaCamera
12
+
13
+
14
+ def project_points_ftheta(
15
+ points_cam: torch.Tensor, # [..., 3],相机系下 3D 点
16
+ cam: FThetaCamera,
17
+ ) -> tuple[torch.Tensor, torch.Tensor]:
18
+ """正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。
19
+
20
+ 返回
21
+ ----
22
+ uv : [..., 2]
23
+ depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算
24
+ 但调用方需用 ``depth > 0`` 做有效性筛选)。
25
+ """
26
+ x = points_cam[..., 0]
27
+ y = points_cam[..., 1]
28
+ z = points_cam[..., 2]
29
+ norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6)
30
+ cos_theta = z / norm
31
+ cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
32
+ theta = torch.acos(cos_theta)
33
+ phi = torch.atan2(y, x)
34
+
35
+ if cam.intr.is_bw_poly:
36
+ # backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。
37
+ # 用牛顿迭代:希望 _eval_poly(r) = theta
38
+ r = theta.clone() # 初始猜测
39
+ for _ in range(8):
40
+ f = cam._eval_poly(r) - theta
41
+ df = cam._eval_poly_grad(r).clamp_min(1e-6)
42
+ r = r - f / df
43
+ r_pix = r
44
+ else:
45
+ r_pix = cam._eval_poly(theta)
46
+
47
+ cos_p = torch.cos(phi)
48
+ sin_p = torch.sin(phi)
49
+ du = r_pix * cos_p
50
+ dv = r_pix * sin_p
51
+ # 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆
52
+ c = cam.intr.linear_cde[0]
53
+ d = cam.intr.linear_cde[1]
54
+ e = cam.intr.linear_cde[2]
55
+ # 简化:忽略 linear_cde 的修正(与 unproject 中近似一致)
56
+ du0 = du
57
+ dv0 = dv
58
+ u = du0 + cam.intr.cx
59
+ v = dv0 + cam.intr.cy
60
+ uv = torch.stack([u, v], dim=-1)
61
+ depth = z.unsqueeze(-1)
62
+ return uv, depth
src/wjad/data/hdmap.py CHANGED
@@ -1,247 +1,247 @@
1
- """HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。
2
-
3
- 输入:clip 标签目录(``labels/{clip_id_full}/``)。
4
- 输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``,
5
- ``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。
6
-
7
- 形状约定(按 README):
8
- - 3d_lanes / lanes.json
9
- labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices']
10
- - 3d_lanelines / lanelines.json
11
- labels[i]['labelData']['shape3d']['polyline3d']['vertices']
12
- - 3d_road_boundaries / road_boundaries.json 同 polyline3d
13
- - 3d_wait_lines / wait_lines.json 同 polyline3d
14
- - 3d_crosswalks / crosswalks.json
15
- labels[i]['labelData']['shape3d']['surface']['vertices']
16
- - 3d_road_markings / road_markings.json 同 surface
17
- - 3d_poles / poles.json 同 polyline3d
18
- - 3d_traffic_lights / 3d_traffic_lights.json
19
- labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点
20
- - 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d
21
-
22
- 折线 → 7-DoF box:
23
- PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计
24
- 弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会
25
- 超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。
26
- """
27
-
28
- from __future__ import annotations
29
-
30
- import json
31
- from pathlib import Path
32
-
33
- import numpy as np
34
- import torch
35
-
36
- from .targets import ObjectTrackInfo
37
- from .label_paths import resolve_clip_file
38
-
39
-
40
- # 折线类长度切分阈值(米)
41
- POLYLINE_SEGMENT_LEN = 10.0
42
- LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点
43
- MIN_LWH = (0.2, 0.2, 0.05)
44
-
45
-
46
- # cls_name -> (folder, json_name, kind)
47
- HDMAP_SOURCES = {
48
- "lane": ("3d_lanes", "lanes.json", "lane_pair"),
49
- "laneline": ("3d_lanelines", "lanelines.json", "polyline"),
50
- "road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"),
51
- "wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"),
52
- "crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"),
53
- "road_marking": ("3d_road_markings", "road_markings.json", "surface"),
54
- "pole": ("3d_poles", "poles.json", "polyline_short"),
55
- # 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json)
56
- "traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"),
57
- "traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"),
58
- }
59
-
60
-
61
- def _load_json_labels(path: Path) -> list:
62
- """容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{<filename>: {labels: ...}}``。"""
63
- if not path.exists():
64
- return []
65
- try:
66
- data = json.loads(path.read_text(encoding="utf-8"))
67
- except Exception:
68
- return []
69
- if isinstance(data, dict):
70
- if isinstance(data.get("labels"), list):
71
- return data["labels"]
72
- for v in data.values():
73
- if isinstance(v, dict) and isinstance(v.get("labels"), list):
74
- return v["labels"]
75
- return []
76
-
77
-
78
- def _verts_to_array(verts) -> np.ndarray:
79
- """vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。"""
80
- if not verts:
81
- return np.zeros((0, 3), dtype=np.float32)
82
- out: list[list[float]] = []
83
- for v in verts:
84
- if isinstance(v, dict):
85
- out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))])
86
- elif isinstance(v, (list, tuple)) and len(v) >= 3:
87
- out.append([float(v[0]), float(v[1]), float(v[2])])
88
- return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32)
89
-
90
-
91
- def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]:
92
- """按累计弧长把折线切成若干段。每段顶点数 >=2。"""
93
- if verts.shape[0] < 2:
94
- return []
95
- edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
96
- cum = np.concatenate([[0.0], np.cumsum(edges)])
97
- total = float(cum[-1])
98
- if total <= seg_len:
99
- return [verts]
100
- n = max(1, int(np.ceil(total / seg_len)))
101
- bounds = np.linspace(0.0, total, n + 1)
102
- chunks: list[np.ndarray] = []
103
- for i in range(n):
104
- lo, hi = bounds[i], bounds[i + 1]
105
- mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6)
106
- chunk = verts[mask]
107
- if chunk.shape[0] >= 2:
108
- chunks.append(chunk)
109
- return chunks
110
-
111
-
112
- def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None:
113
- """[N, 3] -> (center, lwh, yaw)。"""
114
- if verts.shape[0] < 2:
115
- return None
116
- center = verts.mean(0)
117
- centered_xy = verts[:, :2] - center[:2]
118
- if np.allclose(centered_xy, 0.0):
119
- yaw = 0.0
120
- else:
121
- cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1)
122
- _, eigvecs = np.linalg.eigh(cov)
123
- principal = eigvecs[:, -1]
124
- yaw = float(np.arctan2(principal[1], principal[0]))
125
- c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
126
- rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T
127
- l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min())
128
- w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min())
129
- h = float(verts[:, 2].max() - verts[:, 2].min())
130
- l = max(l, MIN_LWH[0])
131
- w = max(w, MIN_LWH[1])
132
- h = max(h, MIN_LWH[2])
133
- return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw
134
-
135
-
136
- def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
137
- """8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。"""
138
- center = corners.mean(0)
139
- edge = corners[1] - corners[0]
140
- yaw = float(np.arctan2(edge[1], edge[0]))
141
- c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
142
- R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
143
- rot = (corners - center) @ R.T
144
- lwh = (rot.max(0) - rot.min(0)).astype(np.float32)
145
- lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32))
146
- return center.astype(np.float32), lwh, yaw
147
-
148
-
149
- def _build_object(
150
- center: np.ndarray,
151
- lwh: np.ndarray,
152
- yaw: float,
153
- cls_name: str,
154
- idx: int,
155
- sub_idx: int = 0,
156
- ) -> ObjectTrackInfo:
157
- T = np.eye(4, dtype=np.float32)
158
- c, s = float(np.cos(yaw)), float(np.sin(yaw))
159
- T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
160
- T[:3, 3] = center
161
- return ObjectTrackInfo(
162
- tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}",
163
- object_to_world=torch.from_numpy(T),
164
- lwh=torch.from_numpy(lwh),
165
- is_moving=False,
166
- object_type=cls_name,
167
- )
168
-
169
-
170
- def parse_hdmap_clip(
171
- clip_label_dir: Path,
172
- segment_len: float = POLYLINE_SEGMENT_LEN,
173
- lane_segment_len: float = LANE_SEGMENT_LEN,
174
- ) -> list[ObjectTrackInfo]:
175
- """解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。"""
176
- out: list[ObjectTrackInfo] = []
177
- for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items():
178
- try:
179
- path = resolve_clip_file(clip_label_dir, subdir, json_name)
180
- except FileNotFoundError:
181
- continue
182
- labels = _load_json_labels(path)
183
- for i, lbl in enumerate(labels):
184
- if not isinstance(lbl, dict):
185
- continue
186
- shape = lbl.get("labelData", {}).get("shape3d", {})
187
- if not isinstance(shape, dict):
188
- continue
189
-
190
- if kind == "cuboid":
191
- verts = shape.get("cuboid3d", {}).get("vertices", [])
192
- arr = _verts_to_array(verts)
193
- if arr.shape[0] != 8:
194
- continue
195
- c, lwh, yaw = _cuboid_to_box(arr)
196
- out.append(_build_object(c, lwh, yaw, cls_name, i))
197
-
198
- elif kind == "surface":
199
- verts = shape.get("surface", {}).get("vertices", [])
200
- arr = _verts_to_array(verts)
201
- if arr.shape[0] < 3:
202
- continue
203
- box = _vertices_to_box(arr)
204
- if box is not None:
205
- out.append(_build_object(*box, cls_name, i))
206
-
207
- elif kind == "polyline":
208
- verts = shape.get("polyline3d", {}).get("vertices", [])
209
- arr = _verts_to_array(verts)
210
- if arr.shape[0] < 2:
211
- continue
212
- for j, chunk in enumerate(_split_polyline(arr, segment_len)):
213
- box = _vertices_to_box(chunk)
214
- if box is not None:
215
- out.append(_build_object(*box, cls_name, i, j))
216
-
217
- elif kind == "polyline_short":
218
- # 杆状物体不切分
219
- verts = shape.get("polyline3d", {}).get("vertices", [])
220
- arr = _verts_to_array(verts)
221
- if arr.shape[0] < 2:
222
- continue
223
- box = _vertices_to_box(arr)
224
- if box is not None:
225
- out.append(_build_object(*box, cls_name, i))
226
-
227
- elif kind == "lane_pair":
228
- pl_root = shape.get("polylines3d", {}).get("polylines", [])
229
- if not isinstance(pl_root, list) or len(pl_root) < 2:
230
- continue
231
- left = _verts_to_array(
232
- pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else []
233
- )
234
- right = _verts_to_array(
235
- pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else []
236
- )
237
- if left.shape[0] == 0 and right.shape[0] == 0:
238
- continue
239
- merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0)
240
- if merged.shape[0] < 2:
241
- continue
242
- for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)):
243
- box = _vertices_to_box(chunk)
244
- if box is not None:
245
- out.append(_build_object(*box, cls_name, i, j))
246
-
247
- return out
 
1
+ """HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。
2
+
3
+ 输入:clip 标签目录(``labels/{clip_id_full}/``)。
4
+ 输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``,
5
+ ``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。
6
+
7
+ 形状约定(按 README):
8
+ - 3d_lanes / lanes.json
9
+ labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices']
10
+ - 3d_lanelines / lanelines.json
11
+ labels[i]['labelData']['shape3d']['polyline3d']['vertices']
12
+ - 3d_road_boundaries / road_boundaries.json 同 polyline3d
13
+ - 3d_wait_lines / wait_lines.json 同 polyline3d
14
+ - 3d_crosswalks / crosswalks.json
15
+ labels[i]['labelData']['shape3d']['surface']['vertices']
16
+ - 3d_road_markings / road_markings.json 同 surface
17
+ - 3d_poles / poles.json 同 polyline3d
18
+ - 3d_traffic_lights / 3d_traffic_lights.json
19
+ labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点
20
+ - 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d
21
+
22
+ 折线 → 7-DoF box:
23
+ PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计
24
+ 弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会
25
+ 超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import json
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import torch
35
+
36
+ from .targets import ObjectTrackInfo
37
+ from .label_paths import resolve_clip_file
38
+
39
+
40
+ # 折线类长度切分阈值(米)
41
+ POLYLINE_SEGMENT_LEN = 10.0
42
+ LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点
43
+ MIN_LWH = (0.2, 0.2, 0.05)
44
+
45
+
46
+ # cls_name -> (folder, json_name, kind)
47
+ HDMAP_SOURCES = {
48
+ "lane": ("3d_lanes", "lanes.json", "lane_pair"),
49
+ "laneline": ("3d_lanelines", "lanelines.json", "polyline"),
50
+ "road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"),
51
+ "wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"),
52
+ "crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"),
53
+ "road_marking": ("3d_road_markings", "road_markings.json", "surface"),
54
+ "pole": ("3d_poles", "poles.json", "polyline_short"),
55
+ # 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json)
56
+ "traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"),
57
+ "traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"),
58
+ }
59
+
60
+
61
+ def _load_json_labels(path: Path) -> list:
62
+ """容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{<filename>: {labels: ...}}``。"""
63
+ if not path.exists():
64
+ return []
65
+ try:
66
+ data = json.loads(path.read_text(encoding="utf-8"))
67
+ except Exception:
68
+ return []
69
+ if isinstance(data, dict):
70
+ if isinstance(data.get("labels"), list):
71
+ return data["labels"]
72
+ for v in data.values():
73
+ if isinstance(v, dict) and isinstance(v.get("labels"), list):
74
+ return v["labels"]
75
+ return []
76
+
77
+
78
+ def _verts_to_array(verts) -> np.ndarray:
79
+ """vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。"""
80
+ if not verts:
81
+ return np.zeros((0, 3), dtype=np.float32)
82
+ out: list[list[float]] = []
83
+ for v in verts:
84
+ if isinstance(v, dict):
85
+ out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))])
86
+ elif isinstance(v, (list, tuple)) and len(v) >= 3:
87
+ out.append([float(v[0]), float(v[1]), float(v[2])])
88
+ return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32)
89
+
90
+
91
+ def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]:
92
+ """按累计弧长把折线切成若干段。每段顶点数 >=2。"""
93
+ if verts.shape[0] < 2:
94
+ return []
95
+ edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
96
+ cum = np.concatenate([[0.0], np.cumsum(edges)])
97
+ total = float(cum[-1])
98
+ if total <= seg_len:
99
+ return [verts]
100
+ n = max(1, int(np.ceil(total / seg_len)))
101
+ bounds = np.linspace(0.0, total, n + 1)
102
+ chunks: list[np.ndarray] = []
103
+ for i in range(n):
104
+ lo, hi = bounds[i], bounds[i + 1]
105
+ mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6)
106
+ chunk = verts[mask]
107
+ if chunk.shape[0] >= 2:
108
+ chunks.append(chunk)
109
+ return chunks
110
+
111
+
112
+ def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None:
113
+ """[N, 3] -> (center, lwh, yaw)。"""
114
+ if verts.shape[0] < 2:
115
+ return None
116
+ center = verts.mean(0)
117
+ centered_xy = verts[:, :2] - center[:2]
118
+ if np.allclose(centered_xy, 0.0):
119
+ yaw = 0.0
120
+ else:
121
+ cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1)
122
+ _, eigvecs = np.linalg.eigh(cov)
123
+ principal = eigvecs[:, -1]
124
+ yaw = float(np.arctan2(principal[1], principal[0]))
125
+ c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
126
+ rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T
127
+ l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min())
128
+ w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min())
129
+ h = float(verts[:, 2].max() - verts[:, 2].min())
130
+ l = max(l, MIN_LWH[0])
131
+ w = max(w, MIN_LWH[1])
132
+ h = max(h, MIN_LWH[2])
133
+ return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw
134
+
135
+
136
+ def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
137
+ """8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。"""
138
+ center = corners.mean(0)
139
+ edge = corners[1] - corners[0]
140
+ yaw = float(np.arctan2(edge[1], edge[0]))
141
+ c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
142
+ R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
143
+ rot = (corners - center) @ R.T
144
+ lwh = (rot.max(0) - rot.min(0)).astype(np.float32)
145
+ lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32))
146
+ return center.astype(np.float32), lwh, yaw
147
+
148
+
149
+ def _build_object(
150
+ center: np.ndarray,
151
+ lwh: np.ndarray,
152
+ yaw: float,
153
+ cls_name: str,
154
+ idx: int,
155
+ sub_idx: int = 0,
156
+ ) -> ObjectTrackInfo:
157
+ T = np.eye(4, dtype=np.float32)
158
+ c, s = float(np.cos(yaw)), float(np.sin(yaw))
159
+ T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
160
+ T[:3, 3] = center
161
+ return ObjectTrackInfo(
162
+ tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}",
163
+ object_to_world=torch.from_numpy(T),
164
+ lwh=torch.from_numpy(lwh),
165
+ is_moving=False,
166
+ object_type=cls_name,
167
+ )
168
+
169
+
170
+ def parse_hdmap_clip(
171
+ clip_label_dir: Path,
172
+ segment_len: float = POLYLINE_SEGMENT_LEN,
173
+ lane_segment_len: float = LANE_SEGMENT_LEN,
174
+ ) -> list[ObjectTrackInfo]:
175
+ """解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。"""
176
+ out: list[ObjectTrackInfo] = []
177
+ for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items():
178
+ try:
179
+ path = resolve_clip_file(clip_label_dir, subdir, json_name)
180
+ except FileNotFoundError:
181
+ continue
182
+ labels = _load_json_labels(path)
183
+ for i, lbl in enumerate(labels):
184
+ if not isinstance(lbl, dict):
185
+ continue
186
+ shape = lbl.get("labelData", {}).get("shape3d", {})
187
+ if not isinstance(shape, dict):
188
+ continue
189
+
190
+ if kind == "cuboid":
191
+ verts = shape.get("cuboid3d", {}).get("vertices", [])
192
+ arr = _verts_to_array(verts)
193
+ if arr.shape[0] != 8:
194
+ continue
195
+ c, lwh, yaw = _cuboid_to_box(arr)
196
+ out.append(_build_object(c, lwh, yaw, cls_name, i))
197
+
198
+ elif kind == "surface":
199
+ verts = shape.get("surface", {}).get("vertices", [])
200
+ arr = _verts_to_array(verts)
201
+ if arr.shape[0] < 3:
202
+ continue
203
+ box = _vertices_to_box(arr)
204
+ if box is not None:
205
+ out.append(_build_object(*box, cls_name, i))
206
+
207
+ elif kind == "polyline":
208
+ verts = shape.get("polyline3d", {}).get("vertices", [])
209
+ arr = _verts_to_array(verts)
210
+ if arr.shape[0] < 2:
211
+ continue
212
+ for j, chunk in enumerate(_split_polyline(arr, segment_len)):
213
+ box = _vertices_to_box(chunk)
214
+ if box is not None:
215
+ out.append(_build_object(*box, cls_name, i, j))
216
+
217
+ elif kind == "polyline_short":
218
+ # 杆状物体不切分
219
+ verts = shape.get("polyline3d", {}).get("vertices", [])
220
+ arr = _verts_to_array(verts)
221
+ if arr.shape[0] < 2:
222
+ continue
223
+ box = _vertices_to_box(arr)
224
+ if box is not None:
225
+ out.append(_build_object(*box, cls_name, i))
226
+
227
+ elif kind == "lane_pair":
228
+ pl_root = shape.get("polylines3d", {}).get("polylines", [])
229
+ if not isinstance(pl_root, list) or len(pl_root) < 2:
230
+ continue
231
+ left = _verts_to_array(
232
+ pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else []
233
+ )
234
+ right = _verts_to_array(
235
+ pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else []
236
+ )
237
+ if left.shape[0] == 0 and right.shape[0] == 0:
238
+ continue
239
+ merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0)
240
+ if merged.shape[0] < 2:
241
+ continue
242
+ for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)):
243
+ box = _vertices_to_box(chunk)
244
+ if box is not None:
245
+ out.append(_build_object(*box, cls_name, i, j))
246
+
247
+ return out
src/wjad/data/label_paths.py CHANGED
@@ -1,218 +1,218 @@
1
- """数据集标签目录布局解析。
2
-
3
- README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径;
4
- 实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError``
5
- 里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- from pathlib import Path
11
-
12
-
13
- def _norm_name(s: str) -> str:
14
- return "".join(c for c in s.lower() if c.isalnum())
15
-
16
-
17
- def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str:
18
- """列出 ``labels/<clip>/<folder>/`` 下文件采样 + clip 根下一级子目录。"""
19
- lines: list[str] = []
20
- sub = labels_dir / folder
21
- if sub.is_dir():
22
- files = sorted(p for p in sub.rglob("*") if p.is_file())
23
- lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):")
24
- for p in files[:max_list]:
25
- try:
26
- rel = p.relative_to(labels_dir).as_posix()
27
- except ValueError:
28
- rel = str(p)
29
- lines.append(f" {rel}")
30
- if len(files) > max_list:
31
- lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出")
32
- else:
33
- lines.append(f"[{folder}/] 不存在:{sub}")
34
- try:
35
- top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir())
36
- lines.append(f"[labels/<clip>/] 一级子目录:{top}")
37
- except OSError as e:
38
- lines.append(f"[labels/<clip>/] 无法列举:{e}")
39
- return "\n".join(lines)
40
-
41
-
42
- def _scan_npy_json_npz(
43
- labels_dir: Path,
44
- folder: str,
45
- fname: str,
46
- *,
47
- exts: tuple[str, ...] = (".npy",),
48
- tokens_norm: list[str],
49
- name_must_contain: str | None = None,
50
- ) -> list[Path]:
51
- """在整棵 labels/<clip>/ 下找候选文件:扩展名 + 归一化名须含各 token。"""
52
- root_hint = labels_dir / folder
53
- search_roots = [root_hint] if root_hint.is_dir() else []
54
- if not search_roots:
55
- search_roots = [labels_dir]
56
- hits: list[Path] = []
57
- for root in search_roots:
58
- for p in root.rglob("*"):
59
- if not p.is_file():
60
- continue
61
- if not p.suffix.lower() in [e.lower() for e in exts]:
62
- continue
63
- if name_must_contain and name_must_contain.lower() not in p.name.lower():
64
- continue
65
- pn = _norm_name(p.name)
66
- if all(tok in pn for tok in tokens_norm if tok):
67
- hits.append(p)
68
- if not hits and root_hint.is_dir():
69
- for p in labels_dir.rglob("*"):
70
- if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]:
71
- continue
72
- if name_must_contain and name_must_contain.lower() not in p.name.lower():
73
- continue
74
- pn = _norm_name(p.name)
75
- if all(tok in pn for tok in tokens_norm if tok):
76
- hits.append(p)
77
- return hits
78
-
79
-
80
- def resolve_clip_file(labels_dir: Path, *parts: str) -> Path:
81
- """在 ``labels/<clip_id>/`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。"""
82
- if not parts:
83
- raise ValueError("parts 不能为空")
84
- if not labels_dir.is_dir():
85
- raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}")
86
-
87
- direct = labels_dir.joinpath(*parts)
88
- if direct.is_file():
89
- return direct
90
- # NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels/<clip>/``
91
- # 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。
92
- clip_stem = labels_dir.resolve().name
93
- if len(parts) >= 2:
94
- folder, fname = parts[0], parts[-1]
95
- if not fname.startswith(f"{clip_stem}."):
96
- stemmed = labels_dir / folder / f"{clip_stem}.{fname}"
97
- if stemmed.is_file():
98
- return stemmed
99
- if len(parts) >= 2:
100
- folder = parts[0]
101
- rest = parts[1:]
102
- doubled = (labels_dir / folder / folder).joinpath(*rest)
103
- if doubled.is_file():
104
- return doubled
105
- fname = parts[-1]
106
- folder = parts[0]
107
- sub = labels_dir / folder
108
- if sub.is_dir():
109
- for p in sub.rglob(fname):
110
- if p.is_file():
111
- return p
112
-
113
- for p in labels_dir.rglob(fname):
114
- if p.is_file():
115
- return p
116
- fl = fname.lower()
117
- for p in labels_dir.rglob("*"):
118
- if p.is_file() and p.name.lower() == fl:
119
- return p
120
-
121
- # ftheta / pinhole
122
- if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"):
123
- prefix = folder + "."
124
- if fname.lower().startswith(prefix.lower()):
125
- cam = fname[len(prefix) : -len(".npy")]
126
- cam_n = _norm_name(cam)
127
- hits = []
128
- for p in labels_dir.rglob("*.npy"):
129
- if not p.is_file():
130
- continue
131
- pn = _norm_name(p.name)
132
- if folder == "ftheta_intrinsic":
133
- if "ftheta" not in pn:
134
- continue
135
- else:
136
- if "pinhole" not in pn:
137
- continue
138
- if cam_n and cam_n in pn:
139
- hits.append(p)
140
- if len(hits) == 1:
141
- return hits[0]
142
- if len(hits) > 1:
143
- hits.sort(key=lambda x: (len(x.parts), str(x)))
144
- return hits[0]
145
-
146
- # pose: ``{idx:06d}.pose.{camera}.npy``
147
- if folder == "pose" and fname.endswith(".npy"):
148
- base = fname[: -len(".npy")]
149
- if ".pose." in base:
150
- idx_part, _, cam_part = base.partition(".pose.")
151
- hits = _scan_npy_json_npz(
152
- labels_dir,
153
- folder,
154
- fname,
155
- exts=(".npy",),
156
- tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)],
157
- name_must_contain="pose",
158
- )
159
- if len(hits) == 1:
160
- return hits[0]
161
- if len(hits) > 1:
162
- hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x)))
163
- return hits[0]
164
-
165
- # vehicle_pose: ``{idx:06d}.vehicle_pose.npy``
166
- if folder == "vehicle_pose" and fname.endswith(".npy"):
167
- idx_part = fname.split(".")[0]
168
- hits = _scan_npy_json_npz(
169
- labels_dir,
170
- folder,
171
- fname,
172
- exts=(".npy",),
173
- tokens_norm=[_norm_name(idx_part), "vehiclepose"],
174
- name_must_contain="vehicle",
175
- )
176
- if len(hits) == 1:
177
- return hits[0]
178
- if len(hits) > 1:
179
- hits.sort(key=lambda x: (len(x.parts), str(x)))
180
- return hits[0]
181
-
182
- # all_object_info
183
- if folder == "all_object_info" and fname.endswith(".json"):
184
- idx_part = fname.split(".")[0]
185
- hits = _scan_npy_json_npz(
186
- labels_dir,
187
- folder,
188
- fname,
189
- exts=(".json",),
190
- tokens_norm=[_norm_name(idx_part), "allobjectinfo"],
191
- )
192
- if len(hits) == 1:
193
- return hits[0]
194
- if len(hits) > 1:
195
- hits.sort(key=lambda x: (len(x.parts), str(x)))
196
- return hits[0]
197
-
198
- # lidar_raw
199
- if folder == "lidar_raw" and fname.endswith(".npz"):
200
- stem = fname[: -len(".npz")]
201
- hits = _scan_npy_json_npz(
202
- labels_dir,
203
- folder,
204
- fname,
205
- exts=(".npz",),
206
- tokens_norm=[_norm_name(stem), "lidar", "raw"],
207
- )
208
- if len(hits) == 1:
209
- return hits[0]
210
- if len(hits) > 1:
211
- hits.sort(key=lambda x: (len(x.parts), str(x)))
212
- return hits[0]
213
-
214
- detail = _diagnose_labels(labels_dir, folder)
215
- raise FileNotFoundError(
216
- f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、"
217
- f"rglob、按帧索引+相机的扫描匹配)。\n{detail}"
218
- )
 
1
+ """数据集标签目录布局解析。
2
+
3
+ README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径;
4
+ 实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError``
5
+ 里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pathlib import Path
11
+
12
+
13
+ def _norm_name(s: str) -> str:
14
+ return "".join(c for c in s.lower() if c.isalnum())
15
+
16
+
17
+ def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str:
18
+ """列出 ``labels/<clip>/<folder>/`` 下文件采样 + clip 根下一级子目录。"""
19
+ lines: list[str] = []
20
+ sub = labels_dir / folder
21
+ if sub.is_dir():
22
+ files = sorted(p for p in sub.rglob("*") if p.is_file())
23
+ lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):")
24
+ for p in files[:max_list]:
25
+ try:
26
+ rel = p.relative_to(labels_dir).as_posix()
27
+ except ValueError:
28
+ rel = str(p)
29
+ lines.append(f" {rel}")
30
+ if len(files) > max_list:
31
+ lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出")
32
+ else:
33
+ lines.append(f"[{folder}/] 不存在:{sub}")
34
+ try:
35
+ top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir())
36
+ lines.append(f"[labels/<clip>/] 一级子目录:{top}")
37
+ except OSError as e:
38
+ lines.append(f"[labels/<clip>/] 无法列举:{e}")
39
+ return "\n".join(lines)
40
+
41
+
42
+ def _scan_npy_json_npz(
43
+ labels_dir: Path,
44
+ folder: str,
45
+ fname: str,
46
+ *,
47
+ exts: tuple[str, ...] = (".npy",),
48
+ tokens_norm: list[str],
49
+ name_must_contain: str | None = None,
50
+ ) -> list[Path]:
51
+ """在整棵 labels/<clip>/ 下找候选文件:扩展名 + 归一化名须含各 token。"""
52
+ root_hint = labels_dir / folder
53
+ search_roots = [root_hint] if root_hint.is_dir() else []
54
+ if not search_roots:
55
+ search_roots = [labels_dir]
56
+ hits: list[Path] = []
57
+ for root in search_roots:
58
+ for p in root.rglob("*"):
59
+ if not p.is_file():
60
+ continue
61
+ if not p.suffix.lower() in [e.lower() for e in exts]:
62
+ continue
63
+ if name_must_contain and name_must_contain.lower() not in p.name.lower():
64
+ continue
65
+ pn = _norm_name(p.name)
66
+ if all(tok in pn for tok in tokens_norm if tok):
67
+ hits.append(p)
68
+ if not hits and root_hint.is_dir():
69
+ for p in labels_dir.rglob("*"):
70
+ if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]:
71
+ continue
72
+ if name_must_contain and name_must_contain.lower() not in p.name.lower():
73
+ continue
74
+ pn = _norm_name(p.name)
75
+ if all(tok in pn for tok in tokens_norm if tok):
76
+ hits.append(p)
77
+ return hits
78
+
79
+
80
+ def resolve_clip_file(labels_dir: Path, *parts: str) -> Path:
81
+ """在 ``labels/<clip_id>/`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。"""
82
+ if not parts:
83
+ raise ValueError("parts 不能为空")
84
+ if not labels_dir.is_dir():
85
+ raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}")
86
+
87
+ direct = labels_dir.joinpath(*parts)
88
+ if direct.is_file():
89
+ return direct
90
+ # NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels/<clip>/``
91
+ # 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。
92
+ clip_stem = labels_dir.resolve().name
93
+ if len(parts) >= 2:
94
+ folder, fname = parts[0], parts[-1]
95
+ if not fname.startswith(f"{clip_stem}."):
96
+ stemmed = labels_dir / folder / f"{clip_stem}.{fname}"
97
+ if stemmed.is_file():
98
+ return stemmed
99
+ if len(parts) >= 2:
100
+ folder = parts[0]
101
+ rest = parts[1:]
102
+ doubled = (labels_dir / folder / folder).joinpath(*rest)
103
+ if doubled.is_file():
104
+ return doubled
105
+ fname = parts[-1]
106
+ folder = parts[0]
107
+ sub = labels_dir / folder
108
+ if sub.is_dir():
109
+ for p in sub.rglob(fname):
110
+ if p.is_file():
111
+ return p
112
+
113
+ for p in labels_dir.rglob(fname):
114
+ if p.is_file():
115
+ return p
116
+ fl = fname.lower()
117
+ for p in labels_dir.rglob("*"):
118
+ if p.is_file() and p.name.lower() == fl:
119
+ return p
120
+
121
+ # ftheta / pinhole
122
+ if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"):
123
+ prefix = folder + "."
124
+ if fname.lower().startswith(prefix.lower()):
125
+ cam = fname[len(prefix) : -len(".npy")]
126
+ cam_n = _norm_name(cam)
127
+ hits = []
128
+ for p in labels_dir.rglob("*.npy"):
129
+ if not p.is_file():
130
+ continue
131
+ pn = _norm_name(p.name)
132
+ if folder == "ftheta_intrinsic":
133
+ if "ftheta" not in pn:
134
+ continue
135
+ else:
136
+ if "pinhole" not in pn:
137
+ continue
138
+ if cam_n and cam_n in pn:
139
+ hits.append(p)
140
+ if len(hits) == 1:
141
+ return hits[0]
142
+ if len(hits) > 1:
143
+ hits.sort(key=lambda x: (len(x.parts), str(x)))
144
+ return hits[0]
145
+
146
+ # pose: ``{idx:06d}.pose.{camera}.npy``
147
+ if folder == "pose" and fname.endswith(".npy"):
148
+ base = fname[: -len(".npy")]
149
+ if ".pose." in base:
150
+ idx_part, _, cam_part = base.partition(".pose.")
151
+ hits = _scan_npy_json_npz(
152
+ labels_dir,
153
+ folder,
154
+ fname,
155
+ exts=(".npy",),
156
+ tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)],
157
+ name_must_contain="pose",
158
+ )
159
+ if len(hits) == 1:
160
+ return hits[0]
161
+ if len(hits) > 1:
162
+ hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x)))
163
+ return hits[0]
164
+
165
+ # vehicle_pose: ``{idx:06d}.vehicle_pose.npy``
166
+ if folder == "vehicle_pose" and fname.endswith(".npy"):
167
+ idx_part = fname.split(".")[0]
168
+ hits = _scan_npy_json_npz(
169
+ labels_dir,
170
+ folder,
171
+ fname,
172
+ exts=(".npy",),
173
+ tokens_norm=[_norm_name(idx_part), "vehiclepose"],
174
+ name_must_contain="vehicle",
175
+ )
176
+ if len(hits) == 1:
177
+ return hits[0]
178
+ if len(hits) > 1:
179
+ hits.sort(key=lambda x: (len(x.parts), str(x)))
180
+ return hits[0]
181
+
182
+ # all_object_info
183
+ if folder == "all_object_info" and fname.endswith(".json"):
184
+ idx_part = fname.split(".")[0]
185
+ hits = _scan_npy_json_npz(
186
+ labels_dir,
187
+ folder,
188
+ fname,
189
+ exts=(".json",),
190
+ tokens_norm=[_norm_name(idx_part), "allobjectinfo"],
191
+ )
192
+ if len(hits) == 1:
193
+ return hits[0]
194
+ if len(hits) > 1:
195
+ hits.sort(key=lambda x: (len(x.parts), str(x)))
196
+ return hits[0]
197
+
198
+ # lidar_raw
199
+ if folder == "lidar_raw" and fname.endswith(".npz"):
200
+ stem = fname[: -len(".npz")]
201
+ hits = _scan_npy_json_npz(
202
+ labels_dir,
203
+ folder,
204
+ fname,
205
+ exts=(".npz",),
206
+ tokens_norm=[_norm_name(stem), "lidar", "raw"],
207
+ )
208
+ if len(hits) == 1:
209
+ return hits[0]
210
+ if len(hits) > 1:
211
+ hits.sort(key=lambda x: (len(x.parts), str(x)))
212
+ return hits[0]
213
+
214
+ detail = _diagnose_labels(labels_dir, folder)
215
+ raise FileNotFoundError(
216
+ f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、"
217
+ f"rglob、按帧索引+相机的扫描匹配)。\n{detail}"
218
+ )
src/wjad/data/se3.py CHANGED
@@ -1,111 +1,111 @@
1
- """SE(3) 与 6D 表示之间的转换。
2
-
3
- 约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。
4
- 平移单位为米;旋转角弧度。
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import numpy as np
10
- import torch
11
-
12
-
13
- def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
14
- """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。
15
-
16
- 使用 Rodrigues 公式数值反求。
17
- """
18
- if isinstance(R, np.ndarray):
19
- R = torch.from_numpy(R).float()
20
- if R.dim() == 2:
21
- R = R.unsqueeze(0)
22
- single = True
23
- else:
24
- single = False
25
-
26
- trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
27
- cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
28
- theta = torch.acos(cos_theta) # [B]
29
-
30
- # 提取轴向量
31
- rx = R[..., 2, 1] - R[..., 1, 2]
32
- ry = R[..., 0, 2] - R[..., 2, 0]
33
- rz = R[..., 1, 0] - R[..., 0, 1]
34
- axis = torch.stack([rx, ry, rz], dim=-1)
35
- sin_theta = torch.sin(theta).clamp_min(1e-7)
36
- axis = axis / (2.0 * sin_theta).unsqueeze(-1)
37
-
38
- aa = axis * theta.unsqueeze(-1)
39
- if single:
40
- aa = aa.squeeze(0)
41
- return aa
42
-
43
-
44
- def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
45
- """轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。"""
46
- theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1]
47
- axis = aa / theta
48
- x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
49
- sin_t = torch.sin(theta.squeeze(-1))
50
- cos_t = torch.cos(theta.squeeze(-1))
51
- one_c = 1.0 - cos_t
52
-
53
- R = torch.stack(
54
- [
55
- cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t,
56
- y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t,
57
- z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c,
58
- ],
59
- dim=-1,
60
- ).reshape(*aa.shape[:-1], 3, 3)
61
- return R
62
-
63
-
64
- def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
65
- """4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。"""
66
- if isinstance(T, np.ndarray):
67
- T = torch.from_numpy(T).float()
68
- if T.dim() == 2:
69
- T = T.unsqueeze(0)
70
- single = True
71
- else:
72
- single = False
73
-
74
- R = T[..., :3, :3]
75
- t = T[..., :3, 3]
76
- aa = rotation_matrix_to_axis_angle(R)
77
- six = torch.cat([t, aa], dim=-1)
78
- if single:
79
- six = six.squeeze(0)
80
- return six
81
-
82
-
83
- def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor:
84
- """6D -> 4x4 SE(3)。"""
85
- if six.dim() == 1:
86
- six = six.unsqueeze(0)
87
- single = True
88
- else:
89
- single = False
90
- t = six[..., :3]
91
- aa = six[..., 3:]
92
- R = axis_angle_to_rotation_matrix(aa)
93
- T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
94
- T[..., :3, :3] = R
95
- T[..., :3, 3] = t
96
- T[..., 3, 3] = 1.0
97
- if single:
98
- T = T.squeeze(0)
99
- return T
100
-
101
-
102
- def invert_se3(T: torch.Tensor) -> torch.Tensor:
103
- """4x4 SE(3) 逆,``[..., 4, 4]``。"""
104
- R = T[..., :3, :3]
105
- t = T[..., :3, 3:4]
106
- Rt = R.transpose(-2, -1)
107
- inv = torch.zeros_like(T)
108
- inv[..., :3, :3] = Rt
109
- inv[..., :3, 3:4] = -Rt @ t
110
- inv[..., 3, 3] = 1.0
111
- return inv
 
1
+ """SE(3) 与 6D 表示之间的转换。
2
+
3
+ 约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。
4
+ 平移单位为米;旋转角弧度。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
14
+ """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。
15
+
16
+ 使用 Rodrigues 公式数值反求。
17
+ """
18
+ if isinstance(R, np.ndarray):
19
+ R = torch.from_numpy(R).float()
20
+ if R.dim() == 2:
21
+ R = R.unsqueeze(0)
22
+ single = True
23
+ else:
24
+ single = False
25
+
26
+ trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
27
+ cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
28
+ theta = torch.acos(cos_theta) # [B]
29
+
30
+ # 提取轴向量
31
+ rx = R[..., 2, 1] - R[..., 1, 2]
32
+ ry = R[..., 0, 2] - R[..., 2, 0]
33
+ rz = R[..., 1, 0] - R[..., 0, 1]
34
+ axis = torch.stack([rx, ry, rz], dim=-1)
35
+ sin_theta = torch.sin(theta).clamp_min(1e-7)
36
+ axis = axis / (2.0 * sin_theta).unsqueeze(-1)
37
+
38
+ aa = axis * theta.unsqueeze(-1)
39
+ if single:
40
+ aa = aa.squeeze(0)
41
+ return aa
42
+
43
+
44
+ def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
45
+ """轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。"""
46
+ theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1]
47
+ axis = aa / theta
48
+ x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
49
+ sin_t = torch.sin(theta.squeeze(-1))
50
+ cos_t = torch.cos(theta.squeeze(-1))
51
+ one_c = 1.0 - cos_t
52
+
53
+ R = torch.stack(
54
+ [
55
+ cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t,
56
+ y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t,
57
+ z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c,
58
+ ],
59
+ dim=-1,
60
+ ).reshape(*aa.shape[:-1], 3, 3)
61
+ return R
62
+
63
+
64
+ def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
65
+ """4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。"""
66
+ if isinstance(T, np.ndarray):
67
+ T = torch.from_numpy(T).float()
68
+ if T.dim() == 2:
69
+ T = T.unsqueeze(0)
70
+ single = True
71
+ else:
72
+ single = False
73
+
74
+ R = T[..., :3, :3]
75
+ t = T[..., :3, 3]
76
+ aa = rotation_matrix_to_axis_angle(R)
77
+ six = torch.cat([t, aa], dim=-1)
78
+ if single:
79
+ six = six.squeeze(0)
80
+ return six
81
+
82
+
83
+ def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor:
84
+ """6D -> 4x4 SE(3)。"""
85
+ if six.dim() == 1:
86
+ six = six.unsqueeze(0)
87
+ single = True
88
+ else:
89
+ single = False
90
+ t = six[..., :3]
91
+ aa = six[..., 3:]
92
+ R = axis_angle_to_rotation_matrix(aa)
93
+ T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
94
+ T[..., :3, :3] = R
95
+ T[..., :3, 3] = t
96
+ T[..., 3, 3] = 1.0
97
+ if single:
98
+ T = T.squeeze(0)
99
+ return T
100
+
101
+
102
+ def invert_se3(T: torch.Tensor) -> torch.Tensor:
103
+ """4x4 SE(3) 逆,``[..., 4, 4]``。"""
104
+ R = T[..., :3, :3]
105
+ t = T[..., :3, 3:4]
106
+ Rt = R.transpose(-2, -1)
107
+ inv = torch.zeros_like(T)
108
+ inv[..., :3, :3] = Rt
109
+ inv[..., :3, 3:4] = -Rt @ t
110
+ inv[..., 3, 3] = 1.0
111
+ return inv
src/wjad/data/targets.py CHANGED
@@ -1,214 +1,214 @@
1
- """检测 / 自车未来轨迹的目标构建。
2
-
3
- 依据 Cosmos-Drive-Dreams 数据集 README:
4
- all_object_info JSON 中以 ``tracking_id`` 为 key,存储
5
- ``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。
6
-
7
- 构建步骤:
8
- 1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系:
9
- object_to_self = inv(vehicle_pose_t) @ object_to_world
10
- 2. 距离 ``≤ max_distance_m`` 过滤;
11
- 3. 投影中心点到当前帧像素,要求落在视锥内;
12
- 4. 用 LIDAR 深度对比做遮挡剔除(粗粒度);
13
- 5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系,
14
- 提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。
15
-
16
- 为方便与 head 输出对齐,最终输出格式:
17
- {"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
18
- "future_traj": [N, 24, 3], "future_valid": [N, 24]}
19
- """
20
-
21
- from __future__ import annotations
22
-
23
- from dataclasses import dataclass
24
-
25
- import numpy as np
26
- import torch
27
-
28
- from ..modules.normalization import symlog
29
- from ..modules.rays import FThetaCamera
30
- from .ftheta_proj import project_points_ftheta
31
- from .se3 import invert_se3
32
-
33
-
34
- @dataclass
35
- class ObjectTrackInfo:
36
- """单个对在某帧的简化记录。"""
37
-
38
- tracking_id: str
39
- object_to_world: torch.Tensor # [4, 4]
40
- lwh: torch.Tensor # [3]
41
- is_moving: bool
42
- object_type: str
43
-
44
-
45
- def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
46
- """从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。
47
-
48
- 使用 ``atan2(R[1,0], R[0,0])``。
49
- """
50
- return torch.atan2(R[..., 1, 0], R[..., 0, 0])
51
-
52
-
53
- def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
54
- """根据 object_type 字符串映射到 (class_index, is_dynamic)。"""
55
- if object_type in dynamic_classes:
56
- return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0
57
- if object_type in structured_classes:
58
- return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
59
- return background_idx, 0 # 未知类型当 background
60
-
61
-
62
- def build_detection_targets(
63
- objects_t: list[ObjectTrackInfo],
64
- objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表
65
- vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world
66
- vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4
67
- cam_intrinsic: FThetaCamera,
68
- cam2vehicle: torch.Tensor, # [4, 4]
69
- image_h: int,
70
- image_w: int,
71
- max_distance_m: float = 48.0,
72
- occlusion_depth_tolerance: float = 0.5,
73
- lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡
74
- dynamic_classes: list[str] | None = None,
75
- structured_classes: list[str] | None = None,
76
- future_horizon: int = 24,
77
- ) -> dict:
78
- """构建一个样本的检测+未来轨迹标签。"""
79
- if dynamic_classes is None:
80
- dynamic_classes = []
81
- if structured_classes is None:
82
- structured_classes = []
83
-
84
- inv_pose_t = invert_se3(vehicle_pose_t)
85
- vehicle2cam = invert_se3(cam2vehicle)
86
-
87
- labels: list[int] = []
88
- boxes: list[list[float]] = []
89
- is_dynamic: list[int] = []
90
- future_traj: list[list[list[float]]] = []
91
- future_valid: list[list[int]] = []
92
-
93
- for obj in objects_t:
94
- T_obj_self = inv_pose_t @ obj.object_to_world # [4,4]
95
- center_self = T_obj_self[:3, 3]
96
-
97
- dist = float(center_self.norm().item())
98
- if dist > max_distance_m:
99
- continue
100
-
101
- # 视锥裁剪:把中心投影到相机系再投影到像素
102
- center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3]
103
- if center_cam[2].item() <= 0:
104
- continue
105
- uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
106
- u, v = uv[0, 0].item(), uv[0, 1].item()
107
- if not (0 <= u < image_w and 0 <= v < image_h):
108
- continue
109
-
110
- # LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比
111
- if lidar_points_self is not None and lidar_points_self.numel() > 0:
112
- ray = center_self / (center_self.norm() + 1e-6)
113
- proj = lidar_points_self @ ray # [P]
114
- # 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°)
115
- cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray
116
- mask = (cosang > 0.996) & (proj > 0)
117
- if mask.any():
118
- lidar_depth = proj[mask].min().item()
119
- if lidar_depth + occlusion_depth_tolerance < dist:
120
- # LIDAR 击中前方更近物体 -> 当前对象被遮挡
121
- continue
122
-
123
- # 类别映射
124
- cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
125
- if cls_idx == 0:
126
- continue
127
- labels.append(cls_idx)
128
- is_dynamic.append(is_dyn)
129
-
130
- yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
131
- l, w, h = obj.lwh.tolist()
132
- # box 坐标 symlog 归一
133
- x_n, y_n, z_n = (
134
- float(symlog(center_self[0]).item()),
135
- float(symlog(center_self[1]).item()),
136
- float(symlog(center_self[2]).item()),
137
- )
138
- l_n = float(symlog(torch.tensor(l)).item())
139
- w_n = float(symlog(torch.tensor(w)).item())
140
- h_n = float(symlog(torch.tensor(h)).item())
141
- boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw])
142
-
143
- # 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身
144
- # 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t)
145
- x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw
146
- future_3 = []
147
- future_v = []
148
- for k in range(future_horizon):
149
- if k >= len(objects_future) or k >= len(vehicle_pose_future):
150
- future_3.append([0.0, 0.0, 0.0])
151
- future_v.append(0)
152
- continue
153
- # 找对象在 t+k+1 帧
154
- future_objs = objects_future[k]
155
- match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None)
156
- if match is None:
157
- future_3.append([0.0, 0.0, 0.0])
158
- future_v.append(0)
159
- continue
160
- T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world
161
- xf = T_obj_self_future[0, 3].item()
162
- yf = T_obj_self_future[1, 3].item()
163
- yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item()
164
- dx = xf - x0
165
- dy = yf - y0
166
- dyaw = yawf - yaw0
167
- # 角度归到 (-pi, pi]
168
- dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
169
- future_3.append([
170
- float(symlog(torch.tensor(dx)).item()),
171
- float(symlog(torch.tensor(dy)).item()),
172
- float(dyaw),
173
- ])
174
- future_v.append(1)
175
- future_traj.append(future_3)
176
- future_valid.append(future_v)
177
-
178
- if not labels:
179
- return {
180
- "labels": torch.zeros(0, dtype=torch.long),
181
- "boxes": torch.zeros(0, 7),
182
- "is_dynamic": torch.zeros(0, dtype=torch.long),
183
- "future_traj": torch.zeros(0, future_horizon, 3),
184
- "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
185
- }
186
- return {
187
- "labels": torch.tensor(labels, dtype=torch.long),
188
- "boxes": torch.tensor(boxes, dtype=torch.float32),
189
- "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
190
- "future_traj": torch.tensor(future_traj, dtype=torch.float32),
191
- "future_valid": torch.tensor(future_valid, dtype=torch.bool),
192
- }
193
-
194
-
195
- def build_ego_future_target(
196
- vehicle_pose_t: torch.Tensor,
197
- vehicle_pose_future: list[torch.Tensor],
198
- horizon: int = 24,
199
- ) -> tuple[torch.Tensor, torch.Tensor]:
200
- """自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。"""
201
- inv_t = invert_se3(vehicle_pose_t)
202
- out = torch.zeros(horizon, 3)
203
- valid = torch.zeros(horizon, dtype=torch.bool)
204
- for k in range(horizon):
205
- if k >= len(vehicle_pose_future):
206
- break
207
- rel = inv_t @ vehicle_pose_future[k]
208
- x, y = rel[0, 3].item(), rel[1, 3].item()
209
- yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item()
210
- out[k, 0] = symlog(torch.tensor(x))
211
- out[k, 1] = symlog(torch.tensor(y))
212
- out[k, 2] = yaw
213
- valid[k] = True
214
- return out, valid
 
1
+ """检测 / 自车未来轨迹的目标构建。
2
+
3
+ 依据 Cosmos-Drive-Dreams 数据集 README:
4
+ all_object_info JSON 中以 ``tracking_id`` 为 key,存储
5
+ ``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。
6
+
7
+ 构建步骤:
8
+ 1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系:
9
+ object_to_self = inv(vehicle_pose_t) @ object_to_world
10
+ 2. 距离 ``≤ max_distance_m`` 过滤;
11
+ 3. 投影中心点到当前帧像素,要求落在视锥内;
12
+ 4. 用 LIDAR 深度对比做遮挡剔除(粗粒度);
13
+ 5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系,
14
+ 提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。
15
+
16
+ 为方便与 head 输出对齐,最终输出格式:
17
+ {"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
18
+ "future_traj": [N, 24, 3], "future_valid": [N, 24]}
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ from ..modules.normalization import symlog
29
+ from ..modules.rays import FThetaCamera
30
+ from .ftheta_proj import project_points_ftheta
31
+ from .se3 import invert_se3
32
+
33
+
34
+ @dataclass
35
+ class ObjectTrackInfo:
36
+ """单个对���在某帧的简化记录。"""
37
+
38
+ tracking_id: str
39
+ object_to_world: torch.Tensor # [4, 4]
40
+ lwh: torch.Tensor # [3]
41
+ is_moving: bool
42
+ object_type: str
43
+
44
+
45
+ def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
46
+ """从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。
47
+
48
+ 使用 ``atan2(R[1,0], R[0,0])``。
49
+ """
50
+ return torch.atan2(R[..., 1, 0], R[..., 0, 0])
51
+
52
+
53
+ def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
54
+ """根据 object_type 字符串映射到 (class_index, is_dynamic)。"""
55
+ if object_type in dynamic_classes:
56
+ return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0
57
+ if object_type in structured_classes:
58
+ return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
59
+ return background_idx, 0 # 未知类型当 background
60
+
61
+
62
+ def build_detection_targets(
63
+ objects_t: list[ObjectTrackInfo],
64
+ objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表
65
+ vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world
66
+ vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4
67
+ cam_intrinsic: FThetaCamera,
68
+ cam2vehicle: torch.Tensor, # [4, 4]
69
+ image_h: int,
70
+ image_w: int,
71
+ max_distance_m: float = 48.0,
72
+ occlusion_depth_tolerance: float = 0.5,
73
+ lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡
74
+ dynamic_classes: list[str] | None = None,
75
+ structured_classes: list[str] | None = None,
76
+ future_horizon: int = 24,
77
+ ) -> dict:
78
+ """构建一个样本的检测+未来轨迹标签。"""
79
+ if dynamic_classes is None:
80
+ dynamic_classes = []
81
+ if structured_classes is None:
82
+ structured_classes = []
83
+
84
+ inv_pose_t = invert_se3(vehicle_pose_t)
85
+ vehicle2cam = invert_se3(cam2vehicle)
86
+
87
+ labels: list[int] = []
88
+ boxes: list[list[float]] = []
89
+ is_dynamic: list[int] = []
90
+ future_traj: list[list[list[float]]] = []
91
+ future_valid: list[list[int]] = []
92
+
93
+ for obj in objects_t:
94
+ T_obj_self = inv_pose_t @ obj.object_to_world # [4,4]
95
+ center_self = T_obj_self[:3, 3]
96
+
97
+ dist = float(center_self.norm().item())
98
+ if dist > max_distance_m:
99
+ continue
100
+
101
+ # 视锥裁剪:把中心投影到相机系再投影到像素
102
+ center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3]
103
+ if center_cam[2].item() <= 0:
104
+ continue
105
+ uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
106
+ u, v = uv[0, 0].item(), uv[0, 1].item()
107
+ if not (0 <= u < image_w and 0 <= v < image_h):
108
+ continue
109
+
110
+ # LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比
111
+ if lidar_points_self is not None and lidar_points_self.numel() > 0:
112
+ ray = center_self / (center_self.norm() + 1e-6)
113
+ proj = lidar_points_self @ ray # [P]
114
+ # 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°)
115
+ cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray
116
+ mask = (cosang > 0.996) & (proj > 0)
117
+ if mask.any():
118
+ lidar_depth = proj[mask].min().item()
119
+ if lidar_depth + occlusion_depth_tolerance < dist:
120
+ # LIDAR 击中前方更近物体 -> 当前对象被遮挡
121
+ continue
122
+
123
+ # 类别映射
124
+ cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
125
+ if cls_idx == 0:
126
+ continue
127
+ labels.append(cls_idx)
128
+ is_dynamic.append(is_dyn)
129
+
130
+ yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
131
+ l, w, h = obj.lwh.tolist()
132
+ # box 坐标 symlog 归一
133
+ x_n, y_n, z_n = (
134
+ float(symlog(center_self[0]).item()),
135
+ float(symlog(center_self[1]).item()),
136
+ float(symlog(center_self[2]).item()),
137
+ )
138
+ l_n = float(symlog(torch.tensor(l)).item())
139
+ w_n = float(symlog(torch.tensor(w)).item())
140
+ h_n = float(symlog(torch.tensor(h)).item())
141
+ boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw])
142
+
143
+ # 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身
144
+ # 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t)
145
+ x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw
146
+ future_3 = []
147
+ future_v = []
148
+ for k in range(future_horizon):
149
+ if k >= len(objects_future) or k >= len(vehicle_pose_future):
150
+ future_3.append([0.0, 0.0, 0.0])
151
+ future_v.append(0)
152
+ continue
153
+ # 找对象在 t+k+1 帧
154
+ future_objs = objects_future[k]
155
+ match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None)
156
+ if match is None:
157
+ future_3.append([0.0, 0.0, 0.0])
158
+ future_v.append(0)
159
+ continue
160
+ T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world
161
+ xf = T_obj_self_future[0, 3].item()
162
+ yf = T_obj_self_future[1, 3].item()
163
+ yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item()
164
+ dx = xf - x0
165
+ dy = yf - y0
166
+ dyaw = yawf - yaw0
167
+ # 角度归到 (-pi, pi]
168
+ dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
169
+ future_3.append([
170
+ float(symlog(torch.tensor(dx)).item()),
171
+ float(symlog(torch.tensor(dy)).item()),
172
+ float(dyaw),
173
+ ])
174
+ future_v.append(1)
175
+ future_traj.append(future_3)
176
+ future_valid.append(future_v)
177
+
178
+ if not labels:
179
+ return {
180
+ "labels": torch.zeros(0, dtype=torch.long),
181
+ "boxes": torch.zeros(0, 7),
182
+ "is_dynamic": torch.zeros(0, dtype=torch.long),
183
+ "future_traj": torch.zeros(0, future_horizon, 3),
184
+ "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
185
+ }
186
+ return {
187
+ "labels": torch.tensor(labels, dtype=torch.long),
188
+ "boxes": torch.tensor(boxes, dtype=torch.float32),
189
+ "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
190
+ "future_traj": torch.tensor(future_traj, dtype=torch.float32),
191
+ "future_valid": torch.tensor(future_valid, dtype=torch.bool),
192
+ }
193
+
194
+
195
+ def build_ego_future_target(
196
+ vehicle_pose_t: torch.Tensor,
197
+ vehicle_pose_future: list[torch.Tensor],
198
+ horizon: int = 24,
199
+ ) -> tuple[torch.Tensor, torch.Tensor]:
200
+ """自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。"""
201
+ inv_t = invert_se3(vehicle_pose_t)
202
+ out = torch.zeros(horizon, 3)
203
+ valid = torch.zeros(horizon, dtype=torch.bool)
204
+ for k in range(horizon):
205
+ if k >= len(vehicle_pose_future):
206
+ break
207
+ rel = inv_t @ vehicle_pose_future[k]
208
+ x, y = rel[0, 3].item(), rel[1, 3].item()
209
+ yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item()
210
+ out[k, 0] = symlog(torch.tensor(x))
211
+ out[k, 1] = symlog(torch.tensor(y))
212
+ out[k, 2] = yaw
213
+ valid[k] = True
214
+ return out, valid
src/wjad/data/transforms.py CHANGED
@@ -1,86 +1,86 @@
1
- """图像与运动学的数据增广。"""
2
-
3
- from __future__ import annotations
4
-
5
- import numpy as np
6
- import torch
7
-
8
-
9
- # DINOv3 的 ImageNet 标准化参数
10
- DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
11
- DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
12
-
13
-
14
- def crop_top_half(image: torch.Tensor) -> torch.Tensor:
15
- """裁去图像上半部分(主要是天空)。
16
-
17
- 输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。
18
- """
19
- if image.dim() == 4:
20
- h = image.shape[2]
21
- return image[:, :, h // 2 :, :]
22
- elif image.dim() == 3:
23
- h = image.shape[1]
24
- return image[:, h // 2 :, :]
25
- raise ValueError(f"unsupported image dim: {image.dim()}")
26
-
27
-
28
- def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor:
29
- """对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。"""
30
- while mean.dim() < image.dim():
31
- mean = mean.unsqueeze(0)
32
- std = std.unsqueeze(0)
33
- return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype)
34
-
35
-
36
- def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor:
37
- """高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。"""
38
- if std <= 0:
39
- return image
40
- return image + torch.randn_like(image) * std
41
-
42
-
43
- def perturb_kinematics(
44
- ego_6d: torch.Tensor, # [T, 6]
45
- intr_vec: torch.Tensor, # [14]
46
- extr_6d: torch.Tensor, # [6]
47
- translation_std_m: float,
48
- rotation_std_deg: float,
49
- intrinsic_std: float,
50
- extrinsic_std: float,
51
- rng: np.random.Generator,
52
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
53
- """在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。
54
-
55
- 返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。
56
-
57
- 返回
58
- ----
59
- perturbed_ego, perturbed_intr, perturbed_extr,
60
- gt_residual_concat (在 symlog 空间作为 calibration 监督,可选;
61
- 本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造)
62
- """
63
- rot_std_rad = np.deg2rad(rotation_std_deg)
64
-
65
- # ego 8x6
66
- delta_ego = np.zeros_like(ego_6d.numpy())
67
- delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3))
68
- delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3))
69
- perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d)
70
-
71
- # intrinsic 14
72
- delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],))
73
- perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec)
74
-
75
- # extrinsic 6
76
- delta_extr = np.zeros_like(extr_6d.numpy())
77
- delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,))
78
- delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,))
79
- perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d)
80
-
81
- return (
82
- perturbed_ego,
83
- perturbed_intr,
84
- perturbed_extr,
85
- torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype),
86
- )
 
1
+ """图像与运动学的数据增广。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ # DINOv3 的 ImageNet 标准化参数
10
+ DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
11
+ DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
12
+
13
+
14
+ def crop_top_half(image: torch.Tensor) -> torch.Tensor:
15
+ """裁去图像上半部分(主要是天空)。
16
+
17
+ 输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。
18
+ """
19
+ if image.dim() == 4:
20
+ h = image.shape[2]
21
+ return image[:, :, h // 2 :, :]
22
+ elif image.dim() == 3:
23
+ h = image.shape[1]
24
+ return image[:, h // 2 :, :]
25
+ raise ValueError(f"unsupported image dim: {image.dim()}")
26
+
27
+
28
+ def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor:
29
+ """对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。"""
30
+ while mean.dim() < image.dim():
31
+ mean = mean.unsqueeze(0)
32
+ std = std.unsqueeze(0)
33
+ return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype)
34
+
35
+
36
+ def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor:
37
+ """高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。"""
38
+ if std <= 0:
39
+ return image
40
+ return image + torch.randn_like(image) * std
41
+
42
+
43
+ def perturb_kinematics(
44
+ ego_6d: torch.Tensor, # [T, 6]
45
+ intr_vec: torch.Tensor, # [14]
46
+ extr_6d: torch.Tensor, # [6]
47
+ translation_std_m: float,
48
+ rotation_std_deg: float,
49
+ intrinsic_std: float,
50
+ extrinsic_std: float,
51
+ rng: np.random.Generator,
52
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
53
+ """在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。
54
+
55
+ 返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。
56
+
57
+ 返回
58
+ ----
59
+ perturbed_ego, perturbed_intr, perturbed_extr,
60
+ gt_residual_concat (在 symlog 空间作为 calibration 监督,可选;
61
+ 本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造)
62
+ """
63
+ rot_std_rad = np.deg2rad(rotation_std_deg)
64
+
65
+ # ego 8x6
66
+ delta_ego = np.zeros_like(ego_6d.numpy())
67
+ delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3))
68
+ delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3))
69
+ perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d)
70
+
71
+ # intrinsic 14
72
+ delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],))
73
+ perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec)
74
+
75
+ # extrinsic 6
76
+ delta_extr = np.zeros_like(extr_6d.numpy())
77
+ delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,))
78
+ delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,))
79
+ perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d)
80
+
81
+ return (
82
+ perturbed_ego,
83
+ perturbed_intr,
84
+ perturbed_extr,
85
+ torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype),
86
+ )
src/wjad/encoders/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- """视觉编码相关:DINOv3 包装、时空压缩。"""
2
-
3
- from .dinov3_wrapper import DINOv3Wrapper
4
-
5
- __all__ = ["DINOv3Wrapper"]
 
1
+ """视觉编码相关:DINOv3 包装、时空压缩。"""
2
+
3
+ from .dinov3_wrapper import DINOv3Wrapper
4
+
5
+ __all__ = ["DINOv3Wrapper"]
src/wjad/encoders/dinov3_wrapper.py CHANGED
@@ -1,104 +1,104 @@
1
- """DINOv3 ViT-B/16 包装器。
2
-
3
- - 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。
4
- - 强制使用 ``attn_implementation="sdpa"``。
5
- - 提供 ``freeze()`` / ``unfreeze()`` 开关。
6
- - 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中
7
- ``(gh, gw) = (H/patch, W/patch)``。
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- from pathlib import Path
13
-
14
- import torch
15
- import torch.nn as nn
16
- from transformers import AutoModel
17
-
18
-
19
- class DINOv3Wrapper(nn.Module):
20
- """加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。"""
21
-
22
- def __init__(
23
- self,
24
- pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m",
25
- attn_implementation: str = "sdpa",
26
- freeze: bool = True,
27
- ) -> None:
28
- super().__init__()
29
- self.pretrained_path = str(pretrained_path)
30
- # 加载 HuggingFace transformers 中的 DINOv3 ViT 模型
31
- self.model = AutoModel.from_pretrained(
32
- self.pretrained_path,
33
- attn_implementation=attn_implementation,
34
- )
35
- cfg = self.model.config
36
- self.hidden_size = cfg.hidden_size
37
- self.patch_size = cfg.patch_size
38
- self.num_register_tokens = getattr(cfg, "num_register_tokens", 4)
39
-
40
- if freeze:
41
- self.freeze()
42
- self._frozen = freeze
43
-
44
- def freeze(self) -> None:
45
- """冻结所有参数。"""
46
- for p in self.model.parameters():
47
- p.requires_grad_(False)
48
- self.model.eval()
49
- self._frozen = True
50
-
51
- def unfreeze(self) -> None:
52
- """解冻全部参(Stage2 微调)。"""
53
- for p in self.model.parameters():
54
- p.requires_grad_(True)
55
- self.model.train()
56
- self._frozen = False
57
-
58
- @property
59
- def is_frozen(self) -> bool:
60
- return self._frozen
61
-
62
- def train(self, mode: bool = True) -> "DINOv3Wrapper":
63
- """覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。"""
64
- super().train(mode)
65
- if self._frozen:
66
- self.model.eval()
67
- return self
68
-
69
- def forward(self, images: torch.Tensor) -> torch.Tensor:
70
- """
71
- 参数
72
- ----
73
- images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。
74
-
75
- 返回
76
- ----
77
- feats : ``[B, T, gh, gw, D]``。
78
- """
79
- b, t, c, h, w = images.shape
80
- # DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]``
81
- flat = images.view(b * t, c, h, w)
82
-
83
- # 冻结分支无需梯度,节省显存与时间
84
- if self._frozen:
85
- with torch.no_grad():
86
- outputs = self.model(pixel_values=flat)
87
- else:
88
- outputs = self.model(pixel_values=flat)
89
-
90
- last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D]
91
- num_prefix = 1 + self.num_register_tokens
92
- patches = last[:, num_prefix:, :] # [B*T, N_patch, D]
93
-
94
- gh = h // self.patch_size
95
- gw = w // self.patch_size
96
- d = patches.shape[-1]
97
- # reshape 回网格
98
- feats = patches.view(b, t, gh, gw, d)
99
- return feats
100
-
101
- @torch.no_grad()
102
- def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]:
103
- """给定输入分辨率,返回 patch 网格大小。"""
104
- return image_h // self.patch_size, image_w // self.patch_size
 
1
+ """DINOv3 ViT-B/16 包装器。
2
+
3
+ - 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。
4
+ - 强制使用 ``attn_implementation="sdpa"``。
5
+ - 提供 ``freeze()`` / ``unfreeze()`` 开关。
6
+ - 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中
7
+ ``(gh, gw) = (H/patch, W/patch)``。
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from transformers import AutoModel
17
+
18
+
19
+ class DINOv3Wrapper(nn.Module):
20
+ """加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。"""
21
+
22
+ def __init__(
23
+ self,
24
+ pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m",
25
+ attn_implementation: str = "sdpa",
26
+ freeze: bool = True,
27
+ ) -> None:
28
+ super().__init__()
29
+ self.pretrained_path = str(pretrained_path)
30
+ # 加载 HuggingFace transformers 中的 DINOv3 ViT 模型
31
+ self.model = AutoModel.from_pretrained(
32
+ self.pretrained_path,
33
+ attn_implementation=attn_implementation,
34
+ )
35
+ cfg = self.model.config
36
+ self.hidden_size = cfg.hidden_size
37
+ self.patch_size = cfg.patch_size
38
+ self.num_register_tokens = getattr(cfg, "num_register_tokens", 4)
39
+
40
+ if freeze:
41
+ self.freeze()
42
+ self._frozen = freeze
43
+
44
+ def freeze(self) -> None:
45
+ """冻结所有参数。"""
46
+ for p in self.model.parameters():
47
+ p.requires_grad_(False)
48
+ self.model.eval()
49
+ self._frozen = True
50
+
51
+ def unfreeze(self) -> None:
52
+ """解冻全部参��(Stage2 微调)。"""
53
+ for p in self.model.parameters():
54
+ p.requires_grad_(True)
55
+ self.model.train()
56
+ self._frozen = False
57
+
58
+ @property
59
+ def is_frozen(self) -> bool:
60
+ return self._frozen
61
+
62
+ def train(self, mode: bool = True) -> "DINOv3Wrapper":
63
+ """覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。"""
64
+ super().train(mode)
65
+ if self._frozen:
66
+ self.model.eval()
67
+ return self
68
+
69
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
70
+ """
71
+ 参数
72
+ ----
73
+ images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。
74
+
75
+ 返回
76
+ ----
77
+ feats : ``[B, T, gh, gw, D]``。
78
+ """
79
+ b, t, c, h, w = images.shape
80
+ # DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]``
81
+ flat = images.view(b * t, c, h, w)
82
+
83
+ # 冻结分支无需梯度,节省显存与时间
84
+ if self._frozen:
85
+ with torch.no_grad():
86
+ outputs = self.model(pixel_values=flat)
87
+ else:
88
+ outputs = self.model(pixel_values=flat)
89
+
90
+ last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D]
91
+ num_prefix = 1 + self.num_register_tokens
92
+ patches = last[:, num_prefix:, :] # [B*T, N_patch, D]
93
+
94
+ gh = h // self.patch_size
95
+ gw = w // self.patch_size
96
+ d = patches.shape[-1]
97
+ # reshape 回网格
98
+ feats = patches.view(b, t, gh, gw, d)
99
+ return feats
100
+
101
+ @torch.no_grad()
102
+ def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]:
103
+ """给定输入分辨率,返回 patch 网格大小。"""
104
+ return image_h // self.patch_size, image_w // self.patch_size
src/wjad/heads/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
- """检测+未来轨迹头 + 控制头。"""
2
-
3
- from .detection_traj import DetectionTrajHead, DetectionTrajOutput
4
- from .control import ControlHead, ControlOutput
5
-
6
- __all__ = [
7
- "DetectionTrajHead",
8
- "DetectionTrajOutput",
9
- "ControlHead",
10
- "ControlOutput",
11
- ]
 
1
+ """检测+未来轨迹头 + 控制头。"""
2
+
3
+ from .detection_traj import DetectionTrajHead, DetectionTrajOutput
4
+ from .control import ControlHead, ControlOutput
5
+
6
+ __all__ = [
7
+ "DetectionTrajHead",
8
+ "DetectionTrajOutput",
9
+ "ControlHead",
10
+ "ControlOutput",
11
+ ]
src/wjad/heads/control.py CHANGED
@@ -1,100 +1,100 @@
1
- """自车控制头:24 个控制 token 输出未来轨迹与全局动作。
2
-
3
- token 切分:
4
- - 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma``
5
- - 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma``
6
- (其余作为冗余 / 未来扩展,暂不监督)
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- from dataclasses import dataclass
12
-
13
- import torch
14
- import torch.nn as nn
15
-
16
-
17
- @dataclass
18
- class ControlOutput:
19
- ego_traj_mu: torch.Tensor # [B, T_future, 3]
20
- ego_traj_log_sigma: torch.Tensor # [B, T_future, 3]
21
- action_mu: torch.Tensor # [B, action_dim]
22
- action_log_sigma: torch.Tensor # [B, action_dim]
23
-
24
-
25
- class ControlHead(nn.Module):
26
- def __init__(
27
- self,
28
- in_dim: int = 768,
29
- hidden_size: int = 384,
30
- num_traj_tokens: int = 12,
31
- num_action_tokens: int = 12,
32
- ego_traj_horizon: int = 24,
33
- ego_traj_dim: int = 3,
34
- action_dim: int = 3,
35
- log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
36
- ) -> None:
37
- super().__init__()
38
- assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24"
39
- self.in_dim = in_dim
40
- self.hidden = hidden_size
41
- self.num_traj_tokens = num_traj_tokens
42
- self.num_action_tokens = num_action_tokens
43
- self.ego_traj_horizon = ego_traj_horizon
44
- self.ego_traj_dim = ego_traj_dim
45
- self.action_dim = action_dim
46
- self.log_sigma_clamp = log_sigma_clamp
47
-
48
- self.norm = nn.LayerNorm(in_dim)
49
- # 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig)
50
- self.traj_proj = nn.Sequential(
51
- nn.Linear(num_traj_tokens * in_dim, hidden_size),
52
- nn.GELU(),
53
- nn.Linear(hidden_size, hidden_size),
54
- nn.GELU(),
55
- )
56
- self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
57
- self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
58
-
59
- # 动作分支:取第 0 个动作 token
60
- self.action_proj = nn.Sequential(
61
- nn.Linear(in_dim, hidden_size),
62
- nn.GELU(),
63
- nn.Linear(hidden_size, hidden_size),
64
- nn.GELU(),
65
- )
66
- self.action_mu_head = nn.Linear(hidden_size, action_dim)
67
- self.action_logsig_head = nn.Linear(hidden_size, action_dim)
68
-
69
- self._init_heads()
70
-
71
- def _init_heads(self) -> None:
72
- for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]:
73
- nn.init.zeros_(m.weight)
74
- nn.init.zeros_(m.bias)
75
-
76
- def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput:
77
- """
78
- ctrl_tokens : ``[B, 24, in_dim]``
79
- """
80
- b, n, d = ctrl_tokens.shape
81
- assert n == self.num_traj_tokens + self.num_action_tokens
82
- x = self.norm(ctrl_tokens)
83
- traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1)
84
- action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token
85
-
86
- traj_h = self.traj_proj(traj_feats)
87
- traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
88
- traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
89
- traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
90
-
91
- action_h = self.action_proj(action_feats)
92
- action_mu = self.action_mu_head(action_h)
93
- action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp)
94
-
95
- return ControlOutput(
96
- ego_traj_mu=traj_mu,
97
- ego_traj_log_sigma=traj_logsig,
98
- action_mu=action_mu,
99
- action_log_sigma=action_logsig,
100
- )
 
1
+ """自车控制头:24 个控制 token 输出未来轨迹与全局动作。
2
+
3
+ token 切分:
4
+ - 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma``
5
+ - 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma``
6
+ (其余作为冗余 / 未来扩展,暂不监督)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+
17
+ @dataclass
18
+ class ControlOutput:
19
+ ego_traj_mu: torch.Tensor # [B, T_future, 3]
20
+ ego_traj_log_sigma: torch.Tensor # [B, T_future, 3]
21
+ action_mu: torch.Tensor # [B, action_dim]
22
+ action_log_sigma: torch.Tensor # [B, action_dim]
23
+
24
+
25
+ class ControlHead(nn.Module):
26
+ def __init__(
27
+ self,
28
+ in_dim: int = 768,
29
+ hidden_size: int = 384,
30
+ num_traj_tokens: int = 12,
31
+ num_action_tokens: int = 12,
32
+ ego_traj_horizon: int = 24,
33
+ ego_traj_dim: int = 3,
34
+ action_dim: int = 3,
35
+ log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
36
+ ) -> None:
37
+ super().__init__()
38
+ assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24"
39
+ self.in_dim = in_dim
40
+ self.hidden = hidden_size
41
+ self.num_traj_tokens = num_traj_tokens
42
+ self.num_action_tokens = num_action_tokens
43
+ self.ego_traj_horizon = ego_traj_horizon
44
+ self.ego_traj_dim = ego_traj_dim
45
+ self.action_dim = action_dim
46
+ self.log_sigma_clamp = log_sigma_clamp
47
+
48
+ self.norm = nn.LayerNorm(in_dim)
49
+ # 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig)
50
+ self.traj_proj = nn.Sequential(
51
+ nn.Linear(num_traj_tokens * in_dim, hidden_size),
52
+ nn.GELU(),
53
+ nn.Linear(hidden_size, hidden_size),
54
+ nn.GELU(),
55
+ )
56
+ self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
57
+ self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
58
+
59
+ # 动作分支:取第 0 个动作 token
60
+ self.action_proj = nn.Sequential(
61
+ nn.Linear(in_dim, hidden_size),
62
+ nn.GELU(),
63
+ nn.Linear(hidden_size, hidden_size),
64
+ nn.GELU(),
65
+ )
66
+ self.action_mu_head = nn.Linear(hidden_size, action_dim)
67
+ self.action_logsig_head = nn.Linear(hidden_size, action_dim)
68
+
69
+ self._init_heads()
70
+
71
+ def _init_heads(self) -> None:
72
+ for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]:
73
+ nn.init.zeros_(m.weight)
74
+ nn.init.zeros_(m.bias)
75
+
76
+ def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput:
77
+ """
78
+ ctrl_tokens : ``[B, 24, in_dim]``
79
+ """
80
+ b, n, d = ctrl_tokens.shape
81
+ assert n == self.num_traj_tokens + self.num_action_tokens
82
+ x = self.norm(ctrl_tokens)
83
+ traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1)
84
+ action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token
85
+
86
+ traj_h = self.traj_proj(traj_feats)
87
+ traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
88
+ traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
89
+ traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
90
+
91
+ action_h = self.action_proj(action_feats)
92
+ action_mu = self.action_mu_head(action_h)
93
+ action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp)
94
+
95
+ return ControlOutput(
96
+ ego_traj_mu=traj_mu,
97
+ ego_traj_log_sigma=traj_logsig,
98
+ action_mu=action_mu,
99
+ action_log_sigma=action_logsig,
100
+ )
src/wjad/heads/detection_traj.py CHANGED
@@ -1,106 +1,106 @@
1
- """统一的检测 + 未来轨迹头。
2
-
3
- 每个检测 query token 输出:
4
- - ``cls`` : ``[num_classes]`` logits(含 background)
5
- - ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失)
6
- - ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw)
7
- - ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw)
8
-
9
- 匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- from dataclasses import dataclass
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
- from ..modules.ffn import SwiGLUFFN
20
-
21
-
22
- @dataclass
23
- class DetectionTrajOutput:
24
- """检测+未来轨迹头输出。"""
25
-
26
- cls_logits: torch.Tensor # [B, Q, num_classes]
27
- is_dynamic_logit: torch.Tensor # [B, Q]
28
- box3d_mu: torch.Tensor # [B, Q, 7]
29
- box3d_log_sigma: torch.Tensor # [B, Q, 7]
30
- traj_mu: torch.Tensor # [B, Q, T_future, 3]
31
- traj_log_sigma: torch.Tensor # [B, Q, T_future, 3]
32
-
33
-
34
- class DetectionTrajHead(nn.Module):
35
- def __init__(
36
- self,
37
- in_dim: int = 768,
38
- hidden_size: int = 384,
39
- num_classes: int = 22,
40
- box_dim: int = 7,
41
- traj_horizon: int = 24,
42
- traj_dim: int = 3,
43
- log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
44
- ) -> None:
45
- super().__init__()
46
- self.num_classes = num_classes
47
- self.box_dim = box_dim
48
- self.traj_horizon = traj_horizon
49
- self.traj_dim = traj_dim
50
- self.log_sigma_clamp = log_sigma_clamp
51
-
52
- # 共享主干 MLP(PreNorm + SwiGLU)
53
- self.norm = nn.LayerNorm(in_dim)
54
- self.shared = nn.Sequential(
55
- nn.Linear(in_dim, hidden_size),
56
- nn.GELU(),
57
- nn.Linear(hidden_size, hidden_size),
58
- nn.GELU(),
59
- )
60
-
61
- # 各分支
62
- self.cls_head = nn.Linear(hidden_size, num_classes)
63
- self.isdyn_head = nn.Linear(hidden_size, 1)
64
- self.box_mu_head = nn.Linear(hidden_size, box_dim)
65
- self.box_logsig_head = nn.Linear(hidden_size, box_dim)
66
- self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
67
- self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
68
-
69
- self._init_heads()
70
-
71
- def _init_heads(self) -> None:
72
- # 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1
73
- for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]:
74
- nn.init.zeros_(m.weight)
75
- nn.init.zeros_(m.bias)
76
- # cls / isdyn 用小初始化即可(避免 background 一开始全选)
77
- nn.init.normal_(self.cls_head.weight, std=0.01)
78
- nn.init.zeros_(self.cls_head.bias)
79
- nn.init.zeros_(self.isdyn_head.weight)
80
- nn.init.zeros_(self.isdyn_head.bias)
81
-
82
- def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput:
83
- """
84
- det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。
85
- """
86
- b, q, _ = det_tokens.shape
87
- feats = self.shared(self.norm(det_tokens))
88
-
89
- cls_logits = self.cls_head(feats)
90
- isdyn_logit = self.isdyn_head(feats).squeeze(-1)
91
-
92
- box_mu = self.box_mu_head(feats)
93
- box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp)
94
-
95
- traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
96
- traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
97
- traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
98
-
99
- return DetectionTrajOutput(
100
- cls_logits=cls_logits,
101
- is_dynamic_logit=isdyn_logit,
102
- box3d_mu=box_mu,
103
- box3d_log_sigma=box_logsig,
104
- traj_mu=traj_mu,
105
- traj_log_sigma=traj_logsig,
106
- )
 
1
+ """统一的检测 + 未来轨迹头。
2
+
3
+ 每个检测 query token 输出:
4
+ - ``cls`` : ``[num_classes]`` logits(含 background)
5
+ - ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失)
6
+ - ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw)
7
+ - ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw)
8
+
9
+ 匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ..modules.ffn import SwiGLUFFN
20
+
21
+
22
+ @dataclass
23
+ class DetectionTrajOutput:
24
+ """检测+未来轨迹头输出。"""
25
+
26
+ cls_logits: torch.Tensor # [B, Q, num_classes]
27
+ is_dynamic_logit: torch.Tensor # [B, Q]
28
+ box3d_mu: torch.Tensor # [B, Q, 7]
29
+ box3d_log_sigma: torch.Tensor # [B, Q, 7]
30
+ traj_mu: torch.Tensor # [B, Q, T_future, 3]
31
+ traj_log_sigma: torch.Tensor # [B, Q, T_future, 3]
32
+
33
+
34
+ class DetectionTrajHead(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_dim: int = 768,
38
+ hidden_size: int = 384,
39
+ num_classes: int = 22,
40
+ box_dim: int = 7,
41
+ traj_horizon: int = 24,
42
+ traj_dim: int = 3,
43
+ log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
44
+ ) -> None:
45
+ super().__init__()
46
+ self.num_classes = num_classes
47
+ self.box_dim = box_dim
48
+ self.traj_horizon = traj_horizon
49
+ self.traj_dim = traj_dim
50
+ self.log_sigma_clamp = log_sigma_clamp
51
+
52
+ # 共享主干 MLP(PreNorm + SwiGLU)
53
+ self.norm = nn.LayerNorm(in_dim)
54
+ self.shared = nn.Sequential(
55
+ nn.Linear(in_dim, hidden_size),
56
+ nn.GELU(),
57
+ nn.Linear(hidden_size, hidden_size),
58
+ nn.GELU(),
59
+ )
60
+
61
+ # 各分支
62
+ self.cls_head = nn.Linear(hidden_size, num_classes)
63
+ self.isdyn_head = nn.Linear(hidden_size, 1)
64
+ self.box_mu_head = nn.Linear(hidden_size, box_dim)
65
+ self.box_logsig_head = nn.Linear(hidden_size, box_dim)
66
+ self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
67
+ self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
68
+
69
+ self._init_heads()
70
+
71
+ def _init_heads(self) -> None:
72
+ # 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1
73
+ for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]:
74
+ nn.init.zeros_(m.weight)
75
+ nn.init.zeros_(m.bias)
76
+ # cls / isdyn 用小初始化即可(避免 background 一开始全选)
77
+ nn.init.normal_(self.cls_head.weight, std=0.01)
78
+ nn.init.zeros_(self.cls_head.bias)
79
+ nn.init.zeros_(self.isdyn_head.weight)
80
+ nn.init.zeros_(self.isdyn_head.bias)
81
+
82
+ def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput:
83
+ """
84
+ det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。
85
+ """
86
+ b, q, _ = det_tokens.shape
87
+ feats = self.shared(self.norm(det_tokens))
88
+
89
+ cls_logits = self.cls_head(feats)
90
+ isdyn_logit = self.isdyn_head(feats).squeeze(-1)
91
+
92
+ box_mu = self.box_mu_head(feats)
93
+ box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp)
94
+
95
+ traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
96
+ traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
97
+ traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
98
+
99
+ return DetectionTrajOutput(
100
+ cls_logits=cls_logits,
101
+ is_dynamic_logit=isdyn_logit,
102
+ box3d_mu=box_mu,
103
+ box3d_log_sigma=box_logsig,
104
+ traj_mu=traj_mu,
105
+ traj_log_sigma=traj_logsig,
106
+ )
src/wjad/losses/__init__.py CHANGED
@@ -1,24 +1,24 @@
1
- """损失函数集合。"""
2
-
3
- from .nll import gaussian_nll
4
- from .detection import (
5
- HungarianMatcher,
6
- detection_losses,
7
- DetectionLossOutputs,
8
- )
9
- from .trajectory import object_traj_nll
10
- from .control import ego_traj_nll, action_nll
11
- from .moe_aux import moe_load_balance_and_boundary
12
- from .calib_reg import calibration_regularization
13
-
14
- __all__ = [
15
- "gaussian_nll",
16
- "HungarianMatcher",
17
- "detection_losses",
18
- "DetectionLossOutputs",
19
- "object_traj_nll",
20
- "ego_traj_nll",
21
- "action_nll",
22
- "moe_load_balance_and_boundary",
23
- "calibration_regularization",
24
- ]
 
1
+ """损失函数集合。"""
2
+
3
+ from .nll import gaussian_nll
4
+ from .detection import (
5
+ HungarianMatcher,
6
+ detection_losses,
7
+ DetectionLossOutputs,
8
+ )
9
+ from .trajectory import object_traj_nll
10
+ from .control import ego_traj_nll, action_nll
11
+ from .moe_aux import moe_load_balance_and_boundary
12
+ from .calib_reg import calibration_regularization
13
+
14
+ __all__ = [
15
+ "gaussian_nll",
16
+ "HungarianMatcher",
17
+ "detection_losses",
18
+ "DetectionLossOutputs",
19
+ "object_traj_nll",
20
+ "ego_traj_nll",
21
+ "action_nll",
22
+ "moe_load_balance_and_boundary",
23
+ "calibration_regularization",
24
+ ]
src/wjad/losses/calib_reg.py CHANGED
@@ -1,21 +1,21 @@
1
- """在线校准残差正则:
2
-
3
- - L2 残差先验:让早期 / 一般情况下残差接近 0;
4
- - Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import torch
10
-
11
-
12
- def calibration_regularization(
13
- ego_residual: torch.Tensor,
14
- intr_residual: torch.Tensor,
15
- extr_residual: torch.Tensor,
16
- l2_weight: float = 1.0,
17
- ) -> torch.Tensor:
18
- e = ego_residual.pow(2).mean()
19
- i = intr_residual.pow(2).mean()
20
- x = extr_residual.pow(2).mean()
21
- return l2_weight * (e + i + x) / 3.0
 
1
+ """在线校准残差正则:
2
+
3
+ - L2 残差先验:让早期 / 一般情况下残差接近 0;
4
+ - Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+
11
+
12
+ def calibration_regularization(
13
+ ego_residual: torch.Tensor,
14
+ intr_residual: torch.Tensor,
15
+ extr_residual: torch.Tensor,
16
+ l2_weight: float = 1.0,
17
+ ) -> torch.Tensor:
18
+ e = ego_residual.pow(2).mean()
19
+ i = intr_residual.pow(2).mean()
20
+ x = extr_residual.pow(2).mean()
21
+ return l2_weight * (e + i + x) / 3.0
src/wjad/losses/control.py CHANGED
@@ -1,25 +1,25 @@
1
- """自车未来轨迹 + 全局动作的 NLL。"""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
-
7
- from .nll import gaussian_nll
8
-
9
-
10
- def ego_traj_nll(
11
- pred_mu: torch.Tensor, # [B, T, 3]
12
- pred_log_sigma: torch.Tensor, # [B, T, 3]
13
- target: torch.Tensor, # [B, T, 3] (symlog 空间)
14
- valid: torch.Tensor | None = None,
15
- ) -> torch.Tensor:
16
- return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
17
-
18
-
19
- def action_nll(
20
- pred_mu: torch.Tensor, # [B, A]
21
- pred_log_sigma: torch.Tensor, # [B, A]
22
- target: torch.Tensor, # [B, A]
23
- valid: torch.Tensor | None = None,
24
- ) -> torch.Tensor:
25
- return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
 
1
+ """自车未来轨迹 + 全局动作的 NLL。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from .nll import gaussian_nll
8
+
9
+
10
+ def ego_traj_nll(
11
+ pred_mu: torch.Tensor, # [B, T, 3]
12
+ pred_log_sigma: torch.Tensor, # [B, T, 3]
13
+ target: torch.Tensor, # [B, T, 3] (symlog 空间)
14
+ valid: torch.Tensor | None = None,
15
+ ) -> torch.Tensor:
16
+ return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
17
+
18
+
19
+ def action_nll(
20
+ pred_mu: torch.Tensor, # [B, A]
21
+ pred_log_sigma: torch.Tensor, # [B, A]
22
+ target: torch.Tensor, # [B, A]
23
+ valid: torch.Tensor | None = None,
24
+ ) -> torch.Tensor:
25
+ return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
src/wjad/losses/detection.py CHANGED
@@ -1,213 +1,213 @@
1
- """检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。
2
-
3
- GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU,
4
- ``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。
5
- 此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- from dataclasses import dataclass
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from scipy.optimize import linear_sum_assignment
15
-
16
- from .nll import gaussian_nll
17
-
18
-
19
- @dataclass
20
- class DetectionLossOutputs:
21
- """检测损失分项与匹配结果。"""
22
-
23
- cls_loss: torch.Tensor
24
- box_nll: torch.Tensor
25
- giou_loss: torch.Tensor
26
- isdyn_loss: torch.Tensor
27
- matched_indices: list[tuple[torch.Tensor, torch.Tensor]]
28
-
29
-
30
- def focal_loss(
31
- logits: torch.Tensor,
32
- target: torch.Tensor,
33
- alpha: float = 0.25,
34
- gamma: float = 2.0,
35
- reduction: str = "mean",
36
- ) -> torch.Tensor:
37
- """多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。"""
38
- log_softmax = F.log_softmax(logits, dim=-1)
39
- pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1)
40
- ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1)
41
- focal = alpha * (1 - pt).pow(gamma) * ce
42
- if reduction == "mean":
43
- return focal.mean()
44
- if reduction == "sum":
45
- return focal.sum()
46
- return focal
47
-
48
-
49
- def _bev_giou(
50
- box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw)
51
- box_b: torch.Tensor, # [..., 5]
52
- ) -> torch.Tensor:
53
- """BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转),
54
- 可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。
55
- """
56
- cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3]
57
- cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3]
58
- a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2
59
- a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2
60
- b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2
61
- b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2
62
- inter_x1 = torch.max(a_x1, b_x1)
63
- inter_y1 = torch.max(a_y1, b_y1)
64
- inter_x2 = torch.min(a_x2, b_x2)
65
- inter_y2 = torch.min(a_y2, b_y2)
66
- inter_w = (inter_x2 - inter_x1).clamp_min(0)
67
- inter_h = (inter_y2 - inter_y1).clamp_min(0)
68
- inter = inter_w * inter_h
69
- area_a = (l_a * w_a).clamp_min(0)
70
- area_b = (l_b * w_b).clamp_min(0)
71
- union = area_a + area_b - inter + 1e-6
72
- iou = inter / union
73
- # GIoU enclosure
74
- enc_x1 = torch.min(a_x1, b_x1)
75
- enc_y1 = torch.min(a_y1, b_y1)
76
- enc_x2 = torch.max(a_x2, b_x2)
77
- enc_y2 = torch.max(a_y2, b_y2)
78
- enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6)
79
- giou = iou - (enc_area - union) / enc_area
80
- return giou
81
-
82
-
83
- def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor:
84
- """3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。"""
85
- bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]])
86
- z_a, h_a = box_a[..., 2], box_a[..., 5]
87
- z_b, h_b = box_b[..., 2], box_b[..., 5]
88
- a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2
89
- b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2
90
- inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0)
91
- union_z = (h_a + h_b - inter_z).clamp_min(1e-6)
92
- z_iou = inter_z / union_z
93
- return bev * z_iou
94
-
95
-
96
- class HungarianMatcher:
97
- """DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。"""
98
-
99
- def __init__(
100
- self,
101
- cls_cost: float = 2.0,
102
- l1_cost: float = 5.0,
103
- giou_cost: float = 2.0,
104
- ) -> None:
105
- self.cls_cost = cls_cost
106
- self.l1_cost = l1_cost
107
- self.giou_cost = giou_cost
108
-
109
- @torch.no_grad()
110
- def match(
111
- self,
112
- cls_logits: torch.Tensor, # [B, Q, C]
113
- box_mu: torch.Tensor, # [B, Q, 7]
114
- targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]}
115
- ) -> list[tuple[torch.Tensor, torch.Tensor]]:
116
- b, q, c = cls_logits.shape
117
- out = []
118
- cls_probs = cls_logits.softmax(-1) # [B, Q, C]
119
-
120
- for i in range(b):
121
- tgt_labels = targets[i]["labels"]
122
- tgt_boxes = targets[i]["boxes"]
123
- n_tgt = tgt_labels.numel()
124
- if n_tgt == 0:
125
- out.append((
126
- torch.empty(0, dtype=torch.long),
127
- torch.empty(0, dtype=torch.long),
128
- ))
129
- continue
130
-
131
- # cost_cls: 越大越好 → 取负
132
- cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt]
133
- cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt]
134
- # giou3d: [Q, n_tgt]
135
- qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1)
136
- tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1)
137
- cost_giou = -giou3d_approx(qa, tb)
138
-
139
- cost = (
140
- self.cls_cost * cost_cls
141
- + self.l1_cost * cost_l1
142
- + self.giou_cost * cost_giou
143
- )
144
- cost_np = cost.cpu().numpy()
145
- row, col = linear_sum_assignment(cost_np)
146
- out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long)))
147
- return out
148
-
149
-
150
- def detection_losses(
151
- cls_logits: torch.Tensor, # [B, Q, C]
152
- box_mu: torch.Tensor, # [B, Q, 7]
153
- box_log_sigma: torch.Tensor, # [B, Q, 7]
154
- isdyn_logit: torch.Tensor, # [B, Q]
155
- targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...}
156
- matcher: HungarianMatcher,
157
- num_classes: int,
158
- background_class: int = 0,
159
- focal_alpha: float = 0.25,
160
- focal_gamma: float = 2.0,
161
- ) -> DetectionLossOutputs:
162
- """返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。"""
163
- indices = matcher.match(cls_logits, box_mu, targets)
164
- b, q, _ = box_mu.shape
165
- device = box_mu.device
166
-
167
- # 构造分类目标:所有 query 默认 background;匹配的填对应 label
168
- target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device)
169
- target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device)
170
- matched_box_pairs = []
171
- matched_logsig_pairs = []
172
- matched_target_boxes = []
173
-
174
- for i, (rows, cols) in enumerate(indices):
175
- if rows.numel() == 0:
176
- continue
177
- rows = rows.to(device)
178
- cols = cols.to(device)
179
- target_classes[i, rows] = targets[i]["labels"][cols].to(device)
180
- target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float()
181
- matched_box_pairs.append(box_mu[i, rows])
182
- matched_logsig_pairs.append(box_log_sigma[i, rows])
183
- matched_target_boxes.append(targets[i]["boxes"][cols].to(device))
184
-
185
- cls_loss = focal_loss(
186
- cls_logits.view(b * q, -1),
187
- target_classes.view(-1),
188
- alpha=focal_alpha,
189
- gamma=focal_gamma,
190
- )
191
-
192
- if matched_box_pairs:
193
- pred_box = torch.cat(matched_box_pairs, dim=0)
194
- pred_logsig = torch.cat(matched_logsig_pairs, dim=0)
195
- gt_box = torch.cat(matched_target_boxes, dim=0)
196
- box_nll = gaussian_nll(pred_box, pred_logsig, gt_box)
197
- giou_v = giou3d_approx(pred_box, gt_box)
198
- giou_loss = (1.0 - giou_v).mean()
199
- else:
200
- box_nll = torch.zeros((), device=device)
201
- giou_loss = torch.zeros((), device=device)
202
-
203
- isdyn_loss = F.binary_cross_entropy_with_logits(
204
- isdyn_logit, target_isdyn, reduction="mean"
205
- )
206
-
207
- return DetectionLossOutputs(
208
- cls_loss=cls_loss,
209
- box_nll=box_nll,
210
- giou_loss=giou_loss,
211
- isdyn_loss=isdyn_loss,
212
- matched_indices=indices,
213
- )
 
1
+ """检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。
2
+
3
+ GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU,
4
+ ``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。
5
+ 此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from scipy.optimize import linear_sum_assignment
15
+
16
+ from .nll import gaussian_nll
17
+
18
+
19
+ @dataclass
20
+ class DetectionLossOutputs:
21
+ """检测损失分项与匹配结果。"""
22
+
23
+ cls_loss: torch.Tensor
24
+ box_nll: torch.Tensor
25
+ giou_loss: torch.Tensor
26
+ isdyn_loss: torch.Tensor
27
+ matched_indices: list[tuple[torch.Tensor, torch.Tensor]]
28
+
29
+
30
+ def focal_loss(
31
+ logits: torch.Tensor,
32
+ target: torch.Tensor,
33
+ alpha: float = 0.25,
34
+ gamma: float = 2.0,
35
+ reduction: str = "mean",
36
+ ) -> torch.Tensor:
37
+ """多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。"""
38
+ log_softmax = F.log_softmax(logits, dim=-1)
39
+ pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1)
40
+ ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1)
41
+ focal = alpha * (1 - pt).pow(gamma) * ce
42
+ if reduction == "mean":
43
+ return focal.mean()
44
+ if reduction == "sum":
45
+ return focal.sum()
46
+ return focal
47
+
48
+
49
+ def _bev_giou(
50
+ box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw)
51
+ box_b: torch.Tensor, # [..., 5]
52
+ ) -> torch.Tensor:
53
+ """BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转),
54
+ 可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。
55
+ """
56
+ cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3]
57
+ cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3]
58
+ a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2
59
+ a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2
60
+ b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2
61
+ b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2
62
+ inter_x1 = torch.max(a_x1, b_x1)
63
+ inter_y1 = torch.max(a_y1, b_y1)
64
+ inter_x2 = torch.min(a_x2, b_x2)
65
+ inter_y2 = torch.min(a_y2, b_y2)
66
+ inter_w = (inter_x2 - inter_x1).clamp_min(0)
67
+ inter_h = (inter_y2 - inter_y1).clamp_min(0)
68
+ inter = inter_w * inter_h
69
+ area_a = (l_a * w_a).clamp_min(0)
70
+ area_b = (l_b * w_b).clamp_min(0)
71
+ union = area_a + area_b - inter + 1e-6
72
+ iou = inter / union
73
+ # GIoU enclosure
74
+ enc_x1 = torch.min(a_x1, b_x1)
75
+ enc_y1 = torch.min(a_y1, b_y1)
76
+ enc_x2 = torch.max(a_x2, b_x2)
77
+ enc_y2 = torch.max(a_y2, b_y2)
78
+ enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6)
79
+ giou = iou - (enc_area - union) / enc_area
80
+ return giou
81
+
82
+
83
+ def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor:
84
+ """3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。"""
85
+ bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]])
86
+ z_a, h_a = box_a[..., 2], box_a[..., 5]
87
+ z_b, h_b = box_b[..., 2], box_b[..., 5]
88
+ a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2
89
+ b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2
90
+ inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0)
91
+ union_z = (h_a + h_b - inter_z).clamp_min(1e-6)
92
+ z_iou = inter_z / union_z
93
+ return bev * z_iou
94
+
95
+
96
+ class HungarianMatcher:
97
+ """DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。"""
98
+
99
+ def __init__(
100
+ self,
101
+ cls_cost: float = 2.0,
102
+ l1_cost: float = 5.0,
103
+ giou_cost: float = 2.0,
104
+ ) -> None:
105
+ self.cls_cost = cls_cost
106
+ self.l1_cost = l1_cost
107
+ self.giou_cost = giou_cost
108
+
109
+ @torch.no_grad()
110
+ def match(
111
+ self,
112
+ cls_logits: torch.Tensor, # [B, Q, C]
113
+ box_mu: torch.Tensor, # [B, Q, 7]
114
+ targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]}
115
+ ) -> list[tuple[torch.Tensor, torch.Tensor]]:
116
+ b, q, c = cls_logits.shape
117
+ out = []
118
+ cls_probs = cls_logits.softmax(-1) # [B, Q, C]
119
+
120
+ for i in range(b):
121
+ tgt_labels = targets[i]["labels"]
122
+ tgt_boxes = targets[i]["boxes"]
123
+ n_tgt = tgt_labels.numel()
124
+ if n_tgt == 0:
125
+ out.append((
126
+ torch.empty(0, dtype=torch.long),
127
+ torch.empty(0, dtype=torch.long),
128
+ ))
129
+ continue
130
+
131
+ # cost_cls: 越大越好 → 取负
132
+ cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt]
133
+ cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt]
134
+ # giou3d: [Q, n_tgt]
135
+ qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1)
136
+ tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1)
137
+ cost_giou = -giou3d_approx(qa, tb)
138
+
139
+ cost = (
140
+ self.cls_cost * cost_cls
141
+ + self.l1_cost * cost_l1
142
+ + self.giou_cost * cost_giou
143
+ )
144
+ cost_np = cost.cpu().numpy()
145
+ row, col = linear_sum_assignment(cost_np)
146
+ out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long)))
147
+ return out
148
+
149
+
150
+ def detection_losses(
151
+ cls_logits: torch.Tensor, # [B, Q, C]
152
+ box_mu: torch.Tensor, # [B, Q, 7]
153
+ box_log_sigma: torch.Tensor, # [B, Q, 7]
154
+ isdyn_logit: torch.Tensor, # [B, Q]
155
+ targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...}
156
+ matcher: HungarianMatcher,
157
+ num_classes: int,
158
+ background_class: int = 0,
159
+ focal_alpha: float = 0.25,
160
+ focal_gamma: float = 2.0,
161
+ ) -> DetectionLossOutputs:
162
+ """返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。"""
163
+ indices = matcher.match(cls_logits, box_mu, targets)
164
+ b, q, _ = box_mu.shape
165
+ device = box_mu.device
166
+
167
+ # 构造分类目标:所有 query 默认 background;匹配的填对应 label
168
+ target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device)
169
+ target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device)
170
+ matched_box_pairs = []
171
+ matched_logsig_pairs = []
172
+ matched_target_boxes = []
173
+
174
+ for i, (rows, cols) in enumerate(indices):
175
+ if rows.numel() == 0:
176
+ continue
177
+ rows = rows.to(device)
178
+ cols = cols.to(device)
179
+ target_classes[i, rows] = targets[i]["labels"][cols].to(device)
180
+ target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float()
181
+ matched_box_pairs.append(box_mu[i, rows])
182
+ matched_logsig_pairs.append(box_log_sigma[i, rows])
183
+ matched_target_boxes.append(targets[i]["boxes"][cols].to(device))
184
+
185
+ cls_loss = focal_loss(
186
+ cls_logits.view(b * q, -1),
187
+ target_classes.view(-1),
188
+ alpha=focal_alpha,
189
+ gamma=focal_gamma,
190
+ )
191
+
192
+ if matched_box_pairs:
193
+ pred_box = torch.cat(matched_box_pairs, dim=0)
194
+ pred_logsig = torch.cat(matched_logsig_pairs, dim=0)
195
+ gt_box = torch.cat(matched_target_boxes, dim=0)
196
+ box_nll = gaussian_nll(pred_box, pred_logsig, gt_box)
197
+ giou_v = giou3d_approx(pred_box, gt_box)
198
+ giou_loss = (1.0 - giou_v).mean()
199
+ else:
200
+ box_nll = torch.zeros((), device=device)
201
+ giou_loss = torch.zeros((), device=device)
202
+
203
+ isdyn_loss = F.binary_cross_entropy_with_logits(
204
+ isdyn_logit, target_isdyn, reduction="mean"
205
+ )
206
+
207
+ return DetectionLossOutputs(
208
+ cls_loss=cls_loss,
209
+ box_nll=box_nll,
210
+ giou_loss=giou_loss,
211
+ isdyn_loss=isdyn_loss,
212
+ matched_indices=indices,
213
+ )
src/wjad/losses/moe_aux.py CHANGED
@@ -1,33 +1,33 @@
1
- """MoE 路由的负载均衡 + 边界正则。
2
-
3
- - 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。
4
- 这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量;
5
- 也加入 ``mean(probs).std()`` 跨专家的均匀性度量。
6
- - 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import torch
12
-
13
- from ..modules.moe import MoEStats
14
-
15
-
16
- def moe_load_balance_and_boundary(
17
- stats_list: list[MoEStats],
18
- load_balance_weight: float = 1.0,
19
- boundary_weight: float = 1.0,
20
- ) -> torch.Tensor:
21
- if not stats_list:
22
- return torch.zeros((), device="cpu")
23
-
24
- device = stats_list[0].logits.device
25
- total = torch.zeros((), device=device)
26
- for stats in stats_list:
27
- # boundary:logits^2 的均值
28
- boundary = stats.logits.pow(2).mean()
29
- # load balance:跨样本的专家选择频率方差
30
- avg_per_expert = stats.probs.mean(dim=0) # [num_routed]
31
- load = avg_per_expert.std()
32
- total = total + boundary_weight * boundary + load_balance_weight * load
33
- return total / max(len(stats_list), 1)
 
1
+ """MoE 路由的负载均衡 + 边界正则。
2
+
3
+ - 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。
4
+ 这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量;
5
+ 也加入 ``mean(probs).std()`` 跨专家的均匀性度量。
6
+ - 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+
13
+ from ..modules.moe import MoEStats
14
+
15
+
16
+ def moe_load_balance_and_boundary(
17
+ stats_list: list[MoEStats],
18
+ load_balance_weight: float = 1.0,
19
+ boundary_weight: float = 1.0,
20
+ ) -> torch.Tensor:
21
+ if not stats_list:
22
+ return torch.zeros((), device="cpu")
23
+
24
+ device = stats_list[0].logits.device
25
+ total = torch.zeros((), device=device)
26
+ for stats in stats_list:
27
+ # boundary:logits^2 的均值
28
+ boundary = stats.logits.pow(2).mean()
29
+ # load balance:跨样本的专家选择频率方差
30
+ avg_per_expert = stats.probs.mean(dim=0) # [num_routed]
31
+ load = avg_per_expert.std()
32
+ total = total + boundary_weight * boundary + load_balance_weight * load
33
+ return total / max(len(stats_list), 1)
src/wjad/losses/nll.py CHANGED
@@ -1,47 +1,47 @@
1
- """高斯 NLL 置信度损失。
2
-
3
- 公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)``
4
- 为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。
5
-
6
- 支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import torch
12
-
13
-
14
- def gaussian_nll(
15
- mu: torch.Tensor,
16
- log_sigma: torch.Tensor,
17
- target: torch.Tensor,
18
- valid_mask: torch.Tensor | None = None,
19
- reduction: str = "mean",
20
- ) -> torch.Tensor:
21
- """高斯负对数似然。
22
-
23
- 所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。
24
- """
25
- diff = target - mu
26
- inv_sigma = torch.exp(-log_sigma)
27
- nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma
28
-
29
- if valid_mask is not None:
30
- # broadcast 到 nll 的 shape
31
- while valid_mask.dim() < nll.dim():
32
- valid_mask = valid_mask.unsqueeze(-1)
33
- valid_mask = valid_mask.to(nll.dtype)
34
- nll = nll * valid_mask
35
- if reduction == "mean":
36
- denom = valid_mask.sum().clamp_min(1.0)
37
- return nll.sum() / denom
38
- elif reduction == "sum":
39
- return nll.sum()
40
- else:
41
- return nll
42
-
43
- if reduction == "mean":
44
- return nll.mean()
45
- if reduction == "sum":
46
- return nll.sum()
47
- return nll
 
1
+ """高斯 NLL 置信度损失。
2
+
3
+ 公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)``
4
+ 为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。
5
+
6
+ 支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+
13
+
14
+ def gaussian_nll(
15
+ mu: torch.Tensor,
16
+ log_sigma: torch.Tensor,
17
+ target: torch.Tensor,
18
+ valid_mask: torch.Tensor | None = None,
19
+ reduction: str = "mean",
20
+ ) -> torch.Tensor:
21
+ """高斯负对数似然。
22
+
23
+ 所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。
24
+ """
25
+ diff = target - mu
26
+ inv_sigma = torch.exp(-log_sigma)
27
+ nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma
28
+
29
+ if valid_mask is not None:
30
+ # broadcast 到 nll 的 shape
31
+ while valid_mask.dim() < nll.dim():
32
+ valid_mask = valid_mask.unsqueeze(-1)
33
+ valid_mask = valid_mask.to(nll.dtype)
34
+ nll = nll * valid_mask
35
+ if reduction == "mean":
36
+ denom = valid_mask.sum().clamp_min(1.0)
37
+ return nll.sum() / denom
38
+ elif reduction == "sum":
39
+ return nll.sum()
40
+ else:
41
+ return nll
42
+
43
+ if reduction == "mean":
44
+ return nll.mean()
45
+ if reduction == "sum":
46
+ return nll.sum()
47
+ return nll
src/wjad/losses/trajectory.py CHANGED
@@ -1,43 +1,43 @@
1
- """动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。"""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
-
7
- from .nll import gaussian_nll
8
-
9
-
10
- def object_traj_nll(
11
- traj_mu: torch.Tensor, # [B, Q, T, 3]
12
- traj_log_sigma: torch.Tensor, # [B, Q, T, 3]
13
- matched_indices: list[tuple[torch.Tensor, torch.Tensor]],
14
- targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]}
15
- ) -> torch.Tensor:
16
- """对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。
17
-
18
- 返回标量 loss。
19
- """
20
- device = traj_mu.device
21
- b = traj_mu.shape[0]
22
- losses = []
23
- for i in range(b):
24
- rows, cols = matched_indices[i]
25
- if rows.numel() == 0:
26
- continue
27
- rows = rows.to(device)
28
- cols = cols.to(device)
29
- is_dyn = targets[i]["is_dynamic"][cols].to(device).bool()
30
- if not is_dyn.any():
31
- continue
32
- sel_rows = rows[is_dyn]
33
- sel_cols = cols[is_dyn]
34
- pred_mu = traj_mu[i, sel_rows] # [n, T, 3]
35
- pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3]
36
- gt_traj = targets[i]["future_traj"][sel_cols].to(device)
37
- valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T]
38
- # 在 (T, 3) 维上算 NLL,valid mask 只到 T 维
39
- nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid)
40
- losses.append(nll)
41
- if not losses:
42
- return torch.zeros((), device=device)
43
- return torch.stack(losses).mean()
 
1
+ """动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from .nll import gaussian_nll
8
+
9
+
10
+ def object_traj_nll(
11
+ traj_mu: torch.Tensor, # [B, Q, T, 3]
12
+ traj_log_sigma: torch.Tensor, # [B, Q, T, 3]
13
+ matched_indices: list[tuple[torch.Tensor, torch.Tensor]],
14
+ targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]}
15
+ ) -> torch.Tensor:
16
+ """对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。
17
+
18
+ 返回标量 loss。
19
+ """
20
+ device = traj_mu.device
21
+ b = traj_mu.shape[0]
22
+ losses = []
23
+ for i in range(b):
24
+ rows, cols = matched_indices[i]
25
+ if rows.numel() == 0:
26
+ continue
27
+ rows = rows.to(device)
28
+ cols = cols.to(device)
29
+ is_dyn = targets[i]["is_dynamic"][cols].to(device).bool()
30
+ if not is_dyn.any():
31
+ continue
32
+ sel_rows = rows[is_dyn]
33
+ sel_cols = cols[is_dyn]
34
+ pred_mu = traj_mu[i, sel_rows] # [n, T, 3]
35
+ pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3]
36
+ gt_traj = targets[i]["future_traj"][sel_cols].to(device)
37
+ valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T]
38
+ # 在 (T, 3) 维上算 NLL,valid mask 只到 T 维
39
+ nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid)
40
+ losses.append(nll)
41
+ if not losses:
42
+ return torch.zeros((), device=device)
43
+ return torch.stack(losses).mean()
src/wjad/model.py CHANGED
@@ -1,289 +1,289 @@
1
- """端到端自动驾驶模型 E2EAVModel。
2
-
3
- forward 流程
4
- 1. ``DINOv3`` 提取 8 帧 patch 特征。
5
- 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV,
6
- 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。
7
- 3. 用 corrected_intr / corrected_extr / corrected_ego 计算
8
- - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。
9
- - 8 个 ego token(symlog 后线性投影)。
10
- 4. 2×2×2 时空压缩 -> 1536 视觉 token。
11
- 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。
12
- 非视觉切片各自加可学习 PE。
13
- 6. 18 层主干(仅视觉切片应用 3D RoPE)。
14
- 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- from dataclasses import dataclass
20
-
21
- import torch
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
-
25
- from .backbone import Backbone, BackboneOutput
26
- from .calibration import OnlineCalibration, CalibrationOutput
27
- from .encoders import DINOv3Wrapper
28
- from .heads import (
29
- ControlHead,
30
- ControlOutput,
31
- DetectionTrajHead,
32
- DetectionTrajOutput,
33
- )
34
- from .modules.learned_pe import LearnedTokenPE
35
- from .modules.normalization import symlog
36
- from .modules.pos_encoding import RoPE3D
37
- from .modules.rays import compute_ego_rays
38
- from .modules.temporal_compress import TemporalCompress2x2x2
39
-
40
-
41
- @dataclass
42
- class E2EOutput:
43
- """模型完整输出。"""
44
-
45
- detection: DetectionTrajOutput
46
- control: ControlOutput
47
- backbone_out: BackboneOutput
48
- calibration: CalibrationOutput
49
-
50
-
51
- class E2EAVModel(nn.Module):
52
- def __init__(
53
- self,
54
- dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
55
- backbone_dim: int = 768,
56
- num_heads: int = 12,
57
- num_dense_layers: int = 9,
58
- num_moe_layers: int = 9,
59
- num_routed_experts: int = 7,
60
- num_shared_experts: int = 1,
61
- topk_experts: int = 3,
62
- ffn_mult: int = 4,
63
- # token 数量
64
- num_history_frames: int = 8,
65
- num_detection_tokens: int = 1024,
66
- num_control_tokens: int = 24,
67
- num_ego_tokens: int = 8,
68
- num_extra_tokens: int = 256,
69
- # 输入分辨率
70
- image_h: int = 384,
71
- image_w: int = 1024,
72
- patch_size: int = 16,
73
- # 头超参
74
- num_classes: int = 22,
75
- traj_horizon: int = 24,
76
- det_head_hidden: int = 384,
77
- ctrl_head_hidden: int = 384,
78
- # 校准
79
- calib_dim: int = 256,
80
- calib_num_query: int = 256,
81
- calib_num_blocks: int = 2,
82
- calib_num_self_per_block: int = 2,
83
- calib_num_heads: int = 8,
84
- calib_residual_range: float = 0.1,
85
- calib_intr_dim: int = 11,
86
- # DINOv3
87
- freeze_dinov3: bool = True,
88
- attn_implementation: str = "sdpa",
89
- ) -> None:
90
- super().__init__()
91
- self.image_h = image_h
92
- self.image_w = image_w
93
- self.patch_size = patch_size
94
- self.num_history = num_history_frames
95
- self.num_det = num_detection_tokens
96
- self.num_ctrl = num_control_tokens
97
- self.num_ego = num_ego_tokens
98
- self.num_extra = num_extra_tokens
99
-
100
- # === 1) DINOv3 ===
101
- self.dinov3 = DINOv3Wrapper(
102
- pretrained_path=dinov3_path,
103
- attn_implementation=attn_implementation,
104
- freeze=freeze_dinov3,
105
- )
106
- dino_dim = self.dinov3.hidden_size
107
-
108
- # === 2) 在线校准 ===
109
- self.calib = OnlineCalibration(
110
- dino_dim=dino_dim,
111
- hidden_dim=calib_dim,
112
- num_query_tokens=calib_num_query,
113
- num_blocks=calib_num_blocks,
114
- num_self_attn_per_block=calib_num_self_per_block,
115
- num_heads=calib_num_heads,
116
- residual_range=calib_residual_range,
117
- num_history_frames=num_history_frames,
118
- intr_dim=calib_intr_dim,
119
- )
120
-
121
- # === 3) 时空压缩 ===
122
- self.compress = TemporalCompress2x2x2(dim=dino_dim)
123
- # patch 网格大小(必须能被 2 整除)
124
- self.gh = image_h // patch_size
125
- self.gw = image_w // patch_size
126
-
127
- # === 4) 各类 token + 可学习 PE ===
128
- self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim
129
- self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
130
- nn.init.trunc_normal_(self.det_tokens, std=0.02)
131
- self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
132
- nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
133
- self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
134
- nn.init.trunc_normal_(self.extra_tokens, std=0.02)
135
-
136
- self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
137
- self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
138
- self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
139
- self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)
140
-
141
- # === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)===
142
- self.rope = RoPE3D(
143
- num_heads=num_heads,
144
- head_dim=backbone_dim // num_heads,
145
- time_size=num_history_frames // 2,
146
- height_size=self.gh // 2,
147
- width_size=self.gw // 2,
148
- )
149
-
150
- # === 6) 主干 18 层 ===
151
- self.backbone = Backbone(
152
- dim=backbone_dim,
153
- num_heads=num_heads,
154
- ffn_mult=ffn_mult,
155
- num_dense_layers=num_dense_layers,
156
- num_moe_layers=num_moe_layers,
157
- num_routed=num_routed_experts,
158
- num_shared=num_shared_experts,
159
- topk=topk_experts,
160
- )
161
-
162
- # === 7) 头 ===
163
- self.det_traj_head = DetectionTrajHead(
164
- in_dim=backbone_dim,
165
- hidden_size=det_head_hidden,
166
- num_classes=num_classes,
167
- traj_horizon=traj_horizon,
168
- )
169
- self.ctrl_head = ControlHead(
170
- in_dim=backbone_dim,
171
- hidden_size=ctrl_head_hidden,
172
- num_traj_tokens=12,
173
- num_action_tokens=num_control_tokens - 12,
174
- ego_traj_horizon=traj_horizon,
175
- )
176
-
177
- # ---------- 工具 ----------
178
-
179
- @property
180
- def num_visual_tokens(self) -> int:
181
- # 2×2×2 压缩后
182
- return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)
183
-
184
- def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
185
- """``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。"""
186
- return self.ego_proj(symlog(ego_6d_corrected))
187
-
188
- def _build_visual_rays(
189
- self,
190
- intr_corrected: torch.Tensor, # [B, calib_intr_dim]
191
- extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle
192
- compressed_thw: tuple[int, int, int],
193
- ) -> torch.Tensor:
194
- """计算压缩后视觉 token 网格的射线方向。
195
-
196
- 在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 +
197
- 2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似,
198
- 所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是
199
- rigid 的相机系;自车运动差异会通过 ego token 传递)。
200
- """
201
- b = intr_corrected.shape[0]
202
- t_, h_, w_ = compressed_thw
203
- rays_grid = compute_ego_rays(
204
- intr_vec=intr_corrected,
205
- cam2vehicle=extr_corrected_se3,
206
- height=self.image_h,
207
- width=self.image_w,
208
- grid_h=h_,
209
- grid_w=w_,
210
- device=intr_corrected.device,
211
- dtype=intr_corrected.dtype,
212
- ) # [B, h_, w_, 3]
213
- # 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3]
214
- rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
215
- rays = rays.reshape(b, t_ * h_ * w_, 3)
216
- return rays
217
-
218
- # ---------- 前向 ----------
219
-
220
- def forward(
221
- self,
222
- images: torch.Tensor, # [B, T=8, 3, H, W]
223
- ego_6d_raw: torch.Tensor, # [B, 8, 6]
224
- intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致
225
- extr_6d_raw: torch.Tensor, # [B, 6]
226
- ) -> E2EOutput:
227
- b, t, _, h, w = images.shape
228
- assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}"
229
-
230
- # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
231
- dino_feats = self.dinov3(images)
232
-
233
- # 2) 校准(symlog 空间残差 + symexp 还原)
234
- calib_out: CalibrationOutput = self.calib(
235
- dino_feats=dino_feats,
236
- ego_raw=ego_6d_raw,
237
- intr_raw=intr_raw,
238
- extr_raw=extr_6d_raw,
239
- )
240
- corrected_ego = calib_out.corrected_ego
241
- corrected_intr = calib_out.corrected_intr
242
- corrected_extr_6d = calib_out.corrected_extr
243
-
244
- # 3) 把 corrected_extr 6D 转成 4x4
245
- from .data.se3 import six_d_to_matrix
246
- cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4]
247
-
248
- # 4) 2x2x2 时空压缩
249
- compressed, thw = self.compress(dino_feats) # [B, N_v, D]
250
- n_v = compressed.shape[1]
251
-
252
- # 5) 视觉射线(用 corrected_intr / corrected_extr)
253
- rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
254
- rope_cos, rope_sin = self.rope.compute_freqs(rays)
255
-
256
- # 6) 构造非视觉 token
257
- ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D]
258
- det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
259
- ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
260
- extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)
261
-
262
- ego_tok = self.ego_pe(ego_tok)
263
- det_tok = self.det_pe(det_tok)
264
- ctrl_tok = self.ctrl_pe(ctrl_tok)
265
- extra_tok = self.extra_pe(extra_tok)
266
-
267
- # 7) 拼接序列:[vision | ego | det | ctrl | extra]
268
- seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
269
- visual_slice = (0, n_v)
270
-
271
- # 8) 主干
272
- bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
273
-
274
- # 9) 切片送入头
275
- offset_det = n_v + self.num_ego
276
- offset_ctrl = offset_det + self.num_det
277
-
278
- det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
279
- ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]
280
-
281
- det_out = self.det_traj_head(det_feats)
282
- ctrl_out = self.ctrl_head(ctrl_feats)
283
-
284
- return E2EOutput(
285
- detection=det_out,
286
- control=ctrl_out,
287
- backbone_out=bb_out,
288
- calibration=calib_out,
289
- )
 
1
+ """端到端自动驾驶模型 E2EAVModel。
2
+
3
+ forward 流程
4
+ 1. ``DINOv3`` 提取 8 帧 patch 特征。
5
+ 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV,
6
+ 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。
7
+ 3. 用 corrected_intr / corrected_extr / corrected_ego 计算
8
+ - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。
9
+ - 8 个 ego token(symlog 后线性投影)。
10
+ 4. 2×2×2 时空压缩 -> 1536 视觉 token。
11
+ 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。
12
+ 非视觉切片各自加可学习 PE。
13
+ 6. 18 层主干(仅视觉切片应用 3D RoPE)。
14
+ 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from dataclasses import dataclass
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from .backbone import Backbone, BackboneOutput
26
+ from .calibration import OnlineCalibration, CalibrationOutput
27
+ from .encoders import DINOv3Wrapper
28
+ from .heads import (
29
+ ControlHead,
30
+ ControlOutput,
31
+ DetectionTrajHead,
32
+ DetectionTrajOutput,
33
+ )
34
+ from .modules.learned_pe import LearnedTokenPE
35
+ from .modules.normalization import symlog
36
+ from .modules.pos_encoding import RoPE3D
37
+ from .modules.rays import compute_ego_rays
38
+ from .modules.temporal_compress import TemporalCompress2x2x2
39
+
40
+
41
+ @dataclass
42
+ class E2EOutput:
43
+ """模型完整输出。"""
44
+
45
+ detection: DetectionTrajOutput
46
+ control: ControlOutput
47
+ backbone_out: BackboneOutput
48
+ calibration: CalibrationOutput
49
+
50
+
51
+ class E2EAVModel(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
55
+ backbone_dim: int = 768,
56
+ num_heads: int = 12,
57
+ num_dense_layers: int = 9,
58
+ num_moe_layers: int = 9,
59
+ num_routed_experts: int = 7,
60
+ num_shared_experts: int = 1,
61
+ topk_experts: int = 3,
62
+ ffn_mult: int = 4,
63
+ # token 数量
64
+ num_history_frames: int = 8,
65
+ num_detection_tokens: int = 1024,
66
+ num_control_tokens: int = 24,
67
+ num_ego_tokens: int = 8,
68
+ num_extra_tokens: int = 256,
69
+ # 输入分辨率
70
+ image_h: int = 384,
71
+ image_w: int = 1024,
72
+ patch_size: int = 16,
73
+ # 头超参
74
+ num_classes: int = 22,
75
+ traj_horizon: int = 24,
76
+ det_head_hidden: int = 384,
77
+ ctrl_head_hidden: int = 384,
78
+ # 校准
79
+ calib_dim: int = 256,
80
+ calib_num_query: int = 256,
81
+ calib_num_blocks: int = 2,
82
+ calib_num_self_per_block: int = 2,
83
+ calib_num_heads: int = 8,
84
+ calib_residual_range: float = 0.1,
85
+ calib_intr_dim: int = 11,
86
+ # DINOv3
87
+ freeze_dinov3: bool = True,
88
+ attn_implementation: str = "sdpa",
89
+ ) -> None:
90
+ super().__init__()
91
+ self.image_h = image_h
92
+ self.image_w = image_w
93
+ self.patch_size = patch_size
94
+ self.num_history = num_history_frames
95
+ self.num_det = num_detection_tokens
96
+ self.num_ctrl = num_control_tokens
97
+ self.num_ego = num_ego_tokens
98
+ self.num_extra = num_extra_tokens
99
+
100
+ # === 1) DINOv3 ===
101
+ self.dinov3 = DINOv3Wrapper(
102
+ pretrained_path=dinov3_path,
103
+ attn_implementation=attn_implementation,
104
+ freeze=freeze_dinov3,
105
+ )
106
+ dino_dim = self.dinov3.hidden_size
107
+
108
+ # === 2) 在线校准 ===
109
+ self.calib = OnlineCalibration(
110
+ dino_dim=dino_dim,
111
+ hidden_dim=calib_dim,
112
+ num_query_tokens=calib_num_query,
113
+ num_blocks=calib_num_blocks,
114
+ num_self_attn_per_block=calib_num_self_per_block,
115
+ num_heads=calib_num_heads,
116
+ residual_range=calib_residual_range,
117
+ num_history_frames=num_history_frames,
118
+ intr_dim=calib_intr_dim,
119
+ )
120
+
121
+ # === 3) 时空压缩 ===
122
+ self.compress = TemporalCompress2x2x2(dim=dino_dim)
123
+ # patch 网格大小(必须能被 2 整除)
124
+ self.gh = image_h // patch_size
125
+ self.gw = image_w // patch_size
126
+
127
+ # === 4) 各类 token + 可学习 PE ===
128
+ self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim
129
+ self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
130
+ nn.init.trunc_normal_(self.det_tokens, std=0.02)
131
+ self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
132
+ nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
133
+ self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
134
+ nn.init.trunc_normal_(self.extra_tokens, std=0.02)
135
+
136
+ self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
137
+ self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
138
+ self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
139
+ self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)
140
+
141
+ # === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)===
142
+ self.rope = RoPE3D(
143
+ num_heads=num_heads,
144
+ head_dim=backbone_dim // num_heads,
145
+ time_size=num_history_frames // 2,
146
+ height_size=self.gh // 2,
147
+ width_size=self.gw // 2,
148
+ )
149
+
150
+ # === 6) 主干 18 层 ===
151
+ self.backbone = Backbone(
152
+ dim=backbone_dim,
153
+ num_heads=num_heads,
154
+ ffn_mult=ffn_mult,
155
+ num_dense_layers=num_dense_layers,
156
+ num_moe_layers=num_moe_layers,
157
+ num_routed=num_routed_experts,
158
+ num_shared=num_shared_experts,
159
+ topk=topk_experts,
160
+ )
161
+
162
+ # === 7) 头 ===
163
+ self.det_traj_head = DetectionTrajHead(
164
+ in_dim=backbone_dim,
165
+ hidden_size=det_head_hidden,
166
+ num_classes=num_classes,
167
+ traj_horizon=traj_horizon,
168
+ )
169
+ self.ctrl_head = ControlHead(
170
+ in_dim=backbone_dim,
171
+ hidden_size=ctrl_head_hidden,
172
+ num_traj_tokens=12,
173
+ num_action_tokens=num_control_tokens - 12,
174
+ ego_traj_horizon=traj_horizon,
175
+ )
176
+
177
+ # ---------- 工具 ----------
178
+
179
+ @property
180
+ def num_visual_tokens(self) -> int:
181
+ # 2×2×2 压缩后
182
+ return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)
183
+
184
+ def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
185
+ """``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。"""
186
+ return self.ego_proj(symlog(ego_6d_corrected))
187
+
188
+ def _build_visual_rays(
189
+ self,
190
+ intr_corrected: torch.Tensor, # [B, calib_intr_dim]
191
+ extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle
192
+ compressed_thw: tuple[int, int, int],
193
+ ) -> torch.Tensor:
194
+ """计算压缩后视觉 token 网格的射线方向。
195
+
196
+ 在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 +
197
+ 2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似,
198
+ 所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是
199
+ rigid 的相机系;自车运动差异会通过 ego token 传递)。
200
+ """
201
+ b = intr_corrected.shape[0]
202
+ t_, h_, w_ = compressed_thw
203
+ rays_grid = compute_ego_rays(
204
+ intr_vec=intr_corrected,
205
+ cam2vehicle=extr_corrected_se3,
206
+ height=self.image_h,
207
+ width=self.image_w,
208
+ grid_h=h_,
209
+ grid_w=w_,
210
+ device=intr_corrected.device,
211
+ dtype=intr_corrected.dtype,
212
+ ) # [B, h_, w_, 3]
213
+ # 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3]
214
+ rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
215
+ rays = rays.reshape(b, t_ * h_ * w_, 3)
216
+ return rays
217
+
218
+ # ---------- 前向 ----------
219
+
220
+ def forward(
221
+ self,
222
+ images: torch.Tensor, # [B, T=8, 3, H, W]
223
+ ego_6d_raw: torch.Tensor, # [B, 8, 6]
224
+ intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致
225
+ extr_6d_raw: torch.Tensor, # [B, 6]
226
+ ) -> E2EOutput:
227
+ b, t, _, h, w = images.shape
228
+ assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}"
229
+
230
+ # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
231
+ dino_feats = self.dinov3(images)
232
+
233
+ # 2) 校准(symlog 空间残差 + symexp 还原)
234
+ calib_out: CalibrationOutput = self.calib(
235
+ dino_feats=dino_feats,
236
+ ego_raw=ego_6d_raw,
237
+ intr_raw=intr_raw,
238
+ extr_raw=extr_6d_raw,
239
+ )
240
+ corrected_ego = calib_out.corrected_ego
241
+ corrected_intr = calib_out.corrected_intr
242
+ corrected_extr_6d = calib_out.corrected_extr
243
+
244
+ # 3) 把 corrected_extr 6D 转成 4x4
245
+ from .data.se3 import six_d_to_matrix
246
+ cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4]
247
+
248
+ # 4) 2x2x2 时空压缩
249
+ compressed, thw = self.compress(dino_feats) # [B, N_v, D]
250
+ n_v = compressed.shape[1]
251
+
252
+ # 5) 视觉射线(用 corrected_intr / corrected_extr)
253
+ rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
254
+ rope_cos, rope_sin = self.rope.compute_freqs(rays)
255
+
256
+ # 6) 构造非视觉 token
257
+ ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D]
258
+ det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
259
+ ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
260
+ extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)
261
+
262
+ ego_tok = self.ego_pe(ego_tok)
263
+ det_tok = self.det_pe(det_tok)
264
+ ctrl_tok = self.ctrl_pe(ctrl_tok)
265
+ extra_tok = self.extra_pe(extra_tok)
266
+
267
+ # 7) 拼接序列:[vision | ego | det | ctrl | extra]
268
+ seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
269
+ visual_slice = (0, n_v)
270
+
271
+ # 8) 主干
272
+ bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
273
+
274
+ # 9) 切片送入头
275
+ offset_det = n_v + self.num_ego
276
+ offset_ctrl = offset_det + self.num_det
277
+
278
+ det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
279
+ ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]
280
+
281
+ det_out = self.det_traj_head(det_feats)
282
+ ctrl_out = self.ctrl_head(ctrl_feats)
283
+
284
+ return E2EOutput(
285
+ detection=det_out,
286
+ control=ctrl_out,
287
+ backbone_out=bb_out,
288
+ calibration=calib_out,
289
+ )
src/wjad/modules/__init__.py CHANGED
@@ -1,28 +1,28 @@
1
- """公用算子模块集合。"""
2
-
3
- from __future__ import annotations
4
-
5
- from .ffn import SwiGLUFFN
6
- from .gate_attention import GateSelfAttention, GateCrossAttention
7
- from .moe import MoEBlock, PerLayerExperts
8
- from .normalization import symlog, symexp
9
- from .pos_encoding import RoPE3D, build_rope_freqs
10
- from .learned_pe import LearnedTokenPE
11
- from .rays import FThetaCamera, compute_ego_rays
12
- from .temporal_compress import TemporalCompress2x2x2
13
-
14
- __all__ = [
15
- "SwiGLUFFN",
16
- "GateSelfAttention",
17
- "GateCrossAttention",
18
- "MoEBlock",
19
- "PerLayerExperts",
20
- "symlog",
21
- "symexp",
22
- "RoPE3D",
23
- "build_rope_freqs",
24
- "LearnedTokenPE",
25
- "FThetaCamera",
26
- "compute_ego_rays",
27
- "TemporalCompress2x2x2",
28
- ]
 
1
+ """公用算子模块集合。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .ffn import SwiGLUFFN
6
+ from .gate_attention import GateSelfAttention, GateCrossAttention
7
+ from .moe import MoEBlock, PerLayerExperts
8
+ from .normalization import symlog, symexp
9
+ from .pos_encoding import RoPE3D, build_rope_freqs
10
+ from .learned_pe import LearnedTokenPE
11
+ from .rays import FThetaCamera, compute_ego_rays
12
+ from .temporal_compress import TemporalCompress2x2x2
13
+
14
+ __all__ = [
15
+ "SwiGLUFFN",
16
+ "GateSelfAttention",
17
+ "GateCrossAttention",
18
+ "MoEBlock",
19
+ "PerLayerExperts",
20
+ "symlog",
21
+ "symexp",
22
+ "RoPE3D",
23
+ "build_rope_freqs",
24
+ "LearnedTokenPE",
25
+ "FThetaCamera",
26
+ "compute_ego_rays",
27
+ "TemporalCompress2x2x2",
28
+ ]
src/wjad/modules/ffn.py CHANGED
@@ -1,30 +1,30 @@
1
- """SwiGLU 前馈网络。
2
-
3
- 实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D)
4
- 即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
-
14
- class SwiGLUFFN(nn.Module):
15
- """SwiGLU FFN: D->4D->SwiGLU->2D->D。
16
-
17
- 使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。
18
- """
19
-
20
- def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
21
- super().__init__()
22
- hidden = mult * dim
23
- self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b
24
- self.fc2 = nn.Linear(hidden, dim, bias=bias)
25
- self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
26
-
27
- def forward(self, x: torch.Tensor) -> torch.Tensor:
28
- ab = self.fc1(x)
29
- a, b = ab.chunk(2, dim=-1)
30
- return self.drop(self.fc2(F.silu(a) * b))
 
1
+ """SwiGLU 前馈网络。
2
+
3
+ 实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D)
4
+ 即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ """SwiGLU FFN: D->4D->SwiGLU->2D->D。
16
+
17
+ 使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。
18
+ """
19
+
20
+ def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
21
+ super().__init__()
22
+ hidden = mult * dim
23
+ self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b
24
+ self.fc2 = nn.Linear(hidden, dim, bias=bias)
25
+ self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ ab = self.fc1(x)
29
+ a, b = ab.chunk(2, dim=-1)
30
+ return self.drop(self.fc2(F.silu(a) * b))
src/wjad/modules/gate_attention.py CHANGED
@@ -1,181 +1,181 @@
1
- """GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。
2
-
3
- 与 Design.md 一致:
4
- - Q 经 Linear + Sigmoid 生成 D 维门控参数;
5
- - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj;
6
- - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import math
12
- from typing import Optional
13
-
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
-
18
- from .pos_encoding import apply_rope
19
-
20
-
21
- class _MultiHeadProj(nn.Module):
22
- """通用的多头 Q/K/V 投影 + reshape。"""
23
-
24
- def __init__(
25
- self,
26
- dim_q: int,
27
- dim_kv: int,
28
- num_heads: int,
29
- head_dim: int,
30
- q_bias: bool = True,
31
- kv_bias: bool = True,
32
- ) -> None:
33
- super().__init__()
34
- self.num_heads = num_heads
35
- self.head_dim = head_dim
36
- inner = num_heads * head_dim
37
- self.q_proj = nn.Linear(dim_q, inner, bias=q_bias)
38
- self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
39
- self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
40
-
41
- def project_q(self, x: torch.Tensor) -> torch.Tensor:
42
- b, n, _ = x.shape
43
- return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
44
-
45
- def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
46
- b, n, _ = x.shape
47
- k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
48
- v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
49
- return k, v
50
-
51
-
52
- class _GateModule(nn.Module):
53
- """门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。
54
-
55
- bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。
56
- 这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。
57
- """
58
-
59
- def __init__(self, dim: int, init_bias: float = 5.0) -> None:
60
- super().__init__()
61
- self.proj = nn.Linear(dim, dim)
62
- nn.init.zeros_(self.proj.weight)
63
- nn.init.constant_(self.proj.bias, init_bias)
64
-
65
- def forward(self, x: torch.Tensor) -> torch.Tensor:
66
- return torch.sigmoid(self.proj(x))
67
-
68
-
69
- class GateSelfAttention(nn.Module):
70
- """门控自注意力,使用 PyTorch SDPA。
71
-
72
- 支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。
73
- """
74
-
75
- def __init__(
76
- self,
77
- dim: int,
78
- num_heads: int,
79
- dropout: float = 0.0,
80
- gate_init_bias: float = 5.0,
81
- q_bias: bool = True,
82
- kv_bias: bool = True,
83
- ) -> None:
84
- super().__init__()
85
- assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
86
- self.dim = dim
87
- self.num_heads = num_heads
88
- self.head_dim = dim // num_heads
89
- self.scale = 1.0 / math.sqrt(self.head_dim)
90
- self.dropout_p = dropout
91
-
92
- self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias)
93
- self.gate = _GateModule(dim, init_bias=gate_init_bias)
94
- self.out_proj = nn.Linear(dim, dim, bias=True)
95
-
96
- def forward(
97
- self,
98
- x: torch.Tensor,
99
- rope_cos: Optional[torch.Tensor] = None,
100
- rope_sin: Optional[torch.Tensor] = None,
101
- visual_slice: Optional[tuple[int, int]] = None,
102
- ) -> torch.Tensor:
103
- """
104
- 参数
105
- ----
106
- x : [B, N, D]
107
- rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None
108
- visual_slice : (start, end),指定视觉 token 在序列中的范围。
109
- 非视觉 token 切片 Q/K 不做 RoPE。
110
- """
111
- b, n, _ = x.shape
112
- q = self.proj.project_q(x) # [B, H, N, Dh]
113
- k, v = self.proj.project_kv(x)
114
-
115
- # 仅对视觉切片应用 RoPE
116
- if rope_cos is not None and visual_slice is not None:
117
- s, e = visual_slice
118
- q_v = q[:, :, s:e, :]
119
- k_v = k[:, :, s:e, :]
120
- q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin)
121
- q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2)
122
- k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2)
123
-
124
- # SDPA:[B, H, N, Dh]
125
- attn = F.scaled_dot_product_attention(
126
- q,
127
- k,
128
- v,
129
- attn_mask=None,
130
- dropout_p=self.dropout_p if self.training else 0.0,
131
- is_causal=False,
132
- )
133
- # 多头合并
134
- attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim)
135
-
136
- # 门控 ⊗ 多头合并后的 V,再 out_proj
137
- gate = self.gate(x) # 用 Q 的源(即 x)生成门控
138
- out = self.out_proj(attn * gate)
139
- return out
140
-
141
-
142
- class GateCrossAttention(nn.Module):
143
- """门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。"""
144
-
145
- def __init__(
146
- self,
147
- dim_q: int,
148
- dim_kv: int,
149
- num_heads: int,
150
- dropout: float = 0.0,
151
- gate_init_bias: float = 5.0,
152
- q_bias: bool = True,
153
- kv_bias: bool = True,
154
- ) -> None:
155
- super().__init__()
156
- assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除"
157
- self.dim_q = dim_q
158
- self.num_heads = num_heads
159
- self.head_dim = dim_q // num_heads
160
- self.dropout_p = dropout
161
-
162
- self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias)
163
- self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
164
- self.out_proj = nn.Linear(dim_q, dim_q, bias=True)
165
-
166
- def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor:
167
- b, n, _ = q_in.shape
168
- q = self.proj.project_q(q_in)
169
- k, v = self.proj.project_kv(kv_in)
170
-
171
- attn = F.scaled_dot_product_attention(
172
- q,
173
- k,
174
- v,
175
- attn_mask=None,
176
- dropout_p=self.dropout_p if self.training else 0.0,
177
- is_causal=False,
178
- )
179
- attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q)
180
- gate = self.gate(q_in)
181
- return self.out_proj(attn * gate)
 
1
+ """GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。
2
+
3
+ 与 Design.md 一致:
4
+ - Q 经 Linear + Sigmoid 生成 D 维门控参数;
5
+ - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj;
6
+ - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import math
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .pos_encoding import apply_rope
19
+
20
+
21
+ class _MultiHeadProj(nn.Module):
22
+ """通用的多头 Q/K/V 投影 + reshape。"""
23
+
24
+ def __init__(
25
+ self,
26
+ dim_q: int,
27
+ dim_kv: int,
28
+ num_heads: int,
29
+ head_dim: int,
30
+ q_bias: bool = True,
31
+ kv_bias: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.num_heads = num_heads
35
+ self.head_dim = head_dim
36
+ inner = num_heads * head_dim
37
+ self.q_proj = nn.Linear(dim_q, inner, bias=q_bias)
38
+ self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
39
+ self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
40
+
41
+ def project_q(self, x: torch.Tensor) -> torch.Tensor:
42
+ b, n, _ = x.shape
43
+ return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
44
+
45
+ def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
46
+ b, n, _ = x.shape
47
+ k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
48
+ v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
49
+ return k, v
50
+
51
+
52
+ class _GateModule(nn.Module):
53
+ """门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。
54
+
55
+ bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。
56
+ 这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。
57
+ """
58
+
59
+ def __init__(self, dim: int, init_bias: float = 5.0) -> None:
60
+ super().__init__()
61
+ self.proj = nn.Linear(dim, dim)
62
+ nn.init.zeros_(self.proj.weight)
63
+ nn.init.constant_(self.proj.bias, init_bias)
64
+
65
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
66
+ return torch.sigmoid(self.proj(x))
67
+
68
+
69
+ class GateSelfAttention(nn.Module):
70
+ """门控自注意力,使用 PyTorch SDPA。
71
+
72
+ 支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_heads: int,
79
+ dropout: float = 0.0,
80
+ gate_init_bias: float = 5.0,
81
+ q_bias: bool = True,
82
+ kv_bias: bool = True,
83
+ ) -> None:
84
+ super().__init__()
85
+ assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
86
+ self.dim = dim
87
+ self.num_heads = num_heads
88
+ self.head_dim = dim // num_heads
89
+ self.scale = 1.0 / math.sqrt(self.head_dim)
90
+ self.dropout_p = dropout
91
+
92
+ self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias)
93
+ self.gate = _GateModule(dim, init_bias=gate_init_bias)
94
+ self.out_proj = nn.Linear(dim, dim, bias=True)
95
+
96
+ def forward(
97
+ self,
98
+ x: torch.Tensor,
99
+ rope_cos: Optional[torch.Tensor] = None,
100
+ rope_sin: Optional[torch.Tensor] = None,
101
+ visual_slice: Optional[tuple[int, int]] = None,
102
+ ) -> torch.Tensor:
103
+ """
104
+ 参数
105
+ ----
106
+ x : [B, N, D]
107
+ rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None
108
+ visual_slice : (start, end),指定视觉 token 在序列中的范围。
109
+ 非视觉 token 切片 Q/K 不做 RoPE。
110
+ """
111
+ b, n, _ = x.shape
112
+ q = self.proj.project_q(x) # [B, H, N, Dh]
113
+ k, v = self.proj.project_kv(x)
114
+
115
+ # 仅对视觉切片应用 RoPE
116
+ if rope_cos is not None and visual_slice is not None:
117
+ s, e = visual_slice
118
+ q_v = q[:, :, s:e, :]
119
+ k_v = k[:, :, s:e, :]
120
+ q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin)
121
+ q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2)
122
+ k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2)
123
+
124
+ # SDPA:[B, H, N, Dh]
125
+ attn = F.scaled_dot_product_attention(
126
+ q,
127
+ k,
128
+ v,
129
+ attn_mask=None,
130
+ dropout_p=self.dropout_p if self.training else 0.0,
131
+ is_causal=False,
132
+ )
133
+ # 多头合并
134
+ attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim)
135
+
136
+ # 门控 ⊗ 多头合并后的 V,再 out_proj
137
+ gate = self.gate(x) # 用 Q 的源(即 x)生成门控
138
+ out = self.out_proj(attn * gate)
139
+ return out
140
+
141
+
142
+ class GateCrossAttention(nn.Module):
143
+ """门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。"""
144
+
145
+ def __init__(
146
+ self,
147
+ dim_q: int,
148
+ dim_kv: int,
149
+ num_heads: int,
150
+ dropout: float = 0.0,
151
+ gate_init_bias: float = 5.0,
152
+ q_bias: bool = True,
153
+ kv_bias: bool = True,
154
+ ) -> None:
155
+ super().__init__()
156
+ assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除"
157
+ self.dim_q = dim_q
158
+ self.num_heads = num_heads
159
+ self.head_dim = dim_q // num_heads
160
+ self.dropout_p = dropout
161
+
162
+ self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias)
163
+ self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
164
+ self.out_proj = nn.Linear(dim_q, dim_q, bias=True)
165
+
166
+ def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor:
167
+ b, n, _ = q_in.shape
168
+ q = self.proj.project_q(q_in)
169
+ k, v = self.proj.project_kv(kv_in)
170
+
171
+ attn = F.scaled_dot_product_attention(
172
+ q,
173
+ k,
174
+ v,
175
+ attn_mask=None,
176
+ dropout_p=self.dropout_p if self.training else 0.0,
177
+ is_causal=False,
178
+ )
179
+ attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q)
180
+ gate = self.gate(q_in)
181
+ return self.out_proj(attn * gate)
src/wjad/modules/learned_pe.py CHANGED
@@ -1,24 +1,24 @@
1
- """非视觉 token 的可学习位置编码。
2
-
3
- ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
4
- ``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
5
- 直接加到对应 token 上,不参与 RoPE。
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import torch
11
- import torch.nn as nn
12
-
13
-
14
- class LearnedTokenPE(nn.Module):
15
- """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""
16
-
17
- def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
18
- super().__init__()
19
- self.pe = nn.Parameter(torch.empty(num_tokens, dim))
20
- nn.init.trunc_normal_(self.pe, std=init_std)
21
-
22
- def forward(self, x: torch.Tensor) -> torch.Tensor:
23
- # x: [B, N, D]
24
- return x + self.pe.unsqueeze(0)
 
1
+ """非视觉 token 的可学习位置编码。
2
+
3
+ ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
4
+ ``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
5
+ 直接加到对应 token 上,不参与 RoPE。
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ class LearnedTokenPE(nn.Module):
15
+ """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""
16
+
17
+ def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
18
+ super().__init__()
19
+ self.pe = nn.Parameter(torch.empty(num_tokens, dim))
20
+ nn.init.trunc_normal_(self.pe, std=init_std)
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ # x: [B, N, D]
24
+ return x + self.pe.unsqueeze(0)
src/wjad/modules/moe.py CHANGED
@@ -1,129 +1,129 @@
1
- """每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
2
-
3
- 设计要点(与 Design.md 对齐):
4
- - 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。
5
- - 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
6
- - 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
7
- - 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
8
- - 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
9
- ``losses/moe_aux.py`` 聚合成正则损失。
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- from dataclasses import dataclass
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
- from .ffn import SwiGLUFFN
20
-
21
-
22
- @dataclass
23
- class MoEStats:
24
- """单层 MoE 输出的辅助统计,用于损失与监控。"""
25
-
26
- logits: torch.Tensor # [B, num_routed]
27
- probs: torch.Tensor # [B, num_routed],sigmoid 后的概率
28
- topk_mask: torch.Tensor # [B, num_routed],0/1
29
-
30
-
31
- class PerLayerExperts(nn.Module):
32
- """单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。"""
33
-
34
- def __init__(
35
- self,
36
- dim: int,
37
- num_routed: int = 7,
38
- num_shared: int = 1,
39
- ffn_mult: int = 4,
40
- dropout: float = 0.0,
41
- ) -> None:
42
- super().__init__()
43
- self.num_routed = num_routed
44
- self.num_shared = num_shared
45
- self.shared = nn.ModuleList(
46
- [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)]
47
- )
48
- self.routed = nn.ModuleList(
49
- [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)]
50
- )
51
-
52
-
53
- class MoEBlock(nn.Module):
54
- """带路由的 MoE FFN 块(每层独立专家库)。"""
55
-
56
- def __init__(
57
- self,
58
- dim: int,
59
- num_routed: int = 7,
60
- num_shared: int = 1,
61
- ffn_mult: int = 4,
62
- topk: int = 3,
63
- dropout: float = 0.0,
64
- ) -> None:
65
- super().__init__()
66
- self.dim = dim
67
- self.num_routed = num_routed
68
- self.num_shared = num_shared
69
- self.topk = topk
70
- self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
71
- self.router = nn.Linear(dim, num_routed, bias=True)
72
- # 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5
73
- nn.init.normal_(self.router.weight, std=0.02)
74
- nn.init.zeros_(self.router.bias)
75
-
76
- # 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk
77
- self._mode: str = "dense"
78
- # 路由温度(温度 < 1 => 锐化)
79
- self.register_buffer("router_temperature", torch.tensor(1.0))
80
-
81
- def set_mode(self, mode: str) -> None:
82
- assert mode in ("dense", "sparse"), f"未知模式: {mode}"
83
- self._mode = mode
84
-
85
- @property
86
- def mode(self) -> str:
87
- return self._mode
88
-
89
- def set_temperature(self, t: float) -> None:
90
- self.router_temperature.fill_(float(t))
91
-
92
- def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
93
- """计算 logits / probs / topk_mask。x: [B, N, D]。"""
94
- pooled = x.mean(dim=1) # [B, D]
95
- logits = self.router(pooled) # [B, num_routed]
96
- # 温度锐化(温度小 => 概率更尖)
97
- scaled = logits / self.router_temperature.clamp_min(1e-3)
98
- probs = torch.sigmoid(scaled)
99
-
100
- if self._mode == "dense" or self.topk >= self.num_routed:
101
- mask = torch.ones_like(probs)
102
- else:
103
- topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1)
104
- mask = torch.zeros_like(probs)
105
- mask.scatter_(-1, topk_idx, 1.0)
106
-
107
- return logits, probs, mask
108
-
109
- def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
110
- b, n, d = x.shape
111
- logits, probs, mask = self._route(x)
112
-
113
- # 共享专家(恒激活,无门控)
114
- out = torch.zeros_like(x)
115
- for sh in self.experts.shared:
116
- out = out + sh(x)
117
-
118
- # 路由专家:按 (probs * mask) 在 batch 维加权
119
- weights = probs * mask # [B, num_routed]
120
- # 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级
121
- # 路由的索引开销;与 Design.md "序列级分配" 一致。
122
- for i, expert in enumerate(self.experts.routed):
123
- w_i = weights[:, i].view(b, 1, 1) # [B,1,1]
124
- # 仅当批内任一样本权重 > 0 时才前向以减少计算
125
- if torch.any(w_i > 0):
126
- out = out + w_i * expert(x)
127
-
128
- stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
129
- return out, stats
 
1
+ """每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
2
+
3
+ 设计要点(与 Design.md 对齐):
4
+ - 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。
5
+ - 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
6
+ - 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
7
+ - 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
8
+ - 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
9
+ ``losses/moe_aux.py`` 聚合成正则损失。
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from .ffn import SwiGLUFFN
20
+
21
+
22
+ @dataclass
23
+ class MoEStats:
24
+ """单层 MoE 输出的辅助统计,用于损失与监控。"""
25
+
26
+ logits: torch.Tensor # [B, num_routed]
27
+ probs: torch.Tensor # [B, num_routed],sigmoid 后的概率
28
+ topk_mask: torch.Tensor # [B, num_routed],0/1
29
+
30
+
31
+ class PerLayerExperts(nn.Module):
32
+ """单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。"""
33
+
34
+ def __init__(
35
+ self,
36
+ dim: int,
37
+ num_routed: int = 7,
38
+ num_shared: int = 1,
39
+ ffn_mult: int = 4,
40
+ dropout: float = 0.0,
41
+ ) -> None:
42
+ super().__init__()
43
+ self.num_routed = num_routed
44
+ self.num_shared = num_shared
45
+ self.shared = nn.ModuleList(
46
+ [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)]
47
+ )
48
+ self.routed = nn.ModuleList(
49
+ [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)]
50
+ )
51
+
52
+
53
+ class MoEBlock(nn.Module):
54
+ """带路由的 MoE FFN 块(每层独立专家库)。"""
55
+
56
+ def __init__(
57
+ self,
58
+ dim: int,
59
+ num_routed: int = 7,
60
+ num_shared: int = 1,
61
+ ffn_mult: int = 4,
62
+ topk: int = 3,
63
+ dropout: float = 0.0,
64
+ ) -> None:
65
+ super().__init__()
66
+ self.dim = dim
67
+ self.num_routed = num_routed
68
+ self.num_shared = num_shared
69
+ self.topk = topk
70
+ self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
71
+ self.router = nn.Linear(dim, num_routed, bias=True)
72
+ # 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5
73
+ nn.init.normal_(self.router.weight, std=0.02)
74
+ nn.init.zeros_(self.router.bias)
75
+
76
+ # 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk
77
+ self._mode: str = "dense"
78
+ # 路由温度(温度 < 1 => 锐化)
79
+ self.register_buffer("router_temperature", torch.tensor(1.0))
80
+
81
+ def set_mode(self, mode: str) -> None:
82
+ assert mode in ("dense", "sparse"), f"未知模式: {mode}"
83
+ self._mode = mode
84
+
85
+ @property
86
+ def mode(self) -> str:
87
+ return self._mode
88
+
89
+ def set_temperature(self, t: float) -> None:
90
+ self.router_temperature.fill_(float(t))
91
+
92
+ def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
93
+ """计算 logits / probs / topk_mask。x: [B, N, D]。"""
94
+ pooled = x.mean(dim=1) # [B, D]
95
+ logits = self.router(pooled) # [B, num_routed]
96
+ # 温度锐化(温度小 => 概率更尖)
97
+ scaled = logits / self.router_temperature.clamp_min(1e-3)
98
+ probs = torch.sigmoid(scaled)
99
+
100
+ if self._mode == "dense" or self.topk >= self.num_routed:
101
+ mask = torch.ones_like(probs)
102
+ else:
103
+ topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1)
104
+ mask = torch.zeros_like(probs)
105
+ mask.scatter_(-1, topk_idx, 1.0)
106
+
107
+ return logits, probs, mask
108
+
109
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
110
+ b, n, d = x.shape
111
+ logits, probs, mask = self._route(x)
112
+
113
+ # 共享专家(恒激活,无门控)
114
+ out = torch.zeros_like(x)
115
+ for sh in self.experts.shared:
116
+ out = out + sh(x)
117
+
118
+ # 路由专家:按 (probs * mask) 在 batch 维加权
119
+ weights = probs * mask # [B, num_routed]
120
+ # 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级
121
+ # 路由的索引开销;与 Design.md "序列级分配" 一致。
122
+ for i, expert in enumerate(self.experts.routed):
123
+ w_i = weights[:, i].view(b, 1, 1) # [B,1,1]
124
+ # 仅当批内任一样本权重 > 0 时才前向以减少计算
125
+ if torch.any(w_i > 0):
126
+ out = out + w_i * expert(x)
127
+
128
+ stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
129
+ return out, stats
src/wjad/modules/normalization.py CHANGED
@@ -1,22 +1,22 @@
1
- """对称对数归一化算子 symlog / symexp。
2
-
3
- 公式:
4
- symlog(x) = sign(x) * log(|x| + 1)
5
- symexp(y) = sign(y) * (exp(|y|) - 1)
6
-
7
- 用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import torch
13
-
14
-
15
- def symlog(x: torch.Tensor) -> torch.Tensor:
16
- """对称对数压缩:sign(x) * log(|x| + 1)。"""
17
- return torch.sign(x) * torch.log1p(torch.abs(x))
18
-
19
-
20
- def symexp(y: torch.Tensor) -> torch.Tensor:
21
- """symlog 的逆:sign(y) * (exp(|y|) - 1)。"""
22
- return torch.sign(y) * torch.expm1(torch.abs(y))
 
1
+ """对称对数归一化算子 symlog / symexp。
2
+
3
+ 公式:
4
+ symlog(x) = sign(x) * log(|x| + 1)
5
+ symexp(y) = sign(y) * (exp(|y|) - 1)
6
+
7
+ 用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+
14
+
15
+ def symlog(x: torch.Tensor) -> torch.Tensor:
16
+ """对称对数压缩:sign(x) * log(|x| + 1)。"""
17
+ return torch.sign(x) * torch.log1p(torch.abs(x))
18
+
19
+
20
+ def symexp(y: torch.Tensor) -> torch.Tensor:
21
+ """symlog 的逆:sign(y) * (exp(|y|) - 1)。"""
22
+ return torch.sign(y) * torch.expm1(torch.abs(y))
src/wjad/modules/pos_encoding.py CHANGED
@@ -1,224 +1,224 @@
1
- """3D RoPE(仅作用于视觉 token)。
2
-
3
- 12 头按 4+4+4 拆为三组:
4
- - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。
5
- - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。
6
- - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。
7
-
8
- 为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else),
9
- 零频段头自然变为恒等映射。
10
-
11
- 将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按
12
- 3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。
13
- """
14
-
15
- from __future__ import annotations
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
-
21
- def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
22
- """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。
23
-
24
- 返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
25
- 若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
26
- 一项视为"零频段"——它的频率会被显式置为 0。
27
- """
28
- if num_components == 0:
29
- return [0, half]
30
- base = half // num_components
31
- splits = [base] * num_components
32
- splits[-1] += half - base * num_components # 余数全归到最后一个分量
33
- return splits
34
-
35
-
36
- def build_rope_freqs(
37
- rays: torch.Tensor,
38
- hwt_grid: torch.Tensor,
39
- num_heads: int = 12,
40
- head_dim: int = 64,
41
- rope_theta: float = 10000.0,
42
- device: torch.device | None = None,
43
- dtype: torch.dtype = torch.float32,
44
- ) -> tuple[torch.Tensor, torch.Tensor]:
45
- """构造 3D RoPE 的 cos / sin 表。
46
-
47
- 参数
48
- ----
49
- rays : Tensor, shape ``[B, N_v, 3]``
50
- 每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。
51
- hwt_grid : Tensor, shape ``[B, N_v, 3]``
52
- 归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。
53
- num_heads : int
54
- 总头数(默认 12)。
55
- head_dim : int
56
- 每头维度(默认 64,必须为偶数)。
57
-
58
- 返回
59
- ----
60
- cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
61
- 每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。
62
- """
63
- assert head_dim % 2 == 0, "head_dim 必须为偶数"
64
- assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"
65
-
66
- half = head_dim // 2
67
- heads_per_group = num_heads // 3
68
- bsz, n_v, _ = rays.shape
69
- if device is None:
70
- device = rays.device
71
-
72
- # === 三组分量值 ===
73
- # group 0: rays (3 components)
74
- # group 1: hwt (3 components)
75
- # group 2: zero (0 components -> 全部 half 视为零频段)
76
- splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays
77
- splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt
78
- splits_g2 = _split_head_dim_for_components(half, 0) # [0, half]
79
-
80
- # === 频率向量(沿 head_dim 半轴)===
81
- # 经典 RoPE: theta_i = base ^ (-2i / d)
82
- # 这里我们对每个分量独立排布频率
83
- def _freqs(num_pairs: int) -> torch.Tensor:
84
- # 前 num_pairs 个用 RoPE 频率,剩余补 0
85
- idx = torch.arange(num_pairs, device=device, dtype=dtype)
86
- freqs = rope_theta ** (-2.0 * idx / head_dim)
87
- return freqs # [num_pairs]
88
-
89
- # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half]
90
- angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)
91
-
92
- # ---- 第 0 组(4 头):射线 ----
93
- base_offset = 0
94
- h0_start = 0
95
- h0_end = h0_start + heads_per_group
96
- cursor = 0
97
- for c in range(3): # dx, dy, dz
98
- n_pairs = splits_g0[c]
99
- if n_pairs > 0:
100
- f = _freqs(n_pairs) # [n_pairs]
101
- comp_val = rays[..., c : c + 1] # [B, N_v, 1]
102
- ang = comp_val * f # 广播 -> [B, N_v, n_pairs]
103
- angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
104
- cursor += n_pairs
105
- # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0
106
-
107
- # ---- 第 1 组(4 头):HWT ----
108
- h1_start = heads_per_group
109
- h1_end = h1_start + heads_per_group
110
- cursor = 0
111
- for c in range(3): # h, w, t
112
- n_pairs = splits_g1[c]
113
- if n_pairs > 0:
114
- f = _freqs(n_pairs)
115
- comp_val = hwt_grid[..., c : c + 1]
116
- ang = comp_val * f
117
- angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
118
- cursor += n_pairs
119
-
120
- # ---- 第 2 组(4 头):零频段 ----
121
- # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零)
122
-
123
- cos = torch.cos(angles)
124
- sin = torch.sin(angles)
125
- return cos, sin
126
-
127
-
128
- def apply_rope(
129
- q: torch.Tensor,
130
- k: torch.Tensor,
131
- cos: torch.Tensor,
132
- sin: torch.Tensor,
133
- ) -> tuple[torch.Tensor, torch.Tensor]:
134
- """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。
135
-
136
- 所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。
137
-
138
- 参数
139
- ----
140
- q, k : Tensor, shape ``[B, H, N_v, head_dim]``
141
- cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``
142
-
143
- 返回
144
- ----
145
- 旋转后的 q, k,形状不变。
146
- """
147
- # 把 cos/sin 转成 [B, H, N_v, half]
148
- cos_e = cos.permute(0, 2, 1, 3)
149
- sin_e = sin.permute(0, 2, 1, 3)
150
-
151
- # 把 head_dim 维度按 (even, odd) 拆开成 [..., half]
152
- q_even = q[..., 0::2]
153
- q_odd = q[..., 1::2]
154
- k_even = k[..., 0::2]
155
- k_odd = k[..., 1::2]
156
-
157
- q_rot_even = q_even * cos_e - q_odd * sin_e
158
- q_rot_odd = q_even * sin_e + q_odd * cos_e
159
- k_rot_even = k_even * cos_e - k_odd * sin_e
160
- k_rot_odd = k_even * sin_e + k_odd * cos_e
161
-
162
- q_out = torch.empty_like(q)
163
- k_out = torch.empty_like(k)
164
- q_out[..., 0::2] = q_rot_even
165
- q_out[..., 1::2] = q_rot_odd
166
- k_out[..., 0::2] = k_rot_even
167
- k_out[..., 1::2] = k_rot_odd
168
- return q_out, k_out
169
-
170
-
171
- class RoPE3D(nn.Module):
172
- """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。
173
-
174
- 使用方式:
175
- rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32)
176
- cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3]
177
- q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
178
- """
179
-
180
- def __init__(
181
- self,
182
- num_heads: int = 12,
183
- head_dim: int = 64,
184
- time_size: int = 4,
185
- height_size: int = 12,
186
- width_size: int = 32,
187
- rope_theta: float = 10000.0,
188
- ) -> None:
189
- super().__init__()
190
- self.num_heads = num_heads
191
- self.head_dim = head_dim
192
- self.rope_theta = rope_theta
193
- self.T = time_size
194
- self.H = height_size
195
- self.W = width_size
196
-
197
- # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W
198
- t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
199
- h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
200
- w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
201
- # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致)
202
- T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
203
- hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3]
204
- self.register_buffer("hwt_grid", hwt, persistent=False)
205
-
206
- @property
207
- def num_visual_tokens(self) -> int:
208
- return self.T * self.H * self.W
209
-
210
- def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
211
- """根据每 token 的射线方向计算 cos/sin。
212
-
213
- ``rays`` shape: ``[B, N_v, 3]``。
214
- """
215
- bsz = rays.shape[0]
216
- hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3]
217
- return build_rope_freqs(
218
- rays=rays,
219
- hwt_grid=hwt,
220
- num_heads=self.num_heads,
221
- head_dim=self.head_dim,
222
- rope_theta=self.rope_theta,
223
- dtype=rays.dtype,
224
- )
 
1
+ """3D RoPE(仅作用于视觉 token)。
2
+
3
+ 12 头按 4+4+4 拆为三组:
4
+ - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。
5
+ - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。
6
+ - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。
7
+
8
+ 为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else),
9
+ 零频段头自然变为恒等映射。
10
+
11
+ 将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按
12
+ 3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
22
+ """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。
23
+
24
+ 返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
25
+ 若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
26
+ 一项视为"零频段"——它的频率会被显式置为 0。
27
+ """
28
+ if num_components == 0:
29
+ return [0, half]
30
+ base = half // num_components
31
+ splits = [base] * num_components
32
+ splits[-1] += half - base * num_components # 余数全归到最后一个分量
33
+ return splits
34
+
35
+
36
+ def build_rope_freqs(
37
+ rays: torch.Tensor,
38
+ hwt_grid: torch.Tensor,
39
+ num_heads: int = 12,
40
+ head_dim: int = 64,
41
+ rope_theta: float = 10000.0,
42
+ device: torch.device | None = None,
43
+ dtype: torch.dtype = torch.float32,
44
+ ) -> tuple[torch.Tensor, torch.Tensor]:
45
+ """构造 3D RoPE 的 cos / sin 表。
46
+
47
+ 参数
48
+ ----
49
+ rays : Tensor, shape ``[B, N_v, 3]``
50
+ 每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。
51
+ hwt_grid : Tensor, shape ``[B, N_v, 3]``
52
+ 归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。
53
+ num_heads : int
54
+ 总头数(默认 12)。
55
+ head_dim : int
56
+ 每头维度(默认 64,必须为偶数)。
57
+
58
+ 返回
59
+ ----
60
+ cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
61
+ 每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。
62
+ """
63
+ assert head_dim % 2 == 0, "head_dim 必须为偶数"
64
+ assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"
65
+
66
+ half = head_dim // 2
67
+ heads_per_group = num_heads // 3
68
+ bsz, n_v, _ = rays.shape
69
+ if device is None:
70
+ device = rays.device
71
+
72
+ # === 三组分量值 ===
73
+ # group 0: rays (3 components)
74
+ # group 1: hwt (3 components)
75
+ # group 2: zero (0 components -> 全部 half 视为零频段)
76
+ splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays
77
+ splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt
78
+ splits_g2 = _split_head_dim_for_components(half, 0) # [0, half]
79
+
80
+ # === 频率向量(沿 head_dim 半轴)===
81
+ # 经典 RoPE: theta_i = base ^ (-2i / d)
82
+ # 这里我们对每个分量独立排布频率
83
+ def _freqs(num_pairs: int) -> torch.Tensor:
84
+ # 前 num_pairs 个用 RoPE 频率,剩余补 0
85
+ idx = torch.arange(num_pairs, device=device, dtype=dtype)
86
+ freqs = rope_theta ** (-2.0 * idx / head_dim)
87
+ return freqs # [num_pairs]
88
+
89
+ # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half]
90
+ angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)
91
+
92
+ # ---- 第 0 组(4 头):射线 ----
93
+ base_offset = 0
94
+ h0_start = 0
95
+ h0_end = h0_start + heads_per_group
96
+ cursor = 0
97
+ for c in range(3): # dx, dy, dz
98
+ n_pairs = splits_g0[c]
99
+ if n_pairs > 0:
100
+ f = _freqs(n_pairs) # [n_pairs]
101
+ comp_val = rays[..., c : c + 1] # [B, N_v, 1]
102
+ ang = comp_val * f # 广播 -> [B, N_v, n_pairs]
103
+ angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
104
+ cursor += n_pairs
105
+ # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0
106
+
107
+ # ---- 第 1 组(4 头):HWT ----
108
+ h1_start = heads_per_group
109
+ h1_end = h1_start + heads_per_group
110
+ cursor = 0
111
+ for c in range(3): # h, w, t
112
+ n_pairs = splits_g1[c]
113
+ if n_pairs > 0:
114
+ f = _freqs(n_pairs)
115
+ comp_val = hwt_grid[..., c : c + 1]
116
+ ang = comp_val * f
117
+ angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
118
+ cursor += n_pairs
119
+
120
+ # ---- 第 2 组(4 头):零频段 ----
121
+ # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零)
122
+
123
+ cos = torch.cos(angles)
124
+ sin = torch.sin(angles)
125
+ return cos, sin
126
+
127
+
128
+ def apply_rope(
129
+ q: torch.Tensor,
130
+ k: torch.Tensor,
131
+ cos: torch.Tensor,
132
+ sin: torch.Tensor,
133
+ ) -> tuple[torch.Tensor, torch.Tensor]:
134
+ """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。
135
+
136
+ 所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。
137
+
138
+ 参数
139
+ ----
140
+ q, k : Tensor, shape ``[B, H, N_v, head_dim]``
141
+ cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``
142
+
143
+ 返回
144
+ ----
145
+ 旋转后的 q, k,形状不变。
146
+ """
147
+ # 把 cos/sin 转成 [B, H, N_v, half]
148
+ cos_e = cos.permute(0, 2, 1, 3)
149
+ sin_e = sin.permute(0, 2, 1, 3)
150
+
151
+ # 把 head_dim 维度按 (even, odd) 拆开成 [..., half]
152
+ q_even = q[..., 0::2]
153
+ q_odd = q[..., 1::2]
154
+ k_even = k[..., 0::2]
155
+ k_odd = k[..., 1::2]
156
+
157
+ q_rot_even = q_even * cos_e - q_odd * sin_e
158
+ q_rot_odd = q_even * sin_e + q_odd * cos_e
159
+ k_rot_even = k_even * cos_e - k_odd * sin_e
160
+ k_rot_odd = k_even * sin_e + k_odd * cos_e
161
+
162
+ q_out = torch.empty_like(q)
163
+ k_out = torch.empty_like(k)
164
+ q_out[..., 0::2] = q_rot_even
165
+ q_out[..., 1::2] = q_rot_odd
166
+ k_out[..., 0::2] = k_rot_even
167
+ k_out[..., 1::2] = k_rot_odd
168
+ return q_out, k_out
169
+
170
+
171
+ class RoPE3D(nn.Module):
172
+ """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。
173
+
174
+ 使用方式:
175
+ rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32)
176
+ cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3]
177
+ q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ num_heads: int = 12,
183
+ head_dim: int = 64,
184
+ time_size: int = 4,
185
+ height_size: int = 12,
186
+ width_size: int = 32,
187
+ rope_theta: float = 10000.0,
188
+ ) -> None:
189
+ super().__init__()
190
+ self.num_heads = num_heads
191
+ self.head_dim = head_dim
192
+ self.rope_theta = rope_theta
193
+ self.T = time_size
194
+ self.H = height_size
195
+ self.W = width_size
196
+
197
+ # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W
198
+ t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
199
+ h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
200
+ w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
201
+ # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致)
202
+ T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
203
+ hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3]
204
+ self.register_buffer("hwt_grid", hwt, persistent=False)
205
+
206
+ @property
207
+ def num_visual_tokens(self) -> int:
208
+ return self.T * self.H * self.W
209
+
210
+ def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
211
+ """根据每 token 的射线方向计算 cos/sin。
212
+
213
+ ``rays`` shape: ``[B, N_v, 3]``。
214
+ """
215
+ bsz = rays.shape[0]
216
+ hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3]
217
+ return build_rope_freqs(
218
+ rays=rays,
219
+ hwt_grid=hwt,
220
+ num_heads=self.num_heads,
221
+ head_dim=self.head_dim,
222
+ rope_theta=self.rope_theta,
223
+ dtype=rays.dtype,
224
+ )
src/wjad/modules/rays.py CHANGED
@@ -1,182 +1,182 @@
1
- """f-theta 相机模型 + 射线计算。
2
-
3
- 依据 Cosmos-Drive-Dreams 数据集 README:
4
- ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。
5
-
6
- f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到
7
- 入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反
8
- 求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial
9
- (theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微
10
- 的非旋转对称形变。
11
-
12
- 为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``,
13
- 即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式,
14
- 这里使用牛顿迭代反求。
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- from dataclasses import dataclass
20
-
21
- import torch
22
- import torch.nn.functional as F
23
-
24
-
25
- @dataclass
26
- class FThetaIntrinsic:
27
- """f-theta 内参(PyTorch 张量形式)。
28
-
29
- 所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。
30
- """
31
-
32
- cx: torch.Tensor # ()
33
- cy: torch.Tensor # ()
34
- w: int
35
- h: int
36
- poly: torch.Tensor # (6,)
37
- is_bw_poly: bool
38
- linear_cde: torch.Tensor # (3,)
39
-
40
-
41
- class FThetaCamera:
42
- """f-theta 相机:像素 -> 单位射线方向(相机坐标系)。"""
43
-
44
- def __init__(self, intr: FThetaIntrinsic) -> None:
45
- self.intr = intr
46
-
47
- @staticmethod
48
- def from_vector(vec: torch.Tensor) -> "FThetaCamera":
49
- """从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。
50
-
51
- 官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``,
52
- 与 ``unproject`` 里近似一致。
53
- """
54
- v = vec.flatten().float()
55
- n = int(v.numel())
56
- if n < 10:
57
- raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)")
58
- cx = v[0]
59
- cy = v[1]
60
- w = int(v[2].item())
61
- h = int(v[3].item())
62
- poly = v[4:10].clone()
63
- if n >= 11:
64
- is_bw = bool(v[10].item() > 0.5)
65
- else:
66
- is_bw = True
67
- if n >= 14:
68
- linear_cde = v[11:14].clone()
69
- else:
70
- linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device)
71
- return FThetaCamera(
72
- FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde)
73
- )
74
-
75
- def _eval_poly(self, r: torch.Tensor) -> torch.Tensor:
76
- """用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。"""
77
- c = self.intr.poly
78
- out = torch.zeros_like(r)
79
- for i in range(c.numel() - 1, -1, -1):
80
- out = out * r + c[i]
81
- return out
82
-
83
- def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor:
84
- """poly 的导数。"""
85
- c = self.intr.poly
86
- n = c.numel()
87
- out = torch.zeros_like(r)
88
- for i in range(n - 1, 0, -1):
89
- out = out * r + c[i] * float(i)
90
- return out
91
-
92
- def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor:
93
- """像素半径 -> 入射角 theta(弧度)。"""
94
- if self.intr.is_bw_poly:
95
- return self._eval_poly(r_pix)
96
- # forward: r_pix = poly(theta) -> 牛顿迭代
97
- theta = r_pix.clone()
98
- for _ in range(8):
99
- f = self._eval_poly(theta) - r_pix
100
- df = self._eval_poly_grad(theta).clamp_min(1e-6)
101
- theta = theta - f / df
102
- return theta
103
-
104
- def unproject(self, uv: torch.Tensor) -> torch.Tensor:
105
- """像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。
106
-
107
- f-theta 反投影:
108
- (du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射)
109
- r_pix = ||(du, dv)||
110
- theta = poly(r_pix) 或 inv_poly(r_pix)
111
- phi = atan2(dv, du)
112
- dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
113
- """
114
- cx = self.intr.cx
115
- cy = self.intr.cy
116
- c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2]
117
-
118
- u = uv[..., 0]
119
- v = uv[..., 1]
120
- # 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似)
121
- du0 = u - cx
122
- dv0 = v - cy
123
- du = c * du0 + d * dv0
124
- dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调
125
- r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6)
126
- theta = self.pixel_to_theta(r_pix)
127
-
128
- sin_t = torch.sin(theta)
129
- cos_t = torch.cos(theta)
130
- cos_p = du / r_pix
131
- sin_p = dv / r_pix
132
- x = sin_t * cos_p
133
- y = sin_t * sin_p
134
- z = cos_t
135
- dir_cam = torch.stack([x, y, z], dim=-1)
136
- return F.normalize(dir_cam, dim=-1)
137
-
138
-
139
- def compute_ego_rays(
140
- intr_vec: torch.Tensor,
141
- cam2vehicle: torch.Tensor,
142
- height: int,
143
- width: int,
144
- grid_h: int,
145
- grid_w: int,
146
- device: torch.device,
147
- dtype: torch.dtype = torch.float32,
148
- ) -> torch.Tensor:
149
- """对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。
150
-
151
- 参数
152
- ----
153
- intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。
154
- cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。
155
- height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。
156
- grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。
157
-
158
- 返回
159
- ----
160
- rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。
161
- """
162
- if intr_vec.dim() == 1:
163
- intr_vec = intr_vec.unsqueeze(0)
164
- if cam2vehicle.dim() == 2:
165
- cam2vehicle = cam2vehicle.unsqueeze(0)
166
- B = intr_vec.shape[0]
167
-
168
- # 在像素中心采样
169
- u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w)
170
- v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h)
171
- vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw]
172
- uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2]
173
-
174
- out = []
175
- for b in range(B):
176
- cam = FThetaCamera.from_vector(intr_vec[b].to(dtype))
177
- dir_cam = cam.unproject(uv) # [gh, gw, 3]
178
- # 旋到自车系:取 cam2vehicle 的 3x3 旋转部分
179
- R = cam2vehicle[b, :3, :3].to(dtype)
180
- dir_veh = dir_cam @ R.T # [gh, gw, 3]
181
- out.append(F.normalize(dir_veh, dim=-1))
182
- return torch.stack(out, dim=0)
 
1
+ """f-theta 相机模型 + 射线计算。
2
+
3
+ 依据 Cosmos-Drive-Dreams 数据集 README:
4
+ ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。
5
+
6
+ f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到
7
+ 入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反
8
+ 求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial
9
+ (theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微
10
+ 的非旋转对称形变。
11
+
12
+ 为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``,
13
+ 即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式,
14
+ 这里使用牛顿迭代反求。
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from dataclasses import dataclass
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+
25
+ @dataclass
26
+ class FThetaIntrinsic:
27
+ """f-theta 内参(PyTorch 张量形式)。
28
+
29
+ 所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。
30
+ """
31
+
32
+ cx: torch.Tensor # ()
33
+ cy: torch.Tensor # ()
34
+ w: int
35
+ h: int
36
+ poly: torch.Tensor # (6,)
37
+ is_bw_poly: bool
38
+ linear_cde: torch.Tensor # (3,)
39
+
40
+
41
+ class FThetaCamera:
42
+ """f-theta 相机:像素 -> 单位射线方向(相机坐标系)。"""
43
+
44
+ def __init__(self, intr: FThetaIntrinsic) -> None:
45
+ self.intr = intr
46
+
47
+ @staticmethod
48
+ def from_vector(vec: torch.Tensor) -> "FThetaCamera":
49
+ """从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。
50
+
51
+ 官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``,
52
+ 与 ``unproject`` 里近似一致。
53
+ """
54
+ v = vec.flatten().float()
55
+ n = int(v.numel())
56
+ if n < 10:
57
+ raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)")
58
+ cx = v[0]
59
+ cy = v[1]
60
+ w = int(v[2].item())
61
+ h = int(v[3].item())
62
+ poly = v[4:10].clone()
63
+ if n >= 11:
64
+ is_bw = bool(v[10].item() > 0.5)
65
+ else:
66
+ is_bw = True
67
+ if n >= 14:
68
+ linear_cde = v[11:14].clone()
69
+ else:
70
+ linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device)
71
+ return FThetaCamera(
72
+ FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde)
73
+ )
74
+
75
+ def _eval_poly(self, r: torch.Tensor) -> torch.Tensor:
76
+ """用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。"""
77
+ c = self.intr.poly
78
+ out = torch.zeros_like(r)
79
+ for i in range(c.numel() - 1, -1, -1):
80
+ out = out * r + c[i]
81
+ return out
82
+
83
+ def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor:
84
+ """poly 的导数。"""
85
+ c = self.intr.poly
86
+ n = c.numel()
87
+ out = torch.zeros_like(r)
88
+ for i in range(n - 1, 0, -1):
89
+ out = out * r + c[i] * float(i)
90
+ return out
91
+
92
+ def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor:
93
+ """像素半径 -> 入射角 theta(弧度)。"""
94
+ if self.intr.is_bw_poly:
95
+ return self._eval_poly(r_pix)
96
+ # forward: r_pix = poly(theta) -> 牛顿迭代
97
+ theta = r_pix.clone()
98
+ for _ in range(8):
99
+ f = self._eval_poly(theta) - r_pix
100
+ df = self._eval_poly_grad(theta).clamp_min(1e-6)
101
+ theta = theta - f / df
102
+ return theta
103
+
104
+ def unproject(self, uv: torch.Tensor) -> torch.Tensor:
105
+ """像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。
106
+
107
+ f-theta 反投影:
108
+ (du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射)
109
+ r_pix = ||(du, dv)||
110
+ theta = poly(r_pix) 或 inv_poly(r_pix)
111
+ phi = atan2(dv, du)
112
+ dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
113
+ """
114
+ cx = self.intr.cx
115
+ cy = self.intr.cy
116
+ c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2]
117
+
118
+ u = uv[..., 0]
119
+ v = uv[..., 1]
120
+ # 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似)
121
+ du0 = u - cx
122
+ dv0 = v - cy
123
+ du = c * du0 + d * dv0
124
+ dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调
125
+ r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6)
126
+ theta = self.pixel_to_theta(r_pix)
127
+
128
+ sin_t = torch.sin(theta)
129
+ cos_t = torch.cos(theta)
130
+ cos_p = du / r_pix
131
+ sin_p = dv / r_pix
132
+ x = sin_t * cos_p
133
+ y = sin_t * sin_p
134
+ z = cos_t
135
+ dir_cam = torch.stack([x, y, z], dim=-1)
136
+ return F.normalize(dir_cam, dim=-1)
137
+
138
+
139
+ def compute_ego_rays(
140
+ intr_vec: torch.Tensor,
141
+ cam2vehicle: torch.Tensor,
142
+ height: int,
143
+ width: int,
144
+ grid_h: int,
145
+ grid_w: int,
146
+ device: torch.device,
147
+ dtype: torch.dtype = torch.float32,
148
+ ) -> torch.Tensor:
149
+ """对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。
150
+
151
+ 参数
152
+ ----
153
+ intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。
154
+ cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。
155
+ height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。
156
+ grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。
157
+
158
+ 返回
159
+ ----
160
+ rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。
161
+ """
162
+ if intr_vec.dim() == 1:
163
+ intr_vec = intr_vec.unsqueeze(0)
164
+ if cam2vehicle.dim() == 2:
165
+ cam2vehicle = cam2vehicle.unsqueeze(0)
166
+ B = intr_vec.shape[0]
167
+
168
+ # 在像素中心采样
169
+ u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w)
170
+ v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h)
171
+ vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw]
172
+ uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2]
173
+
174
+ out = []
175
+ for b in range(B):
176
+ cam = FThetaCamera.from_vector(intr_vec[b].to(dtype))
177
+ dir_cam = cam.unproject(uv) # [gh, gw, 3]
178
+ # 旋到自车系:取 cam2vehicle 的 3x3 旋转部分
179
+ R = cam2vehicle[b, :3, :3].to(dtype)
180
+ dir_veh = dir_cam @ R.T # [gh, gw, 3]
181
+ out.append(F.normalize(dir_veh, dim=-1))
182
+ return torch.stack(out, dim=0)
src/wjad/modules/temporal_compress.py CHANGED
@@ -1,34 +1,34 @@
1
- """2×2×2 时空压缩卷积。
2
-
3
- 将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
4
- 4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import torch
10
- import torch.nn as nn
11
-
12
-
13
- class TemporalCompress2x2x2(nn.Module):
14
- """``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。"""
15
-
16
- def __init__(self, dim: int = 768) -> None:
17
- super().__init__()
18
- self.dim = dim
19
- self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
20
- self.norm = nn.LayerNorm(dim)
21
-
22
- def forward(self, x: torch.Tensor) -> torch.Tensor:
23
- """输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。
24
-
25
- 中间排布:
26
- [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
27
- -> [B, T'*H'*W', D] -> LayerNorm
28
- """
29
- b, t, h, w, d = x.shape
30
- x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W]
31
- y = self.conv(x_in)
32
- bb, dd, t2, h2, w2 = y.shape
33
- y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd)
34
- return self.norm(y), (t2, h2, w2)
 
1
+ """2×2×2 时空压缩卷积。
2
+
3
+ 将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
4
+ 4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class TemporalCompress2x2x2(nn.Module):
14
+ """``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。"""
15
+
16
+ def __init__(self, dim: int = 768) -> None:
17
+ super().__init__()
18
+ self.dim = dim
19
+ self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
20
+ self.norm = nn.LayerNorm(dim)
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ """输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。
24
+
25
+ 中间排布:
26
+ [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
27
+ -> [B, T'*H'*W', D] -> LayerNorm
28
+ """
29
+ b, t, h, w, d = x.shape
30
+ x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W]
31
+ y = self.conv(x_in)
32
+ bb, dd, t2, h2, w2 = y.shape
33
+ y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd)
34
+ return self.norm(y), (t2, h2, w2)