"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。
输出
- 各模块参数数量
- 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
- 推荐设备(HF Sandbox / Jobs)
- 主机内存与磁盘开销
公式说明(粗略上界)
- 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
- AdamW 一阶/二阶矩 (fp32): 8 B/p
- 梯度 (fp32): 4 B/p
- bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
- DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
- 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
- PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
(GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
"""
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
from dataclasses import dataclass
from wjad.model import E2EAVModel
@dataclass
class MemoryReport:
    """Memory-estimate summary for one training configuration (sizes in GB)."""

    bs: int  # batch size used for the estimate
    seq_len: int  # total token count after concatenation (visual + ego/det/ctrl/extra)
    dim: int  # transformer hidden dimension
    layers: int  # number of backbone transformer layers
    params_total: int  # total parameter count, including frozen DINOv3
    params_trainable_stage1: int  # trainable params in Stage 1 (DINOv3 frozen)
    params_trainable_stage2: int  # trainable params in Stage 2 (all params)
    weights_gb_stage1: float  # bf16 weight memory, Stage 1
    weights_gb_stage2: float  # bf16 weight memory, Stage 2
    optim_gb_stage1: float  # optimizer-state memory, Stage 1
    optim_gb_stage2: float  # optimizer-state memory, Stage 2
    activations_gb: float  # estimated peak activation memory
    pcgrad_overhead_gb: float  # extra activations held alive by PCGrad retain_graph
    total_stage1_gb: float  # total GPU memory estimate, Stage 1
    total_stage2_gb: float  # total GPU memory estimate, Stage 2 (peak)
    host_ram_gb: float  # recommended minimum host RAM
    disk_gb: float  # typical disk footprint for the dataset subset used
def count_params(model) -> tuple[int, dict[str, int]]:
    """Return ``(total, per_child)`` parameter counts for *model*.

    Parameters are grouped by the model's immediate children
    (``named_children``), so every nested submodule's parameters are
    attributed to its top-level parent module.
    """
    per_child = {
        child_name: sum(param.numel() for param in child.parameters())
        for child_name, child in model.named_children()
    }
    return sum(per_child.values()), per_child
def estimate(bs: int = 8) -> MemoryReport:
    """Estimate training memory for batch size *bs*.

    Builds the full-scale model once to count parameters, then applies the
    bytes-per-parameter formulas from the module docstring (bf16 weights,
    fp32 master copy, AdamW moments, fp32 gradients) plus rough activation
    and PCGrad overheads.
    """
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        # Full-scale configuration.
        backbone_dim=768,
        num_heads=12,
        num_dense_layers=9,
        num_moe_layers=9,
        num_routed_experts=7,
        num_shared_experts=1,
        topk_experts=3,
        ffn_mult=4,
        num_history_frames=8,
        num_detection_tokens=1024,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=256,
        image_h=384,
        image_w=1024,
        patch_size=16,
        num_classes=22,
        traj_horizon=24,
        freeze_dinov3=True,
    )
    total, by_module = count_params(model)
    dinov3_n = by_module.get("dinov3", 0)
    trainable_stage1 = total - dinov3_n  # Stage 1 freezes the DINOv3 backbone
    trainable_stage2 = total
    # Sequence length: pooled visual tokens + ego / detection / control / extra tokens.
    n_visual = (8 // 2) * (24 // 2) * (64 // 2)
    seq_len = n_visual + 8 + 1024 + 24 + 256
    # === GPU memory ===
    # All figures in GB (divide by 1024**3).
    GB = 1024 ** 3
    weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB  # all weights bf16
    weights_stage2 = (total * 2) / GB
    # Per trainable param: fp32 master (4) + Adam m (4) + Adam v (4) + fp32 grad (4)
    # = 16 B/p. Together with the 2 B/p bf16 weight this matches the 18 B/p
    # upper bound stated in the module docstring (the previous 12 B/p figure
    # omitted the fp32 gradient).
    optim_stage1 = (trainable_stage1 * (4 + 4 + 4 + 4)) / GB
    optim_stage2 = (trainable_stage2 * (4 + 4 + 4 + 4)) / GB
    # Activations: roughly bs * seq_len * dim * 2 B * layers * 1.5 (attn/FFN overlap).
    base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5  # 18 backbone + 6 calibration layers
    # MoE FFN intermediates (4D = 3072): per MoE layer ~ bs * seq_len * 3072 * 2 B * 8 experts.
    moe_act = bs * seq_len * 3072 * 2 * 8 * 9
    # Frozen DINOv3 runs under no_grad; its activations are released right
    # after forward — budget a flat 2 GB peak.
    dino_act = 2.0 * GB
    activations_gb = (base_act + moe_act + dino_act) / GB
    # PCGrad (N autograd.grad calls over shared params) needs retain_graph,
    # which keeps activations alive; worst case ~1.5x, estimated here as +0.5x.
    pcgrad_overhead_gb = 0.5 * activations_gb
    # The trailing +2.0 GB is a flat margin for CUDA context / allocator fragmentation.
    total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0
    total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0
    # === Host RAM ===
    # DataLoader prefetch + workers + CPU model copy + JSON / LiDAR parsing.
    host_ram = 8.0 + bs * 0.3 * 4 * 2  # 4 workers, prefetch factor 2
    # === Disk ===
    # Full dataset ~3 TB; sandbox runs ~5 GB (a few clips); typical ~50 GB
    # (all of one weather split).
    disk = 50.0
    return MemoryReport(
        bs=bs,
        seq_len=seq_len,
        dim=768,
        layers=18,
        params_total=total,
        params_trainable_stage1=trainable_stage1,
        params_trainable_stage2=trainable_stage2,
        weights_gb_stage1=weights_stage1,
        weights_gb_stage2=weights_stage2,
        optim_gb_stage1=optim_stage1,
        optim_gb_stage2=optim_stage2,
        activations_gb=activations_gb,
        pcgrad_overhead_gb=pcgrad_overhead_gb,
        total_stage1_gb=total_stage1,
        total_stage2_gb=total_stage2,
        host_ram_gb=host_ram,
        disk_gb=disk,
    )
def recommend_device(stage_max_gb: float) -> tuple[str, str]:
    """Pick the smallest GPU whose VRAM covers the Stage-2 peak plus margin.

    Returns ``(gpu_name, note)``; falls back to a multi-GPU suggestion when
    no single card fits.
    """
    # 15% headroom for fragmentation, the CUDA caching allocator and
    # cuBLAS workspaces.
    required = stage_max_gb * 1.15
    note = f"需要 ≥{required:.1f} GB"
    gpu_pool = (
        ("T4 16GB", 16),
        ("L4 24GB", 24),
        ("A10G 24GB", 24),
        ("A10G Large 48GB", 48),
        ("A100 40GB", 40),
        ("L40S 48GB", 48),
        ("A100 80GB", 80),
        ("H100 80GB", 80),
    )
    for gpu_name, vram_gb in gpu_pool:
        if vram_gb >= required:
            return gpu_name, note
    return "H200 / 多卡 80GB+", f"{note}(单卡极限)"
def main() -> None:
    """Print the memory-estimate report for a sweep of batch sizes."""
    rule = "=" * 72
    print(rule)
    print(" WJAD 训练显存/内存估算 (bf16 AMP)")
    print(rule)
    for batch_size in (1, 2, 4, 8, 16):
        report = estimate(batch_size)
        print(f"\n--- BS = {batch_size} ---")
        print(f" 总参数 : {report.params_total / 1e6:8.2f} M")
        print(f" 可训练 (S1) : {report.params_trainable_stage1 / 1e6:8.2f} M")
        print(f" 可训练 (S2) : {report.params_trainable_stage2 / 1e6:8.2f} M")
        print(f" 序列长度 : {report.seq_len}")
        print(f" 权重 (S1/S2) : {report.weights_gb_stage1:6.2f} / {report.weights_gb_stage2:6.2f} GB")
        print(f" 优化器 (S1/S2): {report.optim_gb_stage1:6.2f} / {report.optim_gb_stage2:6.2f} GB")
        print(f" 激活 : {report.activations_gb:6.2f} GB")
        print(f" PCGrad 余量 : {report.pcgrad_overhead_gb:6.2f} GB")
        print(f" 显存合计 S1 : {report.total_stage1_gb:6.2f} GB")
        print(f" 显存合计 S2 : {report.total_stage2_gb:6.2f} GB <- 峰值")
        gpu_name, note = recommend_device(report.total_stage2_gb)
        print(f" 推荐 GPU : {gpu_name} ({note})")
        print(f" 主机 RAM : ≥ {report.host_ram_gb:6.2f} GB")
        print(f" 磁盘 (典型) : ≈ {report.disk_gb:6.0f} GB")
    print()
    print("说明:")
    print(" - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
    print(" - 开 ``gradient_checkpointing`` 可把激活降至约 1/3,BS 可成倍提升。")
    print(" - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")


if __name__ == "__main__":
    main()