|
|
""" |
|
|
FDRA Oscillator Implementation with Explicit Decay Parameters |
|
|
|
|
|
This implements the core FDRA oscillator dynamics where each oscillator has: |
|
|
- A decay parameter λ_i ∈ (0, 1) |
|
|
- Half-life τ_i = ln(0.5) / ln(λ_i) |
|
|
|
|
|
The key problem this addresses (from Melanie/Tiago's discovery): |
|
|
- During training at GPT-2 scale, the λ_i collapse to values well below 1.0 (e.g. λ ≈ 0.93 ⇒ τ ≈ 10), i.e. very short half-lives
|
|
- This means oscillators only attend to ~10 tokens instead of full context length |
|
|
- The model works for short-context tasks but fails on long-context reasoning |
|
|
|
|
|
Solution: Half-life regularization to maintain diversity across temporal scales. |
|
|
|
|
|
Authors: FDRA Half-Life Regularization Implementation |
|
|
Date: 2026-01-22 |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from typing import Dict, List, Tuple, Optional, Any |
|
|
from dataclasses import dataclass |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
@dataclass
class OscillatorConfig:
    """Configuration for FDRA oscillator bank.

    Attributes:
        num_oscillators: Number of oscillators in the bank (must be >= 1).
        state_dim: Dimension of each oscillator's state vector.
        sequence_length: Nominal context length; stored by the bank.
        tau_min: Smallest half-life in steps (must be > 0, since log(tau) is taken).
        tau_max: Largest half-life in steps (must be >= tau_min).
        init_method: "log_uniform" for log-spaced half-lives; any other
            value selects the bank's random fallback initialization.

    Raises:
        ValueError: If num_oscillators < 1, tau_min <= 0, or tau_max < tau_min.
    """

    num_oscillators: int = 32
    state_dim: int = 16
    sequence_length: int = 4096
    tau_min: float = 1.0
    tau_max: float = 4096.0
    init_method: str = "log_uniform"

    def __post_init__(self) -> None:
        # An empty bank cannot produce half-life statistics (np.min of empty array).
        if self.num_oscillators < 1:
            raise ValueError(
                f"num_oscillators must be >= 1, got {self.num_oscillators}"
            )
        # Half-lives are initialized via log(tau); tau <= 0 yields -inf/NaN.
        if not self.tau_min > 0:
            raise ValueError(f"tau_min must be > 0, got {self.tau_min}")
        if not self.tau_max >= self.tau_min:
            raise ValueError(
                f"tau_max ({self.tau_max}) must be >= tau_min ({self.tau_min})"
            )
|
|
|
|
|
|
|
|
@dataclass
class OscillatorState:
    """Snapshot of an oscillator bank: hidden states plus decay parameters."""

    h: np.ndarray
    lambdas: np.ndarray

    def copy(self) -> 'OscillatorState':
        """Return an independent snapshot; its arrays share no memory with ``self``."""
        h_snapshot = self.h.copy()
        lambda_snapshot = self.lambdas.copy()
        return OscillatorState(h_snapshot, lambda_snapshot)
|
|
|
|
|
|
|
|
class FDRAOscillatorBank:
    """
    FDRA Oscillator Bank with explicit decay parameters.

    Each oscillator i has:
        h_i(t+1) = λ_i * h_i(t) + u_i(t)

    Where:
        λ_i ∈ (0, 1) is the decay parameter
        τ_i = ln(0.5) / ln(λ_i) is the half-life

    Half-life interpretation:
        τ_i = number of steps for oscillator state to decay to 50%
        (equivalently λ_i ** τ_i == 0.5).

    The goal of half-life regularization:
        Maintain log-uniform distribution of τ_i across [τ_min, τ_max]
        This ensures oscillators can attend to both short and long contexts.

    Attributes:
        config: The configuration the bank was built from.
        n: Number of oscillators (length of ``lambdas``).
        d: Per-oscillator state dimension.
        L: ``config.sequence_length``; stored but not used by this class.
        lambdas: Decay parameters, shape (n,).
        h: Hidden states, shape (n, d).
        history: Diagnostics list; nothing in this class appends to it.
    """

    def __init__(self, config: "OscillatorConfig"):
        self.config = config
        self.n = config.num_oscillators
        self.d = config.state_dim
        self.L = config.sequence_length

        # One decay parameter per oscillator.
        self.lambdas = self._init_lambdas()

        # All oscillators start from a zero state vector.
        self.h = np.zeros((self.n, self.d))

        self.history: List[Dict[str, Any]] = []

    def _init_lambdas(self) -> np.ndarray:
        """
        Initialize decay parameters λ_i.

        With ``init_method == "log_uniform"`` the half-lives are placed on a
        deterministic grid that is evenly spaced in log space (no sampling):

            log(τ_i) = linspace(log(τ_min), log(τ_max), n)[i]
            λ_i = 0.5 ** (1 / τ_i)      (inverse of τ = ln(0.5) / ln(λ))

        so τ_0 = τ_min and τ_{n-1} = τ_max (a single oscillator gets τ_min).

        Any other ``init_method`` falls back to drawing λ_i uniformly from
        [0.5, 0.99) with NumPy's global random state; seed it via
        ``np.random.seed`` for reproducible runs.

        Returns:
            Array of shape (n,) holding the decay parameters.
        """
        if self.config.init_method == "log_uniform":
            log_tau_min = np.log(self.config.tau_min)
            log_tau_max = np.log(self.config.tau_max)

            log_taus = np.linspace(log_tau_min, log_tau_max, self.n)
            taus = np.exp(log_taus)

            lambdas = np.power(0.5, 1.0 / taus)
        else:
            lambdas = np.random.uniform(0.5, 0.99, self.n)

        return lambdas

    def get_half_lives(self) -> np.ndarray:
        """
        Compute half-lives from decay parameters.

            τ_i = ln(0.5) / ln(λ_i)

        Returns:
            Array of shape (n,) with the half-life of each oscillator, in steps.
        """
        # Clip keeps the log finite and non-zero: λ = 1 would divide by
        # ln(1) = 0, and λ <= 0 has no real logarithm.
        safe_lambdas = np.clip(self.lambdas, 1e-10, 1.0 - 1e-10)
        taus = np.log(0.5) / np.log(safe_lambdas)
        return taus

    def get_log_half_lives(self) -> np.ndarray:
        """Get log of half-lives: z_i = log(τ_i)."""
        return np.log(self.get_half_lives())

    def forward(self, u: np.ndarray) -> np.ndarray:
        """
        One step of oscillator dynamics.

            h_i(t+1) = λ_i * h_i(t) + u_i(t)

        Args:
            u: Input signal, shape (num_oscillators, state_dim) (or any shape
               that broadcasts against ``h``).

        Returns:
            A copy of the updated hidden states, shape (num_oscillators, state_dim).
        """
        # Column vector so each oscillator's λ scales its whole state row.
        lambdas_broadcast = self.lambdas[:, np.newaxis]

        self.h = lambdas_broadcast * self.h + u

        # Return a copy so callers cannot mutate the internal state.
        return self.h.copy()

    def reset(self):
        """Reset oscillator states to zero (decay parameters are kept)."""
        self.h = np.zeros((self.n, self.d))

    def get_half_life_statistics(self) -> Dict[str, float]:
        """
        Compute statistics of the half-life distribution.

        Returns:
            Dictionary with min/max/mean/median of τ, and mean/std/min/max of
            log(τ), all as Python floats.
        """
        taus = self.get_half_lives()
        z = np.log(taus)

        return {
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            "log_tau_mean": float(np.mean(z)),
            "log_tau_std": float(np.std(z)),
            "log_tau_min": float(np.min(z)),
            "log_tau_max": float(np.max(z)),
        }

    def get_state(self) -> "OscillatorState":
        """Get a copy of the current oscillator state."""
        return OscillatorState(
            h=self.h.copy(),
            lambdas=self.lambdas.copy()
        )

    def set_state(self, state: "OscillatorState"):
        """
        Set oscillator state from ``state`` (its arrays are copied).

        ``n`` and ``d`` are re-derived from the restored arrays so that a
        later ``reset()`` allocates a state buffer that matches ``lambdas``.
        """
        self.h = state.h.copy()
        self.lambdas = state.lambdas.copy()
        # Keep cached sizes in sync with the restored arrays.
        self.n = self.lambdas.shape[0]
        self.d = self.h.shape[-1]
|
|
|
|
|
|
|
|
class FDRAWithOscillators:
    """
    Full FDRA agent that uses an oscillator bank as its memory.

    On every step the action is scaled by 0.1 and fed to every oscillator
    of the bank. The "slow" state is the half-life-weighted average of the
    oscillator states. The "fast" state is a leaky (factor 0.9) running sum
    of raw actions. Coherence is the cosine similarity between the two.

    Attributes:
        config: Oscillator configuration (defaults to ``OscillatorConfig()``).
        oscillators: The underlying oscillator bank.
        wlc_threshold: Stored threshold; no method of this class reads it.
        fast: Fast state vector of length ``config.state_dim``.
        energy: Accumulated 0.1 * ||action|| over all steps.
        history: Per-step diagnostic dicts appended by ``step``.
    """

    def __init__(
        self,
        osc_config: Optional[OscillatorConfig] = None,
        wlc_threshold: float = 1.0
    ):
        cfg = osc_config or OscillatorConfig()
        self.config = cfg
        self.oscillators = FDRAOscillatorBank(cfg)
        self.wlc_threshold = wlc_threshold

        self.fast = np.zeros(cfg.state_dim)
        self.energy = 0.0
        self.history: List[Dict[str, Any]] = []

    def get_slow_state(self) -> np.ndarray:
        """
        Collapse the oscillator bank into a single slow-state vector.

        Each oscillator's state row is weighted by its share of the total
        half-life, so long-memory oscillators dominate the average.
        """
        half_lives = self.oscillators.get_half_lives()
        share = half_lives / np.sum(half_lives)
        return (self.oscillators.h * share[:, np.newaxis]).sum(axis=0)

    def forward_dynamics(self, action: np.ndarray) -> np.ndarray:
        """
        Advance the agent by one action.

        The action is copied to every oscillator and scaled by 0.1. The bank
        is stepped, and the resulting slow state is computed. Then the fast
        state and the energy counter are updated.

        Returns:
            The slow state computed after the oscillator update.
        """
        n_osc = self.config.num_oscillators
        # Per-oscillator float64 gain column (all 0.1), broadcast across state_dim.
        gain = np.full((n_osc, 1), 0.1)
        drive = np.tile(action, (n_osc, 1)) * gain

        self.oscillators.forward(drive)
        slow_state = self.get_slow_state()

        self.fast = 0.9 * self.fast + action
        self.energy += np.linalg.norm(action) * 0.1

        return slow_state

    def get_coherence(self) -> float:
        """Cosine similarity between slow and fast states (0.0 if either is ~zero)."""
        slow_vec = self.get_slow_state()
        norm_slow = np.linalg.norm(slow_vec)
        norm_fast = np.linalg.norm(self.fast)

        near_zero = norm_slow < 1e-10 or norm_fast < 1e-10
        if near_zero:
            return 0.0
        return float(np.dot(slow_vec, self.fast) / (norm_slow * norm_fast))

    def step(self, action: np.ndarray) -> Dict[str, Any]:
        """Run one step, record its diagnostics in ``history``, and return them."""
        slow_vec = self.forward_dynamics(action)

        record = {
            "slow_norm": float(np.linalg.norm(slow_vec)),
            "fast_norm": float(np.linalg.norm(self.fast)),
            "coherence": self.get_coherence(),
            "energy": self.energy,
            **self.oscillators.get_half_life_statistics(),
        }

        self.history.append(record)
        return record

    def reset(self):
        """Zero the bank, fast state and energy, and clear the history."""
        self.oscillators.reset()
        self.fast = np.zeros(self.config.state_dim)
        self.energy = 0.0
        self.history = []
|
|
|
|
|
|
|
|
def demo_oscillators():
    """Demonstrate oscillator bank behavior.

    Builds a 16-oscillator bank with half-lives log-spaced over [1, 4096] and
    prints its half-life distribution. It then injects one random input and
    lets the state decay under zero input. After several step counts it
    prints, for the first four oscillators, the fraction of the initial state
    norm that remains.

    Since h(t) = λ^t · h(0) and λ = 0.5^(1/τ), the retained fraction after
    t steps is 0.5^(t/τ), which is above 50% exactly when τ > t.
    """
    print("=" * 60)
    print("FDRA OSCILLATOR BANK DEMONSTRATION")
    print("=" * 60)

    config = OscillatorConfig(
        num_oscillators=16,
        state_dim=8,
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0
    )

    bank = FDRAOscillatorBank(config)

    print("\n1. Initial Half-Life Distribution")
    print("-" * 40)
    stats = bank.get_half_life_statistics()
    print(f" τ range: [{stats['tau_min']:.1f}, {stats['tau_max']:.1f}]")
    print(f" τ mean: {stats['tau_mean']:.1f}")
    print(f" log(τ) mean: {stats['log_tau_mean']:.3f}")
    print(f" log(τ) std: {stats['log_tau_std']:.3f}")

    print("\n2. Half-Lives per Oscillator")
    print("-" * 40)
    taus = bank.get_half_lives()
    for i, tau in enumerate(taus):
        bar = "█" * int(np.log(tau) * 3)
        print(f" Osc {i:2d}: τ = {tau:7.1f} steps {bar}")

    print("\n3. Simulating Input Sequence")
    print("-" * 40)

    # One impulse from the zero state: afterwards h == u, so these norms are
    # the baseline that retention is measured against.
    u = np.random.randn(config.num_oscillators, config.state_dim)
    bank.forward(u)
    initial_norms = np.linalg.norm(bank.h, axis=1)

    # Let the state decay with zero input and inspect it at these step counts.
    decay_steps = [10, 50, 100, 500, 1000]
    zero_input = np.zeros((config.num_oscillators, config.state_dim))

    step = 0
    for target in decay_steps:
        while step < target:
            bank.forward(zero_input)
            step += 1

        current_norms = np.linalg.norm(bank.h, axis=1)
        retention = current_norms / (initial_norms + 1e-10)

        print(f"\n After {step} steps:")
        for i, (tau, ret) in enumerate(zip(taus, retention)):
            # Retention after `step` steps is 0.5 ** (step / tau), which is
            # below 50% exactly when tau < step.
            if tau < step:
                expected = "✗ (should be < 50%)"
            else:
                expected = "✓ (should be > 50%)"
            print(f" Osc {i:2d}: τ={tau:7.1f}, retention={ret:.1%} {expected}")
            # Only the first four oscillators are listed per checkpoint.
            if i >= 3:
                print(f" ... ({len(taus) - 4} more)")
                break

    print("\n" + "=" * 60)
    print("OBSERVATION: Oscillators with τ > t retain more than 50% of signal")
    print("This is the desired behavior for long-context modeling.")
    print("=" * 60)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the demo.
    demo_oscillators()
|
|
|