Create tests.py
Browse files
tests.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flow ensemble — smoke test + diagnostics.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import sys, time
|
| 7 |
+
|
| 8 |
+
sys.path.insert(0, '.')
|
| 9 |
+
from flows import (
|
| 10 |
+
QuaternionFlow, QuaternionLiteFlow, VelocityFlow,
|
| 11 |
+
MagnitudeFlow, OrbitalFlow, AlignmentFlow, FlowEnsemble,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
+
B, n, k, d = 32, 128, 64, 256
|
| 16 |
+
|
| 17 |
+
anchors = torch.randn(B, k, d, device=dev)
|
| 18 |
+
queries = torch.randn(B, n, d, device=dev)
|
| 19 |
+
|
| 20 |
+
print("=" * 68)
|
| 21 |
+
print(" Flow Ensemble — Smoke Test")
|
| 22 |
+
print("=" * 68)
|
| 23 |
+
print(f" B={B} n={n} k={k} d={d} device={dev}")
|
| 24 |
+
|
| 25 |
+
# Test each flow independently
|
| 26 |
+
flows_cfg = [
|
| 27 |
+
('QuaternionFlow', lambda: QuaternionFlow(d, k, n_heads=4)),
|
| 28 |
+
('QuaternionLiteFlow', lambda: QuaternionLiteFlow(d, k)),
|
| 29 |
+
('VelocityFlow', lambda: VelocityFlow(d, k)),
|
| 30 |
+
('MagnitudeFlow', lambda: MagnitudeFlow(d, k)),
|
| 31 |
+
('OrbitalFlow', lambda: OrbitalFlow(d, k)),
|
| 32 |
+
('AlignmentFlow', lambda: AlignmentFlow(d, k)),
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
print(f"\n {'Flow':<22} {'Params':>8} {'Out shape':>14} {'Fwd (ms)':>10} {'Conf μ':>8}")
|
| 36 |
+
print(f" {'─'*22} {'─'*8} {'─'*14} {'─'*10} {'─'*8}")
|
| 37 |
+
|
| 38 |
+
live_flows = []
|
| 39 |
+
for name, ctor in flows_cfg:
|
| 40 |
+
try:
|
| 41 |
+
flow = ctor().to(dev)
|
| 42 |
+
params = sum(p.numel() for p in flow.parameters())
|
| 43 |
+
|
| 44 |
+
# Warmup
|
| 45 |
+
for _ in range(3):
|
| 46 |
+
flow(anchors, queries)
|
| 47 |
+
if dev.type == 'cuda':
|
| 48 |
+
torch.cuda.synchronize()
|
| 49 |
+
|
| 50 |
+
# Time
|
| 51 |
+
t0 = time.perf_counter()
|
| 52 |
+
N_runs = 50
|
| 53 |
+
for _ in range(N_runs):
|
| 54 |
+
pred, conf = flow(anchors, queries)
|
| 55 |
+
if dev.type == 'cuda':
|
| 56 |
+
torch.cuda.synchronize()
|
| 57 |
+
elapsed = (time.perf_counter() - t0) / N_runs * 1000
|
| 58 |
+
|
| 59 |
+
print(f" {name:<22} {params:>8,} {str(tuple(pred.shape)):>14} {elapsed:>9.2f} {conf.mean().item():>8.3f}")
|
| 60 |
+
live_flows.append(flow)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f" {name:<22} FAILED: {str(e)[:40]}")
|
| 63 |
+
|
| 64 |
+
# Test ensemble
|
| 65 |
+
print(f"\n Ensemble tests:")
|
| 66 |
+
for fusion in ['weighted', 'gated', 'residual']:
|
| 67 |
+
try:
|
| 68 |
+
ens = FlowEnsemble(live_flows, d, fusion=fusion).to(dev)
|
| 69 |
+
params = sum(p.numel() for p in ens.parameters())
|
| 70 |
+
out = ens(anchors, queries)
|
| 71 |
+
print(f" {fusion:<12} params={params:>10,} out={tuple(out.shape)} norm={out.norm(dim=-1).mean():.3f}")
|
| 72 |
+
|
| 73 |
+
# Diagnostics
|
| 74 |
+
diag = ens.flow_diagnostics(anchors, queries)
|
| 75 |
+
for fname, stats in diag.items():
|
| 76 |
+
print(f" {fname:<18} conf={stats['confidence_mean']:.3f}±{stats['confidence_std']:.3f} "
|
| 77 |
+
f"residual={stats['residual_norm']:.3f} temp={stats['temperature']:.3f}")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f" {fusion:<12} FAILED: {str(e)[:50]}")
|
| 80 |
+
|
| 81 |
+
# Gradient flow test
|
| 82 |
+
print(f"\n Gradient flow:")
|
| 83 |
+
ens = FlowEnsemble(live_flows, d, fusion='weighted').to(dev)
|
| 84 |
+
out = ens(anchors, queries)
|
| 85 |
+
loss = out.sum()
|
| 86 |
+
loss.backward()
|
| 87 |
+
for flow in ens.flows:
|
| 88 |
+
grads = [p.grad is not None for p in flow.parameters()]
|
| 89 |
+
pct = sum(grads) / max(len(grads), 1) * 100
|
| 90 |
+
print(f" {flow.name:<18} {pct:.0f}% params have gradients")
|
| 91 |
+
|
| 92 |
+
print("=" * 68)
|