# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
Strategy: CMA-ES on low-dimensional step function parameterization,
then upsample and polish with JAX gradient descent.
"""
import numpy as np
from scipy.optimize import minimize as scipy_minimize, differential_evolution
import cma
import jax
import jax.numpy as jnp
import optax
import sys
def compute_c1_fast(f_values, dx):
    """Fast C1 computation using FFT.

    Computes ``max(f * f) / (integral of f) ** 2``, where ``f * f`` is the
    autoconvolution of the sampled function ``f_values`` on a grid with
    spacing ``dx``. The autoconvolution is evaluated with a zero-padded real
    FFT instead of a direct convolution.

    Unlike ``compute_c1_numpy``, the samples are NOT clamped to be
    non-negative; callers are expected to pass non-negative values.

    Args:
        f_values: 1-D sequence of function samples.
        dx: Grid spacing used for both the convolution and the integral.

    Returns:
        The C1 ratio, or the sentinel ``1e10`` when the squared integral is
        below ``1e-12`` (this includes empty input).
    """
    f = np.asarray(f_values, dtype=float)
    n = f.size
    integral_sq = (np.sum(f) * dx) ** 2
    # Check the degenerate case before transforming: np.fft.rfft raises on a
    # zero-length input, and the FFT work is wasted when we return the sentinel.
    if n == 0 or integral_sq < 1e-12:
        return 1e10
    # Zero-pad to 2n so the circular FFT product equals the linear
    # autoconvolution (whose length is 2n - 1).
    padded = np.zeros(2 * n)
    padded[:n] = f
    fft_f = np.fft.rfft(padded)
    conv = np.fft.irfft(fft_f * fft_f, n=2 * n) * dx
    return np.max(conv) / integral_sq
def compute_c1_numpy(f_values, n_points):
    """Return the C1 ratio using a direct ``np.convolve`` autoconvolution.

    The samples are clamped at zero, and the grid spacing is taken as
    ``0.5 / n_points``. The result is the peak of the full autoconvolution
    divided by the squared integral of the clamped samples. The sentinel
    ``1e10`` is returned when that squared integral is below ``1e-12``.
    """
    step = 0.5 / n_points
    nonneg = np.maximum(f_values, 0.0)
    self_conv = np.convolve(nonneg, nonneg, mode='full') * step
    mass_sq = (nonneg.sum() * step) ** 2
    return 1e10 if mass_sq < 1e-12 else self_conv.max() / mass_sq
def step_to_fine(heights, N_eval):
    """Expand step heights into a piecewise-constant array of length ``N_eval``.

    Step ``k`` (of ``len(heights)``) fills the grid indices from
    ``int(k * N_eval / len(heights))`` up to, but not including,
    ``int((k + 1) * N_eval / len(heights))``. Negative heights are clamped
    to 0, and an empty ``heights`` yields an all-zero array.
    """
    total = len(heights)
    fine = np.zeros(N_eval)
    for k, height in enumerate(heights):
        lo = int(k * N_eval / total)
        hi = int((k + 1) * N_eval / total)
        fine[lo:hi] = max(height, 0.0)
    return fine
def optimize_cma(n_steps, N_eval, sigma0=0.5, maxiter=2000, seed=42):
    """Search ``n_steps``-piece step functions with CMA-ES to minimise C1.

    Each candidate vector of step heights is clamped at zero and expanded
    onto an ``N_eval``-point grid with ``step_to_fine``. It is then scored by
    ``compute_c1_fast`` with grid spacing ``0.5 / N_eval``. The search starts
    from all heights equal to 0.5, with CMA-ES bounds of [0, 10].

    Args:
        n_steps: Number of step heights being optimised.
        N_eval: Grid resolution used to evaluate each candidate.
        sigma0: Initial CMA-ES step size.
        maxiter: Iteration cap passed to CMA-ES.
        seed: CMA-ES random seed.

    Returns:
        ``(result[0], result[1])`` of the final CMA-ES state, i.e. the best
        heights found and their fitness (``xbest`` / ``fbest`` in cma's
        result tuple).
    """
    spacing = 0.5 / N_eval

    def fitness(candidate):
        fine = step_to_fine(np.maximum(candidate, 0.0), N_eval)
        return compute_c1_fast(fine, spacing)

    start = np.full(n_steps, 0.5)
    settings = dict(
        maxiter=maxiter,
        tolfun=1e-13,
        tolx=1e-13,
        popsize=max(20, 4 + int(3 * np.log(n_steps))),
        seed=seed,
        verbose=-9,
        bounds=[0.0, 10.0],
    )
    strategy = cma.CMAEvolutionStrategy(start, sigma0, settings)
    while not strategy.stop():
        candidates = strategy.ask()
        strategy.tell(candidates, [fitness(c) for c in candidates])
    return strategy.result[0], strategy.result[1]
def jax_polish(f_init, N, adam_steps=50000, verbose=True):
    """Polish a solution using JAX gradient descent + L-BFGS.

    The function is parameterised as ``f = exp(clip(params, -8, 4))``, which
    keeps it strictly positive. Adam minimises a log-sum-exp (soft-max)
    surrogate of the C1 ratio at temperature 300. L-BFGS-B then refines the
    best Adam checkpoint at temperatures 1000, 5000 and 20000, so the
    surrogate approaches the true max.

    Args:
        f_init: Initial function samples (length ``N``). Entries below 1e-6
            are floored before taking the log.
        N: Number of sample points; the grid spacing is ``0.5 / N``.
        adam_steps: Number of Adam iterations.
        verbose: Print progress when True.

    Returns:
        ``(f, c1)``: the better of the L-BFGS result and the best Adam
        checkpoint (or the initial point). Both are scored with
        ``compute_c1_numpy`` so they are compared with the same evaluator.
    """
    dx = 0.5 / N

    def _autoconv(params):
        """Return (autoconvolution of f, (integral of f) ** 2) for f = exp(clip(params))."""
        # Clipping the exponent bounds f to [exp(-8), exp(4)].
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the circular FFT product is the linear autoconvolution.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        """Soft-max surrogate of C1; logsumexp(temp * x) / temp >= max(x)."""
        conv, integral_sq = _autoconv(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        """Exact (hard max) C1 in JAX arithmetic."""
        conv, integral_sq = _autoconv(params)
        return jnp.max(conv) / integral_sq

    # Compiled once; reused by the L-BFGS-B objective for value + gradient
    # in a single pass.
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_smooth))

    # optax's ``decay_steps`` is the TOTAL schedule length (warmup included),
    # so it is adam_steps itself. It is kept strictly larger than the warmup
    # so the cosine phase has positive length even for short runs.
    warmup_steps = min(1000, max(adam_steps // 10, 1))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.003, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    @jax.jit
    def adam_step(params, opt_state, temp):
        """One Adam update on the smooth objective at temperature ``temp``."""
        grads = jax.grad(obj_smooth)(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # Initialize from the given function
    params = jnp.array(np.log(np.maximum(f_init, 1e-6)))
    opt_state = optimizer.init(params)
    adam_temp = 300.0

    # The best iterate under the exact objective is only sampled at
    # checkpoints. The starting point is a candidate too. A NaN score never
    # passes the `<` test, so it is never recorded as best.
    best_c1 = float('inf')
    best_params = params
    init_c1 = float(obj_hard(params))
    if init_c1 < best_c1:
        best_c1 = init_c1
    for step in range(adam_steps):
        params, opt_state = adam_step(params, opt_state, adam_temp)
        if step % 10000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            if verbose:
                print(f" Adam step {step}: C1={hc:.8f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    # L-BFGS polish at increasing temperatures (tighter soft-max).
    def scipy_obj(p, temp):
        val, g = value_and_grad_fn(jnp.array(p), temp)
        return float(val), np.array(g, dtype=np.float64)

    params_np = np.array(best_params, dtype=np.float64)
    for temp in (1000.0, 5000.0, 20000.0):
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp,), method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x

    f_final = np.exp(np.clip(params_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    if verbose:
        print(f" After L-BFGS: C1={c1_final:.10f}")

    # Re-score the Adam best with the same float64 evaluator. obj_hard runs
    # in JAX's default precision (float32 unless x64 is enabled), so its
    # value is not directly comparable to c1_final.
    f_best = np.exp(np.clip(np.array(best_params, dtype=np.float64), -8, 4))
    c1_best = compute_c1_numpy(f_best, N)
    if c1_final < c1_best or (np.isnan(c1_best) and np.isfinite(c1_final)):
        return f_final, c1_final
    return f_best, c1_best
def run():
    """Run the three-phase search and return the best function found.

    Phase 1 runs CMA-ES over step functions with 10-100 steps (three seeds
    each) on a 500-point grid, re-scoring each result with
    ``compute_c1_numpy``. Phase 2 upsamples the best step function to 3000
    points and polishes it with ``jax_polish``. Phase 3 runs ``jax_polish``
    from a near-constant random start. The lower-C1 result of phases 2 and 3
    is kept.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``: the best sampled function,
        its C1 value (returned in both the second and third positions), and
        its number of sample points. ``best_f``/``best_n`` stay ``None`` if
        no phase produced a comparable result.
    """
    # NOTE(review): C1 is returned twice; presumably this matches a 4-tuple
    # format expected by the caller - confirm before changing.
    best_c1 = float('inf')
    best_f = None
    best_n = None
    N_coarse = 500  # For fast CMA-ES evaluation

    # Phase 1: CMA-ES search over step functions
    print("Phase 1: CMA-ES search over step functions")
    best_heights = None
    best_cma_c1 = float('inf')
    for n_steps in [10, 15, 20, 30, 40, 50, 60, 80, 100]:
        for seed in [42, 0, 7]:
            heights, _ = optimize_cma(n_steps, N_coarse, maxiter=3000, seed=seed)
            heights = np.maximum(heights, 0.0)
            # Re-score with the direct-convolution evaluator.
            c1_v = compute_c1_numpy(step_to_fine(heights, N_coarse), N_coarse)
            if c1_v < best_cma_c1:
                best_cma_c1 = c1_v
                best_heights = heights
                print(f" n_steps={n_steps}, seed={seed}: C1={c1_v:.10f} ***", flush=True)
            elif n_steps <= 30:  # Only print small configs
                print(f" n_steps={n_steps}, seed={seed}: C1={c1_v:.10f}", flush=True)

    # Phase 2: Upsample best CMA result and polish with JAX
    N_fine = 3000
    if best_heights is None:
        # Every CMA score was NaN, so there is nothing to upsample.
        print("\nPhase 2: skipped (no comparable CMA-ES result)")
    else:
        print(f"\nPhase 2: Polish best CMA result (C1={best_cma_c1:.10f})")
        f_upsampled = step_to_fine(best_heights, N_fine)
        f_polished, c1_polished = jax_polish(f_upsampled, N_fine, adam_steps=60000)
        print(f" Polished: C1={c1_polished:.10f}")
        if c1_polished < best_c1:
            best_c1 = c1_polished
            best_f = f_polished
            best_n = N_fine

    # Phase 3: Also try direct JAX from scratch (best approach from v4)
    print("\nPhase 3: Direct JAX optimization")
    N = 3000
    # A private RandomState(42) draws exactly the same numbers as
    # np.random.seed(42) followed by np.random.randn, without resetting the
    # process-wide global RNG.
    rng = np.random.RandomState(42)
    init_f = np.ones(N) * 0.5 + 0.02 * rng.randn(N)
    f_jax, c1_jax = jax_polish(init_f, N, adam_steps=80000)
    print(f" Direct JAX: C1={c1_jax:.10f}")
    if c1_jax < best_c1:
        best_c1 = c1_jax
        best_f = f_jax
        best_n = N

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n