Spaces:
Runtime error
Runtime error
File size: 8,581 Bytes
efd23fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 | """
Condensate Layer 0: Membrane Tests
Tests the membrane wrapper on increasingly realistic workloads.
Run: python3 test_membrane.py
"""
import numpy as np
import time
import os
import sys
# Add parent dir to path so we can import membrane
sys.path.insert(0, os.path.dirname(__file__))
from membrane import Membrane
def test_basic_dict():
"""Test 1: Basic dict access tracking."""
print("\n--- Test 1: Basic Dict Access ---")
Membrane.clear()
data = Membrane.wrap({
"name": "test",
"values": [1, 2, 3, 4, 5],
"nested": {"a": 10, "b": 20, "c": 30},
}, "basic")
# Read some values
_ = data["name"]
_ = data["name"] # same key twice
_ = data["values"]
_ = data["nested"]["a"] # nested read β should log both levels
_ = data["nested"]["b"]
# Write
data["name"] = "updated"
assert Membrane.entry_count() > 0, "Should have recorded accesses"
Membrane.print_stats()
print(" PASS")
def test_numpy_arrays():
"""Test 2: Dict of numpy arrays β simulates model weight storage."""
print("\n--- Test 2: NumPy Array State (Simulated Model Weights) ---")
Membrane.clear()
# Simulate a small model with layers of weight matrices
state = {}
for layer in range(8):
state[f"layer_{layer}"] = {
"weight": np.random.randn(256, 256).astype(np.float32),
"bias": np.random.randn(256).astype(np.float32),
"attention": {
"q_proj": np.random.randn(256, 256).astype(np.float32),
"k_proj": np.random.randn(256, 256).astype(np.float32),
"v_proj": np.random.randn(256, 256).astype(np.float32),
}
}
wrapped = Membrane.wrap(state, "model")
total_bytes = sum(
state[f"layer_{i}"]["weight"].nbytes +
state[f"layer_{i}"]["bias"].nbytes +
sum(v.nbytes for v in state[f"layer_{i}"]["attention"].values())
for i in range(8)
)
print(f" Model state: {total_bytes / 1024 / 1024:.1f} MB across 8 layers")
# Simulate a forward pass β sequential layer access
print(" Simulating forward pass...")
for layer_idx in range(8):
layer = wrapped[f"layer_{layer_idx}"]
w = layer["weight"]
b = layer["bias"]
attn = layer["attention"]
q = attn["q_proj"]
k = attn["k_proj"]
v = attn["v_proj"]
# Simulate a second forward pass β same pattern
print(" Simulating second forward pass...")
for layer_idx in range(8):
layer = wrapped[f"layer_{layer_idx}"]
w = layer["weight"]
b = layer["bias"]
attn = layer["attention"]
q = attn["q_proj"]
k = attn["k_proj"]
v = attn["v_proj"]
Membrane.print_stats()
print(" PASS")
def test_selective_access():
"""Test 3: Selective access β some layers hot, some cold.
This is the pattern Condensate exploits: not all state is accessed equally.
"""
print("\n--- Test 3: Selective Access (Hot/Cold Pattern) ---")
Membrane.clear()
state = {}
for layer in range(16):
state[f"layer_{layer}"] = {
"weight": np.random.randn(128, 128).astype(np.float32),
"bias": np.random.randn(128).astype(np.float32),
}
wrapped = Membrane.wrap(state, "selective")
# Simulate: layers 3, 7, 11 are "hot" β accessed 10x more
hot_layers = {3, 7, 11}
for iteration in range(20):
for layer_idx in range(16):
if layer_idx in hot_layers:
# Hot path β always accessed
layer = wrapped[f"layer_{layer_idx}"]
_ = layer["weight"]
_ = layer["bias"]
elif iteration % 10 == 0:
# Cold path β accessed once every 10 iterations
layer = wrapped[f"layer_{layer_idx}"]
_ = layer["weight"]
stats = Membrane.stats()
Membrane.print_stats()
# Verify hot layers have more accesses
hot_count = sum(
stats["paths"].get(f"selective.layer_{i}", {}).get("reads", 0)
for i in hot_layers
)
cold_count = sum(
stats["paths"].get(f"selective.layer_{i}", {}).get("reads", 0)
for i in range(16) if i not in hot_layers
)
ratio = hot_count / max(cold_count, 1)
print(f" Hot/cold access ratio: {ratio:.1f}x")
print(f" (This ratio is what Condensate exploits β hot stays in RAM, cold compresses)")
print(" PASS")
def test_temporal_chains():
"""Test 4: Temporal access chains β A always followed by B followed by C.
This is what the SNN will learn as causal chains for prefetch.
"""
print("\n--- Test 4: Temporal Chains (Causal Access Patterns) ---")
Membrane.clear()
state = {f"region_{i}": np.random.randn(64, 64).astype(np.float32) for i in range(10)}
wrapped = Membrane.wrap(state, "temporal")
# Chain 1: 0 β 3 β 7 (always in this order)
# Chain 2: 1 β 4 β 8 (always in this order)
# Region 5: random, no chain
chains = [
[0, 3, 7],
[1, 4, 8],
]
for _ in range(50):
for chain in chains:
for region_id in chain:
_ = wrapped[f"region_{region_id}"]
time.sleep(0.0001) # tiny delay to separate timestamps
# Random access to region 5
if np.random.random() > 0.5:
_ = wrapped["region_5"]
stats = Membrane.stats()
Membrane.print_stats()
# Check co-accesses β chain members should co-access heavily
coaccesses = stats.get("top_coaccesses", [])
if coaccesses:
print(f" Top co-access pairs found: {len(coaccesses)}")
print(f" (These are the causal chains the SNN would learn)")
print(" PASS")
def test_overhead():
"""Test 5: Measure the membrane's overhead.
This tells us if the observation layer is cheap enough.
"""
print("\n--- Test 5: Overhead Measurement ---")
state = {f"key_{i}": np.random.randn(32).astype(np.float32) for i in range(100)}
# Baseline: raw dict access
iterations = 100_000
start = time.monotonic_ns()
for _ in range(iterations):
for key in ["key_0", "key_50", "key_99"]:
_ = state[key]
raw_ns = time.monotonic_ns() - start
# Membrane: wrapped dict access
Membrane.clear()
wrapped = Membrane.wrap(state.copy(), "overhead")
start = time.monotonic_ns()
for _ in range(iterations):
for key in ["key_0", "key_50", "key_99"]:
_ = wrapped[key]
membrane_ns = time.monotonic_ns() - start
raw_per = raw_ns / (iterations * 3)
membrane_per = membrane_ns / (iterations * 3)
overhead = membrane_per - raw_per
print(f" Raw dict access: {raw_per:.0f} ns/access")
print(f" Membrane access: {membrane_per:.0f} ns/access")
print(f" Overhead per access: {overhead:.0f} ns")
print(f" Slowdown factor: {membrane_per / raw_per:.1f}x")
print(f" Total accesses logged: {Membrane.entry_count()}")
# The membrane is for observation only β overhead is acceptable
# if it's under ~1ΞΌs per access. For production, the Rust core
# will bring this to ~5ns.
if overhead < 5000:
print(f" Overhead acceptable for PoC (< 5ΞΌs)")
else:
print(f" Overhead high β expected for Python, Rust core will fix")
print(" PASS")
def test_save_log():
"""Test 6: Save the access log for Layer 1 analysis."""
print("\n--- Test 6: Save Log ---")
Membrane.clear()
state = {f"region_{i}": np.random.randn(64, 64).astype(np.float32) for i in range(5)}
wrapped = Membrane.wrap(state, "saveable")
# Generate some access patterns
for _ in range(10):
_ = wrapped["region_0"]
_ = wrapped["region_2"]
_ = wrapped["region_4"]
log_path = os.path.join(os.path.dirname(__file__), "test_access_log.json")
Membrane.save_log(log_path)
# Verify file exists and is valid JSON
import json
with open(log_path) as f:
data = json.load(f)
assert "entries" in data
assert len(data["entries"]) == 30 # 3 accesses x 10 iterations
# Clean up
os.remove(log_path)
print(" PASS")
if __name__ == "__main__":
print("=" * 60)
print(" CONDENSATE β Layer 0 Membrane Tests")
print("=" * 60)
test_basic_dict()
test_numpy_arrays()
test_selective_access()
test_temporal_chains()
test_overhead()
test_save_log()
print("\n" + "=" * 60)
print(" ALL TESTS PASSED")
print("=" * 60)
|