| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize |
| import optax |
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Evaluate the discrete autoconvolution ratio of a sampled profile.

    Negative samples are clipped to zero. With spacing h = 0.5 / n_points the
    ratio is max(h * (f conv f)) / (h * sum(f)) ** 2.

    Args:
        f_values: sampled profile values.
        n_points: number of samples, used only to derive the spacing h.

    Returns:
        The ratio, or the sentinel 1e10 when the squared integral is below
        1e-12 (degenerate, all-zero profile).
    """
    h = 0.5 / n_points
    nonneg = np.maximum(f_values, 0.0)
    # The full linear autoconvolution is computed before the degeneracy check,
    # so invalid (e.g. empty) input fails here exactly as before.
    self_conv = np.convolve(nonneg, nonneg, mode='full') * h
    mass_sq = (nonneg.sum() * h) ** 2
    if not mass_sq < 1e-12:
        return self_conv.max() / mass_sq
    return 1e10
|
|
|
|
def _initial_profile(N, seed):
    """Return the length-N starting profile used by `optimize_run` for `seed`.

    Reseeds NumPy's global RNG with `seed` and builds ten candidate shapes on
    a uniform grid over [0, 1]. It picks one by `seed % 10` and floors every
    sample at 0.01, so the profile starts strictly positive.
    """
    np.random.seed(seed)
    x = np.linspace(0, 1, N)
    # Every candidate is built, in key order, so the amount of global RNG
    # state consumed does not depend on which shape is selected.
    inits = {
        0: np.ones(N) * 0.5 + 0.02 * np.random.randn(N),
        1: np.exp(-10 * (x - 0.5) ** 2) + 0.1,
        2: np.exp(-5 * (x - 0.3) ** 2) + 0.05,
        3: np.exp(-5 * (x - 0.7) ** 2) + 0.05,
        4: 0.3 + 0.7 * np.sin(np.pi * x) ** 2 + 0.02 * np.random.randn(N),
        5: np.where(x < 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        6: np.where(x > 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        7: np.exp(-30 * (x - 0.5) ** 2) + 0.01,
        8: 0.5 * (1 + np.cos(4 * np.pi * x)) + 0.1 + 0.02 * np.random.randn(N),
        9: np.abs(np.random.randn(N)) * 0.3 + 0.1,
    }
    init_f = inits.get(seed % 10, np.ones(N) * 0.5)
    return np.maximum(init_f, 0.01)


def optimize_run(N, seed, adam_steps=100000, verbose=True):
    """Minimise the autoconvolution ratio C1 = max(f*f) / (int f)^2 for one seed.

    Phase 1 runs Adam on a smoothed objective: a logsumexp soft-max of the
    autoconvolution plus a flatness penalty. The penalty weight decays from
    0.1 to 0 over the first half of the run, and the soft-max temperature
    rises from 100 towards 300.

    Phase 2 runs three bounded (f >= 0) L-BFGS-B rounds on the soft-max
    objective at temperatures 1e3, 5e3 and 2e4. Each round starts where the
    previous one stopped.

    Every candidate is scored with the same float64 evaluator,
    `compute_c1_numpy`, so the Adam and L-BFGS results are comparable.

    Args:
        N: number of samples of f; the grid spacing is dx = 0.5 / N.
        seed: selects the initial profile and reseeds NumPy's global RNG.
        adam_steps: number of Adam iterations.
        verbose: print progress when True.

    Returns:
        (f, c1): the best non-negative profile found, as a float64 array of
        length N, and its C1 value from `compute_c1_numpy`.
    """
    dx = 0.5 / N

    def get_f(params):
        # The optimised variables are unconstrained; ReLU makes f >= 0.
        return jax.nn.relu(params)

    def compute_conv(f):
        # Zero-pad to 2N so the circular FFT product equals the linear
        # autoconvolution (length 2N - 1) with no wrap-around.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    def integral_sq_of(f):
        # Clamp the integral away from zero so the ratio stays finite.
        return jnp.maximum(jnp.sum(f) * dx, 1e-9) ** 2

    def objective_smooth_only(params, temp):
        """Soft-max C1: logsumexp(temp * conv) / temp over (int f)^2."""
        f = get_f(params)
        conv = compute_conv(f)
        # logsumexp(t*c)/t >= max(c), and the gap shrinks like log(len(c))/t.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq_of(f)

    def objective_reg(params, temp, lam):
        """Smooth max + flatness regularization"""
        f = get_f(params)
        conv = compute_conv(f)
        integral_sq = integral_sq_of(f)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        conv_mean = jnp.sum(conv) / (2 * N)
        conv_var = jnp.sum((conv - conv_mean) ** 2) / (2 * N)
        c1 = smooth_max / integral_sq
        # conv scales like s^2 under f -> s*f, so conv_var scales like s^4;
        # dividing by integral_sq**2 makes the penalty scale-invariant.
        flatness_penalty = lam * conv_var / integral_sq ** 2
        return c1 + flatness_penalty

    # optax's `decay_steps` is the TOTAL schedule length including warmup.
    # Clamp the warmup so short runs still yield a valid, positive decay span.
    warmup_steps = min(2000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    @jax.jit
    def adam_step(params, opt_state, temp, lam):
        # Gradient, optimizer update and parameter update are fused into one
        # compiled call instead of being dispatched op-by-op from Python.
        loss, grads = jax.value_and_grad(objective_reg)(params, temp, lam)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    params = jnp.array(_initial_profile(N, seed))
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_f = np.asarray(params, dtype=np.float64)

    for step in range(adam_steps):
        progress = step / adam_steps  # in [0, 1) for every step of the loop
        temp = 100.0 + progress * 200.0
        lam = 0.1 * max(1.0 - progress * 2, 0.0)
        params, opt_state, _ = adam_step(params, opt_state, temp, lam)

        if step % 20000 == 0 or step == adam_steps - 1:
            candidate = np.asarray(params, dtype=np.float64)
            hard_c1 = compute_c1_numpy(candidate, N)
            if verbose:
                print(f" [{seed}] Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_f = candidate

    smooth_value_and_grad = jax.jit(jax.value_and_grad(objective_smooth_only))

    def scipy_obj(p, temp):
        # One forward/backward pass yields both the value and the gradient.
        val, grad = smooth_value_and_grad(jnp.asarray(p), temp)
        return float(val), np.asarray(grad, dtype=np.float64)

    params_np = best_f.copy()
    for temp_lbfgs in (1000.0, 5000.0, 20000.0):
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp_lbfgs,),
            method='L-BFGS-B', jac=True,
            bounds=[(0, None)] * N,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        c1 = compute_c1_numpy(params_np, N)
        if verbose:
            print(f" [{seed}] L-BFGS temp={temp_lbfgs:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_f = params_np.copy()

    return np.maximum(best_f, 0.0), best_c1
|
|
|
|
def run(N=3000, n_seeds=10, adam_steps=80000):
    """Run `optimize_run` for seeds 0..n_seeds-1 and keep the best profile.

    Args:
        N: number of samples of f passed to `optimize_run`.
        n_seeds: how many consecutive seeds, starting at 0, to try.
        adam_steps: Adam iterations per seed.

    Returns:
        (best_f, best_c1, best_c1, N). The second and third entries are both
        the best C1; the 4-tuple shape is kept for existing callers.
        NOTE(review): the duplicated value looks like a placeholder for a
        second score; confirm with the caller. best_f is None when
        n_seeds == 0.
    """
    best_c1 = float('inf')
    best_f = None

    for seed in range(n_seeds):
        f, c1 = optimize_run(N, seed, adam_steps=adam_steps)
        print(f" Seed {seed}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            print(f" *** NEW GLOBAL BEST: C1={c1:.10f}")

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, N
|
|