| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize |
| import optax |
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Compute the C1 autoconvolution ratio using NumPy (for verification).

    Samples are clipped at zero, then the ratio
    ``max(f * f) / (integral of f) ** 2`` is evaluated with Riemann sums of
    step ``dx = 0.5 / n_points``, where ``f * f`` is the full discrete
    autoconvolution.

    Args:
        f_values: 1-D array-like of function samples.
        n_points: Number of grid points used to derive the step ``dx``.

    Returns:
        The ratio, or the sentinel ``1e10`` when the squared integral is
        below ``1e-12`` (this includes empty or all-non-positive input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Test for a degenerate integral *before* convolving: this skips the
    # O(n^2) convolution when the result is discarded anyway, and it lets
    # empty input return the sentinel instead of making np.convolve raise.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
|
|
|
|
def make_objective_jax(N, dx):
    """Build the JAX objectives used to minimize C1 on an ``N``-sample grid.

    The optimization variable is ``params``; the function samples are
    ``f = exp(params)``, so they are positive by construction.

    Args:
        N: Number of samples; ``params`` is written into the first ``N``
            slots of a zero buffer of length ``2 * N`` before the FFT.
        dx: Step used to scale the convolution and the integral.

    Returns:
        A tuple ``(objective, objective_smooth, grad_fn)`` where
        ``objective(params)`` is ``max(conv) / integral ** 2``,
        ``objective_smooth(params, temp)`` replaces the max with
        ``logsumexp(temp * conv) / temp``, and ``grad_fn(params, temp)`` is
        the jitted gradient of ``objective_smooth`` with respect to
        ``params``.
    """
    fft_len = 2 * N

    def _autoconv_and_mass_sq(params):
        # Shared pipeline: zero-pad to 2N so the circular FFT product
        # equals the linear autoconvolution, then scale by dx.
        values = jnp.exp(params)
        buffer = jnp.zeros(fft_len)
        buffer = buffer.at[:N].set(values)
        spectrum = jnp.fft.rfft(buffer)
        autoconv = jnp.fft.irfft(spectrum * spectrum, n=fft_len) * dx
        mass = jnp.sum(values) * dx
        return autoconv, mass ** 2

    @jax.jit
    def objective(params):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        return jnp.max(autoconv) / mass_sq

    @jax.jit
    def objective_smooth(params, temp):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        # Differentiable surrogate for the max; sharper as temp grows.
        soft_peak = jax.nn.logsumexp(temp * autoconv) / temp
        return soft_peak / mass_sq

    grad_fn = jax.jit(jax.grad(objective_smooth))

    return objective, objective_smooth, grad_fn
|
|
|
|
def _initial_profile(seed, n_points):
    """Return the positive starting profile for ``seed`` on ``n_points`` samples.

    Seeds 0-3 select fixed shapes (Gaussian bump, constant, raised cosine,
    two-level step). Any other seed draws from NumPy's global RNG, which the
    caller seeds beforehand.
    """
    x = np.linspace(0, 1, n_points)
    if seed == 0:
        return np.exp(-10 * (x - 0.5) ** 2) + 0.1
    if seed == 1:
        return np.ones(n_points)
    if seed == 2:
        return 0.5 * (1 + np.cos(2 * np.pi * (x - 0.5))) + 0.1
    if seed == 3:
        return np.where((x > 0.2) & (x < 0.8), 1.5, 0.5)
    return np.abs(np.random.randn(n_points)) * 0.3 + 0.2


def run(grid_sizes=(1000, 2000, 3000), n_seeds=5):
    """Search for a low-C1 function over several grid sizes and starts.

    For every grid size and seed the search runs two phases on
    ``params = log(f)``:

    1. Adam on the smooth (log-sum-exp) objective while the temperature
       ramps from 50 towards 200.
    2. L-BFGS-B refinement at temperatures 500, 1000 and 2000, each stage
       starting from the previous stage's result.

    The best candidate of each run is re-scored with ``compute_c1_numpy``
    and the overall best is kept.

    Args:
        grid_sizes: Iterable of sample counts ``N`` to try.
        n_seeds: Number of starts per grid size; seeds ``0..n_seeds-1``.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. The C1 value appears twice;
        this 4-tuple shape is kept unchanged for existing callers. With no
        grid sizes or no seeds the result is ``(None, inf, inf, None)``.
    """
    best_c1_overall = float('inf')
    best_f_overall = None
    best_n_overall = None

    for N in grid_sizes:
        dx = 0.5 / N
        objective, objective_smooth, _ = make_objective_jax(N, dx)
        # Build and compile value+gradient once per grid size rather than
        # re-creating the un-jitted transform on every optimizer step.
        value_and_grad_fn = jax.jit(jax.value_and_grad(objective_smooth))

        for seed in range(n_seeds):
            print(f"\n--- N={N}, seed={seed} ---")
            np.random.seed(seed)

            init = _initial_profile(seed, N)
            # Optimize in log-space; the floor keeps the log finite.
            params = np.log(np.maximum(init, 1e-6))

            print("Phase 1: Adam optimization...")
            params_jax = jnp.array(params)

            lr_schedule = optax.warmup_cosine_decay_schedule(
                init_value=0.0,
                peak_value=0.01,
                warmup_steps=2000,
                decay_steps=48000,
                end_value=1e-5,
            )
            optimizer = optax.adam(learning_rate=lr_schedule)
            opt_state = optimizer.init(params_jax)

            best_c1_run = float('inf')
            best_params_run = params_jax

            for step in range(50000):
                # Anneal the soft-max temperature linearly, capped at 200.
                temp = min(50.0 + step * 150.0 / 50000, 200.0)

                loss_val, grads = value_and_grad_fn(params_jax, temp)
                updates, opt_state = optimizer.update(grads, opt_state, params_jax)
                params_jax = optax.apply_updates(params_jax, updates)

                if step % 5000 == 0:
                    # Track progress with the hard-max objective.
                    hard_c1 = float(objective(params_jax))
                    print(f" Step {step:5d} | C1(smooth)={float(loss_val):.8f} | C1(hard)={hard_c1:.8f}")
                    if hard_c1 < best_c1_run:
                        best_c1_run = hard_c1
                        best_params_run = params_jax

            # The last iterate is not covered by the periodic check above.
            hard_c1 = float(objective(params_jax))
            if hard_c1 < best_c1_run:
                best_c1_run = hard_c1
                best_params_run = params_jax

            print(f"Phase 2: L-BFGS-B refinement (starting from C1={best_c1_run:.8f})...")
            params_np = np.array(best_params_run)

            for temp in [500.0, 1000.0, 2000.0]:
                def scipy_obj(p, temp=temp):
                    # One fused pass for value and gradient. JAX defaults to
                    # float32, so hand SciPy an explicit float64 gradient.
                    val, g = value_and_grad_fn(jnp.array(p), temp)
                    return float(val), np.asarray(g, dtype=np.float64)

                result = scipy_minimize(
                    scipy_obj,
                    params_np,
                    method='L-BFGS-B',
                    jac=True,
                    options={'maxiter': 2000, 'ftol': 1e-15, 'gtol': 1e-10},
                )
                # Each stage continues from the previous stage's result.
                params_np = result.x
                f_opt = np.exp(params_np)
                c1 = compute_c1_numpy(f_opt, N)
                print(f" temp={temp:.0f}: C1={c1:.10f}")

                if c1 < best_c1_run:
                    best_c1_run = c1
                    best_params_run = jnp.array(params_np)

            # Score the run's best candidate with the NumPy reference.
            f_final = np.exp(np.array(best_params_run))
            c1_final = compute_c1_numpy(f_final, N)
            print(f" Final C1 for this run: {c1_final:.10f}")

            if c1_final < best_c1_overall:
                best_c1_overall = c1_final
                best_f_overall = f_final
                best_n_overall = N
                print(f"*** GLOBAL BEST: C1 = {c1_final:.10f}")

    print(f"\n=== Final best C1: {best_c1_overall:.10f} ===")
    return best_f_overall, best_c1_overall, best_c1_overall, best_n_overall
|
|