"""Retrieval evaluation using annotated ground truth from eval_set.json.

Loads the annotation file, computes Hit Rate@k and MRR for each retrieval
method separately using the pooled_from data and relevant_chunk_ids, and
produces a comparison table.

Usage:
    python scripts/retrieval_eval.py
    python scripts/retrieval_eval.py --k 5
"""

import argparse
import json
import sys
from pathlib import Path

# Make the repo root importable so the "src" package resolves when this
# script is run directly (e.g. python scripts/retrieval_eval.py).
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import PROJECT_ROOT
from src.evaluation.metrics import mrr, precision_at_k, recall_at_k, ndcg_at_k
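# The metric helpers are called as mrr(retrieved_ids, relevant_ids) and
# <metric>_at_k(retrieved_ids, relevant_ids, k); each is expected to return a
# float that can be averaged across questions.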

EVAL_SET_PATH = PROJECT_ROOT / "data" / "eval_set.json"
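
# Example eval_set.json entry (illustrative chunk IDs; only the fields this
# script reads are shown, the real annotation file may carry additional keys):
#
#   {
#       "relevant_chunk_ids": ["chunk_012", "chunk_047"],
#       "pooled_from": {
#           "bm25_top10":   ["chunk_003", "chunk_012", "..."],
#           "vector_top10": ["chunk_047", "chunk_012", "..."],
#           "hybrid_top10": ["chunk_012", "chunk_047", "..."]
#       }
#   }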


def load_eval_set() -> list[dict]:
    if not EVAL_SET_PATH.exists():
        print(f"Eval set not found: {EVAL_SET_PATH}")
        print("Run: python scripts/write_questions.py && python scripts/annotate.py")
        sys.exit(1)

    with open(EVAL_SET_PATH, encoding="utf-8") as f:
        data = json.load(f)

    annotated = [e for e in data if e.get("relevant_chunk_ids")]
    if not annotated:
        print("No annotated entries found. Run scripts/annotate.py first.")
        sys.exit(1)

    return annotated


def evaluate_method(
    entries: list[dict],
    method_key: str,
    k_values: list[int],
) -> dict:
    """Compute retrieval metrics for one method across all entries."""
    all_metrics: dict[str, list[float]] = {}

    for entry in entries:
        pooled = entry.get("pooled_from", {})
        # Compare IDs as strings so int/str mismatches between the pooled
        # rankings and the annotations do not cause false negatives.
        retrieved_ids = [str(x) for x in pooled.get(method_key, [])]
        relevant_ids = set(str(x) for x in entry.get("relevant_chunk_ids", []))

        if not relevant_ids:
            continue

        # Reciprocal rank of the first relevant chunk for this question;
        # averaging these values across questions gives MRR.
        entry_mrr = mrr(retrieved_ids, relevant_ids)
        all_metrics.setdefault("mrr", []).append(entry_mrr)

        for k in k_values:
            # Hit Rate@k: 1.0 if at least one relevant chunk is in the top k.
            hit = 1.0 if any(rid in relevant_ids for rid in retrieved_ids[:k]) else 0.0
            all_metrics.setdefault(f"hit_rate@{k}", []).append(hit)
            all_metrics.setdefault(f"precision@{k}", []).append(
                precision_at_k(retrieved_ids, relevant_ids, k)
            )
            all_metrics.setdefault(f"recall@{k}", []).append(
                recall_at_k(retrieved_ids, relevant_ids, k)
            )
            all_metrics.setdefault(f"ndcg@{k}", []).append(
                ndcg_at_k(retrieved_ids, relevant_ids, k)
            )

    # Average
    return {
        key: sum(vals) / len(vals) if vals else 0.0
        for key, vals in all_metrics.items()
    }


def print_table(results: dict[str, dict], k_values: list[int]) -> None:
    """Print a comparison table."""
    methods = list(results.keys())

    # Collect metric columns in order
    columns = ["mrr"]
    for k in k_values:
        columns.extend([f"hit_rate@{k}", f"precision@{k}", f"recall@{k}", f"ndcg@{k}"])

    # Header
    header = f"{'Method':<15}"
    for col in columns:
        header += f"  {col:>12}"
    print(header)
    print("-" * len(header))

    # Rows
    for method in methods:
        row = f"{method:<15}"
        for col in columns:
            val = results[method].get(col, 0.0)
            row += f"  {val:>12.3f}"
        print(row)


def main():
    parser = argparse.ArgumentParser(description="Retrieval evaluation from annotations")
    parser.add_argument(
        "--k", type=int, nargs="+", default=[1, 3, 5, 10],
        help="k values for metrics (default: 1 3 5 10)",
    )
    args = parser.parse_args()

    entries = load_eval_set()
    print(f"\n=== Retrieval Evaluation ({len(entries)} annotated questions) ===\n")

    # Each key names a ranked top-10 list stored under pooled_from in the eval set.
    method_keys = ["bm25_top10", "vector_top10", "hybrid_top10"]
    results = {}
    for mk in method_keys:
        label = mk.replace("_top10", "")
        results[label] = evaluate_method(entries, mk, args.k)

    print_table(results, args.k)
    print()


if __name__ == "__main__":
    main()