File size: 5,572 Bytes
"""
Best strategy: seed 80 at N=4000, upsample to N=8000,
then up to 80 perturbation rounds (stopping early after 15 consecutive
rounds without improvement) using a varying noise schedule.
"""
import sys
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax
def compute_c1_numpy(f_values, n_points):
    """Evaluate the autoconvolution ratio C1 of a sampled function with NumPy.

    Negative samples are clipped to zero, then the value

        max(f * f) / (integral of f) ** 2

    is computed on a grid of spacing ``dx = 0.5 / n_points``, where ``f * f``
    is the full discrete autoconvolution scaled by ``dx`` and the integral is
    the sum of the samples scaled by ``dx``.

    Args:
        f_values: 1-D array-like of function samples.
        n_points: Number of grid points used to derive ``dx``.

    Returns:
        The ratio, or the sentinel ``1e10`` when the squared integral is
        below ``1e-12`` (all-zero, all-negative or empty input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Reject degenerate input *before* convolving: this skips an O(N^2)
    # convolution whose result would be thrown away, and lets empty input
    # reach the sentinel instead of failing inside np.convolve.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq
def make_fns(N):
    """Build the JAX objectives and their gradients for an N-point grid.

    The optimisation variable is ``params``; the function samples are
    ``exp(clip(params, -8, 4))``, so they are always positive.  Both
    objectives divide a peak of the autoconvolution by the squared integral
    of the samples, using grid spacing ``0.5 / N``.

    Returns:
        A tuple ``(obj_smooth, obj_hard, grad_smooth, grad_hard)``:

        * ``obj_smooth(params, temp)`` uses the soft maximum
          ``logsumexp(temp * conv) / temp``.
        * ``obj_hard(params)`` uses the exact maximum of the autoconvolution.
        * ``grad_smooth`` / ``grad_hard`` are the jitted gradients of the two
          objectives with respect to ``params``.
    """
    step = 0.5 / N
    padded_len = 2 * N

    def _autoconv_and_mass_sq(log_samples):
        # Shared front end of both objectives.
        samples = jnp.exp(jnp.clip(log_samples, -8, 4))
        # Zero-pad to length 2N so the FFT product is the linear (not
        # circular) autoconvolution.
        buffer = jnp.zeros(padded_len).at[:N].set(samples)
        spectrum = jnp.fft.rfft(buffer)
        autoconv = jnp.fft.irfft(spectrum * spectrum, n=padded_len) * step
        mass_sq = (jnp.sum(samples) * step) ** 2
        return autoconv, mass_sq

    @jax.jit
    def obj_smooth(params, temp):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        soft_peak = jax.nn.logsumexp(temp * autoconv) / temp
        return soft_peak / mass_sq

    @jax.jit
    def obj_hard(params):
        autoconv, mass_sq = _autoconv_and_mass_sq(params)
        return jnp.max(autoconv) / mass_sq

    return (
        obj_smooth,
        obj_hard,
        jax.jit(jax.grad(obj_smooth)),
        jax.jit(jax.grad(obj_hard)),
    )
def optimize(N, init_params_np, adam_steps=40000, lr=0.003, temp=300.0):
    """Minimise the autoconvolution ratio on an N-point grid.

    Three stages are run in sequence:

    1. Adam on the smooth objective at temperature ``temp`` with a
       warmup + cosine-decay learning rate.  During the last 2000 steps the
       iterate with the lowest *hard* objective is remembered.
    2. L-BFGS-B on the smooth objective at temperatures 1000 and 10000,
       starting from the remembered iterate.
    3. Three consecutive L-BFGS-B passes on the hard objective.

    Args:
        N: Number of grid points (length of the parameter vector).
        init_params_np: Initial parameter vector; samples are
            ``exp(clip(params, -8, 4))``.
        adam_steps: Number of Adam iterations.
        lr: Peak Adam learning rate.
        temp: Soft-max temperature used during the Adam stage.

    Returns:
        ``(params_np, f, c1)``: the final parameters, the corresponding
        samples ``exp(clip(params_np, -8, 4))`` and their ratio as computed
        by ``compute_c1_numpy``.
    """
    obj_smooth, obj_hard, grad_smooth, grad_hard = make_fns(N)

    def _for_scipy(value_fn, grad_fn, *extra):
        """Adapt a JAX value/gradient pair to SciPy's ``jac=True`` protocol."""
        def value_and_grad(p):
            p_jax = jnp.array(p)
            val = float(value_fn(p_jax, *extra))
            # L-BFGS-B is handed a float64 gradient regardless of the JAX dtype.
            g = np.array(grad_fn(p_jax, *extra), dtype=np.float64)
            return val, g
        return value_and_grad

    params = jnp.array(init_params_np)
    warmup = min(1000, adam_steps // 5)
    # NOTE(review): optax documents ``decay_steps`` as the total schedule
    # length *including* warmup, so with this value the cosine phase ends
    # ``warmup`` steps before the last Adam step and the tail runs at
    # ``end_value``.  Left unchanged to preserve existing behaviour --
    # confirm whether that is intended.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=lr, warmup_steps=warmup,
        decay_steps=adam_steps - warmup, end_value=lr * 1e-4,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params
    tracking_start = adam_steps - 2000
    for step in range(adam_steps):
        # Use the jitted gradient built once by make_fns instead of
        # re-creating a value_and_grad transform (and discarding the loss)
        # on every iteration.
        grads = grad_smooth(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        if step >= tracking_start:
            # Late iterates are ranked by the true (hard-max) objective.
            hc = float(obj_hard(params))
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    params_np = np.array(best_params, dtype=np.float64)

    # Polish on increasingly sharp soft-max approximations.
    for t in [1000.0, 10000.0]:
        result = scipy_minimize(
            _for_scipy(obj_smooth, grad_smooth, t), params_np,
            method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 100},
        )
        params_np = result.x

    # Final refinement directly on the hard objective, restarted three times.
    scipy_hard = _for_scipy(obj_hard, grad_hard)
    for _ in range(3):
        result = scipy_minimize(
            scipy_hard, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100},
        )
        params_np = result.x

    f = np.exp(np.clip(params_np, -8, 4))
    c1 = compute_c1_numpy(f, N)
    return params_np, f, c1
def run():
    """Run the full three-phase search and return the best function found.

    Phase 1 optimises a seeded random start on a 4000-point grid, phase 2
    linearly interpolates that solution onto 8000 points and re-optimises,
    and phase 3 repeatedly perturbs the best parameters with Gaussian noise
    and re-optimises, keeping any improvement.  Phase 3 runs at most 80
    rounds and stops after 15 consecutive rounds without improvement.
    Progress is written to standard output.

    Returns:
        ``(best_f, best_c1, best_c1, N2)`` -- the best samples, the best
        ratio (repeated in the second and third slots) and the final grid
        size.
    """

    def _emit(text):
        sys.stdout.write(text)
        sys.stdout.flush()

    def _noise_for(round_idx):
        # Mostly small perturbations, with a medium one every 5th round and
        # a large one every 10th.
        if round_idx % 10 == 9:
            return 0.15
        if round_idx % 5 == 4:
            return 0.08
        return 0.02 + 0.01 * (round_idx % 3)

    # Phase 1: Seed 80 at N=4000
    N1 = 4000
    np.random.seed(80)
    start_f = np.full(N1, 0.5) + 0.02 * np.random.randn(N1)
    start_params = np.log(np.maximum(start_f, 1e-6))
    params, f, c1 = optimize(N1, start_params, adam_steps=80000, lr=0.005)
    _emit(f"Seed 80 N=4000: C1={c1:.10f}\n")

    # Phase 2: Upsample to N=8000
    N2 = 8000
    coarse_f = np.exp(np.clip(params, -8, 4))
    coarse_grid = np.linspace(0, 1, N1)
    fine_grid = np.linspace(0, 1, N2)
    fine_f = np.interp(fine_grid, coarse_grid, coarse_f)
    fine_params = np.log(np.maximum(fine_f, 1e-6))
    params, f, c1 = optimize(N2, fine_params, adam_steps=40000, lr=0.002)
    _emit(f"Upsample N=8000: C1={c1:.10f}\n")

    best_params, best_f, best_c1 = params, f, c1
    stale_count = 0

    # Phase 3: Many perturbation restarts
    for i in range(80):
        if stale_count >= 15:
            # Stop if no improvement for 15 rounds
            break

        noise_scale = _noise_for(i)
        key = jax.random.PRNGKey(i * 31 + 11)
        noise = noise_scale * jax.random.normal(key, shape=(N2,))
        perturbed = best_params + np.array(noise)

        # Larger kicks get a longer Adam budget to settle.
        steps = 25000 if not noise_scale < 0.1 else 15000
        p, f_p, c1_p = optimize(N2, perturbed, adam_steps=steps, lr=0.001)

        improved = c1_p < best_c1
        if improved:
            best_params, best_f, best_c1 = p, f_p, c1_p
            stale_count = 0
        else:
            stale_count += 1

        if improved or i % 5 == 0:
            line = f" P{i:2d} (s={noise_scale:.3f}): C1={c1_p:.10f}"
            if improved:
                line += " ***"
            line += f" [best={best_c1:.10f}]\n"
            _emit(line)

    _emit(f"\nFinal C1: {best_c1:.10f}\n")
    return best_f, best_c1, best_c1, N2
|