# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
Optimize step function heights using JAX.
K parameters (step heights) instead of N parameters.
Lower-dimensional search = better global exploration.
"""
import sys
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize, differential_evolution
import optax
def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)^2 for the sampled function f.

    Negative samples are clipped to zero first. Integrals are Riemann sums with
    spacing ``0.5 / n_points``.

    Args:
        f_values: 1-D array-like of function samples.
        n_points: number of samples; used only to derive the spacing.

    Returns:
        The ratio as a float, or ``1e10`` when the squared integral is below
        ``1e-12`` (an effectively zero function).
    """
    spacing = 0.5 / n_points
    g = np.maximum(f_values, 0.0)
    # Computed before the degenerate-integral check, as np.convolve validates
    # its input (e.g. rejects empty arrays).
    self_conv = np.convolve(g, g, mode='full') * spacing
    area = np.sum(g) * spacing
    denom = area ** 2
    if denom < 1e-12:
        return 1e10
    return np.max(self_conv) / denom
def make_step_fns(K, N):
    """Create functions for K-step function optimization at N resolution.

    The N sample indices are partitioned into K contiguous steps: step ``i``
    covers indices ``int(i * N / K)`` (inclusive) to ``int((i + 1) * N / K)``
    (exclusive), so every sample belongs to exactly one step. Parameters are
    log step heights, clipped to [-4, 4] before exponentiation, so every
    height is positive.

    Args:
        K: number of steps (optimization parameters); must be >= 1.
        N: number of samples of the resulting function.

    Returns:
        Tuple ``(heights_to_f, obj_smooth, obj_hard, grad_smooth, grad_hard)``:
            heights_to_f(log_heights) -> sampled function of length N.
            obj_smooth(log_heights, temp) -> logsumexp(temp * autoconv) / temp
                divided by the squared integral (smooth upper bound of C1).
            obj_hard(log_heights) -> max(autoconv) / squared integral (C1).
            grad_smooth, grad_hard -> jitted gradients w.r.t. log_heights.

    Raises:
        ValueError: if K < 1.
    """
    if K < 1:
        raise ValueError("K must be a positive integer")
    dx = 0.5 / N

    # step_of[j] is the step containing sample j. Gathering heights with this
    # index gives exactly what multiplying by a one-hot (K, N) matrix gives,
    # without materialising that matrix (K = N = 4000 means 16M entries) and
    # without building it through K sequential full-array copies.
    step_of = np.zeros(N, dtype=np.int32)
    for i in range(K):
        step_of[int(i * N / K):int((i + 1) * N / K)] = i
    step_of = jnp.asarray(step_of)

    def _autoconv(f):
        """Linear autoconvolution of f (length 2N, last entry ~0), scaled by dx."""
        # Zero-padding to 2N makes the FFT's circular convolution equal the
        # linear one, which has length 2N - 1.
        spectrum = jnp.fft.rfft(f, n=2 * N)
        return jnp.fft.irfft(spectrum * spectrum, n=2 * N) * dx

    @jax.jit
    def heights_to_f(log_heights):
        heights = jnp.exp(jnp.clip(log_heights, -4, 4))
        return heights[step_of]

    @jax.jit
    def obj_smooth(log_heights, temp):
        f = heights_to_f(log_heights)
        integral_sq = (jnp.sum(f) * dx) ** 2
        # logsumexp / temp is a smooth over-estimate of max that tightens as
        # temp grows.
        smooth_max = jax.nn.logsumexp(temp * _autoconv(f)) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(log_heights):
        f = heights_to_f(log_heights)
        integral_sq = (jnp.sum(f) * dx) ** 2
        return jnp.max(_autoconv(f)) / integral_sq

    grad_smooth = jax.jit(jax.grad(obj_smooth))
    grad_hard = jax.jit(jax.grad(obj_hard))
    return heights_to_f, obj_smooth, obj_hard, grad_smooth, grad_hard
def _adam_phase(log_heights, obj_smooth, obj_hard, adam_steps, temp=300.0,
                peak_lr=0.01, warmup_steps=1000, check_every=25000):
    """Run Adam on obj_smooth and return the best params seen, scored by obj_hard."""
    # Shrink warmup for short runs so the schedule stays valid; callers using
    # >= 10000 steps get the full 1000-step warmup.
    warmup = min(warmup_steps, max(adam_steps // 10, 1))
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=peak_lr, warmup_steps=warmup,
        # In optax, decay_steps is the TOTAL schedule length including warmup.
        decay_steps=max(adam_steps, warmup + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=schedule)
    value_and_grad = jax.value_and_grad(obj_smooth)

    @jax.jit
    def update(params, opt_state):
        _, grads = value_and_grad(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    opt_state = optimizer.init(log_heights)
    best_c1 = float('inf')
    best_params = log_heights
    for step in range(adam_steps):
        log_heights, opt_state = update(log_heights, opt_state)
        # Score with the exact objective only periodically; it costs an extra pass.
        if step % check_every == 0 or step == adam_steps - 1:
            hc = float(obj_hard(log_heights))
            if hc < best_c1:
                best_c1, best_params = hc, log_heights
    return best_params


def _lbfgs_polish(fun, grad, x0, args=(), **options):
    """Minimize a jitted JAX objective with SciPy L-BFGS-B; return the final x."""
    def value_and_grad(p, *extra):
        p_jax = jnp.array(p)
        return (float(fun(p_jax, *extra)),
                np.array(grad(p_jax, *extra), dtype=np.float64))

    result = scipy_minimize(
        value_and_grad, x0, args=args, method='L-BFGS-B', jac=True,
        options=options,
    )
    return result.x


def optimize_step(K, N, seed=42, adam_steps=100000):
    """Optimize a K-step function sampled at N points to minimize C1.

    Pipeline:
      1. Adam on the smoothed objective (temp=300), keeping the best iterate
         by the exact objective.
      2. L-BFGS-B on progressively sharper smoothed objectives
         (temp=1000, then 10000).
      3. L-BFGS-B on the exact (hard-max) objective.
    The polished point is returned only if its C1 is not worse than that of
    Adam's best point; otherwise Adam's best is returned.

    Args:
        K: number of steps.
        N: number of samples.
        seed: seed for the initial log heights. Uses a legacy NumPy
            RandomState, which draws the same values as np.random.seed(seed)
            but leaves the global RNG untouched.
        adam_steps: number of Adam iterations.

    Returns:
        (f, c1, params): sampled function (length N), its C1 as computed by
        compute_c1_numpy, and the float64 log heights (length K).
    """
    heights_to_f, obj_smooth, obj_hard, grad_smooth, grad_hard = make_step_fns(K, N)

    # Start near the constant function.
    rng = np.random.RandomState(seed)
    log_heights = jnp.array(rng.randn(K) * 0.1)

    adam_best = np.array(
        _adam_phase(log_heights, obj_smooth, obj_hard, adam_steps),
        dtype=np.float64,
    )

    params_np = adam_best
    for temp in (1000.0, 10000.0):
        params_np = _lbfgs_polish(
            obj_smooth, grad_smooth, params_np, args=(temp,),
            maxiter=10000, ftol=1e-15, gtol=1e-14,
        )
    params_np = _lbfgs_polish(
        obj_hard, grad_hard, params_np,
        maxiter=20000, ftol=1e-16, gtol=1e-15,
    )

    # Polishing the smoothed objectives can raise the hard max, so never return
    # a point worse than the best one Adam found. Ties keep the polished point.
    best = None
    for candidate in (params_np, adam_best):
        f = np.array(heights_to_f(jnp.array(candidate)))
        c1 = compute_c1_numpy(f, N)
        if best is None or c1 < best[1]:
            best = (f, c1, candidate)
    return best
def run(*, N=4000,
        step_counts=(20, 30, 40, 50, 60, 80, 100, 150, 200, 300, 500, 800, 1000, 2000),
        n_seeds=5, adam_steps=50000, full_seeds=(15, 8, 2), full_adam_steps=80000):
    """Sweep step counts and seeds, then try full resolution; return the best result.

    For each K in ``step_counts``, run optimize_step at resolution N for seeds
    ``0 .. n_seeds - 1``. Then run full resolution (K = N) for each seed in
    ``full_seeds``. Progress goes to stdout. A line is printed for every new
    global best, for seed 0 of each K, and as a per-K summary.

    Args:
        N: sample resolution.
        step_counts: step counts K to try.
        n_seeds: seeds per step count.
        adam_steps: Adam iterations per step-count run.
        full_seeds: seeds for the full-resolution runs. The defaults are the
            "best seeds from v16" noted by the original author.
        full_adam_steps: Adam iterations per full-resolution run.

    Returns:
        ``(best_f, best_c1, best_c1, N)``. best_c1 appears twice; presumably the
        caller expects a 4-tuple such as (f, c1, score, n). TODO confirm.
        best_f is None if no run produced a C1 below infinity.
    """
    def log(msg):
        sys.stdout.write(msg + "\n")
        sys.stdout.flush()

    best_c1 = float('inf')
    best_f = None

    # Try many different numbers of steps.
    for K in step_counts:
        best_k_c1 = float('inf')
        for seed in range(n_seeds):
            f, c1, _ = optimize_step(K, N, seed=seed, adam_steps=adam_steps)
            best_k_c1 = min(best_k_c1, c1)
            if c1 < best_c1:
                best_c1, best_f = c1, f
                log(f"K={K:4d} seed={seed}: C1={c1:.10f} ***")
            elif seed == 0:
                # Always report the first seed so every K appears in the log.
                log(f"K={K:4d} seed={seed}: C1={c1:.10f}")
        log(f"K={K:4d} best over {n_seeds} seeds: C1={best_k_c1:.10f}")

    # Also try full resolution (K = N), where every sample is its own step.
    K_full = N
    log(f"\nFull resolution K={K_full}...")
    for seed in full_seeds:
        f, c1, _ = optimize_step(K_full, N, seed=seed, adam_steps=full_adam_steps)
        log(f"K={K_full} seed={seed}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1, best_f = c1, f
            log("*** NEW GLOBAL BEST")

    log(f"\nFinal C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, N