# Source: fdra-half-life-regularization / code/fdra_oscillators.py
# Uploaded by juddddd with huggingface_hub (commit 76b087c, verified)
"""
FDRA Oscillator Implementation with Explicit Decay Parameters
This implements the core FDRA oscillator dynamics where each oscillator has:
- A decay parameter λ_i ∈ (0, 1)
- Half-life τ_i = ln(0.5) / ln(λ_i)
The key problem this addresses (from Melanie/Tiago's discovery):
- During training at GPT-2 scale, all λ_i collapse to a narrow band of very short half-lives
  (τ ≈ 10, i.e. λ = 0.5^(1/τ) ≈ 0.93; note that λ → 1.0 would instead mean τ → ∞)
- This means oscillators only attend to ~10 tokens instead of full context length
- The model works for short-context tasks but fails on long-context reasoning
Solution: Half-life regularization to maintain diversity across temporal scales.
Authors: FDRA Half-Life Regularization Implementation
Date: 2026-01-22
"""
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import json
from pathlib import Path
@dataclass
class OscillatorConfig:
    """Configuration for FDRA oscillator bank.

    The half-life bounds are fed through ``log`` and ``0.5 ** (1 / tau)`` when
    a bank is initialised, so they must be positive, finite and ordered.
    Invalid values are rejected here instead of surfacing later as NaN or
    degenerate decay parameters.

    Raises:
        ValueError: If ``num_oscillators`` or ``state_dim`` is smaller than 1,
            or if the bounds do not satisfy ``0 < tau_min <= tau_max < inf``.
    """

    num_oscillators: int = 32  # Number of oscillators
    state_dim: int = 16  # Dimension per oscillator
    sequence_length: int = 4096  # Max sequence length (L)
    tau_min: float = 1.0  # Minimum half-life (in steps)
    tau_max: float = 4096.0  # Maximum half-life (typically = L)
    # Initialization
    init_method: str = "log_uniform"  # "log_uniform" or "random"

    def __post_init__(self) -> None:
        """Validate sizes and half-life bounds right after construction."""
        if self.num_oscillators < 1:
            raise ValueError(
                f"num_oscillators must be >= 1, got {self.num_oscillators}"
            )
        if self.state_dim < 1:
            raise ValueError(f"state_dim must be >= 1, got {self.state_dim}")
        # The chained comparison is False for NaN, so NaN bounds are rejected too.
        if not (0 < self.tau_min <= self.tau_max < float("inf")):
            raise ValueError(
                "half-life bounds must satisfy 0 < tau_min <= tau_max < inf, "
                f"got tau_min={self.tau_min}, tau_max={self.tau_max}"
            )
@dataclass
class OscillatorState:
    """Snapshot of an oscillator bank: hidden states plus decay parameters."""

    h: np.ndarray  # Hidden states: (num_oscillators, state_dim)
    lambdas: np.ndarray  # Decay parameters: (num_oscillators,)

    def copy(self) -> 'OscillatorState':
        """Return a new state whose arrays are independent copies of this one."""
        h_clone = self.h.copy()
        lambdas_clone = self.lambdas.copy()
        return OscillatorState(h=h_clone, lambdas=lambdas_clone)
class FDRAOscillatorBank:
    """
    FDRA Oscillator Bank with explicit decay parameters.

    Each oscillator i has:
        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    Where:
        λ_i ∈ (0, 1) is the decay parameter
        τ_i = ln(0.5) / ln(λ_i) is the half-life

    Half-life interpretation:
        τ_i = number of steps for oscillator state to decay to 50%

    The goal of half-life regularization:
        Maintain log-uniform distribution of τ_i across [τ_min, τ_max]
        This ensures oscillators can attend to both short and long contexts.
    """

    def __init__(self, config: "OscillatorConfig"):
        """Build the bank from ``config`` with zeroed hidden states.

        Args:
            config: Object providing ``num_oscillators``, ``state_dim``,
                ``sequence_length``, ``tau_min``, ``tau_max`` and
                ``init_method``.
        """
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length
        # Initialize decay parameters
        self.lambdas = self._init_lambdas()
        # Initialize hidden states
        self.h = np.zeros((self.n, self.d))
        # Reserved for analysis; no method of this class writes to it.
        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Initialize decay parameters λ_i.

        For ``init_method == "log_uniform"`` the half-lives are evenly spaced
        in log space between τ_min and τ_max and converted with
        λ = 0.5^(1/τ) (the inverse of τ = ln(0.5) / ln(λ)).

        Any other value draws λ uniformly from [0.5, 0.99). A value other
        than ``"random"`` additionally emits a ``RuntimeWarning`` so that a
        misspelled method name does not silently select random
        initialization.

        Returns:
            Array of shape (num_oscillators,) with the decay parameters.
        """
        method = self.config.init_method
        if method == "log_uniform":
            # Evenly spaced in log space
            log_taus = np.linspace(
                np.log(self.config.tau_min),
                np.log(self.config.tau_max),
                self.n,
            )
            taus = np.exp(log_taus)
            # Convert half-lives to decay parameters:
            # λ = exp(ln(0.5) / τ) = 0.5^(1/τ)
            return np.power(0.5, 1.0 / taus)

        if method != "random":
            import warnings

            # stacklevel=3 attributes the warning to whoever built the bank.
            warnings.warn(
                f"Unknown init_method {method!r}; "
                "falling back to 'random' initialization",
                RuntimeWarning,
                stacklevel=3,
            )
        # Random initialization (not recommended)
        return np.random.uniform(0.5, 0.99, self.n)

    def get_half_lives(self) -> np.ndarray:
        """
        Compute half-lives from decay parameters.

        τ_i = ln(0.5) / ln(λ_i)

        λ is clipped to [1e-10, 1 - 1e-10] first so that neither log(0) nor
        a division by log(1) = 0 can occur.
        """
        safe_lambdas = np.clip(self.lambdas, 1e-10, 1.0 - 1e-10)
        return np.log(0.5) / np.log(safe_lambdas)

    def get_log_half_lives(self) -> np.ndarray:
        """Get log of half-lives: z_i = log(τ_i)."""
        return np.log(self.get_half_lives())

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        One step of oscillator dynamics.

        h_i(t+1) = λ_i * h_i(t) + u_i(t)

        Args:
            u: Input signal, shape (num_oscillators, state_dim), or anything
                that broadcasts to the shape of the hidden state.

        Returns:
            Copy of the updated hidden states, shape
            (num_oscillators, state_dim).

        Raises:
            ValueError: If ``u`` cannot be broadcast to the hidden-state
                shape. The hidden state is left untouched in that case.
        """
        # Validate before mutating: an input with extra leading dimensions
        # would otherwise broadcast the hidden state itself to a new shape.
        u = np.broadcast_to(u, self.h.shape)
        # Broadcast lambdas across state dimensions: (n,) -> (n, 1)
        self.h = self.lambdas[:, np.newaxis] * self.h + u
        return self.h.copy()

    def reset(self):
        """Reset oscillator states to zero."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Compute statistics of half-life distribution.

        Returns:
            Dictionary with min, max, mean and median of τ (``tau_*`` keys)
            and mean, std, min and max of log(τ) (``log_tau_*`` keys), all
            as Python floats.
        """
        taus = self.get_half_lives()
        z = np.log(taus)
        return {
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            "log_tau_mean": float(np.mean(z)),
            "log_tau_std": float(np.std(z)),
            "log_tau_min": float(np.min(z)),
            "log_tau_max": float(np.max(z)),
        }

    def get_state(self) -> "OscillatorState":
        """Get current oscillator state as independent array copies."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy(),
        )

    def set_state(self, state: "OscillatorState"):
        """Set oscillator state from ``state``, copying its arrays."""
        self.h = state.h.copy()
        self.lambdas = state.lambdas.copy()
class FDRAWithOscillators:
    """
    Full FDRA agent with oscillator bank for memory.

    This extends the basic FDRA agent to use an oscillator bank
    with explicit decay parameters that can be regularized.

    Architecture:
        Input → [Oscillator Bank] → Slow State → Output
                      ↑                 ↓
                  Fast State ←──────────────
    """

    # Factor applied to the action before it is fed to every oscillator.
    INPUT_SCALE = 0.1
    # Per-step decay of the fast (reactive) state.
    FAST_DECAY = 0.9
    # Energy added per unit of action norm.
    ENERGY_COST = 0.1

    def __init__(
        self,
        osc_config: Optional["OscillatorConfig"] = None,
        wlc_threshold: float = 1.0
    ):
        """Create the agent and its oscillator bank.

        Args:
            osc_config: Bank configuration; a default ``OscillatorConfig()``
                is used when omitted.
            wlc_threshold: Stored on the instance; no method of this class
                reads it.
        """
        self.config = OscillatorConfig() if osc_config is None else osc_config
        self.oscillators = FDRAOscillatorBank(self.config)
        self.wlc_threshold = wlc_threshold
        # Fast state (reactive, for computation)
        self.fast = np.zeros(self.config.state_dim)
        # Energy tracking
        self.energy = 0.0
        self.history: List[Dict[str, Any]] = []

    def get_slow_state(self) -> np.ndarray:
        """
        Aggregate slow state from oscillator bank.

        The slow state is a weighted sum of oscillator states,
        with weights proportional to half-life.

        Returns:
            Array of shape (state_dim,).
        """
        taus = self.oscillators.get_half_lives()
        weights = taus / np.sum(taus)  # Normalize so the weights sum to 1
        # Weighted sum across oscillators
        return np.sum(self.oscillators.h * weights[:, np.newaxis], axis=0)

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Forward dynamics with oscillator bank.

        1. Distribute action across oscillators
        2. Update oscillator bank
        3. Compute slow state from oscillators
        4. Update fast state

        Args:
            action: Input that broadcasts to shape (state_dim,).

        Returns:
            The slow state after the update, shape (state_dim,).

        Raises:
            ValueError: If ``action`` does not broadcast to (state_dim,).
                No state is modified in that case.
        """
        # Validate before touching any state: e.g. a (1, state_dim) action
        # tiles fine for the bank but would turn ``self.fast`` into a 2-D
        # array and only fail later, after the oscillators already advanced.
        np.broadcast_to(action, (self.config.state_dim,))

        # Distribute action to oscillators (same input, different decays),
        # scaled by a shared factor (could be learned per oscillator).
        u = np.tile(action, (self.config.num_oscillators, 1)) * self.INPUT_SCALE
        # Update oscillators
        self.oscillators.forward(u)
        # Get slow state from oscillators
        slow = self.get_slow_state()
        # Update fast state (reactive)
        self.fast = self.FAST_DECAY * self.fast + action
        # Keep the energy a plain float, like every other reported diagnostic.
        self.energy += float(np.linalg.norm(action)) * self.ENERGY_COST
        return slow

    def get_coherence(self) -> float:
        """Cosine similarity between slow and fast states.

        Returns 0.0 when either state has a norm below 1e-10.
        """
        slow = self.get_slow_state()
        slow_norm = np.linalg.norm(slow)
        fast_norm = np.linalg.norm(self.fast)
        if slow_norm < 1e-10 or fast_norm < 1e-10:
            return 0.0
        return float(np.dot(slow, self.fast) / (slow_norm * fast_norm))

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """Execute one step and return diagnostics.

        The returned dictionary (also appended to ``self.history``) contains
        ``slow_norm``, ``fast_norm``, ``coherence``, ``energy`` and the
        half-life statistics reported by the oscillator bank.
        """
        slow = self.forward_dynamics(action)
        coherence = self.get_coherence()
        stats = self.oscillators.get_half_life_statistics()
        result = {
            "slow_norm": float(np.linalg.norm(slow)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            "coherence": coherence,
            "energy": self.energy,
            **stats
        }
        self.history.append(result)
        return result

    def reset(self):
        """Reset all state."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []
def demo_oscillators():
    """Demonstrate oscillator bank behavior.

    Builds a 16-oscillator bank, prints its half-life distribution, injects
    one random pulse and then prints, for the first four oscillators, how
    much of that pulse is retained after 10, 50, 100, 500 and 1000
    zero-input steps.
    """
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)
    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0
    )
    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)
    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)
    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        # Bar length grows with log(τ) so the log-spaced bank looks linear.
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)
    # Pulse input at t=0
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)
    initial_norms = np.linalg.norm(bank.h, axis=1)

    # Let the pulse decay with zero input, reporting at each checkpoint
    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))
    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1
        current_norms = np.linalg.norm(bank.h, axis=1)
        # Small epsilon guards against a zero-norm pulse.
        retention = current_norms / (initial_norms + 1e-10)
        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus, retention)):
            # Retention after `step` steps is 0.5 ** (step / τ), so the 50%
            # boundary sits at τ == step (not at τ == step / 2).
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
            # Only show the first four oscillators to keep the output short.
            if i >= 3:
                print(f" ... ({len(taus) - 4} more)")
                break

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_oscillators()