""" Flow Ensemble — Expanded Test Suite. Assumes geolip-core is installed (Colab with repo loaded). Tests: smoke, linalg integration, multi-scale, ensemble fusion, gradient health, ablation, compile compatibility, memory. """ import torch import torch.nn as nn import torch.nn.functional as F import sys, time, gc # ── Verify geolip_core.linalg is available ── try: import geolip_core.linalg as LA HAS_GEOLIP_LINALG = True print(f"geolip_core.linalg: available") LA.backend.status() except ImportError: import torch.linalg as LA HAS_GEOLIP_LINALG = False print("geolip_core.linalg: NOT available, using torch.linalg fallback") dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def sync(): if dev.type == 'cuda': torch.cuda.synchronize() def time_fn(fn, warmup=5, runs=50): for _ in range(warmup): fn() sync() t0 = time.perf_counter() for _ in range(runs): fn() sync() return (time.perf_counter() - t0) / runs * 1000 def fmt(ms): if ms < 1: return f"{ms*1000:.0f}us" return f"{ms:.2f}ms" def make_data(B, n, k, d): anchors = F.normalize(torch.randn(B, k, d, device=dev), dim=-1) queries = F.normalize(torch.randn(B, n, d, device=dev), dim=-1) return anchors, queries # ═══════════════════════════════════════════════════════════════════ print("=" * 72) print(" Flow Ensemble — Expanded Test Suite") print("=" * 72) print(f" device={dev} geolip_core.linalg={HAS_GEOLIP_LINALG}") if dev.type == 'cuda': print(f" GPU: {torch.cuda.get_device_name()}") print() # ═══════════════════════════════════════════════════════════════════ # 1. SMOKE TEST — all flows, all shapes # ═══════════════════════════════════════════════════════════════════ print(f"{'='*72}\n 1. SMOKE TEST\n{'='*72}") B, n, k, d = 16, 64, 32, 128 anchors, queries = make_data(B, n, k, d) flows_cfg = [ ('QuaternionFlow', lambda d,k: QuaternionFlow(d, k, n_heads=4)), ('QuaternionLiteFlow', lambda d,k: QuaternionLiteFlow(d, k)), ('VelocityFlow', lambda d,k: VelocityFlow(d, k)), ('MagnitudeFlow', lambda d,k: MagnitudeFlow(d, k)), ('OrbitalFlow', lambda d,k: OrbitalFlow(d, k)), ('AlignmentFlow', lambda d,k: AlignmentFlow(d, k)), ] print(f"\n {'Flow':<22} {'Params':>8} {'Shape':>14} {'Time':>10} {'Conf':>8} {'Res norm':>10}") print(f" {'─'*22} {'─'*8} {'─'*14} {'─'*10} {'─'*8} {'─'*10}") live_flows = [] flow_ctors = [] for name, ctor in flows_cfg: try: flow = ctor(d, k).to(dev) params = sum(p.numel() for p in flow.parameters()) pred, conf = flow(anchors, queries) ms = time_fn(lambda: flow(anchors, queries)) res = (pred - queries).norm(dim=-1).mean().item() shape_str = str(tuple(pred.shape)) print(f" {name:<22} {params:>8,} {shape_str:>14} {fmt(ms):>10} {conf.mean().item():>8.3f} {res:>10.3f}") live_flows.append(flow) flow_ctors.append((name, ctor)) except Exception as e: print(f" {name:<22} FAILED: {str(e)[:50]}") # ═══════════════════════════════════════════════════════════════════ # 2. LINALG INTEGRATION # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 2. LINALG INTEGRATION\n{'='*72}") if HAS_GEOLIP_LINALG: print(f"\n Testing eigh dispatch in MagnitudeFlow and OrbitalFlow...") for FlowCls in [MagnitudeFlow, OrbitalFlow]: flow = FlowCls(d, k).to(dev) pred, conf = flow(anchors, queries) ok = torch.isfinite(pred).all().item() and torch.isfinite(conf).all().item() print(f" {flow.name:<18} finite={ok} conf={conf.mean():.3f}") oflow = OrbitalFlow(d, k).to(dev) a_geom = oflow.anchor_proj(anchors) G = torch.bmm(a_geom.transpose(-2, -1), a_geom) vals, vecs = LA.eigh(G) print(f"\n Gram eigenspectrum: shape={tuple(vals.shape)} " f"range=[{vals.min().item():.4f}, {vals.max().item():.4f}]") print(f" Eigenvector orth err: {(torch.bmm(vecs.mT, vecs) - torch.eye(oflow.geom_dim, device=dev)).abs().max().item():.2e}") else: print(" Skipped — geolip_core.linalg not available") # ═══════════════════════════════════════════════════════════════════ # 3. MULTI-SCALE # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 3. MULTI-SCALE\n{'='*72}") configs = [ (4, 16, 8, 64, 'tiny'), (16, 64, 32, 128, 'small'), (32, 128, 64, 256, 'medium'), (64, 256, 128, 256, 'large'), (8, 512, 256, 512, 'wide'), ] print(f"\n OrbitalFlow across scales:") print(f" {'Config':<10} {'B':>4} {'n':>5} {'k':>5} {'d':>5} {'Time':>10} {'OK':>4}") print(f" {'─'*10} {'─'*4} {'─'*5} {'─'*5} {'─'*5} {'─'*10} {'─'*4}") for B_, n_, k_, d_, label in configs: try: of = OrbitalFlow(d_, k_).to(dev) a, q = make_data(B_, n_, k_, d_) pred, conf = of(a, q) ms = time_fn(lambda: of(a, q), warmup=3, runs=20) ok = torch.isfinite(pred).all().item() print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} {fmt(ms):>10} {'OK' if ok else 'NO':>4}") del of, a, q except Exception as e: print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} FAILED: {str(e)[:30]}") # ═══════════════════════════════════════════════════════════════════ # 4. ENSEMBLE FUSION MODES # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 4. ENSEMBLE FUSION\n{'='*72}") B, n, k, d = 16, 64, 32, 128 anchors, queries = make_data(B, n, k, d) for fusion in ['weighted', 'gated', 'residual']: ens = FlowEnsemble(live_flows, d, fusion=fusion).to(dev) out = ens(anchors, queries) ms = time_fn(lambda: ens(anchors, queries), warmup=3, runs=20) preds = [flow(anchors, queries)[0] for flow in ens.flows] cos_sims = [] for i in range(len(preds)): for j in range(i+1, len(preds)): cs = F.cosine_similarity(preds[i].flatten(1), preds[j].flatten(1), dim=-1).mean().item() cos_sims.append(cs) avg_sim = sum(cos_sims) / max(len(cos_sims), 1) print(f"\n {fusion}: time={fmt(ms)} norm={out.norm(dim=-1).mean():.3f} diversity={1-avg_sim:.3f}") diag = ens.flow_diagnostics(anchors, queries) for fname, stats in diag.items(): print(f" {fname:<18} conf={stats['confidence_mean']:.3f}±{stats['confidence_std']:.3f} " f"res={stats['residual_norm']:.3f}") del ens # ═══════════════════════════════════════════════════════════════════ # 5. GRADIENT HEALTH # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 5. GRADIENT HEALTH\n{'='*72}") B, n, k, d = 16, 64, 32, 128 anchors, queries = make_data(B, n, k, d) losses = { 'mse': (lambda o,q: (o - q).pow(2).mean()), 'cosine': (lambda o,q: (1 - F.cosine_similarity(o, q, dim=-1)).mean()), 'norm': (lambda o,q: o.norm(dim=-1).mean()), } print(f"\n {'Flow':<18} {'Loss':<10} {'Grad norm':>12} {'Status':>8}") print(f" {'─'*18} {'─'*10} {'─'*12} {'─'*8}") for loss_name, loss_fn in losses.items(): # Fresh flows for each loss — avoids in-place grad corruption across losses try: test_flows_grad = [ctor(d, k).to(dev) for _, ctor in flow_ctors] ens_g = FlowEnsemble(test_flows_grad, d, fusion='residual').to(dev) ens_g.zero_grad() anchors_g = anchors.detach().clone().requires_grad_(True) queries_g = queries.detach().clone().requires_grad_(True) out = ens_g(anchors_g, queries_g) loss = loss_fn(out, queries_g.detach()) loss.backward() for flow in ens_g.flows: grads = [p.grad for p in flow.parameters() if p.grad is not None] if grads: gn = torch.cat([g.flatten() for g in grads]).norm().item() status = "OK" if 1e-8 < gn < 1e4 else "WARN" print(f" {flow.name:<18} {loss_name:<10} {gn:>12.2e} {status:>8}") else: print(f" {flow.name:<18} {loss_name:<10} {'no grads':>12} {'WARN':>8}") del ens_g, test_flows_grad except RuntimeError as e: if 'inplace' in str(e).lower() or 'in-place' in str(e).lower() or 'modified by' in str(e): print(f" {'*':>18} {loss_name:<10} {'IN-PLACE ERR':>12} {'NOTE':>8}") print(f" FL eigh deflation uses indexed assignment — needs .clone() fix") else: print(f" {'*':>18} {loss_name:<10} {'ERROR':>12}") print(f" {str(e)[:60]}") # ═══════════════════════════════════════════════════════════════════ # 6. ABLATION — solo vs pairs vs full ensemble # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 6. ABLATION (100 training steps, rotation target)\n{'='*72}") B, n, k, d = 32, 128, 64, 256 anchors, queries = make_data(B, n, k, d) R = torch.linalg.qr(torch.randn(d, d, device=dev)).Q.unsqueeze(0) target = torch.bmm(queries, R.expand(B, -1, -1)) def eval_quality(model, anchors, queries, target, steps=100, lr=1e-3): opt = torch.optim.Adam(model.parameters(), lr=lr) for _ in range(steps): opt.zero_grad() pred = model(anchors, queries) if isinstance(model, FlowEnsemble) else model(anchors, queries)[0] loss = (pred - target).pow(2).mean() loss.backward() opt.step() with torch.no_grad(): pred = model(anchors, queries) if isinstance(model, FlowEnsemble) else model(anchors, queries)[0] return (pred - target).pow(2).mean().item() print(f"\n {'Configuration':<35} {'MSE':>10} {'Params':>10}") print(f" {'─'*35} {'─'*10} {'─'*10}") for name, ctor in flow_ctors: try: flow = ctor(d, k).to(dev) params = sum(p.numel() for p in flow.parameters()) mse = eval_quality(flow, anchors, queries, target) print(f" {name:<35} {mse:>10.4f} {params:>10,}") del flow except Exception as e: print(f" {name:<35} FAILED: {str(e)[:30]}") pairs = [ ('Quat + Orbital', [0, 4]), ('Velocity + Magnitude', [2, 3]), ('Orbital + Alignment', [4, 5]), ('Velocity + Orbital', [2, 4]), ] for pair_name, indices in pairs: try: pair_flows = [flow_ctors[i][1](d, k).to(dev) for i in indices if i < len(flow_ctors)] if len(pair_flows) >= 2: ens = FlowEnsemble(pair_flows, d, fusion='weighted').to(dev) params = sum(p.numel() for p in ens.parameters()) mse = eval_quality(ens, anchors, queries, target) print(f" {pair_name:<35} {mse:>10.4f} {params:>10,}") del ens, pair_flows except Exception as e: print(f" {pair_name:<35} FAILED: {str(e)[:30]}") for fusion in ['weighted', 'residual']: try: all_flows = [ctor(d, k).to(dev) for _, ctor in flow_ctors] ens = FlowEnsemble(all_flows, d, fusion=fusion).to(dev) params = sum(p.numel() for p in ens.parameters()) mse = eval_quality(ens, anchors, queries, target) print(f" {'Full (' + fusion + ')':<35} {mse:>10.4f} {params:>10,}") del ens, all_flows except Exception as e: print(f" {'Full (' + fusion + ')':<35} FAILED: {str(e)[:30]}") # ═══════════════════════════════════════════════════════════════════ # 7. COMPILE COMPATIBILITY # ═══════════════════════════════════════════════════════════════════ print(f"\n{'='*72}\n 7. COMPILE COMPATIBILITY\n{'='*72}") B, n, k, d = 8, 32, 16, 64 anchors, queries = make_data(B, n, k, d) print(f"\n {'Flow':<22} {'fullgraph':>12} {'Raw':>10} {'Compiled':>12}") print(f" {'─'*22} {'─'*12} {'─'*10} {'─'*12}") for name, ctor in flow_ctors: try: flow = ctor(d, k).to(dev) t_raw = time_fn(lambda: flow(anchors, queries), warmup=3, runs=30) try: compiled = torch.compile(flow, fullgraph=True) compiled(anchors, queries); sync() t_comp = time_fn(lambda: compiled(anchors, queries), warmup=3, runs=30) status = "OK" except Exception as e: t_comp = -1 status = str(e)[:12] t_str = fmt(t_comp) if t_comp > 0 else "N/A" print(f" {name:<22} {status:>12} {fmt(t_raw):>10} {t_str:>12}") del flow except Exception as e: print(f" {name:<22} FAILED: {str(e)[:40]}") # ═══════════════════════════════════════════════════════════════════ # 8. MEMORY # ═══════════════════════════════════════════════════════════════════ if dev.type == 'cuda': print(f"\n{'='*72}\n 8. MEMORY (B=32, n=128, k=64, d=256)\n{'='*72}") B, n, k, d = 32, 128, 64, 256 anchors, queries = make_data(B, n, k, d) print(f"\n {'Flow':<22} {'Peak MB':>10}") print(f" {'─'*22} {'─'*10}") for name, ctor in flow_ctors: try: flow = ctor(d, k).to(dev) torch.cuda.empty_cache(); gc.collect() torch.cuda.reset_peak_memory_stats() base = torch.cuda.memory_allocated() pred, conf = flow(anchors, queries); sync() peak = (torch.cuda.max_memory_allocated() - base) / 1024**2 print(f" {name:<22} {peak:>9.1f}") del flow, pred, conf except Exception as e: print(f" {name:<22} FAILED: {str(e)[:30]}") try: all_flows = [ctor(d, k).to(dev) for _, ctor in flow_ctors] ens = FlowEnsemble(all_flows, d, fusion='weighted').to(dev) torch.cuda.empty_cache(); gc.collect() torch.cuda.reset_peak_memory_stats() base = torch.cuda.memory_allocated() out = ens(anchors, queries); sync() peak = (torch.cuda.max_memory_allocated() - base) / 1024**2 print(f" {'Full ensemble':<22} {peak:>9.1f}") del ens, all_flows except Exception as e: print(f" {'Full ensemble':<22} FAILED: {str(e)[:30]}") print(f"\n{'='*72}") print(f" Done.") print(f"{'='*72}")