| """ |
| Condensate Layer 1: Graph Builder Tests |
| |
| Tests the graph builder on access logs from the Membrane. |
| Run: python3 test_graph_builder.py |
| """ |
|
|
| import numpy as np |
| import time |
| import os |
| import sys |
|
|
| sys.path.insert(0, os.path.dirname(__file__)) |
| from membrane import Membrane |
| from graph_builder import GraphBuilder |
|
|
|
|
def test_sequential_model():
    """Test 1: Sequential layer access — mimics a transformer forward pass.

    Expectation: each layer forms its own co-access cluster, and the
    layers chain together sequentially.
    """
    print("\n--- Test 1: Sequential Model Forward Pass ---")
    Membrane.clear()

    # 12 layers, each with attention + MLP parameters.
    state = {
        f"layer_{idx}": {
            "weight": np.random.randn(128, 128).astype(np.float32),
            "bias": np.random.randn(128).astype(np.float32),
            "attn_q": np.random.randn(128, 128).astype(np.float32),
            "attn_k": np.random.randn(128, 128).astype(np.float32),
            "attn_v": np.random.randn(128, 128).astype(np.float32),
        }
        for idx in range(12)
    }

    wrapped = Membrane.wrap(state, "model")

    # Five forward passes: every parameter of every layer, in order.
    for _ in range(5):
        for idx in range(12):
            current = wrapped[f"layer_{idx}"]
            for field in ("weight", "bias", "attn_q", "attn_k", "attn_v"):
                _ = current[field]
            time.sleep(0.0002)

    graph = GraphBuilder(causal_window_ns=2_000_000)
    graph.build(Membrane.get_log())
    graph.print_analysis()

    assert len(graph.clusters) > 0, "Should find layer clusters"

    chains = graph.get_causal_chains()
    assert len(chains) > 0, "Should find sequential chains"

    print(" PASS")
|
|
|
|
def test_hot_cold_pattern():
    """Test 2: Hot/cold access — a few regions hammered, the rest rarely read.

    Expectation: a clear temperature split; cold regions are candidates
    for compression.
    """
    print("\n--- Test 2: Hot/Cold Access Pattern ---")
    Membrane.clear()

    state = {}
    for idx in range(20):
        state[f"region_{idx}"] = np.random.randn(64, 64).astype(np.float32)
    wrapped = Membrane.wrap(state, "hotcold")

    hot_set = {2, 7, 13, 18}

    for _ in range(100):
        for idx in range(20):
            # Hot regions on every sweep; cold ones only ~3% of the time.
            # (short-circuit `or` skips the RNG draw for hot regions,
            # matching the original if/elif structure exactly)
            if idx in hot_set or np.random.random() < 0.03:
                _ = wrapped[f"region_{idx}"]

    graph = GraphBuilder()
    graph.build(Membrane.get_log())
    graph.print_analysis()

    # Bucket nodes by the temperature class the builder assigned.
    buckets = {"HOT": [], "COLD": []}
    for node in graph.nodes.values():
        label = getattr(node, '_temp_class', '')
        if label in buckets:
            buckets[label].append(node)
    hot_nodes = buckets["HOT"]
    cold_nodes = buckets["COLD"]

    print(f" HOT nodes: {len(hot_nodes)}, COLD nodes: {len(cold_nodes)}")
    assert len(hot_nodes) >= 3, "Should identify hot regions"
    assert len(cold_nodes) >= 1, "Should identify cold regions"
    print(" PASS")
|
|
|
|
def test_causal_chains():
    """Test 3: Known causal chains — verify the graph discovers them.

    This is the core capability: can we learn prefetch chains?
    """
    print("\n--- Test 3: Causal Chain Discovery ---")
    Membrane.clear()

    state = {f"r{idx}": np.random.randn(32, 32).astype(np.float32)
             for idx in range(10)}
    wrapped = Membrane.wrap(state, "causal")

    for _ in range(80):
        # Chain A: r0 -> r2 -> r5 -> r9, fixed gap after each hop.
        for name, pause in (("r0", 0.0005), ("r2", 0.0005),
                            ("r5", 0.0005), ("r9", 0.001)):
            _ = wrapped[name]
            time.sleep(pause)

        # Chain B: r1 -> r3 -> r6.
        for name, pause in (("r1", 0.0005), ("r3", 0.0005), ("r6", 0.002)):
            _ = wrapped[name]
            time.sleep(pause)

        # Occasional noise: a random unrelated region, ~half the rounds.
        if np.random.random() > 0.5:
            _ = wrapped[f"r{np.random.choice([4, 7, 8])}"]

    graph = GraphBuilder(causal_window_ns=3_000_000)
    graph.build(Membrane.get_log())
    graph.print_analysis()

    chains = graph.get_causal_chains(min_weight=5.0)
    print(f"\n Chains found (weight >= 5): {len(chains)}")
    for chain in chains:
        labels = [path.split(".")[-1] for path, _ in chain]
        print(f" {' β '.join(labels)}")

    assert len(chains) >= 1, "Should discover at least one causal chain"
    print(" PASS")
|
|
|
|
def test_cluster_discovery():
    """Test 4: Co-access clusters — groups of regions always used together.

    These become hyperedges: promote/demote the whole group as a unit.
    """
    print("\n--- Test 4: Cluster (Proto-Hyperedge) Discovery ---")
    Membrane.clear()

    state = {f"item_{idx}": np.random.randn(16).astype(np.float32)
             for idx in range(15)}
    wrapped = Membrane.wrap(state, "cluster")

    for _ in range(60):
        # Group A: items 0-2 always read back to back.
        for name in ("item_0", "item_1", "item_2"):
            _ = wrapped[name]
        time.sleep(0.008)

        # Group B: items 5-8 always read back to back.
        for name in ("item_5", "item_6", "item_7", "item_8"):
            _ = wrapped[name]
        time.sleep(0.008)

        # Group C: items 10-11, present on ~70% of rounds.
        if np.random.random() > 0.3:
            _ = wrapped["item_10"]
            _ = wrapped["item_11"]
            time.sleep(0.008)

        # Singleton noise: one random unaffiliated item per round.
        lone = np.random.choice([3, 4, 9, 12, 13, 14])
        _ = wrapped[f"item_{lone}"]
        time.sleep(0.008)

    graph = GraphBuilder(causal_window_ns=3_000_000, cluster_threshold=0.6)
    graph.build(Membrane.get_log())
    graph.print_analysis()

    print(f"\n Clusters found: {len(graph.clusters)}")
    assert len(graph.clusters) >= 2, "Should find multiple clusters"

    # Group A must appear as (a subset of) one discovered cluster.
    target = {"item_0", "item_1", "item_2"}
    cluster_a_found = any(
        target.issubset({m.split(".")[-1] for m in cluster.members})
        for cluster in graph.clusters
    )

    assert cluster_a_found, "Should find cluster A (items 0,1,2)"
    print(" Cluster A (items 0,1,2) found correctly")
    print(" PASS")
|
|
|
|
def test_real_world_simulation():
    """Test 5: Realistic workload — simulates an AI inference server.

    Pattern:
      - Model weights accessed sequentially (forward pass)
      - KV cache accessed selectively (attention)
      - Config accessed once at start
      - Buffer reused across requests
    """

    def total_array_bytes(obj):
        # Recursively sum ndarray sizes through arbitrarily nested dicts.
        # Replaces the original hand-rolled two-level loop; yields the
        # identical total for this state layout.
        if isinstance(obj, np.ndarray):
            return obj.nbytes
        if isinstance(obj, dict):
            return sum(total_array_bytes(v) for v in obj.values())
        return 0

    print("\n--- Test 5: Realistic AI Inference Simulation ---")
    Membrane.clear()

    state = {
        "config": {"max_tokens": 512, "temperature": 0.7, "top_p": 0.9},
        "buffer": {"input_ids": np.zeros(512, dtype=np.int32),
                   "logits": np.zeros(32000, dtype=np.float32)},
    }

    # Model weights: 6 transformer-ish layers.
    for i in range(6):
        state[f"layer_{i}"] = {
            "q": np.random.randn(64, 64).astype(np.float32),
            "k": np.random.randn(64, 64).astype(np.float32),
            "v": np.random.randn(64, 64).astype(np.float32),
            "ffn_up": np.random.randn(64, 256).astype(np.float32),
            "ffn_down": np.random.randn(256, 64).astype(np.float32),
        }

    # KV cache, one entry per layer (kept as a second loop so the dict
    # insertion order matches the original layout exactly).
    for i in range(6):
        state[f"kv_cache_{i}"] = {
            "keys": np.zeros((512, 64), dtype=np.float32),
            "values": np.zeros((512, 64), dtype=np.float32),
        }

    wrapped = Membrane.wrap(state, "server")

    # Three requests, ten tokens each.
    for _req in range(3):
        # Config is read once per request.
        _ = wrapped["config"]["max_tokens"]
        _ = wrapped["config"]["temperature"]

        # Input buffer prepared.
        _ = wrapped["buffer"]["input_ids"]

        # Token generation: attention weights, KV cache, then MLP per layer.
        for _token in range(10):
            for layer_idx in range(6):
                layer = wrapped[f"layer_{layer_idx}"]
                _ = layer["q"]
                _ = layer["k"]
                _ = layer["v"]

                cache = wrapped[f"kv_cache_{layer_idx}"]
                _ = cache["keys"]
                _ = cache["values"]

                _ = layer["ffn_up"]
                _ = layer["ffn_down"]
                time.sleep(0.0001)

        # Logits written out at the end of the request.
        _ = wrapped["buffer"]["logits"]

    total_bytes = total_array_bytes(state)
    total_mb = total_bytes / 1024 / 1024

    # Was an f-string with no placeholders (lint F541) — plain literal now.
    print(" Simulated: 3 requests Γ 10 tokens Γ 6 layers")
    print(f" Total state: {total_mb:.1f} MB")

    graph = GraphBuilder(causal_window_ns=2_000_000)
    graph.build(Membrane.get_log())
    graph.print_analysis()

    # Persist the learned graph next to this test file.
    graph.save(os.path.join(os.path.dirname(__file__), "inference_graph.json"))

    config_node = graph.nodes.get("server.config.max_tokens")
    layer0_q = graph.nodes.get("server.layer_0.q")

    if config_node and layer0_q:
        print(f" Config accesses: {config_node.access_count} (read once per request)")
        print(f" Layer 0 Q accesses: {layer0_q.access_count} (every token, every request)")
        ratio = layer0_q.access_count / max(config_node.access_count, 1)
        print(f" Ratio: {ratio:.0f}x β config is compressible, Q is not")

    print(" PASS")
|
|
|
|
| if __name__ == "__main__": |
| print("=" * 60) |
| print(" CONDENSATE β Layer 1 Graph Builder Tests") |
| print("=" * 60) |
|
|
| test_sequential_model() |
| test_hot_cold_pattern() |
| test_causal_chains() |
| test_cluster_discovery() |
| test_real_world_simulation() |
|
|
| print("\n" + "=" * 60) |
| print(" ALL TESTS PASSED") |
| print("=" * 60) |
|
|