sae-gemma / tests /test_ablation_utils.py
senator1's picture
Sparse-feature audit of induction in Gemma-2-2B (full project)
253d988
"""
tests/test_ablation_utils.py
Tests for bootstrap_ci (pure NumPy β€” no model, no SAE).
"""
import sys
from pathlib import Path
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from sae_gemma.ablations import bootstrap_ci
# ── Deterministic seed guarantees (bootstrap uses seed=0 internally) ─────────
def test_all_correct_accuracy_1():
"""All-correct array: mean accuracy == 1.0 and CI == [1.0, 1.0]."""
correct = np.ones(100, dtype=bool)
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=200)
assert mean == pytest.approx(1.0)
assert lo == pytest.approx(1.0)
assert hi == pytest.approx(1.0)
def test_all_wrong_accuracy_0():
"""All-wrong array: mean accuracy == 0.0 and CI == [0.0, 0.0]."""
correct = np.zeros(100, dtype=bool)
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=200)
assert mean == pytest.approx(0.0)
assert lo == pytest.approx(0.0)
assert hi == pytest.approx(0.0)
def test_mixed_mean_close_to_true_proportion():
"""For a 70%-correct array the returned mean should be close to 0.7."""
rng = np.random.default_rng(123)
correct = rng.random(500) < 0.7
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=500)
assert abs(mean - 0.7) < 0.05, f"mean={mean:.3f} too far from 0.7"
def test_ci_contains_mean():
"""CI bounds must satisfy lo <= mean <= hi for any input."""
rng = np.random.default_rng(7)
for _ in range(10):
n = rng.integers(50, 300)
p = rng.uniform(0.1, 0.9)
correct = rng.random(n) < p
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=300)
assert lo <= mean + 1e-9, f"lo={lo} > mean={mean}"
assert hi >= mean - 1e-9, f"hi={hi} < mean={mean}"
def test_ci_lo_leq_hi():
"""The lower CI bound is always <= upper CI bound."""
correct = np.array([True, False, True, True, False])
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=500)
assert lo <= hi
def test_return_types_are_python_float():
"""bootstrap_ci should return plain Python floats, not np.float64."""
correct = np.array([True, True, False])
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100)
assert isinstance(mean, float)
assert isinstance(lo, float)
assert isinstance(hi, float)
def test_ci_width_shrinks_with_more_data():
"""Wider CI for small n, narrower CI for large n (same true proportion ~0.5)."""
rng = np.random.default_rng(42)
p = 0.5
small = rng.random(30) < p
large = rng.random(3000) < p
_, lo_s, hi_s = bootstrap_ci(small, n_bootstrap=500)
_, lo_l, hi_l = bootstrap_ci(large, n_bootstrap=500)
width_small = hi_s - lo_s
width_large = hi_l - lo_l
assert width_small > width_large, (
f"Expected small-n CI ({width_small:.3f}) wider than large-n CI ({width_large:.3f})"
)
def test_n_bootstrap_affects_reproducibility():
"""bootstrap_ci is reproducible: same input produces same output on repeated calls."""
correct = np.array([True, False, True, True, False, True])
result1 = bootstrap_ci(correct, n_bootstrap=300)
result2 = bootstrap_ci(correct, n_bootstrap=300)
assert result1 == result2, "bootstrap_ci is not reproducible with seed=0"
def test_single_element_all_correct():
"""Edge case: single True element."""
correct = np.array([True])
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100)
assert mean == pytest.approx(1.0)
assert lo == pytest.approx(1.0)
assert hi == pytest.approx(1.0)
def test_single_element_all_wrong():
"""Edge case: single False element."""
correct = np.array([False])
mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100)
assert mean == pytest.approx(0.0)
assert lo == pytest.approx(0.0)
assert hi == pytest.approx(0.0)
def test_alpha_affects_ci_width():
"""Smaller alpha (wider interval) should produce a wider or equal CI than larger alpha."""
rng = np.random.default_rng(99)
correct = rng.random(200) < 0.6
_, lo_wide, hi_wide = bootstrap_ci(correct, n_bootstrap=500, alpha=0.01) # 99% CI
_, lo_narrow, hi_narrow = bootstrap_ci(correct, n_bootstrap=500, alpha=0.10) # 90% CI
assert (hi_wide - lo_wide) >= (hi_narrow - lo_narrow) - 1e-9