| """ |
| Optimize step function heights using JAX. |
| K parameters (step heights) instead of N parameters. |
| Lower-dimensional search = better global exploration. |
| """ |
| import sys |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scipy.optimize import minimize as scipy_minimize, differential_evolution |
| import optax |
|
|
|
|
def compute_c1_numpy(f_values, n_points):
    """Evaluate the autoconvolution ratio of a sampled function with NumPy.

    Negative samples are clipped to zero. The discrete self-convolution of
    the clipped samples is scaled by the grid spacing ``0.5 / n_points`` and
    its peak is divided by the squared Riemann-sum integral of the samples.

    Args:
        f_values: Sequence of function samples.
        n_points: Number of grid points used to derive the grid spacing.

    Returns:
        ``max(f * f) / (sum(f) * spacing) ** 2`` for the clipped samples, or
        the penalty value ``1e10`` when the squared integral is below
        ``1e-12``.
    """
    spacing = 0.5 / n_points
    clipped = np.maximum(f_values, 0.0)

    # The convolution is evaluated before the degenerate-mass guard on
    # purpose: it keeps the error behaviour for malformed input unchanged.
    scaled_conv = np.convolve(clipped, clipped, mode='full') * spacing
    mass_sq = (np.sum(clipped) * spacing) ** 2

    # A function with (almost) no mass has no meaningful ratio.
    return 1e10 if mass_sq < 1e-12 else np.max(scaled_conv) / mass_sq
|
|
|
|
def make_step_fns(K, N):
    """Create functions for K-step function optimization at N resolution.

    The step function lives on ``N`` grid cells of width ``0.5 / N``. Step
    ``i`` covers cells ``[i * N // K, (i + 1) * N // K)`` and its height is
    ``exp(clip(log_heights[i], -4, 4))``, so heights are always positive.

    Args:
        K: Number of steps (length of the ``log_heights`` parameter vector).
        N: Number of grid cells the step function is sampled on.

    Returns:
        A tuple ``(heights_to_f, obj_smooth, obj_hard, grad_smooth,
        grad_hard)``:

        * ``heights_to_f(log_heights)`` -> the ``N`` sampled function values.
        * ``obj_smooth(log_heights, temp)`` -> the objective with the max of
          the autoconvolution replaced by ``logsumexp(temp * conv) / temp``.
        * ``obj_hard(log_heights)`` -> the objective with the exact max.
        * ``grad_smooth`` / ``grad_hard`` -> gradients of the two objectives
          with respect to ``log_heights``.
    """
    dx = 0.5 / N

    # Row i of the (K, N) indicator matrix marks the grid cells of step i.
    # It is filled in NumPy and transferred to JAX once: JAX arrays are
    # immutable, so K successive `.at[...].set(...)` updates would each work
    # on the whole K x N array.
    edges = [(i * N) // K for i in range(K + 1)]
    indicator = np.zeros((K, N))
    for i in range(K):
        indicator[i, edges[i]:edges[i + 1]] = 1.0
    mapping = jnp.asarray(indicator)

    @jax.jit
    def heights_to_f(log_heights):
        heights = jnp.exp(jnp.clip(log_heights, -4, 4))
        return jnp.dot(heights, mapping)

    def _autoconv_and_mass(log_heights):
        """Return (scaled autoconvolution, squared integral) of the step fn."""
        f = heights_to_f(log_heights)
        # Zero-pad to length 2N so the circular FFT product equals the linear
        # self-convolution (which has 2N - 1 entries).
        padded = jnp.zeros(2 * N).at[:N].set(f)
        spectrum = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(spectrum * spectrum, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(log_heights, temp):
        conv, integral_sq = _autoconv_and_mass(log_heights)
        # Differentiable surrogate for max(conv); tighter for larger temp.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(log_heights):
        conv, integral_sq = _autoconv_and_mass(log_heights)
        return jnp.max(conv) / integral_sq

    grad_smooth = jax.jit(jax.grad(obj_smooth))
    grad_hard = jax.jit(jax.grad(obj_hard))

    return heights_to_f, obj_smooth, obj_hard, grad_smooth, grad_hard
|
|
|
|
def optimize_step(K, N, seed=42, adam_steps=100000):
    """Optimize the heights of a K-step function sampled on N grid cells.

    The search runs in three stages:

    1. Adam on the smoothed objective (temperature 300) with a linear-warmup
       / cosine-decay learning rate. The hard objective is checked every
       25000 steps and on the last step, and the best parameters seen at
       those checkpoints are kept.
    2. L-BFGS-B on the smoothed objective at temperatures 1000 and 10000.
    3. L-BFGS-B on the hard (exact max) objective.

    Args:
        K: Number of steps (parameters).
        N: Grid resolution of the sampled function.
        seed: Seed for the random initial log-heights.
        adam_steps: Number of Adam iterations in stage 1.

    Returns:
        A tuple ``(f, c1, params_np)``: the sampled step function as a NumPy
        array, its objective value from ``compute_c1_numpy``, and the final
        log-height parameters.
    """
    heights_to_f, obj_smooth, obj_hard, grad_smooth, grad_hard = make_step_fns(K, N)

    # A dedicated RandomState yields the same draws as
    # `np.random.seed(seed); np.random.randn(K)` without reseeding the
    # process-global generator as a side effect.
    rng = np.random.RandomState(seed)
    log_heights = jnp.array(rng.randn(K) * 0.1)

    # optax's `decay_steps` is the TOTAL schedule length (warmup included),
    # so the schedule spans all Adam steps. The warmup is shortened for
    # short runs so the cosine phase always has a positive length.
    warmup_steps = max(1, min(1000, adam_steps // 2))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.01, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(log_heights)

    adam_temp = 300.0

    @jax.jit
    def adam_update(params, state):
        """One Adam step on the smoothed objective, compiled once."""
        grads = grad_smooth(params, adam_temp)
        updates, state = optimizer.update(grads, state, params)
        return optax.apply_updates(params, updates), state

    best_c1 = float('inf')
    best_params = log_heights
    last_step = adam_steps - 1

    for step in range(adam_steps):
        log_heights, opt_state = adam_update(log_heights, opt_state)

        # Checkpoint on the exact objective only occasionally: each check
        # forces a device-to-host transfer.
        if step % 25000 == 0 or step == last_step:
            hc = float(obj_hard(log_heights))
            if hc < best_c1:
                best_c1 = hc
                best_params = log_heights

    def lbfgs_polish(value_fn, grad_fn, start, options):
        """Run L-BFGS-B from `start` and return the final parameter vector."""
        def objective(p):
            p_jax = jnp.array(p)
            val = float(value_fn(p_jax))
            g = np.array(grad_fn(p_jax), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            objective, start, method='L-BFGS-B', jac=True, options=options,
        )
        return result.x

    params_np = np.array(best_params, dtype=np.float64)

    # NOTE(review): unless JAX x64 mode is enabled elsewhere, `jnp.array(p)`
    # is presumably float32, so the very tight ftol/gtol below effectively
    # mean "iterate until SciPy stalls" - confirm if precision matters.
    for temp in (1000.0, 10000.0):
        params_np = lbfgs_polish(
            lambda p, t=temp: obj_smooth(p, t),
            lambda p, t=temp: grad_smooth(p, t),
            params_np,
            {'maxiter': 10000, 'ftol': 1e-15, 'gtol': 1e-14},
        )

    # Final polish directly on the non-smoothed objective.
    params_np = lbfgs_polish(
        obj_hard, grad_hard, params_np,
        {'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15},
    )

    f = np.array(heights_to_f(jnp.array(params_np)))
    c1 = compute_c1_numpy(f, N)
    return f, c1, params_np
|
|
|
|
def run():
    """Sweep step counts and seeds, then try full resolution; report the best.

    For a fixed grid of ``N = 4000`` cells, ``optimize_step`` is run for a
    list of step counts ``K`` with seeds 0-4 (50000 Adam steps each), and
    then at full resolution ``K = N`` with seeds 15, 8 and 2 (80000 Adam
    steps each). Progress is written to stdout.

    Returns:
        ``(best_f, best_c1, best_c1, N)``: the best sampled function found,
        its objective value (repeated in the second and third slots), and
        the grid resolution.
        NOTE(review): the duplicated value looks deliberate, presumably to
        match a 4-tuple expected by a caller - verify before changing it.
    """
    N = 4000
    best_c1 = float('inf')
    best_f = None

    for K in [20, 30, 40, 50, 60, 80, 100, 150, 200, 300, 500, 800, 1000, 2000]:
        for seed in range(5):
            f, c1, _ = optimize_step(K, N, seed=seed, adam_steps=50000)

            if c1 < best_c1:
                best_c1 = c1
                best_f = f
                sys.stdout.write(f"K={K:4d} seed={seed}: C1={c1:.10f} ***\n")
            elif seed == 0:
                # Only the first seed of each K is logged when it is not a
                # new global best. The earlier per-K tolerance test compared
                # c1 against a per-K best that had already been updated with
                # c1, so it never filtered out a finite result.
                sys.stdout.write(f"K={K:4d} seed={seed}: C1={c1:.10f}\n")
            sys.stdout.flush()

    K_full = N
    sys.stdout.write(f"\nFull resolution K={K_full}...\n")
    sys.stdout.flush()
    for seed in [15, 8, 2]:
        f, c1, _ = optimize_step(K_full, N, seed=seed, adam_steps=80000)
        sys.stdout.write(f"K={K_full} seed={seed}: C1={c1:.10f}\n")
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            sys.stdout.write("*** NEW GLOBAL BEST\n")
        sys.stdout.flush()

    sys.stdout.write(f"\nFinal C1: {best_c1:.10f}\n")
    sys.stdout.flush()
    return best_f, best_c1, best_c1, N
|
|