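"""
Search for a non-negative sampled function f (N points, spacing 0.5 / N)
that minimizes the autoconvolution ratio C1 = max(f * f) / (integral of f)^2.

Key ideas:
1. Normalize f so its integral is 1 (removes the scale degeneracy, so C1 = max(f * f)).
2. Use a (mean-based) L^p norm of f * f as a smooth approximation of the max.
3. Multi-phase: anneal p during Adam, starting low (p = 4) and increasing
   gradually (to p = 40), then polish with L-BFGS-B on log-sum-exp surrogates
   of increasing sharpness.
4. Run many restarts from diverse initial shapes.
"""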
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize |
| import optax |
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Evaluate the autoconvolution ratio C1 of sampled values in NumPy.

    Negative samples are clipped to zero first. With ``h = 0.5 / n_points``:

        C1 = max((f * f) * h) / (sum(f) * h) ** 2

    where ``f * f`` is the full linear self-convolution.

    Returns:
        The ratio above, or the sentinel ``1e10`` when the squared integral
        is below ``1e-12`` (e.g. an all-zero / all-negative input).
    """
    step = 0.5 / n_points
    clipped = np.maximum(f_values, 0.0)

    # Full linear self-convolution scaled by the grid spacing (a Riemann sum).
    self_conv = step * np.convolve(clipped, clipped, mode='full')

    mass = step * np.sum(clipped)
    denom = mass ** 2
    if denom < 1e-12:
        # Degenerate function: report a huge value rather than divide by ~0.
        return 1e10

    peak = self_conv.max()
    return peak / denom
|
|
|
|
def make_fns(N):
    """Build JIT-compiled C1 objectives for a grid of ``N`` samples.

    The grid spacing is ``dx = 0.5 / N``. Every objective first maps the
    unconstrained ``params`` (length ``N``) to a non-negative function with
    ``sum(f) * dx == 1`` (softplus, then normalization). The C1 ratio
    ``max(f * f) / (integral f) ** 2`` then reduces to ``max(f * f)``.

    Args:
        N: number of grid points (length of ``params``).

    Returns:
        Tuple ``(params_to_f, compute_c1_smooth, compute_c1_logsumexp,
        compute_c1_hard, grad_lse)``:

        * ``params_to_f(params)``: normalized non-negative samples.
        * ``compute_c1_smooth(params, p)``: mean-based L^p norm of the
          autoconvolution. This is a smooth lower bound of C1 that tightens
          as ``p`` grows.
        * ``compute_c1_logsumexp(params, temp)``: ``logsumexp(temp*conv)/temp``.
          This is a smooth upper bound of C1 that tightens as ``temp`` grows.
        * ``compute_c1_hard(params)``: the exact maximum of the autoconvolution.
        * ``grad_lse(params, temp)``: gradient of ``compute_c1_logsumexp``
          with respect to ``params``.
    """
    dx = 0.5 / N

    @jax.jit
    def params_to_f(params):
        """Map params to a non-negative f with unit integral (sum(f) * dx == 1)."""
        f = jax.nn.softplus(params)
        integral = jnp.sum(f) * dx
        # Floor the divisor in case every softplus value underflows to ~0.
        return f / jnp.maximum(integral, 1e-9)

    def autoconv(params):
        """Return the linear autoconvolution (f * f) * dx of length 2N via FFT."""
        f = params_to_f(params)
        # Zero-padding to 2N makes the circular FFT convolution equal to the
        # linear one (2N - 1 terms, plus one trailing ~0 entry).
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    @jax.jit
    def compute_c1_smooth(params, p):
        """Return the mean-based L^p norm of the autoconvolution (C1 surrogate).

        Evaluated as ``m * mean((conv / m) ** p) ** (1 / p)`` with
        ``m = max(conv)`` held constant via ``stop_gradient``. The p-norm is
        positively homogeneous, so this equals ``mean(conv ** p) ** (1 / p)``
        in both value and gradient. Every power stays <= 1, which avoids
        float32 overflow to ``inf`` for large ``p`` or spiky ``f``.
        """
        conv_pos = jnp.maximum(autoconv(params), 0.0)
        scale = jax.lax.stop_gradient(jnp.maximum(jnp.max(conv_pos), 1e-30))
        return scale * jnp.mean((conv_pos / scale) ** p) ** (1.0 / p)

    @jax.jit
    def compute_c1_logsumexp(params, temp):
        """Return the log-sum-exp surrogate of C1 (an upper bound on the max)."""
        return jax.nn.logsumexp(temp * autoconv(params)) / temp

    @jax.jit
    def compute_c1_hard(params):
        """Return the exact C1: the maximum of the normalized autoconvolution."""
        return jnp.max(autoconv(params))

    # Gradient with respect to params (argument 0) only; temp is not differentiated.
    grad_lse = jax.jit(jax.grad(compute_c1_logsumexp))

    return params_to_f, compute_c1_smooth, compute_c1_logsumexp, compute_c1_hard, grad_lse
|
|
|
|
def optimize_single(N, seed, steps=80000, verbose=True):
    """Run one restart: Adam on an annealed L^p surrogate, then L-BFGS-B polish.

    Args:
        N: number of grid points (grid spacing ``0.5 / N``).
        seed: selects the initial shape (``seed % 12``). It also seeds NumPy's
            global RNG, which the random initial shape draws from.
        steps: number of Adam steps.
        verbose: print progress and the final C1 when True.

    Returns:
        ``(f, c1)``.
        ``f`` is softplus of the best parameters found: length ``N`` and
        non-negative, but NOT normalized (C1 is scale-invariant).
        ``c1`` is ``compute_c1_numpy(f, N)``, evaluated in float64.
    """
    _, c1_smooth, c1_lse, c1_hard, grad_lse = make_fns(N)

    def softplus_np(z):
        # Overflow-safe log(1 + exp(z)); np.log1p(np.exp(z)) is inf for z > ~709.
        return np.logaddexp(0.0, z)

    def exact_c1(z):
        # Exact float64 C1 of softplus(z); the surrogates only approximate it.
        return compute_c1_numpy(softplus_np(z), N)

    np.random.seed(seed)
    x = np.linspace(0, 1, N)

    # Diverse initial shapes; successive seeds cycle through them.
    init_types = [
        lambda: np.ones(N),
        lambda: np.exp(-10 * (x - 0.5)**2) + 0.1,
        lambda: np.exp(-5 * (x - 0.3)**2) + 0.05,
        lambda: np.exp(-5 * (x - 0.7)**2) + 0.05,
        lambda: 0.5 + 0.5 * np.cos(2*np.pi*x),
        lambda: np.maximum(1 - 4*np.abs(x - 0.5), 0) + 0.1,
        lambda: np.where((x > 0.1) & (x < 0.9), 1.0, 0.1),
        lambda: np.where((x > 0.2) & (x < 0.8), 1.0, 0.1),
        lambda: np.exp(-50 * (x - 0.5)**2) + 0.01,
        lambda: np.abs(np.random.randn(N)) * 0.3 + 0.2,
        lambda: 1 - 0.5 * np.abs(np.sin(3*np.pi*x)),
        lambda: np.exp(-3*(x - 0.4)**2) + 0.5*np.exp(-3*(x-0.6)**2) + 0.05,
    ]

    init_f = init_types[seed % len(init_types)]()
    init_f = np.maximum(init_f, 0.01)
    # Inverse softplus, so that softplus(params) == init_f.
    params = jnp.array(np.log(np.expm1(init_f)))

    # optax's ``decay_steps`` is the TOTAL schedule length (warmup included).
    # Passing ``steps`` therefore makes the cosine decay end exactly at the
    # last step. ``decay_steps`` must exceed ``warmup_steps``, and warmup is
    # capped at 10% so short runs still work.
    warmup = min(2000, max(steps // 10, 1))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup,
        decay_steps=max(steps, warmup + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, p):
        # Compiled once: gradient, Adam update and apply run as a single
        # XLA call per step. ``p`` is traced, so annealing does not retrace.
        grads = jax.grad(c1_smooth)(params, p)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    best_c1 = float('inf')
    best_params = params

    for step in range(steps):
        progress = min(step / steps, 1.0)
        # Anneal p from 4 (smooth, averaging) towards 40 (close to the max).
        p = 4.0 + progress * 36.0
        params, opt_state = train_step(params, opt_state, p)

        # The exact C1 is checked only periodically, so "best" means the best
        # checkpoint, not necessarily the best iterate.
        if step % 20000 == 0 or step == steps - 1:
            hc = float(c1_hard(params))
            if verbose:
                print(f" [{seed:2d}] Step {step:6d} | C1={hc:.8f} | p={p:.1f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    # Polish with L-BFGS-B on increasingly sharp log-sum-exp surrogates.
    # Each stage starts from the previous stage's result. A stage is only
    # kept if it lowers the exact C1, because the surrogate is not the max.
    # NOTE(review): jnp.array downcasts to float32 unless jax_enable_x64 is
    # set. The objective is then float32-accurate, which likely limits how
    # far ftol=1e-15 can be honoured.
    params_np = np.array(best_params, dtype=np.float64)
    best_np = params_np
    best_exact = exact_c1(best_np)
    for temp in [500.0, 2000.0, 10000.0]:
        def scipy_obj(p_arr, temp=temp):
            p_jax = jnp.array(p_arr)
            val = float(c1_lse(p_jax, temp))
            g = np.array(grad_lse(p_jax, temp), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 3000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        stage_c1 = exact_c1(params_np)
        if stage_c1 < best_exact:
            best_exact = stage_c1
            best_np = params_np

    f_final = softplus_np(best_np)
    c1_final = compute_c1_numpy(f_final, N)

    if verbose:
        print(f" [{seed:2d}] Final: C1={c1_final:.10f}")

    return f_final, c1_final
|
|
|
|
def run():
    """Full search: 12 restarts at N=3000, then refinement at N=5000.

    Stage 1 runs ``optimize_single`` once per initial shape at N=3000.
    Stage 2 works at N=5000 and has two parts. It runs one fresh restart.
    Separately, it upsamples the stage-1 winner and polishes it with
    L-BFGS-B on log-sum-exp surrogates. The lowest C1 among all candidates
    wins.

    Returns:
        ``(best_f, best_c1, best_c1, len(best_f))``. The C1 value appears
        twice to keep the existing 4-tuple shape. NOTE(review): presumably
        the caller expects a separately verified C1 in slot 3; confirm.

    Raises:
        RuntimeError: if no stage-1 restart produced a comparable C1
            (e.g. every run returned NaN).
    """
    N = 3000
    best_c1 = float('inf')
    best_f = None

    for seed in range(12):
        f, c1 = optimize_single(N, seed, steps=60000)
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            print(f" *** GLOBAL BEST: C1={c1:.10f} (seed={seed})")

    if best_f is None:
        # NaN C1 values never compare below inf, so nothing was selected.
        raise RuntimeError("no restart produced a finite C1; cannot refine")

    print(f"\nExtended polishing at N=5000...")
    N2 = 5000

    # Linear interpolation of the stage-1 winner onto the finer grid.
    f_up = np.interp(np.linspace(0, 1, N2), np.linspace(0, 1, len(best_f)), best_f)

    # Fresh restart directly at the finer resolution; its result competes
    # with the other candidates instead of being discarded.
    f_fresh, c1_fresh = optimize_single(N2, 99, steps=40000)
    if c1_fresh < best_c1:
        best_c1 = c1_fresh
        best_f = f_fresh

    _, _, c1_lse2, _, grad_lse2 = make_fns(N2)
    # Inverse softplus of the floored samples. Kept in float64 for scipy
    # rather than round-tripping through a (float32) JAX array.
    params_np = np.log(np.expm1(np.maximum(f_up, 1e-3)))

    # Each stage starts from the previous one. The stage with the lowest
    # exact C1 is kept, since a sharper surrogate can still worsen the max.
    best_params_np = params_np
    best_polished_c1 = float('inf')
    for temp in [500.0, 2000.0, 10000.0, 50000.0]:
        def scipy_obj(p_arr, temp=temp):
            p_jax = jnp.array(p_arr)
            val = float(c1_lse2(p_jax, temp))
            g = np.array(grad_lse2(p_jax, temp), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        stage_c1 = compute_c1_numpy(np.logaddexp(0.0, params_np), N2)
        if stage_c1 < best_polished_c1:
            best_polished_c1 = stage_c1
            best_params_np = params_np

    # Overflow-safe softplus (np.log1p(np.exp(z)) is inf for z > ~709).
    f_final = np.logaddexp(0.0, best_params_np)
    c1_final = compute_c1_numpy(f_final, N2)
    print(f"Upsampled polished: C1={c1_final:.10f}")

    if c1_final < best_c1:
        best_c1 = c1_final
        best_f = f_final

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, len(best_f)
|
|