AbstractPhil commited on
Commit
744d24d
·
verified ·
1 Parent(s): 0f7b996

Create tests.py

Browse files
Files changed (1) hide show
  1. tests.py +92 -0
tests.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flow ensemble — smoke test + diagnostics.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import sys, time
7
+
8
+ sys.path.insert(0, '.')
9
+ from flows import (
10
+ QuaternionFlow, QuaternionLiteFlow, VelocityFlow,
11
+ MagnitudeFlow, OrbitalFlow, AlignmentFlow, FlowEnsemble,
12
+ )
13
+
14
+ dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ B, n, k, d = 32, 128, 64, 256
16
+
17
+ anchors = torch.randn(B, k, d, device=dev)
18
+ queries = torch.randn(B, n, d, device=dev)
19
+
20
+ print("=" * 68)
21
+ print(" Flow Ensemble — Smoke Test")
22
+ print("=" * 68)
23
+ print(f" B={B} n={n} k={k} d={d} device={dev}")
24
+
25
+ # Test each flow independently
26
+ flows_cfg = [
27
+ ('QuaternionFlow', lambda: QuaternionFlow(d, k, n_heads=4)),
28
+ ('QuaternionLiteFlow', lambda: QuaternionLiteFlow(d, k)),
29
+ ('VelocityFlow', lambda: VelocityFlow(d, k)),
30
+ ('MagnitudeFlow', lambda: MagnitudeFlow(d, k)),
31
+ ('OrbitalFlow', lambda: OrbitalFlow(d, k)),
32
+ ('AlignmentFlow', lambda: AlignmentFlow(d, k)),
33
+ ]
34
+
35
+ print(f"\n {'Flow':<22} {'Params':>8} {'Out shape':>14} {'Fwd (ms)':>10} {'Conf μ':>8}")
36
+ print(f" {'─'*22} {'─'*8} {'─'*14} {'─'*10} {'─'*8}")
37
+
38
+ live_flows = []
39
+ for name, ctor in flows_cfg:
40
+ try:
41
+ flow = ctor().to(dev)
42
+ params = sum(p.numel() for p in flow.parameters())
43
+
44
+ # Warmup
45
+ for _ in range(3):
46
+ flow(anchors, queries)
47
+ if dev.type == 'cuda':
48
+ torch.cuda.synchronize()
49
+
50
+ # Time
51
+ t0 = time.perf_counter()
52
+ N_runs = 50
53
+ for _ in range(N_runs):
54
+ pred, conf = flow(anchors, queries)
55
+ if dev.type == 'cuda':
56
+ torch.cuda.synchronize()
57
+ elapsed = (time.perf_counter() - t0) / N_runs * 1000
58
+
59
+ print(f" {name:<22} {params:>8,} {str(tuple(pred.shape)):>14} {elapsed:>9.2f} {conf.mean().item():>8.3f}")
60
+ live_flows.append(flow)
61
+ except Exception as e:
62
+ print(f" {name:<22} FAILED: {str(e)[:40]}")
63
+
64
+ # Test ensemble
65
+ print(f"\n Ensemble tests:")
66
+ for fusion in ['weighted', 'gated', 'residual']:
67
+ try:
68
+ ens = FlowEnsemble(live_flows, d, fusion=fusion).to(dev)
69
+ params = sum(p.numel() for p in ens.parameters())
70
+ out = ens(anchors, queries)
71
+ print(f" {fusion:<12} params={params:>10,} out={tuple(out.shape)} norm={out.norm(dim=-1).mean():.3f}")
72
+
73
+ # Diagnostics
74
+ diag = ens.flow_diagnostics(anchors, queries)
75
+ for fname, stats in diag.items():
76
+ print(f" {fname:<18} conf={stats['confidence_mean']:.3f}±{stats['confidence_std']:.3f} "
77
+ f"residual={stats['residual_norm']:.3f} temp={stats['temperature']:.3f}")
78
+ except Exception as e:
79
+ print(f" {fusion:<12} FAILED: {str(e)[:50]}")
80
+
81
+ # Gradient flow test
82
+ print(f"\n Gradient flow:")
83
+ ens = FlowEnsemble(live_flows, d, fusion='weighted').to(dev)
84
+ out = ens(anchors, queries)
85
+ loss = out.sum()
86
+ loss.backward()
87
+ for flow in ens.flows:
88
+ grads = [p.grad is not None for p in flow.parameters()]
89
+ pct = sum(grads) / max(len(grads), 1) * 100
90
+ print(f" {flow.name:<18} {pct:.0f}% params have gradients")
91
+
92
+ print("=" * 68)