JustinTX's picture
Add files using upload-large-folder tool
2facf1f verified
"""
Key innovations:
1. Normalize f so integral=1 (removes scale degeneracy)
2. Use Lp norm as smooth max approximation
3. Multi-phase: start with low p, increase gradually
4. Many diverse random restarts
"""
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return the discrete autoconvolution ratio C1 of sampled values.

    Negative samples are clipped to zero. With step ``h = 0.5 / n_points``,
    the ratio is ``max(conv(f, f) * h) / (sum(f) * h) ** 2``, where ``conv`` is
    the full discrete self-convolution.

    Returns ``1e10`` as a sentinel when the squared integral is below ``1e-12``.
    """
    step = 0.5 / n_points
    nonneg = np.maximum(f_values, 0.0)
    # Convolve before the degeneracy check so invalid input (e.g. an empty
    # array) still raises, exactly as before.
    self_conv = step * np.convolve(nonneg, nonneg, mode='full')
    area = np.sum(nonneg) * step
    denom = area ** 2
    return 1e10 if denom < 1e-12 else np.max(self_conv) / denom
def make_fns(N):
    """Build JIT-compiled C1 objectives for an N-point discretization.

    The grid spacing is ``dx = 0.5 / N``. Every objective maps unconstrained
    ``params`` (length N) to a non-negative function via softplus, then
    rescales it so ``sum(f) * dx == 1``. It then measures the discrete
    autoconvolution of that function. Because the integral is fixed at 1,
    ``max(autoconv)`` equals C1 directly.

    Returns:
        ``(params_to_f, compute_c1_smooth, compute_c1_logsumexp,
        compute_c1_hard, grad_lse)``, where ``grad_lse(params, temp)`` is the
        jitted gradient of ``compute_c1_logsumexp`` with respect to ``params``.
    """
    dx = 0.5 / N

    @jax.jit
    def params_to_f(params):
        """Convert params to a non-negative function with unit integral."""
        f = jax.nn.softplus(params)
        integral = jnp.sum(f) * dx
        # Dividing out the integral removes the scale degeneracy of C1;
        # the floor guards against division by ~0.
        return f / jnp.maximum(integral, 1e-9)

    def _autoconv(params):
        """Return the linear autoconvolution of params_to_f(params), times dx.

        The linear self-convolution of N samples has 2N - 1 entries. Zero-padding
        to length 2N therefore makes the circular FFT product equal to the linear
        convolution, with no wrap-around.
        """
        f = params_to_f(params)
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    @jax.jit
    def compute_c1_smooth(params, p):
        """Power-mean (Lp) approximation of C1: (mean(conv^p))^(1/p) -> max as p -> inf."""
        # FFT round-off can leave tiny negative entries; clip before the power.
        conv_pos = jnp.maximum(_autoconv(params), 0.0)
        return jnp.mean(conv_pos ** p) ** (1.0 / p)

    @jax.jit
    def compute_c1_logsumexp(params, temp):
        """LogSumExp approximation of C1 (an upper bound that tightens as temp grows)."""
        return jax.nn.logsumexp(temp * _autoconv(params)) / temp

    @jax.jit
    def compute_c1_hard(params):
        """Exact discrete C1: the maximum of the autoconvolution."""
        return jnp.max(_autoconv(params))

    grad_lse = jax.jit(jax.grad(compute_c1_logsumexp))
    return params_to_f, compute_c1_smooth, compute_c1_logsumexp, compute_c1_hard, grad_lse
def _lbfgs_polish(params_np, c1_lse, grad_lse, temps, maxiter):
    """Run L-BFGS-B on the log-sum-exp C1 surrogate once per temperature.

    Each stage starts from the previous stage's result, so increasing
    temperatures progressively tighten the approximation of the max.
    Returns the final float64 parameter vector.
    """
    for temp in temps:
        # Bind temp as a default so the closure does not late-bind the loop variable.
        def scipy_obj(p_arr, temp=temp):
            p_jax = jnp.array(p_arr)
            val = float(c1_lse(p_jax, temp))
            g = np.array(grad_lse(p_jax, temp), dtype=np.float64)
            return val, g

        # NOTE(review): unless jax_enable_x64 is set elsewhere, JAX evaluates
        # this objective in float32, so ftol=1e-15 is effectively unreachable
        # and L-BFGS usually stops on maxiter or a failed line search.
        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': maxiter, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
    return params_np


def optimize_single(N, seed, steps=80000, verbose=True):
    """Optimize one restart: Adam on an Lp surrogate, then L-BFGS polishing.

    Args:
        N: number of grid points (grid spacing 0.5 / N).
        seed: selects the initial shape (``seed % 12``) and seeds NumPy's
            global RNG, which only the "random" initialization draws from.
        steps: number of Adam steps.
        verbose: print progress when True.

    Returns:
        ``(f, c1)``, where ``f`` holds the unnormalized samples softplus(params)
        and ``c1`` is their ``compute_c1_numpy`` score. Both the best Adam
        checkpoint and the L-BFGS result are scored, and the better one is
        returned. This matters because the smooth surrogate being polished can
        disagree with the hard max.
    """
    params_to_f, c1_smooth, c1_lse, c1_hard, grad_lse = make_fns(N)
    np.random.seed(seed)
    x = np.linspace(0, 1, N)
    # Diverse initializations (only the "random" entry consumes the RNG).
    init_types = [
        lambda: np.ones(N),  # constant
        lambda: np.exp(-10 * (x - 0.5)**2) + 0.1,  # centered Gaussian
        lambda: np.exp(-5 * (x - 0.3)**2) + 0.05,  # left Gaussian
        lambda: np.exp(-5 * (x - 0.7)**2) + 0.05,  # right Gaussian
        lambda: 0.5 + 0.5 * np.cos(2*np.pi*x),  # cosine
        lambda: np.maximum(1 - 4*np.abs(x - 0.5), 0) + 0.1,  # triangle
        lambda: np.where((x > 0.1) & (x < 0.9), 1.0, 0.1),  # wide box
        lambda: np.where((x > 0.2) & (x < 0.8), 1.0, 0.1),  # medium box
        lambda: np.exp(-50 * (x - 0.5)**2) + 0.01,  # narrow Gaussian
        lambda: np.abs(np.random.randn(N)) * 0.3 + 0.2,  # random
        lambda: 1 - 0.5 * np.abs(np.sin(3*np.pi*x)),  # wavy
        lambda: np.exp(-3*(x - 0.4)**2) + 0.5*np.exp(-3*(x-0.6)**2) + 0.05,  # bimodal
    ]
    init_f = np.maximum(init_types[seed % len(init_types)](), 0.01)
    # Inverse softplus: log(exp(y) - 1).
    params = jnp.array(np.log(np.expm1(init_f)))

    # optax's decay_steps is the TOTAL schedule length (warmup included), so
    # pass `steps` to cosine-decay over the whole remaining run.
    warmup = min(2000, steps // 10)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup,
        decay_steps=max(steps, 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    # One compiled Adam step; `p` is a traced scalar, so changing it does not retrace.
    @jax.jit
    def adam_step(prm, state, p):
        grads = jax.grad(c1_smooth)(prm, p)
        updates, state = optimizer.update(grads, state, prm)
        return optax.apply_updates(prm, updates), state

    best_c1 = float('inf')
    best_params = params
    for step in range(steps):
        # Raise the Lp exponent linearly from 4 toward 40 to sharpen the max.
        p = 4.0 + (step / steps) * 36.0
        params, opt_state = adam_step(params, opt_state, p)
        if step % 20000 == 0 or step == steps - 1:
            hc = float(c1_hard(params))
            if verbose:
                print(f" [{seed:2d}] Step {step:6d} | C1={hc:.8f} | p={p:.1f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    adam_np = np.asarray(best_params, dtype=np.float64)
    polished_np = _lbfgs_polish(
        adam_np, c1_lse, grad_lse, temps=(500.0, 2000.0, 10000.0), maxiter=3000,
    )

    # Softplus via logaddexp(0, z) avoids the overflow of log1p(exp(z)) for large z.
    # The returned samples are unnormalized; C1 is scale-invariant.
    f_adam = np.logaddexp(0.0, adam_np)
    f_polished = np.logaddexp(0.0, polished_np)
    c1_adam = compute_c1_numpy(f_adam, N)
    c1_polished = compute_c1_numpy(f_polished, N)
    # Keep whichever candidate actually scores better on the hard metric
    # (a NaN polished score falls through to the Adam result).
    if c1_polished <= c1_adam:
        f_final, c1_final = f_polished, c1_polished
    else:
        f_final, c1_final = f_adam, c1_adam
    if verbose:
        print(f" [{seed:2d}] Final: C1={c1_final:.10f}")
    return f_final, c1_final
def run():
    """Run a multi-start search at N=3000, then refine at N=5000.

    The candidates are:
      * the best of 12 restarts at N=3000,
      * a from-scratch run at N=5000 (seed 99),
      * the N=3000 best, linearly upsampled to N=5000 and polished with
        L-BFGS on the log-sum-exp surrogate.
    The candidate with the lowest ``compute_c1_numpy`` score wins.

    Returns:
        ``(best_f, best_c1, best_c1, len(best_f))``. The score is repeated in
        two slots. NOTE(review): presumably this matches a caller's expected
        tuple layout; confirm before changing it.

    Raises:
        RuntimeError: if no restart produced a comparable (non-NaN) C1.
    """
    N = 3000
    N2 = 5000
    best_c1 = float('inf')
    best_f = None
    for seed in range(12):
        f, c1 = optimize_single(N, seed, steps=60000)
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            print(f" *** GLOBAL BEST: C1={c1:.10f} (seed={seed})")
    if best_f is None:
        raise RuntimeError("no restart produced a comparable C1 value")

    print(f"\nExtended polishing at N={N2}...")
    # Upsample the N=3000 winner before other candidates can replace it.
    f_up = np.interp(np.linspace(0, 1, N2), np.linspace(0, 1, len(best_f)), best_f)

    # Candidate: a fresh optimization directly at the finer resolution.
    f_scratch, c1_scratch = optimize_single(N2, 99, steps=40000)
    if c1_scratch < best_c1:
        best_c1 = c1_scratch
        best_f = f_scratch

    # Candidate: L-BFGS polish of the upsampled winner.
    _, _, c1_lse2, _, grad_lse2 = make_fns(N2)
    # Inverse softplus, with a floor so expm1 stays positive.
    params_np = np.log(np.expm1(np.maximum(f_up, 1e-3)))
    for temp in [500.0, 2000.0, 10000.0, 50000.0]:
        # Bind temp as a default so the closure does not late-bind the loop variable.
        def scipy_obj(p_arr, temp=temp):
            p_jax = jnp.array(p_arr)
            val = float(c1_lse2(p_jax, temp))
            g = np.array(grad_lse2(p_jax, temp), dtype=np.float64)
            return val, g

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
    # Softplus via logaddexp(0, z) avoids the overflow of log1p(exp(z)).
    f_final = np.logaddexp(0.0, params_np)
    c1_final = compute_c1_numpy(f_final, N2)
    print(f"Upsampled polished: C1={c1_final:.10f}")
    if c1_final < best_c1:
        best_c1 = c1_final
        best_f = f_final

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, len(best_f)