JustinTX's picture
Add files using upload-large-folder tool
2facf1f verified
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points=None):
    """Return the autoconvolution ratio max(f * f) / (integral of f)^2.

    ``f_values`` are the heights of a step function on ``n_points`` uniform
    cells of width ``dx = 0.5 / n_points`` (i.e. an interval of length 0.5).
    Negative heights are clipped to zero before evaluation, and the
    computation is done in float64 regardless of the input dtype.

    Args:
        f_values: 1-D array-like of step heights.
        n_points: Number of grid cells; defaults to ``len(f_values)``.

    Returns:
        The ratio (a float), or ``1e10`` when the integral is numerically zero,
        which includes empty and all-non-positive input.
    """
    f_nn = np.maximum(np.asarray(f_values, dtype=np.float64), 0.0)
    # Empty input: np.convolve would raise, and the integral is zero anyway.
    if f_nn.size == 0:
        return 1e10
    if n_points is None:
        n_points = f_nn.size
    dx = 0.5 / n_points
    # Check the denominator before doing the O(n^2) convolution.
    integral_sq = (np.sum(f_nn) * dx) ** 2
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
# Default cascade schedule: (grid size N, Adam steps, peak learning rate).
_DEFAULT_STAGES = (
    (1000, 60000, 0.01),
    (2000, 60000, 0.005),
    (4000, 80000, 0.003),
)
_WARMUP_STEPS = 2000  # LR warmup length (capped at 10% of a stage's Adam steps)
_EVAL_EVERY = 10000  # Adam steps between hard-C1 checkpoints
_ADAM_TEMP = 200.0  # fixed logsumexp temperature used during Adam
_LBFGS_TEMPS = (500.0, 2000.0, 10000.0)  # rising temperatures for L-BFGS polishing


def _softplus_np(x):
    """Numerically stable NumPy softplus, log(1 + exp(x)).

    ``np.logaddexp(0, x)`` avoids the overflow of ``np.log1p(np.exp(x))``
    for large ``x``.
    """
    return np.logaddexp(0.0, x)


def _inv_softplus_np(y):
    """Inverse softplus log(exp(y) - 1) for strictly positive ``y``.

    Written as ``y + log(-expm1(-y))`` so large ``y`` cannot overflow.
    """
    y = np.asarray(y, dtype=np.float64)
    return y + np.log(-np.expm1(-y))


def _make_objectives(n, dx):
    """Build jitted (smooth, hard) C1 objectives over softplus pre-activations of size ``n``."""

    def _conv_and_integral_sq(params):
        f = jax.nn.softplus(params)  # smooth non-negativity, allows near-zeros
        # Zero-pad to 2n so the circular FFT product equals the linear autoconvolution.
        padded = jnp.zeros(2 * n, dtype=f.dtype).at[:n].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * n) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def objective_smooth(params, temp):
        conv, integral_sq = _conv_and_integral_sq(params)
        # logsumexp(temp * conv) / temp is a differentiable upper bound on max(conv)
        # that tightens as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def objective_hard(params):
        conv, integral_sq = _conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    return objective_smooth, objective_hard


def _initial_params(n, prev_params, seed):
    """Return softplus pre-activations for a stage of size ``n``.

    The first stage (``prev_params is None``) starts from 0.5 plus small
    seeded noise. Later stages linearly interpolate the previous stage's
    step heights onto the new grid.
    """
    if prev_params is None:
        # Local RandomState: same draws as np.random.seed(seed) without global side effects.
        rng = np.random.RandomState(seed)
        init_f = 0.5 + 0.02 * rng.randn(n)
    else:
        old_f = _softplus_np(prev_params)
        init_f = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(old_f)), old_f)
    return jnp.asarray(_inv_softplus_np(np.maximum(init_f, 1e-4)))


def _run_adam(params, objective_smooth, objective_hard, adam_steps, lr_peak, verbose):
    """Run Adam on the smoothed objective; return the best checkpointed params."""
    warmup = min(_WARMUP_STEPS, adam_steps // 10)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=lr_peak,
        warmup_steps=warmup,
        # optax's decay_steps is the TOTAL schedule length (warmup included), so
        # the cosine phase spans the remaining adam_steps - warmup steps.
        decay_steps=adam_steps,
        end_value=lr_peak * 1e-4,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    # Jit the whole step so the gradient transform and optimizer update are
    # compiled once instead of being re-applied eagerly every iteration.
    @jax.jit
    def train_step(p, state, temp):
        grads = jax.grad(objective_smooth)(p, temp)
        updates, state = optimizer.update(grads, state, p)
        return optax.apply_updates(p, updates), state

    opt_state = optimizer.init(params)
    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        params, opt_state = train_step(params, opt_state, _ADAM_TEMP)
        if step % _EVAL_EVERY == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f" Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params
    return best_params


def _polish_lbfgs(params_np, objective_smooth, n, best_c1, verbose):
    """Refine with L-BFGS-B at rising temperatures; return (best_params, best_c1).

    Candidates are scored with ``compute_c1_numpy`` (float64, hard max) and
    kept only if they beat ``best_c1``; ``params_np`` itself is the fallback.
    """
    # One jitted pass yields both value and gradient.
    smooth_value_and_grad = jax.jit(jax.value_and_grad(objective_smooth))

    def scipy_obj(p, temp):
        val, g = smooth_value_and_grad(jnp.asarray(p), temp)
        return float(val), np.asarray(g, dtype=np.float64)

    best_params = params_np
    for temp in _LBFGS_TEMPS:
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp,), method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        if not np.all(np.isfinite(params_np)):
            # A non-finite iterate would poison every later temperature.
            break
        c1 = compute_c1_numpy(_softplus_np(params_np), n)
        if verbose:
            print(f" temp={temp:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_params = params_np
    return best_params, best_c1


def optimize_cascade(verbose=True, *, stages=None, seed=42):
    """Optimize with cascading resolution: start coarse, upsample, refine.

    Each stage ``(N, adam_steps, lr_peak)`` parameterizes the step function as
    ``softplus(params)`` on N cells of width ``0.5 / N``. Adam minimizes a
    logsumexp-smoothed C1 at a fixed temperature, and L-BFGS-B then polishes
    it at rising temperatures. The stage's best parameters are linearly
    interpolated to seed the next stage.

    Args:
        verbose: Print progress.
        stages: Optional sequence of ``(N, adam_steps, lr_peak)`` tuples;
            defaults to ``_DEFAULT_STAGES``.
        seed: Seed for the random perturbation of the first stage's start.

    Returns:
        ``(f_final, c1_final, N)``: the final stage's best non-negative step
        heights, their C1 from ``compute_c1_numpy``, and that stage's N.

    Raises:
        ValueError: If ``stages`` is empty.
    """
    if stages is None:
        stages = _DEFAULT_STAGES
    if not stages:
        raise ValueError("stages must contain at least one (N, adam_steps, lr_peak) entry")

    best_params_np = None
    n = None
    for stage, (n, adam_steps, lr_peak) in enumerate(stages):
        dx = 0.5 / n
        if verbose:
            print(f"\n=== Stage {stage}: N={n}, steps={adam_steps} ===")
        objective_smooth, objective_hard = _make_objectives(n, dx)
        params = _initial_params(n, best_params_np, seed)
        adam_best = _run_adam(
            params, objective_smooth, objective_hard, adam_steps, lr_peak, verbose
        )
        start = np.asarray(adam_best, dtype=np.float64)
        # Re-score in float64 so the comparison with L-BFGS candidates
        # (scored by compute_c1_numpy) uses the same metric.
        start_c1 = compute_c1_numpy(_softplus_np(start), n)
        if verbose:
            print(f" L-BFGS polishing from C1={start_c1:.8f}")
        best_params_np, best_c1 = _polish_lbfgs(
            start, objective_smooth, n, start_c1, verbose
        )
        if verbose:
            print(f" Stage {stage} best: C1={best_c1:.10f}")

    f_final = _softplus_np(best_params_np)
    c1_final = compute_c1_numpy(f_final, n)
    return f_final, c1_final, n
def run():
    """Run the full resolution cascade and print the final C1.

    Returns:
        ``(f, c1, c1, N)`` — the optimized step heights, the C1 value (which
        appears twice; presumably the caller reads one as the score and one as
        the loss — confirm against the harness), and the final grid size.
    """
    f_values, c1_value, n_points = optimize_cascade(verbose=True)
    print(f"\nFinal C1: {c1_value:.10f}")
    return f_values, c1_value, c1_value, n_points