# Uploaded by JustinTX: "Add files using upload-large-folder tool"
# (commit 2facf1f, verified).
import jax
import jax.numpy as jnp
import optax
import numpy as np
from dataclasses import dataclass
@dataclass()
class Hyperparameters:
    """Tunable settings for a single autoconvolution (C1) optimisation run."""

    # Number of grid cells used to discretise f on its domain.
    num_intervals: int = 1000
    # Peak Adam learning rate, reached at the end of the warmup phase.
    learning_rate: float = 0.01
    # Final learning rate, expressed as a multiple of ``learning_rate``.
    end_lr_factor: float = 1e-5
    # Total number of optimiser steps.
    num_steps: int = 80_000
    # Steps spent ramping the learning rate up from 0 to ``learning_rate``.
    warmup_steps: int = 3_000
    # Final temperature of the log-sum-exp smooth maximum (annealed up from 10).
    smooth_max_temp: float = 100.0
class AutocorrelationOptimizer:
    """Minimise the ratio max(f * f) / (integral of f)^2 over non-negative step functions.

    ``f`` is represented by ``num_intervals`` raw parameters on a domain of
    width 0.5. ``softplus`` maps each raw parameter to a strictly positive
    height, so every parameter vector is feasible. Training runs Adam on a
    log-sum-exp smooth maximum whose temperature is annealed upward. The exact
    ("hard") ratio is tracked every step to select the best iterate.
    """

    def __init__(self, hypers: "Hyperparameters"):
        self.hypers = hypers
        self.domain_width = 0.5
        # Width of one grid cell; every Riemann sum below is scaled by it.
        self.dx = self.domain_width / self.hypers.num_intervals

    def _autoconvolution_terms(self, f_values: jnp.ndarray):
        """Return ``(scaled_autoconvolution, integral_of_f)`` for raw parameters.

        Heights are ``softplus(f_values)``. The autoconvolution is computed with
        a zero-padded FFT of length ``2 * num_intervals``, so it is the linear
        (not circular) convolution of the heights, scaled by ``dx``. The
        integral is clamped to at least 1e-12 before it is returned.
        """
        f_non_negative = jax.nn.softplus(f_values)  # smooth non-negativity
        integral_f = jnp.sum(f_non_negative) * self.dx
        integral_f_safe = jnp.maximum(integral_f, 1e-12)
        n = self.hypers.num_intervals
        # Padding to length 2N turns the FFT product into a linear convolution.
        padded_f = jnp.pad(f_non_negative, (0, n))
        fft_f = jnp.fft.rfft(padded_f)
        conv_f_f = jnp.fft.irfft(fft_f * fft_f, n=2 * n)
        return conv_f_f * self.dx, integral_f_safe

    def _objective_fn(self, f_values: jnp.ndarray, temp: float) -> jnp.ndarray:
        """Differentiable surrogate for the C1 ratio, used as the training loss.

        The hard maximum is replaced by ``logsumexp(temp * g) / temp``. This is
        an upper bound on ``max(g)`` that tightens as ``temp`` grows. It also
        spreads gradient over every entry of ``g``, not just the argmax.
        """
        scaled_conv_f_f, integral_f = self._autoconvolution_terms(f_values)
        smooth_max = jax.nn.logsumexp(temp * scaled_conv_f_f) / temp
        return smooth_max / (integral_f ** 2)

    def _hard_objective(self, f_values: jnp.ndarray) -> jnp.ndarray:
        """Return the exact discretised C1 ratio ``max(f * f) / (integral of f)^2``."""
        scaled_conv_f_f, integral_f = self._autoconvolution_terms(f_values)
        return jnp.max(scaled_conv_f_f) / (integral_f ** 2)

    def train_step(self, f_values, opt_state, temp):
        """Apply one optimiser update on the smooth loss.

        Requires ``self.optimizer``, which ``run_optimization`` sets.

        Returns:
            ``(f_values, opt_state, loss, hard_loss)``: the updated parameters and
            optimiser state, the smooth loss evaluated *before* the update, and
            the hard C1 ratio evaluated *after* it.
        """
        loss, grads = jax.value_and_grad(self._objective_fn)(f_values, temp)
        updates, opt_state = self.optimizer.update(grads, opt_state, f_values)
        f_values = optax.apply_updates(f_values, updates)
        hard_loss = self._hard_objective(f_values)
        return f_values, opt_state, loss, hard_loss

    def run_optimization(self, seed=42):
        """Optimise f from a seeded start and return the best iterate found.

        Args:
            seed: PRNG seed for the uniform noise added to the initial profile.

        Returns:
            ``(final_f, final_c1)``. ``final_f`` is ``softplus`` of the iterate
            with the lowest hard C1 ratio, i.e. the actual non-negative heights.
            ``final_c1`` is that ratio as a Python float (``inf`` if
            ``num_steps`` is 0).
        """
        hp = self.hypers
        # optax's ``decay_steps`` is the TOTAL schedule length including warmup;
        # the cosine phase lasts ``decay_steps - warmup_steps`` steps. Passing
        # ``num_steps`` makes the decay end exactly on the final step. The lower
        # bound keeps the cosine phase non-empty when warmup >= num_steps.
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=hp.learning_rate,
            warmup_steps=hp.warmup_steps,
            decay_steps=max(hp.num_steps, hp.warmup_steps + 1),
            end_value=hp.learning_rate * hp.end_lr_factor,
        )
        self.optimizer = optax.adam(learning_rate=schedule)

        key = jax.random.PRNGKey(seed)
        n = hp.num_intervals
        # Initial raw (pre-softplus) parameters: a Gaussian bump centred on the
        # grid plus small uniform noise.
        x = jnp.linspace(0, 1, n)
        f_values = jnp.exp(-20.0 * (x - 0.5) ** 2)
        f_values = f_values + 0.05 * jax.random.uniform(key, (n,))
        opt_state = self.optimizer.init(f_values)

        train_step_jit = jax.jit(self.train_step)
        best_loss = jnp.inf
        best_f = f_values
        # The temperature rises linearly from 10 to smooth_max_temp over the
        # first half of training, then stays at smooth_max_temp.
        anneal_steps = hp.num_steps * 0.5
        for step in range(hp.num_steps):
            progress = min(step / anneal_steps, 1.0)
            temp = 10.0 + progress * (hp.smooth_max_temp - 10.0)
            f_values, opt_state, loss, hard_loss = train_step_jit(f_values, opt_state, temp)
            if hard_loss < best_loss:
                best_loss = hard_loss
                best_f = f_values
            if step % 5000 == 0 or step == hp.num_steps - 1:
                print(f"Step {step:5d} | C1(smooth) ≈ {loss:.8f} | C1(hard) ≈ {hard_loss:.8f} | temp={temp:.1f}")
        print(f"Best C1 found: {best_loss:.8f}")
        # best_f holds raw parameters; map them to the actual non-negative heights.
        final_f = jax.nn.softplus(best_f)
        final_c1 = float(best_loss)
        return final_f, final_c1
def run():
    """Optimise several (N, learning rate, steps, seed) configurations and keep the best.

    Returns:
        ``(f_values, c1, c1, num_intervals)`` for the configuration with the
        lowest C1 ratio, where ``f_values`` is a NumPy array of the optimised
        heights. The C1 value occupies both the second and third slots.
        Returns ``None`` if no configuration's C1 is below ``inf``.
    """
    # Each entry: (num_intervals, learning_rate, num_steps, seed).
    sweep = (
        (1000, 0.01, 80000, 42),
        (1000, 0.005, 80000, 123),
        (1500, 0.008, 60000, 42),
    )
    champion = None
    champion_c1 = float('inf')
    for grid_size, peak_lr, total_steps, rng_seed in sweep:
        print(f"\n--- Config: N={grid_size}, lr={peak_lr}, steps={total_steps}, seed={rng_seed} ---")
        settings = Hyperparameters(
            num_intervals=grid_size,
            learning_rate=peak_lr,
            num_steps=total_steps,
        )
        heights, c1 = AutocorrelationOptimizer(settings).run_optimization(seed=rng_seed)
        if not c1 < champion_c1:
            continue
        champion_c1 = c1
        champion = (np.array(heights), champion_c1, champion_c1, settings.num_intervals)
        print(f"*** New best: C1 = {champion_c1:.10f}")
    return champion