"""
Search for a sampled non-negative function with a small autoconvolution constant C1.

C1(f) = max(f * f) / (integral of f)^2, where f is sampled at N points with
spacing dx = 0.5 / N and negative samples are treated as zero.

Key insight from analysis: optimal function has U-shape (higher at edges).
Strategy: Basin hopping + L-BFGS with diverse initializations, then JAX polish.
"""
import numpy as np
from scipy.optimize import minimize, basinhopping
import jax
import jax.numpy as jnp

try:
    import optax
except ImportError:  # only jax_polish needs optax; keep the NumPy/SciPy path usable
    optax = None


def compute_c1_numpy(f_values, n_points):
    """Return C1 of ``f_values`` using a direct (O(n^2)) linear autoconvolution.

    Args:
        f_values: samples of f; negative entries are clipped to zero.
        n_points: number of grid points; only used to set dx = 0.5 / n_points.

    Returns:
        max(autoconvolution) / (integral of f)^2, or 1e10 when the squared
        integral is below 1e-12 (degenerate / empty input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(np.asarray(f_values, dtype=np.float64), 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check the degenerate case first: np.convolve rejects empty input.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq


def compute_c1_fft(f_values, dx):
    """Return C1 of ``f_values`` using an FFT-based linear autoconvolution.

    Agrees with ``compute_c1_numpy(f_values, n)`` (up to rounding) when
    ``dx == 0.5 / n``, but costs O(n log n). Negative samples are clipped to
    zero, as in ``compute_c1_numpy``.

    Returns:
        max(autoconvolution) / (integral of f)^2, or 1e10 when the squared
        integral is below 1e-12.
    """
    f_nn = np.maximum(np.asarray(f_values, dtype=np.float64), 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Checked before the FFT so empty input does not reach np.fft.rfft.
    if integral_sq < 1e-12:
        return 1e10
    n = len(f_nn)
    # Zero-padding to 2n makes the circular FFT convolution equal the linear one.
    padded = np.zeros(2 * n)
    padded[:n] = f_nn
    fft_f = np.fft.rfft(padded)
    conv = np.fft.irfft(fft_f * fft_f, n=2 * n) * dx
    return np.max(conv) / integral_sq


def fourier_to_f(coeffs, N):
    """Convert Fourier coefficients to a non-negative function on N samples.

    coeffs = [a0, a1, b1, a2, b2, ..., aK, bK]. The result is
    a0 + sum_k (a_k cos(2*pi*k*x) + b_k sin(2*pi*k*x)) evaluated at
    x = 0, 1/N, ..., (N-1)/N and clipped below at zero. A trailing unpaired
    coefficient (even-length ``coeffs``) is ignored.
    """
    # Force a float accumulator: with integer coeffs np.full would build an int
    # array and the in-place += below would raise a casting error.
    coeffs = np.asarray(coeffs, dtype=np.float64)
    K = (len(coeffs) - 1) // 2
    x = np.linspace(0, 1, N, endpoint=False)
    f = np.full(N, coeffs[0])
    for k in range(1, K + 1):
        f += coeffs[2 * k - 1] * np.cos(2 * np.pi * k * x)
        f += coeffs[2 * k] * np.sin(2 * np.pi * k * x)
    return np.maximum(f, 0.0)


def optimize_fourier_basin(K, N, n_restarts=20, seed=42, niter=200):
    """Basin hopping (with L-BFGS-B local steps) on Fourier coefficients.

    Args:
        K: number of harmonics; the search space has 2*K + 1 coefficients.
        N: number of samples used to evaluate C1 (dx = 0.5 / N).
        n_restarts: number of independent basin-hopping runs. Restart r uses
            initial shape r % 5: constant, +cos, -cos, cos plus a 2nd
            harmonic, or a random perturbation.
        seed: base seed for the random initial shape and for basinhopping.
        niter: basin-hopping iterations per restart.

    Returns:
        (best_coeffs, best_c1). best_coeffs is None only if no restart ran
        (n_restarts <= 0) or every restart produced NaN.
    """
    dx = 0.5 / N
    n_coeffs = 2 * K + 1

    def objective(coeffs):
        return compute_c1_fft(fourier_to_f(coeffs, N), dx)

    best_c1 = float('inf')
    best_coeffs = None
    for restart in range(n_restarts):
        shape = restart % 5
        x0 = np.zeros(n_coeffs)
        x0[0] = 1.0  # base level
        if shape == 1 and n_coeffs > 1:
            x0[1] = 0.3  # cos(2*pi*x) = U-shape
        elif shape == 2 and n_coeffs > 1:
            x0[1] = -0.3  # inverted U-shape
        elif shape == 3 and n_coeffs > 1:
            x0[1] = 0.2
            if n_coeffs > 3:
                x0[3] = 0.1  # U-shape + higher harmonics (a2)
        elif shape == 4:
            # Local RandomState with the same seed as the former global
            # np.random.seed call: identical draws, no global side effect.
            rng = np.random.RandomState(seed + restart * 7)
            x0 = rng.randn(n_coeffs) * 0.2
            x0[0] = 1.0

        # Built per restart so basinhopping/minimize never see a shared dict.
        minimizer_kwargs = {
            'method': 'L-BFGS-B',
            'options': {'maxiter': 500, 'ftol': 1e-12},
        }
        result = basinhopping(
            objective,
            x0,
            minimizer_kwargs=minimizer_kwargs,
            niter=niter,
            T=0.05,
            stepsize=0.3,
            seed=seed + restart,
        )
        c1 = result.fun
        if c1 < best_c1:
            best_c1 = c1
            best_coeffs = result.x
            print(f" restart={restart}: C1={c1:.10f} ***", flush=True)
    return best_coeffs, best_c1


def jax_polish(f_init, N, adam_steps=60000):
    """Polish a sampled function with Adam on a smoothed C1, then L-BFGS-B.

    f is parametrised as exp(clip(params, -8, 4)), so it stays positive.
    Adam minimises a log-sum-exp smoothed maximum at temperature 300. L-BFGS-B
    then refines at temperatures 1000, 5000 and 20000.

    Args:
        f_init: initial samples (length N). Zero/negative entries are lifted
            to the lower end of the parametrisation range.
        N: number of samples (dx = 0.5 / N).
        adam_steps: number of Adam iterations.

    Returns:
        (f, c1): the better of the best Adam checkpoint and the L-BFGS result.
        c1 is ``compute_c1_numpy(f, N)`` for exactly the returned f.

    Raises:
        ImportError: if optax is not installed.
    """
    if optax is None:
        raise ImportError("jax_polish requires the optax package")

    dx = 0.5 / N
    log_lo, log_hi = -8.0, 4.0

    def normalized_autoconv(params):
        f = jnp.exp(jnp.clip(params, log_lo, log_hi))
        # C1 is invariant under f -> a*f, but a log-sum-exp smooth max is not:
        # without normalisation its bias shrinks as f grows, so the optimiser
        # drifts every param into the upper clip where the gradient vanishes.
        # Normalising to unit integral makes the objective scale-free.
        f = f / (jnp.sum(f) * dx)
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    @jax.jit
    def obj_smooth(params, temp):
        # Integral is 1 after normalisation, so this is a smoothed C1.
        return jax.nn.logsumexp(temp * normalized_autoconv(params)) / temp

    @jax.jit
    def obj_hard(params):
        return jnp.max(normalized_autoconv(params))

    # Built and jitted once; the loop previously re-created it every step.
    smooth_value_and_grad = jax.jit(jax.value_and_grad(obj_smooth))

    def to_f(p):
        return np.exp(np.clip(np.asarray(p, dtype=np.float64), log_lo, log_hi))

    # Initial params: rescale to mean 1 (harmless, C1 is scale-invariant) and
    # keep every entry strictly inside the clip window. Outside it jnp.clip has
    # zero gradient, so e.g. log(1e-6) entries would never move.
    f0 = np.asarray(f_init, dtype=np.float64)
    mean0 = f0.mean() if f0.size else 0.0
    f0 = f0 / mean0 if np.isfinite(mean0) and mean0 > 0 else np.ones_like(f0)
    margin = 1e-3
    params = jnp.asarray(
        np.clip(np.log(np.maximum(f0, 1e-300)), log_lo + margin, log_hi - margin)
    )

    # optax's decay_steps is the total schedule length *including* warmup.
    warmup_steps = min(1000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.003,
        warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1),
        end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    adam_temp = 300.0
    best_hard = float('inf')
    best_params = params
    for step in range(adam_steps):
        _, grads = smooth_value_and_grad(params, adam_temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        if step % 15000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            print(f" Step {step}: C1={hc:.8f}")
            if hc < best_hard:
                best_hard = hc
                best_params = params

    # L-BFGS polish with an increasingly sharp smooth max.
    # NOTE(review): JAX evaluates in float32 unless x64 is enabled, so the very
    # tight ftol/gtol below may end in a line-search stop rather than convergence.
    params_np = np.array(best_params, dtype=np.float64)
    for temp in [1000.0, 5000.0, 20000.0]:
        def scipy_obj(p, temp=temp):  # bind temp now (avoid late binding)
            val, g = smooth_value_and_grad(jnp.asarray(p), temp)
            return float(val), np.asarray(g, dtype=np.float64)

        result = minimize(
            scipy_obj,
            params_np,
            method='L-BFGS-B',
            jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x

    f_lbfgs = to_f(params_np)
    c1_lbfgs = compute_c1_numpy(f_lbfgs, N)
    print(f" After L-BFGS: C1={c1_lbfgs:.10f}")

    # Return whichever candidate is better, paired with its own float64 C1,
    # so the reported value always describes the returned function.
    f_adam = to_f(best_params)
    c1_adam = compute_c1_numpy(f_adam, N)
    if c1_lbfgs <= c1_adam:
        return f_lbfgs, c1_lbfgs
    return f_adam, c1_adam


def run():
    """Run the three-phase search and return the best function found.

    Phase 1: basin hopping on Fourier coefficients (N=500) for several K. Each
             candidate is resampled to N=3000 and ranked by its C1 there.
    Phase 2: JAX polish of the best phase-1 candidate.
    Phase 3: JAX polish from a U-shaped initialisation.

    Returns:
        (best_f, best_c1, best_c1, best_n). best_c1 is the C1 of best_f on
        its own grid of best_n samples.
    """
    n_coarse = 500
    n_fine = 3000
    best_c1 = float('inf')
    best_f = None
    best_n = None

    # Phase 1: Basin hopping on Fourier coefficients at moderate resolution
    for K in [10, 20, 30, 50]:
        print(f"\nFourier basin hopping K={K}, N={n_coarse}")
        coeffs, _ = optimize_fourier_basin(K, n_coarse, n_restarts=15)
        if coeffs is None:
            continue
        c1_verify = compute_c1_numpy(fourier_to_f(coeffs, n_coarse), n_coarse)
        print(f" Best: C1={c1_verify:.10f}")
        # Rank by the C1 of the fine-grid function that is actually kept, so
        # best_c1 always matches best_f.
        f_fine = fourier_to_f(coeffs, n_fine)
        c1_fine = compute_c1_numpy(f_fine, n_fine)
        print(f" Resampled to N={n_fine}: C1={c1_fine:.10f}")
        if c1_fine < best_c1:
            best_c1 = c1_fine
            best_f = f_fine
            best_n = n_fine
            print(f" *** GLOBAL BEST: C1={c1_fine:.10f}")

    # Phase 2: Polish best result with JAX
    if best_f is not None:
        print(f"\nPolishing best result (C1={best_c1:.10f})...")
        f_polished, c1_polished = jax_polish(best_f, n_fine, adam_steps=60000)
        print(f" Polished: C1={c1_polished:.10f}")
        if c1_polished < best_c1:
            best_c1 = c1_polished
            best_f = f_polished
            best_n = n_fine

    # Phase 3: Direct JAX with U-shaped init
    print("\nDirect JAX with U-shaped initialization...")
    N = n_fine
    x = np.linspace(0, 1, N)
    # U-shape: higher at edges, lower in middle
    init_f = 0.5 + 0.5 * np.cos(2 * np.pi * x)
    init_f = np.maximum(init_f, 0.01)
    f_jax, c1_jax = jax_polish(init_f, N, adam_steps=60000)
    print(f" U-shape JAX: C1={c1_jax:.10f}")
    if c1_jax < best_c1:
        best_c1 = c1_jax
        best_f = f_jax
        best_n = N

    print(f"\nFinal best C1: {best_c1:.10f}")
    # NOTE(review): best_c1 is returned twice; presumably the caller expects a
    # 4-tuple of this shape — confirm before changing.
    return best_f, best_c1, best_c1, best_n