File size: 3,648 Bytes
2facf1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
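Quick test: N=3000, 200k Adam steps per seed over 3 seeds, then an aggressive L-BFGS polish.
Progress lines are written to stdout and flushed immediately, so they appear
promptly even without PYTHONUNBUFFERED.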
"""
import sys
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax


def compute_c1_numpy(f_values, n_points):
    """Return the autoconvolution ratio C1 of a sampled step function.

    The function is treated as ``n_points`` cells of width ``0.5 / n_points``.
    Negative samples are clipped to zero, and then

        C1 = max(f * f) / (integral of f) ** 2

    Both the autoconvolution and the integral are approximated by Riemann
    sums with that cell width.

    Args:
        f_values: 1-D array-like of samples. Presumably of length
            ``n_points``; note that the cell width is derived from
            ``n_points``, not from ``len(f_values)``.
        n_points: Number of cells; used only to set the cell width.

    Returns:
        The C1 ratio as a NumPy float. Returns the sentinel ``1e10`` when
        the squared integral of the clipped function is below ``1e-12``.
        This includes empty and all-non-positive input.
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(np.asarray(f_values, dtype=np.float64), 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Check for degeneracy *before* convolving. np.convolve raises on empty
    # input, and there is no point doing O(n^2) work for a discarded result.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq


def run(*, n_points=3000, adam_steps=200000, n_seeds=3):
    """Optimise a positive step function to minimise the autoconvolution ratio C1.

    For each seed, ``f = exp(clip(params, -8, 4))`` on ``n_points`` cells of
    width ``0.5 / n_points`` is processed in two stages:

    1. Adam minimises a logsumexp-smoothed max of the autoconvolution,
       divided by the squared integral.
    2. L-BFGS-B polishes the result at increasing smoothing temperatures.

    Each seed's result is scored in float64 with ``compute_c1_numpy``. The
    best result over all seeds is returned.

    Args:
        n_points: Number of grid cells (default 3000).
        adam_steps: Adam iterations per seed (default 200000).
        n_seeds: Number of random restarts, seeded 0..n_seeds-1 (default 3).

    Returns:
        ``(best_f, best_c1, best_c1, n_points)``. The C1 value is repeated in
        positions 1 and 2. ``best_f`` is ``None`` if no seed produced a C1
        below infinity (e.g. ``n_seeds == 0`` or every seed gave NaN).

    Raises:
        ValueError: If ``adam_steps < 2``; the warmup + cosine schedule needs
            at least one step in each phase.
    """
    if adam_steps < 2:
        raise ValueError(f"adam_steps must be >= 2, got {adam_steps}")

    N = n_points
    dx = 0.5 / N
    adam_temp = 300.0
    log_every = 25000
    lbfgs_temps = (500, 1000, 2000, 5000, 10000, 50000)

    def autoconv_and_norm(params):
        """Return (autoconvolution of f, (integral of f) ** 2) for f = exp(params)."""
        # Clip before exp so f stays within [e^-8, e^4].
        f = jnp.exp(jnp.clip(params, -8, 4))
        # Zero-pad to 2N so the FFT product gives the linear (not circular)
        # autoconvolution: 2N - 1 values plus one trailing zero.
        padded = jnp.zeros(2 * N).at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral_sq = (jnp.sum(f) * dx) ** 2
        return conv, integral_sq

    @jax.jit
    def obj_smooth(params, temp):
        # logsumexp(temp * x) / temp is a smooth upper bound on max(x).
        # The bound tightens as temp grows.
        conv, integral_sq = autoconv_and_norm(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    @jax.jit
    def obj_hard(params):
        conv, integral_sq = autoconv_and_norm(params)
        return jnp.max(conv) / integral_sq

    # optax's decay_steps is the TOTAL schedule length including warmup.
    # Passing adam_steps therefore makes the cosine decay end on the final
    # step. The previous value, adam_steps - 3000, ended the decay 3000 steps
    # early and left the tail at end_value.
    warmup_steps = min(3000, adam_steps // 2)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.008, warmup_steps=warmup_steps,
        decay_steps=adam_steps, end_value=1e-7,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)

    @jax.jit
    def adam_step(params, opt_state):
        # Gradient and Adam update compiled together. Previously the grad
        # transform was rebuilt and the update ran un-jitted on every step.
        grads = jax.grad(obj_smooth)(params, adam_temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    smooth_value_and_grad = jax.jit(jax.value_and_grad(obj_smooth))

    def scipy_obj(p, temp):
        # Value and gradient come from a single forward/backward pass.
        # NOTE(review): unless jax_enable_x64 is enabled, jnp.asarray runs
        # this in float32. The ftol/gtol used below are then beyond reachable
        # precision, so L-BFGS-B will likely stop on a line-search failure.
        val, grad = smooth_value_and_grad(jnp.asarray(p), temp)
        return float(val), np.asarray(grad, dtype=np.float64)

    def score(params):
        """Return (f, C1) for params, with C1 computed in float64."""
        f = np.exp(np.clip(np.asarray(params, dtype=np.float64), -8, 4))
        return f, compute_c1_numpy(f, N)

    best_c1_overall = float('inf')
    best_f_overall = None

    for seed in range(n_seeds):
        # A private RandomState draws exactly the same numbers as
        # np.random.seed(seed) followed by np.random.randn, without touching
        # NumPy's global RNG.
        rng = np.random.RandomState(seed)
        init_f = np.abs(rng.randn(N)) * 0.3 + 0.2
        params = jnp.array(np.log(np.maximum(init_f, 1e-6)))
        opt_state = optimizer.init(params)

        best_c1 = float('inf')
        best_params = params

        for step in range(adam_steps):
            params, opt_state = adam_step(params, opt_state)

            if step % log_every == 0 or step == adam_steps - 1:
                hc = float(obj_hard(params))
                sys.stdout.write(f"[{seed}] Step {step:7d} | C1={hc:.8f}\n")
                sys.stdout.flush()
                if hc < best_c1:
                    best_c1 = hc
                    best_params = params

        # L-BFGS polish: raise the smoothing temperature step by step so the
        # smooth objective approaches the true max.
        params_np = np.array(best_params, dtype=np.float64)
        for temp_lbfgs in lbfgs_temps:
            result = scipy_minimize(
                scipy_obj, params_np, args=(float(temp_lbfgs),),
                method='L-BFGS-B', jac=True,
                options={'maxiter': 10000, 'ftol': 1e-15, 'gtol': 1e-14, 'maxcor': 50},
            )
            params_np = result.x

        # The polish minimises a *smoothed* objective, so its true C1 can end
        # up worse than its starting point. Keep whichever result is better.
        f_adam, c1_adam = score(best_params)
        f_final, c1_final = score(params_np)
        if np.isnan(c1_final) or c1_adam < c1_final:
            sys.stdout.write(
                f"[{seed}] L-BFGS polish worsened C1 ({c1_final:.10f}); "
                f"keeping Adam result\n"
            )
            f_final, c1_final = f_adam, c1_adam
        sys.stdout.write(f"[{seed}] Final: C1={c1_final:.10f}\n")
        sys.stdout.flush()

        if c1_final < best_c1_overall:
            best_c1_overall = c1_final
            best_f_overall = f_final

    sys.stdout.write(f"Best C1: {best_c1_overall:.10f}\n")
    sys.stdout.flush()
    return best_f_overall, best_c1_overall, best_c1_overall, N