# ai-response-validator/tests/unit/test_drift.py
# Author: mbochniak01
# Commit ffbf46f: "Replace HHEM with sentence-level NLI, add claim decomposition and drift detection"
"""
Unit tests for drift detection — detect_drift() only.
No model loading, no IO, no telemetry.
"""
import sys
from pathlib import Path
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "eval"))
from drift import ALPHA, MIN_CURRENT_SAMPLES, MetricDrift, detect_drift
METRICS = ["faithfulness", "answer_relevancy", "pii_leakage", "token_budget", "chain_terminology"]
def _scores(n: int, **col_values: list[float]) -> dict[str, list[float]]:
"""Build a Scores dict with fixed values per column; defaults to 0.9 for others."""
data: dict[str, list[float]] = {}
for metric in METRICS:
data[metric] = col_values.get(metric, [0.9] * n)
return data
class TestDetectDrift:
    """Behavioural tests for detect_drift() on synthetic score tables."""

    # Current-sample count guaranteed to clear the MIN_CURRENT_SAMPLES gate.
    # Tests that iterate over results, or expect a specific number of them,
    # must use at least this many samples. Otherwise detect_drift() returns []
    # and loops over the results would pass vacuously.
    SAMPLES = max(MIN_CURRENT_SAMPLES, 30)

    @staticmethod
    def _get(results: list[MetricDrift], metric: str) -> MetricDrift:
        """Return the result for ``metric``, failing with a clear message if absent."""
        matches = [r for r in results if r.metric == metric]
        assert matches, f"no drift result for {metric!r}; got {[r.metric for r in results]}"
        return matches[0]

    def test_identical_distributions_no_drift(self) -> None:
        rng = np.random.default_rng(42)
        scores = rng.uniform(0.5, 1.0, 50).tolist()
        ref = _scores(50, faithfulness=scores)
        cur = _scores(50, faithfulness=scores)
        faith = self._get(detect_drift(cur, ref), "faithfulness")
        assert faith.drifted is False

    def test_shifted_distribution_detected(self) -> None:
        ref = _scores(50, faithfulness=[0.9] * 50)
        cur = _scores(50, faithfulness=[0.1] * 50)
        faith = self._get(detect_drift(cur, ref), "faithfulness")
        assert faith.drifted is True
        assert faith.p_value < ALPHA

    def test_below_min_samples_excluded(self) -> None:
        ref = _scores(50)
        cur = _scores(MIN_CURRENT_SAMPLES - 1)
        assert detect_drift(cur, ref) == []

    def test_exactly_min_samples_included(self) -> None:
        ref = _scores(50)
        cur = _scores(MIN_CURRENT_SAMPLES)
        assert len(detect_drift(cur, ref)) == len(METRICS)

    def test_ks_statistic_in_range(self) -> None:
        ref = _scores(50, faithfulness=[0.9] * 50)
        cur = _scores(50, faithfulness=[0.1] * 50)
        faith = self._get(detect_drift(cur, ref), "faithfulness")
        assert 0.0 <= faith.ks_statistic <= 1.0

    def test_means_computed_correctly(self) -> None:
        n = self.SAMPLES
        ref = _scores(n, faithfulness=[0.8] * n)
        cur = _scores(n, faithfulness=[0.4] * n)
        faith = self._get(detect_drift(cur, ref), "faithfulness")
        assert faith.ref_mean == pytest.approx(0.8, abs=1e-3)
        assert faith.cur_mean == pytest.approx(0.4, abs=1e-3)

    def test_all_metrics_returned(self) -> None:
        ref = _scores(self.SAMPLES)
        cur = _scores(self.SAMPLES)
        result_names = {r.metric for r in detect_drift(cur, ref)}
        assert result_names == set(METRICS)

    def test_result_is_metric_drift_dataclass(self) -> None:
        ref = _scores(self.SAMPLES)
        cur = _scores(self.SAMPLES)
        results = detect_drift(cur, ref)
        # Guard against the loop below passing trivially on an empty result.
        assert results, "expected drift results for every metric"
        for r in results:
            assert isinstance(r, MetricDrift)
            assert isinstance(r.drifted, bool)
            assert isinstance(r.ks_statistic, float)
            assert isinstance(r.p_value, float)

    def test_custom_alpha_respected(self) -> None:
        # Two evenly spaced 50-point samples of equal width, offset by 20% of
        # that width. Their empirical CDFs differ by at most 10/50, so the KS
        # statistic is 0.2, and the p-value should land well inside
        # (0.001, 0.999). With disjoint samples, p would be ~0, and both alphas
        # would flag drift even if alpha were ignored.
        base = np.linspace(0.5, 1.0, 50)
        ref = _scores(50, faithfulness=base.tolist())
        cur = _scores(50, faithfulness=(base - 0.1).tolist())
        strict = self._get(detect_drift(cur, ref, alpha=0.001), "faithfulness")
        lenient = self._get(detect_drift(cur, ref, alpha=0.999), "faithfulness")
        # Precondition: the p-value lies strictly between the two alphas, so
        # the two calls can only disagree if alpha is actually honoured.
        assert 0.001 < strict.p_value < 0.999
        assert strict.drifted is False
        assert lenient.drifted is True

    def test_missing_metric_column_skipped(self) -> None:
        n = self.SAMPLES
        ref: dict[str, list[float]] = {"faithfulness": [0.9] * n}
        cur: dict[str, list[float]] = {"faithfulness": [0.4] * n}
        results = detect_drift(cur, ref)
        assert [r.metric for r in results] == ["faithfulness"]

    def test_empty_reference_skipped(self) -> None:
        # Use enough current samples that the metric is excluded because the
        # reference is empty, not because of the MIN_CURRENT_SAMPLES gate.
        ref: dict[str, list[float]] = {"faithfulness": []}
        cur: dict[str, list[float]] = {"faithfulness": [0.4] * self.SAMPLES}
        assert detect_drift(cur, ref) == []

    def test_sample_counts_in_result(self) -> None:
        cur_n = self.SAMPLES
        ref_n = cur_n + 20  # differs from cur_n so swapped fields are caught
        results = detect_drift(_scores(cur_n), _scores(ref_n))
        assert len(results) == len(METRICS)
        for r in results:
            assert r.ref_n == ref_n
            assert r.cur_n == cur_n