| import jax |
| import jax.numpy as jnp |
| import optax |
| import numpy as np |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class Hyperparameters:
    """Tunable settings for a single AutocorrelationOptimizer run.

    Fields, in positional order:
        num_intervals: number of grid cells used to discretise f.
        learning_rate: peak Adam learning rate (reached after warmup).
        end_lr_factor: final learning rate as a fraction of ``learning_rate``.
        num_steps: total number of optimisation steps.
        warmup_steps: length of the linear learning-rate warmup, in steps.
        smooth_max_temp: final temperature of the log-sum-exp smooth maximum.
    """

    num_intervals: int = 1_000
    learning_rate: float = 1e-2
    end_lr_factor: float = 1e-5
    num_steps: int = 80_000
    warmup_steps: int = 3_000
    smooth_max_temp: float = 1e2
|
|
|
|
| class AutocorrelationOptimizer: |
| def __init__(self, hypers: Hyperparameters): |
| self.hypers = hypers |
| self.domain_width = 0.5 |
| self.dx = self.domain_width / self.hypers.num_intervals |
|
|
| def _objective_fn(self, f_values: jnp.ndarray, temp: float) -> jnp.ndarray: |
| f_non_negative = jax.nn.softplus(f_values) |
| integral_f = jnp.sum(f_non_negative) * self.dx |
| eps = 1e-12 |
| integral_f_safe = jnp.maximum(integral_f, eps) |
|
|
| N = self.hypers.num_intervals |
| padded_f = jnp.pad(f_non_negative, (0, N)) |
| fft_f = jnp.fft.rfft(padded_f) |
| fft_conv = fft_f * fft_f |
| conv_f_f = jnp.fft.irfft(fft_conv, n=2 * N) |
|
|
| scaled_conv_f_f = conv_f_f * self.dx |
|
|
| |
| smooth_max = jax.nn.logsumexp(temp * scaled_conv_f_f) / temp |
|
|
| c1_ratio = smooth_max / (integral_f_safe ** 2) |
| return c1_ratio |
|
|
| def _hard_objective(self, f_values: jnp.ndarray) -> jnp.ndarray: |
| f_non_negative = jax.nn.softplus(f_values) |
| integral_f = jnp.sum(f_non_negative) * self.dx |
| eps = 1e-12 |
| integral_f_safe = jnp.maximum(integral_f, eps) |
|
|
| N = self.hypers.num_intervals |
| padded_f = jnp.pad(f_non_negative, (0, N)) |
| fft_f = jnp.fft.rfft(padded_f) |
| fft_conv = fft_f * fft_f |
| conv_f_f = jnp.fft.irfft(fft_conv, n=2 * N) |
| scaled_conv_f_f = conv_f_f * self.dx |
| max_conv = jnp.max(scaled_conv_f_f) |
| c1_ratio = max_conv / (integral_f_safe ** 2) |
| return c1_ratio |
|
|
| def train_step(self, f_values, opt_state, temp): |
| loss, grads = jax.value_and_grad(self._objective_fn)(f_values, temp) |
| updates, opt_state = self.optimizer.update(grads, opt_state, f_values) |
| f_values = optax.apply_updates(f_values, updates) |
| hard_loss = self._hard_objective(f_values) |
| return f_values, opt_state, loss, hard_loss |
|
|
| def run_optimization(self, seed=42): |
| schedule = optax.warmup_cosine_decay_schedule( |
| init_value=0.0, |
| peak_value=self.hypers.learning_rate, |
| warmup_steps=self.hypers.warmup_steps, |
| decay_steps=self.hypers.num_steps - self.hypers.warmup_steps, |
| end_value=self.hypers.learning_rate * self.hypers.end_lr_factor, |
| ) |
| self.optimizer = optax.adam(learning_rate=schedule) |
|
|
| key = jax.random.PRNGKey(seed) |
| N = self.hypers.num_intervals |
|
|
| |
| x = jnp.linspace(0, 1, N) |
| |
| f_values = jnp.exp(-20.0 * (x - 0.5) ** 2) |
| f_values = f_values + 0.05 * jax.random.uniform(key, (N,)) |
|
|
| opt_state = self.optimizer.init(f_values) |
|
|
| train_step_jit = jax.jit(self.train_step) |
|
|
| best_loss = jnp.inf |
| best_f = f_values |
|
|
| for step in range(self.hypers.num_steps): |
| |
| progress = min(step / (self.hypers.num_steps * 0.5), 1.0) |
| temp = 10.0 + progress * (self.hypers.smooth_max_temp - 10.0) |
|
|
| f_values, opt_state, loss, hard_loss = train_step_jit(f_values, opt_state, temp) |
|
|
| if hard_loss < best_loss: |
| best_loss = hard_loss |
| best_f = f_values |
|
|
| if step % 5000 == 0 or step == self.hypers.num_steps - 1: |
| print(f"Step {step:5d} | C1(smooth) ≈ {loss:.8f} | C1(hard) ≈ {hard_loss:.8f} | temp={temp:.1f}") |
|
|
| print(f"Best C1 found: {best_loss:.8f}") |
|
|
| |
| final_f = jax.nn.softplus(best_f) |
| final_c1 = float(best_loss) |
| return final_f, final_c1 |
|
|
|
|
def run(configs=None):
    """Run AutocorrelationOptimizer over several configurations; keep the best.

    Args:
        configs: optional iterable of ``(num_intervals, learning_rate,
            num_steps, seed)`` tuples. When ``None``, the three built-in
            configurations are used. All other hyperparameters keep their
            ``Hyperparameters`` defaults.

    Returns:
        ``None`` if no configuration produced a C1 below ``inf``. Otherwise
        ``(f_values, c1, c1, num_intervals)`` for the best configuration, where
        ``f_values`` is a NumPy array of the non-negative samples.
        NOTE(review): the C1 value occupies two slots; presumably downstream
        code expects two C1 fields. Confirm against the caller.
    """
    if configs is None:
        configs = (
            (1000, 0.01, 80000, 42),
            (1000, 0.005, 80000, 123),
            (1500, 0.008, 60000, 42),
        )

    best_c1 = float('inf')
    best_result = None

    for n_intervals, lr, steps, seed in configs:
        print(f"\n--- Config: N={n_intervals}, lr={lr}, steps={steps}, seed={seed} ---")
        hypers = Hyperparameters(
            num_intervals=n_intervals,
            learning_rate=lr,
            num_steps=steps,
        )
        optimizer = AutocorrelationOptimizer(hypers)
        optimized_f, final_c1 = optimizer.run_optimization(seed=seed)

        # A NaN C1 never compares as smaller, so it is never selected.
        if final_c1 < best_c1:
            best_c1 = final_c1
            best_result = (np.array(optimized_f), best_c1, best_c1, hypers.num_intervals)
            print(f"*** New best: C1 = {best_c1:.10f}")

    return best_result
|
|