"""
FDRA Oscillator Implementation with Explicit Decay Parameters

This implements the core FDRA oscillator dynamics where each oscillator has:
- A decay parameter λ_i ∈ (0, 1)
- Half-life τ_i = ln(0.5) / ln(λ_i)

The key problem this addresses (from Melanie/Tiago's discovery):
- During training at GPT-2 scale, all oscillators collapse to very short
  half-lives. (The original note phrased this as "all λ_i collapse to near
  1.0"; note that by the formula above a half-life of ~10 steps corresponds
  to λ ≈ 0.933, whereas λ → 1 means an ever *longer* half-life.)
- This means oscillators only attend to ~10 tokens instead of full context length
- The model works for short-context tasks but fails on long-context reasoning

Solution: Half-life regularization to maintain diversity across temporal scales.

Authors: FDRA Half-Life Regularization Implementation
Date: 2026-01-22
"""

import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import json
from pathlib import Path


@dataclass
class OscillatorConfig:
    """Configuration for FDRA oscillator bank."""

    num_oscillators: int = 32  # Number of oscillators
    state_dim: int = 16  # Dimension per oscillator
    sequence_length: int = 4096  # Max sequence length (L)
    tau_min: float = 1.0  # Minimum half-life
    tau_max: float = 4096.0  # Maximum half-life (typically = L)

    # Initialization
    init_method: str = "log_uniform"  # "log_uniform" or "random"


@dataclass
class OscillatorState:
    """State of an oscillator bank."""

    h: np.ndarray  # Hidden states: (num_oscillators, state_dim)
    lambdas: np.ndarray  # Decay parameters: (num_oscillators,)

    def copy(self) -> 'OscillatorState':
        """Return a snapshot whose arrays are independent copies."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy()
        )


class FDRAOscillatorBank:
    """
    FDRA Oscillator Bank with explicit decay parameters.

    Each oscillator i has:
        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    Where:
        λ_i ∈ (0, 1) is the decay parameter
        τ_i = ln(0.5) / ln(λ_i) is the half-life

    Half-life interpretation:
        τ_i = number of steps for oscillator state to decay to 50%

    The goal of half-life regularization:
        Maintain log-uniform distribution of τ_i across [τ_min, τ_max]
        This ensures oscillators can attend to both short and long contexts.
    """

    def __init__(self, config: OscillatorConfig):
        """
        Build the bank described by ``config``.

        Raises:
            ValueError: if ``config.init_method`` is "log_uniform" and
                ``tau_min`` or ``tau_max`` is not strictly positive.
        """
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length

        # Initialize decay parameters
        self.lambdas = self._init_lambdas()

        # Initialize hidden states
        self.h = np.zeros((self.n, self.d))

        # Track history for analysis
        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Initialize decay parameters λ_i.

        For log-uniform half-lives, we want:
            τ_i ~ LogUniform(τ_min, τ_max)

        Since τ = ln(0.5) / ln(λ), we have:
            λ = 0.5^(1/τ)

        So for log-uniform τ:
            log(τ) ~ Uniform(log(τ_min), log(τ_max))
            τ = exp(log_τ)
            λ = 0.5^(1/τ)

        The half-lives are placed deterministically, evenly spaced in log
        space (not sampled). Any ``init_method`` other than "log_uniform"
        falls back to drawing λ uniformly from [0.5, 0.99).

        Returns:
            Array of shape (num_oscillators,) with the decay parameters.

        Raises:
            ValueError: for "log_uniform" init when ``tau_min`` or
                ``tau_max`` is not strictly positive.
        """
        if self.config.init_method == "log_uniform":
            tau_min = self.config.tau_min
            tau_max = self.config.tau_max
            # log() of a non-positive bound would yield -inf/NaN and, in turn,
            # λ = 0 or NaN; fail loudly instead of building a degenerate bank.
            if not (tau_min > 0 and tau_max > 0):
                raise ValueError(
                    "tau_min and tau_max must be positive for log_uniform "
                    f"initialization, got tau_min={tau_min}, tau_max={tau_max}"
                )

            # Log-uniform distribution of half-lives
            log_tau_min = np.log(tau_min)
            log_tau_max = np.log(tau_max)

            # Evenly spaced in log space
            log_taus = np.linspace(log_tau_min, log_tau_max, self.n)
            taus = np.exp(log_taus)

            # Convert half-lives to decay parameters
            # λ = exp(ln(0.5) / τ) = 0.5^(1/τ)
            lambdas = np.power(0.5, 1.0 / taus)
        else:
            # Random initialization (not recommended)
            lambdas = np.random.uniform(0.5, 0.99, self.n)

        return lambdas

    def get_half_lives(self) -> np.ndarray:
        """
        Compute half-lives from decay parameters.

        τ_i = ln(0.5) / ln(λ_i)

        Returns:
            Array of half-lives, one per entry of ``self.lambdas``.
        """
        # Clamp lambdas away from 0 and 1 so that neither log(0) = -inf nor
        # log(1) = 0 (division by zero) can occur.
        safe_lambdas = np.clip(self.lambdas, 1e-10, 1.0 - 1e-10)
        taus = np.log(0.5) / np.log(safe_lambdas)
        return taus

    def get_log_half_lives(self) -> np.ndarray:
        """Get log of half-lives: z_i = log(τ_i)."""
        return np.log(self.get_half_lives())

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        One step of oscillator dynamics.

        h_i(t+1) = λ_i * h_i(t) + u_i(t)

        Args:
            u: Input signal, shape (num_oscillators, state_dim). Anything
                that broadcasts against the stored state without changing
                its shape is accepted.

        Returns:
            Copy of the updated hidden states, same shape as ``self.h``.

        Raises:
            ValueError: if ``u`` cannot be broadcast against the state, or
                if the broadcast would change the shape of the state. The
                stored state is left untouched in both cases.
        """
        # Broadcast lambdas across state dimensions
        lambdas_broadcast = self.lambdas[:, np.newaxis]  # (n, 1)

        # Apply dynamics
        new_h = lambdas_broadcast * self.h + u

        # An input with extra/larger dimensions would broadcast silently and
        # change the shape of the bank's state; reject it instead.
        if new_h.shape != self.h.shape:
            raise ValueError(
                f"input of shape {np.shape(u)} would change the oscillator "
                f"state shape from {self.h.shape} to {new_h.shape}"
            )

        self.h = new_h
        return self.h.copy()

    def reset(self):
        """Reset oscillator states to zero."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Compute statistics of half-life distribution.

        Returns:
            Dictionary with min, max, mean and median of the half-lives
            ("tau_*" keys) and mean, std, min and max of their natural
            logarithm ("log_tau_*" keys), all as Python floats.
        """
        taus = self.get_half_lives()
        z = np.log(taus)

        return {
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            "log_tau_mean": float(np.mean(z)),
            "log_tau_std": float(np.std(z)),
            "log_tau_min": float(np.min(z)),
            "log_tau_max": float(np.max(z)),
        }

    def get_state(self) -> OscillatorState:
        """Get current oscillator state (arrays are copied)."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy()
        )

    def set_state(self, state: OscillatorState):
        """Set oscillator state from ``state`` (arrays are copied)."""
        self.h = state.h.copy()
        self.lambdas = state.lambdas.copy()


class FDRAWithOscillators:
    """
    Full FDRA agent with oscillator bank for memory.

    This extends the basic FDRA agent to use an oscillator bank
    with explicit decay parameters that can be regularized.

    Architecture:
        Input → [Oscillator Bank] → Slow State → Output
                      ↑                  ↓
                  Fast State ←──────────────

    NOTE(review): the diagram above is reconstructed from a flattened
    original. In the code below the fast state is driven directly by the
    action and is only compared with the slow state in ``get_coherence``;
    no feedback path into the oscillator bank is implemented here.
    """

    def __init__(
        self,
        osc_config: Optional[OscillatorConfig] = None,
        wlc_threshold: float = 1.0
    ):
        """
        Args:
            osc_config: Oscillator bank configuration; a default
                ``OscillatorConfig()`` is used when omitted.
            wlc_threshold: Stored on the instance; not read by any method
                of this class.
        """
        self.config = osc_config or OscillatorConfig()
        self.oscillators = FDRAOscillatorBank(self.config)
        self.wlc_threshold = wlc_threshold

        # Fast state (reactive, for computation)
        self.fast = np.zeros(self.config.state_dim)

        # Energy tracking
        self.energy = 0.0
        self.history: List[Dict[str, Any]] = []

    def get_slow_state(self) -> np.ndarray:
        """
        Aggregate slow state from oscillator bank.

        The slow state is a weighted sum of oscillator states,
        with weights proportional to half-life.

        Returns:
            Array of shape (state_dim,).
        """
        taus = self.oscillators.get_half_lives()
        weights = taus / np.sum(taus)  # Normalize

        # Weighted sum across oscillators
        weighted_h = self.oscillators.h * weights[:, np.newaxis]
        slow = np.sum(weighted_h, axis=0)  # (state_dim,)

        return slow

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Forward dynamics with oscillator bank.

        1. Distribute action across oscillators
        2. Update oscillator bank
        3. Compute slow state from oscillators
        4. Update fast state

        Args:
            action: Input vector of length state_dim.

        Returns:
            The slow state after the update, shape (state_dim,).
        """
        # Distribute action to oscillators (same input, different decays)
        u = np.tile(action, (self.config.num_oscillators, 1))  # (n, d)

        # Scale by oscillator-specific factors (optional: could learn these)
        scale = 0.1 * np.ones((self.config.num_oscillators, 1))
        u = u * scale

        # Update oscillators
        self.oscillators.forward(u)

        # Get slow state from oscillators
        slow = self.get_slow_state()

        # Update fast state (reactive)
        self.fast = 0.9 * self.fast + action

        # Energy
        self.energy += np.linalg.norm(action) * 0.1

        return slow

    def get_coherence(self) -> float:
        """
        Coherence between slow and fast states.

        Returns:
            Cosine similarity of the slow and fast state vectors, or 0.0
            when either vector has (near-)zero norm.
        """
        slow = self.get_slow_state()

        slow_norm = np.linalg.norm(slow)
        fast_norm = np.linalg.norm(self.fast)

        if slow_norm < 1e-10 or fast_norm < 1e-10:
            return 0.0

        return float(np.dot(slow, self.fast) / (slow_norm * fast_norm))

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """
        Execute one step and return diagnostics.

        Args:
            action: Input vector of length state_dim.

        Returns:
            Dictionary with "slow_norm", "fast_norm", "coherence", "energy"
            and the half-life statistics of the oscillator bank. The same
            dictionary is appended to ``self.history``.
        """
        slow = self.forward_dynamics(action)
        coherence = self.get_coherence()
        stats = self.oscillators.get_half_life_statistics()

        result = {
            "slow_norm": float(np.linalg.norm(slow)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            "coherence": coherence,
            # Accumulated via np.linalg.norm, so cast to a plain float like
            # every other entry in this dict.
            "energy": float(self.energy),
            **stats
        }

        self.history.append(result)
        return result

    def reset(self):
        """Reset all state."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []


def demo_oscillators():
    """Demonstrate oscillator bank behavior."""
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)

    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0
    )

    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)

    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)

    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)

    # Pulse input at t=0
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)

    initial_norms = np.linalg.norm(bank.h, axis=1)

    # Let the pulse decay with zero input, reporting at each checkpoint
    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))

    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1

        current_norms = np.linalg.norm(bank.h, axis=1)
        retention = current_norms / (initial_norms + 1e-10)

        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus, retention)):
            # Retention after `step` steps is 0.5 ** (step / tau), which
            # drops below 50% exactly when the half-life is shorter than
            # the elapsed number of steps.
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
            if i >= 3:
                print(f" ... ({len(taus) - 4} more)")
                break

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)


if __name__ == "__main__":
    demo_oscillators()