|
|
|
|
|
""" |
|
|
Comprehensive WrinkleBrane Test Suite |
|
|
Tests the wave-interference associative memory capabilities. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
sys.path.append(str(Path(__file__).resolve().parent / "src")) |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import time |
|
|
from wrinklebrane.membrane_bank import MembraneBank |
|
|
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats |
|
|
from wrinklebrane.slicer import make_slicer |
|
|
from wrinklebrane.write_ops import store_pairs |
|
|
from wrinklebrane.metrics import psnr, ssim |
|
|
|
|
|
def _make_test_pattern(index, H, W, device):
    """Return the ``index``-th synthetic binary test image of shape ``(H, W)``.

    Three shape families are cycled so consecutive keys store clearly
    different images; ``index // 3`` grows the shape on every full cycle:

    * ``index % 3 == 0``: filled disc of radius ``3 + index // 3`` centred at
      row ``H // 2``, column ``W // 2``;
    * ``index % 3 == 1``: centred filled square of side ``4 + index // 3``;
    * ``index % 3 == 2``: one-pixel diagonal starting at row ``index // 3``.
    """
    pattern = torch.zeros(H, W, device=device)
    family, step = index % 3, index // 3

    if family == 0:
        # Broadcast a column of row indices against a row of column indices to
        # get every pixel's squared distance from the centre in one operation.
        rows = torch.arange(H, device=device).unsqueeze(1)
        cols = torch.arange(W, device=device).unsqueeze(0)
        radius = 3 + step
        inside = (rows - H // 2) ** 2 + (cols - W // 2) ** 2 <= radius ** 2
        pattern[inside] = 1.0
    elif family == 1:
        size = 4 + step
        # Centre each axis independently so non-square images work too; clamp
        # at 0 so an oversized square does not produce a negative slice start.
        top = max((H - size) // 2, 0)
        left = max((W - size) // 2, 0)
        pattern[top:top + size, left:left + size] = 1.0
    else:
        for col in range(min(H, W)):
            row = col + step
            if row < H:
                pattern[row, col] = 1.0

    return pattern


def test_basic_storage_retrieval():
    """Store K synthetic patterns under Hadamard codes and read them back.

    A single-batch membrane bank receives ``K`` binary patterns keyed
    ``0..K-1``; every channel is then sliced back out and compared with its
    original via PSNR and SSIM, which are printed per pattern and on average.

    Returns:
        ``True`` when the average PSNR exceeds 40 dB (medium or high
        fidelity), ``False`` otherwise. ``main`` folds this into the overall
        suite verdict.
        NOTE(review): pytest warns when a ``test_*`` function returns a
        non-None value; the return is kept because ``main`` depends on it.
    """
    print("🧪 Testing Basic Storage & Retrieval...")

    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Using device: {device}")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = [_make_test_pattern(i, H, W, device) for i in range(K)]
    keys = torch.arange(K, device=device)
    values = torch.stack(patterns)
    alphas = torch.ones(K, device=device)

    # store_pairs yields a new membrane state; the bank is handed the
    # difference from its current contents (presumably write() is additive —
    # confirm against MembraneBank).
    M = store_pairs(bank.read(), C, keys, values, alphas)
    bank.write(M - bank.read())

    # Drop the batch axis (B == 1) so readouts[i] lines up with patterns[i].
    readouts = slicer(bank.read()).squeeze(0)

    psnr_values = []
    ssim_values = []
    print(" Fidelity Results:")
    for i, original in enumerate(patterns):
        original_np = original.cpu().numpy()
        retrieved_np = readouts[i].cpu().numpy()

        psnr_val = psnr(original_np, retrieved_np)
        ssim_val = ssim(original_np, retrieved_np)
        psnr_values.append(psnr_val)
        ssim_values.append(ssim_val)

        print(f" Pattern {i}: PSNR={psnr_val:.2f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = sum(psnr_values) / K
    avg_ssim = sum(ssim_values) / K
    print(f" Average PSNR: {avg_psnr:.2f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    # Fidelity bands: > 80 dB high, > 40 dB medium, otherwise a failure.
    if avg_psnr > 80:
        print("✅ Basic storage & retrieval: HIGH FIDELITY")
        return True
    if avg_psnr > 40:
        print("⚠️ Basic storage & retrieval: MEDIUM FIDELITY")
        return True
    print("❌ Basic storage & retrieval: LOW FIDELITY")
    return False
|
|
|
|
|
def test_code_comparison():
    """Print coherence and orthogonality diagnostics for three code families.

    Builds Hadamard, DCT and Gaussian codebooks with ``K`` code columns
    (assumed shape ``(L, K)`` — TODO confirm against wrinklebrane.codes) and,
    for each, prints the max/mean absolute off-diagonal statistics reported
    by ``coherence_stats`` plus the Frobenius norm of ``C.T @ C - I``. That
    norm is zero exactly when the ``K`` columns are orthonormal.
    """
    print("\n🧪 Testing Different Code Types...")

    L, K = 32, 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    code_types = {
        "Hadamard": hadamard_codes(L, K).to(device),
        "DCT": dct_codes(L, K).to(device),
        "Gaussian": gaussian_codes(L, K).to(device),
    }

    for name, codes in code_types.items():
        stats = coherence_stats(codes)
        print(f" {name} Codes:")
        print(f" Max off-diagonal: {stats['max_abs_offdiag']:.6f}")
        print(f" Mean off-diagonal: {stats['mean_abs_offdiag']:.6f}")

        # Gram matrix of the code columns; a perfectly orthonormal codebook
        # gives the K x K identity.
        gram = codes.T @ codes
        identity = torch.eye(K, device=device, dtype=codes.dtype)
        # torch.linalg.norm on a matrix with default ord is the Frobenius norm,
        # matching the deprecated torch.norm default.
        orthogonality_error = torch.linalg.norm(gram - identity).item()
        print(f" Orthogonality error: {orthogonality_error:.6f}")
|
|
|
|
|
def test_capacity_scaling():
    """Report average retrieval PSNR as the number of stored patterns grows.

    For each load ``K`` in ``capacities`` a fresh single-batch bank
    (``L = 64``, 8 x 8 membranes) is filled with ``K`` uniform-random patterns
    under Hadamard codes, read back through the slicer, and the mean PSNR
    over all ``K`` patterns is printed. Nothing is returned.
    """
    print("\n🧪 Testing Capacity Scaling...")

    B, L, H, W = 1, 64, 8, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    capacities = [4, 8, 16, 32]

    for K in capacities:
        print(f" Testing {K} stored patterns...")

        # A new bank per load, so earlier iterations' writes are not present.
        bank = MembraneBank(L=L, H=H, W=W, device=device)
        bank.allocate(B)

        C = hadamard_codes(L, K).to(device)
        slicer = make_slicer(C)

        patterns = torch.rand(K, H, W, device=device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())

        # Drop the batch axis (B == 1) so readouts[i] lines up with patterns[i].
        readouts = slicer(bank.read()).squeeze(0)

        # Transfer each tensor to host memory once rather than once per pattern.
        originals_np = patterns.cpu().numpy()
        readouts_np = readouts.cpu().numpy()
        avg_psnr = sum(psnr(originals_np[i], readouts_np[i]) for i in range(K)) / K
        print(f" Average PSNR: {avg_psnr:.2f}dB")
|
|
|
|
|
def test_interference_analysis():
    """Measure cross-talk into channels that were never written.

    Random patterns are stored under a subset of keys (``active_keys``) and
    then all ``K`` channels are sliced back out. For a stored channel the
    L2 norm of its readout is printed next to the norm of the original
    pattern. For an empty channel the readout norm is printed as
    interference (ideally near zero). Nothing is returned.
    """
    print("\n🧪 Testing Interference Analysis...")

    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Only a subset of the K channels is written; the others stay empty.
    active_keys = [0, 2, 4]
    patterns = torch.rand(len(active_keys), H, W, device=device)
    keys = torch.tensor(active_keys, device=device)
    alphas = torch.ones(len(active_keys), device=device)

    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    # Drop the batch axis (B == 1) so readouts[i] is channel i.
    readouts = slicer(bank.read()).squeeze(0)

    # Map each stored key to its row in `patterns` (O(1) instead of list.index).
    slot_of_key = {key: slot for slot, key in enumerate(active_keys)}

    print(" Interference Results:")
    for i in range(K):
        # Flattened L2 norm, identical to the old torch.norm default.
        readout_norm = torch.linalg.norm(readouts[i]).item()
        if i in slot_of_key:
            original_power = torch.linalg.norm(patterns[slot_of_key[i]]).item()
            print(f" Channel {i} (stored): Signal power {readout_norm:.4f} (original {original_power:.4f})")
        else:
            print(f" Channel {i} (empty): Interference {readout_norm:.6f}")
|
|
|
|
|
def performance_benchmark():
    """Time repeated writes and reads on a larger membrane bank.

    Performs 10 write rounds (each storing all ``K`` patterns) and 100 slicer
    reads, and prints the mean wall-clock time per round together with the
    implied throughput. On CUDA the device is synchronised around each timed
    region, because kernels launch asynchronously and would otherwise be
    only partly included in the measurement.
    """
    print("\n⚡ Performance Benchmark...")

    B, L, H, W, K = 4, 128, 32, 32, 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f" Configuration: B={B}, L={L}, H={H}, W={W}, K={K}")
    # Footprint estimate assumes 4-byte (float32) membrane elements.
    print(f" Memory footprint: {B*L*H*W*4/1e6:.1f}MB (membranes)")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = torch.rand(K, H, W, device=device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    def _sync():
        # Wait for queued CUDA work so the timer covers execution, not just
        # kernel launches. No-op on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    write_rounds = 10
    _sync()
    start = time.perf_counter()
    for _ in range(write_rounds):
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
    _sync()
    write_time = (time.perf_counter() - start) / write_rounds

    read_rounds = 100
    _sync()
    start = time.perf_counter()
    for _ in range(read_rounds):
        slicer(bank.read())
    _sync()
    read_time = (time.perf_counter() - start) / read_rounds

    print(f" Write time: {write_time*1000:.2f}ms ({K/write_time:.0f} patterns/sec)")
    print(f" Read time: {read_time*1000:.2f}ms ({K*B/read_time:.0f} readouts/sec)")
|
|
|
|
|
def main():
    """Run every WrinkleBrane section in order and report an overall verdict.

    Seeds torch and NumPy, then runs the basic storage, code comparison,
    capacity, interference and benchmark sections. Only
    ``test_basic_storage_retrieval`` contributes to the pass/fail result; the
    other sections are diagnostic and just print.

    Returns:
        ``True`` if the basic storage check passed and no section raised;
        ``False`` otherwise. On an exception the traceback is printed.
    """
    print("🌊 WrinkleBrane Comprehensive Test Suite")
    print("=" * 50)

    # Fixed seeds so the torch.rand patterns used by the sections repeat.
    torch.manual_seed(42)
    np.random.seed(42)

    success = True
    try:
        success &= test_basic_storage_retrieval()
        test_code_comparison()
        test_capacity_scaling()
        test_interference_analysis()
        performance_benchmark()

        print("\n" + "=" * 50)
        if success:
            print("🎉 WrinkleBrane: ALL TESTS PASSED")
            print(" Wave-interference associative memory working correctly!")
        else:
            print("⚠️ WrinkleBrane: Some tests showed issues")
            print(" System functional but may need optimization")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and turn
        # it into a False verdict instead of propagating.
        import traceback

        print(f"\n❌ Test suite failed with error: {e}")
        traceback.print_exc()
        return False

    return success
|
|
|
|
|
if __name__ == "__main__":
    # Turn the suite verdict into the process exit status (0 = pass,
    # 1 = fail) so CI jobs and shell scripts can detect failures.
    sys.exit(0 if main() else 1)