Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- MEMORY.md +63 -63
- README.md +63 -63
- configs/default.yaml +192 -192
- pyproject.toml +40 -40
- scripts/download_data.py +76 -76
- scripts/estimate_memory.py +203 -203
- scripts/ingest_hub_to_bucket.py +234 -207
- scripts/push_cpu_ingest_job.py +148 -141
- scripts/push_to_jobs.py +196 -196
- scripts/push_to_sandbox.py +185 -185
- scripts/sandbox_real_data.py +236 -236
- scripts/smoke_test.py +78 -78
- scripts/smoke_train.py +152 -152
- scripts/update_deps.py +123 -123
- src/wjad/__init__.py +5 -5
- src/wjad/backbone/__init__.py +6 -6
- src/wjad/backbone/backbone.py +110 -110
- src/wjad/backbone/blocks.py +79 -79
- src/wjad/calibration/__init__.py +5 -5
- src/wjad/calibration/online_calib.py +196 -196
- src/wjad/data/__init__.py +39 -39
- src/wjad/data/cosmos_dataset.py +439 -439
- src/wjad/data/ftheta_proj.py +62 -62
- src/wjad/data/hdmap.py +247 -247
- src/wjad/data/label_paths.py +218 -218
- src/wjad/data/se3.py +111 -111
- src/wjad/data/targets.py +214 -214
- src/wjad/data/transforms.py +86 -86
- src/wjad/encoders/__init__.py +5 -5
- src/wjad/encoders/dinov3_wrapper.py +104 -104
- src/wjad/heads/__init__.py +11 -11
- src/wjad/heads/control.py +100 -100
- src/wjad/heads/detection_traj.py +106 -106
- src/wjad/losses/__init__.py +24 -24
- src/wjad/losses/calib_reg.py +21 -21
- src/wjad/losses/control.py +25 -25
- src/wjad/losses/detection.py +213 -213
- src/wjad/losses/moe_aux.py +33 -33
- src/wjad/losses/nll.py +47 -47
- src/wjad/losses/trajectory.py +43 -43
- src/wjad/model.py +289 -289
- src/wjad/modules/__init__.py +28 -28
- src/wjad/modules/ffn.py +30 -30
- src/wjad/modules/gate_attention.py +181 -181
- src/wjad/modules/learned_pe.py +24 -24
- src/wjad/modules/moe.py +129 -129
- src/wjad/modules/normalization.py +22 -22
- src/wjad/modules/pos_encoding.py +224 -224
- src/wjad/modules/rays.py +182 -182
- src/wjad/modules/temporal_compress.py +34 -34
MEMORY.md
CHANGED
|
@@ -1,63 +1,63 @@
|
|
| 1 |
-
# 显存与内存估算(bf16 AMP + GradNorm + PCGrad)
|
| 2 |
-
|
| 3 |
-
由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。
|
| 4 |
-
|
| 5 |
-
| 项目 | 数值 |
|
| 6 |
-
|---|---|
|
| 7 |
-
| 总参数 | **725.62 M** |
|
| 8 |
-
| 可训练 (Stage1, DINOv3 冻结) | 639.96 M |
|
| 9 |
-
| 可训练 (Stage2, DINOv3 解冻) | 725.62 M |
|
| 10 |
-
| 序列长度(拼接后) | 2848 |
|
| 11 |
-
|
| 12 |
-
## 显存(含 15% 余量)
|
| 13 |
-
|
| 14 |
-
| Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 |
|
| 15 |
-
|---:|---:|---|---|
|
| 16 |
-
| 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` |
|
| 17 |
-
| 2 | ~18 GB | L4 24GB | `l4x1` |
|
| 18 |
-
| 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` |
|
| 19 |
-
| **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** |
|
| 20 |
-
| 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` |
|
| 21 |
-
|
| 22 |
-
显存细分(BS=8 Stage2):
|
| 23 |
-
- 权重 (bf16): 1.35 GB
|
| 24 |
-
- 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB
|
| 25 |
-
- 主激活 (bf16, 18 + 6 层): ~12.6 GB
|
| 26 |
-
- PCGrad retain_graph 开销: ~6.3 GB
|
| 27 |
-
- 缓冲 / cuDNN workspace / 碎片: ~2 GB
|
| 28 |
-
|
| 29 |
-
如显存不足:
|
| 30 |
-
- 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB)
|
| 31 |
-
- BS=4 + `grad_accum_steps=2` 等价 BS=8 训练
|
| 32 |
-
- 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量
|
| 33 |
-
|
| 34 |
-
## 主机内存 / 磁盘
|
| 35 |
-
|
| 36 |
-
| 项目 | 数值(BS=8) |
|
| 37 |
-
|---|---|
|
| 38 |
-
| 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) |
|
| 39 |
-
| 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB |
|
| 40 |
-
| 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB |
|
| 41 |
-
| 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB |
|
| 42 |
-
|
| 43 |
-
## 设备选择建议
|
| 44 |
-
|
| 45 |
-
- **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。
|
| 46 |
-
- **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。
|
| 47 |
-
- **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。
|
| 48 |
-
|
| 49 |
-
## 复现命令
|
| 50 |
-
|
| 51 |
-
```bash
|
| 52 |
-
# 升级依赖到最新(写入 requirements.lock.txt)
|
| 53 |
-
python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124
|
| 54 |
-
|
| 55 |
-
# 估算
|
| 56 |
-
python scripts/estimate_memory.py
|
| 57 |
-
|
| 58 |
-
# Sandbox 推送
|
| 59 |
-
python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large
|
| 60 |
-
|
| 61 |
-
# Jobs 全量
|
| 62 |
-
python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1
|
| 63 |
-
```
|
|
|
|
| 1 |
+
# 显存与内存估算(bf16 AMP + GradNorm + PCGrad)
|
| 2 |
+
|
| 3 |
+
由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。
|
| 4 |
+
|
| 5 |
+
| 项目 | 数值 |
|
| 6 |
+
|---|---|
|
| 7 |
+
| 总参数 | **725.62 M** |
|
| 8 |
+
| 可训练 (Stage1, DINOv3 冻结) | 639.96 M |
|
| 9 |
+
| 可训练 (Stage2, DINOv3 解冻) | 725.62 M |
|
| 10 |
+
| 序列长度(拼接后) | 2848 |
|
| 11 |
+
|
| 12 |
+
## 显存(含 15% 余量)
|
| 13 |
+
|
| 14 |
+
| Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 |
|
| 15 |
+
|---:|---:|---|---|
|
| 16 |
+
| 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` |
|
| 17 |
+
| 2 | ~18 GB | L4 24GB | `l4x1` |
|
| 18 |
+
| 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` |
|
| 19 |
+
| **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** |
|
| 20 |
+
| 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` |
|
| 21 |
+
|
| 22 |
+
显存细分(BS=8 Stage2):
|
| 23 |
+
- 权重 (bf16): 1.35 GB
|
| 24 |
+
- 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB
|
| 25 |
+
- 主激活 (bf16, 18 + 6 层): ~12.6 GB
|
| 26 |
+
- PCGrad retain_graph 开销: ~6.3 GB
|
| 27 |
+
- 缓冲 / cuDNN workspace / 碎片: ~2 GB
|
| 28 |
+
|
| 29 |
+
如显存不足:
|
| 30 |
+
- 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB)
|
| 31 |
+
- BS=4 + `grad_accum_steps=2` 等价 BS=8 训练
|
| 32 |
+
- 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量
|
| 33 |
+
|
| 34 |
+
## 主机内存 / 磁盘
|
| 35 |
+
|
| 36 |
+
| 项目 | 数值(BS=8) |
|
| 37 |
+
|---|---|
|
| 38 |
+
| 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) |
|
| 39 |
+
| 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB |
|
| 40 |
+
| 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB |
|
| 41 |
+
| 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB |
|
| 42 |
+
|
| 43 |
+
## 设备选择建议
|
| 44 |
+
|
| 45 |
+
- **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。
|
| 46 |
+
- **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。
|
| 47 |
+
- **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。
|
| 48 |
+
|
| 49 |
+
## 复现命令
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
# 升级依赖到最新(写入 requirements.lock.txt)
|
| 53 |
+
python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124
|
| 54 |
+
|
| 55 |
+
# 估算
|
| 56 |
+
python scripts/estimate_memory.py
|
| 57 |
+
|
| 58 |
+
# Sandbox 推送
|
| 59 |
+
python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large
|
| 60 |
+
|
| 61 |
+
# Jobs 全量
|
| 62 |
+
python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1
|
| 63 |
+
```
|
README.md
CHANGED
|
@@ -1,63 +1,63 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: WJAD Sandbox
|
| 3 |
-
emoji: 🚗
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 7860
|
| 8 |
-
pinned: false
|
| 9 |
-
---
|
| 10 |
-
|
| 11 |
-
# WJAD - 端到端自动驾驶模型
|
| 12 |
-
|
| 13 |
-
基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。
|
| 14 |
-
|
| 15 |
-
## 架构概览
|
| 16 |
-
|
| 17 |
-
- **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。
|
| 18 |
-
- **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。
|
| 19 |
-
- **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。
|
| 20 |
-
- **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。
|
| 21 |
-
- **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。
|
| 22 |
-
- **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。
|
| 23 |
-
- **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。
|
| 24 |
-
- **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。
|
| 25 |
-
- **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。
|
| 26 |
-
|
| 27 |
-
## 三步训练路径
|
| 28 |
-
|
| 29 |
-
```bash
|
| 30 |
-
# 1. 本地跑通(纯随机张量)
|
| 31 |
-
python -m scripts.smoke_test
|
| 32 |
-
|
| 33 |
-
# 2. HF Sandbox 微小训练
|
| 34 |
-
python -m scripts.push_to_sandbox
|
| 35 |
-
|
| 36 |
-
# 3. HF Jobs 全量训练
|
| 37 |
-
python -m scripts.push_to_jobs
|
| 38 |
-
```
|
| 39 |
-
|
| 40 |
-
## 数据准备
|
| 41 |
-
|
| 42 |
-
```bash
|
| 43 |
-
python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap
|
| 44 |
-
```
|
| 45 |
-
|
| 46 |
-
## 项目结构
|
| 47 |
-
|
| 48 |
-
```
|
| 49 |
-
src/wjad/
|
| 50 |
-
├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/...
|
| 51 |
-
├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩
|
| 52 |
-
├── calibration/ # 在线校准网络
|
| 53 |
-
├── backbone/ # 18 层主干
|
| 54 |
-
├── heads/ # 检测+预测头、控制头
|
| 55 |
-
├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广
|
| 56 |
-
├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则
|
| 57 |
-
├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度
|
| 58 |
-
└── model.py # 顶层 E2EAVModel
|
| 59 |
-
```
|
| 60 |
-
|
| 61 |
-
## License
|
| 62 |
-
|
| 63 |
-
代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: WJAD Sandbox
|
| 3 |
+
emoji: 🚗
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# WJAD - 端到端自动驾驶模型
|
| 12 |
+
|
| 13 |
+
基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。
|
| 14 |
+
|
| 15 |
+
## 架构概览
|
| 16 |
+
|
| 17 |
+
- **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。
|
| 18 |
+
- **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。
|
| 19 |
+
- **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。
|
| 20 |
+
- **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。
|
| 21 |
+
- **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。
|
| 22 |
+
- **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。
|
| 23 |
+
- **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。
|
| 24 |
+
- **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。
|
| 25 |
+
- **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。
|
| 26 |
+
|
| 27 |
+
## 三步训练路径
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# 1. 本地跑通(纯随机张量)
|
| 31 |
+
python -m scripts.smoke_test
|
| 32 |
+
|
| 33 |
+
# 2. HF Sandbox 微小训练
|
| 34 |
+
python -m scripts.push_to_sandbox
|
| 35 |
+
|
| 36 |
+
# 3. HF Jobs 全量训练
|
| 37 |
+
python -m scripts.push_to_jobs
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## 数据准备
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## 项目结构
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
src/wjad/
|
| 50 |
+
├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/...
|
| 51 |
+
├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩
|
| 52 |
+
├── calibration/ # 在线校准网络
|
| 53 |
+
├── backbone/ # 18 层主干
|
| 54 |
+
├── heads/ # 检测+预测头、控制头
|
| 55 |
+
├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广
|
| 56 |
+
├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则
|
| 57 |
+
├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度
|
| 58 |
+
└── model.py # 顶层 E2EAVModel
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## License
|
| 62 |
+
|
| 63 |
+
代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。
|
configs/default.yaml
CHANGED
|
@@ -1,192 +1,192 @@
|
|
| 1 |
-
# 端到端自动驾驶模型默认配置(与 Design.md 对齐)
|
| 2 |
-
|
| 3 |
-
# === 全局 ===
|
| 4 |
-
seed: 42
|
| 5 |
-
device: cuda
|
| 6 |
-
mixed_precision: bf16 # H100 推荐 bf16
|
| 7 |
-
gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速
|
| 8 |
-
|
| 9 |
-
# === 输入 ===
|
| 10 |
-
input:
|
| 11 |
-
image_height: 384 # 已裁去上半天空的高度
|
| 12 |
-
image_width: 1024
|
| 13 |
-
num_history_frames: 8 # t-7..t(含当前)
|
| 14 |
-
num_future_frames: 24 # 预测窗口
|
| 15 |
-
camera_name: camera_front_wide_120fov # 当前唯一开放的视角
|
| 16 |
-
|
| 17 |
-
# === 视觉编码器(DINOv3)===
|
| 18 |
-
dinov3:
|
| 19 |
-
pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m"
|
| 20 |
-
hidden_size: 768
|
| 21 |
-
patch_size: 16
|
| 22 |
-
num_register_tokens: 4
|
| 23 |
-
attn_implementation: sdpa
|
| 24 |
-
freeze_in_stage1: true # Stage1 冻结 DINOv3
|
| 25 |
-
finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率
|
| 26 |
-
|
| 27 |
-
# === 时空压缩 ===
|
| 28 |
-
temporal_compress:
|
| 29 |
-
kernel: [2, 2, 2]
|
| 30 |
-
stride: [2, 2, 2]
|
| 31 |
-
|
| 32 |
-
# === 主干 ===
|
| 33 |
-
backbone:
|
| 34 |
-
hidden_size: 768
|
| 35 |
-
num_heads: 12 # head_dim = 64
|
| 36 |
-
ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D)
|
| 37 |
-
num_dense_layers: 9
|
| 38 |
-
num_moe_layers: 9 # 共 18 层
|
| 39 |
-
dropout: 0.0
|
| 40 |
-
prenorm: true
|
| 41 |
-
|
| 42 |
-
# === MoE ===
|
| 43 |
-
moe:
|
| 44 |
-
num_routed_experts: 7
|
| 45 |
-
num_shared_experts: 1
|
| 46 |
-
topk: 3 # Stage2 激活专家数
|
| 47 |
-
router_temperature_init: 0.5
|
| 48 |
-
router_temperature_final: 1.0
|
| 49 |
-
load_balance_weight: 0.01
|
| 50 |
-
boundary_weight: 0.001 # mean(logits^2) 防越界
|
| 51 |
-
|
| 52 |
-
# === Token 数量 ===
|
| 53 |
-
tokens:
|
| 54 |
-
num_detection: 1024
|
| 55 |
-
num_control: 24
|
| 56 |
-
num_ego: 8
|
| 57 |
-
num_extra: 256
|
| 58 |
-
|
| 59 |
-
# === 在线校准 ===
|
| 60 |
-
calibration:
|
| 61 |
-
intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14
|
| 62 |
-
hidden_size: 256
|
| 63 |
-
num_query_tokens: 256
|
| 64 |
-
num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self
|
| 65 |
-
num_blocks: 2 # 总计 6 层
|
| 66 |
-
num_heads: 8 # head_dim = 32
|
| 67 |
-
residual_range: 0.1 # Tanh * range
|
| 68 |
-
init_zero_output: true
|
| 69 |
-
|
| 70 |
-
# === 检测+未来预测头 ===
|
| 71 |
-
det_traj_head:
|
| 72 |
-
num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap)
|
| 73 |
-
box_dim: 7 # x,y,z,l,w,h,yaw
|
| 74 |
-
traj_horizon: 24 # 未来 24 帧
|
| 75 |
-
traj_dim: 3 # dx,dy,dyaw
|
| 76 |
-
hidden_size: 384
|
| 77 |
-
log_sigma_clamp: [-7.0, 7.0]
|
| 78 |
-
|
| 79 |
-
# === 控制头 ===
|
| 80 |
-
control_head:
|
| 81 |
-
num_traj_tokens: 12 # 解码 24 帧 ego 轨迹
|
| 82 |
-
num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ
|
| 83 |
-
ego_traj_dim: 3 # x,y,yaw
|
| 84 |
-
action_dim: 3
|
| 85 |
-
hidden_size: 384
|
| 86 |
-
log_sigma_clamp: [-7.0, 7.0]
|
| 87 |
-
|
| 88 |
-
# === 检测目标筛选 ===
|
| 89 |
-
detection:
|
| 90 |
-
max_distance_m: 48.0
|
| 91 |
-
occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米)
|
| 92 |
-
min_box_pixels: 8
|
| 93 |
-
dynamic_classes:
|
| 94 |
-
- Automobile
|
| 95 |
-
- Heavy_truck
|
| 96 |
-
- Bus
|
| 97 |
-
- Train_or_tram_car
|
| 98 |
-
- Trolley_bus
|
| 99 |
-
- Other_vehicle
|
| 100 |
-
- Trailer
|
| 101 |
-
- Person
|
| 102 |
-
- Stroller
|
| 103 |
-
- Rider
|
| 104 |
-
- Animal
|
| 105 |
-
- Protruding_object
|
| 106 |
-
|
| 107 |
-
# === 数据 ===
|
| 108 |
-
data:
|
| 109 |
-
root: "./data/cosmos"
|
| 110 |
-
hdmap_subdir: "rds_hq"
|
| 111 |
-
synthetic_subdir: "cosmos_synthetic/single_view"
|
| 112 |
-
use_synthetic: true
|
| 113 |
-
use_real: false
|
| 114 |
-
weather:
|
| 115 |
-
- Sunny
|
| 116 |
-
- Morning
|
| 117 |
-
- Golden_hour
|
| 118 |
-
- Night
|
| 119 |
-
- Rainy
|
| 120 |
-
- Snowy
|
| 121 |
-
- Foggy
|
| 122 |
-
num_workers: 4
|
| 123 |
-
pin_memory: true
|
| 124 |
-
prefetch_factor: 2
|
| 125 |
-
augmentation:
|
| 126 |
-
gaussian_noise_std: 0.01
|
| 127 |
-
color_jitter: 0.1
|
| 128 |
-
perturb_translation_std_m: 0.1 # Stage1 中期开启
|
| 129 |
-
perturb_rotation_std_deg: 0.5
|
| 130 |
-
perturb_intrinsic_std: 0.005
|
| 131 |
-
perturb_extrinsic_std: 0.005
|
| 132 |
-
|
| 133 |
-
# === 损失权重(GradNorm 自适应的 1-6 任务初值)===
|
| 134 |
-
loss:
|
| 135 |
-
cls_weight: 1.0
|
| 136 |
-
box_weight: 1.0
|
| 137 |
-
isdyn_weight: 1.0
|
| 138 |
-
traj_obj_weight: 1.0
|
| 139 |
-
traj_ego_weight: 1.0
|
| 140 |
-
ctrl_weight: 1.0
|
| 141 |
-
moe_weight: 1.0 # 固定权重正则
|
| 142 |
-
calib_weight: 0.1 # 固定权重正则
|
| 143 |
-
giou_weight: 0.5 # 在 box 内组合
|
| 144 |
-
focal_alpha: 0.25
|
| 145 |
-
focal_gamma: 2.0
|
| 146 |
-
matcher_cls_cost: 2.0
|
| 147 |
-
matcher_l1_cost: 5.0
|
| 148 |
-
matcher_giou_cost: 2.0
|
| 149 |
-
|
| 150 |
-
# === 多任务训练 ===
|
| 151 |
-
# PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用:
|
| 152 |
-
# 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl)
|
| 153 |
-
# 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。
|
| 154 |
-
multitask:
|
| 155 |
-
enable_gradnorm: true
|
| 156 |
-
enable_pcgrad: true
|
| 157 |
-
gradnorm_alpha: 1.5
|
| 158 |
-
gradnorm_lr: 0.025
|
| 159 |
-
pcgrad_shuffle: true
|
| 160 |
-
|
| 161 |
-
# === 训练 ===
|
| 162 |
-
train:
|
| 163 |
-
batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps
|
| 164 |
-
grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps
|
| 165 |
-
ckpt_dir: outputs/checkpoints
|
| 166 |
-
total_steps: 100000
|
| 167 |
-
warmup_steps: 1000
|
| 168 |
-
base_lr: 2.0e-4
|
| 169 |
-
min_lr: 1.0e-6
|
| 170 |
-
weight_decay: 0.05
|
| 171 |
-
optimizer: adamw
|
| 172 |
-
betas: [0.9, 0.95]
|
| 173 |
-
grad_clip: 1.0
|
| 174 |
-
log_interval: 20
|
| 175 |
-
ckpt_interval: 1000
|
| 176 |
-
eval_interval: 5000
|
| 177 |
-
stage1_steps: 60000 # Stage1 步数
|
| 178 |
-
stage1_perturb_start: 20000 # 中期开始扰动
|
| 179 |
-
grad_monitor_threshold: 1.0e-7
|
| 180 |
-
param_groups:
|
| 181 |
-
dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供
|
| 182 |
-
backbone_lr_mult: 1.0
|
| 183 |
-
calibration_lr_mult: 0.1
|
| 184 |
-
head_lr_mult: 1.0
|
| 185 |
-
gate_lr_mult: 0.1 # 门控参数低 LR
|
| 186 |
-
|
| 187 |
-
# === 部署 ===
|
| 188 |
-
deploy:
|
| 189 |
-
hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库
|
| 190 |
-
push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt
|
| 191 |
-
hub_ckpt_prefix: checkpoints # Hub 上子目录
|
| 192 |
-
hf_sandbox_space: "fuzirui/wjad-sandbox"
|
|
|
|
| 1 |
+
# 端到端自动驾驶模型默认配置(与 Design.md 对齐)
|
| 2 |
+
|
| 3 |
+
# === 全局 ===
|
| 4 |
+
seed: 42
|
| 5 |
+
device: cuda
|
| 6 |
+
mixed_precision: bf16 # H100 推荐 bf16
|
| 7 |
+
gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速
|
| 8 |
+
|
| 9 |
+
# === 输入 ===
|
| 10 |
+
input:
|
| 11 |
+
image_height: 384 # 已裁去上半天空的高度
|
| 12 |
+
image_width: 1024
|
| 13 |
+
num_history_frames: 8 # t-7..t(含当前)
|
| 14 |
+
num_future_frames: 24 # 预测窗口
|
| 15 |
+
camera_name: camera_front_wide_120fov # 当前唯一开放的视角
|
| 16 |
+
|
| 17 |
+
# === 视觉编码器(DINOv3)===
|
| 18 |
+
dinov3:
|
| 19 |
+
pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m"
|
| 20 |
+
hidden_size: 768
|
| 21 |
+
patch_size: 16
|
| 22 |
+
num_register_tokens: 4
|
| 23 |
+
attn_implementation: sdpa
|
| 24 |
+
freeze_in_stage1: true # Stage1 冻结 DINOv3
|
| 25 |
+
finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率
|
| 26 |
+
|
| 27 |
+
# === 时空压缩 ===
|
| 28 |
+
temporal_compress:
|
| 29 |
+
kernel: [2, 2, 2]
|
| 30 |
+
stride: [2, 2, 2]
|
| 31 |
+
|
| 32 |
+
# === 主干 ===
|
| 33 |
+
backbone:
|
| 34 |
+
hidden_size: 768
|
| 35 |
+
num_heads: 12 # head_dim = 64
|
| 36 |
+
ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D)
|
| 37 |
+
num_dense_layers: 9
|
| 38 |
+
num_moe_layers: 9 # 共 18 层
|
| 39 |
+
dropout: 0.0
|
| 40 |
+
prenorm: true
|
| 41 |
+
|
| 42 |
+
# === MoE ===
|
| 43 |
+
moe:
|
| 44 |
+
num_routed_experts: 7
|
| 45 |
+
num_shared_experts: 1
|
| 46 |
+
topk: 3 # Stage2 激活专家数
|
| 47 |
+
router_temperature_init: 0.5
|
| 48 |
+
router_temperature_final: 1.0
|
| 49 |
+
load_balance_weight: 0.01
|
| 50 |
+
boundary_weight: 0.001 # mean(logits^2) 防越界
|
| 51 |
+
|
| 52 |
+
# === Token 数量 ===
|
| 53 |
+
tokens:
|
| 54 |
+
num_detection: 1024
|
| 55 |
+
num_control: 24
|
| 56 |
+
num_ego: 8
|
| 57 |
+
num_extra: 256
|
| 58 |
+
|
| 59 |
+
# === 在线校准 ===
|
| 60 |
+
calibration:
|
| 61 |
+
intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14
|
| 62 |
+
hidden_size: 256
|
| 63 |
+
num_query_tokens: 256
|
| 64 |
+
num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self
|
| 65 |
+
num_blocks: 2 # 总计 6 层
|
| 66 |
+
num_heads: 8 # head_dim = 32
|
| 67 |
+
residual_range: 0.1 # Tanh * range
|
| 68 |
+
init_zero_output: true
|
| 69 |
+
|
| 70 |
+
# === 检测+未来预测头 ===
|
| 71 |
+
det_traj_head:
|
| 72 |
+
num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap)
|
| 73 |
+
box_dim: 7 # x,y,z,l,w,h,yaw
|
| 74 |
+
traj_horizon: 24 # 未来 24 帧
|
| 75 |
+
traj_dim: 3 # dx,dy,dyaw
|
| 76 |
+
hidden_size: 384
|
| 77 |
+
log_sigma_clamp: [-7.0, 7.0]
|
| 78 |
+
|
| 79 |
+
# === 控制头 ===
|
| 80 |
+
control_head:
|
| 81 |
+
num_traj_tokens: 12 # 解码 24 帧 ego 轨迹
|
| 82 |
+
num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ
|
| 83 |
+
ego_traj_dim: 3 # x,y,yaw
|
| 84 |
+
action_dim: 3
|
| 85 |
+
hidden_size: 384
|
| 86 |
+
log_sigma_clamp: [-7.0, 7.0]
|
| 87 |
+
|
| 88 |
+
# === 检测目标筛选 ===
|
| 89 |
+
detection:
|
| 90 |
+
max_distance_m: 48.0
|
| 91 |
+
occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米)
|
| 92 |
+
min_box_pixels: 8
|
| 93 |
+
dynamic_classes:
|
| 94 |
+
- Automobile
|
| 95 |
+
- Heavy_truck
|
| 96 |
+
- Bus
|
| 97 |
+
- Train_or_tram_car
|
| 98 |
+
- Trolley_bus
|
| 99 |
+
- Other_vehicle
|
| 100 |
+
- Trailer
|
| 101 |
+
- Person
|
| 102 |
+
- Stroller
|
| 103 |
+
- Rider
|
| 104 |
+
- Animal
|
| 105 |
+
- Protruding_object
|
| 106 |
+
|
| 107 |
+
# === 数据 ===
|
| 108 |
+
data:
|
| 109 |
+
root: "./data/cosmos"
|
| 110 |
+
hdmap_subdir: "rds_hq"
|
| 111 |
+
synthetic_subdir: "cosmos_synthetic/single_view"
|
| 112 |
+
use_synthetic: true
|
| 113 |
+
use_real: false
|
| 114 |
+
weather:
|
| 115 |
+
- Sunny
|
| 116 |
+
- Morning
|
| 117 |
+
- Golden_hour
|
| 118 |
+
- Night
|
| 119 |
+
- Rainy
|
| 120 |
+
- Snowy
|
| 121 |
+
- Foggy
|
| 122 |
+
num_workers: 4
|
| 123 |
+
pin_memory: true
|
| 124 |
+
prefetch_factor: 2
|
| 125 |
+
augmentation:
|
| 126 |
+
gaussian_noise_std: 0.01
|
| 127 |
+
color_jitter: 0.1
|
| 128 |
+
perturb_translation_std_m: 0.1 # Stage1 中期开启
|
| 129 |
+
perturb_rotation_std_deg: 0.5
|
| 130 |
+
perturb_intrinsic_std: 0.005
|
| 131 |
+
perturb_extrinsic_std: 0.005
|
| 132 |
+
|
| 133 |
+
# === 损失权重(GradNorm 自适应的 1-6 任务初值)===
|
| 134 |
+
loss:
|
| 135 |
+
cls_weight: 1.0
|
| 136 |
+
box_weight: 1.0
|
| 137 |
+
isdyn_weight: 1.0
|
| 138 |
+
traj_obj_weight: 1.0
|
| 139 |
+
traj_ego_weight: 1.0
|
| 140 |
+
ctrl_weight: 1.0
|
| 141 |
+
moe_weight: 1.0 # 固定权重正则
|
| 142 |
+
calib_weight: 0.1 # 固定权重正则
|
| 143 |
+
giou_weight: 0.5 # 在 box 内组合
|
| 144 |
+
focal_alpha: 0.25
|
| 145 |
+
focal_gamma: 2.0
|
| 146 |
+
matcher_cls_cost: 2.0
|
| 147 |
+
matcher_l1_cost: 5.0
|
| 148 |
+
matcher_giou_cost: 2.0
|
| 149 |
+
|
| 150 |
+
# === 多任务训练 ===
|
| 151 |
+
# PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用:
|
| 152 |
+
# 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl)
|
| 153 |
+
# 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。
|
| 154 |
+
multitask:
|
| 155 |
+
enable_gradnorm: true
|
| 156 |
+
enable_pcgrad: true
|
| 157 |
+
gradnorm_alpha: 1.5
|
| 158 |
+
gradnorm_lr: 0.025
|
| 159 |
+
pcgrad_shuffle: true
|
| 160 |
+
|
| 161 |
+
# === 训练 ===
|
| 162 |
+
train:
|
| 163 |
+
batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps
|
| 164 |
+
grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps
|
| 165 |
+
ckpt_dir: outputs/checkpoints
|
| 166 |
+
total_steps: 100000
|
| 167 |
+
warmup_steps: 1000
|
| 168 |
+
base_lr: 2.0e-4
|
| 169 |
+
min_lr: 1.0e-6
|
| 170 |
+
weight_decay: 0.05
|
| 171 |
+
optimizer: adamw
|
| 172 |
+
betas: [0.9, 0.95]
|
| 173 |
+
grad_clip: 1.0
|
| 174 |
+
log_interval: 20
|
| 175 |
+
ckpt_interval: 1000
|
| 176 |
+
eval_interval: 5000
|
| 177 |
+
stage1_steps: 60000 # Stage1 步数
|
| 178 |
+
stage1_perturb_start: 20000 # 中期开始扰动
|
| 179 |
+
grad_monitor_threshold: 1.0e-7
|
| 180 |
+
param_groups:
|
| 181 |
+
dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供
|
| 182 |
+
backbone_lr_mult: 1.0
|
| 183 |
+
calibration_lr_mult: 0.1
|
| 184 |
+
head_lr_mult: 1.0
|
| 185 |
+
gate_lr_mult: 0.1 # 门控参数低 LR
|
| 186 |
+
|
| 187 |
+
# === 部署 ===
|
| 188 |
+
deploy:
|
| 189 |
+
hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库
|
| 190 |
+
push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt
|
| 191 |
+
hub_ckpt_prefix: checkpoints # Hub 上子目录
|
| 192 |
+
hf_sandbox_space: "fuzirui/wjad-sandbox"
|
pyproject.toml
CHANGED
|
@@ -1,40 +1,40 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["setuptools>=68", "wheel"]
|
| 3 |
-
build-backend = "setuptools.build_meta"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "wjad"
|
| 7 |
-
version = "0.1.0"
|
| 8 |
-
description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration."
|
| 9 |
-
requires-python = ">=3.10"
|
| 10 |
-
dependencies = [
|
| 11 |
-
"torch>=2.4",
|
| 12 |
-
"transformers>=4.56",
|
| 13 |
-
"safetensors>=0.4",
|
| 14 |
-
"numpy>=1.24",
|
| 15 |
-
"opencv-python-headless>=4.8",
|
| 16 |
-
"einops>=0.7",
|
| 17 |
-
"scipy>=1.11",
|
| 18 |
-
"pyyaml>=6.0",
|
| 19 |
-
"tqdm>=4.66",
|
| 20 |
-
"huggingface_hub>=0.24",
|
| 21 |
-
"pillow>=10.0",
|
| 22 |
-
"av>=12.0",
|
| 23 |
-
]
|
| 24 |
-
|
| 25 |
-
[project.optional-dependencies]
|
| 26 |
-
dev = [
|
| 27 |
-
"pytest>=7",
|
| 28 |
-
"pytest-cov>=4",
|
| 29 |
-
"ruff>=0.5",
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
[tool.setuptools]
|
| 33 |
-
package-dir = {"" = "src"}
|
| 34 |
-
|
| 35 |
-
[tool.setuptools.packages.find]
|
| 36 |
-
where = ["src"]
|
| 37 |
-
|
| 38 |
-
[tool.ruff]
|
| 39 |
-
line-length = 120
|
| 40 |
-
target-version = "py310"
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "wjad"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration."
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"torch>=2.4",
|
| 12 |
+
"transformers>=4.56",
|
| 13 |
+
"safetensors>=0.4",
|
| 14 |
+
"numpy>=1.24",
|
| 15 |
+
"opencv-python-headless>=4.8",
|
| 16 |
+
"einops>=0.7",
|
| 17 |
+
"scipy>=1.11",
|
| 18 |
+
"pyyaml>=6.0",
|
| 19 |
+
"tqdm>=4.66",
|
| 20 |
+
"huggingface_hub>=0.24",
|
| 21 |
+
"pillow>=10.0",
|
| 22 |
+
"av>=12.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
dev = [
|
| 27 |
+
"pytest>=7",
|
| 28 |
+
"pytest-cov>=4",
|
| 29 |
+
"ruff>=0.5",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
[tool.setuptools]
|
| 33 |
+
package-dir = {"" = "src"}
|
| 34 |
+
|
| 35 |
+
[tool.setuptools.packages.find]
|
| 36 |
+
where = ["src"]
|
| 37 |
+
|
| 38 |
+
[tool.ruff]
|
| 39 |
+
line-length = 120
|
| 40 |
+
target-version = "py310"
|
scripts/download_data.py
CHANGED
|
@@ -1,76 +1,76 @@
|
|
| 1 |
-
"""下载 NVIDIA Cosmos-Drive-Dreams 数据集。
|
| 2 |
-
|
| 3 |
-
直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。
|
| 4 |
-
|
| 5 |
-
用法:
|
| 6 |
-
# 完整下载(synthetic + lidar + hdmap),约 3TB
|
| 7 |
-
python scripts/download_data.py --odir ./data/cosmos --workers 8
|
| 8 |
-
|
| 9 |
-
# 仅烟囱:限制 clip 数量(约几 GB,取决于 N)
|
| 10 |
-
python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import argparse
|
| 16 |
-
import subprocess
|
| 17 |
-
import sys
|
| 18 |
-
import urllib.request
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
NV_DOWNLOAD_URL = (
|
| 23 |
-
"https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _ensure_official_script(local_path: Path) -> None:
|
| 28 |
-
if local_path.exists():
|
| 29 |
-
return
|
| 30 |
-
print(f"[download_data] 下载 NVIDIA download.py -> {local_path}")
|
| 31 |
-
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 32 |
-
with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f:
|
| 33 |
-
f.write(resp.read())
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def main() -> None:
|
| 37 |
-
parser = argparse.ArgumentParser()
|
| 38 |
-
parser.add_argument("--odir", required=True, help="数据输出目录")
|
| 39 |
-
parser.add_argument(
|
| 40 |
-
"--file_types",
|
| 41 |
-
default="synthetic,lidar,hdmap",
|
| 42 |
-
help="数据类型逗号分隔列表",
|
| 43 |
-
)
|
| 44 |
-
parser.add_argument("--workers", type=int, default=4)
|
| 45 |
-
parser.add_argument(
|
| 46 |
-
"--limit",
|
| 47 |
-
type=int,
|
| 48 |
-
default=None,
|
| 49 |
-
metavar="N",
|
| 50 |
-
help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)",
|
| 51 |
-
)
|
| 52 |
-
parser.add_argument("--clean_cache", action="store_true")
|
| 53 |
-
args = parser.parse_args()
|
| 54 |
-
|
| 55 |
-
odir = Path(args.odir)
|
| 56 |
-
nv_script = odir / ".nvidia_download.py"
|
| 57 |
-
_ensure_official_script(nv_script)
|
| 58 |
-
|
| 59 |
-
cmd = [
|
| 60 |
-
sys.executable,
|
| 61 |
-
str(nv_script),
|
| 62 |
-
"--odir", str(odir),
|
| 63 |
-
"--file_types", args.file_types,
|
| 64 |
-
"--workers", str(args.workers),
|
| 65 |
-
]
|
| 66 |
-
if args.limit is not None:
|
| 67 |
-
cmd.extend(["--limit", str(args.limit)])
|
| 68 |
-
if args.clean_cache:
|
| 69 |
-
cmd.append("--clean_cache")
|
| 70 |
-
print(f"[download_data] $ {' '.join(cmd)}")
|
| 71 |
-
rc = subprocess.call(cmd)
|
| 72 |
-
sys.exit(rc)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
if __name__ == "__main__":
|
| 76 |
-
main()
|
|
|
|
| 1 |
+
"""下载 NVIDIA Cosmos-Drive-Dreams 数据集。
|
| 2 |
+
|
| 3 |
+
直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。
|
| 4 |
+
|
| 5 |
+
用法:
|
| 6 |
+
# 完整下载(synthetic + lidar + hdmap),约 3TB
|
| 7 |
+
python scripts/download_data.py --odir ./data/cosmos --workers 8
|
| 8 |
+
|
| 9 |
+
# 仅烟囱:限制 clip 数量(约几 GB,取决于 N)
|
| 10 |
+
python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
import urllib.request
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
NV_DOWNLOAD_URL = (
|
| 23 |
+
"https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _ensure_official_script(local_path: Path) -> None:
|
| 28 |
+
if local_path.exists():
|
| 29 |
+
return
|
| 30 |
+
print(f"[download_data] 下载 NVIDIA download.py -> {local_path}")
|
| 31 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f:
|
| 33 |
+
f.write(resp.read())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
|
| 37 |
+
parser = argparse.ArgumentParser()
|
| 38 |
+
parser.add_argument("--odir", required=True, help="数据输出目录")
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--file_types",
|
| 41 |
+
default="synthetic,lidar,hdmap",
|
| 42 |
+
help="数据类型逗号分隔列表",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument("--workers", type=int, default=4)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--limit",
|
| 47 |
+
type=int,
|
| 48 |
+
default=None,
|
| 49 |
+
metavar="N",
|
| 50 |
+
help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument("--clean_cache", action="store_true")
|
| 53 |
+
args = parser.parse_args()
|
| 54 |
+
|
| 55 |
+
odir = Path(args.odir)
|
| 56 |
+
nv_script = odir / ".nvidia_download.py"
|
| 57 |
+
_ensure_official_script(nv_script)
|
| 58 |
+
|
| 59 |
+
cmd = [
|
| 60 |
+
sys.executable,
|
| 61 |
+
str(nv_script),
|
| 62 |
+
"--odir", str(odir),
|
| 63 |
+
"--file_types", args.file_types,
|
| 64 |
+
"--workers", str(args.workers),
|
| 65 |
+
]
|
| 66 |
+
if args.limit is not None:
|
| 67 |
+
cmd.extend(["--limit", str(args.limit)])
|
| 68 |
+
if args.clean_cache:
|
| 69 |
+
cmd.append("--clean_cache")
|
| 70 |
+
print(f"[download_data] $ {' '.join(cmd)}")
|
| 71 |
+
rc = subprocess.call(cmd)
|
| 72 |
+
sys.exit(rc)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
main()
|
scripts/estimate_memory.py
CHANGED
|
@@ -1,203 +1,203 @@
|
|
| 1 |
-
"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
|
| 2 |
-
|
| 3 |
-
输出
|
| 4 |
-
- 各模块参数数量
|
| 5 |
-
- 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
|
| 6 |
-
- 推荐设备(HF Sandbox / Jobs)
|
| 7 |
-
- 主机内存与磁盘开销
|
| 8 |
-
|
| 9 |
-
公式说明(粗略上界)
|
| 10 |
-
- 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
|
| 11 |
-
- AdamW 一阶/二阶矩 (fp32): 8 B/p
|
| 12 |
-
- 梯度 (fp32): 4 B/p
|
| 13 |
-
- bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
|
| 14 |
-
- DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
|
| 15 |
-
- 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
|
| 16 |
-
SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
|
| 17 |
-
需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
|
| 18 |
-
- PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
|
| 19 |
-
每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
|
| 20 |
-
(GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
from __future__ import annotations
|
| 24 |
-
|
| 25 |
-
import sys
|
| 26 |
-
from pathlib import Path
|
| 27 |
-
|
| 28 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 29 |
-
sys.path.insert(0, str(ROOT / "src"))
|
| 30 |
-
|
| 31 |
-
from dataclasses import dataclass
|
| 32 |
-
|
| 33 |
-
from wjad.model import E2EAVModel
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
@dataclass
|
| 37 |
-
class MemoryReport:
|
| 38 |
-
bs: int
|
| 39 |
-
seq_len: int
|
| 40 |
-
dim: int
|
| 41 |
-
layers: int
|
| 42 |
-
params_total: int
|
| 43 |
-
params_trainable_stage1: int
|
| 44 |
-
params_trainable_stage2: int
|
| 45 |
-
weights_gb_stage1: float
|
| 46 |
-
weights_gb_stage2: float
|
| 47 |
-
optim_gb_stage1: float
|
| 48 |
-
optim_gb_stage2: float
|
| 49 |
-
activations_gb: float
|
| 50 |
-
pcgrad_overhead_gb: float
|
| 51 |
-
total_stage1_gb: float
|
| 52 |
-
total_stage2_gb: float
|
| 53 |
-
host_ram_gb: float
|
| 54 |
-
disk_gb: float
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def count_params(model) -> tuple[int, dict[str, int]]:
|
| 58 |
-
total = 0
|
| 59 |
-
by_module: dict[str, int] = {}
|
| 60 |
-
for name, child in model.named_children():
|
| 61 |
-
n = sum(p.numel() for p in child.parameters())
|
| 62 |
-
by_module[name] = n
|
| 63 |
-
total += n
|
| 64 |
-
return total, by_module
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def estimate(bs: int = 8) -> MemoryReport:
|
| 68 |
-
model = E2EAVModel(
|
| 69 |
-
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 70 |
-
# 完整规模
|
| 71 |
-
backbone_dim=768,
|
| 72 |
-
num_heads=12,
|
| 73 |
-
num_dense_layers=9,
|
| 74 |
-
num_moe_layers=9,
|
| 75 |
-
num_routed_experts=7,
|
| 76 |
-
num_shared_experts=1,
|
| 77 |
-
topk_experts=3,
|
| 78 |
-
ffn_mult=4,
|
| 79 |
-
num_history_frames=8,
|
| 80 |
-
num_detection_tokens=1024,
|
| 81 |
-
num_control_tokens=24,
|
| 82 |
-
num_ego_tokens=8,
|
| 83 |
-
num_extra_tokens=256,
|
| 84 |
-
image_h=384,
|
| 85 |
-
image_w=1024,
|
| 86 |
-
patch_size=16,
|
| 87 |
-
num_classes=22,
|
| 88 |
-
traj_horizon=24,
|
| 89 |
-
freeze_dinov3=True,
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
total, by_module = count_params(model)
|
| 93 |
-
dinov3_n = by_module.get("dinov3", 0)
|
| 94 |
-
trainable_stage1 = total - dinov3_n
|
| 95 |
-
trainable_stage2 = total
|
| 96 |
-
|
| 97 |
-
# 序列长度(拼接后总 token 数 + 上下文)
|
| 98 |
-
n_visual = (8 // 2) * (24 // 2) * (64 // 2)
|
| 99 |
-
seq_len = n_visual + 8 + 1024 + 24 + 256
|
| 100 |
-
|
| 101 |
-
# === 显存 ===
|
| 102 |
-
# 单位:GB(除以 1024**3)
|
| 103 |
-
GB = 1024 ** 3
|
| 104 |
-
weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16
|
| 105 |
-
weights_stage2 = (total * 2) / GB
|
| 106 |
-
optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v
|
| 107 |
-
optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB
|
| 108 |
-
|
| 109 |
-
# 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠)
|
| 110 |
-
base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6
|
| 111 |
-
# MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家)
|
| 112 |
-
moe_act = bs * seq_len * 3072 * 2 * 8 * 9
|
| 113 |
-
# DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值
|
| 114 |
-
dino_act = 2.0 * GB
|
| 115 |
-
activations_gb = (base_act + moe_act + dino_act) / GB
|
| 116 |
-
|
| 117 |
-
# PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会
|
| 118 |
-
# 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算
|
| 119 |
-
pcgrad_overhead_gb = 0.5 * activations_gb
|
| 120 |
-
|
| 121 |
-
total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
|
| 122 |
-
total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0
|
| 123 |
-
|
| 124 |
-
# === 主机 RAM ===
|
| 125 |
-
# DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析
|
| 126 |
-
host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2
|
| 127 |
-
|
| 128 |
-
# === 磁盘 ===
|
| 129 |
-
# 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部)
|
| 130 |
-
disk = 50.0
|
| 131 |
-
|
| 132 |
-
return MemoryReport(
|
| 133 |
-
bs=bs,
|
| 134 |
-
seq_len=seq_len,
|
| 135 |
-
dim=768,
|
| 136 |
-
layers=18,
|
| 137 |
-
params_total=total,
|
| 138 |
-
params_trainable_stage1=trainable_stage1,
|
| 139 |
-
params_trainable_stage2=trainable_stage2,
|
| 140 |
-
weights_gb_stage1=weights_stage1,
|
| 141 |
-
weights_gb_stage2=weights_stage2,
|
| 142 |
-
optim_gb_stage1=optim_stage1,
|
| 143 |
-
optim_gb_stage2=optim_stage2,
|
| 144 |
-
activations_gb=activations_gb,
|
| 145 |
-
pcgrad_overhead_gb=pcgrad_overhead_gb,
|
| 146 |
-
total_stage1_gb=total_stage1,
|
| 147 |
-
total_stage2_gb=total_stage2,
|
| 148 |
-
host_ram_gb=host_ram,
|
| 149 |
-
disk_gb=disk,
|
| 150 |
-
)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def recommend_device(stage_max_gb: float) -> tuple[str, str]:
|
| 154 |
-
"""根据 Stage2 峰值显存推荐 GPU。"""
|
| 155 |
-
margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace)
|
| 156 |
-
need = stage_max_gb * margin
|
| 157 |
-
candidates = [
|
| 158 |
-
("T4 16GB", 16),
|
| 159 |
-
("L4 24GB", 24),
|
| 160 |
-
("A10G 24GB", 24),
|
| 161 |
-
("A10G Large 48GB", 48),
|
| 162 |
-
("A100 40GB", 40),
|
| 163 |
-
("L40S 48GB", 48),
|
| 164 |
-
("A100 80GB", 80),
|
| 165 |
-
("H100 80GB", 80),
|
| 166 |
-
]
|
| 167 |
-
fit = [c for c in candidates if c[1] >= need]
|
| 168 |
-
if not fit:
|
| 169 |
-
return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)"
|
| 170 |
-
return fit[0][0], f"需要 ≥{need:.1f} GB"
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def main() -> None:
|
| 174 |
-
print("=" * 72)
|
| 175 |
-
print(" WJAD 训练显存/内存估算 (bf16 AMP)")
|
| 176 |
-
print("=" * 72)
|
| 177 |
-
for bs in (1, 2, 4, 8, 16):
|
| 178 |
-
r = estimate(bs)
|
| 179 |
-
print(f"\n--- BS = {bs} ---")
|
| 180 |
-
print(f" 总参数 : {r.params_total / 1e6:8.2f} M")
|
| 181 |
-
print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M")
|
| 182 |
-
print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M")
|
| 183 |
-
print(f" 序列长度 : {r.seq_len}")
|
| 184 |
-
print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
|
| 185 |
-
print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
|
| 186 |
-
print(f" 激活 : {r.activations_gb:6.2f} GB")
|
| 187 |
-
print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB")
|
| 188 |
-
print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB")
|
| 189 |
-
print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值")
|
| 190 |
-
gpu, note = recommend_device(r.total_stage2_gb)
|
| 191 |
-
print(f" 推荐 GPU : {gpu} ({note})")
|
| 192 |
-
print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB")
|
| 193 |
-
print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB")
|
| 194 |
-
|
| 195 |
-
print()
|
| 196 |
-
print("说明:")
|
| 197 |
-
print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
|
| 198 |
-
print(" - 开 ``gradient_checkpointing`` 可把激活降
|
| 199 |
-
print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
if __name__ == "__main__":
|
| 203 |
-
main()
|
|
|
|
| 1 |
+
"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
|
| 2 |
+
|
| 3 |
+
输出
|
| 4 |
+
- 各模块参数数量
|
| 5 |
+
- 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
|
| 6 |
+
- 推荐设备(HF Sandbox / Jobs)
|
| 7 |
+
- 主机内存与磁盘开销
|
| 8 |
+
|
| 9 |
+
公式说明(粗略上界)
|
| 10 |
+
- 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
|
| 11 |
+
- AdamW 一阶/二阶矩 (fp32): 8 B/p
|
| 12 |
+
- 梯度 (fp32): 4 B/p
|
| 13 |
+
- bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
|
| 14 |
+
- DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
|
| 15 |
+
- 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
|
| 16 |
+
SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
|
| 17 |
+
需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
|
| 18 |
+
- PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
|
| 19 |
+
每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
|
| 20 |
+
(GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 29 |
+
sys.path.insert(0, str(ROOT / "src"))
|
| 30 |
+
|
| 31 |
+
from dataclasses import dataclass
|
| 32 |
+
|
| 33 |
+
from wjad.model import E2EAVModel
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class MemoryReport:
|
| 38 |
+
bs: int
|
| 39 |
+
seq_len: int
|
| 40 |
+
dim: int
|
| 41 |
+
layers: int
|
| 42 |
+
params_total: int
|
| 43 |
+
params_trainable_stage1: int
|
| 44 |
+
params_trainable_stage2: int
|
| 45 |
+
weights_gb_stage1: float
|
| 46 |
+
weights_gb_stage2: float
|
| 47 |
+
optim_gb_stage1: float
|
| 48 |
+
optim_gb_stage2: float
|
| 49 |
+
activations_gb: float
|
| 50 |
+
pcgrad_overhead_gb: float
|
| 51 |
+
total_stage1_gb: float
|
| 52 |
+
total_stage2_gb: float
|
| 53 |
+
host_ram_gb: float
|
| 54 |
+
disk_gb: float
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def count_params(model) -> tuple[int, dict[str, int]]:
|
| 58 |
+
total = 0
|
| 59 |
+
by_module: dict[str, int] = {}
|
| 60 |
+
for name, child in model.named_children():
|
| 61 |
+
n = sum(p.numel() for p in child.parameters())
|
| 62 |
+
by_module[name] = n
|
| 63 |
+
total += n
|
| 64 |
+
return total, by_module
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def estimate(bs: int = 8) -> MemoryReport:
|
| 68 |
+
model = E2EAVModel(
|
| 69 |
+
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 70 |
+
# 完整规模
|
| 71 |
+
backbone_dim=768,
|
| 72 |
+
num_heads=12,
|
| 73 |
+
num_dense_layers=9,
|
| 74 |
+
num_moe_layers=9,
|
| 75 |
+
num_routed_experts=7,
|
| 76 |
+
num_shared_experts=1,
|
| 77 |
+
topk_experts=3,
|
| 78 |
+
ffn_mult=4,
|
| 79 |
+
num_history_frames=8,
|
| 80 |
+
num_detection_tokens=1024,
|
| 81 |
+
num_control_tokens=24,
|
| 82 |
+
num_ego_tokens=8,
|
| 83 |
+
num_extra_tokens=256,
|
| 84 |
+
image_h=384,
|
| 85 |
+
image_w=1024,
|
| 86 |
+
patch_size=16,
|
| 87 |
+
num_classes=22,
|
| 88 |
+
traj_horizon=24,
|
| 89 |
+
freeze_dinov3=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
total, by_module = count_params(model)
|
| 93 |
+
dinov3_n = by_module.get("dinov3", 0)
|
| 94 |
+
trainable_stage1 = total - dinov3_n
|
| 95 |
+
trainable_stage2 = total
|
| 96 |
+
|
| 97 |
+
# 序列长度(拼接后总 token 数 + 上下文)
|
| 98 |
+
n_visual = (8 // 2) * (24 // 2) * (64 // 2)
|
| 99 |
+
seq_len = n_visual + 8 + 1024 + 24 + 256
|
| 100 |
+
|
| 101 |
+
# === 显存 ===
|
| 102 |
+
# 单位:GB(除以 1024**3)
|
| 103 |
+
GB = 1024 ** 3
|
| 104 |
+
weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16
|
| 105 |
+
weights_stage2 = (total * 2) / GB
|
| 106 |
+
optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v
|
| 107 |
+
optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB
|
| 108 |
+
|
| 109 |
+
# 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠)
|
| 110 |
+
base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6
|
| 111 |
+
# MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家)
|
| 112 |
+
moe_act = bs * seq_len * 3072 * 2 * 8 * 9
|
| 113 |
+
# DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值
|
| 114 |
+
dino_act = 2.0 * GB
|
| 115 |
+
activations_gb = (base_act + moe_act + dino_act) / GB
|
| 116 |
+
|
| 117 |
+
# PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会
|
| 118 |
+
# 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算
|
| 119 |
+
pcgrad_overhead_gb = 0.5 * activations_gb
|
| 120 |
+
|
| 121 |
+
total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
|
| 122 |
+
total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0
|
| 123 |
+
|
| 124 |
+
# === 主机 RAM ===
|
| 125 |
+
# DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析
|
| 126 |
+
host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2
|
| 127 |
+
|
| 128 |
+
# === 磁盘 ===
|
| 129 |
+
# 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部)
|
| 130 |
+
disk = 50.0
|
| 131 |
+
|
| 132 |
+
return MemoryReport(
|
| 133 |
+
bs=bs,
|
| 134 |
+
seq_len=seq_len,
|
| 135 |
+
dim=768,
|
| 136 |
+
layers=18,
|
| 137 |
+
params_total=total,
|
| 138 |
+
params_trainable_stage1=trainable_stage1,
|
| 139 |
+
params_trainable_stage2=trainable_stage2,
|
| 140 |
+
weights_gb_stage1=weights_stage1,
|
| 141 |
+
weights_gb_stage2=weights_stage2,
|
| 142 |
+
optim_gb_stage1=optim_stage1,
|
| 143 |
+
optim_gb_stage2=optim_stage2,
|
| 144 |
+
activations_gb=activations_gb,
|
| 145 |
+
pcgrad_overhead_gb=pcgrad_overhead_gb,
|
| 146 |
+
total_stage1_gb=total_stage1,
|
| 147 |
+
total_stage2_gb=total_stage2,
|
| 148 |
+
host_ram_gb=host_ram,
|
| 149 |
+
disk_gb=disk,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def recommend_device(stage_max_gb: float) -> tuple[str, str]:
|
| 154 |
+
"""根据 Stage2 峰值显存推荐 GPU。"""
|
| 155 |
+
margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace)
|
| 156 |
+
need = stage_max_gb * margin
|
| 157 |
+
candidates = [
|
| 158 |
+
("T4 16GB", 16),
|
| 159 |
+
("L4 24GB", 24),
|
| 160 |
+
("A10G 24GB", 24),
|
| 161 |
+
("A10G Large 48GB", 48),
|
| 162 |
+
("A100 40GB", 40),
|
| 163 |
+
("L40S 48GB", 48),
|
| 164 |
+
("A100 80GB", 80),
|
| 165 |
+
("H100 80GB", 80),
|
| 166 |
+
]
|
| 167 |
+
fit = [c for c in candidates if c[1] >= need]
|
| 168 |
+
if not fit:
|
| 169 |
+
return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)"
|
| 170 |
+
return fit[0][0], f"需要 ≥{need:.1f} GB"
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def main() -> None:
|
| 174 |
+
print("=" * 72)
|
| 175 |
+
print(" WJAD 训练显存/内存估算 (bf16 AMP)")
|
| 176 |
+
print("=" * 72)
|
| 177 |
+
for bs in (1, 2, 4, 8, 16):
|
| 178 |
+
r = estimate(bs)
|
| 179 |
+
print(f"\n--- BS = {bs} ---")
|
| 180 |
+
print(f" 总参数 : {r.params_total / 1e6:8.2f} M")
|
| 181 |
+
print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M")
|
| 182 |
+
print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M")
|
| 183 |
+
print(f" 序列长度 : {r.seq_len}")
|
| 184 |
+
print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
|
| 185 |
+
print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
|
| 186 |
+
print(f" 激活 : {r.activations_gb:6.2f} GB")
|
| 187 |
+
print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB")
|
| 188 |
+
print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB")
|
| 189 |
+
print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值")
|
| 190 |
+
gpu, note = recommend_device(r.total_stage2_gb)
|
| 191 |
+
print(f" 推荐 GPU : {gpu} ({note})")
|
| 192 |
+
print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB")
|
| 193 |
+
print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB")
|
| 194 |
+
|
| 195 |
+
print()
|
| 196 |
+
print("说明:")
|
| 197 |
+
print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
|
| 198 |
+
print(" - 开 ``gradient_checkpointing`` 可把激活降��约 1/3,BS 可成倍提升。")
|
| 199 |
+
print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
main()
|
scripts/ingest_hub_to_bucket.py
CHANGED
|
@@ -1,207 +1,234 @@
|
|
| 1 |
-
"""将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
--bucket
|
| 18 |
-
--
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
if
|
| 76 |
-
continue
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
"--
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
"
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
"
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
"
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。
|
| 2 |
+
|
| 3 |
+
HF Jobs 容器 **根分区**(ephemeral)常有 **~50GiB** 上限。``copy_files`` 在 Hub 侧完成,几乎不占根盘;
|
| 4 |
+
解压、``pip``、Hub 客户端默认缓存会写 ``/tmp``、``~/.cache``,易触发 eviction。
|
| 5 |
+
若设 ``--bucket-mount``,会把 ``TMPDIR``、``HF_HOME`` 等重定向到挂载点下 ``.wjad_ephemeral/``,
|
| 6 |
+
大块临时数据落在 **Bucket**。训练时用 Volume 挂载 Bucket,``WJAD_DATA_ROOT`` 指到 mirror 或解压树即可,
|
| 7 |
+
无需把整库先下载到根盘。
|
| 8 |
+
|
| 9 |
+
解压输出默认写入 **另一棵目录树**(与 ``--dest-prefix`` 平级的 ``{dest-prefix}_unpacked/``),
|
| 10 |
+
相对路径与镜像里的 ``.tar`` 一致,避免在源树旁叠 ``*_extracted/`` 导致 ``rglob`` 反复扫到嵌套 tar。
|
| 11 |
+
|
| 12 |
+
示例(本地或 Job 内,且已挂载 bucket 到 ``/mnt/cosmos``)::
|
| 13 |
+
|
| 14 |
+
python scripts/ingest_hub_to_bucket.py \\
|
| 15 |
+
--bucket fuzirui/my-cosmos-bucket \\
|
| 16 |
+
--dest-prefix cosmos_hub_mirror \\
|
| 17 |
+
--bucket-mount /mnt/cosmos \\
|
| 18 |
+
--extract-tars
|
| 19 |
+
|
| 20 |
+
仅拷贝、不解压::
|
| 21 |
+
|
| 22 |
+
python scripts/ingest_hub_to_bucket.py \\
|
| 23 |
+
--bucket fuzirui/my-cosmos-bucket \\
|
| 24 |
+
--source 'hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/' \\
|
| 25 |
+
--dest-prefix raw \\
|
| 26 |
+
--copy-only
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import argparse
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
import tarfile
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
|
| 37 |
+
from huggingface_hub import HfApi, create_bucket
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _ensure_trailing_slash_hf_url(url: str) -> str:
|
| 41 |
+
s = url.strip()
|
| 42 |
+
if s.endswith("/"):
|
| 43 |
+
return s
|
| 44 |
+
return s + "/"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _archive_stem(path: Path) -> str:
|
| 48 |
+
"""``foo.tar.gz`` -> ``foo``;``bar.tar`` -> ``bar``。"""
|
| 49 |
+
n = path.name
|
| 50 |
+
for ext in (".tar.gz", ".tar.xz", ".tgz", ".tar"):
|
| 51 |
+
if n.endswith(ext):
|
| 52 |
+
return n[: -len(ext)]
|
| 53 |
+
return path.stem
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _is_under_path(path: Path, parent: Path) -> bool:
|
| 57 |
+
try:
|
| 58 |
+
path.resolve().relative_to(parent.resolve())
|
| 59 |
+
return True
|
| 60 |
+
except ValueError:
|
| 61 |
+
return False
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _collect_archives(
|
| 65 |
+
root: Path,
|
| 66 |
+
patterns: tuple[str, ...],
|
| 67 |
+
*,
|
| 68 |
+
exclude_under: Path | None = None,
|
| 69 |
+
) -> list[Path]:
|
| 70 |
+
"""收集待解压归档,排除历史 ``*_extracted`` 目录及解压输出树,避免嵌套/重复扫描。"""
|
| 71 |
+
out: list[Path] = []
|
| 72 |
+
seen: set[Path] = set()
|
| 73 |
+
for pat in patterns:
|
| 74 |
+
for p in root.rglob(pat):
|
| 75 |
+
if not p.is_file():
|
| 76 |
+
continue
|
| 77 |
+
rp = p.resolve()
|
| 78 |
+
if rp in seen:
|
| 79 |
+
continue
|
| 80 |
+
if any(part.endswith("_extracted") or part == "_extracted" for part in p.parts):
|
| 81 |
+
continue
|
| 82 |
+
if exclude_under is not None and _is_under_path(p, exclude_under):
|
| 83 |
+
continue
|
| 84 |
+
seen.add(rp)
|
| 85 |
+
out.append(p)
|
| 86 |
+
return sorted(out)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _redirect_ephemeral_to_bucket(bucket_mount: Path) -> None:
|
| 90 |
+
"""把临时文件与 HF 缓存写到 Bucket 挂载点,避免撑爆 Job 50G 根分区。"""
|
| 91 |
+
base = bucket_mount / ".wjad_ephemeral"
|
| 92 |
+
tmp = base / "tmp"
|
| 93 |
+
hf_home = base / "hf_home"
|
| 94 |
+
xdg = base / "xdg_cache"
|
| 95 |
+
for d in (tmp, hf_home, hf_home / "hub", xdg):
|
| 96 |
+
d.mkdir(parents=True, exist_ok=True)
|
| 97 |
+
os.environ["TMPDIR"] = str(tmp)
|
| 98 |
+
os.environ["TMP"] = str(tmp)
|
| 99 |
+
os.environ["TEMP"] = str(tmp)
|
| 100 |
+
os.environ["HF_HOME"] = str(hf_home)
|
| 101 |
+
os.environ["HF_HUB_CACHE"] = str(hf_home / "hub")
|
| 102 |
+
os.environ["XDG_CACHE_HOME"] = str(xdg)
|
| 103 |
+
print(f"[ingest] 临时/缓存 -> {base}", flush=True)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main() -> None:
|
| 107 |
+
parser = argparse.ArgumentParser(description="Hub copy_files → Bucket,可选按镜像目录解压 tar")
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--source",
|
| 110 |
+
default="hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/",
|
| 111 |
+
help="hf:// 源(仓库或 bucket 前缀),目录建议以 / 结尾",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--bucket",
|
| 115 |
+
required=True,
|
| 116 |
+
help='目标 bucket id,如 "user/my-bucket"(不要写 hf://)',
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--dest-prefix",
|
| 120 |
+
default="cosmos_hub_mirror",
|
| 121 |
+
help="copy_files 写入 bucket 内的子路径(不要用前导 /)",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--ensure-bucket",
|
| 125 |
+
action="store_true",
|
| 126 |
+
help="若不存在则 create_bucket(..., exist_ok=True)",
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--copy-only",
|
| 130 |
+
action="store_true",
|
| 131 |
+
help="只做 copy_files,不解压",
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--bucket-mount",
|
| 135 |
+
default=None,
|
| 136 |
+
help="Job 内 bucket 挂载点(如 /mnt/cosmos);若设 --extract-tars 则必填",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--extract-tars",
|
| 140 |
+
action="store_true",
|
| 141 |
+
help="解压 mirror 树下的 tar;输出见 --extract-out-prefix",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--extract-out-prefix",
|
| 145 |
+
default=None,
|
| 146 |
+
metavar="NAME",
|
| 147 |
+
help="解压根目录(bucket 内相对路径,与 dest-prefix 平级)。默认 {dest-prefix}_unpacked",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--extract-beside-tar",
|
| 151 |
+
action="store_true",
|
| 152 |
+
help="旧行为:在每条 tar 旁建 ``{name}_extracted``(易与 rglob 嵌套 tar 纠缠,一般不推荐)",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--max-tars",
|
| 156 |
+
type=int,
|
| 157 |
+
default=None,
|
| 158 |
+
help="最多处理多少个 tar(烟囱/限流)",
|
| 159 |
+
)
|
| 160 |
+
args = parser.parse_args()
|
| 161 |
+
|
| 162 |
+
if args.bucket_mount:
|
| 163 |
+
_redirect_ephemeral_to_bucket(Path(args.bucket_mount))
|
| 164 |
+
|
| 165 |
+
src = _ensure_trailing_slash_hf_url(args.source)
|
| 166 |
+
dest_prefix = args.dest_prefix.strip().strip("/")
|
| 167 |
+
dest = f"hf://buckets/{args.bucket}/{dest_prefix}/"
|
| 168 |
+
|
| 169 |
+
api = HfApi()
|
| 170 |
+
if args.ensure_bucket:
|
| 171 |
+
create_bucket(args.bucket, exist_ok=True)
|
| 172 |
+
print(f"[ingest] bucket ready: {args.bucket}", flush=True)
|
| 173 |
+
|
| 174 |
+
print(f"[ingest] copy_files\n {src}\n -> {dest}", flush=True)
|
| 175 |
+
api.copy_files(src, dest)
|
| 176 |
+
print("[ingest] copy_files 完成", flush=True)
|
| 177 |
+
|
| 178 |
+
if args.copy_only or not args.extract_tars:
|
| 179 |
+
return
|
| 180 |
+
|
| 181 |
+
if not args.bucket_mount:
|
| 182 |
+
print("[ingest] 错误: --extract-tars 需要 --bucket-mount", file=sys.stderr)
|
| 183 |
+
sys.exit(2)
|
| 184 |
+
|
| 185 |
+
root = Path(args.bucket_mount) / dest_prefix
|
| 186 |
+
out_rel = args.extract_out_prefix
|
| 187 |
+
if out_rel is None:
|
| 188 |
+
out_rel = f"{dest_prefix}_unpacked"
|
| 189 |
+
out_rel = out_rel.strip().strip("/")
|
| 190 |
+
extract_base = Path(args.bucket_mount) / out_rel
|
| 191 |
+
|
| 192 |
+
if not root.is_dir():
|
| 193 |
+
print(f"[ingest] 警告: 镜像路径不存在或尚不可见: {root}", flush=True)
|
| 194 |
+
|
| 195 |
+
patterns = ("*.tar", "*.tar.gz", "*.tar.xz", "*.tgz")
|
| 196 |
+
archives = _collect_archives(root, patterns, exclude_under=extract_base)
|
| 197 |
+
|
| 198 |
+
if args.max_tars is not None:
|
| 199 |
+
archives = archives[: args.max_tars]
|
| 200 |
+
|
| 201 |
+
mode = "beside-tar" if args.extract_beside_tar else f"mirror -> {extract_base}"
|
| 202 |
+
print(f"[ingest] 将解压 {len(archives)} 个归档 under {root}(模式: {mode})", flush=True)
|
| 203 |
+
|
| 204 |
+
for i, tar_path in enumerate(archives):
|
| 205 |
+
if args.extract_beside_tar:
|
| 206 |
+
out_dir = tar_path.parent / f"{tar_path.name}_extracted"
|
| 207 |
+
else:
|
| 208 |
+
rel = tar_path.relative_to(root)
|
| 209 |
+
out_dir = extract_base / rel.parent / _archive_stem(tar_path)
|
| 210 |
+
|
| 211 |
+
if out_dir.exists() and any(out_dir.iterdir()):
|
| 212 |
+
print(f"[ingest] ({i + 1}/{len(archives)}) 跳过(已存在非空) {out_dir}", flush=True)
|
| 213 |
+
continue
|
| 214 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 215 |
+
print(f"[ingest] ({i + 1}/{len(archives)}) {tar_path} -> {out_dir}", flush=True)
|
| 216 |
+
try:
|
| 217 |
+
with tarfile.open(tar_path, mode="r:*") as tf:
|
| 218 |
+
_extract(tf, out_dir)
|
| 219 |
+
except Exception as e:
|
| 220 |
+
print(f"[ingest] 解压失败 {tar_path}: {e}", flush=True)
|
| 221 |
+
raise
|
| 222 |
+
|
| 223 |
+
print("[ingest] 全部完成", flush=True)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _extract(tf: tarfile.TarFile, out_dir: Path) -> None:
|
| 227 |
+
if sys.version_info >= (3, 12):
|
| 228 |
+
tf.extractall(out_dir, filter="data")
|
| 229 |
+
else:
|
| 230 |
+
tf.extractall(out_dir)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
|
| 234 |
+
main()
|
scripts/push_cpu_ingest_job.py
CHANGED
|
@@ -1,141 +1,148 @@
|
|
| 1 |
-
"""提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。
|
| 2 |
-
|
| 3 |
-
- 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。
|
| 4 |
-
- 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。
|
| 5 |
-
- 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。
|
| 6 |
-
|
| 7 |
-
用法::
|
| 8 |
-
|
| 9 |
-
python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data
|
| 10 |
-
python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow
|
| 11 |
-
python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/'
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import argparse
|
| 17 |
-
import os
|
| 18 |
-
import sys
|
| 19 |
-
|
| 20 |
-
from huggingface_hub import HfApi, Volume, create_bucket
|
| 21 |
-
|
| 22 |
-
try:
|
| 23 |
-
from huggingface_hub.cli._cli_utils import parse_env_map
|
| 24 |
-
except Exception: # pragma: no cover
|
| 25 |
-
parse_env_map = None
|
| 26 |
-
|
| 27 |
-
DEFAULT_CODE_REPO = "fuzirui/WJAD"
|
| 28 |
-
DEFAULT_BUCKET = "fuzirui/WJAD"
|
| 29 |
-
DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
|
| 30 |
-
DEFAULT_DEST_PREFIX = "cosmos_hub_mirror"
|
| 31 |
-
DEFAULT_IMAGE = "python:3.12"
|
| 32 |
-
DEFAULT_TIMEOUT = "7d"
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def _secrets_for_job() -> dict | None:
|
| 36 |
-
if parse_env_map is not None:
|
| 37 |
-
try:
|
| 38 |
-
m = parse_env_map(["HF_TOKEN"])
|
| 39 |
-
if m.get("HF_TOKEN"):
|
| 40 |
-
return m
|
| 41 |
-
except Exception:
|
| 42 |
-
pass
|
| 43 |
-
t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 44 |
-
return {"HF_TOKEN": t} if t else None
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def main() -> None:
|
| 48 |
-
parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压")
|
| 49 |
-
parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)")
|
| 50 |
-
parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id")
|
| 51 |
-
parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset"))
|
| 52 |
-
parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录")
|
| 53 |
-
parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径")
|
| 54 |
-
parser.add_argument(
|
| 55 |
-
"--skip-create-bucket",
|
| 56 |
-
action="store_true",
|
| 57 |
-
help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)",
|
| 58 |
-
)
|
| 59 |
-
parser.add_argument(
|
| 60 |
-
"--no-extract",
|
| 61 |
-
action="store_true",
|
| 62 |
-
help="只做 copy_files,不解压 tar",
|
| 63 |
-
)
|
| 64 |
-
parser.add_argument(
|
| 65 |
-
"--max-tars",
|
| 66 |
-
type=int,
|
| 67 |
-
default=None,
|
| 68 |
-
help="传给 ingest_hub_to_bucket.py --max-tars",
|
| 69 |
-
)
|
| 70 |
-
parser.add_argument(
|
| 71 |
-
"--extract-out-prefix",
|
| 72 |
-
default=None,
|
| 73 |
-
metavar="NAME",
|
| 74 |
-
help="解压输出子路径(默认 {dest-prefix}_unpacked)",
|
| 75 |
-
)
|
| 76 |
-
parser.add_argument(
|
| 77 |
-
"--extract-beside-tar",
|
| 78 |
-
action="store_true",
|
| 79 |
-
help="旧行为:在每条 tar 旁解压为 _extracted",
|
| 80 |
-
)
|
| 81 |
-
parser.add_argument("--image", default=DEFAULT_IMAGE)
|
| 82 |
-
parser.add_argument("--timeout", default=DEFAULT_TIMEOUT)
|
| 83 |
-
parser.add_argument("--follow", action="store_true")
|
| 84 |
-
parser.add_argument("--no-secrets", action="store_true")
|
| 85 |
-
args = parser.parse_args()
|
| 86 |
-
|
| 87 |
-
bucket_mount = "/mnt/cosmos"
|
| 88 |
-
code_mount = "/workspace"
|
| 89 |
-
|
| 90 |
-
max_tars = ""
|
| 91 |
-
if args.max_tars is not None:
|
| 92 |
-
max_tars = f" --max-tars {args.max_tars}"
|
| 93 |
-
|
| 94 |
-
extract_flag = "" if args.no_extract else " --extract-tars"
|
| 95 |
-
|
| 96 |
-
extract_beside = " --extract-beside-tar" if args.extract_beside_tar else ""
|
| 97 |
-
out_prefix = ""
|
| 98 |
-
if args.extract_out_prefix:
|
| 99 |
-
out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'"
|
| 100 |
-
|
| 101 |
-
script = f"""set -euo pipefail
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
""
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。
|
| 2 |
+
|
| 3 |
+
- 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。
|
| 4 |
+
- 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。
|
| 5 |
+
- 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。
|
| 6 |
+
|
| 7 |
+
用法::
|
| 8 |
+
|
| 9 |
+
python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data
|
| 10 |
+
python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow
|
| 11 |
+
python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/'
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
from huggingface_hub import HfApi, Volume, create_bucket
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from huggingface_hub.cli._cli_utils import parse_env_map
|
| 24 |
+
except Exception: # pragma: no cover
|
| 25 |
+
parse_env_map = None
|
| 26 |
+
|
| 27 |
+
DEFAULT_CODE_REPO = "fuzirui/WJAD"
|
| 28 |
+
DEFAULT_BUCKET = "fuzirui/WJAD"
|
| 29 |
+
DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/"
|
| 30 |
+
DEFAULT_DEST_PREFIX = "cosmos_hub_mirror"
|
| 31 |
+
DEFAULT_IMAGE = "python:3.12"
|
| 32 |
+
DEFAULT_TIMEOUT = "7d"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _secrets_for_job() -> dict | None:
|
| 36 |
+
if parse_env_map is not None:
|
| 37 |
+
try:
|
| 38 |
+
m = parse_env_map(["HF_TOKEN"])
|
| 39 |
+
if m.get("HF_TOKEN"):
|
| 40 |
+
return m
|
| 41 |
+
except Exception:
|
| 42 |
+
pass
|
| 43 |
+
t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 44 |
+
return {"HF_TOKEN": t} if t else None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main() -> None:
|
| 48 |
+
parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压")
|
| 49 |
+
parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)")
|
| 50 |
+
parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id")
|
| 51 |
+
parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset"))
|
| 52 |
+
parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录")
|
| 53 |
+
parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径")
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--skip-create-bucket",
|
| 56 |
+
action="store_true",
|
| 57 |
+
help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--no-extract",
|
| 61 |
+
action="store_true",
|
| 62 |
+
help="只做 copy_files,不解压 tar",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--max-tars",
|
| 66 |
+
type=int,
|
| 67 |
+
default=None,
|
| 68 |
+
help="传给 ingest_hub_to_bucket.py --max-tars",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--extract-out-prefix",
|
| 72 |
+
default=None,
|
| 73 |
+
metavar="NAME",
|
| 74 |
+
help="解压输出子路径(默认 {dest-prefix}_unpacked)",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--extract-beside-tar",
|
| 78 |
+
action="store_true",
|
| 79 |
+
help="旧行为:在每条 tar 旁解压为 _extracted",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument("--image", default=DEFAULT_IMAGE)
|
| 82 |
+
parser.add_argument("--timeout", default=DEFAULT_TIMEOUT)
|
| 83 |
+
parser.add_argument("--follow", action="store_true")
|
| 84 |
+
parser.add_argument("--no-secrets", action="store_true")
|
| 85 |
+
args = parser.parse_args()
|
| 86 |
+
|
| 87 |
+
bucket_mount = "/mnt/cosmos"
|
| 88 |
+
code_mount = "/workspace"
|
| 89 |
+
|
| 90 |
+
max_tars = ""
|
| 91 |
+
if args.max_tars is not None:
|
| 92 |
+
max_tars = f" --max-tars {args.max_tars}"
|
| 93 |
+
|
| 94 |
+
extract_flag = "" if args.no_extract else " --extract-tars"
|
| 95 |
+
|
| 96 |
+
extract_beside = " --extract-beside-tar" if args.extract_beside_tar else ""
|
| 97 |
+
out_prefix = ""
|
| 98 |
+
if args.extract_out_prefix:
|
| 99 |
+
out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'"
|
| 100 |
+
|
| 101 |
+
script = f"""set -euo pipefail
|
| 102 |
+
Eph="{bucket_mount}/.wjad_ephemeral"
|
| 103 |
+
mkdir -p "$Eph/tmp" "$Eph/hf_home/hub" "$Eph/xdg_cache"
|
| 104 |
+
export TMPDIR="$Eph/tmp"
|
| 105 |
+
export TMP="$TMPDIR" TEMP="$TMPDIR"
|
| 106 |
+
export HF_HOME="$Eph/hf_home"
|
| 107 |
+
export HF_HUB_CACHE="$HF_HOME/hub"
|
| 108 |
+
export XDG_CACHE_HOME="$Eph/xdg_cache"
|
| 109 |
+
pip install --root-user-action=ignore --no-cache-dir 'huggingface_hub>=0.30'
|
| 110 |
+
python {code_mount}/scripts/ingest_hub_to_bucket.py \\
|
| 111 |
+
--bucket '{args.bucket}' \\
|
| 112 |
+
--source '{args.source}' \\
|
| 113 |
+
--dest-prefix '{args.dest_prefix}' \\
|
| 114 |
+
--bucket-mount '{bucket_mount}'{extract_flag}{max_tars}{out_prefix}{extract_beside}
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
secrets = None if args.no_secrets else _secrets_for_job()
|
| 118 |
+
if secrets is None and not args.no_secrets:
|
| 119 |
+
print("[push_cpu_ingest] 警告: 无 HF_TOKEN,gated 数据会失败。", file=sys.stderr)
|
| 120 |
+
|
| 121 |
+
if not args.skip_create_bucket:
|
| 122 |
+
create_bucket(args.bucket, exist_ok=True)
|
| 123 |
+
print(f"[push_cpu_ingest] bucket 已确保存在(或已存在): {args.bucket}")
|
| 124 |
+
|
| 125 |
+
volumes = [
|
| 126 |
+
Volume(type=args.code_type, source=args.code_repo, mount_path=code_mount),
|
| 127 |
+
Volume(type="bucket", source=args.bucket, mount_path=bucket_mount),
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
api = HfApi()
|
| 131 |
+
job = api.run_job(
|
| 132 |
+
image=args.image,
|
| 133 |
+
command=["bash", "-lc", script],
|
| 134 |
+
flavor="cpu-basic",
|
| 135 |
+
volumes=volumes,
|
| 136 |
+
secrets=secrets,
|
| 137 |
+
timeout=args.timeout,
|
| 138 |
+
)
|
| 139 |
+
print(f"[push_cpu_ingest] Job ID: {job.id}")
|
| 140 |
+
print(f"[push_cpu_ingest] URL: {job.url}")
|
| 141 |
+
|
| 142 |
+
if args.follow:
|
| 143 |
+
for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
|
| 144 |
+
print(line, end="" if str(line).endswith("\n") else "\n")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
main()
|
scripts/push_to_jobs.py
CHANGED
|
@@ -1,196 +1,196 @@
|
|
| 1 |
-
"""提交 Hugging Face Jobs 正式训练。
|
| 2 |
-
|
| 3 |
-
- 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。
|
| 4 |
-
- **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树
|
| 5 |
-
(``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载**
|
| 6 |
-
不是这棵树,``build_clip_index`` 会得到 0 条样本。
|
| 7 |
-
- **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录
|
| 8 |
-
``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用
|
| 9 |
-
``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把
|
| 10 |
-
``WJAD_DATA_ROOT`` 指到挂载路径。
|
| 11 |
-
- **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动
|
| 12 |
-
流式需另用 ``IterableDataset`` + shard,属后续工作。
|
| 13 |
-
|
| 14 |
-
用法:
|
| 15 |
-
|
| 16 |
-
python scripts/push_to_jobs.py
|
| 17 |
-
python scripts/push_to_jobs.py --follow
|
| 18 |
-
python scripts/push_to_jobs.py --download-limit 0
|
| 19 |
-
python scripts/push_to_jobs.py --skip-download
|
| 20 |
-
python scripts/push_to_jobs.py --mount-hub-dataset
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
from __future__ import annotations
|
| 24 |
-
|
| 25 |
-
import argparse
|
| 26 |
-
import sys
|
| 27 |
-
|
| 28 |
-
from huggingface_hub import HfApi, Volume
|
| 29 |
-
|
| 30 |
-
try:
|
| 31 |
-
from huggingface_hub.cli._cli_utils import parse_env_map
|
| 32 |
-
except Exception: # pragma: no cover
|
| 33 |
-
parse_env_map = None
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
DEFAULT_FLAVOR = "a10g-large"
|
| 37 |
-
DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
|
| 38 |
-
DEFAULT_REPO = "fuzirui/WJAD"
|
| 39 |
-
DEFAULT_MOUNT = "/workspace"
|
| 40 |
-
DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams"
|
| 41 |
-
DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos"
|
| 42 |
-
DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos"
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _secrets_for_job() -> dict | None:
|
| 46 |
-
if parse_env_map is not None:
|
| 47 |
-
try:
|
| 48 |
-
m = parse_env_map(["HF_TOKEN"])
|
| 49 |
-
if m.get("HF_TOKEN"):
|
| 50 |
-
return m
|
| 51 |
-
except Exception:
|
| 52 |
-
pass
|
| 53 |
-
import os
|
| 54 |
-
|
| 55 |
-
t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 56 |
-
return {"HF_TOKEN": t} if t else None
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def main() -> None:
|
| 60 |
-
parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练")
|
| 61 |
-
parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id")
|
| 62 |
-
parser.add_argument(
|
| 63 |
-
"--repo-type",
|
| 64 |
-
default="model",
|
| 65 |
-
choices=("model", "space", "dataset"),
|
| 66 |
-
help="仓库类型",
|
| 67 |
-
)
|
| 68 |
-
parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径")
|
| 69 |
-
parser.add_argument("--flavor", default=DEFAULT_FLAVOR)
|
| 70 |
-
parser.add_argument("--image", default=DEFAULT_IMAGE)
|
| 71 |
-
parser.add_argument(
|
| 72 |
-
"--follow",
|
| 73 |
-
action="store_true",
|
| 74 |
-
help="跟随 Job 日志直到结束",
|
| 75 |
-
)
|
| 76 |
-
parser.add_argument("--no-secrets", action="store_true")
|
| 77 |
-
parser.add_argument("--timeout", default=None, help="如 168h、7d")
|
| 78 |
-
# —— 数据:默认 NVIDIA 下载到可写目录 ——
|
| 79 |
-
parser.add_argument(
|
| 80 |
-
"--data-prep-dir",
|
| 81 |
-
default=DEFAULT_DATA_PREP_DIR,
|
| 82 |
-
help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖",
|
| 83 |
-
)
|
| 84 |
-
parser.add_argument(
|
| 85 |
-
"--download-workers",
|
| 86 |
-
type=int,
|
| 87 |
-
default=8,
|
| 88 |
-
help="download_data.py --workers",
|
| 89 |
-
)
|
| 90 |
-
parser.add_argument(
|
| 91 |
-
"--download-limit",
|
| 92 |
-
type=int,
|
| 93 |
-
default=8,
|
| 94 |
-
metavar="N",
|
| 95 |
-
help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)",
|
| 96 |
-
)
|
| 97 |
-
parser.add_argument(
|
| 98 |
-
"--skip-download",
|
| 99 |
-
action="store_true",
|
| 100 |
-
help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)",
|
| 101 |
-
)
|
| 102 |
-
# —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)——
|
| 103 |
-
parser.add_argument(
|
| 104 |
-
"--mount-hub-dataset",
|
| 105 |
-
action="store_true",
|
| 106 |
-
help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)",
|
| 107 |
-
)
|
| 108 |
-
parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID")
|
| 109 |
-
parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT)
|
| 110 |
-
|
| 111 |
-
args = parser.parse_args()
|
| 112 |
-
|
| 113 |
-
code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount)
|
| 114 |
-
ro_mount = args.mount
|
| 115 |
-
work = "/tmp/wjad-run"
|
| 116 |
-
volumes: list[Volume] = [code_vol]
|
| 117 |
-
|
| 118 |
-
if args.mount_hub_dataset:
|
| 119 |
-
volumes.append(
|
| 120 |
-
Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount)
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
use_auto_download = not args.skip_download and not args.mount_hub_dataset
|
| 124 |
-
data_root_default = (
|
| 125 |
-
args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
limit_tail = ""
|
| 129 |
-
if use_auto_download and args.download_limit > 0:
|
| 130 |
-
limit_tail = f" --limit {args.download_limit}"
|
| 131 |
-
|
| 132 |
-
download_block = ""
|
| 133 |
-
if use_auto_download:
|
| 134 |
-
download_block = f"""
|
| 135 |
-
mkdir -p "$DATA_ROOT"
|
| 136 |
-
python scripts/download_data.py --odir "$DATA_ROOT" \\
|
| 137 |
-
--file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail}
|
| 138 |
-
"""
|
| 139 |
-
|
| 140 |
-
script = f"""set -euo pipefail
|
| 141 |
-
rm -rf {work}
|
| 142 |
-
cp -a {ro_mount} {work}
|
| 143 |
-
cd {work}
|
| 144 |
-
export PIP_ROOT_USER_ACTION=ignore
|
| 145 |
-
pip install --root-user-action=ignore --no-cache-dir -e .
|
| 146 |
-
export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}"
|
| 147 |
-
DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}"
|
| 148 |
-
{download_block}
|
| 149 |
-
python -m wjad.train.runner_local \\
|
| 150 |
-
--device cuda \\
|
| 151 |
-
--config configs/default.yaml \\
|
| 152 |
-
--data_root "$DATA_ROOT" \\
|
| 153 |
-
--dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}"
|
| 154 |
-
"""
|
| 155 |
-
|
| 156 |
-
secrets = None if args.no_secrets else _secrets_for_job()
|
| 157 |
-
if secrets is None and not args.no_secrets:
|
| 158 |
-
print(
|
| 159 |
-
"[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。",
|
| 160 |
-
file=sys.stderr,
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
if args.mount_hub_dataset:
|
| 164 |
-
print(
|
| 165 |
-
f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};"
|
| 166 |
-
"若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset,"
|
| 167 |
-
"改用默认自动 download。"
|
| 168 |
-
)
|
| 169 |
-
elif use_auto_download:
|
| 170 |
-
lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)"
|
| 171 |
-
print(
|
| 172 |
-
f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。"
|
| 173 |
-
"全量请 --download-limit 0 并保证磁盘或 Bucket。"
|
| 174 |
-
)
|
| 175 |
-
else:
|
| 176 |
-
print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。")
|
| 177 |
-
|
| 178 |
-
api = HfApi()
|
| 179 |
-
job = api.run_job(
|
| 180 |
-
image=args.image,
|
| 181 |
-
command=["bash", "-lc", script],
|
| 182 |
-
flavor=args.flavor,
|
| 183 |
-
volumes=volumes,
|
| 184 |
-
secrets=secrets,
|
| 185 |
-
timeout=args.timeout,
|
| 186 |
-
)
|
| 187 |
-
print(f"[push_to_jobs] Job ID: {job.id}")
|
| 188 |
-
print(f"[push_to_jobs] URL: {job.url}")
|
| 189 |
-
|
| 190 |
-
if args.follow:
|
| 191 |
-
for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
|
| 192 |
-
print(line, end="" if str(line).endswith("\n") else "\n")
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
if __name__ == "__main__":
|
| 196 |
-
main()
|
|
|
|
| 1 |
+
"""提交 Hugging Face Jobs 正式训练。
|
| 2 |
+
|
| 3 |
+
- 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。
|
| 4 |
+
- **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树
|
| 5 |
+
(``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载**
|
| 6 |
+
不是这棵树,``build_clip_index`` 会得到 0 条样本。
|
| 7 |
+
- **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录
|
| 8 |
+
``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用
|
| 9 |
+
``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把
|
| 10 |
+
``WJAD_DATA_ROOT`` 指到挂载路径。
|
| 11 |
+
- **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动
|
| 12 |
+
流式需另用 ``IterableDataset`` + shard,属后续工作。
|
| 13 |
+
|
| 14 |
+
用法:
|
| 15 |
+
|
| 16 |
+
python scripts/push_to_jobs.py
|
| 17 |
+
python scripts/push_to_jobs.py --follow
|
| 18 |
+
python scripts/push_to_jobs.py --download-limit 0
|
| 19 |
+
python scripts/push_to_jobs.py --skip-download
|
| 20 |
+
python scripts/push_to_jobs.py --mount-hub-dataset
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import sys
|
| 27 |
+
|
| 28 |
+
from huggingface_hub import HfApi, Volume
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from huggingface_hub.cli._cli_utils import parse_env_map
|
| 32 |
+
except Exception: # pragma: no cover
|
| 33 |
+
parse_env_map = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
DEFAULT_FLAVOR = "a10g-large"
|
| 37 |
+
DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime"
|
| 38 |
+
DEFAULT_REPO = "fuzirui/WJAD"
|
| 39 |
+
DEFAULT_MOUNT = "/workspace"
|
| 40 |
+
DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams"
|
| 41 |
+
DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos"
|
| 42 |
+
DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _secrets_for_job() -> dict | None:
|
| 46 |
+
if parse_env_map is not None:
|
| 47 |
+
try:
|
| 48 |
+
m = parse_env_map(["HF_TOKEN"])
|
| 49 |
+
if m.get("HF_TOKEN"):
|
| 50 |
+
return m
|
| 51 |
+
except Exception:
|
| 52 |
+
pass
|
| 53 |
+
import os
|
| 54 |
+
|
| 55 |
+
t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 56 |
+
return {"HF_TOKEN": t} if t else None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main() -> None:
|
| 60 |
+
parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练")
|
| 61 |
+
parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id")
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--repo-type",
|
| 64 |
+
default="model",
|
| 65 |
+
choices=("model", "space", "dataset"),
|
| 66 |
+
help="仓库类型",
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径")
|
| 69 |
+
parser.add_argument("--flavor", default=DEFAULT_FLAVOR)
|
| 70 |
+
parser.add_argument("--image", default=DEFAULT_IMAGE)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--follow",
|
| 73 |
+
action="store_true",
|
| 74 |
+
help="跟随 Job 日志直到结束",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument("--no-secrets", action="store_true")
|
| 77 |
+
parser.add_argument("--timeout", default=None, help="如 168h、7d")
|
| 78 |
+
# —— 数据:默认 NVIDIA 下载到可写目录 ——
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--data-prep-dir",
|
| 81 |
+
default=DEFAULT_DATA_PREP_DIR,
|
| 82 |
+
help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--download-workers",
|
| 86 |
+
type=int,
|
| 87 |
+
default=8,
|
| 88 |
+
help="download_data.py --workers",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--download-limit",
|
| 92 |
+
type=int,
|
| 93 |
+
default=8,
|
| 94 |
+
metavar="N",
|
| 95 |
+
help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--skip-download",
|
| 99 |
+
action="store_true",
|
| 100 |
+
help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)",
|
| 101 |
+
)
|
| 102 |
+
# —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)——
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--mount-hub-dataset",
|
| 105 |
+
action="store_true",
|
| 106 |
+
help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID")
|
| 109 |
+
parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT)
|
| 110 |
+
|
| 111 |
+
args = parser.parse_args()
|
| 112 |
+
|
| 113 |
+
code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount)
|
| 114 |
+
ro_mount = args.mount
|
| 115 |
+
work = "/tmp/wjad-run"
|
| 116 |
+
volumes: list[Volume] = [code_vol]
|
| 117 |
+
|
| 118 |
+
if args.mount_hub_dataset:
|
| 119 |
+
volumes.append(
|
| 120 |
+
Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount)
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
use_auto_download = not args.skip_download and not args.mount_hub_dataset
|
| 124 |
+
data_root_default = (
|
| 125 |
+
args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
limit_tail = ""
|
| 129 |
+
if use_auto_download and args.download_limit > 0:
|
| 130 |
+
limit_tail = f" --limit {args.download_limit}"
|
| 131 |
+
|
| 132 |
+
download_block = ""
|
| 133 |
+
if use_auto_download:
|
| 134 |
+
download_block = f"""
|
| 135 |
+
mkdir -p "$DATA_ROOT"
|
| 136 |
+
python scripts/download_data.py --odir "$DATA_ROOT" \\
|
| 137 |
+
--file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail}
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
script = f"""set -euo pipefail
|
| 141 |
+
rm -rf {work}
|
| 142 |
+
cp -a {ro_mount} {work}
|
| 143 |
+
cd {work}
|
| 144 |
+
export PIP_ROOT_USER_ACTION=ignore
|
| 145 |
+
pip install --root-user-action=ignore --no-cache-dir -e .
|
| 146 |
+
export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}"
|
| 147 |
+
DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}"
|
| 148 |
+
{download_block}
|
| 149 |
+
python -m wjad.train.runner_local \\
|
| 150 |
+
--device cuda \\
|
| 151 |
+
--config configs/default.yaml \\
|
| 152 |
+
--data_root "$DATA_ROOT" \\
|
| 153 |
+
--dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}"
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
secrets = None if args.no_secrets else _secrets_for_job()
|
| 157 |
+
if secrets is None and not args.no_secrets:
|
| 158 |
+
print(
|
| 159 |
+
"[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。",
|
| 160 |
+
file=sys.stderr,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if args.mount_hub_dataset:
|
| 164 |
+
print(
|
| 165 |
+
f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};"
|
| 166 |
+
"若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset,"
|
| 167 |
+
"改用默认自动 download。"
|
| 168 |
+
)
|
| 169 |
+
elif use_auto_download:
|
| 170 |
+
lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)"
|
| 171 |
+
print(
|
| 172 |
+
f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。"
|
| 173 |
+
"全量请 --download-limit 0 并保证磁盘或 Bucket。"
|
| 174 |
+
)
|
| 175 |
+
else:
|
| 176 |
+
print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。")
|
| 177 |
+
|
| 178 |
+
api = HfApi()
|
| 179 |
+
job = api.run_job(
|
| 180 |
+
image=args.image,
|
| 181 |
+
command=["bash", "-lc", script],
|
| 182 |
+
flavor=args.flavor,
|
| 183 |
+
volumes=volumes,
|
| 184 |
+
secrets=secrets,
|
| 185 |
+
timeout=args.timeout,
|
| 186 |
+
)
|
| 187 |
+
print(f"[push_to_jobs] Job ID: {job.id}")
|
| 188 |
+
print(f"[push_to_jobs] URL: {job.url}")
|
| 189 |
+
|
| 190 |
+
if args.follow:
|
| 191 |
+
for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True):
|
| 192 |
+
print(line, end="" if str(line).endswith("\n") else "\n")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
main()
|
scripts/push_to_sandbox.py
CHANGED
|
@@ -1,185 +1,185 @@
|
|
| 1 |
-
"""推送代码到 HF Space,做 sandbox 微训练。
|
| 2 |
-
|
| 3 |
-
依据 ``estimate_memory.py`` 的估算:
|
| 4 |
-
- BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存;
|
| 5 |
-
- 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致;
|
| 6 |
-
- 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。
|
| 7 |
-
|
| 8 |
-
本脚本:
|
| 9 |
-
1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space,
|
| 10 |
-
Space SDK = ``docker``;
|
| 11 |
-
2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等);
|
| 12 |
-
3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。
|
| 13 |
-
|
| 14 |
-
要求:先在本地 ``hf auth login``。
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import argparse
|
| 20 |
-
from pathlib import Path
|
| 21 |
-
|
| 22 |
-
from huggingface_hub import HfApi, create_repo
|
| 23 |
-
|
| 24 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 25 |
-
DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
| 26 |
-
|
| 27 |
-
ENV DEBIAN_FRONTEND=noninteractive
|
| 28 |
-
ENV PYTHONUNBUFFERED=1
|
| 29 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \\
|
| 30 |
-
python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\
|
| 31 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 32 |
-
|
| 33 |
-
# HF Space 默认用户(避免权限问题)
|
| 34 |
-
RUN useradd -m -u 1000 user
|
| 35 |
-
USER user
|
| 36 |
-
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 37 |
-
|
| 38 |
-
WORKDIR /app
|
| 39 |
-
COPY --chown=user pyproject.toml /app/
|
| 40 |
-
COPY --chown=user src /app/src
|
| 41 |
-
COPY --chown=user scripts /app/scripts
|
| 42 |
-
COPY --chown=user configs /app/configs
|
| 43 |
-
COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m
|
| 44 |
-
COPY --chown=user app.py /app/app.py
|
| 45 |
-
|
| 46 |
-
RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\
|
| 47 |
-
&& python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\
|
| 48 |
-
&& python3 -m pip install --user --no-cache-dir -e .
|
| 49 |
-
|
| 50 |
-
EXPOSE 7860
|
| 51 |
-
CMD ["python3", "app.py"]
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。
|
| 55 |
-
|
| 56 |
-
启动后:
|
| 57 |
-
1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)
|
| 58 |
-
2. 主进程开 HTTP server on :7860,返回最新日志
|
| 59 |
-
|
| 60 |
-
阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。
|
| 61 |
-
阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。
|
| 62 |
-
"""
|
| 63 |
-
import os
|
| 64 |
-
import subprocess
|
| 65 |
-
import sys
|
| 66 |
-
import threading
|
| 67 |
-
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 68 |
-
|
| 69 |
-
LOG_PATH = "/tmp/wjad.log"
|
| 70 |
-
PORT = 7860
|
| 71 |
-
# 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。
|
| 72 |
-
_MODE = os.environ.get("SANDBOX_MODE", "smoke")
|
| 73 |
-
if _MODE == "real_data":
|
| 74 |
-
LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"]
|
| 75 |
-
else:
|
| 76 |
-
LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"]
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def _print_env(f):
|
| 80 |
-
f.write("=" * 72 + "\\n")
|
| 81 |
-
f.write(" WJAD HF Sandbox\\n")
|
| 82 |
-
f.write("=" * 72 + "\\n")
|
| 83 |
-
f.write(f"Python: {sys.version}\\n")
|
| 84 |
-
try:
|
| 85 |
-
import torch
|
| 86 |
-
f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n")
|
| 87 |
-
if torch.cuda.is_available():
|
| 88 |
-
p = torch.cuda.get_device_properties(0)
|
| 89 |
-
f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n")
|
| 90 |
-
except Exception as e:
|
| 91 |
-
f.write(f"torch import failed: {e}\\n")
|
| 92 |
-
f.flush()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def run_training():
|
| 96 |
-
with open(LOG_PATH, "w", buffering=1) as f:
|
| 97 |
-
_print_env(f)
|
| 98 |
-
f.write(f"$ {' '.join(LAUNCH_CMD)}\\n")
|
| 99 |
-
f.flush()
|
| 100 |
-
p = subprocess.Popen(
|
| 101 |
-
LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app"
|
| 102 |
-
)
|
| 103 |
-
rc = p.wait()
|
| 104 |
-
f.write(f"\\n[exit code = {rc}]\\n")
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
class Handler(BaseHTTPRequestHandler):
|
| 108 |
-
def do_GET(self):
|
| 109 |
-
try:
|
| 110 |
-
with open(LOG_PATH, "r") as f:
|
| 111 |
-
body = f.read()
|
| 112 |
-
except FileNotFoundError:
|
| 113 |
-
body = "starting..."
|
| 114 |
-
self.send_response(200)
|
| 115 |
-
self.send_header("Content-Type", "text/plain; charset=utf-8")
|
| 116 |
-
self.end_headers()
|
| 117 |
-
self.wfile.write(body.encode("utf-8"))
|
| 118 |
-
|
| 119 |
-
def log_message(self, fmt, *args):
|
| 120 |
-
return
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
if __name__ == "__main__":
|
| 124 |
-
threading.Thread(target=run_training, daemon=True).start()
|
| 125 |
-
HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
|
| 126 |
-
'''
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def main() -> None:
|
| 130 |
-
parser = argparse.ArgumentParser()
|
| 131 |
-
parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox")
|
| 132 |
-
parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)")
|
| 133 |
-
parser.add_argument("--private", action="store_true")
|
| 134 |
-
parser.add_argument(
|
| 135 |
-
"--mode",
|
| 136 |
-
choices=["smoke", "real_data"],
|
| 137 |
-
default="smoke",
|
| 138 |
-
help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer",
|
| 139 |
-
)
|
| 140 |
-
args = parser.parse_args()
|
| 141 |
-
|
| 142 |
-
api = HfApi()
|
| 143 |
-
print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})")
|
| 144 |
-
create_repo(
|
| 145 |
-
args.repo,
|
| 146 |
-
repo_type="space",
|
| 147 |
-
space_sdk="docker",
|
| 148 |
-
space_hardware=args.gpu,
|
| 149 |
-
private=args.private,
|
| 150 |
-
exist_ok=True,
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
# 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings
|
| 154 |
-
# -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。
|
| 155 |
-
api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode)
|
| 156 |
-
if args.mode == "real_data":
|
| 157 |
-
print(
|
| 158 |
-
"[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets "
|
| 159 |
-
"里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-"
|
| 160 |
-
"Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。"
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# 落盘 Dockerfile / app.py
|
| 164 |
-
(ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8")
|
| 165 |
-
(ROOT / "app.py").write_text(APP_PY, encoding="utf-8")
|
| 166 |
-
|
| 167 |
-
print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...")
|
| 168 |
-
api.upload_folder(
|
| 169 |
-
folder_path=str(ROOT),
|
| 170 |
-
repo_id=args.repo,
|
| 171 |
-
repo_type="space",
|
| 172 |
-
ignore_patterns=[
|
| 173 |
-
".venv/*",
|
| 174 |
-
"data/*",
|
| 175 |
-
"**/__pycache__/*",
|
| 176 |
-
"*.pyc",
|
| 177 |
-
"agent-tools/*",
|
| 178 |
-
".git/*",
|
| 179 |
-
],
|
| 180 |
-
)
|
| 181 |
-
print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}")
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
if __name__ == "__main__":
|
| 185 |
-
main()
|
|
|
|
| 1 |
+
"""推送代码到 HF Space,做 sandbox 微训练。
|
| 2 |
+
|
| 3 |
+
依据 ``estimate_memory.py`` 的估算:
|
| 4 |
+
- BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存;
|
| 5 |
+
- 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致;
|
| 6 |
+
- 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。
|
| 7 |
+
|
| 8 |
+
本脚本:
|
| 9 |
+
1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space,
|
| 10 |
+
Space SDK = ``docker``;
|
| 11 |
+
2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等);
|
| 12 |
+
3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。
|
| 13 |
+
|
| 14 |
+
要求:先在本地 ``hf auth login``。
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
from huggingface_hub import HfApi, create_repo
|
| 23 |
+
|
| 24 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 25 |
+
DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
| 26 |
+
|
| 27 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 28 |
+
ENV PYTHONUNBUFFERED=1
|
| 29 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \\
|
| 30 |
+
python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\
|
| 31 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 32 |
+
|
| 33 |
+
# HF Space 默认用户(避免权限问题)
|
| 34 |
+
RUN useradd -m -u 1000 user
|
| 35 |
+
USER user
|
| 36 |
+
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 37 |
+
|
| 38 |
+
WORKDIR /app
|
| 39 |
+
COPY --chown=user pyproject.toml /app/
|
| 40 |
+
COPY --chown=user src /app/src
|
| 41 |
+
COPY --chown=user scripts /app/scripts
|
| 42 |
+
COPY --chown=user configs /app/configs
|
| 43 |
+
COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m
|
| 44 |
+
COPY --chown=user app.py /app/app.py
|
| 45 |
+
|
| 46 |
+
RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\
|
| 47 |
+
&& python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\
|
| 48 |
+
&& python3 -m pip install --user --no-cache-dir -e .
|
| 49 |
+
|
| 50 |
+
EXPOSE 7860
|
| 51 |
+
CMD ["python3", "app.py"]
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。
|
| 55 |
+
|
| 56 |
+
启动后:
|
| 57 |
+
1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log)
|
| 58 |
+
2. 主进程开 HTTP server on :7860,返回最新日志
|
| 59 |
+
|
| 60 |
+
阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。
|
| 61 |
+
阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。
|
| 62 |
+
"""
|
| 63 |
+
import os
|
| 64 |
+
import subprocess
|
| 65 |
+
import sys
|
| 66 |
+
import threading
|
| 67 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 68 |
+
|
| 69 |
+
LOG_PATH = "/tmp/wjad.log"
|
| 70 |
+
PORT = 7860
|
| 71 |
+
# 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。
|
| 72 |
+
_MODE = os.environ.get("SANDBOX_MODE", "smoke")
|
| 73 |
+
if _MODE == "real_data":
|
| 74 |
+
LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"]
|
| 75 |
+
else:
|
| 76 |
+
LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _print_env(f):
|
| 80 |
+
f.write("=" * 72 + "\\n")
|
| 81 |
+
f.write(" WJAD HF Sandbox\\n")
|
| 82 |
+
f.write("=" * 72 + "\\n")
|
| 83 |
+
f.write(f"Python: {sys.version}\\n")
|
| 84 |
+
try:
|
| 85 |
+
import torch
|
| 86 |
+
f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n")
|
| 87 |
+
if torch.cuda.is_available():
|
| 88 |
+
p = torch.cuda.get_device_properties(0)
|
| 89 |
+
f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n")
|
| 90 |
+
except Exception as e:
|
| 91 |
+
f.write(f"torch import failed: {e}\\n")
|
| 92 |
+
f.flush()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def run_training():
|
| 96 |
+
with open(LOG_PATH, "w", buffering=1) as f:
|
| 97 |
+
_print_env(f)
|
| 98 |
+
f.write(f"$ {' '.join(LAUNCH_CMD)}\\n")
|
| 99 |
+
f.flush()
|
| 100 |
+
p = subprocess.Popen(
|
| 101 |
+
LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app"
|
| 102 |
+
)
|
| 103 |
+
rc = p.wait()
|
| 104 |
+
f.write(f"\\n[exit code = {rc}]\\n")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Handler(BaseHTTPRequestHandler):
|
| 108 |
+
def do_GET(self):
|
| 109 |
+
try:
|
| 110 |
+
with open(LOG_PATH, "r") as f:
|
| 111 |
+
body = f.read()
|
| 112 |
+
except FileNotFoundError:
|
| 113 |
+
body = "starting..."
|
| 114 |
+
self.send_response(200)
|
| 115 |
+
self.send_header("Content-Type", "text/plain; charset=utf-8")
|
| 116 |
+
self.end_headers()
|
| 117 |
+
self.wfile.write(body.encode("utf-8"))
|
| 118 |
+
|
| 119 |
+
def log_message(self, fmt, *args):
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
threading.Thread(target=run_training, daemon=True).start()
|
| 125 |
+
HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
|
| 126 |
+
'''
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def main() -> None:
|
| 130 |
+
parser = argparse.ArgumentParser()
|
| 131 |
+
parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox")
|
| 132 |
+
parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)")
|
| 133 |
+
parser.add_argument("--private", action="store_true")
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--mode",
|
| 136 |
+
choices=["smoke", "real_data"],
|
| 137 |
+
default="smoke",
|
| 138 |
+
help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer",
|
| 139 |
+
)
|
| 140 |
+
args = parser.parse_args()
|
| 141 |
+
|
| 142 |
+
api = HfApi()
|
| 143 |
+
print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})")
|
| 144 |
+
create_repo(
|
| 145 |
+
args.repo,
|
| 146 |
+
repo_type="space",
|
| 147 |
+
space_sdk="docker",
|
| 148 |
+
space_hardware=args.gpu,
|
| 149 |
+
private=args.private,
|
| 150 |
+
exist_ok=True,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings
|
| 154 |
+
# -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。
|
| 155 |
+
api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode)
|
| 156 |
+
if args.mode == "real_data":
|
| 157 |
+
print(
|
| 158 |
+
"[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets "
|
| 159 |
+
"里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-"
|
| 160 |
+
"Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# 落盘 Dockerfile / app.py
|
| 164 |
+
(ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8")
|
| 165 |
+
(ROOT / "app.py").write_text(APP_PY, encoding="utf-8")
|
| 166 |
+
|
| 167 |
+
print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...")
|
| 168 |
+
api.upload_folder(
|
| 169 |
+
folder_path=str(ROOT),
|
| 170 |
+
repo_id=args.repo,
|
| 171 |
+
repo_type="space",
|
| 172 |
+
ignore_patterns=[
|
| 173 |
+
".venv/*",
|
| 174 |
+
"data/*",
|
| 175 |
+
"**/__pycache__/*",
|
| 176 |
+
"*.pyc",
|
| 177 |
+
"agent-tools/*",
|
| 178 |
+
".git/*",
|
| 179 |
+
],
|
| 180 |
+
)
|
| 181 |
+
print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
main()
|
scripts/sandbox_real_data.py
CHANGED
|
@@ -1,236 +1,236 @@
|
|
| 1 |
-
"""Sandbox 真实数据微验证脚本。
|
| 2 |
-
|
| 3 |
-
由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17
|
| 4 |
-
个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。
|
| 5 |
-
因此本脚本采用混合方案:
|
| 6 |
-
|
| 7 |
-
1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的
|
| 8 |
-
全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB;
|
| 9 |
-
2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配
|
| 10 |
-
``wjad.data.cosmos_dataset`` 期待的布局;
|
| 11 |
-
3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频
|
| 12 |
-
(文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768);
|
| 13 |
-
4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 +
|
| 14 |
-
伪造视觉的训练。
|
| 15 |
-
|
| 16 |
-
这样能验证:
|
| 17 |
-
- 数据集索引(``build_clip_index``)
|
| 18 |
-
- 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参)
|
| 19 |
-
- LIDAR 加载与遮挡过滤
|
| 20 |
-
- Hungarian 匹配 + DETR loss
|
| 21 |
-
- 端到端 forward / GradNorm / PCGrad / 反传
|
| 22 |
-
|
| 23 |
-
但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
from __future__ import annotations
|
| 27 |
-
|
| 28 |
-
import os
|
| 29 |
-
import shutil
|
| 30 |
-
import subprocess
|
| 31 |
-
import sys
|
| 32 |
-
import tarfile
|
| 33 |
-
import urllib.request
|
| 34 |
-
from pathlib import Path
|
| 35 |
-
|
| 36 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 37 |
-
sys.path.insert(0, str(ROOT / "src"))
|
| 38 |
-
|
| 39 |
-
DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos"))
|
| 40 |
-
NV_DOWNLOAD_URL = (
|
| 41 |
-
"https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _print_section(title: str) -> None:
|
| 46 |
-
bar = "=" * 60
|
| 47 |
-
print(f"\n{bar}\n{title}\n{bar}", flush=True)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def step1_download_labels() -> None:
|
| 51 |
-
"""用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。"""
|
| 52 |
-
_print_section("STEP 1 下载真实标签(1 个 clip)")
|
| 53 |
-
DATA_ROOT.mkdir(parents=True, exist_ok=True)
|
| 54 |
-
nv_script = DATA_ROOT / ".nvidia_download.py"
|
| 55 |
-
if not nv_script.exists():
|
| 56 |
-
print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True)
|
| 57 |
-
with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f:
|
| 58 |
-
f.write(r.read())
|
| 59 |
-
# 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载,
|
| 60 |
-
# 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。
|
| 61 |
-
cmd = [
|
| 62 |
-
sys.executable,
|
| 63 |
-
str(nv_script),
|
| 64 |
-
"--odir", str(DATA_ROOT),
|
| 65 |
-
"--file_types", "lidar,hdmap",
|
| 66 |
-
"--workers", "4",
|
| 67 |
-
"--limit", "1",
|
| 68 |
-
]
|
| 69 |
-
print(f"$ {' '.join(cmd)}", flush=True)
|
| 70 |
-
rc = subprocess.call(cmd)
|
| 71 |
-
if rc != 0:
|
| 72 |
-
sys.exit(f"download.py 失败 rc={rc}")
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def _hoist_single_subdir(out_dir: Path) -> None:
|
| 76 |
-
"""若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。"""
|
| 77 |
-
if not out_dir.is_dir():
|
| 78 |
-
return
|
| 79 |
-
subs = [p for p in out_dir.iterdir() if p.is_dir()]
|
| 80 |
-
files = [p for p in out_dir.iterdir() if p.is_file()]
|
| 81 |
-
if len(subs) == 1 and not files:
|
| 82 |
-
child = subs[0]
|
| 83 |
-
for item in child.iterdir():
|
| 84 |
-
dest = out_dir / item.name
|
| 85 |
-
if dest.exists():
|
| 86 |
-
continue
|
| 87 |
-
shutil.move(str(item), str(dest))
|
| 88 |
-
try:
|
| 89 |
-
child.rmdir()
|
| 90 |
-
except OSError:
|
| 91 |
-
pass
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def step2_reorganize_labels() -> str:
|
| 95 |
-
"""把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。
|
| 96 |
-
|
| 97 |
-
返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。
|
| 98 |
-
"""
|
| 99 |
-
_print_section("STEP 2 解压标签到 labels/<clip_id>/<folder> 布局")
|
| 100 |
-
|
| 101 |
-
common_folders = [
|
| 102 |
-
"all_object_info",
|
| 103 |
-
"captions",
|
| 104 |
-
"car_mask_coarse",
|
| 105 |
-
"ftheta_intrinsic",
|
| 106 |
-
"pinhole_intrinsic",
|
| 107 |
-
"pose",
|
| 108 |
-
"vehicle_pose",
|
| 109 |
-
"lidar_raw",
|
| 110 |
-
# HDMap 9 类
|
| 111 |
-
"3d_lanes",
|
| 112 |
-
"3d_lanelines",
|
| 113 |
-
"3d_road_boundaries",
|
| 114 |
-
"3d_wait_lines",
|
| 115 |
-
"3d_crosswalks",
|
| 116 |
-
"3d_road_markings",
|
| 117 |
-
"3d_poles",
|
| 118 |
-
"3d_traffic_lights",
|
| 119 |
-
"3d_traffic_signs",
|
| 120 |
-
]
|
| 121 |
-
|
| 122 |
-
clip_id_full: str | None = None # {clip_id}_{start}_{end}
|
| 123 |
-
clip_id: str | None = None
|
| 124 |
-
|
| 125 |
-
for folder in common_folders:
|
| 126 |
-
src = DATA_ROOT / folder
|
| 127 |
-
if not src.exists():
|
| 128 |
-
print(f" - skip {folder} (not downloaded)", flush=True)
|
| 129 |
-
continue
|
| 130 |
-
tars = sorted(src.glob("*.tar"))
|
| 131 |
-
if not tars:
|
| 132 |
-
print(f" - skip {folder} (no .tar)", flush=True)
|
| 133 |
-
continue
|
| 134 |
-
if clip_id_full is None:
|
| 135 |
-
clip_id_full = tars[0].stem
|
| 136 |
-
clip_id = clip_id_full.rsplit("_", 2)[0]
|
| 137 |
-
print(f" -> chosen clip_id_full = {clip_id_full}", flush=True)
|
| 138 |
-
print(f" -> video / symlink clip_id = {clip_id}", flush=True)
|
| 139 |
-
use_tars = [t for t in tars if t.stem == clip_id_full]
|
| 140 |
-
if not use_tars:
|
| 141 |
-
print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True)
|
| 142 |
-
continue
|
| 143 |
-
tar_path = use_tars[0]
|
| 144 |
-
# 目标目录
|
| 145 |
-
out_dir = DATA_ROOT / "labels" / clip_id_full / folder
|
| 146 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 147 |
-
with tarfile.open(tar_path, "r") as tf:
|
| 148 |
-
tf.extractall(out_dir)
|
| 149 |
-
_hoist_single_subdir(out_dir)
|
| 150 |
-
# 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...)
|
| 151 |
-
_hoist_single_subdir(out_dir)
|
| 152 |
-
# 列几个样例
|
| 153 |
-
members = sorted(out_dir.rglob("*"))[:3]
|
| 154 |
-
for m in members:
|
| 155 |
-
print(f" {m.relative_to(DATA_ROOT)}", flush=True)
|
| 156 |
-
print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True)
|
| 157 |
-
|
| 158 |
-
if clip_id_full is None:
|
| 159 |
-
sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集")
|
| 160 |
-
|
| 161 |
-
# 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是
|
| 162 |
-
# {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。
|
| 163 |
-
short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type]
|
| 164 |
-
if not short_dir.exists():
|
| 165 |
-
try:
|
| 166 |
-
short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True)
|
| 167 |
-
except OSError:
|
| 168 |
-
shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir)
|
| 169 |
-
return clip_id # type: ignore[return-value]
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def step3_make_fake_video(clip_id: str) -> None:
|
| 173 |
-
"""合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。"""
|
| 174 |
-
_print_section("STEP 3 合成占位视频(随机噪声 mp4)")
|
| 175 |
-
import numpy as np
|
| 176 |
-
import cv2
|
| 177 |
-
|
| 178 |
-
syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation"
|
| 179 |
-
syn_dir.mkdir(parents=True, exist_ok=True)
|
| 180 |
-
out_path = syn_dir / f"{clip_id}_0_Sunny.mp4"
|
| 181 |
-
|
| 182 |
-
H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768
|
| 183 |
-
rng = np.random.default_rng(0)
|
| 184 |
-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 185 |
-
writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H))
|
| 186 |
-
if not writer.isOpened():
|
| 187 |
-
sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}")
|
| 188 |
-
for _ in range(T):
|
| 189 |
-
frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)
|
| 190 |
-
writer.write(frame)
|
| 191 |
-
writer.release()
|
| 192 |
-
print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def step4_run_trainer(clip_id: str) -> None:
|
| 196 |
-
"""跑 runner_local --tiny --max_steps 4。"""
|
| 197 |
-
_print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)")
|
| 198 |
-
cmd = [
|
| 199 |
-
sys.executable,
|
| 200 |
-
"-m",
|
| 201 |
-
"wjad.train.runner_local",
|
| 202 |
-
"--config", str(ROOT / "configs" / "default.yaml"),
|
| 203 |
-
"--data_root", str(DATA_ROOT),
|
| 204 |
-
"--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 205 |
-
"--device", "cuda" if _has_cuda() else "cpu",
|
| 206 |
-
"--tiny",
|
| 207 |
-
"--max_steps", "4",
|
| 208 |
-
]
|
| 209 |
-
env = os.environ.copy()
|
| 210 |
-
env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "")
|
| 211 |
-
print(f"$ {' '.join(cmd)}", flush=True)
|
| 212 |
-
rc = subprocess.call(cmd, env=env)
|
| 213 |
-
if rc != 0:
|
| 214 |
-
sys.exit(f"trainer 失败 rc={rc}")
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def _has_cuda() -> bool:
|
| 218 |
-
try:
|
| 219 |
-
import torch
|
| 220 |
-
return torch.cuda.is_available()
|
| 221 |
-
except Exception:
|
| 222 |
-
return False
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
def main() -> None:
|
| 226 |
-
_print_section("WJAD Sandbox Real-Data Tiny Test")
|
| 227 |
-
print(f"DATA_ROOT = {DATA_ROOT}", flush=True)
|
| 228 |
-
step1_download_labels()
|
| 229 |
-
clip_id = step2_reorganize_labels()
|
| 230 |
-
step3_make_fake_video(clip_id)
|
| 231 |
-
step4_run_trainer(clip_id)
|
| 232 |
-
_print_section("DONE")
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
if __name__ == "__main__":
|
| 236 |
-
main()
|
|
|
|
| 1 |
+
"""Sandbox 真实数据微验证脚本。
|
| 2 |
+
|
| 3 |
+
由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17
|
| 4 |
+
个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。
|
| 5 |
+
因此本脚本采用混合方案:
|
| 6 |
+
|
| 7 |
+
1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的
|
| 8 |
+
全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB;
|
| 9 |
+
2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配
|
| 10 |
+
``wjad.data.cosmos_dataset`` 期待的布局;
|
| 11 |
+
3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频
|
| 12 |
+
(文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768);
|
| 13 |
+
4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 +
|
| 14 |
+
伪造视觉的训练。
|
| 15 |
+
|
| 16 |
+
这样能验证:
|
| 17 |
+
- 数据集索引(``build_clip_index``)
|
| 18 |
+
- 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参)
|
| 19 |
+
- LIDAR 加载与遮挡过滤
|
| 20 |
+
- Hungarian 匹配 + DETR loss
|
| 21 |
+
- 端到端 forward / GradNorm / PCGrad / 反传
|
| 22 |
+
|
| 23 |
+
但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import os
|
| 29 |
+
import shutil
|
| 30 |
+
import subprocess
|
| 31 |
+
import sys
|
| 32 |
+
import tarfile
|
| 33 |
+
import urllib.request
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 37 |
+
sys.path.insert(0, str(ROOT / "src"))
|
| 38 |
+
|
| 39 |
+
DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos"))
|
| 40 |
+
NV_DOWNLOAD_URL = (
|
| 41 |
+
"https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _print_section(title: str) -> None:
|
| 46 |
+
bar = "=" * 60
|
| 47 |
+
print(f"\n{bar}\n{title}\n{bar}", flush=True)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def step1_download_labels() -> None:
|
| 51 |
+
"""用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。"""
|
| 52 |
+
_print_section("STEP 1 下载真实标签(1 个 clip)")
|
| 53 |
+
DATA_ROOT.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
nv_script = DATA_ROOT / ".nvidia_download.py"
|
| 55 |
+
if not nv_script.exists():
|
| 56 |
+
print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True)
|
| 57 |
+
with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f:
|
| 58 |
+
f.write(r.read())
|
| 59 |
+
# 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载,
|
| 60 |
+
# 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。
|
| 61 |
+
cmd = [
|
| 62 |
+
sys.executable,
|
| 63 |
+
str(nv_script),
|
| 64 |
+
"--odir", str(DATA_ROOT),
|
| 65 |
+
"--file_types", "lidar,hdmap",
|
| 66 |
+
"--workers", "4",
|
| 67 |
+
"--limit", "1",
|
| 68 |
+
]
|
| 69 |
+
print(f"$ {' '.join(cmd)}", flush=True)
|
| 70 |
+
rc = subprocess.call(cmd)
|
| 71 |
+
if rc != 0:
|
| 72 |
+
sys.exit(f"download.py 失败 rc={rc}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _hoist_single_subdir(out_dir: Path) -> None:
|
| 76 |
+
"""若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。"""
|
| 77 |
+
if not out_dir.is_dir():
|
| 78 |
+
return
|
| 79 |
+
subs = [p for p in out_dir.iterdir() if p.is_dir()]
|
| 80 |
+
files = [p for p in out_dir.iterdir() if p.is_file()]
|
| 81 |
+
if len(subs) == 1 and not files:
|
| 82 |
+
child = subs[0]
|
| 83 |
+
for item in child.iterdir():
|
| 84 |
+
dest = out_dir / item.name
|
| 85 |
+
if dest.exists():
|
| 86 |
+
continue
|
| 87 |
+
shutil.move(str(item), str(dest))
|
| 88 |
+
try:
|
| 89 |
+
child.rmdir()
|
| 90 |
+
except OSError:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def step2_reorganize_labels() -> str:
|
| 95 |
+
"""把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。
|
| 96 |
+
|
| 97 |
+
返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。
|
| 98 |
+
"""
|
| 99 |
+
_print_section("STEP 2 解压标签到 labels/<clip_id>/<folder> 布局")
|
| 100 |
+
|
| 101 |
+
common_folders = [
|
| 102 |
+
"all_object_info",
|
| 103 |
+
"captions",
|
| 104 |
+
"car_mask_coarse",
|
| 105 |
+
"ftheta_intrinsic",
|
| 106 |
+
"pinhole_intrinsic",
|
| 107 |
+
"pose",
|
| 108 |
+
"vehicle_pose",
|
| 109 |
+
"lidar_raw",
|
| 110 |
+
# HDMap 9 类
|
| 111 |
+
"3d_lanes",
|
| 112 |
+
"3d_lanelines",
|
| 113 |
+
"3d_road_boundaries",
|
| 114 |
+
"3d_wait_lines",
|
| 115 |
+
"3d_crosswalks",
|
| 116 |
+
"3d_road_markings",
|
| 117 |
+
"3d_poles",
|
| 118 |
+
"3d_traffic_lights",
|
| 119 |
+
"3d_traffic_signs",
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
clip_id_full: str | None = None # {clip_id}_{start}_{end}
|
| 123 |
+
clip_id: str | None = None
|
| 124 |
+
|
| 125 |
+
for folder in common_folders:
|
| 126 |
+
src = DATA_ROOT / folder
|
| 127 |
+
if not src.exists():
|
| 128 |
+
print(f" - skip {folder} (not downloaded)", flush=True)
|
| 129 |
+
continue
|
| 130 |
+
tars = sorted(src.glob("*.tar"))
|
| 131 |
+
if not tars:
|
| 132 |
+
print(f" - skip {folder} (no .tar)", flush=True)
|
| 133 |
+
continue
|
| 134 |
+
if clip_id_full is None:
|
| 135 |
+
clip_id_full = tars[0].stem
|
| 136 |
+
clip_id = clip_id_full.rsplit("_", 2)[0]
|
| 137 |
+
print(f" -> chosen clip_id_full = {clip_id_full}", flush=True)
|
| 138 |
+
print(f" -> video / symlink clip_id = {clip_id}", flush=True)
|
| 139 |
+
use_tars = [t for t in tars if t.stem == clip_id_full]
|
| 140 |
+
if not use_tars:
|
| 141 |
+
print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True)
|
| 142 |
+
continue
|
| 143 |
+
tar_path = use_tars[0]
|
| 144 |
+
# 目标目录
|
| 145 |
+
out_dir = DATA_ROOT / "labels" / clip_id_full / folder
|
| 146 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 147 |
+
with tarfile.open(tar_path, "r") as tf:
|
| 148 |
+
tf.extractall(out_dir)
|
| 149 |
+
_hoist_single_subdir(out_dir)
|
| 150 |
+
# 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...)
|
| 151 |
+
_hoist_single_subdir(out_dir)
|
| 152 |
+
# 列几个样例
|
| 153 |
+
members = sorted(out_dir.rglob("*"))[:3]
|
| 154 |
+
for m in members:
|
| 155 |
+
print(f" {m.relative_to(DATA_ROOT)}", flush=True)
|
| 156 |
+
print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True)
|
| 157 |
+
|
| 158 |
+
if clip_id_full is None:
|
| 159 |
+
sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集")
|
| 160 |
+
|
| 161 |
+
# 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是
|
| 162 |
+
# {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。
|
| 163 |
+
short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type]
|
| 164 |
+
if not short_dir.exists():
|
| 165 |
+
try:
|
| 166 |
+
short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True)
|
| 167 |
+
except OSError:
|
| 168 |
+
shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir)
|
| 169 |
+
return clip_id # type: ignore[return-value]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def step3_make_fake_video(clip_id: str) -> None:
|
| 173 |
+
"""合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。"""
|
| 174 |
+
_print_section("STEP 3 合成占位视频(随机噪声 mp4)")
|
| 175 |
+
import numpy as np
|
| 176 |
+
import cv2
|
| 177 |
+
|
| 178 |
+
syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation"
|
| 179 |
+
syn_dir.mkdir(parents=True, exist_ok=True)
|
| 180 |
+
out_path = syn_dir / f"{clip_id}_0_Sunny.mp4"
|
| 181 |
+
|
| 182 |
+
H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768
|
| 183 |
+
rng = np.random.default_rng(0)
|
| 184 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 185 |
+
writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H))
|
| 186 |
+
if not writer.isOpened():
|
| 187 |
+
sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}")
|
| 188 |
+
for _ in range(T):
|
| 189 |
+
frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)
|
| 190 |
+
writer.write(frame)
|
| 191 |
+
writer.release()
|
| 192 |
+
print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def step4_run_trainer(clip_id: str) -> None:
|
| 196 |
+
"""跑 runner_local --tiny --max_steps 4。"""
|
| 197 |
+
_print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)")
|
| 198 |
+
cmd = [
|
| 199 |
+
sys.executable,
|
| 200 |
+
"-m",
|
| 201 |
+
"wjad.train.runner_local",
|
| 202 |
+
"--config", str(ROOT / "configs" / "default.yaml"),
|
| 203 |
+
"--data_root", str(DATA_ROOT),
|
| 204 |
+
"--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 205 |
+
"--device", "cuda" if _has_cuda() else "cpu",
|
| 206 |
+
"--tiny",
|
| 207 |
+
"--max_steps", "4",
|
| 208 |
+
]
|
| 209 |
+
env = os.environ.copy()
|
| 210 |
+
env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "")
|
| 211 |
+
print(f"$ {' '.join(cmd)}", flush=True)
|
| 212 |
+
rc = subprocess.call(cmd, env=env)
|
| 213 |
+
if rc != 0:
|
| 214 |
+
sys.exit(f"trainer 失败 rc={rc}")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _has_cuda() -> bool:
|
| 218 |
+
try:
|
| 219 |
+
import torch
|
| 220 |
+
return torch.cuda.is_available()
|
| 221 |
+
except Exception:
|
| 222 |
+
return False
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def main() -> None:
|
| 226 |
+
_print_section("WJAD Sandbox Real-Data Tiny Test")
|
| 227 |
+
print(f"DATA_ROOT = {DATA_ROOT}", flush=True)
|
| 228 |
+
step1_download_labels()
|
| 229 |
+
clip_id = step2_reorganize_labels()
|
| 230 |
+
step3_make_fake_video(clip_id)
|
| 231 |
+
step4_run_trainer(clip_id)
|
| 232 |
+
_print_section("DONE")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
main()
|
scripts/smoke_test.py
CHANGED
|
@@ -1,78 +1,78 @@
|
|
| 1 |
-
"""本地烟囱测试:用随机张量验证 forward + backward。
|
| 2 |
-
|
| 3 |
-
运行:
|
| 4 |
-
python -m scripts.smoke_test
|
| 5 |
-
或:
|
| 6 |
-
python scripts/smoke_test.py
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import sys
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
|
| 14 |
-
# 允许直接运行
|
| 15 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 16 |
-
sys.path.insert(0, str(ROOT / "src"))
|
| 17 |
-
|
| 18 |
-
import torch
|
| 19 |
-
|
| 20 |
-
from wjad.model import E2EAVModel
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def main() -> None:
|
| 24 |
-
torch.manual_seed(0)
|
| 25 |
-
device = "cpu"
|
| 26 |
-
|
| 27 |
-
print("[smoke_test] 构建模型...")
|
| 28 |
-
model = E2EAVModel(
|
| 29 |
-
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 30 |
-
# 减小测试规模以适配 CPU
|
| 31 |
-
num_dense_layers=2,
|
| 32 |
-
num_moe_layers=2,
|
| 33 |
-
num_detection_tokens=64,
|
| 34 |
-
num_extra_tokens=32,
|
| 35 |
-
num_classes=22,
|
| 36 |
-
).to(device)
|
| 37 |
-
|
| 38 |
-
# 切到 sparse 验证 Top-3 路径也通
|
| 39 |
-
model.backbone.set_moe_mode("sparse")
|
| 40 |
-
|
| 41 |
-
B, T = 1, 8
|
| 42 |
-
images = torch.randn(B, T, 3, 384, 1024, device=device)
|
| 43 |
-
ego_6d = torch.zeros(B, T, 6, device=device)
|
| 44 |
-
ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进
|
| 45 |
-
intr_vec = torch.tensor([[
|
| 46 |
-
512.0, 192.0, 1024, 384, # cx, cy, w, h
|
| 47 |
-
0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly
|
| 48 |
-
1.0, # is_bw_poly(与 Cosmos 11 维一致)
|
| 49 |
-
]], device=device)
|
| 50 |
-
extr_6d = torch.zeros(B, 6, device=device)
|
| 51 |
-
|
| 52 |
-
print("[smoke_test] 前向...")
|
| 53 |
-
out = model(images, ego_6d, intr_vec, extr_6d)
|
| 54 |
-
print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
|
| 55 |
-
print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
|
| 56 |
-
print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
|
| 57 |
-
print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
|
| 58 |
-
print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
|
| 59 |
-
print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")
|
| 60 |
-
|
| 61 |
-
# 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss
|
| 62 |
-
loss = (
|
| 63 |
-
out.detection.cls_logits.float().abs().mean()
|
| 64 |
-
+ out.detection.box3d_mu.float().abs().mean()
|
| 65 |
-
+ out.detection.traj_mu.float().abs().mean()
|
| 66 |
-
+ out.control.ego_traj_mu.float().abs().mean()
|
| 67 |
-
)
|
| 68 |
-
print(f"[smoke_test] loss = {loss.item():.6f}")
|
| 69 |
-
loss.backward()
|
| 70 |
-
grad_norm = sum(
|
| 71 |
-
p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None
|
| 72 |
-
)
|
| 73 |
-
print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
|
| 74 |
-
print("[smoke_test] OK")
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
if __name__ == "__main__":
|
| 78 |
-
main()
|
|
|
|
| 1 |
+
"""本地烟囱测试:用随机张量验证 forward + backward。
|
| 2 |
+
|
| 3 |
+
运行:
|
| 4 |
+
python -m scripts.smoke_test
|
| 5 |
+
或:
|
| 6 |
+
python scripts/smoke_test.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# 允许直接运行
|
| 15 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 16 |
+
sys.path.insert(0, str(ROOT / "src"))
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from wjad.model import E2EAVModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main() -> None:
|
| 24 |
+
torch.manual_seed(0)
|
| 25 |
+
device = "cpu"
|
| 26 |
+
|
| 27 |
+
print("[smoke_test] 构建模型...")
|
| 28 |
+
model = E2EAVModel(
|
| 29 |
+
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 30 |
+
# 减小测试规模以适配 CPU
|
| 31 |
+
num_dense_layers=2,
|
| 32 |
+
num_moe_layers=2,
|
| 33 |
+
num_detection_tokens=64,
|
| 34 |
+
num_extra_tokens=32,
|
| 35 |
+
num_classes=22,
|
| 36 |
+
).to(device)
|
| 37 |
+
|
| 38 |
+
# 切到 sparse 验证 Top-3 路径也通
|
| 39 |
+
model.backbone.set_moe_mode("sparse")
|
| 40 |
+
|
| 41 |
+
B, T = 1, 8
|
| 42 |
+
images = torch.randn(B, T, 3, 384, 1024, device=device)
|
| 43 |
+
ego_6d = torch.zeros(B, T, 6, device=device)
|
| 44 |
+
ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进
|
| 45 |
+
intr_vec = torch.tensor([[
|
| 46 |
+
512.0, 192.0, 1024, 384, # cx, cy, w, h
|
| 47 |
+
0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly
|
| 48 |
+
1.0, # is_bw_poly(与 Cosmos 11 维一致)
|
| 49 |
+
]], device=device)
|
| 50 |
+
extr_6d = torch.zeros(B, 6, device=device)
|
| 51 |
+
|
| 52 |
+
print("[smoke_test] 前向...")
|
| 53 |
+
out = model(images, ego_6d, intr_vec, extr_6d)
|
| 54 |
+
print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
|
| 55 |
+
print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
|
| 56 |
+
print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
|
| 57 |
+
print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
|
| 58 |
+
print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
|
| 59 |
+
print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")
|
| 60 |
+
|
| 61 |
+
# 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss
|
| 62 |
+
loss = (
|
| 63 |
+
out.detection.cls_logits.float().abs().mean()
|
| 64 |
+
+ out.detection.box3d_mu.float().abs().mean()
|
| 65 |
+
+ out.detection.traj_mu.float().abs().mean()
|
| 66 |
+
+ out.control.ego_traj_mu.float().abs().mean()
|
| 67 |
+
)
|
| 68 |
+
print(f"[smoke_test] loss = {loss.item():.6f}")
|
| 69 |
+
loss.backward()
|
| 70 |
+
grad_norm = sum(
|
| 71 |
+
p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None
|
| 72 |
+
)
|
| 73 |
+
print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
|
| 74 |
+
print("[smoke_test] OK")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
scripts/smoke_train.py
CHANGED
|
@@ -1,152 +1,152 @@
|
|
| 1 |
-
"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。
|
| 2 |
-
|
| 3 |
-
不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
from __future__ import annotations
|
| 7 |
-
|
| 8 |
-
import os
|
| 9 |
-
import sys
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
|
| 12 |
-
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 13 |
-
|
| 14 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 15 |
-
sys.path.insert(0, str(ROOT / "src"))
|
| 16 |
-
|
| 17 |
-
import logging
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import torch
|
| 21 |
-
|
| 22 |
-
from wjad.model import E2EAVModel
|
| 23 |
-
from wjad.train.trainer import Trainer, TrainerConfig
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _make_dummy_batch(
|
| 27 |
-
B: int = 1,
|
| 28 |
-
T: int = 8,
|
| 29 |
-
H: int = 64,
|
| 30 |
-
W: int = 128,
|
| 31 |
-
num_classes: int = 22,
|
| 32 |
-
num_objects: int = 3,
|
| 33 |
-
) -> dict:
|
| 34 |
-
"""构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
|
| 35 |
-
images = torch.randn(B, T, 3, H, W)
|
| 36 |
-
ego_6d = torch.zeros(B, T, 6)
|
| 37 |
-
intr_vec = torch.tensor([[
|
| 38 |
-
W / 2, H / 2, W, H,
|
| 39 |
-
0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
|
| 40 |
-
1.0,
|
| 41 |
-
]] * B)
|
| 42 |
-
extr_6d = torch.zeros(B, 6)
|
| 43 |
-
ego_future = torch.zeros(B, 24, 3)
|
| 44 |
-
ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
|
| 45 |
-
|
| 46 |
-
targets = []
|
| 47 |
-
for _ in range(B):
|
| 48 |
-
boxes = torch.zeros(num_objects, 7)
|
| 49 |
-
boxes[:, 3:6] = 2.0
|
| 50 |
-
targets.append({
|
| 51 |
-
"labels": torch.randint(1, num_classes, (num_objects,)),
|
| 52 |
-
"boxes": boxes,
|
| 53 |
-
"is_dynamic": torch.ones(num_objects, dtype=torch.long),
|
| 54 |
-
"future_traj": torch.zeros(num_objects, 24, 3),
|
| 55 |
-
"future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
|
| 56 |
-
})
|
| 57 |
-
|
| 58 |
-
return {
|
| 59 |
-
"images": images,
|
| 60 |
-
"ego_6d": ego_6d,
|
| 61 |
-
"intr_vec": intr_vec,
|
| 62 |
-
"extr_6d": extr_6d,
|
| 63 |
-
"ego_future": ego_future,
|
| 64 |
-
"ego_future_valid": ego_future_valid,
|
| 65 |
-
"targets": targets,
|
| 66 |
-
"meta": [{}] * B,
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def main() -> None:
|
| 71 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 72 |
-
torch.manual_seed(0)
|
| 73 |
-
|
| 74 |
-
has_cuda = torch.cuda.is_available()
|
| 75 |
-
device = "cuda" if has_cuda else "cpu"
|
| 76 |
-
if has_cuda:
|
| 77 |
-
# GPU 上跑接近真实规模:full 384x1024 + 完整 18 层
|
| 78 |
-
# a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定
|
| 79 |
-
H, W = 384, 1024
|
| 80 |
-
B = 2
|
| 81 |
-
num_dense, num_moe = 9, 9
|
| 82 |
-
num_det = 1024
|
| 83 |
-
num_extra = 256
|
| 84 |
-
amp = "bf16"
|
| 85 |
-
n_steps = 4
|
| 86 |
-
use_grad_ckpt = True
|
| 87 |
-
else:
|
| 88 |
-
# CPU 上跑极小规模仅做 sanity
|
| 89 |
-
H, W = 64, 128
|
| 90 |
-
B = 1
|
| 91 |
-
num_dense, num_moe = 2, 2
|
| 92 |
-
num_det = 32
|
| 93 |
-
num_extra = 16
|
| 94 |
-
amp = "fp32"
|
| 95 |
-
n_steps = 4
|
| 96 |
-
use_grad_ckpt = False
|
| 97 |
-
|
| 98 |
-
print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
|
| 99 |
-
model = E2EAVModel(
|
| 100 |
-
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 101 |
-
num_dense_layers=num_dense,
|
| 102 |
-
num_moe_layers=num_moe,
|
| 103 |
-
num_detection_tokens=num_det,
|
| 104 |
-
num_control_tokens=24,
|
| 105 |
-
num_ego_tokens=8,
|
| 106 |
-
num_extra_tokens=num_extra,
|
| 107 |
-
num_classes=22,
|
| 108 |
-
image_h=H,
|
| 109 |
-
image_w=W,
|
| 110 |
-
patch_size=16,
|
| 111 |
-
)
|
| 112 |
-
if use_grad_ckpt:
|
| 113 |
-
model.backbone.set_gradient_checkpointing(True)
|
| 114 |
-
# sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可
|
| 115 |
-
# 验证两阶段路径切换。完整训练交给 H100 Jobs。
|
| 116 |
-
model.dinov3.freeze()
|
| 117 |
-
|
| 118 |
-
cfg = TrainerConfig(
|
| 119 |
-
total_steps=n_steps,
|
| 120 |
-
warmup_steps=1,
|
| 121 |
-
base_lr=1e-4,
|
| 122 |
-
log_interval=1,
|
| 123 |
-
stage1_steps=2, # 跑到 stage2 验证切换路径
|
| 124 |
-
stage1_perturb_start=1,
|
| 125 |
-
enable_gradnorm=True,
|
| 126 |
-
enable_pcgrad=True, # 全程启用 PCGrad
|
| 127 |
-
mixed_precision=amp,
|
| 128 |
-
unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可
|
| 129 |
-
)
|
| 130 |
-
trainer = Trainer(model, cfg, num_classes=22, device=device)
|
| 131 |
-
rng = np.random.default_rng(0)
|
| 132 |
-
|
| 133 |
-
if has_cuda:
|
| 134 |
-
torch.cuda.reset_peak_memory_stats()
|
| 135 |
-
|
| 136 |
-
for step in range(n_steps):
|
| 137 |
-
batch = _make_dummy_batch(B=B, H=H, W=W)
|
| 138 |
-
info = trainer.train_step(batch, rng)
|
| 139 |
-
print(
|
| 140 |
-
f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
|
| 141 |
-
f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
|
| 142 |
-
f"weights={[f'{w:.2f}' for w in info['weights']]}"
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
if has_cuda:
|
| 146 |
-
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 147 |
-
print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
|
| 148 |
-
print("[smoke_train] OK")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
if __name__ == "__main__":
|
| 152 |
-
main()
|
|
|
|
| 1 |
+
"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。
|
| 2 |
+
|
| 3 |
+
不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 13 |
+
|
| 14 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 15 |
+
sys.path.insert(0, str(ROOT / "src"))
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from wjad.model import E2EAVModel
|
| 23 |
+
from wjad.train.trainer import Trainer, TrainerConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _make_dummy_batch(
|
| 27 |
+
B: int = 1,
|
| 28 |
+
T: int = 8,
|
| 29 |
+
H: int = 64,
|
| 30 |
+
W: int = 128,
|
| 31 |
+
num_classes: int = 22,
|
| 32 |
+
num_objects: int = 3,
|
| 33 |
+
) -> dict:
|
| 34 |
+
"""构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
|
| 35 |
+
images = torch.randn(B, T, 3, H, W)
|
| 36 |
+
ego_6d = torch.zeros(B, T, 6)
|
| 37 |
+
intr_vec = torch.tensor([[
|
| 38 |
+
W / 2, H / 2, W, H,
|
| 39 |
+
0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
|
| 40 |
+
1.0,
|
| 41 |
+
]] * B)
|
| 42 |
+
extr_6d = torch.zeros(B, 6)
|
| 43 |
+
ego_future = torch.zeros(B, 24, 3)
|
| 44 |
+
ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
|
| 45 |
+
|
| 46 |
+
targets = []
|
| 47 |
+
for _ in range(B):
|
| 48 |
+
boxes = torch.zeros(num_objects, 7)
|
| 49 |
+
boxes[:, 3:6] = 2.0
|
| 50 |
+
targets.append({
|
| 51 |
+
"labels": torch.randint(1, num_classes, (num_objects,)),
|
| 52 |
+
"boxes": boxes,
|
| 53 |
+
"is_dynamic": torch.ones(num_objects, dtype=torch.long),
|
| 54 |
+
"future_traj": torch.zeros(num_objects, 24, 3),
|
| 55 |
+
"future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
|
| 56 |
+
})
|
| 57 |
+
|
| 58 |
+
return {
|
| 59 |
+
"images": images,
|
| 60 |
+
"ego_6d": ego_6d,
|
| 61 |
+
"intr_vec": intr_vec,
|
| 62 |
+
"extr_6d": extr_6d,
|
| 63 |
+
"ego_future": ego_future,
|
| 64 |
+
"ego_future_valid": ego_future_valid,
|
| 65 |
+
"targets": targets,
|
| 66 |
+
"meta": [{}] * B,
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main() -> None:
|
| 71 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 72 |
+
torch.manual_seed(0)
|
| 73 |
+
|
| 74 |
+
has_cuda = torch.cuda.is_available()
|
| 75 |
+
device = "cuda" if has_cuda else "cpu"
|
| 76 |
+
if has_cuda:
|
| 77 |
+
# GPU 上跑接近真实规模:full 384x1024 + 完整 18 层
|
| 78 |
+
# a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定
|
| 79 |
+
H, W = 384, 1024
|
| 80 |
+
B = 2
|
| 81 |
+
num_dense, num_moe = 9, 9
|
| 82 |
+
num_det = 1024
|
| 83 |
+
num_extra = 256
|
| 84 |
+
amp = "bf16"
|
| 85 |
+
n_steps = 4
|
| 86 |
+
use_grad_ckpt = True
|
| 87 |
+
else:
|
| 88 |
+
# CPU 上跑极小规模仅做 sanity
|
| 89 |
+
H, W = 64, 128
|
| 90 |
+
B = 1
|
| 91 |
+
num_dense, num_moe = 2, 2
|
| 92 |
+
num_det = 32
|
| 93 |
+
num_extra = 16
|
| 94 |
+
amp = "fp32"
|
| 95 |
+
n_steps = 4
|
| 96 |
+
use_grad_ckpt = False
|
| 97 |
+
|
| 98 |
+
print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
|
| 99 |
+
model = E2EAVModel(
|
| 100 |
+
dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
|
| 101 |
+
num_dense_layers=num_dense,
|
| 102 |
+
num_moe_layers=num_moe,
|
| 103 |
+
num_detection_tokens=num_det,
|
| 104 |
+
num_control_tokens=24,
|
| 105 |
+
num_ego_tokens=8,
|
| 106 |
+
num_extra_tokens=num_extra,
|
| 107 |
+
num_classes=22,
|
| 108 |
+
image_h=H,
|
| 109 |
+
image_w=W,
|
| 110 |
+
patch_size=16,
|
| 111 |
+
)
|
| 112 |
+
if use_grad_ckpt:
|
| 113 |
+
model.backbone.set_gradient_checkpointing(True)
|
| 114 |
+
# sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可
|
| 115 |
+
# 验证两阶段路径切换。完整训练交给 H100 Jobs。
|
| 116 |
+
model.dinov3.freeze()
|
| 117 |
+
|
| 118 |
+
cfg = TrainerConfig(
|
| 119 |
+
total_steps=n_steps,
|
| 120 |
+
warmup_steps=1,
|
| 121 |
+
base_lr=1e-4,
|
| 122 |
+
log_interval=1,
|
| 123 |
+
stage1_steps=2, # 跑到 stage2 验证切换路径
|
| 124 |
+
stage1_perturb_start=1,
|
| 125 |
+
enable_gradnorm=True,
|
| 126 |
+
enable_pcgrad=True, # 全程启用 PCGrad
|
| 127 |
+
mixed_precision=amp,
|
| 128 |
+
unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可
|
| 129 |
+
)
|
| 130 |
+
trainer = Trainer(model, cfg, num_classes=22, device=device)
|
| 131 |
+
rng = np.random.default_rng(0)
|
| 132 |
+
|
| 133 |
+
if has_cuda:
|
| 134 |
+
torch.cuda.reset_peak_memory_stats()
|
| 135 |
+
|
| 136 |
+
for step in range(n_steps):
|
| 137 |
+
batch = _make_dummy_batch(B=B, H=H, W=W)
|
| 138 |
+
info = trainer.train_step(batch, rng)
|
| 139 |
+
print(
|
| 140 |
+
f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
|
| 141 |
+
f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
|
| 142 |
+
f"weights={[f'{w:.2f}' for w in info['weights']]}"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if has_cuda:
|
| 146 |
+
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 147 |
+
print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
|
| 148 |
+
print("[smoke_train] OK")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
main()
|
scripts/update_deps.py
CHANGED
|
@@ -1,123 +1,123 @@
|
|
| 1 |
-
"""自动把项目依赖升级到 PyPI 最新版。
|
| 2 |
-
|
| 3 |
-
特点:
|
| 4 |
-
- 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
|
| 5 |
-
``project.optional-dependencies``;
|
| 6 |
-
- 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
|
| 7 |
-
- 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
|
| 8 |
-
选项(``--torch-index https://download.pytorch.org/whl/cu124``);
|
| 9 |
-
- 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
|
| 10 |
-
在 HF Sandbox / Jobs 环境中复现。
|
| 11 |
-
|
| 12 |
-
注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
from __future__ import annotations
|
| 16 |
-
|
| 17 |
-
import argparse
|
| 18 |
-
import shutil
|
| 19 |
-
import subprocess
|
| 20 |
-
import sys
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
import tomllib # py3.11+
|
| 25 |
-
except ImportError: # pragma: no cover
|
| 26 |
-
import tomli as tomllib # type: ignore
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
ROOT = Path(__file__).resolve().parent.parent
|
| 30 |
-
PYPROJECT = ROOT / "pyproject.toml"
|
| 31 |
-
LOCK_FILE = ROOT / "requirements.lock.txt"
|
| 32 |
-
TORCH_PKGS = {"torch", "torchvision", "torchaudio"}
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def parse_pyproject() -> tuple[list[str], list[str]]:
|
| 36 |
-
"""返回 (主依赖, dev 依赖) 的纯包名列表。"""
|
| 37 |
-
data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
|
| 38 |
-
main = [
|
| 39 |
-
_strip_spec(d) for d in data.get("project", {}).get("dependencies", [])
|
| 40 |
-
]
|
| 41 |
-
dev = [
|
| 42 |
-
_strip_spec(d)
|
| 43 |
-
for d in data.get("project", {})
|
| 44 |
-
.get("optional-dependencies", {})
|
| 45 |
-
.get("dev", [])
|
| 46 |
-
]
|
| 47 |
-
return main, dev
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _strip_spec(req: str) -> str:
|
| 51 |
-
"""去掉版本约束,只留包名。"""
|
| 52 |
-
name = req.split(";")[0] # 去掉 environment marker
|
| 53 |
-
for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"):
|
| 54 |
-
if sym in name:
|
| 55 |
-
name = name.split(sym)[0]
|
| 56 |
-
return name.strip()
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def run(cmd: list[str], dry_run: bool = False) -> int:
|
| 60 |
-
print("$", " ".join(cmd))
|
| 61 |
-
if dry_run:
|
| 62 |
-
return 0
|
| 63 |
-
return subprocess.call(cmd)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
|
| 67 |
-
base = [sys.executable, "-m", "pip", "install", "--upgrade"]
|
| 68 |
-
if with_pre:
|
| 69 |
-
base.append("--pre")
|
| 70 |
-
if extra_index:
|
| 71 |
-
base += ["--extra-index-url", extra_index]
|
| 72 |
-
rc = run(base + pkgs, dry_run=dry_run)
|
| 73 |
-
if rc != 0:
|
| 74 |
-
sys.exit(rc)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def main() -> None:
|
| 78 |
-
parser = argparse.ArgumentParser()
|
| 79 |
-
parser.add_argument(
|
| 80 |
-
"--torch-index",
|
| 81 |
-
default=None,
|
| 82 |
-
help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
|
| 83 |
-
)
|
| 84 |
-
parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
|
| 85 |
-
parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
|
| 86 |
-
parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
|
| 87 |
-
args = parser.parse_args()
|
| 88 |
-
|
| 89 |
-
if not PYPROJECT.exists():
|
| 90 |
-
print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
|
| 91 |
-
sys.exit(1)
|
| 92 |
-
|
| 93 |
-
main_deps, dev_deps = parse_pyproject()
|
| 94 |
-
|
| 95 |
-
# 把 torch 系列单独处理(用专用索引)
|
| 96 |
-
torch_deps = [p for p in main_deps if p in TORCH_PKGS]
|
| 97 |
-
other_deps = [p for p in main_deps if p not in TORCH_PKGS]
|
| 98 |
-
|
| 99 |
-
print(f"[update_deps] 升级 pip / setuptools / wheel ...")
|
| 100 |
-
upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)
|
| 101 |
-
|
| 102 |
-
if torch_deps:
|
| 103 |
-
print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
|
| 104 |
-
upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 105 |
-
|
| 106 |
-
if other_deps:
|
| 107 |
-
print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
|
| 108 |
-
upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 109 |
-
|
| 110 |
-
if dev_deps and not args.no_dev:
|
| 111 |
-
print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
|
| 112 |
-
upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 113 |
-
|
| 114 |
-
print("[update_deps] 写入锁定文件 ...")
|
| 115 |
-
if not args.dry_run:
|
| 116 |
-
with open(LOCK_FILE, "w", encoding="utf-8") as f:
|
| 117 |
-
subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True)
|
| 118 |
-
print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
|
| 119 |
-
print("[update_deps] OK")
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
if __name__ == "__main__":
|
| 123 |
-
main()
|
|
|
|
| 1 |
+
"""自动把项目依赖升级到 PyPI 最新版。
|
| 2 |
+
|
| 3 |
+
特点:
|
| 4 |
+
- 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与
|
| 5 |
+
``project.optional-dependencies``;
|
| 6 |
+
- 直接调用 ``pip install --upgrade <pkg>`` 把所有第三方依赖升级;
|
| 7 |
+
- 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL
|
| 8 |
+
选项(``--torch-index https://download.pytorch.org/whl/cu124``);
|
| 9 |
+
- 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于
|
| 10 |
+
在 HF Sandbox / Jobs 环境中复现。
|
| 11 |
+
|
| 12 |
+
注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import shutil
|
| 19 |
+
import subprocess
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
import tomllib # py3.11+
|
| 25 |
+
except ImportError: # pragma: no cover
|
| 26 |
+
import tomli as tomllib # type: ignore
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 30 |
+
PYPROJECT = ROOT / "pyproject.toml"
|
| 31 |
+
LOCK_FILE = ROOT / "requirements.lock.txt"
|
| 32 |
+
TORCH_PKGS = {"torch", "torchvision", "torchaudio"}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def parse_pyproject() -> tuple[list[str], list[str]]:
|
| 36 |
+
"""返回 (主依赖, dev 依赖) 的纯包名列表。"""
|
| 37 |
+
data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8"))
|
| 38 |
+
main = [
|
| 39 |
+
_strip_spec(d) for d in data.get("project", {}).get("dependencies", [])
|
| 40 |
+
]
|
| 41 |
+
dev = [
|
| 42 |
+
_strip_spec(d)
|
| 43 |
+
for d in data.get("project", {})
|
| 44 |
+
.get("optional-dependencies", {})
|
| 45 |
+
.get("dev", [])
|
| 46 |
+
]
|
| 47 |
+
return main, dev
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _strip_spec(req: str) -> str:
|
| 51 |
+
"""去掉版本约束,只留包名。"""
|
| 52 |
+
name = req.split(";")[0] # 去掉 environment marker
|
| 53 |
+
for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"):
|
| 54 |
+
if sym in name:
|
| 55 |
+
name = name.split(sym)[0]
|
| 56 |
+
return name.strip()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def run(cmd: list[str], dry_run: bool = False) -> int:
|
| 60 |
+
print("$", " ".join(cmd))
|
| 61 |
+
if dry_run:
|
| 62 |
+
return 0
|
| 63 |
+
return subprocess.call(cmd)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None:
|
| 67 |
+
base = [sys.executable, "-m", "pip", "install", "--upgrade"]
|
| 68 |
+
if with_pre:
|
| 69 |
+
base.append("--pre")
|
| 70 |
+
if extra_index:
|
| 71 |
+
base += ["--extra-index-url", extra_index]
|
| 72 |
+
rc = run(base + pkgs, dry_run=dry_run)
|
| 73 |
+
if rc != 0:
|
| 74 |
+
sys.exit(rc)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def main() -> None:
|
| 78 |
+
parser = argparse.ArgumentParser()
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--torch-index",
|
| 81 |
+
default=None,
|
| 82 |
+
help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖")
|
| 85 |
+
parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行")
|
| 86 |
+
parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release")
|
| 87 |
+
args = parser.parse_args()
|
| 88 |
+
|
| 89 |
+
if not PYPROJECT.exists():
|
| 90 |
+
print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr)
|
| 91 |
+
sys.exit(1)
|
| 92 |
+
|
| 93 |
+
main_deps, dev_deps = parse_pyproject()
|
| 94 |
+
|
| 95 |
+
# 把 torch 系列单独处理(用专用索引)
|
| 96 |
+
torch_deps = [p for p in main_deps if p in TORCH_PKGS]
|
| 97 |
+
other_deps = [p for p in main_deps if p not in TORCH_PKGS]
|
| 98 |
+
|
| 99 |
+
print(f"[update_deps] 升级 pip / setuptools / wheel ...")
|
| 100 |
+
upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run)
|
| 101 |
+
|
| 102 |
+
if torch_deps:
|
| 103 |
+
print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...")
|
| 104 |
+
upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 105 |
+
|
| 106 |
+
if other_deps:
|
| 107 |
+
print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...")
|
| 108 |
+
upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 109 |
+
|
| 110 |
+
if dev_deps and not args.no_dev:
|
| 111 |
+
print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...")
|
| 112 |
+
upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre)
|
| 113 |
+
|
| 114 |
+
print("[update_deps] 写入锁定文件 ...")
|
| 115 |
+
if not args.dry_run:
|
| 116 |
+
with open(LOCK_FILE, "w", encoding="utf-8") as f:
|
| 117 |
+
subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True)
|
| 118 |
+
print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}")
|
| 119 |
+
print("[update_deps] OK")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
main()
|
src/wjad/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
"""WJAD: 端到端自动驾驶模型主包。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
__version__ = "0.1.0"
|
|
|
|
| 1 |
+
"""WJAD: 端到端自动驾驶模型主包。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
__version__ = "0.1.0"
|
src/wjad/backbone/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
"""18 层主干。"""
|
| 2 |
-
|
| 3 |
-
from .backbone import Backbone, BackboneOutput
|
| 4 |
-
from .blocks import DenseBlock, MoEBlockWithAttn
|
| 5 |
-
|
| 6 |
-
__all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"]
|
|
|
|
| 1 |
+
"""18 层主干。"""
|
| 2 |
+
|
| 3 |
+
from .backbone import Backbone, BackboneOutput
|
| 4 |
+
from .blocks import DenseBlock, MoEBlockWithAttn
|
| 5 |
+
|
| 6 |
+
__all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"]
|
src/wjad/backbone/backbone.py
CHANGED
|
@@ -1,110 +1,110 @@
|
|
| 1 |
-
"""18 层主干:前 9 Dense + 后 9 MoE。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
from dataclasses import dataclass, field
|
| 6 |
-
from typing import Optional
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.utils.checkpoint as cp
|
| 11 |
-
|
| 12 |
-
from ..modules.moe import MoEStats
|
| 13 |
-
from .blocks import DenseBlock, MoEBlockWithAttn
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
@dataclass
|
| 17 |
-
class BackboneOutput:
|
| 18 |
-
"""主干输出。"""
|
| 19 |
-
|
| 20 |
-
hidden_states: torch.Tensor # [B, N, D]
|
| 21 |
-
moe_stats: list[MoEStats] = field(default_factory=list)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Backbone(nn.Module):
|
| 25 |
-
"""端到端主干。
|
| 26 |
-
|
| 27 |
-
输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用
|
| 28 |
-
可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
dim: int = 768,
|
| 34 |
-
num_heads: int = 12,
|
| 35 |
-
ffn_mult: int = 4,
|
| 36 |
-
num_dense_layers: int = 9,
|
| 37 |
-
num_moe_layers: int = 9,
|
| 38 |
-
num_routed: int = 7,
|
| 39 |
-
num_shared: int = 1,
|
| 40 |
-
topk: int = 3,
|
| 41 |
-
dropout: float = 0.0,
|
| 42 |
-
) -> None:
|
| 43 |
-
super().__init__()
|
| 44 |
-
self.dim = dim
|
| 45 |
-
self.num_heads = num_heads
|
| 46 |
-
self.num_dense_layers = num_dense_layers
|
| 47 |
-
self.num_moe_layers = num_moe_layers
|
| 48 |
-
|
| 49 |
-
self.dense_layers = nn.ModuleList([
|
| 50 |
-
DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
|
| 51 |
-
for _ in range(num_dense_layers)
|
| 52 |
-
])
|
| 53 |
-
self.moe_layers = nn.ModuleList([
|
| 54 |
-
MoEBlockWithAttn(
|
| 55 |
-
dim,
|
| 56 |
-
num_heads,
|
| 57 |
-
num_routed=num_routed,
|
| 58 |
-
num_shared=num_shared,
|
| 59 |
-
topk=topk,
|
| 60 |
-
ffn_mult=ffn_mult,
|
| 61 |
-
dropout=dropout,
|
| 62 |
-
)
|
| 63 |
-
for _ in range(num_moe_layers)
|
| 64 |
-
])
|
| 65 |
-
self.final_norm = nn.LayerNorm(dim)
|
| 66 |
-
# 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存
|
| 67 |
-
self.gradient_checkpointing = False
|
| 68 |
-
|
| 69 |
-
def set_gradient_checkpointing(self, enabled: bool) -> None:
|
| 70 |
-
"""开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。"""
|
| 71 |
-
self.gradient_checkpointing = enabled
|
| 72 |
-
|
| 73 |
-
def set_moe_mode(self, mode: str) -> None:
|
| 74 |
-
"""切换所有 MoE 层模式('dense' / 'sparse')。"""
|
| 75 |
-
for blk in self.moe_layers:
|
| 76 |
-
blk.set_mode(mode)
|
| 77 |
-
|
| 78 |
-
def set_router_temperature(self, t: float) -> None:
|
| 79 |
-
for blk in self.moe_layers:
|
| 80 |
-
blk.set_temperature(t)
|
| 81 |
-
|
| 82 |
-
def forward(
|
| 83 |
-
self,
|
| 84 |
-
x: torch.Tensor,
|
| 85 |
-
rope_cos: Optional[torch.Tensor] = None,
|
| 86 |
-
rope_sin: Optional[torch.Tensor] = None,
|
| 87 |
-
visual_slice: Optional[tuple[int, int]] = None,
|
| 88 |
-
) -> BackboneOutput:
|
| 89 |
-
moe_stats: list[MoEStats] = []
|
| 90 |
-
use_ckpt = self.gradient_checkpointing and self.training
|
| 91 |
-
|
| 92 |
-
for blk in self.dense_layers:
|
| 93 |
-
if use_ckpt:
|
| 94 |
-
x = cp.checkpoint(
|
| 95 |
-
blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
|
| 96 |
-
)
|
| 97 |
-
else:
|
| 98 |
-
x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 99 |
-
|
| 100 |
-
for blk in self.moe_layers:
|
| 101 |
-
if use_ckpt:
|
| 102 |
-
x, stats = cp.checkpoint(
|
| 103 |
-
blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
|
| 104 |
-
)
|
| 105 |
-
else:
|
| 106 |
-
x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 107 |
-
moe_stats.append(stats)
|
| 108 |
-
|
| 109 |
-
x = self.final_norm(x)
|
| 110 |
-
return BackboneOutput(hidden_states=x, moe_stats=moe_stats)
|
|
|
|
| 1 |
+
"""18 层主干:前 9 Dense + 后 9 MoE。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.utils.checkpoint as cp
|
| 11 |
+
|
| 12 |
+
from ..modules.moe import MoEStats
|
| 13 |
+
from .blocks import DenseBlock, MoEBlockWithAttn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class BackboneOutput:
|
| 18 |
+
"""主干输出。"""
|
| 19 |
+
|
| 20 |
+
hidden_states: torch.Tensor # [B, N, D]
|
| 21 |
+
moe_stats: list[MoEStats] = field(default_factory=list)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Backbone(nn.Module):
|
| 25 |
+
"""端到端主干。
|
| 26 |
+
|
| 27 |
+
输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用
|
| 28 |
+
可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
dim: int = 768,
|
| 34 |
+
num_heads: int = 12,
|
| 35 |
+
ffn_mult: int = 4,
|
| 36 |
+
num_dense_layers: int = 9,
|
| 37 |
+
num_moe_layers: int = 9,
|
| 38 |
+
num_routed: int = 7,
|
| 39 |
+
num_shared: int = 1,
|
| 40 |
+
topk: int = 3,
|
| 41 |
+
dropout: float = 0.0,
|
| 42 |
+
) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.dim = dim
|
| 45 |
+
self.num_heads = num_heads
|
| 46 |
+
self.num_dense_layers = num_dense_layers
|
| 47 |
+
self.num_moe_layers = num_moe_layers
|
| 48 |
+
|
| 49 |
+
self.dense_layers = nn.ModuleList([
|
| 50 |
+
DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
|
| 51 |
+
for _ in range(num_dense_layers)
|
| 52 |
+
])
|
| 53 |
+
self.moe_layers = nn.ModuleList([
|
| 54 |
+
MoEBlockWithAttn(
|
| 55 |
+
dim,
|
| 56 |
+
num_heads,
|
| 57 |
+
num_routed=num_routed,
|
| 58 |
+
num_shared=num_shared,
|
| 59 |
+
topk=topk,
|
| 60 |
+
ffn_mult=ffn_mult,
|
| 61 |
+
dropout=dropout,
|
| 62 |
+
)
|
| 63 |
+
for _ in range(num_moe_layers)
|
| 64 |
+
])
|
| 65 |
+
self.final_norm = nn.LayerNorm(dim)
|
| 66 |
+
# 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存
|
| 67 |
+
self.gradient_checkpointing = False
|
| 68 |
+
|
| 69 |
+
def set_gradient_checkpointing(self, enabled: bool) -> None:
|
| 70 |
+
"""开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。"""
|
| 71 |
+
self.gradient_checkpointing = enabled
|
| 72 |
+
|
| 73 |
+
def set_moe_mode(self, mode: str) -> None:
|
| 74 |
+
"""切换所有 MoE 层模式('dense' / 'sparse')。"""
|
| 75 |
+
for blk in self.moe_layers:
|
| 76 |
+
blk.set_mode(mode)
|
| 77 |
+
|
| 78 |
+
def set_router_temperature(self, t: float) -> None:
|
| 79 |
+
for blk in self.moe_layers:
|
| 80 |
+
blk.set_temperature(t)
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
x: torch.Tensor,
|
| 85 |
+
rope_cos: Optional[torch.Tensor] = None,
|
| 86 |
+
rope_sin: Optional[torch.Tensor] = None,
|
| 87 |
+
visual_slice: Optional[tuple[int, int]] = None,
|
| 88 |
+
) -> BackboneOutput:
|
| 89 |
+
moe_stats: list[MoEStats] = []
|
| 90 |
+
use_ckpt = self.gradient_checkpointing and self.training
|
| 91 |
+
|
| 92 |
+
for blk in self.dense_layers:
|
| 93 |
+
if use_ckpt:
|
| 94 |
+
x = cp.checkpoint(
|
| 95 |
+
blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 99 |
+
|
| 100 |
+
for blk in self.moe_layers:
|
| 101 |
+
if use_ckpt:
|
| 102 |
+
x, stats = cp.checkpoint(
|
| 103 |
+
blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 107 |
+
moe_stats.append(stats)
|
| 108 |
+
|
| 109 |
+
x = self.final_norm(x)
|
| 110 |
+
return BackboneOutput(hidden_states=x, moe_stats=moe_stats)
|
src/wjad/backbone/blocks.py
CHANGED
|
@@ -1,79 +1,79 @@
|
|
| 1 |
-
"""主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
from typing import Optional
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
|
| 10 |
-
from ..modules.ffn import SwiGLUFFN
|
| 11 |
-
from ..modules.gate_attention import GateSelfAttention
|
| 12 |
-
from ..modules.moe import MoEBlock, MoEStats
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class DenseBlock(nn.Module):
|
| 16 |
-
"""PreNorm GateSelfAttention + PreNorm SwiGLU FFN。"""
|
| 17 |
-
|
| 18 |
-
def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None:
|
| 19 |
-
super().__init__()
|
| 20 |
-
self.norm1 = nn.LayerNorm(dim)
|
| 21 |
-
self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
|
| 22 |
-
self.norm2 = nn.LayerNorm(dim)
|
| 23 |
-
self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)
|
| 24 |
-
|
| 25 |
-
def forward(
|
| 26 |
-
self,
|
| 27 |
-
x: torch.Tensor,
|
| 28 |
-
rope_cos: Optional[torch.Tensor] = None,
|
| 29 |
-
rope_sin: Optional[torch.Tensor] = None,
|
| 30 |
-
visual_slice: Optional[tuple[int, int]] = None,
|
| 31 |
-
) -> torch.Tensor:
|
| 32 |
-
x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 33 |
-
x = x + self.ffn(self.norm2(x))
|
| 34 |
-
return x
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class MoEBlockWithAttn(nn.Module):
|
| 38 |
-
"""PreNorm GateSelfAttention + PreNorm MoE FFN。"""
|
| 39 |
-
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
dim: int,
|
| 43 |
-
num_heads: int,
|
| 44 |
-
num_routed: int = 7,
|
| 45 |
-
num_shared: int = 1,
|
| 46 |
-
topk: int = 3,
|
| 47 |
-
ffn_mult: int = 4,
|
| 48 |
-
dropout: float = 0.0,
|
| 49 |
-
) -> None:
|
| 50 |
-
super().__init__()
|
| 51 |
-
self.norm1 = nn.LayerNorm(dim)
|
| 52 |
-
self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
|
| 53 |
-
self.norm2 = nn.LayerNorm(dim)
|
| 54 |
-
self.moe = MoEBlock(
|
| 55 |
-
dim,
|
| 56 |
-
num_routed=num_routed,
|
| 57 |
-
num_shared=num_shared,
|
| 58 |
-
topk=topk,
|
| 59 |
-
ffn_mult=ffn_mult,
|
| 60 |
-
dropout=dropout,
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
def set_mode(self, mode: str) -> None:
|
| 64 |
-
self.moe.set_mode(mode)
|
| 65 |
-
|
| 66 |
-
def set_temperature(self, t: float) -> None:
|
| 67 |
-
self.moe.set_temperature(t)
|
| 68 |
-
|
| 69 |
-
def forward(
|
| 70 |
-
self,
|
| 71 |
-
x: torch.Tensor,
|
| 72 |
-
rope_cos: Optional[torch.Tensor] = None,
|
| 73 |
-
rope_sin: Optional[torch.Tensor] = None,
|
| 74 |
-
visual_slice: Optional[tuple[int, int]] = None,
|
| 75 |
-
) -> tuple[torch.Tensor, MoEStats]:
|
| 76 |
-
x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 77 |
-
moe_out, stats = self.moe(self.norm2(x))
|
| 78 |
-
x = x + moe_out
|
| 79 |
-
return x, stats
|
|
|
|
| 1 |
+
"""主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from ..modules.ffn import SwiGLUFFN
|
| 11 |
+
from ..modules.gate_attention import GateSelfAttention
|
| 12 |
+
from ..modules.moe import MoEBlock, MoEStats
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DenseBlock(nn.Module):
|
| 16 |
+
"""PreNorm GateSelfAttention + PreNorm SwiGLU FFN。"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None:
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 21 |
+
self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
|
| 22 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 23 |
+
self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)
|
| 24 |
+
|
| 25 |
+
def forward(
|
| 26 |
+
self,
|
| 27 |
+
x: torch.Tensor,
|
| 28 |
+
rope_cos: Optional[torch.Tensor] = None,
|
| 29 |
+
rope_sin: Optional[torch.Tensor] = None,
|
| 30 |
+
visual_slice: Optional[tuple[int, int]] = None,
|
| 31 |
+
) -> torch.Tensor:
|
| 32 |
+
x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 33 |
+
x = x + self.ffn(self.norm2(x))
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MoEBlockWithAttn(nn.Module):
|
| 38 |
+
"""PreNorm GateSelfAttention + PreNorm MoE FFN。"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
dim: int,
|
| 43 |
+
num_heads: int,
|
| 44 |
+
num_routed: int = 7,
|
| 45 |
+
num_shared: int = 1,
|
| 46 |
+
topk: int = 3,
|
| 47 |
+
ffn_mult: int = 4,
|
| 48 |
+
dropout: float = 0.0,
|
| 49 |
+
) -> None:
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 52 |
+
self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
|
| 53 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 54 |
+
self.moe = MoEBlock(
|
| 55 |
+
dim,
|
| 56 |
+
num_routed=num_routed,
|
| 57 |
+
num_shared=num_shared,
|
| 58 |
+
topk=topk,
|
| 59 |
+
ffn_mult=ffn_mult,
|
| 60 |
+
dropout=dropout,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def set_mode(self, mode: str) -> None:
|
| 64 |
+
self.moe.set_mode(mode)
|
| 65 |
+
|
| 66 |
+
def set_temperature(self, t: float) -> None:
|
| 67 |
+
self.moe.set_temperature(t)
|
| 68 |
+
|
| 69 |
+
def forward(
|
| 70 |
+
self,
|
| 71 |
+
x: torch.Tensor,
|
| 72 |
+
rope_cos: Optional[torch.Tensor] = None,
|
| 73 |
+
rope_sin: Optional[torch.Tensor] = None,
|
| 74 |
+
visual_slice: Optional[tuple[int, int]] = None,
|
| 75 |
+
) -> tuple[torch.Tensor, MoEStats]:
|
| 76 |
+
x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 77 |
+
moe_out, stats = self.moe(self.norm2(x))
|
| 78 |
+
x = x + moe_out
|
| 79 |
+
return x, stats
|
src/wjad/calibration/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
"""在线校准模块。"""
|
| 2 |
-
|
| 3 |
-
from .online_calib import OnlineCalibration, CalibrationOutput
|
| 4 |
-
|
| 5 |
-
__all__ = ["OnlineCalibration", "CalibrationOutput"]
|
|
|
|
| 1 |
+
"""在线校准模块。"""
|
| 2 |
+
|
| 3 |
+
from .online_calib import OnlineCalibration, CalibrationOutput
|
| 4 |
+
|
| 5 |
+
__all__ = ["OnlineCalibration", "CalibrationOutput"]
|
src/wjad/calibration/online_calib.py
CHANGED
|
@@ -1,196 +1,196 @@
|
|
| 1 |
-
"""在线校准网络。
|
| 2 |
-
|
| 3 |
-
输入
|
| 4 |
-
- DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。
|
| 5 |
-
- 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。
|
| 6 |
-
- f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``);
|
| 7 |
-
README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。
|
| 8 |
-
- 相机外参 6D ``[B, 6]``。
|
| 9 |
-
|
| 10 |
-
流程
|
| 11 |
-
- 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外
|
| 12 |
-
条件 token 与 256 个可学习 query token 拼接。
|
| 13 |
-
- 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。
|
| 14 |
-
- 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出
|
| 15 |
-
symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。
|
| 16 |
-
|
| 17 |
-
输出
|
| 18 |
-
- ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、
|
| 19 |
-
``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
from __future__ import annotations
|
| 23 |
-
|
| 24 |
-
from dataclasses import dataclass
|
| 25 |
-
|
| 26 |
-
import torch
|
| 27 |
-
import torch.nn as nn
|
| 28 |
-
|
| 29 |
-
from ..modules.gate_attention import GateCrossAttention, GateSelfAttention
|
| 30 |
-
from ..modules.learned_pe import LearnedTokenPE
|
| 31 |
-
from ..modules.normalization import symexp, symlog
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
@dataclass
|
| 35 |
-
class CalibrationOutput:
|
| 36 |
-
"""校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。"""
|
| 37 |
-
|
| 38 |
-
ego_residual: torch.Tensor # [B, 8, 6]
|
| 39 |
-
intr_residual: torch.Tensor # [B, intr_dim]
|
| 40 |
-
extr_residual: torch.Tensor # [B, 6]
|
| 41 |
-
corrected_ego: torch.Tensor # [B, 8, 6],真实空间
|
| 42 |
-
corrected_intr: torch.Tensor # [B, intr_dim]
|
| 43 |
-
corrected_extr: torch.Tensor # [B, 6]
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class _CalibBlock(nn.Module):
|
| 47 |
-
"""单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。"""
|
| 48 |
-
|
| 49 |
-
def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None:
|
| 50 |
-
super().__init__()
|
| 51 |
-
self.cross_norm = nn.LayerNorm(dim)
|
| 52 |
-
self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads)
|
| 53 |
-
self.cross_drop = nn.Dropout(0.0)
|
| 54 |
-
self.self_blocks = nn.ModuleList()
|
| 55 |
-
for _ in range(num_self):
|
| 56 |
-
self.self_blocks.append(
|
| 57 |
-
nn.ModuleDict({
|
| 58 |
-
"norm": nn.LayerNorm(dim),
|
| 59 |
-
"attn": GateSelfAttention(dim, num_heads=num_heads),
|
| 60 |
-
})
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
|
| 64 |
-
# 1) Cross
|
| 65 |
-
q = q + self.cross(self.cross_norm(q), kv)
|
| 66 |
-
# 2) Self ×2
|
| 67 |
-
for blk in self.self_blocks:
|
| 68 |
-
q = q + blk["attn"](blk["norm"](q))
|
| 69 |
-
return q
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class OnlineCalibration(nn.Module):
|
| 73 |
-
def __init__(
|
| 74 |
-
self,
|
| 75 |
-
dino_dim: int = 768,
|
| 76 |
-
hidden_dim: int = 256,
|
| 77 |
-
num_query_tokens: int = 256,
|
| 78 |
-
num_blocks: int = 2,
|
| 79 |
-
num_self_attn_per_block: int = 2,
|
| 80 |
-
num_heads: int = 8,
|
| 81 |
-
residual_range: float = 0.1,
|
| 82 |
-
ego_dim: int = 6, # 3 平移 + 3 轴角
|
| 83 |
-
intr_dim: int = 11,
|
| 84 |
-
extr_dim: int = 6,
|
| 85 |
-
num_history_frames: int = 8,
|
| 86 |
-
init_zero_output: bool = True,
|
| 87 |
-
) -> None:
|
| 88 |
-
super().__init__()
|
| 89 |
-
self.hidden_dim = hidden_dim
|
| 90 |
-
self.residual_range = residual_range
|
| 91 |
-
self.ego_dim = ego_dim
|
| 92 |
-
self.intr_dim = intr_dim
|
| 93 |
-
self.extr_dim = extr_dim
|
| 94 |
-
self.num_history_frames = num_history_frames
|
| 95 |
-
|
| 96 |
-
# 256 可学习 query token
|
| 97 |
-
self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim))
|
| 98 |
-
nn.init.trunc_normal_(self.query_tokens, std=0.02)
|
| 99 |
-
self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim)
|
| 100 |
-
|
| 101 |
-
# 条件 token 编码(symlog 空间)
|
| 102 |
-
self.ego_proj = nn.Linear(ego_dim, hidden_dim)
|
| 103 |
-
self.intr_proj = nn.Linear(intr_dim, hidden_dim)
|
| 104 |
-
self.extr_proj = nn.Linear(extr_dim, hidden_dim)
|
| 105 |
-
# 条件 token 也加可学习 PE(与 query 区分)
|
| 106 |
-
num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr
|
| 107 |
-
self.cond_pe = LearnedTokenPE(num_cond, hidden_dim)
|
| 108 |
-
self.num_cond = num_cond
|
| 109 |
-
self.num_query = num_query_tokens
|
| 110 |
-
|
| 111 |
-
# KV 上下文:DINOv3 patch 特征投影到 hidden_dim
|
| 112 |
-
self.kv_proj = nn.Linear(dino_dim, hidden_dim)
|
| 113 |
-
self.kv_norm = nn.LayerNorm(hidden_dim)
|
| 114 |
-
|
| 115 |
-
# 校准 block × num_blocks
|
| 116 |
-
self.blocks = nn.ModuleList([
|
| 117 |
-
_CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block)
|
| 118 |
-
for _ in range(num_blocks)
|
| 119 |
-
])
|
| 120 |
-
|
| 121 |
-
# 输出 MLP:取 last token -> 残差向量
|
| 122 |
-
residual_total = num_history_frames * ego_dim + intr_dim + extr_dim
|
| 123 |
-
self.out_norm = nn.LayerNorm(hidden_dim)
|
| 124 |
-
self.out_mlp = nn.Sequential(
|
| 125 |
-
nn.Linear(hidden_dim, hidden_dim),
|
| 126 |
-
nn.GELU(),
|
| 127 |
-
nn.Linear(hidden_dim, residual_total),
|
| 128 |
-
)
|
| 129 |
-
# 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0
|
| 130 |
-
if init_zero_output:
|
| 131 |
-
nn.init.zeros_(self.out_mlp[-1].weight)
|
| 132 |
-
nn.init.zeros_(self.out_mlp[-1].bias)
|
| 133 |
-
|
| 134 |
-
def forward(
|
| 135 |
-
self,
|
| 136 |
-
dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino]
|
| 137 |
-
ego_raw: torch.Tensor, # [B, 8, 6] 真实空间
|
| 138 |
-
intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致
|
| 139 |
-
extr_raw: torch.Tensor, # [B, 6]
|
| 140 |
-
) -> CalibrationOutput:
|
| 141 |
-
b = dino_feats.shape[0]
|
| 142 |
-
if intr_raw.shape[-1] != self.intr_dim:
|
| 143 |
-
raise ValueError(
|
| 144 |
-
f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。"
|
| 145 |
-
f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。"
|
| 146 |
-
)
|
| 147 |
-
# === 上下文 K/V ===
|
| 148 |
-
# 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D]
|
| 149 |
-
kv = dino_feats.reshape(b, -1, dino_feats.shape[-1])
|
| 150 |
-
kv = self.kv_norm(self.kv_proj(kv))
|
| 151 |
-
|
| 152 |
-
# === 条件 token(symlog 空间)===
|
| 153 |
-
ego_sym = symlog(ego_raw)
|
| 154 |
-
intr_sym = symlog(intr_raw)
|
| 155 |
-
extr_sym = symlog(extr_raw)
|
| 156 |
-
|
| 157 |
-
ego_tok = self.ego_proj(ego_sym) # [B, 8, D]
|
| 158 |
-
intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D]
|
| 159 |
-
extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D]
|
| 160 |
-
cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D]
|
| 161 |
-
cond = self.cond_pe(cond)
|
| 162 |
-
|
| 163 |
-
# === 拼接 query token ===
|
| 164 |
-
q = self.query_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 165 |
-
q = self.query_pe(q)
|
| 166 |
-
seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D]
|
| 167 |
-
|
| 168 |
-
# === 6 层 block ===
|
| 169 |
-
for blk in self.blocks:
|
| 170 |
-
seq = blk(seq, kv)
|
| 171 |
-
|
| 172 |
-
# === 取 last token ===
|
| 173 |
-
last = self.out_norm(seq[:, -1, :]) # [B, D]
|
| 174 |
-
residual_flat = self.out_mlp(last)
|
| 175 |
-
# Tanh + 缩放
|
| 176 |
-
residual_flat = torch.tanh(residual_flat) * self.residual_range
|
| 177 |
-
|
| 178 |
-
# 拆分
|
| 179 |
-
n_ego = self.num_history_frames * self.ego_dim
|
| 180 |
-
ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim)
|
| 181 |
-
intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim]
|
| 182 |
-
extr_res = residual_flat[:, n_ego + self.intr_dim :]
|
| 183 |
-
|
| 184 |
-
# symlog 空间叠加 + symexp 还原
|
| 185 |
-
corrected_ego = symexp(symlog(ego_raw) + ego_res)
|
| 186 |
-
corrected_intr = symexp(symlog(intr_raw) + intr_res)
|
| 187 |
-
corrected_extr = symexp(symlog(extr_raw) + extr_res)
|
| 188 |
-
|
| 189 |
-
return CalibrationOutput(
|
| 190 |
-
ego_residual=ego_res,
|
| 191 |
-
intr_residual=intr_res,
|
| 192 |
-
extr_residual=extr_res,
|
| 193 |
-
corrected_ego=corrected_ego,
|
| 194 |
-
corrected_intr=corrected_intr,
|
| 195 |
-
corrected_extr=corrected_extr,
|
| 196 |
-
)
|
|
|
|
| 1 |
+
"""在线校准网络。
|
| 2 |
+
|
| 3 |
+
输入
|
| 4 |
+
- DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。
|
| 5 |
+
- 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。
|
| 6 |
+
- f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``);
|
| 7 |
+
README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。
|
| 8 |
+
- 相机外参 6D ``[B, 6]``。
|
| 9 |
+
|
| 10 |
+
流程
|
| 11 |
+
- 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外
|
| 12 |
+
条件 token 与 256 个可学习 query token 拼接。
|
| 13 |
+
- 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。
|
| 14 |
+
- 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出
|
| 15 |
+
symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。
|
| 16 |
+
|
| 17 |
+
输出
|
| 18 |
+
- ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、
|
| 19 |
+
``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
|
| 29 |
+
from ..modules.gate_attention import GateCrossAttention, GateSelfAttention
|
| 30 |
+
from ..modules.learned_pe import LearnedTokenPE
|
| 31 |
+
from ..modules.normalization import symexp, symlog
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class CalibrationOutput:
|
| 36 |
+
"""校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。"""
|
| 37 |
+
|
| 38 |
+
ego_residual: torch.Tensor # [B, 8, 6]
|
| 39 |
+
intr_residual: torch.Tensor # [B, intr_dim]
|
| 40 |
+
extr_residual: torch.Tensor # [B, 6]
|
| 41 |
+
corrected_ego: torch.Tensor # [B, 8, 6],真实空间
|
| 42 |
+
corrected_intr: torch.Tensor # [B, intr_dim]
|
| 43 |
+
corrected_extr: torch.Tensor # [B, 6]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class _CalibBlock(nn.Module):
|
| 47 |
+
"""单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。"""
|
| 48 |
+
|
| 49 |
+
def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None:
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.cross_norm = nn.LayerNorm(dim)
|
| 52 |
+
self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads)
|
| 53 |
+
self.cross_drop = nn.Dropout(0.0)
|
| 54 |
+
self.self_blocks = nn.ModuleList()
|
| 55 |
+
for _ in range(num_self):
|
| 56 |
+
self.self_blocks.append(
|
| 57 |
+
nn.ModuleDict({
|
| 58 |
+
"norm": nn.LayerNorm(dim),
|
| 59 |
+
"attn": GateSelfAttention(dim, num_heads=num_heads),
|
| 60 |
+
})
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
# 1) Cross
|
| 65 |
+
q = q + self.cross(self.cross_norm(q), kv)
|
| 66 |
+
# 2) Self ×2
|
| 67 |
+
for blk in self.self_blocks:
|
| 68 |
+
q = q + blk["attn"](blk["norm"](q))
|
| 69 |
+
return q
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class OnlineCalibration(nn.Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
dino_dim: int = 768,
|
| 76 |
+
hidden_dim: int = 256,
|
| 77 |
+
num_query_tokens: int = 256,
|
| 78 |
+
num_blocks: int = 2,
|
| 79 |
+
num_self_attn_per_block: int = 2,
|
| 80 |
+
num_heads: int = 8,
|
| 81 |
+
residual_range: float = 0.1,
|
| 82 |
+
ego_dim: int = 6, # 3 平移 + 3 轴角
|
| 83 |
+
intr_dim: int = 11,
|
| 84 |
+
extr_dim: int = 6,
|
| 85 |
+
num_history_frames: int = 8,
|
| 86 |
+
init_zero_output: bool = True,
|
| 87 |
+
) -> None:
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.hidden_dim = hidden_dim
|
| 90 |
+
self.residual_range = residual_range
|
| 91 |
+
self.ego_dim = ego_dim
|
| 92 |
+
self.intr_dim = intr_dim
|
| 93 |
+
self.extr_dim = extr_dim
|
| 94 |
+
self.num_history_frames = num_history_frames
|
| 95 |
+
|
| 96 |
+
# 256 可学习 query token
|
| 97 |
+
self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim))
|
| 98 |
+
nn.init.trunc_normal_(self.query_tokens, std=0.02)
|
| 99 |
+
self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim)
|
| 100 |
+
|
| 101 |
+
# 条件 token 编码(symlog 空间)
|
| 102 |
+
self.ego_proj = nn.Linear(ego_dim, hidden_dim)
|
| 103 |
+
self.intr_proj = nn.Linear(intr_dim, hidden_dim)
|
| 104 |
+
self.extr_proj = nn.Linear(extr_dim, hidden_dim)
|
| 105 |
+
# 条件 token 也加可学习 PE(与 query 区分)
|
| 106 |
+
num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr
|
| 107 |
+
self.cond_pe = LearnedTokenPE(num_cond, hidden_dim)
|
| 108 |
+
self.num_cond = num_cond
|
| 109 |
+
self.num_query = num_query_tokens
|
| 110 |
+
|
| 111 |
+
# KV 上下文:DINOv3 patch 特征投影到 hidden_dim
|
| 112 |
+
self.kv_proj = nn.Linear(dino_dim, hidden_dim)
|
| 113 |
+
self.kv_norm = nn.LayerNorm(hidden_dim)
|
| 114 |
+
|
| 115 |
+
# 校准 block × num_blocks
|
| 116 |
+
self.blocks = nn.ModuleList([
|
| 117 |
+
_CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block)
|
| 118 |
+
for _ in range(num_blocks)
|
| 119 |
+
])
|
| 120 |
+
|
| 121 |
+
# 输出 MLP:取 last token -> 残差向量
|
| 122 |
+
residual_total = num_history_frames * ego_dim + intr_dim + extr_dim
|
| 123 |
+
self.out_norm = nn.LayerNorm(hidden_dim)
|
| 124 |
+
self.out_mlp = nn.Sequential(
|
| 125 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 126 |
+
nn.GELU(),
|
| 127 |
+
nn.Linear(hidden_dim, residual_total),
|
| 128 |
+
)
|
| 129 |
+
# 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0
|
| 130 |
+
if init_zero_output:
|
| 131 |
+
nn.init.zeros_(self.out_mlp[-1].weight)
|
| 132 |
+
nn.init.zeros_(self.out_mlp[-1].bias)
|
| 133 |
+
|
| 134 |
+
def forward(
|
| 135 |
+
self,
|
| 136 |
+
dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino]
|
| 137 |
+
ego_raw: torch.Tensor, # [B, 8, 6] 真实空间
|
| 138 |
+
intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致
|
| 139 |
+
extr_raw: torch.Tensor, # [B, 6]
|
| 140 |
+
) -> CalibrationOutput:
|
| 141 |
+
b = dino_feats.shape[0]
|
| 142 |
+
if intr_raw.shape[-1] != self.intr_dim:
|
| 143 |
+
raise ValueError(
|
| 144 |
+
f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。"
|
| 145 |
+
f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。"
|
| 146 |
+
)
|
| 147 |
+
# === 上下文 K/V ===
|
| 148 |
+
# 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D]
|
| 149 |
+
kv = dino_feats.reshape(b, -1, dino_feats.shape[-1])
|
| 150 |
+
kv = self.kv_norm(self.kv_proj(kv))
|
| 151 |
+
|
| 152 |
+
# === 条件 token(symlog 空间)===
|
| 153 |
+
ego_sym = symlog(ego_raw)
|
| 154 |
+
intr_sym = symlog(intr_raw)
|
| 155 |
+
extr_sym = symlog(extr_raw)
|
| 156 |
+
|
| 157 |
+
ego_tok = self.ego_proj(ego_sym) # [B, 8, D]
|
| 158 |
+
intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D]
|
| 159 |
+
extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D]
|
| 160 |
+
cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D]
|
| 161 |
+
cond = self.cond_pe(cond)
|
| 162 |
+
|
| 163 |
+
# === 拼接 query token ===
|
| 164 |
+
q = self.query_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 165 |
+
q = self.query_pe(q)
|
| 166 |
+
seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D]
|
| 167 |
+
|
| 168 |
+
# === 6 层 block ===
|
| 169 |
+
for blk in self.blocks:
|
| 170 |
+
seq = blk(seq, kv)
|
| 171 |
+
|
| 172 |
+
# === 取 last token ===
|
| 173 |
+
last = self.out_norm(seq[:, -1, :]) # [B, D]
|
| 174 |
+
residual_flat = self.out_mlp(last)
|
| 175 |
+
# Tanh + 缩放
|
| 176 |
+
residual_flat = torch.tanh(residual_flat) * self.residual_range
|
| 177 |
+
|
| 178 |
+
# 拆分
|
| 179 |
+
n_ego = self.num_history_frames * self.ego_dim
|
| 180 |
+
ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim)
|
| 181 |
+
intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim]
|
| 182 |
+
extr_res = residual_flat[:, n_ego + self.intr_dim :]
|
| 183 |
+
|
| 184 |
+
# symlog 空间叠加 + symexp 还原
|
| 185 |
+
corrected_ego = symexp(symlog(ego_raw) + ego_res)
|
| 186 |
+
corrected_intr = symexp(symlog(intr_raw) + intr_res)
|
| 187 |
+
corrected_extr = symexp(symlog(extr_raw) + extr_res)
|
| 188 |
+
|
| 189 |
+
return CalibrationOutput(
|
| 190 |
+
ego_residual=ego_res,
|
| 191 |
+
intr_residual=intr_res,
|
| 192 |
+
extr_residual=extr_res,
|
| 193 |
+
corrected_ego=corrected_ego,
|
| 194 |
+
corrected_intr=corrected_intr,
|
| 195 |
+
corrected_extr=corrected_extr,
|
| 196 |
+
)
|
src/wjad/data/__init__.py
CHANGED
|
@@ -1,39 +1,39 @@
|
|
| 1 |
-
"""Cosmos-Drive-Dreams 数据加载与目标构建。"""
|
| 2 |
-
|
| 3 |
-
from .se3 import (
|
| 4 |
-
matrix_to_6d,
|
| 5 |
-
six_d_to_matrix,
|
| 6 |
-
invert_se3,
|
| 7 |
-
rotation_matrix_to_axis_angle,
|
| 8 |
-
axis_angle_to_rotation_matrix,
|
| 9 |
-
)
|
| 10 |
-
from .ftheta_proj import project_points_ftheta
|
| 11 |
-
from .transforms import (
|
| 12 |
-
crop_top_half,
|
| 13 |
-
normalize_image,
|
| 14 |
-
add_gaussian_noise,
|
| 15 |
-
perturb_kinematics,
|
| 16 |
-
)
|
| 17 |
-
from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo
|
| 18 |
-
from .hdmap import parse_hdmap_clip, HDMAP_SOURCES
|
| 19 |
-
from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index
|
| 20 |
-
|
| 21 |
-
__all__ = [
|
| 22 |
-
"matrix_to_6d",
|
| 23 |
-
"six_d_to_matrix",
|
| 24 |
-
"invert_se3",
|
| 25 |
-
"rotation_matrix_to_axis_angle",
|
| 26 |
-
"axis_angle_to_rotation_matrix",
|
| 27 |
-
"project_points_ftheta",
|
| 28 |
-
"crop_top_half",
|
| 29 |
-
"normalize_image",
|
| 30 |
-
"add_gaussian_noise",
|
| 31 |
-
"perturb_kinematics",
|
| 32 |
-
"build_detection_targets",
|
| 33 |
-
"build_ego_future_target",
|
| 34 |
-
"ObjectTrackInfo",
|
| 35 |
-
"CosmosDriveDreamsDataset",
|
| 36 |
-
"build_clip_index",
|
| 37 |
-
"parse_hdmap_clip",
|
| 38 |
-
"HDMAP_SOURCES",
|
| 39 |
-
]
|
|
|
|
| 1 |
+
"""Cosmos-Drive-Dreams 数据加载与目标构建。"""
|
| 2 |
+
|
| 3 |
+
from .se3 import (
|
| 4 |
+
matrix_to_6d,
|
| 5 |
+
six_d_to_matrix,
|
| 6 |
+
invert_se3,
|
| 7 |
+
rotation_matrix_to_axis_angle,
|
| 8 |
+
axis_angle_to_rotation_matrix,
|
| 9 |
+
)
|
| 10 |
+
from .ftheta_proj import project_points_ftheta
|
| 11 |
+
from .transforms import (
|
| 12 |
+
crop_top_half,
|
| 13 |
+
normalize_image,
|
| 14 |
+
add_gaussian_noise,
|
| 15 |
+
perturb_kinematics,
|
| 16 |
+
)
|
| 17 |
+
from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo
|
| 18 |
+
from .hdmap import parse_hdmap_clip, HDMAP_SOURCES
|
| 19 |
+
from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"matrix_to_6d",
|
| 23 |
+
"six_d_to_matrix",
|
| 24 |
+
"invert_se3",
|
| 25 |
+
"rotation_matrix_to_axis_angle",
|
| 26 |
+
"axis_angle_to_rotation_matrix",
|
| 27 |
+
"project_points_ftheta",
|
| 28 |
+
"crop_top_half",
|
| 29 |
+
"normalize_image",
|
| 30 |
+
"add_gaussian_noise",
|
| 31 |
+
"perturb_kinematics",
|
| 32 |
+
"build_detection_targets",
|
| 33 |
+
"build_ego_future_target",
|
| 34 |
+
"ObjectTrackInfo",
|
| 35 |
+
"CosmosDriveDreamsDataset",
|
| 36 |
+
"build_clip_index",
|
| 37 |
+
"parse_hdmap_clip",
|
| 38 |
+
"HDMAP_SOURCES",
|
| 39 |
+
]
|
src/wjad/data/cosmos_dataset.py
CHANGED
|
@@ -1,439 +1,439 @@
|
|
| 1 |
-
"""Cosmos-Drive-Dreams 数据集加载器(真实实现)。
|
| 2 |
-
|
| 3 |
-
期待目录结构(从 NVIDIA 提供的 .tar 解压):
|
| 4 |
-
|
| 5 |
-
data_root/
|
| 6 |
-
synthetic/single_view/
|
| 7 |
-
generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频
|
| 8 |
-
labels/{clip_id}/
|
| 9 |
-
vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU
|
| 10 |
-
pose/000000.pose.{camera}.npy # 30 FPS, OpenCV
|
| 11 |
-
ftheta_intrinsic/ftheta_intrinsic.{camera}.npy
|
| 12 |
-
all_object_info/000000.all_object_info.json
|
| 13 |
-
lidar_raw/000000.lidar_raw.npz # 10 FPS
|
| 14 |
-
|
| 15 |
-
每段 clip 提供:
|
| 16 |
-
- 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。
|
| 17 |
-
- 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
from __future__ import annotations
|
| 21 |
-
|
| 22 |
-
import json
|
| 23 |
-
from dataclasses import dataclass
|
| 24 |
-
from pathlib import Path
|
| 25 |
-
from typing import Sequence
|
| 26 |
-
|
| 27 |
-
import cv2
|
| 28 |
-
import numpy as np
|
| 29 |
-
import torch
|
| 30 |
-
from torch.utils.data import Dataset
|
| 31 |
-
|
| 32 |
-
from ..modules.normalization import symlog
|
| 33 |
-
from ..modules.rays import FThetaCamera
|
| 34 |
-
from .label_paths import resolve_clip_file
|
| 35 |
-
from .hdmap import parse_hdmap_clip
|
| 36 |
-
from .se3 import matrix_to_6d
|
| 37 |
-
from .targets import (
|
| 38 |
-
ObjectTrackInfo,
|
| 39 |
-
build_detection_targets,
|
| 40 |
-
build_ego_future_target,
|
| 41 |
-
)
|
| 42 |
-
from .transforms import DINOV3_MEAN, DINOV3_STD
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
# 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。
|
| 46 |
-
DEFAULT_DYNAMIC_CLASSES = [
|
| 47 |
-
"Automobile",
|
| 48 |
-
"Heavy_truck",
|
| 49 |
-
"Bus",
|
| 50 |
-
"Train_or_tram_car",
|
| 51 |
-
"Trolley_bus",
|
| 52 |
-
"Other_vehicle",
|
| 53 |
-
"Trailer",
|
| 54 |
-
"Person",
|
| 55 |
-
"Stroller",
|
| 56 |
-
"Rider",
|
| 57 |
-
"Animal",
|
| 58 |
-
"Protruding_object",
|
| 59 |
-
]
|
| 60 |
-
|
| 61 |
-
# 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。
|
| 62 |
-
DEFAULT_STRUCTURED_CLASSES = [
|
| 63 |
-
"lane",
|
| 64 |
-
"laneline",
|
| 65 |
-
"road_boundary",
|
| 66 |
-
"wait_line",
|
| 67 |
-
"crosswalk",
|
| 68 |
-
"road_marking",
|
| 69 |
-
"pole",
|
| 70 |
-
"traffic_light",
|
| 71 |
-
"traffic_sign",
|
| 72 |
-
]
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
@dataclass
|
| 76 |
-
class ClipSample:
|
| 77 |
-
"""clip 索引项。"""
|
| 78 |
-
|
| 79 |
-
clip_id: str
|
| 80 |
-
chunk_id: int
|
| 81 |
-
weather: str
|
| 82 |
-
video_path: Path
|
| 83 |
-
labels_dir: Path
|
| 84 |
-
anchor_t: int # 当前帧(含),范围 [7, 96]
|
| 85 |
-
chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def build_clip_index(
|
| 89 |
-
data_root: str | Path,
|
| 90 |
-
weathers: Sequence[str] = ("Sunny",),
|
| 91 |
-
chunk_ids: Sequence[int] = (0, 1),
|
| 92 |
-
camera_name: str = "camera_front_wide_120fov",
|
| 93 |
-
stride: int = 8,
|
| 94 |
-
anchor_min: int = 7,
|
| 95 |
-
anchor_max: int = 96,
|
| 96 |
-
max_clips: int | None = None,
|
| 97 |
-
) -> list[ClipSample]:
|
| 98 |
-
"""枚举所有可用 (clip, chunk, weather, anchor_t) 样本。
|
| 99 |
-
|
| 100 |
-
锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧
|
| 101 |
-
``chunk_offset + t``(chunk_offset = chunk_id * 121)。
|
| 102 |
-
"""
|
| 103 |
-
root = Path(data_root)
|
| 104 |
-
syn_dir = root / "synthetic" / "single_view" / "generation"
|
| 105 |
-
labels_dir = root / "labels"
|
| 106 |
-
|
| 107 |
-
samples: list[ClipSample] = []
|
| 108 |
-
if not syn_dir.exists():
|
| 109 |
-
return samples
|
| 110 |
-
|
| 111 |
-
for video_path in sorted(syn_dir.glob("*.mp4")):
|
| 112 |
-
# 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4
|
| 113 |
-
# clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析
|
| 114 |
-
stem = video_path.stem
|
| 115 |
-
parts = stem.rsplit("_", 2)
|
| 116 |
-
if len(parts) != 3:
|
| 117 |
-
continue
|
| 118 |
-
clip_id, chunk_str, weather = parts
|
| 119 |
-
try:
|
| 120 |
-
chunk_id = int(chunk_str)
|
| 121 |
-
except ValueError:
|
| 122 |
-
continue
|
| 123 |
-
if chunk_id not in chunk_ids or weather not in weathers:
|
| 124 |
-
continue
|
| 125 |
-
|
| 126 |
-
clip_label_dir = labels_dir / clip_id
|
| 127 |
-
if not clip_label_dir.exists():
|
| 128 |
-
continue
|
| 129 |
-
|
| 130 |
-
chunk_offset = chunk_id * 121
|
| 131 |
-
for t in range(anchor_min, anchor_max + 1, stride):
|
| 132 |
-
samples.append(
|
| 133 |
-
ClipSample(
|
| 134 |
-
clip_id=clip_id,
|
| 135 |
-
chunk_id=chunk_id,
|
| 136 |
-
weather=weather,
|
| 137 |
-
video_path=video_path,
|
| 138 |
-
labels_dir=clip_label_dir,
|
| 139 |
-
anchor_t=t,
|
| 140 |
-
chunk_offset=chunk_offset,
|
| 141 |
-
)
|
| 142 |
-
)
|
| 143 |
-
if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips:
|
| 144 |
-
break
|
| 145 |
-
|
| 146 |
-
return samples
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def _load_video_frames(
|
| 150 |
-
video_path: Path,
|
| 151 |
-
frame_indices: Sequence[int],
|
| 152 |
-
target_h: int,
|
| 153 |
-
target_w: int,
|
| 154 |
-
) -> torch.Tensor:
|
| 155 |
-
"""从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。"""
|
| 156 |
-
cap = cv2.VideoCapture(str(video_path))
|
| 157 |
-
if not cap.isOpened():
|
| 158 |
-
raise FileNotFoundError(f"无法打开视频: {video_path}")
|
| 159 |
-
frames = []
|
| 160 |
-
for idx in frame_indices:
|
| 161 |
-
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
| 162 |
-
ok, bgr = cap.read()
|
| 163 |
-
if not ok:
|
| 164 |
-
cap.release()
|
| 165 |
-
raise RuntimeError(f"读取帧 {idx} 失败: {video_path}")
|
| 166 |
-
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 167 |
-
rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA)
|
| 168 |
-
# 裁去上半部分(天空)后高度变为 target_h
|
| 169 |
-
rgb = rgb[target_h:, :, :]
|
| 170 |
-
rgb = rgb.astype(np.float32) / 255.0
|
| 171 |
-
frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W]
|
| 172 |
-
cap.release()
|
| 173 |
-
return torch.stack(frames, dim=0)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def _load_npy(path: Path) -> np.ndarray:
|
| 177 |
-
return np.load(path, allow_pickle=False)
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
def _load_object_info(path: Path) -> list[ObjectTrackInfo]:
|
| 181 |
-
"""解析单帧 all_object_info JSON。"""
|
| 182 |
-
if not path.exists():
|
| 183 |
-
return []
|
| 184 |
-
data = json.loads(path.read_text())
|
| 185 |
-
out = []
|
| 186 |
-
for tid, info in data.items():
|
| 187 |
-
T = torch.tensor(info["object_to_world"], dtype=torch.float32)
|
| 188 |
-
lwh = torch.tensor(info["object_lwh"], dtype=torch.float32)
|
| 189 |
-
out.append(
|
| 190 |
-
ObjectTrackInfo(
|
| 191 |
-
tracking_id=tid,
|
| 192 |
-
object_to_world=T,
|
| 193 |
-
lwh=lwh,
|
| 194 |
-
is_moving=bool(info.get("object_is_moving", False)),
|
| 195 |
-
object_type=str(info.get("object_type", "")),
|
| 196 |
-
)
|
| 197 |
-
)
|
| 198 |
-
return out
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def _load_lidar_self_frame(
|
| 202 |
-
labels_dir: Path,
|
| 203 |
-
label_idx: int,
|
| 204 |
-
vehicle_pose: torch.Tensor,
|
| 205 |
-
max_history: int = 3,
|
| 206 |
-
) -> torch.Tensor | None:
|
| 207 |
-
"""读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。
|
| 208 |
-
|
| 209 |
-
LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、
|
| 210 |
-
``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。
|
| 211 |
-
"""
|
| 212 |
-
lidar_idx = (label_idx // 3) * 3
|
| 213 |
-
search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0]
|
| 214 |
-
p: Path | None = None
|
| 215 |
-
for idx_try in search_order:
|
| 216 |
-
try:
|
| 217 |
-
p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz")
|
| 218 |
-
break
|
| 219 |
-
except FileNotFoundError:
|
| 220 |
-
continue
|
| 221 |
-
if p is None:
|
| 222 |
-
return None
|
| 223 |
-
arr = np.load(p, allow_pickle=False)
|
| 224 |
-
xyz_lidar = arr["xyz"] # [N, 3] in lidar frame
|
| 225 |
-
lidar_to_world = arr["lidar_to_world"] # [4, 4]
|
| 226 |
-
# 转到 world 后再转 self
|
| 227 |
-
pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3]
|
| 228 |
-
inv_pose = torch.linalg.inv(vehicle_pose)
|
| 229 |
-
pts_w_t = torch.from_numpy(pts_w).float()
|
| 230 |
-
pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3]
|
| 231 |
-
return pts_self
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class CosmosDriveDreamsDataset(Dataset):
|
| 235 |
-
"""端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。"""
|
| 236 |
-
|
| 237 |
-
def __init__(
|
| 238 |
-
self,
|
| 239 |
-
data_root: str | Path,
|
| 240 |
-
samples: list[ClipSample] | None = None,
|
| 241 |
-
weathers: Sequence[str] = ("Sunny",),
|
| 242 |
-
camera_name: str = "camera_front_wide_120fov",
|
| 243 |
-
image_h: int = 384,
|
| 244 |
-
image_w: int = 1024,
|
| 245 |
-
num_history: int = 8,
|
| 246 |
-
future_horizon: int = 24,
|
| 247 |
-
max_distance_m: float = 48.0,
|
| 248 |
-
occlusion_tol: float = 0.5,
|
| 249 |
-
dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES,
|
| 250 |
-
structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES,
|
| 251 |
-
do_normalize: bool = True,
|
| 252 |
-
use_lidar_occlusion: bool = True,
|
| 253 |
-
use_hdmap: bool = True,
|
| 254 |
-
) -> None:
|
| 255 |
-
super().__init__()
|
| 256 |
-
self.data_root = Path(data_root)
|
| 257 |
-
self.samples = samples if samples is not None else build_clip_index(
|
| 258 |
-
data_root, weathers=weathers, camera_name=camera_name
|
| 259 |
-
)
|
| 260 |
-
self.camera_name = camera_name
|
| 261 |
-
self.image_h = image_h
|
| 262 |
-
self.image_w = image_w
|
| 263 |
-
self.num_history = num_history
|
| 264 |
-
self.future_horizon = future_horizon
|
| 265 |
-
self.max_distance_m = max_distance_m
|
| 266 |
-
self.occlusion_tol = occlusion_tol
|
| 267 |
-
self.dynamic_classes = list(dynamic_classes)
|
| 268 |
-
self.structured_classes = list(structured_classes)
|
| 269 |
-
self.do_normalize = do_normalize
|
| 270 |
-
self.use_lidar_occlusion = use_lidar_occlusion
|
| 271 |
-
self.use_hdmap = use_hdmap
|
| 272 |
-
# HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析
|
| 273 |
-
self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {}
|
| 274 |
-
self._hdmap_cache_max = 32
|
| 275 |
-
|
| 276 |
-
def __len__(self) -> int:
|
| 277 |
-
return len(self.samples)
|
| 278 |
-
|
| 279 |
-
def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor:
|
| 280 |
-
p = resolve_clip_file(
|
| 281 |
-
sample.labels_dir,
|
| 282 |
-
"ftheta_intrinsic",
|
| 283 |
-
f"ftheta_intrinsic.{self.camera_name}.npy",
|
| 284 |
-
)
|
| 285 |
-
return torch.from_numpy(_load_npy(p)).float()
|
| 286 |
-
|
| 287 |
-
def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
|
| 288 |
-
p = resolve_clip_file(
|
| 289 |
-
sample.labels_dir,
|
| 290 |
-
"pose",
|
| 291 |
-
f"{label_idx:06d}.pose.{self.camera_name}.npy",
|
| 292 |
-
)
|
| 293 |
-
return torch.from_numpy(_load_npy(p)).float()
|
| 294 |
-
|
| 295 |
-
def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
|
| 296 |
-
p = resolve_clip_file(
|
| 297 |
-
sample.labels_dir,
|
| 298 |
-
"vehicle_pose",
|
| 299 |
-
f"{label_idx:06d}.vehicle_pose.npy",
|
| 300 |
-
)
|
| 301 |
-
return torch.from_numpy(_load_npy(p)).float()
|
| 302 |
-
|
| 303 |
-
def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]:
|
| 304 |
-
if not self.use_hdmap:
|
| 305 |
-
return []
|
| 306 |
-
key = str(clip_dir)
|
| 307 |
-
cached = self._hdmap_cache.get(key)
|
| 308 |
-
if cached is not None:
|
| 309 |
-
return cached
|
| 310 |
-
objs = parse_hdmap_clip(clip_dir)
|
| 311 |
-
if len(self._hdmap_cache) >= self._hdmap_cache_max:
|
| 312 |
-
self._hdmap_cache.pop(next(iter(self._hdmap_cache)))
|
| 313 |
-
self._hdmap_cache[key] = objs
|
| 314 |
-
return objs
|
| 315 |
-
|
| 316 |
-
def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]:
|
| 317 |
-
p = resolve_clip_file(
|
| 318 |
-
sample.labels_dir,
|
| 319 |
-
"all_object_info",
|
| 320 |
-
f"{label_idx:06d}.all_object_info.json",
|
| 321 |
-
)
|
| 322 |
-
dynamic = _load_object_info(p)
|
| 323 |
-
# HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同),
|
| 324 |
-
# 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差,
|
| 325 |
-
# 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。
|
| 326 |
-
return dynamic + self._load_hdmap_static(sample.labels_dir)
|
| 327 |
-
|
| 328 |
-
def __getitem__(self, idx: int) -> dict:
|
| 329 |
-
s = self.samples[idx]
|
| 330 |
-
# 视频帧索引(chunk 内 0-based)
|
| 331 |
-
t = s.anchor_t
|
| 332 |
-
history_frames = list(range(t - self.num_history + 1, t + 1))
|
| 333 |
-
# 标签索引:chunk_offset + chunk-local idx
|
| 334 |
-
history_label_idx = [s.chunk_offset + f for f in history_frames]
|
| 335 |
-
future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)]
|
| 336 |
-
|
| 337 |
-
# === 1) 加载图像 ===
|
| 338 |
-
# 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率);
|
| 339 |
-
# 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。
|
| 340 |
-
images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w)
|
| 341 |
-
# [T, 3, H, W],[0, 1]
|
| 342 |
-
if self.do_normalize:
|
| 343 |
-
images = (images - DINOV3_MEAN) / DINOV3_STD
|
| 344 |
-
|
| 345 |
-
# === 2) 加载内参 / 外参 ===
|
| 346 |
-
intr_vec = self._load_intrinsic(s) # [14]
|
| 347 |
-
|
| 348 |
-
# 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle
|
| 349 |
-
pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t)
|
| 350 |
-
pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t)
|
| 351 |
-
# cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world
|
| 352 |
-
inv_veh = torch.linalg.inv(pose_veh_world)
|
| 353 |
-
cam2veh = inv_veh @ pose_cam_world
|
| 354 |
-
extr_6d = matrix_to_6d(cam2veh) # [6]
|
| 355 |
-
|
| 356 |
-
# === 3) 历史 8 帧 ego pose(vehicle 6D)===
|
| 357 |
-
ego_6d = []
|
| 358 |
-
for li in history_label_idx:
|
| 359 |
-
T_vw = self._load_pose_vehicle(s, li)
|
| 360 |
-
ego_6d.append(matrix_to_6d(T_vw))
|
| 361 |
-
ego_6d = torch.stack(ego_6d, dim=0) # [8, 6]
|
| 362 |
-
|
| 363 |
-
# === 4) 检测 / 未来轨迹标签 ===
|
| 364 |
-
# objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。
|
| 365 |
-
objs_t = self._load_objects(s, s.chunk_offset + t)
|
| 366 |
-
objs_future = [self._load_objects(s, li) for li in future_label_idx]
|
| 367 |
-
veh_pose_future = []
|
| 368 |
-
for li in future_label_idx:
|
| 369 |
-
try:
|
| 370 |
-
veh_pose_future.append(self._load_pose_vehicle(s, li))
|
| 371 |
-
except FileNotFoundError:
|
| 372 |
-
break
|
| 373 |
-
|
| 374 |
-
cam = FThetaCamera.from_vector(intr_vec)
|
| 375 |
-
lidar_self = None
|
| 376 |
-
if self.use_lidar_occlusion:
|
| 377 |
-
try:
|
| 378 |
-
lidar_self = _load_lidar_self_frame(
|
| 379 |
-
s.labels_dir,
|
| 380 |
-
s.chunk_offset + t,
|
| 381 |
-
pose_veh_world,
|
| 382 |
-
)
|
| 383 |
-
except Exception:
|
| 384 |
-
lidar_self = None
|
| 385 |
-
|
| 386 |
-
det_targets = build_detection_targets(
|
| 387 |
-
objects_t=objs_t,
|
| 388 |
-
objects_future=objs_future,
|
| 389 |
-
vehicle_pose_t=pose_veh_world,
|
| 390 |
-
vehicle_pose_future=veh_pose_future,
|
| 391 |
-
cam_intrinsic=cam,
|
| 392 |
-
cam2vehicle=cam2veh,
|
| 393 |
-
image_h=self.image_h,
|
| 394 |
-
image_w=self.image_w,
|
| 395 |
-
max_distance_m=self.max_distance_m,
|
| 396 |
-
occlusion_depth_tolerance=self.occlusion_tol,
|
| 397 |
-
lidar_points_self=lidar_self,
|
| 398 |
-
dynamic_classes=self.dynamic_classes,
|
| 399 |
-
structured_classes=self.structured_classes,
|
| 400 |
-
future_horizon=self.future_horizon,
|
| 401 |
-
)
|
| 402 |
-
|
| 403 |
-
ego_future, ego_future_valid = build_ego_future_target(
|
| 404 |
-
pose_veh_world, veh_pose_future, horizon=self.future_horizon
|
| 405 |
-
)
|
| 406 |
-
|
| 407 |
-
sample_out = {
|
| 408 |
-
"images": images,
|
| 409 |
-
"ego_6d": ego_6d,
|
| 410 |
-
"intr_vec": intr_vec,
|
| 411 |
-
"extr_6d": extr_6d,
|
| 412 |
-
"ego_future": ego_future,
|
| 413 |
-
"ego_future_valid": ego_future_valid,
|
| 414 |
-
"targets": det_targets,
|
| 415 |
-
"meta": {
|
| 416 |
-
"clip_id": s.clip_id,
|
| 417 |
-
"chunk_id": s.chunk_id,
|
| 418 |
-
"weather": s.weather,
|
| 419 |
-
"anchor_t": s.anchor_t,
|
| 420 |
-
},
|
| 421 |
-
}
|
| 422 |
-
return sample_out
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
def collate_samples(batch: list[dict]) -> dict:
|
| 426 |
-
"""自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack;
|
| 427 |
-
targets 列表保留为 list(便于匈牙利匹配处理变长 N);
|
| 428 |
-
meta 也保留为 list。"""
|
| 429 |
-
out = {
|
| 430 |
-
"images": torch.stack([b["images"] for b in batch], dim=0),
|
| 431 |
-
"ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0),
|
| 432 |
-
"intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0),
|
| 433 |
-
"extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0),
|
| 434 |
-
"ego_future": torch.stack([b["ego_future"] for b in batch], dim=0),
|
| 435 |
-
"ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0),
|
| 436 |
-
"targets": [b["targets"] for b in batch],
|
| 437 |
-
"meta": [b["meta"] for b in batch],
|
| 438 |
-
}
|
| 439 |
-
return out
|
|
|
|
| 1 |
+
"""Cosmos-Drive-Dreams 数据集加载器(真实实现)。
|
| 2 |
+
|
| 3 |
+
期待目录结构(从 NVIDIA 提供的 .tar 解压):
|
| 4 |
+
|
| 5 |
+
data_root/
|
| 6 |
+
synthetic/single_view/
|
| 7 |
+
generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频
|
| 8 |
+
labels/{clip_id}/
|
| 9 |
+
vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU
|
| 10 |
+
pose/000000.pose.{camera}.npy # 30 FPS, OpenCV
|
| 11 |
+
ftheta_intrinsic/ftheta_intrinsic.{camera}.npy
|
| 12 |
+
all_object_info/000000.all_object_info.json
|
| 13 |
+
lidar_raw/000000.lidar_raw.npz # 10 FPS
|
| 14 |
+
|
| 15 |
+
每段 clip 提供:
|
| 16 |
+
- 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。
|
| 17 |
+
- 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Sequence
|
| 26 |
+
|
| 27 |
+
import cv2
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
from torch.utils.data import Dataset
|
| 31 |
+
|
| 32 |
+
from ..modules.normalization import symlog
|
| 33 |
+
from ..modules.rays import FThetaCamera
|
| 34 |
+
from .label_paths import resolve_clip_file
|
| 35 |
+
from .hdmap import parse_hdmap_clip
|
| 36 |
+
from .se3 import matrix_to_6d
|
| 37 |
+
from .targets import (
|
| 38 |
+
ObjectTrackInfo,
|
| 39 |
+
build_detection_targets,
|
| 40 |
+
build_ego_future_target,
|
| 41 |
+
)
|
| 42 |
+
from .transforms import DINOV3_MEAN, DINOV3_STD
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。
|
| 46 |
+
DEFAULT_DYNAMIC_CLASSES = [
|
| 47 |
+
"Automobile",
|
| 48 |
+
"Heavy_truck",
|
| 49 |
+
"Bus",
|
| 50 |
+
"Train_or_tram_car",
|
| 51 |
+
"Trolley_bus",
|
| 52 |
+
"Other_vehicle",
|
| 53 |
+
"Trailer",
|
| 54 |
+
"Person",
|
| 55 |
+
"Stroller",
|
| 56 |
+
"Rider",
|
| 57 |
+
"Animal",
|
| 58 |
+
"Protruding_object",
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
# 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。
|
| 62 |
+
DEFAULT_STRUCTURED_CLASSES = [
|
| 63 |
+
"lane",
|
| 64 |
+
"laneline",
|
| 65 |
+
"road_boundary",
|
| 66 |
+
"wait_line",
|
| 67 |
+
"crosswalk",
|
| 68 |
+
"road_marking",
|
| 69 |
+
"pole",
|
| 70 |
+
"traffic_light",
|
| 71 |
+
"traffic_sign",
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ClipSample:
|
| 77 |
+
"""clip 索引项。"""
|
| 78 |
+
|
| 79 |
+
clip_id: str
|
| 80 |
+
chunk_id: int
|
| 81 |
+
weather: str
|
| 82 |
+
video_path: Path
|
| 83 |
+
labels_dir: Path
|
| 84 |
+
anchor_t: int # 当前帧(含),范围 [7, 96]
|
| 85 |
+
chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def build_clip_index(
|
| 89 |
+
data_root: str | Path,
|
| 90 |
+
weathers: Sequence[str] = ("Sunny",),
|
| 91 |
+
chunk_ids: Sequence[int] = (0, 1),
|
| 92 |
+
camera_name: str = "camera_front_wide_120fov",
|
| 93 |
+
stride: int = 8,
|
| 94 |
+
anchor_min: int = 7,
|
| 95 |
+
anchor_max: int = 96,
|
| 96 |
+
max_clips: int | None = None,
|
| 97 |
+
) -> list[ClipSample]:
|
| 98 |
+
"""枚举所有可用 (clip, chunk, weather, anchor_t) 样本。
|
| 99 |
+
|
| 100 |
+
锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧
|
| 101 |
+
``chunk_offset + t``(chunk_offset = chunk_id * 121)。
|
| 102 |
+
"""
|
| 103 |
+
root = Path(data_root)
|
| 104 |
+
syn_dir = root / "synthetic" / "single_view" / "generation"
|
| 105 |
+
labels_dir = root / "labels"
|
| 106 |
+
|
| 107 |
+
samples: list[ClipSample] = []
|
| 108 |
+
if not syn_dir.exists():
|
| 109 |
+
return samples
|
| 110 |
+
|
| 111 |
+
for video_path in sorted(syn_dir.glob("*.mp4")):
|
| 112 |
+
# 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4
|
| 113 |
+
# clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析
|
| 114 |
+
stem = video_path.stem
|
| 115 |
+
parts = stem.rsplit("_", 2)
|
| 116 |
+
if len(parts) != 3:
|
| 117 |
+
continue
|
| 118 |
+
clip_id, chunk_str, weather = parts
|
| 119 |
+
try:
|
| 120 |
+
chunk_id = int(chunk_str)
|
| 121 |
+
except ValueError:
|
| 122 |
+
continue
|
| 123 |
+
if chunk_id not in chunk_ids or weather not in weathers:
|
| 124 |
+
continue
|
| 125 |
+
|
| 126 |
+
clip_label_dir = labels_dir / clip_id
|
| 127 |
+
if not clip_label_dir.exists():
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
chunk_offset = chunk_id * 121
|
| 131 |
+
for t in range(anchor_min, anchor_max + 1, stride):
|
| 132 |
+
samples.append(
|
| 133 |
+
ClipSample(
|
| 134 |
+
clip_id=clip_id,
|
| 135 |
+
chunk_id=chunk_id,
|
| 136 |
+
weather=weather,
|
| 137 |
+
video_path=video_path,
|
| 138 |
+
labels_dir=clip_label_dir,
|
| 139 |
+
anchor_t=t,
|
| 140 |
+
chunk_offset=chunk_offset,
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips:
|
| 144 |
+
break
|
| 145 |
+
|
| 146 |
+
return samples
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _load_video_frames(
|
| 150 |
+
video_path: Path,
|
| 151 |
+
frame_indices: Sequence[int],
|
| 152 |
+
target_h: int,
|
| 153 |
+
target_w: int,
|
| 154 |
+
) -> torch.Tensor:
|
| 155 |
+
"""从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。"""
|
| 156 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 157 |
+
if not cap.isOpened():
|
| 158 |
+
raise FileNotFoundError(f"无法打开视频: {video_path}")
|
| 159 |
+
frames = []
|
| 160 |
+
for idx in frame_indices:
|
| 161 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
| 162 |
+
ok, bgr = cap.read()
|
| 163 |
+
if not ok:
|
| 164 |
+
cap.release()
|
| 165 |
+
raise RuntimeError(f"读取帧 {idx} 失败: {video_path}")
|
| 166 |
+
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 167 |
+
rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA)
|
| 168 |
+
# 裁去上半部分(天空)后高度变为 target_h
|
| 169 |
+
rgb = rgb[target_h:, :, :]
|
| 170 |
+
rgb = rgb.astype(np.float32) / 255.0
|
| 171 |
+
frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W]
|
| 172 |
+
cap.release()
|
| 173 |
+
return torch.stack(frames, dim=0)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _load_npy(path: Path) -> np.ndarray:
|
| 177 |
+
return np.load(path, allow_pickle=False)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _load_object_info(path: Path) -> list[ObjectTrackInfo]:
|
| 181 |
+
"""解析单帧 all_object_info JSON。"""
|
| 182 |
+
if not path.exists():
|
| 183 |
+
return []
|
| 184 |
+
data = json.loads(path.read_text())
|
| 185 |
+
out = []
|
| 186 |
+
for tid, info in data.items():
|
| 187 |
+
T = torch.tensor(info["object_to_world"], dtype=torch.float32)
|
| 188 |
+
lwh = torch.tensor(info["object_lwh"], dtype=torch.float32)
|
| 189 |
+
out.append(
|
| 190 |
+
ObjectTrackInfo(
|
| 191 |
+
tracking_id=tid,
|
| 192 |
+
object_to_world=T,
|
| 193 |
+
lwh=lwh,
|
| 194 |
+
is_moving=bool(info.get("object_is_moving", False)),
|
| 195 |
+
object_type=str(info.get("object_type", "")),
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
return out
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _load_lidar_self_frame(
|
| 202 |
+
labels_dir: Path,
|
| 203 |
+
label_idx: int,
|
| 204 |
+
vehicle_pose: torch.Tensor,
|
| 205 |
+
max_history: int = 3,
|
| 206 |
+
) -> torch.Tensor | None:
|
| 207 |
+
"""读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。
|
| 208 |
+
|
| 209 |
+
LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、
|
| 210 |
+
``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。
|
| 211 |
+
"""
|
| 212 |
+
lidar_idx = (label_idx // 3) * 3
|
| 213 |
+
search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0]
|
| 214 |
+
p: Path | None = None
|
| 215 |
+
for idx_try in search_order:
|
| 216 |
+
try:
|
| 217 |
+
p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz")
|
| 218 |
+
break
|
| 219 |
+
except FileNotFoundError:
|
| 220 |
+
continue
|
| 221 |
+
if p is None:
|
| 222 |
+
return None
|
| 223 |
+
arr = np.load(p, allow_pickle=False)
|
| 224 |
+
xyz_lidar = arr["xyz"] # [N, 3] in lidar frame
|
| 225 |
+
lidar_to_world = arr["lidar_to_world"] # [4, 4]
|
| 226 |
+
# 转到 world 后再转 self
|
| 227 |
+
pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3]
|
| 228 |
+
inv_pose = torch.linalg.inv(vehicle_pose)
|
| 229 |
+
pts_w_t = torch.from_numpy(pts_w).float()
|
| 230 |
+
pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3]
|
| 231 |
+
return pts_self
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class CosmosDriveDreamsDataset(Dataset):
|
| 235 |
+
"""端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。"""
|
| 236 |
+
|
| 237 |
+
def __init__(
|
| 238 |
+
self,
|
| 239 |
+
data_root: str | Path,
|
| 240 |
+
samples: list[ClipSample] | None = None,
|
| 241 |
+
weathers: Sequence[str] = ("Sunny",),
|
| 242 |
+
camera_name: str = "camera_front_wide_120fov",
|
| 243 |
+
image_h: int = 384,
|
| 244 |
+
image_w: int = 1024,
|
| 245 |
+
num_history: int = 8,
|
| 246 |
+
future_horizon: int = 24,
|
| 247 |
+
max_distance_m: float = 48.0,
|
| 248 |
+
occlusion_tol: float = 0.5,
|
| 249 |
+
dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES,
|
| 250 |
+
structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES,
|
| 251 |
+
do_normalize: bool = True,
|
| 252 |
+
use_lidar_occlusion: bool = True,
|
| 253 |
+
use_hdmap: bool = True,
|
| 254 |
+
) -> None:
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.data_root = Path(data_root)
|
| 257 |
+
self.samples = samples if samples is not None else build_clip_index(
|
| 258 |
+
data_root, weathers=weathers, camera_name=camera_name
|
| 259 |
+
)
|
| 260 |
+
self.camera_name = camera_name
|
| 261 |
+
self.image_h = image_h
|
| 262 |
+
self.image_w = image_w
|
| 263 |
+
self.num_history = num_history
|
| 264 |
+
self.future_horizon = future_horizon
|
| 265 |
+
self.max_distance_m = max_distance_m
|
| 266 |
+
self.occlusion_tol = occlusion_tol
|
| 267 |
+
self.dynamic_classes = list(dynamic_classes)
|
| 268 |
+
self.structured_classes = list(structured_classes)
|
| 269 |
+
self.do_normalize = do_normalize
|
| 270 |
+
self.use_lidar_occlusion = use_lidar_occlusion
|
| 271 |
+
self.use_hdmap = use_hdmap
|
| 272 |
+
# HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析
|
| 273 |
+
self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {}
|
| 274 |
+
self._hdmap_cache_max = 32
|
| 275 |
+
|
| 276 |
+
def __len__(self) -> int:
|
| 277 |
+
return len(self.samples)
|
| 278 |
+
|
| 279 |
+
def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor:
|
| 280 |
+
p = resolve_clip_file(
|
| 281 |
+
sample.labels_dir,
|
| 282 |
+
"ftheta_intrinsic",
|
| 283 |
+
f"ftheta_intrinsic.{self.camera_name}.npy",
|
| 284 |
+
)
|
| 285 |
+
return torch.from_numpy(_load_npy(p)).float()
|
| 286 |
+
|
| 287 |
+
def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
|
| 288 |
+
p = resolve_clip_file(
|
| 289 |
+
sample.labels_dir,
|
| 290 |
+
"pose",
|
| 291 |
+
f"{label_idx:06d}.pose.{self.camera_name}.npy",
|
| 292 |
+
)
|
| 293 |
+
return torch.from_numpy(_load_npy(p)).float()
|
| 294 |
+
|
| 295 |
+
def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
|
| 296 |
+
p = resolve_clip_file(
|
| 297 |
+
sample.labels_dir,
|
| 298 |
+
"vehicle_pose",
|
| 299 |
+
f"{label_idx:06d}.vehicle_pose.npy",
|
| 300 |
+
)
|
| 301 |
+
return torch.from_numpy(_load_npy(p)).float()
|
| 302 |
+
|
| 303 |
+
def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]:
|
| 304 |
+
if not self.use_hdmap:
|
| 305 |
+
return []
|
| 306 |
+
key = str(clip_dir)
|
| 307 |
+
cached = self._hdmap_cache.get(key)
|
| 308 |
+
if cached is not None:
|
| 309 |
+
return cached
|
| 310 |
+
objs = parse_hdmap_clip(clip_dir)
|
| 311 |
+
if len(self._hdmap_cache) >= self._hdmap_cache_max:
|
| 312 |
+
self._hdmap_cache.pop(next(iter(self._hdmap_cache)))
|
| 313 |
+
self._hdmap_cache[key] = objs
|
| 314 |
+
return objs
|
| 315 |
+
|
| 316 |
+
def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]:
|
| 317 |
+
p = resolve_clip_file(
|
| 318 |
+
sample.labels_dir,
|
| 319 |
+
"all_object_info",
|
| 320 |
+
f"{label_idx:06d}.all_object_info.json",
|
| 321 |
+
)
|
| 322 |
+
dynamic = _load_object_info(p)
|
| 323 |
+
# HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同),
|
| 324 |
+
# 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差,
|
| 325 |
+
# 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。
|
| 326 |
+
return dynamic + self._load_hdmap_static(sample.labels_dir)
|
| 327 |
+
|
| 328 |
+
def __getitem__(self, idx: int) -> dict:
|
| 329 |
+
s = self.samples[idx]
|
| 330 |
+
# 视频帧索引(chunk 内 0-based)
|
| 331 |
+
t = s.anchor_t
|
| 332 |
+
history_frames = list(range(t - self.num_history + 1, t + 1))
|
| 333 |
+
# 标签索引:chunk_offset + chunk-local idx
|
| 334 |
+
history_label_idx = [s.chunk_offset + f for f in history_frames]
|
| 335 |
+
future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)]
|
| 336 |
+
|
| 337 |
+
# === 1) 加载图像 ===
|
| 338 |
+
# 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率);
|
| 339 |
+
# 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。
|
| 340 |
+
images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w)
|
| 341 |
+
# [T, 3, H, W],[0, 1]
|
| 342 |
+
if self.do_normalize:
|
| 343 |
+
images = (images - DINOV3_MEAN) / DINOV3_STD
|
| 344 |
+
|
| 345 |
+
# === 2) 加载内参 / 外参 ===
|
| 346 |
+
intr_vec = self._load_intrinsic(s) # [14]
|
| 347 |
+
|
| 348 |
+
# 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle
|
| 349 |
+
pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t)
|
| 350 |
+
pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t)
|
| 351 |
+
# cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world
|
| 352 |
+
inv_veh = torch.linalg.inv(pose_veh_world)
|
| 353 |
+
cam2veh = inv_veh @ pose_cam_world
|
| 354 |
+
extr_6d = matrix_to_6d(cam2veh) # [6]
|
| 355 |
+
|
| 356 |
+
# === 3) 历史 8 帧 ego pose(vehicle 6D)===
|
| 357 |
+
ego_6d = []
|
| 358 |
+
for li in history_label_idx:
|
| 359 |
+
T_vw = self._load_pose_vehicle(s, li)
|
| 360 |
+
ego_6d.append(matrix_to_6d(T_vw))
|
| 361 |
+
ego_6d = torch.stack(ego_6d, dim=0) # [8, 6]
|
| 362 |
+
|
| 363 |
+
# === 4) 检测 / 未来轨迹标签 ===
|
| 364 |
+
# objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。
|
| 365 |
+
objs_t = self._load_objects(s, s.chunk_offset + t)
|
| 366 |
+
objs_future = [self._load_objects(s, li) for li in future_label_idx]
|
| 367 |
+
veh_pose_future = []
|
| 368 |
+
for li in future_label_idx:
|
| 369 |
+
try:
|
| 370 |
+
veh_pose_future.append(self._load_pose_vehicle(s, li))
|
| 371 |
+
except FileNotFoundError:
|
| 372 |
+
break
|
| 373 |
+
|
| 374 |
+
cam = FThetaCamera.from_vector(intr_vec)
|
| 375 |
+
lidar_self = None
|
| 376 |
+
if self.use_lidar_occlusion:
|
| 377 |
+
try:
|
| 378 |
+
lidar_self = _load_lidar_self_frame(
|
| 379 |
+
s.labels_dir,
|
| 380 |
+
s.chunk_offset + t,
|
| 381 |
+
pose_veh_world,
|
| 382 |
+
)
|
| 383 |
+
except Exception:
|
| 384 |
+
lidar_self = None
|
| 385 |
+
|
| 386 |
+
det_targets = build_detection_targets(
|
| 387 |
+
objects_t=objs_t,
|
| 388 |
+
objects_future=objs_future,
|
| 389 |
+
vehicle_pose_t=pose_veh_world,
|
| 390 |
+
vehicle_pose_future=veh_pose_future,
|
| 391 |
+
cam_intrinsic=cam,
|
| 392 |
+
cam2vehicle=cam2veh,
|
| 393 |
+
image_h=self.image_h,
|
| 394 |
+
image_w=self.image_w,
|
| 395 |
+
max_distance_m=self.max_distance_m,
|
| 396 |
+
occlusion_depth_tolerance=self.occlusion_tol,
|
| 397 |
+
lidar_points_self=lidar_self,
|
| 398 |
+
dynamic_classes=self.dynamic_classes,
|
| 399 |
+
structured_classes=self.structured_classes,
|
| 400 |
+
future_horizon=self.future_horizon,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
ego_future, ego_future_valid = build_ego_future_target(
|
| 404 |
+
pose_veh_world, veh_pose_future, horizon=self.future_horizon
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
sample_out = {
|
| 408 |
+
"images": images,
|
| 409 |
+
"ego_6d": ego_6d,
|
| 410 |
+
"intr_vec": intr_vec,
|
| 411 |
+
"extr_6d": extr_6d,
|
| 412 |
+
"ego_future": ego_future,
|
| 413 |
+
"ego_future_valid": ego_future_valid,
|
| 414 |
+
"targets": det_targets,
|
| 415 |
+
"meta": {
|
| 416 |
+
"clip_id": s.clip_id,
|
| 417 |
+
"chunk_id": s.chunk_id,
|
| 418 |
+
"weather": s.weather,
|
| 419 |
+
"anchor_t": s.anchor_t,
|
| 420 |
+
},
|
| 421 |
+
}
|
| 422 |
+
return sample_out
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def collate_samples(batch: list[dict]) -> dict:
|
| 426 |
+
"""自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack;
|
| 427 |
+
targets 列表保留为 list(便于匈牙利匹配处理变长 N);
|
| 428 |
+
meta 也保留为 list。"""
|
| 429 |
+
out = {
|
| 430 |
+
"images": torch.stack([b["images"] for b in batch], dim=0),
|
| 431 |
+
"ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0),
|
| 432 |
+
"intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0),
|
| 433 |
+
"extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0),
|
| 434 |
+
"ego_future": torch.stack([b["ego_future"] for b in batch], dim=0),
|
| 435 |
+
"ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0),
|
| 436 |
+
"targets": [b["targets"] for b in batch],
|
| 437 |
+
"meta": [b["meta"] for b in batch],
|
| 438 |
+
}
|
| 439 |
+
return out
|
src/wjad/data/ftheta_proj.py
CHANGED
|
@@ -1,62 +1,62 @@
|
|
| 1 |
-
"""f-theta 正向投影:3D 点(相机系) -> 像素。
|
| 2 |
-
|
| 3 |
-
仅支持 backward polynomial 形式(与 NVIDIA 工具一致),
|
| 4 |
-
forward 形式可在内部用牛顿迭代反推。
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from ..modules.rays import FThetaCamera
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def project_points_ftheta(
|
| 15 |
-
points_cam: torch.Tensor, # [..., 3],相机系下 3D 点
|
| 16 |
-
cam: FThetaCamera,
|
| 17 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 18 |
-
"""正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。
|
| 19 |
-
|
| 20 |
-
返回
|
| 21 |
-
----
|
| 22 |
-
uv : [..., 2]
|
| 23 |
-
depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算
|
| 24 |
-
但调用方需用 ``depth > 0`` 做有效性筛选)。
|
| 25 |
-
"""
|
| 26 |
-
x = points_cam[..., 0]
|
| 27 |
-
y = points_cam[..., 1]
|
| 28 |
-
z = points_cam[..., 2]
|
| 29 |
-
norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6)
|
| 30 |
-
cos_theta = z / norm
|
| 31 |
-
cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
|
| 32 |
-
theta = torch.acos(cos_theta)
|
| 33 |
-
phi = torch.atan2(y, x)
|
| 34 |
-
|
| 35 |
-
if cam.intr.is_bw_poly:
|
| 36 |
-
# backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。
|
| 37 |
-
# 用牛顿迭代:希望 _eval_poly(r) = theta
|
| 38 |
-
r = theta.clone() # 初始猜测
|
| 39 |
-
for _ in range(8):
|
| 40 |
-
f = cam._eval_poly(r) - theta
|
| 41 |
-
df = cam._eval_poly_grad(r).clamp_min(1e-6)
|
| 42 |
-
r = r - f / df
|
| 43 |
-
r_pix = r
|
| 44 |
-
else:
|
| 45 |
-
r_pix = cam._eval_poly(theta)
|
| 46 |
-
|
| 47 |
-
cos_p = torch.cos(phi)
|
| 48 |
-
sin_p = torch.sin(phi)
|
| 49 |
-
du = r_pix * cos_p
|
| 50 |
-
dv = r_pix * sin_p
|
| 51 |
-
# 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆
|
| 52 |
-
c = cam.intr.linear_cde[0]
|
| 53 |
-
d = cam.intr.linear_cde[1]
|
| 54 |
-
e = cam.intr.linear_cde[2]
|
| 55 |
-
# 简化:忽略 linear_cde 的修正(与 unproject 中近似一致)
|
| 56 |
-
du0 = du
|
| 57 |
-
dv0 = dv
|
| 58 |
-
u = du0 + cam.intr.cx
|
| 59 |
-
v = dv0 + cam.intr.cy
|
| 60 |
-
uv = torch.stack([u, v], dim=-1)
|
| 61 |
-
depth = z.unsqueeze(-1)
|
| 62 |
-
return uv, depth
|
|
|
|
| 1 |
+
"""f-theta 正向投影:3D 点(相机系) -> 像素。
|
| 2 |
+
|
| 3 |
+
仅支持 backward polynomial 形式(与 NVIDIA 工具一致),
|
| 4 |
+
forward 形式可在内部用牛顿迭代反推。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from ..modules.rays import FThetaCamera
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def project_points_ftheta(
|
| 15 |
+
points_cam: torch.Tensor, # [..., 3],相机系下 3D 点
|
| 16 |
+
cam: FThetaCamera,
|
| 17 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 18 |
+
"""正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。
|
| 19 |
+
|
| 20 |
+
返回
|
| 21 |
+
----
|
| 22 |
+
uv : [..., 2]
|
| 23 |
+
depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算
|
| 24 |
+
但调用方需用 ``depth > 0`` 做有效性筛选)。
|
| 25 |
+
"""
|
| 26 |
+
x = points_cam[..., 0]
|
| 27 |
+
y = points_cam[..., 1]
|
| 28 |
+
z = points_cam[..., 2]
|
| 29 |
+
norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6)
|
| 30 |
+
cos_theta = z / norm
|
| 31 |
+
cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
|
| 32 |
+
theta = torch.acos(cos_theta)
|
| 33 |
+
phi = torch.atan2(y, x)
|
| 34 |
+
|
| 35 |
+
if cam.intr.is_bw_poly:
|
| 36 |
+
# backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。
|
| 37 |
+
# 用牛顿迭代:希望 _eval_poly(r) = theta
|
| 38 |
+
r = theta.clone() # 初始猜测
|
| 39 |
+
for _ in range(8):
|
| 40 |
+
f = cam._eval_poly(r) - theta
|
| 41 |
+
df = cam._eval_poly_grad(r).clamp_min(1e-6)
|
| 42 |
+
r = r - f / df
|
| 43 |
+
r_pix = r
|
| 44 |
+
else:
|
| 45 |
+
r_pix = cam._eval_poly(theta)
|
| 46 |
+
|
| 47 |
+
cos_p = torch.cos(phi)
|
| 48 |
+
sin_p = torch.sin(phi)
|
| 49 |
+
du = r_pix * cos_p
|
| 50 |
+
dv = r_pix * sin_p
|
| 51 |
+
# 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆
|
| 52 |
+
c = cam.intr.linear_cde[0]
|
| 53 |
+
d = cam.intr.linear_cde[1]
|
| 54 |
+
e = cam.intr.linear_cde[2]
|
| 55 |
+
# 简化:忽略 linear_cde 的修正(与 unproject 中近似一致)
|
| 56 |
+
du0 = du
|
| 57 |
+
dv0 = dv
|
| 58 |
+
u = du0 + cam.intr.cx
|
| 59 |
+
v = dv0 + cam.intr.cy
|
| 60 |
+
uv = torch.stack([u, v], dim=-1)
|
| 61 |
+
depth = z.unsqueeze(-1)
|
| 62 |
+
return uv, depth
|
src/wjad/data/hdmap.py
CHANGED
|
@@ -1,247 +1,247 @@
|
|
| 1 |
-
"""HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。
|
| 2 |
-
|
| 3 |
-
输入:clip 标签目录(``labels/{clip_id_full}/``)。
|
| 4 |
-
输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``,
|
| 5 |
-
``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。
|
| 6 |
-
|
| 7 |
-
形状约定(按 README):
|
| 8 |
-
- 3d_lanes / lanes.json
|
| 9 |
-
labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices']
|
| 10 |
-
- 3d_lanelines / lanelines.json
|
| 11 |
-
labels[i]['labelData']['shape3d']['polyline3d']['vertices']
|
| 12 |
-
- 3d_road_boundaries / road_boundaries.json 同 polyline3d
|
| 13 |
-
- 3d_wait_lines / wait_lines.json 同 polyline3d
|
| 14 |
-
- 3d_crosswalks / crosswalks.json
|
| 15 |
-
labels[i]['labelData']['shape3d']['surface']['vertices']
|
| 16 |
-
- 3d_road_markings / road_markings.json 同 surface
|
| 17 |
-
- 3d_poles / poles.json 同 polyline3d
|
| 18 |
-
- 3d_traffic_lights / 3d_traffic_lights.json
|
| 19 |
-
labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点
|
| 20 |
-
- 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d
|
| 21 |
-
|
| 22 |
-
折线 → 7-DoF box:
|
| 23 |
-
PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计
|
| 24 |
-
弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会
|
| 25 |
-
超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
from __future__ import annotations
|
| 29 |
-
|
| 30 |
-
import json
|
| 31 |
-
from pathlib import Path
|
| 32 |
-
|
| 33 |
-
import numpy as np
|
| 34 |
-
import torch
|
| 35 |
-
|
| 36 |
-
from .targets import ObjectTrackInfo
|
| 37 |
-
from .label_paths import resolve_clip_file
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# 折线类长度切分阈值(米)
|
| 41 |
-
POLYLINE_SEGMENT_LEN = 10.0
|
| 42 |
-
LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点
|
| 43 |
-
MIN_LWH = (0.2, 0.2, 0.05)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
# cls_name -> (folder, json_name, kind)
|
| 47 |
-
HDMAP_SOURCES = {
|
| 48 |
-
"lane": ("3d_lanes", "lanes.json", "lane_pair"),
|
| 49 |
-
"laneline": ("3d_lanelines", "lanelines.json", "polyline"),
|
| 50 |
-
"road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"),
|
| 51 |
-
"wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"),
|
| 52 |
-
"crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"),
|
| 53 |
-
"road_marking": ("3d_road_markings", "road_markings.json", "surface"),
|
| 54 |
-
"pole": ("3d_poles", "poles.json", "polyline_short"),
|
| 55 |
-
# 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json)
|
| 56 |
-
"traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"),
|
| 57 |
-
"traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"),
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def _load_json_labels(path: Path) -> list:
|
| 62 |
-
"""容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{<filename>: {labels: ...}}``。"""
|
| 63 |
-
if not path.exists():
|
| 64 |
-
return []
|
| 65 |
-
try:
|
| 66 |
-
data = json.loads(path.read_text(encoding="utf-8"))
|
| 67 |
-
except Exception:
|
| 68 |
-
return []
|
| 69 |
-
if isinstance(data, dict):
|
| 70 |
-
if isinstance(data.get("labels"), list):
|
| 71 |
-
return data["labels"]
|
| 72 |
-
for v in data.values():
|
| 73 |
-
if isinstance(v, dict) and isinstance(v.get("labels"), list):
|
| 74 |
-
return v["labels"]
|
| 75 |
-
return []
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def _verts_to_array(verts) -> np.ndarray:
|
| 79 |
-
"""vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。"""
|
| 80 |
-
if not verts:
|
| 81 |
-
return np.zeros((0, 3), dtype=np.float32)
|
| 82 |
-
out: list[list[float]] = []
|
| 83 |
-
for v in verts:
|
| 84 |
-
if isinstance(v, dict):
|
| 85 |
-
out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))])
|
| 86 |
-
elif isinstance(v, (list, tuple)) and len(v) >= 3:
|
| 87 |
-
out.append([float(v[0]), float(v[1]), float(v[2])])
|
| 88 |
-
return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]:
|
| 92 |
-
"""按累计弧长把折线切成若干段。每段顶点数 >=2。"""
|
| 93 |
-
if verts.shape[0] < 2:
|
| 94 |
-
return []
|
| 95 |
-
edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
|
| 96 |
-
cum = np.concatenate([[0.0], np.cumsum(edges)])
|
| 97 |
-
total = float(cum[-1])
|
| 98 |
-
if total <= seg_len:
|
| 99 |
-
return [verts]
|
| 100 |
-
n = max(1, int(np.ceil(total / seg_len)))
|
| 101 |
-
bounds = np.linspace(0.0, total, n + 1)
|
| 102 |
-
chunks: list[np.ndarray] = []
|
| 103 |
-
for i in range(n):
|
| 104 |
-
lo, hi = bounds[i], bounds[i + 1]
|
| 105 |
-
mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6)
|
| 106 |
-
chunk = verts[mask]
|
| 107 |
-
if chunk.shape[0] >= 2:
|
| 108 |
-
chunks.append(chunk)
|
| 109 |
-
return chunks
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None:
|
| 113 |
-
"""[N, 3] -> (center, lwh, yaw)。"""
|
| 114 |
-
if verts.shape[0] < 2:
|
| 115 |
-
return None
|
| 116 |
-
center = verts.mean(0)
|
| 117 |
-
centered_xy = verts[:, :2] - center[:2]
|
| 118 |
-
if np.allclose(centered_xy, 0.0):
|
| 119 |
-
yaw = 0.0
|
| 120 |
-
else:
|
| 121 |
-
cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1)
|
| 122 |
-
_, eigvecs = np.linalg.eigh(cov)
|
| 123 |
-
principal = eigvecs[:, -1]
|
| 124 |
-
yaw = float(np.arctan2(principal[1], principal[0]))
|
| 125 |
-
c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
|
| 126 |
-
rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T
|
| 127 |
-
l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min())
|
| 128 |
-
w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min())
|
| 129 |
-
h = float(verts[:, 2].max() - verts[:, 2].min())
|
| 130 |
-
l = max(l, MIN_LWH[0])
|
| 131 |
-
w = max(w, MIN_LWH[1])
|
| 132 |
-
h = max(h, MIN_LWH[2])
|
| 133 |
-
return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
|
| 137 |
-
"""8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。"""
|
| 138 |
-
center = corners.mean(0)
|
| 139 |
-
edge = corners[1] - corners[0]
|
| 140 |
-
yaw = float(np.arctan2(edge[1], edge[0]))
|
| 141 |
-
c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
|
| 142 |
-
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
|
| 143 |
-
rot = (corners - center) @ R.T
|
| 144 |
-
lwh = (rot.max(0) - rot.min(0)).astype(np.float32)
|
| 145 |
-
lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32))
|
| 146 |
-
return center.astype(np.float32), lwh, yaw
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def _build_object(
|
| 150 |
-
center: np.ndarray,
|
| 151 |
-
lwh: np.ndarray,
|
| 152 |
-
yaw: float,
|
| 153 |
-
cls_name: str,
|
| 154 |
-
idx: int,
|
| 155 |
-
sub_idx: int = 0,
|
| 156 |
-
) -> ObjectTrackInfo:
|
| 157 |
-
T = np.eye(4, dtype=np.float32)
|
| 158 |
-
c, s = float(np.cos(yaw)), float(np.sin(yaw))
|
| 159 |
-
T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
|
| 160 |
-
T[:3, 3] = center
|
| 161 |
-
return ObjectTrackInfo(
|
| 162 |
-
tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}",
|
| 163 |
-
object_to_world=torch.from_numpy(T),
|
| 164 |
-
lwh=torch.from_numpy(lwh),
|
| 165 |
-
is_moving=False,
|
| 166 |
-
object_type=cls_name,
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def parse_hdmap_clip(
|
| 171 |
-
clip_label_dir: Path,
|
| 172 |
-
segment_len: float = POLYLINE_SEGMENT_LEN,
|
| 173 |
-
lane_segment_len: float = LANE_SEGMENT_LEN,
|
| 174 |
-
) -> list[ObjectTrackInfo]:
|
| 175 |
-
"""解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。"""
|
| 176 |
-
out: list[ObjectTrackInfo] = []
|
| 177 |
-
for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items():
|
| 178 |
-
try:
|
| 179 |
-
path = resolve_clip_file(clip_label_dir, subdir, json_name)
|
| 180 |
-
except FileNotFoundError:
|
| 181 |
-
continue
|
| 182 |
-
labels = _load_json_labels(path)
|
| 183 |
-
for i, lbl in enumerate(labels):
|
| 184 |
-
if not isinstance(lbl, dict):
|
| 185 |
-
continue
|
| 186 |
-
shape = lbl.get("labelData", {}).get("shape3d", {})
|
| 187 |
-
if not isinstance(shape, dict):
|
| 188 |
-
continue
|
| 189 |
-
|
| 190 |
-
if kind == "cuboid":
|
| 191 |
-
verts = shape.get("cuboid3d", {}).get("vertices", [])
|
| 192 |
-
arr = _verts_to_array(verts)
|
| 193 |
-
if arr.shape[0] != 8:
|
| 194 |
-
continue
|
| 195 |
-
c, lwh, yaw = _cuboid_to_box(arr)
|
| 196 |
-
out.append(_build_object(c, lwh, yaw, cls_name, i))
|
| 197 |
-
|
| 198 |
-
elif kind == "surface":
|
| 199 |
-
verts = shape.get("surface", {}).get("vertices", [])
|
| 200 |
-
arr = _verts_to_array(verts)
|
| 201 |
-
if arr.shape[0] < 3:
|
| 202 |
-
continue
|
| 203 |
-
box = _vertices_to_box(arr)
|
| 204 |
-
if box is not None:
|
| 205 |
-
out.append(_build_object(*box, cls_name, i))
|
| 206 |
-
|
| 207 |
-
elif kind == "polyline":
|
| 208 |
-
verts = shape.get("polyline3d", {}).get("vertices", [])
|
| 209 |
-
arr = _verts_to_array(verts)
|
| 210 |
-
if arr.shape[0] < 2:
|
| 211 |
-
continue
|
| 212 |
-
for j, chunk in enumerate(_split_polyline(arr, segment_len)):
|
| 213 |
-
box = _vertices_to_box(chunk)
|
| 214 |
-
if box is not None:
|
| 215 |
-
out.append(_build_object(*box, cls_name, i, j))
|
| 216 |
-
|
| 217 |
-
elif kind == "polyline_short":
|
| 218 |
-
# 杆状物体不切分
|
| 219 |
-
verts = shape.get("polyline3d", {}).get("vertices", [])
|
| 220 |
-
arr = _verts_to_array(verts)
|
| 221 |
-
if arr.shape[0] < 2:
|
| 222 |
-
continue
|
| 223 |
-
box = _vertices_to_box(arr)
|
| 224 |
-
if box is not None:
|
| 225 |
-
out.append(_build_object(*box, cls_name, i))
|
| 226 |
-
|
| 227 |
-
elif kind == "lane_pair":
|
| 228 |
-
pl_root = shape.get("polylines3d", {}).get("polylines", [])
|
| 229 |
-
if not isinstance(pl_root, list) or len(pl_root) < 2:
|
| 230 |
-
continue
|
| 231 |
-
left = _verts_to_array(
|
| 232 |
-
pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else []
|
| 233 |
-
)
|
| 234 |
-
right = _verts_to_array(
|
| 235 |
-
pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else []
|
| 236 |
-
)
|
| 237 |
-
if left.shape[0] == 0 and right.shape[0] == 0:
|
| 238 |
-
continue
|
| 239 |
-
merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0)
|
| 240 |
-
if merged.shape[0] < 2:
|
| 241 |
-
continue
|
| 242 |
-
for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)):
|
| 243 |
-
box = _vertices_to_box(chunk)
|
| 244 |
-
if box is not None:
|
| 245 |
-
out.append(_build_object(*box, cls_name, i, j))
|
| 246 |
-
|
| 247 |
-
return out
|
|
|
|
| 1 |
+
"""HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。
|
| 2 |
+
|
| 3 |
+
输入:clip 标签目录(``labels/{clip_id_full}/``)。
|
| 4 |
+
输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``,
|
| 5 |
+
``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。
|
| 6 |
+
|
| 7 |
+
形状约定(按 README):
|
| 8 |
+
- 3d_lanes / lanes.json
|
| 9 |
+
labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices']
|
| 10 |
+
- 3d_lanelines / lanelines.json
|
| 11 |
+
labels[i]['labelData']['shape3d']['polyline3d']['vertices']
|
| 12 |
+
- 3d_road_boundaries / road_boundaries.json 同 polyline3d
|
| 13 |
+
- 3d_wait_lines / wait_lines.json 同 polyline3d
|
| 14 |
+
- 3d_crosswalks / crosswalks.json
|
| 15 |
+
labels[i]['labelData']['shape3d']['surface']['vertices']
|
| 16 |
+
- 3d_road_markings / road_markings.json 同 surface
|
| 17 |
+
- 3d_poles / poles.json 同 polyline3d
|
| 18 |
+
- 3d_traffic_lights / 3d_traffic_lights.json
|
| 19 |
+
labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点
|
| 20 |
+
- 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d
|
| 21 |
+
|
| 22 |
+
折线 → 7-DoF box:
|
| 23 |
+
PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计
|
| 24 |
+
弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会
|
| 25 |
+
超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import torch
|
| 35 |
+
|
| 36 |
+
from .targets import ObjectTrackInfo
|
| 37 |
+
from .label_paths import resolve_clip_file
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# 折线类长度切分阈值(米)
|
| 41 |
+
POLYLINE_SEGMENT_LEN = 10.0
|
| 42 |
+
LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点
|
| 43 |
+
MIN_LWH = (0.2, 0.2, 0.05)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# cls_name -> (folder, json_name, kind)
|
| 47 |
+
HDMAP_SOURCES = {
|
| 48 |
+
"lane": ("3d_lanes", "lanes.json", "lane_pair"),
|
| 49 |
+
"laneline": ("3d_lanelines", "lanelines.json", "polyline"),
|
| 50 |
+
"road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"),
|
| 51 |
+
"wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"),
|
| 52 |
+
"crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"),
|
| 53 |
+
"road_marking": ("3d_road_markings", "road_markings.json", "surface"),
|
| 54 |
+
"pole": ("3d_poles", "poles.json", "polyline_short"),
|
| 55 |
+
# 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json)
|
| 56 |
+
"traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"),
|
| 57 |
+
"traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"),
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _load_json_labels(path: Path) -> list:
|
| 62 |
+
"""容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{<filename>: {labels: ...}}``。"""
|
| 63 |
+
if not path.exists():
|
| 64 |
+
return []
|
| 65 |
+
try:
|
| 66 |
+
data = json.loads(path.read_text(encoding="utf-8"))
|
| 67 |
+
except Exception:
|
| 68 |
+
return []
|
| 69 |
+
if isinstance(data, dict):
|
| 70 |
+
if isinstance(data.get("labels"), list):
|
| 71 |
+
return data["labels"]
|
| 72 |
+
for v in data.values():
|
| 73 |
+
if isinstance(v, dict) and isinstance(v.get("labels"), list):
|
| 74 |
+
return v["labels"]
|
| 75 |
+
return []
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _verts_to_array(verts) -> np.ndarray:
|
| 79 |
+
"""vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。"""
|
| 80 |
+
if not verts:
|
| 81 |
+
return np.zeros((0, 3), dtype=np.float32)
|
| 82 |
+
out: list[list[float]] = []
|
| 83 |
+
for v in verts:
|
| 84 |
+
if isinstance(v, dict):
|
| 85 |
+
out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))])
|
| 86 |
+
elif isinstance(v, (list, tuple)) and len(v) >= 3:
|
| 87 |
+
out.append([float(v[0]), float(v[1]), float(v[2])])
|
| 88 |
+
return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]:
|
| 92 |
+
"""按累计弧长把折线切成若干段。每段顶点数 >=2。"""
|
| 93 |
+
if verts.shape[0] < 2:
|
| 94 |
+
return []
|
| 95 |
+
edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
|
| 96 |
+
cum = np.concatenate([[0.0], np.cumsum(edges)])
|
| 97 |
+
total = float(cum[-1])
|
| 98 |
+
if total <= seg_len:
|
| 99 |
+
return [verts]
|
| 100 |
+
n = max(1, int(np.ceil(total / seg_len)))
|
| 101 |
+
bounds = np.linspace(0.0, total, n + 1)
|
| 102 |
+
chunks: list[np.ndarray] = []
|
| 103 |
+
for i in range(n):
|
| 104 |
+
lo, hi = bounds[i], bounds[i + 1]
|
| 105 |
+
mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6)
|
| 106 |
+
chunk = verts[mask]
|
| 107 |
+
if chunk.shape[0] >= 2:
|
| 108 |
+
chunks.append(chunk)
|
| 109 |
+
return chunks
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None:
|
| 113 |
+
"""[N, 3] -> (center, lwh, yaw)。"""
|
| 114 |
+
if verts.shape[0] < 2:
|
| 115 |
+
return None
|
| 116 |
+
center = verts.mean(0)
|
| 117 |
+
centered_xy = verts[:, :2] - center[:2]
|
| 118 |
+
if np.allclose(centered_xy, 0.0):
|
| 119 |
+
yaw = 0.0
|
| 120 |
+
else:
|
| 121 |
+
cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1)
|
| 122 |
+
_, eigvecs = np.linalg.eigh(cov)
|
| 123 |
+
principal = eigvecs[:, -1]
|
| 124 |
+
yaw = float(np.arctan2(principal[1], principal[0]))
|
| 125 |
+
c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
|
| 126 |
+
rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T
|
| 127 |
+
l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min())
|
| 128 |
+
w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min())
|
| 129 |
+
h = float(verts[:, 2].max() - verts[:, 2].min())
|
| 130 |
+
l = max(l, MIN_LWH[0])
|
| 131 |
+
w = max(w, MIN_LWH[1])
|
| 132 |
+
h = max(h, MIN_LWH[2])
|
| 133 |
+
return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
|
| 137 |
+
"""8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。"""
|
| 138 |
+
center = corners.mean(0)
|
| 139 |
+
edge = corners[1] - corners[0]
|
| 140 |
+
yaw = float(np.arctan2(edge[1], edge[0]))
|
| 141 |
+
c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
|
| 142 |
+
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
|
| 143 |
+
rot = (corners - center) @ R.T
|
| 144 |
+
lwh = (rot.max(0) - rot.min(0)).astype(np.float32)
|
| 145 |
+
lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32))
|
| 146 |
+
return center.astype(np.float32), lwh, yaw
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _build_object(
|
| 150 |
+
center: np.ndarray,
|
| 151 |
+
lwh: np.ndarray,
|
| 152 |
+
yaw: float,
|
| 153 |
+
cls_name: str,
|
| 154 |
+
idx: int,
|
| 155 |
+
sub_idx: int = 0,
|
| 156 |
+
) -> ObjectTrackInfo:
|
| 157 |
+
T = np.eye(4, dtype=np.float32)
|
| 158 |
+
c, s = float(np.cos(yaw)), float(np.sin(yaw))
|
| 159 |
+
T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
|
| 160 |
+
T[:3, 3] = center
|
| 161 |
+
return ObjectTrackInfo(
|
| 162 |
+
tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}",
|
| 163 |
+
object_to_world=torch.from_numpy(T),
|
| 164 |
+
lwh=torch.from_numpy(lwh),
|
| 165 |
+
is_moving=False,
|
| 166 |
+
object_type=cls_name,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def parse_hdmap_clip(
|
| 171 |
+
clip_label_dir: Path,
|
| 172 |
+
segment_len: float = POLYLINE_SEGMENT_LEN,
|
| 173 |
+
lane_segment_len: float = LANE_SEGMENT_LEN,
|
| 174 |
+
) -> list[ObjectTrackInfo]:
|
| 175 |
+
"""解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。"""
|
| 176 |
+
out: list[ObjectTrackInfo] = []
|
| 177 |
+
for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items():
|
| 178 |
+
try:
|
| 179 |
+
path = resolve_clip_file(clip_label_dir, subdir, json_name)
|
| 180 |
+
except FileNotFoundError:
|
| 181 |
+
continue
|
| 182 |
+
labels = _load_json_labels(path)
|
| 183 |
+
for i, lbl in enumerate(labels):
|
| 184 |
+
if not isinstance(lbl, dict):
|
| 185 |
+
continue
|
| 186 |
+
shape = lbl.get("labelData", {}).get("shape3d", {})
|
| 187 |
+
if not isinstance(shape, dict):
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
if kind == "cuboid":
|
| 191 |
+
verts = shape.get("cuboid3d", {}).get("vertices", [])
|
| 192 |
+
arr = _verts_to_array(verts)
|
| 193 |
+
if arr.shape[0] != 8:
|
| 194 |
+
continue
|
| 195 |
+
c, lwh, yaw = _cuboid_to_box(arr)
|
| 196 |
+
out.append(_build_object(c, lwh, yaw, cls_name, i))
|
| 197 |
+
|
| 198 |
+
elif kind == "surface":
|
| 199 |
+
verts = shape.get("surface", {}).get("vertices", [])
|
| 200 |
+
arr = _verts_to_array(verts)
|
| 201 |
+
if arr.shape[0] < 3:
|
| 202 |
+
continue
|
| 203 |
+
box = _vertices_to_box(arr)
|
| 204 |
+
if box is not None:
|
| 205 |
+
out.append(_build_object(*box, cls_name, i))
|
| 206 |
+
|
| 207 |
+
elif kind == "polyline":
|
| 208 |
+
verts = shape.get("polyline3d", {}).get("vertices", [])
|
| 209 |
+
arr = _verts_to_array(verts)
|
| 210 |
+
if arr.shape[0] < 2:
|
| 211 |
+
continue
|
| 212 |
+
for j, chunk in enumerate(_split_polyline(arr, segment_len)):
|
| 213 |
+
box = _vertices_to_box(chunk)
|
| 214 |
+
if box is not None:
|
| 215 |
+
out.append(_build_object(*box, cls_name, i, j))
|
| 216 |
+
|
| 217 |
+
elif kind == "polyline_short":
|
| 218 |
+
# 杆状物体不切分
|
| 219 |
+
verts = shape.get("polyline3d", {}).get("vertices", [])
|
| 220 |
+
arr = _verts_to_array(verts)
|
| 221 |
+
if arr.shape[0] < 2:
|
| 222 |
+
continue
|
| 223 |
+
box = _vertices_to_box(arr)
|
| 224 |
+
if box is not None:
|
| 225 |
+
out.append(_build_object(*box, cls_name, i))
|
| 226 |
+
|
| 227 |
+
elif kind == "lane_pair":
|
| 228 |
+
pl_root = shape.get("polylines3d", {}).get("polylines", [])
|
| 229 |
+
if not isinstance(pl_root, list) or len(pl_root) < 2:
|
| 230 |
+
continue
|
| 231 |
+
left = _verts_to_array(
|
| 232 |
+
pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else []
|
| 233 |
+
)
|
| 234 |
+
right = _verts_to_array(
|
| 235 |
+
pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else []
|
| 236 |
+
)
|
| 237 |
+
if left.shape[0] == 0 and right.shape[0] == 0:
|
| 238 |
+
continue
|
| 239 |
+
merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0)
|
| 240 |
+
if merged.shape[0] < 2:
|
| 241 |
+
continue
|
| 242 |
+
for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)):
|
| 243 |
+
box = _vertices_to_box(chunk)
|
| 244 |
+
if box is not None:
|
| 245 |
+
out.append(_build_object(*box, cls_name, i, j))
|
| 246 |
+
|
| 247 |
+
return out
|
src/wjad/data/label_paths.py
CHANGED
|
@@ -1,218 +1,218 @@
|
|
| 1 |
-
"""数据集标签目录布局解析。
|
| 2 |
-
|
| 3 |
-
README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径;
|
| 4 |
-
实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError``
|
| 5 |
-
里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def _norm_name(s: str) -> str:
|
| 14 |
-
return "".join(c for c in s.lower() if c.isalnum())
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str:
|
| 18 |
-
"""列出 ``labels/<clip>/<folder>/`` 下文件采样 + clip 根下一级子目录。"""
|
| 19 |
-
lines: list[str] = []
|
| 20 |
-
sub = labels_dir / folder
|
| 21 |
-
if sub.is_dir():
|
| 22 |
-
files = sorted(p for p in sub.rglob("*") if p.is_file())
|
| 23 |
-
lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):")
|
| 24 |
-
for p in files[:max_list]:
|
| 25 |
-
try:
|
| 26 |
-
rel = p.relative_to(labels_dir).as_posix()
|
| 27 |
-
except ValueError:
|
| 28 |
-
rel = str(p)
|
| 29 |
-
lines.append(f" {rel}")
|
| 30 |
-
if len(files) > max_list:
|
| 31 |
-
lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出")
|
| 32 |
-
else:
|
| 33 |
-
lines.append(f"[{folder}/] 不存在:{sub}")
|
| 34 |
-
try:
|
| 35 |
-
top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir())
|
| 36 |
-
lines.append(f"[labels/<clip>/] 一级子目录:{top}")
|
| 37 |
-
except OSError as e:
|
| 38 |
-
lines.append(f"[labels/<clip>/] 无法列举:{e}")
|
| 39 |
-
return "\n".join(lines)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def _scan_npy_json_npz(
|
| 43 |
-
labels_dir: Path,
|
| 44 |
-
folder: str,
|
| 45 |
-
fname: str,
|
| 46 |
-
*,
|
| 47 |
-
exts: tuple[str, ...] = (".npy",),
|
| 48 |
-
tokens_norm: list[str],
|
| 49 |
-
name_must_contain: str | None = None,
|
| 50 |
-
) -> list[Path]:
|
| 51 |
-
"""在整棵 labels/<clip>/ 下找候选文件:扩展名 + 归一化名须含各 token。"""
|
| 52 |
-
root_hint = labels_dir / folder
|
| 53 |
-
search_roots = [root_hint] if root_hint.is_dir() else []
|
| 54 |
-
if not search_roots:
|
| 55 |
-
search_roots = [labels_dir]
|
| 56 |
-
hits: list[Path] = []
|
| 57 |
-
for root in search_roots:
|
| 58 |
-
for p in root.rglob("*"):
|
| 59 |
-
if not p.is_file():
|
| 60 |
-
continue
|
| 61 |
-
if not p.suffix.lower() in [e.lower() for e in exts]:
|
| 62 |
-
continue
|
| 63 |
-
if name_must_contain and name_must_contain.lower() not in p.name.lower():
|
| 64 |
-
continue
|
| 65 |
-
pn = _norm_name(p.name)
|
| 66 |
-
if all(tok in pn for tok in tokens_norm if tok):
|
| 67 |
-
hits.append(p)
|
| 68 |
-
if not hits and root_hint.is_dir():
|
| 69 |
-
for p in labels_dir.rglob("*"):
|
| 70 |
-
if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]:
|
| 71 |
-
continue
|
| 72 |
-
if name_must_contain and name_must_contain.lower() not in p.name.lower():
|
| 73 |
-
continue
|
| 74 |
-
pn = _norm_name(p.name)
|
| 75 |
-
if all(tok in pn for tok in tokens_norm if tok):
|
| 76 |
-
hits.append(p)
|
| 77 |
-
return hits
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def resolve_clip_file(labels_dir: Path, *parts: str) -> Path:
|
| 81 |
-
"""在 ``labels/<clip_id>/`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。"""
|
| 82 |
-
if not parts:
|
| 83 |
-
raise ValueError("parts 不能为空")
|
| 84 |
-
if not labels_dir.is_dir():
|
| 85 |
-
raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}")
|
| 86 |
-
|
| 87 |
-
direct = labels_dir.joinpath(*parts)
|
| 88 |
-
if direct.is_file():
|
| 89 |
-
return direct
|
| 90 |
-
# NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels/<clip>/``
|
| 91 |
-
# 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。
|
| 92 |
-
clip_stem = labels_dir.resolve().name
|
| 93 |
-
if len(parts) >= 2:
|
| 94 |
-
folder, fname = parts[0], parts[-1]
|
| 95 |
-
if not fname.startswith(f"{clip_stem}."):
|
| 96 |
-
stemmed = labels_dir / folder / f"{clip_stem}.{fname}"
|
| 97 |
-
if stemmed.is_file():
|
| 98 |
-
return stemmed
|
| 99 |
-
if len(parts) >= 2:
|
| 100 |
-
folder = parts[0]
|
| 101 |
-
rest = parts[1:]
|
| 102 |
-
doubled = (labels_dir / folder / folder).joinpath(*rest)
|
| 103 |
-
if doubled.is_file():
|
| 104 |
-
return doubled
|
| 105 |
-
fname = parts[-1]
|
| 106 |
-
folder = parts[0]
|
| 107 |
-
sub = labels_dir / folder
|
| 108 |
-
if sub.is_dir():
|
| 109 |
-
for p in sub.rglob(fname):
|
| 110 |
-
if p.is_file():
|
| 111 |
-
return p
|
| 112 |
-
|
| 113 |
-
for p in labels_dir.rglob(fname):
|
| 114 |
-
if p.is_file():
|
| 115 |
-
return p
|
| 116 |
-
fl = fname.lower()
|
| 117 |
-
for p in labels_dir.rglob("*"):
|
| 118 |
-
if p.is_file() and p.name.lower() == fl:
|
| 119 |
-
return p
|
| 120 |
-
|
| 121 |
-
# ftheta / pinhole
|
| 122 |
-
if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"):
|
| 123 |
-
prefix = folder + "."
|
| 124 |
-
if fname.lower().startswith(prefix.lower()):
|
| 125 |
-
cam = fname[len(prefix) : -len(".npy")]
|
| 126 |
-
cam_n = _norm_name(cam)
|
| 127 |
-
hits = []
|
| 128 |
-
for p in labels_dir.rglob("*.npy"):
|
| 129 |
-
if not p.is_file():
|
| 130 |
-
continue
|
| 131 |
-
pn = _norm_name(p.name)
|
| 132 |
-
if folder == "ftheta_intrinsic":
|
| 133 |
-
if "ftheta" not in pn:
|
| 134 |
-
continue
|
| 135 |
-
else:
|
| 136 |
-
if "pinhole" not in pn:
|
| 137 |
-
continue
|
| 138 |
-
if cam_n and cam_n in pn:
|
| 139 |
-
hits.append(p)
|
| 140 |
-
if len(hits) == 1:
|
| 141 |
-
return hits[0]
|
| 142 |
-
if len(hits) > 1:
|
| 143 |
-
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 144 |
-
return hits[0]
|
| 145 |
-
|
| 146 |
-
# pose: ``{idx:06d}.pose.{camera}.npy``
|
| 147 |
-
if folder == "pose" and fname.endswith(".npy"):
|
| 148 |
-
base = fname[: -len(".npy")]
|
| 149 |
-
if ".pose." in base:
|
| 150 |
-
idx_part, _, cam_part = base.partition(".pose.")
|
| 151 |
-
hits = _scan_npy_json_npz(
|
| 152 |
-
labels_dir,
|
| 153 |
-
folder,
|
| 154 |
-
fname,
|
| 155 |
-
exts=(".npy",),
|
| 156 |
-
tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)],
|
| 157 |
-
name_must_contain="pose",
|
| 158 |
-
)
|
| 159 |
-
if len(hits) == 1:
|
| 160 |
-
return hits[0]
|
| 161 |
-
if len(hits) > 1:
|
| 162 |
-
hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x)))
|
| 163 |
-
return hits[0]
|
| 164 |
-
|
| 165 |
-
# vehicle_pose: ``{idx:06d}.vehicle_pose.npy``
|
| 166 |
-
if folder == "vehicle_pose" and fname.endswith(".npy"):
|
| 167 |
-
idx_part = fname.split(".")[0]
|
| 168 |
-
hits = _scan_npy_json_npz(
|
| 169 |
-
labels_dir,
|
| 170 |
-
folder,
|
| 171 |
-
fname,
|
| 172 |
-
exts=(".npy",),
|
| 173 |
-
tokens_norm=[_norm_name(idx_part), "vehiclepose"],
|
| 174 |
-
name_must_contain="vehicle",
|
| 175 |
-
)
|
| 176 |
-
if len(hits) == 1:
|
| 177 |
-
return hits[0]
|
| 178 |
-
if len(hits) > 1:
|
| 179 |
-
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 180 |
-
return hits[0]
|
| 181 |
-
|
| 182 |
-
# all_object_info
|
| 183 |
-
if folder == "all_object_info" and fname.endswith(".json"):
|
| 184 |
-
idx_part = fname.split(".")[0]
|
| 185 |
-
hits = _scan_npy_json_npz(
|
| 186 |
-
labels_dir,
|
| 187 |
-
folder,
|
| 188 |
-
fname,
|
| 189 |
-
exts=(".json",),
|
| 190 |
-
tokens_norm=[_norm_name(idx_part), "allobjectinfo"],
|
| 191 |
-
)
|
| 192 |
-
if len(hits) == 1:
|
| 193 |
-
return hits[0]
|
| 194 |
-
if len(hits) > 1:
|
| 195 |
-
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 196 |
-
return hits[0]
|
| 197 |
-
|
| 198 |
-
# lidar_raw
|
| 199 |
-
if folder == "lidar_raw" and fname.endswith(".npz"):
|
| 200 |
-
stem = fname[: -len(".npz")]
|
| 201 |
-
hits = _scan_npy_json_npz(
|
| 202 |
-
labels_dir,
|
| 203 |
-
folder,
|
| 204 |
-
fname,
|
| 205 |
-
exts=(".npz",),
|
| 206 |
-
tokens_norm=[_norm_name(stem), "lidar", "raw"],
|
| 207 |
-
)
|
| 208 |
-
if len(hits) == 1:
|
| 209 |
-
return hits[0]
|
| 210 |
-
if len(hits) > 1:
|
| 211 |
-
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 212 |
-
return hits[0]
|
| 213 |
-
|
| 214 |
-
detail = _diagnose_labels(labels_dir, folder)
|
| 215 |
-
raise FileNotFoundError(
|
| 216 |
-
f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、"
|
| 217 |
-
f"rglob、按帧索引+相机的扫描匹配)。\n{detail}"
|
| 218 |
-
)
|
|
|
|
| 1 |
+
"""数据集标签目录布局解析。
|
| 2 |
+
|
| 3 |
+
README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径;
|
| 4 |
+
实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError``
|
| 5 |
+
里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _norm_name(s: str) -> str:
|
| 14 |
+
return "".join(c for c in s.lower() if c.isalnum())
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str:
|
| 18 |
+
"""列出 ``labels/<clip>/<folder>/`` 下文件采样 + clip 根下一级子目录。"""
|
| 19 |
+
lines: list[str] = []
|
| 20 |
+
sub = labels_dir / folder
|
| 21 |
+
if sub.is_dir():
|
| 22 |
+
files = sorted(p for p in sub.rglob("*") if p.is_file())
|
| 23 |
+
lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):")
|
| 24 |
+
for p in files[:max_list]:
|
| 25 |
+
try:
|
| 26 |
+
rel = p.relative_to(labels_dir).as_posix()
|
| 27 |
+
except ValueError:
|
| 28 |
+
rel = str(p)
|
| 29 |
+
lines.append(f" {rel}")
|
| 30 |
+
if len(files) > max_list:
|
| 31 |
+
lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出")
|
| 32 |
+
else:
|
| 33 |
+
lines.append(f"[{folder}/] 不存在:{sub}")
|
| 34 |
+
try:
|
| 35 |
+
top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir())
|
| 36 |
+
lines.append(f"[labels/<clip>/] 一级子目录:{top}")
|
| 37 |
+
except OSError as e:
|
| 38 |
+
lines.append(f"[labels/<clip>/] 无法列举:{e}")
|
| 39 |
+
return "\n".join(lines)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _scan_npy_json_npz(
|
| 43 |
+
labels_dir: Path,
|
| 44 |
+
folder: str,
|
| 45 |
+
fname: str,
|
| 46 |
+
*,
|
| 47 |
+
exts: tuple[str, ...] = (".npy",),
|
| 48 |
+
tokens_norm: list[str],
|
| 49 |
+
name_must_contain: str | None = None,
|
| 50 |
+
) -> list[Path]:
|
| 51 |
+
"""在整棵 labels/<clip>/ 下找候选文件:扩展名 + 归一化名须含各 token。"""
|
| 52 |
+
root_hint = labels_dir / folder
|
| 53 |
+
search_roots = [root_hint] if root_hint.is_dir() else []
|
| 54 |
+
if not search_roots:
|
| 55 |
+
search_roots = [labels_dir]
|
| 56 |
+
hits: list[Path] = []
|
| 57 |
+
for root in search_roots:
|
| 58 |
+
for p in root.rglob("*"):
|
| 59 |
+
if not p.is_file():
|
| 60 |
+
continue
|
| 61 |
+
if not p.suffix.lower() in [e.lower() for e in exts]:
|
| 62 |
+
continue
|
| 63 |
+
if name_must_contain and name_must_contain.lower() not in p.name.lower():
|
| 64 |
+
continue
|
| 65 |
+
pn = _norm_name(p.name)
|
| 66 |
+
if all(tok in pn for tok in tokens_norm if tok):
|
| 67 |
+
hits.append(p)
|
| 68 |
+
if not hits and root_hint.is_dir():
|
| 69 |
+
for p in labels_dir.rglob("*"):
|
| 70 |
+
if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]:
|
| 71 |
+
continue
|
| 72 |
+
if name_must_contain and name_must_contain.lower() not in p.name.lower():
|
| 73 |
+
continue
|
| 74 |
+
pn = _norm_name(p.name)
|
| 75 |
+
if all(tok in pn for tok in tokens_norm if tok):
|
| 76 |
+
hits.append(p)
|
| 77 |
+
return hits
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def resolve_clip_file(labels_dir: Path, *parts: str) -> Path:
|
| 81 |
+
"""在 ``labels/<clip_id>/`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。"""
|
| 82 |
+
if not parts:
|
| 83 |
+
raise ValueError("parts 不能为空")
|
| 84 |
+
if not labels_dir.is_dir():
|
| 85 |
+
raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}")
|
| 86 |
+
|
| 87 |
+
direct = labels_dir.joinpath(*parts)
|
| 88 |
+
if direct.is_file():
|
| 89 |
+
return direct
|
| 90 |
+
# NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels/<clip>/``
|
| 91 |
+
# 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。
|
| 92 |
+
clip_stem = labels_dir.resolve().name
|
| 93 |
+
if len(parts) >= 2:
|
| 94 |
+
folder, fname = parts[0], parts[-1]
|
| 95 |
+
if not fname.startswith(f"{clip_stem}."):
|
| 96 |
+
stemmed = labels_dir / folder / f"{clip_stem}.{fname}"
|
| 97 |
+
if stemmed.is_file():
|
| 98 |
+
return stemmed
|
| 99 |
+
if len(parts) >= 2:
|
| 100 |
+
folder = parts[0]
|
| 101 |
+
rest = parts[1:]
|
| 102 |
+
doubled = (labels_dir / folder / folder).joinpath(*rest)
|
| 103 |
+
if doubled.is_file():
|
| 104 |
+
return doubled
|
| 105 |
+
fname = parts[-1]
|
| 106 |
+
folder = parts[0]
|
| 107 |
+
sub = labels_dir / folder
|
| 108 |
+
if sub.is_dir():
|
| 109 |
+
for p in sub.rglob(fname):
|
| 110 |
+
if p.is_file():
|
| 111 |
+
return p
|
| 112 |
+
|
| 113 |
+
for p in labels_dir.rglob(fname):
|
| 114 |
+
if p.is_file():
|
| 115 |
+
return p
|
| 116 |
+
fl = fname.lower()
|
| 117 |
+
for p in labels_dir.rglob("*"):
|
| 118 |
+
if p.is_file() and p.name.lower() == fl:
|
| 119 |
+
return p
|
| 120 |
+
|
| 121 |
+
# ftheta / pinhole
|
| 122 |
+
if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"):
|
| 123 |
+
prefix = folder + "."
|
| 124 |
+
if fname.lower().startswith(prefix.lower()):
|
| 125 |
+
cam = fname[len(prefix) : -len(".npy")]
|
| 126 |
+
cam_n = _norm_name(cam)
|
| 127 |
+
hits = []
|
| 128 |
+
for p in labels_dir.rglob("*.npy"):
|
| 129 |
+
if not p.is_file():
|
| 130 |
+
continue
|
| 131 |
+
pn = _norm_name(p.name)
|
| 132 |
+
if folder == "ftheta_intrinsic":
|
| 133 |
+
if "ftheta" not in pn:
|
| 134 |
+
continue
|
| 135 |
+
else:
|
| 136 |
+
if "pinhole" not in pn:
|
| 137 |
+
continue
|
| 138 |
+
if cam_n and cam_n in pn:
|
| 139 |
+
hits.append(p)
|
| 140 |
+
if len(hits) == 1:
|
| 141 |
+
return hits[0]
|
| 142 |
+
if len(hits) > 1:
|
| 143 |
+
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 144 |
+
return hits[0]
|
| 145 |
+
|
| 146 |
+
# pose: ``{idx:06d}.pose.{camera}.npy``
|
| 147 |
+
if folder == "pose" and fname.endswith(".npy"):
|
| 148 |
+
base = fname[: -len(".npy")]
|
| 149 |
+
if ".pose." in base:
|
| 150 |
+
idx_part, _, cam_part = base.partition(".pose.")
|
| 151 |
+
hits = _scan_npy_json_npz(
|
| 152 |
+
labels_dir,
|
| 153 |
+
folder,
|
| 154 |
+
fname,
|
| 155 |
+
exts=(".npy",),
|
| 156 |
+
tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)],
|
| 157 |
+
name_must_contain="pose",
|
| 158 |
+
)
|
| 159 |
+
if len(hits) == 1:
|
| 160 |
+
return hits[0]
|
| 161 |
+
if len(hits) > 1:
|
| 162 |
+
hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x)))
|
| 163 |
+
return hits[0]
|
| 164 |
+
|
| 165 |
+
# vehicle_pose: ``{idx:06d}.vehicle_pose.npy``
|
| 166 |
+
if folder == "vehicle_pose" and fname.endswith(".npy"):
|
| 167 |
+
idx_part = fname.split(".")[0]
|
| 168 |
+
hits = _scan_npy_json_npz(
|
| 169 |
+
labels_dir,
|
| 170 |
+
folder,
|
| 171 |
+
fname,
|
| 172 |
+
exts=(".npy",),
|
| 173 |
+
tokens_norm=[_norm_name(idx_part), "vehiclepose"],
|
| 174 |
+
name_must_contain="vehicle",
|
| 175 |
+
)
|
| 176 |
+
if len(hits) == 1:
|
| 177 |
+
return hits[0]
|
| 178 |
+
if len(hits) > 1:
|
| 179 |
+
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 180 |
+
return hits[0]
|
| 181 |
+
|
| 182 |
+
# all_object_info
|
| 183 |
+
if folder == "all_object_info" and fname.endswith(".json"):
|
| 184 |
+
idx_part = fname.split(".")[0]
|
| 185 |
+
hits = _scan_npy_json_npz(
|
| 186 |
+
labels_dir,
|
| 187 |
+
folder,
|
| 188 |
+
fname,
|
| 189 |
+
exts=(".json",),
|
| 190 |
+
tokens_norm=[_norm_name(idx_part), "allobjectinfo"],
|
| 191 |
+
)
|
| 192 |
+
if len(hits) == 1:
|
| 193 |
+
return hits[0]
|
| 194 |
+
if len(hits) > 1:
|
| 195 |
+
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 196 |
+
return hits[0]
|
| 197 |
+
|
| 198 |
+
# lidar_raw
|
| 199 |
+
if folder == "lidar_raw" and fname.endswith(".npz"):
|
| 200 |
+
stem = fname[: -len(".npz")]
|
| 201 |
+
hits = _scan_npy_json_npz(
|
| 202 |
+
labels_dir,
|
| 203 |
+
folder,
|
| 204 |
+
fname,
|
| 205 |
+
exts=(".npz",),
|
| 206 |
+
tokens_norm=[_norm_name(stem), "lidar", "raw"],
|
| 207 |
+
)
|
| 208 |
+
if len(hits) == 1:
|
| 209 |
+
return hits[0]
|
| 210 |
+
if len(hits) > 1:
|
| 211 |
+
hits.sort(key=lambda x: (len(x.parts), str(x)))
|
| 212 |
+
return hits[0]
|
| 213 |
+
|
| 214 |
+
detail = _diagnose_labels(labels_dir, folder)
|
| 215 |
+
raise FileNotFoundError(
|
| 216 |
+
f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、"
|
| 217 |
+
f"rglob、按帧索引+相机的扫描匹配)。\n{detail}"
|
| 218 |
+
)
|
src/wjad/data/se3.py
CHANGED
|
@@ -1,111 +1,111 @@
|
|
| 1 |
-
"""SE(3) 与 6D 表示之间的转换。
|
| 2 |
-
|
| 3 |
-
约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。
|
| 4 |
-
平移单位为米;旋转角弧度。
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
|
| 14 |
-
"""3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。
|
| 15 |
-
|
| 16 |
-
使用 Rodrigues 公式数值反求。
|
| 17 |
-
"""
|
| 18 |
-
if isinstance(R, np.ndarray):
|
| 19 |
-
R = torch.from_numpy(R).float()
|
| 20 |
-
if R.dim() == 2:
|
| 21 |
-
R = R.unsqueeze(0)
|
| 22 |
-
single = True
|
| 23 |
-
else:
|
| 24 |
-
single = False
|
| 25 |
-
|
| 26 |
-
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
|
| 27 |
-
cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
|
| 28 |
-
theta = torch.acos(cos_theta) # [B]
|
| 29 |
-
|
| 30 |
-
# 提取轴向量
|
| 31 |
-
rx = R[..., 2, 1] - R[..., 1, 2]
|
| 32 |
-
ry = R[..., 0, 2] - R[..., 2, 0]
|
| 33 |
-
rz = R[..., 1, 0] - R[..., 0, 1]
|
| 34 |
-
axis = torch.stack([rx, ry, rz], dim=-1)
|
| 35 |
-
sin_theta = torch.sin(theta).clamp_min(1e-7)
|
| 36 |
-
axis = axis / (2.0 * sin_theta).unsqueeze(-1)
|
| 37 |
-
|
| 38 |
-
aa = axis * theta.unsqueeze(-1)
|
| 39 |
-
if single:
|
| 40 |
-
aa = aa.squeeze(0)
|
| 41 |
-
return aa
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
|
| 45 |
-
"""轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。"""
|
| 46 |
-
theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1]
|
| 47 |
-
axis = aa / theta
|
| 48 |
-
x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
|
| 49 |
-
sin_t = torch.sin(theta.squeeze(-1))
|
| 50 |
-
cos_t = torch.cos(theta.squeeze(-1))
|
| 51 |
-
one_c = 1.0 - cos_t
|
| 52 |
-
|
| 53 |
-
R = torch.stack(
|
| 54 |
-
[
|
| 55 |
-
cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t,
|
| 56 |
-
y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t,
|
| 57 |
-
z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c,
|
| 58 |
-
],
|
| 59 |
-
dim=-1,
|
| 60 |
-
).reshape(*aa.shape[:-1], 3, 3)
|
| 61 |
-
return R
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
|
| 65 |
-
"""4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。"""
|
| 66 |
-
if isinstance(T, np.ndarray):
|
| 67 |
-
T = torch.from_numpy(T).float()
|
| 68 |
-
if T.dim() == 2:
|
| 69 |
-
T = T.unsqueeze(0)
|
| 70 |
-
single = True
|
| 71 |
-
else:
|
| 72 |
-
single = False
|
| 73 |
-
|
| 74 |
-
R = T[..., :3, :3]
|
| 75 |
-
t = T[..., :3, 3]
|
| 76 |
-
aa = rotation_matrix_to_axis_angle(R)
|
| 77 |
-
six = torch.cat([t, aa], dim=-1)
|
| 78 |
-
if single:
|
| 79 |
-
six = six.squeeze(0)
|
| 80 |
-
return six
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor:
|
| 84 |
-
"""6D -> 4x4 SE(3)。"""
|
| 85 |
-
if six.dim() == 1:
|
| 86 |
-
six = six.unsqueeze(0)
|
| 87 |
-
single = True
|
| 88 |
-
else:
|
| 89 |
-
single = False
|
| 90 |
-
t = six[..., :3]
|
| 91 |
-
aa = six[..., 3:]
|
| 92 |
-
R = axis_angle_to_rotation_matrix(aa)
|
| 93 |
-
T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
|
| 94 |
-
T[..., :3, :3] = R
|
| 95 |
-
T[..., :3, 3] = t
|
| 96 |
-
T[..., 3, 3] = 1.0
|
| 97 |
-
if single:
|
| 98 |
-
T = T.squeeze(0)
|
| 99 |
-
return T
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def invert_se3(T: torch.Tensor) -> torch.Tensor:
|
| 103 |
-
"""4x4 SE(3) 逆,``[..., 4, 4]``。"""
|
| 104 |
-
R = T[..., :3, :3]
|
| 105 |
-
t = T[..., :3, 3:4]
|
| 106 |
-
Rt = R.transpose(-2, -1)
|
| 107 |
-
inv = torch.zeros_like(T)
|
| 108 |
-
inv[..., :3, :3] = Rt
|
| 109 |
-
inv[..., :3, 3:4] = -Rt @ t
|
| 110 |
-
inv[..., 3, 3] = 1.0
|
| 111 |
-
return inv
|
|
|
|
| 1 |
+
"""SE(3) 与 6D 表示之间的转换。
|
| 2 |
+
|
| 3 |
+
约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。
|
| 4 |
+
平移单位为米;旋转角弧度。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
|
| 14 |
+
"""3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。
|
| 15 |
+
|
| 16 |
+
使用 Rodrigues 公式数值反求。
|
| 17 |
+
"""
|
| 18 |
+
if isinstance(R, np.ndarray):
|
| 19 |
+
R = torch.from_numpy(R).float()
|
| 20 |
+
if R.dim() == 2:
|
| 21 |
+
R = R.unsqueeze(0)
|
| 22 |
+
single = True
|
| 23 |
+
else:
|
| 24 |
+
single = False
|
| 25 |
+
|
| 26 |
+
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
|
| 27 |
+
cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
|
| 28 |
+
theta = torch.acos(cos_theta) # [B]
|
| 29 |
+
|
| 30 |
+
# 提取轴向量
|
| 31 |
+
rx = R[..., 2, 1] - R[..., 1, 2]
|
| 32 |
+
ry = R[..., 0, 2] - R[..., 2, 0]
|
| 33 |
+
rz = R[..., 1, 0] - R[..., 0, 1]
|
| 34 |
+
axis = torch.stack([rx, ry, rz], dim=-1)
|
| 35 |
+
sin_theta = torch.sin(theta).clamp_min(1e-7)
|
| 36 |
+
axis = axis / (2.0 * sin_theta).unsqueeze(-1)
|
| 37 |
+
|
| 38 |
+
aa = axis * theta.unsqueeze(-1)
|
| 39 |
+
if single:
|
| 40 |
+
aa = aa.squeeze(0)
|
| 41 |
+
return aa
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
"""轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。"""
|
| 46 |
+
theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1]
|
| 47 |
+
axis = aa / theta
|
| 48 |
+
x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
|
| 49 |
+
sin_t = torch.sin(theta.squeeze(-1))
|
| 50 |
+
cos_t = torch.cos(theta.squeeze(-1))
|
| 51 |
+
one_c = 1.0 - cos_t
|
| 52 |
+
|
| 53 |
+
R = torch.stack(
|
| 54 |
+
[
|
| 55 |
+
cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t,
|
| 56 |
+
y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t,
|
| 57 |
+
z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c,
|
| 58 |
+
],
|
| 59 |
+
dim=-1,
|
| 60 |
+
).reshape(*aa.shape[:-1], 3, 3)
|
| 61 |
+
return R
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
|
| 65 |
+
"""4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。"""
|
| 66 |
+
if isinstance(T, np.ndarray):
|
| 67 |
+
T = torch.from_numpy(T).float()
|
| 68 |
+
if T.dim() == 2:
|
| 69 |
+
T = T.unsqueeze(0)
|
| 70 |
+
single = True
|
| 71 |
+
else:
|
| 72 |
+
single = False
|
| 73 |
+
|
| 74 |
+
R = T[..., :3, :3]
|
| 75 |
+
t = T[..., :3, 3]
|
| 76 |
+
aa = rotation_matrix_to_axis_angle(R)
|
| 77 |
+
six = torch.cat([t, aa], dim=-1)
|
| 78 |
+
if single:
|
| 79 |
+
six = six.squeeze(0)
|
| 80 |
+
return six
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
"""6D -> 4x4 SE(3)。"""
|
| 85 |
+
if six.dim() == 1:
|
| 86 |
+
six = six.unsqueeze(0)
|
| 87 |
+
single = True
|
| 88 |
+
else:
|
| 89 |
+
single = False
|
| 90 |
+
t = six[..., :3]
|
| 91 |
+
aa = six[..., 3:]
|
| 92 |
+
R = axis_angle_to_rotation_matrix(aa)
|
| 93 |
+
T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
|
| 94 |
+
T[..., :3, :3] = R
|
| 95 |
+
T[..., :3, 3] = t
|
| 96 |
+
T[..., 3, 3] = 1.0
|
| 97 |
+
if single:
|
| 98 |
+
T = T.squeeze(0)
|
| 99 |
+
return T
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def invert_se3(T: torch.Tensor) -> torch.Tensor:
|
| 103 |
+
"""4x4 SE(3) 逆,``[..., 4, 4]``。"""
|
| 104 |
+
R = T[..., :3, :3]
|
| 105 |
+
t = T[..., :3, 3:4]
|
| 106 |
+
Rt = R.transpose(-2, -1)
|
| 107 |
+
inv = torch.zeros_like(T)
|
| 108 |
+
inv[..., :3, :3] = Rt
|
| 109 |
+
inv[..., :3, 3:4] = -Rt @ t
|
| 110 |
+
inv[..., 3, 3] = 1.0
|
| 111 |
+
return inv
|
src/wjad/data/targets.py
CHANGED
|
@@ -1,214 +1,214 @@
|
|
| 1 |
-
"""检测 / 自车未来轨迹的目标构建。
|
| 2 |
-
|
| 3 |
-
依据 Cosmos-Drive-Dreams 数据集 README:
|
| 4 |
-
all_object_info JSON 中以 ``tracking_id`` 为 key,存储
|
| 5 |
-
``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。
|
| 6 |
-
|
| 7 |
-
构建步骤:
|
| 8 |
-
1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系:
|
| 9 |
-
object_to_self = inv(vehicle_pose_t) @ object_to_world
|
| 10 |
-
2. 距离 ``≤ max_distance_m`` 过滤;
|
| 11 |
-
3. 投影中心点到当前帧像素,要求落在视锥内;
|
| 12 |
-
4. 用 LIDAR 深度对比做遮挡剔除(粗粒度);
|
| 13 |
-
5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系,
|
| 14 |
-
提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。
|
| 15 |
-
|
| 16 |
-
为方便与 head 输出对齐,最终输出格式:
|
| 17 |
-
{"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
|
| 18 |
-
"future_traj": [N, 24, 3], "future_valid": [N, 24]}
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
from __future__ import annotations
|
| 22 |
-
|
| 23 |
-
from dataclasses import dataclass
|
| 24 |
-
|
| 25 |
-
import numpy as np
|
| 26 |
-
import torch
|
| 27 |
-
|
| 28 |
-
from ..modules.normalization import symlog
|
| 29 |
-
from ..modules.rays import FThetaCamera
|
| 30 |
-
from .ftheta_proj import project_points_ftheta
|
| 31 |
-
from .se3 import invert_se3
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
@dataclass
|
| 35 |
-
class ObjectTrackInfo:
|
| 36 |
-
"""单个对
|
| 37 |
-
|
| 38 |
-
tracking_id: str
|
| 39 |
-
object_to_world: torch.Tensor # [4, 4]
|
| 40 |
-
lwh: torch.Tensor # [3]
|
| 41 |
-
is_moving: bool
|
| 42 |
-
object_type: str
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
|
| 46 |
-
"""从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。
|
| 47 |
-
|
| 48 |
-
使用 ``atan2(R[1,0], R[0,0])``。
|
| 49 |
-
"""
|
| 50 |
-
return torch.atan2(R[..., 1, 0], R[..., 0, 0])
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
|
| 54 |
-
"""根据 object_type 字符串映射到 (class_index, is_dynamic)。"""
|
| 55 |
-
if object_type in dynamic_classes:
|
| 56 |
-
return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0
|
| 57 |
-
if object_type in structured_classes:
|
| 58 |
-
return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
|
| 59 |
-
return background_idx, 0 # 未知类型当 background
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def build_detection_targets(
|
| 63 |
-
objects_t: list[ObjectTrackInfo],
|
| 64 |
-
objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表
|
| 65 |
-
vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world
|
| 66 |
-
vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4
|
| 67 |
-
cam_intrinsic: FThetaCamera,
|
| 68 |
-
cam2vehicle: torch.Tensor, # [4, 4]
|
| 69 |
-
image_h: int,
|
| 70 |
-
image_w: int,
|
| 71 |
-
max_distance_m: float = 48.0,
|
| 72 |
-
occlusion_depth_tolerance: float = 0.5,
|
| 73 |
-
lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡
|
| 74 |
-
dynamic_classes: list[str] | None = None,
|
| 75 |
-
structured_classes: list[str] | None = None,
|
| 76 |
-
future_horizon: int = 24,
|
| 77 |
-
) -> dict:
|
| 78 |
-
"""构建一个样本的检测+未来轨迹标签。"""
|
| 79 |
-
if dynamic_classes is None:
|
| 80 |
-
dynamic_classes = []
|
| 81 |
-
if structured_classes is None:
|
| 82 |
-
structured_classes = []
|
| 83 |
-
|
| 84 |
-
inv_pose_t = invert_se3(vehicle_pose_t)
|
| 85 |
-
vehicle2cam = invert_se3(cam2vehicle)
|
| 86 |
-
|
| 87 |
-
labels: list[int] = []
|
| 88 |
-
boxes: list[list[float]] = []
|
| 89 |
-
is_dynamic: list[int] = []
|
| 90 |
-
future_traj: list[list[list[float]]] = []
|
| 91 |
-
future_valid: list[list[int]] = []
|
| 92 |
-
|
| 93 |
-
for obj in objects_t:
|
| 94 |
-
T_obj_self = inv_pose_t @ obj.object_to_world # [4,4]
|
| 95 |
-
center_self = T_obj_self[:3, 3]
|
| 96 |
-
|
| 97 |
-
dist = float(center_self.norm().item())
|
| 98 |
-
if dist > max_distance_m:
|
| 99 |
-
continue
|
| 100 |
-
|
| 101 |
-
# 视锥裁剪:把中心投影到相机系再投影到像素
|
| 102 |
-
center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3]
|
| 103 |
-
if center_cam[2].item() <= 0:
|
| 104 |
-
continue
|
| 105 |
-
uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
|
| 106 |
-
u, v = uv[0, 0].item(), uv[0, 1].item()
|
| 107 |
-
if not (0 <= u < image_w and 0 <= v < image_h):
|
| 108 |
-
continue
|
| 109 |
-
|
| 110 |
-
# LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比
|
| 111 |
-
if lidar_points_self is not None and lidar_points_self.numel() > 0:
|
| 112 |
-
ray = center_self / (center_self.norm() + 1e-6)
|
| 113 |
-
proj = lidar_points_self @ ray # [P]
|
| 114 |
-
# 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°)
|
| 115 |
-
cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray
|
| 116 |
-
mask = (cosang > 0.996) & (proj > 0)
|
| 117 |
-
if mask.any():
|
| 118 |
-
lidar_depth = proj[mask].min().item()
|
| 119 |
-
if lidar_depth + occlusion_depth_tolerance < dist:
|
| 120 |
-
# LIDAR 击中前方更近物体 -> 当前对象被遮挡
|
| 121 |
-
continue
|
| 122 |
-
|
| 123 |
-
# 类别映射
|
| 124 |
-
cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
|
| 125 |
-
if cls_idx == 0:
|
| 126 |
-
continue
|
| 127 |
-
labels.append(cls_idx)
|
| 128 |
-
is_dynamic.append(is_dyn)
|
| 129 |
-
|
| 130 |
-
yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
|
| 131 |
-
l, w, h = obj.lwh.tolist()
|
| 132 |
-
# box 坐标 symlog 归一
|
| 133 |
-
x_n, y_n, z_n = (
|
| 134 |
-
float(symlog(center_self[0]).item()),
|
| 135 |
-
float(symlog(center_self[1]).item()),
|
| 136 |
-
float(symlog(center_self[2]).item()),
|
| 137 |
-
)
|
| 138 |
-
l_n = float(symlog(torch.tensor(l)).item())
|
| 139 |
-
w_n = float(symlog(torch.tensor(w)).item())
|
| 140 |
-
h_n = float(symlog(torch.tensor(h)).item())
|
| 141 |
-
boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw])
|
| 142 |
-
|
| 143 |
-
# 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身
|
| 144 |
-
# 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t)
|
| 145 |
-
x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw
|
| 146 |
-
future_3 = []
|
| 147 |
-
future_v = []
|
| 148 |
-
for k in range(future_horizon):
|
| 149 |
-
if k >= len(objects_future) or k >= len(vehicle_pose_future):
|
| 150 |
-
future_3.append([0.0, 0.0, 0.0])
|
| 151 |
-
future_v.append(0)
|
| 152 |
-
continue
|
| 153 |
-
# 找对象在 t+k+1 帧
|
| 154 |
-
future_objs = objects_future[k]
|
| 155 |
-
match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None)
|
| 156 |
-
if match is None:
|
| 157 |
-
future_3.append([0.0, 0.0, 0.0])
|
| 158 |
-
future_v.append(0)
|
| 159 |
-
continue
|
| 160 |
-
T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world
|
| 161 |
-
xf = T_obj_self_future[0, 3].item()
|
| 162 |
-
yf = T_obj_self_future[1, 3].item()
|
| 163 |
-
yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item()
|
| 164 |
-
dx = xf - x0
|
| 165 |
-
dy = yf - y0
|
| 166 |
-
dyaw = yawf - yaw0
|
| 167 |
-
# 角度归到 (-pi, pi]
|
| 168 |
-
dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
|
| 169 |
-
future_3.append([
|
| 170 |
-
float(symlog(torch.tensor(dx)).item()),
|
| 171 |
-
float(symlog(torch.tensor(dy)).item()),
|
| 172 |
-
float(dyaw),
|
| 173 |
-
])
|
| 174 |
-
future_v.append(1)
|
| 175 |
-
future_traj.append(future_3)
|
| 176 |
-
future_valid.append(future_v)
|
| 177 |
-
|
| 178 |
-
if not labels:
|
| 179 |
-
return {
|
| 180 |
-
"labels": torch.zeros(0, dtype=torch.long),
|
| 181 |
-
"boxes": torch.zeros(0, 7),
|
| 182 |
-
"is_dynamic": torch.zeros(0, dtype=torch.long),
|
| 183 |
-
"future_traj": torch.zeros(0, future_horizon, 3),
|
| 184 |
-
"future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
|
| 185 |
-
}
|
| 186 |
-
return {
|
| 187 |
-
"labels": torch.tensor(labels, dtype=torch.long),
|
| 188 |
-
"boxes": torch.tensor(boxes, dtype=torch.float32),
|
| 189 |
-
"is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
|
| 190 |
-
"future_traj": torch.tensor(future_traj, dtype=torch.float32),
|
| 191 |
-
"future_valid": torch.tensor(future_valid, dtype=torch.bool),
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def build_ego_future_target(
|
| 196 |
-
vehicle_pose_t: torch.Tensor,
|
| 197 |
-
vehicle_pose_future: list[torch.Tensor],
|
| 198 |
-
horizon: int = 24,
|
| 199 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 200 |
-
"""自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。"""
|
| 201 |
-
inv_t = invert_se3(vehicle_pose_t)
|
| 202 |
-
out = torch.zeros(horizon, 3)
|
| 203 |
-
valid = torch.zeros(horizon, dtype=torch.bool)
|
| 204 |
-
for k in range(horizon):
|
| 205 |
-
if k >= len(vehicle_pose_future):
|
| 206 |
-
break
|
| 207 |
-
rel = inv_t @ vehicle_pose_future[k]
|
| 208 |
-
x, y = rel[0, 3].item(), rel[1, 3].item()
|
| 209 |
-
yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item()
|
| 210 |
-
out[k, 0] = symlog(torch.tensor(x))
|
| 211 |
-
out[k, 1] = symlog(torch.tensor(y))
|
| 212 |
-
out[k, 2] = yaw
|
| 213 |
-
valid[k] = True
|
| 214 |
-
return out, valid
|
|
|
|
| 1 |
+
"""检测 / 自车未来轨迹的目标构建。
|
| 2 |
+
|
| 3 |
+
依据 Cosmos-Drive-Dreams 数据集 README:
|
| 4 |
+
all_object_info JSON 中以 ``tracking_id`` 为 key,存储
|
| 5 |
+
``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。
|
| 6 |
+
|
| 7 |
+
构建步骤:
|
| 8 |
+
1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系:
|
| 9 |
+
object_to_self = inv(vehicle_pose_t) @ object_to_world
|
| 10 |
+
2. 距离 ``≤ max_distance_m`` 过滤;
|
| 11 |
+
3. 投影中心点到当前帧像素,要求落在视锥内;
|
| 12 |
+
4. 用 LIDAR 深度对比做遮挡剔除(粗粒度);
|
| 13 |
+
5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系,
|
| 14 |
+
提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。
|
| 15 |
+
|
| 16 |
+
为方便与 head 输出对齐,最终输出格式:
|
| 17 |
+
{"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
|
| 18 |
+
"future_traj": [N, 24, 3], "future_valid": [N, 24]}
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from ..modules.normalization import symlog
|
| 29 |
+
from ..modules.rays import FThetaCamera
|
| 30 |
+
from .ftheta_proj import project_points_ftheta
|
| 31 |
+
from .se3 import invert_se3
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class ObjectTrackInfo:
|
| 36 |
+
"""单个对���在某帧的简化记录。"""
|
| 37 |
+
|
| 38 |
+
tracking_id: str
|
| 39 |
+
object_to_world: torch.Tensor # [4, 4]
|
| 40 |
+
lwh: torch.Tensor # [3]
|
| 41 |
+
is_moving: bool
|
| 42 |
+
object_type: str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
"""从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。
|
| 47 |
+
|
| 48 |
+
使用 ``atan2(R[1,0], R[0,0])``。
|
| 49 |
+
"""
|
| 50 |
+
return torch.atan2(R[..., 1, 0], R[..., 0, 0])
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
|
| 54 |
+
"""根据 object_type 字符串映射到 (class_index, is_dynamic)。"""
|
| 55 |
+
if object_type in dynamic_classes:
|
| 56 |
+
return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0
|
| 57 |
+
if object_type in structured_classes:
|
| 58 |
+
return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
|
| 59 |
+
return background_idx, 0 # 未知类型当 background
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_detection_targets(
|
| 63 |
+
objects_t: list[ObjectTrackInfo],
|
| 64 |
+
objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表
|
| 65 |
+
vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world
|
| 66 |
+
vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4
|
| 67 |
+
cam_intrinsic: FThetaCamera,
|
| 68 |
+
cam2vehicle: torch.Tensor, # [4, 4]
|
| 69 |
+
image_h: int,
|
| 70 |
+
image_w: int,
|
| 71 |
+
max_distance_m: float = 48.0,
|
| 72 |
+
occlusion_depth_tolerance: float = 0.5,
|
| 73 |
+
lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡
|
| 74 |
+
dynamic_classes: list[str] | None = None,
|
| 75 |
+
structured_classes: list[str] | None = None,
|
| 76 |
+
future_horizon: int = 24,
|
| 77 |
+
) -> dict:
|
| 78 |
+
"""构建一个样本的检测+未来轨迹标签。"""
|
| 79 |
+
if dynamic_classes is None:
|
| 80 |
+
dynamic_classes = []
|
| 81 |
+
if structured_classes is None:
|
| 82 |
+
structured_classes = []
|
| 83 |
+
|
| 84 |
+
inv_pose_t = invert_se3(vehicle_pose_t)
|
| 85 |
+
vehicle2cam = invert_se3(cam2vehicle)
|
| 86 |
+
|
| 87 |
+
labels: list[int] = []
|
| 88 |
+
boxes: list[list[float]] = []
|
| 89 |
+
is_dynamic: list[int] = []
|
| 90 |
+
future_traj: list[list[list[float]]] = []
|
| 91 |
+
future_valid: list[list[int]] = []
|
| 92 |
+
|
| 93 |
+
for obj in objects_t:
|
| 94 |
+
T_obj_self = inv_pose_t @ obj.object_to_world # [4,4]
|
| 95 |
+
center_self = T_obj_self[:3, 3]
|
| 96 |
+
|
| 97 |
+
dist = float(center_self.norm().item())
|
| 98 |
+
if dist > max_distance_m:
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
# 视锥裁剪:把中心投影到相机系再投影到像素
|
| 102 |
+
center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3]
|
| 103 |
+
if center_cam[2].item() <= 0:
|
| 104 |
+
continue
|
| 105 |
+
uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
|
| 106 |
+
u, v = uv[0, 0].item(), uv[0, 1].item()
|
| 107 |
+
if not (0 <= u < image_w and 0 <= v < image_h):
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比
|
| 111 |
+
if lidar_points_self is not None and lidar_points_self.numel() > 0:
|
| 112 |
+
ray = center_self / (center_self.norm() + 1e-6)
|
| 113 |
+
proj = lidar_points_self @ ray # [P]
|
| 114 |
+
# 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°)
|
| 115 |
+
cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray
|
| 116 |
+
mask = (cosang > 0.996) & (proj > 0)
|
| 117 |
+
if mask.any():
|
| 118 |
+
lidar_depth = proj[mask].min().item()
|
| 119 |
+
if lidar_depth + occlusion_depth_tolerance < dist:
|
| 120 |
+
# LIDAR 击中前方更近物体 -> 当前对象被遮挡
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
# 类别映射
|
| 124 |
+
cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
|
| 125 |
+
if cls_idx == 0:
|
| 126 |
+
continue
|
| 127 |
+
labels.append(cls_idx)
|
| 128 |
+
is_dynamic.append(is_dyn)
|
| 129 |
+
|
| 130 |
+
yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
|
| 131 |
+
l, w, h = obj.lwh.tolist()
|
| 132 |
+
# box 坐标 symlog 归一
|
| 133 |
+
x_n, y_n, z_n = (
|
| 134 |
+
float(symlog(center_self[0]).item()),
|
| 135 |
+
float(symlog(center_self[1]).item()),
|
| 136 |
+
float(symlog(center_self[2]).item()),
|
| 137 |
+
)
|
| 138 |
+
l_n = float(symlog(torch.tensor(l)).item())
|
| 139 |
+
w_n = float(symlog(torch.tensor(w)).item())
|
| 140 |
+
h_n = float(symlog(torch.tensor(h)).item())
|
| 141 |
+
boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw])
|
| 142 |
+
|
| 143 |
+
# 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身
|
| 144 |
+
# 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t)
|
| 145 |
+
x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw
|
| 146 |
+
future_3 = []
|
| 147 |
+
future_v = []
|
| 148 |
+
for k in range(future_horizon):
|
| 149 |
+
if k >= len(objects_future) or k >= len(vehicle_pose_future):
|
| 150 |
+
future_3.append([0.0, 0.0, 0.0])
|
| 151 |
+
future_v.append(0)
|
| 152 |
+
continue
|
| 153 |
+
# 找对象在 t+k+1 帧
|
| 154 |
+
future_objs = objects_future[k]
|
| 155 |
+
match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None)
|
| 156 |
+
if match is None:
|
| 157 |
+
future_3.append([0.0, 0.0, 0.0])
|
| 158 |
+
future_v.append(0)
|
| 159 |
+
continue
|
| 160 |
+
T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world
|
| 161 |
+
xf = T_obj_self_future[0, 3].item()
|
| 162 |
+
yf = T_obj_self_future[1, 3].item()
|
| 163 |
+
yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item()
|
| 164 |
+
dx = xf - x0
|
| 165 |
+
dy = yf - y0
|
| 166 |
+
dyaw = yawf - yaw0
|
| 167 |
+
# 角度归到 (-pi, pi]
|
| 168 |
+
dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
|
| 169 |
+
future_3.append([
|
| 170 |
+
float(symlog(torch.tensor(dx)).item()),
|
| 171 |
+
float(symlog(torch.tensor(dy)).item()),
|
| 172 |
+
float(dyaw),
|
| 173 |
+
])
|
| 174 |
+
future_v.append(1)
|
| 175 |
+
future_traj.append(future_3)
|
| 176 |
+
future_valid.append(future_v)
|
| 177 |
+
|
| 178 |
+
if not labels:
|
| 179 |
+
return {
|
| 180 |
+
"labels": torch.zeros(0, dtype=torch.long),
|
| 181 |
+
"boxes": torch.zeros(0, 7),
|
| 182 |
+
"is_dynamic": torch.zeros(0, dtype=torch.long),
|
| 183 |
+
"future_traj": torch.zeros(0, future_horizon, 3),
|
| 184 |
+
"future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
|
| 185 |
+
}
|
| 186 |
+
return {
|
| 187 |
+
"labels": torch.tensor(labels, dtype=torch.long),
|
| 188 |
+
"boxes": torch.tensor(boxes, dtype=torch.float32),
|
| 189 |
+
"is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
|
| 190 |
+
"future_traj": torch.tensor(future_traj, dtype=torch.float32),
|
| 191 |
+
"future_valid": torch.tensor(future_valid, dtype=torch.bool),
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def build_ego_future_target(
|
| 196 |
+
vehicle_pose_t: torch.Tensor,
|
| 197 |
+
vehicle_pose_future: list[torch.Tensor],
|
| 198 |
+
horizon: int = 24,
|
| 199 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 200 |
+
"""自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。"""
|
| 201 |
+
inv_t = invert_se3(vehicle_pose_t)
|
| 202 |
+
out = torch.zeros(horizon, 3)
|
| 203 |
+
valid = torch.zeros(horizon, dtype=torch.bool)
|
| 204 |
+
for k in range(horizon):
|
| 205 |
+
if k >= len(vehicle_pose_future):
|
| 206 |
+
break
|
| 207 |
+
rel = inv_t @ vehicle_pose_future[k]
|
| 208 |
+
x, y = rel[0, 3].item(), rel[1, 3].item()
|
| 209 |
+
yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item()
|
| 210 |
+
out[k, 0] = symlog(torch.tensor(x))
|
| 211 |
+
out[k, 1] = symlog(torch.tensor(y))
|
| 212 |
+
out[k, 2] = yaw
|
| 213 |
+
valid[k] = True
|
| 214 |
+
return out, valid
|
src/wjad/data/transforms.py
CHANGED
|
@@ -1,86 +1,86 @@
|
|
| 1 |
-
"""图像与运动学的数据增广。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# DINOv3 的 ImageNet 标准化参数
|
| 10 |
-
DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 11 |
-
DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def crop_top_half(image: torch.Tensor) -> torch.Tensor:
|
| 15 |
-
"""裁去图像上半部分(主要是天空)。
|
| 16 |
-
|
| 17 |
-
输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。
|
| 18 |
-
"""
|
| 19 |
-
if image.dim() == 4:
|
| 20 |
-
h = image.shape[2]
|
| 21 |
-
return image[:, :, h // 2 :, :]
|
| 22 |
-
elif image.dim() == 3:
|
| 23 |
-
h = image.shape[1]
|
| 24 |
-
return image[:, h // 2 :, :]
|
| 25 |
-
raise ValueError(f"unsupported image dim: {image.dim()}")
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor:
|
| 29 |
-
"""对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。"""
|
| 30 |
-
while mean.dim() < image.dim():
|
| 31 |
-
mean = mean.unsqueeze(0)
|
| 32 |
-
std = std.unsqueeze(0)
|
| 33 |
-
return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor:
|
| 37 |
-
"""高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。"""
|
| 38 |
-
if std <= 0:
|
| 39 |
-
return image
|
| 40 |
-
return image + torch.randn_like(image) * std
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def perturb_kinematics(
|
| 44 |
-
ego_6d: torch.Tensor, # [T, 6]
|
| 45 |
-
intr_vec: torch.Tensor, # [14]
|
| 46 |
-
extr_6d: torch.Tensor, # [6]
|
| 47 |
-
translation_std_m: float,
|
| 48 |
-
rotation_std_deg: float,
|
| 49 |
-
intrinsic_std: float,
|
| 50 |
-
extrinsic_std: float,
|
| 51 |
-
rng: np.random.Generator,
|
| 52 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 53 |
-
"""在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。
|
| 54 |
-
|
| 55 |
-
返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。
|
| 56 |
-
|
| 57 |
-
返回
|
| 58 |
-
----
|
| 59 |
-
perturbed_ego, perturbed_intr, perturbed_extr,
|
| 60 |
-
gt_residual_concat (在 symlog 空间作为 calibration 监督,可选;
|
| 61 |
-
本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造)
|
| 62 |
-
"""
|
| 63 |
-
rot_std_rad = np.deg2rad(rotation_std_deg)
|
| 64 |
-
|
| 65 |
-
# ego 8x6
|
| 66 |
-
delta_ego = np.zeros_like(ego_6d.numpy())
|
| 67 |
-
delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3))
|
| 68 |
-
delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3))
|
| 69 |
-
perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d)
|
| 70 |
-
|
| 71 |
-
# intrinsic 14
|
| 72 |
-
delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],))
|
| 73 |
-
perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec)
|
| 74 |
-
|
| 75 |
-
# extrinsic 6
|
| 76 |
-
delta_extr = np.zeros_like(extr_6d.numpy())
|
| 77 |
-
delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,))
|
| 78 |
-
delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,))
|
| 79 |
-
perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d)
|
| 80 |
-
|
| 81 |
-
return (
|
| 82 |
-
perturbed_ego,
|
| 83 |
-
perturbed_intr,
|
| 84 |
-
perturbed_extr,
|
| 85 |
-
torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype),
|
| 86 |
-
)
|
|
|
|
| 1 |
+
"""图像与运动学的数据增广。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# DINOv3 的 ImageNet 标准化参数
|
| 10 |
+
DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 11 |
+
DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def crop_top_half(image: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""裁去图像上半部分(主要是天空)。
|
| 16 |
+
|
| 17 |
+
输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。
|
| 18 |
+
"""
|
| 19 |
+
if image.dim() == 4:
|
| 20 |
+
h = image.shape[2]
|
| 21 |
+
return image[:, :, h // 2 :, :]
|
| 22 |
+
elif image.dim() == 3:
|
| 23 |
+
h = image.shape[1]
|
| 24 |
+
return image[:, h // 2 :, :]
|
| 25 |
+
raise ValueError(f"unsupported image dim: {image.dim()}")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor:
|
| 29 |
+
"""对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。"""
|
| 30 |
+
while mean.dim() < image.dim():
|
| 31 |
+
mean = mean.unsqueeze(0)
|
| 32 |
+
std = std.unsqueeze(0)
|
| 33 |
+
return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor:
|
| 37 |
+
"""高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。"""
|
| 38 |
+
if std <= 0:
|
| 39 |
+
return image
|
| 40 |
+
return image + torch.randn_like(image) * std
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def perturb_kinematics(
|
| 44 |
+
ego_6d: torch.Tensor, # [T, 6]
|
| 45 |
+
intr_vec: torch.Tensor, # [14]
|
| 46 |
+
extr_6d: torch.Tensor, # [6]
|
| 47 |
+
translation_std_m: float,
|
| 48 |
+
rotation_std_deg: float,
|
| 49 |
+
intrinsic_std: float,
|
| 50 |
+
extrinsic_std: float,
|
| 51 |
+
rng: np.random.Generator,
|
| 52 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 53 |
+
"""在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。
|
| 54 |
+
|
| 55 |
+
返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。
|
| 56 |
+
|
| 57 |
+
返回
|
| 58 |
+
----
|
| 59 |
+
perturbed_ego, perturbed_intr, perturbed_extr,
|
| 60 |
+
gt_residual_concat (在 symlog 空间作为 calibration 监督,可选;
|
| 61 |
+
本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造)
|
| 62 |
+
"""
|
| 63 |
+
rot_std_rad = np.deg2rad(rotation_std_deg)
|
| 64 |
+
|
| 65 |
+
# ego 8x6
|
| 66 |
+
delta_ego = np.zeros_like(ego_6d.numpy())
|
| 67 |
+
delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3))
|
| 68 |
+
delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3))
|
| 69 |
+
perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d)
|
| 70 |
+
|
| 71 |
+
# intrinsic 14
|
| 72 |
+
delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],))
|
| 73 |
+
perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec)
|
| 74 |
+
|
| 75 |
+
# extrinsic 6
|
| 76 |
+
delta_extr = np.zeros_like(extr_6d.numpy())
|
| 77 |
+
delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,))
|
| 78 |
+
delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,))
|
| 79 |
+
perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d)
|
| 80 |
+
|
| 81 |
+
return (
|
| 82 |
+
perturbed_ego,
|
| 83 |
+
perturbed_intr,
|
| 84 |
+
perturbed_extr,
|
| 85 |
+
torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype),
|
| 86 |
+
)
|
src/wjad/encoders/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
"""视觉编码相关:DINOv3 包装、时空压缩。"""
|
| 2 |
-
|
| 3 |
-
from .dinov3_wrapper import DINOv3Wrapper
|
| 4 |
-
|
| 5 |
-
__all__ = ["DINOv3Wrapper"]
|
|
|
|
| 1 |
+
"""视觉编码相关:DINOv3 包装、时空压缩。"""
|
| 2 |
+
|
| 3 |
+
from .dinov3_wrapper import DINOv3Wrapper
|
| 4 |
+
|
| 5 |
+
__all__ = ["DINOv3Wrapper"]
|
src/wjad/encoders/dinov3_wrapper.py
CHANGED
|
@@ -1,104 +1,104 @@
|
|
| 1 |
-
"""DINOv3 ViT-B/16 包装器。
|
| 2 |
-
|
| 3 |
-
- 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。
|
| 4 |
-
- 强制使用 ``attn_implementation="sdpa"``。
|
| 5 |
-
- 提供 ``freeze()`` / ``unfreeze()`` 开关。
|
| 6 |
-
- 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中
|
| 7 |
-
``(gh, gw) = (H/patch, W/patch)``。
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
from __future__ import annotations
|
| 11 |
-
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
import torch.nn as nn
|
| 16 |
-
from transformers import AutoModel
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class DINOv3Wrapper(nn.Module):
|
| 20 |
-
"""加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。"""
|
| 21 |
-
|
| 22 |
-
def __init__(
|
| 23 |
-
self,
|
| 24 |
-
pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m",
|
| 25 |
-
attn_implementation: str = "sdpa",
|
| 26 |
-
freeze: bool = True,
|
| 27 |
-
) -> None:
|
| 28 |
-
super().__init__()
|
| 29 |
-
self.pretrained_path = str(pretrained_path)
|
| 30 |
-
# 加载 HuggingFace transformers 中的 DINOv3 ViT 模型
|
| 31 |
-
self.model = AutoModel.from_pretrained(
|
| 32 |
-
self.pretrained_path,
|
| 33 |
-
attn_implementation=attn_implementation,
|
| 34 |
-
)
|
| 35 |
-
cfg = self.model.config
|
| 36 |
-
self.hidden_size = cfg.hidden_size
|
| 37 |
-
self.patch_size = cfg.patch_size
|
| 38 |
-
self.num_register_tokens = getattr(cfg, "num_register_tokens", 4)
|
| 39 |
-
|
| 40 |
-
if freeze:
|
| 41 |
-
self.freeze()
|
| 42 |
-
self._frozen = freeze
|
| 43 |
-
|
| 44 |
-
def freeze(self) -> None:
|
| 45 |
-
"""冻结所有参数。"""
|
| 46 |
-
for p in self.model.parameters():
|
| 47 |
-
p.requires_grad_(False)
|
| 48 |
-
self.model.eval()
|
| 49 |
-
self._frozen = True
|
| 50 |
-
|
| 51 |
-
def unfreeze(self) -> None:
|
| 52 |
-
"""解冻全部参
|
| 53 |
-
for p in self.model.parameters():
|
| 54 |
-
p.requires_grad_(True)
|
| 55 |
-
self.model.train()
|
| 56 |
-
self._frozen = False
|
| 57 |
-
|
| 58 |
-
@property
|
| 59 |
-
def is_frozen(self) -> bool:
|
| 60 |
-
return self._frozen
|
| 61 |
-
|
| 62 |
-
def train(self, mode: bool = True) -> "DINOv3Wrapper":
|
| 63 |
-
"""覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。"""
|
| 64 |
-
super().train(mode)
|
| 65 |
-
if self._frozen:
|
| 66 |
-
self.model.eval()
|
| 67 |
-
return self
|
| 68 |
-
|
| 69 |
-
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
| 70 |
-
"""
|
| 71 |
-
参数
|
| 72 |
-
----
|
| 73 |
-
images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。
|
| 74 |
-
|
| 75 |
-
返回
|
| 76 |
-
----
|
| 77 |
-
feats : ``[B, T, gh, gw, D]``。
|
| 78 |
-
"""
|
| 79 |
-
b, t, c, h, w = images.shape
|
| 80 |
-
# DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]``
|
| 81 |
-
flat = images.view(b * t, c, h, w)
|
| 82 |
-
|
| 83 |
-
# 冻结分支无需梯度,节省显存与时间
|
| 84 |
-
if self._frozen:
|
| 85 |
-
with torch.no_grad():
|
| 86 |
-
outputs = self.model(pixel_values=flat)
|
| 87 |
-
else:
|
| 88 |
-
outputs = self.model(pixel_values=flat)
|
| 89 |
-
|
| 90 |
-
last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D]
|
| 91 |
-
num_prefix = 1 + self.num_register_tokens
|
| 92 |
-
patches = last[:, num_prefix:, :] # [B*T, N_patch, D]
|
| 93 |
-
|
| 94 |
-
gh = h // self.patch_size
|
| 95 |
-
gw = w // self.patch_size
|
| 96 |
-
d = patches.shape[-1]
|
| 97 |
-
# reshape 回网格
|
| 98 |
-
feats = patches.view(b, t, gh, gw, d)
|
| 99 |
-
return feats
|
| 100 |
-
|
| 101 |
-
@torch.no_grad()
|
| 102 |
-
def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]:
|
| 103 |
-
"""给定输入分辨率,返回 patch 网格大小。"""
|
| 104 |
-
return image_h // self.patch_size, image_w // self.patch_size
|
|
|
|
| 1 |
+
"""DINOv3 ViT-B/16 包装器。
|
| 2 |
+
|
| 3 |
+
- 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。
|
| 4 |
+
- 强制使用 ``attn_implementation="sdpa"``。
|
| 5 |
+
- 提供 ``freeze()`` / ``unfreeze()`` 开关。
|
| 6 |
+
- 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中
|
| 7 |
+
``(gh, gw) = (H/patch, W/patch)``。
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from transformers import AutoModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DINOv3Wrapper(nn.Module):
|
| 20 |
+
"""加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m",
|
| 25 |
+
attn_implementation: str = "sdpa",
|
| 26 |
+
freeze: bool = True,
|
| 27 |
+
) -> None:
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.pretrained_path = str(pretrained_path)
|
| 30 |
+
# 加载 HuggingFace transformers 中的 DINOv3 ViT 模型
|
| 31 |
+
self.model = AutoModel.from_pretrained(
|
| 32 |
+
self.pretrained_path,
|
| 33 |
+
attn_implementation=attn_implementation,
|
| 34 |
+
)
|
| 35 |
+
cfg = self.model.config
|
| 36 |
+
self.hidden_size = cfg.hidden_size
|
| 37 |
+
self.patch_size = cfg.patch_size
|
| 38 |
+
self.num_register_tokens = getattr(cfg, "num_register_tokens", 4)
|
| 39 |
+
|
| 40 |
+
if freeze:
|
| 41 |
+
self.freeze()
|
| 42 |
+
self._frozen = freeze
|
| 43 |
+
|
| 44 |
+
def freeze(self) -> None:
|
| 45 |
+
"""冻结所有参数。"""
|
| 46 |
+
for p in self.model.parameters():
|
| 47 |
+
p.requires_grad_(False)
|
| 48 |
+
self.model.eval()
|
| 49 |
+
self._frozen = True
|
| 50 |
+
|
| 51 |
+
def unfreeze(self) -> None:
|
| 52 |
+
"""解冻全部参��(Stage2 微调)。"""
|
| 53 |
+
for p in self.model.parameters():
|
| 54 |
+
p.requires_grad_(True)
|
| 55 |
+
self.model.train()
|
| 56 |
+
self._frozen = False
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def is_frozen(self) -> bool:
|
| 60 |
+
return self._frozen
|
| 61 |
+
|
| 62 |
+
def train(self, mode: bool = True) -> "DINOv3Wrapper":
|
| 63 |
+
"""覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。"""
|
| 64 |
+
super().train(mode)
|
| 65 |
+
if self._frozen:
|
| 66 |
+
self.model.eval()
|
| 67 |
+
return self
|
| 68 |
+
|
| 69 |
+
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
"""
|
| 71 |
+
参数
|
| 72 |
+
----
|
| 73 |
+
images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。
|
| 74 |
+
|
| 75 |
+
返回
|
| 76 |
+
----
|
| 77 |
+
feats : ``[B, T, gh, gw, D]``。
|
| 78 |
+
"""
|
| 79 |
+
b, t, c, h, w = images.shape
|
| 80 |
+
# DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]``
|
| 81 |
+
flat = images.view(b * t, c, h, w)
|
| 82 |
+
|
| 83 |
+
# 冻结分支无需梯度,节省显存与时间
|
| 84 |
+
if self._frozen:
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
outputs = self.model(pixel_values=flat)
|
| 87 |
+
else:
|
| 88 |
+
outputs = self.model(pixel_values=flat)
|
| 89 |
+
|
| 90 |
+
last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D]
|
| 91 |
+
num_prefix = 1 + self.num_register_tokens
|
| 92 |
+
patches = last[:, num_prefix:, :] # [B*T, N_patch, D]
|
| 93 |
+
|
| 94 |
+
gh = h // self.patch_size
|
| 95 |
+
gw = w // self.patch_size
|
| 96 |
+
d = patches.shape[-1]
|
| 97 |
+
# reshape 回网格
|
| 98 |
+
feats = patches.view(b, t, gh, gw, d)
|
| 99 |
+
return feats
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]:
|
| 103 |
+
"""给定输入分辨率,返回 patch 网格大小。"""
|
| 104 |
+
return image_h // self.patch_size, image_w // self.patch_size
|
src/wjad/heads/__init__.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
-
"""检测+未来轨迹头 + 控制头。"""
|
| 2 |
-
|
| 3 |
-
from .detection_traj import DetectionTrajHead, DetectionTrajOutput
|
| 4 |
-
from .control import ControlHead, ControlOutput
|
| 5 |
-
|
| 6 |
-
__all__ = [
|
| 7 |
-
"DetectionTrajHead",
|
| 8 |
-
"DetectionTrajOutput",
|
| 9 |
-
"ControlHead",
|
| 10 |
-
"ControlOutput",
|
| 11 |
-
]
|
|
|
|
| 1 |
+
"""检测+未来轨迹头 + 控制头。"""
|
| 2 |
+
|
| 3 |
+
from .detection_traj import DetectionTrajHead, DetectionTrajOutput
|
| 4 |
+
from .control import ControlHead, ControlOutput
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"DetectionTrajHead",
|
| 8 |
+
"DetectionTrajOutput",
|
| 9 |
+
"ControlHead",
|
| 10 |
+
"ControlOutput",
|
| 11 |
+
]
|
src/wjad/heads/control.py
CHANGED
|
@@ -1,100 +1,100 @@
|
|
| 1 |
-
"""自车控制头:24 个控制 token 输出未来轨迹与全局动作。
|
| 2 |
-
|
| 3 |
-
token 切分:
|
| 4 |
-
- 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma``
|
| 5 |
-
- 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma``
|
| 6 |
-
(其余作为冗余 / 未来扩展,暂不监督)
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@dataclass
|
| 18 |
-
class ControlOutput:
|
| 19 |
-
ego_traj_mu: torch.Tensor # [B, T_future, 3]
|
| 20 |
-
ego_traj_log_sigma: torch.Tensor # [B, T_future, 3]
|
| 21 |
-
action_mu: torch.Tensor # [B, action_dim]
|
| 22 |
-
action_log_sigma: torch.Tensor # [B, action_dim]
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class ControlHead(nn.Module):
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
in_dim: int = 768,
|
| 29 |
-
hidden_size: int = 384,
|
| 30 |
-
num_traj_tokens: int = 12,
|
| 31 |
-
num_action_tokens: int = 12,
|
| 32 |
-
ego_traj_horizon: int = 24,
|
| 33 |
-
ego_traj_dim: int = 3,
|
| 34 |
-
action_dim: int = 3,
|
| 35 |
-
log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
|
| 36 |
-
) -> None:
|
| 37 |
-
super().__init__()
|
| 38 |
-
assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24"
|
| 39 |
-
self.in_dim = in_dim
|
| 40 |
-
self.hidden = hidden_size
|
| 41 |
-
self.num_traj_tokens = num_traj_tokens
|
| 42 |
-
self.num_action_tokens = num_action_tokens
|
| 43 |
-
self.ego_traj_horizon = ego_traj_horizon
|
| 44 |
-
self.ego_traj_dim = ego_traj_dim
|
| 45 |
-
self.action_dim = action_dim
|
| 46 |
-
self.log_sigma_clamp = log_sigma_clamp
|
| 47 |
-
|
| 48 |
-
self.norm = nn.LayerNorm(in_dim)
|
| 49 |
-
# 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig)
|
| 50 |
-
self.traj_proj = nn.Sequential(
|
| 51 |
-
nn.Linear(num_traj_tokens * in_dim, hidden_size),
|
| 52 |
-
nn.GELU(),
|
| 53 |
-
nn.Linear(hidden_size, hidden_size),
|
| 54 |
-
nn.GELU(),
|
| 55 |
-
)
|
| 56 |
-
self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
|
| 57 |
-
self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
|
| 58 |
-
|
| 59 |
-
# 动作分支:取第 0 个动作 token
|
| 60 |
-
self.action_proj = nn.Sequential(
|
| 61 |
-
nn.Linear(in_dim, hidden_size),
|
| 62 |
-
nn.GELU(),
|
| 63 |
-
nn.Linear(hidden_size, hidden_size),
|
| 64 |
-
nn.GELU(),
|
| 65 |
-
)
|
| 66 |
-
self.action_mu_head = nn.Linear(hidden_size, action_dim)
|
| 67 |
-
self.action_logsig_head = nn.Linear(hidden_size, action_dim)
|
| 68 |
-
|
| 69 |
-
self._init_heads()
|
| 70 |
-
|
| 71 |
-
def _init_heads(self) -> None:
|
| 72 |
-
for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]:
|
| 73 |
-
nn.init.zeros_(m.weight)
|
| 74 |
-
nn.init.zeros_(m.bias)
|
| 75 |
-
|
| 76 |
-
def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput:
|
| 77 |
-
"""
|
| 78 |
-
ctrl_tokens : ``[B, 24, in_dim]``
|
| 79 |
-
"""
|
| 80 |
-
b, n, d = ctrl_tokens.shape
|
| 81 |
-
assert n == self.num_traj_tokens + self.num_action_tokens
|
| 82 |
-
x = self.norm(ctrl_tokens)
|
| 83 |
-
traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1)
|
| 84 |
-
action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token
|
| 85 |
-
|
| 86 |
-
traj_h = self.traj_proj(traj_feats)
|
| 87 |
-
traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
|
| 88 |
-
traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
|
| 89 |
-
traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
|
| 90 |
-
|
| 91 |
-
action_h = self.action_proj(action_feats)
|
| 92 |
-
action_mu = self.action_mu_head(action_h)
|
| 93 |
-
action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp)
|
| 94 |
-
|
| 95 |
-
return ControlOutput(
|
| 96 |
-
ego_traj_mu=traj_mu,
|
| 97 |
-
ego_traj_log_sigma=traj_logsig,
|
| 98 |
-
action_mu=action_mu,
|
| 99 |
-
action_log_sigma=action_logsig,
|
| 100 |
-
)
|
|
|
|
| 1 |
+
"""自车控制头:24 个控制 token 输出未来轨迹与全局动作。
|
| 2 |
+
|
| 3 |
+
token 切分:
|
| 4 |
+
- 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma``
|
| 5 |
+
- 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma``
|
| 6 |
+
(其余作为冗余 / 未来扩展,暂不监督)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class ControlOutput:
|
| 19 |
+
ego_traj_mu: torch.Tensor # [B, T_future, 3]
|
| 20 |
+
ego_traj_log_sigma: torch.Tensor # [B, T_future, 3]
|
| 21 |
+
action_mu: torch.Tensor # [B, action_dim]
|
| 22 |
+
action_log_sigma: torch.Tensor # [B, action_dim]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ControlHead(nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
in_dim: int = 768,
|
| 29 |
+
hidden_size: int = 384,
|
| 30 |
+
num_traj_tokens: int = 12,
|
| 31 |
+
num_action_tokens: int = 12,
|
| 32 |
+
ego_traj_horizon: int = 24,
|
| 33 |
+
ego_traj_dim: int = 3,
|
| 34 |
+
action_dim: int = 3,
|
| 35 |
+
log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
|
| 36 |
+
) -> None:
|
| 37 |
+
super().__init__()
|
| 38 |
+
assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24"
|
| 39 |
+
self.in_dim = in_dim
|
| 40 |
+
self.hidden = hidden_size
|
| 41 |
+
self.num_traj_tokens = num_traj_tokens
|
| 42 |
+
self.num_action_tokens = num_action_tokens
|
| 43 |
+
self.ego_traj_horizon = ego_traj_horizon
|
| 44 |
+
self.ego_traj_dim = ego_traj_dim
|
| 45 |
+
self.action_dim = action_dim
|
| 46 |
+
self.log_sigma_clamp = log_sigma_clamp
|
| 47 |
+
|
| 48 |
+
self.norm = nn.LayerNorm(in_dim)
|
| 49 |
+
# 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig)
|
| 50 |
+
self.traj_proj = nn.Sequential(
|
| 51 |
+
nn.Linear(num_traj_tokens * in_dim, hidden_size),
|
| 52 |
+
nn.GELU(),
|
| 53 |
+
nn.Linear(hidden_size, hidden_size),
|
| 54 |
+
nn.GELU(),
|
| 55 |
+
)
|
| 56 |
+
self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
|
| 57 |
+
self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim)
|
| 58 |
+
|
| 59 |
+
# 动作分支:取第 0 个动作 token
|
| 60 |
+
self.action_proj = nn.Sequential(
|
| 61 |
+
nn.Linear(in_dim, hidden_size),
|
| 62 |
+
nn.GELU(),
|
| 63 |
+
nn.Linear(hidden_size, hidden_size),
|
| 64 |
+
nn.GELU(),
|
| 65 |
+
)
|
| 66 |
+
self.action_mu_head = nn.Linear(hidden_size, action_dim)
|
| 67 |
+
self.action_logsig_head = nn.Linear(hidden_size, action_dim)
|
| 68 |
+
|
| 69 |
+
self._init_heads()
|
| 70 |
+
|
| 71 |
+
def _init_heads(self) -> None:
|
| 72 |
+
for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]:
|
| 73 |
+
nn.init.zeros_(m.weight)
|
| 74 |
+
nn.init.zeros_(m.bias)
|
| 75 |
+
|
| 76 |
+
def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput:
|
| 77 |
+
"""
|
| 78 |
+
ctrl_tokens : ``[B, 24, in_dim]``
|
| 79 |
+
"""
|
| 80 |
+
b, n, d = ctrl_tokens.shape
|
| 81 |
+
assert n == self.num_traj_tokens + self.num_action_tokens
|
| 82 |
+
x = self.norm(ctrl_tokens)
|
| 83 |
+
traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1)
|
| 84 |
+
action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token
|
| 85 |
+
|
| 86 |
+
traj_h = self.traj_proj(traj_feats)
|
| 87 |
+
traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
|
| 88 |
+
traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim)
|
| 89 |
+
traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
|
| 90 |
+
|
| 91 |
+
action_h = self.action_proj(action_feats)
|
| 92 |
+
action_mu = self.action_mu_head(action_h)
|
| 93 |
+
action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp)
|
| 94 |
+
|
| 95 |
+
return ControlOutput(
|
| 96 |
+
ego_traj_mu=traj_mu,
|
| 97 |
+
ego_traj_log_sigma=traj_logsig,
|
| 98 |
+
action_mu=action_mu,
|
| 99 |
+
action_log_sigma=action_logsig,
|
| 100 |
+
)
|
src/wjad/heads/detection_traj.py
CHANGED
|
@@ -1,106 +1,106 @@
|
|
| 1 |
-
"""统一的检测 + 未来轨迹头。
|
| 2 |
-
|
| 3 |
-
每个检测 query token 输出:
|
| 4 |
-
- ``cls`` : ``[num_classes]`` logits(含 background)
|
| 5 |
-
- ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失)
|
| 6 |
-
- ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw)
|
| 7 |
-
- ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw)
|
| 8 |
-
|
| 9 |
-
匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from dataclasses import dataclass
|
| 15 |
-
|
| 16 |
-
import torch
|
| 17 |
-
import torch.nn as nn
|
| 18 |
-
|
| 19 |
-
from ..modules.ffn import SwiGLUFFN
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@dataclass
|
| 23 |
-
class DetectionTrajOutput:
|
| 24 |
-
"""检测+未来轨迹头输出。"""
|
| 25 |
-
|
| 26 |
-
cls_logits: torch.Tensor # [B, Q, num_classes]
|
| 27 |
-
is_dynamic_logit: torch.Tensor # [B, Q]
|
| 28 |
-
box3d_mu: torch.Tensor # [B, Q, 7]
|
| 29 |
-
box3d_log_sigma: torch.Tensor # [B, Q, 7]
|
| 30 |
-
traj_mu: torch.Tensor # [B, Q, T_future, 3]
|
| 31 |
-
traj_log_sigma: torch.Tensor # [B, Q, T_future, 3]
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class DetectionTrajHead(nn.Module):
|
| 35 |
-
def __init__(
|
| 36 |
-
self,
|
| 37 |
-
in_dim: int = 768,
|
| 38 |
-
hidden_size: int = 384,
|
| 39 |
-
num_classes: int = 22,
|
| 40 |
-
box_dim: int = 7,
|
| 41 |
-
traj_horizon: int = 24,
|
| 42 |
-
traj_dim: int = 3,
|
| 43 |
-
log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
|
| 44 |
-
) -> None:
|
| 45 |
-
super().__init__()
|
| 46 |
-
self.num_classes = num_classes
|
| 47 |
-
self.box_dim = box_dim
|
| 48 |
-
self.traj_horizon = traj_horizon
|
| 49 |
-
self.traj_dim = traj_dim
|
| 50 |
-
self.log_sigma_clamp = log_sigma_clamp
|
| 51 |
-
|
| 52 |
-
# 共享主干 MLP(PreNorm + SwiGLU)
|
| 53 |
-
self.norm = nn.LayerNorm(in_dim)
|
| 54 |
-
self.shared = nn.Sequential(
|
| 55 |
-
nn.Linear(in_dim, hidden_size),
|
| 56 |
-
nn.GELU(),
|
| 57 |
-
nn.Linear(hidden_size, hidden_size),
|
| 58 |
-
nn.GELU(),
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
# 各分支
|
| 62 |
-
self.cls_head = nn.Linear(hidden_size, num_classes)
|
| 63 |
-
self.isdyn_head = nn.Linear(hidden_size, 1)
|
| 64 |
-
self.box_mu_head = nn.Linear(hidden_size, box_dim)
|
| 65 |
-
self.box_logsig_head = nn.Linear(hidden_size, box_dim)
|
| 66 |
-
self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
|
| 67 |
-
self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
|
| 68 |
-
|
| 69 |
-
self._init_heads()
|
| 70 |
-
|
| 71 |
-
def _init_heads(self) -> None:
|
| 72 |
-
# 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1
|
| 73 |
-
for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]:
|
| 74 |
-
nn.init.zeros_(m.weight)
|
| 75 |
-
nn.init.zeros_(m.bias)
|
| 76 |
-
# cls / isdyn 用小初始化即可(避免 background 一开始全选)
|
| 77 |
-
nn.init.normal_(self.cls_head.weight, std=0.01)
|
| 78 |
-
nn.init.zeros_(self.cls_head.bias)
|
| 79 |
-
nn.init.zeros_(self.isdyn_head.weight)
|
| 80 |
-
nn.init.zeros_(self.isdyn_head.bias)
|
| 81 |
-
|
| 82 |
-
def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput:
|
| 83 |
-
"""
|
| 84 |
-
det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。
|
| 85 |
-
"""
|
| 86 |
-
b, q, _ = det_tokens.shape
|
| 87 |
-
feats = self.shared(self.norm(det_tokens))
|
| 88 |
-
|
| 89 |
-
cls_logits = self.cls_head(feats)
|
| 90 |
-
isdyn_logit = self.isdyn_head(feats).squeeze(-1)
|
| 91 |
-
|
| 92 |
-
box_mu = self.box_mu_head(feats)
|
| 93 |
-
box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp)
|
| 94 |
-
|
| 95 |
-
traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
|
| 96 |
-
traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
|
| 97 |
-
traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
|
| 98 |
-
|
| 99 |
-
return DetectionTrajOutput(
|
| 100 |
-
cls_logits=cls_logits,
|
| 101 |
-
is_dynamic_logit=isdyn_logit,
|
| 102 |
-
box3d_mu=box_mu,
|
| 103 |
-
box3d_log_sigma=box_logsig,
|
| 104 |
-
traj_mu=traj_mu,
|
| 105 |
-
traj_log_sigma=traj_logsig,
|
| 106 |
-
)
|
|
|
|
| 1 |
+
"""统一的检测 + 未来轨迹头。
|
| 2 |
+
|
| 3 |
+
每个检测 query token 输出:
|
| 4 |
+
- ``cls`` : ``[num_classes]`` logits(含 background)
|
| 5 |
+
- ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失)
|
| 6 |
+
- ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw)
|
| 7 |
+
- ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw)
|
| 8 |
+
|
| 9 |
+
匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from ..modules.ffn import SwiGLUFFN
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class DetectionTrajOutput:
|
| 24 |
+
"""检测+未来轨迹头输出。"""
|
| 25 |
+
|
| 26 |
+
cls_logits: torch.Tensor # [B, Q, num_classes]
|
| 27 |
+
is_dynamic_logit: torch.Tensor # [B, Q]
|
| 28 |
+
box3d_mu: torch.Tensor # [B, Q, 7]
|
| 29 |
+
box3d_log_sigma: torch.Tensor # [B, Q, 7]
|
| 30 |
+
traj_mu: torch.Tensor # [B, Q, T_future, 3]
|
| 31 |
+
traj_log_sigma: torch.Tensor # [B, Q, T_future, 3]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DetectionTrajHead(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
in_dim: int = 768,
|
| 38 |
+
hidden_size: int = 384,
|
| 39 |
+
num_classes: int = 22,
|
| 40 |
+
box_dim: int = 7,
|
| 41 |
+
traj_horizon: int = 24,
|
| 42 |
+
traj_dim: int = 3,
|
| 43 |
+
log_sigma_clamp: tuple[float, float] = (-7.0, 7.0),
|
| 44 |
+
) -> None:
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.num_classes = num_classes
|
| 47 |
+
self.box_dim = box_dim
|
| 48 |
+
self.traj_horizon = traj_horizon
|
| 49 |
+
self.traj_dim = traj_dim
|
| 50 |
+
self.log_sigma_clamp = log_sigma_clamp
|
| 51 |
+
|
| 52 |
+
# 共享主干 MLP(PreNorm + SwiGLU)
|
| 53 |
+
self.norm = nn.LayerNorm(in_dim)
|
| 54 |
+
self.shared = nn.Sequential(
|
| 55 |
+
nn.Linear(in_dim, hidden_size),
|
| 56 |
+
nn.GELU(),
|
| 57 |
+
nn.Linear(hidden_size, hidden_size),
|
| 58 |
+
nn.GELU(),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# 各分支
|
| 62 |
+
self.cls_head = nn.Linear(hidden_size, num_classes)
|
| 63 |
+
self.isdyn_head = nn.Linear(hidden_size, 1)
|
| 64 |
+
self.box_mu_head = nn.Linear(hidden_size, box_dim)
|
| 65 |
+
self.box_logsig_head = nn.Linear(hidden_size, box_dim)
|
| 66 |
+
self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
|
| 67 |
+
self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim)
|
| 68 |
+
|
| 69 |
+
self._init_heads()
|
| 70 |
+
|
| 71 |
+
def _init_heads(self) -> None:
|
| 72 |
+
# 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1
|
| 73 |
+
for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]:
|
| 74 |
+
nn.init.zeros_(m.weight)
|
| 75 |
+
nn.init.zeros_(m.bias)
|
| 76 |
+
# cls / isdyn 用小初始化即可(避免 background 一开始全选)
|
| 77 |
+
nn.init.normal_(self.cls_head.weight, std=0.01)
|
| 78 |
+
nn.init.zeros_(self.cls_head.bias)
|
| 79 |
+
nn.init.zeros_(self.isdyn_head.weight)
|
| 80 |
+
nn.init.zeros_(self.isdyn_head.bias)
|
| 81 |
+
|
| 82 |
+
def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput:
|
| 83 |
+
"""
|
| 84 |
+
det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。
|
| 85 |
+
"""
|
| 86 |
+
b, q, _ = det_tokens.shape
|
| 87 |
+
feats = self.shared(self.norm(det_tokens))
|
| 88 |
+
|
| 89 |
+
cls_logits = self.cls_head(feats)
|
| 90 |
+
isdyn_logit = self.isdyn_head(feats).squeeze(-1)
|
| 91 |
+
|
| 92 |
+
box_mu = self.box_mu_head(feats)
|
| 93 |
+
box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp)
|
| 94 |
+
|
| 95 |
+
traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
|
| 96 |
+
traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim)
|
| 97 |
+
traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp)
|
| 98 |
+
|
| 99 |
+
return DetectionTrajOutput(
|
| 100 |
+
cls_logits=cls_logits,
|
| 101 |
+
is_dynamic_logit=isdyn_logit,
|
| 102 |
+
box3d_mu=box_mu,
|
| 103 |
+
box3d_log_sigma=box_logsig,
|
| 104 |
+
traj_mu=traj_mu,
|
| 105 |
+
traj_log_sigma=traj_logsig,
|
| 106 |
+
)
|
src/wjad/losses/__init__.py
CHANGED
|
@@ -1,24 +1,24 @@
|
|
| 1 |
-
"""损失函数集合。"""
|
| 2 |
-
|
| 3 |
-
from .nll import gaussian_nll
|
| 4 |
-
from .detection import (
|
| 5 |
-
HungarianMatcher,
|
| 6 |
-
detection_losses,
|
| 7 |
-
DetectionLossOutputs,
|
| 8 |
-
)
|
| 9 |
-
from .trajectory import object_traj_nll
|
| 10 |
-
from .control import ego_traj_nll, action_nll
|
| 11 |
-
from .moe_aux import moe_load_balance_and_boundary
|
| 12 |
-
from .calib_reg import calibration_regularization
|
| 13 |
-
|
| 14 |
-
__all__ = [
|
| 15 |
-
"gaussian_nll",
|
| 16 |
-
"HungarianMatcher",
|
| 17 |
-
"detection_losses",
|
| 18 |
-
"DetectionLossOutputs",
|
| 19 |
-
"object_traj_nll",
|
| 20 |
-
"ego_traj_nll",
|
| 21 |
-
"action_nll",
|
| 22 |
-
"moe_load_balance_and_boundary",
|
| 23 |
-
"calibration_regularization",
|
| 24 |
-
]
|
|
|
|
| 1 |
+
"""损失函数集合。"""
|
| 2 |
+
|
| 3 |
+
from .nll import gaussian_nll
|
| 4 |
+
from .detection import (
|
| 5 |
+
HungarianMatcher,
|
| 6 |
+
detection_losses,
|
| 7 |
+
DetectionLossOutputs,
|
| 8 |
+
)
|
| 9 |
+
from .trajectory import object_traj_nll
|
| 10 |
+
from .control import ego_traj_nll, action_nll
|
| 11 |
+
from .moe_aux import moe_load_balance_and_boundary
|
| 12 |
+
from .calib_reg import calibration_regularization
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"gaussian_nll",
|
| 16 |
+
"HungarianMatcher",
|
| 17 |
+
"detection_losses",
|
| 18 |
+
"DetectionLossOutputs",
|
| 19 |
+
"object_traj_nll",
|
| 20 |
+
"ego_traj_nll",
|
| 21 |
+
"action_nll",
|
| 22 |
+
"moe_load_balance_and_boundary",
|
| 23 |
+
"calibration_regularization",
|
| 24 |
+
]
|
src/wjad/losses/calib_reg.py
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
-
"""在线校准残差正则:
|
| 2 |
-
|
| 3 |
-
- L2 残差先验:让早期 / 一般情况下残差接近 0;
|
| 4 |
-
- Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def calibration_regularization(
|
| 13 |
-
ego_residual: torch.Tensor,
|
| 14 |
-
intr_residual: torch.Tensor,
|
| 15 |
-
extr_residual: torch.Tensor,
|
| 16 |
-
l2_weight: float = 1.0,
|
| 17 |
-
) -> torch.Tensor:
|
| 18 |
-
e = ego_residual.pow(2).mean()
|
| 19 |
-
i = intr_residual.pow(2).mean()
|
| 20 |
-
x = extr_residual.pow(2).mean()
|
| 21 |
-
return l2_weight * (e + i + x) / 3.0
|
|
|
|
| 1 |
+
"""在线校准残差正则:
|
| 2 |
+
|
| 3 |
+
- L2 残差先验:让早期 / 一般情况下残差接近 0;
|
| 4 |
+
- Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def calibration_regularization(
|
| 13 |
+
ego_residual: torch.Tensor,
|
| 14 |
+
intr_residual: torch.Tensor,
|
| 15 |
+
extr_residual: torch.Tensor,
|
| 16 |
+
l2_weight: float = 1.0,
|
| 17 |
+
) -> torch.Tensor:
|
| 18 |
+
e = ego_residual.pow(2).mean()
|
| 19 |
+
i = intr_residual.pow(2).mean()
|
| 20 |
+
x = extr_residual.pow(2).mean()
|
| 21 |
+
return l2_weight * (e + i + x) / 3.0
|
src/wjad/losses/control.py
CHANGED
|
@@ -1,25 +1,25 @@
|
|
| 1 |
-
"""自车未来轨迹 + 全局动作的 NLL。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .nll import gaussian_nll
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def ego_traj_nll(
|
| 11 |
-
pred_mu: torch.Tensor, # [B, T, 3]
|
| 12 |
-
pred_log_sigma: torch.Tensor, # [B, T, 3]
|
| 13 |
-
target: torch.Tensor, # [B, T, 3] (symlog 空间)
|
| 14 |
-
valid: torch.Tensor | None = None,
|
| 15 |
-
) -> torch.Tensor:
|
| 16 |
-
return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def action_nll(
|
| 20 |
-
pred_mu: torch.Tensor, # [B, A]
|
| 21 |
-
pred_log_sigma: torch.Tensor, # [B, A]
|
| 22 |
-
target: torch.Tensor, # [B, A]
|
| 23 |
-
valid: torch.Tensor | None = None,
|
| 24 |
-
) -> torch.Tensor:
|
| 25 |
-
return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
|
|
|
|
| 1 |
+
"""自车未来轨迹 + 全局动作的 NLL。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .nll import gaussian_nll
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def ego_traj_nll(
|
| 11 |
+
pred_mu: torch.Tensor, # [B, T, 3]
|
| 12 |
+
pred_log_sigma: torch.Tensor, # [B, T, 3]
|
| 13 |
+
target: torch.Tensor, # [B, T, 3] (symlog 空间)
|
| 14 |
+
valid: torch.Tensor | None = None,
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
+
return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def action_nll(
|
| 20 |
+
pred_mu: torch.Tensor, # [B, A]
|
| 21 |
+
pred_log_sigma: torch.Tensor, # [B, A]
|
| 22 |
+
target: torch.Tensor, # [B, A]
|
| 23 |
+
valid: torch.Tensor | None = None,
|
| 24 |
+
) -> torch.Tensor:
|
| 25 |
+
return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid)
|
src/wjad/losses/detection.py
CHANGED
|
@@ -1,213 +1,213 @@
|
|
| 1 |
-
"""检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。
|
| 2 |
-
|
| 3 |
-
GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU,
|
| 4 |
-
``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。
|
| 5 |
-
此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
from dataclasses import dataclass
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from scipy.optimize import linear_sum_assignment
|
| 15 |
-
|
| 16 |
-
from .nll import gaussian_nll
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@dataclass
|
| 20 |
-
class DetectionLossOutputs:
|
| 21 |
-
"""检测损失分项与匹配结果。"""
|
| 22 |
-
|
| 23 |
-
cls_loss: torch.Tensor
|
| 24 |
-
box_nll: torch.Tensor
|
| 25 |
-
giou_loss: torch.Tensor
|
| 26 |
-
isdyn_loss: torch.Tensor
|
| 27 |
-
matched_indices: list[tuple[torch.Tensor, torch.Tensor]]
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def focal_loss(
|
| 31 |
-
logits: torch.Tensor,
|
| 32 |
-
target: torch.Tensor,
|
| 33 |
-
alpha: float = 0.25,
|
| 34 |
-
gamma: float = 2.0,
|
| 35 |
-
reduction: str = "mean",
|
| 36 |
-
) -> torch.Tensor:
|
| 37 |
-
"""多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。"""
|
| 38 |
-
log_softmax = F.log_softmax(logits, dim=-1)
|
| 39 |
-
pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1)
|
| 40 |
-
ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1)
|
| 41 |
-
focal = alpha * (1 - pt).pow(gamma) * ce
|
| 42 |
-
if reduction == "mean":
|
| 43 |
-
return focal.mean()
|
| 44 |
-
if reduction == "sum":
|
| 45 |
-
return focal.sum()
|
| 46 |
-
return focal
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def _bev_giou(
|
| 50 |
-
box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw)
|
| 51 |
-
box_b: torch.Tensor, # [..., 5]
|
| 52 |
-
) -> torch.Tensor:
|
| 53 |
-
"""BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转),
|
| 54 |
-
可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。
|
| 55 |
-
"""
|
| 56 |
-
cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3]
|
| 57 |
-
cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3]
|
| 58 |
-
a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2
|
| 59 |
-
a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2
|
| 60 |
-
b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2
|
| 61 |
-
b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2
|
| 62 |
-
inter_x1 = torch.max(a_x1, b_x1)
|
| 63 |
-
inter_y1 = torch.max(a_y1, b_y1)
|
| 64 |
-
inter_x2 = torch.min(a_x2, b_x2)
|
| 65 |
-
inter_y2 = torch.min(a_y2, b_y2)
|
| 66 |
-
inter_w = (inter_x2 - inter_x1).clamp_min(0)
|
| 67 |
-
inter_h = (inter_y2 - inter_y1).clamp_min(0)
|
| 68 |
-
inter = inter_w * inter_h
|
| 69 |
-
area_a = (l_a * w_a).clamp_min(0)
|
| 70 |
-
area_b = (l_b * w_b).clamp_min(0)
|
| 71 |
-
union = area_a + area_b - inter + 1e-6
|
| 72 |
-
iou = inter / union
|
| 73 |
-
# GIoU enclosure
|
| 74 |
-
enc_x1 = torch.min(a_x1, b_x1)
|
| 75 |
-
enc_y1 = torch.min(a_y1, b_y1)
|
| 76 |
-
enc_x2 = torch.max(a_x2, b_x2)
|
| 77 |
-
enc_y2 = torch.max(a_y2, b_y2)
|
| 78 |
-
enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6)
|
| 79 |
-
giou = iou - (enc_area - union) / enc_area
|
| 80 |
-
return giou
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor:
|
| 84 |
-
"""3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。"""
|
| 85 |
-
bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]])
|
| 86 |
-
z_a, h_a = box_a[..., 2], box_a[..., 5]
|
| 87 |
-
z_b, h_b = box_b[..., 2], box_b[..., 5]
|
| 88 |
-
a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2
|
| 89 |
-
b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2
|
| 90 |
-
inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0)
|
| 91 |
-
union_z = (h_a + h_b - inter_z).clamp_min(1e-6)
|
| 92 |
-
z_iou = inter_z / union_z
|
| 93 |
-
return bev * z_iou
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class HungarianMatcher:
|
| 97 |
-
"""DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。"""
|
| 98 |
-
|
| 99 |
-
def __init__(
|
| 100 |
-
self,
|
| 101 |
-
cls_cost: float = 2.0,
|
| 102 |
-
l1_cost: float = 5.0,
|
| 103 |
-
giou_cost: float = 2.0,
|
| 104 |
-
) -> None:
|
| 105 |
-
self.cls_cost = cls_cost
|
| 106 |
-
self.l1_cost = l1_cost
|
| 107 |
-
self.giou_cost = giou_cost
|
| 108 |
-
|
| 109 |
-
@torch.no_grad()
|
| 110 |
-
def match(
|
| 111 |
-
self,
|
| 112 |
-
cls_logits: torch.Tensor, # [B, Q, C]
|
| 113 |
-
box_mu: torch.Tensor, # [B, Q, 7]
|
| 114 |
-
targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]}
|
| 115 |
-
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
| 116 |
-
b, q, c = cls_logits.shape
|
| 117 |
-
out = []
|
| 118 |
-
cls_probs = cls_logits.softmax(-1) # [B, Q, C]
|
| 119 |
-
|
| 120 |
-
for i in range(b):
|
| 121 |
-
tgt_labels = targets[i]["labels"]
|
| 122 |
-
tgt_boxes = targets[i]["boxes"]
|
| 123 |
-
n_tgt = tgt_labels.numel()
|
| 124 |
-
if n_tgt == 0:
|
| 125 |
-
out.append((
|
| 126 |
-
torch.empty(0, dtype=torch.long),
|
| 127 |
-
torch.empty(0, dtype=torch.long),
|
| 128 |
-
))
|
| 129 |
-
continue
|
| 130 |
-
|
| 131 |
-
# cost_cls: 越大越好 → 取负
|
| 132 |
-
cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt]
|
| 133 |
-
cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt]
|
| 134 |
-
# giou3d: [Q, n_tgt]
|
| 135 |
-
qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1)
|
| 136 |
-
tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1)
|
| 137 |
-
cost_giou = -giou3d_approx(qa, tb)
|
| 138 |
-
|
| 139 |
-
cost = (
|
| 140 |
-
self.cls_cost * cost_cls
|
| 141 |
-
+ self.l1_cost * cost_l1
|
| 142 |
-
+ self.giou_cost * cost_giou
|
| 143 |
-
)
|
| 144 |
-
cost_np = cost.cpu().numpy()
|
| 145 |
-
row, col = linear_sum_assignment(cost_np)
|
| 146 |
-
out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long)))
|
| 147 |
-
return out
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def detection_losses(
|
| 151 |
-
cls_logits: torch.Tensor, # [B, Q, C]
|
| 152 |
-
box_mu: torch.Tensor, # [B, Q, 7]
|
| 153 |
-
box_log_sigma: torch.Tensor, # [B, Q, 7]
|
| 154 |
-
isdyn_logit: torch.Tensor, # [B, Q]
|
| 155 |
-
targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...}
|
| 156 |
-
matcher: HungarianMatcher,
|
| 157 |
-
num_classes: int,
|
| 158 |
-
background_class: int = 0,
|
| 159 |
-
focal_alpha: float = 0.25,
|
| 160 |
-
focal_gamma: float = 2.0,
|
| 161 |
-
) -> DetectionLossOutputs:
|
| 162 |
-
"""返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。"""
|
| 163 |
-
indices = matcher.match(cls_logits, box_mu, targets)
|
| 164 |
-
b, q, _ = box_mu.shape
|
| 165 |
-
device = box_mu.device
|
| 166 |
-
|
| 167 |
-
# 构造分类目标:所有 query 默认 background;匹配的填对应 label
|
| 168 |
-
target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device)
|
| 169 |
-
target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device)
|
| 170 |
-
matched_box_pairs = []
|
| 171 |
-
matched_logsig_pairs = []
|
| 172 |
-
matched_target_boxes = []
|
| 173 |
-
|
| 174 |
-
for i, (rows, cols) in enumerate(indices):
|
| 175 |
-
if rows.numel() == 0:
|
| 176 |
-
continue
|
| 177 |
-
rows = rows.to(device)
|
| 178 |
-
cols = cols.to(device)
|
| 179 |
-
target_classes[i, rows] = targets[i]["labels"][cols].to(device)
|
| 180 |
-
target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float()
|
| 181 |
-
matched_box_pairs.append(box_mu[i, rows])
|
| 182 |
-
matched_logsig_pairs.append(box_log_sigma[i, rows])
|
| 183 |
-
matched_target_boxes.append(targets[i]["boxes"][cols].to(device))
|
| 184 |
-
|
| 185 |
-
cls_loss = focal_loss(
|
| 186 |
-
cls_logits.view(b * q, -1),
|
| 187 |
-
target_classes.view(-1),
|
| 188 |
-
alpha=focal_alpha,
|
| 189 |
-
gamma=focal_gamma,
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
if matched_box_pairs:
|
| 193 |
-
pred_box = torch.cat(matched_box_pairs, dim=0)
|
| 194 |
-
pred_logsig = torch.cat(matched_logsig_pairs, dim=0)
|
| 195 |
-
gt_box = torch.cat(matched_target_boxes, dim=0)
|
| 196 |
-
box_nll = gaussian_nll(pred_box, pred_logsig, gt_box)
|
| 197 |
-
giou_v = giou3d_approx(pred_box, gt_box)
|
| 198 |
-
giou_loss = (1.0 - giou_v).mean()
|
| 199 |
-
else:
|
| 200 |
-
box_nll = torch.zeros((), device=device)
|
| 201 |
-
giou_loss = torch.zeros((), device=device)
|
| 202 |
-
|
| 203 |
-
isdyn_loss = F.binary_cross_entropy_with_logits(
|
| 204 |
-
isdyn_logit, target_isdyn, reduction="mean"
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
return DetectionLossOutputs(
|
| 208 |
-
cls_loss=cls_loss,
|
| 209 |
-
box_nll=box_nll,
|
| 210 |
-
giou_loss=giou_loss,
|
| 211 |
-
isdyn_loss=isdyn_loss,
|
| 212 |
-
matched_indices=indices,
|
| 213 |
-
)
|
|
|
|
| 1 |
+
"""检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。
|
| 2 |
+
|
| 3 |
+
GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU,
|
| 4 |
+
``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。
|
| 5 |
+
此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from scipy.optimize import linear_sum_assignment
|
| 15 |
+
|
| 16 |
+
from .nll import gaussian_nll
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class DetectionLossOutputs:
|
| 21 |
+
"""检测损失分项与匹配结果。"""
|
| 22 |
+
|
| 23 |
+
cls_loss: torch.Tensor
|
| 24 |
+
box_nll: torch.Tensor
|
| 25 |
+
giou_loss: torch.Tensor
|
| 26 |
+
isdyn_loss: torch.Tensor
|
| 27 |
+
matched_indices: list[tuple[torch.Tensor, torch.Tensor]]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def focal_loss(
|
| 31 |
+
logits: torch.Tensor,
|
| 32 |
+
target: torch.Tensor,
|
| 33 |
+
alpha: float = 0.25,
|
| 34 |
+
gamma: float = 2.0,
|
| 35 |
+
reduction: str = "mean",
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
"""多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。"""
|
| 38 |
+
log_softmax = F.log_softmax(logits, dim=-1)
|
| 39 |
+
pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1)
|
| 40 |
+
ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1)
|
| 41 |
+
focal = alpha * (1 - pt).pow(gamma) * ce
|
| 42 |
+
if reduction == "mean":
|
| 43 |
+
return focal.mean()
|
| 44 |
+
if reduction == "sum":
|
| 45 |
+
return focal.sum()
|
| 46 |
+
return focal
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _bev_giou(
|
| 50 |
+
box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw)
|
| 51 |
+
box_b: torch.Tensor, # [..., 5]
|
| 52 |
+
) -> torch.Tensor:
|
| 53 |
+
"""BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转),
|
| 54 |
+
可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。
|
| 55 |
+
"""
|
| 56 |
+
cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3]
|
| 57 |
+
cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3]
|
| 58 |
+
a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2
|
| 59 |
+
a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2
|
| 60 |
+
b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2
|
| 61 |
+
b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2
|
| 62 |
+
inter_x1 = torch.max(a_x1, b_x1)
|
| 63 |
+
inter_y1 = torch.max(a_y1, b_y1)
|
| 64 |
+
inter_x2 = torch.min(a_x2, b_x2)
|
| 65 |
+
inter_y2 = torch.min(a_y2, b_y2)
|
| 66 |
+
inter_w = (inter_x2 - inter_x1).clamp_min(0)
|
| 67 |
+
inter_h = (inter_y2 - inter_y1).clamp_min(0)
|
| 68 |
+
inter = inter_w * inter_h
|
| 69 |
+
area_a = (l_a * w_a).clamp_min(0)
|
| 70 |
+
area_b = (l_b * w_b).clamp_min(0)
|
| 71 |
+
union = area_a + area_b - inter + 1e-6
|
| 72 |
+
iou = inter / union
|
| 73 |
+
# GIoU enclosure
|
| 74 |
+
enc_x1 = torch.min(a_x1, b_x1)
|
| 75 |
+
enc_y1 = torch.min(a_y1, b_y1)
|
| 76 |
+
enc_x2 = torch.max(a_x2, b_x2)
|
| 77 |
+
enc_y2 = torch.max(a_y2, b_y2)
|
| 78 |
+
enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6)
|
| 79 |
+
giou = iou - (enc_area - union) / enc_area
|
| 80 |
+
return giou
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
"""3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。"""
|
| 85 |
+
bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]])
|
| 86 |
+
z_a, h_a = box_a[..., 2], box_a[..., 5]
|
| 87 |
+
z_b, h_b = box_b[..., 2], box_b[..., 5]
|
| 88 |
+
a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2
|
| 89 |
+
b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2
|
| 90 |
+
inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0)
|
| 91 |
+
union_z = (h_a + h_b - inter_z).clamp_min(1e-6)
|
| 92 |
+
z_iou = inter_z / union_z
|
| 93 |
+
return bev * z_iou
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class HungarianMatcher:
|
| 97 |
+
"""DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。"""
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
cls_cost: float = 2.0,
|
| 102 |
+
l1_cost: float = 5.0,
|
| 103 |
+
giou_cost: float = 2.0,
|
| 104 |
+
) -> None:
|
| 105 |
+
self.cls_cost = cls_cost
|
| 106 |
+
self.l1_cost = l1_cost
|
| 107 |
+
self.giou_cost = giou_cost
|
| 108 |
+
|
| 109 |
+
@torch.no_grad()
|
| 110 |
+
def match(
|
| 111 |
+
self,
|
| 112 |
+
cls_logits: torch.Tensor, # [B, Q, C]
|
| 113 |
+
box_mu: torch.Tensor, # [B, Q, 7]
|
| 114 |
+
targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]}
|
| 115 |
+
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
| 116 |
+
b, q, c = cls_logits.shape
|
| 117 |
+
out = []
|
| 118 |
+
cls_probs = cls_logits.softmax(-1) # [B, Q, C]
|
| 119 |
+
|
| 120 |
+
for i in range(b):
|
| 121 |
+
tgt_labels = targets[i]["labels"]
|
| 122 |
+
tgt_boxes = targets[i]["boxes"]
|
| 123 |
+
n_tgt = tgt_labels.numel()
|
| 124 |
+
if n_tgt == 0:
|
| 125 |
+
out.append((
|
| 126 |
+
torch.empty(0, dtype=torch.long),
|
| 127 |
+
torch.empty(0, dtype=torch.long),
|
| 128 |
+
))
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
# cost_cls: 越大越好 → 取负
|
| 132 |
+
cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt]
|
| 133 |
+
cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt]
|
| 134 |
+
# giou3d: [Q, n_tgt]
|
| 135 |
+
qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1)
|
| 136 |
+
tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1)
|
| 137 |
+
cost_giou = -giou3d_approx(qa, tb)
|
| 138 |
+
|
| 139 |
+
cost = (
|
| 140 |
+
self.cls_cost * cost_cls
|
| 141 |
+
+ self.l1_cost * cost_l1
|
| 142 |
+
+ self.giou_cost * cost_giou
|
| 143 |
+
)
|
| 144 |
+
cost_np = cost.cpu().numpy()
|
| 145 |
+
row, col = linear_sum_assignment(cost_np)
|
| 146 |
+
out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long)))
|
| 147 |
+
return out
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def detection_losses(
|
| 151 |
+
cls_logits: torch.Tensor, # [B, Q, C]
|
| 152 |
+
box_mu: torch.Tensor, # [B, Q, 7]
|
| 153 |
+
box_log_sigma: torch.Tensor, # [B, Q, 7]
|
| 154 |
+
isdyn_logit: torch.Tensor, # [B, Q]
|
| 155 |
+
targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...}
|
| 156 |
+
matcher: HungarianMatcher,
|
| 157 |
+
num_classes: int,
|
| 158 |
+
background_class: int = 0,
|
| 159 |
+
focal_alpha: float = 0.25,
|
| 160 |
+
focal_gamma: float = 2.0,
|
| 161 |
+
) -> DetectionLossOutputs:
|
| 162 |
+
"""返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。"""
|
| 163 |
+
indices = matcher.match(cls_logits, box_mu, targets)
|
| 164 |
+
b, q, _ = box_mu.shape
|
| 165 |
+
device = box_mu.device
|
| 166 |
+
|
| 167 |
+
# 构造分类目标:所有 query 默认 background;匹配的填对应 label
|
| 168 |
+
target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device)
|
| 169 |
+
target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device)
|
| 170 |
+
matched_box_pairs = []
|
| 171 |
+
matched_logsig_pairs = []
|
| 172 |
+
matched_target_boxes = []
|
| 173 |
+
|
| 174 |
+
for i, (rows, cols) in enumerate(indices):
|
| 175 |
+
if rows.numel() == 0:
|
| 176 |
+
continue
|
| 177 |
+
rows = rows.to(device)
|
| 178 |
+
cols = cols.to(device)
|
| 179 |
+
target_classes[i, rows] = targets[i]["labels"][cols].to(device)
|
| 180 |
+
target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float()
|
| 181 |
+
matched_box_pairs.append(box_mu[i, rows])
|
| 182 |
+
matched_logsig_pairs.append(box_log_sigma[i, rows])
|
| 183 |
+
matched_target_boxes.append(targets[i]["boxes"][cols].to(device))
|
| 184 |
+
|
| 185 |
+
cls_loss = focal_loss(
|
| 186 |
+
cls_logits.view(b * q, -1),
|
| 187 |
+
target_classes.view(-1),
|
| 188 |
+
alpha=focal_alpha,
|
| 189 |
+
gamma=focal_gamma,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if matched_box_pairs:
|
| 193 |
+
pred_box = torch.cat(matched_box_pairs, dim=0)
|
| 194 |
+
pred_logsig = torch.cat(matched_logsig_pairs, dim=0)
|
| 195 |
+
gt_box = torch.cat(matched_target_boxes, dim=0)
|
| 196 |
+
box_nll = gaussian_nll(pred_box, pred_logsig, gt_box)
|
| 197 |
+
giou_v = giou3d_approx(pred_box, gt_box)
|
| 198 |
+
giou_loss = (1.0 - giou_v).mean()
|
| 199 |
+
else:
|
| 200 |
+
box_nll = torch.zeros((), device=device)
|
| 201 |
+
giou_loss = torch.zeros((), device=device)
|
| 202 |
+
|
| 203 |
+
isdyn_loss = F.binary_cross_entropy_with_logits(
|
| 204 |
+
isdyn_logit, target_isdyn, reduction="mean"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
return DetectionLossOutputs(
|
| 208 |
+
cls_loss=cls_loss,
|
| 209 |
+
box_nll=box_nll,
|
| 210 |
+
giou_loss=giou_loss,
|
| 211 |
+
isdyn_loss=isdyn_loss,
|
| 212 |
+
matched_indices=indices,
|
| 213 |
+
)
|
src/wjad/losses/moe_aux.py
CHANGED
|
@@ -1,33 +1,33 @@
|
|
| 1 |
-
"""MoE 路由的负载均衡 + 边界正则。
|
| 2 |
-
|
| 3 |
-
- 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。
|
| 4 |
-
这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量;
|
| 5 |
-
也加入 ``mean(probs).std()`` 跨专家的均匀性度量。
|
| 6 |
-
- 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from ..modules.moe import MoEStats
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def moe_load_balance_and_boundary(
|
| 17 |
-
stats_list: list[MoEStats],
|
| 18 |
-
load_balance_weight: float = 1.0,
|
| 19 |
-
boundary_weight: float = 1.0,
|
| 20 |
-
) -> torch.Tensor:
|
| 21 |
-
if not stats_list:
|
| 22 |
-
return torch.zeros((), device="cpu")
|
| 23 |
-
|
| 24 |
-
device = stats_list[0].logits.device
|
| 25 |
-
total = torch.zeros((), device=device)
|
| 26 |
-
for stats in stats_list:
|
| 27 |
-
# boundary:logits^2 的均值
|
| 28 |
-
boundary = stats.logits.pow(2).mean()
|
| 29 |
-
# load balance:跨样本的专家选择频率方差
|
| 30 |
-
avg_per_expert = stats.probs.mean(dim=0) # [num_routed]
|
| 31 |
-
load = avg_per_expert.std()
|
| 32 |
-
total = total + boundary_weight * boundary + load_balance_weight * load
|
| 33 |
-
return total / max(len(stats_list), 1)
|
|
|
|
| 1 |
+
"""MoE 路由的负载均衡 + 边界正则。
|
| 2 |
+
|
| 3 |
+
- 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。
|
| 4 |
+
这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量;
|
| 5 |
+
也加入 ``mean(probs).std()`` 跨专家的均匀性度量。
|
| 6 |
+
- 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from ..modules.moe import MoEStats
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def moe_load_balance_and_boundary(
|
| 17 |
+
stats_list: list[MoEStats],
|
| 18 |
+
load_balance_weight: float = 1.0,
|
| 19 |
+
boundary_weight: float = 1.0,
|
| 20 |
+
) -> torch.Tensor:
|
| 21 |
+
if not stats_list:
|
| 22 |
+
return torch.zeros((), device="cpu")
|
| 23 |
+
|
| 24 |
+
device = stats_list[0].logits.device
|
| 25 |
+
total = torch.zeros((), device=device)
|
| 26 |
+
for stats in stats_list:
|
| 27 |
+
# boundary:logits^2 的均值
|
| 28 |
+
boundary = stats.logits.pow(2).mean()
|
| 29 |
+
# load balance:跨样本的专家选择频率方差
|
| 30 |
+
avg_per_expert = stats.probs.mean(dim=0) # [num_routed]
|
| 31 |
+
load = avg_per_expert.std()
|
| 32 |
+
total = total + boundary_weight * boundary + load_balance_weight * load
|
| 33 |
+
return total / max(len(stats_list), 1)
|
src/wjad/losses/nll.py
CHANGED
|
@@ -1,47 +1,47 @@
|
|
| 1 |
-
"""高斯 NLL 置信度损失。
|
| 2 |
-
|
| 3 |
-
公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)``
|
| 4 |
-
为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。
|
| 5 |
-
|
| 6 |
-
支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def gaussian_nll(
|
| 15 |
-
mu: torch.Tensor,
|
| 16 |
-
log_sigma: torch.Tensor,
|
| 17 |
-
target: torch.Tensor,
|
| 18 |
-
valid_mask: torch.Tensor | None = None,
|
| 19 |
-
reduction: str = "mean",
|
| 20 |
-
) -> torch.Tensor:
|
| 21 |
-
"""高斯负对数似然。
|
| 22 |
-
|
| 23 |
-
所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。
|
| 24 |
-
"""
|
| 25 |
-
diff = target - mu
|
| 26 |
-
inv_sigma = torch.exp(-log_sigma)
|
| 27 |
-
nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma
|
| 28 |
-
|
| 29 |
-
if valid_mask is not None:
|
| 30 |
-
# broadcast 到 nll 的 shape
|
| 31 |
-
while valid_mask.dim() < nll.dim():
|
| 32 |
-
valid_mask = valid_mask.unsqueeze(-1)
|
| 33 |
-
valid_mask = valid_mask.to(nll.dtype)
|
| 34 |
-
nll = nll * valid_mask
|
| 35 |
-
if reduction == "mean":
|
| 36 |
-
denom = valid_mask.sum().clamp_min(1.0)
|
| 37 |
-
return nll.sum() / denom
|
| 38 |
-
elif reduction == "sum":
|
| 39 |
-
return nll.sum()
|
| 40 |
-
else:
|
| 41 |
-
return nll
|
| 42 |
-
|
| 43 |
-
if reduction == "mean":
|
| 44 |
-
return nll.mean()
|
| 45 |
-
if reduction == "sum":
|
| 46 |
-
return nll.sum()
|
| 47 |
-
return nll
|
|
|
|
| 1 |
+
"""高斯 NLL 置信度损失。
|
| 2 |
+
|
| 3 |
+
公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)``
|
| 4 |
+
为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。
|
| 5 |
+
|
| 6 |
+
支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def gaussian_nll(
|
| 15 |
+
mu: torch.Tensor,
|
| 16 |
+
log_sigma: torch.Tensor,
|
| 17 |
+
target: torch.Tensor,
|
| 18 |
+
valid_mask: torch.Tensor | None = None,
|
| 19 |
+
reduction: str = "mean",
|
| 20 |
+
) -> torch.Tensor:
|
| 21 |
+
"""高斯负对数似然。
|
| 22 |
+
|
| 23 |
+
所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。
|
| 24 |
+
"""
|
| 25 |
+
diff = target - mu
|
| 26 |
+
inv_sigma = torch.exp(-log_sigma)
|
| 27 |
+
nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma
|
| 28 |
+
|
| 29 |
+
if valid_mask is not None:
|
| 30 |
+
# broadcast 到 nll 的 shape
|
| 31 |
+
while valid_mask.dim() < nll.dim():
|
| 32 |
+
valid_mask = valid_mask.unsqueeze(-1)
|
| 33 |
+
valid_mask = valid_mask.to(nll.dtype)
|
| 34 |
+
nll = nll * valid_mask
|
| 35 |
+
if reduction == "mean":
|
| 36 |
+
denom = valid_mask.sum().clamp_min(1.0)
|
| 37 |
+
return nll.sum() / denom
|
| 38 |
+
elif reduction == "sum":
|
| 39 |
+
return nll.sum()
|
| 40 |
+
else:
|
| 41 |
+
return nll
|
| 42 |
+
|
| 43 |
+
if reduction == "mean":
|
| 44 |
+
return nll.mean()
|
| 45 |
+
if reduction == "sum":
|
| 46 |
+
return nll.sum()
|
| 47 |
+
return nll
|
src/wjad/losses/trajectory.py
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
-
"""动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .nll import gaussian_nll
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def object_traj_nll(
|
| 11 |
-
traj_mu: torch.Tensor, # [B, Q, T, 3]
|
| 12 |
-
traj_log_sigma: torch.Tensor, # [B, Q, T, 3]
|
| 13 |
-
matched_indices: list[tuple[torch.Tensor, torch.Tensor]],
|
| 14 |
-
targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]}
|
| 15 |
-
) -> torch.Tensor:
|
| 16 |
-
"""对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。
|
| 17 |
-
|
| 18 |
-
返回标量 loss。
|
| 19 |
-
"""
|
| 20 |
-
device = traj_mu.device
|
| 21 |
-
b = traj_mu.shape[0]
|
| 22 |
-
losses = []
|
| 23 |
-
for i in range(b):
|
| 24 |
-
rows, cols = matched_indices[i]
|
| 25 |
-
if rows.numel() == 0:
|
| 26 |
-
continue
|
| 27 |
-
rows = rows.to(device)
|
| 28 |
-
cols = cols.to(device)
|
| 29 |
-
is_dyn = targets[i]["is_dynamic"][cols].to(device).bool()
|
| 30 |
-
if not is_dyn.any():
|
| 31 |
-
continue
|
| 32 |
-
sel_rows = rows[is_dyn]
|
| 33 |
-
sel_cols = cols[is_dyn]
|
| 34 |
-
pred_mu = traj_mu[i, sel_rows] # [n, T, 3]
|
| 35 |
-
pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3]
|
| 36 |
-
gt_traj = targets[i]["future_traj"][sel_cols].to(device)
|
| 37 |
-
valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T]
|
| 38 |
-
# 在 (T, 3) 维上算 NLL,valid mask 只到 T 维
|
| 39 |
-
nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid)
|
| 40 |
-
losses.append(nll)
|
| 41 |
-
if not losses:
|
| 42 |
-
return torch.zeros((), device=device)
|
| 43 |
-
return torch.stack(losses).mean()
|
|
|
|
| 1 |
+
"""动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .nll import gaussian_nll
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def object_traj_nll(
|
| 11 |
+
traj_mu: torch.Tensor, # [B, Q, T, 3]
|
| 12 |
+
traj_log_sigma: torch.Tensor, # [B, Q, T, 3]
|
| 13 |
+
matched_indices: list[tuple[torch.Tensor, torch.Tensor]],
|
| 14 |
+
targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]}
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
+
"""对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。
|
| 17 |
+
|
| 18 |
+
返回标量 loss。
|
| 19 |
+
"""
|
| 20 |
+
device = traj_mu.device
|
| 21 |
+
b = traj_mu.shape[0]
|
| 22 |
+
losses = []
|
| 23 |
+
for i in range(b):
|
| 24 |
+
rows, cols = matched_indices[i]
|
| 25 |
+
if rows.numel() == 0:
|
| 26 |
+
continue
|
| 27 |
+
rows = rows.to(device)
|
| 28 |
+
cols = cols.to(device)
|
| 29 |
+
is_dyn = targets[i]["is_dynamic"][cols].to(device).bool()
|
| 30 |
+
if not is_dyn.any():
|
| 31 |
+
continue
|
| 32 |
+
sel_rows = rows[is_dyn]
|
| 33 |
+
sel_cols = cols[is_dyn]
|
| 34 |
+
pred_mu = traj_mu[i, sel_rows] # [n, T, 3]
|
| 35 |
+
pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3]
|
| 36 |
+
gt_traj = targets[i]["future_traj"][sel_cols].to(device)
|
| 37 |
+
valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T]
|
| 38 |
+
# 在 (T, 3) 维上算 NLL,valid mask 只到 T 维
|
| 39 |
+
nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid)
|
| 40 |
+
losses.append(nll)
|
| 41 |
+
if not losses:
|
| 42 |
+
return torch.zeros((), device=device)
|
| 43 |
+
return torch.stack(losses).mean()
|
src/wjad/model.py
CHANGED
|
@@ -1,289 +1,289 @@
|
|
| 1 |
-
"""端到端自动驾驶模型 E2EAVModel。
|
| 2 |
-
|
| 3 |
-
forward 流程
|
| 4 |
-
1. ``DINOv3`` 提取 8 帧 patch 特征。
|
| 5 |
-
2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV,
|
| 6 |
-
输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。
|
| 7 |
-
3. 用 corrected_intr / corrected_extr / corrected_ego 计算
|
| 8 |
-
- 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。
|
| 9 |
-
- 8 个 ego token(symlog 后线性投影)。
|
| 10 |
-
4. 2×2×2 时空压缩 -> 1536 视觉 token。
|
| 11 |
-
5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。
|
| 12 |
-
非视觉切片各自加可学习 PE。
|
| 13 |
-
6. 18 层主干(仅视觉切片应用 3D RoPE)。
|
| 14 |
-
7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import torch.nn as nn
|
| 23 |
-
import torch.nn.functional as F
|
| 24 |
-
|
| 25 |
-
from .backbone import Backbone, BackboneOutput
|
| 26 |
-
from .calibration import OnlineCalibration, CalibrationOutput
|
| 27 |
-
from .encoders import DINOv3Wrapper
|
| 28 |
-
from .heads import (
|
| 29 |
-
ControlHead,
|
| 30 |
-
ControlOutput,
|
| 31 |
-
DetectionTrajHead,
|
| 32 |
-
DetectionTrajOutput,
|
| 33 |
-
)
|
| 34 |
-
from .modules.learned_pe import LearnedTokenPE
|
| 35 |
-
from .modules.normalization import symlog
|
| 36 |
-
from .modules.pos_encoding import RoPE3D
|
| 37 |
-
from .modules.rays import compute_ego_rays
|
| 38 |
-
from .modules.temporal_compress import TemporalCompress2x2x2
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
@dataclass
|
| 42 |
-
class E2EOutput:
|
| 43 |
-
"""模型完整输出。"""
|
| 44 |
-
|
| 45 |
-
detection: DetectionTrajOutput
|
| 46 |
-
control: ControlOutput
|
| 47 |
-
backbone_out: BackboneOutput
|
| 48 |
-
calibration: CalibrationOutput
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class E2EAVModel(nn.Module):
|
| 52 |
-
def __init__(
|
| 53 |
-
self,
|
| 54 |
-
dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
|
| 55 |
-
backbone_dim: int = 768,
|
| 56 |
-
num_heads: int = 12,
|
| 57 |
-
num_dense_layers: int = 9,
|
| 58 |
-
num_moe_layers: int = 9,
|
| 59 |
-
num_routed_experts: int = 7,
|
| 60 |
-
num_shared_experts: int = 1,
|
| 61 |
-
topk_experts: int = 3,
|
| 62 |
-
ffn_mult: int = 4,
|
| 63 |
-
# token 数量
|
| 64 |
-
num_history_frames: int = 8,
|
| 65 |
-
num_detection_tokens: int = 1024,
|
| 66 |
-
num_control_tokens: int = 24,
|
| 67 |
-
num_ego_tokens: int = 8,
|
| 68 |
-
num_extra_tokens: int = 256,
|
| 69 |
-
# 输入分辨率
|
| 70 |
-
image_h: int = 384,
|
| 71 |
-
image_w: int = 1024,
|
| 72 |
-
patch_size: int = 16,
|
| 73 |
-
# 头超参
|
| 74 |
-
num_classes: int = 22,
|
| 75 |
-
traj_horizon: int = 24,
|
| 76 |
-
det_head_hidden: int = 384,
|
| 77 |
-
ctrl_head_hidden: int = 384,
|
| 78 |
-
# 校准
|
| 79 |
-
calib_dim: int = 256,
|
| 80 |
-
calib_num_query: int = 256,
|
| 81 |
-
calib_num_blocks: int = 2,
|
| 82 |
-
calib_num_self_per_block: int = 2,
|
| 83 |
-
calib_num_heads: int = 8,
|
| 84 |
-
calib_residual_range: float = 0.1,
|
| 85 |
-
calib_intr_dim: int = 11,
|
| 86 |
-
# DINOv3
|
| 87 |
-
freeze_dinov3: bool = True,
|
| 88 |
-
attn_implementation: str = "sdpa",
|
| 89 |
-
) -> None:
|
| 90 |
-
super().__init__()
|
| 91 |
-
self.image_h = image_h
|
| 92 |
-
self.image_w = image_w
|
| 93 |
-
self.patch_size = patch_size
|
| 94 |
-
self.num_history = num_history_frames
|
| 95 |
-
self.num_det = num_detection_tokens
|
| 96 |
-
self.num_ctrl = num_control_tokens
|
| 97 |
-
self.num_ego = num_ego_tokens
|
| 98 |
-
self.num_extra = num_extra_tokens
|
| 99 |
-
|
| 100 |
-
# === 1) DINOv3 ===
|
| 101 |
-
self.dinov3 = DINOv3Wrapper(
|
| 102 |
-
pretrained_path=dinov3_path,
|
| 103 |
-
attn_implementation=attn_implementation,
|
| 104 |
-
freeze=freeze_dinov3,
|
| 105 |
-
)
|
| 106 |
-
dino_dim = self.dinov3.hidden_size
|
| 107 |
-
|
| 108 |
-
# === 2) 在线校准 ===
|
| 109 |
-
self.calib = OnlineCalibration(
|
| 110 |
-
dino_dim=dino_dim,
|
| 111 |
-
hidden_dim=calib_dim,
|
| 112 |
-
num_query_tokens=calib_num_query,
|
| 113 |
-
num_blocks=calib_num_blocks,
|
| 114 |
-
num_self_attn_per_block=calib_num_self_per_block,
|
| 115 |
-
num_heads=calib_num_heads,
|
| 116 |
-
residual_range=calib_residual_range,
|
| 117 |
-
num_history_frames=num_history_frames,
|
| 118 |
-
intr_dim=calib_intr_dim,
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
# === 3) 时空压缩 ===
|
| 122 |
-
self.compress = TemporalCompress2x2x2(dim=dino_dim)
|
| 123 |
-
# patch 网格大小(必须能被 2 整除)
|
| 124 |
-
self.gh = image_h // patch_size
|
| 125 |
-
self.gw = image_w // patch_size
|
| 126 |
-
|
| 127 |
-
# === 4) 各类 token + 可学习 PE ===
|
| 128 |
-
self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim
|
| 129 |
-
self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
|
| 130 |
-
nn.init.trunc_normal_(self.det_tokens, std=0.02)
|
| 131 |
-
self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
|
| 132 |
-
nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
|
| 133 |
-
self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
|
| 134 |
-
nn.init.trunc_normal_(self.extra_tokens, std=0.02)
|
| 135 |
-
|
| 136 |
-
self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
|
| 137 |
-
self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
|
| 138 |
-
self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
|
| 139 |
-
self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)
|
| 140 |
-
|
| 141 |
-
# === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)===
|
| 142 |
-
self.rope = RoPE3D(
|
| 143 |
-
num_heads=num_heads,
|
| 144 |
-
head_dim=backbone_dim // num_heads,
|
| 145 |
-
time_size=num_history_frames // 2,
|
| 146 |
-
height_size=self.gh // 2,
|
| 147 |
-
width_size=self.gw // 2,
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
# === 6) 主干 18 层 ===
|
| 151 |
-
self.backbone = Backbone(
|
| 152 |
-
dim=backbone_dim,
|
| 153 |
-
num_heads=num_heads,
|
| 154 |
-
ffn_mult=ffn_mult,
|
| 155 |
-
num_dense_layers=num_dense_layers,
|
| 156 |
-
num_moe_layers=num_moe_layers,
|
| 157 |
-
num_routed=num_routed_experts,
|
| 158 |
-
num_shared=num_shared_experts,
|
| 159 |
-
topk=topk_experts,
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
# === 7) 头 ===
|
| 163 |
-
self.det_traj_head = DetectionTrajHead(
|
| 164 |
-
in_dim=backbone_dim,
|
| 165 |
-
hidden_size=det_head_hidden,
|
| 166 |
-
num_classes=num_classes,
|
| 167 |
-
traj_horizon=traj_horizon,
|
| 168 |
-
)
|
| 169 |
-
self.ctrl_head = ControlHead(
|
| 170 |
-
in_dim=backbone_dim,
|
| 171 |
-
hidden_size=ctrl_head_hidden,
|
| 172 |
-
num_traj_tokens=12,
|
| 173 |
-
num_action_tokens=num_control_tokens - 12,
|
| 174 |
-
ego_traj_horizon=traj_horizon,
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# ---------- 工具 ----------
|
| 178 |
-
|
| 179 |
-
@property
|
| 180 |
-
def num_visual_tokens(self) -> int:
|
| 181 |
-
# 2×2×2 压缩后
|
| 182 |
-
return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)
|
| 183 |
-
|
| 184 |
-
def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
|
| 185 |
-
"""``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。"""
|
| 186 |
-
return self.ego_proj(symlog(ego_6d_corrected))
|
| 187 |
-
|
| 188 |
-
def _build_visual_rays(
|
| 189 |
-
self,
|
| 190 |
-
intr_corrected: torch.Tensor, # [B, calib_intr_dim]
|
| 191 |
-
extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle
|
| 192 |
-
compressed_thw: tuple[int, int, int],
|
| 193 |
-
) -> torch.Tensor:
|
| 194 |
-
"""计算压缩后视觉 token 网格的射线方向。
|
| 195 |
-
|
| 196 |
-
在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 +
|
| 197 |
-
2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似,
|
| 198 |
-
所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是
|
| 199 |
-
rigid 的相机系;自车运动差异会通过 ego token 传递)。
|
| 200 |
-
"""
|
| 201 |
-
b = intr_corrected.shape[0]
|
| 202 |
-
t_, h_, w_ = compressed_thw
|
| 203 |
-
rays_grid = compute_ego_rays(
|
| 204 |
-
intr_vec=intr_corrected,
|
| 205 |
-
cam2vehicle=extr_corrected_se3,
|
| 206 |
-
height=self.image_h,
|
| 207 |
-
width=self.image_w,
|
| 208 |
-
grid_h=h_,
|
| 209 |
-
grid_w=w_,
|
| 210 |
-
device=intr_corrected.device,
|
| 211 |
-
dtype=intr_corrected.dtype,
|
| 212 |
-
) # [B, h_, w_, 3]
|
| 213 |
-
# 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3]
|
| 214 |
-
rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
|
| 215 |
-
rays = rays.reshape(b, t_ * h_ * w_, 3)
|
| 216 |
-
return rays
|
| 217 |
-
|
| 218 |
-
# ---------- 前向 ----------
|
| 219 |
-
|
| 220 |
-
def forward(
|
| 221 |
-
self,
|
| 222 |
-
images: torch.Tensor, # [B, T=8, 3, H, W]
|
| 223 |
-
ego_6d_raw: torch.Tensor, # [B, 8, 6]
|
| 224 |
-
intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致
|
| 225 |
-
extr_6d_raw: torch.Tensor, # [B, 6]
|
| 226 |
-
) -> E2EOutput:
|
| 227 |
-
b, t, _, h, w = images.shape
|
| 228 |
-
assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}"
|
| 229 |
-
|
| 230 |
-
# 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
|
| 231 |
-
dino_feats = self.dinov3(images)
|
| 232 |
-
|
| 233 |
-
# 2) 校准(symlog 空间残差 + symexp 还原)
|
| 234 |
-
calib_out: CalibrationOutput = self.calib(
|
| 235 |
-
dino_feats=dino_feats,
|
| 236 |
-
ego_raw=ego_6d_raw,
|
| 237 |
-
intr_raw=intr_raw,
|
| 238 |
-
extr_raw=extr_6d_raw,
|
| 239 |
-
)
|
| 240 |
-
corrected_ego = calib_out.corrected_ego
|
| 241 |
-
corrected_intr = calib_out.corrected_intr
|
| 242 |
-
corrected_extr_6d = calib_out.corrected_extr
|
| 243 |
-
|
| 244 |
-
# 3) 把 corrected_extr 6D 转成 4x4
|
| 245 |
-
from .data.se3 import six_d_to_matrix
|
| 246 |
-
cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4]
|
| 247 |
-
|
| 248 |
-
# 4) 2x2x2 时空压缩
|
| 249 |
-
compressed, thw = self.compress(dino_feats) # [B, N_v, D]
|
| 250 |
-
n_v = compressed.shape[1]
|
| 251 |
-
|
| 252 |
-
# 5) 视觉射线(用 corrected_intr / corrected_extr)
|
| 253 |
-
rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
|
| 254 |
-
rope_cos, rope_sin = self.rope.compute_freqs(rays)
|
| 255 |
-
|
| 256 |
-
# 6) 构造非视觉 token
|
| 257 |
-
ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D]
|
| 258 |
-
det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 259 |
-
ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 260 |
-
extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 261 |
-
|
| 262 |
-
ego_tok = self.ego_pe(ego_tok)
|
| 263 |
-
det_tok = self.det_pe(det_tok)
|
| 264 |
-
ctrl_tok = self.ctrl_pe(ctrl_tok)
|
| 265 |
-
extra_tok = self.extra_pe(extra_tok)
|
| 266 |
-
|
| 267 |
-
# 7) 拼接序列:[vision | ego | det | ctrl | extra]
|
| 268 |
-
seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
|
| 269 |
-
visual_slice = (0, n_v)
|
| 270 |
-
|
| 271 |
-
# 8) 主干
|
| 272 |
-
bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 273 |
-
|
| 274 |
-
# 9) 切片送入头
|
| 275 |
-
offset_det = n_v + self.num_ego
|
| 276 |
-
offset_ctrl = offset_det + self.num_det
|
| 277 |
-
|
| 278 |
-
det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
|
| 279 |
-
ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]
|
| 280 |
-
|
| 281 |
-
det_out = self.det_traj_head(det_feats)
|
| 282 |
-
ctrl_out = self.ctrl_head(ctrl_feats)
|
| 283 |
-
|
| 284 |
-
return E2EOutput(
|
| 285 |
-
detection=det_out,
|
| 286 |
-
control=ctrl_out,
|
| 287 |
-
backbone_out=bb_out,
|
| 288 |
-
calibration=calib_out,
|
| 289 |
-
)
|
|
|
|
| 1 |
+
"""端到端自动驾驶模型 E2EAVModel。
|
| 2 |
+
|
| 3 |
+
forward 流程
|
| 4 |
+
1. ``DINOv3`` 提取 8 帧 patch 特征。
|
| 5 |
+
2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV,
|
| 6 |
+
输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。
|
| 7 |
+
3. 用 corrected_intr / corrected_extr / corrected_ego 计算
|
| 8 |
+
- 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。
|
| 9 |
+
- 8 个 ego token(symlog 后线性投影)。
|
| 10 |
+
4. 2×2×2 时空压缩 -> 1536 视觉 token。
|
| 11 |
+
5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。
|
| 12 |
+
非视觉切片各自加可学习 PE。
|
| 13 |
+
6. 18 层主干(仅视觉切片应用 3D RoPE)。
|
| 14 |
+
7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
from .backbone import Backbone, BackboneOutput
|
| 26 |
+
from .calibration import OnlineCalibration, CalibrationOutput
|
| 27 |
+
from .encoders import DINOv3Wrapper
|
| 28 |
+
from .heads import (
|
| 29 |
+
ControlHead,
|
| 30 |
+
ControlOutput,
|
| 31 |
+
DetectionTrajHead,
|
| 32 |
+
DetectionTrajOutput,
|
| 33 |
+
)
|
| 34 |
+
from .modules.learned_pe import LearnedTokenPE
|
| 35 |
+
from .modules.normalization import symlog
|
| 36 |
+
from .modules.pos_encoding import RoPE3D
|
| 37 |
+
from .modules.rays import compute_ego_rays
|
| 38 |
+
from .modules.temporal_compress import TemporalCompress2x2x2
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class E2EOutput:
|
| 43 |
+
"""模型完整输出。"""
|
| 44 |
+
|
| 45 |
+
detection: DetectionTrajOutput
|
| 46 |
+
control: ControlOutput
|
| 47 |
+
backbone_out: BackboneOutput
|
| 48 |
+
calibration: CalibrationOutput
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class E2EAVModel(nn.Module):
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
|
| 55 |
+
backbone_dim: int = 768,
|
| 56 |
+
num_heads: int = 12,
|
| 57 |
+
num_dense_layers: int = 9,
|
| 58 |
+
num_moe_layers: int = 9,
|
| 59 |
+
num_routed_experts: int = 7,
|
| 60 |
+
num_shared_experts: int = 1,
|
| 61 |
+
topk_experts: int = 3,
|
| 62 |
+
ffn_mult: int = 4,
|
| 63 |
+
# token 数量
|
| 64 |
+
num_history_frames: int = 8,
|
| 65 |
+
num_detection_tokens: int = 1024,
|
| 66 |
+
num_control_tokens: int = 24,
|
| 67 |
+
num_ego_tokens: int = 8,
|
| 68 |
+
num_extra_tokens: int = 256,
|
| 69 |
+
# 输入分辨率
|
| 70 |
+
image_h: int = 384,
|
| 71 |
+
image_w: int = 1024,
|
| 72 |
+
patch_size: int = 16,
|
| 73 |
+
# 头超参
|
| 74 |
+
num_classes: int = 22,
|
| 75 |
+
traj_horizon: int = 24,
|
| 76 |
+
det_head_hidden: int = 384,
|
| 77 |
+
ctrl_head_hidden: int = 384,
|
| 78 |
+
# 校准
|
| 79 |
+
calib_dim: int = 256,
|
| 80 |
+
calib_num_query: int = 256,
|
| 81 |
+
calib_num_blocks: int = 2,
|
| 82 |
+
calib_num_self_per_block: int = 2,
|
| 83 |
+
calib_num_heads: int = 8,
|
| 84 |
+
calib_residual_range: float = 0.1,
|
| 85 |
+
calib_intr_dim: int = 11,
|
| 86 |
+
# DINOv3
|
| 87 |
+
freeze_dinov3: bool = True,
|
| 88 |
+
attn_implementation: str = "sdpa",
|
| 89 |
+
) -> None:
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.image_h = image_h
|
| 92 |
+
self.image_w = image_w
|
| 93 |
+
self.patch_size = patch_size
|
| 94 |
+
self.num_history = num_history_frames
|
| 95 |
+
self.num_det = num_detection_tokens
|
| 96 |
+
self.num_ctrl = num_control_tokens
|
| 97 |
+
self.num_ego = num_ego_tokens
|
| 98 |
+
self.num_extra = num_extra_tokens
|
| 99 |
+
|
| 100 |
+
# === 1) DINOv3 ===
|
| 101 |
+
self.dinov3 = DINOv3Wrapper(
|
| 102 |
+
pretrained_path=dinov3_path,
|
| 103 |
+
attn_implementation=attn_implementation,
|
| 104 |
+
freeze=freeze_dinov3,
|
| 105 |
+
)
|
| 106 |
+
dino_dim = self.dinov3.hidden_size
|
| 107 |
+
|
| 108 |
+
# === 2) 在线校准 ===
|
| 109 |
+
self.calib = OnlineCalibration(
|
| 110 |
+
dino_dim=dino_dim,
|
| 111 |
+
hidden_dim=calib_dim,
|
| 112 |
+
num_query_tokens=calib_num_query,
|
| 113 |
+
num_blocks=calib_num_blocks,
|
| 114 |
+
num_self_attn_per_block=calib_num_self_per_block,
|
| 115 |
+
num_heads=calib_num_heads,
|
| 116 |
+
residual_range=calib_residual_range,
|
| 117 |
+
num_history_frames=num_history_frames,
|
| 118 |
+
intr_dim=calib_intr_dim,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# === 3) 时空压缩 ===
|
| 122 |
+
self.compress = TemporalCompress2x2x2(dim=dino_dim)
|
| 123 |
+
# patch 网格大小(必须能被 2 整除)
|
| 124 |
+
self.gh = image_h // patch_size
|
| 125 |
+
self.gw = image_w // patch_size
|
| 126 |
+
|
| 127 |
+
# === 4) 各类 token + 可学习 PE ===
|
| 128 |
+
self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim
|
| 129 |
+
self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
|
| 130 |
+
nn.init.trunc_normal_(self.det_tokens, std=0.02)
|
| 131 |
+
self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
|
| 132 |
+
nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
|
| 133 |
+
self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
|
| 134 |
+
nn.init.trunc_normal_(self.extra_tokens, std=0.02)
|
| 135 |
+
|
| 136 |
+
self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
|
| 137 |
+
self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
|
| 138 |
+
self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
|
| 139 |
+
self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)
|
| 140 |
+
|
| 141 |
+
# === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)===
|
| 142 |
+
self.rope = RoPE3D(
|
| 143 |
+
num_heads=num_heads,
|
| 144 |
+
head_dim=backbone_dim // num_heads,
|
| 145 |
+
time_size=num_history_frames // 2,
|
| 146 |
+
height_size=self.gh // 2,
|
| 147 |
+
width_size=self.gw // 2,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# === 6) 主干 18 层 ===
|
| 151 |
+
self.backbone = Backbone(
|
| 152 |
+
dim=backbone_dim,
|
| 153 |
+
num_heads=num_heads,
|
| 154 |
+
ffn_mult=ffn_mult,
|
| 155 |
+
num_dense_layers=num_dense_layers,
|
| 156 |
+
num_moe_layers=num_moe_layers,
|
| 157 |
+
num_routed=num_routed_experts,
|
| 158 |
+
num_shared=num_shared_experts,
|
| 159 |
+
topk=topk_experts,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# === 7) 头 ===
|
| 163 |
+
self.det_traj_head = DetectionTrajHead(
|
| 164 |
+
in_dim=backbone_dim,
|
| 165 |
+
hidden_size=det_head_hidden,
|
| 166 |
+
num_classes=num_classes,
|
| 167 |
+
traj_horizon=traj_horizon,
|
| 168 |
+
)
|
| 169 |
+
self.ctrl_head = ControlHead(
|
| 170 |
+
in_dim=backbone_dim,
|
| 171 |
+
hidden_size=ctrl_head_hidden,
|
| 172 |
+
num_traj_tokens=12,
|
| 173 |
+
num_action_tokens=num_control_tokens - 12,
|
| 174 |
+
ego_traj_horizon=traj_horizon,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# ---------- 工具 ----------
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def num_visual_tokens(self) -> int:
|
| 181 |
+
# 2×2×2 压缩后
|
| 182 |
+
return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)
|
| 183 |
+
|
| 184 |
+
def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
|
| 185 |
+
"""``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。"""
|
| 186 |
+
return self.ego_proj(symlog(ego_6d_corrected))
|
| 187 |
+
|
| 188 |
+
def _build_visual_rays(
|
| 189 |
+
self,
|
| 190 |
+
intr_corrected: torch.Tensor, # [B, calib_intr_dim]
|
| 191 |
+
extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle
|
| 192 |
+
compressed_thw: tuple[int, int, int],
|
| 193 |
+
) -> torch.Tensor:
|
| 194 |
+
"""计算压缩后视觉 token 网格的射线方向。
|
| 195 |
+
|
| 196 |
+
在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 +
|
| 197 |
+
2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似,
|
| 198 |
+
所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是
|
| 199 |
+
rigid 的相机系;自车运动差异会通过 ego token 传递)。
|
| 200 |
+
"""
|
| 201 |
+
b = intr_corrected.shape[0]
|
| 202 |
+
t_, h_, w_ = compressed_thw
|
| 203 |
+
rays_grid = compute_ego_rays(
|
| 204 |
+
intr_vec=intr_corrected,
|
| 205 |
+
cam2vehicle=extr_corrected_se3,
|
| 206 |
+
height=self.image_h,
|
| 207 |
+
width=self.image_w,
|
| 208 |
+
grid_h=h_,
|
| 209 |
+
grid_w=w_,
|
| 210 |
+
device=intr_corrected.device,
|
| 211 |
+
dtype=intr_corrected.dtype,
|
| 212 |
+
) # [B, h_, w_, 3]
|
| 213 |
+
# 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3]
|
| 214 |
+
rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
|
| 215 |
+
rays = rays.reshape(b, t_ * h_ * w_, 3)
|
| 216 |
+
return rays
|
| 217 |
+
|
| 218 |
+
# ---------- 前向 ----------
|
| 219 |
+
|
| 220 |
+
def forward(
|
| 221 |
+
self,
|
| 222 |
+
images: torch.Tensor, # [B, T=8, 3, H, W]
|
| 223 |
+
ego_6d_raw: torch.Tensor, # [B, 8, 6]
|
| 224 |
+
intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致
|
| 225 |
+
extr_6d_raw: torch.Tensor, # [B, 6]
|
| 226 |
+
) -> E2EOutput:
|
| 227 |
+
b, t, _, h, w = images.shape
|
| 228 |
+
assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}"
|
| 229 |
+
|
| 230 |
+
# 1) DINOv3 patch tokens [B, T, gh, gw, D_dino]
|
| 231 |
+
dino_feats = self.dinov3(images)
|
| 232 |
+
|
| 233 |
+
# 2) 校准(symlog 空间残差 + symexp 还原)
|
| 234 |
+
calib_out: CalibrationOutput = self.calib(
|
| 235 |
+
dino_feats=dino_feats,
|
| 236 |
+
ego_raw=ego_6d_raw,
|
| 237 |
+
intr_raw=intr_raw,
|
| 238 |
+
extr_raw=extr_6d_raw,
|
| 239 |
+
)
|
| 240 |
+
corrected_ego = calib_out.corrected_ego
|
| 241 |
+
corrected_intr = calib_out.corrected_intr
|
| 242 |
+
corrected_extr_6d = calib_out.corrected_extr
|
| 243 |
+
|
| 244 |
+
# 3) 把 corrected_extr 6D 转成 4x4
|
| 245 |
+
from .data.se3 import six_d_to_matrix
|
| 246 |
+
cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4]
|
| 247 |
+
|
| 248 |
+
# 4) 2x2x2 时空压缩
|
| 249 |
+
compressed, thw = self.compress(dino_feats) # [B, N_v, D]
|
| 250 |
+
n_v = compressed.shape[1]
|
| 251 |
+
|
| 252 |
+
# 5) 视觉射线(用 corrected_intr / corrected_extr)
|
| 253 |
+
rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
|
| 254 |
+
rope_cos, rope_sin = self.rope.compute_freqs(rays)
|
| 255 |
+
|
| 256 |
+
# 6) 构造非视觉 token
|
| 257 |
+
ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D]
|
| 258 |
+
det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 259 |
+
ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 260 |
+
extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)
|
| 261 |
+
|
| 262 |
+
ego_tok = self.ego_pe(ego_tok)
|
| 263 |
+
det_tok = self.det_pe(det_tok)
|
| 264 |
+
ctrl_tok = self.ctrl_pe(ctrl_tok)
|
| 265 |
+
extra_tok = self.extra_pe(extra_tok)
|
| 266 |
+
|
| 267 |
+
# 7) 拼接序列:[vision | ego | det | ctrl | extra]
|
| 268 |
+
seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
|
| 269 |
+
visual_slice = (0, n_v)
|
| 270 |
+
|
| 271 |
+
# 8) 主干
|
| 272 |
+
bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
|
| 273 |
+
|
| 274 |
+
# 9) 切片送入头
|
| 275 |
+
offset_det = n_v + self.num_ego
|
| 276 |
+
offset_ctrl = offset_det + self.num_det
|
| 277 |
+
|
| 278 |
+
det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
|
| 279 |
+
ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]
|
| 280 |
+
|
| 281 |
+
det_out = self.det_traj_head(det_feats)
|
| 282 |
+
ctrl_out = self.ctrl_head(ctrl_feats)
|
| 283 |
+
|
| 284 |
+
return E2EOutput(
|
| 285 |
+
detection=det_out,
|
| 286 |
+
control=ctrl_out,
|
| 287 |
+
backbone_out=bb_out,
|
| 288 |
+
calibration=calib_out,
|
| 289 |
+
)
|
src/wjad/modules/__init__.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
| 1 |
-
"""公用算子模块集合。"""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
from .ffn import SwiGLUFFN
|
| 6 |
-
from .gate_attention import GateSelfAttention, GateCrossAttention
|
| 7 |
-
from .moe import MoEBlock, PerLayerExperts
|
| 8 |
-
from .normalization import symlog, symexp
|
| 9 |
-
from .pos_encoding import RoPE3D, build_rope_freqs
|
| 10 |
-
from .learned_pe import LearnedTokenPE
|
| 11 |
-
from .rays import FThetaCamera, compute_ego_rays
|
| 12 |
-
from .temporal_compress import TemporalCompress2x2x2
|
| 13 |
-
|
| 14 |
-
__all__ = [
|
| 15 |
-
"SwiGLUFFN",
|
| 16 |
-
"GateSelfAttention",
|
| 17 |
-
"GateCrossAttention",
|
| 18 |
-
"MoEBlock",
|
| 19 |
-
"PerLayerExperts",
|
| 20 |
-
"symlog",
|
| 21 |
-
"symexp",
|
| 22 |
-
"RoPE3D",
|
| 23 |
-
"build_rope_freqs",
|
| 24 |
-
"LearnedTokenPE",
|
| 25 |
-
"FThetaCamera",
|
| 26 |
-
"compute_ego_rays",
|
| 27 |
-
"TemporalCompress2x2x2",
|
| 28 |
-
]
|
|
|
|
| 1 |
+
"""公用算子模块集合。"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from .ffn import SwiGLUFFN
|
| 6 |
+
from .gate_attention import GateSelfAttention, GateCrossAttention
|
| 7 |
+
from .moe import MoEBlock, PerLayerExperts
|
| 8 |
+
from .normalization import symlog, symexp
|
| 9 |
+
from .pos_encoding import RoPE3D, build_rope_freqs
|
| 10 |
+
from .learned_pe import LearnedTokenPE
|
| 11 |
+
from .rays import FThetaCamera, compute_ego_rays
|
| 12 |
+
from .temporal_compress import TemporalCompress2x2x2
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"SwiGLUFFN",
|
| 16 |
+
"GateSelfAttention",
|
| 17 |
+
"GateCrossAttention",
|
| 18 |
+
"MoEBlock",
|
| 19 |
+
"PerLayerExperts",
|
| 20 |
+
"symlog",
|
| 21 |
+
"symexp",
|
| 22 |
+
"RoPE3D",
|
| 23 |
+
"build_rope_freqs",
|
| 24 |
+
"LearnedTokenPE",
|
| 25 |
+
"FThetaCamera",
|
| 26 |
+
"compute_ego_rays",
|
| 27 |
+
"TemporalCompress2x2x2",
|
| 28 |
+
]
|
src/wjad/modules/ffn.py
CHANGED
|
@@ -1,30 +1,30 @@
|
|
| 1 |
-
"""SwiGLU 前馈网络。
|
| 2 |
-
|
| 3 |
-
实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D)
|
| 4 |
-
即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
import torch.nn.functional as F
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class SwiGLUFFN(nn.Module):
|
| 15 |
-
"""SwiGLU FFN: D->4D->SwiGLU->2D->D。
|
| 16 |
-
|
| 17 |
-
使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
|
| 21 |
-
super().__init__()
|
| 22 |
-
hidden = mult * dim
|
| 23 |
-
self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b
|
| 24 |
-
self.fc2 = nn.Linear(hidden, dim, bias=bias)
|
| 25 |
-
self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 26 |
-
|
| 27 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 28 |
-
ab = self.fc1(x)
|
| 29 |
-
a, b = ab.chunk(2, dim=-1)
|
| 30 |
-
return self.drop(self.fc2(F.silu(a) * b))
|
|
|
|
| 1 |
+
"""SwiGLU 前馈网络。
|
| 2 |
+
|
| 3 |
+
实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D)
|
| 4 |
+
即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
|
| 15 |
+
"""SwiGLU FFN: D->4D->SwiGLU->2D->D。
|
| 16 |
+
|
| 17 |
+
使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
|
| 21 |
+
super().__init__()
|
| 22 |
+
hidden = mult * dim
|
| 23 |
+
self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b
|
| 24 |
+
self.fc2 = nn.Linear(hidden, dim, bias=bias)
|
| 25 |
+
self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 26 |
+
|
| 27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
ab = self.fc1(x)
|
| 29 |
+
a, b = ab.chunk(2, dim=-1)
|
| 30 |
+
return self.drop(self.fc2(F.silu(a) * b))
|
src/wjad/modules/gate_attention.py
CHANGED
|
@@ -1,181 +1,181 @@
|
|
| 1 |
-
"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。
|
| 2 |
-
|
| 3 |
-
与 Design.md 一致:
|
| 4 |
-
- Q 经 Linear + Sigmoid 生成 D 维门控参数;
|
| 5 |
-
- 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj;
|
| 6 |
-
- 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import math
|
| 12 |
-
from typing import Optional
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
import torch.nn as nn
|
| 16 |
-
import torch.nn.functional as F
|
| 17 |
-
|
| 18 |
-
from .pos_encoding import apply_rope
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class _MultiHeadProj(nn.Module):
|
| 22 |
-
"""通用的多头 Q/K/V 投影 + reshape。"""
|
| 23 |
-
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
dim_q: int,
|
| 27 |
-
dim_kv: int,
|
| 28 |
-
num_heads: int,
|
| 29 |
-
head_dim: int,
|
| 30 |
-
q_bias: bool = True,
|
| 31 |
-
kv_bias: bool = True,
|
| 32 |
-
) -> None:
|
| 33 |
-
super().__init__()
|
| 34 |
-
self.num_heads = num_heads
|
| 35 |
-
self.head_dim = head_dim
|
| 36 |
-
inner = num_heads * head_dim
|
| 37 |
-
self.q_proj = nn.Linear(dim_q, inner, bias=q_bias)
|
| 38 |
-
self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
|
| 39 |
-
self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
|
| 40 |
-
|
| 41 |
-
def project_q(self, x: torch.Tensor) -> torch.Tensor:
|
| 42 |
-
b, n, _ = x.shape
|
| 43 |
-
return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 44 |
-
|
| 45 |
-
def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 46 |
-
b, n, _ = x.shape
|
| 47 |
-
k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 48 |
-
v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 49 |
-
return k, v
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class _GateModule(nn.Module):
|
| 53 |
-
"""门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。
|
| 54 |
-
|
| 55 |
-
bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。
|
| 56 |
-
这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
def __init__(self, dim: int, init_bias: float = 5.0) -> None:
|
| 60 |
-
super().__init__()
|
| 61 |
-
self.proj = nn.Linear(dim, dim)
|
| 62 |
-
nn.init.zeros_(self.proj.weight)
|
| 63 |
-
nn.init.constant_(self.proj.bias, init_bias)
|
| 64 |
-
|
| 65 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 66 |
-
return torch.sigmoid(self.proj(x))
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class GateSelfAttention(nn.Module):
|
| 70 |
-
"""门控自注意力,使用 PyTorch SDPA。
|
| 71 |
-
|
| 72 |
-
支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。
|
| 73 |
-
"""
|
| 74 |
-
|
| 75 |
-
def __init__(
|
| 76 |
-
self,
|
| 77 |
-
dim: int,
|
| 78 |
-
num_heads: int,
|
| 79 |
-
dropout: float = 0.0,
|
| 80 |
-
gate_init_bias: float = 5.0,
|
| 81 |
-
q_bias: bool = True,
|
| 82 |
-
kv_bias: bool = True,
|
| 83 |
-
) -> None:
|
| 84 |
-
super().__init__()
|
| 85 |
-
assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
|
| 86 |
-
self.dim = dim
|
| 87 |
-
self.num_heads = num_heads
|
| 88 |
-
self.head_dim = dim // num_heads
|
| 89 |
-
self.scale = 1.0 / math.sqrt(self.head_dim)
|
| 90 |
-
self.dropout_p = dropout
|
| 91 |
-
|
| 92 |
-
self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias)
|
| 93 |
-
self.gate = _GateModule(dim, init_bias=gate_init_bias)
|
| 94 |
-
self.out_proj = nn.Linear(dim, dim, bias=True)
|
| 95 |
-
|
| 96 |
-
def forward(
|
| 97 |
-
self,
|
| 98 |
-
x: torch.Tensor,
|
| 99 |
-
rope_cos: Optional[torch.Tensor] = None,
|
| 100 |
-
rope_sin: Optional[torch.Tensor] = None,
|
| 101 |
-
visual_slice: Optional[tuple[int, int]] = None,
|
| 102 |
-
) -> torch.Tensor:
|
| 103 |
-
"""
|
| 104 |
-
参数
|
| 105 |
-
----
|
| 106 |
-
x : [B, N, D]
|
| 107 |
-
rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None
|
| 108 |
-
visual_slice : (start, end),指定视觉 token 在序列中的范围。
|
| 109 |
-
非视觉 token 切片 Q/K 不做 RoPE。
|
| 110 |
-
"""
|
| 111 |
-
b, n, _ = x.shape
|
| 112 |
-
q = self.proj.project_q(x) # [B, H, N, Dh]
|
| 113 |
-
k, v = self.proj.project_kv(x)
|
| 114 |
-
|
| 115 |
-
# 仅对视觉切片应用 RoPE
|
| 116 |
-
if rope_cos is not None and visual_slice is not None:
|
| 117 |
-
s, e = visual_slice
|
| 118 |
-
q_v = q[:, :, s:e, :]
|
| 119 |
-
k_v = k[:, :, s:e, :]
|
| 120 |
-
q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin)
|
| 121 |
-
q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2)
|
| 122 |
-
k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2)
|
| 123 |
-
|
| 124 |
-
# SDPA:[B, H, N, Dh]
|
| 125 |
-
attn = F.scaled_dot_product_attention(
|
| 126 |
-
q,
|
| 127 |
-
k,
|
| 128 |
-
v,
|
| 129 |
-
attn_mask=None,
|
| 130 |
-
dropout_p=self.dropout_p if self.training else 0.0,
|
| 131 |
-
is_causal=False,
|
| 132 |
-
)
|
| 133 |
-
# 多头合并
|
| 134 |
-
attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim)
|
| 135 |
-
|
| 136 |
-
# 门控 ⊗ 多头合并后的 V,再 out_proj
|
| 137 |
-
gate = self.gate(x) # 用 Q 的源(即 x)生成门控
|
| 138 |
-
out = self.out_proj(attn * gate)
|
| 139 |
-
return out
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
class GateCrossAttention(nn.Module):
|
| 143 |
-
"""门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。"""
|
| 144 |
-
|
| 145 |
-
def __init__(
|
| 146 |
-
self,
|
| 147 |
-
dim_q: int,
|
| 148 |
-
dim_kv: int,
|
| 149 |
-
num_heads: int,
|
| 150 |
-
dropout: float = 0.0,
|
| 151 |
-
gate_init_bias: float = 5.0,
|
| 152 |
-
q_bias: bool = True,
|
| 153 |
-
kv_bias: bool = True,
|
| 154 |
-
) -> None:
|
| 155 |
-
super().__init__()
|
| 156 |
-
assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除"
|
| 157 |
-
self.dim_q = dim_q
|
| 158 |
-
self.num_heads = num_heads
|
| 159 |
-
self.head_dim = dim_q // num_heads
|
| 160 |
-
self.dropout_p = dropout
|
| 161 |
-
|
| 162 |
-
self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias)
|
| 163 |
-
self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
|
| 164 |
-
self.out_proj = nn.Linear(dim_q, dim_q, bias=True)
|
| 165 |
-
|
| 166 |
-
def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor:
|
| 167 |
-
b, n, _ = q_in.shape
|
| 168 |
-
q = self.proj.project_q(q_in)
|
| 169 |
-
k, v = self.proj.project_kv(kv_in)
|
| 170 |
-
|
| 171 |
-
attn = F.scaled_dot_product_attention(
|
| 172 |
-
q,
|
| 173 |
-
k,
|
| 174 |
-
v,
|
| 175 |
-
attn_mask=None,
|
| 176 |
-
dropout_p=self.dropout_p if self.training else 0.0,
|
| 177 |
-
is_causal=False,
|
| 178 |
-
)
|
| 179 |
-
attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q)
|
| 180 |
-
gate = self.gate(q_in)
|
| 181 |
-
return self.out_proj(attn * gate)
|
|
|
|
| 1 |
+
"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。
|
| 2 |
+
|
| 3 |
+
与 Design.md 一致:
|
| 4 |
+
- Q 经 Linear + Sigmoid 生成 D 维门控参数;
|
| 5 |
+
- 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj;
|
| 6 |
+
- 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from .pos_encoding import apply_rope
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _MultiHeadProj(nn.Module):
|
| 22 |
+
"""通用的多头 Q/K/V 投影 + reshape。"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
dim_q: int,
|
| 27 |
+
dim_kv: int,
|
| 28 |
+
num_heads: int,
|
| 29 |
+
head_dim: int,
|
| 30 |
+
q_bias: bool = True,
|
| 31 |
+
kv_bias: bool = True,
|
| 32 |
+
) -> None:
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.num_heads = num_heads
|
| 35 |
+
self.head_dim = head_dim
|
| 36 |
+
inner = num_heads * head_dim
|
| 37 |
+
self.q_proj = nn.Linear(dim_q, inner, bias=q_bias)
|
| 38 |
+
self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
|
| 39 |
+
self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias)
|
| 40 |
+
|
| 41 |
+
def project_q(self, x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
b, n, _ = x.shape
|
| 43 |
+
return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 44 |
+
|
| 45 |
+
def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 46 |
+
b, n, _ = x.shape
|
| 47 |
+
k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 48 |
+
v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
|
| 49 |
+
return k, v
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _GateModule(nn.Module):
|
| 53 |
+
"""门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。
|
| 54 |
+
|
| 55 |
+
bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。
|
| 56 |
+
这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, dim: int, init_bias: float = 5.0) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.proj = nn.Linear(dim, dim)
|
| 62 |
+
nn.init.zeros_(self.proj.weight)
|
| 63 |
+
nn.init.constant_(self.proj.bias, init_bias)
|
| 64 |
+
|
| 65 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 66 |
+
return torch.sigmoid(self.proj(x))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class GateSelfAttention(nn.Module):
|
| 70 |
+
"""门控自注意力,使用 PyTorch SDPA。
|
| 71 |
+
|
| 72 |
+
支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
dim: int,
|
| 78 |
+
num_heads: int,
|
| 79 |
+
dropout: float = 0.0,
|
| 80 |
+
gate_init_bias: float = 5.0,
|
| 81 |
+
q_bias: bool = True,
|
| 82 |
+
kv_bias: bool = True,
|
| 83 |
+
) -> None:
|
| 84 |
+
super().__init__()
|
| 85 |
+
assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
|
| 86 |
+
self.dim = dim
|
| 87 |
+
self.num_heads = num_heads
|
| 88 |
+
self.head_dim = dim // num_heads
|
| 89 |
+
self.scale = 1.0 / math.sqrt(self.head_dim)
|
| 90 |
+
self.dropout_p = dropout
|
| 91 |
+
|
| 92 |
+
self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias)
|
| 93 |
+
self.gate = _GateModule(dim, init_bias=gate_init_bias)
|
| 94 |
+
self.out_proj = nn.Linear(dim, dim, bias=True)
|
| 95 |
+
|
| 96 |
+
def forward(
|
| 97 |
+
self,
|
| 98 |
+
x: torch.Tensor,
|
| 99 |
+
rope_cos: Optional[torch.Tensor] = None,
|
| 100 |
+
rope_sin: Optional[torch.Tensor] = None,
|
| 101 |
+
visual_slice: Optional[tuple[int, int]] = None,
|
| 102 |
+
) -> torch.Tensor:
|
| 103 |
+
"""
|
| 104 |
+
参数
|
| 105 |
+
----
|
| 106 |
+
x : [B, N, D]
|
| 107 |
+
rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None
|
| 108 |
+
visual_slice : (start, end),指定视觉 token 在序列中的范围。
|
| 109 |
+
非视觉 token 切片 Q/K 不做 RoPE。
|
| 110 |
+
"""
|
| 111 |
+
b, n, _ = x.shape
|
| 112 |
+
q = self.proj.project_q(x) # [B, H, N, Dh]
|
| 113 |
+
k, v = self.proj.project_kv(x)
|
| 114 |
+
|
| 115 |
+
# 仅对视觉切片应用 RoPE
|
| 116 |
+
if rope_cos is not None and visual_slice is not None:
|
| 117 |
+
s, e = visual_slice
|
| 118 |
+
q_v = q[:, :, s:e, :]
|
| 119 |
+
k_v = k[:, :, s:e, :]
|
| 120 |
+
q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin)
|
| 121 |
+
q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2)
|
| 122 |
+
k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2)
|
| 123 |
+
|
| 124 |
+
# SDPA:[B, H, N, Dh]
|
| 125 |
+
attn = F.scaled_dot_product_attention(
|
| 126 |
+
q,
|
| 127 |
+
k,
|
| 128 |
+
v,
|
| 129 |
+
attn_mask=None,
|
| 130 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 131 |
+
is_causal=False,
|
| 132 |
+
)
|
| 133 |
+
# 多头合并
|
| 134 |
+
attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim)
|
| 135 |
+
|
| 136 |
+
# 门控 ⊗ 多头合并后的 V,再 out_proj
|
| 137 |
+
gate = self.gate(x) # 用 Q 的源(即 x)生成门控
|
| 138 |
+
out = self.out_proj(attn * gate)
|
| 139 |
+
return out
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class GateCrossAttention(nn.Module):
|
| 143 |
+
"""门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。"""
|
| 144 |
+
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
dim_q: int,
|
| 148 |
+
dim_kv: int,
|
| 149 |
+
num_heads: int,
|
| 150 |
+
dropout: float = 0.0,
|
| 151 |
+
gate_init_bias: float = 5.0,
|
| 152 |
+
q_bias: bool = True,
|
| 153 |
+
kv_bias: bool = True,
|
| 154 |
+
) -> None:
|
| 155 |
+
super().__init__()
|
| 156 |
+
assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除"
|
| 157 |
+
self.dim_q = dim_q
|
| 158 |
+
self.num_heads = num_heads
|
| 159 |
+
self.head_dim = dim_q // num_heads
|
| 160 |
+
self.dropout_p = dropout
|
| 161 |
+
|
| 162 |
+
self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias)
|
| 163 |
+
self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
|
| 164 |
+
self.out_proj = nn.Linear(dim_q, dim_q, bias=True)
|
| 165 |
+
|
| 166 |
+
def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor:
|
| 167 |
+
b, n, _ = q_in.shape
|
| 168 |
+
q = self.proj.project_q(q_in)
|
| 169 |
+
k, v = self.proj.project_kv(kv_in)
|
| 170 |
+
|
| 171 |
+
attn = F.scaled_dot_product_attention(
|
| 172 |
+
q,
|
| 173 |
+
k,
|
| 174 |
+
v,
|
| 175 |
+
attn_mask=None,
|
| 176 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 177 |
+
is_causal=False,
|
| 178 |
+
)
|
| 179 |
+
attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q)
|
| 180 |
+
gate = self.gate(q_in)
|
| 181 |
+
return self.out_proj(attn * gate)
|
src/wjad/modules/learned_pe.py
CHANGED
|
@@ -1,24 +1,24 @@
|
|
| 1 |
-
"""非视觉 token 的可学习位置编码。
|
| 2 |
-
|
| 3 |
-
ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
|
| 4 |
-
``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
|
| 5 |
-
直接加到对应 token 上,不参与 RoPE。
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
import torch.nn as nn
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class LearnedTokenPE(nn.Module):
|
| 15 |
-
"""形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
|
| 18 |
-
super().__init__()
|
| 19 |
-
self.pe = nn.Parameter(torch.empty(num_tokens, dim))
|
| 20 |
-
nn.init.trunc_normal_(self.pe, std=init_std)
|
| 21 |
-
|
| 22 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
-
# x: [B, N, D]
|
| 24 |
-
return x + self.pe.unsqueeze(0)
|
|
|
|
| 1 |
+
"""非视觉 token 的可学习位置编码。
|
| 2 |
+
|
| 3 |
+
ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
|
| 4 |
+
``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
|
| 5 |
+
直接加到对应 token 上,不参与 RoPE。
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LearnedTokenPE(nn.Module):
|
| 15 |
+
"""形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.pe = nn.Parameter(torch.empty(num_tokens, dim))
|
| 20 |
+
nn.init.trunc_normal_(self.pe, std=init_std)
|
| 21 |
+
|
| 22 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
# x: [B, N, D]
|
| 24 |
+
return x + self.pe.unsqueeze(0)
|
src/wjad/modules/moe.py
CHANGED
|
@@ -1,129 +1,129 @@
|
|
| 1 |
-
"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
|
| 2 |
-
|
| 3 |
-
设计要点(与 Design.md 对齐):
|
| 4 |
-
- 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。
|
| 5 |
-
- 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
|
| 6 |
-
- 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
|
| 7 |
-
- 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
|
| 8 |
-
- 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
|
| 9 |
-
``losses/moe_aux.py`` 聚合成正则损失。
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from dataclasses import dataclass
|
| 15 |
-
|
| 16 |
-
import torch
|
| 17 |
-
import torch.nn as nn
|
| 18 |
-
|
| 19 |
-
from .ffn import SwiGLUFFN
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@dataclass
|
| 23 |
-
class MoEStats:
|
| 24 |
-
"""单层 MoE 输出的辅助统计,用于损失与监控。"""
|
| 25 |
-
|
| 26 |
-
logits: torch.Tensor # [B, num_routed]
|
| 27 |
-
probs: torch.Tensor # [B, num_routed],sigmoid 后的概率
|
| 28 |
-
topk_mask: torch.Tensor # [B, num_routed],0/1
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class PerLayerExperts(nn.Module):
|
| 32 |
-
"""单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。"""
|
| 33 |
-
|
| 34 |
-
def __init__(
|
| 35 |
-
self,
|
| 36 |
-
dim: int,
|
| 37 |
-
num_routed: int = 7,
|
| 38 |
-
num_shared: int = 1,
|
| 39 |
-
ffn_mult: int = 4,
|
| 40 |
-
dropout: float = 0.0,
|
| 41 |
-
) -> None:
|
| 42 |
-
super().__init__()
|
| 43 |
-
self.num_routed = num_routed
|
| 44 |
-
self.num_shared = num_shared
|
| 45 |
-
self.shared = nn.ModuleList(
|
| 46 |
-
[SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)]
|
| 47 |
-
)
|
| 48 |
-
self.routed = nn.ModuleList(
|
| 49 |
-
[SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)]
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
class MoEBlock(nn.Module):
|
| 54 |
-
"""带路由的 MoE FFN 块(每层独立专家库)。"""
|
| 55 |
-
|
| 56 |
-
def __init__(
|
| 57 |
-
self,
|
| 58 |
-
dim: int,
|
| 59 |
-
num_routed: int = 7,
|
| 60 |
-
num_shared: int = 1,
|
| 61 |
-
ffn_mult: int = 4,
|
| 62 |
-
topk: int = 3,
|
| 63 |
-
dropout: float = 0.0,
|
| 64 |
-
) -> None:
|
| 65 |
-
super().__init__()
|
| 66 |
-
self.dim = dim
|
| 67 |
-
self.num_routed = num_routed
|
| 68 |
-
self.num_shared = num_shared
|
| 69 |
-
self.topk = topk
|
| 70 |
-
self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
|
| 71 |
-
self.router = nn.Linear(dim, num_routed, bias=True)
|
| 72 |
-
# 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5
|
| 73 |
-
nn.init.normal_(self.router.weight, std=0.02)
|
| 74 |
-
nn.init.zeros_(self.router.bias)
|
| 75 |
-
|
| 76 |
-
# 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk
|
| 77 |
-
self._mode: str = "dense"
|
| 78 |
-
# 路由温度(温度 < 1 => 锐化)
|
| 79 |
-
self.register_buffer("router_temperature", torch.tensor(1.0))
|
| 80 |
-
|
| 81 |
-
def set_mode(self, mode: str) -> None:
|
| 82 |
-
assert mode in ("dense", "sparse"), f"未知模式: {mode}"
|
| 83 |
-
self._mode = mode
|
| 84 |
-
|
| 85 |
-
@property
|
| 86 |
-
def mode(self) -> str:
|
| 87 |
-
return self._mode
|
| 88 |
-
|
| 89 |
-
def set_temperature(self, t: float) -> None:
|
| 90 |
-
self.router_temperature.fill_(float(t))
|
| 91 |
-
|
| 92 |
-
def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 93 |
-
"""计算 logits / probs / topk_mask。x: [B, N, D]。"""
|
| 94 |
-
pooled = x.mean(dim=1) # [B, D]
|
| 95 |
-
logits = self.router(pooled) # [B, num_routed]
|
| 96 |
-
# 温度锐化(温度小 => 概率更尖)
|
| 97 |
-
scaled = logits / self.router_temperature.clamp_min(1e-3)
|
| 98 |
-
probs = torch.sigmoid(scaled)
|
| 99 |
-
|
| 100 |
-
if self._mode == "dense" or self.topk >= self.num_routed:
|
| 101 |
-
mask = torch.ones_like(probs)
|
| 102 |
-
else:
|
| 103 |
-
topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1)
|
| 104 |
-
mask = torch.zeros_like(probs)
|
| 105 |
-
mask.scatter_(-1, topk_idx, 1.0)
|
| 106 |
-
|
| 107 |
-
return logits, probs, mask
|
| 108 |
-
|
| 109 |
-
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
|
| 110 |
-
b, n, d = x.shape
|
| 111 |
-
logits, probs, mask = self._route(x)
|
| 112 |
-
|
| 113 |
-
# 共享专家(恒激活,无门控)
|
| 114 |
-
out = torch.zeros_like(x)
|
| 115 |
-
for sh in self.experts.shared:
|
| 116 |
-
out = out + sh(x)
|
| 117 |
-
|
| 118 |
-
# 路由专家:按 (probs * mask) 在 batch 维加权
|
| 119 |
-
weights = probs * mask # [B, num_routed]
|
| 120 |
-
# 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级
|
| 121 |
-
# 路由的索引开销;与 Design.md "序列级分配" 一致。
|
| 122 |
-
for i, expert in enumerate(self.experts.routed):
|
| 123 |
-
w_i = weights[:, i].view(b, 1, 1) # [B,1,1]
|
| 124 |
-
# 仅当批内任一样本权重 > 0 时才前向以减少计算
|
| 125 |
-
if torch.any(w_i > 0):
|
| 126 |
-
out = out + w_i * expert(x)
|
| 127 |
-
|
| 128 |
-
stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
|
| 129 |
-
return out, stats
|
|
|
|
| 1 |
+
"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
|
| 2 |
+
|
| 3 |
+
设计要点(与 Design.md 对齐):
|
| 4 |
+
- 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。
|
| 5 |
+
- 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
|
| 6 |
+
- 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
|
| 7 |
+
- 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
|
| 8 |
+
- 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
|
| 9 |
+
``losses/moe_aux.py`` 聚合成正则损失。
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from .ffn import SwiGLUFFN
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class MoEStats:
|
| 24 |
+
"""单层 MoE 输出的辅助统计,用于损失与监控。"""
|
| 25 |
+
|
| 26 |
+
logits: torch.Tensor # [B, num_routed]
|
| 27 |
+
probs: torch.Tensor # [B, num_routed],sigmoid 后的概率
|
| 28 |
+
topk_mask: torch.Tensor # [B, num_routed],0/1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PerLayerExperts(nn.Module):
|
| 32 |
+
"""单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
dim: int,
|
| 37 |
+
num_routed: int = 7,
|
| 38 |
+
num_shared: int = 1,
|
| 39 |
+
ffn_mult: int = 4,
|
| 40 |
+
dropout: float = 0.0,
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.num_routed = num_routed
|
| 44 |
+
self.num_shared = num_shared
|
| 45 |
+
self.shared = nn.ModuleList(
|
| 46 |
+
[SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)]
|
| 47 |
+
)
|
| 48 |
+
self.routed = nn.ModuleList(
|
| 49 |
+
[SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MoEBlock(nn.Module):
|
| 54 |
+
"""带路由的 MoE FFN 块(每层独立专家库)。"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
dim: int,
|
| 59 |
+
num_routed: int = 7,
|
| 60 |
+
num_shared: int = 1,
|
| 61 |
+
ffn_mult: int = 4,
|
| 62 |
+
topk: int = 3,
|
| 63 |
+
dropout: float = 0.0,
|
| 64 |
+
) -> None:
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.dim = dim
|
| 67 |
+
self.num_routed = num_routed
|
| 68 |
+
self.num_shared = num_shared
|
| 69 |
+
self.topk = topk
|
| 70 |
+
self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
|
| 71 |
+
self.router = nn.Linear(dim, num_routed, bias=True)
|
| 72 |
+
# 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5
|
| 73 |
+
nn.init.normal_(self.router.weight, std=0.02)
|
| 74 |
+
nn.init.zeros_(self.router.bias)
|
| 75 |
+
|
| 76 |
+
# 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk
|
| 77 |
+
self._mode: str = "dense"
|
| 78 |
+
# 路由温度(温度 < 1 => 锐化)
|
| 79 |
+
self.register_buffer("router_temperature", torch.tensor(1.0))
|
| 80 |
+
|
| 81 |
+
def set_mode(self, mode: str) -> None:
|
| 82 |
+
assert mode in ("dense", "sparse"), f"未知模式: {mode}"
|
| 83 |
+
self._mode = mode
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def mode(self) -> str:
|
| 87 |
+
return self._mode
|
| 88 |
+
|
| 89 |
+
def set_temperature(self, t: float) -> None:
|
| 90 |
+
self.router_temperature.fill_(float(t))
|
| 91 |
+
|
| 92 |
+
def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 93 |
+
"""计算 logits / probs / topk_mask。x: [B, N, D]。"""
|
| 94 |
+
pooled = x.mean(dim=1) # [B, D]
|
| 95 |
+
logits = self.router(pooled) # [B, num_routed]
|
| 96 |
+
# 温度锐化(温度小 => 概率更尖)
|
| 97 |
+
scaled = logits / self.router_temperature.clamp_min(1e-3)
|
| 98 |
+
probs = torch.sigmoid(scaled)
|
| 99 |
+
|
| 100 |
+
if self._mode == "dense" or self.topk >= self.num_routed:
|
| 101 |
+
mask = torch.ones_like(probs)
|
| 102 |
+
else:
|
| 103 |
+
topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1)
|
| 104 |
+
mask = torch.zeros_like(probs)
|
| 105 |
+
mask.scatter_(-1, topk_idx, 1.0)
|
| 106 |
+
|
| 107 |
+
return logits, probs, mask
|
| 108 |
+
|
| 109 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
|
| 110 |
+
b, n, d = x.shape
|
| 111 |
+
logits, probs, mask = self._route(x)
|
| 112 |
+
|
| 113 |
+
# 共享专家(恒激活,无门控)
|
| 114 |
+
out = torch.zeros_like(x)
|
| 115 |
+
for sh in self.experts.shared:
|
| 116 |
+
out = out + sh(x)
|
| 117 |
+
|
| 118 |
+
# 路由专家:按 (probs * mask) 在 batch 维加权
|
| 119 |
+
weights = probs * mask # [B, num_routed]
|
| 120 |
+
# 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级
|
| 121 |
+
# 路由的索引开销;与 Design.md "序列级分配" 一致。
|
| 122 |
+
for i, expert in enumerate(self.experts.routed):
|
| 123 |
+
w_i = weights[:, i].view(b, 1, 1) # [B,1,1]
|
| 124 |
+
# 仅当批内任一样本权重 > 0 时才前向以减少计算
|
| 125 |
+
if torch.any(w_i > 0):
|
| 126 |
+
out = out + w_i * expert(x)
|
| 127 |
+
|
| 128 |
+
stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
|
| 129 |
+
return out, stats
|
src/wjad/modules/normalization.py
CHANGED
|
@@ -1,22 +1,22 @@
|
|
| 1 |
-
"""对称对数归一化算子 symlog / symexp。
|
| 2 |
-
|
| 3 |
-
公式:
|
| 4 |
-
symlog(x) = sign(x) * log(|x| + 1)
|
| 5 |
-
symexp(y) = sign(y) * (exp(|y|) - 1)
|
| 6 |
-
|
| 7 |
-
用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
from __future__ import annotations
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def symlog(x: torch.Tensor) -> torch.Tensor:
|
| 16 |
-
"""对称对数压缩:sign(x) * log(|x| + 1)。"""
|
| 17 |
-
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def symexp(y: torch.Tensor) -> torch.Tensor:
|
| 21 |
-
"""symlog 的逆:sign(y) * (exp(|y|) - 1)。"""
|
| 22 |
-
return torch.sign(y) * torch.expm1(torch.abs(y))
|
|
|
|
| 1 |
+
"""对称对数归一化算子 symlog / symexp。
|
| 2 |
+
|
| 3 |
+
公式:
|
| 4 |
+
symlog(x) = sign(x) * log(|x| + 1)
|
| 5 |
+
symexp(y) = sign(y) * (exp(|y|) - 1)
|
| 6 |
+
|
| 7 |
+
用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
| 16 |
+
"""对称对数压缩:sign(x) * log(|x| + 1)。"""
|
| 17 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def symexp(y: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
"""symlog 的逆:sign(y) * (exp(|y|) - 1)。"""
|
| 22 |
+
return torch.sign(y) * torch.expm1(torch.abs(y))
|
src/wjad/modules/pos_encoding.py
CHANGED
|
@@ -1,224 +1,224 @@
|
|
| 1 |
-
"""3D RoPE(仅作用于视觉 token)。
|
| 2 |
-
|
| 3 |
-
12 头按 4+4+4 拆为三组:
|
| 4 |
-
- 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。
|
| 5 |
-
- 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。
|
| 6 |
-
- 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。
|
| 7 |
-
|
| 8 |
-
为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else),
|
| 9 |
-
零频段头自然变为恒等映射。
|
| 10 |
-
|
| 11 |
-
将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按
|
| 12 |
-
3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
from __future__ import annotations
|
| 16 |
-
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
|
| 22 |
-
"""把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。
|
| 23 |
-
|
| 24 |
-
返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
|
| 25 |
-
若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
|
| 26 |
-
一项视为"零频段"——它的频率会被显式置为 0。
|
| 27 |
-
"""
|
| 28 |
-
if num_components == 0:
|
| 29 |
-
return [0, half]
|
| 30 |
-
base = half // num_components
|
| 31 |
-
splits = [base] * num_components
|
| 32 |
-
splits[-1] += half - base * num_components # 余数全归到最后一个分量
|
| 33 |
-
return splits
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def build_rope_freqs(
|
| 37 |
-
rays: torch.Tensor,
|
| 38 |
-
hwt_grid: torch.Tensor,
|
| 39 |
-
num_heads: int = 12,
|
| 40 |
-
head_dim: int = 64,
|
| 41 |
-
rope_theta: float = 10000.0,
|
| 42 |
-
device: torch.device | None = None,
|
| 43 |
-
dtype: torch.dtype = torch.float32,
|
| 44 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 45 |
-
"""构造 3D RoPE 的 cos / sin 表。
|
| 46 |
-
|
| 47 |
-
参数
|
| 48 |
-
----
|
| 49 |
-
rays : Tensor, shape ``[B, N_v, 3]``
|
| 50 |
-
每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。
|
| 51 |
-
hwt_grid : Tensor, shape ``[B, N_v, 3]``
|
| 52 |
-
归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。
|
| 53 |
-
num_heads : int
|
| 54 |
-
总头数(默认 12)。
|
| 55 |
-
head_dim : int
|
| 56 |
-
每头维度(默认 64,必须为偶数)。
|
| 57 |
-
|
| 58 |
-
返回
|
| 59 |
-
----
|
| 60 |
-
cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
|
| 61 |
-
每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。
|
| 62 |
-
"""
|
| 63 |
-
assert head_dim % 2 == 0, "head_dim 必须为偶数"
|
| 64 |
-
assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"
|
| 65 |
-
|
| 66 |
-
half = head_dim // 2
|
| 67 |
-
heads_per_group = num_heads // 3
|
| 68 |
-
bsz, n_v, _ = rays.shape
|
| 69 |
-
if device is None:
|
| 70 |
-
device = rays.device
|
| 71 |
-
|
| 72 |
-
# === 三组分量值 ===
|
| 73 |
-
# group 0: rays (3 components)
|
| 74 |
-
# group 1: hwt (3 components)
|
| 75 |
-
# group 2: zero (0 components -> 全部 half 视为零频段)
|
| 76 |
-
splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays
|
| 77 |
-
splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt
|
| 78 |
-
splits_g2 = _split_head_dim_for_components(half, 0) # [0, half]
|
| 79 |
-
|
| 80 |
-
# === 频率向量(沿 head_dim 半轴)===
|
| 81 |
-
# 经典 RoPE: theta_i = base ^ (-2i / d)
|
| 82 |
-
# 这里我们对每个分量独立排布频率
|
| 83 |
-
def _freqs(num_pairs: int) -> torch.Tensor:
|
| 84 |
-
# 前 num_pairs 个用 RoPE 频率,剩余补 0
|
| 85 |
-
idx = torch.arange(num_pairs, device=device, dtype=dtype)
|
| 86 |
-
freqs = rope_theta ** (-2.0 * idx / head_dim)
|
| 87 |
-
return freqs # [num_pairs]
|
| 88 |
-
|
| 89 |
-
# 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half]
|
| 90 |
-
angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)
|
| 91 |
-
|
| 92 |
-
# ---- 第 0 组(4 头):射线 ----
|
| 93 |
-
base_offset = 0
|
| 94 |
-
h0_start = 0
|
| 95 |
-
h0_end = h0_start + heads_per_group
|
| 96 |
-
cursor = 0
|
| 97 |
-
for c in range(3): # dx, dy, dz
|
| 98 |
-
n_pairs = splits_g0[c]
|
| 99 |
-
if n_pairs > 0:
|
| 100 |
-
f = _freqs(n_pairs) # [n_pairs]
|
| 101 |
-
comp_val = rays[..., c : c + 1] # [B, N_v, 1]
|
| 102 |
-
ang = comp_val * f # 广播 -> [B, N_v, n_pairs]
|
| 103 |
-
angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
|
| 104 |
-
cursor += n_pairs
|
| 105 |
-
# 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0
|
| 106 |
-
|
| 107 |
-
# ---- 第 1 组(4 头):HWT ----
|
| 108 |
-
h1_start = heads_per_group
|
| 109 |
-
h1_end = h1_start + heads_per_group
|
| 110 |
-
cursor = 0
|
| 111 |
-
for c in range(3): # h, w, t
|
| 112 |
-
n_pairs = splits_g1[c]
|
| 113 |
-
if n_pairs > 0:
|
| 114 |
-
f = _freqs(n_pairs)
|
| 115 |
-
comp_val = hwt_grid[..., c : c + 1]
|
| 116 |
-
ang = comp_val * f
|
| 117 |
-
angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
|
| 118 |
-
cursor += n_pairs
|
| 119 |
-
|
| 120 |
-
# ---- 第 2 组(4 头):零频段 ----
|
| 121 |
-
# 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零)
|
| 122 |
-
|
| 123 |
-
cos = torch.cos(angles)
|
| 124 |
-
sin = torch.sin(angles)
|
| 125 |
-
return cos, sin
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def apply_rope(
|
| 129 |
-
q: torch.Tensor,
|
| 130 |
-
k: torch.Tensor,
|
| 131 |
-
cos: torch.Tensor,
|
| 132 |
-
sin: torch.Tensor,
|
| 133 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 134 |
-
"""对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。
|
| 135 |
-
|
| 136 |
-
所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。
|
| 137 |
-
|
| 138 |
-
参数
|
| 139 |
-
----
|
| 140 |
-
q, k : Tensor, shape ``[B, H, N_v, head_dim]``
|
| 141 |
-
cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``
|
| 142 |
-
|
| 143 |
-
返回
|
| 144 |
-
----
|
| 145 |
-
旋转后的 q, k,形状不变。
|
| 146 |
-
"""
|
| 147 |
-
# 把 cos/sin 转成 [B, H, N_v, half]
|
| 148 |
-
cos_e = cos.permute(0, 2, 1, 3)
|
| 149 |
-
sin_e = sin.permute(0, 2, 1, 3)
|
| 150 |
-
|
| 151 |
-
# 把 head_dim 维度按 (even, odd) 拆开成 [..., half]
|
| 152 |
-
q_even = q[..., 0::2]
|
| 153 |
-
q_odd = q[..., 1::2]
|
| 154 |
-
k_even = k[..., 0::2]
|
| 155 |
-
k_odd = k[..., 1::2]
|
| 156 |
-
|
| 157 |
-
q_rot_even = q_even * cos_e - q_odd * sin_e
|
| 158 |
-
q_rot_odd = q_even * sin_e + q_odd * cos_e
|
| 159 |
-
k_rot_even = k_even * cos_e - k_odd * sin_e
|
| 160 |
-
k_rot_odd = k_even * sin_e + k_odd * cos_e
|
| 161 |
-
|
| 162 |
-
q_out = torch.empty_like(q)
|
| 163 |
-
k_out = torch.empty_like(k)
|
| 164 |
-
q_out[..., 0::2] = q_rot_even
|
| 165 |
-
q_out[..., 1::2] = q_rot_odd
|
| 166 |
-
k_out[..., 0::2] = k_rot_even
|
| 167 |
-
k_out[..., 1::2] = k_rot_odd
|
| 168 |
-
return q_out, k_out
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
class RoPE3D(nn.Module):
|
| 172 |
-
"""3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。
|
| 173 |
-
|
| 174 |
-
使用方式:
|
| 175 |
-
rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32)
|
| 176 |
-
cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3]
|
| 177 |
-
q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
|
| 178 |
-
"""
|
| 179 |
-
|
| 180 |
-
def __init__(
|
| 181 |
-
self,
|
| 182 |
-
num_heads: int = 12,
|
| 183 |
-
head_dim: int = 64,
|
| 184 |
-
time_size: int = 4,
|
| 185 |
-
height_size: int = 12,
|
| 186 |
-
width_size: int = 32,
|
| 187 |
-
rope_theta: float = 10000.0,
|
| 188 |
-
) -> None:
|
| 189 |
-
super().__init__()
|
| 190 |
-
self.num_heads = num_heads
|
| 191 |
-
self.head_dim = head_dim
|
| 192 |
-
self.rope_theta = rope_theta
|
| 193 |
-
self.T = time_size
|
| 194 |
-
self.H = height_size
|
| 195 |
-
self.W = width_size
|
| 196 |
-
|
| 197 |
-
# 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W
|
| 198 |
-
t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
|
| 199 |
-
h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
|
| 200 |
-
w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
|
| 201 |
-
# 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致)
|
| 202 |
-
T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
|
| 203 |
-
hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3]
|
| 204 |
-
self.register_buffer("hwt_grid", hwt, persistent=False)
|
| 205 |
-
|
| 206 |
-
@property
|
| 207 |
-
def num_visual_tokens(self) -> int:
|
| 208 |
-
return self.T * self.H * self.W
|
| 209 |
-
|
| 210 |
-
def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 211 |
-
"""根据每 token 的射线方向计算 cos/sin。
|
| 212 |
-
|
| 213 |
-
``rays`` shape: ``[B, N_v, 3]``。
|
| 214 |
-
"""
|
| 215 |
-
bsz = rays.shape[0]
|
| 216 |
-
hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3]
|
| 217 |
-
return build_rope_freqs(
|
| 218 |
-
rays=rays,
|
| 219 |
-
hwt_grid=hwt,
|
| 220 |
-
num_heads=self.num_heads,
|
| 221 |
-
head_dim=self.head_dim,
|
| 222 |
-
rope_theta=self.rope_theta,
|
| 223 |
-
dtype=rays.dtype,
|
| 224 |
-
)
|
|
|
|
| 1 |
+
"""3D RoPE(仅作用于视觉 token)。
|
| 2 |
+
|
| 3 |
+
12 头按 4+4+4 拆为三组:
|
| 4 |
+
- 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。
|
| 5 |
+
- 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。
|
| 6 |
+
- 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。
|
| 7 |
+
|
| 8 |
+
为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else),
|
| 9 |
+
零频段头自然变为恒等映射。
|
| 10 |
+
|
| 11 |
+
将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按
|
| 12 |
+
3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
|
| 22 |
+
"""把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。
|
| 23 |
+
|
| 24 |
+
返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
|
| 25 |
+
若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
|
| 26 |
+
一项视为"零频段"——它的频率会被显式置为 0。
|
| 27 |
+
"""
|
| 28 |
+
if num_components == 0:
|
| 29 |
+
return [0, half]
|
| 30 |
+
base = half // num_components
|
| 31 |
+
splits = [base] * num_components
|
| 32 |
+
splits[-1] += half - base * num_components # 余数全归到最后一个分量
|
| 33 |
+
return splits
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_rope_freqs(
|
| 37 |
+
rays: torch.Tensor,
|
| 38 |
+
hwt_grid: torch.Tensor,
|
| 39 |
+
num_heads: int = 12,
|
| 40 |
+
head_dim: int = 64,
|
| 41 |
+
rope_theta: float = 10000.0,
|
| 42 |
+
device: torch.device | None = None,
|
| 43 |
+
dtype: torch.dtype = torch.float32,
|
| 44 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 45 |
+
"""构造 3D RoPE 的 cos / sin 表。
|
| 46 |
+
|
| 47 |
+
参数
|
| 48 |
+
----
|
| 49 |
+
rays : Tensor, shape ``[B, N_v, 3]``
|
| 50 |
+
每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。
|
| 51 |
+
hwt_grid : Tensor, shape ``[B, N_v, 3]``
|
| 52 |
+
归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。
|
| 53 |
+
num_heads : int
|
| 54 |
+
总头数(默认 12)。
|
| 55 |
+
head_dim : int
|
| 56 |
+
每头维度(默认 64,必须为偶数)。
|
| 57 |
+
|
| 58 |
+
返回
|
| 59 |
+
----
|
| 60 |
+
cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
|
| 61 |
+
每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。
|
| 62 |
+
"""
|
| 63 |
+
assert head_dim % 2 == 0, "head_dim 必须为偶数"
|
| 64 |
+
assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"
|
| 65 |
+
|
| 66 |
+
half = head_dim // 2
|
| 67 |
+
heads_per_group = num_heads // 3
|
| 68 |
+
bsz, n_v, _ = rays.shape
|
| 69 |
+
if device is None:
|
| 70 |
+
device = rays.device
|
| 71 |
+
|
| 72 |
+
# === 三组分量值 ===
|
| 73 |
+
# group 0: rays (3 components)
|
| 74 |
+
# group 1: hwt (3 components)
|
| 75 |
+
# group 2: zero (0 components -> 全部 half 视为零频段)
|
| 76 |
+
splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays
|
| 77 |
+
splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt
|
| 78 |
+
splits_g2 = _split_head_dim_for_components(half, 0) # [0, half]
|
| 79 |
+
|
| 80 |
+
# === 频率向量(沿 head_dim 半轴)===
|
| 81 |
+
# 经典 RoPE: theta_i = base ^ (-2i / d)
|
| 82 |
+
# 这里我们对每个分量独立排布频率
|
| 83 |
+
def _freqs(num_pairs: int) -> torch.Tensor:
|
| 84 |
+
# 前 num_pairs 个用 RoPE 频率,剩余补 0
|
| 85 |
+
idx = torch.arange(num_pairs, device=device, dtype=dtype)
|
| 86 |
+
freqs = rope_theta ** (-2.0 * idx / head_dim)
|
| 87 |
+
return freqs # [num_pairs]
|
| 88 |
+
|
| 89 |
+
# 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half]
|
| 90 |
+
angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)
|
| 91 |
+
|
| 92 |
+
# ---- 第 0 组(4 头):射线 ----
|
| 93 |
+
base_offset = 0
|
| 94 |
+
h0_start = 0
|
| 95 |
+
h0_end = h0_start + heads_per_group
|
| 96 |
+
cursor = 0
|
| 97 |
+
for c in range(3): # dx, dy, dz
|
| 98 |
+
n_pairs = splits_g0[c]
|
| 99 |
+
if n_pairs > 0:
|
| 100 |
+
f = _freqs(n_pairs) # [n_pairs]
|
| 101 |
+
comp_val = rays[..., c : c + 1] # [B, N_v, 1]
|
| 102 |
+
ang = comp_val * f # 广播 -> [B, N_v, n_pairs]
|
| 103 |
+
angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
|
| 104 |
+
cursor += n_pairs
|
| 105 |
+
# 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0
|
| 106 |
+
|
| 107 |
+
# ---- 第 1 组(4 头):HWT ----
|
| 108 |
+
h1_start = heads_per_group
|
| 109 |
+
h1_end = h1_start + heads_per_group
|
| 110 |
+
cursor = 0
|
| 111 |
+
for c in range(3): # h, w, t
|
| 112 |
+
n_pairs = splits_g1[c]
|
| 113 |
+
if n_pairs > 0:
|
| 114 |
+
f = _freqs(n_pairs)
|
| 115 |
+
comp_val = hwt_grid[..., c : c + 1]
|
| 116 |
+
ang = comp_val * f
|
| 117 |
+
angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
|
| 118 |
+
cursor += n_pairs
|
| 119 |
+
|
| 120 |
+
# ---- 第 2 组(4 头):零频段 ----
|
| 121 |
+
# 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零)
|
| 122 |
+
|
| 123 |
+
cos = torch.cos(angles)
|
| 124 |
+
sin = torch.sin(angles)
|
| 125 |
+
return cos, sin
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def apply_rope(
|
| 129 |
+
q: torch.Tensor,
|
| 130 |
+
k: torch.Tensor,
|
| 131 |
+
cos: torch.Tensor,
|
| 132 |
+
sin: torch.Tensor,
|
| 133 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 134 |
+
"""对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。
|
| 135 |
+
|
| 136 |
+
所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。
|
| 137 |
+
|
| 138 |
+
参数
|
| 139 |
+
----
|
| 140 |
+
q, k : Tensor, shape ``[B, H, N_v, head_dim]``
|
| 141 |
+
cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``
|
| 142 |
+
|
| 143 |
+
返回
|
| 144 |
+
----
|
| 145 |
+
旋转后的 q, k,形状不变。
|
| 146 |
+
"""
|
| 147 |
+
# 把 cos/sin 转成 [B, H, N_v, half]
|
| 148 |
+
cos_e = cos.permute(0, 2, 1, 3)
|
| 149 |
+
sin_e = sin.permute(0, 2, 1, 3)
|
| 150 |
+
|
| 151 |
+
# 把 head_dim 维度按 (even, odd) 拆开成 [..., half]
|
| 152 |
+
q_even = q[..., 0::2]
|
| 153 |
+
q_odd = q[..., 1::2]
|
| 154 |
+
k_even = k[..., 0::2]
|
| 155 |
+
k_odd = k[..., 1::2]
|
| 156 |
+
|
| 157 |
+
q_rot_even = q_even * cos_e - q_odd * sin_e
|
| 158 |
+
q_rot_odd = q_even * sin_e + q_odd * cos_e
|
| 159 |
+
k_rot_even = k_even * cos_e - k_odd * sin_e
|
| 160 |
+
k_rot_odd = k_even * sin_e + k_odd * cos_e
|
| 161 |
+
|
| 162 |
+
q_out = torch.empty_like(q)
|
| 163 |
+
k_out = torch.empty_like(k)
|
| 164 |
+
q_out[..., 0::2] = q_rot_even
|
| 165 |
+
q_out[..., 1::2] = q_rot_odd
|
| 166 |
+
k_out[..., 0::2] = k_rot_even
|
| 167 |
+
k_out[..., 1::2] = k_rot_odd
|
| 168 |
+
return q_out, k_out
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class RoPE3D(nn.Module):
|
| 172 |
+
"""3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。
|
| 173 |
+
|
| 174 |
+
使用方式:
|
| 175 |
+
rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32)
|
| 176 |
+
cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3]
|
| 177 |
+
q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
num_heads: int = 12,
|
| 183 |
+
head_dim: int = 64,
|
| 184 |
+
time_size: int = 4,
|
| 185 |
+
height_size: int = 12,
|
| 186 |
+
width_size: int = 32,
|
| 187 |
+
rope_theta: float = 10000.0,
|
| 188 |
+
) -> None:
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.num_heads = num_heads
|
| 191 |
+
self.head_dim = head_dim
|
| 192 |
+
self.rope_theta = rope_theta
|
| 193 |
+
self.T = time_size
|
| 194 |
+
self.H = height_size
|
| 195 |
+
self.W = width_size
|
| 196 |
+
|
| 197 |
+
# 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W
|
| 198 |
+
t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
|
| 199 |
+
h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
|
| 200 |
+
w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
|
| 201 |
+
# 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致)
|
| 202 |
+
T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
|
| 203 |
+
hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3]
|
| 204 |
+
self.register_buffer("hwt_grid", hwt, persistent=False)
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def num_visual_tokens(self) -> int:
|
| 208 |
+
return self.T * self.H * self.W
|
| 209 |
+
|
| 210 |
+
def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 211 |
+
"""根据每 token 的射线方向计算 cos/sin。
|
| 212 |
+
|
| 213 |
+
``rays`` shape: ``[B, N_v, 3]``。
|
| 214 |
+
"""
|
| 215 |
+
bsz = rays.shape[0]
|
| 216 |
+
hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3]
|
| 217 |
+
return build_rope_freqs(
|
| 218 |
+
rays=rays,
|
| 219 |
+
hwt_grid=hwt,
|
| 220 |
+
num_heads=self.num_heads,
|
| 221 |
+
head_dim=self.head_dim,
|
| 222 |
+
rope_theta=self.rope_theta,
|
| 223 |
+
dtype=rays.dtype,
|
| 224 |
+
)
|
src/wjad/modules/rays.py
CHANGED
|
@@ -1,182 +1,182 @@
|
|
| 1 |
-
"""f-theta 相机模型 + 射线计算。
|
| 2 |
-
|
| 3 |
-
依据 Cosmos-Drive-Dreams 数据集 README:
|
| 4 |
-
ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。
|
| 5 |
-
|
| 6 |
-
f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到
|
| 7 |
-
入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反
|
| 8 |
-
求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial
|
| 9 |
-
(theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微
|
| 10 |
-
的非旋转对称形变。
|
| 11 |
-
|
| 12 |
-
为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``,
|
| 13 |
-
即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式,
|
| 14 |
-
这里使用牛顿迭代反求。
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import torch.nn.functional as F
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class FThetaIntrinsic:
|
| 27 |
-
"""f-theta 内参(PyTorch 张量形式)。
|
| 28 |
-
|
| 29 |
-
所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
cx: torch.Tensor # ()
|
| 33 |
-
cy: torch.Tensor # ()
|
| 34 |
-
w: int
|
| 35 |
-
h: int
|
| 36 |
-
poly: torch.Tensor # (6,)
|
| 37 |
-
is_bw_poly: bool
|
| 38 |
-
linear_cde: torch.Tensor # (3,)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class FThetaCamera:
|
| 42 |
-
"""f-theta 相机:像素 -> 单位射线方向(相机坐标系)。"""
|
| 43 |
-
|
| 44 |
-
def __init__(self, intr: FThetaIntrinsic) -> None:
|
| 45 |
-
self.intr = intr
|
| 46 |
-
|
| 47 |
-
@staticmethod
|
| 48 |
-
def from_vector(vec: torch.Tensor) -> "FThetaCamera":
|
| 49 |
-
"""从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。
|
| 50 |
-
|
| 51 |
-
官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``,
|
| 52 |
-
与 ``unproject`` 里近似一致。
|
| 53 |
-
"""
|
| 54 |
-
v = vec.flatten().float()
|
| 55 |
-
n = int(v.numel())
|
| 56 |
-
if n < 10:
|
| 57 |
-
raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)")
|
| 58 |
-
cx = v[0]
|
| 59 |
-
cy = v[1]
|
| 60 |
-
w = int(v[2].item())
|
| 61 |
-
h = int(v[3].item())
|
| 62 |
-
poly = v[4:10].clone()
|
| 63 |
-
if n >= 11:
|
| 64 |
-
is_bw = bool(v[10].item() > 0.5)
|
| 65 |
-
else:
|
| 66 |
-
is_bw = True
|
| 67 |
-
if n >= 14:
|
| 68 |
-
linear_cde = v[11:14].clone()
|
| 69 |
-
else:
|
| 70 |
-
linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device)
|
| 71 |
-
return FThetaCamera(
|
| 72 |
-
FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde)
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
def _eval_poly(self, r: torch.Tensor) -> torch.Tensor:
|
| 76 |
-
"""用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。"""
|
| 77 |
-
c = self.intr.poly
|
| 78 |
-
out = torch.zeros_like(r)
|
| 79 |
-
for i in range(c.numel() - 1, -1, -1):
|
| 80 |
-
out = out * r + c[i]
|
| 81 |
-
return out
|
| 82 |
-
|
| 83 |
-
def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor:
|
| 84 |
-
"""poly 的导数。"""
|
| 85 |
-
c = self.intr.poly
|
| 86 |
-
n = c.numel()
|
| 87 |
-
out = torch.zeros_like(r)
|
| 88 |
-
for i in range(n - 1, 0, -1):
|
| 89 |
-
out = out * r + c[i] * float(i)
|
| 90 |
-
return out
|
| 91 |
-
|
| 92 |
-
def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor:
|
| 93 |
-
"""像素半径 -> 入射角 theta(弧度)。"""
|
| 94 |
-
if self.intr.is_bw_poly:
|
| 95 |
-
return self._eval_poly(r_pix)
|
| 96 |
-
# forward: r_pix = poly(theta) -> 牛顿迭代
|
| 97 |
-
theta = r_pix.clone()
|
| 98 |
-
for _ in range(8):
|
| 99 |
-
f = self._eval_poly(theta) - r_pix
|
| 100 |
-
df = self._eval_poly_grad(theta).clamp_min(1e-6)
|
| 101 |
-
theta = theta - f / df
|
| 102 |
-
return theta
|
| 103 |
-
|
| 104 |
-
def unproject(self, uv: torch.Tensor) -> torch.Tensor:
|
| 105 |
-
"""像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。
|
| 106 |
-
|
| 107 |
-
f-theta 反投影:
|
| 108 |
-
(du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射)
|
| 109 |
-
r_pix = ||(du, dv)||
|
| 110 |
-
theta = poly(r_pix) 或 inv_poly(r_pix)
|
| 111 |
-
phi = atan2(dv, du)
|
| 112 |
-
dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
|
| 113 |
-
"""
|
| 114 |
-
cx = self.intr.cx
|
| 115 |
-
cy = self.intr.cy
|
| 116 |
-
c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2]
|
| 117 |
-
|
| 118 |
-
u = uv[..., 0]
|
| 119 |
-
v = uv[..., 1]
|
| 120 |
-
# 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似)
|
| 121 |
-
du0 = u - cx
|
| 122 |
-
dv0 = v - cy
|
| 123 |
-
du = c * du0 + d * dv0
|
| 124 |
-
dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调
|
| 125 |
-
r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6)
|
| 126 |
-
theta = self.pixel_to_theta(r_pix)
|
| 127 |
-
|
| 128 |
-
sin_t = torch.sin(theta)
|
| 129 |
-
cos_t = torch.cos(theta)
|
| 130 |
-
cos_p = du / r_pix
|
| 131 |
-
sin_p = dv / r_pix
|
| 132 |
-
x = sin_t * cos_p
|
| 133 |
-
y = sin_t * sin_p
|
| 134 |
-
z = cos_t
|
| 135 |
-
dir_cam = torch.stack([x, y, z], dim=-1)
|
| 136 |
-
return F.normalize(dir_cam, dim=-1)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def compute_ego_rays(
|
| 140 |
-
intr_vec: torch.Tensor,
|
| 141 |
-
cam2vehicle: torch.Tensor,
|
| 142 |
-
height: int,
|
| 143 |
-
width: int,
|
| 144 |
-
grid_h: int,
|
| 145 |
-
grid_w: int,
|
| 146 |
-
device: torch.device,
|
| 147 |
-
dtype: torch.dtype = torch.float32,
|
| 148 |
-
) -> torch.Tensor:
|
| 149 |
-
"""对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。
|
| 150 |
-
|
| 151 |
-
参数
|
| 152 |
-
----
|
| 153 |
-
intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。
|
| 154 |
-
cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。
|
| 155 |
-
height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。
|
| 156 |
-
grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。
|
| 157 |
-
|
| 158 |
-
返回
|
| 159 |
-
----
|
| 160 |
-
rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。
|
| 161 |
-
"""
|
| 162 |
-
if intr_vec.dim() == 1:
|
| 163 |
-
intr_vec = intr_vec.unsqueeze(0)
|
| 164 |
-
if cam2vehicle.dim() == 2:
|
| 165 |
-
cam2vehicle = cam2vehicle.unsqueeze(0)
|
| 166 |
-
B = intr_vec.shape[0]
|
| 167 |
-
|
| 168 |
-
# 在像素中心采样
|
| 169 |
-
u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w)
|
| 170 |
-
v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h)
|
| 171 |
-
vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw]
|
| 172 |
-
uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2]
|
| 173 |
-
|
| 174 |
-
out = []
|
| 175 |
-
for b in range(B):
|
| 176 |
-
cam = FThetaCamera.from_vector(intr_vec[b].to(dtype))
|
| 177 |
-
dir_cam = cam.unproject(uv) # [gh, gw, 3]
|
| 178 |
-
# 旋到自车系:取 cam2vehicle 的 3x3 旋转部分
|
| 179 |
-
R = cam2vehicle[b, :3, :3].to(dtype)
|
| 180 |
-
dir_veh = dir_cam @ R.T # [gh, gw, 3]
|
| 181 |
-
out.append(F.normalize(dir_veh, dim=-1))
|
| 182 |
-
return torch.stack(out, dim=0)
|
|
|
|
| 1 |
+
"""f-theta 相机模型 + 射线计算。
|
| 2 |
+
|
| 3 |
+
依据 Cosmos-Drive-Dreams 数据集 README:
|
| 4 |
+
ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。
|
| 5 |
+
|
| 6 |
+
f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到
|
| 7 |
+
入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反
|
| 8 |
+
求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial
|
| 9 |
+
(theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微
|
| 10 |
+
的非旋转对称形变。
|
| 11 |
+
|
| 12 |
+
为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``,
|
| 13 |
+
即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式,
|
| 14 |
+
这里使用牛顿迭代反求。
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class FThetaIntrinsic:
|
| 27 |
+
"""f-theta 内参(PyTorch 张量形式)。
|
| 28 |
+
|
| 29 |
+
所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
cx: torch.Tensor # ()
|
| 33 |
+
cy: torch.Tensor # ()
|
| 34 |
+
w: int
|
| 35 |
+
h: int
|
| 36 |
+
poly: torch.Tensor # (6,)
|
| 37 |
+
is_bw_poly: bool
|
| 38 |
+
linear_cde: torch.Tensor # (3,)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FThetaCamera:
|
| 42 |
+
"""f-theta 相机:像素 -> 单位射线方向(相机坐标系)。"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, intr: FThetaIntrinsic) -> None:
|
| 45 |
+
self.intr = intr
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def from_vector(vec: torch.Tensor) -> "FThetaCamera":
|
| 49 |
+
"""从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。
|
| 50 |
+
|
| 51 |
+
官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``,
|
| 52 |
+
与 ``unproject`` 里近似一致。
|
| 53 |
+
"""
|
| 54 |
+
v = vec.flatten().float()
|
| 55 |
+
n = int(v.numel())
|
| 56 |
+
if n < 10:
|
| 57 |
+
raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)")
|
| 58 |
+
cx = v[0]
|
| 59 |
+
cy = v[1]
|
| 60 |
+
w = int(v[2].item())
|
| 61 |
+
h = int(v[3].item())
|
| 62 |
+
poly = v[4:10].clone()
|
| 63 |
+
if n >= 11:
|
| 64 |
+
is_bw = bool(v[10].item() > 0.5)
|
| 65 |
+
else:
|
| 66 |
+
is_bw = True
|
| 67 |
+
if n >= 14:
|
| 68 |
+
linear_cde = v[11:14].clone()
|
| 69 |
+
else:
|
| 70 |
+
linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device)
|
| 71 |
+
return FThetaCamera(
|
| 72 |
+
FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _eval_poly(self, r: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
"""用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。"""
|
| 77 |
+
c = self.intr.poly
|
| 78 |
+
out = torch.zeros_like(r)
|
| 79 |
+
for i in range(c.numel() - 1, -1, -1):
|
| 80 |
+
out = out * r + c[i]
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
"""poly 的导数。"""
|
| 85 |
+
c = self.intr.poly
|
| 86 |
+
n = c.numel()
|
| 87 |
+
out = torch.zeros_like(r)
|
| 88 |
+
for i in range(n - 1, 0, -1):
|
| 89 |
+
out = out * r + c[i] * float(i)
|
| 90 |
+
return out
|
| 91 |
+
|
| 92 |
+
def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""像素半径 -> 入射角 theta(弧度)。"""
|
| 94 |
+
if self.intr.is_bw_poly:
|
| 95 |
+
return self._eval_poly(r_pix)
|
| 96 |
+
# forward: r_pix = poly(theta) -> 牛顿迭代
|
| 97 |
+
theta = r_pix.clone()
|
| 98 |
+
for _ in range(8):
|
| 99 |
+
f = self._eval_poly(theta) - r_pix
|
| 100 |
+
df = self._eval_poly_grad(theta).clamp_min(1e-6)
|
| 101 |
+
theta = theta - f / df
|
| 102 |
+
return theta
|
| 103 |
+
|
| 104 |
+
def unproject(self, uv: torch.Tensor) -> torch.Tensor:
|
| 105 |
+
"""像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。
|
| 106 |
+
|
| 107 |
+
f-theta 反投影:
|
| 108 |
+
(du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射)
|
| 109 |
+
r_pix = ||(du, dv)||
|
| 110 |
+
theta = poly(r_pix) 或 inv_poly(r_pix)
|
| 111 |
+
phi = atan2(dv, du)
|
| 112 |
+
dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
|
| 113 |
+
"""
|
| 114 |
+
cx = self.intr.cx
|
| 115 |
+
cy = self.intr.cy
|
| 116 |
+
c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2]
|
| 117 |
+
|
| 118 |
+
u = uv[..., 0]
|
| 119 |
+
v = uv[..., 1]
|
| 120 |
+
# 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似)
|
| 121 |
+
du0 = u - cx
|
| 122 |
+
dv0 = v - cy
|
| 123 |
+
du = c * du0 + d * dv0
|
| 124 |
+
dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调
|
| 125 |
+
r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6)
|
| 126 |
+
theta = self.pixel_to_theta(r_pix)
|
| 127 |
+
|
| 128 |
+
sin_t = torch.sin(theta)
|
| 129 |
+
cos_t = torch.cos(theta)
|
| 130 |
+
cos_p = du / r_pix
|
| 131 |
+
sin_p = dv / r_pix
|
| 132 |
+
x = sin_t * cos_p
|
| 133 |
+
y = sin_t * sin_p
|
| 134 |
+
z = cos_t
|
| 135 |
+
dir_cam = torch.stack([x, y, z], dim=-1)
|
| 136 |
+
return F.normalize(dir_cam, dim=-1)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def compute_ego_rays(
|
| 140 |
+
intr_vec: torch.Tensor,
|
| 141 |
+
cam2vehicle: torch.Tensor,
|
| 142 |
+
height: int,
|
| 143 |
+
width: int,
|
| 144 |
+
grid_h: int,
|
| 145 |
+
grid_w: int,
|
| 146 |
+
device: torch.device,
|
| 147 |
+
dtype: torch.dtype = torch.float32,
|
| 148 |
+
) -> torch.Tensor:
|
| 149 |
+
"""对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。
|
| 150 |
+
|
| 151 |
+
参数
|
| 152 |
+
----
|
| 153 |
+
intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。
|
| 154 |
+
cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。
|
| 155 |
+
height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。
|
| 156 |
+
grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。
|
| 157 |
+
|
| 158 |
+
返回
|
| 159 |
+
----
|
| 160 |
+
rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。
|
| 161 |
+
"""
|
| 162 |
+
if intr_vec.dim() == 1:
|
| 163 |
+
intr_vec = intr_vec.unsqueeze(0)
|
| 164 |
+
if cam2vehicle.dim() == 2:
|
| 165 |
+
cam2vehicle = cam2vehicle.unsqueeze(0)
|
| 166 |
+
B = intr_vec.shape[0]
|
| 167 |
+
|
| 168 |
+
# 在像素中心采样
|
| 169 |
+
u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w)
|
| 170 |
+
v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h)
|
| 171 |
+
vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw]
|
| 172 |
+
uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2]
|
| 173 |
+
|
| 174 |
+
out = []
|
| 175 |
+
for b in range(B):
|
| 176 |
+
cam = FThetaCamera.from_vector(intr_vec[b].to(dtype))
|
| 177 |
+
dir_cam = cam.unproject(uv) # [gh, gw, 3]
|
| 178 |
+
# 旋到自车系:取 cam2vehicle 的 3x3 旋转部分
|
| 179 |
+
R = cam2vehicle[b, :3, :3].to(dtype)
|
| 180 |
+
dir_veh = dir_cam @ R.T # [gh, gw, 3]
|
| 181 |
+
out.append(F.normalize(dir_veh, dim=-1))
|
| 182 |
+
return torch.stack(out, dim=0)
|
src/wjad/modules/temporal_compress.py
CHANGED
|
@@ -1,34 +1,34 @@
|
|
| 1 |
-
"""2×2×2 时空压缩卷积。
|
| 2 |
-
|
| 3 |
-
将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
|
| 4 |
-
4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class TemporalCompress2x2x2(nn.Module):
|
| 14 |
-
"""``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。"""
|
| 15 |
-
|
| 16 |
-
def __init__(self, dim: int = 768) -> None:
|
| 17 |
-
super().__init__()
|
| 18 |
-
self.dim = dim
|
| 19 |
-
self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
|
| 20 |
-
self.norm = nn.LayerNorm(dim)
|
| 21 |
-
|
| 22 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
-
"""输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。
|
| 24 |
-
|
| 25 |
-
中间排布:
|
| 26 |
-
[B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
|
| 27 |
-
-> [B, T'*H'*W', D] -> LayerNorm
|
| 28 |
-
"""
|
| 29 |
-
b, t, h, w, d = x.shape
|
| 30 |
-
x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W]
|
| 31 |
-
y = self.conv(x_in)
|
| 32 |
-
bb, dd, t2, h2, w2 = y.shape
|
| 33 |
-
y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd)
|
| 34 |
-
return self.norm(y), (t2, h2, w2)
|
|
|
|
| 1 |
+
"""2×2×2 时空压缩卷积。
|
| 2 |
+
|
| 3 |
+
将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
|
| 4 |
+
4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TemporalCompress2x2x2(nn.Module):
|
| 14 |
+
"""``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, dim: int = 768) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.dim = dim
|
| 19 |
+
self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
|
| 20 |
+
self.norm = nn.LayerNorm(dim)
|
| 21 |
+
|
| 22 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
"""输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。
|
| 24 |
+
|
| 25 |
+
中间排布:
|
| 26 |
+
[B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
|
| 27 |
+
-> [B, T'*H'*W', D] -> LayerNorm
|
| 28 |
+
"""
|
| 29 |
+
b, t, h, w, d = x.shape
|
| 30 |
+
x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W]
|
| 31 |
+
y = self.conv(x_in)
|
| 32 |
+
bb, dd, t2, h2, w2 = y.shape
|
| 33 |
+
y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd)
|
| 34 |
+
return self.norm(y), (t2, h2, w2)
|