mbochniak01
Replace HHEM with sentence-level NLI, add claim decomposition and drift detection
ffbf46f | """ | |
| Unit tests for drift detection — detect_drift() only. | |
| No model loading, no IO, no telemetry. | |
| """ | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pytest | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "eval")) | |
| from drift import ALPHA, MIN_CURRENT_SAMPLES, MetricDrift, detect_drift | |
METRICS = ["faithfulness", "answer_relevancy", "pii_leakage", "token_budget", "chain_terminology"]


def _scores(n: int, **col_values: list[float]) -> dict[str, list[float]]:
    """Build a Scores dict containing one column per metric in METRICS.

    Columns named in ``col_values`` use the given values. Every other metric
    is filled with ``n`` copies of 0.9.

    Args:
        n: Number of samples per column.
        **col_values: Explicit per-metric sample lists, keyed by metric name.

    Returns:
        A fresh dict mapping each metric name, in METRICS order, to its own
        list of scores.

    Raises:
        TypeError: If a keyword does not name a metric in METRICS. Without
            this check a typo would silently fall back to the default column
            and the test would pass vacuously.
        ValueError: If an explicit column does not hold exactly ``n`` values.
    """
    unknown = sorted(set(col_values) - set(METRICS))
    if unknown:
        raise TypeError(f"unknown metric column(s): {', '.join(unknown)}")
    for metric, values in col_values.items():
        if len(values) != n:
            raise ValueError(f"column {metric!r} has {len(values)} values, expected {n}")
    # Copy each column so ref/cur dicts built from the same list never alias.
    return {metric: list(col_values.get(metric, [0.9] * n)) for metric in METRICS}
class TestDetectDrift:
    """Behavioural tests for ``detect_drift`` on hand-built score dicts."""

    @staticmethod
    def _result_for(results: list[MetricDrift], metric: str) -> MetricDrift:
        """Return the first result for ``metric``, failing clearly if it is absent."""
        matches = [r for r in results if r.metric == metric]
        assert matches, f"no drift result for {metric!r}; got {[r.metric for r in results]}"
        return matches[0]

    def test_identical_distributions_no_drift(self) -> None:
        rng = np.random.default_rng(42)
        scores = rng.uniform(0.5, 1.0, 50).tolist()
        ref = _scores(50, faithfulness=scores)
        cur = _scores(50, faithfulness=scores)
        faith = self._result_for(detect_drift(cur, ref), "faithfulness")
        assert faith.drifted is False

    def test_shifted_distribution_detected(self) -> None:
        ref = _scores(50, faithfulness=[0.9] * 50)
        cur = _scores(50, faithfulness=[0.1] * 50)
        faith = self._result_for(detect_drift(cur, ref), "faithfulness")
        assert faith.drifted is True
        assert faith.p_value < ALPHA

    def test_below_min_samples_excluded(self) -> None:
        ref = _scores(50)
        cur = _scores(MIN_CURRENT_SAMPLES - 1)
        assert detect_drift(cur, ref) == []

    def test_exactly_min_samples_included(self) -> None:
        ref = _scores(50)
        cur = _scores(MIN_CURRENT_SAMPLES)
        assert len(detect_drift(cur, ref)) == len(METRICS)

    def test_ks_statistic_in_range(self) -> None:
        ref = _scores(50, faithfulness=[0.9] * 50)
        cur = _scores(50, faithfulness=[0.1] * 50)
        faith = self._result_for(detect_drift(cur, ref), "faithfulness")
        assert 0.0 <= faith.ks_statistic <= 1.0

    def test_means_computed_correctly(self) -> None:
        ref = _scores(10, faithfulness=[0.8] * 10)
        cur = _scores(10, faithfulness=[0.4] * 10)
        faith = self._result_for(detect_drift(cur, ref), "faithfulness")
        assert faith.ref_mean == pytest.approx(0.8, abs=1e-3)
        assert faith.cur_mean == pytest.approx(0.4, abs=1e-3)

    def test_all_metrics_returned(self) -> None:
        ref = _scores(30)
        cur = _scores(30)
        result_names = {r.metric for r in detect_drift(cur, ref)}
        assert result_names == set(METRICS)

    def test_result_is_metric_drift_dataclass(self) -> None:
        ref = _scores(20)
        cur = _scores(20)
        results = detect_drift(cur, ref)
        # Guard: without this, an empty result list would make the loop vacuous.
        assert len(results) == len(METRICS)
        for r in results:
            assert isinstance(r, MetricDrift)
            assert isinstance(r.drifted, bool)
            assert isinstance(r.ks_statistic, float)
            assert isinstance(r.p_value, float)

    def test_custom_alpha_respected(self) -> None:
        # ref is constant 0.9. cur moves 15 of its 50 samples to 0.1, so the KS
        # statistic is D = 0.3 and the p-value is roughly 0.02 for n = m = 50
        # (asymptotic KS estimate). That lies strictly between the two alphas
        # below, so the verdict must flip when only alpha changes. Disjoint
        # samples would drift under both alphas and prove nothing.
        ref = _scores(50, faithfulness=[0.9] * 50)
        cur = _scores(50, faithfulness=[0.9] * 35 + [0.1] * 15)
        faith_strict = self._result_for(detect_drift(cur, ref, alpha=0.001), "faithfulness")
        faith_lenient = self._result_for(detect_drift(cur, ref, alpha=0.5), "faithfulness")
        # Precondition: the p-value really sits between the two thresholds.
        assert 0.001 < faith_strict.p_value < 0.5
        assert faith_strict.drifted is False
        assert faith_lenient.drifted is True

    def test_missing_metric_column_skipped(self) -> None:
        ref: dict[str, list[float]] = {"faithfulness": [0.9] * 20}
        cur: dict[str, list[float]] = {"faithfulness": [0.4] * 20}
        results = detect_drift(cur, ref)
        assert all(r.metric == "faithfulness" for r in results)
        assert len(results) == 1

    def test_empty_reference_skipped(self) -> None:
        ref: dict[str, list[float]] = {"faithfulness": []}
        cur: dict[str, list[float]] = {"faithfulness": [0.4] * 20}
        assert detect_drift(cur, ref) == []

    def test_sample_counts_in_result(self) -> None:
        ref = _scores(30)
        cur = _scores(10)
        results = detect_drift(cur, ref)
        # Guard: without this, an empty result list would make the loop vacuous.
        assert len(results) == len(METRICS)
        for r in results:
            assert r.ref_n == 30
            assert r.cur_n == 10