# fdra-half-life-regularization / code / half_life_regularizer.py
# Uploaded with huggingface_hub (revision 4e48d38, user "juddddd")
"""
Half-Life Regularizer for FDRA Oscillators
This implements the exact mathematical regularizer from the design specification:
## Regularizer 1: Log-Uniform Half-Life Prior (primary)
Target distribution: p(τ) ∝ 1/τ for τ ∈ [τ_min, τ_max]
This gives equal mass per temporal decade (log scale).
Loss:
z_i = log(τ_i)
μ = mean(z_i)
σ² = mean((z_i - μ)²)
μ* = (log(τ_min) + log(τ_max)) / 2
σ²* = (log(τ_max) - log(τ_min))² / 12
L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²
## Regularizer 2: Long-Tail Survival Constraint (supporting)
Ensure existence of long-range oscillators:
s_i = σ(k * (τ_i - γ*L))
tail_mass = mean(s_i)
L_tail = max(0, ρ - tail_mass)²
Where:
γ = 0.5 (fraction of full context)
ρ = 0.05 (minimum fraction of oscillators)
k = 10.0 (sigmoid sharpness)
## Regularizer 3: Tau Bounds Constraint (CRITICAL FIX)
The moment-matching loss (L_HL) can be satisfied by a pathological bimodal
distribution with taus outside [tau_min, tau_max]. This creates oscillators
that are either useless (tau << 1) or extreme (tau >> L).
L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2)
## Combined Loss
L_total = L_task + λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds
Authors: Half-Life Regularization Implementation
Date: 2026-01-22
"""
import numpy as np
from typing import Dict, Tuple, Optional, Any
from dataclasses import dataclass
from pathlib import Path
import json
from datetime import datetime
@dataclass
class HalfLifeRegularizerConfig:
    """Configuration for half-life regularization.

    Groups the target half-life range, the per-regularizer coefficients,
    and the overall loss weights (lambda1..lambda3) used by
    ``HalfLifeRegularizer``.

    Raises:
        ValueError: On construction, if the tau range or sequence length
            is not a sensible, positive, properly-ordered configuration
            (invalid values would otherwise produce NaN target moments
            downstream via ``np.log``).
    """

    # Task parameters
    sequence_length: int = 4096   # L - max sequence length
    tau_min: float = 1.0          # Minimum target half-life
    tau_max: float = 4096.0       # Maximum target half-life (= L)

    # Log-Uniform Prior coefficients
    alpha: float = 1.0            # Weight for mean constraint
    beta: float = 1.0             # Weight for variance constraint

    # Long-Tail Survival coefficients
    gamma: float = 0.5            # Fraction of full context for long-range
    rho: float = 0.05             # Minimum fraction of long-range oscillators
    k: float = 10.0               # Sigmoid sharpness

    # Overall loss weights
    lambda1: float = 0.01         # Weight for L_HL in total loss
    lambda2: float = 0.01         # Weight for L_tail in total loss

    # Tau bound constraint (prevents pathological distributions)
    lambda3: float = 0.1          # Weight for L_bounds
    bound_sharpness: float = 5.0  # Sharpness of soft bound penalties

    def __post_init__(self) -> None:
        """Validate the tau range so log-space targets are well defined."""
        if self.sequence_length <= 0:
            raise ValueError(
                f"sequence_length must be positive, got {self.sequence_length}"
            )
        if self.tau_min <= 0:
            # log(tau_min) must be finite.
            raise ValueError(f"tau_min must be positive, got {self.tau_min}")
        if self.tau_max <= self.tau_min:
            raise ValueError(
                f"tau_max ({self.tau_max}) must exceed tau_min ({self.tau_min})"
            )
class HalfLifeRegularizer:
    """
    Half-Life Regularizer for FDRA Oscillator Banks.

    Prevents decay parameter collapse by regularizing the half-life
    distribution toward a log-uniform target on [tau_min, tau_max],
    keeping a minimum fraction of long-range oscillators alive, and
    penalizing half-lives outside the target bounds.

    Usage:
        config = HalfLifeRegularizerConfig()
        regularizer = HalfLifeRegularizer(config)
        # During training:
        lambdas = oscillator_bank.lambdas
        loss, metrics = regularizer.compute(lambdas)
        # Add to total loss:
        total_loss = task_loss + loss
        # Log metrics:
        log(metrics)
    """

    def __init__(self, config: "HalfLifeRegularizerConfig"):
        """
        Args:
            config: Hyper-parameters (tau range, coefficients, loss weights).
        """
        self.config = config
        # Pre-compute target statistics of z = log(tau) under the log-uniform
        # prior, i.e. z ~ Uniform[log(tau_min), log(tau_max)].
        z_min = np.log(config.tau_min)
        z_max = np.log(config.tau_max)
        # Target mean in log space: midpoint of the uniform interval.
        self.mu_star = (z_min + z_max) / 2.0
        # Target variance in log space: (range)^2 / 12 for a uniform.
        self.sigma2_star = (z_max - z_min) ** 2 / 12.0
        # Oscillators with tau above gamma * L count as "long range".
        self.tau_threshold = config.gamma * config.sequence_length

    def lambdas_to_half_lives(self, lambdas: np.ndarray) -> np.ndarray:
        """
        Convert decay parameters to half-lives.

        τ_i = ln(0.5) / ln(λ_i)

        Args:
            lambdas: Decay parameters in (0, 1), shape (N,). Any array-like
                (e.g. a list) is accepted and converted.

        Returns:
            taus: Half-lives, shape (N,)

        Raises:
            ValueError: If `lambdas` is empty — downstream statistics
                (mean, variance, fractions) would be undefined.
        """
        lam = np.asarray(lambdas, dtype=float)
        if lam.size == 0:
            raise ValueError("lambdas must be a non-empty array")
        # Clamp away from {0, 1} so both logs are finite and the
        # denominator is nonzero.
        safe_lambdas = np.clip(lam, 1e-10, 1.0 - 1e-10)
        taus = np.log(0.5) / np.log(safe_lambdas)
        return taus

    def compute_log_uniform_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Log-Uniform Half-Life Prior loss.

        L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²

        Matches the first two moments of log(tau) against the moments of a
        uniform distribution on [log(tau_min), log(tau_max)].

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        # Current log-half-life statistics.
        taus = self.lambdas_to_half_lives(lambdas)
        z = np.log(taus)
        mu = np.mean(z)
        sigma2 = np.var(z)
        # Quadratic penalties on deviation from the target moments.
        mean_loss = self.config.alpha * (mu - self.mu_star) ** 2
        var_loss = self.config.beta * (sigma2 - self.sigma2_star) ** 2
        loss = mean_loss + var_loss
        metrics = {
            "log_tau_mean": float(mu),
            "log_tau_var": float(sigma2),
            "log_tau_target_mean": float(self.mu_star),
            "log_tau_target_var": float(self.sigma2_star),
            "mean_deviation": float(abs(mu - self.mu_star)),
            "var_deviation": float(abs(sigma2 - self.sigma2_star)),
            "log_uniform_loss": float(loss),
        }
        return float(loss), metrics

    def compute_long_tail_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Long-Tail Survival Constraint loss.

        s_i = σ(k * (τ_i - γ*L))
        tail_mass = mean(s_i)
        L_tail = max(0, ρ - tail_mass)²

        Penalizes the bank only when the (soft) fraction of long-range
        oscillators falls below rho.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        # Soft indicator: s_i ≈ 1 if τ_i > threshold, ≈ 0 otherwise.
        x = self.config.k * (taus - self.tau_threshold)
        x = np.clip(x, -500, 500)  # Prevent exp overflow
        s = 1.0 / (1.0 + np.exp(-x))
        # Soft fraction of oscillators in the long-tail regime.
        tail_mass = np.mean(s)
        # One-sided penalty: only a deficit (tail_mass < rho) is penalized.
        deficit = max(0, self.config.rho - tail_mass)
        loss = deficit ** 2
        # Hard-threshold count, for logging only (not differentiable).
        n_long_range = np.sum(taus > self.tau_threshold)
        frac_long_range = n_long_range / taus.size
        metrics = {
            "tail_mass": float(tail_mass),
            "tail_target": float(self.config.rho),
            "tail_deficit": float(deficit),
            "n_long_range": int(n_long_range),
            "frac_long_range": float(frac_long_range),
            "tau_threshold": float(self.tau_threshold),
            "long_tail_loss": float(loss),
        }
        return float(loss), metrics

    def compute_bounds_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute tau bounds constraint loss.

        CRITICAL FIX: The moment-matching loss alone can be satisfied by
        a pathological bimodal distribution with taus outside
        [tau_min, tau_max].

        L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2)

        Uses a soft quadratic penalty scaled by ``bound_sharpness``.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        k = self.config.bound_sharpness
        # Soft lower bound: penalize tau < tau_min.
        below_min = np.maximum(0, self.config.tau_min - taus)
        lower_penalty = np.mean((k * below_min) ** 2)
        # Soft upper bound: penalize tau > tau_max.
        above_max = np.maximum(0, taus - self.config.tau_max)
        upper_penalty = np.mean((k * above_max) ** 2)
        loss = lower_penalty + upper_penalty
        # Hard counts, for logging only.
        n_below_min = np.sum(taus < self.config.tau_min)
        n_above_max = np.sum(taus > self.config.tau_max)
        metrics = {
            "bounds_loss": float(loss),
            "lower_bound_penalty": float(lower_penalty),
            "upper_bound_penalty": float(upper_penalty),
            "n_below_tau_min": int(n_below_min),
            "n_above_tau_max": int(n_above_max),
            "frac_in_bounds": float(1 - (n_below_min + n_above_max) / taus.size),
        }
        return float(loss), metrics

    def compute(self, lambdas: np.ndarray) -> Tuple[float, Dict[str, Any]]:
        """
        Compute total half-life regularization loss.

        L_total = λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds

        CRITICAL: L_bounds prevents the pathological case where
        moment-matching is satisfied by a bimodal distribution with taus
        outside [tau_min, tau_max].

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Total regularization loss
            metrics: All component metrics (keys from every sub-loss, plus
                summary statistics of the tau distribution)
        """
        log_uniform_loss, log_uniform_metrics = self.compute_log_uniform_loss(lambdas)
        long_tail_loss, long_tail_metrics = self.compute_long_tail_loss(lambdas)
        bounds_loss, bounds_metrics = self.compute_bounds_loss(lambdas)
        # Weighted combination (bounds loss is CRITICAL).
        total_loss = (
            self.config.lambda1 * log_uniform_loss +
            self.config.lambda2 * long_tail_loss +
            self.config.lambda3 * bounds_loss
        )
        # Summary statistics of the half-life distribution, for logging.
        taus = self.lambdas_to_half_lives(lambdas)
        metrics = {
            "total_regularization_loss": float(total_loss),
            "log_uniform_component": float(self.config.lambda1 * log_uniform_loss),
            "long_tail_component": float(self.config.lambda2 * long_tail_loss),
            "bounds_component": float(self.config.lambda3 * bounds_loss),
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            **log_uniform_metrics,
            **long_tail_metrics,
            **bounds_metrics,
        }
        return float(total_loss), metrics

    def compute_gradient(
        self,
        lambdas: np.ndarray,
        epsilon: float = 1e-5
    ) -> np.ndarray:
        """
        Compute gradient of regularization loss w.r.t. lambdas.

        Uses central finite differences for simplicity; a real training
        implementation would use autodiff. O(N) loss evaluations — intended
        for diagnostics, not the training hot path.

        Args:
            lambdas: Decay parameters, shape (N,)
            epsilon: Perturbation size

        Returns:
            grad: Gradient, shape (N,)
        """
        lam = np.asarray(lambdas, dtype=float)
        grad = np.zeros_like(lam)
        for i in range(lam.size):
            # Positive perturbation.
            lambdas_plus = lam.copy()
            lambdas_plus[i] += epsilon
            loss_plus, _ = self.compute(lambdas_plus)
            # Negative perturbation.
            lambdas_minus = lam.copy()
            lambdas_minus[i] -= epsilon
            loss_minus, _ = self.compute(lambdas_minus)
            # Central difference.
            grad[i] = (loss_plus - loss_minus) / (2 * epsilon)
        return grad

    def diagnose(self, lambdas: np.ndarray) -> str:
        """
        Generate diagnostic string for current half-life distribution.

        Includes all three loss components (log-uniform, long-tail, bounds)
        so the printed components decompose the printed total, plus a
        log-scale histogram of the half-lives spanning [tau_min, tau_max].

        Args:
            lambdas: Decay parameters

        Returns:
            Diagnostic string
        """
        _, metrics = self.compute(lambdas)
        taus = self.lambdas_to_half_lives(lambdas)
        lines = [
            "=" * 60,
            "HALF-LIFE REGULARIZER DIAGNOSTICS",
            "=" * 60,
            "",
            "Current Distribution:",
            f"  τ range: [{metrics['tau_min']:.1f}, {metrics['tau_max']:.1f}]",
            f"  τ mean: {metrics['tau_mean']:.1f}",
            f"  τ median: {metrics['tau_median']:.1f}",
            "",
            "Target Distribution:",
            f"  τ range: [{self.config.tau_min}, {self.config.tau_max}]",
            f"  log(τ) target mean: {self.mu_star:.3f}",
            f"  log(τ) target var: {self.sigma2_star:.3f}",
            "",
            "Log-Uniform Prior:",
            f"  log(τ) mean: {metrics['log_tau_mean']:.3f} (target: {metrics['log_tau_target_mean']:.3f})",
            f"  log(τ) var: {metrics['log_tau_var']:.3f} (target: {metrics['log_tau_target_var']:.3f})",
            f"  Mean deviation: {metrics['mean_deviation']:.3f}",
            f"  Var deviation: {metrics['var_deviation']:.3f}",
            f"  Loss: {metrics['log_uniform_loss']:.6f}",
            "",
            "Long-Tail Survival:",
            f"  Threshold: τ > {metrics['tau_threshold']:.1f}",
            f"  Long-range count: {metrics['n_long_range']}/{len(lambdas)} ({metrics['frac_long_range']:.1%})",
            f"  Tail mass (soft): {metrics['tail_mass']:.3f} (target: {metrics['tail_target']:.3f})",
            f"  Loss: {metrics['long_tail_loss']:.6f}",
            "",
            "Tau Bounds:",
            f"  In-bounds fraction: {metrics['frac_in_bounds']:.1%}",
            f"  Below τ_min: {metrics['n_below_tau_min']}, above τ_max: {metrics['n_above_tau_max']}",
            f"  Loss: {metrics['bounds_loss']:.6f}",
            "",
            "Total Regularization Loss:",
            f"  Log-uniform component: {metrics['log_uniform_component']:.6f}",
            f"  Long-tail component: {metrics['long_tail_component']:.6f}",
            f"  Bounds component: {metrics['bounds_component']:.6f}",
            f"  Total: {metrics['total_regularization_loss']:.6f}",
            "",
        ]
        # Half-life histogram over the configured target range (log scale).
        # FIX: lower edge follows tau_min rather than a hardcoded 10^0, so
        # taus below 1 are not silently dropped when tau_min < 1.
        lines.append("Half-Life Histogram (log scale):")
        bins = np.logspace(
            np.log10(self.config.tau_min), np.log10(self.config.tau_max), 11
        )
        hist, _ = np.histogram(taus, bins=bins)
        for i, count in enumerate(hist):
            bar = "█" * int(count)
            lines.append(f"  [{bins[i]:7.1f}, {bins[i+1]:7.1f}): {count:2d} {bar}")
        lines.append("")
        lines.append("=" * 60)
        return "\n".join(lines)
def simulate_collapse_and_recovery():
    """
    Simulate the half-life collapse problem and demonstrate regularization.

    This shows:
    1. Initial log-uniform distribution (good)
    2. Simulated collapse to short half-lives (bad, mimics training at scale)
    3. Regularization gradient direction (recovery)

    Returns:
        Dict with "initial", "collapsed" and "regularized" entries, each
        containing the scalar "loss" and the full "metrics" dict.
    """
    print("=" * 70)
    print("HALF-LIFE COLLAPSE AND REGULARIZATION DEMONSTRATION")
    print("=" * 70)
    config = HalfLifeRegularizerConfig(
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
        lambda1=0.01,
        lambda2=0.01
    )
    regularizer = HalfLifeRegularizer(config)

    # --- Initial Distribution (good) ---
    print("\n1. INITIAL DISTRIBUTION (Log-Uniform)")
    print("-" * 60)
    n_oscillators = 32
    # Evenly spaced in log space => log-uniform half-lives.
    log_taus_init = np.linspace(np.log(1.0), np.log(4096.0), n_oscillators)
    taus_init = np.exp(log_taus_init)
    # Invert tau = ln(0.5)/ln(lambda):  lambda = 0.5^(1/tau).
    lambdas_init = np.power(0.5, 1.0 / taus_init)
    loss_init, metrics_init = regularizer.compute(lambdas_init)
    print(f"  Half-lives: [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_init:.6f}")
    print(f"  Long-range oscillators: {metrics_init['n_long_range']}/{n_oscillators}")

    # --- Collapsed Distribution (bad) ---
    print("\n2. COLLAPSED DISTRIBUTION (Training at Scale)")
    print("-" * 60)
    print("  Simulating what happens during GPT-2 scale training...")
    # FIX: seeded generator so the demonstration is reproducible run-to-run.
    rng = np.random.default_rng(0)
    # All half-lives collapse to < 10 steps.
    taus_collapsed = rng.uniform(2, 10, n_oscillators)
    lambdas_collapsed = np.power(0.5, 1.0 / taus_collapsed)
    loss_collapsed, metrics_collapsed = regularizer.compute(lambdas_collapsed)
    # FIX: guard the ratio against a (theoretically) zero initial loss.
    ratio = loss_collapsed / loss_init if loss_init > 0 else float("inf")
    print(f"  Half-lives: [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_collapsed:.6f} ({ratio:.0f}x initial)")
    print(f"  Long-range oscillators: {metrics_collapsed['n_long_range']}/{n_oscillators}")

    # --- Regularization Gradient ---
    print("\n3. REGULARIZATION GRADIENT ANALYSIS")
    print("-" * 60)
    grad = regularizer.compute_gradient(lambdas_collapsed)
    print("  Gradient direction indicates how to adjust λ_i to reduce loss:")
    print("  (Negative gradient → increase λ → longer half-life)")
    print()
    # Show gradient for the first few oscillators only.
    for i in range(min(5, n_oscillators)):
        tau_i = taus_collapsed[i]
        grad_i = grad[i]
        direction = "→ increase τ" if grad_i < 0 else "→ decrease τ"
        print(f"  Osc {i}: τ={tau_i:.1f}, grad={grad_i:+.4f} {direction}")
    print(f"  ... ({n_oscillators - 5} more)")
    print(f"\n  Mean gradient magnitude: {np.mean(np.abs(grad)):.4f}")

    # --- After One Regularization Step ---
    print("\n4. AFTER REGULARIZATION STEP")
    print("-" * 60)
    lr = 1.0  # Learning rate for the illustrative gradient step
    lambdas_reg = lambdas_collapsed - lr * grad
    lambdas_reg = np.clip(lambdas_reg, 0.01, 0.9999)  # Keep valid
    loss_reg, metrics_reg = regularizer.compute(lambdas_reg)
    print(f"  Half-lives: [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_reg:.6f} ({loss_reg/loss_collapsed:.1%} of collapsed)")
    print(f"  Long-range oscillators: {metrics_reg['n_long_range']}/{n_oscillators}")

    # --- Summary ---
    print("\n5. SUMMARY")
    print("-" * 60)
    print(f"""
  State              | Loss      | τ range         | Long-range
  -------------------|-----------|-----------------|------------
  Initial (good)     | {loss_init:.6f} | [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}] | {metrics_init['n_long_range']}/{n_oscillators}
  Collapsed (bad)    | {loss_collapsed:.6f} | [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}] | {metrics_collapsed['n_long_range']}/{n_oscillators}
  After 1 reg step   | {loss_reg:.6f} | [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}] | {metrics_reg['n_long_range']}/{n_oscillators}
""")
    print("=" * 70)
    print("CONCLUSION:")
    print("  The regularizer provides gradients that push collapsed half-lives")
    print("  back toward a log-uniform distribution spanning the full context.")
    print("=" * 70)
    return {
        "initial": {"loss": loss_init, "metrics": metrics_init},
        "collapsed": {"loss": loss_collapsed, "metrics": metrics_collapsed},
        "regularized": {"loss": loss_reg, "metrics": metrics_reg},
    }
# Script entry point: run the collapse-and-recovery demonstration.
if __name__ == "__main__":
    simulate_collapse_and_recovery()