# Scraped file-viewer header (not part of the program): file size 5,567 bytes, revision 2facf1f.
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)**2 for a sampled profile.

    Negative samples are clipped to zero first. The grid spacing is
    ``0.5 / n_points``; it scales both the discrete autoconvolution and
    the discrete integral.

    Args:
        f_values: 1-D sequence of samples of the profile.
        n_points: Number of grid points used to derive the spacing.

    Returns:
        The ratio described above, or the sentinel ``1e10`` when the squared
        integral is below ``1e-12`` (an essentially zero profile).
    """
    step = 0.5 / n_points
    clipped = np.maximum(f_values, 0.0)
    # The convolution is evaluated before the degenerate-mass check so that
    # invalid input fails in the same place it always has.
    self_conv = step * np.convolve(clipped, clipped, mode='full')
    mass_sq = (step * clipped.sum()) ** 2
    return 1e10 if mass_sq < 1e-12 else self_conv.max() / mass_sq
def optimize_run(N, seed, adam_steps=100000, verbose=True):
    """Search for a non-negative profile with a small autoconvolution ratio C1.

    C1 is ``max(f * f) / (integral of f)**2`` on a grid of ``N`` samples with
    spacing ``0.5 / N``. The search has two phases:

    1. Adam on a smooth-max (log-sum-exp) surrogate plus a variance
       ("flatness") penalty on the autoconvolution. The smooth-max
       temperature rises from 100 to 300 over the run and the penalty weight
       falls linearly from 0.1 to 0 over the first half.
    2. L-BFGS-B polishing of the best Adam iterate on the unregularised
       surrogate at temperatures 1000, 5000 and 20000, with a lower bound of
       zero on every sample.

    Args:
        N: Number of grid samples of the profile.
        seed: Seeds NumPy's global RNG; ``seed % 10`` also selects which of
            the ten initial shapes is used.
        adam_steps: Number of Adam iterations.
        verbose: When true, print progress for both phases.

    Returns:
        ``(f_final, best_c1)``: the best profile found as a non-negative
        NumPy array, and the lowest C1 value recorded for it.
    """
    dx = 0.5 / N

    def get_f(params):
        # ReLU keeps the profile non-negative while allowing exact zeros.
        return jax.nn.relu(params)

    def compute_conv(f):
        # Zero-padding to 2N makes the circular FFT product equal to the
        # linear autoconvolution (length 2N - 1), i.e. no wrap-around.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    def conv_and_integral_sq(params):
        """Return the autoconvolution and the (floored) squared integral."""
        f = get_f(params)
        integral = jnp.sum(f) * dx
        # The floor avoids dividing by zero when the profile collapses.
        return compute_conv(f), jnp.maximum(integral, 1e-9) ** 2

    def objective_reg(params, temp, lam):
        """Smooth max + flatness regularization."""
        conv, integral_sq = conv_and_integral_sq(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        # Flatness regularization: variance of the autoconvolution over the
        # whole padded grid.
        conv_mean = jnp.sum(conv) / (2 * N)
        conv_var = jnp.sum((conv - conv_mean) ** 2) / (2 * N)
        return smooth_max / integral_sq + lam * conv_var / integral_sq ** 2

    @jax.jit
    def objective_hard(params):
        """Exact (non-smooth) C1 used for monitoring and model selection."""
        conv, integral_sq = conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    def objective_smooth_only(params, temp):
        """Smooth-max surrogate of C1 without any regularization."""
        conv, integral_sq = conv_and_integral_sq(params)
        return jax.nn.logsumexp(temp * conv) / temp / integral_sq

    # Build the differentiated functions once; re-creating the transform on
    # every optimizer step adds avoidable dispatch overhead.
    value_and_grad_reg = jax.jit(jax.value_and_grad(objective_reg))
    value_and_grad_smooth = jax.jit(jax.value_and_grad(objective_smooth_only))

    # Initialize with diverse shapes. Every entry is built (and consumes
    # random draws) regardless of which one is selected, so the values seen
    # for a given seed are stable.
    np.random.seed(seed)
    x = np.linspace(0, 1, N)
    inits = {
        0: np.ones(N) * 0.5 + 0.02 * np.random.randn(N),
        1: np.exp(-10 * (x - 0.5) ** 2) + 0.1,
        2: np.exp(-5 * (x - 0.3) ** 2) + 0.05,  # asymmetric
        3: np.exp(-5 * (x - 0.7) ** 2) + 0.05,  # asymmetric other way
        4: 0.3 + 0.7 * np.sin(np.pi * x) ** 2 + 0.02 * np.random.randn(N),
        5: np.where(x < 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        6: np.where(x > 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        7: np.exp(-30 * (x - 0.5) ** 2) + 0.01,  # sharp peak
        8: 0.5 * (1 + np.cos(4 * np.pi * x)) + 0.1 + 0.02 * np.random.randn(N),
        9: np.abs(np.random.randn(N)) * 0.3 + 0.1,
    }
    # Keep every sample strictly positive so ReLU starts with live gradients.
    init_f = np.maximum(inits[seed % 10], 0.01)
    params = jnp.array(init_f)

    # Adam optimization. optax's ``decay_steps`` is the TOTAL schedule length
    # (warm-up included), so it is set to the full run; the warm-up is capped
    # for short runs so the cosine part always has a positive length.
    warmup_steps = max(0, min(2000, adam_steps // 2))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params
    for step in range(adam_steps):
        progress = step / adam_steps
        temp = 100.0 + progress * 200.0
        # Decrease flatness regularization over time (zero from the midpoint).
        lam = 0.1 * max(1.0 - progress * 2, 0.0)
        _, grads = value_and_grad_reg(params, temp, lam)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if step % 20000 == 0 or step == adam_steps - 1:
            # NOTE(review): this value comes from the JAX/FFT path while the
            # L-BFGS phase scores with compute_c1_numpy; the two may differ
            # slightly in precision - confirm this is acceptable.
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f" [{seed}] Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params

    # L-BFGS polishing (no regularization), sharpening the smooth max in
    # stages and warm-starting each stage from the previous result.
    params_np = np.array(best_params, dtype=np.float64)
    for temp_lbfgs in [1000.0, 5000.0, 20000.0]:
        def scipy_obj(p, temp=temp_lbfgs):
            # One fused call yields both the value and the gradient.
            val, g = value_and_grad_smooth(jnp.array(p), temp)
            return float(val), np.array(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj, params_np, method='L-BFGS-B', jac=True,
            bounds=[(0, None)] * N,  # Non-negativity constraint
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        params_np = result.x
        f_opt = np.maximum(params_np, 0.0)
        c1 = compute_c1_numpy(f_opt, N)
        if verbose:
            print(f" [{seed}] L-BFGS temp={temp_lbfgs:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_params = jnp.array(params_np)

    f_final = np.maximum(np.array(best_params), 0.0)
    return f_final, best_c1
def run():
    """Run ten seeded optimizations on a 3000-point grid and keep the best.

    Each seed in ``range(10)`` is passed to ``optimize_run`` with 80000 Adam
    steps; the per-seed result and every improvement are printed.

    Returns:
        ``(best_f, best_c1, best_c1, N)``: the best profile, its C1 value
        (deliberately repeated in the third slot), and the grid size.
    """
    n_points = 3000
    incumbent_f = None
    incumbent_c1 = float('inf')

    for trial_seed in range(10):
        candidate_f, candidate_c1 = optimize_run(
            n_points, trial_seed, adam_steps=80000
        )
        print(f" Seed {trial_seed}: C1={candidate_c1:.10f}")
        # Written as a negated "<" so a NaN result never becomes the incumbent.
        if not candidate_c1 < incumbent_c1:
            continue
        incumbent_f, incumbent_c1 = candidate_f, candidate_c1
        print(f" *** NEW GLOBAL BEST: C1={candidate_c1:.10f}")

    print(f"\nFinal best C1: {incumbent_c1:.10f}")
    return incumbent_f, incumbent_c1, incumbent_c1, n_points
|