juddddd committed on
Commit
5daceea
·
verified ·
1 Parent(s): 1d5bbb5

Upload training/fdra_oscillators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/fdra_oscillators.py +376 -0
training/fdra_oscillators.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
FDRA Oscillator Implementation with Explicit Decay Parameters

This implements the core FDRA oscillator dynamics where each oscillator has:
- A decay parameter λ_i ∈ (0, 1)
- Half-life τ_i = ln(0.5) / ln(λ_i)

The key problem this addresses (from Melanie/Tiago's discovery):
- During training at GPT-2 scale, all λ_i collapse to similar values with
  very short half-lives (τ_i ≈ 10 tokens, i.e. λ_i ≈ 0.93).
  NOTE(review): the original text said the λ_i collapse "to near 1.0", but
  under τ = ln(0.5) / ln(λ) a λ close to 1 means a *long* half-life; confirm
  which of the two statements matches the observation.
- This means oscillators only attend to ~10 tokens instead of the full context length
- The model works for short-context tasks but fails on long-context reasoning

Solution: Half-life regularization to maintain diversity across temporal scales.

Authors: FDRA Half-Life Regularization Implementation
Date: 2026-01-22
"""
18
+
19
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
24
+
25
+
26
@dataclass
class OscillatorConfig:
    """Configuration for an FDRA oscillator bank.

    Raises:
        ValueError: if a size is not positive, or if the half-life range does
            not satisfy ``0 < tau_min <= tau_max < inf``. The log-uniform
            initialisation takes ``log`` of both bounds, so non-positive or
            infinite values would otherwise surface later as NaN/inf decays.
    """

    num_oscillators: int = 32  # Number of oscillators
    state_dim: int = 16  # Dimension per oscillator
    sequence_length: int = 4096  # Max sequence length (L)
    tau_min: float = 1.0  # Minimum half-life
    tau_max: float = 4096.0  # Maximum half-life (typically = L)

    # Initialization
    init_method: str = "log_uniform"  # "log_uniform" or "random"

    def __post_init__(self) -> None:
        # Fail at construction time instead of producing empty/NaN arrays later.
        for field_name in ("num_oscillators", "state_dim", "sequence_length"):
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value!r}")
        # Written as a chained comparison so that NaN bounds are rejected too.
        if not (0.0 < self.tau_min <= self.tau_max < float("inf")):
            raise ValueError(
                "half-life range must satisfy 0 < tau_min <= tau_max < inf, "
                f"got tau_min={self.tau_min!r}, tau_max={self.tau_max!r}"
            )
37
+
38
+
39
@dataclass
class OscillatorState:
    """Snapshot of an oscillator bank: hidden states plus decay parameters."""

    h: np.ndarray  # Hidden states: (num_oscillators, state_dim)
    lambdas: np.ndarray  # Decay parameters: (num_oscillators,)

    def copy(self) -> 'OscillatorState':
        """Return a new state whose arrays are independent copies of this one's."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy(),
        )
50
+
51
+
52
class FDRAOscillatorBank:
    """
    FDRA oscillator bank with explicit decay parameters.

    Each oscillator i follows the linear recurrence

        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    where λ_i is its decay parameter and

        τ_i = ln(0.5) / ln(λ_i)

    is its half-life: the number of zero-input steps after which the state
    has decayed to 50% of its value.

    The goal of half-life regularization is to keep the τ_i log-uniformly
    spread over [τ_min, τ_max] so that the bank covers both short and long
    temporal scales.

    Attributes:
        config: the configuration object passed to the constructor.
        n: number of oscillators (``config.num_oscillators``).
        d: state dimension per oscillator (``config.state_dim``).
        L: ``config.sequence_length`` (stored; not used by the methods here).
        lambdas: decay parameters, shape (n,).
        h: hidden states, shape (n, d).
        history: list reserved for analysis records; this class never writes it.
    """

    # Margin used to keep λ strictly inside (0, 1) before taking its log.
    _LAMBDA_EPS = 1e-10

    def __init__(self, config: "OscillatorConfig"):
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length

        # Initialize decay parameters
        self.lambdas = self._init_lambdas()

        # Initialize hidden states
        self.h = np.zeros((self.n, self.d))

        # Track history for analysis
        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Build the initial decay parameters λ_i according to ``config.init_method``.

        "log_uniform": half-lives are evenly spaced in log space between
        ``tau_min`` and ``tau_max`` and converted with λ = 0.5^(1/τ), the
        inverse of τ = ln(0.5) / ln(λ).

        "random": λ_i drawn uniformly from [0.5, 0.99) using NumPy's global
        random state (not recommended).

        Returns:
            Array of shape (num_oscillators,).

        Raises:
            ValueError: for any other ``init_method``. A misspelt method used
                to fall back silently to the random initialisation.
        """
        method = self.config.init_method

        if method == "log_uniform":
            log_taus = np.linspace(
                np.log(self.config.tau_min),
                np.log(self.config.tau_max),
                self.n,
            )
            taus = np.exp(log_taus)
            # λ = exp(ln(0.5) / τ) = 0.5^(1/τ)
            return np.power(0.5, 1.0 / taus)

        if method == "random":
            return np.random.uniform(0.5, 0.99, self.n)

        raise ValueError(
            f"unknown init_method {method!r}; expected 'log_uniform' or 'random'"
        )

    def get_half_lives(self) -> np.ndarray:
        """
        Compute half-lives from the decay parameters: τ_i = ln(0.5) / ln(λ_i).

        λ is clipped to [1e-10, 1 - 1e-10] first so that log(λ) is finite and
        non-zero; λ >= 1 therefore reports a very large but finite half-life.

        Returns:
            Array of shape (num_oscillators,).
        """
        eps = self._LAMBDA_EPS
        safe_lambdas = np.clip(self.lambdas, eps, 1.0 - eps)
        return np.log(0.5) / np.log(safe_lambdas)

    def get_log_half_lives(self) -> np.ndarray:
        """Get log of half-lives: z_i = log(τ_i)."""
        return np.log(self.get_half_lives())

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        Advance the bank by one step: h_i(t+1) = λ_i * h_i(t) + u_i(t).

        Args:
            u: Input signal of shape (num_oscillators, state_dim), or anything
                NumPy can broadcast to that shape (e.g. a (state_dim,) vector
                shared by all oscillators).

        Returns:
            A copy of the updated hidden states, shape (num_oscillators, state_dim).

        Raises:
            ValueError: if ``u`` cannot be broadcast to (num_oscillators,
                state_dim). Without this check an over-sized input would
                silently change the shape of ``self.h``.
        """
        target_shape = (self.n, self.d)
        try:
            u = np.broadcast_to(np.asarray(u), target_shape)
        except ValueError as err:
            raise ValueError(
                f"input of shape {np.shape(u)} cannot be broadcast to {target_shape}"
            ) from err

        # (n, 1) column so each oscillator's λ scales its whole state row.
        self.h = self.lambdas[:, np.newaxis] * self.h + u

        return self.h.copy()

    def reset(self):
        """Reset oscillator states to zero (decay parameters are kept)."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Summarise the current half-life distribution.

        Returns:
            Dictionary of plain floats with the keys ``tau_min``, ``tau_max``,
            ``tau_mean``, ``tau_median`` (linear space) and ``log_tau_mean``,
            ``log_tau_std``, ``log_tau_min``, ``log_tau_max`` (natural-log space).
        """
        taus = self.get_half_lives()
        z = np.log(taus)

        return {
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            "log_tau_mean": float(np.mean(z)),
            "log_tau_std": float(np.std(z)),
            "log_tau_min": float(np.min(z)),
            "log_tau_max": float(np.max(z)),
        }

    def get_state(self) -> "OscillatorState":
        """Return a snapshot (independent copies) of hidden states and decays."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy(),
        )

    def set_state(self, state: "OscillatorState"):
        """
        Replace hidden states and decay parameters with copies from ``state``.

        Raises:
            ValueError: if ``state.h`` is not (num_oscillators, state_dim) or
                ``state.lambdas`` is not (num_oscillators,); the bank is left
                unchanged in that case.
        """
        new_h = np.array(state.h)
        new_lambdas = np.array(state.lambdas)

        if new_h.shape != (self.n, self.d):
            raise ValueError(
                f"state.h has shape {new_h.shape}, expected {(self.n, self.d)}"
            )
        if new_lambdas.shape != (self.n,):
            raise ValueError(
                f"state.lambdas has shape {new_lambdas.shape}, expected {(self.n,)}"
            )

        self.h = new_h
        self.lambdas = new_lambdas
191
+
192
+
193
class FDRAWithOscillators:
    """
    FDRA agent that keeps its memory in an oscillator bank.

    Data flow implemented by this class:
        action -> scaled copy fed to every oscillator -> slow state
        action -> leaky integrator                    -> fast state

    The slow state is the half-life-weighted sum of the oscillator states;
    the fast state is ``FAST_DECAY * fast + action``.

    NOTE(review): the original docstring diagram also drew arrows between the
    slow and fast states and back into the bank; no such coupling exists in
    the code of this class. ``wlc_threshold`` is stored but not read here.

    Attributes:
        config: oscillator configuration (a default one if none was given).
        oscillators: the underlying ``FDRAOscillatorBank``.
        wlc_threshold: value passed to the constructor.
        fast: fast (reactive) state, shape (state_dim,).
        energy: running sum of ``ENERGY_COST * ||action||``.
        history: list of the diagnostic dicts returned by ``step``.
    """

    INPUT_SCALE = 0.1  # factor applied to the action before it enters the bank
    FAST_DECAY = 0.9  # per-step retention of the fast state
    ENERGY_COST = 0.1  # energy added per unit of action norm

    def __init__(
        self,
        osc_config: Optional["OscillatorConfig"] = None,
        wlc_threshold: float = 1.0
    ):
        # Explicit None test: do not rely on the truthiness of a config object.
        self.config = OscillatorConfig() if osc_config is None else osc_config
        self.oscillators = FDRAOscillatorBank(self.config)
        self.wlc_threshold = wlc_threshold

        # Fast state (reactive, for computation)
        self.fast = np.zeros(self.config.state_dim)

        # Energy tracking
        self.energy = 0.0

        self.history: List[Dict[str, Any]] = []

    @staticmethod
    def _cosine(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 if either norm is below 1e-10."""
        a_norm = np.linalg.norm(a)
        b_norm = np.linalg.norm(b)
        if a_norm < 1e-10 or b_norm < 1e-10:
            return 0.0
        return float(np.dot(a, b) / (a_norm * b_norm))

    def get_slow_state(self) -> np.ndarray:
        """
        Aggregate the slow state from the oscillator bank.

        The slow state is a weighted sum of the oscillator states with
        weights proportional to each oscillator's half-life (normalised to
        sum to 1).

        Returns:
            Array of shape (state_dim,).
        """
        taus = self.oscillators.get_half_lives()
        weights = taus / np.sum(taus)

        # (n, 1) weights scale each oscillator's row before summing over rows.
        return np.sum(self.oscillators.h * weights[:, np.newaxis], axis=0)

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Apply one action to the agent.

        1. Feed the same scaled action to every oscillator
        2. Update the oscillator bank
        3. Compute the slow state from the oscillators
        4. Update the fast state and the energy counter

        Args:
            action: vector of shape (state_dim,) (or broadcastable to it).

        Returns:
            The slow state after the update, shape (state_dim,).

        Raises:
            ValueError: if ``action`` cannot be broadcast to (state_dim,);
                checked before any state is modified.
        """
        action = np.asarray(action)
        expected = (self.config.state_dim,)
        try:
            np.broadcast_to(action, expected)
        except ValueError as err:
            raise ValueError(
                f"action of shape {action.shape} cannot be broadcast to {expected}"
            ) from err

        # Same input for every oscillator; only the decays differ.
        u = self.INPUT_SCALE * np.tile(action, (self.config.num_oscillators, 1))
        self.oscillators.forward(u)

        slow = self.get_slow_state()

        # Fast state is a leaky integrator of the raw action.
        self.fast = self.FAST_DECAY * self.fast + action

        # Cast so that ``energy`` stays a plain Python float.
        self.energy += self.ENERGY_COST * float(np.linalg.norm(action))

        return slow

    def get_coherence(self) -> float:
        """Cosine similarity between slow and fast states (0.0 if either is ~zero)."""
        return self._cosine(self.get_slow_state(), self.fast)

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """
        Execute one step and return diagnostics.

        Returns:
            Dict with ``slow_norm``, ``fast_norm``, ``coherence``, ``energy``
            and every key of the bank's ``get_half_life_statistics()``. The
            same dict is appended to ``self.history``.
        """
        slow = self.forward_dynamics(action)

        result = {
            "slow_norm": float(np.linalg.norm(slow)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            # Reuse the slow state just computed; nothing changed in between.
            "coherence": self._cosine(slow, self.fast),
            "energy": float(self.energy),
            **self.oscillators.get_half_life_statistics(),
        }

        self.history.append(result)
        return result

    def reset(self):
        """Reset oscillator states, fast state, energy and history."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []
304
+
305
+
306
def demo_oscillators():
    """Print a walkthrough of an oscillator bank.

    Shows the initial half-life statistics, the per-oscillator half-lives,
    and how much of a single random input pulse each of the first four
    oscillators retains after 10, 50, 100, 500 and 1000 zero-input steps.
    The pulse comes from NumPy's global random state, so the retention
    percentages are deterministic but the underlying norms are not.
    """
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)

    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
    )

    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)
    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)
    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        # Bar length grows with log(τ), three characters per unit.
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)

    # Pulse input at t=0; the bank starts at zero, so h equals the pulse.
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)
    initial_norms = np.linalg.norm(bank.h, axis=1)

    # Let the pulse decay with zero input, reporting at each checkpoint.
    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))

    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1

        current_norms = np.linalg.norm(bank.h, axis=1)
        retention = current_norms / (initial_norms + 1e-10)

        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus, retention)):
            # After t zero-input steps retention is 0.5 ** (t / τ), so it is
            # below 50% exactly when τ < t (the threshold is t, not t / 2).
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
            if i >= 3:
                # Only the first four oscillators are listed per checkpoint.
                print(f" ... ({len(taus) - 4} more)")
                break

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)


if __name__ == "__main__":
    demo_oscillators()