# Page-scrape residue (commit header), kept as comments so the module parses:
# han-na's picture
# remove query_type
# f400c92
"""
evaluation/eval.py
Runs the full pipeline against test_queries.jsonl and computes retrieval metrics.
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
import traceback
from pathlib import Path
from typing import List
# ── Run from current_spring2026/ ──────────────────────────────────────────────
sys.path.insert(0, str(Path(__file__).parent.parent))
from pipeline import run_query, PipelineResult
# ── Metric helpers ────────────────────────────────────────────────────────────
def hit_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> int:
    """Return 1 if any ground-truth id appears in the first *k* retrieved ids, else 0."""
    top_k = retrieved[:k]
    for wanted in ground_truths:
        if wanted in top_k:
            return 1
    return 0
def reciprocal_rank(retrieved: List[str], ground_truths: List[str]) -> float:
    """Return 1/rank of the first retrieved id found in *ground_truths*.

    Ranks are 1-based.  Returns 0.0 when no retrieved id is a ground truth.
    """
    first_hit_rank = next(
        (
            rank
            for rank, doc_id in enumerate(retrieved, start=1)
            if doc_id in ground_truths
        ),
        None,
    )
    if first_hit_rank is None:
        return 0.0
    return 1.0 / first_hit_rank
def recall_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> float:
    """Return the fraction of ground-truth ids present in the first *k* retrieved ids.

    Returns 0.0 when *ground_truths* is empty.
    """
    if not ground_truths:
        return 0.0
    top_k = retrieved[:k]
    found = [wanted for wanted in ground_truths if wanted in top_k]
    return len(found) / len(ground_truths)
def precision_at_k(retrieved: List[str], ground_truths: List[str], k: int) -> float:
    """Return the fraction of the top-*k* slots filled by ground-truth ids.

    The denominator is always *k*, even when fewer than *k* ids were
    retrieved, so missing results count against precision.

    Returns 0.0 for any non-positive *k*: a negative *k* would otherwise
    slice from the end of *retrieved* and divide by a negative number,
    yielding a meaningless negative precision.
    """
    if k <= 0:
        return 0.0
    hits = sum(1 for ark in retrieved[:k] if ark in ground_truths)
    return hits / k
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Run every test query through the pipeline and write CSV reports.

    Command-line options:
        --queries  Path to a JSONL file with one JSON object per non-blank
                   line.  Each object must have ``question`` and
                   ``question_type``; ``ground_truths`` is an optional list
                   of objects carrying an ``ark_id``.
        --out      Path of the per-query CSV.  A summary CSV is written next
                   to it as ``<stem>_summary.csv``.

    A query whose pipeline call raises is recorded as a row with blank
    metrics and the exception text in ``error``; the run continues.

    Exits with status 1 when the queries file is missing or holds no entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--queries",
        default=str(Path(__file__).parent.parent / "test_queries.jsonl"),
        help="Path to test_queries.jsonl",
    )
    parser.add_argument(
        "--out",
        default=str(Path(__file__).parent / "eval_results.csv"),
        help="Path to save per-query CSV results",
    )
    args = parser.parse_args()

    K_VALUES = [10, 30, 50]
    MAX_K = max(K_VALUES)

    queries_path = Path(args.queries)
    if not queries_path.exists():
        print(f"ERROR: test queries file not found at {queries_path}")
        sys.exit(1)

    # Read as UTF-8 explicitly so results do not depend on the platform's
    # default locale encoding (the CSV outputs are written as UTF-8 too).
    with open(queries_path, encoding="utf-8") as f:
        entries = [json.loads(line) for line in f if line.strip()]

    # Without at least one row there is no header to derive for the CSV.
    if not entries:
        print(f"ERROR: no queries found in {queries_path}")
        sys.exit(1)

    print(f"Loaded {len(entries)} queries from {queries_path}")
    print(f"Evaluating top-{K_VALUES} retrieved results\n")

    rows = []
    for i, entry in enumerate(entries):
        question = entry["question"]
        qtype = entry["question_type"]
        # Ground-truth ids are compared without the "commonwealth:" prefix.
        ground_truths = [
            g["ark_id"].removeprefix("commonwealth:")
            for g in entry.get("ground_truths", [])
        ]
        print(f"[{i+1:02d}/{len(entries)}] ({qtype}) {question[:70]}...")

        try:
            result: PipelineResult = run_query(question, top_k=MAX_K)
            retrieved_ids = [doc.ark_id for doc in result.documents]
            mrr = reciprocal_rank(retrieved_ids, ground_truths)

            # Hallucination test: pipeline should return no docs (or say "no results")
            if qtype == "hallucination_test":
                hallucination_pass = int(
                    len(retrieved_ids) == 0
                    or "no relevant" in result.generation.response.lower()
                    or "not found" in result.generation.response.lower()
                )
            else:
                hallucination_pass = ""

            row = {
                "question": question,
                "question_type": qtype,
                "rewritten_query": result.intent.rewritten_query,
                "num_ground_truths": len(ground_truths),
                "num_retrieved": len(retrieved_ids),
                "mrr": round(mrr, 4),
                "hallucination_pass": hallucination_pass,
                "response_preview": result.generation.response[:150].replace("\n", " "),
                "retrieved_ids": "|".join(retrieved_ids),
                "ground_truth_ids": "|".join(ground_truths),
                "latency_ms": result.latency_ms,
                "error": "",
            }
            for k in K_VALUES:
                row[f"hit_at_{k}"] = hit_at_k(retrieved_ids, ground_truths, k)
                row[f"recall_at_{k}"] = round(recall_at_k(retrieved_ids, ground_truths, k), 4)
                row[f"precision_at_{k}"] = round(precision_at_k(retrieved_ids, ground_truths, k), 4)

            status = " " + " ".join(f"hit@{k}={row[f'hit_at_{k}']}" for k in K_VALUES)
            status += f" mrr={mrr:.3f}"
            if qtype == "hallucination_test":
                status += f" hallucination_pass={hallucination_pass}"
            print(status)
        except Exception as e:
            # Best-effort: one failing query must not abort the whole run.
            traceback.print_exc()
            print(f" ERROR: {e}")
            # Keys must match the success row exactly: the CSV header is taken
            # from the first row and DictWriter rejects unknown keys.
            row = {
                "question": question,
                "question_type": qtype,
                "rewritten_query": "",
                "num_ground_truths": len(ground_truths),
                "num_retrieved": 0,
                "mrr": "",
                "hallucination_pass": "",
                "response_preview": "",
                "retrieved_ids": "",
                "ground_truth_ids": "|".join(ground_truths),
                "latency_ms": "",
                "error": str(e),
            }
            for k in K_VALUES:
                row[f"hit_at_{k}"] = ""
                row[f"recall_at_{k}"] = ""
                row[f"precision_at_{k}"] = ""

        rows.append(row)

    # ── Save CSV ──────────────────────────────────────────────────────────────
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = list(rows[0].keys())
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nPer-query results saved to {out_path}")

    # ── Summary by query type ─────────────────────────────────────────────────
    print("\n" + "=" * 55)
    print("SUMMARY")
    print("=" * 55)

    summary_rows = []
    for qtype in ["metadata", "full_text", "hallucination_test"]:
        # Failed queries have a blank "mrr" and are excluded from the averages.
        subset = [r for r in rows if r["question_type"] == qtype and r["mrr"] != ""]
        if not subset:
            continue
        n = len(subset)

        def avg(key):
            """Mean of column *key* over the successful rows of this type."""
            return sum(r[key] for r in subset) / n

        print(f"\n{qtype} (n={n})")
        print(f" MRR : {avg('mrr'):.3f}")
        summary_row = {
            "question_type": qtype,
            "n": n,
            "mrr": round(avg("mrr"), 4),
            "hallucination_pass_rate": "",
        }
        for k in K_VALUES:
            print(f" Hit@{k:<2} : {avg(f'hit_at_{k}'):.3f}")
            print(f" Recall@{k:<2} : {avg(f'recall_at_{k}'):.3f}")
            print(f" Precision@{k:<2} : {avg(f'precision_at_{k}'):.3f}")
            summary_row[f"hit_at_{k}"] = round(avg(f"hit_at_{k}"), 4)
            summary_row[f"recall_at_{k}"] = round(avg(f"recall_at_{k}"), 4)
            summary_row[f"precision_at_{k}"] = round(avg(f"precision_at_{k}"), 4)

        if qtype == "hallucination_test":
            hall_subset = [r for r in rows if r["question_type"] == qtype and r["hallucination_pass"] != ""]
            if hall_subset:
                pass_rate = sum(r["hallucination_pass"] for r in hall_subset) / len(hall_subset)
                print(f" Hallucination pass : {pass_rate:.3f}")
                summary_row["hallucination_pass_rate"] = round(pass_rate, 4)

        summary_rows.append(summary_row)

    errors = [r for r in rows if r["error"]]
    if errors:
        print(f"\nFailed queries: {len(errors)}")
        for r in errors:
            print(f" - {r['question'][:60]}: {r['error']}")
    print()

    # ── Save summary CSV ──────────────────────────────────────────────────────
    if summary_rows:
        summary_path = out_path.with_name(out_path.stem + "_summary.csv")
        summary_fieldnames = list(summary_rows[0].keys())
        with open(summary_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=summary_fieldnames)
            writer.writeheader()
            writer.writerows(summary_rows)
        print(f"Summary results saved to {summary_path}")


if __name__ == "__main__":
    main()