#!/usr/bin/env python3 """Test suite for Direction 6: Learnable Code Evolution. Tests the learnable codebook system: 1. Backward compat — frozen learnable bank ≡ plain 1D bank 2. Gradient flow — C_raw receives gradients through write+read paths 3. Orthogonality regulariser — penalises non-orthogonal codes 4. Reconstruction training — codes improve fidelity over training steps 5. Hadamard vs learned — compare on structured data distributions 6. Coherence tracking — monitor coherence_stats during training 7. Overload regime — can learned codes extend capacity beyond K=L? 8. Continuous addressing — learnable codes + continuous read Conventions (matching project test style): - torch.manual_seed(42), np.random.seed(42) - Setup → Store → Retrieve → Measure → Report - PSNR thresholds: >100 dB EXCELLENT, >80 dB HIGH FIDELITY, >50 dB GOOD """ from __future__ import annotations import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent / "src")) import math import numpy as np import torch from torch import nn from wrinklebrane.codes import ( hadamard_codes, dct_codes, gaussian_codes, coherence_stats, normalize_columns, ) from wrinklebrane.membrane_1d import ( MembraneBank1D, store_pairs_1d, Slicer1D, cosine_similarity_matrix, token_retrieval_accuracy, soft_code_weights_1d, ) from wrinklebrane.learnable_codes import ( LearnableCodebook, LearnableMemoryBank1D, orthogonality_loss, reconstruction_loss, train_codebook, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_embeddings(T: int, D: int, seed: int = 42, signed: bool = True) -> torch.Tensor: gen = torch.Generator().manual_seed(seed) if signed: return torch.randn(T, D, generator=gen) * 0.5 return torch.rand(T, D, generator=gen) def _psnr(pred: torch.Tensor, target: torch.Tensor) -> float: mse = float((pred.detach() - target.detach()).pow(2).mean()) if mse < 1e-30: return 300.0 dr = float(target.detach().abs().max()) if dr < 1e-10: dr = 1.0 return 10.0 * math.log10(dr ** 2 / mse) # ===================================================================== # Test 1: Backward Compatibility # ===================================================================== def test_1_backward_compat(): """Frozen learnable bank should produce identical results to plain 1D.""" print("\n" + "=" * 60) print("TEST 1: Backward Compatibility (frozen learnable ≡ plain 1D)") print("=" * 60) torch.manual_seed(42) np.random.seed(42) L, K, D = 64, 8, 128 B = 1 T = K C = hadamard_codes(L, K) embeddings = _make_embeddings(T, D) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) # --- Plain 1D pipeline --- M_plain = torch.zeros(B, L, D) M_plain = store_pairs_1d(M_plain, C, keys, embeddings, alphas) slicer = Slicer1D(C, bias=False, relu=False) Y_plain = slicer(M_plain) # --- Frozen learnable bank --- bank = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True) bank.allocate(B) bank.store(keys, embeddings, alphas) Y_learn = bank.retrieve() max_diff = float((Y_plain - Y_learn).abs().max()) mean_diff = float((Y_plain - Y_learn).abs().mean()) print(f"\n Configuration: L={L}, K={K}, D={D}") print(f" Max |Y_plain - Y_learn|: {max_diff:.2e}") print(f" Mean |Y_plain - Y_learn|: {mean_diff:.2e}") passed = max_diff < 1e-6 print(f"\n {'PASS' if passed else 'FAIL'}: Frozen learnable ≡ plain 1D " f"(max diff = {max_diff:.2e})") return passed # ===================================================================== # Test 2: Gradient Flow # ===================================================================== def test_2_gradient_flow(): """C_raw receives gradients through the shared write+read path.""" print("\n" + "=" * 60) print("TEST 2: Gradient Flow Through Shared Codebook") print("=" * 60) torch.manual_seed(42) L, K, D = 16, 4, 32 B = 1 T = K embeddings = _make_embeddings(T, D) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) bank = LearnableMemoryBank1D(L, K, D, init="hadamard") bank.allocate(B) # Forward bank.store(keys, embeddings, alphas) Y = bank.retrieve() # Loss: reconstruct stored embeddings target = embeddings.unsqueeze(0).expand(B, -1, -1) loss = reconstruction_loss(Y, target) loss.backward() print(f"\n Configuration: L={L}, K={K}, D={D}") print(f" Loss: {loss.item():.6f}") grad = bank.codebook.C_raw.grad if grad is not None: norm = float(grad.norm()) nonzero = int((grad.abs() > 1e-10).sum()) total = grad.numel() print(f" C_raw gradient: norm = {norm:.6f} | nz = {nonzero}/{total}") has_grad = norm > 1e-10 else: print(f" C_raw gradient: NONE") has_grad = False passed = has_grad print(f"\n {'PASS' if passed else 'FAIL'}: C_raw receives gradients") return passed # ===================================================================== # Test 3: Orthogonality Regulariser # ===================================================================== def test_3_orthogonality_regulariser(): """Orthogonality loss is zero for Hadamard, nonzero for random.""" print("\n" + "=" * 60) print("TEST 3: Orthogonality Regulariser") print("=" * 60) torch.manual_seed(42) L, K = 64, 8 configs = [ ("hadamard", "hadamard"), ("dct", "dct"), ("gaussian", "gaussian"), ("random", "random"), ("identity", "identity"), ] print(f"\n Configuration: L={L}, K={K}") print(f"\n {'Init':>12} {'OrthoLoss':>12} {'MaxCoherence':>14} " f"{'MeanCoherence':>14}") print(f" {'-'*12} {'-'*12} {'-'*14} {'-'*14}") results = {} for label, init in configs: cb = LearnableCodebook(L, K, init=init) loss = float(cb.ortho_loss()) coh = cb.coherence() results[label] = { "loss": loss, "max_coh": coh["max_abs_offdiag"], "mean_coh": coh["mean_abs_offdiag"], } print(f" {label:>12} {loss:>11.6f} {coh['max_abs_offdiag']:>13.6f} " f"{coh['mean_abs_offdiag']:>13.6f}") # Hadamard should be near-zero; random should be significantly nonzero hadamard_low = results["hadamard"]["loss"] < 1e-6 random_higher = results["random"]["loss"] > results["hadamard"]["loss"] # Verify loss is differentiable cb = LearnableCodebook(L, K, init="random") loss = cb.ortho_loss() loss.backward() grad_exists = cb.C_raw.grad is not None and float(cb.C_raw.grad.norm()) > 0 print(f"\n Hadamard ortho_loss ≈ 0: {'YES' if hadamard_low else 'NO'}") print(f" Random > Hadamard: {'YES' if random_higher else 'NO'}") print(f" Loss is differentiable: {'YES' if grad_exists else 'NO'}") passed = hadamard_low and random_higher and grad_exists print(f"\n {'PASS' if passed else 'FAIL'}: Orthogonality regulariser") return passed # ===================================================================== # Test 4: Reconstruction Training # ===================================================================== def test_4_reconstruction_training(): """Training with reconstruction loss + ortho reg improves fidelity.""" print("\n" + "=" * 60) print("TEST 4: Reconstruction Training") print("=" * 60) torch.manual_seed(42) L, K, D = 32, 8, 64 B = 1 T = K # Fixed data distribution: same embeddings every step embeddings = _make_embeddings(T, D, seed=42) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) def data_fn(): return keys, embeddings, alphas # --- Baseline: random init, no training --- bank_before = LearnableMemoryBank1D(L, K, D, init="random", freeze_codes=True) bank_before.allocate(B) bank_before.store(keys, embeddings, alphas) Y_before = bank_before.retrieve() psnr_before = sum( _psnr(Y_before[0, k], embeddings[k]) for k in range(T) ) / T # --- Train from random init --- torch.manual_seed(42) bank = LearnableMemoryBank1D(L, K, D, init="random") history = train_codebook( bank, data_fn, n_steps=200, lr=1e-2, ortho_lambda=0.01, B=B, log_every=20, ) # Evaluate after training bank.allocate(B) bank.store(keys, embeddings, alphas) Y_after = bank.retrieve() psnr_after = sum( _psnr(Y_after[0, k], embeddings[k]) for k in range(T) ) / T print(f"\n Configuration: L={L}, K={K}, D={D}") print(f" Init: random codes | Steps: 200 | lr: 1e-2 | λ_ortho: 0.01") print(f"\n Training trajectory:") print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10} " f"{'MaxCoh':>10} {'MeanCoh':>10}") print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}") for h in history: print(f" {h['step']:>6} {h['total_loss']:>9.6f} " f"{h['recon_loss']:>9.6f} {h['ortho_loss']:>9.6f} " f"{h['max_coherence']:>9.6f} {h['mean_coherence']:>9.6f}") print(f"\n PSNR before training: {psnr_before:.1f} dB") print(f" PSNR after training: {psnr_after:.1f} dB") print(f" Improvement: {psnr_after - psnr_before:.1f} dB") # Training should improve fidelity improved = psnr_after > psnr_before + 5.0 # At least 5 dB improvement loss_decreased = history[-1]["total_loss"] < history[0]["total_loss"] passed = improved and loss_decreased print(f"\n PSNR improved by >5 dB: {'YES' if improved else 'NO'}") print(f" Loss decreased: {'YES' if loss_decreased else 'NO'}") print(f"\n {'PASS' if passed else 'FAIL'}: Reconstruction training") return passed # ===================================================================== # Test 5: Hadamard vs Learned # ===================================================================== def test_5_hadamard_vs_learned(): """Compare fixed Hadamard vs learned codes on structured data.""" print("\n" + "=" * 60) print("TEST 5: Hadamard vs Learned Codes") print("=" * 60) torch.manual_seed(42) L, K, D = 32, 8, 64 B = 1 T = K # Structured data: correlated embeddings (e.g. smooth, low-frequency) gen = torch.Generator().manual_seed(42) base = torch.randn(1, D, generator=gen) * 0.5 embeddings = base + torch.randn(T, D, generator=gen) * 0.1 keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) def data_fn(): return keys, embeddings, alphas # --- Fixed Hadamard --- bank_had = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True) bank_had.allocate(B) bank_had.store(keys, embeddings, alphas) Y_had = bank_had.retrieve() psnr_had = sum(_psnr(Y_had[0, k], embeddings[k]) for k in range(T)) / T coh_had = bank_had.coherence() # --- Trained from Hadamard init --- torch.manual_seed(42) bank_learn_h = LearnableMemoryBank1D(L, K, D, init="hadamard") train_codebook( bank_learn_h, data_fn, n_steps=200, lr=1e-2, ortho_lambda=0.01, B=B, log_every=200, ) bank_learn_h.allocate(B) bank_learn_h.store(keys, embeddings, alphas) Y_learn_h = bank_learn_h.retrieve() psnr_learn_h = sum( _psnr(Y_learn_h[0, k], embeddings[k]) for k in range(T) ) / T coh_learn_h = bank_learn_h.coherence() # --- Trained from random init --- torch.manual_seed(42) bank_learn_r = LearnableMemoryBank1D(L, K, D, init="random") train_codebook( bank_learn_r, data_fn, n_steps=200, lr=1e-2, ortho_lambda=0.01, B=B, log_every=200, ) bank_learn_r.allocate(B) bank_learn_r.store(keys, embeddings, alphas) Y_learn_r = bank_learn_r.retrieve() psnr_learn_r = sum( _psnr(Y_learn_r[0, k], embeddings[k]) for k in range(T) ) / T coh_learn_r = bank_learn_r.coherence() print(f"\n Configuration: L={L}, K={K}, D={D}") print(f" Data: correlated embeddings (base + noise)") print(f"\n {'Config':<28} {'PSNR':>8} {'MaxCoh':>10} {'MeanCoh':>10}") print(f" {'-'*28} {'-'*8} {'-'*10} {'-'*10}") print(f" {'Fixed Hadamard':<28} {psnr_had:>7.1f}dB " f"{coh_had['max_abs_offdiag']:>9.6f} " f"{coh_had['mean_abs_offdiag']:>9.6f}") print(f" {'Learned (Hadamard init)':<28} {psnr_learn_h:>7.1f}dB " f"{coh_learn_h['max_abs_offdiag']:>9.6f} " f"{coh_learn_h['mean_abs_offdiag']:>9.6f}") print(f" {'Learned (random init)':<28} {psnr_learn_r:>7.1f}dB " f"{coh_learn_r['max_abs_offdiag']:>9.6f} " f"{coh_learn_r['mean_abs_offdiag']:>9.6f}") # Fixed Hadamard should already be excellent (orthogonal codes) hadamard_excellent = psnr_had > 100.0 # Learned from random should significantly improve over random baseline learned_competitive = psnr_learn_r > 50.0 passed = hadamard_excellent and learned_competitive print(f"\n Fixed Hadamard > 100 dB: {'YES' if hadamard_excellent else 'NO'}") print(f" Learned (random init) > 50 dB: " f"{'YES' if learned_competitive else 'NO'}") print(f"\n {'PASS' if passed else 'FAIL'}: Hadamard vs learned") return passed # ===================================================================== # Test 6: Coherence Tracking # ===================================================================== def test_6_coherence_tracking(): """Monitor coherence_stats during training — should stay controlled.""" print("\n" + "=" * 60) print("TEST 6: Coherence Tracking During Training") print("=" * 60) torch.manual_seed(42) L, K, D = 32, 8, 64 B = 1 T = K embeddings = _make_embeddings(T, D, seed=42) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) def data_fn(): return keys, embeddings, alphas # Train with different ortho_lambda values lambdas = [0.0, 0.01, 0.1, 1.0] print(f"\n Configuration: L={L}, K={K}, D={D}") print(f" Training 200 steps from random init at various λ_ortho") print(f"\n {'λ_ortho':>8} {'Final PSNR':>12} {'Final MaxCoh':>14} " f"{'Final MeanCoh':>14} {'Final OrthoL':>14}") print(f" {'-'*8} {'-'*12} {'-'*14} {'-'*14} {'-'*14}") results = {} for lam in lambdas: torch.manual_seed(42) bank = LearnableMemoryBank1D(L, K, D, init="random") history = train_codebook( bank, data_fn, n_steps=200, lr=1e-2, ortho_lambda=lam, B=B, log_every=200, ) bank.allocate(B) bank.store(keys, embeddings, alphas) Y = bank.retrieve() psnr = sum(_psnr(Y[0, k], embeddings[k]) for k in range(T)) / T coh = bank.coherence() results[lam] = { "psnr": psnr, "max_coh": coh["max_abs_offdiag"], "mean_coh": coh["mean_abs_offdiag"], "ortho_loss": history[-1]["ortho_loss"], } print(f" {lam:>8.2f} {psnr:>11.1f}dB " f"{coh['max_abs_offdiag']:>13.6f} " f"{coh['mean_abs_offdiag']:>13.6f} " f"{history[-1]['ortho_loss']:>13.6f}") # Higher λ should produce lower coherence (more orthogonal) coherence_controlled = ( results[1.0]["max_coh"] < results[0.0]["max_coh"] or results[1.0]["max_coh"] < 0.3 # absolute threshold ) # λ=0 may overfit to data but lose orthogonality no_reg_higher_coh = results[0.0]["max_coh"] > results[1.0]["max_coh"] passed = coherence_controlled print(f"\n λ=1.0 controls coherence: " f"{'YES' if coherence_controlled else 'NO'}") print(f" λ=0 has higher coherence: " f"{'YES' if no_reg_higher_coh else 'NO'}") print(f"\n {'PASS' if passed else 'FAIL'}: Coherence tracking") return passed # ===================================================================== # Test 7: Overload Regime # ===================================================================== def test_7_overload_regime(): """Can learned codes extend effective capacity beyond K=L?""" print("\n" + "=" * 60) print("TEST 7: Overload Regime (K > L)") print("=" * 60) torch.manual_seed(42) L = 16 K = 32 # 2× overload D = 64 B = 1 T = K embeddings = _make_embeddings(T, D, seed=42) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) def data_fn(): return keys, embeddings, alphas # --- Fixed Hadamard at overload --- bank_had = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True) bank_had.allocate(B) bank_had.store(keys, embeddings, alphas) Y_had = bank_had.retrieve() psnr_had = sum(_psnr(Y_had[0, k], embeddings[k]) for k in range(T)) / T # --- Learned codes at overload --- torch.manual_seed(42) bank_learn = LearnableMemoryBank1D(L, K, D, init="hadamard") history = train_codebook( bank_learn, data_fn, n_steps=300, lr=1e-2, ortho_lambda=0.001, B=B, log_every=50, ) bank_learn.allocate(B) bank_learn.store(keys, embeddings, alphas) Y_learn = bank_learn.retrieve() psnr_learn = sum( _psnr(Y_learn[0, k], embeddings[k]) for k in range(T) ) / T print(f"\n Configuration: L={L}, K={K} (K/L = {K/L:.1f}×), D={D}") print(f"\n Training trajectory:") print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10}") print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10}") for h in history: print(f" {h['step']:>6} {h['total_loss']:>9.6f} " f"{h['recon_loss']:>9.6f} {h['ortho_loss']:>9.6f}") print(f"\n Fixed Hadamard at 2× overload: {psnr_had:.1f} dB") print(f" Learned codes at 2× overload: {psnr_learn:.1f} dB") # Learned codes should improve over fixed at overload learned_better = psnr_learn > psnr_had # Both should be degraded compared to within-capacity (< 140 dB) both_degraded = psnr_had < 50.0 and psnr_learn < 100.0 passed = learned_better print(f"\n Learned > Fixed at overload: " f"{'YES' if learned_better else 'NO'} " f"({psnr_learn:.1f} vs {psnr_had:.1f} dB)") print(f" Both degraded vs capacity: {'YES' if both_degraded else 'NO'}") print(f"\n {'PASS' if passed else 'FAIL'}: Overload regime") return passed # ===================================================================== # Test 8: Continuous Addressing Integration # ===================================================================== def test_8_continuous_addressing(): """Learnable codes work with continuous read path.""" print("\n" + "=" * 60) print("TEST 8: Continuous Addressing + Learnable Codes") print("=" * 60) torch.manual_seed(42) L, K, D = 32, 8, 64 D_query = L # match codebook dimension B = 1 T = K embeddings = _make_embeddings(T, D, seed=42) keys = torch.arange(T, dtype=torch.long) alphas = torch.ones(T) bank = LearnableMemoryBank1D(L, K, D, init="hadamard") bank.allocate(B) # Store discrete (high fidelity) bank.store(keys, embeddings, alphas) # Read continuous: use codebook columns as queries C = bank.codebook() # [L, K] queries = C.T # [K, L=D_query] — each query is a code column projection = C.detach().clone() # [L, K] — project into code space temperatures = [1e-6, 0.01, 0.1, 1.0] print(f"\n Configuration: L={L}, K={K}, D={D}, D_query={D_query}") print(f" Store: discrete | Read: continuous at varying T") print(f"\n {'Temp':>8} {'Avg PSNR':>10} {'Avg CosSim':>12}") print(f" {'-'*8} {'-'*10} {'-'*12}") results = {} for temp in temperatures: Y = bank.retrieve_continuous(queries, projection, temperature=temp) psnrs = [_psnr(Y[0, k], embeddings[k]) for k in range(T)] metrics = token_retrieval_accuracy( Y[0].detach(), embeddings, threshold=0.999 ) avg_p = sum(psnrs) / len(psnrs) results[temp] = {"psnr": avg_p, "cos": metrics["cosine_sim"]} print(f" {temp:>8.1e} {avg_p:>9.1f}dB {metrics['cosine_sim']:>11.6f}") # Verify gradient flows through learnable codes to continuous read bank.allocate(B) bank.store(keys, embeddings, alphas) proj_param = nn.Parameter(C.detach().clone()) Y = bank.retrieve_continuous(queries.detach(), proj_param, temperature=0.01) target = embeddings.unsqueeze(0) loss = reconstruction_loss(Y, target) loss.backward() code_grad = bank.codebook.C_raw.grad is not None proj_grad = proj_param.grad is not None print(f"\n Gradient through C_raw: {'YES' if code_grad else 'NO'}") print(f" Gradient through projection: {'YES' if proj_grad else 'NO'}") low_t_ok = results[1e-6]["psnr"] > 80.0 degrades = results[1.0]["psnr"] < results[1e-6]["psnr"] passed = low_t_ok and degrades and code_grad print(f"\n Low T PSNR > 80 dB: {'YES' if low_t_ok else 'NO'} " f"({results[1e-6]['psnr']:.1f} dB)") print(f" Degrades at high T: {'YES' if degrades else 'NO'}") print(f"\n {'PASS' if passed else 'FAIL'}: Continuous addressing + learnable codes") return passed # ===================================================================== # Main # ===================================================================== def main(): print("=" * 60) print(" DIRECTION 6: LEARNABLE CODE EVOLUTION") print(" Learnable Codebook Test Suite") print("=" * 60) tests = [ ("Backward Compat", test_1_backward_compat), ("Gradient Flow", test_2_gradient_flow), ("Orthogonality Regulariser", test_3_orthogonality_regulariser), ("Reconstruction Training", test_4_reconstruction_training), ("Hadamard vs Learned", test_5_hadamard_vs_learned), ("Coherence Tracking", test_6_coherence_tracking), ("Overload Regime", test_7_overload_regime), ("Continuous Addressing", test_8_continuous_addressing), ] results = {} for name, test_fn in tests: try: results[name] = test_fn() except Exception as e: print(f"\n ERROR in {name}: {e}") import traceback traceback.print_exc() results[name] = False print("\n" + "=" * 60) print(" SUMMARY") print("=" * 60) for name, passed in results.items(): status = "PASS" if passed else "FAIL" print(f" [{status}] {name}") n_pass = sum(1 for p in results.values() if p) n_total = len(results) print(f"\n {n_pass}/{n_total} tests passed") if n_pass == n_total: print("\n ALL TESTS PASSED") else: print(f"\n {n_total - n_pass} FAILURES") return n_pass == n_total if __name__ == "__main__": success = main() sys.exit(0 if success else 1)