# Provenance (residue from the hosting web page): uploaded by JustinTX via
# "Add files using upload-large-folder tool", commit 2facf1f (verified).
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return the autoconvolution ratio C1 of a sampled profile, using NumPy.

    ``f_values`` is treated as ``n_points`` samples with spacing
    ``0.5 / n_points``.  Negative samples are clipped to zero, then the ratio
    ``max(f * f) / (integral of f) ** 2`` is returned, where ``f * f`` is the
    full (linear) discrete autoconvolution scaled by the spacing.

    A degenerate profile whose squared integral is below ``1e-12`` yields the
    sentinel ``1e10`` instead of dividing by (almost) zero.
    """
    spacing = 0.5 / n_points
    nonneg = np.maximum(f_values, 0.0)

    mass_squared = (np.sum(nonneg) * spacing) ** 2
    if mass_squared < 1e-12:
        return 1e10

    peak = np.max(np.convolve(nonneg, nonneg, mode='full') * spacing)
    return peak / mass_squared
def _initial_profile(N, seed):
    """Build the starting profile used by ``optimize_run`` for ``seed``.

    Ten hand-picked shapes on a uniform grid over [0, 1] are built, and the
    one keyed by ``seed % 10`` is returned.  All ten are built in this fixed
    order from a local ``RandomState(seed)``.  The noise in the chosen shape
    is therefore the same as drawing from the global generator right after
    ``np.random.seed(seed)``, without mutating global RNG state.  The result
    is floored at 0.01 so every sample starts strictly positive, i.e. outside
    ReLU's flat (zero-gradient) region.
    """
    rng = np.random.RandomState(seed)
    x = np.linspace(0, 1, N)
    inits = {
        0: np.ones(N) * 0.5 + 0.02 * rng.randn(N),
        1: np.exp(-10 * (x - 0.5) ** 2) + 0.1,
        2: np.exp(-5 * (x - 0.3) ** 2) + 0.05,  # asymmetric
        3: np.exp(-5 * (x - 0.7) ** 2) + 0.05,  # asymmetric other way
        4: 0.3 + 0.7 * np.sin(np.pi * x) ** 2 + 0.02 * rng.randn(N),
        5: np.where(x < 0.5, 1.0, 0.3) + 0.02 * rng.randn(N),
        6: np.where(x > 0.5, 1.0, 0.3) + 0.02 * rng.randn(N),
        7: np.exp(-30 * (x - 0.5) ** 2) + 0.01,  # sharp peak
        8: 0.5 * (1 + np.cos(4 * np.pi * x)) + 0.1 + 0.02 * rng.randn(N),
        9: np.abs(rng.randn(N)) * 0.3 + 0.1,
    }
    return np.maximum(inits[seed % 10], 0.01)


def optimize_run(N, seed, adam_steps=100000, verbose=True):
    """Minimise the autoconvolution ratio C1 starting from one initial shape.

    The run has two phases:

    1. Adam on a smoothed objective: a log-sum-exp soft maximum of the
       autoconvolution, plus a flatness (variance) penalty.  The temperature
       rises from 100 to 300, and the penalty weight decays linearly from 0.1
       to 0 by the halfway point.  The exact C1 is sampled every 20000 steps
       and on the last step.
    2. L-BFGS-B polishing of the un-regularised soft-max objective at
       temperatures 1000, 5000 and 20000, under a non-negativity bound.  Each
       result is re-scored with ``compute_c1_numpy``.

    Args:
        N: number of samples of f; the grid spacing is ``0.5 / N``.
        seed: selects the initial shape (``seed % 10``) and seeds its noise.
        adam_steps: number of Adam iterations (0 skips phase 1).
        verbose: print progress lines.

    Returns:
        ``(f_final, best_c1)``: the best non-negative profile found, as a
        NumPy float64 array, and its C1 value.
    """
    dx = 0.5 / N

    def get_f(params):
        # ReLU keeps f non-negative while still allowing exact zeros.
        return jax.nn.relu(params)

    def compute_conv(f):
        # Zero-pad to 2N so the FFT product gives the linear (not circular)
        # autoconvolution; scaling by dx approximates the integral.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    def conv_and_integral_sq(params):
        # Shared front end of every objective: the autoconvolution and the
        # squared integral of f, floored to avoid dividing by zero.
        f = get_f(params)
        integral = jnp.sum(f) * dx
        return compute_conv(f), jnp.maximum(integral, 1e-9) ** 2

    @jax.jit
    def objective_reg(params, temp, lam):
        """Smooth max + flatness regularization (Adam phase)."""
        conv, integral_sq = conv_and_integral_sq(params)
        # log-sum-exp soft maximum; approaches max(conv) as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        # Variance over the whole padded autoconvolution (all 2N entries,
        # including its near-zero tail).  Dividing by integral_sq**2 makes the
        # penalty invariant to rescaling f, like C1 itself.
        flatness_penalty = lam * jnp.var(conv) / integral_sq ** 2
        return smooth_max / integral_sq + flatness_penalty

    @jax.jit
    def objective_hard(params):
        """Exact C1 (hard max), evaluated in JAX precision."""
        conv, integral_sq = conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    def objective_smooth_only(params, temp):
        """Soft-max C1 without regularization (L-BFGS phase)."""
        conv, integral_sq = conv_and_integral_sq(params)
        return jax.nn.logsumexp(temp * conv) / temp / integral_sq

    value_and_grad_smooth = jax.jit(jax.value_and_grad(objective_smooth_only))

    # ---- Phase 1: Adam ----------------------------------------------------
    params = jnp.asarray(_initial_profile(N, seed))

    # In optax, decay_steps is the total schedule length *including* warmup,
    # so the full run length is passed.  Warmup is capped so that short runs
    # still get a cosine phase of positive length.
    warmup_steps = min(2000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    @jax.jit
    def adam_step(params, opt_state, temp, lam):
        # Compiled once: temp and lam are traced arguments, so their changing
        # annealing values do not trigger recompilation.
        grads = jax.grad(objective_reg)(params, temp, lam)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    best_c1 = float('inf')
    best_params = np.asarray(params, dtype=np.float64)
    for step in range(adam_steps):
        progress = step / adam_steps  # always in [0, 1) inside the loop
        temp = 100.0 + progress * 200.0
        # Flatness regularization fades out linearly, reaching 0 at halfway.
        lam = 0.1 * max(1.0 - progress * 2, 0.0)
        params, opt_state = adam_step(params, opt_state, temp, lam)
        if step % 20000 == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f" [{seed}] Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = np.asarray(params, dtype=np.float64)

    # ---- Phase 2: L-BFGS-B polishing (no regularization) ------------------
    params_np = np.array(best_params, dtype=np.float64)
    bounds = [(0, None)] * N  # non-negativity constraint
    for temp_lbfgs in (1000.0, 5000.0, 20000.0):
        def scipy_obj(p, temp=temp_lbfgs):
            # One fused forward/backward pass gives both value and gradient.
            val, g = value_and_grad_smooth(jnp.asarray(p), temp)
            return float(val), np.asarray(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            bounds=bounds,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        f_opt = np.maximum(params_np, 0.0)
        # Re-score with the direct (non-FFT) NumPy convolution.
        c1 = compute_c1_numpy(f_opt, N)
        if verbose:
            print(f" [{seed}] L-BFGS temp={temp_lbfgs:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            # Keep the exact float64 solution that c1 was computed from.
            best_params = params_np

    f_final = np.maximum(best_params, 0.0)
    return f_final, best_c1
def run(N=3000, n_seeds=10, adam_steps=80000):
    """Run ``optimize_run`` for several seeds and keep the lowest C1 found.

    Args:
        N: grid size passed to every run (default 3000).
        n_seeds: seeds ``0 .. n_seeds - 1`` are tried (default 10).
        adam_steps: Adam iterations per run (default 80000).

    Returns:
        ``(best_f, best_c1, best_c1, N)``.  ``best_f`` stays ``None`` if no run
        produced a C1 below infinity.  The C1 value appears twice.
        NOTE(review): presumably the caller expects a
        ``(profile, c1, reported_c1, n_points)`` tuple; confirm before
        changing the shape.
    """
    best_c1 = float('inf')
    best_f = None
    for seed in range(n_seeds):
        f, c1 = optimize_run(N, seed, adam_steps=adam_steps)
        print(f" Seed {seed}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            print(f" *** NEW GLOBAL BEST: C1={c1:.10f}")
    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, N