""" Key insight: use JAX's hard max directly (subgradient via jnp.max). Phase 1: Adam with moderate temp smooth max for exploration. Phase 2: L-BFGS with HARD max for exact convergence. """ import sys import jax import jax.numpy as jnp import numpy as np from scipy.optimize import minimize as scipy_minimize import optax def compute_c1_numpy(f_values, n_points): dx = 0.5 / n_points f_nn = np.maximum(f_values, 0.0) autoconv = np.convolve(f_nn, f_nn, mode='full') * dx integral_sq = (np.sum(f_nn) * dx) ** 2 if integral_sq < 1e-12: return 1e10 return np.max(autoconv) / integral_sq def run(): N = 4000 dx = 0.5 / N @jax.jit def obj_smooth(params, temp): f = jnp.exp(jnp.clip(params, -8, 4)) padded = jnp.zeros(2 * N) padded = padded.at[:N].set(f) fft_f = jnp.fft.rfft(padded) conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx integral_sq = (jnp.sum(f) * dx) ** 2 smooth_max = jax.nn.logsumexp(temp * conv) / temp return smooth_max / integral_sq @jax.jit def obj_hard(params): f = jnp.exp(jnp.clip(params, -8, 4)) padded = jnp.zeros(2 * N) padded = padded.at[:N].set(f) fft_f = jnp.fft.rfft(padded) conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx integral_sq = (jnp.sum(f) * dx) ** 2 return jnp.max(conv) / integral_sq grad_smooth = jax.jit(jax.grad(obj_smooth)) grad_hard = jax.jit(jax.grad(obj_hard)) best_c1_overall = float('inf') best_f_overall = None for seed in range(5): np.random.seed(seed) # Flat initialization (near constant = best for gradient descent start) init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N) params = jnp.array(np.log(np.maximum(init_f, 1e-6))) # Phase 1: Adam with smooth max for exploration adam_steps = 100000 lr_schedule = optax.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.005, warmup_steps=2000, decay_steps=adam_steps - 2000, end_value=1e-6, ) optimizer = optax.adam(learning_rate=lr_schedule) opt_state = optimizer.init(params) best_c1 = float('inf') best_params = params for step in range(adam_steps): loss, grads = jax.value_and_grad(obj_smooth)(params, 300.0) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) if step % 25000 == 0 or step == adam_steps - 1: hc = float(obj_hard(params)) sys.stdout.write(f"[{seed}] Adam {step:7d} | C1={hc:.10f}\n") sys.stdout.flush() if hc < best_c1: best_c1 = hc best_params = params # Phase 2: L-BFGS with HARD max (exact objective) sys.stdout.write(f"[{seed}] Phase 2: L-BFGS with hard max from C1={best_c1:.10f}\n") sys.stdout.flush() params_np = np.array(best_params, dtype=np.float64) # First do smooth L-BFGS to get close for temp in [1000.0, 10000.0]: def scipy_obj_smooth(p): p_jax = jnp.array(p) val = float(obj_smooth(p_jax, temp)) g = np.array(grad_smooth(p_jax, temp), dtype=np.float64) return val, g result = scipy_minimize( scipy_obj_smooth, params_np, method='L-BFGS-B', jac=True, options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 100}, ) params_np = result.x # Then hard max L-BFGS def scipy_obj_hard(p): p_jax = jnp.array(p) val = float(obj_hard(p_jax)) g = np.array(grad_hard(p_jax), dtype=np.float64) return val, g result = scipy_minimize( scipy_obj_hard, params_np, method='L-BFGS-B', jac=True, options={'maxiter': 20000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 100}, ) params_np = result.x f_final = np.exp(np.clip(params_np, -8, 4)) c1_final = compute_c1_numpy(f_final, N) sys.stdout.write(f"[{seed}] Hard L-BFGS: C1={c1_final:.10f}\n") sys.stdout.flush() if c1_final < best_c1_overall: best_c1_overall = c1_final best_f_overall = f_final sys.stdout.write(f"*** GLOBAL BEST: C1={c1_final:.10f}\n") sys.stdout.flush() sys.stdout.write(f"\nFinal C1: {best_c1_overall:.10f}\n") sys.stdout.flush() return best_f_overall, best_c1_overall, best_c1_overall, N