File size: 4,299 Bytes
2facf1f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return max(f*f) / (integral of f)^2 for a sampled function, in NumPy.

    ``f_values`` are samples on a grid with spacing ``0.5 / n_points``.
    Negative samples are clamped to zero before evaluation.

    Args:
        f_values: 1-D array-like of samples.
        n_points: Number of grid points; sets the spacing ``dx = 0.5 / n_points``.

    Returns:
        ``max(autoconv) / integral**2``. Both the full discrete
        autoconvolution and the integral are Riemann sums scaled by ``dx``.
        Returns ``1e10`` when the squared integral is below ``1e-12``
        (empty, all-nonpositive, or numerically zero input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check for the degenerate case before convolving. This skips an O(n^2)
    # convolution whose result would be thrown away, and it avoids the
    # ValueError np.convolve raises on empty input.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
def optimize_single(N, seed, adam_steps=80000, verbose=True):
    """Search for a positive f on N grid points that minimizes max(f*f) / (integral f)^2.

    f is parameterized as ``exp(clip(params, -10, 5))``.

    Phase 1 runs Adam on a log-sum-exp smoothed version of the ratio. The
    temperature ramps from 20 to 300, and the best iterate is tracked with the
    exact (hard-max) ratio every 10000 steps and at the last step.

    Phase 2 starts from that best iterate and runs L-BFGS-B successively at
    temperatures 500, 1000 and 5000. Each run continues from the previous
    result.

    Args:
        N: Number of grid points (grid spacing ``0.5 / N``).
        seed: Seed for the random perturbation of the initial guess.
        adam_steps: Number of Adam iterations in phase 1.
        verbose: Print progress when True.

    Returns:
        ``(f_final, c1_final)``: the best samples found (NumPy array) and
        their ratio as evaluated by ``compute_c1_numpy``.
    """
    dx = 0.5 / N

    def _autoconv(params):
        # Map params to strictly positive f and return (conv, (sum(f)*dx)^2).
        # The clip keeps exp() finite; it has zero gradient outside [-10, 5].
        f = jnp.exp(jnp.clip(params, -10, 5))
        # rfft with n=2N zero-pads f, so the spectral product gives the linear
        # (not circular) autoconvolution.
        fft_f = jnp.fft.rfft(f, n=2 * N)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def objective_smooth(params, temp):
        conv, integral_sq = _autoconv(params)
        # logsumexp(t*x)/t is an upper bound on max(x) and tightens as t grows.
        # NOTE(review): conv is not normalized by integral_sq, so the
        # smoothing strength depends on the overall scale of f.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def objective_hard(params):
        conv, integral_sq = _autoconv(params)
        return jnp.max(conv) / integral_sq

    value_and_grad_smooth = jax.jit(jax.value_and_grad(objective_smooth))

    # Initialize: a broad bump near 0.5 with small noise. A private
    # RandomState yields exactly the stream np.random.seed(seed) followed by
    # np.random.randn(N) would, without mutating NumPy's global RNG.
    rng = np.random.RandomState(seed)
    init = 0.5 + 0.05 * rng.randn(N)
    params = jnp.array(np.log(np.maximum(init, 1e-6)))

    # Phase 1: Adam with increasing temperature.
    # optax's decay_steps is the TOTAL schedule length including warmup, so
    # passing adam_steps makes the cosine decay end at the final Adam step.
    # Warmup is capped for short runs so the cosine segment stays positive.
    warmup_steps = min(2000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.008, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    @jax.jit
    def adam_step(params, opt_state, temp):
        # One compiled Adam update. This avoids per-step Python dispatch of
        # the gradient transform and the optax update.
        grads = jax.grad(objective_smooth)(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        # Temperature annealing: linear 20 -> 300 over the first 70% of
        # steps, then held at 300.
        progress = min(step / (adam_steps * 0.7), 1.0)
        temp = 20.0 + progress * 280.0
        params, opt_state = adam_step(params, opt_state, temp)
        if step % 10000 == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f" Step {step:6d} | C1={hard_c1:.8f} | temp={temp:.0f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params

    # Phase 2: L-BFGS-B polishing with very high temperature
    if verbose:
        print(f" Phase 2: L-BFGS polishing from C1={best_c1:.8f}")
    params_np = np.array(best_params)
    for temp in [500.0, 1000.0, 5000.0]:
        # NOTE(review): JAX defaults to float32, so the value and gradient
        # L-BFGS-B sees have float32 precision. ftol=1e-15 is below what this
        # objective can resolve unless jax_enable_x64 is set.
        def scipy_obj(p, temp=temp):
            val, g = value_and_grad_smooth(jnp.array(p), temp)
            return float(val), np.array(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 3000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        f_opt = np.exp(np.clip(params_np, -10, 5))
        c1 = compute_c1_numpy(f_opt, N)
        if verbose:
            print(f" temp={temp:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            # Keep the float64 params whose C1 was just measured, so the
            # returned f is exactly that one (no float32 round-trip).
            best_params = params_np

    f_final = np.exp(np.clip(np.array(best_params), -10, 5))
    c1_final = compute_c1_numpy(f_final, N)
    return f_final, c1_final
def run(configs=None):
    """Run optimize_single over a sweep of configurations and keep the best result.

    Args:
        configs: Optional iterable of ``(N, seed, adam_steps)`` tuples.
            Defaults to ``[(2000, 0, 100000), (2000, 1, 100000),
            (3000, 0, 80000)]``.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. The best C1 value occupies
        both the second and third slots. NOTE(review): confirm whether the
        caller expects a different value in the third slot. ``best_f`` and
        ``best_n`` are None (and C1 is inf) when ``configs`` is empty.
    """
    if configs is None:
        configs = [
            (2000, 0, 100000),
            (2000, 1, 100000),
            (3000, 0, 80000),
        ]
    best_c1 = float('inf')
    best_f = None
    best_n = None
    for N, seed, steps in configs:
        print(f"\n=== N={N}, seed={seed}, steps={steps} ===")
        f, c1 = optimize_single(N, seed, adam_steps=steps)
        print(f" Result: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_f = f
            best_n = N
            print(f" *** NEW GLOBAL BEST: C1={c1:.10f}")
    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n
|