# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
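Search for a non-negative function f minimizing the autoconvolution ratio
C1 = max(f * f) / (integral of f)^2, with f sampled on a uniform grid.

Key insight from analysis: the optimal function has a U-shape (higher at the edges).
Strategy: basin hopping + L-BFGS over Fourier coefficients with diverse
initializations, then a JAX gradient-based polish.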
"""
import numpy as np
from scipy.optimize import minimize, basinhopping
import jax
import jax.numpy as jnp
import optax
def compute_c1_numpy(f_values, n_points):
    """Return the autoconvolution ratio C1 of a sampled function (direct convolution).

    Negative samples are clamped to zero first. With grid spacing
    ``h = 0.5 / n_points`` the ratio is ``max(h * (f * f)) / (h * sum(f)) ** 2``.

    Args:
        f_values: Sampled function values (array-like).
        n_points: Grid size used to derive the spacing ``h``.

    Returns:
        The ratio, or the sentinel ``1e10`` when the squared integral is below
        ``1e-12``.
    """
    h = 0.5 / n_points
    clamped = np.maximum(f_values, 0.0)
    # Full linear autoconvolution, scaled to approximate the continuous integral.
    peak = np.max(np.convolve(clamped, clamped, mode='full') * h)
    area = np.sum(clamped) * h
    denom = area ** 2
    return 1e10 if denom < 1e-12 else peak / denom
def compute_c1_fft(f_values, dx):
    """Return the autoconvolution ratio C1 using an FFT-based convolution.

    Unlike ``compute_c1_numpy`` this does not clamp negative values; callers
    are expected to pass non-negative samples.

    Args:
        f_values: Sampled function values (length n).
        dx: Grid spacing used to scale both the convolution and the integral.

    Returns:
        ``max(conv) / (sum(f) * dx) ** 2``, or ``1e10`` when the squared
        integral is below ``1e-12``.
    """
    # Zero-padding to 2n makes the circular FFT convolution equal to the
    # linear autoconvolution (length 2n - 1 fits without wrap-around).
    length = 2 * len(f_values)
    spectrum = np.fft.rfft(f_values, n=length)
    autoconv = np.fft.irfft(spectrum * spectrum, n=length) * dx
    mass_sq = (np.sum(f_values) * dx) ** 2
    if mass_sq < 1e-12:
        return 1e10
    return np.max(autoconv) / mass_sq
def fourier_to_f(coeffs, N):
    """Evaluate a truncated Fourier series on a grid and clamp it to be non-negative.

    Args:
        coeffs: Sequence ``[a0, a1, b1, a2, b2, ..., aK, bK]`` with
            ``K = (len(coeffs) - 1) // 2``. If the length is even, the trailing
            coefficient is ignored.
        N: Number of samples, taken at ``x = j / N`` for ``j = 0..N-1``
            (the period endpoint 1 is excluded).

    Returns:
        Float ndarray of length ``N`` holding
        ``max(a0 + sum_k (a_k cos(2*pi*k*x) + b_k sin(2*pi*k*x)), 0)``.
    """
    K = (len(coeffs) - 1) // 2
    x = np.linspace(0, 1, N, endpoint=False)
    # Force a float accumulator. An integer a0 would otherwise yield an int
    # array, and the in-place += of float terms below would raise a casting error.
    f = np.full(N, coeffs[0], dtype=float)
    for k in range(1, K + 1):
        phase = 2 * np.pi * k * x
        f += coeffs[2 * k - 1] * np.cos(phase)
        f += coeffs[2 * k] * np.sin(phase)
    return np.maximum(f, 0.0)
def _initial_coeffs(shape, n_coeffs, rng):
    """Return the starting coefficient vector for initialization family ``shape`` (0-4)."""
    x0 = np.zeros(n_coeffs)
    if shape == 4:
        # Random perturbation around the constant function.
        x0[:] = rng.randn(n_coeffs) * 0.2
    elif n_coeffs > 1:  # shapes 1-3 set the first cosine, which needs K >= 1
        if shape == 1:
            x0[1] = 0.3  # cos(2*pi*x) = U-shape
        elif shape == 2:
            x0[1] = -0.3  # inverted U-shape
        elif shape == 3:
            x0[1] = 0.2
            if n_coeffs > 3:  # the second cosine harmonic needs K >= 2
                x0[3] = 0.1  # U-shape + higher harmonics
    x0[0] = 1.0  # base level
    return x0


def optimize_fourier_basin(K, N, n_restarts=20, seed=42, *, niter=200, T=0.05, stepsize=0.3):
    """Minimize C1 over truncated Fourier coefficients with repeated basin hopping.

    Restart ``r`` starts from initialization family ``r % 5``. Each family has
    a0 = 1.0:
    0 is constant, 1 is a U-shape, 2 is an inverted U, 3 is a U-shape plus a
    second cosine harmonic, and 4 is a seeded random perturbation.
    When K is too small for a family's extra harmonics, they are omitted.
    Basin hopping uses L-BFGS-B as its local minimizer on
    ``compute_c1_fft(fourier_to_f(coeffs, N), 0.5 / N)``.

    Args:
        K: Number of harmonics; the coefficient vector has ``2 * K + 1`` entries.
        N: Grid size used to evaluate each candidate.
        n_restarts: Number of independent basin-hopping runs.
        seed: Base seed. Restart ``r`` draws its random start from
            ``seed + 7 * r`` and passes ``seed + r`` to basinhopping.
        niter, T, stepsize: Forwarded to ``scipy.optimize.basinhopping``.
            The defaults match the previously hard-coded values.

    Returns:
        ``(best_coeffs, best_c1)``. With ``n_restarts == 0`` this is ``(None, inf)``.
    """
    dx = 0.5 / N
    n_coeffs = 2 * K + 1

    def objective(coeffs):
        return compute_c1_fft(fourier_to_f(coeffs, N), dx)

    # Local-minimizer settings are identical for every restart.
    minimizer_kwargs = {
        'method': 'L-BFGS-B',
        'options': {'maxiter': 500, 'ftol': 1e-12},
    }
    best_c1 = float('inf')
    best_coeffs = None
    for restart in range(n_restarts):
        # A private RandomState yields the same draws as np.random.seed + randn,
        # without clobbering the global numpy random state.
        rng = np.random.RandomState(seed + restart * 7)
        x0 = _initial_coeffs(restart % 5, n_coeffs, rng)
        result = basinhopping(
            objective, x0,
            minimizer_kwargs=minimizer_kwargs,
            niter=niter,
            T=T,
            stepsize=stepsize,
            seed=seed + restart,
        )
        c1 = result.fun
        if c1 < best_c1:
            best_c1 = c1
            best_coeffs = result.x
            print(f" restart={restart}: C1={c1:.10f} ***", flush=True)
    return best_coeffs, best_c1
def jax_polish(f_init, N, adam_steps=60000):
    """Refine a sampled function to lower C1 with JAX (Adam, then L-BFGS-B).

    The function is parameterized as ``f = exp(clip(params, -8, 4))``, which
    keeps it strictly positive. The refinement has two stages:

    1. Adam with a warmup + cosine-decay learning rate minimizes a smoothed
       objective. In it, ``max(conv)`` is replaced by
       ``logsumexp(temp * conv) / temp`` with temp = 300. The exact (hard-max)
       C1 is sampled every 15000 steps and on the final step, and the best
       sampled parameters are kept.
    2. scipy L-BFGS-B minimizes the smoothed objective at temperatures 1000,
       5000 and 20000, starting from the best Adam parameters.

    Args:
        f_init: Initial samples (length ``N``). Values below 1e-6 are raised to
            1e-6 before taking the log.
        N: Number of samples.
        adam_steps: Number of Adam iterations.

    Returns:
        ``(f, c1)``. ``f`` is the better of the best Adam iterate and the
        L-BFGS result, as a float64 ndarray. ``c1`` is its C1 computed by
        ``compute_c1_numpy``, so the score always describes the returned function.
    """
    dx = 0.5 / N

    def _conv_and_integral_sq(params):
        # Shared forward pass: positive samples -> zero-padded FFT autoconvolution.
        f = jnp.exp(jnp.clip(params, -8, 4))
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    def obj_smooth(params, temp):
        conv, integral_sq = _conv_and_integral_sq(params)
        # logsumexp / temp is a smooth upper bound on max(conv); it tightens as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    # Build and compile value+gradient once. Previously jax.value_and_grad was
    # re-created on every Adam step, which adds per-step tracing/dispatch overhead.
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_smooth))

    # optax counts decay_steps from step 0 (warmup included). Using adam_steps
    # makes the cosine phase end exactly on the last step. The guards keep the
    # schedule valid for small adam_steps; the default 60000 still uses warmup 1000.
    warmup_steps = min(1000, max(adam_steps // 10, 1))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.003, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    adam_temp = 300.0

    @jax.jit
    def adam_step(params, opt_state):
        _, grads = value_and_grad_fn(params, adam_temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    params = jnp.array(np.log(np.maximum(f_init, 1e-6)))
    opt_state = optimizer.init(params)
    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        params, opt_state = adam_step(params, opt_state)
        if step % 15000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            print(f" Step {step}: C1={hc:.8f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    # L-BFGS polish on progressively sharper smoothings of the max.
    params_np = np.array(best_params, dtype=np.float64)
    for temp in (1000.0, 5000.0, 20000.0):
        def scipy_obj(p, temp=temp):
            val, g = value_and_grad_fn(jnp.array(p), temp)
            return float(val), np.array(g, dtype=np.float64)

        result = minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
    f_final = np.exp(np.clip(params_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    print(f" After L-BFGS: C1={c1_final:.10f}")

    # L-BFGS optimizes a smoothed surrogate and can worsen the true max. Fall
    # back to the best Adam iterate, keeping the returned f and C1 paired.
    f_adam = np.exp(np.clip(np.asarray(best_params, dtype=np.float64), -8, 4))
    c1_adam = compute_c1_numpy(f_adam, N)
    if c1_adam < c1_final:
        return f_adam, c1_adam
    return f_final, c1_final
def run():
    """Run the three-phase C1 search and return the best function found.

    The search has three phases:

    - Phase 1: Fourier basin hopping for K in (10, 20, 30, 50) on a 500-point
      grid. The coefficients with the best coarse-grid C1 are re-sampled on a
      3000-point grid.
    - Phase 2: JAX polish of that fine-grid function.
    - Phase 3: an independent JAX polish from a U-shaped initialization.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. ``best_c1`` is the C1 of
        ``best_f`` evaluated on its own ``best_n``-point grid. The score appears
        twice, presumably for callers that unpack four items; kept unchanged.
    """
    N_coarse = 500
    N_fine = 3000
    best_c1 = float('inf')
    best_f = None
    best_n = None

    # Phase 1: Basin hopping on Fourier coefficients at moderate resolution.
    # Candidates are ranked by their coarse-grid score.
    best_coarse_c1 = float('inf')
    for K in [10, 20, 30, 50]:
        print(f"\nFourier basin hopping K={K}, N={N_coarse}")
        coeffs, _ = optimize_fourier_basin(K, N_coarse, n_restarts=15)
        f = fourier_to_f(coeffs, N_coarse)
        c1_verify = compute_c1_numpy(f, N_coarse)
        print(f" Best: C1={c1_verify:.10f}")
        if c1_verify < best_coarse_c1:
            best_coarse_c1 = c1_verify
            # Upsample to the fine grid and score it there, so best_c1 always
            # describes best_f (the coarse score belongs to a different array).
            best_f = fourier_to_f(coeffs, N_fine)
            best_c1 = compute_c1_numpy(best_f, N_fine)
            best_n = N_fine
            print(f" *** GLOBAL BEST: C1={c1_verify:.10f}")

    # Phase 2: Polish best result with JAX
    if best_f is not None:
        print(f"\nPolishing best result (C1={best_c1:.10f})...")
        f_polished, c1_polished = jax_polish(best_f, N_fine, adam_steps=60000)
        print(f" Polished: C1={c1_polished:.10f}")
        if c1_polished < best_c1:
            best_c1 = c1_polished
            best_f = f_polished
            best_n = N_fine

    # Phase 3: Direct JAX with U-shaped init
    print("\nDirect JAX with U-shaped initialization...")
    N = N_fine
    x = np.linspace(0, 1, N)
    # U-shape: higher at edges, lower in middle; floor keeps log() finite.
    init_f = np.maximum(0.5 + 0.5 * np.cos(2 * np.pi * x), 0.01)
    f_jax, c1_jax = jax_polish(init_f, N, adam_steps=60000)
    print(f" U-shape JAX: C1={c1_jax:.10f}")
    if c1_jax < best_c1:
        best_c1 = c1_jax
        best_f = f_jax
        best_n = N
    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n