juddddd committed on
Commit
196fec7
·
verified ·
1 Parent(s): b2058e2

Upload ablation/routing_ablation_experiment.py with huggingface_hub

Browse files
ablation/routing_ablation_experiment.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Routing Ablation Experiment: Readout Neutralization + Structured Interference
3
+
4
+ PURPOSE: Determine if routing genuinely improves global identity retention,
5
+ or if the prior result was due to readout alignment and noise-only interference.
6
+
7
+ PART A - READOUT NEUTRALIZATION:
8
+ Test with three readout modes to decouple write/read channels:
9
+ - uniform: mean(h_i) across all oscillators
10
+ - slow_only: mean(h_i) for tau_i >= threshold
11
+ - tau_weighted: sum(tau_i * h_i) / sum(tau_i) [original]
12
+
13
+ PART B - STRUCTURED INTERFERENCE:
14
+ Replace Gaussian noise with low-rank correlated interference.
15
+
16
+ DECISION RULE:
17
+ - If C/D dominate B under uniform readout → routing is real
18
+ - If C/D only dominate under tau_weighted → readout alignment artifact
19
+
20
+ Authors: Routing Ablation
21
+ Date: 2026-01-22
22
+ """
23
+
24
+ import numpy as np
25
+ import json
26
+ import hashlib
27
+ from typing import Dict, List, Tuple, Optional, Any
28
+ from dataclasses import dataclass
29
+ from pathlib import Path
30
+ from datetime import datetime
31
+ import sys
32
+
33
+ sys.path.insert(0, str(Path(__file__).parent.parent))
34
+
35
+ from training.fdra_oscillators import FDRAOscillatorBank, OscillatorConfig
36
+
37
+
38
def compute_checkpoint_hash(lambdas: np.ndarray) -> str:
    """Return a short fingerprint of ``lambdas``.

    The fingerprint is the first 16 hexadecimal characters of the SHA-256
    digest of the array's raw bytes (``lambdas.tobytes()``).
    """
    hasher = hashlib.sha256()
    hasher.update(lambdas.tobytes())
    full_digest = hasher.hexdigest()
    return full_digest[:16]
40
+
41
+
42
@dataclass
class ParameterSnapshot:
    """Frozen-in-time view of a set of decay parameters for one condition."""

    lambdas: np.ndarray  # copy of the decay parameters the snapshot was built from
    checkpoint_hash: str  # short fingerprint of ``lambdas``
    half_life_stats: Dict[str, Any]  # summary statistics over the per-oscillator taus
    per_oscillator_taus: List[float]  # tau for every entry of ``lambdas``
    condition_name: str  # human-readable label of the condition

    @classmethod
    def from_lambdas(cls, lambdas: np.ndarray, condition_name: str) -> 'ParameterSnapshot':
        """Build a snapshot from ``lambdas``.

        Each tau is computed as ``log(0.5) / log(lambda)`` after clipping the
        lambdas into ``[1e-10, 1 - 1e-10]`` so the logarithm stays finite and
        non-zero.
        """
        clipped = np.clip(lambdas, 1e-10, 1 - 1e-10)
        half_lives = np.log(0.5) / np.log(clipped)
        is_long_range = half_lives >= 2048

        summary = {
            "tau_min": float(np.min(half_lives)),
            "tau_max": float(np.max(half_lives)),
            "tau_mean": float(np.mean(half_lives)),
            "frac_tau_ge_2048": float(np.mean(is_long_range)),
            "n_long_range": int(np.sum(is_long_range)),
        }

        return cls(
            lambdas=lambdas.copy(),
            checkpoint_hash=compute_checkpoint_hash(lambdas),
            half_life_stats=summary,
            per_oscillator_taus=half_lives.tolist(),
            condition_name=condition_name,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-friendly part of the snapshot (no raw arrays)."""
        exported_fields = ("condition_name", "checkpoint_hash", "half_life_stats")
        return {field_name: getattr(self, field_name) for field_name in exported_fields}
77
+
78
+
79
def sample_tau_collapsed(n: int, seed: int = 42) -> np.ndarray:
    """Sample ``n`` lambdas whose taus are uniform on [2, 10].

    Each tau is converted to a lambda via ``0.5 ** (1 / tau)``.
    """
    half_lives = np.random.default_rng(seed).uniform(2, 10, n)
    exponents = 1.0 / half_lives
    return np.power(0.5, exponents)
83
+
84
+
85
def sample_tau_anchored_tail(n: int, L: int = 4096, p_tail: float = 0.25, seed: int = 42) -> np.ndarray:
    """Sample ``n`` lambdas with a slow "tail" anchored around ``L``.

    ``int(n * p_tail)`` taus are drawn log-uniformly from ``[0.75 * L, 1.25 * L]``
    and placed first; the remaining taus are drawn log-uniformly from
    ``[1, 512]``. Taus are converted to lambdas via ``0.5 ** (1 / tau)``.
    """
    rng = np.random.default_rng(seed)

    tail_count = int(n * p_tail)
    body_count = n - tail_count

    # Draw order matters for reproducibility: tail first, then the rest.
    tail_lo, tail_hi = 0.75 * L, 1.25 * L
    tail_taus = np.exp(rng.uniform(np.log(tail_lo), np.log(tail_hi), tail_count))
    body_taus = np.exp(rng.uniform(np.log(1), np.log(512), body_count))

    all_taus = np.concatenate([tail_taus, body_taus])
    return np.power(0.5, 1.0 / all_taus)
100
+
101
+
102
class IdentityEncoderWithReadoutModes:
    """
    Identity encoder with configurable routing AND readout strategies.

    Routing modes (write):
    - "uniform": Equal weight to all oscillators
    - "tau_weighted": Weight ∝ τ
    - "tau_gated": Only write to oscillators with τ > L / 4
      (falls back to uniform weights when no oscillator qualifies)

    Readout modes (read):
    - "uniform": mean(h_i) - equal weight
    - "slow_only": mean(h_i for τ_i >= tau_threshold)
    - "tau_weighted": sum(τ_i * h_i) / sum(τ_i)
    """

    def __init__(self, dim: int = 16, routing_mode: str = "uniform", readout_mode: str = "tau_weighted"):
        """Store the modes and build the three fixed identity patterns."""
        self.dim = dim
        self.routing_mode = routing_mode
        self.readout_mode = readout_mode
        # Cutoff used by the "slow_only" readout. Note that "tau_gated"
        # routing does NOT use this value; it gates on L / 4 instead.
        self.tau_threshold = 2048

        self.patterns = {
            "decision_rule": self._make_pattern(0),
            "normative_constraint": self._make_pattern(1),
            "self_continuity": self._make_pattern(2),
        }

    def _make_pattern(self, idx: int) -> np.ndarray:
        """Return a length-``dim`` vector with ``dim // 3`` equal entries.

        The non-zero run starts at ``(idx * dim // 3) % dim`` and wraps around;
        each entry is ``1 / sqrt(dim // 3)`` so the pattern has unit norm.
        """
        width = self.dim // 3
        pattern = np.zeros(self.dim)
        start = (idx * self.dim // 3) % self.dim
        for offset in range(width):
            pattern[(start + offset) % self.dim] = 1.0 / np.sqrt(width)
        return pattern

    def _compute_routing_weights(self, taus: np.ndarray, L: int = 4096) -> np.ndarray:
        """Return per-oscillator write weights (summing to 1) for the routing mode.

        Raises:
            ValueError: if ``self.routing_mode`` is not a known mode.
        """
        count = len(taus)
        if self.routing_mode == "uniform":
            return np.ones(count) / count
        if self.routing_mode == "tau_weighted":
            return taus / np.sum(taus)
        if self.routing_mode == "tau_gated":
            threshold = L / 4
            mask = (taus > threshold).astype(float)
            if np.sum(mask) == 0:
                # No oscillator is slow enough: fall back to uniform routing.
                return np.ones(count) / count
            return mask / np.sum(mask)
        raise ValueError(f"Unknown routing mode: {self.routing_mode}")

    def encode(self, bank: "FDRAOscillatorBank", strength: float = 1.0):
        """Write every identity pattern into ``bank``.

        For each pattern, the input ``outer(weights, pattern) * strength * n``
        (``n`` = number of oscillators) is passed to ``bank.forward`` ten times.
        """
        taus = bank.get_half_lives()
        weights = self._compute_routing_weights(taus, bank.L)

        for pattern in self.patterns.values():
            u = np.outer(weights, pattern) * strength * len(taus)
            for _ in range(10):
                bank.forward(u)

    def _read_state(self, bank: "FDRAOscillatorBank", taus: np.ndarray) -> Optional[np.ndarray]:
        """Collapse ``bank.h`` over oscillators according to the readout mode.

        Returns ``None`` when "slow_only" finds no oscillator at or above
        ``tau_threshold``.
        """
        if self.readout_mode == "uniform":
            return np.mean(bank.h, axis=0)
        if self.readout_mode == "slow_only":
            mask = taus >= self.tau_threshold
            if np.sum(mask) == 0:
                return None
            return np.mean(bank.h[mask], axis=0)
        if self.readout_mode == "tau_weighted":
            weights = taus / np.sum(taus)
            return np.sum(bank.h * weights[:, np.newaxis], axis=0)
        raise ValueError(f"Unknown readout mode: {self.readout_mode}")

    def measure_identity(self, bank: "FDRAOscillatorBank") -> Dict[str, float]:
        """Measure identity with CONFIGURABLE readout mode.

        Returns, for every pattern name, the cosine-style alignment
        ``dot(readout, pattern) / ||readout||`` clamped below at 0.0. All
        values are ``0.0`` when the readout is unavailable or has (near-)zero
        norm.

        Raises:
            ValueError: if ``self.readout_mode`` is not a known mode.
        """
        taus = bank.get_half_lives()
        slow = self._read_state(bank, taus)
        if slow is None:
            return {name: 0.0 for name in self.patterns}

        slow_norm = np.linalg.norm(slow)
        if slow_norm < 1e-10:
            return {name: 0.0 for name in self.patterns}

        # Clamp with a float literal so every value is a float; max(0, x)
        # would hand back the int 0 for non-positive alignments.
        return {
            name: max(0.0, float(np.dot(slow, pattern) / slow_norm))
            for name, pattern in self.patterns.items()
        }
194
+
195
+
196
class StructuredInterference:
    """
    Generate structured (non-Gaussian) interference.

    Options:
    - "gaussian": Original i.i.d. Gaussian noise
    - "low_rank": Low-rank correlated interference (A @ v(t))
    - "repeating": Repeating pattern interference

    ``self.rng`` is the per-step noise stream. Callers may replace it with
    ``np.random.default_rng(seed)`` to restart that stream; the fixed low-rank
    structure is drawn from an independent stream so such a restart never
    replays the numbers that built the projection matrix.
    """

    def __init__(self, n: int, d: int, mode: str = "gaussian", seed: int = 42):
        """Set up an ``(n, d)`` interference generator.

        Args:
            n: number of rows of each generated array.
            d: number of columns of each generated array.
            mode: "gaussian", "low_rank" or "repeating"; an unknown mode is
                only rejected when ``generate`` is called.
            seed: seed for the noise stream (and, derived from it, for the
                low-rank structure).
        """
        self.n = n
        self.d = d
        self.mode = mode
        self.rng = np.random.default_rng(seed)

        if mode == "low_rank":
            # Draw A and the initial AR(1) state from a stream that is
            # independent of self.rng. Previously both came from self.rng, so
            # re-creating self.rng from the same seed made the AR innovations
            # identical to the rows of A (up to scale), biasing the noise.
            if seed is None:
                structure_rng = np.random.default_rng()
            else:
                structure_rng = np.random.default_rng([seed, 1])
            # Create low-rank projection matrix (rank 4)
            self.rank = 4
            self.A = structure_rng.standard_normal((n * d, self.rank)) / np.sqrt(self.rank)
            self.v_state = structure_rng.standard_normal(self.rank)  # AR(1) state
            self.ar_coef = 0.9  # Autocorrelation

        elif mode == "repeating":
            # Create repeating pattern
            self.period = 32
            self.patterns = [self.rng.standard_normal((n, d)) * 0.5 for _ in range(self.period)]
            self.t = 0

    def generate(self) -> np.ndarray:
        """Return the next ``(n, d)`` interference array.

        Raises:
            ValueError: if ``self.mode`` is not a known mode.
        """
        if self.mode == "gaussian":
            return self.rng.standard_normal((self.n, self.d)) * 0.5

        if self.mode == "low_rank":
            # AR(1) process for v(t); the innovation scale keeps unit variance.
            innovation = self.rng.standard_normal(self.rank)
            self.v_state = self.ar_coef * self.v_state + np.sqrt(1 - self.ar_coef**2) * innovation
            # Project to full space
            flat = (self.A @ self.v_state).reshape(self.n, self.d)
            return flat * 0.5

        if self.mode == "repeating":
            pattern = self.patterns[self.t % self.period]
            self.t += 1
            # Hand out a copy so a consumer that modifies its input in place
            # cannot corrupt the stored pattern for later periods.
            return pattern.copy()

        raise ValueError(f"Unknown interference mode: {self.mode}")
243
+
244
+
245
class RoutingAblationExperiment:
    """
    Ablation experiment for routing with:
    - Multiple readout modes
    - Multiple interference types

    Results and a markdown report are written under ``outputs/routing_ablation``.
    """

    # Seeds used when a caller does not supply any.
    DEFAULT_SEEDS = (42, 137, 256)
    _CONDITIONS = ("A", "B", "C", "D")

    def __init__(
        self,
        num_oscillators: int = 32,
        state_dim: int = 16,
        sequence_length: int = 4096
    ):
        """Store the sizes, build the oscillator config and create the output directory."""
        self.n = num_oscillators
        self.d = state_dim
        self.L = sequence_length

        self.osc_config = OscillatorConfig(
            num_oscillators=num_oscillators,
            state_dim=state_dim,
            sequence_length=sequence_length
        )

        # Numbers of interference steps swept in every condition.
        self.k_values = [0, 64, 128, 256, 512, 1024, 2048, 4096]
        self.output_dir = Path("outputs/routing_ablation")
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def run_identity_trial(
        self,
        snapshot: "ParameterSnapshot",
        encoder: "IdentityEncoderWithReadoutModes",
        interference: "StructuredInterference",
        k: int,
        seed: int
    ) -> Dict[str, Any]:
        """Encode identity, apply ``k`` interference steps and measure retention.

        Returns a dict with ``k``, ``seed``, the mean pre/post identity scores,
        ``retention`` (post / pre), ``identity_preserved`` (retention >= 0.5)
        and ``encoding_failed`` (pre score below 0.2, in which case no
        interference is applied).
        """
        bank = FDRAOscillatorBank(self.osc_config)
        bank.lambdas = snapshot.lambdas.copy()
        bank.reset()

        # Encode
        encoder.encode(bank, strength=1.0)

        # Measure pre
        pre_identity = encoder.measure_identity(bank)
        pre_score = np.mean(list(pre_identity.values()))

        if pre_score < 0.2:
            return {
                "k": k, "seed": seed,
                "pre_score": float(pre_score),
                "post_score": 0.0,
                "retention": 0.0,
                "identity_preserved": False,
                "encoding_failed": True
            }

        # Interference
        # NOTE(review): this restarts the noise stream from `seed`. If the
        # interference object built any fixed structure from a generator
        # seeded with the same value, the noise replays those draws - confirm
        # the generator keeps structure and noise streams independent.
        interference.rng = np.random.default_rng(seed)  # Reset for reproducibility
        for _ in range(k):
            bank.forward(interference.generate())

        # Measure post
        post_identity = encoder.measure_identity(bank)
        post_score = np.mean(list(post_identity.values()))
        retention = post_score / pre_score if pre_score > 0 else 0.0

        return {
            "k": k, "seed": seed,
            "pre_score": float(pre_score),
            "post_score": float(post_score),
            "retention": float(retention),
            # Plain bool rather than np.bool_ so the record is JSON-friendly.
            "identity_preserved": bool(retention >= 0.5),
            "encoding_failed": False
        }

    def run_sweep(
        self,
        snapshot: "ParameterSnapshot",
        encoder: "IdentityEncoderWithReadoutModes",
        interference_mode: str,
        condition_name: str,
        seeds: List[int],
        n_trials: int = 8
    ) -> Dict[str, Any]:
        """Run ``len(seeds) * n_trials`` trials for every ``k`` in ``self.k_values``.

        Each trial uses seed ``seed * 1000 + t``. Returns the condition
        metadata plus an ``analysis`` dict holding the preservation curve and
        the 80% / 50% basin widths.
        """
        print(f"\n {condition_name}")
        print(f" Routing: {encoder.routing_mode}, Readout: {encoder.readout_mode}, Interference: {interference_mode}")

        preservation_curve = []

        for k in self.k_values:
            k_trials = []
            for seed in seeds:
                for t in range(n_trials):
                    trial_seed = seed * 1000 + t
                    interference = StructuredInterference(self.n, self.d, interference_mode, trial_seed)
                    k_trials.append(
                        self.run_identity_trial(snapshot, encoder, interference, k, trial_seed)
                    )

            preserved_rate = np.mean([1 if trial["identity_preserved"] else 0 for trial in k_trials])
            mean_retention = np.mean([trial["retention"] for trial in k_trials])

            preservation_curve.append({
                "k": k,
                "preserved_rate": float(preserved_rate),
                "mean_retention": float(mean_retention)
            })

        # Basin widths
        # NOTE(review): this is the LARGEST k meeting the threshold, not the
        # first k at which preservation fails; the two differ if the curve is
        # non-monotonic - confirm which definition is intended.
        bw80 = max([p["k"] for p in preservation_curve if p["preserved_rate"] >= 0.8], default=0)
        bw50 = max([p["k"] for p in preservation_curve if p["preserved_rate"] >= 0.5], default=0)

        # Print summary
        print(f" Basin width (80%): {bw80}, (50%): {bw50}")

        return {
            "condition_name": condition_name,
            "routing_mode": encoder.routing_mode,
            "readout_mode": encoder.readout_mode,
            "interference_mode": interference_mode,
            "analysis": {
                "preservation_curve": preservation_curve,
                "basin_width_80": bw80,
                "basin_width_50": bw50
            }
        }

    def _print_basin_table(self, results: Dict[str, Any], readout_modes: List[str], threshold_pct: int) -> None:
        """Print the readout-mode x condition basin-width table for one threshold."""
        print("\n" + "=" * 70)
        print(f"BASIN WIDTH TABLE ({threshold_pct}% threshold)")
        print("=" * 70)
        print(f"\n{'Readout':<15} | {'A':>6} | {'B':>6} | {'C':>6} | {'D':>6}")
        print("-" * 50)

        for readout_mode in readout_modes:
            row = f"{readout_mode:<15} |"
            for cond in self._CONDITIONS:
                bw = results[readout_mode][cond]["analysis"][f"basin_width_{threshold_pct}"]
                row += f" {bw:>5} |"
            print(row)

    def run_part_a(self, seeds: Optional[List[int]] = None, n_trials: int = 8) -> Dict[str, Any]:
        """Part A: Readout Neutralization - 4 conditions × 3 readout modes

        ``seeds`` defaults to ``[42, 137, 256]``. Returns the per-readout,
        per-condition sweep results together with a verdict and explanation.
        """
        if seeds is None:
            # Fresh list per call instead of a shared mutable default.
            seeds = list(self.DEFAULT_SEEDS)

        print("=" * 70)
        print("PART A: READOUT NEUTRALIZATION")
        print("=" * 70)
        print("\nQuestion: Does routing advantage hold under different readout modes?")
        print()

        # Create snapshots
        collapsed = ParameterSnapshot.from_lambdas(sample_tau_collapsed(self.n), "collapsed")
        anchored = ParameterSnapshot.from_lambdas(sample_tau_anchored_tail(self.n, self.L), "anchored_tail")

        # Define conditions
        routing_conditions = [
            ("A", collapsed, "uniform"),
            ("B", anchored, "uniform"),
            ("C", anchored, "tau_weighted"),
            ("D", anchored, "tau_gated"),
        ]

        readout_modes = ["uniform", "slow_only", "tau_weighted"]

        results = {}

        for readout_mode in readout_modes:
            print(f"\n--- Readout mode: {readout_mode} ---")
            results[readout_mode] = {}

            for cond_name, snapshot, routing_mode in routing_conditions:
                encoder = IdentityEncoderWithReadoutModes(self.d, routing_mode, readout_mode)
                results[readout_mode][cond_name] = self.run_sweep(
                    snapshot, encoder, "gaussian",
                    f"{cond_name}) {snapshot.condition_name} + {routing_mode}",
                    seeds, n_trials
                )

        # Generate 3×4 tables
        self._print_basin_table(results, readout_modes, 80)
        self._print_basin_table(results, readout_modes, 50)

        # Decision
        print("\n" + "=" * 70)
        print("DECISION (Part A)")
        print("=" * 70)

        # Check if C/D dominate B under uniform readout
        bw_B_uniform = results["uniform"]["B"]["analysis"]["basin_width_50"]
        bw_C_uniform = results["uniform"]["C"]["analysis"]["basin_width_50"]
        bw_D_uniform = results["uniform"]["D"]["analysis"]["basin_width_50"]

        if bw_C_uniform > bw_B_uniform * 1.5 or bw_D_uniform > bw_B_uniform * 1.5:
            verdict = "ROUTING_GENUINE"
            explanation = (
                f"C/D dominate B even under UNIFORM readout:\n"
                f" B (uniform readout): {bw_B_uniform}\n"
                f" C (uniform readout): {bw_C_uniform}\n"
                f" D (uniform readout): {bw_D_uniform}\n"
                f"→ Routing genuinely improves global retention, not just aligned readout."
            )
        else:
            verdict = "READOUT_ARTIFACT"
            explanation = (
                f"C/D advantage disappears under UNIFORM readout:\n"
                f" B (uniform readout): {bw_B_uniform}\n"
                f" C (uniform readout): {bw_C_uniform}\n"
                f" D (uniform readout): {bw_D_uniform}\n"
                f"→ Prior result was partially readout alignment artifact.\n"
                f"→ Identity concentrates in slow modes but leaks elsewhere."
            )

        print(f"\n Verdict: {verdict}")
        print(f"\n {explanation}")

        return {
            "results": results,
            "verdict": verdict,
            "explanation": explanation
        }

    def run_part_b(self, seeds: Optional[List[int]] = None, n_trials: int = 8) -> Dict[str, Any]:
        """Part B: Structured Interference - B vs C with different interference types

        ``seeds`` defaults to ``[42, 137, 256]``. Both conditions use the
        uniform readout. Returns the sweep results with a verdict and
        explanation.
        """
        if seeds is None:
            # Fresh list per call instead of a shared mutable default.
            seeds = list(self.DEFAULT_SEEDS)

        print("\n" + "=" * 70)
        print("PART B: STRUCTURED INTERFERENCE")
        print("=" * 70)
        print("\nQuestion: Does routing hold against structured (non-Gaussian) interference?")
        print()

        anchored = ParameterSnapshot.from_lambdas(sample_tau_anchored_tail(self.n, self.L), "anchored_tail")

        interference_modes = ["gaussian", "low_rank", "repeating"]

        results = {}

        for interference_mode in interference_modes:
            print(f"\n--- Interference mode: {interference_mode} ---")
            results[interference_mode] = {}

            # Only B and C, uniform readout
            for cond_name, routing_mode in [("B", "uniform"), ("C", "tau_weighted")]:
                encoder = IdentityEncoderWithReadoutModes(self.d, routing_mode, "uniform")
                results[interference_mode][cond_name] = self.run_sweep(
                    anchored, encoder, interference_mode,
                    f"{cond_name}) anchored + {routing_mode}",
                    seeds, n_trials
                )

        # Table
        print("\n" + "=" * 70)
        print("STRUCTURED INTERFERENCE TABLE (uniform readout, 50% threshold)")
        print("=" * 70)
        print(f"\n{'Interference':<15} | {'B':>6} | {'C':>6} | {'Delta':>8}")
        print("-" * 45)

        # Count, while printing, how many interference types keep C ahead of B.
        holds_count = 0
        for interference_mode in interference_modes:
            bw_B = results[interference_mode]["B"]["analysis"]["basin_width_50"]
            bw_C = results[interference_mode]["C"]["analysis"]["basin_width_50"]
            delta = bw_C - bw_B
            print(f"{interference_mode:<15} | {bw_B:>5} | {bw_C:>5} | {delta:>+7}")
            if bw_C > bw_B:
                holds_count += 1

        # Decision
        print("\n" + "=" * 70)
        print("DECISION (Part B)")
        print("=" * 70)

        if holds_count == len(interference_modes):
            verdict = "ROUTING_ROBUST"
            explanation = "Routing advantage holds across ALL interference types."
        elif holds_count > 0:
            verdict = "ROUTING_PARTIAL"
            explanation = f"Routing advantage holds for {holds_count}/{len(interference_modes)} interference types."
        else:
            verdict = "ROUTING_FRAGILE"
            explanation = "Routing advantage disappears under structured interference."

        print(f"\n Verdict: {verdict}")
        print(f"\n {explanation}")

        return {
            "results": results,
            "verdict": verdict,
            "explanation": explanation
        }

    def run_full_ablation(self) -> Dict[str, Any]:
        """Run complete ablation experiment.

        Runs Part A and Part B, derives the overall verdict, writes a JSON
        results file and a markdown report into ``self.output_dir`` and
        returns the full results dict.
        """
        print("=" * 70)
        print("ROUTING ABLATION EXPERIMENT")
        print("=" * 70)
        print("\nThis experiment tests if the routing breakthrough is genuine or artifact.")
        print("=" * 70)

        seeds = list(self.DEFAULT_SEEDS)
        n_trials = 8

        part_a = self.run_part_a(seeds, n_trials)
        part_b = self.run_part_b(seeds, n_trials)

        # Overall verdict
        print("\n" + "=" * 70)
        print("OVERALL VERDICT")
        print("=" * 70)

        if part_a["verdict"] == "ROUTING_GENUINE" and part_b["verdict"] == "ROUTING_ROBUST":
            overall = "ROUTING_CONFIRMED"
            overall_explanation = (
                "Routing is GENUINE and ROBUST:\n"
                " - Holds under uniform readout (not alignment artifact)\n"
                " - Holds under structured interference (not noise-specific)\n"
                "→ Ready to integrate into FDRA training."
            )
        elif part_a["verdict"] == "ROUTING_GENUINE":
            overall = "ROUTING_CONFIRMED_PARTIAL"
            overall_explanation = (
                "Routing is GENUINE but may be noise-specific:\n"
                " - Holds under uniform readout ✓\n"
                " - May degrade under structured interference\n"
                "→ Worth integrating but monitor under real conditions."
            )
        elif part_b["verdict"] in ["ROUTING_ROBUST", "ROUTING_PARTIAL"]:
            overall = "ROUTING_LIMITED"
            overall_explanation = (
                "Routing helps but has readout alignment component:\n"
                " - Advantage reduced under uniform readout\n"
                " - Still helps under some interference types\n"
                "→ Need auxiliary loss or architectural enforcement."
            )
        else:
            overall = "ROUTING_ARTIFACT"
            overall_explanation = (
                "Prior routing result was largely artifact:\n"
                " - Disappears under uniform readout\n"
                " - Fragile to structured interference\n"
                "→ Need fundamentally different approach."
            )

        print(f"\n {overall}")
        print(f"\n {overall_explanation}")
        print("=" * 70)

        # One clock reading so the recorded timestamp and file names agree.
        now = datetime.now()

        # Save results
        full_results = {
            "timestamp": now.isoformat(),
            "experiment": "routing_ablation",
            "part_a": part_a,
            "part_b": part_b,
            "overall": {
                "verdict": overall,
                "explanation": overall_explanation
            }
        }

        ts = now.strftime("%Y%m%d_%H%M%S")
        # Explicit UTF-8: the report contains non-ASCII characters (τ, →, ✓)
        # that a platform-default encoding may be unable to write.
        with open(self.output_dir / f"routing_ablation_{ts}.json", "w", encoding="utf-8") as f:
            json.dump(full_results, f, indent=2, default=str)

        # Generate report
        report = self._generate_report(full_results)
        with open(self.output_dir / f"ABLATION_REPORT_{ts}.md", "w", encoding="utf-8") as f:
            f.write(report)

        print(f"\nResults saved to: {self.output_dir}/")

        return full_results

    def _generate_report(self, results: Dict[str, Any]) -> str:
        """Render the markdown report for a full results dict.

        Reads ``timestamp``, ``part_a``, ``part_b`` and ``overall`` from
        ``results``; both tables use the 50% basin width.
        """
        report = f"""# Routing Ablation Experiment

**Date:** {results['timestamp']}

## Purpose

Test if the routing breakthrough is genuine or an artifact of:
1. Readout alignment (τ-weighted write + τ-weighted read)
2. Noise-only interference (Gaussian vs structured)

## Part A: Readout Neutralization

### Question
Does C/D dominate B under UNIFORM readout (not just τ-weighted)?

### Results

**Basin Width Table (50% threshold)**

| Readout | A | B | C | D |
|---------|---|---|---|---|
"""
        part_a_results = results["part_a"]["results"]
        for readout_mode in ["uniform", "slow_only", "tau_weighted"]:
            cells = "".join(
                f" {part_a_results[readout_mode][cond]['analysis']['basin_width_50']} |"
                for cond in self._CONDITIONS
            )
            report += f"| {readout_mode} |" + cells + "\n"

        report += f"""
### Verdict: {results['part_a']['verdict']}

{results['part_a']['explanation']}

## Part B: Structured Interference

### Question
Does routing hold against low-rank correlated and repeating interference?

### Results

**Basin Width Table (uniform readout, 50% threshold)**

| Interference | B | C | Delta |
|--------------|---|---|-------|
"""
        part_b_results = results["part_b"]["results"]
        for interference_mode in ["gaussian", "low_rank", "repeating"]:
            bw_B = part_b_results[interference_mode]["B"]["analysis"]["basin_width_50"]
            bw_C = part_b_results[interference_mode]["C"]["analysis"]["basin_width_50"]
            delta = bw_C - bw_B
            report += f"| {interference_mode} | {bw_B} | {bw_C} | {delta:+d} |\n"

        report += f"""
### Verdict: {results['part_b']['verdict']}

{results['part_b']['explanation']}

## Overall Verdict

**{results['overall']['verdict']}**

{results['overall']['explanation']}

---

*Report generated by routing_ablation_experiment.py*
"""
        return report
704
+
705
+
706
def run_ablation():
    """Run the full ablation with 32 oscillators, state dim 16 and sequence length 4096."""
    settings = {
        "num_oscillators": 32,
        "state_dim": 16,
        "sequence_length": 4096,
    }
    return RoutingAblationExperiment(**settings).run_full_ablation()


if __name__ == "__main__":
    run_ablation()