# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
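"""
L-BFGS-B minimization of the autoconvolution ratio C1 using JAX analytical gradients.
Many restarts (structured and random initializations), each annealed through
increasing log-sum-exp temperatures; gradient-based L-BFGS-B keeps convergence fast.
"""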
import functools

import numpy as np
from scipy.optimize import minimize
import jax
import jax.numpy as jnp
def compute_c1_numpy(f_values, n_points):
    """Return the discretized autoconvolution ratio C1 of a sampled function.

    ``f_values`` are samples on a grid of spacing ``dx = 0.5 / n_points``.
    Negative samples are clipped to 0. The result is
    ``max(conv(f, f) * dx) / (sum(f) * dx) ** 2``, i.e. the peak of the
    Riemann-sum autoconvolution over the squared Riemann-sum integral.

    Returns the sentinel ``1e10`` when the squared integral is below
    ``1e-12`` (including empty or all-nonpositive input), where the ratio
    is meaningless.
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Guard before convolving: np.convolve raises ValueError on empty input,
    # and the O(N^2) convolution is wasted work when returning the sentinel.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
@functools.lru_cache(maxsize=16)
def _value_and_grad_for(N):
    """Return a jitted ``(params, temp) -> (value, grad wrt params)`` for length N.

    ``temp`` is a traced argument, so a single compilation per N serves every
    temperature; the cache stops repeated make_objective calls from
    re-tracing and re-compiling the same computation.
    """
    dx = 0.5 / N

    def objective(params, temp):
        # Log-space parameterization keeps f strictly positive; clipping bounds
        # f to [e^-8, e^4], and the gradient is zero for params outside [-8, 4].
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the circular FFT product equals the linear
        # autoconvolution (length 2N - 1) without wrap-around.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        # logsumexp(t * c) / t is a smooth upper bound on max(c) that
        # tightens as the temperature t grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    return jax.jit(jax.value_and_grad(objective))


def make_objective(N, temp):
    """Build a SciPy ``fun(params) -> (value, grad)`` for ``minimize(..., jac=True)``.

    The objective is a log-sum-exp-smoothed version of the C1 ratio for a
    length-``N`` profile ``f = exp(clip(params, -8, 4))`` on grid spacing
    ``0.5 / N``. Larger ``temp`` makes the smooth max closer to the true max.

    NOTE(review): unless ``jax_enable_x64`` is enabled, JAX computes in
    float32, so value/grad carry roughly 1e-7 relative precision — much
    coarser than very tight L-BFGS-B ftol/gtol settings can exploit.
    """
    value_and_grad = _value_and_grad_for(N)
    temp = float(temp)

    def scipy_wrapper(params_np):
        # SciPy passes float64 NumPy arrays and expects a Python float plus a
        # float64 gradient array back.
        v, g = value_and_grad(jnp.asarray(params_np), temp)
        return float(v), np.array(g, dtype=np.float64)

    return scipy_wrapper
# Number of hand-picked (non-random) initial profiles produced by _make_inits.
_N_STRUCTURED_INITS = 13


def _make_inits(N, n_inits=30):
    """Return ``n_inits`` initial profiles of length ``N``, each floored at 0.01.

    Indices 0-12 are fixed shapes sampled on ``x = linspace(0, 1, N)``.
    Later indices are ``|randn(N)| * 0.3 + 0.2`` drawn from a ``RandomState``
    seeded with the index: the same values the legacy global generator gives
    after ``np.random.seed(index)``, without mutating global RNG state.
    """
    x = np.linspace(0, 1, N)
    structured = [
        np.ones(N),                                           # 0: constant
        0.5 + 0.5 * np.cos(2*np.pi*x),                        # 1: U-shape
        np.exp(-10*(x-0.5)**2) + 0.1,                         # 2: centered Gaussian
        np.exp(-5*(x-0.3)**2) + 0.05,                         # 3: left-shifted Gaussian
        np.exp(-5*(x-0.7)**2) + 0.05,                         # 4: right-shifted Gaussian
        0.2 + 0.8 * np.cos(np.pi*x)**2,                       # 5: peaks at both ends, 0.2 at x=0.5
        0.5 + 0.3*np.cos(2*np.pi*x) + 0.1*np.cos(4*np.pi*x),  # 6: two-term cosine
        np.maximum(1 - 4*np.abs(x-0.5), 0) + 0.1,             # 7: triangle
        0.3 + 0.7*x,                                          # 8: linear increasing
        0.3 + 0.7*(1-x),                                      # 9: linear decreasing
        1 + 0.5*np.cos(6*np.pi*x),                            # 10: wavy
        np.where((x > 0.15) & (x < 0.85), 1.0, 0.3),          # 11: wide box
        0.5 + 0.3*np.cos(2*np.pi*x) + 0.15*np.cos(4*np.pi*x) + 0.05*np.cos(6*np.pi*x),  # 12
    ]
    inits = []
    for idx in range(n_inits):
        if idx < _N_STRUCTURED_INITS:
            f = structured[idx]
        else:
            f = np.abs(np.random.RandomState(idx).randn(N)) * 0.3 + 0.2
        inits.append(np.maximum(f, 0.01))
    return inits


def run(*, n_values=(1000, 2000, 4000), n_inits=30,
        temperatures=(50.0, 200.0, 1000.0, 5000.0, 20000.0)):
    """Minimize C1 over several grid sizes and restarts; return the best profile.

    For each N in ``n_values``, every profile from ``_make_inits(N, n_inits)``
    is optimized in log-space (``f = exp(clip(params, -8, 4))``) with
    L-BFGS-B, once per temperature in ``temperatures``, each stage warm-started
    from the previous one. The final profile is scored with
    ``compute_c1_numpy`` and the lowest C1 seen is kept.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``: the best profile, its C1
        (which appears twice; the 4-tuple shape is kept for existing
        callers), and the N it was found at. ``best_f`` and ``best_n`` stay
        ``None`` if no run scored below ``inf``.
    """
    best_c1 = float('inf')
    best_f = None
    best_n = None
    for N in n_values:
        print(f"\n=== N={N} ===")
        # Build each temperature's objective once per N instead of once per
        # (init, temperature) pair, since each make_objective call creates
        # new jitted closures.
        objectives = [make_objective(N, temp) for temp in temperatures]
        for init_idx, init_f in enumerate(_make_inits(N, n_inits)):
            params = np.log(np.maximum(init_f, 1e-6))
            # Progressive temperature increase: the smooth max tightens
            # toward the true max at each stage.
            for obj_fn in objectives:
                result = minimize(
                    obj_fn, params, method='L-BFGS-B', jac=True,
                    options={'maxiter': 3000, 'ftol': 1e-15, 'gtol': 1e-12},
                )
                params = result.x
            f_opt = np.exp(np.clip(params, -8, 4))
            c1 = compute_c1_numpy(f_opt, N)
            if c1 < best_c1:
                best_c1 = c1
                best_f = f_opt
                best_n = N
                print(f" init={init_idx}: C1={c1:.10f} ***", flush=True)
            elif init_idx < _N_STRUCTURED_INITS:
                # Always report the hand-picked starts, even when not a new best.
                print(f" init={init_idx}: C1={c1:.10f}", flush=True)
    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n