# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
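Extended Adam training (200k steps by default; run() uses 150k) followed by an
aggressive L-BFGS polish. Multiple seeds, focused on N=3000, with a final
L-BFGS refinement of the best result upsampled to N=5000.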
"""
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Evaluate the discrete C1 ratio of a step function in NumPy (float64).

    Negative entries of ``f_values`` are clamped to zero. Every cell has width
    ``0.5 / n_points``. The ratio is the peak of the width-scaled full
    autoconvolution divided by the squared integral of the function.

    Returns the sentinel ``1e10`` when the squared integral is below 1e-12.
    """
    width = 0.5 / n_points
    g = np.maximum(f_values, 0.0)
    # Full-mode convolution is computed before the guard, so an empty input
    # still raises from np.convolve rather than hitting the sentinel.
    self_conv = width * np.convolve(g, g, mode='full')
    mass_sq = (width * np.sum(g)) ** 2
    if mass_sq < 1e-12:
        # Degenerate (near-zero) function: report a huge ratio instead.
        return 1e10
    peak = np.max(self_conv)
    return peak / mass_sq
def optimize_single(N, seed, adam_steps=200000):
    """Minimize the discrete C1 ratio for one random initialization.

    The function is parameterized as ``f = exp(clip(params, -8, 4))``, so it is
    strictly positive. The optimization runs in two phases:

    1. Adam with a warmup + cosine learning-rate schedule, applied to a
       log-sum-exp smoothed objective at temperature 300. The exact (hard)
       ratio is checked every 50000 steps and on the last step, and the best
       checkpoint is kept.
    2. L-BFGS-B polishing at increasing temperatures. Each stage starts from
       the previous stage's result. The parameters with the lowest exact C1
       (from ``compute_c1_numpy``) seen so far are retained, so a later stage
       can never make the returned result worse.

    Args:
        N: number of cells; each cell has width ``0.5 / N``.
        seed: seed for NumPy's global RNG, used for the random initialization.
        adam_steps: number of Adam iterations (0 skips the Adam phase).

    Returns:
        ``(f_final, c1_final, params_np)``: the best function values (length
        ``N``), their C1 from ``compute_c1_numpy``, and the corresponding
        log-parameters (float64 NumPy array).

    NOTE: ``jnp.asarray`` on float64 input produces float32 unless JAX's x64
    mode is enabled, so the very tight L-BFGS tolerances are effectively
    limited by float32 precision.
    """
    dx = 0.5 / N

    def _autoconv(params):
        # exp keeps f positive; the clip bounds its dynamic range (note the
        # gradient is zero for parameters outside [-8, 4]).
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the FFT product is a linear (not circular)
        # autoconvolution: its length 2N-1 fits without wrap-around.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = _autoconv(params)
        # logsumexp(t*x)/t upper-bounds max(x) by at most log(2N)/t, so a
        # larger temperature gives a tighter but less smooth surrogate.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _autoconv(params)
        return jnp.max(conv) / integral_sq

    # One fused, compiled call for both value and gradient.
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_smooth))

    # Initialize (NumPy global RNG, kept for reproducibility of past runs).
    np.random.seed(seed)
    init_f = np.abs(np.random.randn(N)) * 0.3 + 0.2
    params = jnp.array(np.log(np.maximum(init_f, 1e-6)))

    # optax's `decay_steps` is the TOTAL schedule length (warmup included), so
    # pass adam_steps itself; the cosine phase then ends exactly on the last
    # step. The warmup is capped so short runs still have a cosine phase.
    warmup_steps = min(3000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.008, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, 1), end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    adam_temp = 300.0

    @jax.jit
    def adam_step(params, opt_state):
        # Compile the whole update instead of dispatching op-by-op per step.
        _, grads = value_and_grad_fn(params, adam_temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    opt_state = optimizer.init(params)
    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        params, opt_state = adam_step(params, opt_state)
        if step % 50000 == 0 or step == adam_steps - 1:
            hc = float(obj_hard(params))
            print(f" [{seed}] Step {step:7d} | C1={hc:.8f}")
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    # Aggressive L-BFGS polish with increasing temperature (annealing toward
    # the hard max); keep the best exact C1 seen across all stages.
    params_np = np.asarray(best_params, dtype=np.float64)
    best_np = params_np
    best_exact = compute_c1_numpy(np.exp(np.clip(params_np, -8, 4)), N)
    for temp_lbfgs in (500, 1000, 2000, 5000, 10000, 50000):
        def scipy_obj(p, t=float(temp_lbfgs)):
            val, g = value_and_grad_fn(jnp.asarray(p), t)
            return float(val), np.asarray(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 10000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 50},
        )
        params_np = result.x
        stage_c1 = compute_c1_numpy(np.exp(np.clip(params_np, -8, 4)), N)
        if stage_c1 < best_exact:
            best_exact = stage_c1
            best_np = params_np

    f_final = np.exp(np.clip(best_np, -8, 4))
    c1_final = compute_c1_numpy(f_final, N)
    print(f" [{seed}] After L-BFGS: C1={c1_final:.10f}")
    return f_final, c1_final, best_np
def run(n_points=3000, n_seeds=5, adam_steps=150000, upsample_points=5000):
    """Multi-seed search at ``n_points``, then refine the winner upsampled.

    Each seed runs ``optimize_single``. The best resulting function is then
    linearly interpolated onto ``upsample_points`` cells and polished with
    L-BFGS-B at increasing temperatures. That refinement is kept only if it
    lowers C1 (as measured by ``compute_c1_numpy``). Pass
    ``upsample_points=0`` (or ``None``) to skip the refinement.

    Args:
        n_points: grid size for the multi-seed search.
        n_seeds: number of seeds (``0 .. n_seeds-1``) to try.
        adam_steps: Adam iterations per seed.
        upsample_points: grid size for the refinement stage.

    Returns:
        ``(best_f, best_c1, best_c1, len(best_f))``. The C1 value appears
        twice. NOTE(review): the duplicated slot is preserved for callers;
        confirm what the third element was meant to be.

    Raises:
        ValueError: if no seed produced a finite C1 (e.g. ``n_seeds < 1``).
    """
    N = n_points
    best_c1 = float('inf')
    best_f = None
    for seed in range(n_seeds):
        f, c1, _ = optimize_single(N, seed, adam_steps=adam_steps)
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            print(f" *** GLOBAL BEST: C1={c1:.10f} (seed={seed})")
    if best_f is None:
        raise ValueError("no seed produced a finite C1; check n_seeds >= 1")

    if upsample_points:
        # Also try a finer grid, starting from the best seed's solution.
        N2 = upsample_points
        dx2 = 0.5 / N2
        print(f"\nUpsampling to N={N2}...")
        f_up = np.interp(np.linspace(0, 1, N2), np.linspace(0, 1, N), best_f)

        @jax.jit
        def obj_smooth_up(params, temp):
            f = jnp.exp(jnp.clip(params, -8, 4))
            # Zero-pad to 2*N2 so the FFT product is a linear autoconvolution.
            padded = jnp.zeros(2 * N2).at[:N2].set(f)
            fft_f = jnp.fft.rfft(padded)
            conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N2) * dx2
            integral_sq = (jnp.sum(f) * dx2) ** 2
            # Smooth (log-sum-exp) surrogate for max(conv).
            smooth_max = jax.nn.logsumexp(temp * conv) / temp
            return smooth_max / integral_sq

        # One fused, compiled call for both value and gradient.
        value_and_grad_up = jax.jit(jax.value_and_grad(obj_smooth_up))
        params_np = np.log(np.maximum(f_up, 1e-6))
        for temp in (1000, 5000, 20000, 100000):
            def scipy_obj(p, t=float(temp)):
                val, g = value_and_grad_up(jnp.asarray(p), t)
                return float(val), np.asarray(g, dtype=np.float64)

            result = scipy_minimize(
                scipy_obj, params_np, method='L-BFGS-B', jac=True,
                options={'maxiter': 10000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 50},
            )
            params_np = result.x
            f_up_opt = np.exp(np.clip(params_np, -8, 4))
            c1_up = compute_c1_numpy(f_up_opt, N2)
            print(f" temp={temp}: C1={c1_up:.10f}")
            if c1_up < best_c1:
                best_c1 = c1_up
                best_f = f_up_opt

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, len(best_f)