# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
CMA-ES on step function heights with SMALL N_eval for speed.
Then upsample and polish.
"""
import sys
import numpy as np
import cma
import jax
import jax.numpy as jnp
import optax
from scipy.optimize import minimize as scipy_minimize
def compute_c1_fast(f, dx):
    """Return max(f * f) / (integral of f)^2 using an FFT autoconvolution.

    ``f`` holds samples on a uniform grid with spacing ``dx``. The signal is
    zero-padded to length ``2 * len(f)``, so the circular FFT product equals
    the linear autoconvolution. Unlike ``compute_c1_numpy``, negative samples
    are NOT clipped here; callers are expected to pass non-negative values.

    Args:
        f: 1-D array of function samples.
        dx: grid spacing used for both the convolution and the integral.

    Returns:
        The ratio as a float, or ``1e10`` when the squared integral is
        (numerically) zero. This includes an empty ``f``, so optimizers treat
        such a candidate as very bad.
    """
    integral_sq = (np.sum(f) * dx) ** 2
    # Check the degenerate case first. This skips a wasted FFT and keeps an
    # empty input from reaching np.fft.rfft, which rejects zero-length arrays.
    if integral_sq < 1e-12:
        return 1e10
    n = len(f)
    padded = np.zeros(2 * n)
    padded[:n] = f
    spectrum = np.fft.rfft(padded)
    conv = np.fft.irfft(spectrum * spectrum, n=2 * n) * dx
    return np.max(conv) / integral_sq
def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)^2 via a direct convolution.

    Negative samples are clipped to zero first. The grid spacing is
    ``0.5 / n_points``. When ``n_points == len(f_values)`` (as the callers
    in this file pass it), the samples span an interval of length 1/2.

    Args:
        f_values: 1-D array of function samples.
        n_points: number of grid points; determines the spacing.

    Returns:
        The ratio as a float, or ``1e10`` when the squared integral of the
        clipped samples is (numerically) zero. This covers empty and
        all-non-positive inputs.
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Guard before convolving: np.convolve raises on empty input, and the
    # O(N^2) convolution is pointless when the sentinel is returned anyway.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
def step_to_fine(heights, N_eval):
    """Expand K step heights onto an N_eval-point grid.

    Step ``k`` fills grid indices ``[int(k*N_eval/K), int((k+1)*N_eval/K))``.
    Negative heights are replaced by 0. If K > N_eval, some steps may map to
    an empty range and are then skipped.
    """
    num_steps = len(heights)
    fine = np.zeros(N_eval)
    for idx, height in enumerate(heights):
        lo = int(idx * N_eval / num_steps)
        hi = int((idx + 1) * N_eval / num_steps)
        fine[lo:hi] = max(height, 0.0)
    return fine
def run_cma(K, N_eval, seed=42, maxiter=5000, popsize=None):
    """Optimize K step-function heights with CMA-ES to minimize C1.

    Heights are searched in log space and mapped through
    ``exp(clip(x, -4, 4))``, so they stay strictly positive and bounded. Each
    candidate is expanded onto an ``N_eval``-point grid by ``step_to_fine``
    and scored with ``compute_c1_fast``.

    Args:
        K: number of step-function pieces (the search dimension).
        N_eval: grid size used to evaluate each candidate.
        seed: RNG seed for CMA-ES. ``0`` is remapped to a fixed nonzero
            value, because pycma interprets 0 as "seed from the clock".
        maxiter: maximum number of CMA-ES generations.
        popsize: population size. Defaults to ``max(50, 4 + int(3 ln K))``.

    Returns:
        ``(heights, best_c1)``: the best heights found (length K, positive)
        and the objective value CMA-ES recorded for them.
    """
    dx = 0.5 / N_eval
    if popsize is None:
        popsize = max(50, 4 + int(3 * np.log(K)))

    def to_heights(log_heights):
        # Shared log-space -> height mapping for the objective and the result.
        return np.exp(np.clip(log_heights, -4, 4))

    def objective(log_heights):
        return compute_c1_fast(step_to_fine(to_heights(log_heights), N_eval), dx)

    # pycma treats a seed of 0 (like None) as "seed from the current time",
    # so seed=0 would give a non-reproducible run. Remap 0 to a fixed
    # nonzero value; every other seed is passed through unchanged.
    cma_seed = 2**31 - 1 if seed == 0 else seed

    x0 = np.zeros(K)  # log-heights 0 -> constant function (heights = 1)
    opts = {
        'maxiter': maxiter,
        'tolfun': 1e-14,
        'tolx': 1e-14,
        'popsize': popsize,
        'seed': cma_seed,
        'verbose': -9,
    }
    es = cma.CMAEvolutionStrategy(x0, 1.0, opts)
    gen = 0
    while not es.stop():
        solutions = es.ask()
        es.tell(solutions, [objective(s) for s in solutions])
        gen += 1
        if gen % 500 == 0:
            bf = es.result[1]
            sys.stdout.write(f" Gen {gen}: best={bf:.10f}\n")
            sys.stdout.flush()
    return to_heights(es.result[0]), es.result[1]
def jax_polish(f_init, N, adam_steps=50000):
    """Refine a positive function on an N-point grid to minimize C1.

    The function is parameterized as ``f = exp(clip(params, -8, 4))``, which
    keeps every sample strictly positive. Optimization runs in three stages:

    1. Adam on a smoothed objective. The soft maximum (logsumexp at
       temperature 300) of the autoconvolution replaces the hard max, and a
       warmup + cosine learning-rate schedule is used. The exact C1 is checked
       every 10000 steps (and at the last step), and the best checkpoint is
       kept.
    2. L-BFGS-B on the smoothed objective at temperatures 1000 and 10000.
    3. L-BFGS-B on the exact (hard-max) objective.

    Args:
        f_init: initial samples, length N. Values below 1e-6 (including
            negatives) are floored to 1e-6 before taking the log.
        N: grid size; the grid spacing is ``0.5 / N``.
        adam_steps: number of Adam iterations.

    Returns:
        ``(f, c1)``: the better (by ``compute_c1_numpy``) of the L-BFGS result
        and the best Adam checkpoint, together with its C1.

    NOTE(review): JAX computes in float32 unless 64-bit mode is enabled
    elsewhere, which limits what the very tight L-BFGS tolerances can
    achieve. Confirm the intended precision.
    """
    dx = 0.5 / N

    def _autoconv(params):
        # Shared forward pass: map params to positive samples, compute the
        # zero-padded FFT autoconvolution, and the squared integral.
        f = jnp.exp(jnp.clip(params, -8, 4))
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = _autoconv(params)
        # logsumexp(t*x)/t is a smooth upper bound on max(x); it tightens as
        # the temperature t grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _autoconv(params)
        return jnp.max(conv) / integral_sq

    # Compiled once and reused by every Adam step and L-BFGS evaluation.
    grad_smooth = jax.jit(jax.grad(obj_smooth))
    grad_hard = jax.jit(jax.grad(obj_hard))

    params = jnp.array(np.log(np.maximum(f_init, 1e-6)))

    # optax's `decay_steps` is the TOTAL schedule length including warmup, so
    # it is the full Adam run length. Warmup is capped at half the run so
    # short runs still get a valid (positive-length) cosine phase.
    warmup_steps = min(1000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.003, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        grads = grad_smooth(params, 300.0)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        if step % 10000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            sys.stdout.write(f" Adam {step:6d}: C1={hc:.10f}\n")
            sys.stdout.flush()
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    # L-BFGS-B on progressively sharper smooth surrogates.
    params_np = np.array(best_params, dtype=np.float64)
    for temp in (1000.0, 10000.0):
        def scipy_obj(p, temp=temp):  # bind temp now (no late binding)
            p_jax = jnp.array(p)
            val = float(obj_smooth(p_jax, temp))
            g = np.array(grad_smooth(p_jax, temp), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 50},
        )
        params_np = result.x

    # Final L-BFGS-B on the exact (non-smooth) objective.
    def scipy_obj_hard(p):
        p_jax = jnp.array(p)
        val = float(obj_hard(p_jax))
        g = np.array(grad_hard(p_jax), dtype=np.float64)
        return val, g

    result = scipy_minimize(
        scipy_obj_hard, params_np, method='L-BFGS-B', jac=True,
        options={'maxiter': 10000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100},
    )
    params_np = result.x

    f_final = np.exp(np.clip(params_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    sys.stdout.write(f" After L-BFGS: C1={c1_final:.10f}\n")
    sys.stdout.flush()

    # The smooth stages optimize a surrogate and can raise the true C1, so
    # never return something worse than the best Adam checkpoint.
    f_adam = np.exp(np.clip(np.asarray(best_params, dtype=np.float64), -8, 4))
    c1_adam = compute_c1_numpy(f_adam, N)
    if c1_adam < c1_final:
        sys.stdout.write(f" Keeping Adam checkpoint: C1={c1_adam:.10f}\n")
        sys.stdout.flush()
        return f_adam, c1_adam
    return f_final, c1_final
def run():
    """Run the full search pipeline and return the best function found.

    Phase 1: CMA-ES over step functions for each K in
        {20, 30, 40, 50, 60, 80, 100} and five seeds, evaluated at
        N_coarse = 200 points.
    Phase 2: upsample the best step function to N_fine = 4000 points and
        polish it with ``jax_polish`` (60000 Adam steps). This phase is
        skipped if Phase 1 produced no comparable (non-NaN) candidate.
    Phase 3: run ``jax_polish`` directly from a noisy constant initialization
        (seed 15, 80000 Adam steps) at the same resolution.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. The best C1 is deliberately
        returned in two slots. NOTE(review): the caller's expected tuple
        layout is not visible here; confirm before changing. ``best_f`` and
        ``best_n`` stay None if no polish produced a C1 below infinity.
    """
    best_c1 = float('inf')
    best_f = None
    best_n = None
    N_coarse = 200  # Very fast evaluation for CMA-ES

    # Phase 1: CMA-ES exploration with step functions
    sys.stdout.write("Phase 1: CMA-ES on step functions\n")
    sys.stdout.flush()
    best_heights = None
    best_cma_c1 = float('inf')
    best_K = None  # initialized so the summary line can never hit NameError
    for K in [20, 30, 40, 50, 60, 80, 100]:
        for seed in [42, 0, 7, 123, 99]:
            heights, _ = run_cma(K, N_coarse, seed=seed, maxiter=5000, popsize=80)
            # Re-score on the coarse grid rather than trusting CMA's value.
            f = step_to_fine(heights, N_coarse)
            c1v = compute_c1_fast(f, 0.5 / N_coarse)
            if c1v < best_cma_c1:
                best_cma_c1 = c1v
                best_heights = heights.copy()
                best_K = K
                sys.stdout.write(f" K={K} seed={seed}: C1={c1v:.10f} ***\n")
            elif seed == 42:
                sys.stdout.write(f" K={K} seed={seed}: C1={c1v:.10f}\n")
            sys.stdout.flush()
    sys.stdout.write(f"\nBest CMA-ES: C1={best_cma_c1:.10f} with K={best_K}\n")
    sys.stdout.flush()

    N_fine = 4000

    # Phase 2: Upsample and polish with JAX
    if best_heights is None:
        sys.stdout.write(" Skipping Phase 2: no valid CMA-ES candidate\n")
        sys.stdout.flush()
    else:
        f_up = step_to_fine(best_heights, N_fine)
        sys.stdout.write(f"\nPhase 2: Polish at N={N_fine}\n")
        sys.stdout.flush()
        f_polished, c1_polished = jax_polish(f_up, N_fine, adam_steps=60000)
        sys.stdout.write(f" Polished: C1={c1_polished:.10f}\n")
        sys.stdout.flush()
        if c1_polished < best_c1:
            best_c1 = c1_polished
            best_f = f_polished
            best_n = N_fine

    # Phase 3: Direct JAX with seed 15 (best from v16)
    sys.stdout.write(f"\nPhase 3: Direct JAX seed 15 at N={N_fine}\n")
    sys.stdout.flush()
    # A local legacy RandomState gives the same draws as np.random.seed(15)
    # without clobbering the global NumPy RNG.
    rng = np.random.RandomState(15)
    init_f = np.ones(N_fine) * 0.5 + 0.02 * rng.randn(N_fine)
    f_direct, c1_direct = jax_polish(init_f, N_fine, adam_steps=80000)
    sys.stdout.write(f" Direct: C1={c1_direct:.10f}\n")
    sys.stdout.flush()
    if c1_direct < best_c1:
        best_c1 = c1_direct
        best_f = f_direct
        best_n = N_fine

    sys.stdout.write(f"\nFinal C1: {best_c1:.10f}\n")
    sys.stdout.flush()
    return best_f, best_c1, best_c1, best_n