# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
# EVOLVE-BLOCK-START
import jax
import jax.numpy as jnp
import optax
import numpy as np
from dataclasses import dataclass
@dataclass
class Hyperparameters:
    """Tunable settings for the optimization run.

    Attributes:
        num_intervals: Number of samples used to discretize the function.
        learning_rate: Peak learning rate of the schedule.
        end_lr_factor: Final learning rate, as a fraction of ``learning_rate``.
        num_steps: Number of optimizer steps to perform.
        warmup_steps: Number of learning-rate warmup steps.
    """

    num_intervals: int = 600
    learning_rate: float = 5e-3
    end_lr_factor: float = 0.0001
    num_steps: int = 40_000
    warmup_steps: int = 2_000
class AutocorrelationOptimizer:
    """Optimizes a discretized function to find the minimal C1 constant.

    The function is represented by ``num_intervals`` samples spaced ``dx``
    apart on a domain of width ``domain_width`` (0.5).  The minimized
    quantity is ``max(f * f) / (integral of f) ** 2``, where ``f * f`` is the
    discrete self-convolution of the samples after clipping them at zero.

    Attributes:
        hypers: Hyperparameters supplied at construction.
        domain_width: Width of the discretized domain.
        dx: Sample spacing, ``domain_width / num_intervals``.
        optimizer: The optax optimizer; ``None`` until ``run_optimization``
            creates it.
    """

    def __init__(self, hypers: "Hyperparameters"):
        """Store the hyperparameters and derive the sample spacing."""
        self.hypers = hypers
        self.domain_width = 0.5
        self.dx = self.domain_width / self.hypers.num_intervals
        # Declared here so the attribute always exists; it is built in
        # run_optimization() from the hyperparameters current at that time.
        self.optimizer = None

    def _objective_fn(self, f_values: jnp.ndarray) -> jnp.ndarray:
        """Compute the C1 ratio of the (clipped) samples.

        We minimize this ratio to find a tight upper bound.

        Args:
            f_values: Unconstrained samples; negative entries are clipped to
                zero before anything else is computed.

        Returns:
            Scalar ``max(f * f) / (integral of f) ** 2``.
        """
        f_non_negative = jax.nn.relu(f_values)

        integral_f = jnp.sum(f_non_negative) * self.dx
        # Floor the integral so an all-zero candidate does not divide by zero.
        eps = 1e-9
        integral_f_safe = jnp.maximum(integral_f, eps)

        # Zero-pad to twice the length so the circular convolution computed
        # through the FFT equals the linear self-convolution.
        N = self.hypers.num_intervals
        padded_f = jnp.pad(f_non_negative, (0, N))
        fft_f = jnp.fft.fft(padded_f)
        conv_f_f = jnp.fft.ifft(fft_f * fft_f).real

        # Scale by dx to turn the discrete sum into an integral approximation.
        max_conv = jnp.max(conv_f_f * self.dx)

        # Return the value to be MINIMIZED.
        return max_conv / (integral_f_safe**2)

    def _build_optimizer(self) -> "optax.GradientTransformation":
        """Create Adam driven by a linear-warmup / cosine-decay schedule."""
        peak_lr = self.hypers.learning_rate
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=peak_lr,
            warmup_steps=self.hypers.warmup_steps,
            # optax defines decay_steps as the TOTAL schedule length, warmup
            # included; passing num_steps makes the cosine phase span exactly
            # the num_steps - warmup_steps steps that follow the warmup.
            decay_steps=self.hypers.num_steps,
            end_value=peak_lr * self.hypers.end_lr_factor,
        )
        return optax.adam(learning_rate=schedule)

    def train_step(self, f_values: jnp.ndarray, opt_state: "optax.OptState") -> tuple:
        """Perform a single training step.

        Args:
            f_values: Current samples.
            opt_state: Optimizer state matching ``self.optimizer``.

        Returns:
            ``(new_f_values, new_opt_state, loss)`` where ``loss`` is the
            objective evaluated at the *incoming* ``f_values``.

        Raises:
            RuntimeError: If no optimizer has been created yet.
        """
        if self.optimizer is None:
            raise RuntimeError(
                "optimizer is not initialized; call run_optimization() first"
            )
        loss, grads = jax.value_and_grad(self._objective_fn)(f_values)
        updates, opt_state = self.optimizer.update(grads, opt_state, f_values)
        f_values = optax.apply_updates(f_values, updates)
        return f_values, opt_state, loss

    def run_optimization(self):
        """Set up and run the full optimization process.

        Returns:
            ``(f, c1)``: the optimized samples clipped at zero, and the C1
            ratio evaluated on exactly those samples.
        """
        self.optimizer = self._build_optimizer()

        # Initial guess: a box on the middle half of the grid plus a small
        # uniform perturbation drawn with a fixed seed.
        key = jax.random.PRNGKey(42)
        N = self.hypers.num_intervals
        f_values = jnp.zeros((N,))
        start_idx, end_idx = N // 4, 3 * N // 4
        f_values = f_values.at[start_idx:end_idx].set(1.0)
        f_values += 0.05 * jax.random.uniform(key, (N,))

        opt_state = self.optimizer.init(f_values)
        print(
            f"Number of intervals (N): {self.hypers.num_intervals}, Steps: {self.hypers.num_steps}"
        )

        train_step_jit = jax.jit(self.train_step)
        last_step = self.hypers.num_steps - 1
        for step in range(self.hypers.num_steps):
            f_values, opt_state, loss = train_step_jit(f_values, opt_state)
            if step % 2000 == 0 or step == last_step:
                print(f"Step {step:5d} | C1 ≈ {loss:.8f}")

        # The loss from the last step was measured before that step's update,
        # so re-evaluate on the samples that are actually returned.
        final_f = jax.nn.relu(f_values)
        final_c1 = self._objective_fn(final_f)
        print(f"Final C1 found: {final_c1:.8f}")
        return final_f, final_c1
def run():
    """Run the optimization with default hyperparameters.

    Returns:
        A 4-tuple ``(f_values, c1, loss, num_intervals)``: the optimized
        samples as a NumPy array, the final loss as a Python float, the same
        loss as returned by ``run_optimization``, and the number of samples.
    """
    config = Hyperparameters()
    best_f, best_loss = AutocorrelationOptimizer(config).run_optimization()
    c1_value = float(best_loss)
    return np.array(best_f), c1_value, best_loss, config.num_intervals
# EVOLVE-BLOCK-END