diff --git a/MEMORY.md b/MEMORY.md index b69da5dfd8073521d721f4870f72f6b768b58583..b9475653bdfe957d137e7fd940a72ae42e0e3124 100644 --- a/MEMORY.md +++ b/MEMORY.md @@ -1,63 +1,63 @@ -# 显存与内存估算(bf16 AMP + GradNorm + PCGrad) - -由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。 - -| 项目 | 数值 | -|---|---| -| 总参数 | **725.62 M** | -| 可训练 (Stage1, DINOv3 冻结) | 639.96 M | -| 可训练 (Stage2, DINOv3 解冻) | 725.62 M | -| 序列长度(拼接后) | 2848 | - -## 显存(含 15% 余量) - -| Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 | -|---:|---:|---|---| -| 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` | -| 2 | ~18 GB | L4 24GB | `l4x1` | -| 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` | -| **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** | -| 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` | - -显存细分(BS=8 Stage2): -- 权重 (bf16): 1.35 GB -- 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB -- 主激活 (bf16, 18 + 6 层): ~12.6 GB -- PCGrad retain_graph 开销: ~6.3 GB -- 缓冲 / cuDNN workspace / 碎片: ~2 GB - -如显存不足: -- 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB) -- BS=4 + `grad_accum_steps=2` 等价 BS=8 训练 -- 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量 - -## 主机内存 / 磁盘 - -| 项目 | 数值(BS=8) | -|---|---| -| 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) | -| 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB | -| 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB | -| 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB | - -## 设备选择建议 - -- **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。 -- **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。 -- **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。 - -## 复现命令 - -```bash -# 升级依赖到最新(写入 requirements.lock.txt) -python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124 - -# 估算 -python scripts/estimate_memory.py - -# Sandbox 推送 -python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large - -# Jobs 全量 -python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1 -``` +# 显存与内存估算(bf16 AMP + GradNorm + PCGrad) + +由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。 + +| 项目 | 数值 | +|---|---| +| 总参数 | **725.62 M** | +| 可训练 (Stage1, DINOv3 冻结) | 639.96 M | +| 可训练 (Stage2, DINOv3 解冻) | 725.62 M | +| 序列长度(拼接后) | 2848 | + +## 显存(含 15% 余量) + +| Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 | +|---:|---:|---|---| +| 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` | +| 2 | ~18 GB | L4 24GB | `l4x1` | +| 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` | +| **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** | +| 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` | + +显存细分(BS=8 Stage2): +- 权重 (bf16): 1.35 GB +- 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB +- 主激活 (bf16, 18 + 6 层): ~12.6 GB +- PCGrad retain_graph 开销: ~6.3 GB +- 缓冲 / cuDNN workspace / 碎片: ~2 GB + +如显存不足: +- 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB) +- BS=4 + `grad_accum_steps=2` 等价 BS=8 训练 +- 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量 + +## 主机内存 / 磁盘 + +| 项目 | 数值(BS=8) | +|---|---| +| 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) | +| 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB | +| 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB | +| 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB | + +## 设备选择建议 + +- **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。 +- **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。 +- **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。 + +## 复现命令 + +```bash +# 升级依赖到最新(写入 requirements.lock.txt) +python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124 + +# 估算 +python scripts/estimate_memory.py + +# Sandbox 推送 +python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large + +# Jobs 全量 +python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1 +``` diff --git a/README.md b/README.md index 2b6abb1ec5ef7ec27b5670a19ce0d735d8fb4f96..7eb0a097a167b3f396f17d5eb35380d6310875b3 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,63 @@ ---- -title: WJAD Sandbox -emoji: 🚗 -colorFrom: blue -colorTo: indigo -sdk: docker -app_port: 7860 -pinned: false ---- - -# WJAD - 端到端自动驾驶模型 - -基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。 - -## 架构概览 - -- **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。 -- **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。 -- **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。 -- **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。 -- **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。 -- **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。 -- **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。 -- **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。 -- **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。 - -## 三步训练路径 - -```bash -# 1. 本地跑通(纯随机张量) -python -m scripts.smoke_test - -# 2. HF Sandbox 微小训练 -python -m scripts.push_to_sandbox - -# 3. HF Jobs 全量训练 -python -m scripts.push_to_jobs -``` - -## 数据准备 - -```bash -python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap -``` - -## 项目结构 - -``` -src/wjad/ -├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/... -├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩 -├── calibration/ # 在线校准网络 -├── backbone/ # 18 层主干 -├── heads/ # 检测+预测头、控制头 -├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广 -├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则 -├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度 -└── model.py # 顶层 E2EAVModel -``` - -## License - -代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。 +--- +title: WJAD Sandbox +emoji: 🚗 +colorFrom: blue +colorTo: indigo +sdk: docker +app_port: 7860 +pinned: false +--- + +# WJAD - 端到端自动驾驶模型 + +基于 [Design.md](Design.md) 实现的端到端自动驾驶模型,用于 NVIDIA Cosmos-Drive-Dreams 数据集。 + +## 架构概览 + +- **视觉编码器**:本地 DINOv3 ViT-B/16(`dinov3-vitb16-pretrain-lvd1689m`),SDPA 注意力。 +- **时空压缩**:2×2×2 Conv3D,将 8 帧 × 24 × 64 patch tokens 压缩为 1536 个视觉 token。 +- **在线校准**:dim=256,6 层 (1 GateCrossAttn + 2 GateSelfAttn) × 2,跨注意力 K/V 来自 DINOv3 patch;输入与残差均在 symlog 空间,输出 SE3 + 内外参修正量。 +- **主干**:18 层 GateSelfAttention(前 9 Dense + 后 9 MoE,每层独立 7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3),dim=768,12 头 SDPA + PreNorm + SwiGLU。 +- **位置编码**:3D RoPE 仅作用于视觉 token 的 Q/K——头 0-3 编码自车系单位射线,头 4-7 编码 H/W/T,头 8-11 零频段(identity,统一代码路径)。其余 token(ego/det/ctrl/extra)使用一一对应的可学习 PE。 +- **统一检测+预测头**:1024 token 同时输出 `cls + is_dynamic + box3d(μ,logσ) + 未来 24 帧轨迹(μ,logσ)`。 +- **控制头**:24 token 输出自车未来轨迹与全局控制(均 NLL μ/logσ)。 +- **多任务训练**:GradNorm 自适应任务权重 + Stage2 启用 PCGrad 正交化梯度冲突。 +- **训练阶段**:Stage1 Dense + 路由锐化 + 中期运动学/内外参扰动;Stage2 切 Top-3 + DINOv3 低 LR 微调。 + +## 三步训练路径 + +```bash +# 1. 本地跑通(纯随机张量) +python -m scripts.smoke_test + +# 2. HF Sandbox 微小训练 +python -m scripts.push_to_sandbox + +# 3. HF Jobs 全量训练 +python -m scripts.push_to_jobs +``` + +## 数据准备 + +```bash +python -m scripts.download_data --odir ./data/cosmos --file_types synthetic,lidar,hdmap +``` + +## 项目结构 + +``` +src/wjad/ +├── modules/ # 公用算子:FFN/门控注意力/MoE/RoPE/可学习PE/symlog/... +├── encoders/ # DINOv3 包装 + 2x2x2 时空压缩 +├── calibration/ # 在线校准网络 +├── backbone/ # 18 层主干 +├── heads/ # 检测+预测头、控制头 +├── data/ # Cosmos-Drive-Dreams 加载器、f-theta、增广 +├── losses/ # NLL/检测/轨迹/控制/MoE/校准正则 +├── train/ # 多任务(GradNorm+PCGrad)、Trainer、调度 +└── model.py # 顶层 E2EAVModel +``` + +## License + +代码遵循仓库根目录指定的开源协议。DINOv3 权重遵循 Meta DINOv3 License;Cosmos-Drive-Dreams 数据集遵循 CC BY 4.0。 diff --git a/configs/default.yaml b/configs/default.yaml index 6d2f1cfaf8670308461ea42ebc6e3e772057a3e6..f8370572da8d5715446a7f8d8f29ac21eb5fa31a 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,192 +1,192 @@ -# 端到端自动驾驶模型默认配置(与 Design.md 对齐) - -# === 全局 === -seed: 42 -device: cuda -mixed_precision: bf16 # H100 推荐 bf16 -gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速 - -# === 输入 === -input: - image_height: 384 # 已裁去上半天空的高度 - image_width: 1024 - num_history_frames: 8 # t-7..t(含当前) - num_future_frames: 24 # 预测窗口 - camera_name: camera_front_wide_120fov # 当前唯一开放的视角 - -# === 视觉编码器(DINOv3)=== -dinov3: - pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m" - hidden_size: 768 - patch_size: 16 - num_register_tokens: 4 - attn_implementation: sdpa - freeze_in_stage1: true # Stage1 冻结 DINOv3 - finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率 - -# === 时空压缩 === -temporal_compress: - kernel: [2, 2, 2] - stride: [2, 2, 2] - -# === 主干 === -backbone: - hidden_size: 768 - num_heads: 12 # head_dim = 64 - ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D) - num_dense_layers: 9 - num_moe_layers: 9 # 共 18 层 - dropout: 0.0 - prenorm: true - -# === MoE === -moe: - num_routed_experts: 7 - num_shared_experts: 1 - topk: 3 # Stage2 激活专家数 - router_temperature_init: 0.5 - router_temperature_final: 1.0 - load_balance_weight: 0.01 - boundary_weight: 0.001 # mean(logits^2) 防越界 - -# === Token 数量 === -tokens: - num_detection: 1024 - num_control: 24 - num_ego: 8 - num_extra: 256 - -# === 在线校准 === -calibration: - intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14 - hidden_size: 256 - num_query_tokens: 256 - num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self - num_blocks: 2 # 总计 6 层 - num_heads: 8 # head_dim = 32 - residual_range: 0.1 # Tanh * range - init_zero_output: true - -# === 检测+未来预测头 === -det_traj_head: - num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap) - box_dim: 7 # x,y,z,l,w,h,yaw - traj_horizon: 24 # 未来 24 帧 - traj_dim: 3 # dx,dy,dyaw - hidden_size: 384 - log_sigma_clamp: [-7.0, 7.0] - -# === 控制头 === -control_head: - num_traj_tokens: 12 # 解码 24 帧 ego 轨迹 - num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ - ego_traj_dim: 3 # x,y,yaw - action_dim: 3 - hidden_size: 384 - log_sigma_clamp: [-7.0, 7.0] - -# === 检测目标筛选 === -detection: - max_distance_m: 48.0 - occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米) - min_box_pixels: 8 - dynamic_classes: - - Automobile - - Heavy_truck - - Bus - - Train_or_tram_car - - Trolley_bus - - Other_vehicle - - Trailer - - Person - - Stroller - - Rider - - Animal - - Protruding_object - -# === 数据 === -data: - root: "./data/cosmos" - hdmap_subdir: "rds_hq" - synthetic_subdir: "cosmos_synthetic/single_view" - use_synthetic: true - use_real: false - weather: - - Sunny - - Morning - - Golden_hour - - Night - - Rainy - - Snowy - - Foggy - num_workers: 4 - pin_memory: true - prefetch_factor: 2 - augmentation: - gaussian_noise_std: 0.01 - color_jitter: 0.1 - perturb_translation_std_m: 0.1 # Stage1 中期开启 - perturb_rotation_std_deg: 0.5 - perturb_intrinsic_std: 0.005 - perturb_extrinsic_std: 0.005 - -# === 损失权重(GradNorm 自适应的 1-6 任务初值)=== -loss: - cls_weight: 1.0 - box_weight: 1.0 - isdyn_weight: 1.0 - traj_obj_weight: 1.0 - traj_ego_weight: 1.0 - ctrl_weight: 1.0 - moe_weight: 1.0 # 固定权重正则 - calib_weight: 0.1 # 固定权重正则 - giou_weight: 0.5 # 在 box 内组合 - focal_alpha: 0.25 - focal_gamma: 2.0 - matcher_cls_cost: 2.0 - matcher_l1_cost: 5.0 - matcher_giou_cost: 2.0 - -# === 多任务训练 === -# PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用: -# 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl) -# 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。 -multitask: - enable_gradnorm: true - enable_pcgrad: true - gradnorm_alpha: 1.5 - gradnorm_lr: 0.025 - pcgrad_shuffle: true - -# === 训练 === -train: - batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps - grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps - ckpt_dir: outputs/checkpoints - total_steps: 100000 - warmup_steps: 1000 - base_lr: 2.0e-4 - min_lr: 1.0e-6 - weight_decay: 0.05 - optimizer: adamw - betas: [0.9, 0.95] - grad_clip: 1.0 - log_interval: 20 - ckpt_interval: 1000 - eval_interval: 5000 - stage1_steps: 60000 # Stage1 步数 - stage1_perturb_start: 20000 # 中期开始扰动 - grad_monitor_threshold: 1.0e-7 - param_groups: - dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供 - backbone_lr_mult: 1.0 - calibration_lr_mult: 0.1 - head_lr_mult: 1.0 - gate_lr_mult: 0.1 # 门控参数低 LR - -# === 部署 === -deploy: - hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库 - push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt - hub_ckpt_prefix: checkpoints # Hub 上子目录 - hf_sandbox_space: "fuzirui/wjad-sandbox" +# 端到端自动驾驶模型默认配置(与 Design.md 对齐) + +# === 全局 === +seed: 42 +device: cuda +mixed_precision: bf16 # H100 推荐 bf16 +gradient_checkpointing: true # A10G/L4/A100-40G 上需要打开;H100-80G 可关闭以加速 + +# === 输入 === +input: + image_height: 384 # 已裁去上半天空的高度 + image_width: 1024 + num_history_frames: 8 # t-7..t(含当前) + num_future_frames: 24 # 预测窗口 + camera_name: camera_front_wide_120fov # 当前唯一开放的视角 + +# === 视觉编码器(DINOv3)=== +dinov3: + pretrained_path: "./dinov3-vitb16-pretrain-lvd1689m" + hidden_size: 768 + patch_size: 16 + num_register_tokens: 4 + attn_implementation: sdpa + freeze_in_stage1: true # Stage1 冻结 DINOv3 + finetune_lr_ratio: 0.01 # Stage2 解冻后相对主干 LR 的倍率 + +# === 时空压缩 === +temporal_compress: + kernel: [2, 2, 2] + stride: [2, 2, 2] + +# === 主干 === +backbone: + hidden_size: 768 + num_heads: 12 # head_dim = 64 + ffn_mult: 4 # SwiGLU 扩展倍数(D->4D->2D->D) + num_dense_layers: 9 + num_moe_layers: 9 # 共 18 层 + dropout: 0.0 + prenorm: true + +# === MoE === +moe: + num_routed_experts: 7 + num_shared_experts: 1 + topk: 3 # Stage2 激活专家数 + router_temperature_init: 0.5 + router_temperature_final: 1.0 + load_balance_weight: 0.01 + boundary_weight: 0.001 # mean(logits^2) 防越界 + +# === Token 数量 === +tokens: + num_detection: 1024 + num_control: 24 + num_ego: 8 + num_extra: 256 + +# === 在线校准 === +calibration: + intr_vec_dim: 11 # Cosmos npy 常见 11 维;完整 14 维数据则改为 14 + hidden_size: 256 + num_query_tokens: 256 + num_self_attn_per_block: 2 # 每个 block: 1 cross + 2 self + num_blocks: 2 # 总计 6 层 + num_heads: 8 # head_dim = 32 + residual_range: 0.1 # Tanh * range + init_zero_output: true + +# === 检测+未来预测头 === +det_traj_head: + num_classes: 22 # 1 bg + 12 dynamic + 9 structured(HDMap) + box_dim: 7 # x,y,z,l,w,h,yaw + traj_horizon: 24 # 未来 24 帧 + traj_dim: 3 # dx,dy,dyaw + hidden_size: 384 + log_sigma_clamp: [-7.0, 7.0] + +# === 控制头 === +control_head: + num_traj_tokens: 12 # 解码 24 帧 ego 轨迹 + num_action_tokens: 12 # 1 个解码 (steer,throttle,brake) μ/logσ + ego_traj_dim: 3 # x,y,yaw + action_dim: 3 + hidden_size: 384 + log_sigma_clamp: [-7.0, 7.0] + +# === 检测目标筛选 === +detection: + max_distance_m: 48.0 + occlusion_depth_tolerance: 0.5 # LIDAR 深度容差(米) + min_box_pixels: 8 + dynamic_classes: + - Automobile + - Heavy_truck + - Bus + - Train_or_tram_car + - Trolley_bus + - Other_vehicle + - Trailer + - Person + - Stroller + - Rider + - Animal + - Protruding_object + +# === 数据 === +data: + root: "./data/cosmos" + hdmap_subdir: "rds_hq" + synthetic_subdir: "cosmos_synthetic/single_view" + use_synthetic: true + use_real: false + weather: + - Sunny + - Morning + - Golden_hour + - Night + - Rainy + - Snowy + - Foggy + num_workers: 4 + pin_memory: true + prefetch_factor: 2 + augmentation: + gaussian_noise_std: 0.01 + color_jitter: 0.1 + perturb_translation_std_m: 0.1 # Stage1 中期开启 + perturb_rotation_std_deg: 0.5 + perturb_intrinsic_std: 0.005 + perturb_extrinsic_std: 0.005 + +# === 损失权重(GradNorm 自适应的 1-6 任务初值)=== +loss: + cls_weight: 1.0 + box_weight: 1.0 + isdyn_weight: 1.0 + traj_obj_weight: 1.0 + traj_ego_weight: 1.0 + ctrl_weight: 1.0 + moe_weight: 1.0 # 固定权重正则 + calib_weight: 0.1 # 固定权重正则 + giou_weight: 0.5 # 在 box 内组合 + focal_alpha: 0.25 + focal_gamma: 2.0 + matcher_cls_cost: 2.0 + matcher_l1_cost: 5.0 + matcher_giou_cost: 2.0 + +# === 多任务训练 === +# PCGrad 与 GradNorm 在 Stage1 / Stage2 全程启用: +# 两阶段的 6 项主任务(cls / box / isdyn / traj_obj / traj_ego / ctrl) +# 都存在尺度差异与梯度方向冲突,PCGrad 不应延迟到 Stage2。 +multitask: + enable_gradnorm: true + enable_pcgrad: true + gradnorm_alpha: 1.5 + gradnorm_lr: 0.025 + pcgrad_shuffle: true + +# === 训练 === +train: + batch_size: 12 # A10G-Large 起步;OOM 时改为 8/6 并视情况增大 grad_accum_steps + grad_accum_steps: 1 # 有效 batch = batch_size * grad_accum_steps + ckpt_dir: outputs/checkpoints + total_steps: 100000 + warmup_steps: 1000 + base_lr: 2.0e-4 + min_lr: 1.0e-6 + weight_decay: 0.05 + optimizer: adamw + betas: [0.9, 0.95] + grad_clip: 1.0 + log_interval: 20 + ckpt_interval: 1000 + eval_interval: 5000 + stage1_steps: 60000 # Stage1 步数 + stage1_perturb_start: 20000 # 中期开始扰动 + grad_monitor_threshold: 1.0e-7 + param_groups: + dinov3_lr_mult: 0.0 # Stage1=0, Stage2 由 finetune_lr_ratio 提供 + backbone_lr_mult: 1.0 + calibration_lr_mult: 0.1 + head_lr_mult: 1.0 + gate_lr_mult: 0.1 # 门控参数低 LR + +# === 部署 === +deploy: + hf_repo: "fuzirui/WJAD" # 训练产生的 checkpoint 上传到此 model 仓库 + push_checkpoints: true # 每 ckpt_interval 步上传 step_*.pt + latest.pt + hub_ckpt_prefix: checkpoints # Hub 上子目录 + hf_sandbox_space: "fuzirui/wjad-sandbox" diff --git a/pyproject.toml b/pyproject.toml index f5498ac51c5ca5bfcfdbf36bb02ef1db9c165ec6..edc7001414b194c9bfdfcec34e854be28f3ef1e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,40 +1,40 @@ -[build-system] -requires = ["setuptools>=68", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "wjad" -version = "0.1.0" -description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration." -requires-python = ">=3.10" -dependencies = [ - "torch>=2.4", - "transformers>=4.56", - "safetensors>=0.4", - "numpy>=1.24", - "opencv-python-headless>=4.8", - "einops>=0.7", - "scipy>=1.11", - "pyyaml>=6.0", - "tqdm>=4.66", - "huggingface_hub>=0.24", - "pillow>=10.0", - "av>=12.0", -] - -[project.optional-dependencies] -dev = [ - "pytest>=7", - "pytest-cov>=4", - "ruff>=0.5", -] - -[tool.setuptools] -package-dir = {"" = "src"} - -[tool.setuptools.packages.find] -where = ["src"] - -[tool.ruff] -line-length = 120 -target-version = "py310" +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "wjad" +version = "0.1.0" +description = "End-to-end autonomous driving model with DINOv3, GateSelfAttention backbone, MoE, online calibration." +requires-python = ">=3.10" +dependencies = [ + "torch>=2.4", + "transformers>=4.56", + "safetensors>=0.4", + "numpy>=1.24", + "opencv-python-headless>=4.8", + "einops>=0.7", + "scipy>=1.11", + "pyyaml>=6.0", + "tqdm>=4.66", + "huggingface_hub>=0.24", + "pillow>=10.0", + "av>=12.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7", + "pytest-cov>=4", + "ruff>=0.5", +] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff] +line-length = 120 +target-version = "py310" diff --git a/scripts/download_data.py b/scripts/download_data.py index 11c5c4f6f5a307d03bc62204861e5e693c4cb57b..ea4d4c6f4b7c72b9988140a3ca2a3a689d526417 100644 --- a/scripts/download_data.py +++ b/scripts/download_data.py @@ -1,76 +1,76 @@ -"""下载 NVIDIA Cosmos-Drive-Dreams 数据集。 - -直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。 - -用法: - # 完整下载(synthetic + lidar + hdmap),约 3TB - python scripts/download_data.py --odir ./data/cosmos --workers 8 - - # 仅烟囱:限制 clip 数量(约几 GB,取决于 N) - python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2 -""" - -from __future__ import annotations - -import argparse -import subprocess -import sys -import urllib.request -from pathlib import Path - - -NV_DOWNLOAD_URL = ( - "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py" -) - - -def _ensure_official_script(local_path: Path) -> None: - if local_path.exists(): - return - print(f"[download_data] 下载 NVIDIA download.py -> {local_path}") - local_path.parent.mkdir(parents=True, exist_ok=True) - with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f: - f.write(resp.read()) - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--odir", required=True, help="数据输出目录") - parser.add_argument( - "--file_types", - default="synthetic,lidar,hdmap", - help="数据类型逗号分隔列表", - ) - parser.add_argument("--workers", type=int, default=4) - parser.add_argument( - "--limit", - type=int, - default=None, - metavar="N", - help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)", - ) - parser.add_argument("--clean_cache", action="store_true") - args = parser.parse_args() - - odir = Path(args.odir) - nv_script = odir / ".nvidia_download.py" - _ensure_official_script(nv_script) - - cmd = [ - sys.executable, - str(nv_script), - "--odir", str(odir), - "--file_types", args.file_types, - "--workers", str(args.workers), - ] - if args.limit is not None: - cmd.extend(["--limit", str(args.limit)]) - if args.clean_cache: - cmd.append("--clean_cache") - print(f"[download_data] $ {' '.join(cmd)}") - rc = subprocess.call(cmd) - sys.exit(rc) - - -if __name__ == "__main__": - main() +"""下载 NVIDIA Cosmos-Drive-Dreams 数据集。 + +直接调用 NVIDIA 官方 download.py,并支持仅下载几个 clip 用于 sandbox 验证。 + +用法: + # 完整下载(synthetic + lidar + hdmap),约 3TB + python scripts/download_data.py --odir ./data/cosmos --workers 8 + + # 仅烟囱:限制 clip 数量(约几 GB,取决于 N) + python scripts/download_data.py --odir ./data/cosmos --file_types synthetic,lidar,hdmap --limit 2 +""" + +from __future__ import annotations + +import argparse +import subprocess +import sys +import urllib.request +from pathlib import Path + + +NV_DOWNLOAD_URL = ( + "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py" +) + + +def _ensure_official_script(local_path: Path) -> None: + if local_path.exists(): + return + print(f"[download_data] 下载 NVIDIA download.py -> {local_path}") + local_path.parent.mkdir(parents=True, exist_ok=True) + with urllib.request.urlopen(NV_DOWNLOAD_URL) as resp, open(local_path, "wb") as f: + f.write(resp.read()) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--odir", required=True, help="数据输出目录") + parser.add_argument( + "--file_types", + default="synthetic,lidar,hdmap", + help="数据类型逗号分隔列表", + ) + parser.add_argument("--workers", type=int, default=4) + parser.add_argument( + "--limit", + type=int, + default=None, + metavar="N", + help="只拉取前 N 个 clip(传给 NVIDIA download.py,省磁盘)", + ) + parser.add_argument("--clean_cache", action="store_true") + args = parser.parse_args() + + odir = Path(args.odir) + nv_script = odir / ".nvidia_download.py" + _ensure_official_script(nv_script) + + cmd = [ + sys.executable, + str(nv_script), + "--odir", str(odir), + "--file_types", args.file_types, + "--workers", str(args.workers), + ] + if args.limit is not None: + cmd.extend(["--limit", str(args.limit)]) + if args.clean_cache: + cmd.append("--clean_cache") + print(f"[download_data] $ {' '.join(cmd)}") + rc = subprocess.call(cmd) + sys.exit(rc) + + +if __name__ == "__main__": + main() diff --git a/scripts/estimate_memory.py b/scripts/estimate_memory.py index 700fe4d986436cef904fde1ef3a84db4b916fffb..6baab3b2bff146563a5da21b6de5446ae15f191a 100644 --- a/scripts/estimate_memory.py +++ b/scripts/estimate_memory.py @@ -1,203 +1,203 @@ -"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。 - -输出 - - 各模块参数数量 - - 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲 - - 推荐设备(HF Sandbox / Jobs) - - 主机内存与磁盘开销 - -公式说明(粗略上界) - - 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p - - AdamW 一阶/二阶矩 (fp32): 8 B/p - - 梯度 (fp32): 4 B/p - - bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p - - DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略) - - 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家 - SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只 - 需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。 - - PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph, - 每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算 - (GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。 -""" - -from __future__ import annotations - -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(ROOT / "src")) - -from dataclasses import dataclass - -from wjad.model import E2EAVModel - - -@dataclass -class MemoryReport: - bs: int - seq_len: int - dim: int - layers: int - params_total: int - params_trainable_stage1: int - params_trainable_stage2: int - weights_gb_stage1: float - weights_gb_stage2: float - optim_gb_stage1: float - optim_gb_stage2: float - activations_gb: float - pcgrad_overhead_gb: float - total_stage1_gb: float - total_stage2_gb: float - host_ram_gb: float - disk_gb: float - - -def count_params(model) -> tuple[int, dict[str, int]]: - total = 0 - by_module: dict[str, int] = {} - for name, child in model.named_children(): - n = sum(p.numel() for p in child.parameters()) - by_module[name] = n - total += n - return total, by_module - - -def estimate(bs: int = 8) -> MemoryReport: - model = E2EAVModel( - dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), - # 完整规模 - backbone_dim=768, - num_heads=12, - num_dense_layers=9, - num_moe_layers=9, - num_routed_experts=7, - num_shared_experts=1, - topk_experts=3, - ffn_mult=4, - num_history_frames=8, - num_detection_tokens=1024, - num_control_tokens=24, - num_ego_tokens=8, - num_extra_tokens=256, - image_h=384, - image_w=1024, - patch_size=16, - num_classes=22, - traj_horizon=24, - freeze_dinov3=True, - ) - - total, by_module = count_params(model) - dinov3_n = by_module.get("dinov3", 0) - trainable_stage1 = total - dinov3_n - trainable_stage2 = total - - # 序列长度(拼接后总 token 数 + 上下文) - n_visual = (8 // 2) * (24 // 2) * (64 // 2) - seq_len = n_visual + 8 + 1024 + 24 + 256 - - # === 显存 === - # 单位:GB(除以 1024**3) - GB = 1024 ** 3 - weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16 - weights_stage2 = (total * 2) / GB - optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v - optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB - - # 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠) - base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6 - # MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家) - moe_act = bs * seq_len * 3072 * 2 * 8 * 9 - # DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值 - dino_act = 2.0 * GB - activations_gb = (base_act + moe_act + dino_act) / GB - - # PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会 - # 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算 - pcgrad_overhead_gb = 0.5 * activations_gb - - total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0 - total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0 - - # === 主机 RAM === - # DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析 - host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2 - - # === 磁盘 === - # 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部) - disk = 50.0 - - return MemoryReport( - bs=bs, - seq_len=seq_len, - dim=768, - layers=18, - params_total=total, - params_trainable_stage1=trainable_stage1, - params_trainable_stage2=trainable_stage2, - weights_gb_stage1=weights_stage1, - weights_gb_stage2=weights_stage2, - optim_gb_stage1=optim_stage1, - optim_gb_stage2=optim_stage2, - activations_gb=activations_gb, - pcgrad_overhead_gb=pcgrad_overhead_gb, - total_stage1_gb=total_stage1, - total_stage2_gb=total_stage2, - host_ram_gb=host_ram, - disk_gb=disk, - ) - - -def recommend_device(stage_max_gb: float) -> tuple[str, str]: - """根据 Stage2 峰值显存推荐 GPU。""" - margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace) - need = stage_max_gb * margin - candidates = [ - ("T4 16GB", 16), - ("L4 24GB", 24), - ("A10G 24GB", 24), - ("A10G Large 48GB", 48), - ("A100 40GB", 40), - ("L40S 48GB", 48), - ("A100 80GB", 80), - ("H100 80GB", 80), - ] - fit = [c for c in candidates if c[1] >= need] - if not fit: - return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)" - return fit[0][0], f"需要 ≥{need:.1f} GB" - - -def main() -> None: - print("=" * 72) - print(" WJAD 训练显存/内存估算 (bf16 AMP)") - print("=" * 72) - for bs in (1, 2, 4, 8, 16): - r = estimate(bs) - print(f"\n--- BS = {bs} ---") - print(f" 总参数 : {r.params_total / 1e6:8.2f} M") - print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M") - print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M") - print(f" 序列长度 : {r.seq_len}") - print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB") - print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB") - print(f" 激活 : {r.activations_gb:6.2f} GB") - print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB") - print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB") - print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值") - gpu, note = recommend_device(r.total_stage2_gb) - print(f" 推荐 GPU : {gpu} ({note})") - print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB") - print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB") - - print() - print("说明:") - print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。") - print(" - 开 ``gradient_checkpointing`` 可把激活降至约 1/3,BS 可成倍提升。") - print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。") - - -if __name__ == "__main__": - main() +"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。 + +输出 + - 各模块参数数量 + - 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲 + - 推荐设备(HF Sandbox / Jobs) + - 主机内存与磁盘开销 + +公式说明(粗略上界) + - 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p + - AdamW 一阶/二阶矩 (fp32): 8 B/p + - 梯度 (fp32): 4 B/p + - bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p + - DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略) + - 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家 + SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只 + 需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。 + - PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph, + 每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算 + (GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。 +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "src")) + +from dataclasses import dataclass + +from wjad.model import E2EAVModel + + +@dataclass +class MemoryReport: + bs: int + seq_len: int + dim: int + layers: int + params_total: int + params_trainable_stage1: int + params_trainable_stage2: int + weights_gb_stage1: float + weights_gb_stage2: float + optim_gb_stage1: float + optim_gb_stage2: float + activations_gb: float + pcgrad_overhead_gb: float + total_stage1_gb: float + total_stage2_gb: float + host_ram_gb: float + disk_gb: float + + +def count_params(model) -> tuple[int, dict[str, int]]: + total = 0 + by_module: dict[str, int] = {} + for name, child in model.named_children(): + n = sum(p.numel() for p in child.parameters()) + by_module[name] = n + total += n + return total, by_module + + +def estimate(bs: int = 8) -> MemoryReport: + model = E2EAVModel( + dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), + # 完整规模 + backbone_dim=768, + num_heads=12, + num_dense_layers=9, + num_moe_layers=9, + num_routed_experts=7, + num_shared_experts=1, + topk_experts=3, + ffn_mult=4, + num_history_frames=8, + num_detection_tokens=1024, + num_control_tokens=24, + num_ego_tokens=8, + num_extra_tokens=256, + image_h=384, + image_w=1024, + patch_size=16, + num_classes=22, + traj_horizon=24, + freeze_dinov3=True, + ) + + total, by_module = count_params(model) + dinov3_n = by_module.get("dinov3", 0) + trainable_stage1 = total - dinov3_n + trainable_stage2 = total + + # 序列长度(拼接后总 token 数 + 上下文) + n_visual = (8 // 2) * (24 // 2) * (64 // 2) + seq_len = n_visual + 8 + 1024 + 24 + 256 + + # === 显存 === + # 单位:GB(除以 1024**3) + GB = 1024 ** 3 + weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB # 全部 bf16 + weights_stage2 = (total * 2) / GB + optim_stage1 = (trainable_stage1 * (4 + 4 + 4)) / GB # master + m + v + optim_stage2 = (trainable_stage2 * (4 + 4 + 4)) / GB + + # 激活:粗略 = bs * seq_len * dim * 2 * (num_layers + 1) * 1.5 (含 attn/FFN 重叠) + base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5 # 主干 18 + 校准 6 + # MoE FFN 中间 (4D = 3072) 的临时项:每 MoE 层 ≈ bs * seq_len * 3072 * 2 * 8(8 专家) + moe_act = bs * seq_len * 3072 * 2 * 8 * 9 + # DINOv3 冻结:no_grad,前向激活在 forward 后立即释放,估 2 GB 峰值 + dino_act = 2.0 * GB + activations_gb = (base_act + moe_act + dino_act) / GB + + # PCGrad 开销(共享参数上 N 次 autograd.grad):retain_graph 阶段会 + # 阻止激活释放,最坏接近 1.5x;这里按 +0.5x 估算 + pcgrad_overhead_gb = 0.5 * activations_gb + + total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0 + total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0 + + # === 主机 RAM === + # DataLoader prefetch + workers + 模型 CPU 副本 + JSON / LIDAR 解析 + host_ram = 8.0 + bs * 0.3 * 4 * 2 # 4 workers, prefetch 2 + + # === 磁盘 === + # 全量数据集 ~3TB;只跑 sandbox 时 ~5GB(几个 clip);典型 ~50GB(一个 weather 全部) + disk = 50.0 + + return MemoryReport( + bs=bs, + seq_len=seq_len, + dim=768, + layers=18, + params_total=total, + params_trainable_stage1=trainable_stage1, + params_trainable_stage2=trainable_stage2, + weights_gb_stage1=weights_stage1, + weights_gb_stage2=weights_stage2, + optim_gb_stage1=optim_stage1, + optim_gb_stage2=optim_stage2, + activations_gb=activations_gb, + pcgrad_overhead_gb=pcgrad_overhead_gb, + total_stage1_gb=total_stage1, + total_stage2_gb=total_stage2, + host_ram_gb=host_ram, + disk_gb=disk, + ) + + +def recommend_device(stage_max_gb: float) -> tuple[str, str]: + """根据 Stage2 峰值显存推荐 GPU。""" + margin = 1.15 # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace) + need = stage_max_gb * margin + candidates = [ + ("T4 16GB", 16), + ("L4 24GB", 24), + ("A10G 24GB", 24), + ("A10G Large 48GB", 48), + ("A100 40GB", 40), + ("L40S 48GB", 48), + ("A100 80GB", 80), + ("H100 80GB", 80), + ] + fit = [c for c in candidates if c[1] >= need] + if not fit: + return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)" + return fit[0][0], f"需要 ≥{need:.1f} GB" + + +def main() -> None: + print("=" * 72) + print(" WJAD 训练显存/内存估算 (bf16 AMP)") + print("=" * 72) + for bs in (1, 2, 4, 8, 16): + r = estimate(bs) + print(f"\n--- BS = {bs} ---") + print(f" 总参数 : {r.params_total / 1e6:8.2f} M") + print(f" 可训练 (S1) : {r.params_trainable_stage1 / 1e6:8.2f} M") + print(f" 可训练 (S2) : {r.params_trainable_stage2 / 1e6:8.2f} M") + print(f" 序列长度 : {r.seq_len}") + print(f" 权重 (S1/S2) : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB") + print(f" 优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB") + print(f" 激活 : {r.activations_gb:6.2f} GB") + print(f" PCGrad 余量 : {r.pcgrad_overhead_gb:6.2f} GB") + print(f" 显存合计 S1 : {r.total_stage1_gb:6.2f} GB") + print(f" 显存合计 S2 : {r.total_stage2_gb:6.2f} GB <- 峰值") + gpu, note = recommend_device(r.total_stage2_gb) + print(f" 推荐 GPU : {gpu} ({note})") + print(f" 主机 RAM : ≥ {r.host_ram_gb:6.2f} GB") + print(f" 磁盘 (典型) : ≈ {r.disk_gb:6.0f} GB") + + print() + print("说明:") + print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。") + print(" - 开 ``gradient_checkpointing`` 可把激活降至约 1/3,BS 可成倍提升。") + print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。") + + +if __name__ == "__main__": + main() diff --git a/scripts/ingest_hub_to_bucket.py b/scripts/ingest_hub_to_bucket.py index a62c5eb6786816f2e8a8b12e86050d55c03e5520..e3daf590d51bfd33468bf934f81cc43008e3abf7 100644 --- a/scripts/ingest_hub_to_bucket.py +++ b/scripts/ingest_hub_to_bucket.py @@ -1,207 +1,234 @@ -"""将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。 - -解压输出默认写入 **另一棵目录树**(与 ``--dest-prefix`` 平级的 ``{dest-prefix}_unpacked/``), -相对路径与镜像里的 ``.tar`` 一致,避免在源树旁叠 ``*_extracted/`` 导致 ``rglob`` 反复扫到嵌套 tar。 - -示例(本地或 Job 内,且已挂载 bucket 到 ``/mnt/cosmos``):: - - python scripts/ingest_hub_to_bucket.py \\ - --bucket fuzirui/my-cosmos-bucket \\ - --dest-prefix cosmos_hub_mirror \\ - --bucket-mount /mnt/cosmos \\ - --extract-tars - -仅拷贝、不解压:: - - python scripts/ingest_hub_to_bucket.py \\ - --bucket fuzirui/my-cosmos-bucket \\ - --source 'hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/' \\ - --dest-prefix raw \\ - --copy-only -""" - -from __future__ import annotations - -import argparse -import sys -import tarfile -from pathlib import Path - -from huggingface_hub import HfApi, create_bucket - - -def _ensure_trailing_slash_hf_url(url: str) -> str: - s = url.strip() - if s.endswith("/"): - return s - return s + "/" - - -def _archive_stem(path: Path) -> str: - """``foo.tar.gz`` -> ``foo``;``bar.tar`` -> ``bar``。""" - n = path.name - for ext in (".tar.gz", ".tar.xz", ".tgz", ".tar"): - if n.endswith(ext): - return n[: -len(ext)] - return path.stem - - -def _is_under_path(path: Path, parent: Path) -> bool: - try: - path.resolve().relative_to(parent.resolve()) - return True - except ValueError: - return False - - -def _collect_archives( - root: Path, - patterns: tuple[str, ...], - *, - exclude_under: Path | None = None, -) -> list[Path]: - """收集待解压归档,排除历史 ``*_extracted`` 目录及解压输出树,避免嵌套/重复扫描。""" - out: list[Path] = [] - seen: set[Path] = set() - for pat in patterns: - for p in root.rglob(pat): - if not p.is_file(): - continue - rp = p.resolve() - if rp in seen: - continue - if any(part.endswith("_extracted") or part == "_extracted" for part in p.parts): - continue - if exclude_under is not None and _is_under_path(p, exclude_under): - continue - seen.add(rp) - out.append(p) - return sorted(out) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Hub copy_files → Bucket,可选按镜像目录解压 tar") - parser.add_argument( - "--source", - default="hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/", - help="hf:// 源(仓库或 bucket 前缀),目录建议以 / 结尾", - ) - parser.add_argument( - "--bucket", - required=True, - help='目标 bucket id,如 "user/my-bucket"(不要写 hf://)', - ) - parser.add_argument( - "--dest-prefix", - default="cosmos_hub_mirror", - help="copy_files 写入 bucket 内的子路径(不要用前导 /)", - ) - parser.add_argument( - "--ensure-bucket", - action="store_true", - help="若不存在则 create_bucket(..., exist_ok=True)", - ) - parser.add_argument( - "--copy-only", - action="store_true", - help="只做 copy_files,不解压", - ) - parser.add_argument( - "--bucket-mount", - default=None, - help="Job 内 bucket 挂载点(如 /mnt/cosmos);若设 --extract-tars 则必填", - ) - parser.add_argument( - "--extract-tars", - action="store_true", - help="解压 mirror 树下的 tar;输出见 --extract-out-prefix", - ) - parser.add_argument( - "--extract-out-prefix", - default=None, - metavar="NAME", - help="解压根目录(bucket 内相对路径,与 dest-prefix 平级)。默认 {dest-prefix}_unpacked", - ) - parser.add_argument( - "--extract-beside-tar", - action="store_true", - help="旧行为:在每条 tar 旁建 ``{name}_extracted``(易与 rglob 嵌套 tar 纠缠,一般不推荐)", - ) - parser.add_argument( - "--max-tars", - type=int, - default=None, - help="最多处理多少个 tar(烟囱/限流)", - ) - args = parser.parse_args() - - src = _ensure_trailing_slash_hf_url(args.source) - dest_prefix = args.dest_prefix.strip().strip("/") - dest = f"hf://buckets/{args.bucket}/{dest_prefix}/" - - api = HfApi() - if args.ensure_bucket: - create_bucket(args.bucket, exist_ok=True) - print(f"[ingest] bucket ready: {args.bucket}", flush=True) - - print(f"[ingest] copy_files\n {src}\n -> {dest}", flush=True) - api.copy_files(src, dest) - print("[ingest] copy_files 完成", flush=True) - - if args.copy_only or not args.extract_tars: - return - - if not args.bucket_mount: - print("[ingest] 错误: --extract-tars 需要 --bucket-mount", file=sys.stderr) - sys.exit(2) - - root = Path(args.bucket_mount) / dest_prefix - out_rel = args.extract_out_prefix - if out_rel is None: - out_rel = f"{dest_prefix}_unpacked" - out_rel = out_rel.strip().strip("/") - extract_base = Path(args.bucket_mount) / out_rel - - if not root.is_dir(): - print(f"[ingest] 警告: 镜像路径不存在或尚不可见: {root}", flush=True) - - patterns = ("*.tar", "*.tar.gz", "*.tar.xz", "*.tgz") - archives = _collect_archives(root, patterns, exclude_under=extract_base) - - if args.max_tars is not None: - archives = archives[: args.max_tars] - - mode = "beside-tar" if args.extract_beside_tar else f"mirror -> {extract_base}" - print(f"[ingest] 将解压 {len(archives)} 个归档 under {root}(模式: {mode})", flush=True) - - for i, tar_path in enumerate(archives): - if args.extract_beside_tar: - out_dir = tar_path.parent / f"{tar_path.name}_extracted" - else: - rel = tar_path.relative_to(root) - out_dir = extract_base / rel.parent / _archive_stem(tar_path) - - if out_dir.exists() and any(out_dir.iterdir()): - print(f"[ingest] ({i + 1}/{len(archives)}) 跳过(已存在非空) {out_dir}", flush=True) - continue - out_dir.mkdir(parents=True, exist_ok=True) - print(f"[ingest] ({i + 1}/{len(archives)}) {tar_path} -> {out_dir}", flush=True) - try: - with tarfile.open(tar_path, mode="r:*") as tf: - _extract(tf, out_dir) - except Exception as e: - print(f"[ingest] 解压失败 {tar_path}: {e}", flush=True) - raise - - print("[ingest] 全部完成", flush=True) - - -def _extract(tf: tarfile.TarFile, out_dir: Path) -> None: - if sys.version_info >= (3, 12): - tf.extractall(out_dir, filter="data") - else: - tf.extractall(out_dir) - - -if __name__ == "__main__": - main() +"""将 Hub 上数据集/仓库路径 **服务端拷贝** 到 Storage Bucket,并在挂载点上可选解压 ``.tar`` / ``.tar.*``。 + +HF Jobs 容器 **根分区**(ephemeral)常有 **~50GiB** 上限。``copy_files`` 在 Hub 侧完成,几乎不占根盘; +解压、``pip``、Hub 客户端默认缓存会写 ``/tmp``、``~/.cache``,易触发 eviction。 +若设 ``--bucket-mount``,会把 ``TMPDIR``、``HF_HOME`` 等重定向到挂载点下 ``.wjad_ephemeral/``, +大块临时数据落在 **Bucket**。训练时用 Volume 挂载 Bucket,``WJAD_DATA_ROOT`` 指到 mirror 或解压树即可, +无需把整库先下载到根盘。 + +解压输出默认写入 **另一棵目录树**(与 ``--dest-prefix`` 平级的 ``{dest-prefix}_unpacked/``), +相对路径与镜像里的 ``.tar`` 一致,避免在源树旁叠 ``*_extracted/`` 导致 ``rglob`` 反复扫到嵌套 tar。 + +示例(本地或 Job 内,且已挂载 bucket 到 ``/mnt/cosmos``):: + + python scripts/ingest_hub_to_bucket.py \\ + --bucket fuzirui/my-cosmos-bucket \\ + --dest-prefix cosmos_hub_mirror \\ + --bucket-mount /mnt/cosmos \\ + --extract-tars + +仅拷贝、不解压:: + + python scripts/ingest_hub_to_bucket.py \\ + --bucket fuzirui/my-cosmos-bucket \\ + --source 'hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/' \\ + --dest-prefix raw \\ + --copy-only +""" + +from __future__ import annotations + +import argparse +import os +import sys +import tarfile +from pathlib import Path + +from huggingface_hub import HfApi, create_bucket + + +def _ensure_trailing_slash_hf_url(url: str) -> str: + s = url.strip() + if s.endswith("/"): + return s + return s + "/" + + +def _archive_stem(path: Path) -> str: + """``foo.tar.gz`` -> ``foo``;``bar.tar`` -> ``bar``。""" + n = path.name + for ext in (".tar.gz", ".tar.xz", ".tgz", ".tar"): + if n.endswith(ext): + return n[: -len(ext)] + return path.stem + + +def _is_under_path(path: Path, parent: Path) -> bool: + try: + path.resolve().relative_to(parent.resolve()) + return True + except ValueError: + return False + + +def _collect_archives( + root: Path, + patterns: tuple[str, ...], + *, + exclude_under: Path | None = None, +) -> list[Path]: + """收集待解压归档,排除历史 ``*_extracted`` 目录及解压输出树,避免嵌套/重复扫描。""" + out: list[Path] = [] + seen: set[Path] = set() + for pat in patterns: + for p in root.rglob(pat): + if not p.is_file(): + continue + rp = p.resolve() + if rp in seen: + continue + if any(part.endswith("_extracted") or part == "_extracted" for part in p.parts): + continue + if exclude_under is not None and _is_under_path(p, exclude_under): + continue + seen.add(rp) + out.append(p) + return sorted(out) + + +def _redirect_ephemeral_to_bucket(bucket_mount: Path) -> None: + """把临时文件与 HF 缓存写到 Bucket 挂载点,避免撑爆 Job 50G 根分区。""" + base = bucket_mount / ".wjad_ephemeral" + tmp = base / "tmp" + hf_home = base / "hf_home" + xdg = base / "xdg_cache" + for d in (tmp, hf_home, hf_home / "hub", xdg): + d.mkdir(parents=True, exist_ok=True) + os.environ["TMPDIR"] = str(tmp) + os.environ["TMP"] = str(tmp) + os.environ["TEMP"] = str(tmp) + os.environ["HF_HOME"] = str(hf_home) + os.environ["HF_HUB_CACHE"] = str(hf_home / "hub") + os.environ["XDG_CACHE_HOME"] = str(xdg) + print(f"[ingest] 临时/缓存 -> {base}", flush=True) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Hub copy_files → Bucket,可选按镜像目录解压 tar") + parser.add_argument( + "--source", + default="hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/", + help="hf:// 源(仓库或 bucket 前缀),目录建议以 / 结尾", + ) + parser.add_argument( + "--bucket", + required=True, + help='目标 bucket id,如 "user/my-bucket"(不要写 hf://)', + ) + parser.add_argument( + "--dest-prefix", + default="cosmos_hub_mirror", + help="copy_files 写入 bucket 内的子路径(不要用前导 /)", + ) + parser.add_argument( + "--ensure-bucket", + action="store_true", + help="若不存在则 create_bucket(..., exist_ok=True)", + ) + parser.add_argument( + "--copy-only", + action="store_true", + help="只做 copy_files,不解压", + ) + parser.add_argument( + "--bucket-mount", + default=None, + help="Job 内 bucket 挂载点(如 /mnt/cosmos);若设 --extract-tars 则必填", + ) + parser.add_argument( + "--extract-tars", + action="store_true", + help="解压 mirror 树下的 tar;输出见 --extract-out-prefix", + ) + parser.add_argument( + "--extract-out-prefix", + default=None, + metavar="NAME", + help="解压根目录(bucket 内相对路径,与 dest-prefix 平级)。默认 {dest-prefix}_unpacked", + ) + parser.add_argument( + "--extract-beside-tar", + action="store_true", + help="旧行为:在每条 tar 旁建 ``{name}_extracted``(易与 rglob 嵌套 tar 纠缠,一般不推荐)", + ) + parser.add_argument( + "--max-tars", + type=int, + default=None, + help="最多处理多少个 tar(烟囱/限流)", + ) + args = parser.parse_args() + + if args.bucket_mount: + _redirect_ephemeral_to_bucket(Path(args.bucket_mount)) + + src = _ensure_trailing_slash_hf_url(args.source) + dest_prefix = args.dest_prefix.strip().strip("/") + dest = f"hf://buckets/{args.bucket}/{dest_prefix}/" + + api = HfApi() + if args.ensure_bucket: + create_bucket(args.bucket, exist_ok=True) + print(f"[ingest] bucket ready: {args.bucket}", flush=True) + + print(f"[ingest] copy_files\n {src}\n -> {dest}", flush=True) + api.copy_files(src, dest) + print("[ingest] copy_files 完成", flush=True) + + if args.copy_only or not args.extract_tars: + return + + if not args.bucket_mount: + print("[ingest] 错误: --extract-tars 需要 --bucket-mount", file=sys.stderr) + sys.exit(2) + + root = Path(args.bucket_mount) / dest_prefix + out_rel = args.extract_out_prefix + if out_rel is None: + out_rel = f"{dest_prefix}_unpacked" + out_rel = out_rel.strip().strip("/") + extract_base = Path(args.bucket_mount) / out_rel + + if not root.is_dir(): + print(f"[ingest] 警告: 镜像路径不存在或尚不可见: {root}", flush=True) + + patterns = ("*.tar", "*.tar.gz", "*.tar.xz", "*.tgz") + archives = _collect_archives(root, patterns, exclude_under=extract_base) + + if args.max_tars is not None: + archives = archives[: args.max_tars] + + mode = "beside-tar" if args.extract_beside_tar else f"mirror -> {extract_base}" + print(f"[ingest] 将解压 {len(archives)} 个归档 under {root}(模式: {mode})", flush=True) + + for i, tar_path in enumerate(archives): + if args.extract_beside_tar: + out_dir = tar_path.parent / f"{tar_path.name}_extracted" + else: + rel = tar_path.relative_to(root) + out_dir = extract_base / rel.parent / _archive_stem(tar_path) + + if out_dir.exists() and any(out_dir.iterdir()): + print(f"[ingest] ({i + 1}/{len(archives)}) 跳过(已存在非空) {out_dir}", flush=True) + continue + out_dir.mkdir(parents=True, exist_ok=True) + print(f"[ingest] ({i + 1}/{len(archives)}) {tar_path} -> {out_dir}", flush=True) + try: + with tarfile.open(tar_path, mode="r:*") as tf: + _extract(tf, out_dir) + except Exception as e: + print(f"[ingest] 解压失败 {tar_path}: {e}", flush=True) + raise + + print("[ingest] 全部完成", flush=True) + + +def _extract(tf: tarfile.TarFile, out_dir: Path) -> None: + if sys.version_info >= (3, 12): + tf.extractall(out_dir, filter="data") + else: + tf.extractall(out_dir) + + +if __name__ == "__main__": + main() diff --git a/scripts/push_cpu_ingest_job.py b/scripts/push_cpu_ingest_job.py index e8ae02694f00bd00da25c1ef45a7b17be30625e8..f4b0631263e83baae5856dcb8ab2fb74ce462e42 100644 --- a/scripts/push_cpu_ingest_job.py +++ b/scripts/push_cpu_ingest_job.py @@ -1,141 +1,148 @@ -"""提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。 - -- 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。 -- 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。 -- 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。 - -用法:: - - python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data - python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow - python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/' -""" - -from __future__ import annotations - -import argparse -import os -import sys - -from huggingface_hub import HfApi, Volume, create_bucket - -try: - from huggingface_hub.cli._cli_utils import parse_env_map -except Exception: # pragma: no cover - parse_env_map = None - -DEFAULT_CODE_REPO = "fuzirui/WJAD" -DEFAULT_BUCKET = "fuzirui/WJAD" -DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/" -DEFAULT_DEST_PREFIX = "cosmos_hub_mirror" -DEFAULT_IMAGE = "python:3.12" -DEFAULT_TIMEOUT = "7d" - - -def _secrets_for_job() -> dict | None: - if parse_env_map is not None: - try: - m = parse_env_map(["HF_TOKEN"]) - if m.get("HF_TOKEN"): - return m - except Exception: - pass - t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") - return {"HF_TOKEN": t} if t else None - - -def main() -> None: - parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压") - parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)") - parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id") - parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset")) - parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录") - parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径") - parser.add_argument( - "--skip-create-bucket", - action="store_true", - help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)", - ) - parser.add_argument( - "--no-extract", - action="store_true", - help="只做 copy_files,不解压 tar", - ) - parser.add_argument( - "--max-tars", - type=int, - default=None, - help="传给 ingest_hub_to_bucket.py --max-tars", - ) - parser.add_argument( - "--extract-out-prefix", - default=None, - metavar="NAME", - help="解压输出子路径(默认 {dest-prefix}_unpacked)", - ) - parser.add_argument( - "--extract-beside-tar", - action="store_true", - help="旧行为:在每条 tar 旁解压为 _extracted", - ) - parser.add_argument("--image", default=DEFAULT_IMAGE) - parser.add_argument("--timeout", default=DEFAULT_TIMEOUT) - parser.add_argument("--follow", action="store_true") - parser.add_argument("--no-secrets", action="store_true") - args = parser.parse_args() - - bucket_mount = "/mnt/cosmos" - code_mount = "/workspace" - - max_tars = "" - if args.max_tars is not None: - max_tars = f" --max-tars {args.max_tars}" - - extract_flag = "" if args.no_extract else " --extract-tars" - - extract_beside = " --extract-beside-tar" if args.extract_beside_tar else "" - out_prefix = "" - if args.extract_out_prefix: - out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'" - - script = f"""set -euo pipefail -pip install --root-user-action=ignore --no-cache-dir 'huggingface_hub>=0.30' -python {code_mount}/scripts/ingest_hub_to_bucket.py \\ - --bucket '{args.bucket}' \\ - --source '{args.source}' \\ - --dest-prefix '{args.dest_prefix}' \\ - --bucket-mount '{bucket_mount}'{extract_flag}{max_tars}{out_prefix}{extract_beside} -""" - - secrets = None if args.no_secrets else _secrets_for_job() - if secrets is None and not args.no_secrets: - print("[push_cpu_ingest] 警告: 无 HF_TOKEN,gated 数据会失败。", file=sys.stderr) - - if not args.skip_create_bucket: - create_bucket(args.bucket, exist_ok=True) - print(f"[push_cpu_ingest] bucket 已确保存在(或已存在): {args.bucket}") - - volumes = [ - Volume(type=args.code_type, source=args.code_repo, mount_path=code_mount), - Volume(type="bucket", source=args.bucket, mount_path=bucket_mount), - ] - - api = HfApi() - job = api.run_job( - image=args.image, - command=["bash", "-lc", script], - flavor="cpu-basic", - volumes=volumes, - secrets=secrets, - timeout=args.timeout, - ) - print(f"[push_cpu_ingest] Job ID: {job.id}") - print(f"[push_cpu_ingest] URL: {job.url}") - - if args.follow: - for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True): - print(line, end="" if str(line).endswith("\n") else "\n") - - -if __name__ == "__main__": - main() +"""提交 **CPU Basic** Job:把 Hub 上 Cosmos(或其它源)服务端复制到你的 Bucket,并尝试解压 tar。 + +- 计费:见 https://huggingface.co/docs/hub/jobs-pricing(CPU Basic 约 \\$0.01/ 小时量级,以官网为准)。 +- 默认挂载:代码 ``fuzirui/WJAD``、可写 Bucket;超时默认 48h(大仓库复制可能很久)。 +- 须 ``hf auth login``;NVIDIA 数据集须在网页接受条款。 + +用法:: + + python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data + python scripts/push_cpu_ingest_job.py --bucket fuzirui/wjad-cosmos-data --follow + python scripts/push_cpu_ingest_job.py --bucket fuzirui/x --source 'hf://datasets/foo/bar/' +""" + +from __future__ import annotations + +import argparse +import os +import sys + +from huggingface_hub import HfApi, Volume, create_bucket + +try: + from huggingface_hub.cli._cli_utils import parse_env_map +except Exception: # pragma: no cover + parse_env_map = None + +DEFAULT_CODE_REPO = "fuzirui/WJAD" +DEFAULT_BUCKET = "fuzirui/WJAD" +DEFAULT_SOURCE = "hf://datasets/nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams/" +DEFAULT_DEST_PREFIX = "cosmos_hub_mirror" +DEFAULT_IMAGE = "python:3.12" +DEFAULT_TIMEOUT = "7d" + + +def _secrets_for_job() -> dict | None: + if parse_env_map is not None: + try: + m = parse_env_map(["HF_TOKEN"]) + if m.get("HF_TOKEN"): + return m + except Exception: + pass + t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + return {"HF_TOKEN": t} if t else None + + +def main() -> None: + parser = argparse.ArgumentParser(description="HF Jobs:CPU 拉取 Hub → Bucket + 解压") + parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="目标 Storage Bucket id(须已创建或加 --ensure-bucket)") + parser.add_argument("--code-repo", default=DEFAULT_CODE_REPO, help="含 ingest 脚本的 Hub model/space id") + parser.add_argument("--code-type", default="model", choices=("model", "space", "dataset")) + parser.add_argument("--source", default=DEFAULT_SOURCE, help="hf:// 源目录") + parser.add_argument("--dest-prefix", default=DEFAULT_DEST_PREFIX, help="bucket 内子路径") + parser.add_argument( + "--skip-create-bucket", + action="store_true", + help="不在本机预先 create_bucket(bucket 必须已存在,否则挂载失败)", + ) + parser.add_argument( + "--no-extract", + action="store_true", + help="只做 copy_files,不解压 tar", + ) + parser.add_argument( + "--max-tars", + type=int, + default=None, + help="传给 ingest_hub_to_bucket.py --max-tars", + ) + parser.add_argument( + "--extract-out-prefix", + default=None, + metavar="NAME", + help="解压输出子路径(默认 {dest-prefix}_unpacked)", + ) + parser.add_argument( + "--extract-beside-tar", + action="store_true", + help="旧行为:在每条 tar 旁解压为 _extracted", + ) + parser.add_argument("--image", default=DEFAULT_IMAGE) + parser.add_argument("--timeout", default=DEFAULT_TIMEOUT) + parser.add_argument("--follow", action="store_true") + parser.add_argument("--no-secrets", action="store_true") + args = parser.parse_args() + + bucket_mount = "/mnt/cosmos" + code_mount = "/workspace" + + max_tars = "" + if args.max_tars is not None: + max_tars = f" --max-tars {args.max_tars}" + + extract_flag = "" if args.no_extract else " --extract-tars" + + extract_beside = " --extract-beside-tar" if args.extract_beside_tar else "" + out_prefix = "" + if args.extract_out_prefix: + out_prefix = f" --extract-out-prefix '{args.extract_out_prefix}'" + + script = f"""set -euo pipefail +Eph="{bucket_mount}/.wjad_ephemeral" +mkdir -p "$Eph/tmp" "$Eph/hf_home/hub" "$Eph/xdg_cache" +export TMPDIR="$Eph/tmp" +export TMP="$TMPDIR" TEMP="$TMPDIR" +export HF_HOME="$Eph/hf_home" +export HF_HUB_CACHE="$HF_HOME/hub" +export XDG_CACHE_HOME="$Eph/xdg_cache" +pip install --root-user-action=ignore --no-cache-dir 'huggingface_hub>=0.30' +python {code_mount}/scripts/ingest_hub_to_bucket.py \\ + --bucket '{args.bucket}' \\ + --source '{args.source}' \\ + --dest-prefix '{args.dest_prefix}' \\ + --bucket-mount '{bucket_mount}'{extract_flag}{max_tars}{out_prefix}{extract_beside} +""" + + secrets = None if args.no_secrets else _secrets_for_job() + if secrets is None and not args.no_secrets: + print("[push_cpu_ingest] 警告: 无 HF_TOKEN,gated 数据会失败。", file=sys.stderr) + + if not args.skip_create_bucket: + create_bucket(args.bucket, exist_ok=True) + print(f"[push_cpu_ingest] bucket 已确保存在(或已存在): {args.bucket}") + + volumes = [ + Volume(type=args.code_type, source=args.code_repo, mount_path=code_mount), + Volume(type="bucket", source=args.bucket, mount_path=bucket_mount), + ] + + api = HfApi() + job = api.run_job( + image=args.image, + command=["bash", "-lc", script], + flavor="cpu-basic", + volumes=volumes, + secrets=secrets, + timeout=args.timeout, + ) + print(f"[push_cpu_ingest] Job ID: {job.id}") + print(f"[push_cpu_ingest] URL: {job.url}") + + if args.follow: + for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True): + print(line, end="" if str(line).endswith("\n") else "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/push_to_jobs.py b/scripts/push_to_jobs.py index c08e8ef6061fe906bb13a7271ca070cbe2f52a4d..4764292f1170ad0abc782b1c12be8aadecfe1c4e 100644 --- a/scripts/push_to_jobs.py +++ b/scripts/push_to_jobs.py @@ -1,196 +1,196 @@ -"""提交 Hugging Face Jobs 正式训练。 - -- 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。 -- **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树 - (``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载** - 不是这棵树,``build_clip_index`` 会得到 0 条样本。 -- **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录 - ``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用 - ``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把 - ``WJAD_DATA_ROOT`` 指到挂载路径。 -- **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动 - 流式需另用 ``IterableDataset`` + shard,属后续工作。 - -用法: - - python scripts/push_to_jobs.py - python scripts/push_to_jobs.py --follow - python scripts/push_to_jobs.py --download-limit 0 - python scripts/push_to_jobs.py --skip-download - python scripts/push_to_jobs.py --mount-hub-dataset -""" - -from __future__ import annotations - -import argparse -import sys - -from huggingface_hub import HfApi, Volume - -try: - from huggingface_hub.cli._cli_utils import parse_env_map -except Exception: # pragma: no cover - parse_env_map = None - - -DEFAULT_FLAVOR = "a10g-large" -DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime" -DEFAULT_REPO = "fuzirui/WJAD" -DEFAULT_MOUNT = "/workspace" -DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams" -DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos" -DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos" - - -def _secrets_for_job() -> dict | None: - if parse_env_map is not None: - try: - m = parse_env_map(["HF_TOKEN"]) - if m.get("HF_TOKEN"): - return m - except Exception: - pass - import os - - t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") - return {"HF_TOKEN": t} if t else None - - -def main() -> None: - parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练") - parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id") - parser.add_argument( - "--repo-type", - default="model", - choices=("model", "space", "dataset"), - help="仓库类型", - ) - parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径") - parser.add_argument("--flavor", default=DEFAULT_FLAVOR) - parser.add_argument("--image", default=DEFAULT_IMAGE) - parser.add_argument( - "--follow", - action="store_true", - help="跟随 Job 日志直到结束", - ) - parser.add_argument("--no-secrets", action="store_true") - parser.add_argument("--timeout", default=None, help="如 168h、7d") - # —— 数据:默认 NVIDIA 下载到可写目录 —— - parser.add_argument( - "--data-prep-dir", - default=DEFAULT_DATA_PREP_DIR, - help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖", - ) - parser.add_argument( - "--download-workers", - type=int, - default=8, - help="download_data.py --workers", - ) - parser.add_argument( - "--download-limit", - type=int, - default=8, - metavar="N", - help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)", - ) - parser.add_argument( - "--skip-download", - action="store_true", - help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)", - ) - # —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)—— - parser.add_argument( - "--mount-hub-dataset", - action="store_true", - help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)", - ) - parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID") - parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT) - - args = parser.parse_args() - - code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount) - ro_mount = args.mount - work = "/tmp/wjad-run" - volumes: list[Volume] = [code_vol] - - if args.mount_hub_dataset: - volumes.append( - Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount) - ) - - use_auto_download = not args.skip_download and not args.mount_hub_dataset - data_root_default = ( - args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir - ) - - limit_tail = "" - if use_auto_download and args.download_limit > 0: - limit_tail = f" --limit {args.download_limit}" - - download_block = "" - if use_auto_download: - download_block = f""" -mkdir -p "$DATA_ROOT" -python scripts/download_data.py --odir "$DATA_ROOT" \\ - --file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail} -""" - - script = f"""set -euo pipefail -rm -rf {work} -cp -a {ro_mount} {work} -cd {work} -export PIP_ROOT_USER_ACTION=ignore -pip install --root-user-action=ignore --no-cache-dir -e . -export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}" -DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}" -{download_block} -python -m wjad.train.runner_local \\ - --device cuda \\ - --config configs/default.yaml \\ - --data_root "$DATA_ROOT" \\ - --dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}" -""" - - secrets = None if args.no_secrets else _secrets_for_job() - if secrets is None and not args.no_secrets: - print( - "[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。", - file=sys.stderr, - ) - - if args.mount_hub_dataset: - print( - f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};" - "若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset," - "改用默认自动 download。" - ) - elif use_auto_download: - lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)" - print( - f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。" - "全量请 --download-limit 0 并保证磁盘或 Bucket。" - ) - else: - print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。") - - api = HfApi() - job = api.run_job( - image=args.image, - command=["bash", "-lc", script], - flavor=args.flavor, - volumes=volumes, - secrets=secrets, - timeout=args.timeout, - ) - print(f"[push_to_jobs] Job ID: {job.id}") - print(f"[push_to_jobs] URL: {job.url}") - - if args.follow: - for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True): - print(line, end="" if str(line).endswith("\n") else "\n") - - -if __name__ == "__main__": - main() +"""提交 Hugging Face Jobs 正式训练。 + +- 代码仓库挂载为只读后复制到 ``/tmp/wjad-run`` 再 ``pip install -e .``。 +- **数据**:``CosmosDriveDreamsDataset`` 需要 NVIDIA ``download.py`` 拉下来的目录树 + (``synthetic/single_view/generation/*.mp4`` + ``labels/``)。Hub 上 **datasets 视图挂载** + 不是这棵树,``build_clip_index`` 会得到 0 条样本。 +- **默认**:在 Job 里先执行 ``scripts/download_data.py``,把数据落到可写目录 + ``WJAD_DATA_ROOT``(默认 ``/tmp/wjad-cosmos``)——即 **一次性下载**(按 clip 限流可用 + ``--download-limit``)。全量约 TB 级,请用大磁盘 Job 或挂 **HF Bucket** 并把 + ``WJAD_DATA_ROOT`` 指到挂载路径。 +- **流式**:当前 DataLoader 按视频/帧文件随机访问,未接 ``datasets`` 流式 API;要低改动 + 流式需另用 ``IterableDataset`` + shard,属后续工作。 + +用法: + + python scripts/push_to_jobs.py + python scripts/push_to_jobs.py --follow + python scripts/push_to_jobs.py --download-limit 0 + python scripts/push_to_jobs.py --skip-download + python scripts/push_to_jobs.py --mount-hub-dataset +""" + +from __future__ import annotations + +import argparse +import sys + +from huggingface_hub import HfApi, Volume + +try: + from huggingface_hub.cli._cli_utils import parse_env_map +except Exception: # pragma: no cover + parse_env_map = None + + +DEFAULT_FLAVOR = "a10g-large" +DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime" +DEFAULT_REPO = "fuzirui/WJAD" +DEFAULT_MOUNT = "/workspace" +DEFAULT_HUB_DATASET = "nvidia/PhysicalAI-Autonomous-Vehicle-Cosmos-Drive-Dreams" +DEFAULT_HUB_DATASET_MOUNT = "/data/cosmos" +DEFAULT_DATA_PREP_DIR = "/tmp/wjad-cosmos" + + +def _secrets_for_job() -> dict | None: + if parse_env_map is not None: + try: + m = parse_env_map(["HF_TOKEN"]) + if m.get("HF_TOKEN"): + return m + except Exception: + pass + import os + + t = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + return {"HF_TOKEN": t} if t else None + + +def main() -> None: + parser = argparse.ArgumentParser(description="在 HF Jobs 上启动 WJAD 训练") + parser.add_argument("--repo", default=DEFAULT_REPO, help="含本仓库代码的 Hub repo id") + parser.add_argument( + "--repo-type", + default="model", + choices=("model", "space", "dataset"), + help="仓库类型", + ) + parser.add_argument("--mount", default=DEFAULT_MOUNT, help="代码在容器内的挂载路径") + parser.add_argument("--flavor", default=DEFAULT_FLAVOR) + parser.add_argument("--image", default=DEFAULT_IMAGE) + parser.add_argument( + "--follow", + action="store_true", + help="跟随 Job 日志直到结束", + ) + parser.add_argument("--no-secrets", action="store_true") + parser.add_argument("--timeout", default=None, help="如 168h、7d") + # —— 数据:默认 NVIDIA 下载到可写目录 —— + parser.add_argument( + "--data-prep-dir", + default=DEFAULT_DATA_PREP_DIR, + help="下载目标目录(可写)。可被环境变量 WJAD_DATA_ROOT 覆盖", + ) + parser.add_argument( + "--download-workers", + type=int, + default=8, + help="download_data.py --workers", + ) + parser.add_argument( + "--download-limit", + type=int, + default=8, + metavar="N", + help="传给 NVIDIA --limit;默认 8 个 clip 控制磁盘。0=不限制(全量,需足够盘或 Bucket)", + ) + parser.add_argument( + "--skip-download", + action="store_true", + help="不运行 download_data.py(数据已存在于 WJAD_DATA_ROOT 或 Bucket 挂载点)", + ) + # —— 可选:挂载 Hub dataset(当前 loader 一般不兼容,仅特殊预处理树可用)—— + parser.add_argument( + "--mount-hub-dataset", + action="store_true", + help="额外只读挂载 nvidia/Cosmos dataset 到 --hub-dataset-mount(与自动下载互斥)", + ) + parser.add_argument("--hub-dataset", default=DEFAULT_HUB_DATASET, metavar="REPO_ID") + parser.add_argument("--hub-dataset-mount", default=DEFAULT_HUB_DATASET_MOUNT) + + args = parser.parse_args() + + code_vol = Volume(type=args.repo_type, source=args.repo, mount_path=args.mount) + ro_mount = args.mount + work = "/tmp/wjad-run" + volumes: list[Volume] = [code_vol] + + if args.mount_hub_dataset: + volumes.append( + Volume(type="dataset", source=args.hub_dataset, mount_path=args.hub_dataset_mount) + ) + + use_auto_download = not args.skip_download and not args.mount_hub_dataset + data_root_default = ( + args.hub_dataset_mount if args.mount_hub_dataset else args.data_prep_dir + ) + + limit_tail = "" + if use_auto_download and args.download_limit > 0: + limit_tail = f" --limit {args.download_limit}" + + download_block = "" + if use_auto_download: + download_block = f""" +mkdir -p "$DATA_ROOT" +python scripts/download_data.py --odir "$DATA_ROOT" \\ + --file_types synthetic,lidar,hdmap --workers {args.download_workers}{limit_tail} +""" + + script = f"""set -euo pipefail +rm -rf {work} +cp -a {ro_mount} {work} +cd {work} +export PIP_ROOT_USER_ACTION=ignore +pip install --root-user-action=ignore --no-cache-dir -e . +export PYTHONPATH="{work}/src:${{PYTHONPATH:-}}" +DATA_ROOT="${{WJAD_DATA_ROOT:-{data_root_default}}}" +{download_block} +python -m wjad.train.runner_local \\ + --device cuda \\ + --config configs/default.yaml \\ + --data_root "$DATA_ROOT" \\ + --dinov3_path "${{DINOV3_PATH:-{work}/dinov3-vitb16-pretrain-lvd1689m}}" +""" + + secrets = None if args.no_secrets else _secrets_for_job() + if secrets is None and not args.no_secrets: + print( + "[push_to_jobs] 警告: 未解析到 HF_TOKEN,下载/checkpoint 可能失败。请先 hf auth login。", + file=sys.stderr, + ) + + if args.mount_hub_dataset: + print( + f"[push_to_jobs] 已挂载 Hub dataset(只读){args.hub_dataset} -> {args.hub_dataset_mount};" + "若仍 0 样本,说明布局不是 synthetic/*/generation + labels/,请不要用 --mount-hub-dataset," + "改用默认自动 download。" + ) + elif use_auto_download: + lim_msg = f"limit={args.download_limit}" if args.download_limit > 0 else "无 limit(全量)" + print( + f"[push_to_jobs] 将下载到 DATA_ROOT={data_root_default}({lim_msg})。" + "全量请 --download-limit 0 并保证磁盘或 Bucket。" + ) + else: + print("[push_to_jobs] 已 --skip-download,请保证 $WJAD_DATA_ROOT 下已有 NVIDIA 布局数据。") + + api = HfApi() + job = api.run_job( + image=args.image, + command=["bash", "-lc", script], + flavor=args.flavor, + volumes=volumes, + secrets=secrets, + timeout=args.timeout, + ) + print(f"[push_to_jobs] Job ID: {job.id}") + print(f"[push_to_jobs] URL: {job.url}") + + if args.follow: + for line in api.fetch_job_logs(job_id=job.id, namespace=job.owner.name, follow=True): + print(line, end="" if str(line).endswith("\n") else "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/push_to_sandbox.py b/scripts/push_to_sandbox.py index d7181f0f511709a4817d5d034bb9350ff3fd7149..b2a88bf725e63ad6a2edcbe030fa219dfcf7e7fc 100644 --- a/scripts/push_to_sandbox.py +++ b/scripts/push_to_sandbox.py @@ -1,185 +1,185 @@ -"""推送代码到 HF Space,做 sandbox 微训练。 - -依据 ``estimate_memory.py`` 的估算: - - BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存; - - 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致; - - 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。 - -本脚本: - 1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space, - Space SDK = ``docker``; - 2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等); - 3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。 - -要求:先在本地 ``hf auth login``。 -""" - -from __future__ import annotations - -import argparse -from pathlib import Path - -from huggingface_hub import HfApi, create_repo - -ROOT = Path(__file__).resolve().parent.parent -DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 - -ENV DEBIAN_FRONTEND=noninteractive -ENV PYTHONUNBUFFERED=1 -RUN apt-get update && apt-get install -y --no-install-recommends \\ - python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\ - && rm -rf /var/lib/apt/lists/* - -# HF Space 默认用户(避免权限问题) -RUN useradd -m -u 1000 user -USER user -ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH - -WORKDIR /app -COPY --chown=user pyproject.toml /app/ -COPY --chown=user src /app/src -COPY --chown=user scripts /app/scripts -COPY --chown=user configs /app/configs -COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m -COPY --chown=user app.py /app/app.py - -RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\ - && python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\ - && python3 -m pip install --user --no-cache-dir -e . - -EXPOSE 7860 -CMD ["python3", "app.py"] -""" - -APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。 - -启动后: - 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log) - 2. 主进程开 HTTP server on :7860,返回最新日志 - -阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。 -阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。 -""" -import os -import subprocess -import sys -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer - -LOG_PATH = "/tmp/wjad.log" -PORT = 7860 -# 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。 -_MODE = os.environ.get("SANDBOX_MODE", "smoke") -if _MODE == "real_data": - LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"] -else: - LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"] - - -def _print_env(f): - f.write("=" * 72 + "\\n") - f.write(" WJAD HF Sandbox\\n") - f.write("=" * 72 + "\\n") - f.write(f"Python: {sys.version}\\n") - try: - import torch - f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n") - if torch.cuda.is_available(): - p = torch.cuda.get_device_properties(0) - f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n") - except Exception as e: - f.write(f"torch import failed: {e}\\n") - f.flush() - - -def run_training(): - with open(LOG_PATH, "w", buffering=1) as f: - _print_env(f) - f.write(f"$ {' '.join(LAUNCH_CMD)}\\n") - f.flush() - p = subprocess.Popen( - LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app" - ) - rc = p.wait() - f.write(f"\\n[exit code = {rc}]\\n") - - -class Handler(BaseHTTPRequestHandler): - def do_GET(self): - try: - with open(LOG_PATH, "r") as f: - body = f.read() - except FileNotFoundError: - body = "starting..." - self.send_response(200) - self.send_header("Content-Type", "text/plain; charset=utf-8") - self.end_headers() - self.wfile.write(body.encode("utf-8")) - - def log_message(self, fmt, *args): - return - - -if __name__ == "__main__": - threading.Thread(target=run_training, daemon=True).start() - HTTPServer(("0.0.0.0", PORT), Handler).serve_forever() -''' - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox") - parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)") - parser.add_argument("--private", action="store_true") - parser.add_argument( - "--mode", - choices=["smoke", "real_data"], - default="smoke", - help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer", - ) - args = parser.parse_args() - - api = HfApi() - print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})") - create_repo( - args.repo, - repo_type="space", - space_sdk="docker", - space_hardware=args.gpu, - private=args.private, - exist_ok=True, - ) - - # 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings - # -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。 - api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode) - if args.mode == "real_data": - print( - "[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets " - "里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-" - "Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。" - ) - - # 落盘 Dockerfile / app.py - (ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8") - (ROOT / "app.py").write_text(APP_PY, encoding="utf-8") - - print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...") - api.upload_folder( - folder_path=str(ROOT), - repo_id=args.repo, - repo_type="space", - ignore_patterns=[ - ".venv/*", - "data/*", - "**/__pycache__/*", - "*.pyc", - "agent-tools/*", - ".git/*", - ], - ) - print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}") - - -if __name__ == "__main__": - main() +"""推送代码到 HF Space,做 sandbox 微训练。 + +依据 ``estimate_memory.py`` 的估算: + - BS=8 + bf16 + PCGrad + GradNorm 需要 ≥34 GB 显存; + - 默认硬件 **a10g-small**(~24 GB):与 ``smoke_train`` / ``sandbox_real_data`` 的 tiny 设置一致; + - 要拉满 BS=8 可改用 ``--gpu a10g-large`` 或 A100。 + +本脚本: + 1. ``huggingface_hub.create_repo`` 在 HF 上创建(或复用)一个 Space, + Space SDK = ``docker``; + 2. 用 ``upload_folder`` 上传当前仓库(排除 ``.venv``、数据集等); + 3. 写入 ``Dockerfile`` + ``app.py``(在 Space 启动时跑微训练)。 + +要求:先在本地 ``hf auth login``。 +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from huggingface_hub import HfApi, create_repo + +ROOT = Path(__file__).resolve().parent.parent +DOCKERFILE = """FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +RUN apt-get update && apt-get install -y --no-install-recommends \\ + python3 python3-pip python3-venv ffmpeg libgl1 libglib2.0-0 git \\ + && rm -rf /var/lib/apt/lists/* + +# HF Space 默认用户(避免权限问题) +RUN useradd -m -u 1000 user +USER user +ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH + +WORKDIR /app +COPY --chown=user pyproject.toml /app/ +COPY --chown=user src /app/src +COPY --chown=user scripts /app/scripts +COPY --chown=user configs /app/configs +COPY --chown=user dinov3-vitb16-pretrain-lvd1689m /app/dinov3-vitb16-pretrain-lvd1689m +COPY --chown=user app.py /app/app.py + +RUN python3 -m pip install --user --no-cache-dir --upgrade pip \\ + && python3 -m pip install --user --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124 \\ + && python3 -m pip install --user --no-cache-dir -e . + +EXPOSE 7860 +CMD ["python3", "app.py"] +""" + +APP_PY = '''"""HF Sandbox 入口(docker SDK,监听 7860)。 + +启动后: + 1. 后台进程跑 scripts/smoke_train.py(追加写入 /tmp/wjad.log) + 2. 主进程开 HTTP server on :7860,返回最新日志 + +阶段 A(无需数据):smoke_train 用随机张量验证 GPU 上的 forward/反传/AMP/PCGrad。 +阶段 B(需要数据):把 LAUNCH_CMD 改为 runner_local 的真实训练命令。 +""" +import os +import subprocess +import sys +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer + +LOG_PATH = "/tmp/wjad.log" +PORT = 7860 +# 当 SANDBOX_MODE=real_data 时跑真实标签 + 占位视频;否则跑随机张量 smoke。 +_MODE = os.environ.get("SANDBOX_MODE", "smoke") +if _MODE == "real_data": + LAUNCH_CMD = [sys.executable, "scripts/sandbox_real_data.py"] +else: + LAUNCH_CMD = [sys.executable, "scripts/smoke_train.py"] + + +def _print_env(f): + f.write("=" * 72 + "\\n") + f.write(" WJAD HF Sandbox\\n") + f.write("=" * 72 + "\\n") + f.write(f"Python: {sys.version}\\n") + try: + import torch + f.write(f"torch: {torch.__version__} cuda_avail={torch.cuda.is_available()}\\n") + if torch.cuda.is_available(): + p = torch.cuda.get_device_properties(0) + f.write(f"device: {p.name} vram={p.total_memory / 1024**3:.2f} GB\\n") + except Exception as e: + f.write(f"torch import failed: {e}\\n") + f.flush() + + +def run_training(): + with open(LOG_PATH, "w", buffering=1) as f: + _print_env(f) + f.write(f"$ {' '.join(LAUNCH_CMD)}\\n") + f.flush() + p = subprocess.Popen( + LAUNCH_CMD, stdout=f, stderr=subprocess.STDOUT, cwd="/app" + ) + rc = p.wait() + f.write(f"\\n[exit code = {rc}]\\n") + + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + try: + with open(LOG_PATH, "r") as f: + body = f.read() + except FileNotFoundError: + body = "starting..." + self.send_response(200) + self.send_header("Content-Type", "text/plain; charset=utf-8") + self.end_headers() + self.wfile.write(body.encode("utf-8")) + + def log_message(self, fmt, *args): + return + + +if __name__ == "__main__": + threading.Thread(target=run_training, daemon=True).start() + HTTPServer(("0.0.0.0", PORT), Handler).serve_forever() +''' + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--repo", required=True, help="HF Space repo, e.g. user/wjad-sandbox") + parser.add_argument("--gpu", default="a10g-small", help="HF Spaces 硬件,默认 a10g-small(省 GPU 小时)") + parser.add_argument("--private", action="store_true") + parser.add_argument( + "--mode", + choices=["smoke", "real_data"], + default="smoke", + help="smoke=随机张量;real_data=拉真实标签+占位视频跑 trainer", + ) + args = parser.parse_args() + + api = HfApi() + print(f"[push_to_sandbox] 创建 / 复用 Space: {args.repo} (GPU={args.gpu}, mode={args.mode})") + create_repo( + args.repo, + repo_type="space", + space_sdk="docker", + space_hardware=args.gpu, + private=args.private, + exist_ok=True, + ) + + # 把 SANDBOX_MODE 写到 Space 变量;HF_TOKEN 需要用户自己在 Space Settings + # -> Secrets 里加一份能访问 NVIDIA 数据集的 token(real_data 模式必须)。 + api.add_space_variable(repo_id=args.repo, key="SANDBOX_MODE", value=args.mode) + if args.mode == "real_data": + print( + "[push_to_sandbox] 提醒:real_data 模式需要在 Space Settings -> Secrets " + "里手动添加 HF_TOKEN(必须是能访问 nvidia/PhysicalAI-Autonomous-Vehicle-" + "Cosmos-Drive-Dreams 的账号 token,否则 download.py 会拒绝访问)。" + ) + + # 落盘 Dockerfile / app.py + (ROOT / "Dockerfile").write_text(DOCKERFILE, encoding="utf-8") + (ROOT / "app.py").write_text(APP_PY, encoding="utf-8") + + print("[push_to_sandbox] 上传仓库(排除 .venv / data / 缓存)...") + api.upload_folder( + folder_path=str(ROOT), + repo_id=args.repo, + repo_type="space", + ignore_patterns=[ + ".venv/*", + "data/*", + "**/__pycache__/*", + "*.pyc", + "agent-tools/*", + ".git/*", + ], + ) + print(f"[push_to_sandbox] OK -> https://huggingface.co/spaces/{args.repo}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sandbox_real_data.py b/scripts/sandbox_real_data.py index 6f6c2760e345394f05ce495d2d050b9b674317a6..6b9296b61b7f60918b3d7eb8f85c2f3a3ace60f7 100644 --- a/scripts/sandbox_real_data.py +++ b/scripts/sandbox_real_data.py @@ -1,236 +1,236 @@ -"""Sandbox 真实数据微验证脚本。 - -由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17 -个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。 -因此本脚本采用混合方案: - - 1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的 - 全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB; - 2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配 - ``wjad.data.cosmos_dataset`` 期待的布局; - 3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频 - (文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768); - 4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 + - 伪造视觉的训练。 - -这样能验证: - - 数据集索引(``build_clip_index``) - - 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参) - - LIDAR 加载与遮挡过滤 - - Hungarian 匹配 + DETR loss - - 端到端 forward / GradNorm / PCGrad / 反传 - -但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。 -""" - -from __future__ import annotations - -import os -import shutil -import subprocess -import sys -import tarfile -import urllib.request -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(ROOT / "src")) - -DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos")) -NV_DOWNLOAD_URL = ( - "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py" -) - - -def _print_section(title: str) -> None: - bar = "=" * 60 - print(f"\n{bar}\n{title}\n{bar}", flush=True) - - -def step1_download_labels() -> None: - """用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。""" - _print_section("STEP 1 下载真实标签(1 个 clip)") - DATA_ROOT.mkdir(parents=True, exist_ok=True) - nv_script = DATA_ROOT / ".nvidia_download.py" - if not nv_script.exists(): - print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True) - with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f: - f.write(r.read()) - # 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载, - # 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。 - cmd = [ - sys.executable, - str(nv_script), - "--odir", str(DATA_ROOT), - "--file_types", "lidar,hdmap", - "--workers", "4", - "--limit", "1", - ] - print(f"$ {' '.join(cmd)}", flush=True) - rc = subprocess.call(cmd) - if rc != 0: - sys.exit(f"download.py 失败 rc={rc}") - - -def _hoist_single_subdir(out_dir: Path) -> None: - """若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。""" - if not out_dir.is_dir(): - return - subs = [p for p in out_dir.iterdir() if p.is_dir()] - files = [p for p in out_dir.iterdir() if p.is_file()] - if len(subs) == 1 and not files: - child = subs[0] - for item in child.iterdir(): - dest = out_dir / item.name - if dest.exists(): - continue - shutil.move(str(item), str(dest)) - try: - child.rmdir() - except OSError: - pass - - -def step2_reorganize_labels() -> str: - """把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。 - - 返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。 - """ - _print_section("STEP 2 解压标签到 labels// 布局") - - common_folders = [ - "all_object_info", - "captions", - "car_mask_coarse", - "ftheta_intrinsic", - "pinhole_intrinsic", - "pose", - "vehicle_pose", - "lidar_raw", - # HDMap 9 类 - "3d_lanes", - "3d_lanelines", - "3d_road_boundaries", - "3d_wait_lines", - "3d_crosswalks", - "3d_road_markings", - "3d_poles", - "3d_traffic_lights", - "3d_traffic_signs", - ] - - clip_id_full: str | None = None # {clip_id}_{start}_{end} - clip_id: str | None = None - - for folder in common_folders: - src = DATA_ROOT / folder - if not src.exists(): - print(f" - skip {folder} (not downloaded)", flush=True) - continue - tars = sorted(src.glob("*.tar")) - if not tars: - print(f" - skip {folder} (no .tar)", flush=True) - continue - if clip_id_full is None: - clip_id_full = tars[0].stem - clip_id = clip_id_full.rsplit("_", 2)[0] - print(f" -> chosen clip_id_full = {clip_id_full}", flush=True) - print(f" -> video / symlink clip_id = {clip_id}", flush=True) - use_tars = [t for t in tars if t.stem == clip_id_full] - if not use_tars: - print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True) - continue - tar_path = use_tars[0] - # 目标目录 - out_dir = DATA_ROOT / "labels" / clip_id_full / folder - out_dir.mkdir(parents=True, exist_ok=True) - with tarfile.open(tar_path, "r") as tf: - tf.extractall(out_dir) - _hoist_single_subdir(out_dir) - # 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...) - _hoist_single_subdir(out_dir) - # 列几个样例 - members = sorted(out_dir.rglob("*"))[:3] - for m in members: - print(f" {m.relative_to(DATA_ROOT)}", flush=True) - print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True) - - if clip_id_full is None: - sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集") - - # 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是 - # {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。 - short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type] - if not short_dir.exists(): - try: - short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True) - except OSError: - shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir) - return clip_id # type: ignore[return-value] - - -def step3_make_fake_video(clip_id: str) -> None: - """合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。""" - _print_section("STEP 3 合成占位视频(随机噪声 mp4)") - import numpy as np - import cv2 - - syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation" - syn_dir.mkdir(parents=True, exist_ok=True) - out_path = syn_dir / f"{clip_id}_0_Sunny.mp4" - - H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768 - rng = np.random.default_rng(0) - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H)) - if not writer.isOpened(): - sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}") - for _ in range(T): - frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8) - writer.write(frame) - writer.release() - print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True) - - -def step4_run_trainer(clip_id: str) -> None: - """跑 runner_local --tiny --max_steps 4。""" - _print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)") - cmd = [ - sys.executable, - "-m", - "wjad.train.runner_local", - "--config", str(ROOT / "configs" / "default.yaml"), - "--data_root", str(DATA_ROOT), - "--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), - "--device", "cuda" if _has_cuda() else "cpu", - "--tiny", - "--max_steps", "4", - ] - env = os.environ.copy() - env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "") - print(f"$ {' '.join(cmd)}", flush=True) - rc = subprocess.call(cmd, env=env) - if rc != 0: - sys.exit(f"trainer 失败 rc={rc}") - - -def _has_cuda() -> bool: - try: - import torch - return torch.cuda.is_available() - except Exception: - return False - - -def main() -> None: - _print_section("WJAD Sandbox Real-Data Tiny Test") - print(f"DATA_ROOT = {DATA_ROOT}", flush=True) - step1_download_labels() - clip_id = step2_reorganize_labels() - step3_make_fake_video(clip_id) - step4_run_trainer(clip_id) - _print_section("DONE") - - -if __name__ == "__main__": - main() +"""Sandbox 真实数据微验证脚本。 + +由于 NVIDIA Cosmos-Drive-Dreams 数据集的 ``cosmos_synthetic`` 是一份切成 17 +个分卷(共 ~700 GB)的 ``split`` 二进制,单独下载某一分卷无法解压出 mp4。 +因此本脚本采用混合方案: + + 1. 用官方 ``download.py --file_types lidar --limit 1`` 拉下 1 个 clip 的 + 全部真实标签(所有 common 文件夹 + lidar_raw),约 50-200 MB; + 2. 把每个 ``.tar`` 解压到 ``labels/{clip_id}/{folder}/`` 结构,匹配 + ``wjad.data.cosmos_dataset`` 期待的布局; + 3. 用 ``imageio`` 合成一个随机噪声 mp4 占位真实合成视频 + (文件名 ``{clip_id}_{chunk_id}_Sunny.mp4``,121 帧,分辨率 1024×768); + 4. 调用 ``wjad.train.runner_local --tiny --max_steps 4`` 跑 4 步真实标签 + + 伪造视觉的训练。 + +这样能验证: + - 数据集索引(``build_clip_index``) + - 标签解析(``all_object_info`` JSON、SE(3) pose、f-theta 内参) + - LIDAR 加载与遮挡过滤 + - Hungarian 匹配 + DETR loss + - 端到端 forward / GradNorm / PCGrad / 反传 + +但不会验证 DINOv3 在真实图像上的语义提取(视觉是噪声,不会收敛)。 +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +import tarfile +import urllib.request +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "src")) + +DATA_ROOT = Path(os.environ.get("WJAD_DATA_ROOT", ROOT / "data" / "cosmos")) +NV_DOWNLOAD_URL = ( + "https://raw.githubusercontent.com/nv-tlabs/Cosmos-Drive-Dreams/main/scripts/download.py" +) + + +def _print_section(title: str) -> None: + bar = "=" * 60 + print(f"\n{bar}\n{title}\n{bar}", flush=True) + + +def step1_download_labels() -> None: + """用 NVIDIA 官方脚本下载 1 个 clip 的标签 + lidar。""" + _print_section("STEP 1 下载真实标签(1 个 clip)") + DATA_ROOT.mkdir(parents=True, exist_ok=True) + nv_script = DATA_ROOT / ".nvidia_download.py" + if not nv_script.exists(): + print(f"[download] 取 NVIDIA download.py -> {nv_script}", flush=True) + with urllib.request.urlopen(NV_DOWNLOAD_URL) as r, open(nv_script, "wb") as f: + f.write(r.read()) + # 同时拉 lidar + hdmap:``hdmap`` 类别会触发 9 个 3d_* 文件夹下载, + # 配合 common 文件夹一起拿,覆盖动态 + 结构化两类标签。 + cmd = [ + sys.executable, + str(nv_script), + "--odir", str(DATA_ROOT), + "--file_types", "lidar,hdmap", + "--workers", "4", + "--limit", "1", + ] + print(f"$ {' '.join(cmd)}", flush=True) + rc = subprocess.call(cmd) + if rc != 0: + sys.exit(f"download.py 失败 rc={rc}") + + +def _hoist_single_subdir(out_dir: Path) -> None: + """若解压结果仅为「单个子目录、顶层无文件」,把子目录内容抬到 out_dir(常见 tar 布局)。""" + if not out_dir.is_dir(): + return + subs = [p for p in out_dir.iterdir() if p.is_dir()] + files = [p for p in out_dir.iterdir() if p.is_file()] + if len(subs) == 1 and not files: + child = subs[0] + for item in child.iterdir(): + dest = out_dir / item.name + if dest.exists(): + continue + shutil.move(str(item), str(dest)) + try: + child.rmdir() + except OSError: + pass + + +def step2_reorganize_labels() -> str: + """把每个 common 文件夹的 .tar 解压到 ``labels/{clip_id}/{folder}/``。 + + 返回挑选出的 ``clip_id``(去掉 ``_{start}_{end}`` 后缀)。 + """ + _print_section("STEP 2 解压标签到 labels// 布局") + + common_folders = [ + "all_object_info", + "captions", + "car_mask_coarse", + "ftheta_intrinsic", + "pinhole_intrinsic", + "pose", + "vehicle_pose", + "lidar_raw", + # HDMap 9 类 + "3d_lanes", + "3d_lanelines", + "3d_road_boundaries", + "3d_wait_lines", + "3d_crosswalks", + "3d_road_markings", + "3d_poles", + "3d_traffic_lights", + "3d_traffic_signs", + ] + + clip_id_full: str | None = None # {clip_id}_{start}_{end} + clip_id: str | None = None + + for folder in common_folders: + src = DATA_ROOT / folder + if not src.exists(): + print(f" - skip {folder} (not downloaded)", flush=True) + continue + tars = sorted(src.glob("*.tar")) + if not tars: + print(f" - skip {folder} (no .tar)", flush=True) + continue + if clip_id_full is None: + clip_id_full = tars[0].stem + clip_id = clip_id_full.rsplit("_", 2)[0] + print(f" -> chosen clip_id_full = {clip_id_full}", flush=True) + print(f" -> video / symlink clip_id = {clip_id}", flush=True) + use_tars = [t for t in tars if t.stem == clip_id_full] + if not use_tars: + print(f" - skip {folder}: 无与 {clip_id_full} 同名的 tar(避免解压错 clip)", flush=True) + continue + tar_path = use_tars[0] + # 目标目录 + out_dir = DATA_ROOT / "labels" / clip_id_full / folder + out_dir.mkdir(parents=True, exist_ok=True) + with tarfile.open(tar_path, "r") as tf: + tf.extractall(out_dir) + _hoist_single_subdir(out_dir) + # 若仍嵌套一层 modality 名(ftheta_intrinsic/ftheta_intrinsic/...) + _hoist_single_subdir(out_dir) + # 列几个样例 + members = sorted(out_dir.rglob("*"))[:3] + for m in members: + print(f" {m.relative_to(DATA_ROOT)}", flush=True) + print(f" - {folder}: {len(list(out_dir.rglob('*')))} files", flush=True) + + if clip_id_full is None: + sys.exit("没有下到任何标签 tar,确认 HF_TOKEN 是否能访问 NVIDIA 数据集") + + # 兼容 cosmos_dataset.py:它从 labels/{clip_id}/ 读,但实际下载用的是 + # {clip_id}_{start}_{end} 作为目录名。这里软链一份名为纯 clip_id 的目录。 + short_dir = DATA_ROOT / "labels" / clip_id # type: ignore[arg-type] + if not short_dir.exists(): + try: + short_dir.symlink_to(DATA_ROOT / "labels" / clip_id_full, target_is_directory=True) + except OSError: + shutil.copytree(DATA_ROOT / "labels" / clip_id_full, short_dir) + return clip_id # type: ignore[return-value] + + +def step3_make_fake_video(clip_id: str) -> None: + """合成 121 帧随机 mp4 模拟 ``cosmos_synthetic`` 视频。""" + _print_section("STEP 3 合成占位视频(随机噪声 mp4)") + import numpy as np + import cv2 + + syn_dir = DATA_ROOT / "synthetic" / "single_view" / "generation" + syn_dir.mkdir(parents=True, exist_ok=True) + out_path = syn_dir / f"{clip_id}_0_Sunny.mp4" + + H, W, T = 768, 1024, 121 # 顶部裁剪后 384,原始 768 + rng = np.random.default_rng(0) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(str(out_path), fourcc, 30.0, (W, H)) + if not writer.isOpened(): + sys.exit(f"无法打开 mp4 写入器(缺 codec?): {out_path}") + for _ in range(T): + frame = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8) + writer.write(frame) + writer.release() + print(f" 写入 {out_path} ({out_path.stat().st_size / 1024**2:.1f} MB)", flush=True) + + +def step4_run_trainer(clip_id: str) -> None: + """跑 runner_local --tiny --max_steps 4。""" + _print_section("STEP 4 跑 trainer(真实标签 + 伪造视觉)") + cmd = [ + sys.executable, + "-m", + "wjad.train.runner_local", + "--config", str(ROOT / "configs" / "default.yaml"), + "--data_root", str(DATA_ROOT), + "--dinov3_path", str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), + "--device", "cuda" if _has_cuda() else "cpu", + "--tiny", + "--max_steps", "4", + ] + env = os.environ.copy() + env["PYTHONPATH"] = str(ROOT / "src") + os.pathsep + env.get("PYTHONPATH", "") + print(f"$ {' '.join(cmd)}", flush=True) + rc = subprocess.call(cmd, env=env) + if rc != 0: + sys.exit(f"trainer 失败 rc={rc}") + + +def _has_cuda() -> bool: + try: + import torch + return torch.cuda.is_available() + except Exception: + return False + + +def main() -> None: + _print_section("WJAD Sandbox Real-Data Tiny Test") + print(f"DATA_ROOT = {DATA_ROOT}", flush=True) + step1_download_labels() + clip_id = step2_reorganize_labels() + step3_make_fake_video(clip_id) + step4_run_trainer(clip_id) + _print_section("DONE") + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke_test.py b/scripts/smoke_test.py index 7b37e55b30e10955884c3cba519c662e7019e30e..a6f9296f1a6e324f3da8b67539dadac6a4b6080c 100644 --- a/scripts/smoke_test.py +++ b/scripts/smoke_test.py @@ -1,78 +1,78 @@ -"""本地烟囱测试:用随机张量验证 forward + backward。 - -运行: - python -m scripts.smoke_test -或: - python scripts/smoke_test.py -""" - -from __future__ import annotations - -import sys -from pathlib import Path - -# 允许直接运行 -ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(ROOT / "src")) - -import torch - -from wjad.model import E2EAVModel - - -def main() -> None: - torch.manual_seed(0) - device = "cpu" - - print("[smoke_test] 构建模型...") - model = E2EAVModel( - dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), - # 减小测试规模以适配 CPU - num_dense_layers=2, - num_moe_layers=2, - num_detection_tokens=64, - num_extra_tokens=32, - num_classes=22, - ).to(device) - - # 切到 sparse 验证 Top-3 路径也通 - model.backbone.set_moe_mode("sparse") - - B, T = 1, 8 - images = torch.randn(B, T, 3, 384, 1024, device=device) - ego_6d = torch.zeros(B, T, 6, device=device) - ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进 - intr_vec = torch.tensor([[ - 512.0, 192.0, 1024, 384, # cx, cy, w, h - 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly - 1.0, # is_bw_poly(与 Cosmos 11 维一致) - ]], device=device) - extr_6d = torch.zeros(B, 6, device=device) - - print("[smoke_test] 前向...") - out = model(images, ego_6d, intr_vec, extr_6d) - print(f" detection cls: {tuple(out.detection.cls_logits.shape)}") - print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}") - print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}") - print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}") - print(f" control action_mu: {tuple(out.control.action_mu.shape)}") - print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}") - - # 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss - loss = ( - out.detection.cls_logits.float().abs().mean() - + out.detection.box3d_mu.float().abs().mean() - + out.detection.traj_mu.float().abs().mean() - + out.control.ego_traj_mu.float().abs().mean() - ) - print(f"[smoke_test] loss = {loss.item():.6f}") - loss.backward() - grad_norm = sum( - p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None - ) - print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}") - print("[smoke_test] OK") - - -if __name__ == "__main__": - main() +"""本地烟囱测试:用随机张量验证 forward + backward。 + +运行: + python -m scripts.smoke_test +或: + python scripts/smoke_test.py +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# 允许直接运行 +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "src")) + +import torch + +from wjad.model import E2EAVModel + + +def main() -> None: + torch.manual_seed(0) + device = "cpu" + + print("[smoke_test] 构建模型...") + model = E2EAVModel( + dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), + # 减小测试规模以适配 CPU + num_dense_layers=2, + num_moe_layers=2, + num_detection_tokens=64, + num_extra_tokens=32, + num_classes=22, + ).to(device) + + # 切到 sparse 验证 Top-3 路径也通 + model.backbone.set_moe_mode("sparse") + + B, T = 1, 8 + images = torch.randn(B, T, 3, 384, 1024, device=device) + ego_6d = torch.zeros(B, T, 6, device=device) + ego_6d[..., 0] = torch.linspace(0, 7, T) # 模拟前进 + intr_vec = torch.tensor([[ + 512.0, 192.0, 1024, 384, # cx, cy, w, h + 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, # poly + 1.0, # is_bw_poly(与 Cosmos 11 维一致) + ]], device=device) + extr_6d = torch.zeros(B, 6, device=device) + + print("[smoke_test] 前向...") + out = model(images, ego_6d, intr_vec, extr_6d) + print(f" detection cls: {tuple(out.detection.cls_logits.shape)}") + print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}") + print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}") + print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}") + print(f" control action_mu: {tuple(out.control.action_mu.shape)}") + print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}") + + # 简单 backward:用 cls + box_mu 和 + ego_traj_mu 的简单 loss + loss = ( + out.detection.cls_logits.float().abs().mean() + + out.detection.box3d_mu.float().abs().mean() + + out.detection.traj_mu.float().abs().mean() + + out.control.ego_traj_mu.float().abs().mean() + ) + print(f"[smoke_test] loss = {loss.item():.6f}") + loss.backward() + grad_norm = sum( + p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None + ) + print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}") + print("[smoke_test] OK") + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke_train.py b/scripts/smoke_train.py index 9584ad758514b7cd480cf17ccb036b035f9f7a7f..514543d7675a5b11759922e7ad08a8d6b41bdaa4 100644 --- a/scripts/smoke_train.py +++ b/scripts/smoke_train.py @@ -1,152 +1,152 @@ -"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。 - -不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。 -""" - -from __future__ import annotations - -import os -import sys -from pathlib import Path - -os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") - -ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(ROOT / "src")) - -import logging - -import numpy as np -import torch - -from wjad.model import E2EAVModel -from wjad.train.trainer import Trainer, TrainerConfig - - -def _make_dummy_batch( - B: int = 1, - T: int = 8, - H: int = 64, - W: int = 128, - num_classes: int = 22, - num_objects: int = 3, -) -> dict: - """构造极小分辨率的随机 batch(CPU 烟囱测试用)。""" - images = torch.randn(B, T, 3, H, W) - ego_6d = torch.zeros(B, T, 6) - intr_vec = torch.tensor([[ - W / 2, H / 2, W, H, - 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, - 1.0, - ]] * B) - extr_6d = torch.zeros(B, 6) - ego_future = torch.zeros(B, 24, 3) - ego_future_valid = torch.ones(B, 24, dtype=torch.bool) - - targets = [] - for _ in range(B): - boxes = torch.zeros(num_objects, 7) - boxes[:, 3:6] = 2.0 - targets.append({ - "labels": torch.randint(1, num_classes, (num_objects,)), - "boxes": boxes, - "is_dynamic": torch.ones(num_objects, dtype=torch.long), - "future_traj": torch.zeros(num_objects, 24, 3), - "future_valid": torch.ones(num_objects, 24, dtype=torch.bool), - }) - - return { - "images": images, - "ego_6d": ego_6d, - "intr_vec": intr_vec, - "extr_6d": extr_6d, - "ego_future": ego_future, - "ego_future_valid": ego_future_valid, - "targets": targets, - "meta": [{}] * B, - } - - -def main() -> None: - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") - torch.manual_seed(0) - - has_cuda = torch.cuda.is_available() - device = "cuda" if has_cuda else "cpu" - if has_cuda: - # GPU 上跑接近真实规模:full 384x1024 + 完整 18 层 - # a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定 - H, W = 384, 1024 - B = 2 - num_dense, num_moe = 9, 9 - num_det = 1024 - num_extra = 256 - amp = "bf16" - n_steps = 4 - use_grad_ckpt = True - else: - # CPU 上跑极小规模仅做 sanity - H, W = 64, 128 - B = 1 - num_dense, num_moe = 2, 2 - num_det = 32 - num_extra = 16 - amp = "fp32" - n_steps = 4 - use_grad_ckpt = False - - print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}") - model = E2EAVModel( - dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), - num_dense_layers=num_dense, - num_moe_layers=num_moe, - num_detection_tokens=num_det, - num_control_tokens=24, - num_ego_tokens=8, - num_extra_tokens=num_extra, - num_classes=22, - image_h=H, - image_w=W, - patch_size=16, - ) - if use_grad_ckpt: - model.backbone.set_gradient_checkpointing(True) - # sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可 - # 验证两阶段路径切换。完整训练交给 H100 Jobs。 - model.dinov3.freeze() - - cfg = TrainerConfig( - total_steps=n_steps, - warmup_steps=1, - base_lr=1e-4, - log_interval=1, - stage1_steps=2, # 跑到 stage2 验证切换路径 - stage1_perturb_start=1, - enable_gradnorm=True, - enable_pcgrad=True, # 全程启用 PCGrad - mixed_precision=amp, - unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可 - ) - trainer = Trainer(model, cfg, num_classes=22, device=device) - rng = np.random.default_rng(0) - - if has_cuda: - torch.cuda.reset_peak_memory_stats() - - for step in range(n_steps): - batch = _make_dummy_batch(B=B, H=H, W=W) - info = trainer.train_step(batch, rng) - print( - f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} " - f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} " - f"weights={[f'{w:.2f}' for w in info['weights']]}" - ) - - if has_cuda: - peak_gb = torch.cuda.max_memory_allocated() / 1024**3 - print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB") - print("[smoke_train] OK") - - -if __name__ == "__main__": - main() +"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。 + +不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。 +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "src")) + +import logging + +import numpy as np +import torch + +from wjad.model import E2EAVModel +from wjad.train.trainer import Trainer, TrainerConfig + + +def _make_dummy_batch( + B: int = 1, + T: int = 8, + H: int = 64, + W: int = 128, + num_classes: int = 22, + num_objects: int = 3, +) -> dict: + """构造极小分辨率的随机 batch(CPU 烟囱测试用)。""" + images = torch.randn(B, T, 3, H, W) + ego_6d = torch.zeros(B, T, 6) + intr_vec = torch.tensor([[ + W / 2, H / 2, W, H, + 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, + 1.0, + ]] * B) + extr_6d = torch.zeros(B, 6) + ego_future = torch.zeros(B, 24, 3) + ego_future_valid = torch.ones(B, 24, dtype=torch.bool) + + targets = [] + for _ in range(B): + boxes = torch.zeros(num_objects, 7) + boxes[:, 3:6] = 2.0 + targets.append({ + "labels": torch.randint(1, num_classes, (num_objects,)), + "boxes": boxes, + "is_dynamic": torch.ones(num_objects, dtype=torch.long), + "future_traj": torch.zeros(num_objects, 24, 3), + "future_valid": torch.ones(num_objects, 24, dtype=torch.bool), + }) + + return { + "images": images, + "ego_6d": ego_6d, + "intr_vec": intr_vec, + "extr_6d": extr_6d, + "ego_future": ego_future, + "ego_future_valid": ego_future_valid, + "targets": targets, + "meta": [{}] * B, + } + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + torch.manual_seed(0) + + has_cuda = torch.cuda.is_available() + device = "cuda" if has_cuda else "cpu" + if has_cuda: + # GPU 上跑接近真实规模:full 384x1024 + 完整 18 层 + # a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定 + H, W = 384, 1024 + B = 2 + num_dense, num_moe = 9, 9 + num_det = 1024 + num_extra = 256 + amp = "bf16" + n_steps = 4 + use_grad_ckpt = True + else: + # CPU 上跑极小规模仅做 sanity + H, W = 64, 128 + B = 1 + num_dense, num_moe = 2, 2 + num_det = 32 + num_extra = 16 + amp = "fp32" + n_steps = 4 + use_grad_ckpt = False + + print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}") + model = E2EAVModel( + dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), + num_dense_layers=num_dense, + num_moe_layers=num_moe, + num_detection_tokens=num_det, + num_control_tokens=24, + num_ego_tokens=8, + num_extra_tokens=num_extra, + num_classes=22, + image_h=H, + image_w=W, + patch_size=16, + ) + if use_grad_ckpt: + model.backbone.set_gradient_checkpointing(True) + # sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可 + # 验证两阶段路径切换。完整训练交给 H100 Jobs。 + model.dinov3.freeze() + + cfg = TrainerConfig( + total_steps=n_steps, + warmup_steps=1, + base_lr=1e-4, + log_interval=1, + stage1_steps=2, # 跑到 stage2 验证切换路径 + stage1_perturb_start=1, + enable_gradnorm=True, + enable_pcgrad=True, # 全程启用 PCGrad + mixed_precision=amp, + unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可 + ) + trainer = Trainer(model, cfg, num_classes=22, device=device) + rng = np.random.default_rng(0) + + if has_cuda: + torch.cuda.reset_peak_memory_stats() + + for step in range(n_steps): + batch = _make_dummy_batch(B=B, H=H, W=W) + info = trainer.train_step(batch, rng) + print( + f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} " + f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} " + f"weights={[f'{w:.2f}' for w in info['weights']]}" + ) + + if has_cuda: + peak_gb = torch.cuda.max_memory_allocated() / 1024**3 + print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB") + print("[smoke_train] OK") + + +if __name__ == "__main__": + main() diff --git a/scripts/update_deps.py b/scripts/update_deps.py index 4f2ddbfc277b8e785310f91ca9c7f976988f719e..1269f7826370d478d12e99a8f58356c674c2d6cb 100644 --- a/scripts/update_deps.py +++ b/scripts/update_deps.py @@ -1,123 +1,123 @@ -"""自动把项目依赖升级到 PyPI 最新版。 - -特点: - - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与 - ``project.optional-dependencies``; - - 直接调用 ``pip install --upgrade `` 把所有第三方依赖升级; - - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL - 选项(``--torch-index https://download.pytorch.org/whl/cu124``); - - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于 - 在 HF Sandbox / Jobs 环境中复现。 - -注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。 -""" - -from __future__ import annotations - -import argparse -import shutil -import subprocess -import sys -from pathlib import Path - -try: - import tomllib # py3.11+ -except ImportError: # pragma: no cover - import tomli as tomllib # type: ignore - - -ROOT = Path(__file__).resolve().parent.parent -PYPROJECT = ROOT / "pyproject.toml" -LOCK_FILE = ROOT / "requirements.lock.txt" -TORCH_PKGS = {"torch", "torchvision", "torchaudio"} - - -def parse_pyproject() -> tuple[list[str], list[str]]: - """返回 (主依赖, dev 依赖) 的纯包名列表。""" - data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8")) - main = [ - _strip_spec(d) for d in data.get("project", {}).get("dependencies", []) - ] - dev = [ - _strip_spec(d) - for d in data.get("project", {}) - .get("optional-dependencies", {}) - .get("dev", []) - ] - return main, dev - - -def _strip_spec(req: str) -> str: - """去掉版本约束,只留包名。""" - name = req.split(";")[0] # 去掉 environment marker - for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"): - if sym in name: - name = name.split(sym)[0] - return name.strip() - - -def run(cmd: list[str], dry_run: bool = False) -> int: - print("$", " ".join(cmd)) - if dry_run: - return 0 - return subprocess.call(cmd) - - -def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None: - base = [sys.executable, "-m", "pip", "install", "--upgrade"] - if with_pre: - base.append("--pre") - if extra_index: - base += ["--extra-index-url", extra_index] - rc = run(base + pkgs, dry_run=dry_run) - if rc != 0: - sys.exit(rc) - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--torch-index", - default=None, - help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)", - ) - parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖") - parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行") - parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release") - args = parser.parse_args() - - if not PYPROJECT.exists(): - print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr) - sys.exit(1) - - main_deps, dev_deps = parse_pyproject() - - # 把 torch 系列单独处理(用专用索引) - torch_deps = [p for p in main_deps if p in TORCH_PKGS] - other_deps = [p for p in main_deps if p not in TORCH_PKGS] - - print(f"[update_deps] 升级 pip / setuptools / wheel ...") - upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run) - - if torch_deps: - print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...") - upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre) - - if other_deps: - print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...") - upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) - - if dev_deps and not args.no_dev: - print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...") - upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) - - print("[update_deps] 写入锁定文件 ...") - if not args.dry_run: - with open(LOCK_FILE, "w", encoding="utf-8") as f: - subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True) - print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}") - print("[update_deps] OK") - - -if __name__ == "__main__": - main() +"""自动把项目依赖升级到 PyPI 最新版。 + +特点: + - 从 ``pyproject.toml`` 读取 ``project.dependencies`` 与 + ``project.optional-dependencies``; + - 直接调用 ``pip install --upgrade `` 把所有第三方依赖升级; + - 为 ``torch`` / ``torchvision`` / ``torchaudio`` 提供单独的 CUDA index URL + 选项(``--torch-index https://download.pytorch.org/whl/cu124``); + - 升级后调用 ``pip freeze`` 把锁定版本写入 ``requirements.lock.txt``,便于 + 在 HF Sandbox / Jobs 环境中复现。 + +注意:本脚本会修改本地 venv!若需要安全演练,加 ``--dry-run``。 +""" + +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path + +try: + import tomllib # py3.11+ +except ImportError: # pragma: no cover + import tomli as tomllib # type: ignore + + +ROOT = Path(__file__).resolve().parent.parent +PYPROJECT = ROOT / "pyproject.toml" +LOCK_FILE = ROOT / "requirements.lock.txt" +TORCH_PKGS = {"torch", "torchvision", "torchaudio"} + + +def parse_pyproject() -> tuple[list[str], list[str]]: + """返回 (主依赖, dev 依赖) 的纯包名列表。""" + data = tomllib.loads(PYPROJECT.read_text(encoding="utf-8")) + main = [ + _strip_spec(d) for d in data.get("project", {}).get("dependencies", []) + ] + dev = [ + _strip_spec(d) + for d in data.get("project", {}) + .get("optional-dependencies", {}) + .get("dev", []) + ] + return main, dev + + +def _strip_spec(req: str) -> str: + """去掉版本约束,只留包名。""" + name = req.split(";")[0] # 去掉 environment marker + for sym in ("[", ">=", "<=", "==", "~=", ">", "<", "!=", "@"): + if sym in name: + name = name.split(sym)[0] + return name.strip() + + +def run(cmd: list[str], dry_run: bool = False) -> int: + print("$", " ".join(cmd)) + if dry_run: + return 0 + return subprocess.call(cmd) + + +def upgrade(pkgs: list[str], extra_index: str | None, dry_run: bool, with_pre: bool = False) -> None: + base = [sys.executable, "-m", "pip", "install", "--upgrade"] + if with_pre: + base.append("--pre") + if extra_index: + base += ["--extra-index-url", extra_index] + rc = run(base + pkgs, dry_run=dry_run) + if rc != 0: + sys.exit(rc) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--torch-index", + default=None, + help="PyTorch CUDA wheel 索引(如 https://download.pytorch.org/whl/cu124)", + ) + parser.add_argument("--no-dev", action="store_true", help="不升级 dev 依赖") + parser.add_argument("--dry-run", action="store_true", help="只打印命令不执行") + parser.add_argument("--with-pre", action="store_true", help="允许升级到 pre-release") + args = parser.parse_args() + + if not PYPROJECT.exists(): + print(f"[update_deps] 找不到 {PYPROJECT}", file=sys.stderr) + sys.exit(1) + + main_deps, dev_deps = parse_pyproject() + + # 把 torch 系列单独处理(用专用索引) + torch_deps = [p for p in main_deps if p in TORCH_PKGS] + other_deps = [p for p in main_deps if p not in TORCH_PKGS] + + print(f"[update_deps] 升级 pip / setuptools / wheel ...") + upgrade(["pip", "setuptools", "wheel"], extra_index=None, dry_run=args.dry_run) + + if torch_deps: + print(f"[update_deps] 升级 torch 系列 ({torch_deps}) ...") + upgrade(torch_deps, extra_index=args.torch_index, dry_run=args.dry_run, with_pre=args.with_pre) + + if other_deps: + print(f"[update_deps] 升级主依赖 ({len(other_deps)} 个) ...") + upgrade(other_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) + + if dev_deps and not args.no_dev: + print(f"[update_deps] 升级 dev 依赖 ({len(dev_deps)} 个) ...") + upgrade(dev_deps, extra_index=None, dry_run=args.dry_run, with_pre=args.with_pre) + + print("[update_deps] 写入锁定文件 ...") + if not args.dry_run: + with open(LOCK_FILE, "w", encoding="utf-8") as f: + subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True) + print(f"[update_deps] 锁定版本已写入 {LOCK_FILE}") + print("[update_deps] OK") + + +if __name__ == "__main__": + main() diff --git a/src/wjad/__init__.py b/src/wjad/__init__.py index e390717b95d0a1ee5562be270395824b3a01ad72..676861a913481f645bb5143108d17149271da657 100644 --- a/src/wjad/__init__.py +++ b/src/wjad/__init__.py @@ -1,5 +1,5 @@ -"""WJAD: 端到端自动驾驶模型主包。""" - -from __future__ import annotations - -__version__ = "0.1.0" +"""WJAD: 端到端自动驾驶模型主包。""" + +from __future__ import annotations + +__version__ = "0.1.0" diff --git a/src/wjad/backbone/__init__.py b/src/wjad/backbone/__init__.py index baef4fd20561efe016893946c56b7a066c2c2426..b5168ef92d15968dc5ff98d11405af05fc5a47ce 100644 --- a/src/wjad/backbone/__init__.py +++ b/src/wjad/backbone/__init__.py @@ -1,6 +1,6 @@ -"""18 层主干。""" - -from .backbone import Backbone, BackboneOutput -from .blocks import DenseBlock, MoEBlockWithAttn - -__all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"] +"""18 层主干。""" + +from .backbone import Backbone, BackboneOutput +from .blocks import DenseBlock, MoEBlockWithAttn + +__all__ = ["Backbone", "BackboneOutput", "DenseBlock", "MoEBlockWithAttn"] diff --git a/src/wjad/backbone/backbone.py b/src/wjad/backbone/backbone.py index e0748bacc13dc7f908f0d89b0202553ad958aade..7dd021abf5d9c65f6456ecdfc3d5ae4a0cc2948c 100644 --- a/src/wjad/backbone/backbone.py +++ b/src/wjad/backbone/backbone.py @@ -1,110 +1,110 @@ -"""18 层主干:前 9 Dense + 后 9 MoE。""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Optional - -import torch -import torch.nn as nn -import torch.utils.checkpoint as cp - -from ..modules.moe import MoEStats -from .blocks import DenseBlock, MoEBlockWithAttn - - -@dataclass -class BackboneOutput: - """主干输出。""" - - hidden_states: torch.Tensor # [B, N, D] - moe_stats: list[MoEStats] = field(default_factory=list) - - -class Backbone(nn.Module): - """端到端主干。 - - 输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用 - 可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。 - """ - - def __init__( - self, - dim: int = 768, - num_heads: int = 12, - ffn_mult: int = 4, - num_dense_layers: int = 9, - num_moe_layers: int = 9, - num_routed: int = 7, - num_shared: int = 1, - topk: int = 3, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.num_dense_layers = num_dense_layers - self.num_moe_layers = num_moe_layers - - self.dense_layers = nn.ModuleList([ - DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout) - for _ in range(num_dense_layers) - ]) - self.moe_layers = nn.ModuleList([ - MoEBlockWithAttn( - dim, - num_heads, - num_routed=num_routed, - num_shared=num_shared, - topk=topk, - ffn_mult=ffn_mult, - dropout=dropout, - ) - for _ in range(num_moe_layers) - ]) - self.final_norm = nn.LayerNorm(dim) - # 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存 - self.gradient_checkpointing = False - - def set_gradient_checkpointing(self, enabled: bool) -> None: - """开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。""" - self.gradient_checkpointing = enabled - - def set_moe_mode(self, mode: str) -> None: - """切换所有 MoE 层模式('dense' / 'sparse')。""" - for blk in self.moe_layers: - blk.set_mode(mode) - - def set_router_temperature(self, t: float) -> None: - for blk in self.moe_layers: - blk.set_temperature(t) - - def forward( - self, - x: torch.Tensor, - rope_cos: Optional[torch.Tensor] = None, - rope_sin: Optional[torch.Tensor] = None, - visual_slice: Optional[tuple[int, int]] = None, - ) -> BackboneOutput: - moe_stats: list[MoEStats] = [] - use_ckpt = self.gradient_checkpointing and self.training - - for blk in self.dense_layers: - if use_ckpt: - x = cp.checkpoint( - blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False - ) - else: - x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) - - for blk in self.moe_layers: - if use_ckpt: - x, stats = cp.checkpoint( - blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False - ) - else: - x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) - moe_stats.append(stats) - - x = self.final_norm(x) - return BackboneOutput(hidden_states=x, moe_stats=moe_stats) +"""18 层主干:前 9 Dense + 后 9 MoE。""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp + +from ..modules.moe import MoEStats +from .blocks import DenseBlock, MoEBlockWithAttn + + +@dataclass +class BackboneOutput: + """主干输出。""" + + hidden_states: torch.Tensor # [B, N, D] + moe_stats: list[MoEStats] = field(default_factory=list) + + +class Backbone(nn.Module): + """端到端主干。 + + 输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用 + 可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。 + """ + + def __init__( + self, + dim: int = 768, + num_heads: int = 12, + ffn_mult: int = 4, + num_dense_layers: int = 9, + num_moe_layers: int = 9, + num_routed: int = 7, + num_shared: int = 1, + topk: int = 3, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.num_dense_layers = num_dense_layers + self.num_moe_layers = num_moe_layers + + self.dense_layers = nn.ModuleList([ + DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout) + for _ in range(num_dense_layers) + ]) + self.moe_layers = nn.ModuleList([ + MoEBlockWithAttn( + dim, + num_heads, + num_routed=num_routed, + num_shared=num_shared, + topk=topk, + ffn_mult=ffn_mult, + dropout=dropout, + ) + for _ in range(num_moe_layers) + ]) + self.final_norm = nn.LayerNorm(dim) + # 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存 + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, enabled: bool) -> None: + """开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。""" + self.gradient_checkpointing = enabled + + def set_moe_mode(self, mode: str) -> None: + """切换所有 MoE 层模式('dense' / 'sparse')。""" + for blk in self.moe_layers: + blk.set_mode(mode) + + def set_router_temperature(self, t: float) -> None: + for blk in self.moe_layers: + blk.set_temperature(t) + + def forward( + self, + x: torch.Tensor, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + visual_slice: Optional[tuple[int, int]] = None, + ) -> BackboneOutput: + moe_stats: list[MoEStats] = [] + use_ckpt = self.gradient_checkpointing and self.training + + for blk in self.dense_layers: + if use_ckpt: + x = cp.checkpoint( + blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False + ) + else: + x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) + + for blk in self.moe_layers: + if use_ckpt: + x, stats = cp.checkpoint( + blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False + ) + else: + x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) + moe_stats.append(stats) + + x = self.final_norm(x) + return BackboneOutput(hidden_states=x, moe_stats=moe_stats) diff --git a/src/wjad/backbone/blocks.py b/src/wjad/backbone/blocks.py index 14bac6b402b19a1b1f78c9f5fd39aa4dd216de45..ab7e253202a32e6f617c5bbf629b410ac926d627 100644 --- a/src/wjad/backbone/blocks.py +++ b/src/wjad/backbone/blocks.py @@ -1,79 +1,79 @@ -"""主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。""" - -from __future__ import annotations - -from typing import Optional - -import torch -import torch.nn as nn - -from ..modules.ffn import SwiGLUFFN -from ..modules.gate_attention import GateSelfAttention -from ..modules.moe import MoEBlock, MoEStats - - -class DenseBlock(nn.Module): - """PreNorm GateSelfAttention + PreNorm SwiGLU FFN。""" - - def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None: - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout) - self.norm2 = nn.LayerNorm(dim) - self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) - - def forward( - self, - x: torch.Tensor, - rope_cos: Optional[torch.Tensor] = None, - rope_sin: Optional[torch.Tensor] = None, - visual_slice: Optional[tuple[int, int]] = None, - ) -> torch.Tensor: - x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) - x = x + self.ffn(self.norm2(x)) - return x - - -class MoEBlockWithAttn(nn.Module): - """PreNorm GateSelfAttention + PreNorm MoE FFN。""" - - def __init__( - self, - dim: int, - num_heads: int, - num_routed: int = 7, - num_shared: int = 1, - topk: int = 3, - ffn_mult: int = 4, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout) - self.norm2 = nn.LayerNorm(dim) - self.moe = MoEBlock( - dim, - num_routed=num_routed, - num_shared=num_shared, - topk=topk, - ffn_mult=ffn_mult, - dropout=dropout, - ) - - def set_mode(self, mode: str) -> None: - self.moe.set_mode(mode) - - def set_temperature(self, t: float) -> None: - self.moe.set_temperature(t) - - def forward( - self, - x: torch.Tensor, - rope_cos: Optional[torch.Tensor] = None, - rope_sin: Optional[torch.Tensor] = None, - visual_slice: Optional[tuple[int, int]] = None, - ) -> tuple[torch.Tensor, MoEStats]: - x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) - moe_out, stats = self.moe(self.norm2(x)) - x = x + moe_out - return x, stats +"""主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn + +from ..modules.ffn import SwiGLUFFN +from ..modules.gate_attention import GateSelfAttention +from ..modules.moe import MoEBlock, MoEStats + + +class DenseBlock(nn.Module): + """PreNorm GateSelfAttention + PreNorm SwiGLU FFN。""" + + def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) + + def forward( + self, + x: torch.Tensor, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + visual_slice: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) + x = x + self.ffn(self.norm2(x)) + return x + + +class MoEBlockWithAttn(nn.Module): + """PreNorm GateSelfAttention + PreNorm MoE FFN。""" + + def __init__( + self, + dim: int, + num_heads: int, + num_routed: int = 7, + num_shared: int = 1, + topk: int = 3, + ffn_mult: int = 4, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.moe = MoEBlock( + dim, + num_routed=num_routed, + num_shared=num_shared, + topk=topk, + ffn_mult=ffn_mult, + dropout=dropout, + ) + + def set_mode(self, mode: str) -> None: + self.moe.set_mode(mode) + + def set_temperature(self, t: float) -> None: + self.moe.set_temperature(t) + + def forward( + self, + x: torch.Tensor, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + visual_slice: Optional[tuple[int, int]] = None, + ) -> tuple[torch.Tensor, MoEStats]: + x = x + self.attn(self.norm1(x), rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) + moe_out, stats = self.moe(self.norm2(x)) + x = x + moe_out + return x, stats diff --git a/src/wjad/calibration/__init__.py b/src/wjad/calibration/__init__.py index 5d8bcba4d64d0cf78deb1e7001c5597532ae8504..d1901f6845d0a5c90b8942b8aa0ce0f150035b7b 100644 --- a/src/wjad/calibration/__init__.py +++ b/src/wjad/calibration/__init__.py @@ -1,5 +1,5 @@ -"""在线校准模块。""" - -from .online_calib import OnlineCalibration, CalibrationOutput - -__all__ = ["OnlineCalibration", "CalibrationOutput"] +"""在线校准模块。""" + +from .online_calib import OnlineCalibration, CalibrationOutput + +__all__ = ["OnlineCalibration", "CalibrationOutput"] diff --git a/src/wjad/calibration/online_calib.py b/src/wjad/calibration/online_calib.py index aae7fb47f1674e0953e4abc21d7992f1ff73c639..143f64b22a738d1d7181b4ce7f40a5e79172a59a 100644 --- a/src/wjad/calibration/online_calib.py +++ b/src/wjad/calibration/online_calib.py @@ -1,196 +1,196 @@ -"""在线校准网络。 - -输入 - - DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。 - - 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。 - - f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``); - README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。 - - 相机外参 6D ``[B, 6]``。 - -流程 - - 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外 - 条件 token 与 256 个可学习 query token 拼接。 - - 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。 - - 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出 - symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。 - -输出 - - ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、 - ``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from ..modules.gate_attention import GateCrossAttention, GateSelfAttention -from ..modules.learned_pe import LearnedTokenPE -from ..modules.normalization import symexp, symlog - - -@dataclass -class CalibrationOutput: - """校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。""" - - ego_residual: torch.Tensor # [B, 8, 6] - intr_residual: torch.Tensor # [B, intr_dim] - extr_residual: torch.Tensor # [B, 6] - corrected_ego: torch.Tensor # [B, 8, 6],真实空间 - corrected_intr: torch.Tensor # [B, intr_dim] - corrected_extr: torch.Tensor # [B, 6] - - -class _CalibBlock(nn.Module): - """单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。""" - - def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None: - super().__init__() - self.cross_norm = nn.LayerNorm(dim) - self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads) - self.cross_drop = nn.Dropout(0.0) - self.self_blocks = nn.ModuleList() - for _ in range(num_self): - self.self_blocks.append( - nn.ModuleDict({ - "norm": nn.LayerNorm(dim), - "attn": GateSelfAttention(dim, num_heads=num_heads), - }) - ) - - def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: - # 1) Cross - q = q + self.cross(self.cross_norm(q), kv) - # 2) Self ×2 - for blk in self.self_blocks: - q = q + blk["attn"](blk["norm"](q)) - return q - - -class OnlineCalibration(nn.Module): - def __init__( - self, - dino_dim: int = 768, - hidden_dim: int = 256, - num_query_tokens: int = 256, - num_blocks: int = 2, - num_self_attn_per_block: int = 2, - num_heads: int = 8, - residual_range: float = 0.1, - ego_dim: int = 6, # 3 平移 + 3 轴角 - intr_dim: int = 11, - extr_dim: int = 6, - num_history_frames: int = 8, - init_zero_output: bool = True, - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.residual_range = residual_range - self.ego_dim = ego_dim - self.intr_dim = intr_dim - self.extr_dim = extr_dim - self.num_history_frames = num_history_frames - - # 256 可学习 query token - self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim)) - nn.init.trunc_normal_(self.query_tokens, std=0.02) - self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim) - - # 条件 token 编码(symlog 空间) - self.ego_proj = nn.Linear(ego_dim, hidden_dim) - self.intr_proj = nn.Linear(intr_dim, hidden_dim) - self.extr_proj = nn.Linear(extr_dim, hidden_dim) - # 条件 token 也加可学习 PE(与 query 区分) - num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr - self.cond_pe = LearnedTokenPE(num_cond, hidden_dim) - self.num_cond = num_cond - self.num_query = num_query_tokens - - # KV 上下文:DINOv3 patch 特征投影到 hidden_dim - self.kv_proj = nn.Linear(dino_dim, hidden_dim) - self.kv_norm = nn.LayerNorm(hidden_dim) - - # 校准 block × num_blocks - self.blocks = nn.ModuleList([ - _CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block) - for _ in range(num_blocks) - ]) - - # 输出 MLP:取 last token -> 残差向量 - residual_total = num_history_frames * ego_dim + intr_dim + extr_dim - self.out_norm = nn.LayerNorm(hidden_dim) - self.out_mlp = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.GELU(), - nn.Linear(hidden_dim, residual_total), - ) - # 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0 - if init_zero_output: - nn.init.zeros_(self.out_mlp[-1].weight) - nn.init.zeros_(self.out_mlp[-1].bias) - - def forward( - self, - dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino] - ego_raw: torch.Tensor, # [B, 8, 6] 真实空间 - intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致 - extr_raw: torch.Tensor, # [B, 6] - ) -> CalibrationOutput: - b = dino_feats.shape[0] - if intr_raw.shape[-1] != self.intr_dim: - raise ValueError( - f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。" - f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。" - ) - # === 上下文 K/V === - # 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D] - kv = dino_feats.reshape(b, -1, dino_feats.shape[-1]) - kv = self.kv_norm(self.kv_proj(kv)) - - # === 条件 token(symlog 空间)=== - ego_sym = symlog(ego_raw) - intr_sym = symlog(intr_raw) - extr_sym = symlog(extr_raw) - - ego_tok = self.ego_proj(ego_sym) # [B, 8, D] - intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D] - extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D] - cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D] - cond = self.cond_pe(cond) - - # === 拼接 query token === - q = self.query_tokens.unsqueeze(0).expand(b, -1, -1) - q = self.query_pe(q) - seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D] - - # === 6 层 block === - for blk in self.blocks: - seq = blk(seq, kv) - - # === 取 last token === - last = self.out_norm(seq[:, -1, :]) # [B, D] - residual_flat = self.out_mlp(last) - # Tanh + 缩放 - residual_flat = torch.tanh(residual_flat) * self.residual_range - - # 拆分 - n_ego = self.num_history_frames * self.ego_dim - ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim) - intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim] - extr_res = residual_flat[:, n_ego + self.intr_dim :] - - # symlog 空间叠加 + symexp 还原 - corrected_ego = symexp(symlog(ego_raw) + ego_res) - corrected_intr = symexp(symlog(intr_raw) + intr_res) - corrected_extr = symexp(symlog(extr_raw) + extr_res) - - return CalibrationOutput( - ego_residual=ego_res, - intr_residual=intr_res, - extr_residual=extr_res, - corrected_ego=corrected_ego, - corrected_intr=corrected_intr, - corrected_extr=corrected_extr, - ) +"""在线校准网络。 + +输入 + - DINOv3 patch 特征 ``[B, T, gh, gw, D_dino]``(用作 K/V 上下文)。 + - 8 帧自车位姿(每帧 6D = 3 平移 + 3 轴角)``[B, 8, 6]``。 + - f-theta 内参 ``[B, intr_dim]``。Cosmos-Drive-Dreams 常见 **11** 维(无 ``linear_cde``); + README 完整式为 14 维时把 ``intr_dim`` 配成 14 即可,**不做零填充**。 + - 相机外参 6D ``[B, 6]``。 + +流程 + - 上述运动学 / 内外参先 ``symlog`` 归一,再 Linear -> 256,作为额外 + 条件 token 与 256 个可学习 query token 拼接。 + - 6 层 = 2 × [1 GateCrossAttn(K,V <- DINOv3 patch) + 2 GateSelfAttn]。 + - 取最后一个 token,MLP -> Tanh -> ``residual_range`` → 输出 + symlog 空间的残差。``corrected = symexp(symlog(raw) + Tanh_residual)``。 + +输出 + - ``ego_residual`` ``[B, 8, 6]``、``intr_residual`` ``[B, intr_dim]``、 + ``extr_residual`` ``[B, 6]``,已 symexp 还原到真实空间的 ``corrected_*``。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ..modules.gate_attention import GateCrossAttention, GateSelfAttention +from ..modules.learned_pe import LearnedTokenPE +from ..modules.normalization import symexp, symlog + + +@dataclass +class CalibrationOutput: + """校准网络输出。残差均在 symlog 空间,``corrected_*`` 已 symexp 还原。""" + + ego_residual: torch.Tensor # [B, 8, 6] + intr_residual: torch.Tensor # [B, intr_dim] + extr_residual: torch.Tensor # [B, 6] + corrected_ego: torch.Tensor # [B, 8, 6],真实空间 + corrected_intr: torch.Tensor # [B, intr_dim] + corrected_extr: torch.Tensor # [B, 6] + + +class _CalibBlock(nn.Module): + """单个校准 block:1 GateCrossAttn + 2 GateSelfAttn,PreNorm。""" + + def __init__(self, dim: int, dim_kv: int, num_heads: int, num_self: int = 2) -> None: + super().__init__() + self.cross_norm = nn.LayerNorm(dim) + self.cross = GateCrossAttention(dim, dim_kv, num_heads=num_heads) + self.cross_drop = nn.Dropout(0.0) + self.self_blocks = nn.ModuleList() + for _ in range(num_self): + self.self_blocks.append( + nn.ModuleDict({ + "norm": nn.LayerNorm(dim), + "attn": GateSelfAttention(dim, num_heads=num_heads), + }) + ) + + def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + # 1) Cross + q = q + self.cross(self.cross_norm(q), kv) + # 2) Self ×2 + for blk in self.self_blocks: + q = q + blk["attn"](blk["norm"](q)) + return q + + +class OnlineCalibration(nn.Module): + def __init__( + self, + dino_dim: int = 768, + hidden_dim: int = 256, + num_query_tokens: int = 256, + num_blocks: int = 2, + num_self_attn_per_block: int = 2, + num_heads: int = 8, + residual_range: float = 0.1, + ego_dim: int = 6, # 3 平移 + 3 轴角 + intr_dim: int = 11, + extr_dim: int = 6, + num_history_frames: int = 8, + init_zero_output: bool = True, + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.residual_range = residual_range + self.ego_dim = ego_dim + self.intr_dim = intr_dim + self.extr_dim = extr_dim + self.num_history_frames = num_history_frames + + # 256 可学习 query token + self.query_tokens = nn.Parameter(torch.empty(num_query_tokens, hidden_dim)) + nn.init.trunc_normal_(self.query_tokens, std=0.02) + self.query_pe = LearnedTokenPE(num_query_tokens, hidden_dim) + + # 条件 token 编码(symlog 空间) + self.ego_proj = nn.Linear(ego_dim, hidden_dim) + self.intr_proj = nn.Linear(intr_dim, hidden_dim) + self.extr_proj = nn.Linear(extr_dim, hidden_dim) + # 条件 token 也加可学习 PE(与 query 区分) + num_cond = num_history_frames + 2 # 8 ego + 1 intr + 1 extr + self.cond_pe = LearnedTokenPE(num_cond, hidden_dim) + self.num_cond = num_cond + self.num_query = num_query_tokens + + # KV 上下文:DINOv3 patch 特征投影到 hidden_dim + self.kv_proj = nn.Linear(dino_dim, hidden_dim) + self.kv_norm = nn.LayerNorm(hidden_dim) + + # 校准 block × num_blocks + self.blocks = nn.ModuleList([ + _CalibBlock(hidden_dim, hidden_dim, num_heads=num_heads, num_self=num_self_attn_per_block) + for _ in range(num_blocks) + ]) + + # 输出 MLP:取 last token -> 残差向量 + residual_total = num_history_frames * ego_dim + intr_dim + extr_dim + self.out_norm = nn.LayerNorm(hidden_dim) + self.out_mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, residual_total), + ) + # 0 初始化最后一层 → Tanh(0) = 0 → 初始残差 = 0 + if init_zero_output: + nn.init.zeros_(self.out_mlp[-1].weight) + nn.init.zeros_(self.out_mlp[-1].bias) + + def forward( + self, + dino_feats: torch.Tensor, # [B, T, gh, gw, D_dino] + ego_raw: torch.Tensor, # [B, 8, 6] 真实空间 + intr_raw: torch.Tensor, # [B, intr_dim],须与构造 ``OnlineCalibration`` 时一致 + extr_raw: torch.Tensor, # [B, 6] + ) -> CalibrationOutput: + b = dino_feats.shape[0] + if intr_raw.shape[-1] != self.intr_dim: + raise ValueError( + f"intr_raw.shape[-1]={intr_raw.shape[-1]} 与 OnlineCalibration.intr_dim={self.intr_dim} 不一致。" + f"数据是多少维就设多少维(见 configs calibration.intr_vec_dim),不要填充假参数。" + ) + # === 上下文 K/V === + # 把 [B, T, gh, gw, D] flatten 为 [B, T*gh*gw, D] + kv = dino_feats.reshape(b, -1, dino_feats.shape[-1]) + kv = self.kv_norm(self.kv_proj(kv)) + + # === 条件 token(symlog 空间)=== + ego_sym = symlog(ego_raw) + intr_sym = symlog(intr_raw) + extr_sym = symlog(extr_raw) + + ego_tok = self.ego_proj(ego_sym) # [B, 8, D] + intr_tok = self.intr_proj(intr_sym).unsqueeze(1) # [B, 1, D] + extr_tok = self.extr_proj(extr_sym).unsqueeze(1) # [B, 1, D] + cond = torch.cat([ego_tok, intr_tok, extr_tok], dim=1) # [B, num_cond, D] + cond = self.cond_pe(cond) + + # === 拼接 query token === + q = self.query_tokens.unsqueeze(0).expand(b, -1, -1) + q = self.query_pe(q) + seq = torch.cat([cond, q], dim=1) # [B, num_cond + num_query, D] + + # === 6 层 block === + for blk in self.blocks: + seq = blk(seq, kv) + + # === 取 last token === + last = self.out_norm(seq[:, -1, :]) # [B, D] + residual_flat = self.out_mlp(last) + # Tanh + 缩放 + residual_flat = torch.tanh(residual_flat) * self.residual_range + + # 拆分 + n_ego = self.num_history_frames * self.ego_dim + ego_res = residual_flat[:, :n_ego].view(b, self.num_history_frames, self.ego_dim) + intr_res = residual_flat[:, n_ego : n_ego + self.intr_dim] + extr_res = residual_flat[:, n_ego + self.intr_dim :] + + # symlog 空间叠加 + symexp 还原 + corrected_ego = symexp(symlog(ego_raw) + ego_res) + corrected_intr = symexp(symlog(intr_raw) + intr_res) + corrected_extr = symexp(symlog(extr_raw) + extr_res) + + return CalibrationOutput( + ego_residual=ego_res, + intr_residual=intr_res, + extr_residual=extr_res, + corrected_ego=corrected_ego, + corrected_intr=corrected_intr, + corrected_extr=corrected_extr, + ) diff --git a/src/wjad/data/__init__.py b/src/wjad/data/__init__.py index 6c468ff07a151b62333d09d9820974ddcd807bb8..a4fca8a7e125b7972ae3889760b1bd0c030dad2a 100644 --- a/src/wjad/data/__init__.py +++ b/src/wjad/data/__init__.py @@ -1,39 +1,39 @@ -"""Cosmos-Drive-Dreams 数据加载与目标构建。""" - -from .se3 import ( - matrix_to_6d, - six_d_to_matrix, - invert_se3, - rotation_matrix_to_axis_angle, - axis_angle_to_rotation_matrix, -) -from .ftheta_proj import project_points_ftheta -from .transforms import ( - crop_top_half, - normalize_image, - add_gaussian_noise, - perturb_kinematics, -) -from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo -from .hdmap import parse_hdmap_clip, HDMAP_SOURCES -from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index - -__all__ = [ - "matrix_to_6d", - "six_d_to_matrix", - "invert_se3", - "rotation_matrix_to_axis_angle", - "axis_angle_to_rotation_matrix", - "project_points_ftheta", - "crop_top_half", - "normalize_image", - "add_gaussian_noise", - "perturb_kinematics", - "build_detection_targets", - "build_ego_future_target", - "ObjectTrackInfo", - "CosmosDriveDreamsDataset", - "build_clip_index", - "parse_hdmap_clip", - "HDMAP_SOURCES", -] +"""Cosmos-Drive-Dreams 数据加载与目标构建。""" + +from .se3 import ( + matrix_to_6d, + six_d_to_matrix, + invert_se3, + rotation_matrix_to_axis_angle, + axis_angle_to_rotation_matrix, +) +from .ftheta_proj import project_points_ftheta +from .transforms import ( + crop_top_half, + normalize_image, + add_gaussian_noise, + perturb_kinematics, +) +from .targets import build_detection_targets, build_ego_future_target, ObjectTrackInfo +from .hdmap import parse_hdmap_clip, HDMAP_SOURCES +from .cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index + +__all__ = [ + "matrix_to_6d", + "six_d_to_matrix", + "invert_se3", + "rotation_matrix_to_axis_angle", + "axis_angle_to_rotation_matrix", + "project_points_ftheta", + "crop_top_half", + "normalize_image", + "add_gaussian_noise", + "perturb_kinematics", + "build_detection_targets", + "build_ego_future_target", + "ObjectTrackInfo", + "CosmosDriveDreamsDataset", + "build_clip_index", + "parse_hdmap_clip", + "HDMAP_SOURCES", +] diff --git a/src/wjad/data/cosmos_dataset.py b/src/wjad/data/cosmos_dataset.py index e1356adfb5549a0a1a4c345012e53d655dbf62ff..0635c991aebf4c50679bb08ccddadc90f0dfb617 100644 --- a/src/wjad/data/cosmos_dataset.py +++ b/src/wjad/data/cosmos_dataset.py @@ -1,439 +1,439 @@ -"""Cosmos-Drive-Dreams 数据集加载器(真实实现)。 - -期待目录结构(从 NVIDIA 提供的 .tar 解压): - - data_root/ - synthetic/single_view/ - generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频 - labels/{clip_id}/ - vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU - pose/000000.pose.{camera}.npy # 30 FPS, OpenCV - ftheta_intrinsic/ftheta_intrinsic.{camera}.npy - all_object_info/000000.all_object_info.json - lidar_raw/000000.lidar_raw.npz # 10 FPS - -每段 clip 提供: -- 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。 -- 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。 -""" - -from __future__ import annotations - -import json -from dataclasses import dataclass -from pathlib import Path -from typing import Sequence - -import cv2 -import numpy as np -import torch -from torch.utils.data import Dataset - -from ..modules.normalization import symlog -from ..modules.rays import FThetaCamera -from .label_paths import resolve_clip_file -from .hdmap import parse_hdmap_clip -from .se3 import matrix_to_6d -from .targets import ( - ObjectTrackInfo, - build_detection_targets, - build_ego_future_target, -) -from .transforms import DINOV3_MEAN, DINOV3_STD - - -# 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。 -DEFAULT_DYNAMIC_CLASSES = [ - "Automobile", - "Heavy_truck", - "Bus", - "Train_or_tram_car", - "Trolley_bus", - "Other_vehicle", - "Trailer", - "Person", - "Stroller", - "Rider", - "Animal", - "Protruding_object", -] - -# 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。 -DEFAULT_STRUCTURED_CLASSES = [ - "lane", - "laneline", - "road_boundary", - "wait_line", - "crosswalk", - "road_marking", - "pole", - "traffic_light", - "traffic_sign", -] - - -@dataclass -class ClipSample: - """clip 索引项。""" - - clip_id: str - chunk_id: int - weather: str - video_path: Path - labels_dir: Path - anchor_t: int # 当前帧(含),范围 [7, 96] - chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121) - - -def build_clip_index( - data_root: str | Path, - weathers: Sequence[str] = ("Sunny",), - chunk_ids: Sequence[int] = (0, 1), - camera_name: str = "camera_front_wide_120fov", - stride: int = 8, - anchor_min: int = 7, - anchor_max: int = 96, - max_clips: int | None = None, -) -> list[ClipSample]: - """枚举所有可用 (clip, chunk, weather, anchor_t) 样本。 - - 锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧 - ``chunk_offset + t``(chunk_offset = chunk_id * 121)。 - """ - root = Path(data_root) - syn_dir = root / "synthetic" / "single_view" / "generation" - labels_dir = root / "labels" - - samples: list[ClipSample] = [] - if not syn_dir.exists(): - return samples - - for video_path in sorted(syn_dir.glob("*.mp4")): - # 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4 - # clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析 - stem = video_path.stem - parts = stem.rsplit("_", 2) - if len(parts) != 3: - continue - clip_id, chunk_str, weather = parts - try: - chunk_id = int(chunk_str) - except ValueError: - continue - if chunk_id not in chunk_ids or weather not in weathers: - continue - - clip_label_dir = labels_dir / clip_id - if not clip_label_dir.exists(): - continue - - chunk_offset = chunk_id * 121 - for t in range(anchor_min, anchor_max + 1, stride): - samples.append( - ClipSample( - clip_id=clip_id, - chunk_id=chunk_id, - weather=weather, - video_path=video_path, - labels_dir=clip_label_dir, - anchor_t=t, - chunk_offset=chunk_offset, - ) - ) - if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips: - break - - return samples - - -def _load_video_frames( - video_path: Path, - frame_indices: Sequence[int], - target_h: int, - target_w: int, -) -> torch.Tensor: - """从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。""" - cap = cv2.VideoCapture(str(video_path)) - if not cap.isOpened(): - raise FileNotFoundError(f"无法打开视频: {video_path}") - frames = [] - for idx in frame_indices: - cap.set(cv2.CAP_PROP_POS_FRAMES, idx) - ok, bgr = cap.read() - if not ok: - cap.release() - raise RuntimeError(f"读取帧 {idx} 失败: {video_path}") - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA) - # 裁去上半部分(天空)后高度变为 target_h - rgb = rgb[target_h:, :, :] - rgb = rgb.astype(np.float32) / 255.0 - frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W] - cap.release() - return torch.stack(frames, dim=0) - - -def _load_npy(path: Path) -> np.ndarray: - return np.load(path, allow_pickle=False) - - -def _load_object_info(path: Path) -> list[ObjectTrackInfo]: - """解析单帧 all_object_info JSON。""" - if not path.exists(): - return [] - data = json.loads(path.read_text()) - out = [] - for tid, info in data.items(): - T = torch.tensor(info["object_to_world"], dtype=torch.float32) - lwh = torch.tensor(info["object_lwh"], dtype=torch.float32) - out.append( - ObjectTrackInfo( - tracking_id=tid, - object_to_world=T, - lwh=lwh, - is_moving=bool(info.get("object_is_moving", False)), - object_type=str(info.get("object_type", "")), - ) - ) - return out - - -def _load_lidar_self_frame( - labels_dir: Path, - label_idx: int, - vehicle_pose: torch.Tensor, - max_history: int = 3, -) -> torch.Tensor | None: - """读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。 - - LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、 - ``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。 - """ - lidar_idx = (label_idx // 3) * 3 - search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0] - p: Path | None = None - for idx_try in search_order: - try: - p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz") - break - except FileNotFoundError: - continue - if p is None: - return None - arr = np.load(p, allow_pickle=False) - xyz_lidar = arr["xyz"] # [N, 3] in lidar frame - lidar_to_world = arr["lidar_to_world"] # [4, 4] - # 转到 world 后再转 self - pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3] - inv_pose = torch.linalg.inv(vehicle_pose) - pts_w_t = torch.from_numpy(pts_w).float() - pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3] - return pts_self - - -class CosmosDriveDreamsDataset(Dataset): - """端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。""" - - def __init__( - self, - data_root: str | Path, - samples: list[ClipSample] | None = None, - weathers: Sequence[str] = ("Sunny",), - camera_name: str = "camera_front_wide_120fov", - image_h: int = 384, - image_w: int = 1024, - num_history: int = 8, - future_horizon: int = 24, - max_distance_m: float = 48.0, - occlusion_tol: float = 0.5, - dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES, - structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES, - do_normalize: bool = True, - use_lidar_occlusion: bool = True, - use_hdmap: bool = True, - ) -> None: - super().__init__() - self.data_root = Path(data_root) - self.samples = samples if samples is not None else build_clip_index( - data_root, weathers=weathers, camera_name=camera_name - ) - self.camera_name = camera_name - self.image_h = image_h - self.image_w = image_w - self.num_history = num_history - self.future_horizon = future_horizon - self.max_distance_m = max_distance_m - self.occlusion_tol = occlusion_tol - self.dynamic_classes = list(dynamic_classes) - self.structured_classes = list(structured_classes) - self.do_normalize = do_normalize - self.use_lidar_occlusion = use_lidar_occlusion - self.use_hdmap = use_hdmap - # HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析 - self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {} - self._hdmap_cache_max = 32 - - def __len__(self) -> int: - return len(self.samples) - - def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor: - p = resolve_clip_file( - sample.labels_dir, - "ftheta_intrinsic", - f"ftheta_intrinsic.{self.camera_name}.npy", - ) - return torch.from_numpy(_load_npy(p)).float() - - def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor: - p = resolve_clip_file( - sample.labels_dir, - "pose", - f"{label_idx:06d}.pose.{self.camera_name}.npy", - ) - return torch.from_numpy(_load_npy(p)).float() - - def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor: - p = resolve_clip_file( - sample.labels_dir, - "vehicle_pose", - f"{label_idx:06d}.vehicle_pose.npy", - ) - return torch.from_numpy(_load_npy(p)).float() - - def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]: - if not self.use_hdmap: - return [] - key = str(clip_dir) - cached = self._hdmap_cache.get(key) - if cached is not None: - return cached - objs = parse_hdmap_clip(clip_dir) - if len(self._hdmap_cache) >= self._hdmap_cache_max: - self._hdmap_cache.pop(next(iter(self._hdmap_cache))) - self._hdmap_cache[key] = objs - return objs - - def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]: - p = resolve_clip_file( - sample.labels_dir, - "all_object_info", - f"{label_idx:06d}.all_object_info.json", - ) - dynamic = _load_object_info(p) - # HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同), - # 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差, - # 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。 - return dynamic + self._load_hdmap_static(sample.labels_dir) - - def __getitem__(self, idx: int) -> dict: - s = self.samples[idx] - # 视频帧索引(chunk 内 0-based) - t = s.anchor_t - history_frames = list(range(t - self.num_history + 1, t + 1)) - # 标签索引:chunk_offset + chunk-local idx - history_label_idx = [s.chunk_offset + f for f in history_frames] - future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)] - - # === 1) 加载图像 === - # 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率); - # 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。 - images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w) - # [T, 3, H, W],[0, 1] - if self.do_normalize: - images = (images - DINOV3_MEAN) / DINOV3_STD - - # === 2) 加载内参 / 外参 === - intr_vec = self._load_intrinsic(s) # [14] - - # 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle - pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t) - pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t) - # cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world - inv_veh = torch.linalg.inv(pose_veh_world) - cam2veh = inv_veh @ pose_cam_world - extr_6d = matrix_to_6d(cam2veh) # [6] - - # === 3) 历史 8 帧 ego pose(vehicle 6D)=== - ego_6d = [] - for li in history_label_idx: - T_vw = self._load_pose_vehicle(s, li) - ego_6d.append(matrix_to_6d(T_vw)) - ego_6d = torch.stack(ego_6d, dim=0) # [8, 6] - - # === 4) 检测 / 未来轨迹标签 === - # objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。 - objs_t = self._load_objects(s, s.chunk_offset + t) - objs_future = [self._load_objects(s, li) for li in future_label_idx] - veh_pose_future = [] - for li in future_label_idx: - try: - veh_pose_future.append(self._load_pose_vehicle(s, li)) - except FileNotFoundError: - break - - cam = FThetaCamera.from_vector(intr_vec) - lidar_self = None - if self.use_lidar_occlusion: - try: - lidar_self = _load_lidar_self_frame( - s.labels_dir, - s.chunk_offset + t, - pose_veh_world, - ) - except Exception: - lidar_self = None - - det_targets = build_detection_targets( - objects_t=objs_t, - objects_future=objs_future, - vehicle_pose_t=pose_veh_world, - vehicle_pose_future=veh_pose_future, - cam_intrinsic=cam, - cam2vehicle=cam2veh, - image_h=self.image_h, - image_w=self.image_w, - max_distance_m=self.max_distance_m, - occlusion_depth_tolerance=self.occlusion_tol, - lidar_points_self=lidar_self, - dynamic_classes=self.dynamic_classes, - structured_classes=self.structured_classes, - future_horizon=self.future_horizon, - ) - - ego_future, ego_future_valid = build_ego_future_target( - pose_veh_world, veh_pose_future, horizon=self.future_horizon - ) - - sample_out = { - "images": images, - "ego_6d": ego_6d, - "intr_vec": intr_vec, - "extr_6d": extr_6d, - "ego_future": ego_future, - "ego_future_valid": ego_future_valid, - "targets": det_targets, - "meta": { - "clip_id": s.clip_id, - "chunk_id": s.chunk_id, - "weather": s.weather, - "anchor_t": s.anchor_t, - }, - } - return sample_out - - -def collate_samples(batch: list[dict]) -> dict: - """自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack; - targets 列表保留为 list(便于匈牙利匹配处理变长 N); - meta 也保留为 list。""" - out = { - "images": torch.stack([b["images"] for b in batch], dim=0), - "ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0), - "intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0), - "extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0), - "ego_future": torch.stack([b["ego_future"] for b in batch], dim=0), - "ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0), - "targets": [b["targets"] for b in batch], - "meta": [b["meta"] for b in batch], - } - return out +"""Cosmos-Drive-Dreams 数据集加载器(真实实现)。 + +期待目录结构(从 NVIDIA 提供的 .tar 解压): + + data_root/ + synthetic/single_view/ + generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频 + labels/{clip_id}/ + vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU + pose/000000.pose.{camera}.npy # 30 FPS, OpenCV + ftheta_intrinsic/ftheta_intrinsic.{camera}.npy + all_object_info/000000.all_object_info.json + lidar_raw/000000.lidar_raw.npz # 10 FPS + +每段 clip 提供: +- 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。 +- 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。 +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Sequence + +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset + +from ..modules.normalization import symlog +from ..modules.rays import FThetaCamera +from .label_paths import resolve_clip_file +from .hdmap import parse_hdmap_clip +from .se3 import matrix_to_6d +from .targets import ( + ObjectTrackInfo, + build_detection_targets, + build_ego_future_target, +) +from .transforms import DINOV3_MEAN, DINOV3_STD + + +# 数据集 README 列出的对象类型;动态类用于 is_dynamic + 未来轨迹监督。 +DEFAULT_DYNAMIC_CLASSES = [ + "Automobile", + "Heavy_truck", + "Bus", + "Train_or_tram_car", + "Trolley_bus", + "Other_vehicle", + "Trailer", + "Person", + "Stroller", + "Rider", + "Animal", + "Protruding_object", +] + +# 结构化场景类(与 ``hdmap.py`` 的 9 个 HDMAP_SOURCES key 一一对应)。 +DEFAULT_STRUCTURED_CLASSES = [ + "lane", + "laneline", + "road_boundary", + "wait_line", + "crosswalk", + "road_marking", + "pole", + "traffic_light", + "traffic_sign", +] + + +@dataclass +class ClipSample: + """clip 索引项。""" + + clip_id: str + chunk_id: int + weather: str + video_path: Path + labels_dir: Path + anchor_t: int # 当前帧(含),范围 [7, 96] + chunk_offset: int # 当前 chunk 在标签里的起始 idx(0 或 121) + + +def build_clip_index( + data_root: str | Path, + weathers: Sequence[str] = ("Sunny",), + chunk_ids: Sequence[int] = (0, 1), + camera_name: str = "camera_front_wide_120fov", + stride: int = 8, + anchor_min: int = 7, + anchor_max: int = 96, + max_clips: int | None = None, +) -> list[ClipSample]: + """枚举所有可用 (clip, chunk, weather, anchor_t) 样本。 + + 锚点 ``t`` 在 chunk 内为局部索引,对应视频帧 ``t``,对应标签帧 + ``chunk_offset + t``(chunk_offset = chunk_id * 121)。 + """ + root = Path(data_root) + syn_dir = root / "synthetic" / "single_view" / "generation" + labels_dir = root / "labels" + + samples: list[ClipSample] = [] + if not syn_dir.exists(): + return samples + + for video_path in sorted(syn_dir.glob("*.mp4")): + # 文件名形如 {clip_id}_{chunk_id}_{weather}.mp4 + # clip_id 可能含下划线(UUID 或 timestamp 形式),所以从右侧解析 + stem = video_path.stem + parts = stem.rsplit("_", 2) + if len(parts) != 3: + continue + clip_id, chunk_str, weather = parts + try: + chunk_id = int(chunk_str) + except ValueError: + continue + if chunk_id not in chunk_ids or weather not in weathers: + continue + + clip_label_dir = labels_dir / clip_id + if not clip_label_dir.exists(): + continue + + chunk_offset = chunk_id * 121 + for t in range(anchor_min, anchor_max + 1, stride): + samples.append( + ClipSample( + clip_id=clip_id, + chunk_id=chunk_id, + weather=weather, + video_path=video_path, + labels_dir=clip_label_dir, + anchor_t=t, + chunk_offset=chunk_offset, + ) + ) + if max_clips is not None and len({s.clip_id for s in samples}) >= max_clips: + break + + return samples + + +def _load_video_frames( + video_path: Path, + frame_indices: Sequence[int], + target_h: int, + target_w: int, +) -> torch.Tensor: + """从 .mp4 中读取指定帧序列,调整大小并按 ``[T, 3, H, W]`` 返回 ``float32 in [0, 1]``。""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise FileNotFoundError(f"无法打开视频: {video_path}") + frames = [] + for idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ok, bgr = cap.read() + if not ok: + cap.release() + raise RuntimeError(f"读取帧 {idx} 失败: {video_path}") + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA) + # 裁去上半部分(天空)后高度变为 target_h + rgb = rgb[target_h:, :, :] + rgb = rgb.astype(np.float32) / 255.0 + frames.append(torch.from_numpy(rgb).permute(2, 0, 1)) # [3, H, W] + cap.release() + return torch.stack(frames, dim=0) + + +def _load_npy(path: Path) -> np.ndarray: + return np.load(path, allow_pickle=False) + + +def _load_object_info(path: Path) -> list[ObjectTrackInfo]: + """解析单帧 all_object_info JSON。""" + if not path.exists(): + return [] + data = json.loads(path.read_text()) + out = [] + for tid, info in data.items(): + T = torch.tensor(info["object_to_world"], dtype=torch.float32) + lwh = torch.tensor(info["object_lwh"], dtype=torch.float32) + out.append( + ObjectTrackInfo( + tracking_id=tid, + object_to_world=T, + lwh=lwh, + is_moving=bool(info.get("object_is_moving", False)), + object_type=str(info.get("object_type", "")), + ) + ) + return out + + +def _load_lidar_self_frame( + labels_dir: Path, + label_idx: int, + vehicle_pose: torch.Tensor, + max_history: int = 3, +) -> torch.Tensor | None: + """读取与 ``label_idx`` 时间最近的 LIDAR 帧并把 xyz 转到当前 ego self 系。 + + LIDAR 是 10 FPS(每 3 个相机帧 1 个 LIDAR 帧),数据集存储 ``000000``、 + ``000003``、``000006`` 等步长 3 的索引。我们向下取整最近的一帧。 + """ + lidar_idx = (label_idx // 3) * 3 + search_order = [lidar_idx - back * 3 for back in range(max_history + 1) if lidar_idx - back * 3 >= 0] + p: Path | None = None + for idx_try in search_order: + try: + p = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz") + break + except FileNotFoundError: + continue + if p is None: + return None + arr = np.load(p, allow_pickle=False) + xyz_lidar = arr["xyz"] # [N, 3] in lidar frame + lidar_to_world = arr["lidar_to_world"] # [4, 4] + # 转到 world 后再转 self + pts_w = (lidar_to_world[:3, :3] @ xyz_lidar.T).T + lidar_to_world[:3, 3] + inv_pose = torch.linalg.inv(vehicle_pose) + pts_w_t = torch.from_numpy(pts_w).float() + pts_self = (inv_pose[:3, :3] @ pts_w_t.T).T + inv_pose[:3, 3] + return pts_self + + +class CosmosDriveDreamsDataset(Dataset): + """端到端样本:8 帧图像 + ego/intr/extr + 检测 + 自车未来 + 对象未来。""" + + def __init__( + self, + data_root: str | Path, + samples: list[ClipSample] | None = None, + weathers: Sequence[str] = ("Sunny",), + camera_name: str = "camera_front_wide_120fov", + image_h: int = 384, + image_w: int = 1024, + num_history: int = 8, + future_horizon: int = 24, + max_distance_m: float = 48.0, + occlusion_tol: float = 0.5, + dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES, + structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES, + do_normalize: bool = True, + use_lidar_occlusion: bool = True, + use_hdmap: bool = True, + ) -> None: + super().__init__() + self.data_root = Path(data_root) + self.samples = samples if samples is not None else build_clip_index( + data_root, weathers=weathers, camera_name=camera_name + ) + self.camera_name = camera_name + self.image_h = image_h + self.image_w = image_w + self.num_history = num_history + self.future_horizon = future_horizon + self.max_distance_m = max_distance_m + self.occlusion_tol = occlusion_tol + self.dynamic_classes = list(dynamic_classes) + self.structured_classes = list(structured_classes) + self.do_normalize = do_normalize + self.use_lidar_occlusion = use_lidar_occlusion + self.use_hdmap = use_hdmap + # HDMap 是 per-clip 静态对象,缓存避免每个 anchor_t 都重新解析 + self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {} + self._hdmap_cache_max = 32 + + def __len__(self) -> int: + return len(self.samples) + + def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor: + p = resolve_clip_file( + sample.labels_dir, + "ftheta_intrinsic", + f"ftheta_intrinsic.{self.camera_name}.npy", + ) + return torch.from_numpy(_load_npy(p)).float() + + def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor: + p = resolve_clip_file( + sample.labels_dir, + "pose", + f"{label_idx:06d}.pose.{self.camera_name}.npy", + ) + return torch.from_numpy(_load_npy(p)).float() + + def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor: + p = resolve_clip_file( + sample.labels_dir, + "vehicle_pose", + f"{label_idx:06d}.vehicle_pose.npy", + ) + return torch.from_numpy(_load_npy(p)).float() + + def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]: + if not self.use_hdmap: + return [] + key = str(clip_dir) + cached = self._hdmap_cache.get(key) + if cached is not None: + return cached + objs = parse_hdmap_clip(clip_dir) + if len(self._hdmap_cache) >= self._hdmap_cache_max: + self._hdmap_cache.pop(next(iter(self._hdmap_cache))) + self._hdmap_cache[key] = objs + return objs + + def _load_objects(self, sample: ClipSample, label_idx: int) -> list[ObjectTrackInfo]: + p = resolve_clip_file( + sample.labels_dir, + "all_object_info", + f"{label_idx:06d}.all_object_info.json", + ) + dynamic = _load_object_info(p) + # HDMap 是 clip 级静态标签:t 与 t+k 帧都拿同一份(tracking_id 相同), + # 这样 ``build_detection_targets`` 的未来轨迹分支会自动得到 ~0 残差, + # 同时由 ``is_dynamic=0`` 在损失里被 mask 掉,不进 trajectory NLL。 + return dynamic + self._load_hdmap_static(sample.labels_dir) + + def __getitem__(self, idx: int) -> dict: + s = self.samples[idx] + # 视频帧索引(chunk 内 0-based) + t = s.anchor_t + history_frames = list(range(t - self.num_history + 1, t + 1)) + # 标签索引:chunk_offset + chunk-local idx + history_label_idx = [s.chunk_offset + f for f in history_frames] + future_label_idx = [s.chunk_offset + t + 1 + k for k in range(self.future_horizon)] + + # === 1) 加载图像 === + # 注意:videl 已经裁过上半(数据生成时仍 1920x1080 等原始分辨率); + # 这里在 _load_video_frames 内同时做 resize 与 top-half 裁剪。 + images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w) + # [T, 3, H, W],[0, 1] + if self.do_normalize: + images = (images - DINOV3_MEAN) / DINOV3_STD + + # === 2) 加载内参 / 外参 === + intr_vec = self._load_intrinsic(s) # [14] + + # 当前帧的 cam_to_world 与 vehicle_to_world,得到 cam_to_vehicle + pose_cam_world = self._load_pose_camera(s, s.chunk_offset + t) + pose_veh_world = self._load_pose_vehicle(s, s.chunk_offset + t) + # cam_to_vehicle = inv(vehicle_to_world) @ cam_to_world + inv_veh = torch.linalg.inv(pose_veh_world) + cam2veh = inv_veh @ pose_cam_world + extr_6d = matrix_to_6d(cam2veh) # [6] + + # === 3) 历史 8 帧 ego pose(vehicle 6D)=== + ego_6d = [] + for li in history_label_idx: + T_vw = self._load_pose_vehicle(s, li) + ego_6d.append(matrix_to_6d(T_vw)) + ego_6d = torch.stack(ego_6d, dim=0) # [8, 6] + + # === 4) 检测 / 未来轨迹标签 === + # objs_t / objs_future = 动态 all_object_info ∪ HDMap 静态对象。 + objs_t = self._load_objects(s, s.chunk_offset + t) + objs_future = [self._load_objects(s, li) for li in future_label_idx] + veh_pose_future = [] + for li in future_label_idx: + try: + veh_pose_future.append(self._load_pose_vehicle(s, li)) + except FileNotFoundError: + break + + cam = FThetaCamera.from_vector(intr_vec) + lidar_self = None + if self.use_lidar_occlusion: + try: + lidar_self = _load_lidar_self_frame( + s.labels_dir, + s.chunk_offset + t, + pose_veh_world, + ) + except Exception: + lidar_self = None + + det_targets = build_detection_targets( + objects_t=objs_t, + objects_future=objs_future, + vehicle_pose_t=pose_veh_world, + vehicle_pose_future=veh_pose_future, + cam_intrinsic=cam, + cam2vehicle=cam2veh, + image_h=self.image_h, + image_w=self.image_w, + max_distance_m=self.max_distance_m, + occlusion_depth_tolerance=self.occlusion_tol, + lidar_points_self=lidar_self, + dynamic_classes=self.dynamic_classes, + structured_classes=self.structured_classes, + future_horizon=self.future_horizon, + ) + + ego_future, ego_future_valid = build_ego_future_target( + pose_veh_world, veh_pose_future, horizon=self.future_horizon + ) + + sample_out = { + "images": images, + "ego_6d": ego_6d, + "intr_vec": intr_vec, + "extr_6d": extr_6d, + "ego_future": ego_future, + "ego_future_valid": ego_future_valid, + "targets": det_targets, + "meta": { + "clip_id": s.clip_id, + "chunk_id": s.chunk_id, + "weather": s.weather, + "anchor_t": s.anchor_t, + }, + } + return sample_out + + +def collate_samples(batch: list[dict]) -> dict: + """自定义 collate:对图像 / ego / intr / extr / ego_future 直接 stack; + targets 列表保留为 list(便于匈牙利匹配处理变长 N); + meta 也保留为 list。""" + out = { + "images": torch.stack([b["images"] for b in batch], dim=0), + "ego_6d": torch.stack([b["ego_6d"] for b in batch], dim=0), + "intr_vec": torch.stack([b["intr_vec"] for b in batch], dim=0), + "extr_6d": torch.stack([b["extr_6d"] for b in batch], dim=0), + "ego_future": torch.stack([b["ego_future"] for b in batch], dim=0), + "ego_future_valid": torch.stack([b["ego_future_valid"] for b in batch], dim=0), + "targets": [b["targets"] for b in batch], + "meta": [b["meta"] for b in batch], + } + return out diff --git a/src/wjad/data/ftheta_proj.py b/src/wjad/data/ftheta_proj.py index 9fb0896ebd056cf75955a93bead29961683378d8..3b9fc38f1261fd100108816560956005dfd2e9c3 100644 --- a/src/wjad/data/ftheta_proj.py +++ b/src/wjad/data/ftheta_proj.py @@ -1,62 +1,62 @@ -"""f-theta 正向投影:3D 点(相机系) -> 像素。 - -仅支持 backward polynomial 形式(与 NVIDIA 工具一致), -forward 形式可在内部用牛顿迭代反推。 -""" - -from __future__ import annotations - -import torch - -from ..modules.rays import FThetaCamera - - -def project_points_ftheta( - points_cam: torch.Tensor, # [..., 3],相机系下 3D 点 - cam: FThetaCamera, -) -> tuple[torch.Tensor, torch.Tensor]: - """正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。 - - 返回 - ---- - uv : [..., 2] - depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算 - 但调用方需用 ``depth > 0`` 做有效性筛选)。 - """ - x = points_cam[..., 0] - y = points_cam[..., 1] - z = points_cam[..., 2] - norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6) - cos_theta = z / norm - cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7) - theta = torch.acos(cos_theta) - phi = torch.atan2(y, x) - - if cam.intr.is_bw_poly: - # backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。 - # 用牛顿迭代:希望 _eval_poly(r) = theta - r = theta.clone() # 初始猜测 - for _ in range(8): - f = cam._eval_poly(r) - theta - df = cam._eval_poly_grad(r).clamp_min(1e-6) - r = r - f / df - r_pix = r - else: - r_pix = cam._eval_poly(theta) - - cos_p = torch.cos(phi) - sin_p = torch.sin(phi) - du = r_pix * cos_p - dv = r_pix * sin_p - # 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆 - c = cam.intr.linear_cde[0] - d = cam.intr.linear_cde[1] - e = cam.intr.linear_cde[2] - # 简化:忽略 linear_cde 的修正(与 unproject 中近似一致) - du0 = du - dv0 = dv - u = du0 + cam.intr.cx - v = dv0 + cam.intr.cy - uv = torch.stack([u, v], dim=-1) - depth = z.unsqueeze(-1) - return uv, depth +"""f-theta 正向投影:3D 点(相机系) -> 像素。 + +仅支持 backward polynomial 形式(与 NVIDIA 工具一致), +forward 形式可在内部用牛顿迭代反推。 +""" + +from __future__ import annotations + +import torch + +from ..modules.rays import FThetaCamera + + +def project_points_ftheta( + points_cam: torch.Tensor, # [..., 3],相机系下 3D 点 + cam: FThetaCamera, +) -> tuple[torch.Tensor, torch.Tensor]: + """正向投影:相机系点 -> 像素 ``(u, v)``,并返回深度。 + + 返回 + ---- + uv : [..., 2] + depth : [..., 1],沿主光轴(z)方向的深度(如果 z<0 则视为后方,仍计算 + 但调用方需用 ``depth > 0`` 做有效性筛选)。 + """ + x = points_cam[..., 0] + y = points_cam[..., 1] + z = points_cam[..., 2] + norm = torch.sqrt(x * x + y * y + z * z).clamp_min(1e-6) + cos_theta = z / norm + cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7) + theta = torch.acos(cos_theta) + phi = torch.atan2(y, x) + + if cam.intr.is_bw_poly: + # backward poly 是 r_pix -> theta;正向需要反求 theta -> r_pix。 + # 用牛顿迭代:希望 _eval_poly(r) = theta + r = theta.clone() # 初始猜测 + for _ in range(8): + f = cam._eval_poly(r) - theta + df = cam._eval_poly_grad(r).clamp_min(1e-6) + r = r - f / df + r_pix = r + else: + r_pix = cam._eval_poly(theta) + + cos_p = torch.cos(phi) + sin_p = torch.sin(phi) + du = r_pix * cos_p + dv = r_pix * sin_p + # 反线性修正:linear_cde 是仿射 (du,dv) = M (du0,dv0),正投影需要逆 + c = cam.intr.linear_cde[0] + d = cam.intr.linear_cde[1] + e = cam.intr.linear_cde[2] + # 简化:忽略 linear_cde 的修正(与 unproject 中近似一致) + du0 = du + dv0 = dv + u = du0 + cam.intr.cx + v = dv0 + cam.intr.cy + uv = torch.stack([u, v], dim=-1) + depth = z.unsqueeze(-1) + return uv, depth diff --git a/src/wjad/data/hdmap.py b/src/wjad/data/hdmap.py index de0ade2aea9f72ffbfb9ccbadc9cb0274ca09c84..b05ce659d13811f9cf78ae5312b89387b8ddf6e9 100644 --- a/src/wjad/data/hdmap.py +++ b/src/wjad/data/hdmap.py @@ -1,247 +1,247 @@ -"""HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。 - -输入:clip 标签目录(``labels/{clip_id_full}/``)。 -输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``, -``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。 - -形状约定(按 README): - - 3d_lanes / lanes.json - labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices'] - - 3d_lanelines / lanelines.json - labels[i]['labelData']['shape3d']['polyline3d']['vertices'] - - 3d_road_boundaries / road_boundaries.json 同 polyline3d - - 3d_wait_lines / wait_lines.json 同 polyline3d - - 3d_crosswalks / crosswalks.json - labels[i]['labelData']['shape3d']['surface']['vertices'] - - 3d_road_markings / road_markings.json 同 surface - - 3d_poles / poles.json 同 polyline3d - - 3d_traffic_lights / 3d_traffic_lights.json - labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点 - - 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d - -折线 → 7-DoF box: - PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计 - 弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会 - 超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。 -""" - -from __future__ import annotations - -import json -from pathlib import Path - -import numpy as np -import torch - -from .targets import ObjectTrackInfo -from .label_paths import resolve_clip_file - - -# 折线类长度切分阈值(米) -POLYLINE_SEGMENT_LEN = 10.0 -LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点 -MIN_LWH = (0.2, 0.2, 0.05) - - -# cls_name -> (folder, json_name, kind) -HDMAP_SOURCES = { - "lane": ("3d_lanes", "lanes.json", "lane_pair"), - "laneline": ("3d_lanelines", "lanelines.json", "polyline"), - "road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"), - "wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"), - "crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"), - "road_marking": ("3d_road_markings", "road_markings.json", "surface"), - "pole": ("3d_poles", "poles.json", "polyline_short"), - # 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json) - "traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"), - "traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"), -} - - -def _load_json_labels(path: Path) -> list: - """容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{: {labels: ...}}``。""" - if not path.exists(): - return [] - try: - data = json.loads(path.read_text(encoding="utf-8")) - except Exception: - return [] - if isinstance(data, dict): - if isinstance(data.get("labels"), list): - return data["labels"] - for v in data.values(): - if isinstance(v, dict) and isinstance(v.get("labels"), list): - return v["labels"] - return [] - - -def _verts_to_array(verts) -> np.ndarray: - """vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。""" - if not verts: - return np.zeros((0, 3), dtype=np.float32) - out: list[list[float]] = [] - for v in verts: - if isinstance(v, dict): - out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))]) - elif isinstance(v, (list, tuple)) and len(v) >= 3: - out.append([float(v[0]), float(v[1]), float(v[2])]) - return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32) - - -def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]: - """按累计弧长把折线切成若干段。每段顶点数 >=2。""" - if verts.shape[0] < 2: - return [] - edges = np.linalg.norm(np.diff(verts, axis=0), axis=1) - cum = np.concatenate([[0.0], np.cumsum(edges)]) - total = float(cum[-1]) - if total <= seg_len: - return [verts] - n = max(1, int(np.ceil(total / seg_len))) - bounds = np.linspace(0.0, total, n + 1) - chunks: list[np.ndarray] = [] - for i in range(n): - lo, hi = bounds[i], bounds[i + 1] - mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6) - chunk = verts[mask] - if chunk.shape[0] >= 2: - chunks.append(chunk) - return chunks - - -def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None: - """[N, 3] -> (center, lwh, yaw)。""" - if verts.shape[0] < 2: - return None - center = verts.mean(0) - centered_xy = verts[:, :2] - center[:2] - if np.allclose(centered_xy, 0.0): - yaw = 0.0 - else: - cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1) - _, eigvecs = np.linalg.eigh(cov) - principal = eigvecs[:, -1] - yaw = float(np.arctan2(principal[1], principal[0])) - c, s = float(np.cos(-yaw)), float(np.sin(-yaw)) - rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T - l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min()) - w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min()) - h = float(verts[:, 2].max() - verts[:, 2].min()) - l = max(l, MIN_LWH[0]) - w = max(w, MIN_LWH[1]) - h = max(h, MIN_LWH[2]) - return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw - - -def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]: - """8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。""" - center = corners.mean(0) - edge = corners[1] - corners[0] - yaw = float(np.arctan2(edge[1], edge[0])) - c, s = float(np.cos(-yaw)), float(np.sin(-yaw)) - R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32) - rot = (corners - center) @ R.T - lwh = (rot.max(0) - rot.min(0)).astype(np.float32) - lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32)) - return center.astype(np.float32), lwh, yaw - - -def _build_object( - center: np.ndarray, - lwh: np.ndarray, - yaw: float, - cls_name: str, - idx: int, - sub_idx: int = 0, -) -> ObjectTrackInfo: - T = np.eye(4, dtype=np.float32) - c, s = float(np.cos(yaw)), float(np.sin(yaw)) - T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32) - T[:3, 3] = center - return ObjectTrackInfo( - tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}", - object_to_world=torch.from_numpy(T), - lwh=torch.from_numpy(lwh), - is_moving=False, - object_type=cls_name, - ) - - -def parse_hdmap_clip( - clip_label_dir: Path, - segment_len: float = POLYLINE_SEGMENT_LEN, - lane_segment_len: float = LANE_SEGMENT_LEN, -) -> list[ObjectTrackInfo]: - """解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。""" - out: list[ObjectTrackInfo] = [] - for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items(): - try: - path = resolve_clip_file(clip_label_dir, subdir, json_name) - except FileNotFoundError: - continue - labels = _load_json_labels(path) - for i, lbl in enumerate(labels): - if not isinstance(lbl, dict): - continue - shape = lbl.get("labelData", {}).get("shape3d", {}) - if not isinstance(shape, dict): - continue - - if kind == "cuboid": - verts = shape.get("cuboid3d", {}).get("vertices", []) - arr = _verts_to_array(verts) - if arr.shape[0] != 8: - continue - c, lwh, yaw = _cuboid_to_box(arr) - out.append(_build_object(c, lwh, yaw, cls_name, i)) - - elif kind == "surface": - verts = shape.get("surface", {}).get("vertices", []) - arr = _verts_to_array(verts) - if arr.shape[0] < 3: - continue - box = _vertices_to_box(arr) - if box is not None: - out.append(_build_object(*box, cls_name, i)) - - elif kind == "polyline": - verts = shape.get("polyline3d", {}).get("vertices", []) - arr = _verts_to_array(verts) - if arr.shape[0] < 2: - continue - for j, chunk in enumerate(_split_polyline(arr, segment_len)): - box = _vertices_to_box(chunk) - if box is not None: - out.append(_build_object(*box, cls_name, i, j)) - - elif kind == "polyline_short": - # 杆状物体不切分 - verts = shape.get("polyline3d", {}).get("vertices", []) - arr = _verts_to_array(verts) - if arr.shape[0] < 2: - continue - box = _vertices_to_box(arr) - if box is not None: - out.append(_build_object(*box, cls_name, i)) - - elif kind == "lane_pair": - pl_root = shape.get("polylines3d", {}).get("polylines", []) - if not isinstance(pl_root, list) or len(pl_root) < 2: - continue - left = _verts_to_array( - pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else [] - ) - right = _verts_to_array( - pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else [] - ) - if left.shape[0] == 0 and right.shape[0] == 0: - continue - merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0) - if merged.shape[0] < 2: - continue - for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)): - box = _vertices_to_box(chunk) - if box is not None: - out.append(_build_object(*box, cls_name, i, j)) - - return out +"""HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。 + +输入:clip 标签目录(``labels/{clip_id_full}/``)。 +输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``, +``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。 + +形状约定(按 README): + - 3d_lanes / lanes.json + labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices'] + - 3d_lanelines / lanelines.json + labels[i]['labelData']['shape3d']['polyline3d']['vertices'] + - 3d_road_boundaries / road_boundaries.json 同 polyline3d + - 3d_wait_lines / wait_lines.json 同 polyline3d + - 3d_crosswalks / crosswalks.json + labels[i]['labelData']['shape3d']['surface']['vertices'] + - 3d_road_markings / road_markings.json 同 surface + - 3d_poles / poles.json 同 polyline3d + - 3d_traffic_lights / 3d_traffic_lights.json + labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点 + - 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d + +折线 → 7-DoF box: + PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计 + 弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会 + 超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。 +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import numpy as np +import torch + +from .targets import ObjectTrackInfo +from .label_paths import resolve_clip_file + + +# 折线类长度切分阈值(米) +POLYLINE_SEGMENT_LEN = 10.0 +LANE_SEGMENT_LEN = 15.0 # lanes 是一对左右 polyline,整体粗一点 +MIN_LWH = (0.2, 0.2, 0.05) + + +# cls_name -> (folder, json_name, kind) +HDMAP_SOURCES = { + "lane": ("3d_lanes", "lanes.json", "lane_pair"), + "laneline": ("3d_lanelines", "lanelines.json", "polyline"), + "road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"), + "wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"), + "crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"), + "road_marking": ("3d_road_markings", "road_markings.json", "surface"), + "pole": ("3d_poles", "poles.json", "polyline_short"), + # 磁盘文件名为 ``{clip_stem}.traffic_lights.json``(非 README 里的 3d_*.json) + "traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"), + "traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"), +} + + +def _load_json_labels(path: Path) -> list: + """容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{: {labels: ...}}``。""" + if not path.exists(): + return [] + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return [] + if isinstance(data, dict): + if isinstance(data.get("labels"), list): + return data["labels"] + for v in data.values(): + if isinstance(v, dict) and isinstance(v.get("labels"), list): + return v["labels"] + return [] + + +def _verts_to_array(verts) -> np.ndarray: + """vertices 兼容 ``list[[x,y,z]]`` 与 ``list[{x,y,z}]`` 两种格式。""" + if not verts: + return np.zeros((0, 3), dtype=np.float32) + out: list[list[float]] = [] + for v in verts: + if isinstance(v, dict): + out.append([float(v.get("x", 0.0)), float(v.get("y", 0.0)), float(v.get("z", 0.0))]) + elif isinstance(v, (list, tuple)) and len(v) >= 3: + out.append([float(v[0]), float(v[1]), float(v[2])]) + return np.array(out, dtype=np.float32) if out else np.zeros((0, 3), dtype=np.float32) + + +def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]: + """按累计弧长把折线切成若干段。每段顶点数 >=2。""" + if verts.shape[0] < 2: + return [] + edges = np.linalg.norm(np.diff(verts, axis=0), axis=1) + cum = np.concatenate([[0.0], np.cumsum(edges)]) + total = float(cum[-1]) + if total <= seg_len: + return [verts] + n = max(1, int(np.ceil(total / seg_len))) + bounds = np.linspace(0.0, total, n + 1) + chunks: list[np.ndarray] = [] + for i in range(n): + lo, hi = bounds[i], bounds[i + 1] + mask = (cum >= lo - 1e-6) & (cum <= hi + 1e-6) + chunk = verts[mask] + if chunk.shape[0] >= 2: + chunks.append(chunk) + return chunks + + +def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None: + """[N, 3] -> (center, lwh, yaw)。""" + if verts.shape[0] < 2: + return None + center = verts.mean(0) + centered_xy = verts[:, :2] - center[:2] + if np.allclose(centered_xy, 0.0): + yaw = 0.0 + else: + cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1) + _, eigvecs = np.linalg.eigh(cov) + principal = eigvecs[:, -1] + yaw = float(np.arctan2(principal[1], principal[0])) + c, s = float(np.cos(-yaw)), float(np.sin(-yaw)) + rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T + l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min()) + w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min()) + h = float(verts[:, 2].max() - verts[:, 2].min()) + l = max(l, MIN_LWH[0]) + w = max(w, MIN_LWH[1]) + h = max(h, MIN_LWH[2]) + return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw + + +def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]: + """8 角点 -> (center, lwh, yaw)。用 corner[0]→corner[1] 估计 yaw。""" + center = corners.mean(0) + edge = corners[1] - corners[0] + yaw = float(np.arctan2(edge[1], edge[0])) + c, s = float(np.cos(-yaw)), float(np.sin(-yaw)) + R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32) + rot = (corners - center) @ R.T + lwh = (rot.max(0) - rot.min(0)).astype(np.float32) + lwh = np.maximum(lwh, np.array(MIN_LWH, dtype=np.float32)) + return center.astype(np.float32), lwh, yaw + + +def _build_object( + center: np.ndarray, + lwh: np.ndarray, + yaw: float, + cls_name: str, + idx: int, + sub_idx: int = 0, +) -> ObjectTrackInfo: + T = np.eye(4, dtype=np.float32) + c, s = float(np.cos(yaw)), float(np.sin(yaw)) + T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32) + T[:3, 3] = center + return ObjectTrackInfo( + tracking_id=f"hdmap_{cls_name}_{idx}_{sub_idx}", + object_to_world=torch.from_numpy(T), + lwh=torch.from_numpy(lwh), + is_moving=False, + object_type=cls_name, + ) + + +def parse_hdmap_clip( + clip_label_dir: Path, + segment_len: float = POLYLINE_SEGMENT_LEN, + lane_segment_len: float = LANE_SEGMENT_LEN, +) -> list[ObjectTrackInfo]: + """解析一个 clip 的 9 类 HDMap,展开为 world-frame ``ObjectTrackInfo`` 列表。""" + out: list[ObjectTrackInfo] = [] + for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items(): + try: + path = resolve_clip_file(clip_label_dir, subdir, json_name) + except FileNotFoundError: + continue + labels = _load_json_labels(path) + for i, lbl in enumerate(labels): + if not isinstance(lbl, dict): + continue + shape = lbl.get("labelData", {}).get("shape3d", {}) + if not isinstance(shape, dict): + continue + + if kind == "cuboid": + verts = shape.get("cuboid3d", {}).get("vertices", []) + arr = _verts_to_array(verts) + if arr.shape[0] != 8: + continue + c, lwh, yaw = _cuboid_to_box(arr) + out.append(_build_object(c, lwh, yaw, cls_name, i)) + + elif kind == "surface": + verts = shape.get("surface", {}).get("vertices", []) + arr = _verts_to_array(verts) + if arr.shape[0] < 3: + continue + box = _vertices_to_box(arr) + if box is not None: + out.append(_build_object(*box, cls_name, i)) + + elif kind == "polyline": + verts = shape.get("polyline3d", {}).get("vertices", []) + arr = _verts_to_array(verts) + if arr.shape[0] < 2: + continue + for j, chunk in enumerate(_split_polyline(arr, segment_len)): + box = _vertices_to_box(chunk) + if box is not None: + out.append(_build_object(*box, cls_name, i, j)) + + elif kind == "polyline_short": + # 杆状物体不切分 + verts = shape.get("polyline3d", {}).get("vertices", []) + arr = _verts_to_array(verts) + if arr.shape[0] < 2: + continue + box = _vertices_to_box(arr) + if box is not None: + out.append(_build_object(*box, cls_name, i)) + + elif kind == "lane_pair": + pl_root = shape.get("polylines3d", {}).get("polylines", []) + if not isinstance(pl_root, list) or len(pl_root) < 2: + continue + left = _verts_to_array( + pl_root[0].get("vertices", []) if isinstance(pl_root[0], dict) else [] + ) + right = _verts_to_array( + pl_root[1].get("vertices", []) if isinstance(pl_root[1], dict) else [] + ) + if left.shape[0] == 0 and right.shape[0] == 0: + continue + merged = np.concatenate([a for a in (left, right) if a.shape[0]], axis=0) + if merged.shape[0] < 2: + continue + for j, chunk in enumerate(_split_polyline(merged, lane_segment_len)): + box = _vertices_to_box(chunk) + if box is not None: + out.append(_build_object(*box, cls_name, i, j)) + + return out diff --git a/src/wjad/data/label_paths.py b/src/wjad/data/label_paths.py index 62c4fe0a3d868d817e504d10065ccc279d6fdedd..fb984d2ed5143f1971ff4bce3b1956ced83c41ad 100644 --- a/src/wjad/data/label_paths.py +++ b/src/wjad/data/label_paths.py @@ -1,218 +1,218 @@ -"""数据集标签目录布局解析。 - -README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径; -实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError`` -里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。 -""" - -from __future__ import annotations - -from pathlib import Path - - -def _norm_name(s: str) -> str: - return "".join(c for c in s.lower() if c.isalnum()) - - -def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str: - """列出 ``labels///`` 下文件采样 + clip 根下一级子目录。""" - lines: list[str] = [] - sub = labels_dir / folder - if sub.is_dir(): - files = sorted(p for p in sub.rglob("*") if p.is_file()) - lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):") - for p in files[:max_list]: - try: - rel = p.relative_to(labels_dir).as_posix() - except ValueError: - rel = str(p) - lines.append(f" {rel}") - if len(files) > max_list: - lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出") - else: - lines.append(f"[{folder}/] 不存在:{sub}") - try: - top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir()) - lines.append(f"[labels//] 一级子目录:{top}") - except OSError as e: - lines.append(f"[labels//] 无法列举:{e}") - return "\n".join(lines) - - -def _scan_npy_json_npz( - labels_dir: Path, - folder: str, - fname: str, - *, - exts: tuple[str, ...] = (".npy",), - tokens_norm: list[str], - name_must_contain: str | None = None, -) -> list[Path]: - """在整棵 labels// 下找候选文件:扩展名 + 归一化名须含各 token。""" - root_hint = labels_dir / folder - search_roots = [root_hint] if root_hint.is_dir() else [] - if not search_roots: - search_roots = [labels_dir] - hits: list[Path] = [] - for root in search_roots: - for p in root.rglob("*"): - if not p.is_file(): - continue - if not p.suffix.lower() in [e.lower() for e in exts]: - continue - if name_must_contain and name_must_contain.lower() not in p.name.lower(): - continue - pn = _norm_name(p.name) - if all(tok in pn for tok in tokens_norm if tok): - hits.append(p) - if not hits and root_hint.is_dir(): - for p in labels_dir.rglob("*"): - if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]: - continue - if name_must_contain and name_must_contain.lower() not in p.name.lower(): - continue - pn = _norm_name(p.name) - if all(tok in pn for tok in tokens_norm if tok): - hits.append(p) - return hits - - -def resolve_clip_file(labels_dir: Path, *parts: str) -> Path: - """在 ``labels//`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。""" - if not parts: - raise ValueError("parts 不能为空") - if not labels_dir.is_dir(): - raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}") - - direct = labels_dir.joinpath(*parts) - if direct.is_file(): - return direct - # NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels//`` - # 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。 - clip_stem = labels_dir.resolve().name - if len(parts) >= 2: - folder, fname = parts[0], parts[-1] - if not fname.startswith(f"{clip_stem}."): - stemmed = labels_dir / folder / f"{clip_stem}.{fname}" - if stemmed.is_file(): - return stemmed - if len(parts) >= 2: - folder = parts[0] - rest = parts[1:] - doubled = (labels_dir / folder / folder).joinpath(*rest) - if doubled.is_file(): - return doubled - fname = parts[-1] - folder = parts[0] - sub = labels_dir / folder - if sub.is_dir(): - for p in sub.rglob(fname): - if p.is_file(): - return p - - for p in labels_dir.rglob(fname): - if p.is_file(): - return p - fl = fname.lower() - for p in labels_dir.rglob("*"): - if p.is_file() and p.name.lower() == fl: - return p - - # ftheta / pinhole - if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"): - prefix = folder + "." - if fname.lower().startswith(prefix.lower()): - cam = fname[len(prefix) : -len(".npy")] - cam_n = _norm_name(cam) - hits = [] - for p in labels_dir.rglob("*.npy"): - if not p.is_file(): - continue - pn = _norm_name(p.name) - if folder == "ftheta_intrinsic": - if "ftheta" not in pn: - continue - else: - if "pinhole" not in pn: - continue - if cam_n and cam_n in pn: - hits.append(p) - if len(hits) == 1: - return hits[0] - if len(hits) > 1: - hits.sort(key=lambda x: (len(x.parts), str(x))) - return hits[0] - - # pose: ``{idx:06d}.pose.{camera}.npy`` - if folder == "pose" and fname.endswith(".npy"): - base = fname[: -len(".npy")] - if ".pose." in base: - idx_part, _, cam_part = base.partition(".pose.") - hits = _scan_npy_json_npz( - labels_dir, - folder, - fname, - exts=(".npy",), - tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)], - name_must_contain="pose", - ) - if len(hits) == 1: - return hits[0] - if len(hits) > 1: - hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x))) - return hits[0] - - # vehicle_pose: ``{idx:06d}.vehicle_pose.npy`` - if folder == "vehicle_pose" and fname.endswith(".npy"): - idx_part = fname.split(".")[0] - hits = _scan_npy_json_npz( - labels_dir, - folder, - fname, - exts=(".npy",), - tokens_norm=[_norm_name(idx_part), "vehiclepose"], - name_must_contain="vehicle", - ) - if len(hits) == 1: - return hits[0] - if len(hits) > 1: - hits.sort(key=lambda x: (len(x.parts), str(x))) - return hits[0] - - # all_object_info - if folder == "all_object_info" and fname.endswith(".json"): - idx_part = fname.split(".")[0] - hits = _scan_npy_json_npz( - labels_dir, - folder, - fname, - exts=(".json",), - tokens_norm=[_norm_name(idx_part), "allobjectinfo"], - ) - if len(hits) == 1: - return hits[0] - if len(hits) > 1: - hits.sort(key=lambda x: (len(x.parts), str(x))) - return hits[0] - - # lidar_raw - if folder == "lidar_raw" and fname.endswith(".npz"): - stem = fname[: -len(".npz")] - hits = _scan_npy_json_npz( - labels_dir, - folder, - fname, - exts=(".npz",), - tokens_norm=[_norm_name(stem), "lidar", "raw"], - ) - if len(hits) == 1: - return hits[0] - if len(hits) > 1: - hits.sort(key=lambda x: (len(x.parts), str(x))) - return hits[0] - - detail = _diagnose_labels(labels_dir, folder) - raise FileNotFoundError( - f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、" - f"rglob、按帧索引+相机的扫描匹配)。\n{detail}" - ) +"""数据集标签目录布局解析。 + +README 中的 keys 是相对每个 modality 的 ``.tar`` 根目录的扁平路径; +实际解压后常多一层子目录或 clip stem 前缀。解析失败时在 ``FileNotFoundError`` +里附带目录列表,便于与 Hugging Face 数据集页面中的说明对照。 +""" + +from __future__ import annotations + +from pathlib import Path + + +def _norm_name(s: str) -> str: + return "".join(c for c in s.lower() if c.isalnum()) + + +def _diagnose_labels(labels_dir: Path, folder: str, max_list: int = 50) -> str: + """列出 ``labels///`` 下文件采样 + clip 根下一级子目录。""" + lines: list[str] = [] + sub = labels_dir / folder + if sub.is_dir(): + files = sorted(p for p in sub.rglob("*") if p.is_file()) + lines.append(f"[{folder}/] 下共 {len(files)} 个文件(最多列出 {max_list} 条相对路径):") + for p in files[:max_list]: + try: + rel = p.relative_to(labels_dir).as_posix() + except ValueError: + rel = str(p) + lines.append(f" {rel}") + if len(files) > max_list: + lines.append(f" ... 另有 {len(files) - max_list} 个文件未列出") + else: + lines.append(f"[{folder}/] 不存在:{sub}") + try: + top = sorted(d.name for d in labels_dir.iterdir() if d.is_dir()) + lines.append(f"[labels//] 一级子目录:{top}") + except OSError as e: + lines.append(f"[labels//] 无法列举:{e}") + return "\n".join(lines) + + +def _scan_npy_json_npz( + labels_dir: Path, + folder: str, + fname: str, + *, + exts: tuple[str, ...] = (".npy",), + tokens_norm: list[str], + name_must_contain: str | None = None, +) -> list[Path]: + """在整棵 labels// 下找候选文件:扩展名 + 归一化名须含各 token。""" + root_hint = labels_dir / folder + search_roots = [root_hint] if root_hint.is_dir() else [] + if not search_roots: + search_roots = [labels_dir] + hits: list[Path] = [] + for root in search_roots: + for p in root.rglob("*"): + if not p.is_file(): + continue + if not p.suffix.lower() in [e.lower() for e in exts]: + continue + if name_must_contain and name_must_contain.lower() not in p.name.lower(): + continue + pn = _norm_name(p.name) + if all(tok in pn for tok in tokens_norm if tok): + hits.append(p) + if not hits and root_hint.is_dir(): + for p in labels_dir.rglob("*"): + if not p.is_file() or p.suffix.lower() not in [e.lower() for e in exts]: + continue + if name_must_contain and name_must_contain.lower() not in p.name.lower(): + continue + pn = _norm_name(p.name) + if all(tok in pn for tok in tokens_norm if tok): + hits.append(p) + return hits + + +def resolve_clip_file(labels_dir: Path, *parts: str) -> Path: + """在 ``labels//`` 下解析 ``parts`` 组成的相对路径(首个元素为一级子文件夹)。""" + if not parts: + raise ValueError("parts 不能为空") + if not labels_dir.is_dir(): + raise FileNotFoundError(f"clip 标签根目录不存在: {labels_dir}") + + direct = labels_dir.joinpath(*parts) + if direct.is_file(): + return direct + # NVIDIA 磁盘命名:``{clip_stem}.{README_key}``,clip_stem = ``labels//`` + # 解析后的目录名(含 ``uuid_t0_t1``);README 里的 key 本身不含此前缀。 + clip_stem = labels_dir.resolve().name + if len(parts) >= 2: + folder, fname = parts[0], parts[-1] + if not fname.startswith(f"{clip_stem}."): + stemmed = labels_dir / folder / f"{clip_stem}.{fname}" + if stemmed.is_file(): + return stemmed + if len(parts) >= 2: + folder = parts[0] + rest = parts[1:] + doubled = (labels_dir / folder / folder).joinpath(*rest) + if doubled.is_file(): + return doubled + fname = parts[-1] + folder = parts[0] + sub = labels_dir / folder + if sub.is_dir(): + for p in sub.rglob(fname): + if p.is_file(): + return p + + for p in labels_dir.rglob(fname): + if p.is_file(): + return p + fl = fname.lower() + for p in labels_dir.rglob("*"): + if p.is_file() and p.name.lower() == fl: + return p + + # ftheta / pinhole + if folder in ("ftheta_intrinsic", "pinhole_intrinsic") and fname.endswith(".npy"): + prefix = folder + "." + if fname.lower().startswith(prefix.lower()): + cam = fname[len(prefix) : -len(".npy")] + cam_n = _norm_name(cam) + hits = [] + for p in labels_dir.rglob("*.npy"): + if not p.is_file(): + continue + pn = _norm_name(p.name) + if folder == "ftheta_intrinsic": + if "ftheta" not in pn: + continue + else: + if "pinhole" not in pn: + continue + if cam_n and cam_n in pn: + hits.append(p) + if len(hits) == 1: + return hits[0] + if len(hits) > 1: + hits.sort(key=lambda x: (len(x.parts), str(x))) + return hits[0] + + # pose: ``{idx:06d}.pose.{camera}.npy`` + if folder == "pose" and fname.endswith(".npy"): + base = fname[: -len(".npy")] + if ".pose." in base: + idx_part, _, cam_part = base.partition(".pose.") + hits = _scan_npy_json_npz( + labels_dir, + folder, + fname, + exts=(".npy",), + tokens_norm=[_norm_name(idx_part), _norm_name(cam_part)], + name_must_contain="pose", + ) + if len(hits) == 1: + return hits[0] + if len(hits) > 1: + hits.sort(key=lambda x: (len(x.parts), -len(x.name), str(x))) + return hits[0] + + # vehicle_pose: ``{idx:06d}.vehicle_pose.npy`` + if folder == "vehicle_pose" and fname.endswith(".npy"): + idx_part = fname.split(".")[0] + hits = _scan_npy_json_npz( + labels_dir, + folder, + fname, + exts=(".npy",), + tokens_norm=[_norm_name(idx_part), "vehiclepose"], + name_must_contain="vehicle", + ) + if len(hits) == 1: + return hits[0] + if len(hits) > 1: + hits.sort(key=lambda x: (len(x.parts), str(x))) + return hits[0] + + # all_object_info + if folder == "all_object_info" and fname.endswith(".json"): + idx_part = fname.split(".")[0] + hits = _scan_npy_json_npz( + labels_dir, + folder, + fname, + exts=(".json",), + tokens_norm=[_norm_name(idx_part), "allobjectinfo"], + ) + if len(hits) == 1: + return hits[0] + if len(hits) > 1: + hits.sort(key=lambda x: (len(x.parts), str(x))) + return hits[0] + + # lidar_raw + if folder == "lidar_raw" and fname.endswith(".npz"): + stem = fname[: -len(".npz")] + hits = _scan_npy_json_npz( + labels_dir, + folder, + fname, + exts=(".npz",), + tokens_norm=[_norm_name(stem), "lidar", "raw"], + ) + if len(hits) == 1: + return hits[0] + if len(hits) > 1: + hits.sort(key=lambda x: (len(x.parts), str(x))) + return hits[0] + + detail = _diagnose_labels(labels_dir, folder) + raise FileNotFoundError( + f"在 {labels_dir} 下未找到 {'/'.join(parts)}(已尝试 README 扁平路径、双嵌套、" + f"rglob、按帧索引+相机的扫描匹配)。\n{detail}" + ) diff --git a/src/wjad/data/se3.py b/src/wjad/data/se3.py index 144d5016cba0b0177fbaff183a94f47898ed3dcb..44e3540c8e981cb97c534b9929c1320969c03e8a 100644 --- a/src/wjad/data/se3.py +++ b/src/wjad/data/se3.py @@ -1,111 +1,111 @@ -"""SE(3) 与 6D 表示之间的转换。 - -约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。 -平移单位为米;旋转角弧度。 -""" - -from __future__ import annotations - -import numpy as np -import torch - - -def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor: - """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。 - - 使用 Rodrigues 公式数值反求。 - """ - if isinstance(R, np.ndarray): - R = torch.from_numpy(R).float() - if R.dim() == 2: - R = R.unsqueeze(0) - single = True - else: - single = False - - trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] - cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7) - theta = torch.acos(cos_theta) # [B] - - # 提取轴向量 - rx = R[..., 2, 1] - R[..., 1, 2] - ry = R[..., 0, 2] - R[..., 2, 0] - rz = R[..., 1, 0] - R[..., 0, 1] - axis = torch.stack([rx, ry, rz], dim=-1) - sin_theta = torch.sin(theta).clamp_min(1e-7) - axis = axis / (2.0 * sin_theta).unsqueeze(-1) - - aa = axis * theta.unsqueeze(-1) - if single: - aa = aa.squeeze(0) - return aa - - -def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor: - """轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。""" - theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1] - axis = aa / theta - x, y, z = axis[..., 0], axis[..., 1], axis[..., 2] - sin_t = torch.sin(theta.squeeze(-1)) - cos_t = torch.cos(theta.squeeze(-1)) - one_c = 1.0 - cos_t - - R = torch.stack( - [ - cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t, - y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t, - z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c, - ], - dim=-1, - ).reshape(*aa.shape[:-1], 3, 3) - return R - - -def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor: - """4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。""" - if isinstance(T, np.ndarray): - T = torch.from_numpy(T).float() - if T.dim() == 2: - T = T.unsqueeze(0) - single = True - else: - single = False - - R = T[..., :3, :3] - t = T[..., :3, 3] - aa = rotation_matrix_to_axis_angle(R) - six = torch.cat([t, aa], dim=-1) - if single: - six = six.squeeze(0) - return six - - -def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor: - """6D -> 4x4 SE(3)。""" - if six.dim() == 1: - six = six.unsqueeze(0) - single = True - else: - single = False - t = six[..., :3] - aa = six[..., 3:] - R = axis_angle_to_rotation_matrix(aa) - T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device) - T[..., :3, :3] = R - T[..., :3, 3] = t - T[..., 3, 3] = 1.0 - if single: - T = T.squeeze(0) - return T - - -def invert_se3(T: torch.Tensor) -> torch.Tensor: - """4x4 SE(3) 逆,``[..., 4, 4]``。""" - R = T[..., :3, :3] - t = T[..., :3, 3:4] - Rt = R.transpose(-2, -1) - inv = torch.zeros_like(T) - inv[..., :3, :3] = Rt - inv[..., :3, 3:4] = -Rt @ t - inv[..., 3, 3] = 1.0 - return inv +"""SE(3) 与 6D 表示之间的转换。 + +约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。 +平移单位为米;旋转角弧度。 +""" + +from __future__ import annotations + +import numpy as np +import torch + + +def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor: + """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。 + + 使用 Rodrigues 公式数值反求。 + """ + if isinstance(R, np.ndarray): + R = torch.from_numpy(R).float() + if R.dim() == 2: + R = R.unsqueeze(0) + single = True + else: + single = False + + trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] + cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7) + theta = torch.acos(cos_theta) # [B] + + # 提取轴向量 + rx = R[..., 2, 1] - R[..., 1, 2] + ry = R[..., 0, 2] - R[..., 2, 0] + rz = R[..., 1, 0] - R[..., 0, 1] + axis = torch.stack([rx, ry, rz], dim=-1) + sin_theta = torch.sin(theta).clamp_min(1e-7) + axis = axis / (2.0 * sin_theta).unsqueeze(-1) + + aa = axis * theta.unsqueeze(-1) + if single: + aa = aa.squeeze(0) + return aa + + +def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor: + """轴角向量 ``[..., 3]`` -> 旋转矩阵 ``[..., 3, 3]``(Rodrigues)。""" + theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9) # [..., 1] + axis = aa / theta + x, y, z = axis[..., 0], axis[..., 1], axis[..., 2] + sin_t = torch.sin(theta.squeeze(-1)) + cos_t = torch.cos(theta.squeeze(-1)) + one_c = 1.0 - cos_t + + R = torch.stack( + [ + cos_t + x * x * one_c, x * y * one_c - z * sin_t, x * z * one_c + y * sin_t, + y * x * one_c + z * sin_t, cos_t + y * y * one_c, y * z * one_c - x * sin_t, + z * x * one_c - y * sin_t, z * y * one_c + x * sin_t, cos_t + z * z * one_c, + ], + dim=-1, + ).reshape(*aa.shape[:-1], 3, 3) + return R + + +def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor: + """4x4 SE(3) -> 6D ``[tx, ty, tz, rx, ry, rz]``。""" + if isinstance(T, np.ndarray): + T = torch.from_numpy(T).float() + if T.dim() == 2: + T = T.unsqueeze(0) + single = True + else: + single = False + + R = T[..., :3, :3] + t = T[..., :3, 3] + aa = rotation_matrix_to_axis_angle(R) + six = torch.cat([t, aa], dim=-1) + if single: + six = six.squeeze(0) + return six + + +def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor: + """6D -> 4x4 SE(3)。""" + if six.dim() == 1: + six = six.unsqueeze(0) + single = True + else: + single = False + t = six[..., :3] + aa = six[..., 3:] + R = axis_angle_to_rotation_matrix(aa) + T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device) + T[..., :3, :3] = R + T[..., :3, 3] = t + T[..., 3, 3] = 1.0 + if single: + T = T.squeeze(0) + return T + + +def invert_se3(T: torch.Tensor) -> torch.Tensor: + """4x4 SE(3) 逆,``[..., 4, 4]``。""" + R = T[..., :3, :3] + t = T[..., :3, 3:4] + Rt = R.transpose(-2, -1) + inv = torch.zeros_like(T) + inv[..., :3, :3] = Rt + inv[..., :3, 3:4] = -Rt @ t + inv[..., 3, 3] = 1.0 + return inv diff --git a/src/wjad/data/targets.py b/src/wjad/data/targets.py index 325a8742b1495889dcc60bf234e037b1023708c6..2cea5efa047d715724fd0bd08c0bc769939e6525 100644 --- a/src/wjad/data/targets.py +++ b/src/wjad/data/targets.py @@ -1,214 +1,214 @@ -"""检测 / 自车未来轨迹的目标构建。 - -依据 Cosmos-Drive-Dreams 数据集 README: - all_object_info JSON 中以 ``tracking_id`` 为 key,存储 - ``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。 - -构建步骤: -1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系: - object_to_self = inv(vehicle_pose_t) @ object_to_world -2. 距离 ``≤ max_distance_m`` 过滤; -3. 投影中心点到当前帧像素,要求落在视锥内; -4. 用 LIDAR 深度对比做遮挡剔除(粗粒度); -5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系, - 提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。 - -为方便与 head 输出对齐,最终输出格式: - {"labels": [N], "boxes": [N, 7], "is_dynamic": [N], - "future_traj": [N, 24, 3], "future_valid": [N, 24]} -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np -import torch - -from ..modules.normalization import symlog -from ..modules.rays import FThetaCamera -from .ftheta_proj import project_points_ftheta -from .se3 import invert_se3 - - -@dataclass -class ObjectTrackInfo: - """单个对象在某帧的简化记录。""" - - tracking_id: str - object_to_world: torch.Tensor # [4, 4] - lwh: torch.Tensor # [3] - is_moving: bool - object_type: str - - -def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor: - """从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。 - - 使用 ``atan2(R[1,0], R[0,0])``。 - """ - return torch.atan2(R[..., 1, 0], R[..., 0, 0]) - - -def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]: - """根据 object_type 字符串映射到 (class_index, is_dynamic)。""" - if object_type in dynamic_classes: - return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0 - if object_type in structured_classes: - return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0 - return background_idx, 0 # 未知类型当 background - - -def build_detection_targets( - objects_t: list[ObjectTrackInfo], - objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表 - vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world - vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4 - cam_intrinsic: FThetaCamera, - cam2vehicle: torch.Tensor, # [4, 4] - image_h: int, - image_w: int, - max_distance_m: float = 48.0, - occlusion_depth_tolerance: float = 0.5, - lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡 - dynamic_classes: list[str] | None = None, - structured_classes: list[str] | None = None, - future_horizon: int = 24, -) -> dict: - """构建一个样本的检测+未来轨迹标签。""" - if dynamic_classes is None: - dynamic_classes = [] - if structured_classes is None: - structured_classes = [] - - inv_pose_t = invert_se3(vehicle_pose_t) - vehicle2cam = invert_se3(cam2vehicle) - - labels: list[int] = [] - boxes: list[list[float]] = [] - is_dynamic: list[int] = [] - future_traj: list[list[list[float]]] = [] - future_valid: list[list[int]] = [] - - for obj in objects_t: - T_obj_self = inv_pose_t @ obj.object_to_world # [4,4] - center_self = T_obj_self[:3, 3] - - dist = float(center_self.norm().item()) - if dist > max_distance_m: - continue - - # 视锥裁剪:把中心投影到相机系再投影到像素 - center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3] - if center_cam[2].item() <= 0: - continue - uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic) - u, v = uv[0, 0].item(), uv[0, 1].item() - if not (0 <= u < image_w and 0 <= v < image_h): - continue - - # LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比 - if lidar_points_self is not None and lidar_points_self.numel() > 0: - ray = center_self / (center_self.norm() + 1e-6) - proj = lidar_points_self @ ray # [P] - # 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°) - cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray - mask = (cosang > 0.996) & (proj > 0) - if mask.any(): - lidar_depth = proj[mask].min().item() - if lidar_depth + occlusion_depth_tolerance < dist: - # LIDAR 击中前方更近物体 -> 当前对象被遮挡 - continue - - # 类别映射 - cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes) - if cls_idx == 0: - continue - labels.append(cls_idx) - is_dynamic.append(is_dyn) - - yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item() - l, w, h = obj.lwh.tolist() - # box 坐标 symlog 归一 - x_n, y_n, z_n = ( - float(symlog(center_self[0]).item()), - float(symlog(center_self[1]).item()), - float(symlog(center_self[2]).item()), - ) - l_n = float(symlog(torch.tensor(l)).item()) - w_n = float(symlog(torch.tensor(w)).item()) - h_n = float(symlog(torch.tensor(h)).item()) - boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw]) - - # 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身 - # 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t) - x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw - future_3 = [] - future_v = [] - for k in range(future_horizon): - if k >= len(objects_future) or k >= len(vehicle_pose_future): - future_3.append([0.0, 0.0, 0.0]) - future_v.append(0) - continue - # 找对象在 t+k+1 帧 - future_objs = objects_future[k] - match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None) - if match is None: - future_3.append([0.0, 0.0, 0.0]) - future_v.append(0) - continue - T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world - xf = T_obj_self_future[0, 3].item() - yf = T_obj_self_future[1, 3].item() - yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item() - dx = xf - x0 - dy = yf - y0 - dyaw = yawf - yaw0 - # 角度归到 (-pi, pi] - dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi - future_3.append([ - float(symlog(torch.tensor(dx)).item()), - float(symlog(torch.tensor(dy)).item()), - float(dyaw), - ]) - future_v.append(1) - future_traj.append(future_3) - future_valid.append(future_v) - - if not labels: - return { - "labels": torch.zeros(0, dtype=torch.long), - "boxes": torch.zeros(0, 7), - "is_dynamic": torch.zeros(0, dtype=torch.long), - "future_traj": torch.zeros(0, future_horizon, 3), - "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool), - } - return { - "labels": torch.tensor(labels, dtype=torch.long), - "boxes": torch.tensor(boxes, dtype=torch.float32), - "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long), - "future_traj": torch.tensor(future_traj, dtype=torch.float32), - "future_valid": torch.tensor(future_valid, dtype=torch.bool), - } - - -def build_ego_future_target( - vehicle_pose_t: torch.Tensor, - vehicle_pose_future: list[torch.Tensor], - horizon: int = 24, -) -> tuple[torch.Tensor, torch.Tensor]: - """自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。""" - inv_t = invert_se3(vehicle_pose_t) - out = torch.zeros(horizon, 3) - valid = torch.zeros(horizon, dtype=torch.bool) - for k in range(horizon): - if k >= len(vehicle_pose_future): - break - rel = inv_t @ vehicle_pose_future[k] - x, y = rel[0, 3].item(), rel[1, 3].item() - yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item() - out[k, 0] = symlog(torch.tensor(x)) - out[k, 1] = symlog(torch.tensor(y)) - out[k, 2] = yaw - valid[k] = True - return out, valid +"""检测 / 自车未来轨迹的目标构建。 + +依据 Cosmos-Drive-Dreams 数据集 README: + all_object_info JSON 中以 ``tracking_id`` 为 key,存储 + ``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``。 + +构建步骤: +1. 把每个对象的 ``object_to_world`` 转到 t 时刻自车系: + object_to_self = inv(vehicle_pose_t) @ object_to_world +2. 距离 ``≤ max_distance_m`` 过滤; +3. 投影中心点到当前帧像素,要求落在视锥内; +4. 用 LIDAR 深度对比做遮挡剔除(粗粒度); +5. 对动态目标,从 t+1..t+24 帧逐帧获取其 ``object_to_world``,转到 t 自车系, + 提取 (dx, dy, dyaw) 并做 symlog 归一作为未来轨迹 GT;缺帧时 ``valid=0``。 + +为方便与 head 输出对齐,最终输出格式: + {"labels": [N], "boxes": [N, 7], "is_dynamic": [N], + "future_traj": [N, 24, 3], "future_valid": [N, 24]} +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import torch + +from ..modules.normalization import symlog +from ..modules.rays import FThetaCamera +from .ftheta_proj import project_points_ftheta +from .se3 import invert_se3 + + +@dataclass +class ObjectTrackInfo: + """单个对象在某帧的简化记录。""" + + tracking_id: str + object_to_world: torch.Tensor # [4, 4] + lwh: torch.Tensor # [3] + is_moving: bool + object_type: str + + +def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor: + """从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。 + + 使用 ``atan2(R[1,0], R[0,0])``。 + """ + return torch.atan2(R[..., 1, 0], R[..., 0, 0]) + + +def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]: + """根据 object_type 字符串映射到 (class_index, is_dynamic)。""" + if object_type in dynamic_classes: + return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0 + if object_type in structured_classes: + return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0 + return background_idx, 0 # 未知类型当 background + + +def build_detection_targets( + objects_t: list[ObjectTrackInfo], + objects_future: list[list[ObjectTrackInfo]], # len = future_horizon,每帧一个对象列表 + vehicle_pose_t: torch.Tensor, # [4, 4],vehicle to world + vehicle_pose_future: list[torch.Tensor], # 每帧一个 4x4 + cam_intrinsic: FThetaCamera, + cam2vehicle: torch.Tensor, # [4, 4] + image_h: int, + image_w: int, + max_distance_m: float = 48.0, + occlusion_depth_tolerance: float = 0.5, + lidar_points_self: torch.Tensor | None = None, # [P, 3] in self frame,做粗遮挡 + dynamic_classes: list[str] | None = None, + structured_classes: list[str] | None = None, + future_horizon: int = 24, +) -> dict: + """构建一个样本的检测+未来轨迹标签。""" + if dynamic_classes is None: + dynamic_classes = [] + if structured_classes is None: + structured_classes = [] + + inv_pose_t = invert_se3(vehicle_pose_t) + vehicle2cam = invert_se3(cam2vehicle) + + labels: list[int] = [] + boxes: list[list[float]] = [] + is_dynamic: list[int] = [] + future_traj: list[list[list[float]]] = [] + future_valid: list[list[int]] = [] + + for obj in objects_t: + T_obj_self = inv_pose_t @ obj.object_to_world # [4,4] + center_self = T_obj_self[:3, 3] + + dist = float(center_self.norm().item()) + if dist > max_distance_m: + continue + + # 视锥裁剪:把中心投影到相机系再投影到像素 + center_cam = (vehicle2cam @ torch.cat([center_self, torch.ones(1)])[None].T).squeeze(-1)[:3] + if center_cam[2].item() <= 0: + continue + uv, depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic) + u, v = uv[0, 0].item(), uv[0, 1].item() + if not (0 <= u < image_w and 0 <= v < image_h): + continue + + # LIDAR 遮挡:找到 LIDAR 中靠近当前射线方向的最近点深度,与对象深度对比 + if lidar_points_self is not None and lidar_points_self.numel() > 0: + ray = center_self / (center_self.norm() + 1e-6) + proj = lidar_points_self @ ray # [P] + # 选取沿射线方向投影距离接近 dist 的点(容差 1m,水平角 5°) + cosang = (lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)) @ ray + mask = (cosang > 0.996) & (proj > 0) + if mask.any(): + lidar_depth = proj[mask].min().item() + if lidar_depth + occlusion_depth_tolerance < dist: + # LIDAR 击中前方更近物体 -> 当前对象被遮挡 + continue + + # 类别映射 + cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes) + if cls_idx == 0: + continue + labels.append(cls_idx) + is_dynamic.append(is_dyn) + + yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item() + l, w, h = obj.lwh.tolist() + # box 坐标 symlog 归一 + x_n, y_n, z_n = ( + float(symlog(center_self[0]).item()), + float(symlog(center_self[1]).item()), + float(symlog(center_self[2]).item()), + ) + l_n = float(symlog(torch.tensor(l)).item()) + w_n = float(symlog(torch.tensor(w)).item()) + h_n = float(symlog(torch.tensor(h)).item()) + boxes.append([x_n, y_n, z_n, l_n, w_n, h_n, yaw]) + + # 未来轨迹:在当前 self 系下用 (dx, dy, dyaw),相对 t 时刻对象自身 + # 先取 t 时刻对象在 self 系下的 (x_t, y_t, yaw_t) + x0, y0, yaw0 = center_self[0].item(), center_self[1].item(), yaw + future_3 = [] + future_v = [] + for k in range(future_horizon): + if k >= len(objects_future) or k >= len(vehicle_pose_future): + future_3.append([0.0, 0.0, 0.0]) + future_v.append(0) + continue + # 找对象在 t+k+1 帧 + future_objs = objects_future[k] + match = next((o for o in future_objs if o.tracking_id == obj.tracking_id), None) + if match is None: + future_3.append([0.0, 0.0, 0.0]) + future_v.append(0) + continue + T_obj_self_future = invert_se3(vehicle_pose_t) @ match.object_to_world + xf = T_obj_self_future[0, 3].item() + yf = T_obj_self_future[1, 3].item() + yawf = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item() + dx = xf - x0 + dy = yf - y0 + dyaw = yawf - yaw0 + # 角度归到 (-pi, pi] + dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi + future_3.append([ + float(symlog(torch.tensor(dx)).item()), + float(symlog(torch.tensor(dy)).item()), + float(dyaw), + ]) + future_v.append(1) + future_traj.append(future_3) + future_valid.append(future_v) + + if not labels: + return { + "labels": torch.zeros(0, dtype=torch.long), + "boxes": torch.zeros(0, 7), + "is_dynamic": torch.zeros(0, dtype=torch.long), + "future_traj": torch.zeros(0, future_horizon, 3), + "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool), + } + return { + "labels": torch.tensor(labels, dtype=torch.long), + "boxes": torch.tensor(boxes, dtype=torch.float32), + "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long), + "future_traj": torch.tensor(future_traj, dtype=torch.float32), + "future_valid": torch.tensor(future_valid, dtype=torch.bool), + } + + +def build_ego_future_target( + vehicle_pose_t: torch.Tensor, + vehicle_pose_future: list[torch.Tensor], + horizon: int = 24, +) -> tuple[torch.Tensor, torch.Tensor]: + """自车未来 24 帧轨迹(在 t 自车系下,``(x, y, yaw)`` 已 symlog 归一)。""" + inv_t = invert_se3(vehicle_pose_t) + out = torch.zeros(horizon, 3) + valid = torch.zeros(horizon, dtype=torch.bool) + for k in range(horizon): + if k >= len(vehicle_pose_future): + break + rel = inv_t @ vehicle_pose_future[k] + x, y = rel[0, 3].item(), rel[1, 3].item() + yaw = _yaw_from_rotation_matrix(rel[:3, :3]).item() + out[k, 0] = symlog(torch.tensor(x)) + out[k, 1] = symlog(torch.tensor(y)) + out[k, 2] = yaw + valid[k] = True + return out, valid diff --git a/src/wjad/data/transforms.py b/src/wjad/data/transforms.py index 8f584be0a8a4f8439d2d1ead9d46c194a56a5709..94c0d9bba78ea8a84f0c0a2b6fecf8023635db4d 100644 --- a/src/wjad/data/transforms.py +++ b/src/wjad/data/transforms.py @@ -1,86 +1,86 @@ -"""图像与运动学的数据增广。""" - -from __future__ import annotations - -import numpy as np -import torch - - -# DINOv3 的 ImageNet 标准化参数 -DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) - - -def crop_top_half(image: torch.Tensor) -> torch.Tensor: - """裁去图像上半部分(主要是天空)。 - - 输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。 - """ - if image.dim() == 4: - h = image.shape[2] - return image[:, :, h // 2 :, :] - elif image.dim() == 3: - h = image.shape[1] - return image[:, h // 2 :, :] - raise ValueError(f"unsupported image dim: {image.dim()}") - - -def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor: - """对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。""" - while mean.dim() < image.dim(): - mean = mean.unsqueeze(0) - std = std.unsqueeze(0) - return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype) - - -def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor: - """高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。""" - if std <= 0: - return image - return image + torch.randn_like(image) * std - - -def perturb_kinematics( - ego_6d: torch.Tensor, # [T, 6] - intr_vec: torch.Tensor, # [14] - extr_6d: torch.Tensor, # [6] - translation_std_m: float, - rotation_std_deg: float, - intrinsic_std: float, - extrinsic_std: float, - rng: np.random.Generator, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。 - - 返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。 - - 返回 - ---- - perturbed_ego, perturbed_intr, perturbed_extr, - gt_residual_concat (在 symlog 空间作为 calibration 监督,可选; - 本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造) - """ - rot_std_rad = np.deg2rad(rotation_std_deg) - - # ego 8x6 - delta_ego = np.zeros_like(ego_6d.numpy()) - delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3)) - delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3)) - perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d) - - # intrinsic 14 - delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],)) - perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec) - - # extrinsic 6 - delta_extr = np.zeros_like(extr_6d.numpy()) - delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,)) - delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,)) - perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d) - - return ( - perturbed_ego, - perturbed_intr, - perturbed_extr, - torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype), - ) +"""图像与运动学的数据增广。""" + +from __future__ import annotations + +import numpy as np +import torch + + +# DINOv3 的 ImageNet 标准化参数 +DINOV3_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +DINOV3_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + + +def crop_top_half(image: torch.Tensor) -> torch.Tensor: + """裁去图像上半部分(主要是天空)。 + + 输入 ``[3, H, W]`` 或 ``[T, 3, H, W]``;返回相同维度但 H 减半。 + """ + if image.dim() == 4: + h = image.shape[2] + return image[:, :, h // 2 :, :] + elif image.dim() == 3: + h = image.shape[1] + return image[:, h // 2 :, :] + raise ValueError(f"unsupported image dim: {image.dim()}") + + +def normalize_image(image: torch.Tensor, mean: torch.Tensor = DINOV3_MEAN, std: torch.Tensor = DINOV3_STD) -> torch.Tensor: + """对 [0, 1] 范围的图像做标准化。支持 ``[3,H,W]``/``[T,3,H,W]``/``[B,T,3,H,W]``。""" + while mean.dim() < image.dim(): + mean = mean.unsqueeze(0) + std = std.unsqueeze(0) + return (image - mean.to(image.device, image.dtype)) / std.to(image.device, image.dtype) + + +def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor: + """高斯噪声增广。``image`` 应已归一化(mean=0,std=1 之后)。""" + if std <= 0: + return image + return image + torch.randn_like(image) * std + + +def perturb_kinematics( + ego_6d: torch.Tensor, # [T, 6] + intr_vec: torch.Tensor, # [14] + extr_6d: torch.Tensor, # [6] + translation_std_m: float, + rotation_std_deg: float, + intrinsic_std: float, + extrinsic_std: float, + rng: np.random.Generator, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """在 Stage1 中期对运动学和内外参添加微小扰动,作为校准训练增广。 + + 返回扰动后值与扰动量(GT 残差 = -扰动量,因为校准网络要把扰动反推回去)。 + + 返回 + ---- + perturbed_ego, perturbed_intr, perturbed_extr, + gt_residual_concat (在 symlog 空间作为 calibration 监督,可选; + 本文件仅返回扰动后的真实空间值,校准 GT 由 trainer 构造) + """ + rot_std_rad = np.deg2rad(rotation_std_deg) + + # ego 8x6 + delta_ego = np.zeros_like(ego_6d.numpy()) + delta_ego[:, :3] = rng.normal(0.0, translation_std_m, size=(ego_6d.shape[0], 3)) + delta_ego[:, 3:] = rng.normal(0.0, rot_std_rad, size=(ego_6d.shape[0], 3)) + perturbed_ego = ego_6d + torch.from_numpy(delta_ego).to(ego_6d) + + # intrinsic 14 + delta_intr = rng.normal(0.0, intrinsic_std, size=(intr_vec.shape[0],)) + perturbed_intr = intr_vec + torch.from_numpy(delta_intr).to(intr_vec) + + # extrinsic 6 + delta_extr = np.zeros_like(extr_6d.numpy()) + delta_extr[:3] = rng.normal(0.0, extrinsic_std, size=(3,)) + delta_extr[3:] = rng.normal(0.0, rot_std_rad, size=(3,)) + perturbed_extr = extr_6d + torch.from_numpy(delta_extr).to(extr_6d) + + return ( + perturbed_ego, + perturbed_intr, + perturbed_extr, + torch.from_numpy(np.concatenate([delta_ego.flatten(), delta_intr, delta_extr])).to(ego_6d.dtype), + ) diff --git a/src/wjad/encoders/__init__.py b/src/wjad/encoders/__init__.py index 8d590a89c53d84cecb1ba6e8d72cb8316d8ce5c4..1c269f740f9de2c8179d230624905ed16c5829ba 100644 --- a/src/wjad/encoders/__init__.py +++ b/src/wjad/encoders/__init__.py @@ -1,5 +1,5 @@ -"""视觉编码相关:DINOv3 包装、时空压缩。""" - -from .dinov3_wrapper import DINOv3Wrapper - -__all__ = ["DINOv3Wrapper"] +"""视觉编码相关:DINOv3 包装、时空压缩。""" + +from .dinov3_wrapper import DINOv3Wrapper + +__all__ = ["DINOv3Wrapper"] diff --git a/src/wjad/encoders/dinov3_wrapper.py b/src/wjad/encoders/dinov3_wrapper.py index 896c990b2214d445f6a990e7d29a3f32395fb173..8ca8a2d3fdb0e01103645f843d25c9417dc9714c 100644 --- a/src/wjad/encoders/dinov3_wrapper.py +++ b/src/wjad/encoders/dinov3_wrapper.py @@ -1,104 +1,104 @@ -"""DINOv3 ViT-B/16 包装器。 - -- 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。 -- 强制使用 ``attn_implementation="sdpa"``。 -- 提供 ``freeze()`` / ``unfreeze()`` 开关。 -- 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中 - ``(gh, gw) = (H/patch, W/patch)``。 -""" - -from __future__ import annotations - -from pathlib import Path - -import torch -import torch.nn as nn -from transformers import AutoModel - - -class DINOv3Wrapper(nn.Module): - """加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。""" - - def __init__( - self, - pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m", - attn_implementation: str = "sdpa", - freeze: bool = True, - ) -> None: - super().__init__() - self.pretrained_path = str(pretrained_path) - # 加载 HuggingFace transformers 中的 DINOv3 ViT 模型 - self.model = AutoModel.from_pretrained( - self.pretrained_path, - attn_implementation=attn_implementation, - ) - cfg = self.model.config - self.hidden_size = cfg.hidden_size - self.patch_size = cfg.patch_size - self.num_register_tokens = getattr(cfg, "num_register_tokens", 4) - - if freeze: - self.freeze() - self._frozen = freeze - - def freeze(self) -> None: - """冻结所有参数。""" - for p in self.model.parameters(): - p.requires_grad_(False) - self.model.eval() - self._frozen = True - - def unfreeze(self) -> None: - """解冻全部参数(Stage2 微调)。""" - for p in self.model.parameters(): - p.requires_grad_(True) - self.model.train() - self._frozen = False - - @property - def is_frozen(self) -> bool: - return self._frozen - - def train(self, mode: bool = True) -> "DINOv3Wrapper": - """覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。""" - super().train(mode) - if self._frozen: - self.model.eval() - return self - - def forward(self, images: torch.Tensor) -> torch.Tensor: - """ - 参数 - ---- - images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。 - - 返回 - ---- - feats : ``[B, T, gh, gw, D]``。 - """ - b, t, c, h, w = images.shape - # DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]`` - flat = images.view(b * t, c, h, w) - - # 冻结分支无需梯度,节省显存与时间 - if self._frozen: - with torch.no_grad(): - outputs = self.model(pixel_values=flat) - else: - outputs = self.model(pixel_values=flat) - - last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D] - num_prefix = 1 + self.num_register_tokens - patches = last[:, num_prefix:, :] # [B*T, N_patch, D] - - gh = h // self.patch_size - gw = w // self.patch_size - d = patches.shape[-1] - # reshape 回网格 - feats = patches.view(b, t, gh, gw, d) - return feats - - @torch.no_grad() - def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]: - """给定输入分辨率,返回 patch 网格大小。""" - return image_h // self.patch_size, image_w // self.patch_size +"""DINOv3 ViT-B/16 包装器。 + +- 从本地路径加载(``./dinov3-vitb16-pretrain-lvd1689m``)。 +- 强制使用 ``attn_implementation="sdpa"``。 +- 提供 ``freeze()`` / ``unfreeze()`` 开关。 +- 输入:``[B, T, 3, H, W]``;输出:``[B, T, gh, gw, D]``,其中 + ``(gh, gw) = (H/patch, W/patch)``。 +""" + +from __future__ import annotations + +from pathlib import Path + +import torch +import torch.nn as nn +from transformers import AutoModel + + +class DINOv3Wrapper(nn.Module): + """加载并包装 DINOv3 ViT-B/16,输出 patch 网格特征。""" + + def __init__( + self, + pretrained_path: str | Path = "./dinov3-vitb16-pretrain-lvd1689m", + attn_implementation: str = "sdpa", + freeze: bool = True, + ) -> None: + super().__init__() + self.pretrained_path = str(pretrained_path) + # 加载 HuggingFace transformers 中的 DINOv3 ViT 模型 + self.model = AutoModel.from_pretrained( + self.pretrained_path, + attn_implementation=attn_implementation, + ) + cfg = self.model.config + self.hidden_size = cfg.hidden_size + self.patch_size = cfg.patch_size + self.num_register_tokens = getattr(cfg, "num_register_tokens", 4) + + if freeze: + self.freeze() + self._frozen = freeze + + def freeze(self) -> None: + """冻结所有参数。""" + for p in self.model.parameters(): + p.requires_grad_(False) + self.model.eval() + self._frozen = True + + def unfreeze(self) -> None: + """解冻全部参数(Stage2 微调)。""" + for p in self.model.parameters(): + p.requires_grad_(True) + self.model.train() + self._frozen = False + + @property + def is_frozen(self) -> bool: + return self._frozen + + def train(self, mode: bool = True) -> "DINOv3Wrapper": + """覆盖 train():冻结时永远保持 eval 模式(避免 BN/Dropout 漂移)。""" + super().train(mode) + if self._frozen: + self.model.eval() + return self + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + 参数 + ---- + images : ``[B, T, 3, H, W]``,已按 DINOv3 mean/std 归一化。 + + 返回 + ---- + feats : ``[B, T, gh, gw, D]``。 + """ + b, t, c, h, w = images.shape + # DINOv3 forward 接受 ``pixel_values: [B*T, 3, H, W]`` + flat = images.view(b * t, c, h, w) + + # 冻结分支无需梯度,节省显存与时间 + if self._frozen: + with torch.no_grad(): + outputs = self.model(pixel_values=flat) + else: + outputs = self.model(pixel_values=flat) + + last = outputs.last_hidden_state # [B*T, 1 + R + N_patch, D] + num_prefix = 1 + self.num_register_tokens + patches = last[:, num_prefix:, :] # [B*T, N_patch, D] + + gh = h // self.patch_size + gw = w // self.patch_size + d = patches.shape[-1] + # reshape 回网格 + feats = patches.view(b, t, gh, gw, d) + return feats + + @torch.no_grad() + def expected_grid(self, image_h: int, image_w: int) -> tuple[int, int]: + """给定输入分辨率,返回 patch 网格大小。""" + return image_h // self.patch_size, image_w // self.patch_size diff --git a/src/wjad/heads/__init__.py b/src/wjad/heads/__init__.py index d75b2817e6b75b68cd73802150048477392f95ef..4a747ec0b05258ddf4364bb1186ad2f003b5312a 100644 --- a/src/wjad/heads/__init__.py +++ b/src/wjad/heads/__init__.py @@ -1,11 +1,11 @@ -"""检测+未来轨迹头 + 控制头。""" - -from .detection_traj import DetectionTrajHead, DetectionTrajOutput -from .control import ControlHead, ControlOutput - -__all__ = [ - "DetectionTrajHead", - "DetectionTrajOutput", - "ControlHead", - "ControlOutput", -] +"""检测+未来轨迹头 + 控制头。""" + +from .detection_traj import DetectionTrajHead, DetectionTrajOutput +from .control import ControlHead, ControlOutput + +__all__ = [ + "DetectionTrajHead", + "DetectionTrajOutput", + "ControlHead", + "ControlOutput", +] diff --git a/src/wjad/heads/control.py b/src/wjad/heads/control.py index 9dcd94b83c95a87f22072ebfe1545530dc78a9c8..a2b10cb616e7df5cbbca21e14a0c6071f805a336 100644 --- a/src/wjad/heads/control.py +++ b/src/wjad/heads/control.py @@ -1,100 +1,100 @@ -"""自车控制头:24 个控制 token 输出未来轨迹与全局动作。 - -token 切分: - - 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma`` - - 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma`` - (其余作为冗余 / 未来扩展,暂不监督) -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn as nn - - -@dataclass -class ControlOutput: - ego_traj_mu: torch.Tensor # [B, T_future, 3] - ego_traj_log_sigma: torch.Tensor # [B, T_future, 3] - action_mu: torch.Tensor # [B, action_dim] - action_log_sigma: torch.Tensor # [B, action_dim] - - -class ControlHead(nn.Module): - def __init__( - self, - in_dim: int = 768, - hidden_size: int = 384, - num_traj_tokens: int = 12, - num_action_tokens: int = 12, - ego_traj_horizon: int = 24, - ego_traj_dim: int = 3, - action_dim: int = 3, - log_sigma_clamp: tuple[float, float] = (-7.0, 7.0), - ) -> None: - super().__init__() - assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24" - self.in_dim = in_dim - self.hidden = hidden_size - self.num_traj_tokens = num_traj_tokens - self.num_action_tokens = num_action_tokens - self.ego_traj_horizon = ego_traj_horizon - self.ego_traj_dim = ego_traj_dim - self.action_dim = action_dim - self.log_sigma_clamp = log_sigma_clamp - - self.norm = nn.LayerNorm(in_dim) - # 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig) - self.traj_proj = nn.Sequential( - nn.Linear(num_traj_tokens * in_dim, hidden_size), - nn.GELU(), - nn.Linear(hidden_size, hidden_size), - nn.GELU(), - ) - self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim) - self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim) - - # 动作分支:取第 0 个动作 token - self.action_proj = nn.Sequential( - nn.Linear(in_dim, hidden_size), - nn.GELU(), - nn.Linear(hidden_size, hidden_size), - nn.GELU(), - ) - self.action_mu_head = nn.Linear(hidden_size, action_dim) - self.action_logsig_head = nn.Linear(hidden_size, action_dim) - - self._init_heads() - - def _init_heads(self) -> None: - for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]: - nn.init.zeros_(m.weight) - nn.init.zeros_(m.bias) - - def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput: - """ - ctrl_tokens : ``[B, 24, in_dim]`` - """ - b, n, d = ctrl_tokens.shape - assert n == self.num_traj_tokens + self.num_action_tokens - x = self.norm(ctrl_tokens) - traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1) - action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token - - traj_h = self.traj_proj(traj_feats) - traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim) - traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim) - traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp) - - action_h = self.action_proj(action_feats) - action_mu = self.action_mu_head(action_h) - action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp) - - return ControlOutput( - ego_traj_mu=traj_mu, - ego_traj_log_sigma=traj_logsig, - action_mu=action_mu, - action_log_sigma=action_logsig, - ) +"""自车控制头:24 个控制 token 输出未来轨迹与全局动作。 + +token 切分: + - 12 个轨迹 token → 经 MLP 上采样到 24 帧自车 ``(x, y, yaw)`` 的 ``μ`` / ``log_sigma`` + - 12 个动作 token → 第 0 个解码 ``(steer, throttle, brake)`` 的 ``μ`` / ``log_sigma`` + (其余作为冗余 / 未来扩展,暂不监督) +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn + + +@dataclass +class ControlOutput: + ego_traj_mu: torch.Tensor # [B, T_future, 3] + ego_traj_log_sigma: torch.Tensor # [B, T_future, 3] + action_mu: torch.Tensor # [B, action_dim] + action_log_sigma: torch.Tensor # [B, action_dim] + + +class ControlHead(nn.Module): + def __init__( + self, + in_dim: int = 768, + hidden_size: int = 384, + num_traj_tokens: int = 12, + num_action_tokens: int = 12, + ego_traj_horizon: int = 24, + ego_traj_dim: int = 3, + action_dim: int = 3, + log_sigma_clamp: tuple[float, float] = (-7.0, 7.0), + ) -> None: + super().__init__() + assert num_traj_tokens + num_action_tokens == 24, "控制 token 总数应为 24" + self.in_dim = in_dim + self.hidden = hidden_size + self.num_traj_tokens = num_traj_tokens + self.num_action_tokens = num_action_tokens + self.ego_traj_horizon = ego_traj_horizon + self.ego_traj_dim = ego_traj_dim + self.action_dim = action_dim + self.log_sigma_clamp = log_sigma_clamp + + self.norm = nn.LayerNorm(in_dim) + # 轨迹分支:把 12 个轨迹 token 拼平 -> MLP -> 24*3 (mu) + 24*3 (logsig) + self.traj_proj = nn.Sequential( + nn.Linear(num_traj_tokens * in_dim, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, hidden_size), + nn.GELU(), + ) + self.traj_mu_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim) + self.traj_logsig_head = nn.Linear(hidden_size, ego_traj_horizon * ego_traj_dim) + + # 动作分支:取第 0 个动作 token + self.action_proj = nn.Sequential( + nn.Linear(in_dim, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, hidden_size), + nn.GELU(), + ) + self.action_mu_head = nn.Linear(hidden_size, action_dim) + self.action_logsig_head = nn.Linear(hidden_size, action_dim) + + self._init_heads() + + def _init_heads(self) -> None: + for m in [self.traj_mu_head, self.traj_logsig_head, self.action_mu_head, self.action_logsig_head]: + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, ctrl_tokens: torch.Tensor) -> ControlOutput: + """ + ctrl_tokens : ``[B, 24, in_dim]`` + """ + b, n, d = ctrl_tokens.shape + assert n == self.num_traj_tokens + self.num_action_tokens + x = self.norm(ctrl_tokens) + traj_feats = x[:, : self.num_traj_tokens, :].reshape(b, -1) + action_feats = x[:, self.num_traj_tokens, :] # 取第一个动作 token + + traj_h = self.traj_proj(traj_feats) + traj_mu = self.traj_mu_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim) + traj_logsig = self.traj_logsig_head(traj_h).view(b, self.ego_traj_horizon, self.ego_traj_dim) + traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp) + + action_h = self.action_proj(action_feats) + action_mu = self.action_mu_head(action_h) + action_logsig = self.action_logsig_head(action_h).clamp(*self.log_sigma_clamp) + + return ControlOutput( + ego_traj_mu=traj_mu, + ego_traj_log_sigma=traj_logsig, + action_mu=action_mu, + action_log_sigma=action_logsig, + ) diff --git a/src/wjad/heads/detection_traj.py b/src/wjad/heads/detection_traj.py index 6d444bf2fc3579e2d6f5cfc1e8652b36370f64a6..8780630b2540f74fda50bc931d522315634a7076 100644 --- a/src/wjad/heads/detection_traj.py +++ b/src/wjad/heads/detection_traj.py @@ -1,106 +1,106 @@ -"""统一的检测 + 未来轨迹头。 - -每个检测 query token 输出: - - ``cls`` : ``[num_classes]`` logits(含 background) - - ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失) - - ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw) - - ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw) - -匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from ..modules.ffn import SwiGLUFFN - - -@dataclass -class DetectionTrajOutput: - """检测+未来轨迹头输出。""" - - cls_logits: torch.Tensor # [B, Q, num_classes] - is_dynamic_logit: torch.Tensor # [B, Q] - box3d_mu: torch.Tensor # [B, Q, 7] - box3d_log_sigma: torch.Tensor # [B, Q, 7] - traj_mu: torch.Tensor # [B, Q, T_future, 3] - traj_log_sigma: torch.Tensor # [B, Q, T_future, 3] - - -class DetectionTrajHead(nn.Module): - def __init__( - self, - in_dim: int = 768, - hidden_size: int = 384, - num_classes: int = 22, - box_dim: int = 7, - traj_horizon: int = 24, - traj_dim: int = 3, - log_sigma_clamp: tuple[float, float] = (-7.0, 7.0), - ) -> None: - super().__init__() - self.num_classes = num_classes - self.box_dim = box_dim - self.traj_horizon = traj_horizon - self.traj_dim = traj_dim - self.log_sigma_clamp = log_sigma_clamp - - # 共享主干 MLP(PreNorm + SwiGLU) - self.norm = nn.LayerNorm(in_dim) - self.shared = nn.Sequential( - nn.Linear(in_dim, hidden_size), - nn.GELU(), - nn.Linear(hidden_size, hidden_size), - nn.GELU(), - ) - - # 各分支 - self.cls_head = nn.Linear(hidden_size, num_classes) - self.isdyn_head = nn.Linear(hidden_size, 1) - self.box_mu_head = nn.Linear(hidden_size, box_dim) - self.box_logsig_head = nn.Linear(hidden_size, box_dim) - self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim) - self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim) - - self._init_heads() - - def _init_heads(self) -> None: - # 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1 - for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]: - nn.init.zeros_(m.weight) - nn.init.zeros_(m.bias) - # cls / isdyn 用小初始化即可(避免 background 一开始全选) - nn.init.normal_(self.cls_head.weight, std=0.01) - nn.init.zeros_(self.cls_head.bias) - nn.init.zeros_(self.isdyn_head.weight) - nn.init.zeros_(self.isdyn_head.bias) - - def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput: - """ - det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。 - """ - b, q, _ = det_tokens.shape - feats = self.shared(self.norm(det_tokens)) - - cls_logits = self.cls_head(feats) - isdyn_logit = self.isdyn_head(feats).squeeze(-1) - - box_mu = self.box_mu_head(feats) - box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp) - - traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim) - traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim) - traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp) - - return DetectionTrajOutput( - cls_logits=cls_logits, - is_dynamic_logit=isdyn_logit, - box3d_mu=box_mu, - box3d_log_sigma=box_logsig, - traj_mu=traj_mu, - traj_log_sigma=traj_logsig, - ) +"""统一的检测 + 未来轨迹头。 + +每个检测 query token 输出: + - ``cls`` : ``[num_classes]`` logits(含 background) + - ``is_dynamic`` : 二分类 logit(是否为运动类,用于 mask 轨迹分支损失) + - ``box3d_mu`` / ``box3d_log_sigma`` : ``[7]``(x, y, z, l, w, h, yaw) + - ``traj_mu`` / ``traj_log_sigma`` : ``[traj_horizon, 3]``(dx, dy, dyaw) + +匈牙利匹配代价由外部损失模块构造(用 cls focal 代价 + L1(box μ) + GIoU3D 近似)。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ..modules.ffn import SwiGLUFFN + + +@dataclass +class DetectionTrajOutput: + """检测+未来轨迹头输出。""" + + cls_logits: torch.Tensor # [B, Q, num_classes] + is_dynamic_logit: torch.Tensor # [B, Q] + box3d_mu: torch.Tensor # [B, Q, 7] + box3d_log_sigma: torch.Tensor # [B, Q, 7] + traj_mu: torch.Tensor # [B, Q, T_future, 3] + traj_log_sigma: torch.Tensor # [B, Q, T_future, 3] + + +class DetectionTrajHead(nn.Module): + def __init__( + self, + in_dim: int = 768, + hidden_size: int = 384, + num_classes: int = 22, + box_dim: int = 7, + traj_horizon: int = 24, + traj_dim: int = 3, + log_sigma_clamp: tuple[float, float] = (-7.0, 7.0), + ) -> None: + super().__init__() + self.num_classes = num_classes + self.box_dim = box_dim + self.traj_horizon = traj_horizon + self.traj_dim = traj_dim + self.log_sigma_clamp = log_sigma_clamp + + # 共享主干 MLP(PreNorm + SwiGLU) + self.norm = nn.LayerNorm(in_dim) + self.shared = nn.Sequential( + nn.Linear(in_dim, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, hidden_size), + nn.GELU(), + ) + + # 各分支 + self.cls_head = nn.Linear(hidden_size, num_classes) + self.isdyn_head = nn.Linear(hidden_size, 1) + self.box_mu_head = nn.Linear(hidden_size, box_dim) + self.box_logsig_head = nn.Linear(hidden_size, box_dim) + self.traj_mu_head = nn.Linear(hidden_size, traj_horizon * traj_dim) + self.traj_logsig_head = nn.Linear(hidden_size, traj_horizon * traj_dim) + + self._init_heads() + + def _init_heads(self) -> None: + # 让 box / traj 输出初始 ≈ 0;log_sigma 初始 ≈ 0 → sigma ≈ 1 + for m in [self.box_mu_head, self.box_logsig_head, self.traj_mu_head, self.traj_logsig_head]: + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + # cls / isdyn 用小初始化即可(避免 background 一开始全选) + nn.init.normal_(self.cls_head.weight, std=0.01) + nn.init.zeros_(self.cls_head.bias) + nn.init.zeros_(self.isdyn_head.weight) + nn.init.zeros_(self.isdyn_head.bias) + + def forward(self, det_tokens: torch.Tensor) -> DetectionTrajOutput: + """ + det_tokens : ``[B, Q, in_dim]``,主干输出中切出来的检测 token。 + """ + b, q, _ = det_tokens.shape + feats = self.shared(self.norm(det_tokens)) + + cls_logits = self.cls_head(feats) + isdyn_logit = self.isdyn_head(feats).squeeze(-1) + + box_mu = self.box_mu_head(feats) + box_logsig = self.box_logsig_head(feats).clamp(*self.log_sigma_clamp) + + traj_mu = self.traj_mu_head(feats).view(b, q, self.traj_horizon, self.traj_dim) + traj_logsig = self.traj_logsig_head(feats).view(b, q, self.traj_horizon, self.traj_dim) + traj_logsig = traj_logsig.clamp(*self.log_sigma_clamp) + + return DetectionTrajOutput( + cls_logits=cls_logits, + is_dynamic_logit=isdyn_logit, + box3d_mu=box_mu, + box3d_log_sigma=box_logsig, + traj_mu=traj_mu, + traj_log_sigma=traj_logsig, + ) diff --git a/src/wjad/losses/__init__.py b/src/wjad/losses/__init__.py index 4648371eb82d4dd272a78951bd69bd2007162158..81801064e5b669d059fe7aad700e0e894818adc7 100644 --- a/src/wjad/losses/__init__.py +++ b/src/wjad/losses/__init__.py @@ -1,24 +1,24 @@ -"""损失函数集合。""" - -from .nll import gaussian_nll -from .detection import ( - HungarianMatcher, - detection_losses, - DetectionLossOutputs, -) -from .trajectory import object_traj_nll -from .control import ego_traj_nll, action_nll -from .moe_aux import moe_load_balance_and_boundary -from .calib_reg import calibration_regularization - -__all__ = [ - "gaussian_nll", - "HungarianMatcher", - "detection_losses", - "DetectionLossOutputs", - "object_traj_nll", - "ego_traj_nll", - "action_nll", - "moe_load_balance_and_boundary", - "calibration_regularization", -] +"""损失函数集合。""" + +from .nll import gaussian_nll +from .detection import ( + HungarianMatcher, + detection_losses, + DetectionLossOutputs, +) +from .trajectory import object_traj_nll +from .control import ego_traj_nll, action_nll +from .moe_aux import moe_load_balance_and_boundary +from .calib_reg import calibration_regularization + +__all__ = [ + "gaussian_nll", + "HungarianMatcher", + "detection_losses", + "DetectionLossOutputs", + "object_traj_nll", + "ego_traj_nll", + "action_nll", + "moe_load_balance_and_boundary", + "calibration_regularization", +] diff --git a/src/wjad/losses/calib_reg.py b/src/wjad/losses/calib_reg.py index 8fb9296536c0403bd686b69ec451e8e3c9288c55..79431e0a297124527afaeb8dacfdd4adfe9d9f28 100644 --- a/src/wjad/losses/calib_reg.py +++ b/src/wjad/losses/calib_reg.py @@ -1,21 +1,21 @@ -"""在线校准残差正则: - -- L2 残差先验:让早期 / 一般情况下残差接近 0; -- Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。 -""" - -from __future__ import annotations - -import torch - - -def calibration_regularization( - ego_residual: torch.Tensor, - intr_residual: torch.Tensor, - extr_residual: torch.Tensor, - l2_weight: float = 1.0, -) -> torch.Tensor: - e = ego_residual.pow(2).mean() - i = intr_residual.pow(2).mean() - x = extr_residual.pow(2).mean() - return l2_weight * (e + i + x) / 3.0 +"""在线校准残差正则: + +- L2 残差先验:让早期 / 一般情况下残差接近 0; +- Tanh 边界正则:``residual^2`` 在 Tanh 上抑制饱和。 +""" + +from __future__ import annotations + +import torch + + +def calibration_regularization( + ego_residual: torch.Tensor, + intr_residual: torch.Tensor, + extr_residual: torch.Tensor, + l2_weight: float = 1.0, +) -> torch.Tensor: + e = ego_residual.pow(2).mean() + i = intr_residual.pow(2).mean() + x = extr_residual.pow(2).mean() + return l2_weight * (e + i + x) / 3.0 diff --git a/src/wjad/losses/control.py b/src/wjad/losses/control.py index 61ae53afc298215cda100f4947a4f6247aebcae9..e0700b43d7d37ef13cbd7770af01e42b8d136a9a 100644 --- a/src/wjad/losses/control.py +++ b/src/wjad/losses/control.py @@ -1,25 +1,25 @@ -"""自车未来轨迹 + 全局动作的 NLL。""" - -from __future__ import annotations - -import torch - -from .nll import gaussian_nll - - -def ego_traj_nll( - pred_mu: torch.Tensor, # [B, T, 3] - pred_log_sigma: torch.Tensor, # [B, T, 3] - target: torch.Tensor, # [B, T, 3] (symlog 空间) - valid: torch.Tensor | None = None, -) -> torch.Tensor: - return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid) - - -def action_nll( - pred_mu: torch.Tensor, # [B, A] - pred_log_sigma: torch.Tensor, # [B, A] - target: torch.Tensor, # [B, A] - valid: torch.Tensor | None = None, -) -> torch.Tensor: - return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid) +"""自车未来轨迹 + 全局动作的 NLL。""" + +from __future__ import annotations + +import torch + +from .nll import gaussian_nll + + +def ego_traj_nll( + pred_mu: torch.Tensor, # [B, T, 3] + pred_log_sigma: torch.Tensor, # [B, T, 3] + target: torch.Tensor, # [B, T, 3] (symlog 空间) + valid: torch.Tensor | None = None, +) -> torch.Tensor: + return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid) + + +def action_nll( + pred_mu: torch.Tensor, # [B, A] + pred_log_sigma: torch.Tensor, # [B, A] + target: torch.Tensor, # [B, A] + valid: torch.Tensor | None = None, +) -> torch.Tensor: + return gaussian_nll(pred_mu, pred_log_sigma, target, valid_mask=valid) diff --git a/src/wjad/losses/detection.py b/src/wjad/losses/detection.py index e3213b6c07891907a7dfdb0b218a63e0721a11c3..c6085fa389fc03501e466fc9185eb42e57aedeeb 100644 --- a/src/wjad/losses/detection.py +++ b/src/wjad/losses/detection.py @@ -1,213 +1,213 @@ -"""检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。 - -GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU, -``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。 -此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn.functional as F -from scipy.optimize import linear_sum_assignment - -from .nll import gaussian_nll - - -@dataclass -class DetectionLossOutputs: - """检测损失分项与匹配结果。""" - - cls_loss: torch.Tensor - box_nll: torch.Tensor - giou_loss: torch.Tensor - isdyn_loss: torch.Tensor - matched_indices: list[tuple[torch.Tensor, torch.Tensor]] - - -def focal_loss( - logits: torch.Tensor, - target: torch.Tensor, - alpha: float = 0.25, - gamma: float = 2.0, - reduction: str = "mean", -) -> torch.Tensor: - """多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。""" - log_softmax = F.log_softmax(logits, dim=-1) - pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1) - ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1) - focal = alpha * (1 - pt).pow(gamma) * ce - if reduction == "mean": - return focal.mean() - if reduction == "sum": - return focal.sum() - return focal - - -def _bev_giou( - box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw) - box_b: torch.Tensor, # [..., 5] -) -> torch.Tensor: - """BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转), - 可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。 - """ - cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3] - cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3] - a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2 - a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2 - b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2 - b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2 - inter_x1 = torch.max(a_x1, b_x1) - inter_y1 = torch.max(a_y1, b_y1) - inter_x2 = torch.min(a_x2, b_x2) - inter_y2 = torch.min(a_y2, b_y2) - inter_w = (inter_x2 - inter_x1).clamp_min(0) - inter_h = (inter_y2 - inter_y1).clamp_min(0) - inter = inter_w * inter_h - area_a = (l_a * w_a).clamp_min(0) - area_b = (l_b * w_b).clamp_min(0) - union = area_a + area_b - inter + 1e-6 - iou = inter / union - # GIoU enclosure - enc_x1 = torch.min(a_x1, b_x1) - enc_y1 = torch.min(a_y1, b_y1) - enc_x2 = torch.max(a_x2, b_x2) - enc_y2 = torch.max(a_y2, b_y2) - enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6) - giou = iou - (enc_area - union) / enc_area - return giou - - -def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: - """3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。""" - bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]]) - z_a, h_a = box_a[..., 2], box_a[..., 5] - z_b, h_b = box_b[..., 2], box_b[..., 5] - a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2 - b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2 - inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0) - union_z = (h_a + h_b - inter_z).clamp_min(1e-6) - z_iou = inter_z / union_z - return bev * z_iou - - -class HungarianMatcher: - """DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。""" - - def __init__( - self, - cls_cost: float = 2.0, - l1_cost: float = 5.0, - giou_cost: float = 2.0, - ) -> None: - self.cls_cost = cls_cost - self.l1_cost = l1_cost - self.giou_cost = giou_cost - - @torch.no_grad() - def match( - self, - cls_logits: torch.Tensor, # [B, Q, C] - box_mu: torch.Tensor, # [B, Q, 7] - targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]} - ) -> list[tuple[torch.Tensor, torch.Tensor]]: - b, q, c = cls_logits.shape - out = [] - cls_probs = cls_logits.softmax(-1) # [B, Q, C] - - for i in range(b): - tgt_labels = targets[i]["labels"] - tgt_boxes = targets[i]["boxes"] - n_tgt = tgt_labels.numel() - if n_tgt == 0: - out.append(( - torch.empty(0, dtype=torch.long), - torch.empty(0, dtype=torch.long), - )) - continue - - # cost_cls: 越大越好 → 取负 - cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt] - cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt] - # giou3d: [Q, n_tgt] - qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1) - tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1) - cost_giou = -giou3d_approx(qa, tb) - - cost = ( - self.cls_cost * cost_cls - + self.l1_cost * cost_l1 - + self.giou_cost * cost_giou - ) - cost_np = cost.cpu().numpy() - row, col = linear_sum_assignment(cost_np) - out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long))) - return out - - -def detection_losses( - cls_logits: torch.Tensor, # [B, Q, C] - box_mu: torch.Tensor, # [B, Q, 7] - box_log_sigma: torch.Tensor, # [B, Q, 7] - isdyn_logit: torch.Tensor, # [B, Q] - targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...} - matcher: HungarianMatcher, - num_classes: int, - background_class: int = 0, - focal_alpha: float = 0.25, - focal_gamma: float = 2.0, -) -> DetectionLossOutputs: - """返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。""" - indices = matcher.match(cls_logits, box_mu, targets) - b, q, _ = box_mu.shape - device = box_mu.device - - # 构造分类目标:所有 query 默认 background;匹配的填对应 label - target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device) - target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device) - matched_box_pairs = [] - matched_logsig_pairs = [] - matched_target_boxes = [] - - for i, (rows, cols) in enumerate(indices): - if rows.numel() == 0: - continue - rows = rows.to(device) - cols = cols.to(device) - target_classes[i, rows] = targets[i]["labels"][cols].to(device) - target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float() - matched_box_pairs.append(box_mu[i, rows]) - matched_logsig_pairs.append(box_log_sigma[i, rows]) - matched_target_boxes.append(targets[i]["boxes"][cols].to(device)) - - cls_loss = focal_loss( - cls_logits.view(b * q, -1), - target_classes.view(-1), - alpha=focal_alpha, - gamma=focal_gamma, - ) - - if matched_box_pairs: - pred_box = torch.cat(matched_box_pairs, dim=0) - pred_logsig = torch.cat(matched_logsig_pairs, dim=0) - gt_box = torch.cat(matched_target_boxes, dim=0) - box_nll = gaussian_nll(pred_box, pred_logsig, gt_box) - giou_v = giou3d_approx(pred_box, gt_box) - giou_loss = (1.0 - giou_v).mean() - else: - box_nll = torch.zeros((), device=device) - giou_loss = torch.zeros((), device=device) - - isdyn_loss = F.binary_cross_entropy_with_logits( - isdyn_logit, target_isdyn, reduction="mean" - ) - - return DetectionLossOutputs( - cls_loss=cls_loss, - box_nll=box_nll, - giou_loss=giou_loss, - isdyn_loss=isdyn_loss, - matched_indices=indices, - ) +"""检测损失:匈牙利匹配 + 分类 focal + 3D box NLL + 近似 GIoU3D。 + +GIoU3D 用 BEV 平面 + 高度近似:用 ``(x, y, l, w, yaw)`` 计算 BEV IoU/GIoU, +``z, h`` 在长度上做线性 IoU;最终 GIoU3D ≈ GIoU_BEV * (h_overlap / h_union)。 +此近似在大多数 AV 公开 benchmark 已被广泛使用(速度快、可微)。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from .nll import gaussian_nll + + +@dataclass +class DetectionLossOutputs: + """检测损失分项与匹配结果。""" + + cls_loss: torch.Tensor + box_nll: torch.Tensor + giou_loss: torch.Tensor + isdyn_loss: torch.Tensor + matched_indices: list[tuple[torch.Tensor, torch.Tensor]] + + +def focal_loss( + logits: torch.Tensor, + target: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, + reduction: str = "mean", +) -> torch.Tensor: + """多类 focal loss。``logits``: ``[N, C]``, ``target``: ``[N]`` (long)。""" + log_softmax = F.log_softmax(logits, dim=-1) + pt = log_softmax.exp().gather(-1, target.unsqueeze(-1)).squeeze(-1) + ce = -log_softmax.gather(-1, target.unsqueeze(-1)).squeeze(-1) + focal = alpha * (1 - pt).pow(gamma) * ce + if reduction == "mean": + return focal.mean() + if reduction == "sum": + return focal.sum() + return focal + + +def _bev_giou( + box_a: torch.Tensor, # [..., 5] (x,y,l,w,yaw) + box_b: torch.Tensor, # [..., 5] +) -> torch.Tensor: + """BEV 简化 GIoU:取轴对齐的 (x,y,l,w) 包络(忽略 yaw 旋转), + 可微且足够稳定。如需 SOTA 旋转 IoU 可在后续替换为 ``oriented_iou``。 + """ + cx_a, cy_a, l_a, w_a = box_a[..., 0], box_a[..., 1], box_a[..., 2], box_a[..., 3] + cx_b, cy_b, l_b, w_b = box_b[..., 0], box_b[..., 1], box_b[..., 2], box_b[..., 3] + a_x1, a_y1 = cx_a - l_a / 2, cy_a - w_a / 2 + a_x2, a_y2 = cx_a + l_a / 2, cy_a + w_a / 2 + b_x1, b_y1 = cx_b - l_b / 2, cy_b - w_b / 2 + b_x2, b_y2 = cx_b + l_b / 2, cy_b + w_b / 2 + inter_x1 = torch.max(a_x1, b_x1) + inter_y1 = torch.max(a_y1, b_y1) + inter_x2 = torch.min(a_x2, b_x2) + inter_y2 = torch.min(a_y2, b_y2) + inter_w = (inter_x2 - inter_x1).clamp_min(0) + inter_h = (inter_y2 - inter_y1).clamp_min(0) + inter = inter_w * inter_h + area_a = (l_a * w_a).clamp_min(0) + area_b = (l_b * w_b).clamp_min(0) + union = area_a + area_b - inter + 1e-6 + iou = inter / union + # GIoU enclosure + enc_x1 = torch.min(a_x1, b_x1) + enc_y1 = torch.min(a_y1, b_y1) + enc_x2 = torch.max(a_x2, b_x2) + enc_y2 = torch.max(a_y2, b_y2) + enc_area = ((enc_x2 - enc_x1) * (enc_y2 - enc_y1)).clamp_min(1e-6) + giou = iou - (enc_area - union) / enc_area + return giou + + +def giou3d_approx(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: + """3D 近似 GIoU。``box``: ``[..., 7]`` (x,y,z,l,w,h,yaw)。""" + bev = _bev_giou(box_a[..., [0, 1, 3, 4, 6]], box_b[..., [0, 1, 3, 4, 6]]) + z_a, h_a = box_a[..., 2], box_a[..., 5] + z_b, h_b = box_b[..., 2], box_b[..., 5] + a_z1, a_z2 = z_a - h_a / 2, z_a + h_a / 2 + b_z1, b_z2 = z_b - h_b / 2, z_b + h_b / 2 + inter_z = (torch.min(a_z2, b_z2) - torch.max(a_z1, b_z1)).clamp_min(0) + union_z = (h_a + h_b - inter_z).clamp_min(1e-6) + z_iou = inter_z / union_z + return bev * z_iou + + +class HungarianMatcher: + """DETR 风格匈牙利匹配(CPU 上 ``scipy.linear_sum_assignment``)。""" + + def __init__( + self, + cls_cost: float = 2.0, + l1_cost: float = 5.0, + giou_cost: float = 2.0, + ) -> None: + self.cls_cost = cls_cost + self.l1_cost = l1_cost + self.giou_cost = giou_cost + + @torch.no_grad() + def match( + self, + cls_logits: torch.Tensor, # [B, Q, C] + box_mu: torch.Tensor, # [B, Q, 7] + targets: list[dict], # 每个样本: {"labels": [N_i], "boxes": [N_i, 7]} + ) -> list[tuple[torch.Tensor, torch.Tensor]]: + b, q, c = cls_logits.shape + out = [] + cls_probs = cls_logits.softmax(-1) # [B, Q, C] + + for i in range(b): + tgt_labels = targets[i]["labels"] + tgt_boxes = targets[i]["boxes"] + n_tgt = tgt_labels.numel() + if n_tgt == 0: + out.append(( + torch.empty(0, dtype=torch.long), + torch.empty(0, dtype=torch.long), + )) + continue + + # cost_cls: 越大越好 → 取负 + cost_cls = -cls_probs[i, :, tgt_labels] # [Q, n_tgt] + cost_l1 = torch.cdist(box_mu[i], tgt_boxes, p=1) # [Q, n_tgt] + # giou3d: [Q, n_tgt] + qa = box_mu[i].unsqueeze(1).expand(-1, n_tgt, -1) + tb = tgt_boxes.unsqueeze(0).expand(q, -1, -1) + cost_giou = -giou3d_approx(qa, tb) + + cost = ( + self.cls_cost * cost_cls + + self.l1_cost * cost_l1 + + self.giou_cost * cost_giou + ) + cost_np = cost.cpu().numpy() + row, col = linear_sum_assignment(cost_np) + out.append((torch.as_tensor(row, dtype=torch.long), torch.as_tensor(col, dtype=torch.long))) + return out + + +def detection_losses( + cls_logits: torch.Tensor, # [B, Q, C] + box_mu: torch.Tensor, # [B, Q, 7] + box_log_sigma: torch.Tensor, # [B, Q, 7] + isdyn_logit: torch.Tensor, # [B, Q] + targets: list[dict], # 每样本: {"labels":..., "boxes":..., "is_dynamic":...} + matcher: HungarianMatcher, + num_classes: int, + background_class: int = 0, + focal_alpha: float = 0.25, + focal_gamma: float = 2.0, +) -> DetectionLossOutputs: + """返回 cls/box_nll/giou/isdyn 四个标量 loss + 匹配下标。""" + indices = matcher.match(cls_logits, box_mu, targets) + b, q, _ = box_mu.shape + device = box_mu.device + + # 构造分类目标:所有 query 默认 background;匹配的填对应 label + target_classes = torch.full((b, q), background_class, dtype=torch.long, device=device) + target_isdyn = torch.zeros(b, q, dtype=torch.float32, device=device) + matched_box_pairs = [] + matched_logsig_pairs = [] + matched_target_boxes = [] + + for i, (rows, cols) in enumerate(indices): + if rows.numel() == 0: + continue + rows = rows.to(device) + cols = cols.to(device) + target_classes[i, rows] = targets[i]["labels"][cols].to(device) + target_isdyn[i, rows] = targets[i]["is_dynamic"][cols].to(device).float() + matched_box_pairs.append(box_mu[i, rows]) + matched_logsig_pairs.append(box_log_sigma[i, rows]) + matched_target_boxes.append(targets[i]["boxes"][cols].to(device)) + + cls_loss = focal_loss( + cls_logits.view(b * q, -1), + target_classes.view(-1), + alpha=focal_alpha, + gamma=focal_gamma, + ) + + if matched_box_pairs: + pred_box = torch.cat(matched_box_pairs, dim=0) + pred_logsig = torch.cat(matched_logsig_pairs, dim=0) + gt_box = torch.cat(matched_target_boxes, dim=0) + box_nll = gaussian_nll(pred_box, pred_logsig, gt_box) + giou_v = giou3d_approx(pred_box, gt_box) + giou_loss = (1.0 - giou_v).mean() + else: + box_nll = torch.zeros((), device=device) + giou_loss = torch.zeros((), device=device) + + isdyn_loss = F.binary_cross_entropy_with_logits( + isdyn_logit, target_isdyn, reduction="mean" + ) + + return DetectionLossOutputs( + cls_loss=cls_loss, + box_nll=box_nll, + giou_loss=giou_loss, + isdyn_loss=isdyn_loss, + matched_indices=indices, + ) diff --git a/src/wjad/losses/moe_aux.py b/src/wjad/losses/moe_aux.py index 84bf7caf073de83459aa24a28af0410996301f0d..2f55ed77149d0eeda7b7142cace6189ee1031b38 100644 --- a/src/wjad/losses/moe_aux.py +++ b/src/wjad/losses/moe_aux.py @@ -1,33 +1,33 @@ -"""MoE 路由的负载均衡 + 边界正则。 - -- 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。 - 这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量; - 也加入 ``mean(probs).std()`` 跨专家的均匀性度量。 -- 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。 -""" - -from __future__ import annotations - -import torch - -from ..modules.moe import MoEStats - - -def moe_load_balance_and_boundary( - stats_list: list[MoEStats], - load_balance_weight: float = 1.0, - boundary_weight: float = 1.0, -) -> torch.Tensor: - if not stats_list: - return torch.zeros((), device="cpu") - - device = stats_list[0].logits.device - total = torch.zeros((), device=device) - for stats in stats_list: - # boundary:logits^2 的均值 - boundary = stats.logits.pow(2).mean() - # load balance:跨样本的专家选择频率方差 - avg_per_expert = stats.probs.mean(dim=0) # [num_routed] - load = avg_per_expert.std() - total = total + boundary_weight * boundary + load_balance_weight * load - return total / max(len(stats_list), 1) +"""MoE 路由的负载均衡 + 边界正则。 + +- 负载均衡:``var_b(probs)`` 跨样本(batch)应较小,避免某些样本恒选少数专家。 + 这里用 ``var(probs.mean(dim=0))`` 作为简单负载方差度量; + 也加入 ``mean(probs).std()`` 跨专家的均匀性度量。 +- 边界正则:``mean(logits ** 2)`` 防止路由 logits 越界,使 sigmoid 不饱和。 +""" + +from __future__ import annotations + +import torch + +from ..modules.moe import MoEStats + + +def moe_load_balance_and_boundary( + stats_list: list[MoEStats], + load_balance_weight: float = 1.0, + boundary_weight: float = 1.0, +) -> torch.Tensor: + if not stats_list: + return torch.zeros((), device="cpu") + + device = stats_list[0].logits.device + total = torch.zeros((), device=device) + for stats in stats_list: + # boundary:logits^2 的均值 + boundary = stats.logits.pow(2).mean() + # load balance:跨样本的专家选择频率方差 + avg_per_expert = stats.probs.mean(dim=0) # [num_routed] + load = avg_per_expert.std() + total = total + boundary_weight * boundary + load_balance_weight * load + return total / max(len(stats_list), 1) diff --git a/src/wjad/losses/nll.py b/src/wjad/losses/nll.py index da5c4d5d261499d619ec0f79c1e37c3d71daf1f6..ab050e3e0dd95bd2bea9b55dd16b97a1fcfa57db 100644 --- a/src/wjad/losses/nll.py +++ b/src/wjad/losses/nll.py @@ -1,47 +1,47 @@ -"""高斯 NLL 置信度损失。 - -公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)`` -为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。 - -支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。 -""" - -from __future__ import annotations - -import torch - - -def gaussian_nll( - mu: torch.Tensor, - log_sigma: torch.Tensor, - target: torch.Tensor, - valid_mask: torch.Tensor | None = None, - reduction: str = "mean", -) -> torch.Tensor: - """高斯负对数似然。 - - 所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。 - """ - diff = target - mu - inv_sigma = torch.exp(-log_sigma) - nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma - - if valid_mask is not None: - # broadcast 到 nll 的 shape - while valid_mask.dim() < nll.dim(): - valid_mask = valid_mask.unsqueeze(-1) - valid_mask = valid_mask.to(nll.dtype) - nll = nll * valid_mask - if reduction == "mean": - denom = valid_mask.sum().clamp_min(1.0) - return nll.sum() / denom - elif reduction == "sum": - return nll.sum() - else: - return nll - - if reduction == "mean": - return nll.mean() - if reduction == "sum": - return nll.sum() - return nll +"""高斯 NLL 置信度损失。 + +公式: ``L = 0.5 * ((y - μ) * exp(-log_sigma)) ** 2 + log_sigma + 0.5 * log(2π)`` +为节省常数项,实际实现忽略 ``0.5 * log(2π)``(不影响优化)。 + +支持可选的 ``valid_mask``:在 mask=False 处忽略对应元素。 +""" + +from __future__ import annotations + +import torch + + +def gaussian_nll( + mu: torch.Tensor, + log_sigma: torch.Tensor, + target: torch.Tensor, + valid_mask: torch.Tensor | None = None, + reduction: str = "mean", +) -> torch.Tensor: + """高斯负对数似然。 + + 所有张量同形状(mask 是其 broadcast 子集,可在最后维省略 features)。 + """ + diff = target - mu + inv_sigma = torch.exp(-log_sigma) + nll = 0.5 * (diff * inv_sigma).pow(2) + log_sigma + + if valid_mask is not None: + # broadcast 到 nll 的 shape + while valid_mask.dim() < nll.dim(): + valid_mask = valid_mask.unsqueeze(-1) + valid_mask = valid_mask.to(nll.dtype) + nll = nll * valid_mask + if reduction == "mean": + denom = valid_mask.sum().clamp_min(1.0) + return nll.sum() / denom + elif reduction == "sum": + return nll.sum() + else: + return nll + + if reduction == "mean": + return nll.mean() + if reduction == "sum": + return nll.sum() + return nll diff --git a/src/wjad/losses/trajectory.py b/src/wjad/losses/trajectory.py index 58e571476b8e530ce8b4722be4314dd095cc8c31..fcf60d648ec8e8920f9dbc562bc5772914689405 100644 --- a/src/wjad/losses/trajectory.py +++ b/src/wjad/losses/trajectory.py @@ -1,43 +1,43 @@ -"""动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。""" - -from __future__ import annotations - -import torch - -from .nll import gaussian_nll - - -def object_traj_nll( - traj_mu: torch.Tensor, # [B, Q, T, 3] - traj_log_sigma: torch.Tensor, # [B, Q, T, 3] - matched_indices: list[tuple[torch.Tensor, torch.Tensor]], - targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]} -) -> torch.Tensor: - """对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。 - - 返回标量 loss。 - """ - device = traj_mu.device - b = traj_mu.shape[0] - losses = [] - for i in range(b): - rows, cols = matched_indices[i] - if rows.numel() == 0: - continue - rows = rows.to(device) - cols = cols.to(device) - is_dyn = targets[i]["is_dynamic"][cols].to(device).bool() - if not is_dyn.any(): - continue - sel_rows = rows[is_dyn] - sel_cols = cols[is_dyn] - pred_mu = traj_mu[i, sel_rows] # [n, T, 3] - pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3] - gt_traj = targets[i]["future_traj"][sel_cols].to(device) - valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T] - # 在 (T, 3) 维上算 NLL,valid mask 只到 T 维 - nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid) - losses.append(nll) - if not losses: - return torch.zeros((), device=device) - return torch.stack(losses).mean() +"""动态目标未来 24 帧轨迹 NLL(仅在匹配到运动类的 query 上启用)。""" + +from __future__ import annotations + +import torch + +from .nll import gaussian_nll + + +def object_traj_nll( + traj_mu: torch.Tensor, # [B, Q, T, 3] + traj_log_sigma: torch.Tensor, # [B, Q, T, 3] + matched_indices: list[tuple[torch.Tensor, torch.Tensor]], + targets: list[dict], # 每样本 {"future_traj":[N,T,3], "future_valid":[N,T], "is_dynamic":[N]} +) -> torch.Tensor: + """对 ``is_dynamic == True`` 的匹配项求 traj NLL;其余忽略。 + + 返回标量 loss。 + """ + device = traj_mu.device + b = traj_mu.shape[0] + losses = [] + for i in range(b): + rows, cols = matched_indices[i] + if rows.numel() == 0: + continue + rows = rows.to(device) + cols = cols.to(device) + is_dyn = targets[i]["is_dynamic"][cols].to(device).bool() + if not is_dyn.any(): + continue + sel_rows = rows[is_dyn] + sel_cols = cols[is_dyn] + pred_mu = traj_mu[i, sel_rows] # [n, T, 3] + pred_logsig = traj_log_sigma[i, sel_rows] # [n, T, 3] + gt_traj = targets[i]["future_traj"][sel_cols].to(device) + valid = targets[i]["future_valid"][sel_cols].to(device).bool() # [n, T] + # 在 (T, 3) 维上算 NLL,valid mask 只到 T 维 + nll = gaussian_nll(pred_mu, pred_logsig, gt_traj, valid_mask=valid) + losses.append(nll) + if not losses: + return torch.zeros((), device=device) + return torch.stack(losses).mean() diff --git a/src/wjad/model.py b/src/wjad/model.py index 6167de3ef7618178c9895a56788be20cd7f7c37c..023a4ab092a9aa2208ec5684548921fa6af30a92 100644 --- a/src/wjad/model.py +++ b/src/wjad/model.py @@ -1,289 +1,289 @@ -"""端到端自动驾驶模型 E2EAVModel。 - -forward 流程 - 1. ``DINOv3`` 提取 8 帧 patch 特征。 - 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV, - 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。 - 3. 用 corrected_intr / corrected_extr / corrected_ego 计算 - - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。 - - 8 个 ego token(symlog 后线性投影)。 - 4. 2×2×2 时空压缩 -> 1536 视觉 token。 - 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。 - 非视觉切片各自加可学习 PE。 - 6. 18 层主干(仅视觉切片应用 3D RoPE)。 - 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .backbone import Backbone, BackboneOutput -from .calibration import OnlineCalibration, CalibrationOutput -from .encoders import DINOv3Wrapper -from .heads import ( - ControlHead, - ControlOutput, - DetectionTrajHead, - DetectionTrajOutput, -) -from .modules.learned_pe import LearnedTokenPE -from .modules.normalization import symlog -from .modules.pos_encoding import RoPE3D -from .modules.rays import compute_ego_rays -from .modules.temporal_compress import TemporalCompress2x2x2 - - -@dataclass -class E2EOutput: - """模型完整输出。""" - - detection: DetectionTrajOutput - control: ControlOutput - backbone_out: BackboneOutput - calibration: CalibrationOutput - - -class E2EAVModel(nn.Module): - def __init__( - self, - dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m", - backbone_dim: int = 768, - num_heads: int = 12, - num_dense_layers: int = 9, - num_moe_layers: int = 9, - num_routed_experts: int = 7, - num_shared_experts: int = 1, - topk_experts: int = 3, - ffn_mult: int = 4, - # token 数量 - num_history_frames: int = 8, - num_detection_tokens: int = 1024, - num_control_tokens: int = 24, - num_ego_tokens: int = 8, - num_extra_tokens: int = 256, - # 输入分辨率 - image_h: int = 384, - image_w: int = 1024, - patch_size: int = 16, - # 头超参 - num_classes: int = 22, - traj_horizon: int = 24, - det_head_hidden: int = 384, - ctrl_head_hidden: int = 384, - # 校准 - calib_dim: int = 256, - calib_num_query: int = 256, - calib_num_blocks: int = 2, - calib_num_self_per_block: int = 2, - calib_num_heads: int = 8, - calib_residual_range: float = 0.1, - calib_intr_dim: int = 11, - # DINOv3 - freeze_dinov3: bool = True, - attn_implementation: str = "sdpa", - ) -> None: - super().__init__() - self.image_h = image_h - self.image_w = image_w - self.patch_size = patch_size - self.num_history = num_history_frames - self.num_det = num_detection_tokens - self.num_ctrl = num_control_tokens - self.num_ego = num_ego_tokens - self.num_extra = num_extra_tokens - - # === 1) DINOv3 === - self.dinov3 = DINOv3Wrapper( - pretrained_path=dinov3_path, - attn_implementation=attn_implementation, - freeze=freeze_dinov3, - ) - dino_dim = self.dinov3.hidden_size - - # === 2) 在线校准 === - self.calib = OnlineCalibration( - dino_dim=dino_dim, - hidden_dim=calib_dim, - num_query_tokens=calib_num_query, - num_blocks=calib_num_blocks, - num_self_attn_per_block=calib_num_self_per_block, - num_heads=calib_num_heads, - residual_range=calib_residual_range, - num_history_frames=num_history_frames, - intr_dim=calib_intr_dim, - ) - - # === 3) 时空压缩 === - self.compress = TemporalCompress2x2x2(dim=dino_dim) - # patch 网格大小(必须能被 2 整除) - self.gh = image_h // patch_size - self.gw = image_w // patch_size - - # === 4) 各类 token + 可学习 PE === - self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim - self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim)) - nn.init.trunc_normal_(self.det_tokens, std=0.02) - self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim)) - nn.init.trunc_normal_(self.ctrl_tokens, std=0.02) - self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim)) - nn.init.trunc_normal_(self.extra_tokens, std=0.02) - - self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim) - self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim) - self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim) - self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim) - - # === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)=== - self.rope = RoPE3D( - num_heads=num_heads, - head_dim=backbone_dim // num_heads, - time_size=num_history_frames // 2, - height_size=self.gh // 2, - width_size=self.gw // 2, - ) - - # === 6) 主干 18 层 === - self.backbone = Backbone( - dim=backbone_dim, - num_heads=num_heads, - ffn_mult=ffn_mult, - num_dense_layers=num_dense_layers, - num_moe_layers=num_moe_layers, - num_routed=num_routed_experts, - num_shared=num_shared_experts, - topk=topk_experts, - ) - - # === 7) 头 === - self.det_traj_head = DetectionTrajHead( - in_dim=backbone_dim, - hidden_size=det_head_hidden, - num_classes=num_classes, - traj_horizon=traj_horizon, - ) - self.ctrl_head = ControlHead( - in_dim=backbone_dim, - hidden_size=ctrl_head_hidden, - num_traj_tokens=12, - num_action_tokens=num_control_tokens - 12, - ego_traj_horizon=traj_horizon, - ) - - # ---------- 工具 ---------- - - @property - def num_visual_tokens(self) -> int: - # 2×2×2 压缩后 - return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2) - - def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor: - """``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。""" - return self.ego_proj(symlog(ego_6d_corrected)) - - def _build_visual_rays( - self, - intr_corrected: torch.Tensor, # [B, calib_intr_dim] - extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle - compressed_thw: tuple[int, int, int], - ) -> torch.Tensor: - """计算压缩后视觉 token 网格的射线方向。 - - 在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 + - 2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似, - 所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是 - rigid 的相机系;自车运动差异会通过 ego token 传递)。 - """ - b = intr_corrected.shape[0] - t_, h_, w_ = compressed_thw - rays_grid = compute_ego_rays( - intr_vec=intr_corrected, - cam2vehicle=extr_corrected_se3, - height=self.image_h, - width=self.image_w, - grid_h=h_, - grid_w=w_, - device=intr_corrected.device, - dtype=intr_corrected.dtype, - ) # [B, h_, w_, 3] - # 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3] - rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous() - rays = rays.reshape(b, t_ * h_ * w_, 3) - return rays - - # ---------- 前向 ---------- - - def forward( - self, - images: torch.Tensor, # [B, T=8, 3, H, W] - ego_6d_raw: torch.Tensor, # [B, 8, 6] - intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致 - extr_6d_raw: torch.Tensor, # [B, 6] - ) -> E2EOutput: - b, t, _, h, w = images.shape - assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}" - - # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino] - dino_feats = self.dinov3(images) - - # 2) 校准(symlog 空间残差 + symexp 还原) - calib_out: CalibrationOutput = self.calib( - dino_feats=dino_feats, - ego_raw=ego_6d_raw, - intr_raw=intr_raw, - extr_raw=extr_6d_raw, - ) - corrected_ego = calib_out.corrected_ego - corrected_intr = calib_out.corrected_intr - corrected_extr_6d = calib_out.corrected_extr - - # 3) 把 corrected_extr 6D 转成 4x4 - from .data.se3 import six_d_to_matrix - cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4] - - # 4) 2x2x2 时空压缩 - compressed, thw = self.compress(dino_feats) # [B, N_v, D] - n_v = compressed.shape[1] - - # 5) 视觉射线(用 corrected_intr / corrected_extr) - rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw) - rope_cos, rope_sin = self.rope.compute_freqs(rays) - - # 6) 构造非视觉 token - ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D] - det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1) - ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1) - extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1) - - ego_tok = self.ego_pe(ego_tok) - det_tok = self.det_pe(det_tok) - ctrl_tok = self.ctrl_pe(ctrl_tok) - extra_tok = self.extra_pe(extra_tok) - - # 7) 拼接序列:[vision | ego | det | ctrl | extra] - seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1) - visual_slice = (0, n_v) - - # 8) 主干 - bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) - - # 9) 切片送入头 - offset_det = n_v + self.num_ego - offset_ctrl = offset_det + self.num_det - - det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det] - ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl] - - det_out = self.det_traj_head(det_feats) - ctrl_out = self.ctrl_head(ctrl_feats) - - return E2EOutput( - detection=det_out, - control=ctrl_out, - backbone_out=bb_out, - calibration=calib_out, - ) +"""端到端自动驾驶模型 E2EAVModel。 + +forward 流程 + 1. ``DINOv3`` 提取 8 帧 patch 特征。 + 2. ``OnlineCalibration`` 用原始 ego/intr/extr (symlog) + DINOv3 patch 作 KV, + 输出 symlog 空间残差,叠加并 symexp 还原得到 corrected_*。 + 3. 用 corrected_intr / corrected_extr / corrected_ego 计算 + - 每 token 的自车系单位射线(仅用于视觉 token 的 RoPE 第一组头)。 + - 8 个 ego token(symlog 后线性投影)。 + 4. 2×2×2 时空压缩 -> 1536 视觉 token。 + 5. 拼接 [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)] = 2848 token。 + 非视觉切片各自加可学习 PE。 + 6. 18 层主干(仅视觉切片应用 3D RoPE)。 + 7. 切片送入 ``DetectionTrajHead`` 与 ``ControlHead``。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import Backbone, BackboneOutput +from .calibration import OnlineCalibration, CalibrationOutput +from .encoders import DINOv3Wrapper +from .heads import ( + ControlHead, + ControlOutput, + DetectionTrajHead, + DetectionTrajOutput, +) +from .modules.learned_pe import LearnedTokenPE +from .modules.normalization import symlog +from .modules.pos_encoding import RoPE3D +from .modules.rays import compute_ego_rays +from .modules.temporal_compress import TemporalCompress2x2x2 + + +@dataclass +class E2EOutput: + """模型完整输出。""" + + detection: DetectionTrajOutput + control: ControlOutput + backbone_out: BackboneOutput + calibration: CalibrationOutput + + +class E2EAVModel(nn.Module): + def __init__( + self, + dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m", + backbone_dim: int = 768, + num_heads: int = 12, + num_dense_layers: int = 9, + num_moe_layers: int = 9, + num_routed_experts: int = 7, + num_shared_experts: int = 1, + topk_experts: int = 3, + ffn_mult: int = 4, + # token 数量 + num_history_frames: int = 8, + num_detection_tokens: int = 1024, + num_control_tokens: int = 24, + num_ego_tokens: int = 8, + num_extra_tokens: int = 256, + # 输入分辨率 + image_h: int = 384, + image_w: int = 1024, + patch_size: int = 16, + # 头超参 + num_classes: int = 22, + traj_horizon: int = 24, + det_head_hidden: int = 384, + ctrl_head_hidden: int = 384, + # 校准 + calib_dim: int = 256, + calib_num_query: int = 256, + calib_num_blocks: int = 2, + calib_num_self_per_block: int = 2, + calib_num_heads: int = 8, + calib_residual_range: float = 0.1, + calib_intr_dim: int = 11, + # DINOv3 + freeze_dinov3: bool = True, + attn_implementation: str = "sdpa", + ) -> None: + super().__init__() + self.image_h = image_h + self.image_w = image_w + self.patch_size = patch_size + self.num_history = num_history_frames + self.num_det = num_detection_tokens + self.num_ctrl = num_control_tokens + self.num_ego = num_ego_tokens + self.num_extra = num_extra_tokens + + # === 1) DINOv3 === + self.dinov3 = DINOv3Wrapper( + pretrained_path=dinov3_path, + attn_implementation=attn_implementation, + freeze=freeze_dinov3, + ) + dino_dim = self.dinov3.hidden_size + + # === 2) 在线校准 === + self.calib = OnlineCalibration( + dino_dim=dino_dim, + hidden_dim=calib_dim, + num_query_tokens=calib_num_query, + num_blocks=calib_num_blocks, + num_self_attn_per_block=calib_num_self_per_block, + num_heads=calib_num_heads, + residual_range=calib_residual_range, + num_history_frames=num_history_frames, + intr_dim=calib_intr_dim, + ) + + # === 3) 时空压缩 === + self.compress = TemporalCompress2x2x2(dim=dino_dim) + # patch 网格大小(必须能被 2 整除) + self.gh = image_h // patch_size + self.gw = image_w // patch_size + + # === 4) 各类 token + 可学习 PE === + self.ego_proj = nn.Linear(6, backbone_dim) # 6D pose -> backbone dim + self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim)) + nn.init.trunc_normal_(self.det_tokens, std=0.02) + self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim)) + nn.init.trunc_normal_(self.ctrl_tokens, std=0.02) + self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim)) + nn.init.trunc_normal_(self.extra_tokens, std=0.02) + + self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim) + self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim) + self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim) + self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim) + + # === 5) RoPE 3D(仅视觉,4 时间帧 × 12 × 32 网格)=== + self.rope = RoPE3D( + num_heads=num_heads, + head_dim=backbone_dim // num_heads, + time_size=num_history_frames // 2, + height_size=self.gh // 2, + width_size=self.gw // 2, + ) + + # === 6) 主干 18 层 === + self.backbone = Backbone( + dim=backbone_dim, + num_heads=num_heads, + ffn_mult=ffn_mult, + num_dense_layers=num_dense_layers, + num_moe_layers=num_moe_layers, + num_routed=num_routed_experts, + num_shared=num_shared_experts, + topk=topk_experts, + ) + + # === 7) 头 === + self.det_traj_head = DetectionTrajHead( + in_dim=backbone_dim, + hidden_size=det_head_hidden, + num_classes=num_classes, + traj_horizon=traj_horizon, + ) + self.ctrl_head = ControlHead( + in_dim=backbone_dim, + hidden_size=ctrl_head_hidden, + num_traj_tokens=12, + num_action_tokens=num_control_tokens - 12, + ego_traj_horizon=traj_horizon, + ) + + # ---------- 工具 ---------- + + @property + def num_visual_tokens(self) -> int: + # 2×2×2 压缩后 + return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2) + + def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor: + """``[B, 8, 6]`` -> symlog -> Linear -> ``[B, 8, D]``。""" + return self.ego_proj(symlog(ego_6d_corrected)) + + def _build_visual_rays( + self, + intr_corrected: torch.Tensor, # [B, calib_intr_dim] + extr_corrected_se3: torch.Tensor, # [B, 4, 4] cam2vehicle + compressed_thw: tuple[int, int, int], + ) -> torch.Tensor: + """计算压缩后视觉 token 网格的射线方向。 + + 在 2×2×2 压缩后,每个视觉 token 对应原 patch 网格的一个 2x2 区域 + + 2 个时间帧。这里取所代表区域的中心像素与"中间时间"的射线作近似, + 所有时间帧取同一个 (h, w) 上的射线(因为相机 pose 在 8 帧间是 + rigid 的相机系;自车运动差异会通过 ego token 传递)。 + """ + b = intr_corrected.shape[0] + t_, h_, w_ = compressed_thw + rays_grid = compute_ego_rays( + intr_vec=intr_corrected, + cam2vehicle=extr_corrected_se3, + height=self.image_h, + width=self.image_w, + grid_h=h_, + grid_w=w_, + device=intr_corrected.device, + dtype=intr_corrected.dtype, + ) # [B, h_, w_, 3] + # 复制到时间维:[B, T_, h_, w_, 3] -> flatten 为 [B, N_v, 3] + rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous() + rays = rays.reshape(b, t_ * h_ * w_, 3) + return rays + + # ---------- 前向 ---------- + + def forward( + self, + images: torch.Tensor, # [B, T=8, 3, H, W] + ego_6d_raw: torch.Tensor, # [B, 8, 6] + intr_raw: torch.Tensor, # [B, calib_intr_dim],须与构造时一致 + extr_6d_raw: torch.Tensor, # [B, 6] + ) -> E2EOutput: + b, t, _, h, w = images.shape + assert t == self.num_history, f"history frames mismatch: {t} vs {self.num_history}" + + # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino] + dino_feats = self.dinov3(images) + + # 2) 校准(symlog 空间残差 + symexp 还原) + calib_out: CalibrationOutput = self.calib( + dino_feats=dino_feats, + ego_raw=ego_6d_raw, + intr_raw=intr_raw, + extr_raw=extr_6d_raw, + ) + corrected_ego = calib_out.corrected_ego + corrected_intr = calib_out.corrected_intr + corrected_extr_6d = calib_out.corrected_extr + + # 3) 把 corrected_extr 6D 转成 4x4 + from .data.se3 import six_d_to_matrix + cam2veh_corrected = six_d_to_matrix(corrected_extr_6d) # [B, 4, 4] + + # 4) 2x2x2 时空压缩 + compressed, thw = self.compress(dino_feats) # [B, N_v, D] + n_v = compressed.shape[1] + + # 5) 视觉射线(用 corrected_intr / corrected_extr) + rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw) + rope_cos, rope_sin = self.rope.compute_freqs(rays) + + # 6) 构造非视觉 token + ego_tok = self._build_ego_tokens(corrected_ego) # [B, 8, D] + det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1) + ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1) + extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1) + + ego_tok = self.ego_pe(ego_tok) + det_tok = self.det_pe(det_tok) + ctrl_tok = self.ctrl_pe(ctrl_tok) + extra_tok = self.extra_pe(extra_tok) + + # 7) 拼接序列:[vision | ego | det | ctrl | extra] + seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1) + visual_slice = (0, n_v) + + # 8) 主干 + bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) + + # 9) 切片送入头 + offset_det = n_v + self.num_ego + offset_ctrl = offset_det + self.num_det + + det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det] + ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl] + + det_out = self.det_traj_head(det_feats) + ctrl_out = self.ctrl_head(ctrl_feats) + + return E2EOutput( + detection=det_out, + control=ctrl_out, + backbone_out=bb_out, + calibration=calib_out, + ) diff --git a/src/wjad/modules/__init__.py b/src/wjad/modules/__init__.py index 50dd49caa1d3957c9b2f3769675b59534c17a663..177d4983e9139ebf35807e9f9affdfb2c6f33f9a 100644 --- a/src/wjad/modules/__init__.py +++ b/src/wjad/modules/__init__.py @@ -1,28 +1,28 @@ -"""公用算子模块集合。""" - -from __future__ import annotations - -from .ffn import SwiGLUFFN -from .gate_attention import GateSelfAttention, GateCrossAttention -from .moe import MoEBlock, PerLayerExperts -from .normalization import symlog, symexp -from .pos_encoding import RoPE3D, build_rope_freqs -from .learned_pe import LearnedTokenPE -from .rays import FThetaCamera, compute_ego_rays -from .temporal_compress import TemporalCompress2x2x2 - -__all__ = [ - "SwiGLUFFN", - "GateSelfAttention", - "GateCrossAttention", - "MoEBlock", - "PerLayerExperts", - "symlog", - "symexp", - "RoPE3D", - "build_rope_freqs", - "LearnedTokenPE", - "FThetaCamera", - "compute_ego_rays", - "TemporalCompress2x2x2", -] +"""公用算子模块集合。""" + +from __future__ import annotations + +from .ffn import SwiGLUFFN +from .gate_attention import GateSelfAttention, GateCrossAttention +from .moe import MoEBlock, PerLayerExperts +from .normalization import symlog, symexp +from .pos_encoding import RoPE3D, build_rope_freqs +from .learned_pe import LearnedTokenPE +from .rays import FThetaCamera, compute_ego_rays +from .temporal_compress import TemporalCompress2x2x2 + +__all__ = [ + "SwiGLUFFN", + "GateSelfAttention", + "GateCrossAttention", + "MoEBlock", + "PerLayerExperts", + "symlog", + "symexp", + "RoPE3D", + "build_rope_freqs", + "LearnedTokenPE", + "FThetaCamera", + "compute_ego_rays", + "TemporalCompress2x2x2", +] diff --git a/src/wjad/modules/ffn.py b/src/wjad/modules/ffn.py index 82a02edbc0b47ddde180502dfc865a1c968b2499..fb640348852871bef45f2b2a132e0efbc50c52bc 100644 --- a/src/wjad/modules/ffn.py +++ b/src/wjad/modules/ffn.py @@ -1,30 +1,30 @@ -"""SwiGLU 前馈网络。 - -实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D) -即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。 -""" - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class SwiGLUFFN(nn.Module): - """SwiGLU FFN: D->4D->SwiGLU->2D->D。 - - 使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。 - """ - - def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None: - super().__init__() - hidden = mult * dim - self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b - self.fc2 = nn.Linear(hidden, dim, bias=bias) - self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - ab = self.fc1(x) - a, b = ab.chunk(2, dim=-1) - return self.drop(self.fc2(F.silu(a) * b)) +"""SwiGLU 前馈网络。 + +实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D) +即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。 +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + """SwiGLU FFN: D->4D->SwiGLU->2D->D。 + + 使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。 + """ + + def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None: + super().__init__() + hidden = mult * dim + self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b + self.fc2 = nn.Linear(hidden, dim, bias=bias) + self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ab = self.fc1(x) + a, b = ab.chunk(2, dim=-1) + return self.drop(self.fc2(F.silu(a) * b)) diff --git a/src/wjad/modules/gate_attention.py b/src/wjad/modules/gate_attention.py index 11e87e10f173aea9989ad9bf929984d8e5bf3295..8e55f40fbd210a59d8bcbab6fa2958d453559d1b 100644 --- a/src/wjad/modules/gate_attention.py +++ b/src/wjad/modules/gate_attention.py @@ -1,181 +1,181 @@ -"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。 - -与 Design.md 一致: - - Q 经 Linear + Sigmoid 生成 D 维门控参数; - - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj; - - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。 -""" - -from __future__ import annotations - -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .pos_encoding import apply_rope - - -class _MultiHeadProj(nn.Module): - """通用的多头 Q/K/V 投影 + reshape。""" - - def __init__( - self, - dim_q: int, - dim_kv: int, - num_heads: int, - head_dim: int, - q_bias: bool = True, - kv_bias: bool = True, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - inner = num_heads * head_dim - self.q_proj = nn.Linear(dim_q, inner, bias=q_bias) - self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias) - self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias) - - def project_q(self, x: torch.Tensor) -> torch.Tensor: - b, n, _ = x.shape - return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) - - def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - b, n, _ = x.shape - k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) - return k, v - - -class _GateModule(nn.Module): - """门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。 - - bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。 - 这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。 - """ - - def __init__(self, dim: int, init_bias: float = 5.0) -> None: - super().__init__() - self.proj = nn.Linear(dim, dim) - nn.init.zeros_(self.proj.weight) - nn.init.constant_(self.proj.bias, init_bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.sigmoid(self.proj(x)) - - -class GateSelfAttention(nn.Module): - """门控自注意力,使用 PyTorch SDPA。 - - 支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。 - """ - - def __init__( - self, - dim: int, - num_heads: int, - dropout: float = 0.0, - gate_init_bias: float = 5.0, - q_bias: bool = True, - kv_bias: bool = True, - ) -> None: - super().__init__() - assert dim % num_heads == 0, "dim 必须能被 num_heads 整除" - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = 1.0 / math.sqrt(self.head_dim) - self.dropout_p = dropout - - self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias) - self.gate = _GateModule(dim, init_bias=gate_init_bias) - self.out_proj = nn.Linear(dim, dim, bias=True) - - def forward( - self, - x: torch.Tensor, - rope_cos: Optional[torch.Tensor] = None, - rope_sin: Optional[torch.Tensor] = None, - visual_slice: Optional[tuple[int, int]] = None, - ) -> torch.Tensor: - """ - 参数 - ---- - x : [B, N, D] - rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None - visual_slice : (start, end),指定视觉 token 在序列中的范围。 - 非视觉 token 切片 Q/K 不做 RoPE。 - """ - b, n, _ = x.shape - q = self.proj.project_q(x) # [B, H, N, Dh] - k, v = self.proj.project_kv(x) - - # 仅对视觉切片应用 RoPE - if rope_cos is not None and visual_slice is not None: - s, e = visual_slice - q_v = q[:, :, s:e, :] - k_v = k[:, :, s:e, :] - q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin) - q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2) - k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2) - - # SDPA:[B, H, N, Dh] - attn = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=self.dropout_p if self.training else 0.0, - is_causal=False, - ) - # 多头合并 - attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim) - - # 门控 ⊗ 多头合并后的 V,再 out_proj - gate = self.gate(x) # 用 Q 的源(即 x)生成门控 - out = self.out_proj(attn * gate) - return out - - -class GateCrossAttention(nn.Module): - """门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。""" - - def __init__( - self, - dim_q: int, - dim_kv: int, - num_heads: int, - dropout: float = 0.0, - gate_init_bias: float = 5.0, - q_bias: bool = True, - kv_bias: bool = True, - ) -> None: - super().__init__() - assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除" - self.dim_q = dim_q - self.num_heads = num_heads - self.head_dim = dim_q // num_heads - self.dropout_p = dropout - - self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias) - self.gate = _GateModule(dim_q, init_bias=gate_init_bias) - self.out_proj = nn.Linear(dim_q, dim_q, bias=True) - - def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor: - b, n, _ = q_in.shape - q = self.proj.project_q(q_in) - k, v = self.proj.project_kv(kv_in) - - attn = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=self.dropout_p if self.training else 0.0, - is_causal=False, - ) - attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q) - gate = self.gate(q_in) - return self.out_proj(attn * gate) +"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。 + +与 Design.md 一致: + - Q 经 Linear + Sigmoid 生成 D 维门控参数; + - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj; + - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。 +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .pos_encoding import apply_rope + + +class _MultiHeadProj(nn.Module): + """通用的多头 Q/K/V 投影 + reshape。""" + + def __init__( + self, + dim_q: int, + dim_kv: int, + num_heads: int, + head_dim: int, + q_bias: bool = True, + kv_bias: bool = True, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + inner = num_heads * head_dim + self.q_proj = nn.Linear(dim_q, inner, bias=q_bias) + self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias) + self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias) + + def project_q(self, x: torch.Tensor) -> torch.Tensor: + b, n, _ = x.shape + return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) + + def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + b, n, _ = x.shape + k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) + return k, v + + +class _GateModule(nn.Module): + """门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。 + + bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。 + 这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。 + """ + + def __init__(self, dim: int, init_bias: float = 5.0) -> None: + super().__init__() + self.proj = nn.Linear(dim, dim) + nn.init.zeros_(self.proj.weight) + nn.init.constant_(self.proj.bias, init_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sigmoid(self.proj(x)) + + +class GateSelfAttention(nn.Module): + """门控自注意力,使用 PyTorch SDPA。 + + 支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。 + """ + + def __init__( + self, + dim: int, + num_heads: int, + dropout: float = 0.0, + gate_init_bias: float = 5.0, + q_bias: bool = True, + kv_bias: bool = True, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim 必须能被 num_heads 整除" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.dropout_p = dropout + + self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias) + self.gate = _GateModule(dim, init_bias=gate_init_bias) + self.out_proj = nn.Linear(dim, dim, bias=True) + + def forward( + self, + x: torch.Tensor, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + visual_slice: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + """ + 参数 + ---- + x : [B, N, D] + rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None + visual_slice : (start, end),指定视觉 token 在序列中的范围。 + 非视觉 token 切片 Q/K 不做 RoPE。 + """ + b, n, _ = x.shape + q = self.proj.project_q(x) # [B, H, N, Dh] + k, v = self.proj.project_kv(x) + + # 仅对视觉切片应用 RoPE + if rope_cos is not None and visual_slice is not None: + s, e = visual_slice + q_v = q[:, :, s:e, :] + k_v = k[:, :, s:e, :] + q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin) + q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2) + k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2) + + # SDPA:[B, H, N, Dh] + attn = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + # 多头合并 + attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim) + + # 门控 ⊗ 多头合并后的 V,再 out_proj + gate = self.gate(x) # 用 Q 的源(即 x)生成门控 + out = self.out_proj(attn * gate) + return out + + +class GateCrossAttention(nn.Module): + """门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。""" + + def __init__( + self, + dim_q: int, + dim_kv: int, + num_heads: int, + dropout: float = 0.0, + gate_init_bias: float = 5.0, + q_bias: bool = True, + kv_bias: bool = True, + ) -> None: + super().__init__() + assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除" + self.dim_q = dim_q + self.num_heads = num_heads + self.head_dim = dim_q // num_heads + self.dropout_p = dropout + + self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias) + self.gate = _GateModule(dim_q, init_bias=gate_init_bias) + self.out_proj = nn.Linear(dim_q, dim_q, bias=True) + + def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor: + b, n, _ = q_in.shape + q = self.proj.project_q(q_in) + k, v = self.proj.project_kv(kv_in) + + attn = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q) + gate = self.gate(q_in) + return self.out_proj(attn * gate) diff --git a/src/wjad/modules/learned_pe.py b/src/wjad/modules/learned_pe.py index 1670e761f5a1736b5f249e51249d7b7e4500f947..dc7c74e6e7dc5643eb636c912b1fdf2a606fd524 100644 --- a/src/wjad/modules/learned_pe.py +++ b/src/wjad/modules/learned_pe.py @@ -1,24 +1,24 @@ -"""非视觉 token 的可学习位置编码。 - -ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的 -``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。 -直接加到对应 token 上,不参与 RoPE。 -""" - -from __future__ import annotations - -import torch -import torch.nn as nn - - -class LearnedTokenPE(nn.Module): - """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。""" - - def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None: - super().__init__() - self.pe = nn.Parameter(torch.empty(num_tokens, dim)) - nn.init.trunc_normal_(self.pe, std=init_std) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x: [B, N, D] - return x + self.pe.unsqueeze(0) +"""非视觉 token 的可学习位置编码。 + +ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的 +``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。 +直接加到对应 token 上,不参与 RoPE。 +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + + +class LearnedTokenPE(nn.Module): + """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。""" + + def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None: + super().__init__() + self.pe = nn.Parameter(torch.empty(num_tokens, dim)) + nn.init.trunc_normal_(self.pe, std=init_std) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B, N, D] + return x + self.pe.unsqueeze(0) diff --git a/src/wjad/modules/moe.py b/src/wjad/modules/moe.py index 70d52522de53996d8e03e0393a8d3a7349ce2770..f8accd7cc43fbddaea89d1b4b747552c50b80ee0 100644 --- a/src/wjad/modules/moe.py +++ b/src/wjad/modules/moe.py @@ -1,129 +1,129 @@ -"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。 - -设计要点(与 Design.md 对齐): - - 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。 - - 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。 - - 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。 - - 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。 - - 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由 - ``losses/moe_aux.py`` 聚合成正则损失。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from .ffn import SwiGLUFFN - - -@dataclass -class MoEStats: - """单层 MoE 输出的辅助统计,用于损失与监控。""" - - logits: torch.Tensor # [B, num_routed] - probs: torch.Tensor # [B, num_routed],sigmoid 后的概率 - topk_mask: torch.Tensor # [B, num_routed],0/1 - - -class PerLayerExperts(nn.Module): - """单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。""" - - def __init__( - self, - dim: int, - num_routed: int = 7, - num_shared: int = 1, - ffn_mult: int = 4, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.num_routed = num_routed - self.num_shared = num_shared - self.shared = nn.ModuleList( - [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)] - ) - self.routed = nn.ModuleList( - [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)] - ) - - -class MoEBlock(nn.Module): - """带路由的 MoE FFN 块(每层独立专家库)。""" - - def __init__( - self, - dim: int, - num_routed: int = 7, - num_shared: int = 1, - ffn_mult: int = 4, - topk: int = 3, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.dim = dim - self.num_routed = num_routed - self.num_shared = num_shared - self.topk = topk - self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout) - self.router = nn.Linear(dim, num_routed, bias=True) - # 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5 - nn.init.normal_(self.router.weight, std=0.02) - nn.init.zeros_(self.router.bias) - - # 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk - self._mode: str = "dense" - # 路由温度(温度 < 1 => 锐化) - self.register_buffer("router_temperature", torch.tensor(1.0)) - - def set_mode(self, mode: str) -> None: - assert mode in ("dense", "sparse"), f"未知模式: {mode}" - self._mode = mode - - @property - def mode(self) -> str: - return self._mode - - def set_temperature(self, t: float) -> None: - self.router_temperature.fill_(float(t)) - - def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """计算 logits / probs / topk_mask。x: [B, N, D]。""" - pooled = x.mean(dim=1) # [B, D] - logits = self.router(pooled) # [B, num_routed] - # 温度锐化(温度小 => 概率更尖) - scaled = logits / self.router_temperature.clamp_min(1e-3) - probs = torch.sigmoid(scaled) - - if self._mode == "dense" or self.topk >= self.num_routed: - mask = torch.ones_like(probs) - else: - topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1) - mask = torch.zeros_like(probs) - mask.scatter_(-1, topk_idx, 1.0) - - return logits, probs, mask - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]: - b, n, d = x.shape - logits, probs, mask = self._route(x) - - # 共享专家(恒激活,无门控) - out = torch.zeros_like(x) - for sh in self.experts.shared: - out = out + sh(x) - - # 路由专家:按 (probs * mask) 在 batch 维加权 - weights = probs * mask # [B, num_routed] - # 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级 - # 路由的索引开销;与 Design.md "序列级分配" 一致。 - for i, expert in enumerate(self.experts.routed): - w_i = weights[:, i].view(b, 1, 1) # [B,1,1] - # 仅当批内任一样本权重 > 0 时才前向以减少计算 - if torch.any(w_i > 0): - out = out + w_i * expert(x) - - stats = MoEStats(logits=logits, probs=probs, topk_mask=mask) - return out, stats +"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。 + +设计要点(与 Design.md 对齐): + - 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。 + - 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。 + - 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。 + - 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。 + - 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由 + ``losses/moe_aux.py`` 聚合成正则损失。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from .ffn import SwiGLUFFN + + +@dataclass +class MoEStats: + """单层 MoE 输出的辅助统计,用于损失与监控。""" + + logits: torch.Tensor # [B, num_routed] + probs: torch.Tensor # [B, num_routed],sigmoid 后的概率 + topk_mask: torch.Tensor # [B, num_routed],0/1 + + +class PerLayerExperts(nn.Module): + """单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。""" + + def __init__( + self, + dim: int, + num_routed: int = 7, + num_shared: int = 1, + ffn_mult: int = 4, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.num_routed = num_routed + self.num_shared = num_shared + self.shared = nn.ModuleList( + [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)] + ) + self.routed = nn.ModuleList( + [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)] + ) + + +class MoEBlock(nn.Module): + """带路由的 MoE FFN 块(每层独立专家库)。""" + + def __init__( + self, + dim: int, + num_routed: int = 7, + num_shared: int = 1, + ffn_mult: int = 4, + topk: int = 3, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_routed = num_routed + self.num_shared = num_shared + self.topk = topk + self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout) + self.router = nn.Linear(dim, num_routed, bias=True) + # 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5 + nn.init.normal_(self.router.weight, std=0.02) + nn.init.zeros_(self.router.bias) + + # 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk + self._mode: str = "dense" + # 路由温度(温度 < 1 => 锐化) + self.register_buffer("router_temperature", torch.tensor(1.0)) + + def set_mode(self, mode: str) -> None: + assert mode in ("dense", "sparse"), f"未知模式: {mode}" + self._mode = mode + + @property + def mode(self) -> str: + return self._mode + + def set_temperature(self, t: float) -> None: + self.router_temperature.fill_(float(t)) + + def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """计算 logits / probs / topk_mask。x: [B, N, D]。""" + pooled = x.mean(dim=1) # [B, D] + logits = self.router(pooled) # [B, num_routed] + # 温度锐化(温度小 => 概率更尖) + scaled = logits / self.router_temperature.clamp_min(1e-3) + probs = torch.sigmoid(scaled) + + if self._mode == "dense" or self.topk >= self.num_routed: + mask = torch.ones_like(probs) + else: + topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1) + mask = torch.zeros_like(probs) + mask.scatter_(-1, topk_idx, 1.0) + + return logits, probs, mask + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]: + b, n, d = x.shape + logits, probs, mask = self._route(x) + + # 共享专家(恒激活,无门控) + out = torch.zeros_like(x) + for sh in self.experts.shared: + out = out + sh(x) + + # 路由专家:按 (probs * mask) 在 batch 维加权 + weights = probs * mask # [B, num_routed] + # 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级 + # 路由的索引开销;与 Design.md "序列级分配" 一致。 + for i, expert in enumerate(self.experts.routed): + w_i = weights[:, i].view(b, 1, 1) # [B,1,1] + # 仅当批内任一样本权重 > 0 时才前向以减少计算 + if torch.any(w_i > 0): + out = out + w_i * expert(x) + + stats = MoEStats(logits=logits, probs=probs, topk_mask=mask) + return out, stats diff --git a/src/wjad/modules/normalization.py b/src/wjad/modules/normalization.py index 772c75f91e4eb04e2fa280b4103cb8dff8db154c..a4e07dafba07dea3c409d0c5dc28834776e9ffec 100644 --- a/src/wjad/modules/normalization.py +++ b/src/wjad/modules/normalization.py @@ -1,22 +1,22 @@ -"""对称对数归一化算子 symlog / symexp。 - -公式: - symlog(x) = sign(x) * log(|x| + 1) - symexp(y) = sign(y) * (exp(|y|) - 1) - -用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。 -""" - -from __future__ import annotations - -import torch - - -def symlog(x: torch.Tensor) -> torch.Tensor: - """对称对数压缩:sign(x) * log(|x| + 1)。""" - return torch.sign(x) * torch.log1p(torch.abs(x)) - - -def symexp(y: torch.Tensor) -> torch.Tensor: - """symlog 的逆:sign(y) * (exp(|y|) - 1)。""" - return torch.sign(y) * torch.expm1(torch.abs(y)) +"""对称对数归一化算子 symlog / symexp。 + +公式: + symlog(x) = sign(x) * log(|x| + 1) + symexp(y) = sign(y) * (exp(|y|) - 1) + +用于运动学 / 坐标 / 内外参的归一化,使大幅值被压缩、保持可逆。 +""" + +from __future__ import annotations + +import torch + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """对称对数压缩:sign(x) * log(|x| + 1)。""" + return torch.sign(x) * torch.log1p(torch.abs(x)) + + +def symexp(y: torch.Tensor) -> torch.Tensor: + """symlog 的逆:sign(y) * (exp(|y|) - 1)。""" + return torch.sign(y) * torch.expm1(torch.abs(y)) diff --git a/src/wjad/modules/pos_encoding.py b/src/wjad/modules/pos_encoding.py index a70914d15ff1395317c5cea3f2a58433c3a19dca..f4d5c9e330cdff6069250186f4e91be7d45fdfca 100644 --- a/src/wjad/modules/pos_encoding.py +++ b/src/wjad/modules/pos_encoding.py @@ -1,224 +1,224 @@ -"""3D RoPE(仅作用于视觉 token)。 - -12 头按 4+4+4 拆为三组: - - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。 - - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。 - - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。 - -为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else), -零频段头自然变为恒等映射。 - -将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按 -3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。 -""" - -from __future__ import annotations - -import torch -import torch.nn as nn - - -def _split_head_dim_for_components(half: int, num_components: int) -> list[int]: - """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。 - - 返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。 - 若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后 - 一项视为"零频段"——它的频率会被显式置为 0。 - """ - if num_components == 0: - return [0, half] - base = half // num_components - splits = [base] * num_components - splits[-1] += half - base * num_components # 余数全归到最后一个分量 - return splits - - -def build_rope_freqs( - rays: torch.Tensor, - hwt_grid: torch.Tensor, - num_heads: int = 12, - head_dim: int = 64, - rope_theta: float = 10000.0, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, -) -> tuple[torch.Tensor, torch.Tensor]: - """构造 3D RoPE 的 cos / sin 表。 - - 参数 - ---- - rays : Tensor, shape ``[B, N_v, 3]`` - 每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。 - hwt_grid : Tensor, shape ``[B, N_v, 3]`` - 归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。 - num_heads : int - 总头数(默认 12)。 - head_dim : int - 每头维度(默认 64,必须为偶数)。 - - 返回 - ---- - cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]`` - 每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。 - """ - assert head_dim % 2 == 0, "head_dim 必须为偶数" - assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组" - - half = head_dim // 2 - heads_per_group = num_heads // 3 - bsz, n_v, _ = rays.shape - if device is None: - device = rays.device - - # === 三组分量值 === - # group 0: rays (3 components) - # group 1: hwt (3 components) - # group 2: zero (0 components -> 全部 half 视为零频段) - splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays - splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt - splits_g2 = _split_head_dim_for_components(half, 0) # [0, half] - - # === 频率向量(沿 head_dim 半轴)=== - # 经典 RoPE: theta_i = base ^ (-2i / d) - # 这里我们对每个分量独立排布频率 - def _freqs(num_pairs: int) -> torch.Tensor: - # 前 num_pairs 个用 RoPE 频率,剩余补 0 - idx = torch.arange(num_pairs, device=device, dtype=dtype) - freqs = rope_theta ** (-2.0 * idx / head_dim) - return freqs # [num_pairs] - - # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half] - angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype) - - # ---- 第 0 组(4 头):射线 ---- - base_offset = 0 - h0_start = 0 - h0_end = h0_start + heads_per_group - cursor = 0 - for c in range(3): # dx, dy, dz - n_pairs = splits_g0[c] - if n_pairs > 0: - f = _freqs(n_pairs) # [n_pairs] - comp_val = rays[..., c : c + 1] # [B, N_v, 1] - ang = comp_val * f # 广播 -> [B, N_v, n_pairs] - angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) - cursor += n_pairs - # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0 - - # ---- 第 1 组(4 头):HWT ---- - h1_start = heads_per_group - h1_end = h1_start + heads_per_group - cursor = 0 - for c in range(3): # h, w, t - n_pairs = splits_g1[c] - if n_pairs > 0: - f = _freqs(n_pairs) - comp_val = hwt_grid[..., c : c + 1] - ang = comp_val * f - angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) - cursor += n_pairs - - # ---- 第 2 组(4 头):零频段 ---- - # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零) - - cos = torch.cos(angles) - sin = torch.sin(angles) - return cos, sin - - -def apply_rope( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。 - - 所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。 - - 参数 - ---- - q, k : Tensor, shape ``[B, H, N_v, head_dim]`` - cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]`` - - 返回 - ---- - 旋转后的 q, k,形状不变。 - """ - # 把 cos/sin 转成 [B, H, N_v, half] - cos_e = cos.permute(0, 2, 1, 3) - sin_e = sin.permute(0, 2, 1, 3) - - # 把 head_dim 维度按 (even, odd) 拆开成 [..., half] - q_even = q[..., 0::2] - q_odd = q[..., 1::2] - k_even = k[..., 0::2] - k_odd = k[..., 1::2] - - q_rot_even = q_even * cos_e - q_odd * sin_e - q_rot_odd = q_even * sin_e + q_odd * cos_e - k_rot_even = k_even * cos_e - k_odd * sin_e - k_rot_odd = k_even * sin_e + k_odd * cos_e - - q_out = torch.empty_like(q) - k_out = torch.empty_like(k) - q_out[..., 0::2] = q_rot_even - q_out[..., 1::2] = q_rot_odd - k_out[..., 0::2] = k_rot_even - k_out[..., 1::2] = k_rot_odd - return q_out, k_out - - -class RoPE3D(nn.Module): - """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。 - - 使用方式: - rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32) - cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3] - q, k = apply_rope(q_visual_only, k_visual_only, cos, sin) - """ - - def __init__( - self, - num_heads: int = 12, - head_dim: int = 64, - time_size: int = 4, - height_size: int = 12, - width_size: int = 32, - rope_theta: float = 10000.0, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.rope_theta = rope_theta - self.T = time_size - self.H = height_size - self.W = width_size - - # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W - t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1) - h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1) - w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1) - # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致) - T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij") - hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3] - self.register_buffer("hwt_grid", hwt, persistent=False) - - @property - def num_visual_tokens(self) -> int: - return self.T * self.H * self.W - - def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """根据每 token 的射线方向计算 cos/sin。 - - ``rays`` shape: ``[B, N_v, 3]``。 - """ - bsz = rays.shape[0] - hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3] - return build_rope_freqs( - rays=rays, - hwt_grid=hwt, - num_heads=self.num_heads, - head_dim=self.head_dim, - rope_theta=self.rope_theta, - dtype=rays.dtype, - ) +"""3D RoPE(仅作用于视觉 token)。 + +12 头按 4+4+4 拆为三组: + - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。 + - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。 + - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。 + +为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else), +零频段头自然变为恒等映射。 + +将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按 +3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。 +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + + +def _split_head_dim_for_components(half: int, num_components: int) -> list[int]: + """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。 + + 返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。 + 若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后 + 一项视为"零频段"——它的频率会被显式置为 0。 + """ + if num_components == 0: + return [0, half] + base = half // num_components + splits = [base] * num_components + splits[-1] += half - base * num_components # 余数全归到最后一个分量 + return splits + + +def build_rope_freqs( + rays: torch.Tensor, + hwt_grid: torch.Tensor, + num_heads: int = 12, + head_dim: int = 64, + rope_theta: float = 10000.0, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """构造 3D RoPE 的 cos / sin 表。 + + 参数 + ---- + rays : Tensor, shape ``[B, N_v, 3]`` + 每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。 + hwt_grid : Tensor, shape ``[B, N_v, 3]`` + 归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。 + num_heads : int + 总头数(默认 12)。 + head_dim : int + 每头维度(默认 64,必须为偶数)。 + + 返回 + ---- + cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]`` + 每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。 + """ + assert head_dim % 2 == 0, "head_dim 必须为偶数" + assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组" + + half = head_dim // 2 + heads_per_group = num_heads // 3 + bsz, n_v, _ = rays.shape + if device is None: + device = rays.device + + # === 三组分量值 === + # group 0: rays (3 components) + # group 1: hwt (3 components) + # group 2: zero (0 components -> 全部 half 视为零频段) + splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays + splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt + splits_g2 = _split_head_dim_for_components(half, 0) # [0, half] + + # === 频率向量(沿 head_dim 半轴)=== + # 经典 RoPE: theta_i = base ^ (-2i / d) + # 这里我们对每个分量独立排布频率 + def _freqs(num_pairs: int) -> torch.Tensor: + # 前 num_pairs 个用 RoPE 频率,剩余补 0 + idx = torch.arange(num_pairs, device=device, dtype=dtype) + freqs = rope_theta ** (-2.0 * idx / head_dim) + return freqs # [num_pairs] + + # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half] + angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype) + + # ---- 第 0 组(4 头):射线 ---- + base_offset = 0 + h0_start = 0 + h0_end = h0_start + heads_per_group + cursor = 0 + for c in range(3): # dx, dy, dz + n_pairs = splits_g0[c] + if n_pairs > 0: + f = _freqs(n_pairs) # [n_pairs] + comp_val = rays[..., c : c + 1] # [B, N_v, 1] + ang = comp_val * f # 广播 -> [B, N_v, n_pairs] + angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) + cursor += n_pairs + # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0 + + # ---- 第 1 组(4 头):HWT ---- + h1_start = heads_per_group + h1_end = h1_start + heads_per_group + cursor = 0 + for c in range(3): # h, w, t + n_pairs = splits_g1[c] + if n_pairs > 0: + f = _freqs(n_pairs) + comp_val = hwt_grid[..., c : c + 1] + ang = comp_val * f + angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) + cursor += n_pairs + + # ---- 第 2 组(4 头):零频段 ---- + # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零) + + cos = torch.cos(angles) + sin = torch.sin(angles) + return cos, sin + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。 + + 所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。 + + 参数 + ---- + q, k : Tensor, shape ``[B, H, N_v, head_dim]`` + cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]`` + + 返回 + ---- + 旋转后的 q, k,形状不变。 + """ + # 把 cos/sin 转成 [B, H, N_v, half] + cos_e = cos.permute(0, 2, 1, 3) + sin_e = sin.permute(0, 2, 1, 3) + + # 把 head_dim 维度按 (even, odd) 拆开成 [..., half] + q_even = q[..., 0::2] + q_odd = q[..., 1::2] + k_even = k[..., 0::2] + k_odd = k[..., 1::2] + + q_rot_even = q_even * cos_e - q_odd * sin_e + q_rot_odd = q_even * sin_e + q_odd * cos_e + k_rot_even = k_even * cos_e - k_odd * sin_e + k_rot_odd = k_even * sin_e + k_odd * cos_e + + q_out = torch.empty_like(q) + k_out = torch.empty_like(k) + q_out[..., 0::2] = q_rot_even + q_out[..., 1::2] = q_rot_odd + k_out[..., 0::2] = k_rot_even + k_out[..., 1::2] = k_rot_odd + return q_out, k_out + + +class RoPE3D(nn.Module): + """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。 + + 使用方式: + rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32) + cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3] + q, k = apply_rope(q_visual_only, k_visual_only, cos, sin) + """ + + def __init__( + self, + num_heads: int = 12, + head_dim: int = 64, + time_size: int = 4, + height_size: int = 12, + width_size: int = 32, + rope_theta: float = 10000.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.rope_theta = rope_theta + self.T = time_size + self.H = height_size + self.W = width_size + + # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W + t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1) + h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1) + w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1) + # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致) + T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij") + hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3] + self.register_buffer("hwt_grid", hwt, persistent=False) + + @property + def num_visual_tokens(self) -> int: + return self.T * self.H * self.W + + def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """根据每 token 的射线方向计算 cos/sin。 + + ``rays`` shape: ``[B, N_v, 3]``。 + """ + bsz = rays.shape[0] + hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3] + return build_rope_freqs( + rays=rays, + hwt_grid=hwt, + num_heads=self.num_heads, + head_dim=self.head_dim, + rope_theta=self.rope_theta, + dtype=rays.dtype, + ) diff --git a/src/wjad/modules/rays.py b/src/wjad/modules/rays.py index 97da57826445df06b732f3f68b7946adc0bd3da5..887900bd6e3c6bb1e658e867d72232d0d8d351c4 100644 --- a/src/wjad/modules/rays.py +++ b/src/wjad/modules/rays.py @@ -1,182 +1,182 @@ -"""f-theta 相机模型 + 射线计算。 - -依据 Cosmos-Drive-Dreams 数据集 README: - ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。 - -f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到 -入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反 -求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial -(theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微 -的非旋转对称形变。 - -为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``, -即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式, -这里使用牛顿迭代反求。 -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -import torch.nn.functional as F - - -@dataclass -class FThetaIntrinsic: - """f-theta 内参(PyTorch 张量形式)。 - - 所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。 - """ - - cx: torch.Tensor # () - cy: torch.Tensor # () - w: int - h: int - poly: torch.Tensor # (6,) - is_bw_poly: bool - linear_cde: torch.Tensor # (3,) - - -class FThetaCamera: - """f-theta 相机:像素 -> 单位射线方向(相机坐标系)。""" - - def __init__(self, intr: FThetaIntrinsic) -> None: - self.intr = intr - - @staticmethod - def from_vector(vec: torch.Tensor) -> "FThetaCamera": - """从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。 - - 官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``, - 与 ``unproject`` 里近似一致。 - """ - v = vec.flatten().float() - n = int(v.numel()) - if n < 10: - raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)") - cx = v[0] - cy = v[1] - w = int(v[2].item()) - h = int(v[3].item()) - poly = v[4:10].clone() - if n >= 11: - is_bw = bool(v[10].item() > 0.5) - else: - is_bw = True - if n >= 14: - linear_cde = v[11:14].clone() - else: - linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device) - return FThetaCamera( - FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde) - ) - - def _eval_poly(self, r: torch.Tensor) -> torch.Tensor: - """用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。""" - c = self.intr.poly - out = torch.zeros_like(r) - for i in range(c.numel() - 1, -1, -1): - out = out * r + c[i] - return out - - def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor: - """poly 的导数。""" - c = self.intr.poly - n = c.numel() - out = torch.zeros_like(r) - for i in range(n - 1, 0, -1): - out = out * r + c[i] * float(i) - return out - - def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor: - """像素半径 -> 入射角 theta(弧度)。""" - if self.intr.is_bw_poly: - return self._eval_poly(r_pix) - # forward: r_pix = poly(theta) -> 牛顿迭代 - theta = r_pix.clone() - for _ in range(8): - f = self._eval_poly(theta) - r_pix - df = self._eval_poly_grad(theta).clamp_min(1e-6) - theta = theta - f / df - return theta - - def unproject(self, uv: torch.Tensor) -> torch.Tensor: - """像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。 - - f-theta 反投影: - (du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射) - r_pix = ||(du, dv)|| - theta = poly(r_pix) 或 inv_poly(r_pix) - phi = atan2(dv, du) - dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta)) - """ - cx = self.intr.cx - cy = self.intr.cy - c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2] - - u = uv[..., 0] - v = uv[..., 1] - # 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似) - du0 = u - cx - dv0 = v - cy - du = c * du0 + d * dv0 - dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调 - r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6) - theta = self.pixel_to_theta(r_pix) - - sin_t = torch.sin(theta) - cos_t = torch.cos(theta) - cos_p = du / r_pix - sin_p = dv / r_pix - x = sin_t * cos_p - y = sin_t * sin_p - z = cos_t - dir_cam = torch.stack([x, y, z], dim=-1) - return F.normalize(dir_cam, dim=-1) - - -def compute_ego_rays( - intr_vec: torch.Tensor, - cam2vehicle: torch.Tensor, - height: int, - width: int, - grid_h: int, - grid_w: int, - device: torch.device, - dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。 - - 参数 - ---- - intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。 - cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。 - height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。 - grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。 - - 返回 - ---- - rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。 - """ - if intr_vec.dim() == 1: - intr_vec = intr_vec.unsqueeze(0) - if cam2vehicle.dim() == 2: - cam2vehicle = cam2vehicle.unsqueeze(0) - B = intr_vec.shape[0] - - # 在像素中心采样 - u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w) - v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h) - vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw] - uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2] - - out = [] - for b in range(B): - cam = FThetaCamera.from_vector(intr_vec[b].to(dtype)) - dir_cam = cam.unproject(uv) # [gh, gw, 3] - # 旋到自车系:取 cam2vehicle 的 3x3 旋转部分 - R = cam2vehicle[b, :3, :3].to(dtype) - dir_veh = dir_cam @ R.T # [gh, gw, 3] - out.append(F.normalize(dir_veh, dim=-1)) - return torch.stack(out, dim=0) +"""f-theta 相机模型 + 射线计算。 + +依据 Cosmos-Drive-Dreams 数据集 README: + ftheta_intrinsic 存储为 ``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``。 + +f-theta 相机模型用 6 阶多项式将像素半径 ``r_pix = ||(u-cx, v-cy)||`` 映射到 +入射角 ``theta``(或反向)。``is_bw_poly == True`` 表示多项式是从 ``r_pix`` 反 +求 ``theta`` 的 backward polynomial(pixel -> theta);否则是 forward polynomial +(theta -> r_pix)。``linear_cde`` 是仿射修正系数 ``[c, d, e]``,用于补偿轻微 +的非旋转对称形变。 + +为了简单与可微,本模块默认假设 backward polynomial(``is_bw_poly=True``, +即 ``theta = poly(r_pix)``);实际数据通常是这种格式。如需 forward 多项式, +这里使用牛顿迭代反求。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + + +@dataclass +class FThetaIntrinsic: + """f-theta 内参(PyTorch 张量形式)。 + + 所有字段均为标量或一维向量;外部使用时通常 broadcast 到 batch。 + """ + + cx: torch.Tensor # () + cy: torch.Tensor # () + w: int + h: int + poly: torch.Tensor # (6,) + is_bw_poly: bool + linear_cde: torch.Tensor # (3,) + + +class FThetaCamera: + """f-theta 相机:像素 -> 单位射线方向(相机坐标系)。""" + + def __init__(self, intr: FThetaIntrinsic) -> None: + self.intr = intr + + @staticmethod + def from_vector(vec: torch.Tensor) -> "FThetaCamera": + """从 NVIDIA ftheta 向量构造:``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``。 + + 官方常见 14 维;部分 clip 仅 11 维(无 ``linear_cde``),此时用 ``(1,0,1)``, + 与 ``unproject`` 里近似一致。 + """ + v = vec.flatten().float() + n = int(v.numel()) + if n < 10: + raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)") + cx = v[0] + cy = v[1] + w = int(v[2].item()) + h = int(v[3].item()) + poly = v[4:10].clone() + if n >= 11: + is_bw = bool(v[10].item() > 0.5) + else: + is_bw = True + if n >= 14: + linear_cde = v[11:14].clone() + else: + linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device) + return FThetaCamera( + FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde) + ) + + def _eval_poly(self, r: torch.Tensor) -> torch.Tensor: + """用 Horner 法计算 poly(r) = sum_{i=0..5} c_i * r^i。""" + c = self.intr.poly + out = torch.zeros_like(r) + for i in range(c.numel() - 1, -1, -1): + out = out * r + c[i] + return out + + def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor: + """poly 的导数。""" + c = self.intr.poly + n = c.numel() + out = torch.zeros_like(r) + for i in range(n - 1, 0, -1): + out = out * r + c[i] * float(i) + return out + + def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor: + """像素半径 -> 入射角 theta(弧度)。""" + if self.intr.is_bw_poly: + return self._eval_poly(r_pix) + # forward: r_pix = poly(theta) -> 牛顿迭代 + theta = r_pix.clone() + for _ in range(8): + f = self._eval_poly(theta) - r_pix + df = self._eval_poly_grad(theta).clamp_min(1e-6) + theta = theta - f / df + return theta + + def unproject(self, uv: torch.Tensor) -> torch.Tensor: + """像素坐标 ``[..., 2]`` -> 相机坐标系下的单位方向 ``[..., 3]``。 + + f-theta 反投影: + (du, dv) = (u - cx, v - cy) (并应用 linear_cde 的微小仿射) + r_pix = ||(du, dv)|| + theta = poly(r_pix) 或 inv_poly(r_pix) + phi = atan2(dv, du) + dir_cam = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta)) + """ + cx = self.intr.cx + cy = self.intr.cy + c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2] + + u = uv[..., 0] + v = uv[..., 1] + # 应用线性修正:du' = c*du + d*dv + e*1(NVIDIA 工具中通常是 2x2 仿射,这里做近似) + du0 = u - cx + dv0 = v - cy + du = c * du0 + d * dv0 + dv = e * du0 + dv0 # 简化:保持 dv 不变量、加入 e*du 微调 + r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6) + theta = self.pixel_to_theta(r_pix) + + sin_t = torch.sin(theta) + cos_t = torch.cos(theta) + cos_p = du / r_pix + sin_p = dv / r_pix + x = sin_t * cos_p + y = sin_t * sin_p + z = cos_t + dir_cam = torch.stack([x, y, z], dim=-1) + return F.normalize(dir_cam, dim=-1) + + +def compute_ego_rays( + intr_vec: torch.Tensor, + cam2vehicle: torch.Tensor, + height: int, + width: int, + grid_h: int, + grid_w: int, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """对一个 ``grid_h x grid_w`` 的均匀像素网格计算自车系下单位射线方向。 + + 参数 + ---- + intr_vec : ``[B, 14]`` 或 ``[14]``,f-theta 内参向量。 + cam2vehicle : ``[B, 4, 4]`` 或 ``[4, 4]`` 相机系到自车系的变换。 + height, width : 图像分辨率(像素),用于在 ``[0, w] x [0, h]`` 网格采样。 + grid_h, grid_w : 输出射线网格分辨率(与 patch 网格一致,例如 24x64)。 + + 返回 + ---- + rays : ``[B, grid_h, grid_w, 3]``,自车系下单位方向。 + """ + if intr_vec.dim() == 1: + intr_vec = intr_vec.unsqueeze(0) + if cam2vehicle.dim() == 2: + cam2vehicle = cam2vehicle.unsqueeze(0) + B = intr_vec.shape[0] + + # 在像素中心采样 + u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w) + v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h) + vv, uu = torch.meshgrid(v, u, indexing="ij") # [gh, gw] + uv = torch.stack([uu, vv], dim=-1) # [gh, gw, 2] + + out = [] + for b in range(B): + cam = FThetaCamera.from_vector(intr_vec[b].to(dtype)) + dir_cam = cam.unproject(uv) # [gh, gw, 3] + # 旋到自车系:取 cam2vehicle 的 3x3 旋转部分 + R = cam2vehicle[b, :3, :3].to(dtype) + dir_veh = dir_cam @ R.T # [gh, gw, 3] + out.append(F.normalize(dir_veh, dim=-1)) + return torch.stack(out, dim=0) diff --git a/src/wjad/modules/temporal_compress.py b/src/wjad/modules/temporal_compress.py index cd02b94d0eac1415206f54dc5dcbe7d95da5b751..c85500d00b413d73e166b17cc9ae652d74362018 100644 --- a/src/wjad/modules/temporal_compress.py +++ b/src/wjad/modules/temporal_compress.py @@ -1,34 +1,34 @@ -"""2×2×2 时空压缩卷积。 - -将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为 -4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。 -""" - -from __future__ import annotations - -import torch -import torch.nn as nn - - -class TemporalCompress2x2x2(nn.Module): - """``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。""" - - def __init__(self, dim: int = 768) -> None: - super().__init__() - self.dim = dim - self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0) - self.norm = nn.LayerNorm(dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。 - - 中间排布: - [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W'] - -> [B, T'*H'*W', D] -> LayerNorm - """ - b, t, h, w, d = x.shape - x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W] - y = self.conv(x_in) - bb, dd, t2, h2, w2 = y.shape - y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd) - return self.norm(y), (t2, h2, w2) +"""2×2×2 时空压缩卷积。 + +将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为 +4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。 +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + + +class TemporalCompress2x2x2(nn.Module): + """``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。""" + + def __init__(self, dim: int = 768) -> None: + super().__init__() + self.dim = dim + self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0) + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。 + + 中间排布: + [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W'] + -> [B, T'*H'*W', D] -> LayerNorm + """ + b, t, h, w, d = x.shape + x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W] + y = self.conv(x_in) + bb, dd, t2, h2, w2 = y.shape + y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd) + return self.norm(y), (t2, h2, w2) diff --git a/src/wjad/train/__init__.py b/src/wjad/train/__init__.py index ce88a74467c110c239fc8498c292c565e001c602..6252e2b1166458b5c2548a0d6aed0f12722caafc 100644 --- a/src/wjad/train/__init__.py +++ b/src/wjad/train/__init__.py @@ -1,17 +1,17 @@ -"""训练相关:多任务损失合并、Trainer、调度器。""" - -from .multitask import GradNormBalancer, PCGradCombiner, MultiTaskOptimizer -from .schedule import build_scheduler -from .trainer import Trainer, TrainerConfig, compute_all_losses, MAIN_TASK_KEYS, AUX_TASK_KEYS - -__all__ = [ - "GradNormBalancer", - "PCGradCombiner", - "MultiTaskOptimizer", - "build_scheduler", - "Trainer", - "TrainerConfig", - "compute_all_losses", - "MAIN_TASK_KEYS", - "AUX_TASK_KEYS", -] +"""训练相关:多任务损失合并、Trainer、调度器。""" + +from .multitask import GradNormBalancer, PCGradCombiner, MultiTaskOptimizer +from .schedule import build_scheduler +from .trainer import Trainer, TrainerConfig, compute_all_losses, MAIN_TASK_KEYS, AUX_TASK_KEYS + +__all__ = [ + "GradNormBalancer", + "PCGradCombiner", + "MultiTaskOptimizer", + "build_scheduler", + "Trainer", + "TrainerConfig", + "compute_all_losses", + "MAIN_TASK_KEYS", + "AUX_TASK_KEYS", +] diff --git a/src/wjad/train/multitask.py b/src/wjad/train/multitask.py index 415d79b141e86e306d2a3758d040bcc28e0ab929..432a9abc411da9aa654241890833ac0fd143a9aa 100644 --- a/src/wjad/train/multitask.py +++ b/src/wjad/train/multitask.py @@ -1,284 +1,284 @@ -"""多任务损失合并:GradNorm(自适应权重)+ PCGrad(正交化梯度)。 - -GradNorm(Chen et al. 2018) - 维护任务可学习权重 ``w_i = softplus(raw_w_i)``,按各任务相对训练速度 - 自适应调整。``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``,目标 - ``G_i = mean_norm * r_i^alpha``;以 ``L_grad = sum_i |‖∇w_i L_i‖ - G_i|`` - 回传更新 ``raw_w_i``。最后把 ``w_i`` 重归一化使 ``sum w = N``。 - -PCGrad(Yu et al. 2020) - 分别对每个任务在 **共享参数** 上做 ``autograd.grad`` 得到 ``g_i``,对 - 每对 (i, j),若 `` < 0``,把 ``g_i`` 投影到 ``g_j`` 的正交补; - 每步随机打乱任务顺序避免偏置;最后把所有调整后的梯度求和写回 - ``param.grad``。任务专属参数(仅自身 loss 影响)不需要 PCGrad,由普通 - backward 路径处理。 -""" - -from __future__ import annotations - -import random -from dataclasses import dataclass -from typing import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class GradNormBalancer(nn.Module): - """GradNorm 自适应任务权重。 - - 维护 ``raw_weights``(softplus 参数化),对外暴露归一化权重 ``task_weights``。 - """ - - def __init__( - self, - num_tasks: int, - alpha: float = 1.5, - gradnorm_lr: float = 0.025, - eps: float = 1e-8, - ) -> None: - super().__init__() - self.num_tasks = num_tasks - self.alpha = alpha - self.eps = eps - # raw_weights = 1 → softplus(1) ≈ 1.31,归一化后初始权重均匀。 - self.raw_weights = nn.Parameter(torch.ones(num_tasks)) - self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr) - self.register_buffer("initial_losses", torch.zeros(num_tasks)) - self._initialized = False - - @property - def task_weights(self) -> torch.Tensor: - """重归一化后 sum=N 的权重(保留计算图,可用于反传到 raw_weights)。""" - w = F.softplus(self.raw_weights) + self.eps - return w * (self.num_tasks / w.sum()) - - def initialize(self, losses: torch.Tensor) -> None: - with torch.no_grad(): - self.initial_losses.copy_(losses.detach()) - self._initialized = True - - def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None: - """按 GradNorm 规则更新任务权重。 - - 参数 - ---- - losses : ``[N]``,未加权的各任务 loss(保留计算图)。 - shared_param : 用于估计 ``‖∇w_i L_i‖`` 的代理参数(通常是主干末层 weight)。 - """ - if not self._initialized: - self.initialize(losses) - return - N = self.num_tasks - weights = self.task_weights - weighted = weights * losses - # 对每个任务取代理参数的梯度范数 - gnorms = [] - for i in range(N): - (g,) = torch.autograd.grad( - weighted[i], shared_param, retain_graph=True, create_graph=False - ) - gnorms.append(g.detach().norm(p=2)) - gnorms_t = torch.stack(gnorms) - mean_g = gnorms_t.mean() - - with torch.no_grad(): - losses_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps) - rt = losses_ratio / losses_ratio.mean().clamp_min(self.eps) - target = (mean_g.detach() * rt.pow(self.alpha)).detach() - - # 关键:L_grad 仅通过 weights = f(raw_weights) 反传到 raw_weights。 - # 这里用 (gnorms_t.detach() - target).abs() * weights,让 weight 自身 - # 接受梯度(标准 GradNorm 即此实现)。 - # 不过更稳妥的做法是用差值符号驱动:见 Chen 2018 论文。 - L_grad = (gnorms_t.detach() - target).abs().sum() - # gnorms_t.detach() 已 detach;为让 raw_weights 接收到梯度,需要把 - # 权重的“范数贡献”再次接入。常用近似:把 gnorms 重写为 weights * base。 - # 这里采用论文推荐近似:以 weights 为变量、其它项视为常数。 - # 等价 L_grad' = sum_i weights_i * (||∇L_i_unweighted|| - target_i / weights_i) - # 简化:用 weights * (gnorms_unweighted - target/weights) 的 L1 形式。 - # 避免实现复杂,采用 weights 自身的微弱 L2 锚 + GradNorm 主目标。 - anchor = (weights - 1.0).pow(2).sum() * 1e-3 - # weights 越大、对应任务相对慢 -> 增加 weights;反之减少。 - speed_signal = (weights * (gnorms_t.detach() - target)).sum() - loss_for_w = anchor + speed_signal.abs() * 0 # 占位以便 autograd 不报错 - # 实际驱动信号:让 weights 沿 (gnorms - target) 反向更新 - # 用一个简单 surrogate:sum(weights * sign(gnorms_t - target).detach()) - sign = torch.sign(gnorms_t - target).detach() - surrogate = (weights * sign).sum() - full = anchor + surrogate * 1.0 # 倾向减小 weights 当 gnorm > target - - self.optimizer.zero_grad(set_to_none=True) - full.backward(retain_graph=False) - self.optimizer.step() - - -class PCGradCombiner: - """PCGrad:对共享参数的多任务梯度做正交投影。""" - - def __init__(self, shuffle: bool = True) -> None: - self.shuffle = shuffle - - @torch.no_grad() - def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]: - """对一组扁平的 task 梯度做 PCGrad 投影;返回投影后的列表。""" - n = len(grads_per_task) - adjusted = [g.clone() for g in grads_per_task] - order_template = list(range(n)) - for i in range(n): - order = order_template.copy() - if self.shuffle: - random.shuffle(order) - for j in order: - if j == i: - continue - gi = adjusted[i] - gj = grads_per_task[j] - dot = torch.dot(gi, gj) - if dot.item() < 0: - denom = gj.dot(gj).clamp_min(1e-12) - adjusted[i] = gi - (dot / denom) * gj - return adjusted - - -@dataclass -class MultiTaskOptimizerConfig: - enable_gradnorm: bool = True - enable_pcgrad: bool = False - gradnorm_alpha: float = 1.5 - gradnorm_lr: float = 0.025 - pcgrad_shuffle: bool = True - - -class MultiTaskOptimizer: - """整合 GradNorm + PCGrad 的多任务训练 helper。 - - 使用流程: - mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, head_params, cfg) - for step in ...: - optimizer.zero_grad(set_to_none=True) - losses_main = torch.stack([...]) # [N], 未加权 - loss_aux = ... # 标量正则 - total, w = mto.backward(losses_main, loss_aux, all_trainable_params) - optimizer.step() - """ - - def __init__( - self, - num_main_tasks: int, - shared_params: list[nn.Parameter], - gradnorm_proxy_param: nn.Parameter, - cfg: MultiTaskOptimizerConfig, - ) -> None: - self.cfg = cfg - self.num_main = num_main_tasks - self.shared_params = list(shared_params) - self.shared_set = set(id(p) for p in self.shared_params) - self.proxy = gradnorm_proxy_param - self.gradnorm = ( - GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr) - if cfg.enable_gradnorm - else None - ) - self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None - - def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor: - """获取(并按需更新)任务权重。返回 detach 版本用于加权 loss。""" - if self.gradnorm is None: - return torch.ones(losses_main.shape[0], device=losses_main.device) - # GradNorm 自身的优化器内部 step - self.gradnorm.step(losses_main, self.proxy) - return self.gradnorm.task_weights.detach() - - def backward( - self, - losses_main: torch.Tensor, # [N],未加权 - loss_aux: torch.Tensor, # 标量 - all_trainable_params: Sequence[nn.Parameter], - ) -> tuple[torch.Tensor, torch.Tensor]: - """完成一次反传 + 梯度合并;返回 (total_unweighted_view, weights)。""" - weights = self.task_weights(losses_main) - weighted_main = weights * losses_main # [N] - - if self.pcgrad is None: - # 常规路径:sum(weighted) + aux 一次反传 - total = weighted_main.sum() + loss_aux - total.backward() - return total.detach(), weights - - # === PCGrad 路径 === - # 1) 共享参数:对每个 task 单独 autograd.grad,正交化后写回 .grad。 - # 任务专属(非共享)参数:用 (sum(weighted_main) + aux).backward() 处理。 - - # 1a) 共享参数的 per-task 梯度 - per_task_flat: list[torch.Tensor] = [] - shapes = [p.shape for p in self.shared_params] - for i in range(self.num_main): - grads = torch.autograd.grad( - weighted_main[i], - self.shared_params, - retain_graph=True, - allow_unused=True, - ) - grads = [ - g if g is not None else torch.zeros_like(p) - for g, p in zip(grads, self.shared_params) - ] - per_task_flat.append(torch.cat([g.reshape(-1) for g in grads], dim=0)) - - adjusted = self.pcgrad.project(per_task_flat) - # 原始 per_task_flat 已不再需要:投影时已读完所有 j 引用。立即释放, - # 降低峰值(N × flat 张量)。 - del per_task_flat - # 原地累加,避免 torch.stack 创建 [N, P] 中间张量(再多一份显存)。 - combined_main_flat = adjusted[0] - for k in range(1, len(adjusted)): - combined_main_flat = combined_main_flat + adjusted[k] - adjusted[k] = None # type: ignore[assignment] - del adjusted - - # 1b) aux loss 对共享参数的梯度 - aux_grads = torch.autograd.grad( - loss_aux, - self.shared_params, - retain_graph=True, - allow_unused=True, - ) - aux_grads = [ - g if g is not None else torch.zeros_like(p) - for g, p in zip(aux_grads, self.shared_params) - ] - aux_flat = torch.cat([g.reshape(-1) for g in aux_grads], dim=0) - shared_flat = combined_main_flat + aux_flat - - # 1c) 写回共享参数 .grad - cursor = 0 - for p, shp in zip(self.shared_params, shapes): - n = int(torch.tensor(shp).prod().item()) - chunk = shared_flat[cursor : cursor + n].view(*shp) - if p.grad is None: - p.grad = chunk.detach().clone() - else: - p.grad = p.grad + chunk.detach() - cursor += n - - # 2) 非共享参数:调用 backward 走标准路径 - non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set] - if non_shared: - total_for_ns = weighted_main.sum() + loss_aux - grads_ns = torch.autograd.grad( - total_for_ns, - non_shared, - retain_graph=False, - allow_unused=True, - ) - for p, g in zip(non_shared, grads_ns): - if g is None: - continue - if p.grad is None: - p.grad = g.detach().clone() - else: - p.grad = p.grad + g.detach() - - return (weighted_main.sum().detach() + loss_aux.detach()), weights +"""多任务损失合并:GradNorm(自适应权重)+ PCGrad(正交化梯度)。 + +GradNorm(Chen et al. 2018) + 维护任务可学习权重 ``w_i = softplus(raw_w_i)``,按各任务相对训练速度 + 自适应调整。``r_i(t) = (L_i / L_i(0)) / mean_j(L_j / L_j(0))``,目标 + ``G_i = mean_norm * r_i^alpha``;以 ``L_grad = sum_i |‖∇w_i L_i‖ - G_i|`` + 回传更新 ``raw_w_i``。最后把 ``w_i`` 重归一化使 ``sum w = N``。 + +PCGrad(Yu et al. 2020) + 分别对每个任务在 **共享参数** 上做 ``autograd.grad`` 得到 ``g_i``,对 + 每对 (i, j),若 `` < 0``,把 ``g_i`` 投影到 ``g_j`` 的正交补; + 每步随机打乱任务顺序避免偏置;最后把所有调整后的梯度求和写回 + ``param.grad``。任务专属参数(仅自身 loss 影响)不需要 PCGrad,由普通 + backward 路径处理。 +""" + +from __future__ import annotations + +import random +from dataclasses import dataclass +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GradNormBalancer(nn.Module): + """GradNorm 自适应任务权重。 + + 维护 ``raw_weights``(softplus 参数化),对外暴露归一化权重 ``task_weights``。 + """ + + def __init__( + self, + num_tasks: int, + alpha: float = 1.5, + gradnorm_lr: float = 0.025, + eps: float = 1e-8, + ) -> None: + super().__init__() + self.num_tasks = num_tasks + self.alpha = alpha + self.eps = eps + # raw_weights = 1 → softplus(1) ≈ 1.31,归一化后初始权重均匀。 + self.raw_weights = nn.Parameter(torch.ones(num_tasks)) + self.optimizer = torch.optim.Adam([self.raw_weights], lr=gradnorm_lr) + self.register_buffer("initial_losses", torch.zeros(num_tasks)) + self._initialized = False + + @property + def task_weights(self) -> torch.Tensor: + """重归一化后 sum=N 的权重(保留计算图,可用于反传到 raw_weights)。""" + w = F.softplus(self.raw_weights) + self.eps + return w * (self.num_tasks / w.sum()) + + def initialize(self, losses: torch.Tensor) -> None: + with torch.no_grad(): + self.initial_losses.copy_(losses.detach()) + self._initialized = True + + def step(self, losses: torch.Tensor, shared_param: torch.Tensor) -> None: + """按 GradNorm 规则更新任务权重。 + + 参数 + ---- + losses : ``[N]``,未加权的各任务 loss(保留计算图)。 + shared_param : 用于估计 ``‖∇w_i L_i‖`` 的代理参数(通常是主干末层 weight)。 + """ + if not self._initialized: + self.initialize(losses) + return + N = self.num_tasks + weights = self.task_weights + weighted = weights * losses + # 对每个任务取代理参数的梯度范数 + gnorms = [] + for i in range(N): + (g,) = torch.autograd.grad( + weighted[i], shared_param, retain_graph=True, create_graph=False + ) + gnorms.append(g.detach().norm(p=2)) + gnorms_t = torch.stack(gnorms) + mean_g = gnorms_t.mean() + + with torch.no_grad(): + losses_ratio = losses.detach() / self.initial_losses.clamp_min(self.eps) + rt = losses_ratio / losses_ratio.mean().clamp_min(self.eps) + target = (mean_g.detach() * rt.pow(self.alpha)).detach() + + # 关键:L_grad 仅通过 weights = f(raw_weights) 反传到 raw_weights。 + # 这里用 (gnorms_t.detach() - target).abs() * weights,让 weight 自身 + # 接受梯度(标准 GradNorm 即此实现)。 + # 不过更稳妥的做法是用差值符号驱动:见 Chen 2018 论文。 + L_grad = (gnorms_t.detach() - target).abs().sum() + # gnorms_t.detach() 已 detach;为让 raw_weights 接收到梯度,需要把 + # 权重的“范数贡献”再次接入。常用近似:把 gnorms 重写为 weights * base。 + # 这里采用论文推荐近似:以 weights 为变量、其它项视为常数。 + # 等价 L_grad' = sum_i weights_i * (||∇L_i_unweighted|| - target_i / weights_i) + # 简化:用 weights * (gnorms_unweighted - target/weights) 的 L1 形式。 + # 避免实现复杂,采用 weights 自身的微弱 L2 锚 + GradNorm 主目标。 + anchor = (weights - 1.0).pow(2).sum() * 1e-3 + # weights 越大、对应任务相对慢 -> 增加 weights;反之减少。 + speed_signal = (weights * (gnorms_t.detach() - target)).sum() + loss_for_w = anchor + speed_signal.abs() * 0 # 占位以便 autograd 不报错 + # 实际驱动信号:让 weights 沿 (gnorms - target) 反向更新 + # 用一个简单 surrogate:sum(weights * sign(gnorms_t - target).detach()) + sign = torch.sign(gnorms_t - target).detach() + surrogate = (weights * sign).sum() + full = anchor + surrogate * 1.0 # 倾向减小 weights 当 gnorm > target + + self.optimizer.zero_grad(set_to_none=True) + full.backward(retain_graph=False) + self.optimizer.step() + + +class PCGradCombiner: + """PCGrad:对共享参数的多任务梯度做正交投影。""" + + def __init__(self, shuffle: bool = True) -> None: + self.shuffle = shuffle + + @torch.no_grad() + def project(self, grads_per_task: list[torch.Tensor]) -> list[torch.Tensor]: + """对一组扁平的 task 梯度做 PCGrad 投影;返回投影后的列表。""" + n = len(grads_per_task) + adjusted = [g.clone() for g in grads_per_task] + order_template = list(range(n)) + for i in range(n): + order = order_template.copy() + if self.shuffle: + random.shuffle(order) + for j in order: + if j == i: + continue + gi = adjusted[i] + gj = grads_per_task[j] + dot = torch.dot(gi, gj) + if dot.item() < 0: + denom = gj.dot(gj).clamp_min(1e-12) + adjusted[i] = gi - (dot / denom) * gj + return adjusted + + +@dataclass +class MultiTaskOptimizerConfig: + enable_gradnorm: bool = True + enable_pcgrad: bool = False + gradnorm_alpha: float = 1.5 + gradnorm_lr: float = 0.025 + pcgrad_shuffle: bool = True + + +class MultiTaskOptimizer: + """整合 GradNorm + PCGrad 的多任务训练 helper。 + + 使用流程: + mto = MultiTaskOptimizer(num_tasks, shared_params, proxy, head_params, cfg) + for step in ...: + optimizer.zero_grad(set_to_none=True) + losses_main = torch.stack([...]) # [N], 未加权 + loss_aux = ... # 标量正则 + total, w = mto.backward(losses_main, loss_aux, all_trainable_params) + optimizer.step() + """ + + def __init__( + self, + num_main_tasks: int, + shared_params: list[nn.Parameter], + gradnorm_proxy_param: nn.Parameter, + cfg: MultiTaskOptimizerConfig, + ) -> None: + self.cfg = cfg + self.num_main = num_main_tasks + self.shared_params = list(shared_params) + self.shared_set = set(id(p) for p in self.shared_params) + self.proxy = gradnorm_proxy_param + self.gradnorm = ( + GradNormBalancer(num_main_tasks, alpha=cfg.gradnorm_alpha, gradnorm_lr=cfg.gradnorm_lr) + if cfg.enable_gradnorm + else None + ) + self.pcgrad = PCGradCombiner(shuffle=cfg.pcgrad_shuffle) if cfg.enable_pcgrad else None + + def task_weights(self, losses_main: torch.Tensor) -> torch.Tensor: + """获取(并按需更新)任务权重。返回 detach 版本用于加权 loss。""" + if self.gradnorm is None: + return torch.ones(losses_main.shape[0], device=losses_main.device) + # GradNorm 自身的优化器内部 step + self.gradnorm.step(losses_main, self.proxy) + return self.gradnorm.task_weights.detach() + + def backward( + self, + losses_main: torch.Tensor, # [N],未加权 + loss_aux: torch.Tensor, # 标量 + all_trainable_params: Sequence[nn.Parameter], + ) -> tuple[torch.Tensor, torch.Tensor]: + """完成一次反传 + 梯度合并;返回 (total_unweighted_view, weights)。""" + weights = self.task_weights(losses_main) + weighted_main = weights * losses_main # [N] + + if self.pcgrad is None: + # 常规路径:sum(weighted) + aux 一次反传 + total = weighted_main.sum() + loss_aux + total.backward() + return total.detach(), weights + + # === PCGrad 路径 === + # 1) 共享参数:对每个 task 单独 autograd.grad,正交化后写回 .grad。 + # 任务专属(非共享)参数:用 (sum(weighted_main) + aux).backward() 处理。 + + # 1a) 共享参数的 per-task 梯度 + per_task_flat: list[torch.Tensor] = [] + shapes = [p.shape for p in self.shared_params] + for i in range(self.num_main): + grads = torch.autograd.grad( + weighted_main[i], + self.shared_params, + retain_graph=True, + allow_unused=True, + ) + grads = [ + g if g is not None else torch.zeros_like(p) + for g, p in zip(grads, self.shared_params) + ] + per_task_flat.append(torch.cat([g.reshape(-1) for g in grads], dim=0)) + + adjusted = self.pcgrad.project(per_task_flat) + # 原始 per_task_flat 已不再需要:投影时已读完所有 j 引用。立即释放, + # 降低峰值(N × flat 张量)。 + del per_task_flat + # 原地累加,避免 torch.stack 创建 [N, P] 中间张量(再多一份显存)。 + combined_main_flat = adjusted[0] + for k in range(1, len(adjusted)): + combined_main_flat = combined_main_flat + adjusted[k] + adjusted[k] = None # type: ignore[assignment] + del adjusted + + # 1b) aux loss 对共享参数的梯度 + aux_grads = torch.autograd.grad( + loss_aux, + self.shared_params, + retain_graph=True, + allow_unused=True, + ) + aux_grads = [ + g if g is not None else torch.zeros_like(p) + for g, p in zip(aux_grads, self.shared_params) + ] + aux_flat = torch.cat([g.reshape(-1) for g in aux_grads], dim=0) + shared_flat = combined_main_flat + aux_flat + + # 1c) 写回共享参数 .grad + cursor = 0 + for p, shp in zip(self.shared_params, shapes): + n = int(torch.tensor(shp).prod().item()) + chunk = shared_flat[cursor : cursor + n].view(*shp) + if p.grad is None: + p.grad = chunk.detach().clone() + else: + p.grad = p.grad + chunk.detach() + cursor += n + + # 2) 非共享参数:调用 backward 走标准路径 + non_shared = [p for p in all_trainable_params if id(p) not in self.shared_set] + if non_shared: + total_for_ns = weighted_main.sum() + loss_aux + grads_ns = torch.autograd.grad( + total_for_ns, + non_shared, + retain_graph=False, + allow_unused=True, + ) + for p, g in zip(non_shared, grads_ns): + if g is None: + continue + if p.grad is None: + p.grad = g.detach().clone() + else: + p.grad = p.grad + g.detach() + + return (weighted_main.sum().detach() + loss_aux.detach()), weights diff --git a/src/wjad/train/runner_local.py b/src/wjad/train/runner_local.py index 9681e384b63befe4746d74365aa3fb3c07204219..ef15569525edc0be343e076e90671c2cd87ac08e 100644 --- a/src/wjad/train/runner_local.py +++ b/src/wjad/train/runner_local.py @@ -1,163 +1,163 @@ -"""本地训练入口。 - -支持两种模式: - - ``--tiny`` : 用 1 个 clip / 极少步数验证训练循环可跑通; - - 否则默认按 configs/default.yaml 训练(需要数据集解压完成)。 -""" - -from __future__ import annotations - -import argparse -import logging -import sys -from pathlib import Path - -import torch -import yaml -from torch.utils.data import DataLoader - -from ..data.cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index, collate_samples -from ..model import E2EAVModel -from .trainer import Trainer, TrainerConfig - - -def _load_config(path: str) -> dict: - with open(path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) - - -def _make_model_from_cfg(cfg: dict, dinov3_path: str) -> E2EAVModel: - """根据配置创建模型。""" - return E2EAVModel( - dinov3_path=dinov3_path, - backbone_dim=cfg["backbone"]["hidden_size"], - num_heads=cfg["backbone"]["num_heads"], - num_dense_layers=cfg["backbone"]["num_dense_layers"], - num_moe_layers=cfg["backbone"]["num_moe_layers"], - num_routed_experts=cfg["moe"]["num_routed_experts"], - num_shared_experts=cfg["moe"]["num_shared_experts"], - topk_experts=cfg["moe"]["topk"], - ffn_mult=cfg["backbone"]["ffn_mult"], - num_history_frames=cfg["input"]["num_history_frames"], - num_detection_tokens=cfg["tokens"]["num_detection"], - num_control_tokens=cfg["tokens"]["num_control"], - num_ego_tokens=cfg["tokens"]["num_ego"], - num_extra_tokens=cfg["tokens"]["num_extra"], - image_h=cfg["input"]["image_height"], - image_w=cfg["input"]["image_width"], - patch_size=cfg["dinov3"]["patch_size"], - num_classes=cfg["det_traj_head"]["num_classes"], - traj_horizon=cfg["det_traj_head"]["traj_horizon"], - det_head_hidden=cfg["det_traj_head"]["hidden_size"], - ctrl_head_hidden=cfg["control_head"]["hidden_size"], - calib_dim=cfg["calibration"]["hidden_size"], - calib_num_query=cfg["calibration"]["num_query_tokens"], - calib_num_blocks=cfg["calibration"]["num_blocks"], - calib_num_self_per_block=cfg["calibration"]["num_self_attn_per_block"], - calib_num_heads=cfg["calibration"]["num_heads"], - calib_residual_range=cfg["calibration"]["residual_range"], - calib_intr_dim=cfg["calibration"]["intr_vec_dim"], - freeze_dinov3=cfg["dinov3"]["freeze_in_stage1"], - attn_implementation=cfg["dinov3"]["attn_implementation"], - ) - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--config", default=str(Path(__file__).resolve().parents[3] / "configs" / "default.yaml")) - parser.add_argument("--data_root", default=None, help="覆盖 config 中的 data.root") - parser.add_argument("--dinov3_path", default=str(Path(__file__).resolve().parents[3] / "dinov3-vitb16-pretrain-lvd1689m")) - parser.add_argument("--tiny", action="store_true", help="用极少样本验证训练循环") - parser.add_argument("--max_steps", type=int, default=50) - parser.add_argument("--device", default="cpu") - args = parser.parse_args() - - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") - cfg = _load_config(args.config) - data_root = args.data_root or cfg["data"]["root"] - - samples = build_clip_index( - data_root, - weathers=cfg["data"]["weather"], - camera_name=cfg["input"]["camera_name"], - ) - if args.tiny: - samples = samples[:8] or samples - print(f"[runner_local] 找到 {len(samples)} 个样本于 {data_root}") - if not samples: - print("[runner_local] 没有数据。") - print( - " 说明:loader 需要 NVIDIA download.py 解压后的布局:\n" - " /synthetic/single_view/generation/*.mp4\n" - " /labels//..." - ) - print( - " 在 HF Jobs 上只读挂载 Hub dataset 仓库通常 **不是** 上述树,build_clip_index 会得到 0 条。" - " 请用 Job 默认流程运行 scripts/download_data.py 到可写目录,或本地准备好再 --skip-download。" - ) - print(" 本地/容器内准备:python scripts/download_data.py --odir ./data/cosmos ...") - sys.exit(0) - - dataset = CosmosDriveDreamsDataset( - data_root=data_root, - samples=samples, - camera_name=cfg["input"]["camera_name"], - image_h=cfg["input"]["image_height"], - image_w=cfg["input"]["image_width"], - num_history=cfg["input"]["num_history_frames"], - future_horizon=cfg["input"]["num_future_frames"], - max_distance_m=cfg["detection"]["max_distance_m"], - occlusion_tol=cfg["detection"]["occlusion_depth_tolerance"], - ) - loader = DataLoader( - dataset, - batch_size=cfg["train"]["batch_size"], - shuffle=True, - num_workers=cfg["data"]["num_workers"] if not args.tiny else 0, - collate_fn=collate_samples, - pin_memory=cfg["data"]["pin_memory"], - ) - - model = _make_model_from_cfg(cfg, args.dinov3_path) - if cfg.get("gradient_checkpointing", False): - model.backbone.set_gradient_checkpointing(True) - - tcfg = TrainerConfig( - total_steps=cfg["train"]["total_steps"], - warmup_steps=cfg["train"]["warmup_steps"], - base_lr=cfg["train"]["base_lr"], - min_lr=cfg["train"]["min_lr"], - weight_decay=cfg["train"]["weight_decay"], - grad_clip=cfg["train"]["grad_clip"], - log_interval=cfg["train"]["log_interval"], - ckpt_interval=cfg["train"]["ckpt_interval"], - stage1_steps=cfg["train"]["stage1_steps"], - stage1_perturb_start=cfg["train"]["stage1_perturb_start"], - grad_monitor_threshold=cfg["train"]["grad_monitor_threshold"], - moe_load_balance_weight=cfg["moe"]["load_balance_weight"], - moe_boundary_weight=cfg["moe"]["boundary_weight"], - router_temp_init=cfg["moe"]["router_temperature_init"], - router_temp_final=cfg["moe"]["router_temperature_final"], - loss_giou_weight=cfg["loss"]["giou_weight"], - loss_calib_weight=cfg["loss"]["calib_weight"], - enable_gradnorm=cfg["multitask"]["enable_gradnorm"], - enable_pcgrad=cfg["multitask"]["enable_pcgrad"], - mixed_precision=cfg["mixed_precision"], - grad_accum_steps=cfg["train"]["grad_accum_steps"], - dinov3_lr_mult_stage2=cfg["dinov3"]["finetune_lr_ratio"], - backbone_lr_mult=cfg["train"]["param_groups"]["backbone_lr_mult"], - calibration_lr_mult=cfg["train"]["param_groups"]["calibration_lr_mult"], - head_lr_mult=cfg["train"]["param_groups"]["head_lr_mult"], - gate_lr_mult=cfg["train"]["param_groups"]["gate_lr_mult"], - unfreeze_dinov3_at_stage2=cfg["train"].get("unfreeze_dinov3_at_stage2", True), - ckpt_dir=cfg["train"].get("ckpt_dir", "outputs/checkpoints"), - hub_repo_id=(cfg.get("deploy") or {}).get("hf_repo"), - hub_push_checkpoints=bool((cfg.get("deploy") or {}).get("push_checkpoints", False)), - hub_ckpt_prefix=(cfg.get("deploy") or {}).get("hub_ckpt_prefix", "checkpoints"), - ) - trainer = Trainer(model, tcfg, num_classes=cfg["det_traj_head"]["num_classes"], device=args.device) - trainer.fit(loader, max_steps=args.max_steps) - - -if __name__ == "__main__": - main() +"""本地训练入口。 + +支持两种模式: + - ``--tiny`` : 用 1 个 clip / 极少步数验证训练循环可跑通; + - 否则默认按 configs/default.yaml 训练(需要数据集解压完成)。 +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import torch +import yaml +from torch.utils.data import DataLoader + +from ..data.cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index, collate_samples +from ..model import E2EAVModel +from .trainer import Trainer, TrainerConfig + + +def _load_config(path: str) -> dict: + with open(path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def _make_model_from_cfg(cfg: dict, dinov3_path: str) -> E2EAVModel: + """根据配置创建模型。""" + return E2EAVModel( + dinov3_path=dinov3_path, + backbone_dim=cfg["backbone"]["hidden_size"], + num_heads=cfg["backbone"]["num_heads"], + num_dense_layers=cfg["backbone"]["num_dense_layers"], + num_moe_layers=cfg["backbone"]["num_moe_layers"], + num_routed_experts=cfg["moe"]["num_routed_experts"], + num_shared_experts=cfg["moe"]["num_shared_experts"], + topk_experts=cfg["moe"]["topk"], + ffn_mult=cfg["backbone"]["ffn_mult"], + num_history_frames=cfg["input"]["num_history_frames"], + num_detection_tokens=cfg["tokens"]["num_detection"], + num_control_tokens=cfg["tokens"]["num_control"], + num_ego_tokens=cfg["tokens"]["num_ego"], + num_extra_tokens=cfg["tokens"]["num_extra"], + image_h=cfg["input"]["image_height"], + image_w=cfg["input"]["image_width"], + patch_size=cfg["dinov3"]["patch_size"], + num_classes=cfg["det_traj_head"]["num_classes"], + traj_horizon=cfg["det_traj_head"]["traj_horizon"], + det_head_hidden=cfg["det_traj_head"]["hidden_size"], + ctrl_head_hidden=cfg["control_head"]["hidden_size"], + calib_dim=cfg["calibration"]["hidden_size"], + calib_num_query=cfg["calibration"]["num_query_tokens"], + calib_num_blocks=cfg["calibration"]["num_blocks"], + calib_num_self_per_block=cfg["calibration"]["num_self_attn_per_block"], + calib_num_heads=cfg["calibration"]["num_heads"], + calib_residual_range=cfg["calibration"]["residual_range"], + calib_intr_dim=cfg["calibration"]["intr_vec_dim"], + freeze_dinov3=cfg["dinov3"]["freeze_in_stage1"], + attn_implementation=cfg["dinov3"]["attn_implementation"], + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--config", default=str(Path(__file__).resolve().parents[3] / "configs" / "default.yaml")) + parser.add_argument("--data_root", default=None, help="覆盖 config 中的 data.root") + parser.add_argument("--dinov3_path", default=str(Path(__file__).resolve().parents[3] / "dinov3-vitb16-pretrain-lvd1689m")) + parser.add_argument("--tiny", action="store_true", help="用极少样本验证训练循环") + parser.add_argument("--max_steps", type=int, default=50) + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + cfg = _load_config(args.config) + data_root = args.data_root or cfg["data"]["root"] + + samples = build_clip_index( + data_root, + weathers=cfg["data"]["weather"], + camera_name=cfg["input"]["camera_name"], + ) + if args.tiny: + samples = samples[:8] or samples + print(f"[runner_local] 找到 {len(samples)} 个样本于 {data_root}") + if not samples: + print("[runner_local] 没有数据。") + print( + " 说明:loader 需要 NVIDIA download.py 解压后的布局:\n" + " /synthetic/single_view/generation/*.mp4\n" + " /labels//..." + ) + print( + " 在 HF Jobs 上只读挂载 Hub dataset 仓库通常 **不是** 上述树,build_clip_index 会得到 0 条。" + " 请用 Job 默认流程运行 scripts/download_data.py 到可写目录,或本地准备好再 --skip-download。" + ) + print(" 本地/容器内准备:python scripts/download_data.py --odir ./data/cosmos ...") + sys.exit(0) + + dataset = CosmosDriveDreamsDataset( + data_root=data_root, + samples=samples, + camera_name=cfg["input"]["camera_name"], + image_h=cfg["input"]["image_height"], + image_w=cfg["input"]["image_width"], + num_history=cfg["input"]["num_history_frames"], + future_horizon=cfg["input"]["num_future_frames"], + max_distance_m=cfg["detection"]["max_distance_m"], + occlusion_tol=cfg["detection"]["occlusion_depth_tolerance"], + ) + loader = DataLoader( + dataset, + batch_size=cfg["train"]["batch_size"], + shuffle=True, + num_workers=cfg["data"]["num_workers"] if not args.tiny else 0, + collate_fn=collate_samples, + pin_memory=cfg["data"]["pin_memory"], + ) + + model = _make_model_from_cfg(cfg, args.dinov3_path) + if cfg.get("gradient_checkpointing", False): + model.backbone.set_gradient_checkpointing(True) + + tcfg = TrainerConfig( + total_steps=cfg["train"]["total_steps"], + warmup_steps=cfg["train"]["warmup_steps"], + base_lr=cfg["train"]["base_lr"], + min_lr=cfg["train"]["min_lr"], + weight_decay=cfg["train"]["weight_decay"], + grad_clip=cfg["train"]["grad_clip"], + log_interval=cfg["train"]["log_interval"], + ckpt_interval=cfg["train"]["ckpt_interval"], + stage1_steps=cfg["train"]["stage1_steps"], + stage1_perturb_start=cfg["train"]["stage1_perturb_start"], + grad_monitor_threshold=cfg["train"]["grad_monitor_threshold"], + moe_load_balance_weight=cfg["moe"]["load_balance_weight"], + moe_boundary_weight=cfg["moe"]["boundary_weight"], + router_temp_init=cfg["moe"]["router_temperature_init"], + router_temp_final=cfg["moe"]["router_temperature_final"], + loss_giou_weight=cfg["loss"]["giou_weight"], + loss_calib_weight=cfg["loss"]["calib_weight"], + enable_gradnorm=cfg["multitask"]["enable_gradnorm"], + enable_pcgrad=cfg["multitask"]["enable_pcgrad"], + mixed_precision=cfg["mixed_precision"], + grad_accum_steps=cfg["train"]["grad_accum_steps"], + dinov3_lr_mult_stage2=cfg["dinov3"]["finetune_lr_ratio"], + backbone_lr_mult=cfg["train"]["param_groups"]["backbone_lr_mult"], + calibration_lr_mult=cfg["train"]["param_groups"]["calibration_lr_mult"], + head_lr_mult=cfg["train"]["param_groups"]["head_lr_mult"], + gate_lr_mult=cfg["train"]["param_groups"]["gate_lr_mult"], + unfreeze_dinov3_at_stage2=cfg["train"].get("unfreeze_dinov3_at_stage2", True), + ckpt_dir=cfg["train"].get("ckpt_dir", "outputs/checkpoints"), + hub_repo_id=(cfg.get("deploy") or {}).get("hf_repo"), + hub_push_checkpoints=bool((cfg.get("deploy") or {}).get("push_checkpoints", False)), + hub_ckpt_prefix=(cfg.get("deploy") or {}).get("hub_ckpt_prefix", "checkpoints"), + ) + trainer = Trainer(model, tcfg, num_classes=cfg["det_traj_head"]["num_classes"], device=args.device) + trainer.fit(loader, max_steps=args.max_steps) + + +if __name__ == "__main__": + main() diff --git a/src/wjad/train/schedule.py b/src/wjad/train/schedule.py index df902f6a29e8a90c05f6590e85a81d89c3f7b3f3..83b1f06978d6f47512a95f0c5539db54928e55af 100644 --- a/src/wjad/train/schedule.py +++ b/src/wjad/train/schedule.py @@ -1,29 +1,29 @@ -"""学习率调度:线性 warmup + 余弦退火。""" - -from __future__ import annotations - -import math - -import torch - - -def build_scheduler( - optimizer: torch.optim.Optimizer, - warmup_steps: int, - total_steps: int, - base_lr: float, - min_lr: float, -) -> torch.optim.lr_scheduler.LambdaLR: - """返回 ``LambdaLR``,其中 lr_factor = current_lr / base_lr。""" - - def lr_lambda(step: int) -> float: - if step < warmup_steps: - return float(step + 1) / max(1, warmup_steps) - # 余弦退火 - progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) - progress = min(max(progress, 0.0), 1.0) - cos = 0.5 * (1.0 + math.cos(math.pi * progress)) - ratio = (min_lr + (base_lr - min_lr) * cos) / max(base_lr, 1e-12) - return float(ratio) - - return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) +"""学习率调度:线性 warmup + 余弦退火。""" + +from __future__ import annotations + +import math + +import torch + + +def build_scheduler( + optimizer: torch.optim.Optimizer, + warmup_steps: int, + total_steps: int, + base_lr: float, + min_lr: float, +) -> torch.optim.lr_scheduler.LambdaLR: + """返回 ``LambdaLR``,其中 lr_factor = current_lr / base_lr。""" + + def lr_lambda(step: int) -> float: + if step < warmup_steps: + return float(step + 1) / max(1, warmup_steps) + # 余弦退火 + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + progress = min(max(progress, 0.0), 1.0) + cos = 0.5 * (1.0 + math.cos(math.pi * progress)) + ratio = (min_lr + (base_lr - min_lr) * cos) / max(base_lr, 1e-12) + return float(ratio) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) diff --git a/src/wjad/train/trainer.py b/src/wjad/train/trainer.py index e8e337d40a65845218d04aa0d1bbd4bcf1d86407..7db12f509db41d52de816b9382c749ba2abcbaf1 100644 --- a/src/wjad/train/trainer.py +++ b/src/wjad/train/trainer.py @@ -1,573 +1,573 @@ -"""两阶段训练器 + 梯度监控。 - -Stage 1 (Dense): - - MoE 全部专家加权(dense 模式); - - 路由温度初始 < 1(锐化),训练中线性升到 1; - - DINOv3 冻结; - - 中期开启运动学/内外参扰动,监督校准网络; - - GradNorm 启用。 - -Stage 2 (Sparse): - - MoE 切 Top-3; - - 路由温度退火完成; - - DINOv3 解冻并采用 1/100 主干 LR; - - GradNorm + PCGrad 同时启用。 -""" - -from __future__ import annotations - -import logging -import math -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Sequence - -import numpy as np -import torch -import torch.nn as nn -from torch.utils.data import DataLoader - -from ..losses import ( - HungarianMatcher, - action_nll, - calibration_regularization, - detection_losses, - ego_traj_nll, - moe_load_balance_and_boundary, - object_traj_nll, -) -from ..model import E2EAVModel, E2EOutput -from .multitask import MultiTaskOptimizer, MultiTaskOptimizerConfig -from .schedule import build_scheduler - -log = logging.getLogger(__name__) - - -class _NullContext: - """空 context manager,用于 AMP 关闭时占位 autocast。""" - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -@dataclass -class TrainerConfig: - """Trainer 超参数(与 ``configs/default.yaml`` 对齐)。""" - - total_steps: int = 100000 - warmup_steps: int = 1000 - base_lr: float = 2.0e-4 - min_lr: float = 1.0e-6 - weight_decay: float = 0.05 - grad_clip: float = 1.0 - log_interval: int = 20 - ckpt_interval: int = 1000 - stage1_steps: int = 60000 - stage1_perturb_start: int = 20000 - grad_monitor_threshold: float = 1e-7 - # === AMP / 混合精度 === - # "fp32" / "bf16" / "fp16"。默认 bf16(H100/A100 推荐,无需 GradScaler)。 - mixed_precision: str = "bf16" - grad_accum_steps: int = 1 - # MoE - moe_load_balance_weight: float = 0.01 - moe_boundary_weight: float = 0.001 - router_temp_init: float = 0.5 - router_temp_final: float = 1.0 - # 损失初始权重(GradNorm 自适应主任务 1-6) - loss_giou_weight: float = 0.5 - loss_calib_weight: float = 0.1 - # MultiTask(GradNorm + PCGrad 在 Stage1/Stage2 全程启用—— - # 两阶段的 6 项主任务都存在尺度不均与梯度冲突,PCGrad 不应延迟到 Stage2) - enable_gradnorm: bool = True - enable_pcgrad: bool = True - # 参数组 - dinov3_lr_mult_stage2: float = 0.01 - # 显存吃紧的设备(如 a10g-small)上可关闭 Stage2 DINOv3 解冻,保持冻结 - unfreeze_dinov3_at_stage2: bool = True - backbone_lr_mult: float = 1.0 - calibration_lr_mult: float = 0.1 - head_lr_mult: float = 1.0 - gate_lr_mult: float = 0.1 - # checkpoint(本地 + 可选 Hub) - ckpt_dir: str = "outputs/checkpoints" - hub_repo_id: str | None = None - hub_push_checkpoints: bool = False - hub_ckpt_prefix: str = "checkpoints" - - -def _is_gate_param(name: str) -> bool: - return ".gate." in name or name.endswith(".gate_proj.weight") or name.endswith(".gate_proj.bias") - - -def build_param_groups(model: E2EAVModel, base_lr: float, cfg: TrainerConfig, stage: int) -> list[dict]: - """按模块归类参数为不同 LR 组。Stage1 时 DINOv3 lr=0。""" - groups: dict[str, list[nn.Parameter]] = { - "dinov3": [], - "backbone": [], - "calibration": [], - "head": [], - "gate": [], - "other": [], - } - for name, p in model.named_parameters(): - if not p.requires_grad and stage == 1: - continue - if name.startswith("dinov3."): - groups["dinov3"].append(p) - elif name.startswith("backbone."): - if _is_gate_param(name): - groups["gate"].append(p) - else: - groups["backbone"].append(p) - elif name.startswith("calib."): - if _is_gate_param(name): - groups["gate"].append(p) - else: - groups["calibration"].append(p) - elif name.startswith("det_traj_head.") or name.startswith("ctrl_head."): - groups["head"].append(p) - else: - groups["other"].append(p) - - dinov3_lr = base_lr * (cfg.dinov3_lr_mult_stage2 if stage == 2 else 0.0) - return [ - {"params": groups["dinov3"], "lr": dinov3_lr, "name": "dinov3"}, - {"params": groups["backbone"], "lr": base_lr * cfg.backbone_lr_mult, "name": "backbone"}, - {"params": groups["calibration"], "lr": base_lr * cfg.calibration_lr_mult, "name": "calibration"}, - {"params": groups["head"], "lr": base_lr * cfg.head_lr_mult, "name": "head"}, - {"params": groups["gate"], "lr": base_lr * cfg.gate_lr_mult, "name": "gate"}, - {"params": groups["other"], "lr": base_lr, "name": "other"}, - ] - - -def grad_norm_per_module(model: nn.Module, threshold: float) -> dict[str, float]: - """统计各顶层模块的 grad-norm,返回 dict(用于日志/告警)。 - - 跳过: - - 没有任何 ``requires_grad=True`` 参数的模块(如冻结的 DINOv3、纯 - buffer 模块 RoPE); - - 空模块(参数计数为 0)。 - """ - summary: dict[str, float] = {} - for name, child in model.named_children(): - params = list(child.parameters()) - if not params: - continue - if not any(p.requires_grad for p in params): - # 整个模块被冻结 -> 不监控 - continue - total = 0.0 - seen = 0 - for p in params: - if p.grad is not None: - total += float(p.grad.detach().norm().item()) ** 2 - seen += 1 - if seen == 0: - continue - n = math.sqrt(total) - summary[name] = n - if n < threshold: - log.warning("[grad_monitor] %s grad_norm=%.3e < %.3e", name, n, threshold) - if not math.isfinite(n): - log.error("[grad_monitor] %s grad_norm is %s (NaN/Inf)", name, n) - return summary - - -def compute_all_losses( - model_out: E2EOutput, - batch: dict, - matcher: HungarianMatcher, - num_classes: int, - cfg: TrainerConfig, - perturbation_residual: torch.Tensor | None = None, -) -> dict[str, torch.Tensor]: - """计算 8 项损失,返回字典。 - - ``perturbation_residual``:扰动训练时给定的 ground-truth 残差,用于额外 - 监督校准网络;正常训练为 None。 - """ - targets = batch["targets"] - - det_out = model_out.detection - ctrl_out = model_out.control - calib = model_out.calibration - - det_losses = detection_losses( - cls_logits=det_out.cls_logits, - box_mu=det_out.box3d_mu, - box_log_sigma=det_out.box3d_log_sigma, - isdyn_logit=det_out.is_dynamic_logit, - targets=targets, - matcher=matcher, - num_classes=num_classes, - ) - - L_traj_obj = object_traj_nll( - det_out.traj_mu, - det_out.traj_log_sigma, - det_losses.matched_indices, - targets, - ) - - L_traj_ego = ego_traj_nll( - ctrl_out.ego_traj_mu, - ctrl_out.ego_traj_log_sigma, - batch["ego_future"], - valid=batch.get("ego_future_valid"), - ) - - # 全局动作 GT 通常没有;此处用 0 做占位(实际数据集需补齐) - action_target = batch.get("action_target") - if action_target is None: - action_target = torch.zeros_like(ctrl_out.action_mu) - L_ctrl = action_nll( - ctrl_out.action_mu, ctrl_out.action_log_sigma, action_target - ) + L_traj_ego # 控制损失 = action + ego_traj 复用同一项的便利封装;trainer 视情况拆分 - - # MoE / 校准 正则 - L_moe = moe_load_balance_and_boundary( - model_out.backbone_out.moe_stats, - load_balance_weight=cfg.moe_load_balance_weight, - boundary_weight=cfg.moe_boundary_weight, - ) - L_calib_reg = calibration_regularization( - calib.ego_residual, calib.intr_residual, calib.extr_residual, - l2_weight=1.0, - ) - if perturbation_residual is not None: - # 扰动训练:计算校准网络应该预测的 GT 残差与实际残差的 MSE - actual = torch.cat( - [calib.ego_residual.flatten(1), calib.intr_residual, calib.extr_residual], - dim=-1, - ) - L_calib_reg = L_calib_reg + 1.0 * (actual - perturbation_residual).pow(2).mean() - - return { - "L_cls": det_losses.cls_loss, - "L_box": det_losses.box_nll + cfg.loss_giou_weight * det_losses.giou_loss, - "L_isdyn": det_losses.isdyn_loss, - "L_traj_obj": L_traj_obj, - "L_traj_ego": L_traj_ego, - "L_ctrl": L_ctrl, - "L_moe": L_moe, - "L_calib": L_calib_reg, - } - - -MAIN_TASK_KEYS = ["L_cls", "L_box", "L_isdyn", "L_traj_obj", "L_traj_ego", "L_ctrl"] -AUX_TASK_KEYS = ["L_moe", "L_calib"] - - -class Trainer: - """端到端训练器。""" - - def __init__( - self, - model: E2EAVModel, - cfg: TrainerConfig, - num_classes: int = 22, - device: str = "cuda", - ) -> None: - self.model = model.to(device) - self.cfg = cfg - self.num_classes = num_classes - self.device = device - self.matcher = HungarianMatcher() - self.global_step = 0 - self._micro_step = 0 # 用于 grad_accum - self._stage = 1 - self._build_optimizer() - - # === AMP 配置 === - # 仅在 device 为 cuda 时启用 autocast(CPU 上 bf16 也能跑但收益极小)。 - amp_dtype_map = { - "fp32": None, - "bf16": torch.bfloat16, - "fp16": torch.float16, - } - self.amp_dtype = amp_dtype_map[cfg.mixed_precision] - self.amp_enabled = self.amp_dtype is not None and "cuda" in str(device) - # GradScaler 仅 fp16 需要;bf16 数值范围大无需 scaler - self.scaler = ( - torch.amp.GradScaler("cuda") - if (self.amp_enabled and self.amp_dtype == torch.float16) - else None - ) - - # MoE 初始模式 = dense;Stage2 切 sparse - self.model.backbone.set_moe_mode("dense") - self.model.backbone.set_router_temperature(cfg.router_temp_init) - - # ---------- 优化器构建 ---------- - - def _build_optimizer(self) -> None: - cfg = self.cfg - groups = build_param_groups(self.model, cfg.base_lr, cfg, stage=self._stage) - self.optimizer = torch.optim.AdamW(groups, weight_decay=cfg.weight_decay, betas=(0.9, 0.95)) - self.scheduler = build_scheduler( - self.optimizer, - warmup_steps=cfg.warmup_steps, - total_steps=cfg.total_steps, - base_lr=cfg.base_lr, - min_lr=cfg.min_lr, - ) - - # PCGrad 共享参数 = 主干最后的“共享瓶颈”:final_norm + 最后 1 层 MoE block。 - # 不把全部 DINOv3/Calib/Backbone 都纳入,否则 N 个任务 × full-grad 扁平副本会 - # 在 a10g-small 上瞬间 OOM(~600M 参数 × 6 任务 × 2 副本 ≈ 28 GB)。 - # 较前的层仍享受 GradNorm 自适应加权 + 共同求和的标准多任务训练。 - shared: list[nn.Parameter] = [] - last_moe = self.model.backbone.moe_layers[-1] - for p in self.model.backbone.final_norm.parameters(): - if p.requires_grad: - shared.append(p) - for p in last_moe.parameters(): - if p.requires_grad: - shared.append(p) - # GradNorm 代理参数:取主干最后 LayerNorm 的 weight - proxy = self.model.backbone.final_norm.weight - mt_cfg = MultiTaskOptimizerConfig( - enable_gradnorm=cfg.enable_gradnorm, - enable_pcgrad=cfg.enable_pcgrad, - gradnorm_alpha=1.5, - gradnorm_lr=0.025, - pcgrad_shuffle=True, - ) - self.mto = MultiTaskOptimizer( - num_main_tasks=len(MAIN_TASK_KEYS), - shared_params=shared, - gradnorm_proxy_param=proxy, - cfg=mt_cfg, - ) - # GradNormBalancer 是 nn.Module,需要把 raw_weights / initial_losses 缓冲 - # 移到 model 所在 device,否则与 losses (cuda) 设备不匹配。 - if self.mto.gradnorm is not None: - self.mto.gradnorm.to(self.device) - - def _checkpoint_state(self) -> dict: - state: dict = { - "model": self.model.state_dict(), - "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.state_dict(), - "global_step": self.global_step, - "stage": self._stage, - } - if self.scaler is not None: - state["scaler"] = self.scaler.state_dict() - return state - - def save_checkpoint(self, ckpt_path: Path) -> None: - ckpt_path = Path(ckpt_path) - ckpt_path.parent.mkdir(parents=True, exist_ok=True) - torch.save(self._checkpoint_state(), ckpt_path) - - def push_checkpoint_to_hub(self, ckpt_path: Path) -> None: - """将 ``ckpt_path`` 同步上传到 ``hub_repo_id``(需环境变量 ``HF_TOKEN``)。""" - repo = self.cfg.hub_repo_id - if not repo or not self.cfg.hub_push_checkpoints: - return - try: - from huggingface_hub import HfApi, create_repo - except ImportError as e: - log.warning("[Trainer] huggingface_hub 未安装,跳过 Hub 上传: %s", e) - return - prefix = self.cfg.hub_ckpt_prefix.strip().strip("/") - step = self.global_step - rel_step = f"{prefix}/step_{step:08d}.pt" - rel_latest = f"{prefix}/latest.pt" - api = HfApi(token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) - try: - create_repo(repo, repo_type="model", exist_ok=True) - api.upload_file( - path_or_fileobj=str(ckpt_path), - path_in_repo=rel_step, - repo_id=repo, - repo_type="model", - commit_message=f"checkpoint step {step}", - ) - api.upload_file( - path_or_fileobj=str(ckpt_path), - path_in_repo=rel_latest, - repo_id=repo, - repo_type="model", - commit_message=f"latest ckpt step {step}", - ) - log.info("[Trainer] Hub 上传完成 hf.co/%s (%s)", repo, rel_latest) - except Exception: - log.exception("[Trainer] Hub 上传失败(可检查 HF_TOKEN / 配额)") - - # ---------- 阶段切换 ---------- - - def maybe_switch_stage(self) -> None: - cfg = self.cfg - if self._stage == 1 and self.global_step >= cfg.stage1_steps: - log.info("[Trainer] -> Stage 2 (sparse MoE + DINOv3 finetune + PCGrad)") - self._stage = 2 - # 1) MoE 切 sparse - self.model.backbone.set_moe_mode("sparse") - # 2) 路由温度退火完成 - self.model.backbone.set_router_temperature(cfg.router_temp_final) - # 3) DINOv3 解冻(小显存设备可禁用) - if cfg.unfreeze_dinov3_at_stage2: - self.model.dinov3.unfreeze() - # 4) 重建优化器(包含 DINOv3 参数)+ 启用 PCGrad - self._build_optimizer() - - # ---------- 单步 ---------- - - def train_step(self, batch: dict, rng: np.random.Generator) -> dict: - cfg = self.cfg - self.maybe_switch_stage() - - # 移到 device - batch = {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - # targets 是 list of dict,里面的 tensor 也移到 device - if "targets" in batch and isinstance(batch["targets"], list): - new_targets = [] - for t in batch["targets"]: - new_targets.append({k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()}) - batch["targets"] = new_targets - - # 扰动注入(Stage1 中期开启) - perturb_residual = None - ego_input = batch["ego_6d"] - intr_input = batch["intr_vec"] - extr_input = batch["extr_6d"] - if ( - self._stage == 1 - and self.global_step >= cfg.stage1_perturb_start - and rng.uniform() < 0.5 - ): - from ..data.transforms import perturb_kinematics - ego_input, intr_input, extr_input, delta = perturb_kinematics( - ego_input.cpu().clone(), intr_input.cpu().clone()[0], extr_input.cpu().clone()[0], - translation_std_m=0.1, rotation_std_deg=0.5, - intrinsic_std=0.005, extrinsic_std=0.005, - rng=rng, - ) - ego_input = ego_input.to(self.device) - # intr/extr 是 [B,...] 而 perturb_kinematics 是单样本;这里为简洁仅扰动第 0 个样本 - # 实际生产中应 batched 实现 - intr_input = batch["intr_vec"].clone() - extr_input = batch["extr_6d"].clone() - # GT 残差:在 symlog 空间 = -delta(symlog 是非线性,这里用线性近似) - perturb_residual = -delta.to(self.device).unsqueeze(0).expand(ego_input.shape[0], -1) - - # 前向(AMP autocast 仅包住 forward 与匹配/损失,反传由 PyTorch - # 在 fp32 主梯度下完成;GradNorm/PCGrad 内的 autograd.grad 也在 fp32) - ac_ctx = ( - torch.autocast(device_type="cuda", dtype=self.amp_dtype) - if self.amp_enabled - else _NullContext() - ) - with ac_ctx: - out = self.model( - images=batch["images"], - ego_6d_raw=ego_input, - intr_raw=intr_input, - extr_6d_raw=extr_input, - ) - losses = compute_all_losses( - out, batch, self.matcher, self.num_classes, cfg, - perturbation_residual=perturb_residual, - ) - - # === 把损失提升到 fp32 以保证后续 GradNorm/PCGrad 数值稳定 === - main = torch.stack([losses[k].float() for k in MAIN_TASK_KEYS]) - aux = sum(losses[k].float() for k in AUX_TASK_KEYS) - # 梯度累积:对累积步数取平均 - if cfg.grad_accum_steps > 1: - main = main / cfg.grad_accum_steps - aux = aux / cfg.grad_accum_steps - - # === 反传 === - if self._micro_step == 0: - self.optimizer.zero_grad(set_to_none=True) - - all_params = [p for p in self.model.parameters() if p.requires_grad] - if self.scaler is not None: - # fp16 路径:GradScaler 不直接支持 PCGrad(需手动调度);这里 - # 退化为标准 sum-backward。bf16 推荐路径无此限制。 - total = main.sum() + aux - self.scaler.scale(total).backward() - weights = torch.ones_like(main) - else: - total, weights = self.mto.backward(main, aux, all_params) - - self._micro_step += 1 - do_step = self._micro_step >= cfg.grad_accum_steps - if not do_step: - info_partial = { - "step": self.global_step, - "stage": self._stage, - "total_loss": float(total), - "weights": [float(w) for w in weights], - "grad_norms": {}, - } - for k, v in losses.items(): - info_partial[k] = float(v.detach()) - return info_partial - - # === 梯度裁剪 + 监控 + step === - if self.scaler is not None: - self.scaler.unscale_(self.optimizer) - grad_summary = grad_norm_per_module(self.model, cfg.grad_monitor_threshold) - torch.nn.utils.clip_grad_norm_(all_params, max_norm=cfg.grad_clip) - - if self.scaler is not None: - self.scaler.step(self.optimizer) - self.scaler.update() - else: - self.optimizer.step() - self.scheduler.step() - self._micro_step = 0 - - # 路由温度线性退火 - if self._stage == 1: - ratio = min(1.0, self.global_step / max(1, cfg.stage1_steps)) - t = cfg.router_temp_init + ratio * (cfg.router_temp_final - cfg.router_temp_init) - self.model.backbone.set_router_temperature(t) - - self.global_step += 1 - - info = { - "step": self.global_step, - "stage": self._stage, - "total_loss": float(total), - "weights": [float(w) for w in weights], - "grad_norms": grad_summary, - } - for k, v in losses.items(): - info[k] = float(v.detach()) - return info - - def fit(self, loader: DataLoader, max_steps: int | None = None) -> None: - """简化训练循环。""" - rng = np.random.default_rng(0) - steps = max_steps or self.cfg.total_steps - it = iter(loader) - for _ in range(steps): - try: - batch = next(it) - except StopIteration: - it = iter(loader) - batch = next(it) - info = self.train_step(batch, rng) - step = info["step"] - if self.cfg.ckpt_interval > 0 and step % self.cfg.ckpt_interval == 0: - ckpt_path = Path(self.cfg.ckpt_dir) / f"step_{step:08d}.pt" - self.save_checkpoint(ckpt_path) - self.push_checkpoint_to_hub(ckpt_path) - if step % self.cfg.log_interval == 0: - log.info( - "step=%d stage=%d total=%.4f cls=%.4f box=%.4f isdyn=%.4f traj_obj=%.4f traj_ego=%.4f ctrl=%.4f moe=%.4f calib=%.4f", - info["step"], info["stage"], info["total_loss"], - info["L_cls"], info["L_box"], info["L_isdyn"], - info["L_traj_obj"], info["L_traj_ego"], info["L_ctrl"], - info["L_moe"], info["L_calib"], - ) +"""两阶段训练器 + 梯度监控。 + +Stage 1 (Dense): + - MoE 全部专家加权(dense 模式); + - 路由温度初始 < 1(锐化),训练中线性升到 1; + - DINOv3 冻结; + - 中期开启运动学/内外参扰动,监督校准网络; + - GradNorm 启用。 + +Stage 2 (Sparse): + - MoE 切 Top-3; + - 路由温度退火完成; + - DINOv3 解冻并采用 1/100 主干 LR; + - GradNorm + PCGrad 同时启用。 +""" + +from __future__ import annotations + +import logging +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from ..losses import ( + HungarianMatcher, + action_nll, + calibration_regularization, + detection_losses, + ego_traj_nll, + moe_load_balance_and_boundary, + object_traj_nll, +) +from ..model import E2EAVModel, E2EOutput +from .multitask import MultiTaskOptimizer, MultiTaskOptimizerConfig +from .schedule import build_scheduler + +log = logging.getLogger(__name__) + + +class _NullContext: + """空 context manager,用于 AMP 关闭时占位 autocast。""" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +@dataclass +class TrainerConfig: + """Trainer 超参数(与 ``configs/default.yaml`` 对齐)。""" + + total_steps: int = 100000 + warmup_steps: int = 1000 + base_lr: float = 2.0e-4 + min_lr: float = 1.0e-6 + weight_decay: float = 0.05 + grad_clip: float = 1.0 + log_interval: int = 20 + ckpt_interval: int = 1000 + stage1_steps: int = 60000 + stage1_perturb_start: int = 20000 + grad_monitor_threshold: float = 1e-7 + # === AMP / 混合精度 === + # "fp32" / "bf16" / "fp16"。默认 bf16(H100/A100 推荐,无需 GradScaler)。 + mixed_precision: str = "bf16" + grad_accum_steps: int = 1 + # MoE + moe_load_balance_weight: float = 0.01 + moe_boundary_weight: float = 0.001 + router_temp_init: float = 0.5 + router_temp_final: float = 1.0 + # 损失初始权重(GradNorm 自适应主任务 1-6) + loss_giou_weight: float = 0.5 + loss_calib_weight: float = 0.1 + # MultiTask(GradNorm + PCGrad 在 Stage1/Stage2 全程启用—— + # 两阶段的 6 项主任务都存在尺度不均与梯度冲突,PCGrad 不应延迟到 Stage2) + enable_gradnorm: bool = True + enable_pcgrad: bool = True + # 参数组 + dinov3_lr_mult_stage2: float = 0.01 + # 显存吃紧的设备(如 a10g-small)上可关闭 Stage2 DINOv3 解冻,保持冻结 + unfreeze_dinov3_at_stage2: bool = True + backbone_lr_mult: float = 1.0 + calibration_lr_mult: float = 0.1 + head_lr_mult: float = 1.0 + gate_lr_mult: float = 0.1 + # checkpoint(本地 + 可选 Hub) + ckpt_dir: str = "outputs/checkpoints" + hub_repo_id: str | None = None + hub_push_checkpoints: bool = False + hub_ckpt_prefix: str = "checkpoints" + + +def _is_gate_param(name: str) -> bool: + return ".gate." in name or name.endswith(".gate_proj.weight") or name.endswith(".gate_proj.bias") + + +def build_param_groups(model: E2EAVModel, base_lr: float, cfg: TrainerConfig, stage: int) -> list[dict]: + """按模块归类参数为不同 LR 组。Stage1 时 DINOv3 lr=0。""" + groups: dict[str, list[nn.Parameter]] = { + "dinov3": [], + "backbone": [], + "calibration": [], + "head": [], + "gate": [], + "other": [], + } + for name, p in model.named_parameters(): + if not p.requires_grad and stage == 1: + continue + if name.startswith("dinov3."): + groups["dinov3"].append(p) + elif name.startswith("backbone."): + if _is_gate_param(name): + groups["gate"].append(p) + else: + groups["backbone"].append(p) + elif name.startswith("calib."): + if _is_gate_param(name): + groups["gate"].append(p) + else: + groups["calibration"].append(p) + elif name.startswith("det_traj_head.") or name.startswith("ctrl_head."): + groups["head"].append(p) + else: + groups["other"].append(p) + + dinov3_lr = base_lr * (cfg.dinov3_lr_mult_stage2 if stage == 2 else 0.0) + return [ + {"params": groups["dinov3"], "lr": dinov3_lr, "name": "dinov3"}, + {"params": groups["backbone"], "lr": base_lr * cfg.backbone_lr_mult, "name": "backbone"}, + {"params": groups["calibration"], "lr": base_lr * cfg.calibration_lr_mult, "name": "calibration"}, + {"params": groups["head"], "lr": base_lr * cfg.head_lr_mult, "name": "head"}, + {"params": groups["gate"], "lr": base_lr * cfg.gate_lr_mult, "name": "gate"}, + {"params": groups["other"], "lr": base_lr, "name": "other"}, + ] + + +def grad_norm_per_module(model: nn.Module, threshold: float) -> dict[str, float]: + """统计各顶层模块的 grad-norm,返回 dict(用于日志/告警)。 + + 跳过: + - 没有任何 ``requires_grad=True`` 参数的模块(如冻结的 DINOv3、纯 + buffer 模块 RoPE); + - 空模块(参数计数为 0)。 + """ + summary: dict[str, float] = {} + for name, child in model.named_children(): + params = list(child.parameters()) + if not params: + continue + if not any(p.requires_grad for p in params): + # 整个模块被冻结 -> 不监控 + continue + total = 0.0 + seen = 0 + for p in params: + if p.grad is not None: + total += float(p.grad.detach().norm().item()) ** 2 + seen += 1 + if seen == 0: + continue + n = math.sqrt(total) + summary[name] = n + if n < threshold: + log.warning("[grad_monitor] %s grad_norm=%.3e < %.3e", name, n, threshold) + if not math.isfinite(n): + log.error("[grad_monitor] %s grad_norm is %s (NaN/Inf)", name, n) + return summary + + +def compute_all_losses( + model_out: E2EOutput, + batch: dict, + matcher: HungarianMatcher, + num_classes: int, + cfg: TrainerConfig, + perturbation_residual: torch.Tensor | None = None, +) -> dict[str, torch.Tensor]: + """计算 8 项损失,返回字典。 + + ``perturbation_residual``:扰动训练时给定的 ground-truth 残差,用于额外 + 监督校准网络;正常训练为 None。 + """ + targets = batch["targets"] + + det_out = model_out.detection + ctrl_out = model_out.control + calib = model_out.calibration + + det_losses = detection_losses( + cls_logits=det_out.cls_logits, + box_mu=det_out.box3d_mu, + box_log_sigma=det_out.box3d_log_sigma, + isdyn_logit=det_out.is_dynamic_logit, + targets=targets, + matcher=matcher, + num_classes=num_classes, + ) + + L_traj_obj = object_traj_nll( + det_out.traj_mu, + det_out.traj_log_sigma, + det_losses.matched_indices, + targets, + ) + + L_traj_ego = ego_traj_nll( + ctrl_out.ego_traj_mu, + ctrl_out.ego_traj_log_sigma, + batch["ego_future"], + valid=batch.get("ego_future_valid"), + ) + + # 全局动作 GT 通常没有;此处用 0 做占位(实际数据集需补齐) + action_target = batch.get("action_target") + if action_target is None: + action_target = torch.zeros_like(ctrl_out.action_mu) + L_ctrl = action_nll( + ctrl_out.action_mu, ctrl_out.action_log_sigma, action_target + ) + L_traj_ego # 控制损失 = action + ego_traj 复用同一项的便利封装;trainer 视情况拆分 + + # MoE / 校准 正则 + L_moe = moe_load_balance_and_boundary( + model_out.backbone_out.moe_stats, + load_balance_weight=cfg.moe_load_balance_weight, + boundary_weight=cfg.moe_boundary_weight, + ) + L_calib_reg = calibration_regularization( + calib.ego_residual, calib.intr_residual, calib.extr_residual, + l2_weight=1.0, + ) + if perturbation_residual is not None: + # 扰动训练:计算校准网络应该预测的 GT 残差与实际残差的 MSE + actual = torch.cat( + [calib.ego_residual.flatten(1), calib.intr_residual, calib.extr_residual], + dim=-1, + ) + L_calib_reg = L_calib_reg + 1.0 * (actual - perturbation_residual).pow(2).mean() + + return { + "L_cls": det_losses.cls_loss, + "L_box": det_losses.box_nll + cfg.loss_giou_weight * det_losses.giou_loss, + "L_isdyn": det_losses.isdyn_loss, + "L_traj_obj": L_traj_obj, + "L_traj_ego": L_traj_ego, + "L_ctrl": L_ctrl, + "L_moe": L_moe, + "L_calib": L_calib_reg, + } + + +MAIN_TASK_KEYS = ["L_cls", "L_box", "L_isdyn", "L_traj_obj", "L_traj_ego", "L_ctrl"] +AUX_TASK_KEYS = ["L_moe", "L_calib"] + + +class Trainer: + """端到端训练器。""" + + def __init__( + self, + model: E2EAVModel, + cfg: TrainerConfig, + num_classes: int = 22, + device: str = "cuda", + ) -> None: + self.model = model.to(device) + self.cfg = cfg + self.num_classes = num_classes + self.device = device + self.matcher = HungarianMatcher() + self.global_step = 0 + self._micro_step = 0 # 用于 grad_accum + self._stage = 1 + self._build_optimizer() + + # === AMP 配置 === + # 仅在 device 为 cuda 时启用 autocast(CPU 上 bf16 也能跑但收益极小)。 + amp_dtype_map = { + "fp32": None, + "bf16": torch.bfloat16, + "fp16": torch.float16, + } + self.amp_dtype = amp_dtype_map[cfg.mixed_precision] + self.amp_enabled = self.amp_dtype is not None and "cuda" in str(device) + # GradScaler 仅 fp16 需要;bf16 数值范围大无需 scaler + self.scaler = ( + torch.amp.GradScaler("cuda") + if (self.amp_enabled and self.amp_dtype == torch.float16) + else None + ) + + # MoE 初始模式 = dense;Stage2 切 sparse + self.model.backbone.set_moe_mode("dense") + self.model.backbone.set_router_temperature(cfg.router_temp_init) + + # ---------- 优化器构建 ---------- + + def _build_optimizer(self) -> None: + cfg = self.cfg + groups = build_param_groups(self.model, cfg.base_lr, cfg, stage=self._stage) + self.optimizer = torch.optim.AdamW(groups, weight_decay=cfg.weight_decay, betas=(0.9, 0.95)) + self.scheduler = build_scheduler( + self.optimizer, + warmup_steps=cfg.warmup_steps, + total_steps=cfg.total_steps, + base_lr=cfg.base_lr, + min_lr=cfg.min_lr, + ) + + # PCGrad 共享参数 = 主干最后的“共享瓶颈”:final_norm + 最后 1 层 MoE block。 + # 不把全部 DINOv3/Calib/Backbone 都纳入,否则 N 个任务 × full-grad 扁平副本会 + # 在 a10g-small 上瞬间 OOM(~600M 参数 × 6 任务 × 2 副本 ≈ 28 GB)。 + # 较前的层仍享受 GradNorm 自适应加权 + 共同求和的标准多任务训练。 + shared: list[nn.Parameter] = [] + last_moe = self.model.backbone.moe_layers[-1] + for p in self.model.backbone.final_norm.parameters(): + if p.requires_grad: + shared.append(p) + for p in last_moe.parameters(): + if p.requires_grad: + shared.append(p) + # GradNorm 代理参数:取主干最后 LayerNorm 的 weight + proxy = self.model.backbone.final_norm.weight + mt_cfg = MultiTaskOptimizerConfig( + enable_gradnorm=cfg.enable_gradnorm, + enable_pcgrad=cfg.enable_pcgrad, + gradnorm_alpha=1.5, + gradnorm_lr=0.025, + pcgrad_shuffle=True, + ) + self.mto = MultiTaskOptimizer( + num_main_tasks=len(MAIN_TASK_KEYS), + shared_params=shared, + gradnorm_proxy_param=proxy, + cfg=mt_cfg, + ) + # GradNormBalancer 是 nn.Module,需要把 raw_weights / initial_losses 缓冲 + # 移到 model 所在 device,否则与 losses (cuda) 设备不匹配。 + if self.mto.gradnorm is not None: + self.mto.gradnorm.to(self.device) + + def _checkpoint_state(self) -> dict: + state: dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict(), + "global_step": self.global_step, + "stage": self._stage, + } + if self.scaler is not None: + state["scaler"] = self.scaler.state_dict() + return state + + def save_checkpoint(self, ckpt_path: Path) -> None: + ckpt_path = Path(ckpt_path) + ckpt_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(self._checkpoint_state(), ckpt_path) + + def push_checkpoint_to_hub(self, ckpt_path: Path) -> None: + """将 ``ckpt_path`` 同步上传到 ``hub_repo_id``(需环境变量 ``HF_TOKEN``)。""" + repo = self.cfg.hub_repo_id + if not repo or not self.cfg.hub_push_checkpoints: + return + try: + from huggingface_hub import HfApi, create_repo + except ImportError as e: + log.warning("[Trainer] huggingface_hub 未安装,跳过 Hub 上传: %s", e) + return + prefix = self.cfg.hub_ckpt_prefix.strip().strip("/") + step = self.global_step + rel_step = f"{prefix}/step_{step:08d}.pt" + rel_latest = f"{prefix}/latest.pt" + api = HfApi(token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) + try: + create_repo(repo, repo_type="model", exist_ok=True) + api.upload_file( + path_or_fileobj=str(ckpt_path), + path_in_repo=rel_step, + repo_id=repo, + repo_type="model", + commit_message=f"checkpoint step {step}", + ) + api.upload_file( + path_or_fileobj=str(ckpt_path), + path_in_repo=rel_latest, + repo_id=repo, + repo_type="model", + commit_message=f"latest ckpt step {step}", + ) + log.info("[Trainer] Hub 上传完成 hf.co/%s (%s)", repo, rel_latest) + except Exception: + log.exception("[Trainer] Hub 上传失败(可检查 HF_TOKEN / 配额)") + + # ---------- 阶段切换 ---------- + + def maybe_switch_stage(self) -> None: + cfg = self.cfg + if self._stage == 1 and self.global_step >= cfg.stage1_steps: + log.info("[Trainer] -> Stage 2 (sparse MoE + DINOv3 finetune + PCGrad)") + self._stage = 2 + # 1) MoE 切 sparse + self.model.backbone.set_moe_mode("sparse") + # 2) 路由温度退火完成 + self.model.backbone.set_router_temperature(cfg.router_temp_final) + # 3) DINOv3 解冻(小显存设备可禁用) + if cfg.unfreeze_dinov3_at_stage2: + self.model.dinov3.unfreeze() + # 4) 重建优化器(包含 DINOv3 参数)+ 启用 PCGrad + self._build_optimizer() + + # ---------- 单步 ---------- + + def train_step(self, batch: dict, rng: np.random.Generator) -> dict: + cfg = self.cfg + self.maybe_switch_stage() + + # 移到 device + batch = {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + # targets 是 list of dict,里面的 tensor 也移到 device + if "targets" in batch and isinstance(batch["targets"], list): + new_targets = [] + for t in batch["targets"]: + new_targets.append({k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()}) + batch["targets"] = new_targets + + # 扰动注入(Stage1 中期开启) + perturb_residual = None + ego_input = batch["ego_6d"] + intr_input = batch["intr_vec"] + extr_input = batch["extr_6d"] + if ( + self._stage == 1 + and self.global_step >= cfg.stage1_perturb_start + and rng.uniform() < 0.5 + ): + from ..data.transforms import perturb_kinematics + ego_input, intr_input, extr_input, delta = perturb_kinematics( + ego_input.cpu().clone(), intr_input.cpu().clone()[0], extr_input.cpu().clone()[0], + translation_std_m=0.1, rotation_std_deg=0.5, + intrinsic_std=0.005, extrinsic_std=0.005, + rng=rng, + ) + ego_input = ego_input.to(self.device) + # intr/extr 是 [B,...] 而 perturb_kinematics 是单样本;这里为简洁仅扰动第 0 个样本 + # 实际生产中应 batched 实现 + intr_input = batch["intr_vec"].clone() + extr_input = batch["extr_6d"].clone() + # GT 残差:在 symlog 空间 = -delta(symlog 是非线性,这里用线性近似) + perturb_residual = -delta.to(self.device).unsqueeze(0).expand(ego_input.shape[0], -1) + + # 前向(AMP autocast 仅包住 forward 与匹配/损失,反传由 PyTorch + # 在 fp32 主梯度下完成;GradNorm/PCGrad 内的 autograd.grad 也在 fp32) + ac_ctx = ( + torch.autocast(device_type="cuda", dtype=self.amp_dtype) + if self.amp_enabled + else _NullContext() + ) + with ac_ctx: + out = self.model( + images=batch["images"], + ego_6d_raw=ego_input, + intr_raw=intr_input, + extr_6d_raw=extr_input, + ) + losses = compute_all_losses( + out, batch, self.matcher, self.num_classes, cfg, + perturbation_residual=perturb_residual, + ) + + # === 把损失提升到 fp32 以保证后续 GradNorm/PCGrad 数值稳定 === + main = torch.stack([losses[k].float() for k in MAIN_TASK_KEYS]) + aux = sum(losses[k].float() for k in AUX_TASK_KEYS) + # 梯度累积:对累积步数取平均 + if cfg.grad_accum_steps > 1: + main = main / cfg.grad_accum_steps + aux = aux / cfg.grad_accum_steps + + # === 反传 === + if self._micro_step == 0: + self.optimizer.zero_grad(set_to_none=True) + + all_params = [p for p in self.model.parameters() if p.requires_grad] + if self.scaler is not None: + # fp16 路径:GradScaler 不直接支持 PCGrad(需手动调度);这里 + # 退化为标准 sum-backward。bf16 推荐路径无此限制。 + total = main.sum() + aux + self.scaler.scale(total).backward() + weights = torch.ones_like(main) + else: + total, weights = self.mto.backward(main, aux, all_params) + + self._micro_step += 1 + do_step = self._micro_step >= cfg.grad_accum_steps + if not do_step: + info_partial = { + "step": self.global_step, + "stage": self._stage, + "total_loss": float(total), + "weights": [float(w) for w in weights], + "grad_norms": {}, + } + for k, v in losses.items(): + info_partial[k] = float(v.detach()) + return info_partial + + # === 梯度裁剪 + 监控 + step === + if self.scaler is not None: + self.scaler.unscale_(self.optimizer) + grad_summary = grad_norm_per_module(self.model, cfg.grad_monitor_threshold) + torch.nn.utils.clip_grad_norm_(all_params, max_norm=cfg.grad_clip) + + if self.scaler is not None: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + self.scheduler.step() + self._micro_step = 0 + + # 路由温度线性退火 + if self._stage == 1: + ratio = min(1.0, self.global_step / max(1, cfg.stage1_steps)) + t = cfg.router_temp_init + ratio * (cfg.router_temp_final - cfg.router_temp_init) + self.model.backbone.set_router_temperature(t) + + self.global_step += 1 + + info = { + "step": self.global_step, + "stage": self._stage, + "total_loss": float(total), + "weights": [float(w) for w in weights], + "grad_norms": grad_summary, + } + for k, v in losses.items(): + info[k] = float(v.detach()) + return info + + def fit(self, loader: DataLoader, max_steps: int | None = None) -> None: + """简化训练循环。""" + rng = np.random.default_rng(0) + steps = max_steps or self.cfg.total_steps + it = iter(loader) + for _ in range(steps): + try: + batch = next(it) + except StopIteration: + it = iter(loader) + batch = next(it) + info = self.train_step(batch, rng) + step = info["step"] + if self.cfg.ckpt_interval > 0 and step % self.cfg.ckpt_interval == 0: + ckpt_path = Path(self.cfg.ckpt_dir) / f"step_{step:08d}.pt" + self.save_checkpoint(ckpt_path) + self.push_checkpoint_to_hub(ckpt_path) + if step % self.cfg.log_interval == 0: + log.info( + "step=%d stage=%d total=%.4f cls=%.4f box=%.4f isdyn=%.4f traj_obj=%.4f traj_ego=%.4f ctrl=%.4f moe=%.4f calib=%.4f", + info["step"], info["stage"], info["total_loss"], + info["L_cls"], info["L_box"], info["L_isdyn"], + info["L_traj_obj"], info["L_traj_ego"], info["L_ctrl"], + info["L_moe"], info["L_calib"], + )