File size: 4,634 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。

不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

import logging

import numpy as np
import torch

from wjad.model import E2EAVModel
from wjad.train.trainer import Trainer, TrainerConfig


def _make_dummy_batch(
    B: int = 1,
    T: int = 8,
    H: int = 64,
    W: int = 128,
    num_classes: int = 22,
    num_objects: int = 3,
) -> dict:
    """构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
    images = torch.randn(B, T, 3, H, W)
    ego_6d = torch.zeros(B, T, 6)
    intr_vec = torch.tensor([[
        W / 2, H / 2, W, H,
        0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
        1.0,
    ]] * B)
    extr_6d = torch.zeros(B, 6)
    ego_future = torch.zeros(B, 24, 3)
    ego_future_valid = torch.ones(B, 24, dtype=torch.bool)

    targets = []
    for _ in range(B):
        boxes = torch.zeros(num_objects, 7)
        boxes[:, 3:6] = 2.0
        targets.append({
            "labels": torch.randint(1, num_classes, (num_objects,)),
            "boxes": boxes,
            "is_dynamic": torch.ones(num_objects, dtype=torch.long),
            "future_traj": torch.zeros(num_objects, 24, 3),
            "future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
        })

    return {
        "images": images,
        "ego_6d": ego_6d,
        "intr_vec": intr_vec,
        "extr_6d": extr_6d,
        "ego_future": ego_future,
        "ego_future_valid": ego_future_valid,
        "targets": targets,
        "meta": [{}] * B,
    }


def main() -> None:
    """Run a short end-to-end training smoke test on random batches.

    Chooses a near-real configuration when CUDA is available and a tiny
    CPU-only configuration otherwise, then runs a few trainer steps and
    prints per-step loss/weight diagnostics.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    torch.manual_seed(0)

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    if on_gpu:
        # Near-real scale on GPU: full 384x1024 with the complete 18 layers.
        # BS=4 OOMs on a10g-small (~22 GB); BS=2 is stable once gradient
        # checkpointing is enabled.
        img_h, img_w, batch_size = 384, 1024, 2
        n_dense, n_moe = 9, 9
        n_det_tokens, n_extra_tokens = 1024, 256
        precision = "bf16"
        n_steps = 4
        grad_ckpt = True
    else:
        # Minimal configuration on CPU: sanity check only.
        img_h, img_w, batch_size = 64, 128, 1
        n_dense, n_moe = 2, 2
        n_det_tokens, n_extra_tokens = 32, 16
        precision = "fp32"
        n_steps = 4
        grad_ckpt = False

    print(f"[smoke_train] device={device}, H={img_h} W={img_w} B={batch_size} amp={precision} grad_ckpt={grad_ckpt}")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        num_dense_layers=n_dense,
        num_moe_layers=n_moe,
        num_detection_tokens=n_det_tokens,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=n_extra_tokens,
        num_classes=22,
        image_h=img_h,
        image_w=img_w,
        patch_size=16,
    )
    if grad_ckpt:
        model.backbone.set_gradient_checkpointing(True)
    # The sandbox a10g-small cannot afford a DINOv3 finetune (22 GB memory
    # budget), so the backbone stays frozen here — we only verify the
    # two-stage path switch. Full training is left to H100 Jobs.
    model.dinov3.freeze()

    cfg = TrainerConfig(
        total_steps=n_steps,
        warmup_steps=1,
        base_lr=1e-4,
        log_interval=1,
        stage1_steps=2,            # run into stage2 to exercise the switch
        stage1_perturb_start=1,
        enable_gradnorm=True,
        enable_pcgrad=True,        # PCGrad on for the whole run
        mixed_precision=precision,
        unfreeze_dinov3_at_stage2=False,  # sandbox memory is tight; path check only
    )
    trainer = Trainer(model, cfg, num_classes=22, device=device)
    rng = np.random.default_rng(0)

    if on_gpu:
        torch.cuda.reset_peak_memory_stats()

    for _ in range(n_steps):
        info = trainer.train_step(_make_dummy_batch(B=batch_size, H=img_h, W=img_w), rng)
        print(
            f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
            f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
            f"weights={[f'{w:.2f}' for w in info['weights']]}"
        )

    if on_gpu:
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
    print("[smoke_train] OK")


if __name__ == "__main__":
    main()