"""估算 E2EAVModel 在 BS≥8 训练时的显存/内存需求。

输出
  - 各模块参数数量
  - 训练显存细分:参数 / 优化器 / 梯度 / 主激活 / 多任务梯度副本 / 缓冲
  - 推荐设备(HF Sandbox / Jobs)
  - 主机内存与磁盘开销

公式说明(粗略上界)
  - 参数 (bf16): 2 B/p;fp32 主副本: 4 B/p
  - AdamW 一阶/二阶矩 (fp32): 8 B/p
  - 梯度 (fp32): 4 B/p
  - bf16 训练总计:参数 2 + 主 4 + AdamW 8 + grad 4 = 18 B/可训练 p
  - DINOv3 冻结 Stage1:仅 2 B/p(前向激活按 no_grad 释放,可忽略)
  - 主激活:每层约 ``B * N * D * 2 B``(bf16),18 层;MoE 层另加 8 个专家
    SwiGLU 中间 ``B * N * 2 * 4D * 2 B`` 的临时项,但 Dense 加权求和后只
    需 1 份输出。实际显存按"激活 = 单层峰值 × 层数"近似。
  - PCGrad 在共享参数上 N 次 ``autograd.grad``:需要 retain_graph,
    每个任务额外保留中间激活的引用,最坏放大 N 倍。这里按 1.5x 估算
    (GPU autograd 内部 reuse + checkpointing 后通常远低于 N 倍)。
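  - Worked example (illustrative parameter count): 300 M trainable params
    × 18 B/p ≈ 5.0 GB of parameter-tied state, before any activations.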
"""

from __future__ import annotations

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

from dataclasses import dataclass

from wjad.model import E2EAVModel


@dataclass
class MemoryReport:
    bs: int
    seq_len: int
    dim: int
    layers: int
    params_total: int
    params_trainable_stage1: int
    params_trainable_stage2: int
    weights_gb_stage1: float
    weights_gb_stage2: float
    optim_gb_stage1: float
    optim_gb_stage2: float
    activations_gb: float
    pcgrad_overhead_gb: float
    total_stage1_gb: float
    total_stage2_gb: float
    host_ram_gb: float
    disk_gb: float


def count_params(model) -> tuple[int, dict[str, int]]:
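    """Parameter count per top-level child module (assumes every
    parameter of the model lives under a direct child of ``model``)."""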
    total = 0
    by_module: dict[str, int] = {}
    for name, child in model.named_children():
        n = sum(p.numel() for p in child.parameters())
        by_module[name] = n
        total += n
    return total, by_module


def estimate(bs: int = 8) -> MemoryReport:
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        # full-scale configuration
        backbone_dim=768,
        num_heads=12,
        num_dense_layers=9,
        num_moe_layers=9,
        num_routed_experts=7,
        num_shared_experts=1,
        topk_experts=3,
        ffn_mult=4,
        num_history_frames=8,
        num_detection_tokens=1024,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=256,
        image_h=384,
        image_w=1024,
        patch_size=16,
        num_classes=22,
        traj_horizon=24,
        freeze_dinov3=True,
    )

    total, by_module = count_params(model)
    dinov3_n = by_module.get("dinov3", 0)
    trainable_stage1 = total - dinov3_n
    trainable_stage2 = total

    # Sequence length (total tokens after concatenation + context)
    n_visual = (8 // 2) * (24 // 2) * (64 // 2)  # frames and 24x64 patch grid, each downsampled 2x
    seq_len = n_visual + 8 + 1024 + 24 + 256  # + ego + detection + control + extra tokens

    # === GPU memory ===
    # Units: GB (divide by 1024**3)
    GB = 1024 ** 3
    weights_stage1 = (dinov3_n * 2 + trainable_stage1 * 2) / GB                # all weights in bf16
    weights_stage2 = (total * 2) / GB
    optim_stage1 = (trainable_stage1 * (4 + 4 + 4 + 4)) / GB                   # fp32 master + m + v + grad
    optim_stage2 = (trainable_stage2 * (4 + 4 + 4 + 4)) / GB

    # Activations: rough = bs * seq_len * dim * 2 B * num_layers * 1.5 (attn/FFN overlap)
    base_act = bs * seq_len * 768 * 2 * (18 + 6) * 1.5  # 18 backbone + 6 calibration layers
    # Temporary MoE FFN intermediates (4D = 3072): per MoE layer ≈ bs * seq_len * 3072 * 2 B * 8 experts
    moe_act = bs * seq_len * 3072 * 2 * 8 * 9  # × 9 MoE layers
    # DINOv3 frozen: under no_grad the forward activations are freed right
    # after the forward pass; budget a 2 GB peak
    dino_act = 2.0 * GB
    activations_gb = (base_act + moe_act + dino_act) / GB

    # PCGrad overhead (N autograd.grad passes over shared params): the
    # retain_graph phase keeps activations alive, approaching the 1.5x
    # budgeted in the module docstring; modeled here as +0.5x of the base
    # activations (see the pcgrad_project sketch below)
    pcgrad_overhead_gb = 0.5 * activations_gb

    total_stage1 = weights_stage1 + optim_stage1 + activations_gb + pcgrad_overhead_gb + 2.0  # +2 GB buffers
    total_stage2 = weights_stage2 + optim_stage2 + activations_gb + pcgrad_overhead_gb + 2.0

    # === Host RAM ===
    # DataLoader prefetch + workers + CPU copy of the model + JSON / LiDAR parsing
    host_ram = 8.0 + bs * 0.3 * 4 * 2  # ~0.3 GB/sample, 4 workers, prefetch 2

    # === Disk ===
    # Full dataset ~3 TB; sandbox-only runs ~5 GB (a few clips); typical
    # ~50 GB (everything for one weather condition)
    disk = 50.0

    return MemoryReport(
        bs=bs,
        seq_len=seq_len,
        dim=768,
        layers=18,
        params_total=total,
        params_trainable_stage1=trainable_stage1,
        params_trainable_stage2=trainable_stage2,
        weights_gb_stage1=weights_stage1,
        weights_gb_stage2=weights_stage2,
        optim_gb_stage1=optim_stage1,
        optim_gb_stage2=optim_stage2,
        activations_gb=activations_gb,
        pcgrad_overhead_gb=pcgrad_overhead_gb,
        total_stage1_gb=total_stage1,
        total_stage2_gb=total_stage2,
        host_ram_gb=host_ram,
        disk_gb=disk,
    )
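

# Illustration only (not used by this script): a minimal sketch of the
# PCGrad projection step that the overhead estimate above budgets for.
# Each flattened task gradient is projected off every other task gradient
# it conflicts with, which is why all task graphs must stay alive
# (retain_graph=True) until the projection is done.
def pcgrad_project(task_grads: list["torch.Tensor"]) -> list["torch.Tensor"]:
    import random

    import torch

    projected = []
    for i, g in enumerate(task_grads):
        g = g.clone()
        others = task_grads[:i] + task_grads[i + 1:]
        random.shuffle(others)  # PCGrad visits the other tasks in random order
        for other in others:
            dot = torch.dot(g, other)
            if dot < 0:  # conflicting direction: remove the shared component
                g = g - (dot / other.norm() ** 2) * other
        projected.append(g)
    return projected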


def recommend_device(stage_max_gb: float) -> tuple[str, str]:
    """根据 Stage2 峰值显存推荐 GPU。"""
    margin = 1.15  # 留 15% 余量(碎片化、CUDA caching、cuBLAS workspace)
    need = stage_max_gb * margin
    candidates = [
        ("T4 16GB", 16),
        ("L4 24GB", 24),
        ("A10G 24GB", 24),
        ("A10G Large 48GB", 48),
        ("A100 40GB", 40),
        ("L40S 48GB", 48),
        ("A100 80GB", 80),
        ("H100 80GB", 80),
    ]
    fit = [c for c in candidates if c[1] >= need]
    if not fit:
        return "H200 / 多卡 80GB+", f"需要 ≥{need:.1f} GB(单卡极限)"
    return fit[0][0], f"需要 ≥{need:.1f} GB"
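

# Hedged sketch for the ``gradient_checkpointing`` note printed by main():
# recompute each block during backward instead of storing its activations.
# ``blocks`` stands for whatever iterable of transformer blocks E2EAVModel
# exposes (a hypothetical name here); torch.utils.checkpoint is real.
def forward_with_checkpointing(blocks, x):
    from torch.utils.checkpoint import checkpoint

    for block in blocks:
        # use_reentrant=False is the recommended non-reentrant mode
        x = checkpoint(block, x, use_reentrant=False)
    return x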


def main() -> None:
    print("=" * 72)
    print(" WJAD 训练显存/内存估算 (bf16 AMP)")
    print("=" * 72)
    for bs in (1, 2, 4, 8, 16):
        r = estimate(bs)
        print(f"\n--- BS = {bs} ---")
        print(f"  总参数        : {r.params_total / 1e6:8.2f} M")
        print(f"  可训练 (S1)   : {r.params_trainable_stage1 / 1e6:8.2f} M")
        print(f"  可训练 (S2)   : {r.params_trainable_stage2 / 1e6:8.2f} M")
        print(f"  序列长度      : {r.seq_len}")
        print(f"  权重 (S1/S2)  : {r.weights_gb_stage1:6.2f} / {r.weights_gb_stage2:6.2f} GB")
        print(f"  优化器 (S1/S2): {r.optim_gb_stage1:6.2f} / {r.optim_gb_stage2:6.2f} GB")
        print(f"  激活          : {r.activations_gb:6.2f} GB")
        print(f"  PCGrad 余量   : {r.pcgrad_overhead_gb:6.2f} GB")
        print(f"  显存合计 S1   : {r.total_stage1_gb:6.2f} GB")
        print(f"  显存合计 S2   : {r.total_stage2_gb:6.2f} GB  <- 峰值")
        gpu, note = recommend_device(r.total_stage2_gb)
        print(f"  推荐 GPU      : {gpu}  ({note})")
        print(f"  主机 RAM      : ≥ {r.host_ram_gb:6.2f} GB")
        print(f"  磁盘 (典型)   : ≈ {r.disk_gb:6.0f} GB")

    print()
    print("说明:")
    print("  - 估算包含 bf16 AMP + AdamW(m,v fp32) + 梯度 fp32 主副本 + PCGrad 开销。")
    print("  - 开 ``gradient_checkpointing`` 可把激活降至约 1/3,BS 可成倍提升。")
    print("  - 实测请用 ``nvidia-smi`` 或 ``torch.cuda.max_memory_allocated()`` 校准。")


if __name__ == "__main__":
    main()