| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize |
| import jax |
| import jax.numpy as jnp |
| import optax |
|
|
|
|
| def compute_c1_numpy(f_values, n_points): |
| dx = 0.5 / n_points |
| f_nn = np.maximum(f_values, 0.0) |
| autoconv = np.convolve(f_nn, f_nn, mode='full') * dx |
| integral_sq = (np.sum(f_nn) * dx) ** 2 |
| if integral_sq < 1e-12: |
| return 1e10 |
| return np.max(autoconv) / integral_sq |
|
|
|
|
def fourier_to_function(coeffs, N):
    """Convert Fourier coefficients to non-negative function values.

    coeffs: [a0, a1, b1, a2, b2, ..., aK, bK]
    f(x) = a0 + sum_k (a_k cos(2*pi*k*x) + b_k sin(2*pi*k*x))

    The series is sampled at the N points x = 0, 1/N, ..., (N-1)/N and
    clamped to non-negative values.

    Args:
        coeffs: sequence or array of 2K+1 real coefficients (ints allowed).
            K is taken as (len(coeffs) - 1) // 2, so for an even-length
            input the final coefficient is ignored.
        N: number of sample points.

    Returns:
        float64 array of length N with max(f(x), 0) at each sample point.
    """
    # Coerce to float up front: with integer coefficients, np.full would
    # create an integer array and the in-place float additions below would
    # raise a casting error.
    c = np.asarray(coeffs, dtype=float)
    K = (len(c) - 1) // 2
    x = np.linspace(0, 1, N, endpoint=False)
    f = np.full(N, c[0])
    for k in range(1, K + 1):
        phase = 2 * np.pi * k * x
        f += c[2 * k - 1] * np.cos(phase)
        f += c[2 * k] * np.sin(phase)
    return np.maximum(f, 0.0)
|
|
|
|
def fourier_objective(coeffs, N):
    """Score a Fourier coefficient vector by its C1 autoconvolution ratio.

    The series described by ``coeffs`` is sampled on ``N`` points with
    ``fourier_to_function`` and scored with ``compute_c1_numpy`` on the same
    ``N``-point grid. Lower is better.
    """
    samples = fourier_to_function(coeffs, N)
    score = compute_c1_numpy(samples, N)
    return score
|
|
|
|
def _cma_es_search(K, N_eval):
    """Optimise a K-harmonic Fourier series with CMA-ES, then Nelder-Mead.

    Args:
        K: number of harmonics; the search space has 2K+1 coefficients
            laid out as in ``fourier_to_function``.
        N_eval: number of grid points used to score each candidate.

    Returns:
        (c1, f_opt): the lower of the CMA-ES and Nelder-Mead C1 scores and
        the corresponding function values sampled on ``N_eval`` points.
    """
    import cma  # optional dependency, only needed for this stage

    n_coeffs = 2 * K + 1
    print(f"\n=== CMA-ES with K={K} ({n_coeffs} params), N_eval={N_eval} ===")

    # Start from the constant function f(x) = 0.5 (all harmonics zero).
    x0 = np.zeros(n_coeffs)
    x0[0] = 0.5
    sigma0 = 0.3

    opts = cma.CMAOptions()
    opts['maxiter'] = 3000
    opts['tolfun'] = 1e-12
    opts['tolx'] = 1e-12
    opts['popsize'] = 40
    opts['seed'] = 42
    opts['verbose'] = -1

    es = cma.CMAEvolutionStrategy(x0, sigma0, opts)
    gen = 0
    while not es.stop():
        solutions = es.ask()
        fitnesses = [fourier_objective(s, N_eval) for s in solutions]
        es.tell(solutions, fitnesses)
        gen += 1
        if gen % 100 == 0:
            print(f" Gen {gen}: best C1 = {min(fitnesses):.10f}")

    best_coeffs = es.result[0]  # xbest: best solution seen during the run
    f_opt = fourier_to_function(best_coeffs, N_eval)
    c1 = compute_c1_numpy(f_opt, N_eval)
    print(f" CMA-ES result: C1 = {c1:.10f}")

    # Derivative-free polish: the objective is non-smooth (hard max and
    # non-negativity clamp), so a gradient-based method is not used here.
    result_nm = scipy_minimize(
        lambda c: fourier_objective(c, N_eval),
        best_coeffs, method='Nelder-Mead',
        options={'maxiter': 50000, 'xatol': 1e-12, 'fatol': 1e-12},
    )
    f_polish = fourier_to_function(result_nm.x, N_eval)
    c1_polish = compute_c1_numpy(f_polish, N_eval)
    print(f" After Nelder-Mead: C1 = {c1_polish:.10f}")

    if c1_polish < c1:
        return c1_polish, f_polish
    return c1, f_opt


def _make_adam_step(optimizer, value_and_grad_fn, temp):
    """Build one jitted optax step that minimises the smoothed objective at ``temp``."""

    @jax.jit
    def step(params, opt_state):
        _, grads = value_and_grad_fn(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    return step


def _jax_search(N):
    """Gradient-based search over f = exp(params) sampled on an N-point grid.

    Phases:
      1. Adam (80k steps, smooth max at temp=200) from a noisy constant 0.5.
      2. Five restarts: perturb the best params, then 30k Adam steps at temp=300.
      3. L-BFGS-B continuation at temps 1000, 5000, 20000. Each stage starts
         from the previous one, and the best hard-max result is kept.

    Returns:
        (c1, f_values): the final C1 from ``compute_c1_numpy`` and the
        function values (length N) that achieve it.
    """
    print("\n=== JAX perturbation-restart ===")
    dx = 0.5 / N

    def autoconv(params):
        # f = exp(params) is strictly positive; the clip bounds the exponent.
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the FFT product gives a linear (not circular)
        # autoconvolution.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth_jax(params, temp):
        conv, integral_sq = autoconv(params)
        # logsumexp(temp * conv) / temp is a differentiable upper bound on
        # max(conv) that tightens as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard_jax(params):
        conv, integral_sq = autoconv(params)
        return jnp.max(conv) / integral_sq

    # Built once and reused by every phase, instead of re-creating the
    # autodiff transform on each of the ~230k optimisation steps.
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_smooth_jax))

    # --- Phase 1: long Adam run from a noisy constant ---------------------
    np.random.seed(42)
    init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N)
    params = jnp.array(np.log(np.maximum(init_f, 1e-6)))

    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.008, warmup_steps=2000,
        decay_steps=78000, end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)
    adam_step = _make_adam_step(optimizer, value_and_grad_fn, 200.0)

    best_c1_jax = float('inf')
    best_params_jax = params

    for step in range(80000):
        params, opt_state = adam_step(params, opt_state)
        if step % 20000 == 0:
            hc = float(obj_hard_jax(params))
            print(f" Step {step}: C1={hc:.8f}")
            if hc < best_c1_jax:
                best_c1_jax, best_params_jax = hc, params

    hc = float(obj_hard_jax(params))
    if hc < best_c1_jax:
        best_c1_jax, best_params_jax = hc, params

    # --- Phase 2: perturbation restarts around the incumbent --------------
    # The schedule and optimizer are identical for every restart, so they
    # (and the compiled step) are built once; only the state is re-initialised.
    lr_sched2 = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.003, warmup_steps=1000,
        decay_steps=29000, end_value=1e-6,
    )
    opt2 = optax.adam(learning_rate=lr_sched2)
    restart_step = _make_adam_step(opt2, value_and_grad_fn, 300.0)

    for restart in range(5):
        key = jax.random.PRNGKey(100 + restart)
        perturbed = best_params_jax + 0.1 * jax.random.normal(key, shape=(N,))
        opt_state2 = opt2.init(perturbed)
        for _ in range(30000):
            perturbed, opt_state2 = restart_step(perturbed, opt_state2)

        hc = float(obj_hard_jax(perturbed))
        print(f" Restart {restart}: C1={hc:.8f}")
        if hc < best_c1_jax:
            best_c1_jax, best_params_jax = hc, perturbed

    # --- Phase 3: L-BFGS-B continuation at increasing temperature ---------
    def hard_c1(p):
        return compute_c1_numpy(np.exp(np.clip(p, -8, 4)), N)

    params_np = np.array(best_params_jax, dtype=np.float64)
    best_np = params_np
    best_c1_np = hard_c1(params_np)
    for temp in (1000.0, 5000.0, 20000.0):
        def scipy_obj(p, temp=temp):
            # One forward+backward pass per evaluation; scipy expects float64.
            val, g = value_and_grad_fn(jnp.asarray(p), temp)
            return float(val), np.asarray(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        # A smoothed stage can worsen the true (hard-max) score, so only
        # accept it if it improves on the best seen so far.
        c1_stage = hard_c1(params_np)
        if c1_stage < best_c1_np:
            best_c1_np, best_np = c1_stage, params_np

    f_jax = np.exp(np.clip(best_np, -8, 4))
    c1_jax = compute_c1_numpy(f_jax, N)
    print(f" JAX final: C1={c1_jax:.10f}")
    return c1_jax, f_jax


def run():
    """Search for a non-negative function that minimises the C1 ratio.

    C1 is the autoconvolution ratio computed by ``compute_c1_numpy``. Two
    strategies run in turn, and the lowest score is kept:
      * CMA-ES plus Nelder-Mead polish over truncated Fourier series with
        K = 15, 25, 40, 60 harmonics, scored on 4000 grid points;
      * a JAX Adam / restart / L-BFGS-B search over log f on 3000 points.

    Returns:
        (best_f, best_c1, best_c1, best_n): the best function values, their
        C1 score, the same score again, and the grid size ``best_f`` was
        sampled on. NOTE(review): the score is deliberately returned twice
        to keep the original tuple shape; confirm what callers expect in
        the third slot.
    """
    best_c1 = float('inf')
    best_f = None
    best_n = None
    N_eval = 4000

    for K in (15, 25, 40, 60):
        c1, f_opt = _cma_es_search(K, N_eval)
        if c1 < best_c1:
            best_c1, best_f, best_n = c1, f_opt, N_eval
            print(f" *** NEW GLOBAL BEST: C1 = {c1:.10f}")

    N_jax = 3000
    c1_jax, f_jax = _jax_search(N_jax)
    if c1_jax < best_c1:
        best_c1, best_f, best_n = c1_jax, f_jax, N_jax

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n
|
|