"""
Best strategy: Seed 80 at N=4000, upsample to N=8000, then up to 80
perturbation rounds with a smart noise schedule (stopping early once 15
consecutive rounds fail to improve the best C1).

The quantity minimised is the autoconvolution ratio

    C1(f) = max_t (f * f)(t) / (sum(f) * dx) ** 2

for a sampled function f on ``N`` cells of width ``dx = 0.5 / N``. The
optimiser works on log-parameters ``params`` with ``f = exp(clip(params))``,
which keeps f strictly positive.
"""
import functools
import sys

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize

# Bounds applied to the log-parameters before exponentiating (f in [e^-8, e^4]).
_CLIP_LO = -8
_CLIP_HI = 4


def compute_c1_numpy(f_values, n_points):
    """Return the exact (hard-max) C1 ratio of ``f_values`` using NumPy.

    Args:
        f_values: samples of f; negative entries are clamped to zero.
        n_points: number of cells, which sets ``dx = 0.5 / n_points``.

    Returns:
        ``max(autoconv) / integral**2``, or the sentinel ``1e10`` when the
        squared integral is below ``1e-12`` (f is zero or empty).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(f_values, 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check the degenerate case first: it avoids the O(N^2) convolution and
    # keeps empty input from raising inside np.convolve.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq


@functools.lru_cache(maxsize=8)
def make_fns(N):
    """Build jitted C1 objectives and their gradients for an ``N``-cell grid.

    The result is cached per ``N``, so repeated ``optimize`` calls at the same
    resolution reuse the same jitted callables (and their compiled
    executables) instead of re-tracing fresh closures every time.

    Returns:
        ``(obj_smooth, obj_hard, grad_smooth, grad_hard)`` where
        ``obj_smooth(params, temp)`` replaces the max with
        ``logsumexp(temp * conv) / temp`` and ``obj_hard(params)`` uses the
        true max. Both gradients are taken w.r.t. ``params``.
    """
    dx = 0.5 / N

    def _conv_and_integral_sq(params):
        # Linear autoconvolution via FFT: zero-pad to length 2N so the
        # circular convolution does not wrap around.
        f = jnp.exp(jnp.clip(params, _CLIP_LO, _CLIP_HI))
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        conv, integral_sq = _conv_and_integral_sq(params)
        # Soft maximum; approaches max(conv) from above as temp grows.
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = _conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    grad_smooth = jax.jit(jax.grad(obj_smooth))
    grad_hard = jax.jit(jax.grad(obj_hard))
    return obj_smooth, obj_hard, grad_smooth, grad_hard


def optimize(N, init_params_np, adam_steps=40000, lr=0.003, temp=300.0):
    """Minimise C1 over log-parameters starting from ``init_params_np``.

    Stages:
      1. Adam on the smooth objective at temperature ``temp``, with a linear
         warmup (``min(1000, adam_steps // 5)`` steps) followed by cosine
         decay to ``lr * 1e-4``. During the final 2000 steps the iterate with
         the lowest hard objective is kept.
      2. L-BFGS-B on the smooth objective at temperatures 1e3, then 1e4.
      3. Three L-BFGS-B passes on the hard objective.

    Returns:
        ``(params_np, f, c1)``: the final log-parameters, the sampled function
        ``exp(clip(params_np))``, and its C1 from ``compute_c1_numpy``.
    """
    # optax is needed only by this function. Importing it here keeps the
    # evaluation helpers above usable without it.
    import optax

    obj_smooth, obj_hard, grad_smooth, grad_hard = make_fns(N)
    params = jnp.array(init_params_np)

    warmup = min(1000, adam_steps // 5)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=lr,
        warmup_steps=warmup,
        decay_steps=adam_steps - warmup,
        end_value=lr * 1e-4,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    @jax.jit
    def adam_step(params, opt_state):
        # Compiling the whole update (gradient plus Adam bookkeeping) avoids
        # dispatching many small ops from Python on every step.
        grads = grad_smooth(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    best_c1 = float('inf')
    best_params = params
    track_from = adam_steps - 2000
    for step in range(adam_steps):
        params, opt_state = adam_step(params, opt_state)
        if step >= track_from:
            hc = float(obj_hard(params))
            if hc < best_c1:
                best_c1 = hc
                best_params = params

    def smooth_fg(p, t):
        # (value, gradient) pair in the form L-BFGS-B expects with jac=True.
        p_jax = jnp.array(p)
        val = float(obj_smooth(p_jax, t))
        g = np.array(grad_smooth(p_jax, t), dtype=np.float64)
        return val, g

    def hard_fg(p):
        p_jax = jnp.array(p)
        val = float(obj_hard(p_jax))
        g = np.array(grad_hard(p_jax), dtype=np.float64)
        return val, g

    params_np = np.array(best_params, dtype=np.float64)
    for t in (1000.0, 10000.0):
        result = scipy_minimize(
            smooth_fg, params_np, args=(t,), method='L-BFGS-B', jac=True,
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 100},
        )
        params_np = result.x
    for _ in range(3):
        result = scipy_minimize(
            hard_fg, params_np, method='L-BFGS-B', jac=True,
            options={'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100},
        )
        params_np = result.x

    f = np.exp(np.clip(params_np, _CLIP_LO, _CLIP_HI))
    c1 = compute_c1_numpy(f, N)
    return params_np, f, c1


def _noise_scale(i):
    """Return the perturbation amplitude for round ``i``.

    The amplitude is 0.15 every 10th round (i % 10 == 9), 0.08 on the other
    rounds with i % 5 == 4, and otherwise cycles through 0.02 / 0.03 / 0.04.
    """
    if i % 10 == 9:
        return 0.15  # occasional large perturbation
    if i % 5 == 4:
        return 0.08
    return 0.02 + 0.01 * (i % 3)


def run():
    """Run the three-phase search and report progress on stdout.

    Returns:
        ``(best_f, best_c1, best_c1, N2)``. NOTE(review): ``best_c1`` is
        returned twice, presumably to fill a 4-tuple slot the caller expects.
        Confirm before changing.
    """
    # Phase 1: Seed 80 at N=4000
    N1 = 4000
    np.random.seed(80)
    init_f = np.ones(N1) * 0.5 + 0.02 * np.random.randn(N1)
    init_params = np.log(np.maximum(init_f, 1e-6))
    params, f, c1 = optimize(N1, init_params, adam_steps=80000, lr=0.005)
    sys.stdout.write(f"Seed 80 N=4000: C1={c1:.10f}\n")
    sys.stdout.flush()

    # Phase 2: Upsample to N=8000 by linear interpolation on [0, 1]
    N2 = 8000
    old_f = np.exp(np.clip(params, _CLIP_LO, _CLIP_HI))
    new_f = np.interp(np.linspace(0, 1, N2), np.linspace(0, 1, N1), old_f)
    new_params = np.log(np.maximum(new_f, 1e-6))
    params, f, c1 = optimize(N2, new_params, adam_steps=40000, lr=0.002)
    sys.stdout.write(f"Upsample N=8000: C1={c1:.10f}\n")
    sys.stdout.flush()

    best_params, best_f, best_c1 = params, f, c1
    stale_count = 0

    # Phase 3: Perturb the best solution and re-optimise, stopping after
    # 15 consecutive rounds without improvement.
    for i in range(80):
        if stale_count >= 15:
            break
        key = jax.random.PRNGKey(i * 31 + 11)  # deterministic per round
        noise_scale = _noise_scale(i)
        noise = noise_scale * jax.random.normal(key, shape=(N2,))
        perturbed = best_params + np.array(noise)
        # Larger kicks get more Adam steps to recover.
        steps = 15000 if noise_scale < 0.1 else 25000
        p, f_p, c1_p = optimize(N2, perturbed, adam_steps=steps, lr=0.001)

        improved = c1_p < best_c1
        if improved:
            best_c1, best_params, best_f = c1_p, p, f_p
            stale_count = 0
        else:
            stale_count += 1

        if i % 5 == 0 or improved:
            sys.stdout.write(f"  P{i:2d} (s={noise_scale:.3f}): C1={c1_p:.10f}")
            if improved:
                sys.stdout.write(" ***")
            sys.stdout.write(f"  [best={best_c1:.10f}]\n")
            sys.stdout.flush()

    sys.stdout.write(f"\nFinal C1: {best_c1:.10f}\n")
    sys.stdout.flush()
    return best_f, best_c1, best_c1, N2