"""End-to-end training-loop smoke test: build random batches and run a few trainer steps.

Does not depend on any dataset on disk; it only verifies the
forward/backward/loss/PCGrad/GradNorm chain.
"""


from __future__ import annotations


import os
import sys
from pathlib import Path


# Opt into expandable CUDA memory segments. Set here, before ``import torch``
# below, so it takes effect when torch initializes the CUDA allocator.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


# Make the in-repo ``src`` package importable without installing the project.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
|
|
| import logging |
|
|
| import numpy as np |
| import torch |
|
|
| from wjad.model import E2EAVModel |
| from wjad.train.trainer import Trainer, TrainerConfig |
|
|
|
|
| def _make_dummy_batch( |
| B: int = 1, |
| T: int = 8, |
| H: int = 64, |
| W: int = 128, |
| num_classes: int = 22, |
| num_objects: int = 3, |
| ) -> dict: |
| """构造极小分辨率的随机 batch(CPU 烟囱测试用)。""" |
| images = torch.randn(B, T, 3, H, W) |
| ego_6d = torch.zeros(B, T, 6) |
| intr_vec = torch.tensor([[ |
| W / 2, H / 2, W, H, |
| 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, |
| 1.0, |
| ]] * B) |
| extr_6d = torch.zeros(B, 6) |
| ego_future = torch.zeros(B, 24, 3) |
| ego_future_valid = torch.ones(B, 24, dtype=torch.bool) |
|
|
| targets = [] |
| for _ in range(B): |
| boxes = torch.zeros(num_objects, 7) |
| boxes[:, 3:6] = 2.0 |
| targets.append({ |
| "labels": torch.randint(1, num_classes, (num_objects,)), |
| "boxes": boxes, |
| "is_dynamic": torch.ones(num_objects, dtype=torch.long), |
| "future_traj": torch.zeros(num_objects, 24, 3), |
| "future_valid": torch.ones(num_objects, 24, dtype=torch.bool), |
| }) |
|
|
| return { |
| "images": images, |
| "ego_6d": ego_6d, |
| "intr_vec": intr_vec, |
| "extr_6d": extr_6d, |
| "ego_future": ego_future, |
| "ego_future_valid": ego_future_valid, |
| "targets": targets, |
| "meta": [{}] * B, |
| } |
|
|
|
|
def main() -> None:
    """Run a short smoke-training loop on random batches.

    Picks a large configuration when CUDA is present and a tiny CPU-friendly
    one otherwise, builds the model and trainer, then executes a handful of
    ``train_step`` calls and prints the per-step loss breakdown.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    torch.manual_seed(0)

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    if has_cuda:
        # Realistic resolution / depth to exercise AMP + grad checkpointing.
        H, W, B = 384, 1024, 2
        num_dense = num_moe = 9
        num_det, num_extra = 1024, 256
        amp, n_steps, use_grad_ckpt = "bf16", 4, True
    else:
        # Minimal everything so the test finishes quickly on CPU.
        H, W, B = 64, 128, 1
        num_dense = num_moe = 2
        num_det, num_extra = 32, 16
        amp, n_steps, use_grad_ckpt = "fp32", 4, False

    print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        num_dense_layers=num_dense,
        num_moe_layers=num_moe,
        num_detection_tokens=num_det,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=num_extra,
        num_classes=22,
        image_h=H,
        image_w=W,
        patch_size=16,
    )
    if use_grad_ckpt:
        model.backbone.set_gradient_checkpointing(True)

    model.dinov3.freeze()

    cfg = TrainerConfig(
        total_steps=n_steps,
        warmup_steps=1,
        base_lr=1e-4,
        log_interval=1,
        stage1_steps=2,
        stage1_perturb_start=1,
        enable_gradnorm=True,
        enable_pcgrad=True,
        mixed_precision=amp,
        unfreeze_dinov3_at_stage2=False,
    )
    trainer = Trainer(model, cfg, num_classes=22, device=device)
    rng = np.random.default_rng(0)

    if has_cuda:
        torch.cuda.reset_peak_memory_stats()

    for _ in range(n_steps):
        info = trainer.train_step(_make_dummy_batch(B=B, H=H, W=W), rng)
        print(
            f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
            f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
            f"weights={[f'{w:.2f}' for w in info['weights']]}"
        )

    if has_cuda:
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
    print("[smoke_train] OK")
|
|
|
|
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    main()
|
|