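"""
Take the best approach (v14, seed 2) and push it further:
1. Higher resolution (N = 6000).
2. More Adam steps (150k for the initial run, 30k for each perturbation restart).
3. Multi-phase refinement: Adam -> smooth L-BFGS -> hard L-BFGS -> perturb -> repeat.
Seeds 0 and 42 are also tried; the best result across all seeds is kept.
"""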
import functools
import sys

import jax
import jax.numpy as jnp
import numpy as np
import optax
from scipy.optimize import minimize as scipy_minimize
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Score a sampled function by its autoconvolution ratio, in NumPy.

    ``f_values`` are cell values on ``n_points`` cells of width
    ``0.5 / n_points``; negative entries are clipped to zero first. The score
    is ``max(f*f) / (integral of f) ** 2``, where ``f*f`` is the full discrete
    autoconvolution scaled by the cell width.

    Returns ``1e10`` as a sentinel when the squared integral is below
    ``1e-12`` (an essentially zero function), so degenerate inputs never rank
    as good candidates.
    """
    cell = 0.5 / n_points
    nonneg = np.maximum(f_values, 0.0)
    peak = np.max(np.convolve(nonneg, nonneg, mode='full') * cell)
    mass_sq = (np.sum(nonneg) * cell) ** 2
    return 1e10 if mass_sq < 1e-12 else peak / mass_sq
|
|
|
|
@functools.lru_cache(maxsize=8)
def make_fns(N):
    """Build jitted objectives and their gradients for a grid of ``N`` cells.

    The parameters are log-values of a piecewise-constant function on ``N``
    cells of width ``dx = 0.5 / N``. The function is recovered as
    ``f = exp(clip(params, -8, 4))``, which keeps it strictly positive. Both
    objectives return ``max(f*f) / (integral of f) ** 2``, where ``f*f`` is
    the discrete autoconvolution scaled by ``dx``:

    * ``obj_smooth(params, temp)`` replaces the max with the soft maximum
      ``logsumexp(temp * conv) / temp``. For ``temp > 0`` this upper-bounds
      the true max and approaches it as ``temp`` grows.
    * ``obj_hard(params)`` uses the exact (non-smooth) max.

    Results are cached per ``N``, so repeated phases reuse the same jitted
    functions instead of re-tracing and re-compiling them on every call.

    Returns:
        ``(obj_smooth, obj_hard, grad_smooth, grad_hard)``. The gradients are
        taken with respect to ``params``.
    """
    dx = 0.5 / N

    def _conv_and_norm(params):
        """Return the scaled autoconvolution of f and (integral of f) ** 2."""
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-padding to length 2N makes the circular FFT product equal to
        # the linear (full) autoconvolution. rfft pads implicitly when
        # n > len(f).
        fft_f = jnp.fft.rfft(f, n=2 * N)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = _conv_and_norm(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _conv_and_norm(params)
        return jnp.max(conv) / integral_sq

    grad_smooth = jax.jit(jax.grad(obj_smooth))
    grad_hard = jax.jit(jax.grad(obj_hard))

    return obj_smooth, obj_hard, grad_smooth, grad_hard
|
|
|
|
def adam_phase(params, N, steps, lr_peak, temp=300.0, *, warmup_steps=2000):
    """Minimize the soft-max objective with Adam under a warmup-cosine schedule.

    The learning rate ramps linearly from 0 to ``lr_peak`` during the warmup.
    It then follows a cosine decay down to ``lr_peak * 1e-4``, reached at
    step ``steps``. Every 25000 steps, and at the final step, the exact C1
    (``obj_hard``) is evaluated and logged. The best such checkpoint is
    returned.

    Args:
        params: initial log-parameters of length ``N`` (the callers in this
            file pass JAX arrays).
        N: grid size passed to ``make_fns``.
        steps: number of Adam updates.
        lr_peak: peak learning rate.
        temp: soft-max temperature for the smooth objective.
        warmup_steps: requested warmup length. It is clamped to
            ``steps // 2`` so the cosine segment is never empty.

    Returns:
        ``(best_params, best_c1)``. Only the logged checkpoints are compared,
        so ``best_params`` is the best of those steps, not of every step. If
        ``steps <= 0``, the input params and their C1 are returned.
    """
    obj_smooth, obj_hard, _, _ = make_fns(N)

    if steps <= 0:
        # Nothing to optimize. Report the starting point instead of letting
        # optax reject a non-positive schedule length.
        return params, float(obj_hard(params))

    warmup = min(warmup_steps, steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=lr_peak, warmup_steps=warmup,
        # optax's ``decay_steps`` is the TOTAL schedule length (warmup
        # included), so the cosine decay must end exactly at ``steps``.
        decay_steps=steps, end_value=lr_peak * 1e-4,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)
    smooth_grad = jax.grad(obj_smooth)

    @jax.jit
    def train_step(p, state):
        # One fused update. Jitting the whole step avoids dispatching the
        # gradient, optimizer update and parameter update op-by-op.
        grads = smooth_grad(p, temp)
        updates, state = optimizer.update(grads, state, p)
        return optax.apply_updates(p, updates), state

    best_c1 = float('inf')
    best_params = params

    for step in range(steps):
        params, opt_state = train_step(params, opt_state)

        if step % 25000 == 0 or step == steps - 1:
            hc = float(obj_hard(params))
            sys.stdout.write(f" Adam {step:7d} | C1={hc:.10f}\n")
            sys.stdout.flush()
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    return best_params, best_c1
|
|
|
|
def lbfgs_phase(params_np, N, *, temps=(1000.0, 10000.0, 100000.0), hard_restarts=3):
    """Refine log-parameters with SciPy L-BFGS-B: smooth annealing, then hard max.

    Phase 1 minimizes the soft-max objective at each temperature in ``temps``,
    in order, warm-starting each run from the previous result. A higher
    temperature makes the soft max tighter.

    Phase 2 minimizes the exact (non-smooth) max objective ``hard_restarts``
    times. Each run restarts from the previous result with fresh L-BFGS-B
    curvature history.

    Args:
        params_np: starting log-parameters, length ``N``.
        N: grid size passed to ``make_fns``.
        temps: soft-max temperatures for phase 1.
        hard_restarts: number of L-BFGS-B runs on the hard objective.

    Returns:
        The refined parameters (``result.x`` of the last run, a float64
        NumPy array).

    NOTE(review): unless ``jax_enable_x64`` is set elsewhere, JAX evaluates
    these objectives in float32. The very tight ``ftol``/``gtol`` values below
    are then likely unreachable, and L-BFGS-B may stop on a line-search
    failure instead. TODO: confirm the intended precision.
    """
    obj_smooth, obj_hard, _, _ = make_fns(N)

    # Fusing value and gradient into one jitted call gives one forward pass
    # per L-BFGS evaluation instead of two.
    smooth_vg = jax.jit(jax.value_and_grad(obj_smooth))
    hard_vg = jax.jit(jax.value_and_grad(obj_hard))

    def _to_scipy(val, grad):
        # SciPy expects a Python float and a float64 gradient.
        return float(val), np.asarray(grad, dtype=np.float64)

    def smooth_fun(p, temp):
        return _to_scipy(*smooth_vg(jnp.asarray(p), temp))

    def hard_fun(p):
        return _to_scipy(*hard_vg(jnp.asarray(p)))

    smooth_opts = {'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 100}
    hard_opts = {'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100}

    # Phase 1: anneal the soft-max temperature upward. ``temp`` is passed via
    # ``args`` rather than captured by a closure defined inside the loop.
    for temp in temps:
        result = scipy_minimize(
            smooth_fun, params_np, args=(temp,), method='L-BFGS-B', jac=True,
            options=smooth_opts,
        )
        params_np = result.x

    # Phase 2: polish on the exact max objective.
    for _ in range(hard_restarts):
        result = scipy_minimize(
            hard_fun, params_np, method='L-BFGS-B', jac=True, options=hard_opts,
        )
        params_np = result.x

    return params_np
|
|
|
|
def run(*, N=6000, seeds=(2, 0, 42), adam_steps=150000, adam_lr=0.005,
        n_perturbations=3, perturb_scale=0.05, perturb_steps=30000,
        perturb_lr=0.002):
    """Optimize from several seeds with Adam + L-BFGS and perturbation restarts.

    For each seed:

    1. Start from a noisy constant function (``0.5 + 0.02 * randn``).
    2. Run Adam on the smooth objective, then refine with ``lbfgs_phase``.
    3. Try ``n_perturbations`` restarts. Each adds Gaussian noise (std
       ``perturb_scale``) to the best log-parameters found so far for that
       seed, then re-runs a shorter Adam + L-BFGS pass.

    Every candidate is scored with ``compute_c1_numpy``, and the best function
    across all seeds is kept. Progress is written to stdout. All keyword
    arguments default to the original hard-coded settings, so ``run()``
    behaves as before.

    Returns:
        ``(best_f, best_c1, best_c1, N)``. The C1 value appears twice to keep
        the original return shape. ``best_f`` stays ``None`` if no seed ever
        produced a C1 below ``inf`` (e.g. every score was NaN).
    """
    def log(msg):
        sys.stdout.write(msg)
        sys.stdout.flush()

    best_c1_overall = float('inf')
    best_f_overall = None

    for seed in seeds:
        log(f"\n=== Seed {seed}, N={N} ===\n")

        # A seeded RandomState draws the same values as np.random.seed()
        # followed by np.random.randn(), without mutating the global RNG.
        rng = np.random.RandomState(seed)
        init_f = np.ones(N) * 0.5 + 0.02 * rng.randn(N)
        params = jnp.array(np.log(np.maximum(init_f, 1e-6)))

        params, c1 = adam_phase(params, N, steps=adam_steps, lr_peak=adam_lr)
        log(f" After Adam: C1={c1:.10f}\n")

        params_np = lbfgs_phase(np.array(params, dtype=np.float64), N)
        f_final = np.exp(np.clip(params_np, -8, 4))
        c1_final = compute_c1_numpy(f_final, N)
        log(f" After L-BFGS: C1={c1_final:.10f}\n")

        best_params_np = params_np
        best_c1_seed = c1_final

        for pert in range(n_perturbations):
            key = jax.random.PRNGKey(seed * 100 + pert)
            noise = perturb_scale * jax.random.normal(key, shape=(N,))
            perturbed = jnp.array(best_params_np) + noise

            # The Adam checkpoint C1 is not used: candidates are ranked by
            # the NumPy C1 computed after L-BFGS.
            perturbed, _ = adam_phase(
                perturbed, N, steps=perturb_steps, lr_peak=perturb_lr,
            )

            p_np = lbfgs_phase(np.array(perturbed, dtype=np.float64), N)
            f_p = np.exp(np.clip(p_np, -8, 4))
            c1_p = compute_c1_numpy(f_p, N)
            log(f" Perturbation {pert}: C1={c1_p:.10f}\n")

            if c1_p < best_c1_seed:
                best_c1_seed = c1_p
                best_params_np = p_np
                f_final = f_p

        if best_c1_seed < best_c1_overall:
            best_c1_overall = best_c1_seed
            best_f_overall = f_final
            log(f"*** GLOBAL BEST: C1={best_c1_overall:.10f}\n")

    log(f"\nFinal C1: {best_c1_overall:.10f}\n")
    return best_f_overall, best_c1_overall, best_c1_overall, N
|
|