# Provenance (web-page residue, kept as comments so the file parses):
# uploaded by JustinTX - "Add files using upload-large-folder tool", commit 2facf1f (verified)
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Reference NumPy evaluation of C1 (used to verify the JAX results).

    Negative samples are clipped to zero and each sample is treated as a
    step of width ``0.5 / n_points``. The result is the peak of the discrete
    autoconvolution divided by the squared integral of the profile. If the
    squared integral is below 1e-12, the sentinel ``1e10`` is returned instead.
    """
    h = 0.5 / n_points
    g = np.maximum(f_values, 0.0)
    scaled_conv = h * np.convolve(g, g, mode='full')
    mass_sq = (h * np.sum(g)) ** 2
    # Guard against a (near-)zero profile, for which the ratio is meaningless.
    return 1e10 if mass_sq < 1e-12 else np.max(scaled_conv) / mass_sq
def make_objective_jax(N, dx):
    """Build JAX objectives for minimising the autoconvolution ratio C1.

    The unconstrained vector ``params`` (length N) is mapped to a strictly
    positive profile ``f = exp(params)`` sampled with spacing ``dx``, and

        C1(f) = max_t (f * f)(t) / (integral of f)^2,

    with the autoconvolution computed by a zero-padded FFT of length 2N.

    Args:
        N: number of samples in the profile.
        dx: sample spacing.

    Returns:
        ``(objective, objective_smooth, grad_fn)``:
          - ``objective(params)``: hard C1 (jitted).
          - ``objective_smooth(params, temp)``: log-sum-exp relaxation of C1
            (jitted). It is an upper bound on ``objective`` and exceeds it by
            at most ``log(2N) / temp``. Like C1, it is invariant to rescaling
            ``f`` (a uniform shift of ``params``).
          - ``grad_fn(params, temp)``: jitted gradient of ``objective_smooth``
            with respect to ``params``.
    """
    n_fft = 2 * N

    def ratio_profile(params):
        """Return (f * f)(t) / (integral of f)^2 on the 2N-point grid, f = exp(params)."""
        # exp parameterisation keeps f strictly positive.
        f = jnp.exp(params)
        # rfft with n=2N zero-pads f, so the circular product equals the linear
        # autoconvolution (2N-1 terms; the last bin is ~0).
        fft_f = jnp.fft.rfft(f, n=n_fft)
        conv = jnp.fft.irfft(fft_f * fft_f, n=n_fft) * dx
        integral = jnp.sum(f) * dx
        return conv / integral ** 2

    @jax.jit
    def objective(params):
        return jnp.max(ratio_profile(params))

    @jax.jit
    def objective_smooth(params, temp):
        # Smooth the *normalised* ratio: the relaxation gap is then bounded by
        # log(2N)/temp regardless of the scale of f. This means the optimizer
        # cannot lower the loss merely by inflating f, which would also defeat
        # the temperature schedule.
        return jax.nn.logsumexp(temp * ratio_profile(params)) / temp

    grad_fn = jax.jit(jax.grad(objective_smooth))
    return objective, objective_smooth, grad_fn
def _initial_profile(N, seed):
    """Return the strictly positive starting profile (length N) for a seed index.

    Seeds 0-3 select deterministic shapes; any other seed draws
    ``|randn(N)| * 0.3 + 0.2`` from a generator seeded with ``seed``.
    """
    x = np.linspace(0, 1, N)
    if seed == 0:
        return np.exp(-10 * (x - 0.5) ** 2) + 0.1
    if seed == 1:
        return np.ones(N)
    if seed == 2:
        return 0.5 * (1 + np.cos(2 * np.pi * (x - 0.5))) + 0.1
    if seed == 3:
        # Step function: higher in the middle.
        return np.where((x > 0.2) & (x < 0.8), 1.5, 0.5)
    # A local RandomState yields the same stream as np.random.seed(seed) followed
    # by np.random.randn(N), without mutating the global NumPy RNG.
    return np.abs(np.random.RandomState(seed).randn(N)) * 0.3 + 0.2


def _make_adam_step(objective_smooth, optimizer):
    """Return a jitted Adam update ``(params, opt_state, temp) -> (params, opt_state, loss)``.

    Jitting the whole step (gradient and optax update together) avoids rebuilding
    the gradient transform and dispatching the update op-by-op every iteration.
    ``loss`` is the smooth objective at the params *before* the update.
    """
    value_and_grad = jax.value_and_grad(objective_smooth)

    @jax.jit
    def step(params, opt_state, temp):
        loss, grads = value_and_grad(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    return step


def _adam_phase(params, objective, adam_step, optimizer, n_steps=50000, log_every=5000):
    """Phase 1: Adam on the smooth objective with the temperature ramped from 50 towards 200.

    The hard C1 is checked every ``log_every`` steps and after the final step.
    Returns ``(best_params, best_hard_c1)`` over those checkpoints; if no
    checkpoint improves on ``inf``, the initial params are returned.
    """
    params_jax = jnp.asarray(params)
    opt_state = optimizer.init(params_jax)
    best_c1 = float('inf')
    best_params = params_jax
    for step in range(n_steps):
        temp = min(50.0 + step * 150.0 / n_steps, 200.0)
        params_jax, opt_state, loss_val = adam_step(params_jax, opt_state, temp)
        if step % log_every == 0:
            hard_c1 = float(objective(params_jax))
            print(f" Step {step:5d} | C1(smooth)={float(loss_val):.8f} | C1(hard)={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1, best_params = hard_c1, params_jax
    hard_c1 = float(objective(params_jax))
    if hard_c1 < best_c1:
        best_c1, best_params = hard_c1, params_jax
    return best_params, best_c1


def _lbfgs_refine(params, temp, value_and_grad, maxiter=2000):
    """Minimise the smooth objective at fixed ``temp`` with L-BFGS-B; return the final params (float64)."""
    def fun(p):
        val, grad = value_and_grad(jnp.asarray(p), temp)
        # SciPy's L-BFGS-B works in float64, whereas JAX defaults to float32.
        return float(val), np.asarray(grad, dtype=np.float64)

    result = scipy_minimize(
        fun,
        params,
        method='L-BFGS-B',
        jac=True,
        options={'maxiter': maxiter, 'ftol': 1e-15, 'gtol': 1e-10},
    )
    return result.x


def run():
    """Multi-start search for a positive profile with a small autoconvolution ratio C1.

    For each resolution N in (1000, 2000, 3000) and five initialisations:
      - Phase 1 runs Adam on the smooth relaxation.
      - Phase 2 refines with L-BFGS-B at temperatures 500, 1000 and 2000.

    Candidates are scored with ``compute_c1_numpy`` in float64. The params
    that produced the best score are kept exactly, so the returned profile
    reproduces the reported C1.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. NOTE(review): the C1 value
        appears twice; this shape is kept for compatibility with existing
        callers, so confirm what the second slot was meant to hold.
    """
    n_adam_steps = 50000
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.01,
        warmup_steps=2000,
        # optax counts the warmup inside decay_steps, so the full run length
        # makes the cosine end at the last step. The previous 48000 left the
        # final 2000 steps stuck at end_value.
        decay_steps=n_adam_steps,
        end_value=1e-5,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    best_c1_overall = float('inf')
    best_f_overall = None
    best_n_overall = None
    for N in [1000, 2000, 3000]:
        dx = 0.5 / N
        objective, objective_smooth, _ = make_objective_jax(N, dx)
        adam_step = _make_adam_step(objective_smooth, optimizer)
        # One forward+backward pass per L-BFGS evaluation instead of two.
        smooth_value_and_grad = jax.jit(jax.value_and_grad(objective_smooth))
        for seed in range(5):
            print(f"\n--- N={N}, seed={seed} ---")
            init = _initial_profile(N, seed)
            params = np.log(np.maximum(init, 1e-6))

            # Phase 1: Adam optimization with smooth max (JAX)
            print("Phase 1: Adam optimization...")
            best_params_jax, _ = _adam_phase(
                params, objective, adam_step, optimizer, n_steps=n_adam_steps
            )
            # Re-score with the float64 NumPy metric so that Phase-1 and Phase-2
            # candidates are compared on the same footing.
            best_params_run = np.asarray(best_params_jax, dtype=np.float64)
            best_c1_run = compute_c1_numpy(np.exp(best_params_run), N)
            if not np.isfinite(best_c1_run):
                best_c1_run = float('inf')

            # Phase 2: L-BFGS-B refinement with high temperature smooth max
            print(f"Phase 2: L-BFGS-B refinement (starting from C1={best_c1_run:.8f})...")
            params_np = best_params_run.copy()
            for temp in [500.0, 1000.0, 2000.0]:
                candidate = _lbfgs_refine(params_np, temp, smooth_value_and_grad)
                if not np.all(np.isfinite(candidate)):
                    # Do not seed the next temperature from a diverged iterate.
                    print(f" temp={temp:.0f}: non-finite parameters, skipped")
                    continue
                params_np = candidate
                c1 = compute_c1_numpy(np.exp(params_np), N)
                print(f" temp={temp:.0f}: C1={c1:.10f}")
                if c1 < best_c1_run:
                    best_c1_run = c1
                    # Keep float64; converting to a float32 JAX array would make
                    # the final profile differ from the one that was scored.
                    best_params_run = params_np.copy()

            # Final evaluation
            f_final = np.exp(best_params_run)
            c1_final = compute_c1_numpy(f_final, N)
            print(f" Final C1 for this run: {c1_final:.10f}")
            if c1_final < best_c1_overall:
                best_c1_overall = c1_final
                best_f_overall = f_final
                best_n_overall = N
                print(f"*** GLOBAL BEST: C1 = {c1_final:.10f}")
    print(f"\n=== Final best C1: {best_c1_overall:.10f} ===")
    return best_f_overall, best_c1_overall, best_c1_overall, best_n_overall