"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
输出
- 各模块参数数量
- 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
- 推荐设备(HF Sandbox / Jobs)
- 主机内存与磁盘开销
公式说明(粗略上界)
- 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
- AdamW 一阶/二阶矩 (fp32): 8 B/p
- 梯度 (fp32): 4 B/p
- bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
- DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
- 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
- PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
(GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
"""
from __future__ import annotations

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

from dataclasses import dataclass

from wjad.model import E2EAVModel


@dataclass
class MemoryReport:
    bs: int
    seq_len: int
    dim: int
    layers: int
    params_total: int
    params_trainable_stage1: int
    params_trainable_stage2: int
    weights_gb_stage1: float
    weights_gb_stage2: float
    optim_gb_stage1: float
    optim_gb_stage2: float
    activations_gb: float
    pcgrad_overhead_gb: float
    total_stage1_gb: float
    total_stage2_gb: float
    host_ram_gb: float
    disk_gb: float


def count_params(model) -> tuple[int, dict[str, int]]:
    """Count parameters in total and per top-level child module."""
    total = 0
    by_module: dict[str, int] = {}
    for name, child in model.named_children():
        n = sum(p.numel() for p in child.parameters())
        by_module[name] = n
        total += n
    return total, by_module


def estimate(bs: int = 8) -> MemoryReport:
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        # full-scale configuration
        backbone_dim=768,
        num_heads=12,
        num_dense_layers=9,
        num_moe_layers=9,
        num_routed_experts=7,
        num_shared_experts=1,
        topk_experts=3,
        ffn_mult=4,
        num_history_frames=8,
        num_detection_tokens=1024,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=256,
        image_h=384,
        image_w=1024,
        patch_size=16,
        num_classes=22,
        traj_horizon=24,
        freeze_dinov3=True,
    )
    total, by_module = count_params(model)
    dinov3_n = by_module.get("dinov3", 0)
    trainable_stage1 = total - dinov3_n
    trainable_stage2 = total

    # Sequence length (all concatenated tokens plus context): 8 history frames
    # and a 24 x 64 patch grid (384/16 x 1024/16), each axis halved by the
    # // 2 downsampling factors, plus ego / detection / control / extra tokens.
    n_visual = (8 // 2) * (24 // 2) * (64 // 2)
    seq_len = n_visual + 8 + 1024 + 24 + 256
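    # Worked numbers for the default configuration:
    #   n_visual = 4 * 12 * 32 = 1536
    #   seq_len  = 1536 + 8 + 1024 + 24 + 256 = 2848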

    # === VRAM ===
    # Unit: GB (divide by 1024**3).
    GB = 1024 ** 3
    weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB  # all bf16
    weights_stage2 = (total * 2) / GB
    # fp32 master copy + AdamW m + AdamW v + fp32 grads (4 B each), matching
    # the 18 B/p breakdown in the module docstring.
    optim_stage1 = (trainable_stage1 * (4 + 4 + 4 + 4)) / GB
    optim_stage2 = (trainable_stage2 * (4 + 4 + 4 + 4)) / GB
    # Activations, roughly: bs * seq_len * dim * 2 B * (18 backbone layers +
    # 6 calibration layers) * 1.5 for attention/FFN overlap.
    base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5
    # Transient MoE FFN intermediates (4D = 3072), bf16: per MoE layer
    # ~ bs * seq_len * 3072 * 2 B * 8 experts, over 9 MoE layers.
    moe_act = bs * seq_len * 3072 * 2 * 8 * 9
    # DINOv3 is frozen: under no_grad its forward activations are released
    # right after the forward pass; budget a 2 GB peak.
    dino_act = 2.0 * GB
    activations_gb = (base_act + moe_act + dino_act) / GB
    # PCGrad overhead (N autograd.grad calls over the shared parameters):
    # retain_graph keeps activations alive between calls, up to N-fold in the
    # worst case; budget +0.5x (1.5x total) here.
    pcgrad_overhead_gb = 0.5 * activations_gb
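    # For reference, the PCGrad-style pattern that the +0.5x budget accounts
    # for (a sketch only; ``losses`` and ``shared_params`` are hypothetical
    # names, not attributes of this estimator):
    #
    #   per_task = [torch.autograd.grad(loss, shared_params, retain_graph=True)
    #               for loss in losses]
    #
    # retain_graph keeps intermediate activations alive until the last call,
    # which is why per-task gradients can cost extra activation memory on top
    # of one combined backward pass.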
    # Plus a flat 2 GB for buffers (CUDA context, caching-allocator slack).
    total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
    total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0

    # === Host RAM ===
    # DataLoader prefetch + workers + CPU copy of the model + JSON / LiDAR
    # parsing: base 8 GB + bs * 0.3 GB/sample * 4 workers * prefetch factor 2.
    host_ram = 8.0 + bs * 0.3 * 4 * 2
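    # e.g. bs = 8: 8.0 + 8 * 0.3 * 4 * 2 = 27.2 GB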
    # === Disk ===
    # Full dataset ~3 TB; a sandbox run needs only ~5 GB (a few clips); a
    # typical run ~50 GB (all clips for one weather condition).
    disk = 50.0
    return MemoryReport(
        bs=bs,
        seq_len=seq_len,
        dim=768,
        layers=18,
        params_total=total,
        params_trainable_stage1=trainable_stage1,
        params_trainable_stage2=trainable_stage2,
        weights_gb_stage1=weights_stage1,
        weights_gb_stage2=weights_stage2,
        optim_gb_stage1=optim_stage1,
        optim_gb_stage2=optim_stage2,
        activations_gb=activations_gb,
        pcgrad_overhead_gb=pcgrad_overhead_gb,
        total_stage1_gb=total_stage1,
        total_stage2_gb=total_stage2,
        host_ram_gb=host_ram,
        disk_gb=disk,
    )


def recommend_device(stage_max_gb: float) -> tuple[str, str]:
    """Recommend a GPU for the given Stage-2 peak VRAM estimate."""
    margin = 1.15  # 15% headroom (fragmentation, CUDA caching allocator, cuBLAS workspaces)
    need = stage_max_gb * margin
    candidates = [
        ("T4 16GB", 16),
        ("L4 24GB", 24),
        ("A10G 24GB", 24),
        ("A10G Large 48GB", 48),
        ("A100 40GB", 40),
        ("L40S 48GB", 48),
        ("A100 80GB", 80),
        ("H100 80GB", 80),
    ]
    # Pick the smallest card that fits (the list above is not sorted by VRAM).
    fit = sorted((c for c in candidates if c[1] >= need), key=lambda c: c[1])
    if not fit:
        return "H200 / multi-GPU 80GB+", f"needs ≥{need:.1f} GB (beyond a single card)"
    return fit[0][0], f"needs ≥{need:.1f} GB"
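
# Worked example of the headroom rule above (illustrative numbers): a 30.0 GB
# Stage-2 peak needs 30.0 * 1.15 = 34.5 GB, and the smallest card that fits
# is "A100 40GB".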


def main() -> None:
    print("=" * 72)
    print(" WJAD training VRAM/RAM estimate (bf16 AMP)")
    print("=" * 72)
    for bs in (1, 2, 4, 8, 16):
        r = estimate(bs)
        print(f"\n--- BS = {bs} ---")
        print(f"  Total params      : {r.params_total / 1e6:8.2f} M")
        print(f"  Trainable (S1)    : {r.params_trainable_stage1 / 1e6:8.2f} M")
        print(f"  Trainable (S2)    : {r.params_trainable_stage2 / 1e6:8.2f} M")
        print(f"  Seq length        : {r.seq_len}")
        print(f"  Weights (S1/S2)   : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
        print(f"  Optim+grad (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
        print(f"  Activations       : {r.activations_gb:6.2f} GB")
        print(f"  PCGrad overhead   : {r.pcgrad_overhead_gb:6.2f} GB")
        print(f"  VRAM total S1     : {r.total_stage1_gb:6.2f} GB")
        print(f"  VRAM total S2     : {r.total_stage2_gb:6.2f} GB  <- peak")
        gpu, note = recommend_device(r.total_stage2_gb)
        print(f"  Recommended GPU   : {gpu} ({note})")
        print(f"  Host RAM          : ≥ {r.host_ram_gb:6.2f} GB")
        print(f"  Disk (typical)    : ≈ {r.disk_gb:6.0f} GB")
    print()
    print("Notes:")
    print("  - Includes bf16 AMP + AdamW (m, v in fp32) + fp32 master weights + fp32 grads + PCGrad overhead.")
    print("  - Enabling ``gradient_checkpointing`` cuts activations to roughly 1/3, so batch size can scale up accordingly.")
    print("  - Calibrate against real runs with ``nvidia-smi`` or ``torch.cuda.max_memory_allocated()``.")

if __name__ == "__main__":
    main()