""" Adam with HARD max (no smoothing bias). JAX provides subgradient for jnp.max. """ import sys import jax import jax.numpy as jnp import numpy as np from scipy.optimize import minimize as scipy_minimize import optax def compute_c1_numpy(f_values, n_points): dx = 0.5 / n_points f_nn = np.maximum(f_values, 0.0) autoconv = np.convolve(f_nn, f_nn, mode='full') * dx integral_sq = (np.sum(f_nn) * dx) ** 2 if integral_sq < 1e-12: return 1e10 return np.max(autoconv) / integral_sq def run(): N = 4000 dx = 0.5 / N @jax.jit def obj_hard(params): f = jnp.exp(jnp.clip(params, -8, 4)) padded = jnp.zeros(2 * N) padded = padded.at[:N].set(f) fft_f = jnp.fft.rfft(padded) conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx integral_sq = (jnp.sum(f) * dx) ** 2 return jnp.max(conv) / integral_sq grad_hard = jax.jit(jax.grad(obj_hard)) best_c1_overall = float('inf') best_f_overall = None for seed in [100, 15, 8, 2, 17, 42, 7, 99, 50, 150, 250, 1000]: np.random.seed(seed) init_f = np.ones(N) * 0.5 + 0.02 * np.random.randn(N) params = jnp.array(np.log(np.maximum(init_f, 1e-6))) # Phase 1: Adam with HARD max adam_steps = 150000 lr_schedule = optax.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.005, warmup_steps=3000, decay_steps=adam_steps - 3000, end_value=1e-7, ) optimizer = optax.adam(learning_rate=lr_schedule) opt_state = optimizer.init(params) best_c1 = float('inf') best_params = params for step in range(adam_steps): loss, grads = jax.value_and_grad(obj_hard)(params) # Clip gradients for stability grads = jnp.clip(grads, -1.0, 1.0) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) if step % 50000 == 0: hc = float(loss) sys.stdout.write(f" [{seed:4d}] Step {step:7d} | C1={hc:.10f}\n") sys.stdout.flush() if step >= adam_steps - 1000: hc = float(obj_hard(params)) if hc < best_c1: best_c1 = hc best_params = params # Phase 2: Hard L-BFGS params_np = np.array(best_params, dtype=np.float64) for _ in range(5): def scipy_obj_hard(p): p_jax = jnp.array(p) val = float(obj_hard(p_jax)) g = np.array(grad_hard(p_jax), dtype=np.float64) return val, g result = scipy_minimize( scipy_obj_hard, params_np, method='L-BFGS-B', jac=True, options={'maxiter': 20000, 'ftol': 1e-16, 'gtol': 1e-15, 'maxcor': 100}, ) params_np = result.x f_final = np.exp(np.clip(params_np, -8, 4)) c1_final = compute_c1_numpy(f_final, N) sys.stdout.write(f"Seed {seed:4d}: C1={c1_final:.10f}") sys.stdout.flush() if c1_final < best_c1_overall: best_c1_overall = c1_final best_f_overall = f_final sys.stdout.write(" ***") np.save('/workspace/best_f.npy', f_final) sys.stdout.write("\n") sys.stdout.flush() sys.stdout.write(f"\nFinal C1: {best_c1_overall:.10f}\n") sys.stdout.flush() return best_f_overall, best_c1_overall, best_c1_overall, N