"""CMA-ES on step-function heights with a small N_eval for speed.

The best coarse step function found by CMA-ES is then upsampled to a fine
grid and polished with gradient-based optimisation (Adam, then L-BFGS-B).

The quantity being minimised throughout ("C1") is

    max_t (f * f)(t) / (integral of f) ** 2

evaluated on a uniform grid, where ``f * f`` is the linear autoconvolution.
The grid spacing is always ``0.5 / N``.
NOTE(review): that spacing suggests f is supported on an interval of length
0.5 -- confirm against the problem statement.
"""
import sys

import numpy as np
import cma
import jax
import jax.numpy as jnp
import optax
from scipy.optimize import minimize as scipy_minimize

# A squared integral below this threshold is treated as "f is identically
# zero", for which the ratio is undefined.
_MIN_INTEGRAL_SQ = 1e-12
# Penalty value returned for such degenerate inputs.
_DEGENERATE_C1 = 1e10


def _log(message):
    """Write *message* to stdout verbatim and flush immediately."""
    sys.stdout.write(message)
    sys.stdout.flush()


def compute_c1_fast(f, dx):
    """Return C1 of the samples *f* using an FFT-based autoconvolution.

    Args:
        f: 1-D sequence of samples. Values are used as given (negative
            entries are NOT clamped here, unlike ``compute_c1_numpy``).
        dx: grid spacing used for both the convolution and the integral.

    Returns:
        ``max(f * f) / (sum(f) * dx) ** 2``, or ``1e10`` when the squared
        integral is below ``1e-12``.
    """
    n = len(f)
    # Zero-padding to length 2n makes the circular FFT convolution equal to
    # the linear one (which has 2n - 1 samples).
    spectrum = np.fft.rfft(f, n=2 * n)
    conv = np.fft.irfft(spectrum * spectrum, n=2 * n) * dx
    integral_sq = (np.sum(f) * dx) ** 2
    if integral_sq < _MIN_INTEGRAL_SQ:
        return _DEGENERATE_C1
    return np.max(conv) / integral_sq


def compute_c1_numpy(f_values, n_points):
    """Return C1 of *f_values* using a direct ``np.convolve`` autoconvolution.

    Negative samples are clamped to zero before evaluation.

    Args:
        f_values: 1-D array of samples.
        n_points: number of grid points; only used to derive ``dx = 0.5 /
            n_points``.

    Returns:
        ``max(f * f) / (sum(f) * dx) ** 2``, or ``1e10`` when the squared
        integral is below ``1e-12``.
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    integral_sq = (np.sum(f_nn) * dx) ** 2
    if integral_sq < _MIN_INTEGRAL_SQ:
        return _DEGENERATE_C1
    return np.max(autoconv) / integral_sq


def step_to_fine(heights, N_eval):
    """Expand K step heights into a piecewise-constant array of length N_eval.

    Step ``i`` fills indices ``[i * N_eval // K, (i + 1) * N_eval // K)``.
    Negative heights are clamped to zero.

    Args:
        heights: sequence of K step heights.
        N_eval: length of the returned fine-grid array.

    Returns:
        ``np.ndarray`` of shape ``(N_eval,)``.
    """
    K = len(heights)
    f = np.zeros(N_eval)
    for i, height in enumerate(heights):
        start = i * N_eval // K
        stop = (i + 1) * N_eval // K
        f[start:stop] = max(height, 0.0)
    return f


def run_cma(K, N_eval, seed=42, maxiter=5000, popsize=None):
    """Minimise C1 over K step heights with CMA-ES.

    The search variables are log-heights, clipped to ``[-4, 4]`` before
    exponentiation, so the heights are always strictly positive. The search
    starts from all-zero log-heights (a constant function) with an initial
    step size of 1.0.

    Args:
        K: number of steps.
        N_eval: fine-grid resolution used to evaluate each candidate.
        seed: CMA-ES random seed.
        maxiter: maximum number of CMA-ES generations.
        popsize: population size; defaults to ``max(50, 4 + int(3 * ln K))``.

    Returns:
        ``(heights, best_value)``: the best heights found (after the
        exp/clip transform) and the corresponding objective value reported
        by CMA-ES.
    """
    dx = 0.5 / N_eval
    if popsize is None:
        popsize = max(50, 4 + int(3 * np.log(K)))

    def objective(log_heights):
        heights = np.exp(np.clip(log_heights, -4, 4))
        return compute_c1_fast(step_to_fine(heights, N_eval), dx)

    x0 = np.zeros(K)  # Start with constant function (heights = 1)
    opts = {
        'maxiter': maxiter,
        'tolfun': 1e-14,
        'tolx': 1e-14,
        'popsize': popsize,
        'seed': seed,
        'verbose': -9,
    }
    es = cma.CMAEvolutionStrategy(x0, 1.0, opts)

    gen = 0
    while not es.stop():
        solutions = es.ask()
        fitnesses = [objective(s) for s in solutions]
        es.tell(solutions, fitnesses)
        gen += 1
        if gen % 500 == 0:
            bf = es.result[1]
            _log(f" Gen {gen}: best={bf:.10f}\n")

    best = es.result[0]
    return np.exp(np.clip(best, -4, 4)), es.result[1]


def jax_polish(f_init, N, adam_steps=50000):
    """Polish a fine-grid function by gradient descent on C1.

    The function is parameterised as ``f = exp(clip(params, -8, 4))``.
    Optimisation proceeds in three stages:

    1. Adam on a log-sum-exp smoothed maximum (temperature 300) with a
       warmup + cosine-decay learning-rate schedule.
    2. L-BFGS-B on the smoothed objective at temperatures 1000 and 10000.
    3. L-BFGS-B on the exact (hard-max) objective.

    Args:
        f_init: initial samples of length N; values are floored at ``1e-6``
            before taking the logarithm.
        N: number of grid points (``dx = 0.5 / N``).
        adam_steps: number of Adam iterations.

    Returns:
        ``(f, c1)``: the polished samples and their C1 as measured by
        ``compute_c1_numpy``. If the L-BFGS stages end up worse than the
        best Adam checkpoint, that checkpoint is returned instead.
    """
    dx = 0.5 / N

    # NOTE(review): unless ``jax_enable_x64`` is switched on elsewhere, JAX
    # computes in float32, so the 1e-14..1e-16 tolerances handed to L-BFGS-B
    # below are far under working precision. Consider enabling x64 if the
    # last digits of C1 matter.

    def _autoconv_and_integral_sq(params):
        # Shared by the smooth and hard objectives.
        f = jnp.exp(jnp.clip(params, -8, 4))
        padded = jnp.zeros(2 * N).at[:N].set(f)
        spectrum = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(spectrum * spectrum, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    def obj_smooth(params, temp):
        # logsumexp(temp * x) / temp is a differentiable upper bound on
        # max(x) that tightens as temp grows.
        conv, integral_sq = _autoconv_and_integral_sq(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    def obj_hard(params):
        conv, integral_sq = _autoconv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    # Build and compile the transformed functions once, not per iteration.
    hard_value = jax.jit(obj_hard)
    smooth_value_and_grad = jax.jit(jax.value_and_grad(obj_smooth))
    hard_value_and_grad = jax.jit(jax.value_and_grad(obj_hard))

    params = jnp.array(np.log(np.maximum(f_init, 1e-6)))

    # optax's ``decay_steps`` is the TOTAL schedule length including the
    # warmup, so it must be ``adam_steps`` for the cosine to reach
    # ``end_value`` on the last step. The clamps keep short runs valid.
    warmup_steps = min(1000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.003,
        warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1),
        end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        _, grads = smooth_value_and_grad(params, 300.0)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Checkpoint (and report) only every 10000 steps and at the end.
        if step % 10000 == 0 or step == adam_steps - 1:
            hc = float(hard_value(params))
            _log(f" Adam {step:6d}: C1={hc:.10f}\n")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    def _lbfgs(value_and_grad_fn, x0, options):
        # Run L-BFGS-B on a JAX value-and-gradient function; return argmin.
        def objective(p):
            value, grad = value_and_grad_fn(jnp.array(p))
            return float(value), np.array(grad, dtype=np.float64)

        result = scipy_minimize(
            objective, x0, method='L-BFGS-B', jac=True, options=options,
        )
        return result.x

    adam_params_np = np.array(best_params, dtype=np.float64)
    params_np = adam_params_np

    # Smoothed L-BFGS at increasing temperature, then hard L-BFGS.
    for temp in [1000.0, 10000.0]:
        params_np = _lbfgs(
            lambda p, t=temp: smooth_value_and_grad(p, t),
            params_np,
            {'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 50},
        )
    params_np = _lbfgs(
        hard_value_and_grad,
        params_np,
        {'maxiter': 10000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100},
    )

    f_final = np.exp(np.clip(params_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    _log(f" After L-BFGS: C1={c1_final:.10f}\n")

    # The smoothed stages minimise a surrogate, so the exact C1 can end up
    # worse than (or NaN relative to) the Adam checkpoint we started from.
    # Never hand back something worse than what we already had.
    f_adam = np.exp(np.clip(adam_params_np, -8, 4))
    c1_adam = compute_c1_numpy(f_adam, N)
    if not c1_final <= c1_adam:
        return f_adam, c1_adam
    return f_final, c1_final


def run():
    """Run the full three-phase search and return the best function found.

    Phase 1 explores step functions with CMA-ES on a coarse grid for several
    step counts and seeds. Phase 2 upsamples the best step function and
    polishes it with ``jax_polish``. Phase 3 polishes a near-constant random
    start directly.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. The C1 value appears twice;
        the four-element shape is kept as-is for callers.

    Raises:
        RuntimeError: if CMA-ES never produced a comparable (non-NaN) C1.
    """
    best_c1 = float('inf')
    best_f = None
    best_n = None

    N_coarse = 200  # Very fast evaluation for CMA-ES

    # Phase 1: CMA-ES exploration with step functions
    _log("Phase 1: CMA-ES on step functions\n")
    best_heights = None
    best_K = None
    best_cma_c1 = float('inf')
    for K in [20, 30, 40, 50, 60, 80, 100]:
        for seed in [42, 0, 7, 123, 99]:
            heights, _ = run_cma(
                K, N_coarse, seed=seed, maxiter=5000, popsize=80,
            )
            # Re-score on the coarse grid rather than trusting the value
            # reported by the optimiser.
            f = step_to_fine(heights, N_coarse)
            c1v = compute_c1_fast(f, 0.5 / N_coarse)
            if c1v < best_cma_c1:
                best_cma_c1 = c1v
                best_heights = heights.copy()
                best_K = K
                _log(f" K={K} seed={seed}: C1={c1v:.10f} ***\n")
            elif seed == 42:
                _log(f" K={K} seed={seed}: C1={c1v:.10f}\n")

    if best_heights is None:
        raise RuntimeError(
            "CMA-ES produced no finite C1 value; nothing to polish"
        )
    _log(f"\nBest CMA-ES: C1={best_cma_c1:.10f} with K={best_K}\n")

    # Phase 2: Upsample and polish with JAX
    N_fine = 4000
    f_up = step_to_fine(best_heights, N_fine)
    _log(f"\nPhase 2: Polish at N={N_fine}\n")
    f_polished, c1_polished = jax_polish(f_up, N_fine, adam_steps=60000)
    _log(f" Polished: C1={c1_polished:.10f}\n")
    if c1_polished < best_c1:
        best_c1 = c1_polished
        best_f = f_polished
        best_n = N_fine

    # Phase 3: Direct JAX with seed 15 (best from v16)
    _log(f"\nPhase 3: Direct JAX seed 15 at N={N_fine}\n")
    np.random.seed(15)
    init_f = np.ones(N_fine) * 0.5 + 0.02 * np.random.randn(N_fine)
    f_direct, c1_direct = jax_polish(init_f, N_fine, adam_steps=80000)
    _log(f" Direct: C1={c1_direct:.10f}\n")
    if c1_direct < best_c1:
        best_c1 = c1_direct
        best_f = f_direct
        best_n = N_fine

    _log(f"\nFinal C1: {best_c1:.10f}\n")
    return best_f, best_c1, best_c1, best_n