| """ |
| FDRA Oscillator Implementation with Explicit Decay Parameters |
| |
| This implements the core FDRA oscillator dynamics where each oscillator has: |
| - A decay parameter λ_i ∈ (0, 1) |
| - Half-life τ_i = ln(0.5) / ln(λ_i) |
| |
| The key problem this addresses (from Melanie/Tiago's discovery): |
| - During training at GPT-2 scale, all λ_i collapse to values with very short half-lives (e.g. λ ≈ 0.93 gives τ ≈ 10 steps, whereas τ = 4096 requires λ ≈ 0.99983) |
| - This means oscillators only attend to ~10 tokens instead of full context length |
| - The model works for short-context tasks but fails on long-context reasoning |
| |
| Solution: Half-life regularization to maintain diversity across temporal scales. |
| |
| Authors: FDRA Half-Life Regularization Implementation |
| Date: 2026-01-22 |
| """ |
|
|
| import numpy as np |
| from typing import Dict, List, Tuple, Optional, Any |
| from dataclasses import dataclass |
| import json |
| from pathlib import Path |
|
|
|
|
@dataclass
class OscillatorConfig:
    """Configuration for an FDRA oscillator bank.

    Attributes:
        num_oscillators: Number of oscillators in the bank (must be >= 1).
        state_dim: Width of each oscillator's state vector (must be >= 1).
        sequence_length: Nominal context length; the default equals the
            default ``tau_max``.
        tau_min: First endpoint, in steps, of the log-spaced half-life grid
            used by the ``"log_uniform"`` initialisation (must be > 0).
        tau_max: Last endpoint, in steps, of that half-life grid
            (must be > 0).
        init_method: ``"log_uniform"`` spaces half-lives evenly in log space
            between ``tau_min`` and ``tau_max``; any other value makes the
            bank draw its decay rates at random instead.

    Raises:
        ValueError: From ``__post_init__`` if a size is below 1 or a
            half-life bound is not strictly positive.
    """

    num_oscillators: int = 32
    state_dim: int = 16
    sequence_length: int = 4096
    tau_min: float = 1.0
    tau_max: float = 4096.0
    init_method: str = "log_uniform"

    def __post_init__(self) -> None:
        if self.num_oscillators < 1:
            raise ValueError(
                f"num_oscillators must be >= 1, got {self.num_oscillators}"
            )
        if self.state_dim < 1:
            raise ValueError(f"state_dim must be >= 1, got {self.state_dim}")
        # Half-lives are turned into log(tau) and 0.5 ** (1 / tau); a
        # non-positive or NaN bound would silently yield NaN/degenerate
        # decay rates, so fail fast. (`not x > 0` also rejects NaN.)
        if not self.tau_min > 0:
            raise ValueError(f"tau_min must be > 0, got {self.tau_min}")
        if not self.tau_max > 0:
            raise ValueError(f"tau_max must be > 0, got {self.tau_max}")
|
|
@dataclass
class OscillatorState:
    """Snapshot of an oscillator bank: hidden states plus decay parameters.

    Attributes:
        h: Hidden states (``FDRAOscillatorBank`` keeps these as a
            ``(num_oscillators, state_dim)`` array).
        lambdas: Per-oscillator decay parameters λ_i.
    """

    h: np.ndarray
    lambdas: np.ndarray

    def copy(self) -> 'OscillatorState':
        """Return a new state holding independent copies of both arrays."""
        h_copy, lambdas_copy = self.h.copy(), self.lambdas.copy()
        return OscillatorState(h_copy, lambdas_copy)
|
|
|
|
class FDRAOscillatorBank:
    """
    FDRA oscillator bank with explicit decay parameters.

    Each oscillator i owns one row h_i of the state matrix ``h`` and evolves as

        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    where λ_i ∈ (0, 1) is its decay parameter and

        τ_i = ln(0.5) / ln(λ_i)

    is its half-life: the number of steps for the state to decay to 50% of
    its value under zero input.

    The goal of half-life regularization is to keep log(τ_i) spread evenly
    over [log τ_min, log τ_max] so the bank can attend to both short and
    long contexts.

    Attributes:
        config: Configuration the bank was built from.
        n: Number of oscillators (``config.num_oscillators``).
        d: State width per oscillator (``config.state_dim``).
        L: ``config.sequence_length``; stored, not read by the methods below.
        lambdas: Decay parameters, shape ``(n,)``.
        h: Oscillator states, shape ``(n, d)``.
        history: List of diagnostic dicts; initialised empty and not touched
            by any other method of this class.
    """

    def __init__(self, config: "OscillatorConfig"):
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length

        self.lambdas = self._init_lambdas()
        self.h = np.zeros((self.n, self.d))

        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Build the initial decay parameters λ_i, shape ``(n,)``.

        ``init_method == "log_uniform"``: the half-lives are placed on an
        evenly spaced grid in log space (deterministic, not sampled),

            log τ_i = linspace(log τ_min, log τ_max, n)

        and converted with λ = 0.5^(1/τ), the inverse of
        τ = ln(0.5) / ln(λ).

        Any other ``init_method``: λ_i ~ Uniform[0.5, 0.99) drawn from
        NumPy's global random state, i.e. half-lives between 1 and ~69 steps.
        """
        if self.config.init_method == "log_uniform":
            log_taus = np.linspace(
                np.log(self.config.tau_min), np.log(self.config.tau_max), self.n
            )
            taus = np.exp(log_taus)
            # Invert τ = ln(0.5) / ln(λ):  λ = 0.5 ** (1 / τ).
            return np.power(0.5, 1.0 / taus)

        return np.random.uniform(0.5, 0.99, self.n)

    def get_half_lives(self) -> np.ndarray:
        """
        Return the half-lives τ_i = ln(0.5) / ln(λ_i), shape ``(n,)``.

        λ is clipped to [1e-10, 1 - 1e-10] first so that ln(λ) stays finite
        and strictly negative (λ = 1 would mean an infinite half-life).
        """
        safe_lambdas = np.clip(self.lambdas, 1e-10, 1.0 - 1e-10)
        return np.log(0.5) / np.log(safe_lambdas)

    def get_log_half_lives(self) -> np.ndarray:
        """Return the natural log of the half-lives: z_i = log(τ_i)."""
        return np.log(self.get_half_lives())

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        Advance every oscillator by one step: h_i(t+1) = λ_i * h_i(t) + u_i(t).

        Args:
            u: Input of shape ``(num_oscillators, state_dim)``, or anything
               that broadcasts to it (a scalar, a ``(state_dim,)`` vector
               shared by all oscillators, a ``(num_oscillators, 1)`` column).

        Returns:
            A copy of the updated states, shape ``(num_oscillators, state_dim)``.

        Raises:
            ValueError: If ``u`` cannot be broadcast to the state shape.
                Without this check an input with extra leading axes would
                silently change the shape of ``self.h``.
        """
        try:
            u_b = np.broadcast_to(u, self.h.shape)
        except ValueError as err:
            raise ValueError(
                f"input of shape {np.shape(u)} cannot be broadcast to the "
                f"oscillator state shape {self.h.shape}"
            ) from err

        # Column of decay rates so each λ_i scales its whole row h_i.
        self.h = self.lambdas[:, np.newaxis] * self.h + u_b
        return self.h.copy()

    def reset(self):
        """Reset oscillator states to zero (decay parameters are kept)."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Summarise the current half-life distribution.

        Returns:
            Dict of Python floats: linear-space ``tau_min``, ``tau_max``,
            ``tau_mean``, ``tau_median`` and natural-log-space
            ``log_tau_mean``, ``log_tau_std``, ``log_tau_min``,
            ``log_tau_max``.
        """
        taus = self.get_half_lives()
        z = np.log(taus)

        return {
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            "log_tau_mean": float(np.mean(z)),
            "log_tau_std": float(np.std(z)),
            "log_tau_min": float(np.min(z)),
            "log_tau_max": float(np.max(z)),
        }

    def get_state(self) -> "OscillatorState":
        """Return a snapshot holding copies of the states and decay parameters."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy()
        )

    def set_state(self, state: "OscillatorState"):
        """
        Restore states and decay parameters from ``state`` (arrays are copied).

        Raises:
            ValueError: If ``state.lambdas`` is not shaped ``(n,)`` or
                ``state.h`` is not shaped ``(n, d)``; accepting them would
                leave the bank inconsistent with ``n``/``d`` (e.g. ``reset``
                would rebuild ``h`` at the old size).
        """
        lambdas = np.asarray(state.lambdas)
        h = np.asarray(state.h)
        if lambdas.shape != (self.n,) or h.shape != (self.n, self.d):
            raise ValueError(
                f"state shapes lambdas={lambdas.shape}, h={h.shape} do not "
                f"match the bank's expected ({self.n},) and ({self.n}, {self.d})"
            )
        self.h = h.copy()
        self.lambdas = lambdas.copy()
|
|
|
|
class FDRAWithOscillators:
    """
    FDRA agent whose slow memory is an ``FDRAOscillatorBank``.

    Per step (see ``forward_dynamics``):

        action ──×0.1──► oscillator bank ──half-life-weighted mean──► slow state
        action ─────────► fast state (leaky accumulator, decay 0.9)

    ``get_coherence`` is the cosine similarity of the slow and fast states;
    ``step`` bundles these diagnostics with the bank's half-life statistics
    and appends them to ``history``.
    """

    def __init__(
        self,
        osc_config: Optional[OscillatorConfig] = None,
        wlc_threshold: float = 1.0
    ):
        cfg = osc_config or OscillatorConfig()
        self.config = cfg
        self.oscillators = FDRAOscillatorBank(cfg)
        # Stored for callers; not read by any method of this class.
        self.wlc_threshold = wlc_threshold

        self.fast = np.zeros(cfg.state_dim)
        self.energy = 0.0
        self.history: List[Dict[str, Any]] = []

    def get_slow_state(self) -> np.ndarray:
        """
        Return the half-life-weighted average of the oscillator states.

        Row i of the bank's state is weighted by τ_i / Σ_j τ_j, so
        longer-memory oscillators dominate the result.
        """
        half_lives = self.oscillators.get_half_lives()
        normalized = half_lives / np.sum(half_lives)
        return np.sum(self.oscillators.h * normalized[:, np.newaxis], axis=0)

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Apply one action and return the resulting slow state.

        The action is replicated for every oscillator, damped by 0.1 and fed
        to the bank; the fast state decays by 0.9 and accumulates the raw
        action; energy grows by 0.1 * ||action||.
        """
        rows = self.config.num_oscillators
        damping = np.full((rows, 1), 0.1)
        self.oscillators.forward(np.tile(action, (rows, 1)) * damping)

        slow = self.get_slow_state()

        self.fast = 0.9 * self.fast + action
        self.energy += np.linalg.norm(action) * 0.1

        return slow

    def get_coherence(self) -> float:
        """Cosine similarity of slow and fast states; 0.0 if either norm is ~0."""
        slow = self.get_slow_state()
        norm_slow = np.linalg.norm(slow)
        norm_fast = np.linalg.norm(self.fast)

        if norm_slow < 1e-10 or norm_fast < 1e-10:
            return 0.0
        return float(np.dot(slow, self.fast) / (norm_slow * norm_fast))

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """Run ``forward_dynamics`` once, record diagnostics, and return them."""
        slow = self.forward_dynamics(action)
        coherence = self.get_coherence()
        half_life_stats = self.oscillators.get_half_life_statistics()

        diagnostics: Dict[str, Any] = {
            "slow_norm": float(np.linalg.norm(slow)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            "coherence": coherence,
            "energy": self.energy,
        }
        diagnostics.update(half_life_stats)

        self.history.append(diagnostics)
        return diagnostics

    def reset(self):
        """Clear oscillator and fast states, energy, and the diagnostics history."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []
|
|
|
|
def demo_oscillators():
    """Print a console walkthrough of an oscillator bank's half-lives and decay.

    Builds a 16-oscillator bank (state_dim 8) with the default "log_uniform"
    half-lives spanning [1, 4096] steps and prints its half-life statistics
    and per-oscillator half-lives. It then feeds one random input (from
    NumPy's global random state) followed by zero input. After 10, 50, 100,
    500 and 1000 steps it reports how much of the first four oscillators'
    initial state norm is retained.

    With zero input the state after t steps is λ^t · h0 = 0.5^(t/τ) · h0, so
    an oscillator retains less than 50% exactly when its half-life τ is
    shorter than t.
    """
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)

    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
    )
    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)
    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)
    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        # Bar length is 3·ln(τ); a negative count yields an empty bar.
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)

    # One random impulse, then free decay under zero input.
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)
    initial_norms = np.linalg.norm(bank.h, axis=1)

    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))
    num_shown = 4  # oscillators printed at each checkpoint

    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1

        current_norms = np.linalg.norm(bank.h, axis=1)
        retention = current_norms / (initial_norms + 1e-10)

        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus[:num_shown], retention[:num_shown])):
            # Retention is 0.5 ** (step / tau): it drops below 50% exactly
            # when the half-life is shorter than the elapsed number of steps.
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
        if len(taus) > num_shown:
            print(f" ... ({len(taus) - num_shown} more)")

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the console demonstration when executed directly.
    demo_oscillators()
|
|