| """ |
| L-BFGS-B with JAX analytical gradients. |
| Many random restarts. Fast convergence. |
| """ |
| import numpy as np |
| from scipy.optimize import minimize |
| import jax |
| import jax.numpy as jnp |
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)**2 for a sampled profile.

    Negative samples are clamped to zero first. Both the autoconvolution
    and the integral are discretised with grid spacing ``0.5 / n_points``.

    Args:
        f_values: Samples of the profile (array-like).
        n_points: Number of grid points used to derive the spacing.

    Returns:
        The ratio of the peak of the discrete autoconvolution to the
        squared integral, or the sentinel ``1e10`` when the squared
        integral is below ``1e-12``.
    """
    step = 0.5 / n_points
    clamped = np.maximum(f_values, 0.0)

    # Full linear autoconvolution, scaled by the grid spacing.
    peak = (np.convolve(clamped, clamped, mode='full') * step).max()
    mass_sq = (clamped.sum() * step) ** 2

    # A (near-)zero profile has no meaningful ratio; report a large penalty.
    return 1e10 if mass_sq < 1e-12 else peak / mass_sq
|
|
|
|
def make_objective(N, temp):
    """Build a SciPy-compatible ``(value, gradient)`` callable.

    The returned function maps log-parameters ``p`` (length ``N``) to the
    smoothed ratio ``softmax_peak(f * f) / (integral of f)**2`` where
    ``f = exp(clip(p, -8, 4))``, the autoconvolution is computed by FFT on
    a zero-padded signal of length ``2 * N``, and the peak is replaced by
    ``logsumexp(temp * conv) / temp``.

    Args:
        N: Number of grid points; the grid spacing is ``0.5 / N``.
        temp: Sharpness of the log-sum-exp surrogate for the maximum.

    Returns:
        A function taking an array-like of parameters and returning a
        Python ``float`` value together with a ``float64`` NumPy gradient,
        suitable for ``scipy.optimize.minimize(..., jac=True)``.

    NOTE(review): the computation runs in JAX's active precision, which is
    float32 unless 64-bit mode has been enabled elsewhere -- confirm this
    is adequate for the optimiser tolerances in use.
    """
    spacing = 0.5 / N

    def smoothed_ratio(log_f):
        # Parameterise f through exp so it stays positive; clip keeps it bounded.
        profile = jnp.exp(jnp.clip(log_f, -8, 4))

        # Zero-pad to length 2N so the circular FFT product equals the
        # linear autoconvolution.
        spectrum = jnp.fft.rfft(jnp.zeros(2 * N).at[:N].set(profile))
        autoconv = jnp.fft.irfft(spectrum * spectrum, n=2 * N) * spacing

        mass_sq = (jnp.sum(profile) * spacing) ** 2
        soft_peak = jax.nn.logsumexp(temp * autoconv) / temp
        return soft_peak / mass_sq

    value_and_grad = jax.jit(jax.value_and_grad(jax.jit(smoothed_ratio)))

    def scipy_wrapper(params_np):
        value, grad = value_and_grad(jnp.array(params_np))
        # SciPy expects a plain float and a float64 gradient vector.
        return float(value), np.array(grad, dtype=np.float64)

    return scipy_wrapper
|
|
|
|
def _initial_profiles(N, n_starts=30):
    """Return ``n_starts`` positive starting profiles sampled on ``N`` points.

    The first 13 starts are fixed hand-picked shapes on ``linspace(0, 1, N)``;
    the remaining ones are ``|randn| * 0.3 + 0.2`` drawn from a legacy
    ``RandomState`` seeded with the start index, so every start is
    reproducible without touching NumPy's global random state. All
    profiles are floored at ``0.01``.
    """
    x = np.linspace(0, 1, N)
    handcrafted = [
        np.ones(N),
        0.5 + 0.5 * np.cos(2*np.pi*x),
        np.exp(-10*(x-0.5)**2) + 0.1,
        np.exp(-5*(x-0.3)**2) + 0.05,
        np.exp(-5*(x-0.7)**2) + 0.05,
        0.2 + 0.8 * np.cos(np.pi*x)**2,
        0.5 + 0.3*np.cos(2*np.pi*x) + 0.1*np.cos(4*np.pi*x),
        np.maximum(1 - 4*np.abs(x-0.5), 0) + 0.1,
        0.3 + 0.7*x,
        0.3 + 0.7*(1-x),
        1 + 0.5*np.cos(6*np.pi*x),
        np.where((x > 0.15) & (x < 0.85), 1.0, 0.3),
        0.5 + 0.3*np.cos(2*np.pi*x) + 0.15*np.cos(4*np.pi*x) + 0.05*np.cos(6*np.pi*x),
    ]

    profiles = []
    for seed in range(n_starts):
        if seed < len(handcrafted):
            f = handcrafted[seed]
        else:
            # Same stream as np.random.seed(seed); np.random.randn(N), but
            # kept local so the caller's global RNG state is left alone.
            f = np.abs(np.random.RandomState(seed).randn(N)) * 0.3 + 0.2
        profiles.append(np.maximum(f, 0.01))
    return profiles


def run():
    """Multi-start L-BFGS-B search for the profile with the smallest C1.

    For each resolution ``N`` in ``[1000, 2000, 4000]`` and each of 30
    starting profiles, the log-parameters are optimised through a sequence
    of increasingly sharp smooth-max temperatures, each stage warm-started
    from the previous one. The final profile is scored with the exact
    (non-smoothed) ``compute_c1_numpy`` and the best result over all runs
    is kept. Progress is printed to stdout.

    Returns:
        A 4-tuple ``(best_f, best_c1, best_c1, best_n)``: the best profile,
        its C1 value (present twice; kept as-is for existing callers), and
        the resolution at which it was found.
    """
    best_c1 = float('inf')
    best_f = None
    best_n = None

    temperatures = (50.0, 200.0, 1000.0, 5000.0, 20000.0)
    solver_options = {'maxiter': 3000, 'ftol': 1e-15, 'gtol': 1e-12}

    for N in [1000, 2000, 4000]:
        print(f"\n=== N={N} ===")

        # Build one objective per temperature for this N and reuse it for
        # every restart. Each make_objective call creates fresh jitted
        # closures, so rebuilding inside the restart loop would force JAX
        # to re-trace and re-compile the same computation for every start.
        objectives = [make_objective(N, temp) for temp in temperatures]

        for init_idx, init_f in enumerate(_initial_profiles(N)):
            params = np.log(np.maximum(init_f, 1e-6))

            # Anneal the smooth-max: each stage starts from the previous optimum.
            for obj_fn in objectives:
                result = minimize(
                    obj_fn, params, method='L-BFGS-B', jac=True,
                    options=solver_options,
                )
                params = result.x

            # Score with the exact max, using the same clipping as the objective.
            f_opt = np.exp(np.clip(params, -8, 4))
            c1 = compute_c1_numpy(f_opt, N)

            if c1 < best_c1:
                best_c1 = c1
                best_f = f_opt
                best_n = N
                print(f" init={init_idx}: C1={c1:.10f} ***", flush=True)
            elif init_idx < 13:
                # Only the hand-picked starts are reported when they do not improve.
                print(f" init={init_idx}: C1={c1:.10f}", flush=True)

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n
|
|