# WCNegentropy's picture
# Upload 510 files
# 3d7f6c5 verified
#!/usr/bin/env python3
"""Test suite for Direction 6: Learnable Code Evolution.
Tests the learnable codebook system:
1. Backward compat — frozen learnable bank ≡ plain 1D bank
2. Gradient flow — C_raw receives gradients through write+read paths
3. Orthogonality regulariser — penalises non-orthogonal codes
4. Reconstruction training — codes improve fidelity over training steps
5. Hadamard vs learned — compare on structured data distributions
6. Coherence tracking — monitor coherence_stats during training
7. Overload regime — can learned codes extend capacity beyond K=L?
8. Continuous addressing — learnable codes + continuous read
Conventions (matching project test style):
- torch.manual_seed(42), np.random.seed(42)
- Setup → Store → Retrieve → Measure → Report
- PSNR thresholds: >100 dB EXCELLENT, >80 dB HIGH FIDELITY, >50 dB GOOD
"""
from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
import math
import numpy as np
import torch
from torch import nn
from wrinklebrane.codes import (
hadamard_codes,
dct_codes,
gaussian_codes,
coherence_stats,
normalize_columns,
)
from wrinklebrane.membrane_1d import (
MembraneBank1D,
store_pairs_1d,
Slicer1D,
cosine_similarity_matrix,
token_retrieval_accuracy,
soft_code_weights_1d,
)
from wrinklebrane.learnable_codes import (
LearnableCodebook,
LearnableMemoryBank1D,
orthogonality_loss,
reconstruction_loss,
train_codebook,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_embeddings(T: int, D: int, seed: int = 42, signed: bool = True) -> torch.Tensor:
    """Return a reproducible ``T x D`` tensor of synthetic embeddings.

    Sampling uses a private ``torch.Generator`` seeded with *seed*, so the
    global torch RNG state is not consumed by this helper.

    Args:
        T: Number of rows (items) to generate.
        D: Number of columns (embedding width).
        seed: Seed for the private generator.
        signed: When True, draw standard-normal samples scaled by 0.5;
            when False, draw uniform samples from ``torch.rand``.

    Returns:
        A float tensor of shape ``(T, D)``.
    """
    rng = torch.Generator().manual_seed(seed)
    if not signed:
        return torch.rand(T, D, generator=rng)
    return torch.randn(T, D, generator=rng) * 0.5
def _psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the peak signal-to-noise ratio (in dB) of *pred* against *target*.

    The peak is the largest absolute value found in *target*.  Both tensors
    are detached first, so calling this never touches the autograd graph.

    Args:
        pred: Reconstructed tensor.
        target: Reference tensor (same shape as *pred*, or broadcastable).

    Returns:
        ``10 * log10(peak**2 / mse)``.  A fixed 300.0 is returned when the
        mean squared error is below 1e-30 (effectively exact match), and a
        peak of 1.0 is substituted when the target is all (near-)zero.
    """
    reference = target.detach()
    mse = float((pred.detach() - reference).pow(2).mean())
    if mse < 1e-30:
        # Avoid division by zero / log of infinity on exact reconstruction.
        return 300.0

    peak = float(reference.abs().max())
    if peak < 1e-10:
        # An all-zero target has no dynamic range; fall back to unit peak.
        peak = 1.0
    return 10.0 * math.log10(peak ** 2 / mse)
# =====================================================================
# Test 1: Backward Compatibility
# =====================================================================
def test_1_backward_compat():
    """Verify that a frozen learnable bank matches the plain 1D pipeline.

    The same embeddings are written with Hadamard codes twice: once through
    ``store_pairs_1d`` followed by a ``Slicer1D`` read, and once through a
    ``LearnableMemoryBank1D`` built with ``init="hadamard"`` and
    ``freeze_codes=True``.  The two read-outs are compared element-wise.

    Returns:
        bool: True when the largest absolute difference is below 1e-6.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 1: Backward Compatibility (frozen learnable ≡ plain 1D)")
    print(rule)

    torch.manual_seed(42)
    np.random.seed(42)

    L, K, D = 64, 8, 128
    B = 1
    T = K  # store K items, addressed by keys 0..K-1

    C = hadamard_codes(L, K)
    embeddings = _make_embeddings(T, D)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    # Reference path: functional store into a zeroed membrane, then slice.
    M_plain = store_pairs_1d(torch.zeros(B, L, D), C, keys, embeddings, alphas)
    Y_plain = Slicer1D(C, bias=False, relu=False)(M_plain)

    # Path under test: learnable bank whose codebook is frozen.
    bank = LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    Y_learn = bank.retrieve()

    abs_diff = (Y_plain - Y_learn).abs()
    max_diff = float(abs_diff.max())
    mean_diff = float(abs_diff.mean())

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Max |Y_plain - Y_learn|: {max_diff:.2e}")
    print(f" Mean |Y_plain - Y_learn|: {mean_diff:.2e}")

    passed = max_diff < 1e-6
    print(f"\n {'PASS' if passed else 'FAIL'}: Frozen learnable ≡ plain 1D "
          f"(max diff = {max_diff:.2e})")
    return passed
# =====================================================================
# Test 2: Gradient Flow
# =====================================================================
def test_2_gradient_flow():
    """Check that ``C_raw`` receives a gradient from a store-then-retrieve pass.

    A trainable ``LearnableMemoryBank1D`` stores the embeddings, reads them
    back, and a reconstruction loss against the stored embeddings is
    back-propagated.  The gradient on ``bank.codebook.C_raw`` is then
    inspected.

    Returns:
        bool: True when the gradient exists and its norm exceeds 1e-10.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 2: Gradient Flow Through Shared Codebook")
    print(rule)

    torch.manual_seed(42)

    L, K, D = 16, 4, 32
    B = 1
    T = K

    embeddings = _make_embeddings(T, D)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    bank = LearnableMemoryBank1D(L, K, D, init="hadamard")
    bank.allocate(B)

    # Forward pass: write, then read back.
    bank.store(keys, embeddings, alphas)
    Y = bank.retrieve()

    # The read-out should reproduce what was stored, once per batch entry.
    target = embeddings.unsqueeze(0).expand(B, -1, -1)
    loss = reconstruction_loss(Y, target)
    loss.backward()

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Loss: {loss.item():.6f}")

    grad = bank.codebook.C_raw.grad
    if grad is None:
        print(f" C_raw gradient: NONE")
        has_grad = False
    else:
        norm = float(grad.norm())
        nonzero = int((grad.abs() > 1e-10).sum())
        total = grad.numel()
        print(f" C_raw gradient: norm = {norm:.6f} | nz = {nonzero}/{total}")
        has_grad = norm > 1e-10

    passed = has_grad
    print(f"\n {'PASS' if passed else 'FAIL'}: C_raw receives gradients")
    return passed
# =====================================================================
# Test 3: Orthogonality Regulariser
# =====================================================================
def test_3_orthogonality_regulariser():
    """Compare the orthogonality loss across codebook initialisations.

    For each init scheme a ``LearnableCodebook`` is built and its
    ``ortho_loss()`` and ``coherence()`` statistics are tabulated.  A final
    randomly initialised codebook is used to confirm that the loss can be
    back-propagated to ``C_raw``.

    Returns:
        bool: True when the Hadamard loss is below 1e-6, the random-init loss
        is larger than the Hadamard loss, and the gradient is non-zero.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 3: Orthogonality Regulariser")
    print(rule)

    torch.manual_seed(42)

    L, K = 64, 8
    # Creation order is kept fixed because random inits draw from the
    # global RNG seeded above.
    init_names = ("hadamard", "dct", "gaussian", "random", "identity")

    print(f"\n Configuration: L={L}, K={K}")
    print(f"\n {'Init':>12} {'OrthoLoss':>12} {'MaxCoherence':>14} "
          f"{'MeanCoherence':>14}")
    print(f" {'-'*12} {'-'*12} {'-'*14} {'-'*14}")

    results = {}
    for label in init_names:
        codebook = LearnableCodebook(L, K, init=label)
        loss = float(codebook.ortho_loss())
        coh = codebook.coherence()
        max_coh = coh["max_abs_offdiag"]
        mean_coh = coh["mean_abs_offdiag"]
        results[label] = {
            "loss": loss,
            "max_coh": max_coh,
            "mean_coh": mean_coh,
        }
        print(f" {label:>12} {loss:>11.6f} {max_coh:>13.6f} "
              f"{mean_coh:>13.6f}")

    hadamard_loss = results["hadamard"]["loss"]
    hadamard_low = hadamard_loss < 1e-6
    random_higher = results["random"]["loss"] > hadamard_loss

    # Differentiability check on a fresh random codebook.
    probe = LearnableCodebook(L, K, init="random")
    probe.ortho_loss().backward()
    probe_grad = probe.C_raw.grad
    grad_exists = probe_grad is not None and float(probe_grad.norm()) > 0

    print(f"\n Hadamard ortho_loss ≈ 0: {'YES' if hadamard_low else 'NO'}")
    print(f" Random > Hadamard: {'YES' if random_higher else 'NO'}")
    print(f" Loss is differentiable: {'YES' if grad_exists else 'NO'}")

    passed = hadamard_low and random_higher and grad_exists
    print(f"\n {'PASS' if passed else 'FAIL'}: Orthogonality regulariser")
    return passed
# =====================================================================
# Test 4: Reconstruction Training
# =====================================================================
def test_4_reconstruction_training():
    """Check that training a random codebook improves reconstruction PSNR.

    An untrained, frozen random-init bank provides the baseline PSNR.  A
    second bank, created from the same seed, is trained with
    ``train_codebook`` on a fixed batch and evaluated the same way.

    Returns:
        bool: True when the mean PSNR rises by more than 5 dB and the last
        logged total loss is lower than the first.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 4: Reconstruction Training")
    print(rule)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K

    # The training data never changes: the same batch is served every step.
    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        """Return the fixed (keys, embeddings, alphas) training batch."""
        return keys, embeddings, alphas

    def mean_psnr(Y):
        """Average per-item PSNR of batch entry 0 against the embeddings."""
        scores = [_psnr(Y[0, k], emb) for k, emb in enumerate(embeddings)]
        return sum(scores) / T

    # Baseline: random codes, no optimisation.
    bank_before = LearnableMemoryBank1D(L, K, D, init="random", freeze_codes=True)
    bank_before.allocate(B)
    bank_before.store(keys, embeddings, alphas)
    psnr_before = mean_psnr(bank_before.retrieve())

    # Re-seed so the trainable bank starts from the same global RNG state.
    torch.manual_seed(42)
    bank = LearnableMemoryBank1D(L, K, D, init="random")
    history = train_codebook(
        bank, data_fn,
        n_steps=200, lr=1e-2, ortho_lambda=0.01,
        B=B, log_every=20,
    )

    # Fresh store/retrieve with the trained codes.
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    psnr_after = mean_psnr(bank.retrieve())

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Init: random codes | Steps: 200 | lr: 1e-2 | λ_ortho: 0.01")
    print(f"\n Training trajectory:")
    print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10} "
          f"{'MaxCoh':>10} {'MeanCoh':>10}")
    print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
    for entry in history:
        print(f" {entry['step']:>6} {entry['total_loss']:>9.6f} "
              f"{entry['recon_loss']:>9.6f} {entry['ortho_loss']:>9.6f} "
              f"{entry['max_coherence']:>9.6f} {entry['mean_coherence']:>9.6f}")

    print(f"\n PSNR before training: {psnr_before:.1f} dB")
    print(f" PSNR after training: {psnr_after:.1f} dB")
    print(f" Improvement: {psnr_after - psnr_before:.1f} dB")

    # Require a gain of at least 5 dB and a falling loss curve.
    improved = psnr_after > psnr_before + 5.0
    loss_decreased = history[-1]["total_loss"] < history[0]["total_loss"]
    passed = improved and loss_decreased

    print(f"\n PSNR improved by >5 dB: {'YES' if improved else 'NO'}")
    print(f" Loss decreased: {'YES' if loss_decreased else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Reconstruction training")
    return passed
# =====================================================================
# Test 5: Hadamard vs Learned
# =====================================================================
def test_5_hadamard_vs_learned():
    """Compare fixed Hadamard codes with trained codes on correlated data.

    The embeddings are a shared base vector plus small per-item noise.  Three
    banks are evaluated: frozen Hadamard, trained from Hadamard init, and
    trained from random init.  PSNR and coherence are tabulated for each.

    Returns:
        bool: True when fixed Hadamard exceeds 100 dB and the bank trained
        from random init exceeds 50 dB.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 5: Hadamard vs Learned Codes")
    print(rule)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K

    # Correlated data: every item is the same base vector plus small noise.
    gen = torch.Generator().manual_seed(42)
    base = torch.randn(1, D, generator=gen) * 0.5
    embeddings = base + torch.randn(T, D, generator=gen) * 0.1
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        """Return the fixed (keys, embeddings, alphas) training batch."""
        return keys, embeddings, alphas

    def evaluate(bank):
        """Store + retrieve with *bank*; return (mean PSNR, coherence stats)."""
        bank.allocate(B)
        bank.store(keys, embeddings, alphas)
        Y = bank.retrieve()
        scores = [_psnr(Y[0, k], emb) for k, emb in enumerate(embeddings)]
        return sum(scores) / T, bank.coherence()

    def train_then_evaluate(init):
        """Train a fresh bank from *init* (re-seeded) and evaluate it."""
        torch.manual_seed(42)
        bank = LearnableMemoryBank1D(L, K, D, init=init)
        train_codebook(
            bank, data_fn,
            n_steps=200, lr=1e-2, ortho_lambda=0.01, B=B, log_every=200,
        )
        return evaluate(bank)

    # Frozen Hadamard reference, then the two trained variants in order.
    psnr_had, coh_had = evaluate(
        LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    )
    psnr_learn_h, coh_learn_h = train_then_evaluate("hadamard")
    psnr_learn_r, coh_learn_r = train_then_evaluate("random")

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Data: correlated embeddings (base + noise)")
    print(f"\n {'Config':<28} {'PSNR':>8} {'MaxCoh':>10} {'MeanCoh':>10}")
    print(f" {'-'*28} {'-'*8} {'-'*10} {'-'*10}")
    print(f" {'Fixed Hadamard':<28} {psnr_had:>7.1f}dB "
          f"{coh_had['max_abs_offdiag']:>9.6f} "
          f"{coh_had['mean_abs_offdiag']:>9.6f}")
    print(f" {'Learned (Hadamard init)':<28} {psnr_learn_h:>7.1f}dB "
          f"{coh_learn_h['max_abs_offdiag']:>9.6f} "
          f"{coh_learn_h['mean_abs_offdiag']:>9.6f}")
    print(f" {'Learned (random init)':<28} {psnr_learn_r:>7.1f}dB "
          f"{coh_learn_r['max_abs_offdiag']:>9.6f} "
          f"{coh_learn_r['mean_abs_offdiag']:>9.6f}")

    # Pass criteria: Hadamard above 100 dB, random-init training above 50 dB.
    hadamard_excellent = psnr_had > 100.0
    learned_competitive = psnr_learn_r > 50.0
    passed = hadamard_excellent and learned_competitive

    print(f"\n Fixed Hadamard > 100 dB: {'YES' if hadamard_excellent else 'NO'}")
    print(f" Learned (random init) > 50 dB: "
          f"{'YES' if learned_competitive else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Hadamard vs learned")
    return passed
# =====================================================================
# Test 6: Coherence Tracking
# =====================================================================
def test_6_coherence_tracking():
    """Report how the orthogonality weight affects final code coherence.

    A random-init bank is trained for 200 steps at each value of
    ``ortho_lambda`` (re-seeding before each run), then evaluated for PSNR
    and coherence.

    Returns:
        bool: True when the λ=1.0 run ends with lower max coherence than the
        λ=0.0 run, or with max coherence below 0.3.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 6: Coherence Tracking During Training")
    print(rule)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    B = 1
    T = K

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        """Return the fixed (keys, embeddings, alphas) training batch."""
        return keys, embeddings, alphas

    def mean_psnr(Y):
        """Average per-item PSNR of batch entry 0 against the embeddings."""
        scores = [_psnr(Y[0, k], emb) for k, emb in enumerate(embeddings)]
        return sum(scores) / T

    print(f"\n Configuration: L={L}, K={K}, D={D}")
    print(f" Training 200 steps from random init at various λ_ortho")
    print(f"\n {'λ_ortho':>8} {'Final PSNR':>12} {'Final MaxCoh':>14} "
          f"{'Final MeanCoh':>14} {'Final OrthoL':>14}")
    print(f" {'-'*8} {'-'*12} {'-'*14} {'-'*14} {'-'*14}")

    results = {}
    for lam in (0.0, 0.01, 0.1, 1.0):
        # Same seed for every λ so runs differ only in the regulariser weight.
        torch.manual_seed(42)
        bank = LearnableMemoryBank1D(L, K, D, init="random")
        history = train_codebook(
            bank, data_fn,
            n_steps=200, lr=1e-2, ortho_lambda=lam, B=B, log_every=200,
        )

        bank.allocate(B)
        bank.store(keys, embeddings, alphas)
        psnr = mean_psnr(bank.retrieve())
        coh = bank.coherence()
        final_ortho = history[-1]["ortho_loss"]

        results[lam] = {
            "psnr": psnr,
            "max_coh": coh["max_abs_offdiag"],
            "mean_coh": coh["mean_abs_offdiag"],
            "ortho_loss": final_ortho,
        }
        print(f" {lam:>8.2f} {psnr:>11.1f}dB "
              f"{coh['max_abs_offdiag']:>13.6f} "
              f"{coh['mean_abs_offdiag']:>13.6f} "
              f"{final_ortho:>13.6f}")

    strong_reg_coh = results[1.0]["max_coh"]
    no_reg_coh = results[0.0]["max_coh"]

    # Accept either a relative win over λ=0 or an absolute bound of 0.3.
    coherence_controlled = strong_reg_coh < no_reg_coh or strong_reg_coh < 0.3
    # Informational only: reported below but not part of the verdict.
    no_reg_higher_coh = no_reg_coh > strong_reg_coh
    passed = coherence_controlled

    print(f"\n λ=1.0 controls coherence: "
          f"{'YES' if coherence_controlled else 'NO'}")
    print(f" λ=0 has higher coherence: "
          f"{'YES' if no_reg_higher_coh else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Coherence tracking")
    return passed
# =====================================================================
# Test 7: Overload Regime
# =====================================================================
def test_7_overload_regime():
    """Compare fixed and trained codes when there are more codes than L.

    With ``L=16`` and ``K=32`` the bank is asked to hold twice as many items
    as it has code dimensions.  A frozen Hadamard-init bank is the baseline;
    a trainable Hadamard-init bank is optimised for 300 steps and
    re-evaluated.

    Returns:
        bool: True when the trained bank's mean PSNR exceeds the fixed one's.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 7: Overload Regime (K > L)")
    print(rule)

    torch.manual_seed(42)

    L = 16
    K = 32  # twice as many codes as code dimensions
    D = 64
    B = 1
    T = K

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    def data_fn():
        """Return the fixed (keys, embeddings, alphas) training batch."""
        return keys, embeddings, alphas

    def mean_psnr(bank):
        """Store + retrieve with *bank* and return the average item PSNR."""
        bank.allocate(B)
        bank.store(keys, embeddings, alphas)
        Y = bank.retrieve()
        scores = [_psnr(Y[0, k], emb) for k, emb in enumerate(embeddings)]
        return sum(scores) / T

    # Baseline: frozen Hadamard-init codes under overload.
    psnr_had = mean_psnr(
        LearnableMemoryBank1D(L, K, D, init="hadamard", freeze_codes=True)
    )

    # Candidate: same init, but trained on the reconstruction objective.
    torch.manual_seed(42)
    bank_learn = LearnableMemoryBank1D(L, K, D, init="hadamard")
    history = train_codebook(
        bank_learn, data_fn,
        n_steps=300, lr=1e-2, ortho_lambda=0.001, B=B, log_every=50,
    )
    psnr_learn = mean_psnr(bank_learn)

    print(f"\n Configuration: L={L}, K={K} (K/L = {K/L:.1f}×), D={D}")
    print(f"\n Training trajectory:")
    print(f" {'Step':>6} {'Total':>10} {'Recon':>10} {'Ortho':>10}")
    print(f" {'-'*6} {'-'*10} {'-'*10} {'-'*10}")
    for entry in history:
        print(f" {entry['step']:>6} {entry['total_loss']:>9.6f} "
              f"{entry['recon_loss']:>9.6f} {entry['ortho_loss']:>9.6f}")

    print(f"\n Fixed Hadamard at 2× overload: {psnr_had:.1f} dB")
    print(f" Learned codes at 2× overload: {psnr_learn:.1f} dB")

    # Verdict: training must beat the fixed codes.
    learned_better = psnr_learn > psnr_had
    # Informational only: fixed below 50 dB and learned below 100 dB.
    both_degraded = psnr_had < 50.0 and psnr_learn < 100.0
    passed = learned_better

    print(f"\n Learned > Fixed at overload: "
          f"{'YES' if learned_better else 'NO'} "
          f"({psnr_learn:.1f} vs {psnr_had:.1f} dB)")
    print(f" Both degraded vs capacity: {'YES' if both_degraded else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Overload regime")
    return passed
# =====================================================================
# Test 8: Continuous Addressing Integration
# =====================================================================
def test_8_continuous_addressing():
    """Exercise ``retrieve_continuous`` on a bank with learnable codes.

    Items are stored with discrete keys, then read back through the
    continuous path using the code columns themselves as queries, at several
    temperatures.  A second pass checks that a reconstruction loss on the
    continuous read-out yields gradients on ``C_raw`` and on a projection
    parameter.

    Returns:
        bool: True when PSNR at temperature 1e-6 exceeds 80 dB, PSNR at
        temperature 1.0 is lower than at 1e-6, and ``C_raw.grad`` is set.
        The projection gradient is reported but not part of the verdict.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 8: Continuous Addressing + Learnable Codes")
    print(rule)

    torch.manual_seed(42)

    L, K, D = 32, 8, 64
    D_query = L  # query width equals the code length
    B = 1
    T = K

    embeddings = _make_embeddings(T, D, seed=42)
    keys = torch.arange(T, dtype=torch.long)
    alphas = torch.ones(T)

    bank = LearnableMemoryBank1D(L, K, D, init="hadamard")
    bank.allocate(B)
    # Write with discrete keys; only the read side is continuous.
    bank.store(keys, embeddings, alphas)

    C = bank.codebook()  # [L, K]
    queries = C.T  # [K, L] — one query per code column
    projection = C.detach().clone()  # [L, K] — maps queries into code space

    print(f"\n Configuration: L={L}, K={K}, D={D}, D_query={D_query}")
    print(f" Store: discrete | Read: continuous at varying T")
    print(f"\n {'Temp':>8} {'Avg PSNR':>10} {'Avg CosSim':>12}")
    print(f" {'-'*8} {'-'*10} {'-'*12}")

    results = {}
    for temp in (1e-6, 0.01, 0.1, 1.0):
        Y = bank.retrieve_continuous(queries, projection, temperature=temp)
        psnrs = [_psnr(Y[0, k], emb) for k, emb in enumerate(embeddings)]
        metrics = token_retrieval_accuracy(
            Y[0].detach(), embeddings, threshold=0.999
        )
        avg_p = sum(psnrs) / len(psnrs)
        cos_sim = metrics["cosine_sim"]
        results[temp] = {"psnr": avg_p, "cos": cos_sim}
        print(f" {temp:>8.1e} {avg_p:>9.1f}dB {cos_sim:>11.6f}")

    # Gradient check: rebuild the memory, then read with a trainable projection.
    bank.allocate(B)
    bank.store(keys, embeddings, alphas)
    proj_param = nn.Parameter(C.detach().clone())
    Y = bank.retrieve_continuous(queries.detach(), proj_param, temperature=0.01)
    loss = reconstruction_loss(Y, embeddings.unsqueeze(0))
    loss.backward()

    code_grad = bank.codebook.C_raw.grad is not None
    proj_grad = proj_param.grad is not None
    print(f"\n Gradient through C_raw: {'YES' if code_grad else 'NO'}")
    print(f" Gradient through projection: {'YES' if proj_grad else 'NO'}")

    sharp_psnr = results[1e-6]["psnr"]
    low_t_ok = sharp_psnr > 80.0
    degrades = results[1.0]["psnr"] < sharp_psnr
    passed = low_t_ok and degrades and code_grad

    print(f"\n Low T PSNR > 80 dB: {'YES' if low_t_ok else 'NO'} "
          f"({sharp_psnr:.1f} dB)")
    print(f" Degrades at high T: {'YES' if degrades else 'NO'}")
    print(f"\n {'PASS' if passed else 'FAIL'}: Continuous addressing + learnable codes")
    return passed
# =====================================================================
# Main
# =====================================================================
def main():
    """Run every test in order and print a pass/fail summary.

    A test that raises is reported with its traceback and counted as a
    failure, so one broken test does not stop the remaining ones.

    Returns:
        bool: True when every test returned a truthy result.
    """
    rule = "=" * 60
    print(rule)
    print(" DIRECTION 6: LEARNABLE CODE EVOLUTION")
    print(" Learnable Codebook Test Suite")
    print(rule)

    tests = (
        ("Backward Compat", test_1_backward_compat),
        ("Gradient Flow", test_2_gradient_flow),
        ("Orthogonality Regulariser", test_3_orthogonality_regulariser),
        ("Reconstruction Training", test_4_reconstruction_training),
        ("Hadamard vs Learned", test_5_hadamard_vs_learned),
        ("Coherence Tracking", test_6_coherence_tracking),
        ("Overload Regime", test_7_overload_regime),
        ("Continuous Addressing", test_8_continuous_addressing),
    )

    results = {}
    for name, test_fn in tests:
        try:
            outcome = test_fn()
        except Exception as e:
            # Top-level boundary: show the failure and keep going.
            print(f"\n ERROR in {name}: {e}")
            import traceback
            traceback.print_exc()
            outcome = False
        results[name] = outcome

    print("\n" + rule)
    print(" SUMMARY")
    print(rule)
    for name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] {name}")

    n_total = len(results)
    n_pass = sum(1 for passed in results.values() if passed)
    all_passed = n_pass == n_total

    print(f"\n {n_pass}/{n_total} tests passed")
    if all_passed:
        print("\n ALL TESTS PASSED")
    else:
        print(f"\n {n_total - n_pass} FAILURES")
    return all_passed
if __name__ == "__main__":
    # Exit status mirrors the suite outcome: 0 when all tests passed, else 1.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)