File size: 2,692 Bytes
9508d8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Evaluate retrieval metrics: Recall@K, mAP."""

import argparse
import sys
from pathlib import Path

# ROOT is the directory two levels above this file (the parent of the directory
# that contains it), resolved to an absolute path.
ROOT = Path(__file__).resolve().parent.parent
# Put ROOT at the front of the import search path so it takes precedence over
# other entries. NOTE(review): presumably this lets the script import project
# packages when run directly; nothing in this file imports from ROOT -- confirm.
sys.path.insert(0, str(ROOT))


def compute_recall_at_k(results: list[list[str]], relevant: list[set[str]], k: int) -> float:
    """Return Recall@K averaged over all queries.

    For each query, recall is the share of its relevant ids that occur among
    the first ``k`` entries of its ranked result list.

    Args:
        results: One ranked list of retrieved ids per query.
        relevant: One set of relevant ids per query, aligned with ``results``.
        k: Cut-off rank; only ``results[i][:k]`` is inspected.

    Returns:
        The sum of per-query recalls divided by ``len(results)``, or ``0.0``
        when ``results`` is empty. Queries with an empty relevant set add
        nothing to the sum but still count in the denominator.
    """
    if not results:
        return 0.0

    accumulated = 0.0
    for ranked, gold in zip(results, relevant):
        if not gold:
            # No relevant ids for this query: it contributes zero recall.
            continue
        matched = gold & set(ranked[:k])
        # The intersection is a subset of ``gold``, so the ratio is within [0, 1].
        accumulated += len(matched) / len(gold)
    return accumulated / len(results)


def compute_ap(results: list[str], relevant: set[str]) -> float:
    """Return the Average Precision of one ranked result list.

    Precision is sampled at every rank where a relevant id appears for the
    first time; the samples are summed and divided by the number of relevant
    ids, so relevant ids that are never retrieved lower the score.

    Args:
        results: Ranked list of retrieved ids, best first.
        relevant: Set of ids that are relevant for the query.

    Returns:
        Average Precision in ``[0.0, 1.0]``; ``0.0`` when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0

    credited: set[str] = set()
    hits = 0
    precision_sum = 0.0
    for rank, doc in enumerate(results, start=1):
        # A repeated id still occupies a rank but is credited only once;
        # otherwise duplicates would inflate the score beyond 1.0.
        if doc in relevant and doc not in credited:
            credited.add(doc)
            hits += 1
            precision_sum += hits / rank
    return precision_sum / len(relevant)


def compute_map(results: list[list[str]], relevant: list[set[str]]) -> float:
    """Return the Mean Average Precision over all queries.

    Args:
        results: One ranked list of retrieved ids per query.
        relevant: One set of relevant ids per query, aligned with ``results``.

    Returns:
        The sum of the per-query ``compute_ap`` values divided by
        ``len(results)``, or ``0.0`` when ``results`` is empty.
    """
    if not results:
        return 0.0
    per_query_ap = (compute_ap(ranked, gold) for ranked, gold in zip(results, relevant))
    return sum(per_query_ap) / len(results)


def main():
    """Command-line entry point: print Recall@1/5/10 and mAP for a result file.

    Reads two JSON files given by ``--results-file`` and ``--ground-truth``.
    Each is expected to be a list of objects: result entries carry
    ``query_id`` and ``result_ids``; ground-truth entries carry ``query_id``
    and ``relevant_ids``. Result entries whose ``query_id`` has no
    ground-truth entry are excluded from the metrics and reported on stderr.

    Exits with status 1 (after printing usage) when either argument is missing.
    """
    parser = argparse.ArgumentParser(description="Evaluate retrieval (Recall@K, mAP)")
    parser.add_argument("--results-file", type=Path, help="JSON file with query -> top-K ids")
    parser.add_argument("--ground-truth", type=Path, help="JSON file with query -> set of relevant ids")
    args = parser.parse_args()

    if not args.results_file or not args.ground_truth:
        print("Usage: python evaluate.py --results-file results.json --ground-truth gt.json")
        print("Format: each JSON is a list of {query_id, result_ids} and {query_id, relevant_ids}")
        sys.exit(1)

    import json

    # Decode as UTF-8 explicitly: the platform default encoding can mis-decode
    # or reject non-ASCII ids on systems whose locale is not UTF-8.
    with open(args.results_file, encoding="utf-8") as f:
        results_data = json.load(f)
    with open(args.ground_truth, encoding="utf-8") as f:
        gt_data = json.load(f)

    # Map each query id to its set of relevant ids for O(1) lookup.
    gt_map = {r["query_id"]: set(r["relevant_ids"]) for r in gt_data}
    results_list = []
    relevant_list = []
    skipped = 0
    for r in results_data:
        relevant_ids = gt_map.get(r["query_id"])
        if relevant_ids is None:
            skipped += 1
            continue
        results_list.append(r["result_ids"])
        relevant_list.append(relevant_ids)

    if skipped:
        # Surface dropped queries so the metrics are not mistaken for full coverage.
        print(
            f"Warning: skipped {skipped} of {len(results_data)} result entries "
            "with no ground truth",
            file=sys.stderr,
        )

    r1 = compute_recall_at_k(results_list, relevant_list, 1)
    r5 = compute_recall_at_k(results_list, relevant_list, 5)
    r10 = compute_recall_at_k(results_list, relevant_list, 10)
    mAP = compute_map(results_list, relevant_list)

    print(f"Recall@1:  {r1:.4f}")
    print(f"Recall@5:  {r5:.4f}")
    print(f"Recall@10: {r10:.4f}")
    print(f"mAP:       {mAP:.4f}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()