| """ |
| Strategy: CMA-ES on low-dimensional step function parameterization, |
| then upsample and polish with JAX gradient descent. |
| """ |
| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize, differential_evolution |
| import cma |
| import jax |
| import jax.numpy as jnp |
| import optax |
| import sys |
|
|
|
|
def compute_c1_fast(f_values, dx):
    """Return max(f * f) / (integral of f)^2 using an FFT autoconvolution.

    Args:
        f_values: Samples of f on a uniform grid.
        dx: Grid spacing used for both the convolution and the integral.

    Returns:
        The ratio of the autoconvolution peak to the squared integral, or
        the penalty value 1e10 when the squared integral is below 1e-12
        (which includes an empty ``f_values``).
    """
    n = len(f_values)
    integral_sq = (np.sum(f_values) * dx) ** 2
    # Check the degenerate case first: it avoids a pointless FFT and keeps
    # an empty input from raising inside np.fft.rfft.
    if integral_sq < 1e-12:
        return 1e10

    # Zero-pad to 2n so the circular FFT product equals the linear
    # autoconvolution (which has 2n - 1 non-trivial samples).
    padded = np.zeros(2 * n)
    padded[:n] = f_values
    spectrum = np.fft.rfft(padded)
    conv = np.fft.irfft(spectrum * spectrum, n=2 * n) * dx
    return np.max(conv) / integral_sq
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)^2 via a direct convolution.

    Negative entries of ``f_values`` are clamped to zero first. The grid
    spacing is taken as ``0.5 / n_points``. When the squared integral is
    below 1e-12 the penalty value 1e10 is returned instead of the ratio.
    """
    spacing = 0.5 / n_points
    clamped = np.maximum(f_values, 0.0)
    # np.convolve defaults to mode='full', i.e. the complete linear
    # autoconvolution.
    self_conv = spacing * np.convolve(clamped, clamped)
    mass = np.sum(clamped) * spacing
    mass_sq = mass ** 2
    if mass_sq < 1e-12:
        return 1e10
    return np.max(self_conv) / mass_sq
|
|
|
|
def step_to_fine(heights, N_eval):
    """Expand step heights into a piecewise-constant array of length N_eval.

    Step ``i`` of ``n`` covers indices ``int(i * N_eval / n)`` up to (but not
    including) ``int((i + 1) * N_eval / n)``. Negative heights are written
    as 0.0.
    """
    count = len(heights)
    fine = np.zeros(N_eval)
    for idx, height in enumerate(heights):
        lo = int(idx * N_eval / count)
        hi = int((idx + 1) * N_eval / count)
        fine[lo:hi] = max(height, 0.0)
    return fine
|
|
|
|
def optimize_cma(n_steps, N_eval, sigma0=0.5, maxiter=2000, seed=42):
    """Search for step heights that minimise C1 using CMA-ES.

    Each candidate is a vector of ``n_steps`` heights. It is clamped at
    zero, expanded to ``N_eval`` samples with ``step_to_fine`` and scored
    with ``compute_c1_fast`` on a grid of spacing ``0.5 / N_eval``.

    Args:
        n_steps: Number of step heights (search dimension).
        N_eval: Number of fine-grid samples used to score a candidate.
        sigma0: Initial CMA-ES step size.
        maxiter: Iteration cap passed to CMA-ES.
        seed: Random seed passed to CMA-ES.

    Returns:
        The first two entries of the strategy's ``result``: the best
        solution vector and its objective value.
    """
    grid_dx = 0.5 / N_eval

    def objective(heights):
        profile = step_to_fine(np.maximum(heights, 0.0), N_eval)
        return compute_c1_fast(profile, grid_dx)

    settings = dict(
        maxiter=maxiter,
        tolfun=1e-13,
        tolx=1e-13,
        popsize=max(20, 4 + int(3 * np.log(n_steps))),
        seed=seed,
        verbose=-9,
        bounds=[0.0, 10.0],
    )
    # Every height starts at 0.5.
    start = np.full(n_steps, 0.5)

    strategy = cma.CMAEvolutionStrategy(start, sigma0, settings)
    while not strategy.stop():
        candidates = strategy.ask()
        scores = [objective(candidate) for candidate in candidates]
        strategy.tell(candidates, scores)

    outcome = strategy.result
    return outcome[0], outcome[1]
|
|
|
|
def jax_polish(f_init, N, adam_steps=50000, verbose=True):
    """Polish a candidate function with Adam, then L-BFGS-B.

    The function is parameterised as ``f = exp(clip(params, -8, 4))`` so it
    stays positive and bounded. Adam minimises a log-sum-exp smoothing of
    max(f * f) / (integral of f)^2 at temperature 300. L-BFGS-B then
    refines the best Adam iterate at temperatures 1000, 5000 and 20000.

    Args:
        f_init: Starting samples of f; expected to have ``N`` entries, since
            they are written into the first ``N`` slots of the FFT buffer.
            Values below 1e-6 are raised to 1e-6 before taking the log.
        N: Number of grid points; the grid spacing is ``0.5 / N``.
        adam_steps: Number of Adam iterations.
        verbose: When true, print progress lines.

    Returns:
        ``(f, c1)``: the better of the L-BFGS-B result and the best Adam
        checkpoint, together with its un-smoothed C1 value.
    """
    dx = 0.5 / N

    def _autoconv_and_mass(params):
        # Shared by both objectives: zero-padded FFT autoconvolution of f
        # and the squared integral of f.
        f = jnp.exp(jnp.clip(params, -8, 4))
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = _autoconv_and_mass(params)
        # logsumexp(temp * x) / temp is a differentiable upper bound on
        # max(x) that tightens as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _autoconv_and_mass(params)
        return jnp.max(conv) / integral_sq

    # Build the differentiated function once; rebuilding it inside the
    # training loop repeats transformation work on every step.
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_smooth))

    # NOTE(review): unless jax_enable_x64 is enabled elsewhere, JAX holds
    # these arrays in float32 -- confirm the tight L-BFGS tolerances below
    # are meaningful at that precision.
    params = jnp.array(np.log(np.maximum(f_init, 1e-6)))

    # optax's decay_steps is the TOTAL schedule length including warmup, so
    # it must be adam_steps for the cosine phase to end on the last step.
    # Warmup is shortened for small adam_steps because optax rejects a
    # non-positive cosine phase.
    warmup_steps = min(1000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.003,
        warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1),
        end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params

    adam_temp = 300.0
    for step in range(adam_steps):
        _, grads = value_and_grad_fn(params, adam_temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Checkpoint against the un-smoothed objective every 10000 steps
        # and on the final step.
        if step % 10000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            if verbose:
                print(f" Adam step {step}: C1={hc:.8f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    def scipy_obj(p, temp):
        # One fused pass yields both the value and the gradient; SciPy
        # needs a Python float and a float64 NumPy gradient.
        val, grad = value_and_grad_fn(jnp.array(p), temp)
        return float(val), np.array(grad, dtype=np.float64)

    # Anneal the smoothing temperature so the surrogate approaches the true
    # max, warm-starting each stage from the previous solution.
    params_np = np.array(best_params, dtype=np.float64)
    for temp in (1000.0, 5000.0, 20000.0):
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp,), method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x

    f_final = np.exp(np.clip(params_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    if verbose:
        print(f" After L-BFGS: C1={c1_final:.10f}")

    if c1_final < best_c1:
        return f_final, c1_final
    # L-BFGS did not improve on the best Adam checkpoint; return that.
    f_best = np.exp(np.clip(np.array(best_params), -8, 4))
    return f_best, best_c1
|
|
|
|
def run():
    """Run the three-phase search and return the best function found.

    Phase 1 runs CMA-ES over step functions for several step counts and
    seeds, scoring each on a 500-point grid. Phase 2 upsamples the best
    step function to 3000 points and polishes it with ``jax_polish``.
    Phase 3 polishes a noisy constant start directly.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``: the winning samples, its C1
        value (deliberately repeated in the third slot), and the number of
        grid points it was evaluated on.
    """
    coarse_n = 500
    fine_n = 3000

    # Phase 1: coarse global search.
    print("Phase 1: CMA-ES search over step functions")
    top_heights = None
    top_cma_c1 = float('inf')
    for n_steps in (10, 15, 20, 30, 40, 50, 60, 80, 100):
        for seed in (42, 0, 7):
            heights, _ = optimize_cma(n_steps, coarse_n, maxiter=3000, seed=seed)
            clamped = np.maximum(heights, 0.0)
            # Re-score with the direct-convolution evaluator.
            score = compute_c1_numpy(step_to_fine(clamped, coarse_n), coarse_n)
            if score < top_cma_c1:
                top_cma_c1 = score
                top_heights = clamped
                print(f" n_steps={n_steps}, seed={seed}: C1={score:.10f} ***", flush=True)
            elif n_steps <= 30:
                # Non-improving runs are only reported for small step counts.
                print(f" n_steps={n_steps}, seed={seed}: C1={score:.10f}", flush=True)

    # Phase 2: refine the best step function on the fine grid.
    print(f"\nPhase 2: Polish best CMA result (C1={top_cma_c1:.10f})")
    upsampled = step_to_fine(top_heights, fine_n)
    polished_f, polished_c1 = jax_polish(upsampled, fine_n, adam_steps=60000)
    print(f" Polished: C1={polished_c1:.10f}")

    # Phase 3: gradient-only run from a near-constant start.
    print("\nPhase 3: Direct JAX optimization")
    np.random.seed(42)
    noisy_start = np.ones(fine_n) * 0.5 + 0.02 * np.random.randn(fine_n)
    direct_f, direct_c1 = jax_polish(noisy_start, fine_n, adam_steps=80000)
    print(f" Direct JAX: C1={direct_c1:.10f}")

    # Keep the strictly better candidate; the polished one wins ties.
    winner_f, winner_c1, winner_n = None, float('inf'), None
    for cand_f, cand_c1 in ((polished_f, polished_c1), (direct_f, direct_c1)):
        if cand_c1 < winner_c1:
            winner_f, winner_c1, winner_n = cand_f, cand_c1, fine_n

    print(f"\nFinal best C1: {winner_c1:.10f}")
    return winner_f, winner_c1, winner_c1, winner_n
|
|