"""Retrieval evaluation using annotated ground truth from eval_set.json. Loads the annotation file, computes Hit Rate@k and MRR for each retrieval method separately using the pooled_from data and relevant_chunk_ids, and produces a comparison table. Usage: python scripts/retrieval_eval.py python scripts/retrieval_eval.py --k 5 """ import argparse import json import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from src.config import PROJECT_ROOT from src.evaluation.metrics import mrr, precision_at_k, recall_at_k, ndcg_at_k EVAL_SET_PATH = PROJECT_ROOT / "data" / "eval_set.json" def load_eval_set() -> list[dict]: if not EVAL_SET_PATH.exists(): print(f"Eval set not found: {EVAL_SET_PATH}") print("Run: python scripts/write_questions.py && python scripts/annotate.py") sys.exit(1) with open(EVAL_SET_PATH, encoding="utf-8") as f: data = json.load(f) annotated = [e for e in data if e.get("relevant_chunk_ids")] if not annotated: print("No annotated entries found. Run scripts/annotate.py first.") sys.exit(1) return annotated def evaluate_method( entries: list[dict], method_key: str, k_values: list[int], ) -> dict: """Compute retrieval metrics for one method across all entries.""" all_metrics: dict[str, list[float]] = {} for entry in entries: pooled = entry.get("pooled_from", {}) retrieved_ids = [str(x) for x in pooled.get(method_key, [])] relevant_ids = set(str(x) for x in entry.get("relevant_chunk_ids", [])) if not relevant_ids: continue entry_mrr = mrr(retrieved_ids, relevant_ids) all_metrics.setdefault("mrr", []).append(entry_mrr) for k in k_values: hit = 1.0 if any(rid in relevant_ids for rid in retrieved_ids[:k]) else 0.0 all_metrics.setdefault(f"hit_rate@{k}", []).append(hit) all_metrics.setdefault(f"precision@{k}", []).append( precision_at_k(retrieved_ids, relevant_ids, k) ) all_metrics.setdefault(f"recall@{k}", []).append( recall_at_k(retrieved_ids, relevant_ids, k) ) all_metrics.setdefault(f"ndcg@{k}", []).append( ndcg_at_k(retrieved_ids, relevant_ids, k) ) # Average return { key: sum(vals) / len(vals) if vals else 0.0 for key, vals in all_metrics.items() } def print_table(results: dict[str, dict], k_values: list[int]) -> None: """Print a comparison table.""" methods = list(results.keys()) # Collect metric columns in order columns = ["mrr"] for k in k_values: columns.extend([f"hit_rate@{k}", f"precision@{k}", f"recall@{k}", f"ndcg@{k}"]) # Header header = f"{'Method':<15}" for col in columns: header += f" {col:>12}" print(header) print("-" * len(header)) # Rows for method in methods: row = f"{method:<15}" for col in columns: val = results[method].get(col, 0.0) row += f" {val:>12.3f}" print(row) def main(): parser = argparse.ArgumentParser(description="Retrieval evaluation from annotations") parser.add_argument( "--k", type=int, nargs="+", default=[1, 3, 5, 10], help="k values for metrics (default: 1 3 5 10)", ) args = parser.parse_args() entries = load_eval_set() print(f"\n=== Retrieval Evaluation ({len(entries)} annotated questions) ===\n") method_keys = ["bm25_top10", "vector_top10", "hybrid_top10"] results = {} for mk in method_keys: label = mk.replace("_top10", "") results[label] = evaluate_method(entries, mk, args.k) print_table(results, args.k) print() if __name__ == "__main__": main()