"""
Sequential / Always-Valid Testing.
Reference: Johari, Pekelis, Walsh (2015) "Always Valid Inference:
Bringing Sequential Analysis to A/B Testing." arXiv:1512.04922.
Traditional A/B testing requires a fixed sample size decided *before* the
experiment. If you "peek" at the data and stop early when p < α, the actual
false-positive rate inflates to well above α.
The mixture Sequential Probability Ratio Test (mSPRT) solves this by
maintaining a martingale M_t whose expected value under H0 is 1. We reject
when M_t ≥ 1/α — valid at *any* stopping time, with E[false positives] ≤ α.
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, asdict
from typing import List, Optional
@dataclass
class SequentialResult:
    """Outcome of one mSPRT run over a stream of observations."""

    # Number of observations that were processed.
    n_obs: int
    # Running mSPRT e-value, one entry per observation.
    e_values: List[float]
    # First sample size at which M_t >= 1/alpha (None if never reached).
    rejected_at: Optional[int]
    # Whether the rejection threshold was crossed.
    final_rejected: bool
    # Significance level the test was run at.
    alpha: float
    # Rejection threshold, equal to 1 / alpha.
    threshold: float
@dataclass
class PeekingSimulation:
    """Summary of a simulation of α-inflation caused by peeking."""

    # Number of simulated experiments.
    n_experiments: int
    # Sample sizes at which the data was inspected.
    peek_schedule: List[int]
    # False-positive rate of the traditional test when peeking.
    traditional_fpr: float
    # False-positive rate of mSPRT (expected to be about α).
    msprt_fpr: float
    # Nominal significance level.
    alpha: float
# ── Core mSPRT computation ─────────────────────────────────────────────────────
def _msprt_e_value(
    x_bar: float,  # sample mean of the X_i (each X_i = treatment - control obs)
    n: int,  # number of paired observations seen so far
    sigma: float,  # estimated SD of the per-observation differences
    rho: float,  # SD of the mixing prior (hyperparameter)
) -> float:
    """
    Evaluate the normal-model mSPRT e-value at sample size n.

        M_n = sqrt(sigma^2 / (sigma^2 + n*rho^2))
              * exp(n^2 * x_bar^2 * rho^2 / (2*sigma^2*(sigma^2 + n*rho^2)))

    Under H0 (true mean = 0): E[M_n] = 1 for all n.
    Under H1 (true mean ≠ 0): M_n → ∞ (detection guaranteed).

    Returns the e-value as a plain Python float.
    """
    var_obs = sigma ** 2
    var_prior = rho ** 2
    total_var = var_obs + n * var_prior

    # n * x_bar recovers the running sum of the differences.
    running_sum = n * x_bar

    shrinkage = np.sqrt(var_obs / total_var)
    exponent = running_sum ** 2 * var_prior / (2 * var_obs * total_var)
    return float(shrinkage * np.exp(exponent))
def sequential_test(
    diffs: List[float],  # per-observation treatment - control differences
    alpha: float = 0.05,
    rho_scale: float = 1.0,  # ρ = rho_scale * sigma
) -> SequentialResult:
    """
    Run mSPRT on a stream of paired differences.

    Parameters
    ----------
    diffs : list of (X_treatment_i - X_control_i) values
    alpha : significance level; the rejection threshold is 1 / alpha
    rho_scale : mixing prior width relative to sigma

    Returns
    -------
    SequentialResult
        Holds the running e-values (rounded to 4 decimals), the first sample
        size at which the e-value reached the threshold (None if it never
        did), and the alpha / threshold that were used.

    Notes
    -----
    sigma is estimated once from the *whole* series (population SD) and falls
    back to 1.0 when that estimate is not positive, so early e-values are
    computed with a scale that already uses later observations.
    An empty ``diffs`` yields a result with no e-values and no rejection.
    """
    obs = np.array(diffs, dtype=float)
    n_obs = len(obs)

    # Guard the empty stream explicitly: std() of an empty array is NaN and
    # emits a RuntimeWarning.
    sample_sd = float(obs.std()) if n_obs else 0.0
    sigma = sample_sd if sample_sd > 0 else 1.0
    rho = rho_scale * sigma
    threshold = 1.0 / alpha

    e_values: List[float] = []
    rejected_at: Optional[int] = None

    # Running sums give each prefix mean in O(1) instead of re-averaging
    # obs[:n] on every step (which made the loop quadratic).
    for n, running_sum in enumerate(np.cumsum(obs), start=1):
        ev = _msprt_e_value(running_sum / n, n, sigma, rho)
        e_values.append(round(ev, 4))
        # Only the first crossing is recorded; the unrounded value is compared.
        if rejected_at is None and ev >= threshold:
            rejected_at = n

    return SequentialResult(
        n_obs=n_obs,
        e_values=e_values,
        rejected_at=rejected_at,
        final_rejected=rejected_at is not None,
        alpha=alpha,
        threshold=threshold,
    )
# ── Confidence sequence ────────────────────────────────────────────────────────
def confidence_sequence(
    obs: List[float],
    alpha: float = 0.05,
    sigma: Optional[float] = None,
) -> tuple[List[float], List[float], List[float]]:
    """
    Anytime-valid confidence sequence for the mean.

    The half-width at sample size n is

        sigma * sqrt(2 * (n + v) / (n^2 * v) * log(sqrt((n + v) / v) / alpha))

    with v = 1, i.e. it shrinks roughly like sqrt(log(n) / n). It is wider
    than a fixed-n CI at small n (honest about uncertainty) but is intended to
    hold for ALL stopping rules simultaneously.

    Parameters
    ----------
    obs : observations, in arrival order
    alpha : miscoverage level
    sigma : SD of a single observation; when None it is estimated from the
        whole of ``obs`` (population SD), falling back to 1.0 if that
        estimate is not positive

    Returns
    -------
    (means, lower_bounds, upper_bounds)
        Three lists of plain Python floats, one entry per prefix of ``obs``,
        each rounded to 6 decimals. All three are empty for empty input.
    """
    data = np.array(obs, dtype=float)
    n_obs = len(data)
    if n_obs == 0:
        # Nothing to report; also avoids std() on an empty array.
        return [], [], []

    if sigma is None:
        sample_sd = float(data.std())
        sigma = sample_sd if sample_sd > 0 else 1.0

    # Normal-mixture boundary (the original code comment cites
    # Howard et al. 2021), with intrinsic time v = 1.
    v = 1.0
    n = np.arange(1, n_obs + 1, dtype=float)

    # Running sums give every prefix mean in one pass instead of
    # re-averaging obs[:n] for each n.
    running_means = np.cumsum(data) / n

    log_term = np.log(np.sqrt((n + v) / v) / alpha)
    # Clamp at 0 so a negative log term (alpha larger than sqrt((n+v)/v))
    # cannot reach the square root.
    inner = np.maximum((n + v) / (n ** 2 * v) * log_term, 0.0)
    half_widths = sigma * np.sqrt(2 * inner)

    means = [round(float(m), 6) for m in running_means]
    lowers = [round(float(lo), 6) for lo in running_means - half_widths]
    uppers = [round(float(hi), 6) for hi in running_means + half_widths]
    return means, lowers, uppers
# ── Peeking simulation ─────────────────────────────────────────────────────────
def simulate_peeking(
    n_total: int = 1000,
    n_experiments: int = 2000,
    peek_fractions: List[float] = None,
    alpha: float = 0.05,
    sigma: float = 1.0,
    seed: int = 42,
) -> PeekingSimulation:
    """
    Simulate the α-inflation caused by peeking under the null hypothesis.

    Every experiment draws n_total i.i.d. normal observations with mean 0 and
    evaluates two decision rules at each peek:

    1. Traditional: reject if ANY peek gives a two-sided z-test p < α
    2. mSPRT: reject when M_t ≥ 1/α at any peek

    Expected result (as stated by the original author):
    - Traditional FPR ≈ 2× to 3× the nominal α (due to multiple comparisons)
    - mSPRT FPR ≈ α (always-valid)

    The peek sizes are int(fraction * n_total); peeks of size 0 are skipped.
    The SD used by both rules is estimated from the full n_total draw
    (1.0 if that estimate is not positive), and the mixing SD equals it.
    Both rates are returned rounded to 4 decimals.
    """
    fractions = [0.25, 0.50, 0.75, 1.00] if peek_fractions is None else peek_fractions

    rng = np.random.default_rng(seed)
    peek_schedule = [int(frac * n_total) for frac in fractions]
    threshold = 1.0 / alpha

    # Zero-sized peeks carry no data; they stay in the reported schedule only.
    active_peeks = [size for size in peek_schedule if size != 0]

    trad_hits = 0
    msprt_hits = 0
    for _ in range(n_experiments):
        # One experiment under H0 (true mean = 0).
        sample = rng.normal(0.0, sigma, n_total)
        sd = sample.std()
        sigma_hat = float(sd) if sd > 0 else 1.0
        rho = sigma_hat

        trad_flag = False
        msprt_flag = False
        for size in active_peeks:
            prefix_mean = sample[:size].mean()

            # Traditional rule: two-sided z-test, sticky once it fires.
            if not trad_flag:
                z = prefix_mean / (sigma_hat / np.sqrt(size))
                p = 2 * (1 - _norm_cdf(abs(z)))
                trad_flag = p < alpha

            # mSPRT rule: e-value against 1/alpha, also sticky.
            if not msprt_flag:
                ev = _msprt_e_value(prefix_mean, size, sigma_hat, rho)
                msprt_flag = ev >= threshold

        if trad_flag:
            trad_hits += 1
        if msprt_flag:
            msprt_hits += 1

    return PeekingSimulation(
        n_experiments=n_experiments,
        peek_schedule=peek_schedule,
        traditional_fpr=round(trad_hits / n_experiments, 4),
        msprt_fpr=round(msprt_hits / n_experiments, 4),
        alpha=alpha,
    )
def simulate_detection_speed(
    true_effect: float,
    n_max: int = 500,
    n_experiments: int = 1000,
    sigma: float = 1.0,
    alpha: float = 0.05,
    seed: int = 42,
) -> dict:
    """
    Simulate how quickly mSPRT detects a true effect.

    Each experiment draws n_max normal observations with mean ``true_effect``
    and SD ``sigma``, then checks the mSPRT e-value after every observation
    and stops at the first one that reaches 1 / alpha. Only mSPRT is
    simulated here; no fixed-n test is run for comparison.

    Parameters
    ----------
    true_effect : mean of the simulated observations
    n_max : maximum number of observations per experiment
    n_experiments : number of simulated experiments
    sigma : SD of the observations; treated as known and also used as the
        mixing-prior SD
    alpha : significance level
    seed : seed for the NumPy random generator

    Returns
    -------
    dict with keys
        "true_effect", "alpha", "n_max", "n_experiments" : echoed inputs
        "power" : share of experiments that stopped within n_max (4 decimals)
        "median_stopping_time", "p25_stopping_time", "p75_stopping_time" :
            quantiles of the stopping times, truncated to int
        "stopping_times" : one stopping time per experiment

    Experiments that never cross the threshold are recorded with a stopping
    time of n_max, so the reported quantiles are censored at n_max.
    """
    rng = np.random.default_rng(seed)
    threshold = 1.0 / alpha

    # sigma is known in this simulation, so nothing is estimated per run.
    sigma_hat = sigma
    rho = sigma_hat

    stopping_times = []
    never_detected = 0
    for _ in range(n_experiments):
        obs = rng.normal(true_effect, sigma, n_max)

        # Running sums give each prefix mean in O(1) instead of re-averaging
        # obs[:n] at every step (quadratic in n_max per experiment).
        for n, running_sum in enumerate(np.cumsum(obs), start=1):
            if _msprt_e_value(running_sum / n, n, sigma_hat, rho) >= threshold:
                stopping_times.append(n)
                break
        else:
            # Threshold never reached: count the miss and censor at n_max.
            never_detected += 1
            stopping_times.append(n_max)

    arr = np.array(stopping_times)
    return {
        "true_effect": true_effect,
        "alpha": alpha,
        "n_max": n_max,
        "n_experiments": n_experiments,
        "power": round(1 - never_detected / n_experiments, 4),
        "median_stopping_time": int(np.median(arr)),
        "p25_stopping_time": int(np.percentile(arr, 25)),
        "p75_stopping_time": int(np.percentile(arr, 75)),
        "stopping_times": arr.tolist(),
    }
def _norm_cdf(x: float) -> float:
    """Standard normal CDF via the complementary error function (no scipy needed)."""
    import math

    # Phi(x) = erfc(-x / sqrt(2)) / 2
    scaled = -x / math.sqrt(2)
    return 0.5 * math.erfc(scaled)
|