| """ |
| tests/test_ablation_utils.py |
| Tests for bootstrap_ci (pure NumPy β no model, no SAE). |
| """ |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) |
|
|
| from sae_gemma.ablations import bootstrap_ci |
|
|
|
|
| |
|
|
|
|
| def test_all_correct_accuracy_1(): |
| """All-correct array: mean accuracy == 1.0 and CI == [1.0, 1.0].""" |
| correct = np.ones(100, dtype=bool) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=200) |
| assert mean == pytest.approx(1.0) |
| assert lo == pytest.approx(1.0) |
| assert hi == pytest.approx(1.0) |
|
|
|
|
| def test_all_wrong_accuracy_0(): |
| """All-wrong array: mean accuracy == 0.0 and CI == [0.0, 0.0].""" |
| correct = np.zeros(100, dtype=bool) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=200) |
| assert mean == pytest.approx(0.0) |
| assert lo == pytest.approx(0.0) |
| assert hi == pytest.approx(0.0) |
|
|
|
|
| def test_mixed_mean_close_to_true_proportion(): |
| """For a 70%-correct array the returned mean should be close to 0.7.""" |
| rng = np.random.default_rng(123) |
| correct = rng.random(500) < 0.7 |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=500) |
| assert abs(mean - 0.7) < 0.05, f"mean={mean:.3f} too far from 0.7" |
|
|
|
|
| def test_ci_contains_mean(): |
| """CI bounds must satisfy lo <= mean <= hi for any input.""" |
| rng = np.random.default_rng(7) |
| for _ in range(10): |
| n = rng.integers(50, 300) |
| p = rng.uniform(0.1, 0.9) |
| correct = rng.random(n) < p |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=300) |
| assert lo <= mean + 1e-9, f"lo={lo} > mean={mean}" |
| assert hi >= mean - 1e-9, f"hi={hi} < mean={mean}" |
|
|
|
|
| def test_ci_lo_leq_hi(): |
| """The lower CI bound is always <= upper CI bound.""" |
| correct = np.array([True, False, True, True, False]) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=500) |
| assert lo <= hi |
|
|
|
|
| def test_return_types_are_python_float(): |
| """bootstrap_ci should return plain Python floats, not np.float64.""" |
| correct = np.array([True, True, False]) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100) |
| assert isinstance(mean, float) |
| assert isinstance(lo, float) |
| assert isinstance(hi, float) |
|
|
|
|
| def test_ci_width_shrinks_with_more_data(): |
| """Wider CI for small n, narrower CI for large n (same true proportion ~0.5).""" |
| rng = np.random.default_rng(42) |
| p = 0.5 |
| small = rng.random(30) < p |
| large = rng.random(3000) < p |
|
|
| _, lo_s, hi_s = bootstrap_ci(small, n_bootstrap=500) |
| _, lo_l, hi_l = bootstrap_ci(large, n_bootstrap=500) |
|
|
| width_small = hi_s - lo_s |
| width_large = hi_l - lo_l |
| assert width_small > width_large, ( |
| f"Expected small-n CI ({width_small:.3f}) wider than large-n CI ({width_large:.3f})" |
| ) |
|
|
|
|
| def test_n_bootstrap_affects_reproducibility(): |
| """bootstrap_ci is reproducible: same input produces same output on repeated calls.""" |
| correct = np.array([True, False, True, True, False, True]) |
| result1 = bootstrap_ci(correct, n_bootstrap=300) |
| result2 = bootstrap_ci(correct, n_bootstrap=300) |
| assert result1 == result2, "bootstrap_ci is not reproducible with seed=0" |
|
|
|
|
| def test_single_element_all_correct(): |
| """Edge case: single True element.""" |
| correct = np.array([True]) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100) |
| assert mean == pytest.approx(1.0) |
| assert lo == pytest.approx(1.0) |
| assert hi == pytest.approx(1.0) |
|
|
|
|
| def test_single_element_all_wrong(): |
| """Edge case: single False element.""" |
| correct = np.array([False]) |
| mean, lo, hi = bootstrap_ci(correct, n_bootstrap=100) |
| assert mean == pytest.approx(0.0) |
| assert lo == pytest.approx(0.0) |
| assert hi == pytest.approx(0.0) |
|
|
|
|
| def test_alpha_affects_ci_width(): |
| """Smaller alpha (wider interval) should produce a wider or equal CI than larger alpha.""" |
| rng = np.random.default_rng(99) |
| correct = rng.random(200) < 0.6 |
| _, lo_wide, hi_wide = bootstrap_ci(correct, n_bootstrap=500, alpha=0.01) |
| _, lo_narrow, hi_narrow = bootstrap_ci(correct, n_bootstrap=500, alpha=0.10) |
| assert (hi_wide - lo_wide) >= (hi_narrow - lo_narrow) - 1e-9 |
|
|