JustinTX's picture
Add files using upload-large-folder tool
2facf1f verified
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import jax
import jax.numpy as jnp
import optax
def compute_c1_numpy(f_values, n_points):
    """Return max(f*f) / (integral of f)**2 for sampled values of f.

    ``f*f`` is the autoconvolution. Negative samples are clamped to zero, and
    both the convolution and the integral are Riemann sums with spacing
    ``dx = 0.5 / n_points``.

    Args:
        f_values: 1-D array-like of samples of f.
        n_points: grid size used to set ``dx``; the callers here pass
            ``len(f_values)``.

    Returns:
        The ratio, or the sentinel ``1e10`` when the squared integral is
        below 1e-12 (all-zero, all-negative, or empty input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    # Check the degenerate case before the O(n^2) convolution: it is cheaper,
    # and np.convolve raises ValueError on an empty array.
    integral_sq = (np.sum(f_nn) * dx) ** 2
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
def fourier_to_function(coeffs, N):
    """Convert Fourier coefficients to non-negative function values.

    coeffs: [a0, a1, b1, a2, b2, ..., aK, bK]
    f(x) = a0 + sum_k (a_k cos(2*pi*k*x) + b_k sin(2*pi*k*x)),
    sampled at the N points x = 0, 1/N, ..., (N-1)/N and clamped to >= 0.

    Args:
        coeffs: 1-D array-like of coefficients. Integer input is accepted
            (it is converted to float64). K = (len(coeffs) - 1) // 2, so for
            an even-length ``coeffs`` the last entry is ignored.
        N: number of sample points.

    Returns:
        float64 array of length N.
    """
    # Force float64: np.full would otherwise inherit an integer dtype from
    # coeffs[0], and the in-place float additions below would then fail.
    coeffs = np.asarray(coeffs, dtype=np.float64)
    K = (len(coeffs) - 1) // 2
    x = np.linspace(0, 1, N, endpoint=False)
    f = np.full(N, coeffs[0], dtype=np.float64)
    for k in range(1, K + 1):
        angle = 2 * np.pi * k * x
        f += coeffs[2*k - 1] * np.cos(angle)
        f += coeffs[2*k] * np.sin(angle)
    return np.maximum(f, 0.0)
def fourier_objective(coeffs, N):
    """Score a Fourier coefficient vector by the C1 ratio of its function.

    The coefficients are expanded onto ``N`` samples with
    fourier_to_function, and compute_c1_numpy of those samples is returned.
    This is the value the CMA-ES and Nelder-Mead searches minimize.
    """
    return compute_c1_numpy(fourier_to_function(coeffs, N), N)
def run():
    """Search for a non-negative f that minimizes the C1 ratio of compute_c1_numpy.

    Two strategies are tried, and the better result is kept:

    1. CMA-ES over truncated Fourier coefficients (K = 15, 25, 40, 60), each
       followed by a Nelder-Mead polish, evaluated on 4000 points.
    2. Direct optimization of log f on 3000 points: Adam on a
       logsumexp-smoothed objective, then perturbation restarts, then an
       L-BFGS-B polish at increasing smoothing temperatures.

    Progress is printed to stdout.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``: the best sampled function,
        its C1 value (repeated twice), and the number of samples in
        ``best_f``. NOTE(review): the duplicated ``best_c1`` is kept for
        existing callers; confirm what the third slot was meant to hold.
    """
    N_eval = 4000
    best_c1, best_f = _cma_fourier_search(N_eval, K_values=(15, 25, 40, 60))
    best_n = N_eval if best_f is not None else None

    # Strategy 2: Direct JAX optimization with perturbation-restart
    print("\n=== JAX perturbation-restart ===")
    N = 3000
    c1_jax, f_jax = _jax_direct_search(N)
    if c1_jax < best_c1:
        best_c1, best_f, best_n = c1_jax, f_jax, N

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n


def _cma_fourier_search(N_eval, K_values):
    """Strategy 1: CMA-ES over Fourier coefficients, then a Nelder-Mead polish.

    For each truncation order K, the 2K+1 coefficients read by
    fourier_to_function are optimized, starting from the constant function 0.5.

    Returns:
        ``(best_c1, best_f)`` over all K. ``best_f`` has ``N_eval`` samples,
        or is None if no candidate scored below ``inf``.
    """
    import cma  # third-party; only this strategy needs it

    def objective(c):
        return fourier_objective(c, N_eval)

    best_c1, best_f = float('inf'), None
    for K in K_values:
        n_coeffs = 2 * K + 1
        print(f"\n=== CMA-ES with K={K} ({n_coeffs} params), N_eval={N_eval} ===")
        # Initialize: constant function (only a0 is nonzero)
        x0 = np.zeros(n_coeffs)
        x0[0] = 0.5
        sigma0 = 0.3
        opts = cma.CMAOptions()
        opts['maxiter'] = 3000
        opts['tolfun'] = 1e-12
        opts['tolx'] = 1e-12
        opts['popsize'] = 40
        opts['seed'] = 42
        opts['verbose'] = -1  # quiet
        es = cma.CMAEvolutionStrategy(x0, sigma0, opts)
        gen = 0
        while not es.stop():
            solutions = es.ask()
            fitnesses = [objective(s) for s in solutions]
            es.tell(solutions, fitnesses)
            gen += 1
            if gen % 100 == 0:
                best_fit = min(fitnesses)
                print(f" Gen {gen}: best C1 = {best_fit:.10f}")

        best_coeffs = es.result[0]
        f_opt = fourier_to_function(best_coeffs, N_eval)
        c1 = compute_c1_numpy(f_opt, N_eval)
        print(f" CMA-ES result: C1 = {c1:.10f}")

        # Derivative-free local polish of the coefficients (Nelder-Mead).
        # No gradient is implemented for this numpy objective.
        polished = scipy_minimize(
            objective, best_coeffs, method='Nelder-Mead',
            options={'maxiter': 50000, 'xatol': 1e-12, 'fatol': 1e-12},
        )
        f_polish = fourier_to_function(polished.x, N_eval)
        c1_polish = compute_c1_numpy(f_polish, N_eval)
        print(f" After Nelder-Mead: C1 = {c1_polish:.10f}")
        if c1_polish < c1:
            c1, f_opt = c1_polish, f_polish
        if c1 < best_c1:
            best_c1, best_f = c1, f_opt
            print(f" *** NEW GLOBAL BEST: C1 = {c1:.10f}")
    return best_c1, best_f


def _make_jax_c1_objectives(N):
    """Build jitted C1 objectives for f = exp(clip(params, -8, 4)) on N points.

    Returns ``(obj_smooth, obj_hard)``:
      * ``obj_smooth(params, temp)``: logsumexp(temp * conv) / temp, which is
        an upper bound on max(conv) that tightens as temp grows, divided by
        (integral of f)**2;
      * ``obj_hard(params)``: the same ratio using the exact max(conv).
    Both use spacing ``dx = 0.5 / N``, matching compute_c1_numpy.
    """
    dx = 0.5 / N

    def autoconv_and_mass(params):
        """Return (autoconvolution of f scaled by dx, squared Riemann integral of f)."""
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the FFT product gives the linear, not circular,
        # autoconvolution (length 2N-1 fits without wrap-around).
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = autoconv_and_mass(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = autoconv_and_mass(params)
        return jnp.max(conv) / integral_sq

    return obj_smooth, obj_hard


def _jax_direct_search(N):
    """Strategy 2: optimize log f directly on N points.

    The stages are:
      1. Adam on the smoothed objective (temp=200) from a near-flat start.
      2. Five perturbation restarts (temp=300) from the running best.
      3. A three-stage L-BFGS-B polish at temp 1000, 5000 and 20000.

    Returns:
        ``(c1, f)``: the best candidate as measured by compute_c1_numpy, and
        its N samples. A polish stage that worsens the exact C1 is never kept.
    """
    obj_smooth, obj_hard = _make_jax_c1_objectives(N)
    smooth_grad = jax.jit(jax.grad(obj_smooth))

    def make_adam(peak_value, warmup_steps, decay_steps):
        """Return (optimizer, jitted update step) for Adam with a warmup-cosine schedule."""
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0, peak_value=peak_value, warmup_steps=warmup_steps,
            decay_steps=decay_steps, end_value=1e-6,
        )
        optimizer = optax.adam(learning_rate=schedule)

        # Jit the whole step (gradient + Adam update + apply) once, rather
        # than rebuilding and running autodiff eagerly on every iteration.
        @jax.jit
        def step(params, opt_state, temp):
            grads = smooth_grad(params, temp)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            return optax.apply_updates(params, updates), opt_state

        return optimizer, step

    best_c1 = float('inf')
    best_params = None

    def keep_if_better(params):
        """Return the exact (hard-max) C1 of params, recording params if it is the best so far."""
        nonlocal best_c1, best_params
        hc = float(obj_hard(params))
        if hc < best_c1:
            best_c1, best_params = hc, params
        return hc

    # Near-flat start: f ~= 0.5 plus small Gaussian noise, stored as log f.
    np.random.seed(42)
    init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N)
    params = jnp.array(np.log(np.maximum(init_f, 1e-6)))

    optimizer, step = make_adam(peak_value=0.008, warmup_steps=2000, decay_steps=78000)
    opt_state = optimizer.init(params)
    for i in range(80000):
        params, opt_state = step(params, opt_state, 200.0)
        if i % 20000 == 0:
            hc = keep_if_better(params)
            print(f" Step {i}: C1={hc:.8f}")
    keep_if_better(params)

    # Perturbation restarts from the running best, which earlier restarts may
    # have improved. All restarts share one optimizer and one compiled step.
    restart_opt, restart_step = make_adam(peak_value=0.003, warmup_steps=1000, decay_steps=29000)
    for restart in range(5):
        noise = jax.random.normal(jax.random.PRNGKey(100 + restart), shape=(N,))
        perturbed = best_params + 0.1 * noise
        opt_state = restart_opt.init(perturbed)
        for _ in range(30000):
            perturbed, opt_state = restart_step(perturbed, opt_state, 300.0)
        hc = keep_if_better(perturbed)
        print(f" Restart {restart}: C1={hc:.8f}")

    def to_f(p):
        """Map log-parameters to float64 function samples, as in the objectives."""
        return np.exp(np.clip(np.asarray(p, dtype=np.float64), -8, 4))

    final_f = to_f(best_params)
    final_c1 = compute_c1_numpy(final_f, N)

    # L-BFGS-B polish with increasing temperature (a sharper smooth max each stage).
    params_np = np.array(best_params, dtype=np.float64)
    for temp in (1000.0, 5000.0, 20000.0):
        def scipy_obj(p, temp=temp):  # bind temp now (avoid late binding)
            p_jax = jnp.array(p)
            val = float(obj_smooth(p_jax, temp))
            g = np.array(smooth_grad(p_jax, temp), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        # Continue the temperature schedule from the polished point. Keep it as
        # the answer only if the exact C1 improved, because minimizing the
        # smoothed surrogate can raise the true max.
        params_np = result.x
        candidate_f = to_f(params_np)
        candidate_c1 = compute_c1_numpy(candidate_f, N)
        if candidate_c1 < final_c1:
            final_c1, final_f = candidate_c1, candidate_f

    print(f" JAX final: C1={final_c1:.10f}")
    return final_c1, final_f