#!/usr/bin/env python3 """ Evaluate retrieval quality against a small benchmark file. Benchmark format (JSON): { "queries": [ { "name": "egfr_lung", "query": "EGFR lung cancer", "modality": "auto", "top_k": 20, "relevant_ids": ["..."], "relevance_by_id": {"...": 1.0, "...": 2.0} } ], "k": 10 } Notes: - `relevant_ids` is used for Recall@k and MRR@k. - `relevance_by_id` is used for nDCG@k. - You can provide either or both. """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Any, Dict, List, Set import requests from bioflow.evaluation.metrics import mrr_at_k, ndcg_at_k, recall_at_k def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--benchmark", required=True, help="Path to benchmark JSON") ap.add_argument("--base-url", default="http://localhost:8000") args = ap.parse_args() bench = json.loads(Path(args.benchmark).read_text(encoding="utf-8")) k = int(bench.get("k", 10)) queries = bench.get("queries", []) if not queries: raise SystemExit("Benchmark has no queries") recalls: List[float] = [] mrrs: List[float] = [] ndcgs: List[float] = [] for q in queries: query = q["query"] modality = q.get("modality", "auto") top_k = int(q.get("top_k", max(k, 20))) r = requests.post( f"{args.base_url}/api/search", json={"query": query, "modality": modality, "top_k": top_k, "use_mmr": False}, timeout=60, ) r.raise_for_status() data = r.json() ranked_ids = [str(item.get("id")) for item in data.get("results", []) if item.get("id") is not None] relevant_ids = set(map(str, q.get("relevant_ids", []))) relevance_by_id = {str(k): float(v) for k, v in (q.get("relevance_by_id", {}) or {}).items()} if relevant_ids: recalls.append(recall_at_k(relevant_ids, ranked_ids, k)) mrrs.append(mrr_at_k(relevant_ids, ranked_ids, k)) if relevance_by_id: ndcgs.append(ndcg_at_k(relevance_by_id, ranked_ids, k)) print(f"- {q.get('name', query[:30])}: got={len(ranked_ids)} recall@{k}={recalls[-1] if relevant_ids else 'n/a'} mrr@{k}={mrrs[-1] if relevant_ids else 'n/a'} ndcg@{k}={ndcgs[-1] if relevance_by_id else 'n/a'}") def _avg(xs: List[float]) -> float: return sum(xs) / float(len(xs)) if xs else 0.0 print("=" * 60) print(f"Aggregate (@{k})") if recalls: print(f"Recall: {_avg(recalls):.4f}") print(f"MRR: {_avg(mrrs):.4f}") if ndcgs: print(f"nDCG: {_avg(ndcgs):.4f}") if not (recalls or ndcgs): print("No relevance labels provided; nothing to score.") return 0 if __name__ == "__main__": raise SystemExit(main())