File size: 8,973 Bytes
4256820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
Sequential / Always-Valid Testing.

Reference: Johari, Pekelis, Walsh (2015) "Always Valid Inference:
Bringing Sequential Analysis to A/B Testing." arXiv:1512.04922.

Traditional A/B testing requires a fixed sample size decided *before* the
experiment. If you "peek" at the data and stop early when p < α, the actual
false-positive rate inflates to well above α.

The mixture Sequential Probability Ratio Test (mSPRT) solves this by
maintaining a martingale M_t whose expected value under H0 is 1. We reject
when M_t ≥ 1/α — valid at *any* stopping time, with P(false positive) ≤ α.
"""
from __future__ import annotations

import numpy as np
from dataclasses import dataclass, asdict
from typing import List, Optional


@dataclass
class SequentialResult:
    """Outcome of running the mSPRT over a stream of paired differences."""
    n_obs: int                      # number of observations processed
    e_values: List[float]           # running mSPRT e-value, one entry per observation
    rejected_at: Optional[int]      # first time M_t ≥ 1/α  (None if never)
    final_rejected: bool            # True iff the threshold was crossed at any time
    alpha: float                    # significance level the test was run at
    threshold: float                # = 1 / alpha


@dataclass
class PeekingSimulation:
    """Results of simulating α-inflation from peeking."""
    n_experiments: int              # number of simulated experiments
    peek_schedule: List[int]        # sample sizes where we peek
    traditional_fpr: float          # false positive rate with peeking
    msprt_fpr: float                # false positive rate with mSPRT (≈ α)
    alpha: float                    # nominal significance level


# ── Core mSPRT computation ─────────────────────────────────────────────────────

def _msprt_e_value(
    x_bar: float,
    n: int,
    sigma: float,
    rho: float,
) -> float:
    """
    Evaluate the normal-model mSPRT statistic after ``n`` observations.

    Writing S = n * x_bar for the running sum of differences:

        M_n = sqrt(sigma^2 / (sigma^2 + n*rho^2))
              * exp(S^2 * rho^2 / (2 * sigma^2 * (sigma^2 + n*rho^2)))

    Parameters
    ----------
    x_bar : sample mean of the X_i (each X_i = treatment - control observation)
    n     : number of paired observations seen so far
    sigma : SD of the per-observation differences
    rho   : SD of the normal mixing prior (hyperparameter)

    Returns
    -------
    The e-value M_n as a plain Python float.

    Under H0 (true mean = 0):  E[M_n] = 1  for all n.
    Under H1 (true mean ≠ 0):  M_n → ∞  (detection guaranteed).
    """
    var_obs = sigma ** 2
    var_prior = rho ** 2
    running_sum = n * x_bar
    pooled = var_obs + n * var_prior

    # Shrinkage term (always <= 1) times the evidence term (always >= 1).
    shrinkage = np.sqrt(var_obs / pooled)
    evidence = np.exp(running_sum ** 2 * var_prior / (2 * var_obs * pooled))

    return float(shrinkage * evidence)


def sequential_test(
    diffs: List[float],   # per-observation treatment - control differences
    alpha: float = 0.05,
    rho_scale: float = 1.0,   # ρ = rho_scale * sigma
) -> SequentialResult:
    """
    Run mSPRT on a stream of paired differences.

    Parameters
    ----------
    diffs      : list of (X_treatment_i - X_control_i) values
    alpha      : significance level; the rejection threshold is 1 / alpha
    rho_scale  : mixing prior width relative to sigma

    Returns
    -------
    SequentialResult holding the running e-values (rounded to 4 decimals),
    the first 1-based index at which the e-value reached the threshold
    (None if it never did), and the parameters used.

    Notes
    -----
    sigma is estimated once from *all* of ``diffs`` (population SD) and then
    reused at every step; it falls back to 1.0 when the SD is zero or
    undefined (empty or constant input). An empty ``diffs`` yields a result
    with no e-values and no rejection.
    """
    obs = np.asarray(diffs, dtype=float)
    n_obs = len(obs)

    # Skip std() on empty input: it only emits RuntimeWarnings and returns NaN.
    spread = float(obs.std()) if n_obs else 0.0
    sigma = spread if spread > 0 else 1.0
    rho = rho_scale * sigma
    threshold = 1.0 / alpha

    # Running means in a single O(n) pass instead of re-averaging every prefix.
    running_means = np.cumsum(obs) / np.arange(1, n_obs + 1)

    e_values: List[float] = []
    rejected_at: Optional[int] = None

    for n, x_bar in enumerate(running_means, start=1):
        ev = _msprt_e_value(x_bar, n, sigma, rho)
        e_values.append(round(ev, 4))

        # Compare the unrounded e-value so rounding cannot flip the decision.
        if rejected_at is None and ev >= threshold:
            rejected_at = n

    return SequentialResult(
        n_obs=n_obs,
        e_values=e_values,
        rejected_at=rejected_at,
        final_rejected=rejected_at is not None,
        alpha=alpha,
        threshold=threshold,
    )


# ── Confidence sequence ────────────────────────────────────────────────────────

def confidence_sequence(
    obs: List[float],
    alpha: float = 0.05,
    sigma: Optional[float] = None,
) -> tuple[List[float], List[float], List[float]]:
    """
    Anytime-valid confidence sequence for the mean.

    Uses the two-sided normal-mixture boundary (cf. Howard et al., 2021) with
    mixing parameter v = 1. The half-width at sample size n is

        hw_n = sigma * sqrt(2 * (n + v) / (n^2 * v)
                            * log(sqrt((n + v) / v) / alpha))

    which shrinks like sqrt(log(n) / n): wider than a fixed-n CI at every n,
    but valid for ALL stopping rules simultaneously.

    Parameters
    ----------
    obs    : observations, in arrival order
    alpha  : miscoverage level of the sequence
    sigma  : SD of a single observation; when None it is estimated from all
             of ``obs`` (population SD), falling back to 1.0 if that is zero
             or undefined

    Returns
    -------
    (means, lower_bounds, upper_bounds): three lists of plain floats, one
    entry per prefix of ``obs``, each rounded to 6 decimals. Empty input
    yields three empty lists.
    """
    data = np.asarray(obs, dtype=float)
    count = len(data)

    if sigma is None:
        # Skip std() on empty input: it only emits RuntimeWarnings and returns NaN.
        spread = float(data.std()) if count else 0.0
        sigma = spread if spread > 0 else 1.0

    if count == 0:
        return [], [], []

    v = 1.0  # mixing parameter of the normal-mixture boundary
    n = np.arange(1, count + 1, dtype=float)

    # Running means in a single O(n) pass instead of re-averaging every prefix.
    running_means = np.cumsum(data) / n

    # Clamp at 0 so an alpha larger than sqrt((n + v) / v) cannot produce a
    # negative radicand.
    inner = np.maximum(
        (n + v) / (n ** 2 * v) * np.log(np.sqrt((n + v) / v) / alpha),
        0.0,
    )
    half_widths = sigma * np.sqrt(2 * inner)

    # .tolist() converts to built-in floats so all three lists share one type.
    means = [round(m, 6) for m in running_means.tolist()]
    lowers = [round(b, 6) for b in (running_means - half_widths).tolist()]
    uppers = [round(b, 6) for b in (running_means + half_widths).tolist()]

    return means, lowers, uppers


# ── Peeking simulation ────────────────────────────────────────────────────────

def simulate_peeking(
    n_total: int = 1000,
    n_experiments: int = 2000,
    peek_fractions: List[float] = None,
    alpha: float = 0.05,
    sigma: float = 1.0,
    seed: int = 42,
) -> PeekingSimulation:
    """
    Measure the α-inflation that peeking causes when H0 is true.

    Each simulated experiment draws ``n_total`` N(0, sigma) observations and
    is inspected at the sample sizes ``int(f * n_total)`` for every f in
    ``peek_fractions`` (default: 25%, 50%, 75%, 100%). Two decision rules
    are applied to the same data:

      1. Traditional: reject if ANY peek gives a two-sided z-test p < α
      2. mSPRT:       reject when M_t ≥ 1/α at any peek

    The SD is estimated once per experiment from the full sample (1.0 if it
    is zero) and reused at every peek; the mSPRT mixing SD is set equal to it.

    Expected result:
      - Traditional FPR ≈ 2× to 3× the nominal α (due to multiple comparisons)
      - mSPRT FPR ≈ α (always-valid)

    Returns a PeekingSimulation with both false-positive rates rounded to
    4 decimals.
    """
    fractions = [0.25, 0.50, 0.75, 1.00] if peek_fractions is None else peek_fractions

    rng = np.random.default_rng(seed)
    peek_schedule = [int(frac * n_total) for frac in fractions]
    critical_value = 1.0 / alpha

    naive_hits = 0
    msprt_hits = 0

    for _ in range(n_experiments):
        # One null experiment: true mean is 0.
        sample = rng.normal(0.0, sigma, n_total)
        spread = sample.std()
        sigma_hat = float(spread) if spread > 0 else 1.0

        naive_flag = False
        msprt_flag = False

        for size in peek_schedule:
            if naive_flag and msprt_flag:
                break  # both rules already fired; later peeks change nothing
            if size == 0:
                continue
            prefix_mean = sample[:size].mean()

            if not naive_flag:
                # Two-sided z-test at this peek.
                z = prefix_mean / (sigma_hat / np.sqrt(size))
                p_value = 2 * (1 - _norm_cdf(abs(z)))
                naive_flag = p_value < alpha

            if not msprt_flag:
                e_value = _msprt_e_value(prefix_mean, size, sigma_hat, sigma_hat)
                msprt_flag = e_value >= critical_value

        naive_hits += 1 if naive_flag else 0
        msprt_hits += 1 if msprt_flag else 0

    return PeekingSimulation(
        n_experiments=n_experiments,
        peek_schedule=peek_schedule,
        traditional_fpr=round(naive_hits / n_experiments, 4),
        msprt_fpr=round(msprt_hits / n_experiments, 4),
        alpha=alpha,
    )


def simulate_detection_speed(
    true_effect: float,
    n_max: int = 500,
    n_experiments: int = 1000,
    sigma: float = 1.0,
    alpha: float = 0.05,
    seed: int = 42,
) -> dict:
    """
    Simulate how quickly mSPRT detects a true effect.

    Each experiment draws ``n_max`` N(true_effect, sigma) observations and
    records the first sample size at which the mSPRT e-value reaches
    1 / alpha. sigma is treated as known and the mixing SD is set equal
    to it.

    Parameters
    ----------
    true_effect    : mean of the simulated observations
    n_max          : maximum number of observations per experiment
    n_experiments  : number of simulated experiments (must be > 0)
    sigma          : SD of the simulated observations
    alpha          : significance level; the threshold is 1 / alpha
    seed           : seed for the NumPy random generator

    Returns
    -------
    dict with the inputs plus:
      - "power": fraction of experiments that rejected within ``n_max``
      - "median_stopping_time", "p25_stopping_time", "p75_stopping_time"
      - "stopping_times": one entry per experiment

    Experiments that never reject are recorded with stopping time ``n_max``
    (censored), so the quantiles are lower bounds whenever power < 1.
    """
    rng = np.random.default_rng(seed)
    threshold = 1.0 / alpha
    rho = sigma  # sigma is known in this simulation; the prior SD matches it
    sample_sizes = np.arange(1, n_max + 1)

    stopping_times = []
    never_detected = 0

    for _ in range(n_experiments):
        obs = rng.normal(true_effect, sigma, n_max)
        # Running means in one O(n) pass instead of re-averaging every prefix.
        running_means = np.cumsum(obs) / sample_sizes

        # First n whose e-value crosses the threshold, or None if none does.
        stop = next(
            (
                n
                for n, x_bar in enumerate(running_means, start=1)
                if _msprt_e_value(x_bar, n, sigma, rho) >= threshold
            ),
            None,
        )
        if stop is None:
            never_detected += 1
            stop = n_max  # censored: the experiment ran out of data
        stopping_times.append(stop)

    arr = np.array(stopping_times)
    return {
        "true_effect": true_effect,
        "alpha": alpha,
        "n_max": n_max,
        "n_experiments": n_experiments,
        "power": round(1 - never_detected / n_experiments, 4),
        "median_stopping_time": int(np.median(arr)),
        "p25_stopping_time": int(np.percentile(arr, 25)),
        "p75_stopping_time": int(np.percentile(arr, 75)),
        "stopping_times": arr.tolist(),
    }


def _norm_cdf(x: float) -> float:
    """Return Φ(x), the standard normal CDF, using only the stdlib.

    Relies on the identity Φ(x) = erfc(-x / √2) / 2 so that scipy is not
    needed inside tight simulation loops.
    """
    import math

    scaled = -x / math.sqrt(2)
    return 0.5 * math.erfc(scaled)