"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。 不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。 """ from __future__ import annotations import os import sys from pathlib import Path os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT / "src")) import logging import numpy as np import torch from wjad.model import E2EAVModel from wjad.train.trainer import Trainer, TrainerConfig def _make_dummy_batch( B: int = 1, T: int = 8, H: int = 64, W: int = 128, num_classes: int = 22, num_objects: int = 3, ) -> dict: """构造极小分辨率的随机 batch(CPU 烟囱测试用)。""" images = torch.randn(B, T, 3, H, W) ego_6d = torch.zeros(B, T, 6) intr_vec = torch.tensor([[ W / 2, H / 2, W, H, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 1.0, ]] * B) extr_6d = torch.zeros(B, 6) ego_future = torch.zeros(B, 24, 3) ego_future_valid = torch.ones(B, 24, dtype=torch.bool) targets = [] for _ in range(B): boxes = torch.zeros(num_objects, 7) boxes[:, 3:6] = 2.0 targets.append({ "labels": torch.randint(1, num_classes, (num_objects,)), "boxes": boxes, "is_dynamic": torch.ones(num_objects, dtype=torch.long), "future_traj": torch.zeros(num_objects, 24, 3), "future_valid": torch.ones(num_objects, 24, dtype=torch.bool), }) return { "images": images, "ego_6d": ego_6d, "intr_vec": intr_vec, "extr_6d": extr_6d, "ego_future": ego_future, "ego_future_valid": ego_future_valid, "targets": targets, "meta": [{}] * B, } def main() -> None: logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") torch.manual_seed(0) has_cuda = torch.cuda.is_available() device = "cuda" if has_cuda else "cpu" if has_cuda: # GPU 上跑接近真实规模:full 384x1024 + 完整 18 层 # a10g-small (~22 GB) 上 BS=4 OOM,启用 gradient_checkpointing 后 BS=2 稳定 H, W = 384, 1024 B = 2 num_dense, num_moe = 9, 9 num_det = 1024 num_extra = 256 amp = "bf16" n_steps = 4 use_grad_ckpt = True else: # CPU 上跑极小规模仅做 sanity H, W = 64, 128 B = 1 num_dense, num_moe = 2, 2 num_det = 32 num_extra = 16 amp = "fp32" n_steps = 4 use_grad_ckpt = False print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}") model = E2EAVModel( dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"), num_dense_layers=num_dense, num_moe_layers=num_moe, num_detection_tokens=num_det, num_control_tokens=24, num_ego_tokens=8, num_extra_tokens=num_extra, num_classes=22, image_h=H, image_w=W, patch_size=16, ) if use_grad_ckpt: model.backbone.set_gradient_checkpointing(True) # sandbox a10g-small 不做 DINOv3 finetune(显存预算 22GB 不够),冻结即可 # 验证两阶段路径切换。完整训练交给 H100 Jobs。 model.dinov3.freeze() cfg = TrainerConfig( total_steps=n_steps, warmup_steps=1, base_lr=1e-4, log_interval=1, stage1_steps=2, # 跑到 stage2 验证切换路径 stage1_perturb_start=1, enable_gradnorm=True, enable_pcgrad=True, # 全程启用 PCGrad mixed_precision=amp, unfreeze_dinov3_at_stage2=False, # sandbox 显存有限,验证路径即可 ) trainer = Trainer(model, cfg, num_classes=22, device=device) rng = np.random.default_rng(0) if has_cuda: torch.cuda.reset_peak_memory_stats() for step in range(n_steps): batch = _make_dummy_batch(B=B, H=H, W=W) info = trainer.train_step(batch, rng) print( f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} " f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} " f"weights={[f'{w:.2f}' for w in info['weights']]}" ) if has_cuda: peak_gb = torch.cuda.max_memory_allocated() / 1024**3 print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB") print("[smoke_train] OK") if __name__ == "__main__": 
main()
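

# ---------------------------------------------------------------------------
# Reference sketch (not exercised by this smoke test): PCGrad, which
# TrainerConfig enables above via `enable_pcgrad`, resolves conflicts between
# per-task gradients by projecting each gradient onto the normal plane of any
# gradient it conflicts with (negative dot product). The helper below is a
# minimal stand-alone illustration of that projection, assuming flattened 1-D
# per-task gradient vectors; it is NOT the trainer's actual implementation.
def _pcgrad_sketch(task_grads: list[torch.Tensor]) -> list[torch.Tensor]:
    projected = []
    for i, g in enumerate(task_grads):
        g = g.clone()
        for j, other in enumerate(task_grads):
            if i == j:
                continue
            dot = torch.dot(g, other)
            if dot < 0:
                # Conflicting direction: subtract the component of g that
                # points against `other`, i.e. g <- g - (g.other / |other|^2) * other.
                g = g - (dot / other.dot(other)) * other
        projected.append(g)
    return projected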