File size: 9,712 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#!/usr/bin/env python3
"""
Comprehensive WrinkleBrane Test Suite
Tests the wave-interference associative memory capabilities.
"""

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))

import torch
import numpy as np
import time
from wrinklebrane.membrane_bank import MembraneBank  
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim

def _make_shape_patterns(K, H, W, device):
    """Return ``K`` binary ``(H, W)`` float tensors used as stored values.

    Pattern ``i`` is picked by ``i % 3`` -- 0: filled disc, 1: filled square,
    2: diagonal line -- and ``i // 3`` enlarges the disc/square or shifts the
    line down, so patterns of the same shape still differ.
    """
    side = min(H, W)
    rows = torch.arange(H, device=device).unsqueeze(1)  # (H, 1) row indices
    cols = torch.arange(W, device=device).unsqueeze(0)  # (1, W) column indices
    patterns = []
    for i in range(K):
        step = i // 3
        pattern = torch.zeros(H, W, device=device)
        if i % 3 == 0:  # disc centred on (row H//2, col W//2)
            radius = 3 + step
            inside = (cols - W // 2) ** 2 + (rows - H // 2) ** 2 <= radius ** 2
            pattern[inside] = 1.0
        elif i % 3 == 1:  # square centred in the image
            size = 4 + step
            # Clamp so an oversized square fills the image instead of
            # wrapping around through a negative slice start.
            top = max((H - size) // 2, 0)
            left = max((W - size) // 2, 0)
            pattern[top:top + size, left:left + size] = 1.0
        else:  # main diagonal shifted down by `step` rows, clipped to the image
            d = torch.arange(max(side - step, 0), device=device)
            pattern[d + step, d] = 1.0
        patterns.append(pattern)
    return patterns


def test_basic_storage_retrieval():
    """Store K shape patterns in a fresh membrane bank and read them back.

    Each pattern is written under its own Hadamard code via ``store_pairs``.
    All channels are then sliced out of the bank and compared with the
    originals using PSNR and SSIM.

    Returns:
        True when the mean PSNR exceeds 40 dB (medium or high fidelity),
        False otherwise.
    """
    print("🧪 Testing Basic Storage & Retrieval...")

    # Parameters
    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   Using device: {device}")

    # Create membrane bank and codes
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    # Hadamard codes are used for their orthogonality.
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Distinct geometric test patterns: discs, squares, diagonal lines.
    patterns = _make_shape_patterns(K, H, W, device)

    keys = torch.arange(K, device=device)
    values = torch.stack(patterns)  # [K, H, W]
    alphas = torch.ones(K, device=device)

    # store_pairs returns the updated membranes; commit only the delta.
    M = store_pairs(bank.read(), C, keys, values, alphas)
    bank.write(M - bank.read())

    # Read back all patterns
    readouts = slicer(bank.read()).squeeze(0)  # [B=1, K, H, W] -> [K, H, W]

    print("   Fidelity Results:")
    psnr_vals = []
    ssim_vals = []
    for i, original in enumerate(patterns):
        original_np = original.cpu().numpy()
        retrieved_np = readouts[i].cpu().numpy()

        psnr_val = psnr(original_np, retrieved_np)
        ssim_val = ssim(original_np, retrieved_np)
        psnr_vals.append(psnr_val)
        ssim_vals.append(ssim_val)

        print(f"     Pattern {i}: PSNR={psnr_val:.2f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = sum(psnr_vals) / K
    avg_ssim = sum(ssim_vals) / K

    print(f"   Average PSNR: {avg_psnr:.2f}dB")
    print(f"   Average SSIM: {avg_ssim:.4f}")

    # Thresholds: > 80 dB is reported as high fidelity and > 40 dB as medium.
    # Anything lower fails. (CLAUDE.md's stated target is > 100 dB.)
    if avg_psnr > 80:
        print("✅ Basic storage & retrieval: HIGH FIDELITY")
        return True
    elif avg_psnr > 40:
        print("⚠️  Basic storage & retrieval: MEDIUM FIDELITY")
        return True
    else:
        print("❌ Basic storage & retrieval: LOW FIDELITY")
        return False

def test_code_comparison():
    """Print coherence and orthogonality diagnostics for three code families.

    Hadamard, DCT and Gaussian code matrices are generated with L=32, K=16.
    For each family the function prints:
    - the max/mean absolute off-diagonal values reported by ``coherence_stats``;
    - the Frobenius norm of ``codes.T @ codes - I``, which is zero for
      perfectly orthonormal columns.

    Nothing is returned; this is a diagnostic, not a pass/fail check.
    """
    print("\n🧪 Testing Different Code Types...")

    L, K = 32, 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build every code matrix up front, before any of them is analysed.
    families = [
        ("Hadamard", hadamard_codes(L, K).to(device)),
        ("DCT", dct_codes(L, K).to(device)),
        ("Gaussian", gaussian_codes(L, K).to(device)),
    ]

    for name, codes in families:
        stats = coherence_stats(codes)
        print(f"   {name} Codes:")
        print(f"     Max off-diagonal: {stats['max_abs_offdiag']:.6f}")
        print(f"     Mean off-diagonal: {stats['mean_abs_offdiag']:.6f}")

        # Gram matrix minus identity: how far the columns are from orthonormal.
        gram = codes.T @ codes
        deviation = gram - torch.eye(K, device=device, dtype=codes.dtype)
        orthogonality_error = torch.norm(deviation).item()
        print(f"     Orthogonality error: {orthogonality_error:.6f}")

def test_capacity_scaling():
    """Report mean recall PSNR as the number of stored patterns grows.

    For each load K in (4, 8, 16, 32), the function:
    1. builds a fresh 64-layer, 8x8 membrane bank;
    2. fills it with K uniform-random patterns keyed by Hadamard codes;
    3. reads the patterns back through the slicer;
    4. prints the average PSNR over all K patterns.

    Nothing is returned; this is a diagnostic, not a pass/fail check.
    """
    print("\n🧪 Testing Capacity Scaling...")

    B, L, H, W = 1, 64, 8, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for K in (4, 8, 16, 32):
        print(f"   Testing {K} stored patterns...")

        # A fresh bank per load level so earlier writes do not leak in.
        bank = MembraneBank(L=L, H=H, W=W, device=device)
        bank.allocate(B)

        codes = hadamard_codes(L, K).to(device)
        slicer = make_slicer(codes)

        patterns = torch.rand(K, H, W, device=device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        # Commit only the difference between the updated and current state.
        updated = store_pairs(bank.read(), codes, keys, patterns, alphas)
        bank.write(updated - bank.read())

        readouts = slicer(bank.read()).squeeze(0)

        avg_psnr = sum(
            psnr(patterns[j].cpu().numpy(), readouts[j].cpu().numpy())
            for j in range(K)
        ) / K
        print(f"     Average PSNR: {avg_psnr:.2f}dB")

def test_interference_analysis():
    """Measure cross-talk into code channels that were never written.

    Only keys 0, 2 and 4 of K=8 Hadamard channels receive a random 16x16
    pattern; every channel is then sliced out of the bank.
    - For written channels, the readout norm is printed next to the stored
      pattern's norm.
    - For unwritten channels, the readout norm is printed as interference.

    Nothing is returned.
    """
    print("\n🧪 Testing Interference Analysis...")

    B, L, H, W, K = 1, 32, 16, 16, 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    codes = hadamard_codes(L, K).to(device)
    slicer = make_slicer(codes)

    # Write to a subset of channels only; the rest should stay near zero.
    active_keys = [0, 2, 4]
    n_active = len(active_keys)
    patterns = torch.rand(n_active, H, W, device=device)
    keys = torch.tensor(active_keys, device=device)
    alphas = torch.ones(n_active, device=device)

    updated = store_pairs(bank.read(), codes, keys, patterns, alphas)
    bank.write(updated - bank.read())

    # Slice every channel, including the unused ones.
    readouts = slicer(bank.read()).squeeze(0)  # [K, H, W]

    # Stored key -> its row in `patterns`.
    row_of_key = {key: row for row, key in enumerate(active_keys)}

    print("   Interference Results:")
    for i in range(K):
        row = row_of_key.get(i)
        if row is None:
            interference_power = torch.norm(readouts[i]).item()
            print(f"     Channel {i} (empty):  Interference {interference_power:.6f}")
            continue
        signal_power = torch.norm(readouts[i]).item()
        original_power = torch.norm(patterns[row]).item()
        print(f"     Channel {i} (stored): Signal power {signal_power:.4f} (original {original_power:.4f})")

def performance_benchmark(*, write_iters: int = 10, read_iters: int = 100):
    """Time repeated write and read passes on a larger membrane bank.

    The configuration is fixed at B=4, L=128, H=W=32, with K=64 random
    patterns keyed by Hadamard codes.
    - Each write pass calls ``store_pairs`` and commits the delta to the bank.
    - Each read pass slices all K channels.

    The mean wall-clock time per pass and the derived throughput are printed.

    Args:
        write_iters: number of timed write passes (default 10, must be >= 1).
        read_iters: number of timed read passes (default 100, must be >= 1).

    Raises:
        ValueError: if either iteration count is below 1.

    Note:
        CUDA kernels launch asynchronously. The device is therefore
        synchronized before each clock read; otherwise the timings would
        measure launch overhead rather than execution time.
    """
    if write_iters < 1 or read_iters < 1:
        raise ValueError("write_iters and read_iters must both be >= 1")

    print("\n⚡ Performance Benchmark...")

    B, L, H, W, K = 4, 128, 32, 32, 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _sync():
        # Wait for all queued GPU work; a no-op on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    print(f"   Configuration: B={B}, L={L}, H={H}, W={W}, K={K}")
    # 4 bytes per element assumes float32 membranes.
    print(f"   Memory footprint: {B*L*H*W*4/1e6:.1f}MB (membranes)")

    # Setup
    bank = MembraneBank(L=L, H=H, W=W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = torch.rand(K, H, W, device=device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    # Benchmark write operation
    _sync()  # don't charge setup work still in flight to the write timing
    start = time.perf_counter()
    for _ in range(write_iters):
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
    _sync()
    write_time = (time.perf_counter() - start) / write_iters

    # Benchmark read operation
    start = time.perf_counter()
    for _ in range(read_iters):
        slicer(bank.read())
    _sync()
    read_time = (time.perf_counter() - start) / read_iters

    # Guard against a zero interval on very coarse clocks.
    write_rate = K / write_time if write_time > 0 else float("inf")
    read_rate = K * B / read_time if read_time > 0 else float("inf")

    print(f"   Write time: {write_time*1000:.2f}ms ({write_rate:.0f} patterns/sec)")
    print(f"   Read time: {read_time*1000:.2f}ms ({read_rate:.0f} readouts/sec)")

def main():
    """Run every WrinkleBrane check in sequence and summarise the outcome.

    Torch and NumPy are seeded with 42 first. The storage/retrieval check
    runs next; it is the only one whose result affects the verdict. The
    code-comparison, capacity, interference and benchmark diagnostics follow.

    Returns:
        The storage/retrieval result, or False if any check raised an
        exception (its traceback is printed).
    """
    print("🌊 WrinkleBrane Comprehensive Test Suite")
    separator = "=" * 50
    print(separator)

    # Fixed seeds keep the random patterns identical across runs.
    torch.manual_seed(42)
    np.random.seed(42)

    success = True
    try:
        success &= test_basic_storage_retrieval()
        # Diagnostics only print; their return values are not inspected.
        for diagnostic in (
            test_code_comparison,
            test_capacity_scaling,
            test_interference_analysis,
            performance_benchmark,
        ):
            diagnostic()

        print("\n" + separator)
        if not success:
            print("⚠️  WrinkleBrane: Some tests showed issues")
            print("   System functional but may need optimization")
        else:
            print("🎉 WrinkleBrane: ALL TESTS PASSED")
            print("   Wave-interference associative memory working correctly!")
    except Exception as e:
        # Top-level boundary: report the error and its traceback, then fail.
        print(f"\n❌ Test suite failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return success

if __name__ == "__main__":
    # main() returns False on failure or error. Turn that into a non-zero exit
    # status so shells and CI can detect a failed run.
    sys.exit(0 if main() else 1)