mbochniak01
Replace HHEM with sentence-level NLI, add claim decomposition and drift detection
ffbf46f
Raw
History Blame Contribute Delete
6.34 kB
"""
Drift detection: compare live grader score distributions against the golden-dataset baseline.

Answers: has answer quality shifted since the reference was established?
Catches: model updates, KB staleness, query distribution shift, threshold miscalibration.

Statistical test: two-sample Kolmogorov-Smirnov (scipy.stats.ks_2samp), run per metric
(Evidently's DataDriftPreset default for numerical columns on small samples).
  - H0: current and reference scores are drawn from the same distribution
  - H1: the distributions differ
  - A metric is flagged as drifted if p_value < alpha (default 0.05)

Reference: grader scores of the golden-dataset expected_answer texts (known-good baseline).
Current:   metric scores held in the in-memory telemetry._events buffer.
           NOTE(review): that buffer belongs to the process running this script. Unless the
           telemetry module loads events from storage shared with the API server, it will be
           empty when drift.py is run on its own — TODO confirm.

Usage (from the repository root):
    .venv/bin/python eval/drift.py
"""
import sys
from dataclasses import dataclass
from pathlib import Path
import yaml
from scipy.stats import ks_2samp
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
from grader import (
grade_answer_relevancy,
grade_chain_terminology,
grade_faithfulness_decomposed,
grade_pii_leakage,
grade_token_budget,
)
_EVAL_DIR = Path(__file__).parent

# Golden dataset of question / expected_answer / domain / client pairs used as the baseline.
DATASET_PATH = _EVAL_DIR / "golden-dataset.yaml"
# Per-domain knowledge bases live at <KNOWLEDGE_ROOT>/<domain>/features.yaml.
KNOWLEDGE_ROOT = _EVAL_DIR.parent / "knowledge"

# Grader metrics compared between the reference and current samples (also the report order).
METRICS = [
    "faithfulness",
    "answer_relevancy",
    "pii_leakage",
    "token_budget",
    "chain_terminology",
]
# KS-test significance level: a metric is flagged as drifted when p_value < ALPHA.
ALPHA = 0.05
# A metric is only tested once the current sample has at least this many scores.
MIN_CURRENT_SAMPLES = 5
@dataclass(slots=True)
class MetricDrift:
metric: str
ks_statistic: float
p_value: float
drifted: bool
ref_mean: float
cur_mean: float
ref_n: int
cur_n: int
def _load_kb_context(domain: str) -> str:
    """Return the knowledge base of *domain* as a single grader context string.

    Reads ``KNOWLEDGE_ROOT/<domain>/features.yaml`` and renders each entry of its
    ``documents`` list as a ``[title]`` line followed by the stripped content.
    Entries are separated by blank lines.
    """
    path = KNOWLEDGE_ROOT / domain / "features.yaml"
    # Explicit encoding: read_text() otherwise uses the locale encoding (e.g. cp1252
    # on Windows), which mis-decodes non-ASCII UTF-8 knowledge-base text.
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    return "\n\n".join(
        f"[{doc['title']}]\n{doc['content'].strip()}" for doc in data["documents"]
    )
# Per-metric score samples: metric name -> list of scores, one entry per scored response.
Scores = dict[str, list[float]]


def build_reference() -> Scores:
    """Score every golden-dataset ``expected_answer`` with all graders.

    The knowledge-base context of each domain is loaded once and reused for every
    pair in that domain. Returns one score list per metric in METRICS. The lists
    all have the same length: one entry per golden pair, in dataset order.
    """
    # Explicit encoding: read_text() otherwise depends on the platform locale.
    pairs = yaml.safe_load(DATASET_PATH.read_text(encoding="utf-8"))["pairs"]
    kb_by_domain: dict[str, str] = {}
    scores: Scores = {m: [] for m in METRICS}
    for pair in pairs:
        response = pair["expected_answer"].strip()
        domain = pair["domain"]
        if domain not in kb_by_domain:
            kb_by_domain[domain] = _load_kb_context(domain)
        context = kb_by_domain[domain]
        scores["pii_leakage"].append(grade_pii_leakage(response).score)
        scores["token_budget"].append(grade_token_budget(response).score)
        scores["answer_relevancy"].append(grade_answer_relevancy(pair["question"], response).score)
        scores["faithfulness"].append(grade_faithfulness_decomposed(response, context).score)
        scores["chain_terminology"].append(grade_chain_terminology(response, pair["client"]).score)
    return scores
def build_current() -> Scores:
    """Collect per-metric scores from the in-memory telemetry buffer.

    Only events that carry a non-None value for *every* metric in METRICS are used,
    so all returned lists have the same length and stay row-aligned. Events whose
    ``metrics`` entry is missing, ``None`` or empty are skipped.

    NOTE(review): this reads ``telemetry._events`` of the *current* process. When
    drift.py runs as a standalone script, the buffer is only populated if the
    telemetry module loads events from storage shared with the API — TODO confirm.
    """
    import telemetry  # presumably in backend/, which is put on sys.path at module import

    # Copy the buffer while holding telemetry's lock, then work on the snapshot.
    with telemetry._lock:
        events = list(telemetry._events)

    scores: Scores = {m: [] for m in METRICS}
    for event in events:
        metrics = event.get("metrics")
        # A None/empty metrics mapping previously crashed on .get(); skip it like a missing one.
        if not metrics or any(metrics.get(m) is None for m in METRICS):
            continue
        for m in METRICS:
            scores[m].append(float(metrics[m]))
    return scores
def detect_drift(
    current: Scores,
    reference: Scores,
    alpha: float = ALPHA,
    min_samples: int = MIN_CURRENT_SAMPLES,
) -> list[MetricDrift]:
    """Run a two-sample KS test per metric, comparing current scores to the reference.

    Args:
        current: live (telemetry) scores per metric.
        reference: golden-baseline scores per metric.
        alpha: significance level; a metric is drifted when its p-value < alpha.
        min_samples: metrics with fewer current samples than this are skipped
            (default MIN_CURRENT_SAMPLES). Metrics with no reference samples are
            always skipped.

    Returns:
        One MetricDrift per tested metric, in METRICS order. The stored statistic,
        p-value and means are rounded to 4 decimals. ``drifted`` is decided on the
        unrounded p-value.
    """
    import numpy as np  # imported once per call rather than once per metric

    results: list[MetricDrift] = []
    for metric in METRICS:
        ref_col = reference.get(metric, [])
        cur_col = current.get(metric, [])
        # The KS test needs a non-empty reference, and a tiny current sample gives it little power.
        if len(cur_col) < min_samples or len(ref_col) == 0:
            continue
        ref_arr = np.array(ref_col, dtype=float)
        cur_arr = np.array(cur_col, dtype=float)
        stat, pval = ks_2samp(ref_arr, cur_arr)
        results.append(
            MetricDrift(
                metric=metric,
                ks_statistic=round(float(stat), 4),
                p_value=round(float(pval), 4),
                drifted=bool(pval < alpha),
                ref_mean=round(float(ref_arr.mean()), 4),
                cur_mean=round(float(cur_arr.mean()), 4),
                ref_n=len(ref_arr),
                cur_n=len(cur_arr),
            )
        )
    return results
def report_drift(results: list[MetricDrift], alpha: float = ALPHA) -> None:
    """Print a per-metric drift table followed by a summary of the drifted metrics.

    Args:
        results: output of ``detect_drift``.
        alpha: significance level used for the test; only echoed in the summary.
    """
    header = (
        f"{'metric':<22} {'ks_stat':>7} {'p_value':>7} {'status':>10}"
        f" {'ref_mean':>8} {'cur_mean':>8} {'delta':>7}"
    )
    print(header)
    print("-" * len(header))
    for r in results:
        status = "DRIFT <--" if r.drifted else "ok"
        delta = r.cur_mean - r.ref_mean
        # The "+" flag always prints the sign, so |delta| < 10 fills the 7-wide column.
        print(
            f"{r.metric:<22} {r.ks_statistic:>7.4f} {r.p_value:>7.4f} {status:>10}"
            f" {r.ref_mean:>8.4f} {r.cur_mean:>8.4f} {delta:>+7.4f}"
        )
    drifted = [r for r in results if r.drifted]
    print(f"\nOverall: {len(drifted)}/{len(results)} metrics drifted (alpha={alpha})")
    if drifted:
        print("\nDrifted metrics:")
        for r in drifted:
            # NOTE(review): assumes a higher score is better for every metric
            # (including pii_leakage) — TODO confirm the grader conventions.
            if r.cur_mean < r.ref_mean:
                direction = "degraded"
            elif r.cur_mean > r.ref_mean:
                direction = "improved"
            else:
                # KS can flag a change in distribution shape even when the means match.
                direction = "shifted"
            print(f" {r.metric}: {direction} ({r.ref_mean:.3f} → {r.cur_mean:.3f})")
def run() -> None:
    """CLI entry point: build both score distributions, then test and report drift.

    If fewer than MIN_CURRENT_SAMPLES telemetry events are available, summary
    statistics of the reference distribution are printed instead of running the
    KS test.
    """
    print("\nBuilding reference distribution from golden-dataset.yaml...")
    reference = build_reference()
    # Every metric list has one entry per scored pair, so the first list gives the count.
    ref_n = len(next(iter(reference.values()), []))
    print(f"Reference: {ref_n} pairs\n")

    current = build_current()
    cur_n = len(next(iter(current.values()), []))
    if cur_n < MIN_CURRENT_SAMPLES:
        import numpy as np

        # NOTE(review): build_current reads a process-local buffer, so re-running this
        # script only sees API queries if telemetry persists events — TODO confirm.
        print(
            f"Current: {cur_n} telemetry event(s) — need ≥{MIN_CURRENT_SAMPLES} to run KS test.\n"
            f"Start the API and run some queries, then re-run drift.py.\n\n"
            f"Reference distribution (golden baseline):\n"
        )
        for m in METRICS:
            vals = np.array(reference[m])
            if vals.size == 0:
                # min()/max() raise on zero-size arrays (e.g. a golden dataset with no pairs).
                print(f" {m:<22} (no reference samples)")
                continue
            print(f" {m:<22} mean={vals.mean():.3f} std={vals.std():.3f} min={vals.min():.3f} max={vals.max():.3f}")
        return

    print(f"Current: {cur_n} telemetry events\n")
    results = detect_drift(current, reference)
    if not results:
        print("No metrics had enough data for KS test.\n")
        return
    report_drift(results)
    print()
if __name__ == "__main__":
    # Run the drift check only when executed as a script, not when imported.
    run()