File size: 5,567 Bytes
2facf1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax


def compute_c1_numpy(f_values, n_points):
    """Return max(f * f) / (integral of f)^2 for a sampled function, using NumPy.

    Args:
        f_values: Samples of the function; negative entries are clipped to 0
            before anything else is computed.
        n_points: Number of grid points; the grid spacing is ``0.5 / n_points``.

    Returns:
        The peak of the discrete autoconvolution (scaled by the grid spacing)
        divided by the squared Riemann sum of the clipped samples, or the
        sentinel ``1e10`` when that squared sum is below ``1e-12``.
    """
    spacing = 0.5 / n_points
    clipped = np.maximum(f_values, 0.0)

    # Full linear self-convolution, scaled so it approximates the integral.
    self_conv = spacing * np.convolve(clipped, clipped, mode='full')

    mass = spacing * np.sum(clipped)
    mass_sq = mass ** 2
    if mass_sq < 1e-12:
        # Essentially zero mass: report a large penalty instead of dividing.
        return 1e10

    peak = np.max(self_conv)
    return peak / mass_sq


def optimize_run(N, seed, adam_steps=100000, verbose=True):
    """Minimise the autoconvolution ratio C1 starting from one initial shape.

    The function is parameterised as ``relu(params)`` on ``N`` grid points with
    spacing ``0.5 / N``. Optimisation has two phases:

    1. Adam on a log-sum-exp smoothed maximum of the autoconvolution plus a
       flatness (variance) penalty whose weight decays to zero at half time.
    2. L-BFGS-B polishing (bounds ``>= 0``, no penalty) at three increasing
       smoothing temperatures.

    Args:
        N: Number of grid points.
        seed: Seeds NumPy's global RNG and selects the initial shape via
            ``seed % 10``.
        adam_steps: Number of Adam iterations.
        verbose: When true, print progress lines.

    Returns:
        ``(f_final, best_c1)``: the non-negative samples of the best function
        found (NumPy array of length ``N``) and its hard-max C1 value.
    """
    dx = 0.5 / N

    @jax.jit
    def get_f(params):
        return jax.nn.relu(params)  # ReLU allows exact zeros

    @jax.jit
    def compute_conv(f):
        # Zero-pad to length 2N so the circular FFT product has no wrap-around
        # and equals the linear autoconvolution.
        padded = jnp.zeros(2 * N)
        padded = padded.at[:N].set(f)
        fft_f = jnp.fft.rfft(padded)
        return jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx

    def conv_and_integral_sq(params):
        # Shared preamble of every objective: autoconvolution and the
        # (floored) squared integral used as the normaliser.
        f = get_f(params)
        conv = compute_conv(f)
        integral = jnp.sum(f) * dx
        return conv, jnp.maximum(integral, 1e-9) ** 2

    @jax.jit
    def objective_reg(params, temp, lam):
        """Smooth max + flatness regularization"""
        conv, integral_sq = conv_and_integral_sq(params)

        # Smooth max of convolution
        smooth_max = jax.nn.logsumexp(temp * conv) / temp

        # Flatness regularization: variance of the autoconvolution taken over
        # all 2N samples.
        conv_mean = jnp.sum(conv) / (2 * N)
        conv_var = jnp.sum((conv - conv_mean) ** 2) / (2 * N)

        c1 = smooth_max / integral_sq
        flatness_penalty = lam * conv_var / integral_sq ** 2

        return c1 + flatness_penalty

    @jax.jit
    def objective_hard(params):
        conv, integral_sq = conv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    @jax.jit
    def objective_smooth_only(params, temp):
        conv, integral_sq = conv_and_integral_sq(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    # Build the differentiated functions once; both are reused in hot loops.
    grad_reg = jax.jit(jax.grad(objective_reg))
    smooth_value_and_grad = jax.jit(jax.value_and_grad(objective_smooth_only))

    # Initialize with diverse shapes. All candidates are built eagerly so the
    # sequence of random draws for a given seed stays fixed.
    np.random.seed(seed)
    x = np.linspace(0, 1, N)

    inits = {
        0: np.ones(N) * 0.5 + 0.02 * np.random.randn(N),
        1: np.exp(-10 * (x - 0.5) ** 2) + 0.1,
        2: np.exp(-5 * (x - 0.3) ** 2) + 0.05,  # asymmetric
        3: np.exp(-5 * (x - 0.7) ** 2) + 0.05,  # asymmetric other way
        4: 0.3 + 0.7 * np.sin(np.pi * x) ** 2 + 0.02 * np.random.randn(N),
        5: np.where(x < 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        6: np.where(x > 0.5, 1.0, 0.3) + 0.02 * np.random.randn(N),
        7: np.exp(-30 * (x - 0.5) ** 2) + 0.01,  # sharp peak
        8: 0.5 * (1 + np.cos(4 * np.pi * x)) + 0.1 + 0.02 * np.random.randn(N),
        9: np.abs(np.random.randn(N)) * 0.3 + 0.1,
    }
    # seed % 10 is always one of the keys above, so index directly.
    init_f = np.maximum(inits[seed % 10], 0.01)
    params = jnp.array(init_f)

    # Adam optimization. optax's `decay_steps` is the TOTAL schedule length
    # including warm-up, so pass the full step count; clamp the warm-up so
    # short runs still leave a positive number of cosine-decay steps.
    warmup_steps = max(0, min(2000, adam_steps // 2))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=0.005, warmup_steps=warmup_steps,
        decay_steps=max(adam_steps, warmup_steps + 1), end_value=1e-6,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    best_c1 = float('inf')
    best_params = params

    for step in range(adam_steps):
        progress = step / adam_steps
        # Sharpen the smooth max from temp=100 towards temp=300.
        temp = 100.0 + progress * 200.0
        # Decrease flatness regularization over time (zero from half time on)
        lam = 0.1 * max(1.0 - progress * 2, 0.0)

        grads = grad_reg(params, temp, lam)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        # The hard objective is only sampled every 20000 steps and at the
        # final step; the best sampled iterate seeds the polishing phase.
        if step % 20000 == 0 or step == adam_steps - 1:
            hard_c1 = float(objective_hard(params))
            if verbose:
                print(f"  [{seed}] Step {step:6d} | C1={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params

    # L-BFGS polishing (no regularization)
    # NOTE(review): JAX computes in float32 unless x64 is enabled elsewhere,
    # in which case the tolerances below are tighter than the gradient
    # precision -- confirm whether jax_enable_x64 is intended.
    def scipy_obj(p, temp):
        # Single fused forward/backward pass; SciPy needs float64 outputs.
        val, grad = smooth_value_and_grad(jnp.array(p), temp)
        return float(val), np.array(grad, dtype=np.float64)

    params_np = np.array(best_params, dtype=np.float64)
    for temp_lbfgs in [1000.0, 5000.0, 20000.0]:
        result = scipy_minimize(
            scipy_obj, params_np, args=(temp_lbfgs,),
            method='L-BFGS-B', jac=True,
            bounds=[(0, None)] * N,  # Non-negativity constraint
            options={'maxiter': 5000, 'ftol': 1e-15, 'gtol': 1e-12},
        )
        # Each stage warm-starts the next, even if it did not improve C1.
        params_np = result.x
        f_opt = np.maximum(params_np, 0.0)
        c1 = compute_c1_numpy(f_opt, N)
        if verbose:
            print(f"  [{seed}] L-BFGS temp={temp_lbfgs:.0f}: C1={c1:.10f}")
        if c1 < best_c1:
            best_c1 = c1
            best_params = jnp.array(params_np)

    f_final = np.maximum(np.array(best_params), 0.0)
    return f_final, best_c1


def run():
    """Run ``optimize_run`` for seeds 0-9 on a 3000-point grid and keep the best.

    Each seed is optimised with 80000 Adam steps. Progress is printed per seed
    and whenever a new overall minimum is found.

    Returns:
        ``(best_f, best_c1, best_c1, N)`` -- the best function samples, its C1
        value (reported twice), and the grid size.
    """
    N = 3000
    best_c1, winner = float('inf'), None

    for seed in range(10):
        candidate, c1 = optimize_run(N, seed, adam_steps=80000)
        print(f"  Seed {seed}: C1={c1:.10f}")

        # Guard clause: only a strictly smaller C1 replaces the incumbent.
        if not (c1 < best_c1):
            continue

        best_c1, winner = c1, candidate
        print(f"  *** NEW GLOBAL BEST: C1={c1:.10f}")

    print(f"\nFinal best C1: {best_c1:.10f}")
    return winner, best_c1, best_c1, N