File size: 3,344 Bytes
# EVOLVE-BLOCK-START
import jax
import jax.numpy as jnp
import optax
import numpy as np
from dataclasses import dataclass
import tqdm
@dataclass
class Hyperparameters:
    """Tunable settings consumed by ErdosOptimizer.

    Attributes:
        num_intervals: Number of equal-width steps used to represent h.
        learning_rate: Step size passed to the Adam optimizer.
        num_steps: Number of optimization steps to run.
        penalty_strength: Weight of the quadratic penalty that pushes the
            integral of h towards 1.
    """

    num_intervals: int = 200
    learning_rate: float = 5e-3
    num_steps: int = 20_000
    penalty_strength: float = 1e6
class ErdosOptimizer:
    """
    Finds a step function h that minimizes the maximum overlap integral.

    h is represented by ``num_intervals`` equal-width steps on a domain of
    width 2. Its values are kept in [0, 1] by applying a sigmoid to
    unconstrained latent parameters. The requirement that h integrates to 1
    is imposed only through a quadratic penalty, so it holds approximately.
    """

    def __init__(self, hypers: Hyperparameters):
        self.hypers = hypers
        self.domain_width = 2.0
        # Width of each step of h.
        self.dx = self.domain_width / self.hypers.num_intervals

    def _max_overlap(self, h: jnp.ndarray) -> jnp.ndarray:
        """
        Return max over lags k of dx * sum_n h[n + k] * (1 - h[n]), via FFT.

        Differentiable and jit-compatible, so it is used inside the loss.
        Precision follows the dtype of ``h`` (float32 unless JAX x64 is on).
        """
        j = 1.0 - h
        n = h.shape[0]
        # Zero-pad to length 2n so the circular correlation computed by the
        # FFT equals the linear correlation at every lag -(n-1)..(n-1).
        h_hat = jnp.fft.fft(jnp.pad(h, (0, n)))
        j_hat = jnp.fft.fft(jnp.pad(j, (0, n)))
        correlation = jnp.fft.ifft(h_hat * jnp.conj(j_hat)).real
        return jnp.max(correlation * self.dx)

    def _exact_max_overlap(self, h) -> float:
        """
        Return the same quantity as ``_max_overlap``, computed directly in float64.

        Used for the reported bound: the float32 FFT route introduces rounding
        error that can be visible at the printed precision, whereas a direct
        float64 correlation does not depend on FFT round-off.
        """
        h64 = np.asarray(h, dtype=np.float64)
        # np.correlate(a, v, "full")[k] = sum_n a[n + k] * v[n] over all lags.
        correlation = np.correlate(h64, 1.0 - h64, mode="full")
        return float(np.max(correlation) * self.dx)

    def _objective_fn(self, latent_h_values: jnp.ndarray) -> jnp.ndarray:
        """
        The loss function includes the objective and a penalty for the constraint.

        Returns max_overlap(h) + penalty_strength * (integral(h) - 1) ** 2,
        where h = sigmoid(latent_h_values).
        """
        # Enforce h(x) in [0, 1] via sigmoid (hard constraint).
        h = jax.nn.sigmoid(latent_h_values)
        objective_loss = self._max_overlap(h)
        # Soft constraint: the integral of h over the domain should be 1.
        integral_h = jnp.sum(h) * self.dx
        constraint_loss = (integral_h - 1.0) ** 2
        return objective_loss + self.hypers.penalty_strength * constraint_loss

    def run_optimization(self):
        """
        Optimize the latent step values with Adam and report the result.

        Returns:
            (final_h, c5_bound): ``final_h`` is sigmoid of the optimized
            latent values as a NumPy array; ``c5_bound`` is the maximum overlap
            of ``final_h`` (no penalty term), evaluated in float64.
        """
        optimizer = optax.adam(self.hypers.learning_rate)
        key = jax.random.PRNGKey(42)
        latent_h_values = jax.random.normal(key, (self.hypers.num_intervals,))
        opt_state = optimizer.init(latent_h_values)

        @jax.jit
        def train_step(latent_h_values, opt_state):
            loss, grads = jax.value_and_grad(self._objective_fn)(latent_h_values)
            updates, opt_state = optimizer.update(grads, opt_state)
            latent_h_values = optax.apply_updates(latent_h_values, updates)
            return latent_h_values, opt_state, loss

        print(f"Optimizing a step function with {self.hypers.num_intervals} intervals...")
        for _ in tqdm.tqdm(range(self.hypers.num_steps), desc="Optimizing"):
            latent_h_values, opt_state, _ = train_step(latent_h_values, opt_state)

        # Final h is just the sigmoid of the latent values.
        final_h = jax.nn.sigmoid(latent_h_values)
        # NOTE(review): the integral constraint is only enforced by the penalty,
        # so final_h may integrate to slightly more or less than 1; a strict
        # verifier of the bound should check this.
        c5_bound = self._exact_max_overlap(final_h)
        print(f"Optimization complete. Final C5 upper bound: {c5_bound:.8f}")
        return np.array(final_h), c5_bound
def run():
    """Optimize h using the default Hyperparameters.

    Returns:
        A tuple (final_h_values, c5_bound, num_intervals): the first two items
        are what ErdosOptimizer.run_optimization returns, and the last is the
        number of intervals the run used.
    """
    config = Hyperparameters()
    h_values, bound = ErdosOptimizer(config).run_optimization()
    return h_values, bound, config.num_intervals
# EVOLVE-BLOCK-END
|