Spaces:
Running
Running
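"""
evaluation/eval.py
Runs the full pipeline against test_queries.jsonl and computes retrieval metrics
(Hit@k, Recall@k, Precision@k, MRR) plus a hallucination pass rate for
hallucination-test queries, writing per-query and per-type summary CSVs.
"""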
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| from typing import List | |
# ── Import path setup (run from current_spring2026/) ─────────────────────────
# This file sits in evaluation/; the directory above it is the project root,
# presumably where `pipeline` lives, so prepend it to sys.path before importing.
sys.path.insert(0, str(Path(__file__).parent.parent))
from pipeline import PipelineResult, run_query  # noqa: E402
# ── Metric helpers ────────────────────────────────────────────────────────────
def hit_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> int:
    """Return 1 if any ground-truth ID appears among the first *k* retrieved IDs, else 0."""
    top_k = retrieved[:k]
    for truth in ground_truths:
        if truth in top_k:
            return 1
    return 0
def reciprocal_rank(retrieved: List[str], ground_truths: List[str]) -> float:
    """Return 1/rank of the first retrieved ID that is a ground truth, or 0.0 if none is.

    Ranks are 1-based, so a ground truth in the first position scores 1.0.
    """
    first_rank = next(
        (rank for rank, doc_id in enumerate(retrieved, start=1) if doc_id in ground_truths),
        None,
    )
    return 0.0 if first_rank is None else 1.0 / first_rank
def recall_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> float:
    """Return the fraction of ground-truth IDs found in the first *k* retrieved IDs.

    Returns 0.0 when there are no ground truths (rather than dividing by zero).
    """
    total = len(ground_truths)
    if total == 0:
        return 0.0
    window = retrieved[:k]
    found = sum(truth in window for truth in ground_truths)
    return found / total
def precision_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> float:
    """Return the share of the first *k* retrieved slots that hold a ground-truth ID.

    The denominator is always *k*, even when fewer than *k* IDs were retrieved,
    and duplicate retrieved IDs are each counted. Returns 0.0 when k == 0.
    """
    if k == 0:
        return 0.0
    relevant_in_window = [doc_id for doc_id in retrieved[:k] if doc_id in ground_truths]
    return len(relevant_in_window) / k
# ── Main ──────────────────────────────────────────────────────────────────────
def _blank_row(question: str, qtype: str, ground_truths: List[str], k_values: List[int]) -> dict:
    """Build a per-query CSV row with every column present and all result fields empty.

    Successful and failed queries both start from this template, so every row has
    the same keys in the same order. csv.DictWriter raises ValueError for a row
    containing a key that is not in its fieldnames.
    """
    row = {
        "question": question,
        "question_type": qtype,
        "rewritten_query": "",
        "num_ground_truths": len(ground_truths),
        "num_retrieved": 0,
        "mrr": "",
        "hallucination_pass": "",
        "response_preview": "",
        "retrieved_ids": "",
        "ground_truth_ids": "|".join(ground_truths),
        "latency_ms": "",
        "error": "",
    }
    for k in k_values:
        row[f"hit_at_{k}"] = ""
        row[f"recall_at_{k}"] = ""
        row[f"precision_at_{k}"] = ""
    return row


def _score_query(
    question: str,
    qtype: str,
    ground_truths: List[str],
    k_values: List[int],
    max_k: int,
) -> dict:
    """Run one query through the pipeline, print a status line, and return its scored row.

    Any exception raised by ``run_query`` or while reading its result propagates
    to the caller, which records the query as failed.
    """
    result: PipelineResult = run_query(question, top_k=max_k)
    retrieved_ids = [doc.ark_id for doc in result.documents]
    mrr = reciprocal_rank(retrieved_ids, ground_truths)
    response = result.generation.response

    # Hallucination test: pipeline should return no docs (or say "no results").
    if qtype == "hallucination_test":
        hallucination_pass = int(
            len(retrieved_ids) == 0
            or "no relevant" in response.lower()
            or "not found" in response.lower()
        )
    else:
        hallucination_pass = ""

    row = _blank_row(question, qtype, ground_truths, k_values)
    row.update(
        {
            "rewritten_query": result.intent.rewritten_query,
            "num_retrieved": len(retrieved_ids),
            "mrr": round(mrr, 4),
            "hallucination_pass": hallucination_pass,
            "response_preview": response[:150].replace("\n", " "),
            "retrieved_ids": "|".join(retrieved_ids),
            "latency_ms": result.latency_ms,
        }
    )
    for k in k_values:
        row[f"hit_at_{k}"] = hit_at_k(retrieved_ids, ground_truths, k)
        row[f"recall_at_{k}"] = round(recall_at_k(retrieved_ids, ground_truths, k), 4)
        row[f"precision_at_{k}"] = round(precision_at_k(retrieved_ids, ground_truths, k), 4)

    status = " " + " ".join(f"hit@{k}={row[f'hit_at_{k}']}" for k in k_values)
    status += f" mrr={mrr:.3f}"
    if qtype == "hallucination_test":
        status += f" hallucination_pass={hallucination_pass}"
    print(status)
    return row


def _mean(rows: List[dict], key: str) -> float:
    """Arithmetic mean of ``row[key]`` over a non-empty list of rows."""
    return sum(r[key] for r in rows) / len(rows)


def _summarize(rows: List[dict], k_values: List[int]) -> List[dict]:
    """Print per-question-type averages of the scored rows and return them as summary rows.

    Rows of failed queries (blank ``mrr``) are excluded from the averages. The
    three known types are reported first, then any other types in order of
    first appearance, so new question types are not silently dropped.
    """
    print("\n" + "=" * 55)
    print("SUMMARY")
    print("=" * 55)

    known_types = ["metadata", "full_text", "hallucination_test"]
    extra_types = [
        t for t in dict.fromkeys(r["question_type"] for r in rows) if t not in known_types
    ]

    summary_rows = []
    for qtype in known_types + extra_types:
        subset = [r for r in rows if r["question_type"] == qtype and r["mrr"] != ""]
        if not subset:
            continue
        n = len(subset)
        mean_mrr = _mean(subset, "mrr")
        print(f"\n{qtype} (n={n})")
        print(f" MRR : {mean_mrr:.3f}")
        summary_row = {
            "question_type": qtype,
            "n": n,
            "mrr": round(mean_mrr, 4),
            "hallucination_pass_rate": "",
        }
        for k in k_values:
            mean_hit = _mean(subset, f"hit_at_{k}")
            mean_recall = _mean(subset, f"recall_at_{k}")
            mean_precision = _mean(subset, f"precision_at_{k}")
            print(f" Hit@{k:<2} : {mean_hit:.3f}")
            print(f" Recall@{k:<2} : {mean_recall:.3f}")
            print(f" Precision@{k:<2} : {mean_precision:.3f}")
            summary_row[f"hit_at_{k}"] = round(mean_hit, 4)
            summary_row[f"recall_at_{k}"] = round(mean_recall, 4)
            summary_row[f"precision_at_{k}"] = round(mean_precision, 4)
        if qtype == "hallucination_test":
            judged = [
                r for r in rows
                if r["question_type"] == qtype and r["hallucination_pass"] != ""
            ]
            if judged:
                pass_rate = _mean(judged, "hallucination_pass")
                print(f" Hallucination pass : {pass_rate:.3f}")
                summary_row["hallucination_pass_rate"] = round(pass_rate, 4)
        summary_rows.append(summary_row)
    return summary_rows


def _write_csv(path: Path, rows: List[dict], fieldnames: List[str]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header row (header only if rows is empty)."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main():
    """Evaluate the pipeline on a JSONL query set and write per-query and summary CSVs.

    Each JSONL line must provide ``question`` and ``question_type``; the optional
    ``ground_truths`` is a list of objects with an ``ark_id`` (a leading
    ``commonwealth:`` prefix is stripped before comparing with retrieved IDs).
    A query whose pipeline run raises is recorded with its error message and
    blank metrics instead of aborting the whole run.

    Outputs the per-query CSV at ``--out``. When at least one question type has
    scored rows, it also writes ``<out stem>_summary.csv`` next to it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--queries",
        default=str(Path(__file__).parent.parent / "test_queries.jsonl"),
        help="Path to test_queries.jsonl",
    )
    parser.add_argument(
        "--out",
        default=str(Path(__file__).parent / "eval_results.csv"),
        help="Path to save per-query CSV results",
    )
    parser.add_argument(
        "--k",
        type=int,
        nargs="+",
        default=[10, 30, 50],
        help="Cutoffs for Hit@k / Recall@k / Precision@k (default: 10 30 50)",
    )
    args = parser.parse_args()

    # De-duplicate while keeping the user's order; duplicates would collide as CSV columns.
    k_values = list(dict.fromkeys(args.k))
    if any(k <= 0 for k in k_values):
        parser.error("--k values must be positive integers")
    # Retrieve once at the largest cutoff; the smaller cutoffs are prefixes of it.
    max_k = max(k_values)

    queries_path = Path(args.queries)
    if not queries_path.exists():
        print(f"ERROR: test queries file not found at {queries_path}")
        sys.exit(1)
    # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) may mis-decode question text.
    with open(queries_path, encoding="utf-8") as f:
        entries = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(entries)} queries from {queries_path}")
    print(f"Evaluating top-{k_values} retrieved results\n")

    rows = []
    for i, entry in enumerate(entries):
        question = entry["question"]
        qtype = entry["question_type"]
        ground_truths = [
            g["ark_id"].removeprefix("commonwealth:")
            for g in entry.get("ground_truths", [])
        ]
        print(f"[{i+1:02d}/{len(entries)}] ({qtype}) {question[:70]}...")
        try:
            row = _score_query(question, qtype, ground_truths, k_values, max_k)
        except Exception as e:
            # Top-level boundary: log the failure and keep evaluating the remaining queries.
            traceback.print_exc()
            print(f" ERROR: {e}")
            row = _blank_row(question, qtype, ground_truths, k_values)
            row["error"] = str(e)
        rows.append(row)

    # ── Save CSV ─────────────────────────────────────────────────────────────
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Column list comes from the shared row template rather than rows[0], so it
    # matches every row and also works when there are no rows at all.
    fieldnames = list(_blank_row("", "", [], k_values))
    _write_csv(out_path, rows, fieldnames)
    print(f"\nPer-query results saved to {out_path}")

    # ── Summary by query type ────────────────────────────────────────────────
    summary_rows = _summarize(rows, k_values)

    errors = [r for r in rows if r["error"]]
    if errors:
        print(f"\nFailed queries: {len(errors)}")
        for r in errors:
            print(f" - {r['question'][:60]}: {r['error']}")
    print()

    # ── Save summary CSV ─────────────────────────────────────────────────────
    if summary_rows:
        summary_path = out_path.with_name(out_path.stem + "_summary.csv")
        _write_csv(summary_path, summary_rows, list(summary_rows[0]))
        print(f"Summary results saved to {summary_path}")
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run the evaluation.
    main()