| |
| """Test suite for Direction 6: Learnable Code Evolution. |
| |
| Tests the learnable codebook system: |
| 1. Backward compat — frozen learnable bank ≡ plain 1D bank |
| 2. Gradient flow — C_raw receives gradients through write+read paths |
| 3. Orthogonality regulariser — penalises non-orthogonal codes |
| 4. Reconstruction training — codes improve fidelity over training steps |
| 5. Hadamard vs learned — compare on structured data distributions |
| 6. Coherence tracking — monitor coherence_stats during training |
| 7. Overload regime — can learned codes extend capacity beyond K=L? |
| 8. Continuous addressing — learnable codes + continuous read |
| |
| Conventions (matching project test style): |
| - torch.manual_seed(42), np.random.seed(42) |
| - Setup → Store → Retrieve → Measure → Report |
| - PSNR thresholds: >100 dB EXCELLENT, >80 dB HIGH FIDELITY, >50 dB GOOD |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent / "src")) |
|
|
| import math |
| import numpy as np |
| import torch |
| from torch import nn |
|
|
| from wrinklebrane.codes import ( |
| hadamard_codes, |
| dct_codes, |
| gaussian_codes, |
| coherence_stats, |
| normalize_columns, |
| ) |
| from wrinklebrane.membrane_1d import ( |
| MembraneBank1D, |
| store_pairs_1d, |
| Slicer1D, |
| cosine_similarity_matrix, |
| token_retrieval_accuracy, |
| soft_code_weights_1d, |
| ) |
| from wrinklebrane.learnable_codes import ( |
| LearnableCodebook, |
| LearnableMemoryBank1D, |
| orthogonality_loss, |
| reconstruction_loss, |
| train_codebook, |
| ) |
|
|
|
|
| |
| |
| |
|
|
def _make_embeddings(T: int, D: int, seed: int = 42, signed: bool = True) -> torch.Tensor:
    """Return a reproducible ``(T, D)`` tensor of synthetic embeddings.

    A private ``torch.Generator`` seeded with *seed* is used, so the global
    torch RNG state is neither read nor advanced.

    Args:
        T: number of rows (items).
        D: number of columns (embedding dimension).
        seed: seed for the private generator.
        signed: if True, draw standard-normal values scaled by 0.5;
            otherwise draw uniform values from ``torch.rand``.
    """
    rng = torch.Generator().manual_seed(seed)
    if not signed:
        return torch.rand(T, D, generator=rng)
    return 0.5 * torch.randn(T, D, generator=rng)
|
|
|
|
def _psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Return the PSNR (dB) of *pred* against *target*.

    The peak value is ``max(|target|)``; if that is below 1e-10 it falls
    back to 1.0. An MSE below 1e-30 is treated as a perfect match and
    reported as the sentinel value 300.0 dB. Both inputs are detached, so
    no autograd graph is built.
    """
    ref = target.detach()
    err = float((pred.detach() - ref).pow(2).mean())
    if err < 1e-30:
        return 300.0
    peak = float(ref.abs().max())
    # Avoid a zero/near-zero peak for an all-zero target.
    peak = 1.0 if peak < 1e-10 else peak
    return 10.0 * math.log10(peak ** 2 / err)
|
|
|
|
| |
| |
| |
def test_1_backward_compat():
    """Frozen learnable bank should produce identical results to plain 1D.

    The same keys/embeddings are written and read twice:
      * via ``store_pairs_1d`` + ``Slicer1D`` on Hadamard codes, and
      * via ``LearnableMemoryBank1D(init="hadamard", freeze_codes=True)``.

    Returns:
        bool: True when the maximum absolute difference between the two
        retrieved tensors is below 1e-6.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 1: Backward Compatibility (frozen learnable ≡ plain 1D)")
    print(banner)

    torch.manual_seed(42)
    np.random.seed(42)

    L, K, D = 64, 8, 128
    B = 1
    T = K  # one stored item per code

    codes = hadamard_codes(L, K)
    embeddings = _make_embeddings(T, D)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    # Reference path: plain zero-initialised membrane, then a bias-free,
    # ReLU-free slicer built from the same codes.
    memory = store_pairs_1d(torch.zeros(B, L, D), codes, keys, embeddings, alphas)
    Y_plain = Slicer1D(codes, bias=False, relu=False)(memory)

    # Learnable path with its codes frozen at the Hadamard initialisation.
    bank = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    Y_learn = bank.retrieve()

    abs_diff = (Y_plain - Y_learn).abs()
    max_diff = float(abs_diff.max())
    mean_diff = float(abs_diff.mean())

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Max |Y_plain - Y_learn|: {max_diff:.2e}")
    print(f" Mean |Y_plain - Y_learn|: {mean_diff:.2e}")

    passed = max_diff < 1e-6
    verdict = "PASS" if passed else "FAIL"
    print(f"\n {verdict}: Frozen learnable ≡ plain 1D "
          f"(max diff = {max_diff:.2e})")
    return passed
|
|
|
|
| |
| |
| |
def test_2_gradient_flow():
    """C_raw receives gradients through the shared write+read path.

    A trainable Hadamard-initialised bank stores ``T = K`` embeddings and
    reads them back. ``reconstruction_loss`` against the original
    embeddings is then back-propagated.

    Returns:
        bool: True when ``bank.codebook.C_raw.grad`` exists and its norm
        exceeds 1e-10.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 2: Gradient Flow Through Shared Codebook")
    print(banner)

    torch.manual_seed(42)

    L, K, D = 16, 4, 32
    B = 1
    T = K

    embeddings = _make_embeddings(T, D)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    bank = LearnableMemoryBank1D(L, K, D, init="hadamard")
    bank.allocate(B)

    # store() and retrieve() share one codebook, so C_raw can receive
    # gradient from either path.
    bank.store(keys, embeddings, alphas)
    Y = bank.retrieve()

    # Prepend a batch dimension of size B to the targets.
    target = embeddings.unsqueeze(0).expand(B, -1, -1)
    loss = reconstruction_loss(Y, target)
    loss.backward()

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Loss: {loss.item():.6f}")

    grad = bank.codebook.C_raw.grad
    if grad is None:
        print(" C_raw gradient: NONE")
        passed = False
    else:
        grad_norm = float(grad.norm())
        n_nonzero = int((grad.abs() > 1e-10).sum())
        print(f" C_raw gradient: norm = {grad_norm:.6f} | nz = {n_nonzero}/{grad.numel()}")
        passed = grad_norm > 1e-10

    print(f"\n {'PASS' if passed else 'FAIL'}: C_raw receives gradients")
    return passed
|
|
|
|
| |
| |
| |
def test_3_orthogonality_regulariser():
    """Orthogonality loss is zero for Hadamard, nonzero for random.

    A ``LearnableCodebook`` is built for each init scheme, and its
    ``ortho_loss`` and coherence statistics are reported. A fresh random
    codebook then back-propagates its ``ortho_loss`` to confirm that
    ``C_raw`` receives a gradient.

    Returns:
        bool: True when Hadamard loss < 1e-6, random loss > Hadamard loss,
        and the random codebook's ``C_raw`` gradient has non-zero norm.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 3: Orthogonality Regulariser")
    print(banner)

    torch.manual_seed(42)

    L, K = 64, 8
    init_schemes = ("hadamard", "dct", "gaussian", "random", "identity")

    print(f"\n Configuration: L={L}, K={K}")
    print(f"\n {'Init':>12} {'OrthoLoss':>12} {'MaxCoherence':>14} "
          f"{'MeanCoherence':>14}")
    print(f" {'-'*12} {'-'*12} {'-'*14} {'-'*14}")

    ortho_by_init = {}
    for scheme in init_schemes:
        codebook = LearnableCodebook(L, K, init=scheme)
        ortho_value = float(codebook.ortho_loss())
        stats = codebook.coherence()
        ortho_by_init[scheme] = ortho_value
        print(f" {scheme:>12} {ortho_value:>11.6f} {stats['max_abs_offdiag']:>13.6f} "
              f"{stats['mean_abs_offdiag']:>13.6f}")

    hadamard_low = ortho_by_init["hadamard"] < 1e-6
    random_higher = ortho_by_init["random"] > ortho_by_init["hadamard"]

    # Differentiability probe on a fresh random codebook.
    probe = LearnableCodebook(L, K, init="random")
    probe.ortho_loss().backward()
    probe_grad = probe.C_raw.grad
    grad_exists = probe_grad is not None and float(probe_grad.norm()) > 0

    print(f"\n Hadamard ortho_loss ≈ 0: {'YES' if hadamard_low else 'NO'}")
    print(f" Random > Hadamard: {'YES' if random_higher else 'NO'}")
    print(f" Loss is differentiable: {'YES' if grad_exists else 'NO'}")

    passed = hadamard_low and random_higher and grad_exists
    print(f"\n {'PASS' if passed else 'FAIL'}: Orthogonality regulariser")
    return passed
|
|
|
|
| |
| |
| |
def test_4_reconstruction_training():
    """Training with reconstruction loss + ortho reg improves fidelity.

    A frozen random-init bank provides the baseline average PSNR. A second
    bank, built from the same global RNG state, is trained with
    ``train_codebook`` and then re-measured on the same data.

    Returns:
        bool: True when the average PSNR improves by more than
        ``min_gain_db`` and the last logged total loss is below the first.
        An empty training history counts as "loss not decreased".
    """
    print("\n" + "=" * 60)
    print("TEST 4: Reconstruction Training")
    print("=" * 60)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K
    # Training hyperparameters are defined once so the report line below
    # always matches what train_codebook actually ran with.
    n_steps, lr, ortho_lambda = 200, 1e-2, 0.01
    min_gain_db = 5.0

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        return keys, embeddings, alphas

    # Baseline: random codes, frozen (never trained).
    bank_before = LearnableMemoryBank1D(L, K, D, init="random", freeze_codes=True)
    bank_before.allocate(B)
    bank_before.store(keys, embeddings, alphas)
    Y_before = bank_before.retrieve()
    psnr_before = sum(
        _psnr(Y_before[0, k], embeddings[k]) for k in range(T)
    ) / T

    # Re-seed so the trainable bank's random init is drawn from the same
    # global RNG state as the baseline bank above. _make_embeddings uses
    # its own generator, so the global state was not advanced in between.
    torch.manual_seed(42)
    bank = LearnableMemoryBank1D(L, K, D, init="random")
    history = train_codebook(
        bank, data_fn,
        n_steps=n_steps, lr=lr, ortho_lambda=ortho_lambda,
        B=B, log_every=20,
    )

    # Re-allocate and re-store with the trained codes before measuring.
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    Y_after = bank.retrieve()
    psnr_after = sum(
        _psnr(Y_after[0, k], embeddings[k]) for k in range(T)
    ) / T

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Init: random codes | Steps: {n_steps} | lr: {lr:g} | "
          f"λ_ortho: {ortho_lambda:g}")

    print("\n Training trajectory:")
    print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10} "
          f"{'MaxCoh':>10} {'MeanCoh':>10}")
    print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
    for h in history:
        print(f" {h['step']:>6} {h['total_loss']:>9.6f} "
              f"{h['recon_loss']:>9.6f} {h['ortho_loss']:>9.6f} "
              f"{h['max_coherence']:>9.6f} {h['mean_coherence']:>9.6f}")

    print(f"\n PSNR before training: {psnr_before:.1f} dB")
    print(f" PSNR after training: {psnr_after:.1f} dB")
    print(f" Improvement: {psnr_after - psnr_before:.1f} dB")

    improved = psnr_after > psnr_before + min_gain_db
    # Guard against an empty history so the check reports FAIL instead of
    # raising IndexError.
    loss_decreased = bool(history) and (
        history[-1]["total_loss"] < history[0]["total_loss"]
    )

    passed = improved and loss_decreased
    print(f"\n PSNR improved by >{min_gain_db:g} dB: {'YES' if improved else 'NO'}")
    print(f" Loss decreased: {'YES' if loss_decreased else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Reconstruction training")
    return passed
|
|
|
|
| |
| |
| |
def test_5_hadamard_vs_learned():
    """Compare fixed Hadamard vs learned codes on structured data.

    The embeddings share a common base vector plus small per-item noise.
    Three banks are measured on the same data:
      * frozen Hadamard codes,
      * codes trained from a Hadamard init, and
      * codes trained from a random init.

    Returns:
        bool: True when fixed Hadamard exceeds 100 dB and the random-init
        learned codes exceed 50 dB average PSNR.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 5: Hadamard vs Learned Codes")
    print(banner)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K

    # Correlated data: one shared base vector plus small per-item noise.
    gen = torch.Generator().manual_seed(42)
    base = torch.randn(1, D, generator=gen) * 0.5
    embeddings = base + torch.randn(T, D, generator=gen) * 0.1
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        return keys, embeddings, alphas

    def measure(bank):
        # Allocate, store every item, read back;
        # return (average PSNR, coherence stats).
        bank.allocate(B)
        bank.store(keys, embeddings, alphas)
        Y = bank.retrieve()
        avg_psnr = sum(_psnr(Y[0, k], embeddings[k]) for k in range(T)) / T
        return avg_psnr, bank.coherence()

    def trained_bank(init):
        # Re-seed before each run so every trained bank starts from the same
        # global RNG state.
        torch.manual_seed(42)
        bank = LearnableMemoryBank1D(L, K, D, init=init)
        train_codebook(
            bank, data_fn,
            n_steps=200, lr=1e-2, ortho_lambda=0.01, B=B, log_every=200,
        )
        return bank

    psnr_had, coh_had = measure(
        LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    )
    psnr_learn_h, coh_learn_h = measure(trained_bank("hadamard"))
    psnr_learn_r, coh_learn_r = measure(trained_bank("random"))

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(" Data: correlated embeddings (base + noise)")

    print(f"\n {'Config':<28} {'PSNR':>8} {'MaxCoh':>10} {'MeanCoh':>10}")
    print(f" {'-'*28} {'-'*8} {'-'*10} {'-'*10}")
    table = (
        ("Fixed Hadamard", psnr_had, coh_had),
        ("Learned (Hadamard init)", psnr_learn_h, coh_learn_h),
        ("Learned (random init)", psnr_learn_r, coh_learn_r),
    )
    for label, psnr, coh in table:
        print(f" {label:<28} {psnr:>7.1f}dB "
              f"{coh['max_abs_offdiag']:>9.6f} "
              f"{coh['mean_abs_offdiag']:>9.6f}")

    hadamard_excellent = psnr_had > 100.0
    learned_competitive = psnr_learn_r > 50.0

    passed = hadamard_excellent and learned_competitive
    print(f"\n Fixed Hadamard > 100 dB: {'YES' if hadamard_excellent else 'NO'}")
    print(f" Learned (random init) > 50 dB: "
          f"{'YES' if learned_competitive else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Hadamard vs learned")
    return passed
|
|
|
|
| |
| |
| |
def test_6_coherence_tracking():
    """Monitor coherence_stats during training — should stay controlled.

    Random-init banks are trained, each from seed 42, at several
    ``ortho_lambda`` values. Each run's final PSNR, coherence and last
    logged ortho loss are reported.

    Returns:
        bool: True when the lambda=1.0 run's max off-diagonal coherence is
        either below the lambda=0.0 run's value or below 0.3.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 6: Coherence Tracking During Training")
    print(banner)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K
    n_steps = 200

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        return keys, embeddings, alphas

    lambdas = [0.0, 0.01, 0.1, 1.0]

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Training {n_steps} steps from random init at various λ_ortho")

    print(f"\n {'λ_ortho':>8} {'Final PSNR':>12} {'Final MaxCoh':>14} "
          f"{'Final MeanCoh':>14} {'Final OrthoL':>14}")
    print(f" {'-'*8} {'-'*12} {'-'*14} {'-'*14} {'-'*14}")

    max_coh = {}
    for lam in lambdas:
        torch.manual_seed(42)
        bank = LearnableMemoryBank1D(L, K, D, init="random")
        history = train_codebook(
            bank, data_fn,
            n_steps=n_steps, lr=1e-2, ortho_lambda=lam, B=B, log_every=n_steps,
        )

        bank.allocate(B)
        bank.store(keys, embeddings, alphas)
        Y = bank.retrieve()
        psnr = sum(_psnr(Y[0, k], embeddings[k]) for k in range(T)) / T
        coh = bank.coherence()
        final_ortho = history[-1]["ortho_loss"]

        max_coh[lam] = coh["max_abs_offdiag"]
        print(f" {lam:>8.2f} {psnr:>11.1f}dB "
              f"{coh['max_abs_offdiag']:>13.6f} "
              f"{coh['mean_abs_offdiag']:>13.6f} "
              f"{final_ortho:>13.6f}")

    no_reg_higher_coh = max_coh[0.0] > max_coh[1.0]
    # Controlled if regularisation lowered the max coherence, or if it is
    # already small in absolute terms.
    coherence_controlled = no_reg_higher_coh or max_coh[1.0] < 0.3

    passed = coherence_controlled
    print(f"\n λ=1.0 controls coherence: "
          f"{'YES' if coherence_controlled else 'NO'}")
    print(f" λ=0 has higher coherence: "
          f"{'YES' if no_reg_higher_coh else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Coherence tracking")
    return passed
|
|
|
|
| |
| |
| |
def test_7_overload_regime():
    """Can learned codes extend effective capacity beyond K=L?

    K = 2L items are stored with frozen Hadamard-init codes and with
    trainable Hadamard-init codes (after ``train_codebook``), and their
    average PSNRs are compared.

    Returns:
        bool: True when the trained codes' average PSNR exceeds the frozen
        baseline's.
    """
    print("\n" + "=" * 60)
    print("TEST 7: Overload Regime (K > L)")
    print("=" * 60)

    torch.manual_seed(42)

    L = 16
    K = 32
    D = 64
    B = 1
    T = K
    # Overload factor, derived once so every report line agrees with L and K.
    overload = K / L

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        return keys, embeddings, alphas

    # Baseline: Hadamard-init codes, frozen.
    bank_had = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    bank_had.allocate(B)
    bank_had.store(keys, embeddings, alphas)
    Y_had = bank_had.retrieve()
    psnr_had = sum(_psnr(Y_had[0, k], embeddings[k]) for k in range(T)) / T

    # Trainable Hadamard-init codes with a light orthogonality penalty.
    torch.manual_seed(42)
    bank_learn = LearnableMemoryBank1D(L, K, D, init="hadamard")
    history = train_codebook(
        bank_learn, data_fn,
        n_steps=300, lr=1e-2, ortho_lambda=0.001, B=B, log_every=50,
    )

    bank_learn.allocate(B)
    bank_learn.store(keys, embeddings, alphas)
    Y_learn = bank_learn.retrieve()
    psnr_learn = sum(
        _psnr(Y_learn[0, k], embeddings[k]) for k in range(T)
    ) / T

    print(f"\n Configuration: L={L}, K={K} (K/L = {overload:.1f}×), D={D}")

    print("\n Training trajectory:")
    print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10}")
    print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10}")
    for h in history:
        print(f" {h['step']:>6} {h['total_loss']:>9.6f} "
              f"{h['recon_loss']:>9.6f} {h['ortho_loss']:>9.6f}")

    print(f"\n Fixed Hadamard at {overload:g}× overload: {psnr_had:.1f} dB")
    print(f" Learned codes at {overload:g}× overload: {psnr_learn:.1f} dB")

    learned_better = psnr_learn > psnr_had
    # Informational only (not part of the pass criterion).
    # NOTE(review): the thresholds differ (50 dB for fixed, 100 dB for
    # learned); confirm this asymmetry is intended.
    both_degraded = psnr_had < 50.0 and psnr_learn < 100.0

    passed = learned_better
    print(f"\n Learned > Fixed at overload: "
          f"{'YES' if learned_better else 'NO'} "
          f"({psnr_learn:.1f} vs {psnr_had:.1f} dB)")
    print(f" Both degraded vs capacity: {'YES' if both_degraded else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Overload regime")
    return passed
|
|
|
|
| |
| |
| |
def test_8_continuous_addressing():
    """Learnable codes work with continuous read path.

    Items are stored discretely and then read back with
    ``retrieve_continuous`` at several temperatures. Queries are the
    transposed code matrix, and the projection is a detached copy of it.
    A final backward pass checks that gradients reach ``C_raw``.

    Returns:
        bool: True when PSNR at T=1e-6 exceeds 80 dB, PSNR at T=1.0 is
        lower than at T=1e-6, and ``C_raw`` received a gradient.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 8: Continuous Addressing + Learnable Codes")
    print(banner)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    D_query = L
    B = 1
    T = K

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    bank = LearnableMemoryBank1D(L, K, D, init="hadamard")
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)

    # The current code matrix (transposed) serves as queries; a detached
    # copy serves as the projection.
    C = bank.codebook()
    queries = C.T
    projection = C.detach().clone()

    temperatures = [1e-6, 0.01, 0.1, 1.0]

    print(f"\n Configuration: L={L}, K={K}, D={D}, D_query={D_query}")
    print(" Store: discrete | Read: continuous at varying T")

    print(f"\n {'Temp':>8} {'Avg PSNR':>10} {'Avg CosSim':>12}")
    print(f" {'-'*8} {'-'*10} {'-'*12}")

    psnr_at = {}
    for temp in temperatures:
        Y = bank.retrieve_continuous(queries, projection, temperature=temp)
        per_item = [_psnr(Y[0, k], embeddings[k]) for k in range(T)]
        metrics = token_retrieval_accuracy(
            Y[0].detach(), embeddings, threshold=0.999
        )
        avg_psnr = sum(per_item) / len(per_item)
        psnr_at[temp] = avg_psnr
        print(f" {temp:>8.1e} {avg_psnr:>9.1f}dB {metrics['cosine_sim']:>11.6f}")

    # Gradient check: rebuild memory, read with detached queries and a
    # trainable projection, then back-propagate the reconstruction loss.
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    proj_param = nn.Parameter(C.detach().clone())
    Y = bank.retrieve_continuous(queries.detach(), proj_param, temperature=0.01)
    loss = reconstruction_loss(Y, embeddings.unsqueeze(0))
    loss.backward()

    code_grad = bank.codebook.C_raw.grad is not None
    # Reported only; not part of the pass criterion below.
    proj_grad = proj_param.grad is not None

    print(f"\n Gradient through C_raw: {'YES' if code_grad else 'NO'}")
    print(f" Gradient through projection: {'YES' if proj_grad else 'NO'}")

    low_t_psnr = psnr_at[1e-6]
    low_t_ok = low_t_psnr > 80.0
    degrades = psnr_at[1.0] < low_t_psnr

    passed = low_t_ok and degrades and code_grad
    print(f"\n Low T PSNR > 80 dB: {'YES' if low_t_ok else 'NO'} "
          f"({low_t_psnr:.1f} dB)")
    print(f" Degrades at high T: {'YES' if degrades else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Continuous addressing + learnable codes")
    return passed
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run every Direction 6 test, print a summary, and report overall success.

    A test counts as passed when it returns a truthy value. If a test
    raises, its message and traceback are printed, it is recorded as a
    failure, and the remaining tests still run.

    Returns:
        bool: True only if every test passed.
    """
    rule = "=" * 60
    print(rule)
    print(" DIRECTION 6: LEARNABLE CODE EVOLUTION")
    print(" Learnable Codebook Test Suite")
    print(rule)

    suite = (
        ("Backward Compat", test_1_backward_compat),
        ("Gradient Flow", test_2_gradient_flow),
        ("Orthogonality Regulariser", test_3_orthogonality_regulariser),
        ("Reconstruction Training", test_4_reconstruction_training),
        ("Hadamard vs Learned", test_5_hadamard_vs_learned),
        ("Coherence Tracking", test_6_coherence_tracking),
        ("Overload Regime", test_7_overload_regime),
        ("Continuous Addressing", test_8_continuous_addressing),
    )

    outcomes = {}
    for name, fn in suite:
        try:
            outcomes[name] = fn()
        except Exception as exc:  # top-level runner: report and keep going
            print(f"\n ERROR in {name}: {exc}")
            import traceback
            traceback.print_exc()
            outcomes[name] = False

    print("\n" + rule)
    print(" SUMMARY")
    print(rule)
    for name, ok in outcomes.items():
        print(f" [{'PASS' if ok else 'FAIL'}] {name}")

    n_pass = sum(map(bool, outcomes.values()))
    n_total = len(outcomes)
    print(f"\n {n_pass}/{n_total} tests passed")

    all_passed = n_pass == n_total
    if all_passed:
        print("\n ALL TESTS PASSED")
    else:
        print(f"\n {n_total - n_pass} FAILURES")

    return all_passed
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    sys.exit(0 if main() else 1)
|
|