"""本地冒烟测试(smoke test):用随机张量验证 forward + backward。
| |
| 运行: |
| python -m scripts.smoke_test |
| 或: |
| python scripts/smoke_test.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| |
# Repository root: two levels above this file (scripts/smoke_test.py -> repo root).
ROOT = Path(__file__).resolve().parents[1]
# Prepend <repo>/src so the project package can be imported without installing it.
sys.path.insert(0, str(ROOT.joinpath("src")))
|
|
| import torch |
|
|
| from wjad.model import E2EAVModel |
|
|
|
|
def main() -> None:
    """Build the model, run one forward/backward pass on dummy inputs, report shapes.

    All inputs are synthetic: random images, an ego trajectory that advances one
    unit per frame on channel 0, a fixed 11-element camera vector per batch item,
    and zero extrinsics. The loss is a surrogate (mean absolute value of every
    printed output head); its only purpose is to push gradients through the model.

    Raises:
        RuntimeError: if the surrogate loss or the gradient norm is not finite,
            or if backward produced no gradients at all.
    """
    torch.manual_seed(0)
    device = "cpu"

    print("[smoke_test] 构建模型...")
    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        num_dense_layers=2,
        num_moe_layers=2,
        num_detection_tokens=64,
        num_extra_tokens=32,
        num_classes=22,
    ).to(device)

    # Exercise the backbone's sparse MoE routing path.
    model.backbone.set_moe_mode("sparse")

    B, T = 1, 8
    H, W = 384, 1024
    images = torch.randn(B, T, 3, H, W, device=device)

    # Only channel 0 of the ego pose varies: 0, 1, ..., T-1. The previous
    # linspace(0, 7, T) was only equivalent to this when T == 8, and it also
    # ignored `device`.
    ego_6d = torch.zeros(B, T, 6, device=device)
    ego_6d[..., 0] = torch.arange(T, dtype=ego_6d.dtype, device=device)

    # One 11-element camera vector per batch item (repeat over B instead of a
    # hard-coded batch of 1). Entries 2-3 equal the image width/height.
    # NOTE(review): the full layout is defined by E2EAVModel; confirm there.
    intr_vec = torch.tensor(
        [512.0, 192.0, float(W), float(H), 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 1.0],
        device=device,
    ).repeat(B, 1)
    extr_6d = torch.zeros(B, 6, device=device)

    print("[smoke_test] 前向...")
    out = model(images, ego_6d, intr_vec, extr_6d)
    print(f" detection cls: {tuple(out.detection.cls_logits.shape)}")
    print(f" detection box_mu: {tuple(out.detection.box3d_mu.shape)}")
    print(f" detection traj_mu: {tuple(out.detection.traj_mu.shape)}")
    print(f" control ego_traj_mu: {tuple(out.control.ego_traj_mu.shape)}")
    print(f" control action_mu: {tuple(out.control.action_mu.shape)}")
    print(f" moe_stats per layer: {len(out.backbone_out.moe_stats)}")

    # Surrogate loss over every output head printed above, including action_mu,
    # which was previously left out of backward.
    heads = (
        out.detection.cls_logits,
        out.detection.box3d_mu,
        out.detection.traj_mu,
        out.control.ego_traj_mu,
        out.control.action_mu,
    )
    loss = sum(h.float().abs().mean() for h in heads)
    print(f"[smoke_test] loss = {loss.item():.6f}")
    if not torch.isfinite(loss):
        raise RuntimeError(f"[smoke_test] non-finite loss: {loss.item()}")

    loss.backward()
    params = list(model.parameters())
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        raise RuntimeError("[smoke_test] backward produced no gradients")
    # Reported rather than fatal: with sparse MoE some experts may legitimately
    # receive no tokens, and therefore no gradient, in a single pass.
    print(f"[smoke_test] params with grad: {len(grads)} / {len(params)}")

    grad_norm = torch.stack([g.detach().float().norm() for g in grads]).sum()
    if not torch.isfinite(grad_norm):
        raise RuntimeError("[smoke_test] non-finite gradients detected")
    print(f"[smoke_test] grad sum-of-norms = {grad_norm.item():.4f}")
    print("[smoke_test] OK")
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed directly
    # (python -m scripts.smoke_test or python scripts/smoke_test.py), not on import.
    main()
|
|