|
|
""" |
|
|
Half-Life Regularizer for FDRA Oscillators |
|
|
|
|
|
This implements the exact mathematical regularizer from the Cursor instructions: |
|
|
|
|
|
## Regularizer 1: Log-Uniform Half-Life Prior (primary) |
|
|
|
|
|
Target distribution: p(τ) ∝ 1/τ for τ ∈ [τ_min, τ_max] |
|
|
This gives equal mass per temporal decade (log scale). |
|
|
|
|
|
Loss: |
|
|
z_i = log(τ_i) |
|
|
μ = mean(z_i) |
|
|
σ² = mean((z_i - μ)²) |
|
|
|
|
|
μ* = (log(τ_min) + log(τ_max)) / 2 |
|
|
σ²* = (log(τ_max) - log(τ_min))² / 12 |
|
|
|
|
|
L_HL = α*(μ - μ*)² + β*(σ² - σ²*)² |
|
|
|
|
|
## Regularizer 2: Long-Tail Survival Constraint (supporting) |
|
|
|
|
|
Ensure existence of long-range oscillators: |
|
|
s_i = σ(k * (τ_i - γ*L)) |
|
|
tail_mass = mean(s_i) |
|
|
L_tail = max(0, ρ - tail_mass)² |
|
|
|
|
|
Where: |
|
|
γ = 0.5 (fraction of full context) |
|
|
ρ = 0.05 (minimum fraction of oscillators) |
|
|
k = 10.0 (sigmoid sharpness) |
|
|
|
|
|
## Regularizer 3: Tau Bounds Constraint (CRITICAL FIX) |
|
|
|
|
|
The moment-matching loss (L_HL) can be satisfied by a pathological bimodal |
|
|
distribution with taus outside [tau_min, tau_max]. This creates oscillators |
|
|
that are either useless (tau << 1) or extreme (tau >> L). |
|
|
|
|
|
L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2) |
|
|
|
|
|
## Combined Loss |
|
|
|
|
|
L_total = L_task + λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds |
|
|
|
|
|
Authors: Half-Life Regularization Implementation |
|
|
Date: 2026-01-22 |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from typing import Dict, Tuple, Optional, Any |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
import json |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
@dataclass
class HalfLifeRegularizerConfig:
    """Configuration for half-life regularization.

    Groups the target half-life range, the per-term loss coefficients, and
    the lambda weights that mix the three regularizers (log-uniform prior,
    long-tail survival, tau bounds) into one scalar loss.
    """

    # Full context length L; the long-range threshold is gamma * L.
    sequence_length: int = 4096
    # Support [tau_min, tau_max] of the log-uniform half-life prior.
    tau_min: float = 1.0
    tau_max: float = 4096.0

    # Regularizer 1 (log-uniform prior): weights on the mean-matching (alpha)
    # and variance-matching (beta) terms of L_HL.
    alpha: float = 1.0
    beta: float = 1.0

    # Regularizer 2 (long-tail survival): gamma = fraction of the context
    # defining "long-range", rho = minimum soft fraction of long-range
    # oscillators, k = sigmoid sharpness of the soft count.
    gamma: float = 0.5
    rho: float = 0.05
    k: float = 10.0

    # Mixing weights for regularizers 1 and 2 in the total loss.
    lambda1: float = 0.01
    lambda2: float = 0.01

    # Regularizer 3 (tau bounds): mixing weight and penalty sharpness for
    # taus that escape [tau_min, tau_max].
    lambda3: float = 0.1
    bound_sharpness: float = 5.0
|
|
|
|
|
|
|
|
class HalfLifeRegularizer:
    """
    Half-Life Regularizer for FDRA Oscillator Banks.

    Prevents decay parameter collapse by regularizing the half-life
    distribution toward a log-uniform target over [tau_min, tau_max],
    supported by a long-tail survival constraint and a bounds penalty.

    Usage:
        config = HalfLifeRegularizerConfig()
        regularizer = HalfLifeRegularizer(config)

        # During training:
        lambdas = oscillator_bank.lambdas
        loss, metrics = regularizer.compute(lambdas)

        # Add to total loss:
        total_loss = task_loss + loss

        # Log metrics:
        log(metrics)
    """

    def __init__(self, config: "HalfLifeRegularizerConfig"):
        """
        Precompute the log-uniform target moments and long-range threshold.

        Args:
            config: Regularizer hyperparameters. Forward-referenced so this
                class compiles cleanly even when imported in isolation.
        """
        self.config = config

        # Under p(tau) ∝ 1/tau on [tau_min, tau_max], z = log(tau) is
        # Uniform(z_min, z_max); we match the first two moments of z.
        z_min = np.log(config.tau_min)
        z_max = np.log(config.tau_max)

        # Mean of Uniform(z_min, z_max): the midpoint of the support.
        self.mu_star = (z_min + z_max) / 2.0

        # Variance of Uniform(z_min, z_max): width^2 / 12.
        self.sigma2_star = (z_max - z_min) ** 2 / 12.0

        # Half-life above which an oscillator counts as "long-range".
        self.tau_threshold = config.gamma * config.sequence_length

    @staticmethod
    def _validate_lambdas(lambdas: np.ndarray) -> np.ndarray:
        """Coerce decay parameters to a float array; reject empty input.

        An empty array would otherwise propagate NaN (with RuntimeWarnings)
        through every moment computation, so fail loudly here instead.
        """
        arr = np.asarray(lambdas, dtype=float)
        if arr.size == 0:
            raise ValueError("lambdas must be a non-empty array")
        return arr

    def lambdas_to_half_lives(self, lambdas: np.ndarray) -> np.ndarray:
        """
        Convert decay parameters to half-lives.

        τ_i = ln(0.5) / ln(λ_i)

        Both logs are negative for λ ∈ (0, 1), so τ is positive.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            taus: Half-lives, shape (N,)

        Raises:
            ValueError: If `lambdas` is empty.
        """
        lambdas = self._validate_lambdas(lambdas)
        # Clip away 0 and 1 so both logarithms are finite and nonzero.
        safe_lambdas = np.clip(lambdas, 1e-10, 1.0 - 1e-10)
        return np.log(0.5) / np.log(safe_lambdas)

    def compute_log_uniform_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Log-Uniform Half-Life Prior loss (Regularizer 1).

        L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²

        where μ, σ² are the empirical mean/variance of log(τ) and μ*, σ²*
        are the moments of the log-uniform target.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        z = np.log(taus)

        # Empirical moments of log(tau); np.var is the population variance,
        # matching the σ² = mean((z - μ)²) definition.
        mu = np.mean(z)
        sigma2 = np.var(z)

        mean_loss = self.config.alpha * (mu - self.mu_star) ** 2
        var_loss = self.config.beta * (sigma2 - self.sigma2_star) ** 2

        loss = mean_loss + var_loss

        metrics = {
            "log_tau_mean": float(mu),
            "log_tau_var": float(sigma2),
            "log_tau_target_mean": float(self.mu_star),
            "log_tau_target_var": float(self.sigma2_star),
            "mean_deviation": float(abs(mu - self.mu_star)),
            "var_deviation": float(abs(sigma2 - self.sigma2_star)),
            "log_uniform_loss": float(loss),
        }

        return float(loss), metrics

    def compute_long_tail_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Long-Tail Survival Constraint loss (Regularizer 2).

        s_i = σ(k * (τ_i - γ*L))
        tail_mass = mean(s_i)
        L_tail = max(0, ρ - tail_mass)²

        The sigmoid gives a differentiable "soft count" of oscillators with
        half-lives beyond the long-range threshold γ*L.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)

        # Soft indicator of τ_i exceeding the threshold; clip the sigmoid
        # argument to avoid overflow in exp for extreme taus.
        x = self.config.k * (taus - self.tau_threshold)
        x = np.clip(x, -500, 500)
        s = 1.0 / (1.0 + np.exp(-x))

        tail_mass = np.mean(s)

        # Penalize only a deficit below the target fraction rho (hinge).
        deficit = max(0, self.config.rho - tail_mass)
        loss = deficit ** 2

        # Hard count for reporting (non-differentiable, diagnostics only).
        n_long_range = np.sum(taus > self.tau_threshold)
        frac_long_range = n_long_range / len(taus)

        metrics = {
            "tail_mass": float(tail_mass),
            "tail_target": float(self.config.rho),
            "tail_deficit": float(deficit),
            "n_long_range": int(n_long_range),
            "frac_long_range": float(frac_long_range),
            "tau_threshold": float(self.tau_threshold),
            "long_tail_loss": float(loss),
        }

        return float(loss), metrics

    def compute_bounds_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute tau bounds constraint loss (Regularizer 3).

        CRITICAL FIX: The moment-matching loss alone can be satisfied by
        a pathological bimodal distribution with taus outside
        [tau_min, tau_max]. This loss penalizes such escapees:

            L_bounds = mean((k * relu(tau_min - tau_i))^2)
                     + mean((k * relu(tau_i - tau_max))^2)

        with soft penalty sharpness k = config.bound_sharpness.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        k = self.config.bound_sharpness

        # Quadratic penalty for taus below the lower bound.
        below_min = np.maximum(0, self.config.tau_min - taus)
        lower_penalty = np.mean((k * below_min) ** 2)

        # Quadratic penalty for taus above the upper bound.
        above_max = np.maximum(0, taus - self.config.tau_max)
        upper_penalty = np.mean((k * above_max) ** 2)

        loss = lower_penalty + upper_penalty

        n_below_min = np.sum(taus < self.config.tau_min)
        n_above_max = np.sum(taus > self.config.tau_max)

        metrics = {
            "bounds_loss": float(loss),
            "lower_bound_penalty": float(lower_penalty),
            "upper_bound_penalty": float(upper_penalty),
            "n_below_tau_min": int(n_below_min),
            "n_above_tau_max": int(n_above_max),
            "frac_in_bounds": float(1 - (n_below_min + n_above_max) / len(taus)),
        }

        return float(loss), metrics

    def compute(self, lambdas: np.ndarray) -> Tuple[float, Dict[str, Any]]:
        """
        Compute total half-life regularization loss.

        L_total = λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds

        CRITICAL: L_bounds prevents the pathological case where
        moment-matching is satisfied by a bimodal distribution with taus
        outside [tau_min, tau_max].

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Total regularization loss
            metrics: All component metrics (flattened into one dict)

        Raises:
            ValueError: If `lambdas` is empty.
        """
        log_uniform_loss, log_uniform_metrics = self.compute_log_uniform_loss(lambdas)
        long_tail_loss, long_tail_metrics = self.compute_long_tail_loss(lambdas)
        bounds_loss, bounds_metrics = self.compute_bounds_loss(lambdas)

        total_loss = (
            self.config.lambda1 * log_uniform_loss +
            self.config.lambda2 * long_tail_loss +
            self.config.lambda3 * bounds_loss
        )

        # Summary tau statistics for logging.
        taus = self.lambdas_to_half_lives(lambdas)

        metrics = {
            "total_regularization_loss": float(total_loss),
            "log_uniform_component": float(self.config.lambda1 * log_uniform_loss),
            "long_tail_component": float(self.config.lambda2 * long_tail_loss),
            "bounds_component": float(self.config.lambda3 * bounds_loss),
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            **log_uniform_metrics,
            **long_tail_metrics,
            **bounds_metrics,
        }

        return float(total_loss), metrics

    def compute_gradient(
        self,
        lambdas: np.ndarray,
        epsilon: float = 1e-5
    ) -> np.ndarray:
        """
        Compute gradient of regularization loss w.r.t. lambdas.

        Uses central finite differences for simplicity.
        In a real implementation, this would use autodiff.

        NOTE(review): for λ_i within epsilon of the (0, 1) clip boundaries
        the perturbed value saturates inside lambdas_to_half_lives, which
        flattens the estimated gradient there.

        Args:
            lambdas: Decay parameters, shape (N,)
            epsilon: Perturbation size

        Returns:
            grad: Gradient, shape (N,)
        """
        grad = np.zeros_like(lambdas)

        for i in range(len(lambdas)):
            # Forward perturbation.
            lambdas_plus = lambdas.copy()
            lambdas_plus[i] += epsilon
            loss_plus, _ = self.compute(lambdas_plus)

            # Backward perturbation.
            lambdas_minus = lambdas.copy()
            lambdas_minus[i] -= epsilon
            loss_minus, _ = self.compute(lambdas_minus)

            # Central difference.
            grad[i] = (loss_plus - loss_minus) / (2 * epsilon)

        return grad

    def diagnose(self, lambdas: np.ndarray) -> str:
        """
        Generate diagnostic string for current half-life distribution.

        Covers all three regularizers (including the bounds constraint,
        previously missing from this report) plus a log-scale histogram.

        Args:
            lambdas: Decay parameters

        Returns:
            Diagnostic string
        """
        loss, metrics = self.compute(lambdas)
        taus = self.lambdas_to_half_lives(lambdas)

        lines = [
            "=" * 60,
            "HALF-LIFE REGULARIZER DIAGNOSTICS",
            "=" * 60,
            "",
            "Current Distribution:",
            f"  τ range: [{metrics['tau_min']:.1f}, {metrics['tau_max']:.1f}]",
            f"  τ mean: {metrics['tau_mean']:.1f}",
            f"  τ median: {metrics['tau_median']:.1f}",
            "",
            "Target Distribution:",
            f"  τ range: [{self.config.tau_min}, {self.config.tau_max}]",
            f"  log(τ) target mean: {self.mu_star:.3f}",
            f"  log(τ) target var: {self.sigma2_star:.3f}",
            "",
            "Log-Uniform Prior:",
            f"  log(τ) mean: {metrics['log_tau_mean']:.3f} (target: {metrics['log_tau_target_mean']:.3f})",
            f"  log(τ) var: {metrics['log_tau_var']:.3f} (target: {metrics['log_tau_target_var']:.3f})",
            f"  Mean deviation: {metrics['mean_deviation']:.3f}",
            f"  Var deviation: {metrics['var_deviation']:.3f}",
            f"  Loss: {metrics['log_uniform_loss']:.6f}",
            "",
            "Long-Tail Survival:",
            f"  Threshold: τ > {metrics['tau_threshold']:.1f}",
            f"  Long-range count: {metrics['n_long_range']}/{len(lambdas)} ({metrics['frac_long_range']:.1%})",
            f"  Tail mass (soft): {metrics['tail_mass']:.3f} (target: {metrics['tail_target']:.3f})",
            f"  Loss: {metrics['long_tail_loss']:.6f}",
            "",
            # FIX: report Regularizer 3 (bounds) — it was computed but
            # never shown, hiding out-of-range tau pathologies.
            "Tau Bounds:",
            f"  In-bounds fraction: {metrics['frac_in_bounds']:.1%}",
            f"  Below τ_min: {metrics['n_below_tau_min']}, above τ_max: {metrics['n_above_tau_max']}",
            f"  Loss: {metrics['bounds_loss']:.6f}",
            "",
            "Total Regularization Loss:",
            f"  Log-uniform component: {metrics['log_uniform_component']:.6f}",
            f"  Long-tail component: {metrics['long_tail_component']:.6f}",
            f"  Bounds component: {metrics['bounds_component']:.6f}",
            f"  Total: {metrics['total_regularization_loss']:.6f}",
            "",
        ]

        lines.append("Half-Life Histogram (log scale):")
        # FIX: lower edge follows config.tau_min (was hard-coded 10^0 = 1);
        # identical bins at the default tau_min of 1.0.
        bins = np.logspace(
            np.log10(self.config.tau_min), np.log10(self.config.tau_max), 11
        )
        hist, _ = np.histogram(taus, bins=bins)
        for i, count in enumerate(hist):
            bar = "█" * count
            lines.append(f"  [{bins[i]:7.1f}, {bins[i+1]:7.1f}): {count:2d} {bar}")

        lines.append("")
        lines.append("=" * 60)

        return "\n".join(lines)
|
|
|
|
|
|
|
|
def simulate_collapse_and_recovery(seed: int = 0):
    """
    Simulate the half-life collapse problem and demonstrate regularization.

    This shows:
        1. Initial log-uniform distribution (good)
        2. Simulated collapse to short half-lives (bad, mimics training at scale)
        3. Regularization gradient direction (recovery)

    Args:
        seed: RNG seed for the simulated collapsed distribution, so the
            demonstration is reproducible run-to-run (previously unseeded).

    Returns:
        Dict with "initial", "collapsed", and "regularized" entries, each
        holding {"loss": float, "metrics": dict}.
    """
    print("=" * 70)
    print("HALF-LIFE COLLAPSE AND REGULARIZATION DEMONSTRATION")
    print("=" * 70)

    config = HalfLifeRegularizerConfig(
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
        lambda1=0.01,
        lambda2=0.01
    )

    regularizer = HalfLifeRegularizer(config)

    # FIX: seeded generator — np.random.uniform without a seed made every
    # run of this demo produce different numbers.
    rng = np.random.default_rng(seed)

    # --- 1. Healthy starting point: log-uniform half-lives. ---
    print("\n1. INITIAL DISTRIBUTION (Log-Uniform)")
    print("-" * 60)

    n_oscillators = 32
    log_taus_init = np.linspace(np.log(1.0), np.log(4096.0), n_oscillators)
    taus_init = np.exp(log_taus_init)
    # Invert τ = ln(0.5)/ln(λ): λ = 0.5^(1/τ).
    lambdas_init = np.power(0.5, 1.0 / taus_init)

    loss_init, metrics_init = regularizer.compute(lambdas_init)
    print(f"  Half-lives: [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_init:.6f}")
    print(f"  Long-range oscillators: {metrics_init['n_long_range']}/{n_oscillators}")

    # --- 2. Simulated collapse: all half-lives crammed into [2, 10]. ---
    print("\n2. COLLAPSED DISTRIBUTION (Training at Scale)")
    print("-" * 60)
    print("  Simulating what happens during GPT-2 scale training...")

    taus_collapsed = rng.uniform(2, 10, n_oscillators)
    lambdas_collapsed = np.power(0.5, 1.0 / taus_collapsed)

    loss_collapsed, metrics_collapsed = regularizer.compute(lambdas_collapsed)
    print(f"  Half-lives: [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_collapsed:.6f} ({loss_collapsed/loss_init:.0f}x initial)")
    print(f"  Long-range oscillators: {metrics_collapsed['n_long_range']}/{n_oscillators}")

    # --- 3. Gradient of the regularizer at the collapsed point. ---
    print("\n3. REGULARIZATION GRADIENT ANALYSIS")
    print("-" * 60)

    grad = regularizer.compute_gradient(lambdas_collapsed)

    print("  Gradient direction indicates how to adjust λ_i to reduce loss:")
    print("  (Negative gradient → increase λ → longer half-life)")
    print()

    for i in range(min(5, n_oscillators)):
        tau_i = taus_collapsed[i]
        grad_i = grad[i]
        direction = "→ increase τ" if grad_i < 0 else "→ decrease τ"
        print(f"  Osc {i}: τ={tau_i:.1f}, grad={grad_i:+.4f} {direction}")

    print(f"  ... ({n_oscillators - 5} more)")
    print(f"\n  Mean gradient magnitude: {np.mean(np.abs(grad)):.4f}")

    # --- 4. One gradient-descent step on the regularizer alone. ---
    print("\n4. AFTER REGULARIZATION STEP")
    print("-" * 60)

    lr = 1.0
    lambdas_reg = lambdas_collapsed - lr * grad
    # Keep λ in a numerically safe (0, 1) range after the step.
    lambdas_reg = np.clip(lambdas_reg, 0.01, 0.9999)

    loss_reg, metrics_reg = regularizer.compute(lambdas_reg)

    print(f"  Half-lives: [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_reg:.6f} ({loss_reg/loss_collapsed:.1%} of collapsed)")
    print(f"  Long-range oscillators: {metrics_reg['n_long_range']}/{n_oscillators}")

    # --- 5. Side-by-side summary table. ---
    print("\n5. SUMMARY")
    print("-" * 60)
    print(f"""
    State              | Loss       | τ range             | Long-range
    -------------------|------------|---------------------|------------
    Initial (good)     | {loss_init:.6f}   | [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]       | {metrics_init['n_long_range']}/{n_oscillators}
    Collapsed (bad)    | {loss_collapsed:.6f}   | [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]          | {metrics_collapsed['n_long_range']}/{n_oscillators}
    After 1 reg step   | {loss_reg:.6f}   | [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]          | {metrics_reg['n_long_range']}/{n_oscillators}
    """)

    print("=" * 70)
    print("CONCLUSION:")
    print("  The regularizer provides gradients that push collapsed half-lives")
    print("  back toward a log-uniform distribution spanning the full context.")
    print("=" * 70)

    return {
        "initial": {"loss": loss_init, "metrics": metrics_init},
        "collapsed": {"loss": loss_collapsed, "metrics": metrics_collapsed},
        "regularized": {"loss": loss_reg, "metrics": metrics_reg},
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the collapse-and-recovery demonstration when executed as a script.
    simulate_collapse_and_recovery()
|
|
|