File size: 5,603 Bytes
2facf1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as scipy_minimize
import optax


def compute_c1_numpy(f_values, n_points):
    """Compute C1 using numpy (for verification).

    ``f_values`` are samples of f on a grid with spacing ``dx = 0.5 / n_points``.
    Negative samples are clipped to zero. The result is

        max(autoconv) / (sum(f) * dx) ** 2

    where ``autoconv`` is the full discrete autoconvolution of the clipped
    samples, scaled by ``dx``.

    Args:
        f_values: 1-D array-like of samples. Promoted to float64 so the check
            does not inherit float32 rounding from a JAX caller.
        n_points: grid size used to derive ``dx``.

    Returns:
        The C1 ratio, or the sentinel ``1e10`` when the squared integral is
        below ``1e-12`` (this includes empty or all-non-positive input).
    """
    dx = 0.5 / n_points
    f_nn = np.maximum(np.asarray(f_values, dtype=np.float64), 0.0)
    integral_sq = (np.sum(f_nn) * dx) ** 2
    # Test the degenerate case before convolving: it skips the O(N^2) work,
    # and np.convolve rejects empty input outright.
    if integral_sq < 1e-12:
        return 1e10
    autoconv = np.convolve(f_nn, f_nn, mode='full') * dx
    return np.max(autoconv) / integral_sq


def make_objective_jax(N, dx):
    """Create JAX objective functions for C1 minimization.

    Both objectives take unconstrained ``params`` of shape ``(N,)``. They map
    them to a strictly positive profile ``f = exp(params)``, so no clipping is
    needed. They then form the linear autoconvolution of ``f`` with a
    zero-padded FFT of length ``2 * N``.

    Args:
        N: number of grid samples (required length of ``params``).
        dx: grid spacing used for the Riemann sums.

    Returns:
        ``(objective, objective_smooth, grad_fn)``:

        * ``objective(params)`` is the hard ratio
          ``max(conv) / (sum(f) * dx) ** 2``.
        * ``objective_smooth(params, temp)`` replaces the max with the
          log-sum-exp soft maximum ``logsumexp(temp * conv) / temp``. This is
          an upper bound on the max that tightens as ``temp`` grows.
        * ``grad_fn(params, temp)`` is the jitted gradient of
          ``objective_smooth`` with respect to ``params``.

    Raises:
        ValueError: (at trace time) if ``params`` is not of shape ``(N,)``.

    Note:
        The hard ratio is invariant to rescaling ``f``; the smooth one is not.
        ``temp`` multiplies the unnormalized autoconvolution, so the effective
        sharpness of the soft max also depends on the overall scale of ``f``.
    """

    def _autoconv_and_integral_sq(params):
        # Shared pipeline for both objectives.
        if params.shape != (N,):
            raise ValueError(f"expected params of shape ({N},), got {params.shape}")
        # exp parameterization keeps f strictly positive.
        f = jnp.exp(params)
        # rfft with n=2*N zero-pads f, so the circular product of the spectra
        # equals the linear autoconvolution (2*N - 1 terms plus one ~0 entry).
        fft_f = jnp.fft.rfft(f, n=2 * N)
        conv = jnp.fft.irfft(fft_f * fft_f, n=2 * N) * dx
        integral = jnp.sum(f) * dx
        return conv, integral ** 2

    @jax.jit
    def objective(params):
        conv, integral_sq = _autoconv_and_integral_sq(params)
        return jnp.max(conv) / integral_sq

    @jax.jit
    def objective_smooth(params, temp):
        conv, integral_sq = _autoconv_and_integral_sq(params)
        smooth_max = jax.nn.logsumexp(temp * conv) / temp
        return smooth_max / integral_sq

    grad_fn = jax.jit(jax.grad(objective_smooth))

    return objective, objective_smooth, grad_fn


def _initial_profile(x, seed):
    """Return the positive starting profile for restart ``seed`` on grid ``x``.

    Seeds 0-3 pick fixed hand-chosen shapes. Any other seed draws a random
    positive profile from ``RandomState(seed)``. That stream is identical to
    ``np.random.seed(seed); np.random.randn(...)`` but does not touch NumPy's
    global RNG.
    """
    n = x.shape[0]
    if seed == 0:
        return np.exp(-10 * (x - 0.5) ** 2) + 0.1
    if seed == 1:
        return np.ones(n)
    if seed == 2:
        return 0.5 * (1 + np.cos(2 * np.pi * (x - 0.5))) + 0.1
    if seed == 3:
        # Step function: higher in middle
        return np.where((x > 0.2) & (x < 0.8), 1.5, 0.5)
    return np.abs(np.random.RandomState(seed).randn(n)) * 0.3 + 0.2


def _make_adam_step(objective_smooth, optimizer):
    """Return a jitted step: smooth loss + gradient + Adam update in one call.

    Building this once per grid size avoids re-applying the
    ``jax.value_and_grad`` transformation eagerly on every step. ``temp`` is
    passed as a traced scalar, so changing it does not trigger recompilation.
    """

    @jax.jit
    def adam_step(params, opt_state, temp):
        loss_val, grads = jax.value_and_grad(objective_smooth)(params, temp)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss_val

    return adam_step


def _adam_phase(params, objective, adam_step, optimizer, adam_steps, log_every):
    """Phase 1: Adam on the smooth objective with a rising temperature.

    The hard C1 is sampled every ``log_every`` steps and once after the last
    step. Returns ``(best_c1, best_params)`` over those samples.
    """
    params_jax = jnp.array(params)
    opt_state = optimizer.init(params_jax)
    best_c1 = float('inf')
    best_params = params_jax

    for step in range(adam_steps):
        # Linear temperature ramp from 50 to 200 over the run (capped at 200).
        temp = min(50.0 + step * 150.0 / adam_steps, 200.0)
        # loss_val is the smooth objective at the params *before* this update.
        params_jax, opt_state, loss_val = adam_step(params_jax, opt_state, temp)

        if step % log_every == 0:
            hard_c1 = float(objective(params_jax))
            print(f"  Step {step:5d} | C1(smooth)={float(loss_val):.8f} | C1(hard)={hard_c1:.8f}")
            if hard_c1 < best_c1:
                best_c1 = hard_c1
                best_params = params_jax

    hard_c1 = float(objective(params_jax))
    if hard_c1 < best_c1:
        best_c1 = hard_c1
        best_params = params_jax
    return best_c1, best_params


def _lbfgs_phase(params, value_and_grad, n, best_c1, best_params, temps, maxiter):
    """Phase 2: L-BFGS-B on the smooth objective at increasing temperatures.

    Each temperature starts from the previous stage's result. A stage replaces
    ``(best_c1, best_params)`` only if its float64 NumPy C1 is lower. Winning
    params are kept as float64 NumPy arrays, so the final evaluation is not
    rounded through float32.
    """
    params_np = np.asarray(params, dtype=np.float64)
    for temp in temps:
        def scipy_obj(p, temp=temp):
            # One fused forward/backward pass; SciPy works in float64.
            val, g = value_and_grad(jnp.asarray(p), temp)
            return float(val), np.asarray(g, dtype=np.float64)

        result = scipy_minimize(
            scipy_obj,
            params_np,
            method='L-BFGS-B',
            jac=True,
            options={'maxiter': maxiter, 'ftol': 1e-15, 'gtol': 1e-10},
        )
        params_np = result.x
        c1 = compute_c1_numpy(np.exp(params_np), n)
        print(f"  temp={temp:.0f}: C1={c1:.10f}")

        if c1 < best_c1:
            best_c1 = c1
            best_params = params_np
    return best_c1, best_params


def run(
    *,
    n_values=(1000, 2000, 3000),
    n_seeds=5,
    adam_steps=50000,
    lbfgs_temps=(500.0, 1000.0, 2000.0),
    lbfgs_maxiter=2000,
):
    """Search for a low-C1 profile over several grid sizes and restarts.

    For each grid size ``N`` and each seed, the function:

    1. builds an initial profile (see ``_initial_profile``);
    2. runs Adam on the smooth objective (Phase 1);
    3. refines the result with L-BFGS-B at higher temperatures (Phase 2);
    4. evaluates the best parameters with ``compute_c1_numpy`` in float64.

    The keyword-only arguments default to the original hard-coded sweep.

    Args:
        n_values: grid sizes to try.
        n_seeds: restarts per grid size (seeds ``0 .. n_seeds - 1``).
        adam_steps: Phase 1 iterations per restart.
        lbfgs_temps: Phase 2 soft-max temperatures, applied in order.
        lbfgs_maxiter: L-BFGS-B ``maxiter`` per temperature.

    Returns:
        ``(best_f, best_c1, best_c1, best_n)``. If nothing ran, this is
        ``(None, inf, inf, None)``.
    """
    best_c1_overall = float('inf')
    best_f_overall = None
    best_n_overall = None

    # The schedule depends only on the step count; opt_state.count is reset by
    # optimizer.init for each restart, so one optimizer serves every run.
    warmup_steps = max(1, adam_steps // 25)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=0.01,
        warmup_steps=warmup_steps,
        # optax's decay_steps is the total schedule length *including* the
        # warmup, so this makes the cosine decay end on the last Adam step.
        decay_steps=max(adam_steps, warmup_steps + 1),
        end_value=1e-5,
    )
    optimizer = optax.adam(learning_rate=lr_schedule)
    log_every = max(1, adam_steps // 10)

    for N in n_values:
        dx = 0.5 / N
        objective, objective_smooth, _ = make_objective_jax(N, dx)
        value_and_grad = jax.jit(jax.value_and_grad(objective_smooth))
        adam_step = _make_adam_step(objective_smooth, optimizer)
        x = np.linspace(0, 1, N)

        for seed in range(n_seeds):
            print(f"\n--- N={N}, seed={seed} ---")
            init = _initial_profile(x, seed)
            params = np.log(np.maximum(init, 1e-6))

            # Phase 1: Adam optimization with smooth max (JAX)
            print("Phase 1: Adam optimization...")
            best_c1_run, best_params_run = _adam_phase(
                params, objective, adam_step, optimizer, adam_steps, log_every
            )

            # Phase 2: L-BFGS-B refinement with high temperature smooth max
            print(f"Phase 2: L-BFGS-B refinement (starting from C1={best_c1_run:.8f})...")
            best_c1_run, best_params_run = _lbfgs_phase(
                best_params_run, value_and_grad, N,
                best_c1_run, best_params_run, lbfgs_temps, lbfgs_maxiter,
            )

            # Final evaluation, in float64.
            f_final = np.exp(np.asarray(best_params_run, dtype=np.float64))
            c1_final = compute_c1_numpy(f_final, N)
            print(f"  Final C1 for this run: {c1_final:.10f}")

            if c1_final < best_c1_overall:
                best_c1_overall = c1_final
                best_f_overall = f_final
                best_n_overall = N
                print(f"*** GLOBAL BEST: C1 = {c1_final:.10f}")

    print(f"\n=== Final best C1: {best_c1_overall:.10f} ===")
    # NOTE(review): best_c1_overall appears twice, exactly as in the original
    # 4-tuple; presumably a caller unpacks four values. Confirm before changing.
    return best_f_overall, best_c1_overall, best_c1_overall, best_n_overall