"""Search for a nonnegative step function on [0, 1/2] with a small
autoconvolution ratio ``C1 = max(f * f) / (integral of f) ** 2``.

The function is discretised into ``N`` equal cells of width ``0.5 / N`` and
parameterised in log space as ``f = exp(clip(params, -8, 4))``. Each run uses
Adam on a log-sum-exp smoothed max, then L-BFGS-B polishing at increasing
temperatures; ``run`` chains runs from coarse to fine resolution.
"""

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize

# Log-parameters are clipped to [_LOG_LO, _LOG_HI] before exponentiation.
_LOG_LO = -8
_LOG_HI = 4

# Softmax temperature used during Adam, and the rising temperatures used for
# successive L-BFGS-B polishing passes.
_ADAM_TEMP = 200.0
_LBFGS_TEMPS = (500.0, 2000.0, 5000.0, 20000.0)

# (N, adam_steps, lr_peak) for the coarse-to-fine schedule in run().
_STAGES = ((1500, 80000, 0.008), (3000, 80000, 0.004), (5000, 60000, 0.002))


def _params_to_f(params):
    """Map log-parameters to positive float64 cell values (NumPy)."""
    return np.exp(np.clip(np.asarray(params, dtype=np.float64), _LOG_LO, _LOG_HI))


def compute_c1_numpy(f_values, n_points):
    """Return the discrete autoconvolution ratio C1 of a step function.

    Args:
        f_values: Cell values of the step function; negative entries are
            treated as zero.
        n_points: Number of cells; each cell has width ``0.5 / n_points``.

    Returns:
        ``max(conv) / (sum(f) * dx) ** 2``, where ``conv`` is the full
        discrete autoconvolution scaled by ``dx``, or the sentinel ``1e10``
        when the squared integral is below ``1e-12`` (this also covers an
        empty ``f_values``).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check the degenerate case first: it skips the O(n^2) convolution and
    # avoids np.convolve's error on empty input.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode="full") * dx
    return np.max(autoconv) / integral_sq


def make_fns(N):
    """Build JAX-compiled objective helpers for an ``N``-cell discretisation.

    Returns:
        ``(get_f, objective_smooth, objective_hard, grad_smooth)``:

        - ``get_f(params)``: cell values ``exp(clip(params, -8, 4))``.
        - ``objective_smooth(params, temp)``: C1 with the max replaced by
          ``logsumexp(temp * conv) / temp``, an upper bound on the max that
          tightens as ``temp`` grows.
        - ``objective_hard(params)``: exact discrete C1 (JAX default dtype).
        - ``grad_smooth(params, temp)``: gradient of ``objective_smooth``
          with respect to ``params``.
    """
    dx = 0.5 / N

    @jax.jit
    def get_f(params):
        return jnp.exp(jnp.clip(params, _LOG_LO, _LOG_HI))

    def autoconv(f):
        # Zero-pad to 2N so the circular FFT product equals the linear
        # autoconvolution (2N - 1 values plus one trailing zero).
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    @jax.jit
    def objective_smooth(params, temp):
        f = get_f(params)
        integral_sq = (jnp.sum(f) * dx) ** 2
        # NOTE(review): temp multiplies the raw convolution, whose scale grows
        # with f**2, so the sharpness of this smooth max depends on the
        # overall magnitude of f, not only on temp.
        smooth_max = jax.nn.logsumexp(temp * autoconv(f)) / temp
        return smooth_max / integral_sq

    @jax.jit
    def objective_hard(params):
        f = get_f(params)
        integral_sq = (jnp.sum(f) * dx) ** 2
        return jnp.max(autoconv(f)) / integral_sq

    grad_smooth = jax.jit(jax.grad(objective_smooth))
    return get_f, objective_smooth, objective_hard, grad_smooth


def _initial_params(N, seed, init_params):
    """Return starting log-parameters: upsampled ``init_params`` or noisy constant."""
    if init_params is not None:
        # Linearly interpolate the previous cell values onto N cells, then
        # return to log space (floored at 1e-6 so the log stays finite).
        old_f = _params_to_f(init_params)
        new_f = np.interp(
            np.linspace(0, 1, N), np.linspace(0, 1, len(init_params)), old_f
        )
        return jnp.array(np.log(np.maximum(new_f, 1e-6)))
    # RandomState(seed).randn yields exactly what np.random.seed(seed) followed
    # by np.random.randn would, without mutating NumPy's global RNG.
    rng = np.random.RandomState(seed)
    init_f = np.ones(N) * 0.5 + 0.02 * rng.randn(N)
    return jnp.array(np.log(np.maximum(init_f, 1e-6)))


def _adam_stage(params, fns, adam_steps, lr_peak, verbose):
    """Run Adam on the smoothed objective; return the best checkpointed params."""
    import optax  # only the Adam stage needs optax

    _, _, objective_hard, grad_smooth = fns

    warmup_steps = min(2000, adam_steps // 10)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=lr_peak,
        warmup_steps=warmup_steps,
        # decay_steps is the TOTAL schedule length (warmup included), so the
        # cosine anneals over the remaining adam_steps - warmup_steps steps.
        decay_steps=max(adam_steps, warmup_steps + 1),
        end_value=lr_peak * 1e-5,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    @jax.jit
    def step_fn(params, opt_state):
        grads = grad_smooth(params, _ADAM_TEMP)
        grads = jnp.clip(grads, -1.0, 1.0)  # element-wise gradient clipping
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    best_c1 = float("inf")
    best_params = params
    for step in range(adam_steps):
        params, opt_state = step_fn(params, opt_state)
        if step % 10000 == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f" Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params
    return best_params


def _lbfgs_polish(best_params, best_c1, fns, N, verbose):
    """Refine with L-BFGS-B at rising temperatures; return the best ``(params, c1)``."""
    _, objective_smooth, _, _ = fns
    value_and_grad_smooth = jax.jit(jax.value_and_grad(objective_smooth))

    params_np = np.array(best_params, dtype=np.float64)
    for temp_lbfgs in _LBFGS_TEMPS:

        def scipy_obj(p, temp=temp_lbfgs):
            val, grad = value_and_grad_smooth(jnp.array(p), temp)
            # SciPy's L-BFGS-B expects a Python float and a float64 gradient.
            return float(val), np.array(grad, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj,
            params_np,
            method="L-BFGS-B",
            jac=True,
            options={"maxiter": 5000, "ftol": 1e-15, "gtol": 1e-12},
        )
        if not np.all(np.isfinite(result.x)):
            # A diverged pass would poison every later temperature; start the
            # next one from the last finite iterate instead.
            if verbose:
                print(f" temp={temp_lbfgs:.0f}: non-finite result, skipped")
            continue
        params_np = result.x
        c1 = compute_c1_numpy(_params_to_f(params_np), N)
        if verbose:
            print(f" temp={temp_lbfgs:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            # Keep float64: converting to a JAX array would downcast and the
            # returned params would no longer reproduce best_c1.
            best_params = params_np
    return np.array(best_params), best_c1


def optimize(N, adam_steps, lr_peak, seed=42, init_params=None, verbose=True):
    """Minimise C1 for an ``N``-cell step function.

    Runs Adam on the smoothed objective at temperature 200, then L-BFGS-B
    at temperatures 500, 2000, 5000 and 20000, keeping the iterate with the
    lowest C1.

    Args:
        N: Number of cells.
        adam_steps: Number of Adam iterations.
        lr_peak: Peak learning rate of the warmup + cosine-decay schedule.
        seed: Seed for the random initial guess; ignored when
            ``init_params`` is given.
        init_params: Optional log-parameters from a previous (possibly
            coarser) run, upsampled by linear interpolation of cell values.
        verbose: Print progress.

    Returns:
        ``(params, c1)``: the best log-parameters as a NumPy array and their
        C1 as scored by ``compute_c1_numpy``.
    """
    fns = make_fns(N)
    params = _initial_params(N, seed, init_params)
    best_params = _adam_stage(params, fns, adam_steps, lr_peak, verbose)

    # Re-score in float64 NumPy so every C1 compared from here on comes from
    # the same evaluator.
    best_c1 = compute_c1_numpy(_params_to_f(best_params), N)
    if verbose:
        print(f" L-BFGS polishing from C1={best_c1:.8f}")
    return _lbfgs_polish(best_params, best_c1, fns, N, verbose)


def run():
    """Run the coarse-to-fine schedule plus one direct N=4000 run.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``: best cell values, their C1,
        C1 again, and the cell count. NOTE(review): the duplicated C1 is kept
        for interface compatibility; confirm what callers expect in slot 3.
    """
    best_c1 = float("inf")
    best_f = None
    best_n = None

    params = None
    for stage, (n, steps, lr_peak) in enumerate(_STAGES, start=1):
        prefix = "" if stage == 1 else "\n"
        print(f"{prefix}=== Stage {stage}: N={n} ===")
        # Stage 1 starts from the seeded random guess; later stages upsample.
        params, c1 = optimize(n, steps, lr_peak, seed=42, init_params=params)
        print(f" Stage {stage} C1: {c1:.10f}")

    n_final = _STAGES[-1][0]
    f_final = _params_to_f(params)
    c1_final = compute_c1_numpy(f_final, n_final)
    if c1_final < best_c1:
        best_c1, best_f, best_n = c1_final, f_final, n_final

    # Also try direct N=4000 from scratch with a different seed.
    print("\n=== Direct: N=4000 seed=0 ===")
    params2, c1_2 = optimize(4000, 100000, 0.005, seed=0)
    print(f" Direct C1: {c1_2:.10f}")
    f2 = _params_to_f(params2)
    c1_2_exact = compute_c1_numpy(f2, 4000)
    if c1_2_exact < best_c1:
        best_c1, best_f, best_n = c1_2_exact, f2, 4000

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n