import numpy as np from scipy.optimize import minimize as scipy_minimize import jax import jax.numpy as jnp import optax def compute_c1_numpy(f_values, n_points): dx = 0.5 / n_points f_nn = np.maximum(f_values, 0.0) autoconv = np.convolve(f_nn, f_nn, mode='full') * dx integral_sq = (np.sum(f_nn) * dx) ** 2 if integral_sq < 1e-12: return 1e10 return np.max(autoconv) / integral_sq def fourier_to_function(coeffs, N): """Convert Fourier coefficients to function values. coeffs: [a0, a1, b1, a2, b2, ..., aK, bK] f(x) = a0 + sum_k (a_k cos(2*pi*k*x) + b_k sin(2*pi*k*x)) Clamped to non-negative. """ K = (len(coeffs) - 1) // 2 x = np.linspace(0, 1, N, endpoint=False) f = np.full(N, coeffs[0]) for k in range(1, K + 1): f += coeffs[2*k - 1] * np.cos(2 * np.pi * k * x) f += coeffs[2*k] * np.sin(2 * np.pi * k * x) return np.maximum(f, 0.0) def fourier_objective(coeffs, N): f = fourier_to_function(coeffs, N) return compute_c1_numpy(f, N) def run(): import cma best_c1 = float('inf') best_f = None best_n = None N_eval = 4000 # Try different numbers of Fourier components for K in [15, 25, 40, 60]: n_coeffs = 2 * K + 1 print(f"\n=== CMA-ES with K={K} ({n_coeffs} params), N_eval={N_eval} ===") # Initialize: constant function x0 = np.zeros(n_coeffs) x0[0] = 0.5 # a0 = 0.5 (constant part) sigma0 = 0.3 opts = cma.CMAOptions() opts['maxiter'] = 3000 opts['tolfun'] = 1e-12 opts['tolx'] = 1e-12 opts['popsize'] = 40 opts['seed'] = 42 opts['verbose'] = -1 # quiet es = cma.CMAEvolutionStrategy(x0, sigma0, opts) gen = 0 while not es.stop(): solutions = es.ask() fitnesses = [fourier_objective(s, N_eval) for s in solutions] es.tell(solutions, fitnesses) gen += 1 if gen % 100 == 0: best_fit = min(fitnesses) print(f" Gen {gen}: best C1 = {best_fit:.10f}") result = es.result best_coeffs = result[0] f_opt = fourier_to_function(best_coeffs, N_eval) c1 = compute_c1_numpy(f_opt, N_eval) print(f" CMA-ES result: C1 = {c1:.10f}") # L-BFGS polish on the Fourier coefficients result_lbfgs = scipy_minimize( lambda c: fourier_objective(c, N_eval), best_coeffs, method='Nelder-Mead', options={'maxiter': 50000, 'xatol': 1e-12, 'fatol': 1e-12}, ) f_polish = fourier_to_function(result_lbfgs.x, N_eval) c1_polish = compute_c1_numpy(f_polish, N_eval) print(f" After Nelder-Mead: C1 = {c1_polish:.10f}") if c1_polish < c1: c1 = c1_polish f_opt = f_polish if c1 < best_c1: best_c1 = c1 best_f = f_opt best_n = N_eval print(f" *** NEW GLOBAL BEST: C1 = {c1:.10f}") # Strategy 2: Direct JAX optimization with perturbation-restart print("\n=== JAX perturbation-restart ===") N = 3000 dx = 0.5 / N @jax.jit def obj_smooth_jax(params, temp): f = jnp.exp(jnp.clip(params, -8, 4)) padded = jnp.zeros(2 * N) padded = padded.at[:N].set(f) fft_f = jnp.fft.rfft(padded) conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx integral_sq = (jnp.sum(f) * dx) ** 2 smooth_max = jax.nn.logsumexp(temp * conv) / temp return smooth_max / integral_sq @jax.jit def obj_hard_jax(params): f = jnp.exp(jnp.clip(params, -8, 4)) padded = jnp.zeros(2 * N) padded = padded.at[:N].set(f) fft_f = jnp.fft.rfft(padded) conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx integral_sq = (jnp.sum(f) * dx) ** 2 return jnp.max(conv) / integral_sq grad_fn = jax.jit(jax.grad(obj_smooth_jax)) # Start with flat init np.random.seed(42) init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N) params = jnp.array(np.log(np.maximum(init_f, 1e-6))) # Warm up with Adam lr_schedule = optax.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.008, warmup_steps=2000, decay_steps=78000, end_value=1e-6, ) optimizer = optax.adam(learning_rate=lr_schedule) opt_state = optimizer.init(params) best_c1_jax = float('inf') best_params_jax = params for step in range(80000): temp = 200.0 loss, grads = jax.value_and_grad(obj_smooth_jax)(params, temp) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) if step % 20000 == 0: hc = float(obj_hard_jax(params)) print(f" Step {step}: C1={hc:.8f}") if hc < best_c1_jax: best_c1_jax = hc best_params_jax = params hc = float(obj_hard_jax(params)) if hc < best_c1_jax: best_c1_jax = hc best_params_jax = params # Perturbation restarts for restart in range(5): np.random.seed(100 + restart) perturbed = best_params_jax + 0.1 * jax.random.normal(jax.random.PRNGKey(100 + restart), shape=(N,)) lr_sched2 = optax.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.003, warmup_steps=1000, decay_steps=29000, end_value=1e-6, ) opt2 = optax.adam(learning_rate=lr_sched2) opt_state2 = opt2.init(perturbed) for step in range(30000): loss, grads = jax.value_and_grad(obj_smooth_jax)(perturbed, 300.0) updates, opt_state2 = opt2.update(grads, opt_state2, perturbed) perturbed = optax.apply_updates(perturbed, updates) hc = float(obj_hard_jax(perturbed)) print(f" Restart {restart}: C1={hc:.8f}") if hc < best_c1_jax: best_c1_jax = hc best_params_jax = perturbed # L-BFGS polish params_np = np.array(best_params_jax, dtype=np.float64) for temp in [1000.0, 5000.0, 20000.0]: def scipy_obj(p): p_jax = jnp.array(p) val = float(obj_smooth_jax(p_jax, temp)) g = np.array(grad_fn(p_jax, temp), dtype=np.float64) return val, g result = scipy_minimize( scipy_obj, params_np, method='L-BFGS-B', jac=True, options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12}, ) params_np = result.x f_jax = np.exp(np.clip(params_np, -8, 4)) c1_jax = compute_c1_numpy(f_jax, N) print(f" JAX final: C1={c1_jax:.10f}") if c1_jax < best_c1: best_c1 = c1_jax best_f = f_jax best_n = N print(f"\nFinal best C1: {best_c1:.10f}") return best_f, best_c1, best_c1, best_n