# File size: 5,603 Bytes
# 2facf1f (revision id shown by the file viewer; presumably a commit hash)
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Evaluate the C1 ratio of a sampled profile with plain NumPy.

    Intended as a verification path that does not go through JAX.

    Negative samples are clipped to zero. The discrete autoconvolution of the
    clipped samples is scaled by the grid spacing ``0.5 / n_points`` and its
    peak is divided by the squared discrete integral of the clipped samples.

    Returns the penalty value ``1e10`` when that squared integral is smaller
    than ``1e-12``.
    """
    spacing = 0.5 / n_points
    clipped = np.maximum(f_values, 0.0)

    # The convolution is evaluated before the degeneracy guard so that the
    # call sequence (and any error it raises) is the same for every input.
    self_conv = spacing * np.convolve(clipped, clipped, mode='full')
    mass_sq = (clipped.sum() * spacing) ** 2

    if not mass_sq >= 1e-12:
        return 1e10
    return self_conv.max() / mass_sq
def make_objective_jax(N, dx):
    """Build the JAX objectives used to minimise C1 on an ``N``-sample grid.

    The profile is parameterised as ``f = exp(params)`` so it stays positive.
    Its autoconvolution is computed with a real FFT on a zero-padded buffer of
    length ``2 * N`` and scaled by ``dx``.

    Returns a tuple ``(objective, objective_smooth, grad_fn)``:

    * ``objective(params)`` - max of the autoconvolution divided by the
      squared discrete integral of ``f``.
    * ``objective_smooth(params, temp)`` - the same ratio with the max
      replaced by ``logsumexp(temp * conv) / temp``.
    * ``grad_fn(params, temp)`` - jitted gradient of ``objective_smooth``
      with respect to ``params``.
    """

    def _autoconv_and_mass_sq(params):
        # Shared front end of both objectives; it is traced inline by each
        # jitted caller below.
        f = jnp.exp(params)
        buffer = jnp.zeros(2 * N)
        buffer = buffer.at[:N].set(f)
        spectrum = jnp.fft.rfft(buffer)
        autoconv = jnp.fft.irfft(spectrum * spectrum, n=2 * N)
        autoconv = autoconv * dx
        mass = jnp.sum(f) * dx
        return autoconv, mass ** 2

    @jax.jit
    def objective(params):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        return jnp.max(autoconv) / mass_sq

    @jax.jit
    def objective_smooth(params, temp):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        # Differentiable stand-in for the max; larger temp tracks it closer.
        soft_peak = jax.nn.logsumexp(temp * autoconv) / temp
        return soft_peak / mass_sq

    grad_fn = jax.jit(jax.grad(objective_smooth))
    return objective, objective_smooth, grad_fn
def _initial_profile(seed, n_points):
    """Return the positive starting profile used for restart ``seed``.

    Seeds 0-3 pick fixed shapes on ``linspace(0, 1, n_points)`` (Gaussian
    bump, constant, raised cosine, two-level step); any other seed draws a
    random positive profile.
    """
    # Reseed the global NumPy RNG on every restart, not only for the random
    # branch, so the global RNG state after a run is the same as before.
    np.random.seed(seed)
    x = np.linspace(0, 1, n_points)
    if seed == 0:
        return np.exp(-10 * (x - 0.5) ** 2) + 0.1
    if seed == 1:
        return np.ones(n_points)
    if seed == 2:
        return 0.5 * (1 + np.cos(2 * np.pi * (x - 0.5))) + 0.1
    if seed == 3:
        # Step function: higher in the middle.
        return np.where((x > 0.2) & (x < 0.8), 1.5, 0.5)
    return np.abs(np.random.randn(n_points)) * 0.3 + 0.2


def _run_adam_phase(params_jax, objective, objective_smooth):
    """Phase 1: minimise the smooth objective with Adam.

    Returns ``(best_params, best_c1)`` where ``best_c1`` is the lowest value
    of the hard ``objective`` seen at a logging checkpoint or after the final
    step, and ``best_params`` are the parameters that produced it.
    """
    n_steps = 50000
    log_every = 5000

    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.01,
        warmup_steps=2000,
        decay_steps=48000,
        end_value=1e-5,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params_jax)

    # Build and compile the value-and-gradient transform once per run rather
    # than re-creating it on every one of the n_steps iterations.
    smooth_value_and_grad = jax.jit(jax.value_and_grad(objective_smooth))

    best_c1 = float('inf')
    best_params = params_jax
    for step in range(n_steps):
        # Soft-max temperature ramps linearly upward from 50, capped at 200.
        temp = min(50.0 + step * 150.0 / n_steps, 200.0)
        loss_val, grads = smooth_value_and_grad(params_jax, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params_jax)
        params_jax = optax.apply_updates(params_jax, updates)

        if step % log_every == 0:
            hard_c1 = float(objective(params_jax))
            print(f" Step {step:5d} | C1(smooth)={float(loss_val):.8f} | C1(hard)={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params_jax

    # The last iterate is not on a checkpoint boundary; score it as well.
    hard_c1 = float(objective(params_jax))
    if hard_c1 < best_c1:
        best_c1 = hard_c1
        best_params = params_jax
    return best_params, best_c1


def _run_lbfgs_phase(start_params, best_c1, n_points, objective_smooth, grad_fn):
    """Phase 2: refine with L-BFGS-B at increasing soft-max temperatures.

    Each temperature starts from the previous temperature's result. A result
    is accepted as the new best only if its hard C1, recomputed with
    ``compute_c1_numpy``, is lower than ``best_c1``.

    Returns ``(best_params, best_c1)`` with ``best_params`` as a NumPy array.
    """

    def smooth_value_and_grad_np(p, temp):
        p_jax = jnp.array(p)
        value = float(objective_smooth(p_jax, temp))
        # SciPy's L-BFGS-B works in double precision, while JAX returns
        # float32 unless x64 is enabled; hand back a float64 gradient.
        grad = np.array(grad_fn(p_jax, temp), dtype=np.float64)
        return value, grad

    params_np = np.array(start_params)
    best_params = params_np
    for temp in (500.0, 1000.0, 2000.0):
        result = scipy_minimize(
            smooth_value_and_grad_np,
            params_np,
            args=(temp,),
            method='L-BFGS-B',
            jac=True,
            options={'maxiter': 2000, 'ftol': 1e-15, 'gtol': 1e-10},
        )
        params_np = result.x
        c1 = compute_c1_numpy(np.exp(params_np), n_points)
        print(f" temp={temp:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            # Keep SciPy's float64 iterate as-is; converting it back to a
            # JAX array would round it to float32 by default, so the stored
            # parameters would no longer be the ones that were just scored.
            best_params = params_np.copy()
    return best_params, best_c1


def run():
    """Search for a low-C1 profile over several grid sizes and restarts.

    For each ``N`` in ``[1000, 2000, 3000]`` and each seed in ``range(5)``:
    initialise a positive profile, run Adam on the smooth objective, refine
    with L-BFGS-B, then re-score the best parameters with
    ``compute_c1_numpy``. Progress is printed as the search runs.

    Returns:
        A 4-tuple ``(best_f, best_c1, best_c1, best_n)``: the best profile
        samples, its C1 value (present in both the second and third
        positions; the tuple shape is kept as-is for callers), and the grid
        size it was found on. ``best_f`` and ``best_n`` remain ``None`` if no
        run produced a C1 below infinity.
    """
    best_c1_overall = float('inf')
    best_f_overall = None
    best_n_overall = None

    for N in [1000, 2000, 3000]:
        dx = 0.5 / N
        objective, objective_smooth, grad_fn = make_objective_jax(N, dx)

        for seed in range(5):
            print(f"\n--- N={N}, seed={seed} ---")

            # Optimise log-values so exp(params) stays positive; the floor
            # keeps the log finite.
            init = _initial_profile(seed, N)
            params = np.log(np.maximum(init, 1e-6))

            # Phase 1: Adam optimization with smooth max (JAX)
            print("Phase 1: Adam optimization...")
            best_params_run, best_c1_run = _run_adam_phase(
                jnp.array(params), objective, objective_smooth
            )

            # Phase 2: L-BFGS-B refinement with high temperature smooth max
            print(f"Phase 2: L-BFGS-B refinement (starting from C1={best_c1_run:.8f})...")
            best_params_run, best_c1_run = _run_lbfgs_phase(
                best_params_run, best_c1_run, N, objective_smooth, grad_fn
            )

            # Final evaluation: always re-score with the NumPy reference.
            f_final = np.exp(np.array(best_params_run))
            c1_final = compute_c1_numpy(f_final, N)
            print(f" Final C1 for this run: {c1_final:.10f}")

            if c1_final < best_c1_overall:
                best_c1_overall = c1_final
                best_f_overall = f_final
                best_n_overall = N
                print(f"*** GLOBAL BEST: C1 = {c1_final:.10f}")

    print(f"\n=== Final best C1: {best_c1_overall:.10f} ===")
    return best_f_overall, best_c1_overall, best_c1_overall, best_n_overall
|