File size: 2,510 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Local smoke test: verify forward + backward with random tensors.

Run with:
    python -m scripts.smoke_test
or:
    python scripts/smoke_test.py
"""

from __future__ import annotations

import sys
from pathlib import Path

# Allow running this file directly: ROOT is the directory two levels above
# this file, and ROOT/src is put first on sys.path so that the `wjad` import
# below resolves. This must happen before `wjad` is imported.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

import torch

from wjad.model import E2EAVModel


def main() -> None:
    """Run one forward and one backward pass of ``E2EAVModel`` on random inputs.

    Builds a reduced-size model on CPU, switches the backbone MoE to
    ``"sparse"`` mode, feeds random images together with synthetic ego /
    intrinsics / extrinsics vectors, prints the shapes of the main outputs,
    and back-propagates a simple mean-absolute-value loss.

    Unlike a plain "did it crash" check, the run only reports OK when the loss
    is finite, at least one parameter received a gradient, and every gradient
    norm is finite.

    Raises:
        RuntimeError: If the loss is NaN/Inf, if no parameter received a
            gradient, or if any per-parameter gradient norm is NaN/Inf.
    """
    torch.manual_seed(0)
    device = "cpu"

    print("[smoke_test] 构建模型...")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        # Reduced test scale so the run fits on CPU.
        num_dense_layers=2,
        num_moe_layers=2,
        num_detection_tokens=64,
        num_extra_tokens=32,
        num_classes=22,
    ).to(device)

    # Switch to sparse mode to verify that the Top-3 path also works.
    model.backbone.set_moe_mode("sparse")

    B, T = 1, 8
    images = torch.randn(B, T, 3, 384, 1024, device=device)
    ego_6d = torch.zeros(B, T, 6, device=device)
    ego_6d[..., 0] = torch.linspace(0, 7, T)  # simulate forward motion
    # NOTE(review): intr_vec is built with a leading dimension of 1, which
    # only matches B because B == 1 here; confirm before raising B.
    intr_vec = torch.tensor([[
        512.0, 192.0, 1024, 384,            # cx, cy, w, h
        0.0, 0.5, 0.0, 0.0, 0.0, 0.0,        # poly
        1.0,                                  # is_bw_poly (matches the Cosmos 11-dim layout)
    ]], device=device)
    extr_6d = torch.zeros(B, 6, device=device)

    print("[smoke_test] 前向...")
    out = model(images, ego_6d, intr_vec, extr_6d)
    print(f"  detection cls: {tuple(out.detection.cls_logits.shape)}")
    print(f"  detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
    print(f"  detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
    print(f"  control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
    print(f"  control action_mu: {tuple(out.control.action_mu.shape)}")
    print(f"  moe_stats per layer: {len(out.backbone_out.moe_stats)}")

    # Simple backward: a toy loss summing the mean absolute value of
    # cls logits, box_mu, traj_mu and ego_traj_mu.
    loss = (
        out.detection.cls_logits.float().abs().mean()
        + out.detection.box3d_mu.float().abs().mean()
        + out.detection.traj_mu.float().abs().mean()
        + out.control.ego_traj_mu.float().abs().mean()
    )
    print(f"[smoke_test] loss = {loss.item():.6f}")
    # A NaN/Inf loss does not raise on its own, so check it explicitly.
    if not torch.isfinite(loss).item():
        raise RuntimeError(f"[smoke_test] non-finite loss: {loss.item()}")

    loss.backward()
    grad_norms = [
        p.grad.detach().norm().item() for p in model.parameters() if p.grad is not None
    ]
    # No gradients at all means the graph is detached from every parameter.
    if not grad_norms:
        raise RuntimeError("[smoke_test] no parameter received a gradient")
    if not torch.isfinite(torch.tensor(grad_norms)).all().item():
        raise RuntimeError("[smoke_test] non-finite gradient norm detected")

    grad_norm = sum(grad_norms)
    print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
    print("[smoke_test] OK")


# Script entry point: run the smoke test only when this file is executed
# (directly or via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()