juddddd committed on
Commit
bbbfe5d
·
verified ·
1 Parent(s): 88dae09

Upload training_validation/fdra_oscillators_with_routing.py with huggingface_hub

Browse files
training_validation/fdra_oscillators_with_routing.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FDRA Oscillator Implementation with Explicit Decay Parameters
3
+
4
+ This implements the core FDRA oscillator dynamics where each oscillator has:
5
+ - A decay parameter λ_i ∈ (0, 1)
6
+ - Half-life τ_i = ln(0.5) / ln(λ_i)
7
+
8
+ The key problem this addresses (from Melanie/Tiago's discovery):
9
+ - During training at GPT-2 scale, all half-lives τ_i collapse to very short values (λ_i well below the ≈1.0 needed for long memory)
10
+ - This means oscillators only attend to ~10 tokens instead of full context length
11
+ - The model works for short-context tasks but fails on long-context reasoning
12
+
13
+ Solution: Half-life regularization to maintain diversity across temporal scales.
14
+
15
+ Authors: FDRA Half-Life Regularization Implementation
16
+ Date: 2026-01-22
17
+ """
18
+
19
+ import numpy as np
20
+ from typing import Dict, List, Tuple, Optional, Any
21
+ from dataclasses import dataclass
22
+ import json
23
+ from pathlib import Path
24
+
25
+
26
@dataclass
class OscillatorConfig:
    """
    Settings for an FDRA oscillator bank.

    Attributes:
        num_oscillators: Number of oscillators in the bank.
        state_dim: Dimension of each oscillator's state vector.
        sequence_length: Max sequence length (L).
        tau_min: Minimum half-life.
        tau_max: Maximum half-life (typically equal to L).
        init_method: Initialization scheme, "log_uniform" or "random".
    """

    num_oscillators: int = 32
    state_dim: int = 16
    sequence_length: int = 4096
    tau_min: float = 1.0
    tau_max: float = 4096.0
    init_method: str = "log_uniform"
37
+
38
+
39
@dataclass
class OscillatorState:
    """
    Snapshot of an oscillator bank.

    Attributes:
        h: Hidden states, shape (num_oscillators, state_dim).
        lambdas: Decay parameters, shape (num_oscillators,).
    """

    h: np.ndarray
    lambdas: np.ndarray

    def copy(self) -> 'OscillatorState':
        """Return a new snapshot holding independent copies of both arrays."""
        h_clone = self.h.copy()
        lambdas_clone = self.lambdas.copy()
        return OscillatorState(h=h_clone, lambdas=lambdas_clone)
50
+
51
+
52
class FDRAOscillatorBank:
    """
    Bank of leaky oscillators, each with its own explicit decay parameter.

    Update rule for oscillator i:
        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    The decay parameter λ_i and the half-life τ_i are related by
        τ_i = ln(0.5) / ln(λ_i)    <=>    λ_i = 0.5^(1/τ_i)
    so τ_i is the number of zero-input steps after which the state of
    oscillator i has shrunk to 50% of its value.

    Half-life regularization aims to keep the τ_i log-uniformly spread over
    [τ_min, τ_max] so the bank covers both short and long time scales.

    Attributes:
        config: The configuration object passed to the constructor.
        n: Number of oscillators.
        d: State dimension per oscillator.
        L: Configured sequence length.
        lambdas: Decay parameters, shape (n,).
        h: Hidden states, shape (n, d).
        history: List initialised empty; no method of this class writes to it.
    """

    def __init__(self, config: 'OscillatorConfig'):
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length

        # Decay parameters first, then an all-zero hidden state.
        self.lambdas = self._init_lambdas()
        self.h = np.zeros((self.n, self.d))

        # Reserved for analysis records.
        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Build the initial decay parameters λ_i.

        With init_method == "log_uniform" the half-lives are spaced evenly in
        log space between tau_min and tau_max and converted with
        λ = 0.5^(1/τ). Any other init_method draws λ uniformly at random from
        [0.5, 0.99) using NumPy's global random state.

        Returns:
            Array of decay parameters, shape (n,).
        """
        cfg = self.config

        if cfg.init_method != "log_uniform":
            # Random initialization (not recommended).
            return np.random.uniform(0.5, 0.99, self.n)

        log_lo = np.log(cfg.tau_min)
        log_hi = np.log(cfg.tau_max)

        # Evenly spaced in log(τ), then mapped back to τ.
        half_lives = np.exp(np.linspace(log_lo, log_hi, self.n))

        # λ = exp(ln(0.5) / τ) = 0.5^(1/τ)
        return np.power(0.5, 1.0 / half_lives)

    def get_half_lives(self) -> np.ndarray:
        """
        Convert the current decay parameters to half-lives.

        τ_i = ln(0.5) / ln(λ_i)

        Returns:
            Array of half-lives, shape (n,).
        """
        # Keep λ strictly inside (0, 1): log(1) = 0 would divide by zero.
        bounded = np.clip(self.lambdas, 1e-10, 1.0 - 1e-10)
        return np.log(0.5) / np.log(bounded)

    def get_log_half_lives(self) -> np.ndarray:
        """Return z_i = log(τ_i) for every oscillator."""
        half_lives = self.get_half_lives()
        return np.log(half_lives)

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        Advance every oscillator by one step.

        h_i(t+1) = λ_i * h_i(t) + u_i(t)

        Args:
            u: Input signal, shape (num_oscillators, state_dim).

        Returns:
            A copy of the updated hidden states,
            shape (num_oscillators, state_dim).
        """
        # Column vector (n, 1) so each λ_i scales its whole state row.
        decay = self.lambdas[:, np.newaxis]
        self.h = decay * self.h + u
        return self.h.copy()

    def reset(self):
        """Zero the hidden states; decay parameters are left untouched."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Summarise the half-life distribution.

        Returns:
            Dictionary of plain floats with min / max / mean / median of τ
            ("tau_*" keys) and mean / std / min / max of log(τ)
            ("log_tau_*" keys).
        """
        half_lives = self.get_half_lives()
        log_half_lives = np.log(half_lives)

        # (key, reducer, data) in the order the keys appear in the result.
        reducers = (
            ("tau_min", np.min, half_lives),
            ("tau_max", np.max, half_lives),
            ("tau_mean", np.mean, half_lives),
            ("tau_median", np.median, half_lives),
            ("log_tau_mean", np.mean, log_half_lives),
            ("log_tau_std", np.std, log_half_lives),
            ("log_tau_min", np.min, log_half_lives),
            ("log_tau_max", np.max, log_half_lives),
        )
        return {key: float(reduce(values)) for key, reduce, values in reducers}

    def get_state(self) -> 'OscillatorState':
        """Return a snapshot holding copies of the hidden states and lambdas."""
        h_snapshot = self.h.copy()
        lambdas_snapshot = self.lambdas.copy()
        return OscillatorState(h=h_snapshot, lambdas=lambdas_snapshot)

    def set_state(self, state: 'OscillatorState'):
        """Overwrite hidden states and lambdas with copies taken from `state`."""
        self.h = state.h.copy()
        self.lambdas = state.lambdas.copy()
191
+
192
+
193
class FDRAWithOscillators:
    """
    FDRA agent that uses an oscillator bank as its memory.

    Data flow implemented by this class:
        - Each action is copied to every oscillator, scaled by a per-oscillator
          routing weight and a base factor of 0.1, and fed to the bank.
        - The slow state is the half-life-weighted average of the oscillator
          hidden states.
        - The fast state is a leaky integrator of the raw action
          (fast <- 0.9 * fast + action).

    Routing modes:
        - "uniform": Equal weight to all oscillators (baseline)
        - "tau_weighted": Weight ∝ τ (soft routing to slow modes)
        - "tau_gated": Only write to τ > threshold oscillators

    Attributes:
        config: Oscillator configuration in use.
        oscillators: The underlying oscillator bank.
        wlc_threshold: Stored constructor argument; not read by any method
            of this class.
        routing_mode: One of the routing modes listed above.
        routing_min, routing_max: Clamp range for "tau_weighted" weights.
        gating_threshold: Fraction of sequence_length used as the half-life
            cut-off in "tau_gated" mode.
        fast: Fast state vector, shape (state_dim,).
        energy: Running total of 0.1 * ||action||, kept as a Python float.
        history: Per-step diagnostics appended by `step`.
    """

    def __init__(
        self,
        osc_config: Optional['OscillatorConfig'] = None,
        wlc_threshold: float = 1.0,
        routing_mode: str = "uniform"  # "uniform", "tau_weighted", or "tau_gated"
    ):
        """
        Args:
            osc_config: Oscillator configuration; a default OscillatorConfig
                is created when omitted.
            wlc_threshold: Stored on the instance as-is.
            routing_mode: "uniform", "tau_weighted", or "tau_gated".
        """
        self.config = OscillatorConfig() if osc_config is None else osc_config
        self.oscillators = FDRAOscillatorBank(self.config)
        self.wlc_threshold = wlc_threshold
        self.routing_mode = routing_mode

        # Routing config
        self.routing_min = 0.25  # Minimum routing weight
        self.routing_max = 4.0  # Maximum routing weight
        self.gating_threshold = 0.25  # Fraction of L for gating threshold

        # Fast state (reactive, for computation)
        self.fast = np.zeros(self.config.state_dim)

        # Energy tracking
        self.energy = 0.0

        self.history: List[Dict[str, Any]] = []

    def _compute_routing_weights(self) -> np.ndarray:
        """
        Compute the per-oscillator input weights for the current routing mode.

        Returns:
            Routing weights, shape (num_oscillators,).

        Raises:
            ValueError: If `routing_mode` is not a known mode.
        """
        mode = self.routing_mode
        count = self.config.num_oscillators

        if mode == "uniform":
            # Equal weight everywhere; half-lives are not needed for this.
            return np.ones(count)

        if mode not in ("tau_weighted", "tau_gated"):
            raise ValueError(f"Unknown routing mode: {self.routing_mode}")

        taus = self.oscillators.get_half_lives()

        if mode == "tau_weighted":
            # Weight ∝ τ, normalized by the mean, then clamped for stability.
            weights = taus / np.mean(taus)
            return np.clip(weights, self.routing_min, self.routing_max)

        # "tau_gated": hard gate on half-life above a fraction of L.
        threshold = self.gating_threshold * self.config.sequence_length
        mask = (taus > threshold).astype(float)
        passing = np.sum(mask)
        if passing == 0:
            # No oscillator is slow enough; fall back to uniform routing.
            return np.ones(count)
        # Rescale so the total weight equals that of uniform routing.
        return mask * (count / passing)

    def get_slow_state(self) -> np.ndarray:
        """
        Aggregate the slow state from the oscillator bank.

        Returns:
            Weighted sum of oscillator hidden states with weights
            τ_i / Σ τ, shape (state_dim,).
        """
        taus = self.oscillators.get_half_lives()
        weights = taus / np.sum(taus)  # Normalize to sum to 1

        # Weighted sum across oscillators
        weighted_h = self.oscillators.h * weights[:, np.newaxis]
        return np.sum(weighted_h, axis=0)  # (state_dim,)

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Apply one action to the oscillator bank and the fast state.

        1. Compute routing weights based on mode
        2. Distribute action across oscillators (weighted by routing)
        3. Update oscillator bank
        4. Compute slow state from oscillators
        5. Update fast state and accumulated energy

        Args:
            action: Action vector; tiled once per oscillator, so it is
                expected to have length state_dim.

        Returns:
            The slow state after the update, shape (state_dim,).
        """
        routing_weights = self._compute_routing_weights()  # (n,)

        # One copy of the action per oscillator, scaled by its routing weight.
        u = np.tile(action, (self.config.num_oscillators, 1))  # (n, d)
        u = u * routing_weights[:, np.newaxis]  # (n, d)

        # Scale by base factor
        u = u * 0.1

        self.oscillators.forward(u)
        slow = self.get_slow_state()

        # Update fast state (reactive)
        self.fast = 0.9 * self.fast + action

        # np.linalg.norm returns a NumPy scalar; cast so `energy` stays a
        # plain float like every other value reported by `step`.
        self.energy += float(np.linalg.norm(action)) * 0.1

        return slow

    def get_coherence(self) -> float:
        """
        Cosine similarity between the slow and fast states.

        Returns:
            A float in [-1, 1], or 0.0 when either state has norm below 1e-10.
        """
        slow = self.get_slow_state()
        slow_norm = np.linalg.norm(slow)
        fast_norm = np.linalg.norm(self.fast)

        if slow_norm < 1e-10 or fast_norm < 1e-10:
            return 0.0

        return float(np.dot(slow, self.fast) / (slow_norm * fast_norm))

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """
        Execute one step and return diagnostics.

        Args:
            action: Action vector passed to `forward_dynamics`.

        Returns:
            Dict with "slow_norm", "fast_norm", "coherence", "energy" and the
            oscillator bank's half-life statistics. The same dict is appended
            to `history`.
        """
        slow = self.forward_dynamics(action)
        coherence = self.get_coherence()

        stats = self.oscillators.get_half_life_statistics()

        result = {
            "slow_norm": float(np.linalg.norm(slow)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            "coherence": coherence,
            "energy": self.energy,
            **stats
        }

        self.history.append(result)
        return result

    def reset(self):
        """Reset oscillator states, fast state, energy and history."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []
355
+
356
+
357
def demo_oscillators():
    """
    Print a walkthrough of oscillator bank behaviour.

    Builds a 16-oscillator bank with log-uniform half-lives, prints the
    half-life distribution, injects a single random pulse, then lets the bank
    decay with zero input and reports, at several checkpoints, how much of the
    pulse each of the first four oscillators still retains.

    After t zero-input steps an oscillator retains λ^t = 0.5^(t/τ) of the
    pulse, so retention is above 50% exactly when τ > t; the ✓/✗ label printed
    next to each oscillator reflects that threshold.
    """
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)

    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0
    )

    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)
    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)
    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        # Bar length grows with log(τ) so the log-uniform spacing is visible.
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)

    # Pulse input at t=0; the bank starts at zero, so h equals the pulse.
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)
    initial_norms = np.linalg.norm(bank.h, axis=1)

    # Let the pulse decay with zero input, reporting at each checkpoint
    # (cumulative step counts since the pulse).
    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))

    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1

        current_norms = np.linalg.norm(bank.h, axis=1)
        # Small epsilon guards against a zero-norm pulse row.
        retention = current_norms / (initial_norms + 1e-10)

        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus, retention)):
            # Retention is 0.5^(step/τ): below 50% iff τ < step.
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
            if i >= 3:
                # Only the first four oscillators are listed.
                print(f" ... ({len(taus) - 4} more)")
                break

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)
424
+
425
+
426
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_oscillators()