# WrinkleBrane / comprehensive_test.py
# Author: WCNegentropy
# 📚 Updated with scientifically rigorous documentation
# Commit dc2b9f3 (verified)
#!/usr/bin/env python3
"""
Comprehensive WrinkleBrane Test Suite
Tests the wave-interference associative memory capabilities.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))
import torch
import numpy as np
import time
from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
def _make_shape_patterns(K, H, W, device):
    """Build ``K`` binary ``H x W`` test images cycling circle / square / diagonal.

    Pattern ``i`` is selected by ``i % 3`` and grows with ``i // 3``:
    a filled circle of radius ``3 + i // 3`` centred on the grid, a centred
    filled square of side ``4 + i // 3`` (clipped to the grid), or a diagonal
    line shifted down by ``i // 3`` rows.

    Returns:
        A list of ``K`` float tensors of shape ``(H, W)`` on ``device``.
    """
    rows = torch.arange(H, device=device).view(-1, 1)
    cols = torch.arange(W, device=device).view(1, -1)
    patterns = []
    for i in range(K):
        pattern = torch.zeros(H, W, device=device)
        step = i // 3
        kind = i % 3
        if kind == 0:  # filled circle
            radius = 3 + step
            # Row offsets are measured from the row centre (H // 2) and column
            # offsets from the column centre (W // 2); swapping them misplaces
            # the circle on non-square grids.
            inside = (rows - H // 2) ** 2 + (cols - W // 2) ** 2 <= radius ** 2
            pattern[inside] = 1.0
        elif kind == 1:  # filled square
            # Clip so the slice starts can never go negative (which would wrap).
            size = min(4 + step, H, W)
            top = (H - size) // 2
            left = (W - size) // 2
            pattern[top:top + size, left:left + size] = 1.0
        else:  # diagonal line
            for col in range(min(H, W)):
                row = col + step
                if row < H:  # only the row index can overflow; col < min(H, W)
                    pattern[row, col] = 1.0
        patterns.append(pattern)
    return patterns


def test_basic_storage_retrieval():
    """Store K shape patterns in a fresh membrane bank and check recall fidelity.

    One pattern is written per Hadamard code. Every channel is then read back
    through the slicer, and per-pattern plus average PSNR/SSIM are printed.

    Returns:
        True when the average PSNR exceeds 40 dB (medium or high fidelity),
        False otherwise.
    """
    print("๐Ÿงช Testing Basic Storage & Retrieval...")
    # Parameters
    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Using device: {device}")

    # Create membrane bank and codes
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    # Hadamard codes are chosen for their orthogonality.
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Test patterns: circles, squares and diagonal lines.
    patterns = _make_shape_patterns(K, H, W, device)
    keys = torch.arange(K, device=device)
    values = torch.stack(patterns)  # [K, H, W]
    alphas = torch.ones(K, device=device)

    # store_pairs returns a new membrane state; bank.write is given the delta
    # between that state and the bank's current contents.
    M = store_pairs(bank.read(), C, keys, values, alphas)
    bank.write(M - bank.read())

    # Read back all patterns.
    readouts = slicer(bank.read())  # [B, K, H, W]
    readouts = readouts.squeeze(0)  # [K, H, W] (B == 1)

    print(" Fidelity Results:")
    psnr_vals = []
    ssim_vals = []
    for i, original in enumerate(patterns):
        # Convert each tensor to numpy once and reuse it for both metrics.
        original_np = original.cpu().numpy()
        retrieved_np = readouts[i].cpu().numpy()
        psnr_val = psnr(original_np, retrieved_np)
        ssim_val = ssim(original_np, retrieved_np)
        psnr_vals.append(psnr_val)
        ssim_vals.append(ssim_val)
        print(f" Pattern {i}: PSNR={psnr_val:.2f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = sum(psnr_vals) / K
    avg_ssim = sum(ssim_vals) / K
    print(f" Average PSNR: {avg_psnr:.2f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    # CLAUDE.md expects >100 dB PSNR for ideal recall. This test uses a looser
    # 80 dB cutoff for "high" fidelity and still passes down to 40 dB.
    if avg_psnr > 80:
        print("โœ… Basic storage & retrieval: HIGH FIDELITY")
        return True
    if avg_psnr > 40:
        print("โš ๏ธ Basic storage & retrieval: MEDIUM FIDELITY")
        return True
    print("โŒ Basic storage & retrieval: LOW FIDELITY")
    return False
def test_code_comparison():
    """Print coherence and orthogonality figures for each code family.

    Hadamard, DCT and Gaussian code matrices of the same size are built up
    front, in that order. For each one, the maximum and mean absolute
    off-diagonal values reported by ``coherence_stats`` are printed, followed
    by the Frobenius norm of ``codes.T @ codes - I``. This is informational
    only; nothing is returned.
    """
    print("\n๐Ÿงช Testing Different Code Types...")
    L, K = 32, 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Every code matrix is generated before any is analysed, in this order.
    factories = (
        ("Hadamard", hadamard_codes),
        ("DCT", dct_codes),
        ("Gaussian", gaussian_codes),
    )
    code_types = {label: make(L, K).to(device) for label, make in factories}

    for label, codes in code_types.items():
        coherence = coherence_stats(codes)
        print(f" {label} Codes:")
        print(f" Max off-diagonal: {coherence['max_abs_offdiag']:.6f}")
        print(f" Mean off-diagonal: {coherence['mean_abs_offdiag']:.6f}")

        # Gram matrix vs. a K x K identity; presumably codes are (L, K) with one
        # code per column -- confirm against wrinklebrane.codes.
        gram = codes.T @ codes
        identity = torch.eye(K, device=device, dtype=codes.dtype)
        ortho_err = torch.norm(gram - identity).item()
        print(f" Orthogonality error: {ortho_err:.6f}")
def test_capacity_scaling():
    """Print average recall PSNR as the number of stored patterns grows.

    For each load K in 4, 8, 16 and 32, a fresh membrane bank is filled with K
    uniformly random H x W patterns keyed by Hadamard codes. Every channel is
    read back and the mean PSNR is printed. This is informational only;
    nothing is returned.
    """
    print("\n๐Ÿงช Testing Capacity Scaling...")
    B, L, H, W = 1, 64, 8, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for K in (4, 8, 16, 32):
        print(f" Testing {K} stored patterns...")

        bank = MembraneBank(L=L, H=H, W=W, device=device)
        bank.allocate(B)
        codes = hadamard_codes(L, K).to(device)
        read_channels = make_slicer(codes)

        stored = torch.rand(K, H, W, device=device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        # Write the delta between the updated state and the current one.
        updated = store_pairs(bank.read(), codes, keys, stored, alphas)
        bank.write(updated - bank.read())
        recalled = read_channels(bank.read()).squeeze(0)

        per_pattern = [
            psnr(stored[k].cpu().numpy(), recalled[k].cpu().numpy())
            for k in range(K)
        ]
        avg_psnr = sum(per_pattern) / K
        print(f" Average PSNR: {avg_psnr:.2f}dB")
def test_interference_analysis():
    """Print per-channel readout power when only some channels are written.

    Random patterns are stored under keys 0, 2 and 4 of an 8-code Hadamard
    bank, and then all 8 channels are read back. Written channels report their
    readout norm next to the norm of the stored pattern. Unwritten channels
    report their readout norm as cross-talk. This is informational only;
    nothing is returned.
    """
    print("\n๐Ÿงช Testing Interference Analysis...")
    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    codes = hadamard_codes(L, K).to(device)
    read_channels = make_slicer(codes)

    # Only a subset of the K channels receives a pattern.
    active_keys = [0, 2, 4]
    n_active = len(active_keys)
    stored = torch.rand(n_active, H, W, device=device)
    keys = torch.tensor(active_keys, device=device)
    alphas = torch.ones(n_active, device=device)

    updated = store_pairs(bank.read(), codes, keys, stored, alphas)
    bank.write(updated - bank.read())

    # Read every channel, including the ones that were never written.
    recalled = read_channels(bank.read()).squeeze(0)  # [K, H, W]

    # channel id -> row of `stored` holding that channel's pattern
    slot_for_channel = {channel: slot for slot, channel in enumerate(active_keys)}

    print(" Interference Results:")
    for channel in range(K):
        readout_power = torch.norm(recalled[channel]).item()
        slot = slot_for_channel.get(channel)
        if slot is None:
            # Expected to be small: pure cross-talk from other channels.
            print(f" Channel {channel} (empty): Interference {readout_power:.6f}")
            continue
        stored_power = torch.norm(stored[slot]).item()
        print(f" Channel {channel} (stored): Signal power {readout_power:.4f} (original {stored_power:.4f})")
def _time_per_call(fn, iters, device):
    """Return the mean wall-clock seconds taken by one call of ``fn``.

    ``fn`` is invoked ``iters`` times between two ``time.perf_counter`` reads.
    On a CUDA device the stream is synchronised before the clock starts and
    again before it stops. CUDA kernels launch asynchronously, so without the
    sync only the launch overhead would be measured.
    """
    on_cuda = device.type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if on_cuda:
        torch.cuda.synchronize(device)
    return (time.perf_counter() - start) / iters


def performance_benchmark():
    """Benchmark WrinkleBrane write and read throughput and print the results.

    A larger configuration is used: B=4 banks, L=128 layers, 32x32 membranes
    and K=64 stored patterns. The function reports the mean time of a full
    write (10 timed iterations) and a full read (100 timed iterations), each
    after one untimed warm-up pass.
    """
    print("\nโšก Performance Benchmark...")
    B, L, H, W, K = 4, 128, 32, 32, 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Configuration: B={B}, L={L}, H={H}, W={W}, K={K}")
    # Assumes 4-byte (float32) membrane elements -- TODO confirm MembraneBank's dtype.
    print(f" Memory footprint: {B*L*H*W*4/1e6:.1f}MB (membranes)")

    # Setup
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)
    patterns = torch.rand(K, H, W, device=device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    def write_once():
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())

    def read_once():
        slicer(bank.read())

    # One untimed pass of each, so that one-off setup costs (e.g. CUDA
    # context/kernel initialisation) are not charged to the timed iterations.
    write_once()
    read_once()

    write_time = _time_per_call(write_once, 10, device)
    read_time = _time_per_call(read_once, 100, device)

    # Guard against a zero elapsed time on very coarse clocks; a division by
    # zero here would otherwise abort the whole suite.
    write_rate = K / write_time if write_time > 0 else float("inf")
    read_rate = K * B / read_time if read_time > 0 else float("inf")
    print(f" Write time: {write_time*1000:.2f}ms ({write_rate:.0f} patterns/sec)")
    print(f" Read time: {read_time*1000:.2f}ms ({read_rate:.0f} readouts/sec)")
def main():
    """Run every WrinkleBrane check in sequence and report an overall verdict.

    The torch and NumPy RNGs are seeded first. The basic storage test runs
    next, and its boolean result decides the verdict. The code-comparison,
    capacity, interference and benchmark runs follow and are informational
    only. Any exception aborts the suite: its traceback is printed and False
    is returned.

    Returns:
        True if the basic storage test passed and nothing raised, else False.
    """
    print("๐ŸŒŠ WrinkleBrane Comprehensive Test Suite")
    print("=" * 50)

    # Fixed seeds so random patterns/codes are reproducible run to run.
    torch.manual_seed(42)
    np.random.seed(42)

    success = True
    try:
        success = success & test_basic_storage_retrieval()
        informational_runs = (
            test_code_comparison,
            test_capacity_scaling,
            test_interference_analysis,
            performance_benchmark,
        )
        for run in informational_runs:
            run()

        print("\n" + "=" * 50)
        if success:
            summary = (
                "๐ŸŽ‰ WrinkleBrane: ALL TESTS PASSED",
                " Wave-interference associative memory working correctly!",
            )
        else:
            summary = (
                "โš ๏ธ WrinkleBrane: Some tests showed issues",
                " System functional but may need optimization",
            )
        for line in summary:
            print(line)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print(f"\nโŒ Test suite failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False
    return success
if __name__ == "__main__":
    # Propagate the suite verdict as the process exit status so that shells
    # and CI runners can detect failures (main() returns False on failure).
    sys.exit(0 if main() else 1)