JustinTX's picture
Add files using upload-large-folder tool
2facf1f verified
"""
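Try a squared parameterization (f = params^2) alongside the exp parameterization,
each optimized against a high-temperature smooth max. Adam uses a linear warmup
followed by a single cosine-decay learning-rate schedule (no restarts).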
"""
import sys
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return the autoconvolution ratio C1 of a sampled function.

    The samples are treated as a step function with spacing
    ``dx = 0.5 / n_points``; negative samples are clipped to zero first.
    The result is

        max(f * f) / (sum(f) * dx) ** 2

    where ``f * f`` is the full discrete autoconvolution scaled by ``dx``.

    Args:
        f_values: 1-D array-like of sample values.
        n_points: number of grid points used to define ``dx``.

    Returns:
        The C1 ratio, or the sentinel ``1e10`` when the squared integral of
        the clipped samples is below ``1e-12`` (this includes empty input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(np.asarray(f_values), 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check the degenerate case before convolving: it skips the O(n^2)
    # convolution and avoids np.convolve's ValueError on empty input.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
def run():
    """Minimise the autoconvolution ratio C1 over non-negative step functions.

    f is represented by N = 4000 samples with spacing dx = 0.5 / N, and

        C1(f) = max(f * f) / (sum(f) * dx) ** 2,

    where f * f is the dx-scaled linear autoconvolution. Two parameterisations
    that keep f non-negative are tried:

    * exp:     f = exp(clip(params, -8, 4)); Adam at smooth-max temperature
               1000, then L-BFGS-B at temperatures 5000 and 20000
               (seeds 100, 15, 8, 50, 150);
    * squared: f = params ** 2; Adam at temperature 300, then L-BFGS-B at
               temperatures 1000 and 5000 (seeds 100, 15, 8).

    Each run finishes with three L-BFGS-B restarts on the exact (hard-max)
    ratio. Every final f is re-scored in float64 with ``compute_c1_numpy``,
    and one line per run is written to stdout (`` ***`` marks a new best).

    Returns:
        tuple: ``(best_f, best_c1, best_c1, N)`` -- the best sample vector
        found (length N), its C1 value (repeated twice), and N.
    """
    N = 4000
    dx = 0.5 / N
    adam_steps = 150000
    warmup_steps = 2000
    track_window = 1000  # hard C1 is only checked over the final Adam steps

    def autoconv(f):
        """Return the dx-scaled linear autoconvolution of f (length 2N)."""
        # Zero-padding to 2N (>= 2N - 1) turns the circular FFT product into
        # a linear convolution without wrap-around.
        padded = jnp.zeros(2 * N)
        padded = padded.at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    def make_objectives(to_f):
        """Build jitted (hard C1, smooth value+grad, hard value+grad) for to_f."""

        def obj_smooth(params, temp):
            f = to_f(params)
            integral_sq = (jnp.sum(f) * dx) ** 2
            # logsumexp(t * x) / t is a smooth upper bound on max(x) that
            # tightens as the temperature t grows.
            smooth_max = jax.nn.logsumexp(temp * autoconv(f)) / temp
            return smooth_max / integral_sq

        def obj_hard(params):
            f = to_f(params)
            integral_sq = (jnp.sum(f) * dx) ** 2
            return jnp.max(autoconv(f)) / integral_sq

        return (
            jax.jit(obj_hard),
            jax.jit(jax.value_and_grad(obj_smooth)),
            jax.jit(jax.value_and_grad(obj_hard)),
        )

    def adam_phase(params, obj_hard, vg_smooth, temp):
        """Run Adam on the smooth objective at a fixed temperature.

        Returns the iterate with the lowest hard C1 seen during the last
        ``track_window`` steps.
        """
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
            # In optax, decay_steps is the TOTAL schedule length (warmup
            # included). Passing adam_steps spreads the cosine decay over the
            # whole run instead of parking the LR at end_value at the end.
            decay_steps=adam_steps, end_value=1e-7,
        )
        optimizer = optax.adam(learning_rate=lr_schedule)

        # One compiled update per step, rather than rebuilding the gradient
        # transform and dispatching optimizer ops eagerly every iteration.
        @jax.jit
        def step(p, state):
            _, grads = vg_smooth(p, temp)
            updates, state = optimizer.update(grads, state, p)
            return optax.apply_updates(p, updates), state

        opt_state = optimizer.init(params)
        best_c1 = float('inf')
        best_params = params
        for i in range(adam_steps):
            params, opt_state = step(params, opt_state)
            if i >= adam_steps - track_window:
                hc = float(obj_hard(params))
                if hc < best_c1:
                    best_c1, best_params = hc, params
        return best_params

    def as_scipy(vg, *args):
        """Wrap a jitted value-and-grad function as a scipy ``jac=True`` objective."""
        def fun(p):
            val, g = vg(jnp.array(p), *args)
            return float(val), np.array(g, dtype=np.float64)
        return fun

    def lbfgs_polish(params_np, vg_smooth, vg_hard, temps, maxcor=None):
        """L-BFGS-B on the smooth objective at each temperature in ``temps``,
        then three restarts on the hard objective; return the final x.

        ``maxcor=None`` leaves scipy's default history size in place.
        """
        extra = {} if maxcor is None else {'maxcor': maxcor}
        for temp in temps:
            result = scipy_minimize(
                as_scipy(vg_smooth, temp), params_np, method='L-BFGS-B', jac=True,
                options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, **extra},
            )
            params_np = result.x
        for _ in range(3):
            result = scipy_minimize(
                as_scipy(vg_hard), params_np, method='L-BFGS-B', jac=True,
                options={'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, **extra},
            )
            params_np = result.x
        return params_np

    best_c1_overall = float('inf')
    best_f_overall = None

    def report(label, seed, f_final):
        """Score f_final, print one progress line, and keep the overall best."""
        nonlocal best_c1_overall, best_f_overall
        c1_final = compute_c1_numpy(f_final, N)
        sys.stdout.write(f"{label} seed={seed:4d}: C1={c1_final:.10f}")
        if c1_final < best_c1_overall:
            best_c1_overall = c1_final
            best_f_overall = f_final
            sys.stdout.write(" ***")
        sys.stdout.write("\n")
        sys.stdout.flush()

    # Strategy 1: exp parameterisation with a high Adam temperature (1000)
    # to reduce smooth-max bias.
    exp_hard, exp_vg_smooth, exp_vg_hard = make_objectives(
        lambda p: jnp.exp(jnp.clip(p, -8, 4)))
    for seed in [100, 15, 8, 50, 150]:
        np.random.seed(seed)
        init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N)
        params = jnp.array(np.log(np.maximum(init_f, 1e-6)))
        params = adam_phase(params, exp_hard, exp_vg_smooth, 1000.0)
        params_np = lbfgs_polish(
            np.array(params, dtype=np.float64), exp_vg_smooth, exp_vg_hard,
            temps=[5000.0, 20000.0], maxcor=100,
        )
        report("Exp temp=1000", seed, np.exp(np.clip(params_np, -8, 4)))

    # Strategy 2: squared parameterisation (initial params strictly positive).
    sq_hard, sq_vg_smooth, sq_vg_hard = make_objectives(lambda p: p ** 2)
    for seed in [100, 15, 8]:
        np.random.seed(seed)
        params = jnp.array(
            np.sqrt(np.ones(N) * 0.5 + 0.02 * np.abs(np.random.randn(N))))
        params = adam_phase(params, sq_hard, sq_vg_smooth, 300.0)
        params_np = lbfgs_polish(
            np.array(params, dtype=np.float64), sq_vg_smooth, sq_vg_hard,
            temps=[1000.0, 5000.0],
        )
        # Square in float64 NumPy (as the exp branch does) so both strategies
        # are scored by compute_c1_numpy at the same precision.
        report("Sq temp=300", seed, params_np ** 2)

    sys.stdout.write(f"\nFinal C1: {best_c1_overall:.10f}\n")
    sys.stdout.flush()
    return best_f_overall, best_c1_overall, best_c1_overall, N