Executor-Tyrant-Framework Claude Opus 4.6 (1M context) committed on
Commit
efd23fa
·
1 Parent(s): 262b9d5

Fix HF Space: CPU-only torch, lazy imports

Browse files

- Use --extra-index-url for CPU-only PyTorch (much smaller)
- Lazy-import torch and torch_membrane inside functions
- Prevents import failures during Gradio startup

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

__pycache__/condenser.cpython-312.pyc ADDED
Binary file (26.1 kB). View file
 
__pycache__/graph_builder.cpython-312.pyc ADDED
Binary file (25.9 kB). View file
 
__pycache__/membrane.cpython-312.pyc ADDED
Binary file (17.5 kB). View file
 
__pycache__/predictor.cpython-312.pyc ADDED
Binary file (16.8 kB). View file
 
app.py CHANGED
@@ -9,16 +9,13 @@ Compares baseline vs condensed inference.
9
  """
10
 
11
  import gradio as gr
12
- import torch
13
  import numpy as np
14
  import time
15
- import json
16
  import os
17
  import sys
18
 
19
  sys.path.insert(0, os.path.dirname(__file__))
20
 
21
- from torch_membrane import TorchMembrane
22
  from graph_builder import GraphBuilder
23
  from predictor import Predictor
24
 
@@ -36,7 +33,9 @@ def load_model():
36
  """Load model and install membrane."""
37
  global MODEL, TOKENIZER, MEMBRANE
38
 
 
39
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
40
 
41
  TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
42
  if TOKENIZER.pad_token is None:
@@ -58,6 +57,8 @@ def train_predictor(num_prompts=5):
58
  """Run several prompts to train the predictor on access patterns."""
59
  global PREDICTOR, GRAPH, MEMBRANE
60
 
 
 
61
  if MODEL is None:
62
  load_model()
63
 
@@ -108,6 +109,8 @@ def run_inference(prompt, max_tokens=30):
108
  """Run inference and show activation map + condensation potential."""
109
  global MEMBRANE, PREDICTOR
110
 
 
 
111
  if MODEL is None:
112
  load_model()
113
  if PREDICTOR is None:
 
9
  """
10
 
11
  import gradio as gr
 
12
  import numpy as np
13
  import time
 
14
  import os
15
  import sys
16
 
17
  sys.path.insert(0, os.path.dirname(__file__))
18
 
 
19
  from graph_builder import GraphBuilder
20
  from predictor import Predictor
21
 
 
33
  """Load model and install membrane."""
34
  global MODEL, TOKENIZER, MEMBRANE
35
 
36
+ import torch
37
  from transformers import AutoModelForCausalLM, AutoTokenizer
38
+ from torch_membrane import TorchMembrane
39
 
40
  TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
41
  if TOKENIZER.pad_token is None:
 
57
  """Run several prompts to train the predictor on access patterns."""
58
  global PREDICTOR, GRAPH, MEMBRANE
59
 
60
+ import torch
61
+
62
  if MODEL is None:
63
  load_model()
64
 
 
109
  """Run inference and show activation map + condensation potential."""
110
  global MEMBRANE, PREDICTOR
111
 
112
+ import torch
113
+
114
  if MODEL is None:
115
  load_model()
116
  if PREDICTOR is None:
inference_graph.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  torch
2
  transformers
3
  numpy
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
  torch
3
  transformers
4
  numpy
test_condenser.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Condensate Layer 3: Condenser Tests
3
+
4
+ The moment of truth — does condensation actually save RAM?
5
+
6
+ Run: python3 test_condenser.py
7
+ """
8
+
9
+ import numpy as np
10
+ import time
11
+ import os
12
+ import sys
13
+
14
+ sys.path.insert(0, os.path.dirname(__file__))
15
+ from condenser import Condenser
16
+
17
+
18
+ def test_basic_compression():
19
+ """Test 1: Can we compress and decompress without data loss?"""
20
+ print("\n--- Test 1: Lossless Compression Round-Trip ---")
21
+
22
+ condenser = Condenser(demotion_idle_ms=1)
23
+
24
+ # Register some numpy arrays
25
+ original_data = np.random.randn(256, 256).astype(np.float32)
26
+ condenser.register("test.weights", original_data.copy())
27
+
28
+ region = condenser.regions["test.weights"]
29
+ original_size = region.original_size
30
+
31
+ # Compress to WARM
32
+ saved = region.compress_to_warm()
33
+ assert region.tier == "WARM"
34
+ assert region.hot_data is None
35
+ assert region.warm_data is not None
36
+ print(f" Original: {original_size / 1024:.1f} KB")
37
+ print(f" Compressed: {region.compressed_size / 1024:.1f} KB")
38
+ print(f" Ratio: {original_size / region.compressed_size:.1f}:1")
39
+ print(f" Saved: {saved / 1024:.1f} KB")
40
+
41
+ # Promote back to HOT
42
+ restored = region.promote_to_hot()
43
+ assert region.tier == "HOT"
44
+ assert np.array_equal(restored, original_data), "Data corrupted after round-trip!"
45
+ print(f" Round-trip: LOSSLESS (arrays match exactly)")
46
+
47
+ # Compress to COLD (disk)
48
+ region.compress_to_cold(condenser.cold_dir)
49
+ assert region.tier == "COLD"
50
+ assert region.current_ram_usage == 0
51
+ print(f" Cold (on disk): 0 KB RAM")
52
+
53
+ # Promote from COLD back to HOT
54
+ restored2 = region.promote_to_hot()
55
+ assert region.tier == "HOT"
56
+ assert np.array_equal(restored2, original_data), "Data corrupted after cold round-trip!"
57
+ print(f" Cold round-trip: LOSSLESS")
58
+
59
+ condenser.cleanup()
60
+ print(" PASS")
61
+
62
+
63
+ def test_selective_condensation():
64
+ """Test 2: Hot regions stay hot, cold regions compress.
65
+
66
+ 16 regions, 4 hot, 12 cold. After condensation, only 4 should
67
+ be in RAM at full size.
68
+ """
69
+ print("\n--- Test 2: Selective Condensation ---")
70
+
71
+ # 16 regions × 32KB each = 512KB total
72
+ # Use structured data (sparse + patterns) — like real weights, not pure noise
73
+ state = {}
74
+ for i in range(16):
75
+ arr = np.zeros((128, 64), dtype=np.float32)
76
+ # Sparse: only ~20% nonzero (realistic for many weight matrices)
77
+ mask = np.random.random((128, 64)) < 0.2
78
+ arr[mask] = np.random.randn(mask.sum()).astype(np.float32)
79
+ state[f"block_{i}"] = arr
80
+
81
+ hot_blocks = {0, 1, 2, 3}
82
+
83
+ def workload(wrapped):
84
+ # Hot blocks: accessed every iteration
85
+ for i in hot_blocks:
86
+ _ = wrapped[f"block_{i}"]
87
+
88
+ # Cold blocks: rarely accessed
89
+ if np.random.random() < 0.05:
90
+ idx = np.random.choice(list(range(4, 16)))
91
+ _ = wrapped[f"block_{idx}"]
92
+
93
+ time.sleep(0.001)
94
+
95
+ condenser = Condenser(demotion_idle_ms=10, warmup_iters=15)
96
+ results = condenser.run_benchmark(state, workload, iterations=30,
97
+ name="selective")
98
+ condenser.print_results(results)
99
+
100
+ # Verify tier management is working — cold regions should exist
101
+ last_log = results["promotion_log"][-1] if results["promotion_log"] else {}
102
+ warm_cold = last_log.get("warm", 0) + last_log.get("cold", 0)
103
+ print(f" Condensed regions (WARM+COLD): {warm_cold} of {results['total_regions']}")
104
+ print(f" RAM saved: {results['saved_mb']:.2f} MB ({results['saved_pct']:.1f}%)")
105
+ assert warm_cold >= 8, f"Should condense at least 8 cold regions, got {warm_cold}"
106
+ condenser.cleanup()
107
+ print(" PASS")
108
+
109
+
110
+ def test_inference_workload():
111
+ """Test 3: Simulated AI inference — THE benchmark.
112
+
113
+ 6-layer model with attention + FFN + KV cache.
114
+ Config and unused layers should compress.
115
+ Active layers should stay hot.
116
+ """
117
+ print("\n--- Test 3: AI Inference Workload (The Real Test) ---")
118
+
119
+ state = {}
120
+
121
+ # Model layers (q/k/v 64KB each, FFN matrices 256KB each) — sparse structured weights
122
+ for i in range(6):
123
+ for name in ["q", "k", "v"]:
124
+ arr = np.zeros((128, 128), dtype=np.float32)
125
+ mask = np.random.random((128, 128)) < 0.25
126
+ arr[mask] = np.random.randn(mask.sum()).astype(np.float32)
127
+ state[f"layer_{i}_{name}"] = arr
128
+ for name, shape in [("ffn_up", (128, 512)), ("ffn_down", (512, 128))]:
129
+ arr = np.zeros(shape, dtype=np.float32)
130
+ mask = np.random.random(shape) < 0.2
131
+ arr[mask] = np.random.randn(mask.sum()).astype(np.float32)
132
+ state[f"layer_{i}_{name}"] = arr
133
+
134
+ # KV cache — zeros (compresses extremely well)
135
+ for i in range(6):
136
+ state[f"kv_{i}_keys"] = np.zeros((256, 128), dtype=np.float32)
137
+ state[f"kv_{i}_vals"] = np.zeros((256, 128), dtype=np.float32)
138
+
139
+ # Config and metadata (small)
140
+ for i in range(20):
141
+ state[f"meta_{i}"] = np.zeros(32, dtype=np.float32)
142
+
143
+ def workload(wrapped):
144
+ # Token generation: sequential through layers
145
+ for token in range(3):
146
+ for layer_idx in range(6):
147
+ _ = wrapped[f"layer_{layer_idx}_q"]
148
+ _ = wrapped[f"layer_{layer_idx}_k"]
149
+ _ = wrapped[f"layer_{layer_idx}_v"]
150
+ _ = wrapped[f"kv_{layer_idx}_keys"]
151
+ _ = wrapped[f"kv_{layer_idx}_vals"]
152
+ _ = wrapped[f"layer_{layer_idx}_ffn_up"]
153
+ _ = wrapped[f"layer_{layer_idx}_ffn_down"]
154
+ time.sleep(0.0001)
155
+
156
+ # Metadata accessed once per request
157
+ _ = wrapped["meta_0"]
158
+ _ = wrapped["meta_1"]
159
+
160
+ print(f" State: {len(state)} regions, "
161
+ f"{sum(v.nbytes for v in state.values()) / 1024 / 1024:.2f} MB total")
162
+
163
+ condenser = Condenser(demotion_idle_ms=5, warmup_iters=10)
164
+ results = condenser.run_benchmark(state, workload, iterations=20,
165
+ name="inference")
166
+ condenser.print_results(results)
167
+
168
+ print(f"\n *** INFERENCE RESULTS ***")
169
+ print(f" Baseline RAM: {results['baseline_ram_mb']:.2f} MB")
170
+ print(f" Condensed RAM: {results['avg_condensed_ram_mb']:.2f} MB")
171
+ print(f" Saved: {results['saved_mb']:.2f} MB ({results['saved_pct']:.1f}%)")
172
+ print(f" Prediction acc: {results['prediction_accuracy']}%")
173
+
174
+ condenser.cleanup()
175
+ print(" PASS")
176
+
177
+
178
+ def test_large_state():
179
+ """Test 4: Larger state — stress test with meaningful RAM numbers.
180
+
181
+ 64 regions × 128KB = 8 MB total state.
182
+ Only 8 regions hot at any time = 1 MB needed.
183
+ Target: condense ~7 MB.
184
+ """
185
+ print("\n--- Test 4: Large State Stress Test ---")
186
+
187
+ # 64 regions × 128KB each = 8 MB
188
+ # Structured sparse data — compresses well
189
+ state = {}
190
+ for i in range(64):
191
+ arr = np.zeros((256, 128), dtype=np.float32)
192
+ mask = np.random.random((256, 128)) < 0.15
193
+ arr[mask] = np.random.randn(mask.sum()).astype(np.float32)
194
+ state[f"region_{i}"] = arr
195
+
196
+ # 8 hot regions that rotate
197
+ hot_set_a = set(range(0, 8))
198
+ hot_set_b = set(range(32, 40))
199
+
200
+ iteration_count = [0]
201
+
202
+ def workload(wrapped):
203
+ iteration_count[0] += 1
204
+ # Alternate between two hot sets
205
+ hot = hot_set_a if (iteration_count[0] % 20) < 10 else hot_set_b
206
+
207
+ for i in hot:
208
+ _ = wrapped[f"region_{i}"]
209
+
210
+ time.sleep(0.002)
211
+
212
+ total_mb = sum(v.nbytes for v in state.values()) / 1024 / 1024
213
+ print(f" State: {len(state)} regions, {total_mb:.1f} MB total")
214
+ print(f" Only 8 regions hot at any time (2 MB needed)")
215
+
216
+ condenser = Condenser(demotion_idle_ms=15, warmup_iters=15)
217
+ results = condenser.run_benchmark(state, workload, iterations=40,
218
+ name="large")
219
+ condenser.print_results(results)
220
+
221
+ print(f"\n *** LARGE STATE RESULTS ***")
222
+ print(f" Baseline RAM: {results['baseline_ram_mb']:.1f} MB (all in RAM)")
223
+ print(f" Condensed RAM: {results['avg_condensed_ram_mb']:.1f} MB")
224
+ print(f" Saved: {results['saved_mb']:.1f} MB ({results['saved_pct']:.1f}%)")
225
+
226
+ condenser.cleanup()
227
+ print(" PASS")
228
+
229
+
230
+ def test_prediction_value():
231
+ """Test 5: Measure prediction-driven vs reactive promotions.
232
+
233
+ The ratio of predicted vs reactive tells us how much the
234
+ predictor is actually helping vs just reacting to cache misses.
235
+ """
236
+ print("\n--- Test 5: Prediction Value Measurement ---")
237
+
238
+ state = {f"chunk_{i}": np.random.randn(64, 64).astype(np.float32)
239
+ for i in range(20)}
240
+
241
+ # Predictable pattern: 0→1→2→3, then 10→11→12→13
242
+ def workload(wrapped):
243
+ for i in range(4):
244
+ _ = wrapped[f"chunk_{i}"]
245
+ time.sleep(0.001)
246
+ time.sleep(0.005)
247
+ for i in range(10, 14):
248
+ _ = wrapped[f"chunk_{i}"]
249
+ time.sleep(0.001)
250
+ time.sleep(0.005)
251
+
252
+ condenser = Condenser(demotion_idle_ms=8, warmup_iters=15)
253
+ results = condenser.run_benchmark(state, workload, iterations=25,
254
+ name="predval")
255
+ condenser.print_results(results)
256
+
257
+ pred = results["prediction_promotions"]
258
+ react = results["reactive_promotions"]
259
+ total = pred + react
260
+
261
+ if total > 0:
262
+ pred_pct = pred / total * 100
263
+ print(f"\n Promotions: {total} total")
264
+ print(f" Prediction-driven: {pred} ({pred_pct:.0f}%)")
265
+ print(f" Reactive (miss): {react} ({100-pred_pct:.0f}%)")
266
+
267
+ if pred_pct > 50:
268
+ print(f" GOOD — Majority of promotions are prediction-driven")
269
+ else:
270
+ print(f" Prediction helps but reactive still dominates")
271
+ else:
272
+ print(f" No promotions needed (everything stayed HOT)")
273
+
274
+ condenser.cleanup()
275
+ print(" PASS")
276
+
277
+
278
+ if __name__ == "__main__":
279
+ print("=" * 60)
280
+ print(" CONDENSATE — Layer 3 Condenser Tests")
281
+ print(" The Moment of Truth: Does It Actually Save RAM?")
282
+ print("=" * 60)
283
+
284
+ test_basic_compression()
285
+ test_selective_condensation()
286
+ test_inference_workload()
287
+ test_large_state()
288
+ test_prediction_value()
289
+
290
+ print("\n" + "=" * 60)
291
+ print(" ALL TESTS PASSED")
292
+ print("=" * 60)
test_graph_builder.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Condensate Layer 1: Graph Builder Tests
3
+
4
+ Tests the graph builder on access logs from the Membrane.
5
+ Run: python3 test_graph_builder.py
6
+ """
7
+
8
+ import numpy as np
9
+ import time
10
+ import os
11
+ import sys
12
+
13
+ sys.path.insert(0, os.path.dirname(__file__))
14
+ from membrane import Membrane
15
+ from graph_builder import GraphBuilder
16
+
17
+
18
+ def test_sequential_model():
19
+ """Test 1: Sequential layer access — like a transformer forward pass.
20
+ Should discover: each layer is a cluster, layers chain sequentially.
21
+ """
22
+ print("\n--- Test 1: Sequential Model Forward Pass ---")
23
+ Membrane.clear()
24
+
25
+ # 12-layer "model" with attention components
26
+ state = {}
27
+ for layer in range(12):
28
+ state[f"layer_{layer}"] = {
29
+ "weight": np.random.randn(128, 128).astype(np.float32),
30
+ "bias": np.random.randn(128).astype(np.float32),
31
+ "attn_q": np.random.randn(128, 128).astype(np.float32),
32
+ "attn_k": np.random.randn(128, 128).astype(np.float32),
33
+ "attn_v": np.random.randn(128, 128).astype(np.float32),
34
+ }
35
+
36
+ wrapped = Membrane.wrap(state, "model")
37
+
38
+ # Run 5 "forward passes" — sequential layer access
39
+ for pass_num in range(5):
40
+ for layer_idx in range(12):
41
+ layer = wrapped[f"layer_{layer_idx}"]
42
+ _ = layer["weight"]
43
+ _ = layer["bias"]
44
+ _ = layer["attn_q"]
45
+ _ = layer["attn_k"]
46
+ _ = layer["attn_v"]
47
+ time.sleep(0.0002) # small gap between layers
48
+
49
+ # Build graph
50
+ graph = GraphBuilder(causal_window_ns=2_000_000) # 2ms window
51
+ graph.build(Membrane.get_log())
52
+ graph.print_analysis()
53
+
54
+ # Verify clusters found
55
+ assert len(graph.clusters) > 0, "Should find layer clusters"
56
+
57
+ # Verify causal chains found
58
+ chains = graph.get_causal_chains()
59
+ assert len(chains) > 0, "Should find sequential chains"
60
+
61
+ print(" PASS")
62
+
63
+
64
+ def test_hot_cold_pattern():
65
+ """Test 2: Hot/cold access — some regions hammered, others barely touched.
66
+ Should discover: clear temperature separation, cold regions compressible.
67
+ """
68
+ print("\n--- Test 2: Hot/Cold Access Pattern ---")
69
+ Membrane.clear()
70
+
71
+ # 20 regions, 4 of them hot
72
+ state = {f"region_{i}": np.random.randn(64, 64).astype(np.float32)
73
+ for i in range(20)}
74
+ wrapped = Membrane.wrap(state, "hotcold")
75
+
76
+ hot = {2, 7, 13, 18}
77
+
78
+ for _ in range(100):
79
+ for i in range(20):
80
+ if i in hot:
81
+ _ = wrapped[f"region_{i}"] # hot: every iteration
82
+ elif np.random.random() < 0.03:
83
+ _ = wrapped[f"region_{i}"] # cold: 3% chance
84
+
85
+ graph = GraphBuilder()
86
+ graph.build(Membrane.get_log())
87
+ graph.print_analysis()
88
+
89
+ # Verify temperature classification
90
+ hot_nodes = [n for n in graph.nodes.values()
91
+ if getattr(n, '_temp_class', '') == 'HOT']
92
+ cold_nodes = [n for n in graph.nodes.values()
93
+ if getattr(n, '_temp_class', '') == 'COLD']
94
+
95
+ print(f" HOT nodes: {len(hot_nodes)}, COLD nodes: {len(cold_nodes)}")
96
+ assert len(hot_nodes) >= 3, "Should identify hot regions"
97
+ assert len(cold_nodes) >= 1, "Should identify cold regions"
98
+ print(" PASS")
99
+
100
+
101
+ def test_causal_chains():
102
+ """Test 3: Known causal chains — verify the graph discovers them.
103
+ This is the core capability: can we learn prefetch chains?
104
+ """
105
+ print("\n--- Test 3: Causal Chain Discovery ---")
106
+ Membrane.clear()
107
+
108
+ state = {f"r{i}": np.random.randn(32, 32).astype(np.float32)
109
+ for i in range(10)}
110
+ wrapped = Membrane.wrap(state, "causal")
111
+
112
+ # Chain A: r0 → r2 → r5 → r9 (always this order, ~0.5ms apart)
113
+ # Chain B: r1 → r3 → r6 (always this order)
114
+ # Noise: r4, r7, r8 (random, no pattern)
115
+
116
+ for _ in range(80):
117
+ # Chain A
118
+ _ = wrapped["r0"]
119
+ time.sleep(0.0005)
120
+ _ = wrapped["r2"]
121
+ time.sleep(0.0005)
122
+ _ = wrapped["r5"]
123
+ time.sleep(0.0005)
124
+ _ = wrapped["r9"]
125
+ time.sleep(0.001)
126
+
127
+ # Chain B
128
+ _ = wrapped["r1"]
129
+ time.sleep(0.0005)
130
+ _ = wrapped["r3"]
131
+ time.sleep(0.0005)
132
+ _ = wrapped["r6"]
133
+ time.sleep(0.002)
134
+
135
+ # Noise
136
+ if np.random.random() > 0.5:
137
+ _ = wrapped[f"r{np.random.choice([4, 7, 8])}"]
138
+
139
+ graph = GraphBuilder(causal_window_ns=3_000_000) # 3ms window
140
+ graph.build(Membrane.get_log())
141
+ graph.print_analysis()
142
+
143
+ # Check for discovered chains
144
+ chains = graph.get_causal_chains(min_weight=5.0)
145
+ print(f"\n Chains found (weight >= 5): {len(chains)}")
146
+ for chain in chains:
147
+ path_names = [p.split(".")[-1] for p, _ in chain]
148
+ print(f" {' → '.join(path_names)}")
149
+
150
+ # The graph should find chain-like patterns
151
+ # (exact chains depend on timing, but structure should be visible)
152
+ assert len(chains) >= 1, "Should discover at least one causal chain"
153
+ print(" PASS")
154
+
155
+
156
+ def test_cluster_discovery():
157
+ """Test 4: Co-access clusters — groups of regions always used together.
158
+ These become hyperedges: promote/demote the whole group as a unit.
159
+ """
160
+ print("\n--- Test 4: Cluster (Proto-Hyperedge) Discovery ---")
161
+ Membrane.clear()
162
+
163
+ state = {f"item_{i}": np.random.randn(16).astype(np.float32)
164
+ for i in range(15)}
165
+ wrapped = Membrane.wrap(state, "cluster")
166
+
167
+ # Cluster A: items 0, 1, 2 always together
168
+ # Cluster B: items 5, 6, 7, 8 always together
169
+ # Cluster C: items 10, 11 always together
170
+ # Singletons: 3, 4, 9, 12, 13, 14 — accessed independently
171
+
172
+ for _ in range(60):
173
+ # Cluster A — tight access, big gap after
174
+ _ = wrapped["item_0"]
175
+ _ = wrapped["item_1"]
176
+ _ = wrapped["item_2"]
177
+ time.sleep(0.008) # 8ms gap — outside causal window
178
+
179
+ # Cluster B — tight access, big gap after
180
+ _ = wrapped["item_5"]
181
+ _ = wrapped["item_6"]
182
+ _ = wrapped["item_7"]
183
+ _ = wrapped["item_8"]
184
+ time.sleep(0.008)
185
+
186
+ # Cluster C (less frequent)
187
+ if np.random.random() > 0.3:
188
+ _ = wrapped["item_10"]
189
+ _ = wrapped["item_11"]
190
+ time.sleep(0.008)
191
+
192
+ # Random singletons
193
+ idx = np.random.choice([3, 4, 9, 12, 13, 14])
194
+ _ = wrapped[f"item_{idx}"]
195
+ time.sleep(0.008)
196
+
197
+ graph = GraphBuilder(causal_window_ns=3_000_000, cluster_threshold=0.6)
198
+ graph.build(Membrane.get_log())
199
+ graph.print_analysis()
200
+
201
+ # Should find at least 2 clear clusters
202
+ print(f"\n Clusters found: {len(graph.clusters)}")
203
+ assert len(graph.clusters) >= 2, "Should find multiple clusters"
204
+
205
+ # Verify cluster A members are together
206
+ cluster_a_found = False
207
+ for cluster in graph.clusters:
208
+ paths = {m.split(".")[-1] for m in cluster.members}
209
+ if {"item_0", "item_1", "item_2"}.issubset(paths):
210
+ cluster_a_found = True
211
+ break
212
+
213
+ assert cluster_a_found, "Should find cluster A (items 0,1,2)"
214
+ print(" Cluster A (items 0,1,2) found correctly")
215
+ print(" PASS")
216
+
217
+
218
+ def test_real_world_simulation():
219
+ """Test 5: Realistic workload — simulates an AI inference server.
220
+
221
+ Pattern:
222
+ - Model weights accessed sequentially (forward pass)
223
+ - KV cache accessed selectively (attention)
224
+ - Config accessed once at the start of each request
225
+ - Buffer reused across requests
226
+ """
227
+ print("\n--- Test 5: Realistic AI Inference Simulation ---")
228
+ Membrane.clear()
229
+
230
+ state = {
231
+ "config": {"max_tokens": 512, "temperature": 0.7, "top_p": 0.9},
232
+ "buffer": {"input_ids": np.zeros(512, dtype=np.int32),
233
+ "logits": np.zeros(32000, dtype=np.float32)},
234
+ }
235
+ # Add model layers
236
+ for i in range(6):
237
+ state[f"layer_{i}"] = {
238
+ "q": np.random.randn(64, 64).astype(np.float32),
239
+ "k": np.random.randn(64, 64).astype(np.float32),
240
+ "v": np.random.randn(64, 64).astype(np.float32),
241
+ "ffn_up": np.random.randn(64, 256).astype(np.float32),
242
+ "ffn_down": np.random.randn(256, 64).astype(np.float32),
243
+ }
244
+ # Add KV cache (per layer, grows with sequence)
245
+ for i in range(6):
246
+ state[f"kv_cache_{i}"] = {
247
+ "keys": np.zeros((512, 64), dtype=np.float32),
248
+ "values": np.zeros((512, 64), dtype=np.float32),
249
+ }
250
+
251
+ wrapped = Membrane.wrap(state, "server")
252
+
253
+ # Simulate 3 requests
254
+ for req in range(3):
255
+ # Config read once per request
256
+ _ = wrapped["config"]["max_tokens"]
257
+ _ = wrapped["config"]["temperature"]
258
+
259
+ # Buffer setup
260
+ _ = wrapped["buffer"]["input_ids"]
261
+
262
+ # Forward pass — 10 "tokens" of autoregressive generation
263
+ for token in range(10):
264
+ for layer_idx in range(6):
265
+ # Attention
266
+ layer = wrapped[f"layer_{layer_idx}"]
267
+ _ = layer["q"]
268
+ _ = layer["k"]
269
+ _ = layer["v"]
270
+
271
+ # KV cache read/write
272
+ cache = wrapped[f"kv_cache_{layer_idx}"]
273
+ _ = cache["keys"]
274
+ _ = cache["values"]
275
+
276
+ # FFN
277
+ _ = layer["ffn_up"]
278
+ _ = layer["ffn_down"]
279
+ time.sleep(0.0001)
280
+
281
+ # Logits at the end of each token
282
+ _ = wrapped["buffer"]["logits"]
283
+
284
+ total_bytes = 0
285
+ for k, v in state.items():
286
+ if isinstance(v, dict):
287
+ for v2 in v.values():
288
+ if isinstance(v2, np.ndarray):
289
+ total_bytes += v2.nbytes
290
+ elif isinstance(v2, dict):
291
+ for v3 in v2.values():
292
+ if isinstance(v3, np.ndarray):
293
+ total_bytes += v3.nbytes
294
+ total_mb = total_bytes / 1024 / 1024
295
+
296
+ print(f" Simulated: 3 requests × 10 tokens × 6 layers")
297
+ print(f" Total state: {total_mb:.1f} MB")
298
+
299
+ graph = GraphBuilder(causal_window_ns=2_000_000)
300
+ graph.build(Membrane.get_log())
301
+ graph.print_analysis()
302
+
303
+ # Save for potential Layer 2 testing
304
+ graph.save(os.path.join(os.path.dirname(__file__), "inference_graph.json"))
305
+
306
+ # Verify key insights
307
+ config_node = graph.nodes.get("server.config.max_tokens")
308
+ layer0_q = graph.nodes.get("server.layer_0.q")
309
+
310
+ if config_node and layer0_q:
311
+ print(f" Config accesses: {config_node.access_count} (read once per request)")
312
+ print(f" Layer 0 Q accesses: {layer0_q.access_count} (every token, every request)")
313
+ ratio = layer0_q.access_count / max(config_node.access_count, 1)
314
+ print(f" Ratio: {ratio:.0f}x — config is compressible, Q is not")
315
+
316
+ print(" PASS")
317
+
318
+
319
+ if __name__ == "__main__":
320
+ print("=" * 60)
321
+ print(" CONDENSATE — Layer 1 Graph Builder Tests")
322
+ print("=" * 60)
323
+
324
+ test_sequential_model()
325
+ test_hot_cold_pattern()
326
+ test_causal_chains()
327
+ test_cluster_discovery()
328
+ test_real_world_simulation()
329
+
330
+ print("\n" + "=" * 60)
331
+ print(" ALL TESTS PASSED")
332
+ print("=" * 60)
test_membrane.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Condensate Layer 0: Membrane Tests
3
+
4
+ Tests the membrane wrapper on increasingly realistic workloads.
5
+ Run: python3 test_membrane.py
6
+ """
7
+
8
+ import numpy as np
9
+ import time
10
+ import os
11
+ import sys
12
+
13
+ # Add parent dir to path so we can import membrane
14
+ sys.path.insert(0, os.path.dirname(__file__))
15
+ from membrane import Membrane
16
+
17
+
18
+ def test_basic_dict():
19
+ """Test 1: Basic dict access tracking."""
20
+ print("\n--- Test 1: Basic Dict Access ---")
21
+ Membrane.clear()
22
+
23
+ data = Membrane.wrap({
24
+ "name": "test",
25
+ "values": [1, 2, 3, 4, 5],
26
+ "nested": {"a": 10, "b": 20, "c": 30},
27
+ }, "basic")
28
+
29
+ # Read some values
30
+ _ = data["name"]
31
+ _ = data["name"] # same key twice
32
+ _ = data["values"]
33
+ _ = data["nested"]["a"] # nested read — should log both levels
34
+ _ = data["nested"]["b"]
35
+
36
+ # Write
37
+ data["name"] = "updated"
38
+
39
+ assert Membrane.entry_count() > 0, "Should have recorded accesses"
40
+ Membrane.print_stats()
41
+ print(" PASS")
42
+
43
+
44
+ def test_numpy_arrays():
45
+ """Test 2: Dict of numpy arrays — simulates model weight storage."""
46
+ print("\n--- Test 2: NumPy Array State (Simulated Model Weights) ---")
47
+ Membrane.clear()
48
+
49
+ # Simulate a small model with layers of weight matrices
50
+ state = {}
51
+ for layer in range(8):
52
+ state[f"layer_{layer}"] = {
53
+ "weight": np.random.randn(256, 256).astype(np.float32),
54
+ "bias": np.random.randn(256).astype(np.float32),
55
+ "attention": {
56
+ "q_proj": np.random.randn(256, 256).astype(np.float32),
57
+ "k_proj": np.random.randn(256, 256).astype(np.float32),
58
+ "v_proj": np.random.randn(256, 256).astype(np.float32),
59
+ }
60
+ }
61
+
62
+ wrapped = Membrane.wrap(state, "model")
63
+
64
+ total_bytes = sum(
65
+ state[f"layer_{i}"]["weight"].nbytes +
66
+ state[f"layer_{i}"]["bias"].nbytes +
67
+ sum(v.nbytes for v in state[f"layer_{i}"]["attention"].values())
68
+ for i in range(8)
69
+ )
70
+ print(f" Model state: {total_bytes / 1024 / 1024:.1f} MB across 8 layers")
71
+
72
+ # Simulate a forward pass — sequential layer access
73
+ print(" Simulating forward pass...")
74
+ for layer_idx in range(8):
75
+ layer = wrapped[f"layer_{layer_idx}"]
76
+ w = layer["weight"]
77
+ b = layer["bias"]
78
+ attn = layer["attention"]
79
+ q = attn["q_proj"]
80
+ k = attn["k_proj"]
81
+ v = attn["v_proj"]
82
+
83
+ # Simulate a second forward pass — same pattern
84
+ print(" Simulating second forward pass...")
85
+ for layer_idx in range(8):
86
+ layer = wrapped[f"layer_{layer_idx}"]
87
+ w = layer["weight"]
88
+ b = layer["bias"]
89
+ attn = layer["attention"]
90
+ q = attn["q_proj"]
91
+ k = attn["k_proj"]
92
+ v = attn["v_proj"]
93
+
94
+ Membrane.print_stats()
95
+ print(" PASS")
96
+
97
+
98
+ def test_selective_access():
99
+ """Test 3: Selective access — some layers hot, some cold.
100
+ This is the pattern Condensate exploits: not all state is accessed equally.
101
+ """
102
+ print("\n--- Test 3: Selective Access (Hot/Cold Pattern) ---")
103
+ Membrane.clear()
104
+
105
+ state = {}
106
+ for layer in range(16):
107
+ state[f"layer_{layer}"] = {
108
+ "weight": np.random.randn(128, 128).astype(np.float32),
109
+ "bias": np.random.randn(128).astype(np.float32),
110
+ }
111
+
112
+ wrapped = Membrane.wrap(state, "selective")
113
+
114
+ # Simulate: layers 3, 7, 11 are "hot" — accessed 10x more
115
+ hot_layers = {3, 7, 11}
116
+ for iteration in range(20):
117
+ for layer_idx in range(16):
118
+ if layer_idx in hot_layers:
119
+ # Hot path — always accessed
120
+ layer = wrapped[f"layer_{layer_idx}"]
121
+ _ = layer["weight"]
122
+ _ = layer["bias"]
123
+ elif iteration % 10 == 0:
124
+ # Cold path — accessed once every 10 iterations
125
+ layer = wrapped[f"layer_{layer_idx}"]
126
+ _ = layer["weight"]
127
+
128
+ stats = Membrane.stats()
129
+ Membrane.print_stats()
130
+
131
+ # Verify hot layers have more accesses
132
+ hot_count = sum(
133
+ stats["paths"].get(f"selective.layer_{i}", {}).get("reads", 0)
134
+ for i in hot_layers
135
+ )
136
+ cold_count = sum(
137
+ stats["paths"].get(f"selective.layer_{i}", {}).get("reads", 0)
138
+ for i in range(16) if i not in hot_layers
139
+ )
140
+ ratio = hot_count / max(cold_count, 1)
141
+ print(f" Hot/cold access ratio: {ratio:.1f}x")
142
+ print(f" (This ratio is what Condensate exploits — hot stays in RAM, cold compresses)")
143
+ print(" PASS")
144
+
145
+
146
+ def test_temporal_chains():
147
+ """Test 4: Temporal access chains — A always followed by B followed by C.
148
+ This is what the SNN will learn as causal chains for prefetch.
149
+ """
150
+ print("\n--- Test 4: Temporal Chains (Causal Access Patterns) ---")
151
+ Membrane.clear()
152
+
153
+ state = {f"region_{i}": np.random.randn(64, 64).astype(np.float32) for i in range(10)}
154
+ wrapped = Membrane.wrap(state, "temporal")
155
+
156
+ # Chain 1: 0 → 3 → 7 (always in this order)
157
+ # Chain 2: 1 → 4 → 8 (always in this order)
158
+ # Region 5: random, no chain
159
+ chains = [
160
+ [0, 3, 7],
161
+ [1, 4, 8],
162
+ ]
163
+
164
+ for _ in range(50):
165
+ for chain in chains:
166
+ for region_id in chain:
167
+ _ = wrapped[f"region_{region_id}"]
168
+ time.sleep(0.0001) # tiny delay to separate timestamps
169
+
170
+ # Random access to region 5
171
+ if np.random.random() > 0.5:
172
+ _ = wrapped["region_5"]
173
+
174
+ stats = Membrane.stats()
175
+ Membrane.print_stats()
176
+
177
+ # Check co-accesses — chain members should co-access heavily
178
+ coaccesses = stats.get("top_coaccesses", [])
179
+ if coaccesses:
180
+ print(f" Top co-access pairs found: {len(coaccesses)}")
181
+ print(f" (These are the causal chains the SNN would learn)")
182
+
183
+ print(" PASS")
184
+
185
+
186
+ def test_overhead():
187
+ """Test 5: Measure the membrane's overhead.
188
+ This tells us if the observation layer is cheap enough.
189
+ """
190
+ print("\n--- Test 5: Overhead Measurement ---")
191
+
192
+ state = {f"key_{i}": np.random.randn(32).astype(np.float32) for i in range(100)}
193
+
194
+ # Baseline: raw dict access
195
+ iterations = 100_000
196
+ start = time.monotonic_ns()
197
+ for _ in range(iterations):
198
+ for key in ["key_0", "key_50", "key_99"]:
199
+ _ = state[key]
200
+ raw_ns = time.monotonic_ns() - start
201
+
202
+ # Membrane: wrapped dict access
203
+ Membrane.clear()
204
+ wrapped = Membrane.wrap(state.copy(), "overhead")
205
+ start = time.monotonic_ns()
206
+ for _ in range(iterations):
207
+ for key in ["key_0", "key_50", "key_99"]:
208
+ _ = wrapped[key]
209
+ membrane_ns = time.monotonic_ns() - start
210
+
211
+ raw_per = raw_ns / (iterations * 3)
212
+ membrane_per = membrane_ns / (iterations * 3)
213
+ overhead = membrane_per - raw_per
214
+
215
+ print(f" Raw dict access: {raw_per:.0f} ns/access")
216
+ print(f" Membrane access: {membrane_per:.0f} ns/access")
217
+ print(f" Overhead per access: {overhead:.0f} ns")
218
+ print(f" Slowdown factor: {membrane_per / raw_per:.1f}x")
219
+ print(f" Total accesses logged: {Membrane.entry_count()}")
220
+
221
+ # The membrane is for observation only — overhead is acceptable
222
+ # if it's under ~5μs per access. For production, the Rust core
223
+ # will bring this to ~5ns.
224
+ if overhead < 5000:
225
+ print(f" Overhead acceptable for PoC (< 5μs)")
226
+ else:
227
+ print(f" Overhead high — expected for Python, Rust core will fix")
228
+
229
+ print(" PASS")
230
+
231
+
232
+ def test_save_log():
233
+ """Test 6: Save the access log for Layer 1 analysis."""
234
+ print("\n--- Test 6: Save Log ---")
235
+ Membrane.clear()
236
+
237
+ state = {f"region_{i}": np.random.randn(64, 64).astype(np.float32) for i in range(5)}
238
+ wrapped = Membrane.wrap(state, "saveable")
239
+
240
+ # Generate some access patterns
241
+ for _ in range(10):
242
+ _ = wrapped["region_0"]
243
+ _ = wrapped["region_2"]
244
+ _ = wrapped["region_4"]
245
+
246
+ log_path = os.path.join(os.path.dirname(__file__), "test_access_log.json")
247
+ Membrane.save_log(log_path)
248
+
249
+ # Verify file exists and is valid JSON
250
+ import json
251
+ with open(log_path) as f:
252
+ data = json.load(f)
253
+ assert "entries" in data
254
+ assert len(data["entries"]) == 30 # 3 accesses x 10 iterations
255
+
256
+ # Clean up
257
+ os.remove(log_path)
258
+ print(" PASS")
259
+
260
+
261
+ if __name__ == "__main__":
262
+ print("=" * 60)
263
+ print(" CONDENSATE — Layer 0 Membrane Tests")
264
+ print("=" * 60)
265
+
266
+ test_basic_dict()
267
+ test_numpy_arrays()
268
+ test_selective_access()
269
+ test_temporal_chains()
270
+ test_overhead()
271
+ test_save_log()
272
+
273
+ print("\n" + "=" * 60)
274
+ print(" ALL TESTS PASSED")
275
+ print("=" * 60)
test_predictor.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Condensate Layer 2: Predictor Tests
3
+
4
+ Tests prediction accuracy on known access patterns.
5
+ The key question: can we predict what's coming before it's requested?
6
+
7
+ Run: python3 test_predictor.py
8
+ """
9
+
10
+ import numpy as np
11
+ import time
12
+ import os
13
+ import sys
14
+
15
+ sys.path.insert(0, os.path.dirname(__file__))
16
+ from membrane import Membrane
17
+ from graph_builder import GraphBuilder
18
+ from predictor import Predictor
19
+
20
+
21
def generate_and_learn(name, state, access_fn, train_iters,
                       causal_window_ns=3_000_000):
    """Helper: run a workload, build graph, learn predictor.

    Clears the membrane log, wraps ``state``, replays the workload, then
    builds a graph from the recorded log and trains a predictor on it.

    Args:
        name: Label passed to ``Membrane.wrap`` for the wrapped state.
        state: Object to wrap with the membrane (the tests pass dicts).
        access_fn: Callable taking the wrapped state; invoked once per
            training iteration to generate accesses.
        train_iters: Number of times ``access_fn`` is invoked.
        causal_window_ns: Forwarded to ``GraphBuilder`` unchanged.

    Returns:
        A 3-tuple ``(predictor, graph, train_log)`` after training.
    """
    Membrane.clear()
    wrapped = Membrane.wrap(state, name)

    for _ in range(train_iters):
        access_fn(wrapped)

    train_log = Membrane.get_log()

    graph = GraphBuilder(causal_window_ns=causal_window_ns)
    graph.build(train_log)

    predictor = Predictor()
    predictor.learn(graph)

    return predictor, graph, train_log
42
+
43
+
44
def test_sequential_prediction():
    """Test 1: Sequential layer access — can we predict the next layer?

    Every pass touches layer_0, layer_1, ..., layer_7 in that order. A
    predictor is trained on 20 passes and scored on 10 fresh passes; the
    test requires the reported accuracy to be above 50%.
    """
    print("\n--- Test 1: Sequential Layer Prediction ---")

    layer_names = [f"layer_{idx}" for idx in range(8)]
    state = {}
    for layer_name in layer_names:
        state[layer_name] = {"w": np.random.randn(64, 64).astype(np.float32)}

    def access_fn(wrapped):
        # One pass: read each layer's weights in order, pausing after each.
        for layer_name in layer_names:
            _ = wrapped[layer_name]["w"]
            time.sleep(0.0005)

    # Train on 20 passes
    predictor, _graph, _train_log = generate_and_learn(
        "seq", state, access_fn, train_iters=20
    )
    predictor.print_model()

    # Test on 10 new passes, using a shallow copy of each per-layer dict
    Membrane.clear()
    fresh_state = {}
    for key, value in state.items():
        fresh_state[key] = dict(value) if isinstance(value, dict) else value
    wrapped = Membrane.wrap(fresh_state, "seq")
    for _ in range(10):
        access_fn(wrapped)

    result = predictor.print_score(Membrane.get_log(), verbose=True)

    assert result["accuracy"] > 50, f"Sequential prediction should be >50%, got {result['accuracy']}%"
    print(f" Accuracy: {result['accuracy']}% — sequential prediction works!")
    print(" PASS")
83
+
84
+
85
def test_causal_chain_prediction():
    """Test 2: Known causal chains — A→B→C with consistent timing.

    The workload always accesses r0 → r2 → r5 → r7 about 1ms apart, with an
    occasional random access to one of r1/r3/r4/r6 as noise. The predictor
    should learn the chain and predict r2 when r0 fires, and r5 when r2
    fires; both single-hop links are asserted.

    NOTE(review): the multi-hop case (seeing r0 also predicts r5) is not
    checked here.
    """
    print("\n--- Test 2: Causal Chain Prediction ---")

    state = {f"r{i}": np.random.randn(32).astype(np.float32)
             for i in range(8)}

    def access_fn(wrapped):
        # Chain: r0 → r2 → r5 → r7 (always, ~1ms apart)
        _ = wrapped["r0"]
        time.sleep(0.001)
        _ = wrapped["r2"]
        time.sleep(0.001)
        _ = wrapped["r5"]
        time.sleep(0.001)
        _ = wrapped["r7"]
        time.sleep(0.005)

        # Noise
        if np.random.random() > 0.7:
            _ = wrapped[f"r{np.random.choice([1, 3, 4, 6])}"]
            time.sleep(0.005)

    predictor, _graph, _train_log = generate_and_learn(
        "chain", state, access_fn, train_iters=50
    )

    predictor.print_model()

    # Test: when r0 fires, do we predict r2?
    preds = predictor.predict("chain.r0")
    pred_paths = [p.path for p in preds]
    print(f" When r0 fires, predictions: {[p.path.split('.')[-1] for p in preds[:5]]}")

    r2_predicted = "chain.r2" in pred_paths
    print(f" r2 predicted after r0: {r2_predicted}")

    # Test: when r2 fires, do we predict r5?
    preds_r2 = predictor.predict("chain.r2")
    pred_paths_r2 = [p.path for p in preds_r2]
    r5_predicted = "chain.r5" in pred_paths_r2
    print(f" r5 predicted after r2: {r5_predicted}")

    # Score on fresh data (printed for inspection; the score is not asserted)
    Membrane.clear()
    wrapped = Membrane.wrap(
        {k: v.copy() if hasattr(v, 'copy') else v for k, v in state.items()},
        "chain"
    )
    for _ in range(20):
        access_fn(wrapped)

    predictor.print_score(Membrane.get_log(), verbose=True)

    assert r2_predicted, "Should predict r2 after r0"
    # The docstring's second claim was previously printed but never enforced.
    assert r5_predicted, "Should predict r5 after r2"
    print(" PASS")
145
+
146
+
147
def test_cluster_prediction():
    """Test 3: Cluster co-activation — if one member fires, predict all.

    items 0/1/2 are always read together, as are items 5/6/7, with one
    random singleton read per pass. After training, firing item_0 must
    yield predictions for both item_1 and item_2.
    """
    print("\n--- Test 3: Cluster Co-Activation Prediction ---")

    state = {}
    for idx in range(10):
        state[f"item_{idx}"] = np.random.randn(16).astype(np.float32)

    cluster_a = ("item_0", "item_1", "item_2")
    cluster_b = ("item_5", "item_6", "item_7")

    def access_fn(wrapped):
        # Each cluster is read back-to-back, then followed by a pause.
        for cluster in (cluster_a, cluster_b):
            for key in cluster:
                _ = wrapped[key]
            time.sleep(0.008)

        # Random singletons
        _ = wrapped[f"item_{np.random.choice([3, 4, 8, 9])}"]
        time.sleep(0.008)

    predictor, _graph, _train_log = generate_and_learn(
        "clust", state, access_fn, train_iters=40,
        causal_window_ns=3_000_000
    )
    predictor.print_model()

    # Test: when item_0 fires, predict item_1 and item_2
    preds = predictor.predict("clust.item_0")
    predicted_paths = set()
    for pred in preds:
        predicted_paths.add(pred.path)
    leaf_names = [pred.path.split('.')[-1] for pred in preds[:5]]
    print(f" When item_0 fires: {leaf_names}")

    item_1_predicted = "clust.item_1" in predicted_paths
    item_2_predicted = "clust.item_2" in predicted_paths
    print(f" item_1 predicted: {item_1_predicted}")
    print(f" item_2 predicted: {item_2_predicted}")

    # Score on fresh data (printed only; the score itself is not asserted)
    Membrane.clear()
    fresh_state = {}
    for key, value in state.items():
        fresh_state[key] = value.copy()
    wrapped = Membrane.wrap(fresh_state, "clust")
    for _ in range(15):
        access_fn(wrapped)

    predictor.print_score(Membrane.get_log(), verbose=True)

    assert item_1_predicted and item_2_predicted, "Should predict cluster members"
    print(" PASS")
203
+
204
+
205
def _build_inference_state():
    """Build a fresh simulated-model state: config, 6 layers, 6 KV caches."""
    state = {
        "config": {"temp": 0.7, "max_tok": 512},
    }
    for i in range(6):
        state[f"layer_{i}"] = {
            "q": np.random.randn(64, 64).astype(np.float32),
            "k": np.random.randn(64, 64).astype(np.float32),
            "v": np.random.randn(64, 64).astype(np.float32),
            "ffn": np.random.randn(64, 256).astype(np.float32),
        }
    for i in range(6):
        state[f"kv_{i}"] = {
            "keys": np.zeros((128, 64), dtype=np.float32),
            "vals": np.zeros((128, 64), dtype=np.float32),
        }
    return state


def test_inference_simulation():
    """Test 4: Realistic inference — train on requests, predict on new ones.

    This is the demo workload. Each simulated request reads the config once
    and then, for 5 tokens, walks all 6 layers reading q/k/v, the layer's KV
    cache keys/vals, and the ffn weights. The predictor is trained on 10
    requests and scored on 5 requests against a freshly built state.

    The accuracy is reported with a qualitative verdict but is not asserted:
    this test prints PASS whatever the score.
    """
    print("\n--- Test 4: AI Inference Prediction (The Real Test) ---")

    state = _build_inference_state()

    def access_fn(wrapped):
        # Config once
        _ = wrapped["config"]["temp"]

        # 5 tokens of autoregressive generation
        for tok in range(5):
            for layer_idx in range(6):
                layer = wrapped[f"layer_{layer_idx}"]
                _ = layer["q"]
                _ = layer["k"]
                _ = layer["v"]
                kv = wrapped[f"kv_{layer_idx}"]
                _ = kv["keys"]
                _ = kv["vals"]
                _ = layer["ffn"]
                time.sleep(0.0001)

    # TRAIN on 10 requests
    print(" Training on 10 requests...")
    predictor, _graph, _train_log = generate_and_learn(
        "inf", state, access_fn, train_iters=10,
        causal_window_ns=2_000_000
    )

    predictor.print_model()

    # TEST on 5 new requests
    print(" Testing on 5 new requests...")
    Membrane.clear()

    # Rebuild state for test so scoring does not reuse the training objects
    test_state = _build_inference_state()

    wrapped = Membrane.wrap(test_state, "inf")
    for _ in range(5):
        access_fn(wrapped)

    test_log = Membrane.get_log()
    result = predictor.print_score(test_log, verbose=True)

    # The moment of truth
    accuracy = result["accuracy"]
    print(f"\n *** INFERENCE PREDICTION ACCURACY: {accuracy}% ***")

    if accuracy >= 80:
        print(" EXCELLENT — Condensate can predict inference access patterns!")
        print(" This means: pre-staging works. RAM condensation is viable.")
    elif accuracy >= 60:
        print(" GOOD — Significant prediction capability. Worth pursuing.")
    elif accuracy >= 40:
        print(" MODERATE — Some structure learned. Needs better substrate.")
    else:
        print(" LOW — Pattern too noisy or model too simple. Investigate.")

    print(" PASS")
297
+
298
+
299
def test_prediction_vs_no_prediction():
    """Test 5: Quantify the value — compare predicted vs unpredicted accesses.

    Simulates what would happen with and without prediction:
    - Without: every cold access = full latency (cache miss)
    - With: predicted accesses = pre-staged (cache hit)

    Reports the theoretical speedup under a fixed latency model
    (100ns per hit, 100μs per miss). Nothing is asserted; the test only
    prints the figures and a qualitative verdict.
    """
    print("\n--- Test 5: Prediction Value (Theoretical Speedup) ---")

    state = {f"block_{i}": np.random.randn(128, 128).astype(np.float32)
             for i in range(16)}

    # Blocks 0-3 are "hot" (always in RAM); blocks 4-15 are "cold" (would
    # need paging). A tuple makes the access order explicit rather than
    # depending on set iteration order.
    hot_blocks = (0, 1, 2, 3)

    def access_fn(wrapped):
        # Hot blocks every iteration
        for i in hot_blocks:
            _ = wrapped[f"block_{i}"]
        time.sleep(0.001)

        # Cold blocks: predictable pattern
        # Phase A: blocks 4,5,6 together
        _ = wrapped["block_4"]
        _ = wrapped["block_5"]
        _ = wrapped["block_6"]
        time.sleep(0.005)

        # Phase B: blocks 10,11,12 together
        _ = wrapped["block_10"]
        _ = wrapped["block_11"]
        _ = wrapped["block_12"]
        time.sleep(0.005)

        # Random cold access (unpredictable)
        _ = wrapped[f"block_{np.random.choice([7, 8, 9, 13, 14, 15])}"]
        time.sleep(0.005)

    # Train
    predictor, _graph, _train_log = generate_and_learn(
        "value", state, access_fn, train_iters=30,
        causal_window_ns=3_000_000
    )

    # Test
    Membrane.clear()
    wrapped = Membrane.wrap(
        {k: v.copy() for k, v in state.items()}, "value"
    )
    for _ in range(10):
        access_fn(wrapped)

    result = predictor.score(Membrane.get_log())

    # Simulate latency impact
    hit_rate = result["accuracy"] / 100.0
    # NOTE(review): the number of predictions made is used as the
    # cold-access count — confirm this matches Predictor.score's definition.
    cold_access_count = result["predictions_made"]

    # Latency model (simplified):
    #   Cache hit (predicted & pre-staged): ~100ns (RAM-HOT)
    #   Cache miss (unpredicted cold): ~100μs (disk page-in)
    #   That's a 1000x difference
    hit_latency_ns = 100
    miss_latency_ns = 100_000

    with_prediction = (cold_access_count * hit_rate * hit_latency_ns +
                       cold_access_count * (1 - hit_rate) * miss_latency_ns)

    without_prediction = cold_access_count * miss_latency_ns

    # Guard the division: with zero scored accesses both totals are 0.
    speedup = without_prediction / with_prediction if with_prediction > 0 else 1.0

    print(f"\n Cold accesses in test: {cold_access_count}")
    print(f" Prediction hit rate: {result['accuracy']}%")
    print()
    print(" Without Condensate:")
    print(f" Every cold access = {miss_latency_ns/1000:.0f}μs (page from disk)")
    print(f" Total latency: {without_prediction/1e6:.1f}ms")
    print()
    print(" With Condensate:")
    print(f" Predicted hits: {hit_latency_ns}ns (pre-staged in RAM)")
    print(f" Unpredicted misses: {miss_latency_ns/1000:.0f}μs (still cold)")
    print(f" Total latency: {with_prediction/1e6:.1f}ms")
    print()
    print(f" *** THEORETICAL SPEEDUP: {speedup:.1f}x ***")

    if speedup > 5:
        print(" Significant — prediction eliminates most cold-access latency")
    elif speedup > 2:
        print(" Meaningful — prediction cuts cold-access latency substantially")
    else:
        print(" Marginal — need better prediction or different access patterns")

    print(" PASS")
396
+
397
+
398
if __name__ == "__main__":
    # Script entry point: run every Layer 2 predictor test in order. Any
    # failing assertion aborts the run before the final banner is printed.
    banner = "=" * 60
    print(banner)
    print(" CONDENSATE — Layer 2 Predictor Tests")
    print(banner)

    for run_test in (
        test_sequential_prediction,
        test_causal_chain_prediction,
        test_cluster_prediction,
        test_inference_simulation,
        test_prediction_vs_no_prediction,
    ):
        run_test()

    print("\n" + banner)
    print(" ALL TESTS PASSED")
    print(banner)