File size: 4,299 Bytes
2facf1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax


def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)**2 for a sampled function.

    Negative samples are clamped to zero before anything else is computed.
    The grid spacing ``0.5 / n_points`` scales both the discrete
    autoconvolution and the Riemann-sum integral.

    Args:
        f_values: Samples of the function.
        n_points: Number of grid points used to derive the grid spacing.

    Returns:
        The peak of the scaled autoconvolution divided by the squared
        integral, or the sentinel ``1e10`` when the squared integral is
        smaller than ``1e-12``.
    """
    step = 0.5 / n_points
    clamped = np.maximum(f_values, 0.0)
    scaled_conv = step * np.convolve(clamped, clamped, mode='full')
    mass = clamped.sum() * step
    mass_sq = mass ** 2
    if mass_sq >= 1e-12:
        return scaled_conv.max() / mass_sq
    # Degenerate (essentially zero) function: return a large penalty
    # instead of dividing by a vanishing denominator.
    return 1e10


def optimize_single(N, seed, adam_steps=80000, verbose=True):
    """Minimise the autoconvolution ratio C1 for one grid size and seed.

    The function is parameterised as ``f = exp(clip(params, -10, 5))`` so it
    stays strictly positive. Optimisation runs in two phases:

    1. Adam on a log-sum-exp smoothed maximum whose temperature is annealed
       from 20 to 300 over the first 70% of the steps.
    2. L-BFGS-B polishing at temperatures 500, 1000 and 5000, each stage
       starting from the previous stage's result.

    Args:
        N: Number of grid points; the grid spacing is ``0.5 / N``.
        seed: Seed passed to ``np.random.seed`` for the initial guess.
        adam_steps: Number of Adam iterations. Must be larger than the
            2000-step learning-rate warm-up.
        verbose: When true, print progress for both phases.

    Returns:
        A tuple ``(f_final, c1_final)``: the best function samples found (a
        NumPy array of length ``N``) and their C1 value as computed by
        ``compute_c1_numpy``.
    """
    dx = 0.5 / N
    warmup_steps = 2000

    def autoconv_and_mass(params):
        # Shared by both objectives: zero-pad to 2N so the FFT product is a
        # linear (not circular) autoconvolution.
        f = jnp.exp(jnp.clip(params, -10, 5))
        padded = jnp.zeros(2 * N)
        padded = padded.at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    def objective_smooth(params, temp):
        # Differentiable surrogate: log-sum-exp approaches the true max as
        # temp grows.
        conv, integral_sq = autoconv_and_mass(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def objective_hard(params):
        # Exact (non-smooth) objective, used only for monitoring/selection.
        conv, integral_sq = autoconv_and_mass(params)
        return jnp.max(conv) / integral_sq

    # Build and jit the value-and-gradient function once; both phases reuse
    # it instead of re-creating the transform on every iteration.
    value_and_grad_smooth = jax.jit(jax.value_and_grad(objective_smooth))

    # Initialize with a broad, slightly noisy constant bump.
    np.random.seed(seed)
    init = np.ones(N) * 0.5 + 0.05 * np.random.randn(N)
    params = jnp.array(np.log(np.maximum(init, 1e-6)))

    # Phase 1: Adam with increasing temperature.
    # optax's ``decay_steps`` is the TOTAL schedule length including the
    # warm-up, so it must equal adam_steps for the cosine decay to end on the
    # final step rather than ``warmup_steps`` iterations early.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.008, warmup_steps=warmup_steps,
        decay_steps=adam_steps, end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params

    for step in range(adam_steps):
        # Temperature annealing: start moderate, end high.
        progress = min(step / (adam_steps * 0.7), 1.0)
        temp = 20.0 + progress * 280.0  # 20 -> 300

        _, grads = value_and_grad_smooth(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Checkpoint on the hard objective every 10000 steps and at the end.
        if step % 10000 == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f"  Step {step:6d} | C1={hard_c1:.8f} | temp={temp:.0f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params

    # Phase 2: L-BFGS-B polishing with very high temperature.
    if verbose:
        print(f"  Phase 2: L-BFGS polishing from C1={best_c1:.8f}")

    def scipy_obj(p, temp):
        # SciPy expects a Python float and a float64 gradient array.
        val, g = value_and_grad_smooth(jnp.array(p), temp)
        return float(val), np.array(g, dtype=np.float64)

    params_np = np.array(best_params)
    for temp in (500.0, 1000.0, 5000.0):
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp,), method='L-BFGS-B', jac=True,
            options={'maxiter': 3000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        # Each stage warm-starts the next, even if it did not improve C1.
        params_np = result.x
        f_opt = np.exp(np.clip(params_np, -10, 5))
        c1 = compute_c1_numpy(f_opt, N)
        if verbose:
            print(f"    temp={temp:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_params = jnp.array(params_np)

    f_final = np.exp(np.clip(np.array(best_params), -10, 5))
    c1_final = compute_c1_numpy(f_final, N)
    return f_final, c1_final


def run():
    """Run every configured optimisation and keep the lowest C1.

    Each configuration is a ``(N, seed, adam_steps)`` triple handed to
    ``optimize_single``. Progress and the running best are printed.

    Returns:
        A 4-tuple ``(best_f, best_c1, best_c1, best_n)``: the best function
        samples, the best C1 value (present twice in the tuple), and the
        grid size that produced it.
    """
    configs = (
        (2000, 0, 100000),
        (2000, 1, 100000),
        (3000, 0, 80000),
    )

    best_f, best_n = None, None
    best_c1 = float('inf')

    for N, seed, steps in configs:
        print(f"\n=== N={N}, seed={seed}, steps={steps} ===")
        f, c1 = optimize_single(N, seed, adam_steps=steps)
        print(f"  Result: C1={c1:.10f}")
        if not (c1 < best_c1):
            # Not an improvement over the running best; move on.
            continue
        best_f, best_c1, best_n = f, c1, N
        print(f"  *** NEW GLOBAL BEST: C1={c1:.10f}")

    print(f"\nFinal best C1: {best_c1:.10f}")
    return best_f, best_c1, best_c1, best_n