""" Half-Life Regularizer for FDRA Oscillators This implements the exact mathematical regularizer from the Cursor instructions: ## Regularizer 1: Log-Uniform Half-Life Prior (primary) Target distribution: p(τ) ∝ 1/τ for τ ∈ [τ_min, τ_max] This gives equal mass per temporal decade (log scale). Loss: z_i = log(τ_i) μ = mean(z_i) σ² = mean((z_i - μ)²) μ* = (log(τ_min) + log(τ_max)) / 2 σ²* = (log(τ_max) - log(τ_min))² / 12 L_HL = α*(μ - μ*)² + β*(σ² - σ²*)² ## Regularizer 2: Long-Tail Survival Constraint (supporting) Ensure existence of long-range oscillators: s_i = σ(k * (τ_i - γ*L)) tail_mass = mean(s_i) L_tail = max(0, ρ - tail_mass)² Where: γ = 0.5 (fraction of full context) ρ = 0.05 (minimum fraction of oscillators) k = 10.0 (sigmoid sharpness) ## Regularizer 3: Tau Bounds Constraint (CRITICAL FIX) The moment-matching loss (L_HL) can be satisfied by a pathological bimodal distribution with taus outside [tau_min, tau_max]. This creates oscillators that are either useless (tau << 1) or extreme (tau >> L). L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2) ## Combined Loss L_total = L_task + λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds Authors: Half-Life Regularization Implementation Date: 2026-01-22 """ import numpy as np from typing import Dict, Tuple, Optional, Any from dataclasses import dataclass from pathlib import Path import json from datetime import datetime @dataclass class HalfLifeRegularizerConfig: """Configuration for half-life regularization.""" # Task parameters sequence_length: int = 4096 # L - max sequence length tau_min: float = 1.0 # Minimum target half-life tau_max: float = 4096.0 # Maximum target half-life (= L) # Log-Uniform Prior coefficients alpha: float = 1.0 # Weight for mean constraint beta: float = 1.0 # Weight for variance constraint # Long-Tail Survival coefficients gamma: float = 0.5 # Fraction of full context for long-range rho: float = 0.05 # Minimum fraction of long-range oscillators k: float = 10.0 # Sigmoid sharpness # Overall loss weights lambda1: float = 0.01 # Weight for L_HL in total loss lambda2: float = 0.01 # Weight for L_tail in total loss # NEW: Tau bound constraint (prevents pathological distributions) lambda3: float = 0.1 # Weight for L_bounds bound_sharpness: float = 5.0 # Sharpness of soft bound penalties class HalfLifeRegularizer: """ Half-Life Regularizer for FDRA Oscillator Banks. Prevents decay parameter collapse by regularizing the half-life distribution toward a log-uniform target. Usage: config = HalfLifeRegularizerConfig() regularizer = HalfLifeRegularizer(config) # During training: lambdas = oscillator_bank.lambdas loss, metrics = regularizer.compute(lambdas) # Add to total loss: total_loss = task_loss + loss # Log metrics: log(metrics) """ def __init__(self, config: HalfLifeRegularizerConfig): self.config = config # Pre-compute target statistics z_min = np.log(config.tau_min) z_max = np.log(config.tau_max) # Target mean in log space (center of [z_min, z_max]) self.mu_star = (z_min + z_max) / 2.0 # Target variance in log space (variance of uniform on [z_min, z_max]) self.sigma2_star = (z_max - z_min) ** 2 / 12.0 # Long-range threshold self.tau_threshold = config.gamma * config.sequence_length def lambdas_to_half_lives(self, lambdas: np.ndarray) -> np.ndarray: """ Convert decay parameters to half-lives. τ_i = ln(0.5) / ln(λ_i) Args: lambdas: Decay parameters, shape (N,) Returns: taus: Half-lives, shape (N,) """ # Clamp to avoid numerical issues safe_lambdas = np.clip(lambdas, 1e-10, 1.0 - 1e-10) taus = np.log(0.5) / np.log(safe_lambdas) return taus def compute_log_uniform_loss( self, lambdas: np.ndarray ) -> Tuple[float, Dict[str, float]]: """ Compute Log-Uniform Half-Life Prior loss. L_HL = α*(μ - μ*)² + β*(σ² - σ²*)² Args: lambdas: Decay parameters, shape (N,) Returns: loss: Scalar loss value metrics: Dictionary of component metrics """ # Compute half-lives and log half-lives taus = self.lambdas_to_half_lives(lambdas) z = np.log(taus) # Current statistics mu = np.mean(z) sigma2 = np.var(z) # Compute loss components mean_loss = self.config.alpha * (mu - self.mu_star) ** 2 var_loss = self.config.beta * (sigma2 - self.sigma2_star) ** 2 loss = mean_loss + var_loss metrics = { "log_tau_mean": float(mu), "log_tau_var": float(sigma2), "log_tau_target_mean": float(self.mu_star), "log_tau_target_var": float(self.sigma2_star), "mean_deviation": float(abs(mu - self.mu_star)), "var_deviation": float(abs(sigma2 - self.sigma2_star)), "log_uniform_loss": float(loss), } return float(loss), metrics def compute_long_tail_loss( self, lambdas: np.ndarray ) -> Tuple[float, Dict[str, float]]: """ Compute Long-Tail Survival Constraint loss. s_i = σ(k * (τ_i - γ*L)) tail_mass = mean(s_i) L_tail = max(0, ρ - tail_mass)² Args: lambdas: Decay parameters, shape (N,) Returns: loss: Scalar loss value metrics: Dictionary of component metrics """ # Compute half-lives taus = self.lambdas_to_half_lives(lambdas) # Sigmoid for soft thresholding (with numerical stability) # s_i ≈ 1 if τ_i > threshold, ≈ 0 otherwise x = self.config.k * (taus - self.tau_threshold) x = np.clip(x, -500, 500) # Prevent overflow s = 1.0 / (1.0 + np.exp(-x)) # Fraction of oscillators in long-tail regime tail_mass = np.mean(s) # Loss: penalize if tail_mass < rho deficit = max(0, self.config.rho - tail_mass) loss = deficit ** 2 # Count actual long-range oscillators (hard threshold) n_long_range = np.sum(taus > self.tau_threshold) frac_long_range = n_long_range / len(taus) metrics = { "tail_mass": float(tail_mass), "tail_target": float(self.config.rho), "tail_deficit": float(deficit), "n_long_range": int(n_long_range), "frac_long_range": float(frac_long_range), "tau_threshold": float(self.tau_threshold), "long_tail_loss": float(loss), } return float(loss), metrics def compute_bounds_loss( self, lambdas: np.ndarray ) -> Tuple[float, Dict[str, float]]: """ Compute tau bounds constraint loss. CRITICAL FIX: The moment-matching loss alone can be satisfied by a pathological bimodal distribution with taus outside [tau_min, tau_max]. This loss penalizes taus below tau_min or above tau_max: L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2) Uses soft penalty with configurable sharpness. """ taus = self.lambdas_to_half_lives(lambdas) k = self.config.bound_sharpness # Soft lower bound: penalize tau < tau_min below_min = np.maximum(0, self.config.tau_min - taus) lower_penalty = np.mean((k * below_min) ** 2) # Soft upper bound: penalize tau > tau_max above_max = np.maximum(0, taus - self.config.tau_max) upper_penalty = np.mean((k * above_max) ** 2) loss = lower_penalty + upper_penalty n_below_min = np.sum(taus < self.config.tau_min) n_above_max = np.sum(taus > self.config.tau_max) metrics = { "bounds_loss": float(loss), "lower_bound_penalty": float(lower_penalty), "upper_bound_penalty": float(upper_penalty), "n_below_tau_min": int(n_below_min), "n_above_tau_max": int(n_above_max), "frac_in_bounds": float(1 - (n_below_min + n_above_max) / len(taus)), } return float(loss), metrics def compute(self, lambdas: np.ndarray) -> Tuple[float, Dict[str, Any]]: """ Compute total half-life regularization loss. L_total = λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds CRITICAL: L_bounds prevents the pathological case where moment-matching is satisfied by a bimodal distribution with taus outside [tau_min, tau_max]. Args: lambdas: Decay parameters, shape (N,) Returns: loss: Total regularization loss metrics: All component metrics """ # Compute component losses log_uniform_loss, log_uniform_metrics = self.compute_log_uniform_loss(lambdas) long_tail_loss, long_tail_metrics = self.compute_long_tail_loss(lambdas) bounds_loss, bounds_metrics = self.compute_bounds_loss(lambdas) # Weighted combination (bounds loss is CRITICAL) total_loss = ( self.config.lambda1 * log_uniform_loss + self.config.lambda2 * long_tail_loss + self.config.lambda3 * bounds_loss ) # Compute half-life distribution for logging taus = self.lambdas_to_half_lives(lambdas) metrics = { "total_regularization_loss": float(total_loss), "log_uniform_component": float(self.config.lambda1 * log_uniform_loss), "long_tail_component": float(self.config.lambda2 * long_tail_loss), "bounds_component": float(self.config.lambda3 * bounds_loss), "tau_min": float(np.min(taus)), "tau_max": float(np.max(taus)), "tau_mean": float(np.mean(taus)), "tau_median": float(np.median(taus)), **log_uniform_metrics, **long_tail_metrics, **bounds_metrics, } return float(total_loss), metrics def compute_gradient( self, lambdas: np.ndarray, epsilon: float = 1e-5 ) -> np.ndarray: """ Compute gradient of regularization loss w.r.t. lambdas. Uses finite differences for simplicity. In a real implementation, this would use autodiff. Args: lambdas: Decay parameters, shape (N,) epsilon: Perturbation size Returns: grad: Gradient, shape (N,) """ grad = np.zeros_like(lambdas) for i in range(len(lambdas)): # Positive perturbation lambdas_plus = lambdas.copy() lambdas_plus[i] += epsilon loss_plus, _ = self.compute(lambdas_plus) # Negative perturbation lambdas_minus = lambdas.copy() lambdas_minus[i] -= epsilon loss_minus, _ = self.compute(lambdas_minus) # Central difference grad[i] = (loss_plus - loss_minus) / (2 * epsilon) return grad def diagnose(self, lambdas: np.ndarray) -> str: """ Generate diagnostic string for current half-life distribution. Args: lambdas: Decay parameters Returns: Diagnostic string """ loss, metrics = self.compute(lambdas) taus = self.lambdas_to_half_lives(lambdas) lines = [ "=" * 60, "HALF-LIFE REGULARIZER DIAGNOSTICS", "=" * 60, "", "Current Distribution:", f" τ range: [{metrics['tau_min']:.1f}, {metrics['tau_max']:.1f}]", f" τ mean: {metrics['tau_mean']:.1f}", f" τ median: {metrics['tau_median']:.1f}", "", "Target Distribution:", f" τ range: [{self.config.tau_min}, {self.config.tau_max}]", f" log(τ) target mean: {self.mu_star:.3f}", f" log(τ) target var: {self.sigma2_star:.3f}", "", "Log-Uniform Prior:", f" log(τ) mean: {metrics['log_tau_mean']:.3f} (target: {metrics['log_tau_target_mean']:.3f})", f" log(τ) var: {metrics['log_tau_var']:.3f} (target: {metrics['log_tau_target_var']:.3f})", f" Mean deviation: {metrics['mean_deviation']:.3f}", f" Var deviation: {metrics['var_deviation']:.3f}", f" Loss: {metrics['log_uniform_loss']:.6f}", "", "Long-Tail Survival:", f" Threshold: τ > {metrics['tau_threshold']:.1f}", f" Long-range count: {metrics['n_long_range']}/{len(lambdas)} ({metrics['frac_long_range']:.1%})", f" Tail mass (soft): {metrics['tail_mass']:.3f} (target: {metrics['tail_target']:.3f})", f" Loss: {metrics['long_tail_loss']:.6f}", "", "Total Regularization Loss:", f" Log-uniform component: {metrics['log_uniform_component']:.6f}", f" Long-tail component: {metrics['long_tail_component']:.6f}", f" Total: {metrics['total_regularization_loss']:.6f}", "", ] # Add half-life histogram lines.append("Half-Life Histogram (log scale):") bins = np.logspace(0, np.log10(self.config.tau_max), 11) hist, _ = np.histogram(taus, bins=bins) for i, count in enumerate(hist): bar = "█" * count lines.append(f" [{bins[i]:7.1f}, {bins[i+1]:7.1f}): {count:2d} {bar}") lines.append("") lines.append("=" * 60) return "\n".join(lines) def simulate_collapse_and_recovery(): """ Simulate the half-life collapse problem and demonstrate regularization. This shows: 1. Initial log-uniform distribution (good) 2. Simulated collapse to short half-lives (bad, mimics training at scale) 3. Regularization gradient direction (recovery) """ print("=" * 70) print("HALF-LIFE COLLAPSE AND REGULARIZATION DEMONSTRATION") print("=" * 70) config = HalfLifeRegularizerConfig( sequence_length=4096, tau_min=1.0, tau_max=4096.0, lambda1=0.01, lambda2=0.01 ) regularizer = HalfLifeRegularizer(config) # --- Initial Distribution (good) --- print("\n1. INITIAL DISTRIBUTION (Log-Uniform)") print("-" * 60) n_oscillators = 32 log_taus_init = np.linspace(np.log(1.0), np.log(4096.0), n_oscillators) taus_init = np.exp(log_taus_init) lambdas_init = np.power(0.5, 1.0 / taus_init) loss_init, metrics_init = regularizer.compute(lambdas_init) print(f" Half-lives: [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]") print(f" Regularization loss: {loss_init:.6f}") print(f" Long-range oscillators: {metrics_init['n_long_range']}/{n_oscillators}") # --- Collapsed Distribution (bad) --- print("\n2. COLLAPSED DISTRIBUTION (Training at Scale)") print("-" * 60) print(" Simulating what happens during GPT-2 scale training...") # All half-lives collapse to < 10 steps taus_collapsed = np.random.uniform(2, 10, n_oscillators) lambdas_collapsed = np.power(0.5, 1.0 / taus_collapsed) loss_collapsed, metrics_collapsed = regularizer.compute(lambdas_collapsed) print(f" Half-lives: [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]") print(f" Regularization loss: {loss_collapsed:.6f} ({loss_collapsed/loss_init:.0f}x initial)") print(f" Long-range oscillators: {metrics_collapsed['n_long_range']}/{n_oscillators}") # --- Regularization Gradient --- print("\n3. REGULARIZATION GRADIENT ANALYSIS") print("-" * 60) grad = regularizer.compute_gradient(lambdas_collapsed) print(" Gradient direction indicates how to adjust λ_i to reduce loss:") print(" (Negative gradient → increase λ → longer half-life)") print() # Show gradient for first few oscillators for i in range(min(5, n_oscillators)): tau_i = taus_collapsed[i] grad_i = grad[i] direction = "→ increase τ" if grad_i < 0 else "→ decrease τ" print(f" Osc {i}: τ={tau_i:.1f}, grad={grad_i:+.4f} {direction}") print(f" ... ({n_oscillators - 5} more)") print(f"\n Mean gradient magnitude: {np.mean(np.abs(grad)):.4f}") # --- After One Regularization Step --- print("\n4. AFTER REGULARIZATION STEP") print("-" * 60) lr = 1.0 # Learning rate lambdas_reg = lambdas_collapsed - lr * grad lambdas_reg = np.clip(lambdas_reg, 0.01, 0.9999) # Keep valid loss_reg, metrics_reg = regularizer.compute(lambdas_reg) print(f" Half-lives: [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]") print(f" Regularization loss: {loss_reg:.6f} ({loss_reg/loss_collapsed:.1%} of collapsed)") print(f" Long-range oscillators: {metrics_reg['n_long_range']}/{n_oscillators}") # --- Summary --- print("\n5. SUMMARY") print("-" * 60) print(f""" State | Loss | τ range | Long-range -------------------|-----------|-----------------|------------ Initial (good) | {loss_init:.6f} | [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}] | {metrics_init['n_long_range']}/{n_oscillators} Collapsed (bad) | {loss_collapsed:.6f} | [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}] | {metrics_collapsed['n_long_range']}/{n_oscillators} After 1 reg step | {loss_reg:.6f} | [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}] | {metrics_reg['n_long_range']}/{n_oscillators} """) print("=" * 70) print("CONCLUSION:") print(" The regularizer provides gradients that push collapsed half-lives") print(" back toward a log-uniform distribution spanning the full context.") print("=" * 70) return { "initial": {"loss": loss_init, "metrics": metrics_init}, "collapsed": {"loss": loss_collapsed, "metrics": metrics_collapsed}, "regularized": {"loss": loss_reg, "metrics": metrics_reg}, } if __name__ == "__main__": simulate_collapse_and_recovery()