# File size: 2,692 Bytes
# 9508d8c
"""Evaluate retrieval metrics: Recall@K, mAP."""
import argparse
import sys
from pathlib import Path
# Parent of the directory that contains this file; presumably the project
# root (verify against the repository layout).
ROOT = Path(__file__).resolve().parent.parent
# Put ROOT at the front of the import search path. Nothing in the visible code
# imports through it -- NOTE(review): confirm it is still needed.
sys.path.insert(0, str(ROOT))
def compute_recall_at_k(results: list[list[str]], relevant: list[set[str]], k: int) -> float:
    """Return the mean Recall@K over all queries.

    For each query, recall is the share of its relevant ids that appear among
    the first ``k`` retrieved ids. Queries with an empty relevant set add
    nothing to the sum but still count in the denominator, which is
    ``len(results)``. An empty ``results`` yields 0.0.
    """
    if not results:
        return 0.0

    recall_sum = 0.0
    for retrieved, wanted in zip(results, relevant):
        if not wanted:
            # No ground truth for this query: contributes zero recall.
            continue
        found = wanted & set(retrieved[:k])
        recall_sum += min(1.0, len(found) / len(wanted))
    return recall_sum / len(results)
def compute_ap(results: list[str], relevant: set[str]) -> float:
    """Return the Average Precision of one ranked result list.

    Precision is sampled at every rank where a relevant id appears for the
    first time, and the samples are averaged over ``len(relevant)``, so
    relevant ids that are never retrieved count as zero.

    Args:
        results: Retrieved ids, best match first.
        relevant: Ground-truth relevant ids for the query.

    Returns:
        A value in [0.0, 1.0]; 0.0 when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0

    credited: set[str] = set()
    hits = 0
    prec_sum = 0.0
    for rank, doc in enumerate(results, start=1):
        # A repeated id still occupies its rank but is credited only once;
        # counting it again could push the score above 1.0.
        if doc in relevant and doc not in credited:
            credited.add(doc)
            hits += 1
            prec_sum += hits / rank
    return prec_sum / len(relevant)
def compute_map(results: list[list[str]], relevant: list[set[str]]) -> float:
    """Return the mean of the per-query Average Precision scores (mAP).

    Result lists are paired positionally with their relevant sets; the summed
    AP is divided by ``len(results)``. An empty ``results`` yields 0.0.
    """
    if not results:
        return 0.0

    ap_total = 0
    for ranked, wanted in zip(results, relevant):
        ap_total += compute_ap(ranked, wanted)
    return ap_total / len(results)
def main():
    """Command-line entry point: print Recall@1/5/10 and mAP for a run.

    Both ``--results-file`` and ``--ground-truth`` are read as JSON lists of
    objects. Result entries provide ``query_id`` and ``result_ids``;
    ground-truth entries provide ``query_id`` and ``relevant_ids``. Result
    entries whose ``query_id`` has no ground-truth entry are left out of the
    evaluation.

    Exits with status 1 after printing usage when either argument is missing.
    """
    import json

    parser = argparse.ArgumentParser(description="Evaluate retrieval (Recall@K, mAP)")
    parser.add_argument("--results-file", type=Path, help="JSON file with query -> top-K ids")
    parser.add_argument("--ground-truth", type=Path, help="JSON file with query -> set of relevant ids")
    args = parser.parse_args()

    if not args.results_file or not args.ground_truth:
        print("Usage: python evaluate.py --results-file results.json --ground-truth gt.json")
        print("Format: each JSON is a list of {query_id, result_ids} and {query_id, relevant_ids}")
        sys.exit(1)

    # Read as UTF-8 explicitly: the platform's locale encoding may mis-decode
    # or reject non-ASCII ids.
    with open(args.results_file, encoding="utf-8") as f:
        results_data = json.load(f)
    with open(args.ground_truth, encoding="utf-8") as f:
        gt_data = json.load(f)

    # Build lookup
    gt_map = {r["query_id"]: set(r["relevant_ids"]) for r in gt_data}

    results_list = []
    relevant_list = []
    for r in results_data:
        qid = r["query_id"]
        if qid in gt_map:
            # Keep the two lists aligned: index i of each refers to one query.
            results_list.append(r["result_ids"])
            relevant_list.append(gt_map[qid])

    r1 = compute_recall_at_k(results_list, relevant_list, 1)
    r5 = compute_recall_at_k(results_list, relevant_list, 5)
    r10 = compute_recall_at_k(results_list, relevant_list, 10)
    mean_ap = compute_map(results_list, relevant_list)

    print(f"Recall@1: {r1:.4f}")
    print(f"Recall@5: {r5:.4f}")
    print(f"Recall@10: {r10:.4f}")
    print(f"mAP: {mean_ap:.4f}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|