""" Evaluation harness for the rag-psych retrieval + generation pipeline. Loads `eval/test_queries.yaml`, runs each query end-to-end through `retrieve_hybrid` + `generate`, scores it against the hand-labeled expectations, and prints a markdown table to stdout. A full JSON dump (per-query and aggregate) is also written to `eval/results/{ISO timestamp}.json` so we can diff runs over time as the pipeline evolves. Metrics: - source_routing_top1 — did the rank-1 chunk match an expected source? - source_recall_top5 — fraction of top-5 from any expected source - keyword_recall — fraction of expected_keywords that appear at least once in the top-5 chunk_text (case-insensitive) - refusal_correct — for off_topic queries, did the system refuse? (refusal is either retrieve_hybrid → [] or the generator returning the canonical refusal string) - citation_validity — fraction of cited chunk_ids that are in the retrieved set (1.0 means no hallucinated citations) - negation_held — for queries with `negation:`, none of the forbidden patterns appear in top-5 chunk_text - retrieval_ms / generation_ms / total_ms — wall-clock per query Aggregates: - means of all numeric per-query metrics - off_topic refusal rate (must be 100% to pass) - any non-empty `invalid_cited_ids` is a hallucination flag Run: .venv/bin/python eval/run_eval.py """ from __future__ import annotations import json import os import re import sys import time from datetime import datetime, timezone from pathlib import Path from statistics import mean from typing import Any import psycopg import yaml from dotenv import load_dotenv sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from api.generate import REFUSAL_STRING, generate # noqa: E402 from api.hybrid import retrieve_hybrid # noqa: E402 EVAL_DIR = Path(__file__).resolve().parent QUERIES_PATH = EVAL_DIR / "test_queries.yaml" RESULTS_DIR = EVAL_DIR / "results" def main() -> None: load_dotenv() queries = yaml.safe_load(QUERIES_PATH.read_text())["queries"] print(f"loaded {len(queries)} queries from {QUERIES_PATH.name}\n") results: list[dict[str, Any]] = [] with psycopg.connect(os.environ["DATABASE_URL"]) as conn: for q in queries: results.append(_run_one(conn, q)) print(f" {q['id']:30s} done") aggregate = _aggregate(results) report_md = _format_markdown(results, aggregate) print("\n" + report_md) RESULTS_DIR.mkdir(parents=True, exist_ok=True) ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") out_path = RESULTS_DIR / f"{ts}.json" out_path.write_text(json.dumps({ "timestamp": ts, "n_queries": len(results), "aggregate": aggregate, "per_query": results, }, indent=2)) print(f"\nresults saved to {out_path.relative_to(EVAL_DIR.parent)}") def _run_one(conn: psycopg.Connection, q: dict[str, Any]) -> dict[str, Any]: qid = q["id"] query = q["query"] off_topic = bool(q.get("off_topic", False)) expected_sources = set(q.get("expected_sources") or []) expected_keywords = q.get("expected_keywords") or [] negation = q.get("negation") t0 = time.perf_counter() t_r = time.perf_counter() hits = retrieve_hybrid(conn, query, k=5) retrieval_ms = (time.perf_counter() - t_r) * 1000 gen = generate(query, hits) total_ms = (time.perf_counter() - t0) * 1000 sources = [h.hit.source_type for h in hits] blob = " \n ".join((h.hit.chunk_text or "") for h in hits).lower() source_routing_top1 = ( bool(hits) and bool(expected_sources) and sources[0] in expected_sources ) source_recall_top5 = ( sum(1 for s in sources if s in expected_sources) / 5 if expected_sources else None ) keyword_recall = ( sum(1 for kw in expected_keywords if kw.lower() in blob) / len(expected_keywords) if expected_keywords else None ) refusal_correct = ( gen.refused if off_topic else (not gen.refused or len(hits) == 0) ) # Off-topic queries are correct when refused. Non-off-topic queries are # correct when not refused (or when retrieval refused — also acceptable). if off_topic: refusal_correct = gen.refused else: refusal_correct = True # not penalized; we measure quality elsewhere citation_validity = ( 1.0 if not gen.cited_ids else (1 - len(gen.invalid_cited_ids) / len(gen.cited_ids)) ) negation_held: bool | None = None forbidden_hits: list[str] = [] if negation: for pat in negation.get("forbidden_patterns", []): if pat.lower() in blob: forbidden_hits.append(pat) negation_held = (len(forbidden_hits) == 0) return { "id": qid, "query": query, "off_topic": off_topic, "n_hits": len(hits), "sources_top5": sources, "expected_sources": sorted(expected_sources), "source_routing_top1": source_routing_top1, "source_recall_top5": source_recall_top5, "keyword_recall": keyword_recall, "refused": gen.refused, "off_topic_refusal_correct": (gen.refused == off_topic) if off_topic else None, "cited_ids": gen.cited_ids, "invalid_cited_ids": gen.invalid_cited_ids, "citation_validity": citation_validity, "negation_held": negation_held, "negation_forbidden_hits": forbidden_hits, "answer_first_120": (gen.answer or "")[:120], "model": gen.model, "retrieval_ms": round(retrieval_ms, 1), "generation_ms": round(gen.latency_ms, 1), "total_ms": round(total_ms, 1), } def _aggregate(results: list[dict[str, Any]]) -> dict[str, Any]: def mean_of(key: str, predicate=lambda r: True) -> float | None: values = [r[key] for r in results if predicate(r) and r[key] is not None] return round(mean(values), 3) if values else None on_topic = lambda r: not r["off_topic"] off_topic = lambda r: r["off_topic"] has_neg = lambda r: r["negation_held"] is not None off_topic_results = [r for r in results if r["off_topic"]] off_topic_refusal_rate = ( sum(1 for r in off_topic_results if r["refused"]) / len(off_topic_results) if off_topic_results else None ) any_invalid = any(r["invalid_cited_ids"] for r in results) neg_pass_rate = ( sum(1 for r in results if r["negation_held"]) / sum(1 for r in results if has_neg(r)) if any(has_neg(r) for r in results) else None ) return { "n_queries": len(results), "n_on_topic": sum(1 for r in results if not r["off_topic"]), "n_off_topic": len(off_topic_results), "source_routing_top1_rate": round(sum(1 for r in results if r["source_routing_top1"]) / sum(1 for r in results if on_topic(r)), 3) if any(on_topic(r) for r in results) else None, "mean_source_recall_top5": mean_of("source_recall_top5", on_topic), "mean_keyword_recall": mean_of("keyword_recall", on_topic), "mean_citation_validity": mean_of("citation_validity", on_topic), "any_hallucinated_citation": any_invalid, "off_topic_refusal_rate": off_topic_refusal_rate, "negation_pass_rate": neg_pass_rate, "mean_retrieval_ms": mean_of("retrieval_ms"), "mean_generation_ms": mean_of("generation_ms"), "mean_total_ms": mean_of("total_ms"), } def _format_markdown(results: list[dict[str, Any]], agg: dict[str, Any]) -> str: """Render a compact two-table report: per-query rows + aggregate rollup.""" rows = ["| id | sources@5 | route✓ | kw rec | cite✓ | refused | t_total |", "|---|---|---|---|---|---|---|"] for r in results: kw = "—" if r["keyword_recall"] is None else f"{r['keyword_recall']:.0%}" sr = "—" if r["source_recall_top5"] is None else f"{r['source_recall_top5']:.0%}" cite = f"{r['citation_validity']:.0%}" route = "—" if r["off_topic"] else ("✓" if r["source_routing_top1"] else "✗") refused = "yes" if r["refused"] else "no" rows.append(f"| {r['id']} | {','.join(r['sources_top5'])[:24] or '—'} ({sr}) | {route} | {kw} | {cite} | {refused} | {int(r['total_ms'])}ms |") summary = [ "", "### Aggregate", "", f"- queries: {agg['n_queries']} ({agg['n_on_topic']} on-topic, {agg['n_off_topic']} off-topic)", f"- source-routing top-1: **{(agg['source_routing_top1_rate'] or 0) * 100:.0f}%**", f"- mean source-recall@5: **{(agg['mean_source_recall_top5'] or 0) * 100:.0f}%**", f"- mean keyword-recall: **{(agg['mean_keyword_recall'] or 0) * 100:.0f}%**", f"- mean citation-validity: **{(agg['mean_citation_validity'] or 0) * 100:.0f}%** " f"({'hallucinated citations DETECTED' if agg['any_hallucinated_citation'] else 'no hallucinated citations'})", f"- off-topic refusal rate: **{(agg['off_topic_refusal_rate'] or 0) * 100:.0f}%** " f"(target 100%)", f"- negation pass rate: **{(agg['negation_pass_rate'] or 0) * 100:.0f}%** " f"(target 100%)" if agg['negation_pass_rate'] is not None else "- negation: not measured", f"- mean retrieval: {int(agg['mean_retrieval_ms'])} ms · " f"mean generation: {int(agg['mean_generation_ms'])} ms · " f"mean total: {int(agg['mean_total_ms'])} ms", ] return "\n".join(rows + summary) if __name__ == "__main__": main()