Spaces:
Running
Running
| """ | |
Sequential / Always-Valid Testing.

Reference: Johari, Pekelis, Walsh (2015) "Always Valid Inference:
Bringing Sequential Analysis to A/B Testing." arXiv:1512.04922.

Traditional A/B testing requires a fixed sample size decided *before* the
experiment. If you "peek" at the data and stop early when p < α, the actual
false-positive rate inflates to well above α.

The mixture Sequential Probability Ratio Test (mSPRT) solves this by
maintaining a martingale M_t whose expected value under H0 is 1. We reject
when M_t ≥ 1/α — valid at *any* stopping time, with E[false positives] ≤ α.
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| from dataclasses import dataclass, asdict | |
| from typing import List, Optional | |
@dataclass
class SequentialResult:
    """Outcome of running the mSPRT over a stream of paired differences.

    Instances are built with keyword arguments by ``sequential_test``, so the
    class must be a dataclass (a plain class with bare annotations has no
    generated ``__init__`` and would raise ``TypeError`` on construction).
    """

    n_obs: int                  # number of observations processed
    e_values: List[float]       # running mSPRT e-value after each observation
    rejected_at: Optional[int]  # first n at which M_n >= 1/alpha (None if never)
    final_rejected: bool        # True when the threshold was crossed at some n
    alpha: float                # significance level the test was run at
    threshold: float            # rejection threshold, = 1 / alpha
@dataclass
class PeekingSimulation:
    """Results of simulating alpha-inflation from peeking.

    Declared as a dataclass so it can be constructed with keyword arguments
    and converted with ``dataclasses.asdict``.
    """

    n_experiments: int        # number of simulated experiments
    peek_schedule: List[int]  # sample sizes where we peek
    traditional_fpr: float    # false positive rate of the z-test with peeking
    msprt_fpr: float          # false positive rate with mSPRT (expected ~ alpha)
    alpha: float              # nominal significance level
# ── Core mSPRT computation ─────────────────────────────────────────────────────
def _msprt_e_value(
    x_bar: float,  # sample mean of X_i (each X_i = treatment - control obs)
    n: int,        # number of paired observations seen so far
    sigma: float,  # estimated SD of the per-observation differences
    rho: float,    # mixing prior SD (hyperparameter; rho=sigma gives a diffuse prior)
) -> float:
    """
    Evaluate the normal-model mSPRT e-value after n observations.

        M_n = sqrt(sigma^2 / (sigma^2 + n*rho^2))
              * exp(n^2 * x_bar^2 * rho^2 / (2*sigma^2*(sigma^2 + n*rho^2)))

    Under H0 (true mean = 0): E[M_n] = 1 for all n.
    Under H1 (true mean != 0): M_n -> infinity (detection guaranteed).

    Returns the statistic as a plain Python float.
    """
    var_obs = sigma ** 2
    var_prior = rho ** 2
    total_var = var_obs + n * var_prior

    # n * x_bar is the running sum of the differences.
    running_sum = n * x_bar

    shrinkage = np.sqrt(var_obs / total_var)
    exponent = running_sum ** 2 * var_prior / (2 * var_obs * total_var)
    return float(shrinkage * np.exp(exponent))
def sequential_test(
    diffs: List[float],      # per-observation treatment - control differences
    alpha: float = 0.05,
    rho_scale: float = 1.0,  # rho = rho_scale * sigma
) -> SequentialResult:
    """
    Run mSPRT on a stream of paired differences.

    Parameters
    ----------
    diffs : list of (X_treatment_i - X_control_i) values
    alpha : significance level; the rejection threshold is 1 / alpha
    rho_scale : mixing prior width relative to sigma

    Returns
    -------
    SequentialResult
        Holds the e-value after every observation (rounded to 4 decimals)
        and the first sample size at which the unrounded e-value reached
        the threshold (``None`` if it never did).

    Notes
    -----
    sigma is estimated once from the *entire* series (population SD) and
    falls back to 1.0 when that estimate is not positive (empty or constant
    input).
    """
    obs = np.array(diffs, dtype=float)
    n_obs = len(obs)
    threshold = 1.0 / alpha

    # Compute the SD once; skip it for empty input so NumPy does not warn.
    spread = float(obs.std()) if n_obs > 0 else 0.0
    sigma = spread if spread > 0 else 1.0
    rho = rho_scale * sigma

    # Prefix means via a cumulative sum: O(n) overall instead of re-averaging
    # obs[:n] on every iteration (O(n^2)).
    running_means = np.cumsum(obs) / np.arange(1, n_obs + 1)

    e_values: List[float] = []
    rejected_at: Optional[int] = None
    for n, x_bar in enumerate(running_means, start=1):
        ev = _msprt_e_value(x_bar, n, sigma, rho)
        e_values.append(round(ev, 4))
        # Record only the first crossing, but keep accumulating e-values.
        if rejected_at is None and ev >= threshold:
            rejected_at = n

    return SequentialResult(
        n_obs=n_obs,
        e_values=e_values,
        rejected_at=rejected_at,
        final_rejected=rejected_at is not None,
        alpha=alpha,
        threshold=threshold,
    )
# ── Confidence sequence ────────────────────────────────────────────────────────
def confidence_sequence(
    obs: List[float],
    alpha: float = 0.05,
    sigma: Optional[float] = None,
) -> tuple[List[float], List[float], List[float]]:
    """
    Anytime-valid confidence sequence for the mean.

    Uses a normal-mixture boundary (cf. Howard et al., 2021) with mixing
    parameter v = 1. The half-width at sample size n is

        sigma * sqrt(2*(n+v)/(n^2*v) * log(sqrt((n+v)/v) / alpha))

    which shrinks like sqrt(log(n) / n): wider than a fixed-n CI at small n,
    but intended to hold for ALL stopping rules simultaneously.

    Parameters
    ----------
    obs : observations, in arrival order
    alpha : miscoverage level
    sigma : SD of the observations; when None it is estimated from the full
        sample (population SD), falling back to 1.0 if that is not positive

    Returns
    -------
    (means, lower_bounds, upper_bounds), one entry per prefix of ``obs``,
    each rounded to 6 decimals. Empty input yields three empty lists.
    """
    data = np.array(obs, dtype=float)
    n_obs = len(data)
    if n_obs == 0:
        # Nothing to bound; also avoids NumPy's empty-slice warning in std().
        return [], [], []

    if sigma is None:
        spread = float(data.std())
        sigma = spread if spread > 0 else 1.0

    v = 1.0  # mixing parameter of the normal-mixture boundary
    counts = np.arange(1, n_obs + 1, dtype=float)

    # Prefix means in O(n) via a cumulative sum instead of re-averaging
    # data[:n] for every n.
    running_means = np.cumsum(data) / counts

    # Clamp at 0 so a negative log term (very large alpha) cannot reach sqrt.
    inner = np.maximum(
        (counts + v) / (counts ** 2 * v) * np.log(np.sqrt((counts + v) / v) / alpha),
        0.0,
    )
    half_widths = sigma * np.sqrt(2 * inner)

    means = [round(float(m), 6) for m in running_means]
    lowers = [round(float(lo), 6) for lo in running_means - half_widths]
    uppers = [round(float(hi), 6) for hi in running_means + half_widths]
    return means, lowers, uppers
# ── Peeking simulation ─────────────────────────────────────────────────────────
def simulate_peeking(
    n_total: int = 1000,
    n_experiments: int = 2000,
    peek_fractions: Optional[List[float]] = None,
    alpha: float = 0.05,
    sigma: float = 1.0,
    seed: int = 42,
) -> PeekingSimulation:
    """
    Simulate the alpha-inflation caused by peeking under the null hypothesis.

    Under H0 (no treatment effect), both methods are run on the same data:
      1. Traditional: reject if ANY peek gives a two-sided z-test p < alpha
      2. mSPRT: reject when M_t >= 1/alpha at any peek

    Expected result:
      - Traditional FPR ~ 2x to 3x the nominal alpha (multiple comparisons)
      - mSPRT FPR ~ alpha (always-valid)

    Parameters
    ----------
    n_total : observations generated per experiment
    n_experiments : number of simulated experiments
    peek_fractions : fractions of n_total at which to peek
        (default 25%, 50%, 75%, 100%)
    alpha : nominal significance level
    sigma : true SD of the simulated observations
    seed : seed for the NumPy random generator

    Returns
    -------
    PeekingSimulation with the peek sample sizes and both rejection rates
    (rounded to 4 decimals).
    """
    if peek_fractions is None:
        peek_fractions = [0.25, 0.50, 0.75, 1.00]
    rng = np.random.default_rng(seed)

    # round() rather than int(): f * n_total can land just below an integer
    # (0.29 * 100 == 28.999999999999996), which int() would truncate to 28.
    peek_schedule = [int(round(f * n_total)) for f in peek_fractions]
    threshold = 1.0 / alpha

    traditional_rejects = 0
    msprt_rejects = 0
    for _ in range(n_experiments):
        # Generate n_total i.i.d. observations under H0 (true mean = 0)
        obs = rng.normal(0.0, sigma, n_total)
        spread = float(obs.std())
        sigma_hat = spread if spread > 0 else 1.0
        rho = sigma_hat

        trad_rejected = False
        msprt_rejected = False
        for peek_n in peek_schedule:
            if peek_n <= 0:
                continue
            window = obs[:peek_n]
            # Use the number of observations actually in the window, so a
            # peek size beyond n_total does not overstate the sample size.
            n_seen = len(window)
            if n_seen == 0:
                continue
            window_mean = window.mean()

            # Traditional: two-sided z-test
            if not trad_rejected:
                z = window_mean / (sigma_hat / np.sqrt(n_seen))
                p = 2 * (1 - _norm_cdf(abs(z)))
                if p < alpha:
                    trad_rejected = True

            # mSPRT
            if not msprt_rejected:
                ev = _msprt_e_value(window_mean, n_seen, sigma_hat, rho)
                if ev >= threshold:
                    msprt_rejected = True

            # Both outcomes are final once set; later peeks cannot change them.
            if trad_rejected and msprt_rejected:
                break

        if trad_rejected:
            traditional_rejects += 1
        if msprt_rejected:
            msprt_rejects += 1

    return PeekingSimulation(
        n_experiments=n_experiments,
        peek_schedule=peek_schedule,
        traditional_fpr=round(traditional_rejects / n_experiments, 4),
        msprt_fpr=round(msprt_rejects / n_experiments, 4),
        alpha=alpha,
    )
def simulate_detection_speed(
    true_effect: float,
    n_max: int = 500,
    n_experiments: int = 1000,
    sigma: float = 1.0,
    alpha: float = 0.05,
    seed: int = 42,
) -> dict:
    """
    Simulate how quickly mSPRT detects a true effect.

    Each experiment draws n_max normal observations with mean ``true_effect``
    and SD ``sigma``, then records the first sample size at which the mSPRT
    e-value reaches 1 / alpha. sigma is treated as known and the mixing prior
    SD is set equal to it.

    Returns
    -------
    dict with the inputs, the empirical ``power`` (share of experiments that
    crossed the threshold within n_max), the median / 25th / 75th percentile
    stopping time, and the full list of stopping times. Experiments that never
    cross the threshold are recorded with a stopping time of n_max, so the
    summary statistics include those censored values.
    """
    rng = np.random.default_rng(seed)
    threshold = 1.0 / alpha
    rho = sigma  # sigma is known in this simulation; prior SD matches it
    counts = np.arange(1, n_max + 1)

    stopping_times = []
    never_detected = 0
    for _ in range(n_experiments):
        obs = rng.normal(true_effect, sigma, n_max)
        # Prefix means in O(n_max) via a cumulative sum rather than
        # re-averaging obs[:n] for every n.
        running_means = np.cumsum(obs) / counts
        for n, x_bar in enumerate(running_means, start=1):
            if _msprt_e_value(x_bar, n, sigma, rho) >= threshold:
                stopping_times.append(n)
                break
        else:
            # Threshold never reached: censor the stopping time at n_max.
            never_detected += 1
            stopping_times.append(n_max)

    arr = np.array(stopping_times)
    return {
        "true_effect": true_effect,
        "alpha": alpha,
        "n_max": n_max,
        "n_experiments": n_experiments,
        "power": round(1 - never_detected / n_experiments, 4),
        "median_stopping_time": int(np.median(arr)),
        "p25_stopping_time": int(np.percentile(arr, 25)),
        "p75_stopping_time": int(np.percentile(arr, 75)),
        "stopping_times": arr.tolist(),
    }
def _norm_cdf(x: float) -> float:
    """Return the standard normal CDF at x, using math.erfc (no scipy)."""
    import math

    scaled = -x / math.sqrt(2)
    return 0.5 * math.erfc(scaled)