# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
"""
Adam with HARD max (no smoothing bias).
JAX provides subgradient for jnp.max.
"""
import sys
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return the autoconvolution ratio C1 of a sampled function using NumPy.

    Negative samples are clipped to zero. Then
    ``C1 = max(conv(f, f) * dx) / (sum(f) * dx) ** 2`` with
    ``dx = 0.5 / n_points``, where ``conv`` is the full linear convolution.

    Args:
        f_values: 1-D array-like of samples; entries below 0 are treated as 0.
        n_points: number of grid points used to derive the spacing ``dx``.

    Returns:
        The ratio as a float, or the sentinel ``1e10`` when the squared
        integral is below ``1e-12`` (this includes empty and all-non-positive
        input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check the degenerate case before the O(N^2) convolution. It is cheaper,
    # and np.convolve raises ValueError on empty input.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return float(np.max(autoconv) / integral_sq)
def _make_hard_objective(n_points):
    """Build the jitted hard-max C1 objective and its jitted value-and-gradient.

    ``obj_hard(params)`` maps log-parameters of length ``n_points`` to
    ``max(conv(f, f) * dx) / (sum(f) * dx) ** 2``, where
    ``f = exp(clip(params, -8, 4))`` and ``dx = 0.5 / n_points``.

    Returns:
        ``(obj_hard, value_and_grad)``. The second element is
        ``jax.jit(jax.value_and_grad(obj_hard))``.
    """
    dx = 0.5 / n_points

    @jax.jit
    def obj_hard(params):
        # Clipping the log-parameters bounds f to [exp(-8), exp(4)].
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to length 2N so the FFT's circular convolution equals the
        # linear autoconvolution (length 2N - 1).
        padded = jnp.zeros(2 * n_points).at[:n_points].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * n_points) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        # Hard max (no smoothing); JAX supplies a subgradient for jnp.max.
        return jnp.max(conv) / integral_sq

    return obj_hard, jax.jit(jax.value_and_grad(obj_hard))


def _run_adam(params, opt_state, adam_step, obj_hard, adam_steps, seed):
    """Run ``adam_steps`` jitted Adam steps; return the lowest-C1 params seen in the last 1000 steps."""
    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        # ``loss`` is the objective at the params *before* this update.
        params, opt_state, loss = adam_step(params, opt_state)
        if step % 50000 == 0:
            hc = float(loss)
            sys.stdout.write(f" [{seed:4d}] Step {step:7d} | C1={hc:.10f}\n")
            sys.stdout.flush()
        if step >= adam_steps - 1000:
            hc = float(obj_hard(params))
            if hc < best_c1:
                best_c1 = hc
                best_params = params
    return best_params


def _run_lbfgs(params_np, value_and_grad, rounds):
    """Polish ``params_np`` with ``rounds`` consecutive L-BFGS-B runs, each restarted from the previous result."""
    def scipy_obj_hard(p):
        # NOTE(review): JAX likely evaluates this in float32 (unless x64 is
        # enabled), so the very tight ftol/gtol below may be under the
        # objective's precision; confirm if convergence matters.
        val, g = value_and_grad(jnp.array(p))
        return float(val), np.array(g, dtype=np.float64)

    for _ in range(rounds):
        result = scipy_minimize(
            scipy_obj_hard, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100},
        )
        params_np = result.x
    return params_np


def run(*, n_points=4000,
        seeds=(100, 15, 8, 2, 17, 42, 7, 99, 50, 150, 250, 1000),
        adam_steps=150000, warmup_steps=3000, lbfgs_rounds=5,
        save_path='/workspace/best_f.npy'):
    """Minimize C1 from several random starts and return the best function found.

    For each seed, ``f`` is initialised near 0.5 with small Gaussian noise.
    Phase 1 runs Adam on the hard-max objective in log-parameter space, with
    per-coordinate gradient clipping to [-1, 1] and a warmup + cosine-decay
    learning rate. Phase 2 polishes the result with L-BFGS-B. The final C1 is
    re-evaluated in NumPy via ``compute_c1_numpy``.

    Progress is written to stdout. Every new overall best ``f`` is saved with
    ``np.save`` to ``save_path``; saving is skipped when ``save_path`` is None.
    All keyword arguments default to the original hard-coded settings.

    Returns:
        ``(best_f, best_c1, best_c1, n_points)``. ``best_f`` stays None only if
        no seed's C1 compares below ``inf`` (e.g. all NaN, or no seeds).

    Raises:
        ValueError: if ``adam_steps <= warmup_steps`` (the cosine phase would be
        empty).
    """
    if adam_steps <= warmup_steps:
        raise ValueError("adam_steps must exceed warmup_steps")

    obj_hard, value_and_grad = _make_hard_objective(n_points)

    # optax's ``decay_steps`` is the TOTAL schedule length including warmup.
    # Passing ``adam_steps`` makes the cosine reach ``end_value`` on the final
    # step. The previous ``adam_steps - 3000`` left the last 3000 steps at the
    # floor LR.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
        decay_steps=adam_steps, end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    # One fused, jitted step, compiled once and reused for every seed.
    @jax.jit
    def adam_step(params, opt_state):
        loss, grads = value_and_grad(params)
        # Clip gradients for stability.
        grads = jnp.clip(grads, -1.0, 1.0)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    best_c1_overall = float('inf')
    best_f_overall = None
    for seed in seeds:
        np.random.seed(seed)
        init_f = np.ones(n_points) * 0.5 + 0.02 * np.random.randn(n_points)
        params = jnp.array(np.log(np.maximum(init_f, 1e-6)))

        # Phase 1: Adam with the hard max.
        best_params = _run_adam(params, optimizer.init(params), adam_step,
                                obj_hard, adam_steps, seed)
        # Phase 2: L-BFGS-B on the same hard objective.
        params_np = _run_lbfgs(np.array(best_params, dtype=np.float64),
                               value_and_grad, lbfgs_rounds)

        f_final = np.exp(np.clip(params_np, -8, 4))
        c1_final = compute_c1_numpy(f_final, n_points)
        sys.stdout.write(f"Seed {seed:4d}: C1={c1_final:.10f}")
        sys.stdout.flush()
        if c1_final < best_c1_overall:
            best_c1_overall = c1_final
            best_f_overall = f_final
            sys.stdout.write(" ***")
            if save_path is not None:
                np.save(save_path, f_final)
        sys.stdout.write("\n")
        sys.stdout.flush()
    sys.stdout.write(f"\nFinal C1: {best_c1_overall:.10f}\n")
    sys.stdout.flush()
    # NOTE(review): the best C1 appears twice in the tuple, as in the original;
    # confirm the caller's expected layout before changing it.
    return best_f_overall, best_c1_overall, best_c1_overall, n_points