File size: 9,712 Bytes
dc2b9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
#!/usr/bin/env python3
"""
Comprehensive WrinkleBrane Test Suite
Tests the wave-interference associative memory capabilities.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))
import torch
import numpy as np
import time
from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
def _make_shape_patterns(K, H, W, device):
    """Return a ``[K, H, W]`` float tensor of binary test shapes.

    Pattern ``i`` cycles through three families selected by ``i % 3``:
    filled circles (0), filled centred squares (1) and diagonal lines (2).
    ``i // 3`` grows the circle radius / square side and shifts the diagonal
    down by that many rows. Row and column axes are handled independently,
    so non-square ``H x W`` grids are drawn correctly.
    """
    rows = torch.arange(H, device=device).view(H, 1)
    cols = torch.arange(W, device=device).view(1, W)
    patterns = torch.zeros(K, H, W, device=device)
    for i in range(K):
        family, step = i % 3, i // 3
        if family == 0:  # filled circle centred at (H//2, W//2)
            radius = 3 + step
            mask = (rows - H // 2) ** 2 + (cols - W // 2) ** 2 <= radius ** 2
            patterns[i][mask] = 1.0
        elif family == 1:  # filled square, centred on each axis separately
            size = 4 + step
            # Clamp so an oversized square fills the grid instead of wrapping
            # around via a negative slice start.
            r0 = max((H - size) // 2, 0)
            c0 = max((W - size) // 2, 0)
            patterns[i, r0:r0 + size, c0:c0 + size] = 1.0
        else:  # diagonal (row = d + step, col = d) clipped to the grid
            n = max(0, min(H - step, W))
            d = torch.arange(n, device=device)
            patterns[i, d + step, d] = 1.0
    return patterns


def test_basic_storage_retrieval():
    """Store K geometric shapes in a membrane bank and check read-back fidelity.

    Writes K=8 binary 16x16 shapes (circles, squares, diagonal lines) under
    Hadamard codes, reads every channel back through the slicer and prints
    per-pattern and average PSNR/SSIM.

    Returns:
        bool: ``True`` when the average PSNR exceeds 40 dB (reported as HIGH
        fidelity above 80 dB, MEDIUM otherwise), ``False`` at or below 40 dB.
    """
    print("🧪 Testing Basic Storage & Retrieval...")
    # Parameters
    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Using device: {device}")

    # Create membrane bank and codes
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    # Hadamard codes are used for their orthogonality.
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Key i -> pattern i, every pair written with unit strength.
    patterns = _make_shape_patterns(K, H, W, device)  # [K, H, W]
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    # store_pairs yields the target membrane state; only the difference is
    # handed to bank.write (presumably additive — confirm in MembraneBank).
    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    # Read back all channels; drop the batch dim (B == 1 here).
    readouts = slicer(bank.read()).squeeze(0)  # [K, H, W]

    total_psnr = 0
    total_ssim = 0
    print(" Fidelity Results:")
    for i in range(K):
        original = patterns[i].cpu().numpy()
        retrieved = readouts[i].cpu().numpy()
        psnr_val = psnr(original, retrieved)
        ssim_val = ssim(original, retrieved)
        total_psnr += psnr_val
        total_ssim += ssim_val
        print(f" Pattern {i}: PSNR={psnr_val:.2f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = total_psnr / K
    avg_ssim = total_ssim / K
    print(f" Average PSNR: {avg_psnr:.2f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    # Pass/fail cut-offs used by this check. They are looser than the
    # >100 dB target the project notes aim for: >80 dB counts as high
    # fidelity, >40 dB as medium, and both count as a pass.
    high_fidelity_db = 80
    medium_fidelity_db = 40
    if avg_psnr > high_fidelity_db:
        print("✅ Basic storage & retrieval: HIGH FIDELITY")
        return True
    if avg_psnr > medium_fidelity_db:
        print("⚠️ Basic storage & retrieval: MEDIUM FIDELITY")
        return True
    print("❌ Basic storage & retrieval: LOW FIDELITY")
    return False


def test_code_comparison():
    """Print coherence and orthogonality diagnostics for each code family.

    Builds Hadamard, DCT and Gaussian code matrices from ``L=32``, ``K=16``.
    For each it prints the max/mean absolute off-diagonal values reported by
    ``coherence_stats``, and the Frobenius norm of ``C.T @ C - I``. That norm
    is zero when the K columns are orthonormal. Purely diagnostic: returns
    ``None``.
    """
    print("\n🧪 Testing Different Code Types...")
    L, K = 32, 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build every family up front, in a fixed order, before any reporting.
    families = [
        ("Hadamard", hadamard_codes(L, K).to(device)),
        ("DCT", dct_codes(L, K).to(device)),
        ("Gaussian", gaussian_codes(L, K).to(device)),
    ]

    for label, code_matrix in families:
        coherence = coherence_stats(code_matrix)
        print(f" {label} Codes:")
        print(f" Max off-diagonal: {coherence['max_abs_offdiag']:.6f}")
        print(f" Mean off-diagonal: {coherence['mean_abs_offdiag']:.6f}")

        # Distance of the Gram matrix from identity measures non-orthonormality.
        gram = code_matrix.T @ code_matrix
        identity = torch.eye(K, device=device, dtype=code_matrix.dtype)
        ortho_err = torch.norm(gram - identity).item()
        print(f" Orthogonality error: {ortho_err:.6f}")


def test_capacity_scaling():
    """Report mean read-back PSNR as the number of stored patterns grows.

    Each load K in 4, 8, 16, 32 gets a fresh single-batch bank (L=64, 8x8).
    K uniform-random patterns are stored under Hadamard codes with unit
    strength, all K channels are read back, and the average PSNR is printed.
    Purely diagnostic: returns ``None``.
    """
    print("\n🧪 Testing Capacity Scaling...")
    B, L, H, W = 1, 64, 8, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for load in (4, 8, 16, 32):
        print(f" Testing {load} stored patterns...")

        # A fresh bank per load so earlier runs do not leak into later ones.
        bank = MembraneBank(L=L, H=H, W=W, device=device)
        bank.allocate(B)
        codes = hadamard_codes(L, load).to(device)
        slicer = make_slicer(codes)

        patterns = torch.rand(load, H, W, device=device)
        keys = torch.arange(load, device=device)
        alphas = torch.ones(load, device=device)

        updated = store_pairs(bank.read(), codes, keys, patterns, alphas)
        bank.write(updated - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_sum = sum(
            psnr(patterns[k].cpu().numpy(), readouts[k].cpu().numpy())
            for k in range(load)
        )
        mean_psnr = psnr_sum / load
        print(f" Average PSNR: {mean_psnr:.2f}dB")


def test_interference_analysis():
    """Measure cross-talk into channels that were never written.

    Stores random 16x16 patterns under keys 0, 2 and 4 of a set of 8
    Hadamard codes, then reads back all 8 channels and prints the L2 norm of
    each. Stored channels are shown next to the norm of the pattern written
    there; empty channels are reported as interference. Purely diagnostic:
    returns ``None``.
    """
    print("\n🧪 Testing Interference Analysis...")
    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Only a subset of the K channels receives data.
    active_keys = [0, 2, 4]
    n_active = len(active_keys)
    patterns = torch.rand(n_active, H, W, device=device)
    keys = torch.tensor(active_keys, device=device)
    alphas = torch.ones(n_active, device=device)

    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    # Read every channel, including the ones nothing was written to.
    readouts = slicer(bank.read()).squeeze(0)  # [K, H, W]
    slot_for_key = {key: slot for slot, key in enumerate(active_keys)}

    print(" Interference Results:")
    for channel in range(K):
        power = torch.norm(readouts[channel]).item()
        slot = slot_for_key.get(channel)
        if slot is None:
            # Unused channel: any energy here is leakage from stored pairs.
            print(f" Channel {channel} (empty): Interference {power:.6f}")
            continue
        original_power = torch.norm(patterns[slot]).item()
        print(f" Channel {channel} (stored): Signal power {power:.4f} (original {original_power:.4f})")


def performance_benchmark():
    """Time repeated writes and read-outs on a larger membrane bank.

    Uses B=4, L=128, 32x32 membranes and K=64 random patterns. It prints:
    - the mean wall-clock time of ``store_pairs`` + ``bank.write`` over 10 runs;
    - the mean time of a slicer read-out over 100 runs;
    - throughput figures derived from both.

    CUDA work is queued asynchronously, so the device is synchronized around
    each timed region; otherwise only kernel-launch overhead would be measured.
    """
    print("\n⚡ Performance Benchmark...")
    B, L, H, W, K = 4, 128, 32, 32, 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Configuration: B={B}, L={L}, H={H}, W={W}, K={K}")
    # Assumes 4-byte (float32) membrane elements — TODO confirm MembraneBank dtype.
    print(f" Memory footprint: {B*L*H*W*4/1e6:.1f}MB (membranes)")

    n_write_iters = 10
    n_read_iters = 100

    def _sync():
        """Block until queued CUDA work finishes (no-op on CPU)."""
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Setup
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)
    patterns = torch.rand(K, H, W, device=device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    # Benchmark write operation. perf_counter is the monotonic high-resolution
    # clock intended for measuring short intervals.
    _sync()
    start_time = time.perf_counter()
    for _ in range(n_write_iters):
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
    _sync()
    write_time = (time.perf_counter() - start_time) / n_write_iters

    # Benchmark read operation; the read-out itself is discarded.
    _sync()
    start_time = time.perf_counter()
    for _ in range(n_read_iters):
        slicer(bank.read())
    _sync()
    read_time = (time.perf_counter() - start_time) / n_read_iters

    print(f" Write time: {write_time*1000:.2f}ms ({K/write_time:.0f} patterns/sec)")
    print(f" Read time: {read_time*1000:.2f}ms ({K*B/read_time:.0f} readouts/sec)")


def main():
    """Run every WrinkleBrane check and report an overall verdict.

    Seeds torch and NumPy with 42, then runs the checks in a fixed order.
    Only ``test_basic_storage_retrieval`` decides pass/fail; the remaining
    checks print diagnostics. Any exception is reported with a traceback and
    makes the run fail.

    Returns:
        bool: ``True`` if the basic storage check passed and nothing raised.
    """
    print("🌊 WrinkleBrane Comprehensive Test Suite")
    print("=" * 50)

    # Fixed seeds keep the random patterns identical between runs.
    torch.manual_seed(42)
    np.random.seed(42)

    try:
        passed = test_basic_storage_retrieval()
        diagnostics = (
            test_code_comparison,
            test_capacity_scaling,
            test_interference_analysis,
            performance_benchmark,
        )
        for diagnostic in diagnostics:
            diagnostic()

        print("\n" + "=" * 50)
        if not passed:
            print("⚠️ WrinkleBrane: Some tests showed issues")
            print(" System functional but may need optimization")
        else:
            print("🎉 WrinkleBrane: ALL TESTS PASSED")
            print(" Wave-interference associative memory working correctly!")
    except Exception as exc:
        # Top-level boundary: report the failure with its traceback rather
        # than letting the suite die silently mid-way.
        print(f"\n❌ Test suite failed with error: {exc}")
        import traceback
        traceback.print_exc()
        return False

    return passed


if __name__ == "__main__":
    # Propagate the suite's verdict as the process exit status so shell/CI
    # callers can detect failures (0 = passed, 1 = issues or error).
    sys.exit(0 if main() else 1)