# WJAD / scripts/smoke_test.py
# Author: fuzirui
# Commit 0cfefd2 (verified): Sync WJAD codebase
"""本地冒烟测试:用随机张量验证 forward + backward。
运行:
python -m scripts.smoke_test
或:
python scripts/smoke_test.py
"""
from __future__ import annotations
import sys
from pathlib import Path
# Make the in-repo ``src`` directory importable so the script also works when
# run directly (``python scripts/smoke_test.py``) without installing the package.
# ROOT is the repository root: the directory two levels above this file.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT.joinpath("src")))
import torch
from wjad.model import E2EAVModel
def main() -> None:
    """Run one forward + backward pass of a reduced-size model on random inputs.

    Builds ``E2EAVModel`` with small layer/token counts so it fits on CPU,
    switches the backbone MoE to sparse mode, and feeds random images with
    simple ego/intrinsic/extrinsic vectors. It then prints the main output
    shapes, a toy scalar loss, and the sum of per-parameter gradient norms.

    Raises:
        RuntimeError: if the loss is not finite, if backward produced no
            gradients at all, or if any gradient contains NaN/Inf. Without
            these checks the script would report "OK" for a broken pass.
    """
    torch.manual_seed(0)
    device = "cpu"

    print("[smoke_test] 构建模型...")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        # Reduced sizes so the test fits on CPU.
        num_dense_layers=2,
        num_moe_layers=2,
        num_detection_tokens=64,
        num_extra_tokens=32,
        num_classes=22,
    ).to(device)
    # Switch to sparse mode so the Top-3 routing path is exercised as well.
    model.backbone.set_moe_mode("sparse")

    B, T = 1, 8
    images = torch.randn(B, T, 3, 384, 1024, device=device)
    ego_6d = torch.zeros(B, T, 6, device=device)
    # Simulate forward motion: component 0 grows linearly over the T frames.
    # Created on `device` so this also works if `device` is changed.
    ego_6d[..., 0] = torch.linspace(0, 7, T, device=device)
    intr_vec = torch.tensor([[
        512.0, 192.0, 1024.0, 384.0,  # cx, cy, w, h
        0.0, 0.5, 0.0, 0.0, 0.0, 0.0,  # poly
        1.0,  # is_bw_poly (11 values in total, matching Cosmos)
    ]], device=device).repeat(B, 1)  # one row per batch item, like the other inputs
    extr_6d = torch.zeros(B, 6, device=device)

    print("[smoke_test] 前向...")
    out = model(images, ego_6d, intr_vec, extr_6d)
    print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
    print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
    print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
    print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
    print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
    print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")

    # Toy loss: mean absolute value of cls logits, box mean, traj mean and
    # ego-trajectory mean; the aim is only to get gradients flowing.
    loss = (
        out.detection.cls_logits.float().abs().mean()
        + out.detection.box3d_mu.float().abs().mean()
        + out.detection.traj_mu.float().abs().mean()
        + out.control.ego_traj_mu.float().abs().mean()
    )
    if not torch.isfinite(loss):
        raise RuntimeError(f"[smoke_test] non-finite loss: {loss.item()}")
    print(f"[smoke_test] loss = {loss.item():.6f}")
    loss.backward()

    # Per-parameter gradient norms. The norm is NaN/Inf whenever the
    # gradient contains a NaN/Inf, so checking the norm is sufficient.
    grad_norms = {
        name: p.grad.detach().norm()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    if not grad_norms:
        raise RuntimeError("[smoke_test] backward produced no gradients")
    non_finite = [name for name, g in grad_norms.items() if not torch.isfinite(g)]
    if non_finite:
        raise RuntimeError(
            f"[smoke_test] non-finite gradients in {len(non_finite)} parameter(s), "
            f"e.g. {non_finite[:5]}"
        )

    grad_norm = sum(g.item() for g in grad_norms.values())
    print(f"[smoke_test] grad sum-of-norms = {grad_norm:.4f}")
    print("[smoke_test] OK")


if __name__ == "__main__":
    main()