"""End-to-end training-loop smoke test: build random batches and run a few trainer steps.

Does not depend on any on-disk dataset; only verifies the
forward/backward/loss/PCGrad/GradNorm chain.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# Must be set before torch initializes CUDA; reduces fragmentation-driven OOMs.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Repo root (this file lives one level below it); make `src/` importable
# before pulling in the project packages below.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
import logging
import numpy as np
import torch
from wjad.model import E2EAVModel
from wjad.train.trainer import Trainer, TrainerConfig
def _make_dummy_batch(
B: int = 1,
T: int = 8,
H: int = 64,
W: int = 128,
num_classes: int = 22,
num_objects: int = 3,
) -> dict:
"""构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
images = torch.randn(B, T, 3, H, W)
ego_6d = torch.zeros(B, T, 6)
intr_vec = torch.tensor([[
W / 2, H / 2, W, H,
0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
1.0,
]] * B)
extr_6d = torch.zeros(B, 6)
ego_future = torch.zeros(B, 24, 3)
ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
targets = []
for _ in range(B):
boxes = torch.zeros(num_objects, 7)
boxes[:, 3:6] = 2.0
targets.append({
"labels": torch.randint(1, num_classes, (num_objects,)),
"boxes": boxes,
"is_dynamic": torch.ones(num_objects, dtype=torch.long),
"future_traj": torch.zeros(num_objects, 24, 3),
"future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
})
return {
"images": images,
"ego_6d": ego_6d,
"intr_vec": intr_vec,
"extr_6d": extr_6d,
"ego_future": ego_future,
"ego_future_valid": ego_future_valid,
"targets": targets,
"meta": [{}] * B,
}
def main() -> None:
    """Run a few trainer steps on a synthetic batch to exercise the whole stack.

    Picks a near-real configuration when CUDA is present and a tiny CPU-only
    configuration otherwise, then drives ``Trainer.train_step`` for 4 steps,
    printing per-step losses and (on GPU) peak memory.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    torch.manual_seed(0)
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    n_steps = 4
    if has_cuda:
        # On GPU, run close to real scale: full 384x1024 with all 18 layers.
        # On an a10g-small (~22 GB) BS=4 OOMs; with gradient checkpointing
        # enabled BS=2 is stable.
        H, W, B = 384, 1024, 2
        num_dense = num_moe = 9
        num_det, num_extra = 1024, 256
        amp = "bf16"
        use_grad_ckpt = True
    else:
        # On CPU, run a minimal configuration purely as a sanity check.
        H, W, B = 64, 128, 1
        num_dense = num_moe = 2
        num_det, num_extra = 32, 16
        amp = "fp32"
        use_grad_ckpt = False
    print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        num_dense_layers=num_dense,
        num_moe_layers=num_moe,
        num_detection_tokens=num_det,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=num_extra,
        num_classes=22,
        image_h=H,
        image_w=W,
        patch_size=16,
    )
    if use_grad_ckpt:
        model.backbone.set_gradient_checkpointing(True)
    # The sandbox a10g-small cannot finetune DINOv3 (22 GB budget is too
    # small), so keep it frozen here; this still validates the two-stage path
    # switching. Full training is left to H100 Jobs.
    model.dinov3.freeze()
    cfg = TrainerConfig(
        total_steps=n_steps,
        warmup_steps=1,
        base_lr=1e-4,
        log_interval=1,
        stage1_steps=2,  # run into stage 2 to exercise the switch-over path
        stage1_perturb_start=1,
        enable_gradnorm=True,
        enable_pcgrad=True,  # PCGrad on for the whole run
        mixed_precision=amp,
        unfreeze_dinov3_at_stage2=False,  # sandbox VRAM is limited; path check only
    )
    trainer = Trainer(model, cfg, num_classes=22, device=device)
    rng = np.random.default_rng(0)
    if has_cuda:
        torch.cuda.reset_peak_memory_stats()
    for _ in range(n_steps):
        info = trainer.train_step(_make_dummy_batch(B=B, H=H, W=W), rng)
        print(
            f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
            f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
            f"weights={[f'{w:.2f}' for w in info['weights']]}"
        )
    if has_cuda:
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
    print("[smoke_train] OK")
if __name__ == "__main__":
    main()