| # 显存与内存估算(bf16 AMP + GradNorm + PCGrad) |
|
|
| 由 `scripts/estimate_memory.py` 生成。模型规模:18 层主干(9 Dense + 9 MoE,每 MoE 层 7 路由 + 1 共享专家)+ DINOv3 ViT-B/16 + 6 层校准 + 1024 检测 token + 24 控制 token。 |
|
|
| | 项目 | 数值 | |
| |---|---| |
| | 总参数 | **725.62 M** | |
| | 可训练 (Stage1, DINOv3 冻结) | 639.96 M | |
| | 可训练 (Stage2, DINOv3 解冻) | 725.62 M | |
| | 序列长度(拼接后) | 2848 | |
|
|
| ## 显存(含 15% 余量) |
|
|
| | Batch Size | Stage2 峰值 | 推荐单卡 GPU | HF Sandbox 选项 | |
| |---:|---:|---|---| |
| | 1 | ~16 GB | T4 16GB(紧)/ L4 24GB | `t4-small` | |
| | 2 | ~18 GB | L4 24GB | `l4x1` | |
| | 4 | ~22 GB | L4 24GB / A10G 24GB | `a10g-small` | |
| | **8 (目标)** | **~30 GB** | **A10G Large 48GB / A100 40GB** | **`a10g-large`** | |
| | 16 | ~46 GB | A100 80GB / H100 80GB | `a100-large` | |
|
|
| 显存细分(BS=8 Stage2): |
| - 权重 (bf16): 1.35 GB |
| - 优化器 (AdamW fp32 m+v + 主副本): 8.11 GB |
| - 主激活 (bf16, 18 + 6 层): ~12.6 GB |
| - PCGrad retain_graph 开销: ~6.3 GB |
| - 缓冲 / cuDNN workspace / 碎片: ~2 GB |
| |
| 如显存不足: |
| - 开 `gradient_checkpointing`(激活降至 ~1/3,可把 BS=8 塞进 A10G 24GB 大约 28GB) |
| - BS=4 + `grad_accum_steps=2` 等价 BS=8 训练 |
| - 关 PCGrad(节省 ~6 GB),但牺牲多任务收敛质量 |
|
|
| ## 主机内存 / 磁盘 |
|
|
| | 项目 | 数值(BS=8) | |
| |---|---| |
| | 主机 RAM 推荐 | ≥ 32 GB(DataLoader 4 workers × prefetch 2 + 模型 CPU 副本) | |
| | 磁盘(一个 weather 子集,sandbox 验证) | ~50 GB | |
| | 磁盘(synthetic 全量 121 帧 × 7 weather × 5843 clip) | ~700 GB | |
| | 磁盘(synthetic + lidar + hdmap 全量) | ~3 TB | |
|
|
| ## 设备选择建议 |
|
|
| - **本地烟囱(CPU/小卡)**:`scripts/smoke_test.py` 用极小张量验证 forward+backward,不需要 GPU。 |
| - **HF Sandbox**:`a10g-large` (48 GB),BS=8 + bf16 + PCGrad 一次成功;约 $1.05/小时(HF 价格随时调整请以官方为准)。 |
| - **HF Jobs 全量训练**:`a100x1` (80 GB) 或 `h100x1`,BS=8~16。 |
|
|
| ## 复现命令 |
|
|
| ```bash |
| # 升级依赖到最新(写入 requirements.lock.txt) |
| python scripts/update_deps.py --torch-index https://download.pytorch.org/whl/cu124 |
| |
| # 估算 |
| python scripts/estimate_memory.py |
| |
| # Sandbox 推送 |
| python scripts/push_to_sandbox.py --repo your-username/wjad-sandbox --gpu a10g-large |
| |
| # Jobs 全量 |
| python scripts/push_to_jobs.py --repo your-username/wjad --flavor a100x1 |
| ``` |
|
|