Spaces:
Sleeping
Sleeping
| """Retrieval evaluation script for MediaStorm RAG. | |
| Measures Precision@1, Recall@5, MRR, NDCG@5 against curated ground truth queries. | |
| Zero dependencies beyond the project itself. | |
| Usage: | |
| python eval_retrieval.py [--verbose, -v] [--pipeline] | |
| """ | |
| import asyncio | |
| import math | |
| import time | |
| # --------------------------------------------------------------------------- | |
| # Metrics | |
| # --------------------------------------------------------------------------- | |
def precision_at_1(retrieved_ids: list[str], expected_ids: set[str]) -> float:
    """Score the top-ranked result: 1.0 when it is relevant, 0.0 otherwise.

    An empty result list or an empty ground-truth set always scores 0.0.
    """
    has_inputs = bool(retrieved_ids) and bool(expected_ids)
    top_is_relevant = has_inputs and retrieved_ids[0] in expected_ids
    return float(top_is_relevant)
def recall_at_k(retrieved_ids: list[str], expected_ids: set[str], k: int = 5) -> float:
    """Return the share of expected documents present in the first ``k`` results.

    When no documents are expected there is nothing to miss, so the score
    is 1.0.
    """
    if not expected_ids:
        return 1.0
    # Set comprehension so a document repeated in the top-k counts once.
    hits = {doc_id for doc_id in retrieved_ids[:k] if doc_id in expected_ids}
    return len(hits) / len(expected_ids)
def mrr(retrieved_ids: list[str], expected_ids: set[str]) -> float:
    """Reciprocal rank (1/rank) of the first relevant result for one query.

    Averaging this value across queries yields the Mean Reciprocal Rank.
    Returns 0.0 when no documents are expected or none of the retrieved
    results is relevant.
    """
    if not expected_ids:
        return 0.0
    first_hit_rank = next(
        (
            rank
            for rank, doc_id in enumerate(retrieved_ids, start=1)
            if doc_id in expected_ids
        ),
        None,
    )
    return 0.0 if first_hit_rank is None else 1.0 / first_hit_rank
def ndcg_at_k(retrieved_ids: list[str], expected_ids: set[str], k: int = 5) -> float:
    """Normalized Discounted Cumulative Gain at k, using binary relevance.

    Args:
        retrieved_ids: Result IDs in ranked order (best first).
        expected_ids: IDs considered relevant for the query.
        k: Number of top results to score.

    Returns:
        DCG of the top-k results divided by the ideal DCG (all relevant
        documents ranked first). 0.0 when ``expected_ids`` is empty or the
        ideal DCG is zero.

    Each relevant ID is credited only at its first occurrence. Without this,
    a relevant ID repeated in the result list would be counted at every
    position and the score could exceed 1.0.
    """
    if not expected_ids:
        return 0.0

    # DCG: a relevant document at 0-based position i contributes 1 / log2(i + 2).
    dcg = 0.0
    credited: set[str] = set()
    for i, rid in enumerate(retrieved_ids[:k]):
        if rid in expected_ids and rid not in credited:
            credited.add(rid)
            dcg += 1.0 / math.log2(i + 2)  # i+2 because log2(1)=0

    # Ideal DCG (all relevant docs at top)
    ideal_count = min(len(expected_ids), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_count))
    return dcg / idcg if idcg > 0 else 0.0
| # --------------------------------------------------------------------------- | |
| # Ground truth queries | |
| # --------------------------------------------------------------------------- | |
# Curated ground truth. Each entry is a dict with:
#   "query":    the natural-language search string sent to the retriever
#   "expected": set of short story-ID prefixes considered relevant
#               (_resolve_uids expands them to full IDs before scoring)
#   "category": grouping key used for the per-category breakdown in run_eval
# Entries in the "edge_no_match" category have an empty "expected" set; run_eval
# counts them as passed only when no stories are returned.
EVAL_QUERIES = [
    # --- GEOGRAPHIC (5) ---
    {
        "query": "Stories about the war in Congo",
        "expected": {"f6cea0f9", "3d1c98e7", "e1018d23", "2f638ea0"},
        "category": "geographic",
    },
    {
        "query": "Documentaries set in Afghanistan",
        "expected": {"d7982140", "f38a381a", "5e34b1f2", "b3fc2f4f", "dab272b4"},
        "category": "geographic",
    },
    {
        "query": "Stories about East Africa — Kenya, Ethiopia, Somalia",
        "expected": {"44004365", "b37d6691", "d371bdd4", "f9427ae8", "1d972cf2"},
        "category": "geographic",
    },
    {
        "query": "Stories filmed in Latin America or Mexico",
        "expected": {"13e631b1", "6ed4cd6a", "7708bf66", "ebe28ea7", "5fa24a64"},
        "category": "geographic",
    },
    {
        "query": "Stories about the Middle East conflict — Israel Palestine",
        "expected": {"aab689a6"},
        "category": "geographic",
    },
    # --- THEMATIC (5) ---
    {
        "query": "Stories about PTSD and veterans returning from war",
        "expected": {"e53b9d54", "b3fc2f4f", "dab272b4", "8f3f7b47", "5936e80e"},
        "category": "thematic",
    },
    {
        "query": "Climate change and environmental destruction",
        "expected": {"7233cf20", "44004365", "d371bdd4", "deb75fcf", "18b8f8d9", "b6f35a10"},
        "category": "thematic",
    },
    {
        "query": "Child marriage and women's rights",
        "expected": {"5e34b1f2", "b37d6691", "9b36adb3", "65c1fa57", "ebe28ea7"},
        "category": "thematic",
    },
    {
        "query": "Wildlife conservation and endangered species",
        "expected": {"f9427ae8", "c6fc31b3", "e1018d23", "39431e99", "b8c0f1a0"},
        "category": "thematic",
    },
    {
        "query": "Immigration and refugee stories",
        "expected": {"49866bbe", "f4fe3cbf", "6ed4cd6a", "ce125ae0", "6e613637"},
        "category": "thematic",
    },
    # --- TEMPORAL (4) ---
    {
        "query": "MediaStorm's earliest stories from 2005-2006",
        "expected": {"5936e80e", "88995eea", "0a2d36a5", "13e631b1", "5fa24a64", "214b44a0", "9e5325aa"},
        "category": "temporal",
    },
    {
        "query": "Recent stories from 2022 to 2025",
        "expected": {"b8c0f1a0", "609d8d9f", "64b3132a", "f9427ae8", "051af735"},
        "category": "temporal",
    },
    {
        "query": "Stories from the 2008 financial crisis era",
        "expected": {"5ce6a28d", "b333209b", "732657e6"},
        "category": "temporal",
    },
    {
        "query": "Stories published around 2010",
        "expected": {"44004365", "575bf728", "9b36adb3", "fbd54b9c", "826d329f"},
        "category": "temporal",
    },
    # --- PEOPLE (4) ---
    {
        "query": "Stories about Sebastiao Salgado",
        "expected": {"3fa4c5e5"},
        "category": "people",
    },
    {
        "query": "Stories featuring Don McCullin",
        "expected": {"3f5a3517"},
        "category": "people",
    },
    {
        "query": "Stories about Ai Weiwei",
        "expected": {"c346cb01"},
        "category": "people",
    },
    {
        "query": "Stories about Angelina Jolie and humanitarian work",
        "expected": {"fade2a94"},
        "category": "people",
    },
    # --- GENRE/FORMAT (4) ---
    {
        "query": "Photo essays in the archive",
        "expected": {"0a2d36a5", "13e631b1", "34d720e4", "732657e6", "d7982140", "c3d52625", "9e5325aa", "e53b9d54"},
        "category": "genre",
    },
    {
        "query": "Interactive multimedia projects or crisis guides",
        "expected": {"1815903a", "6b84f13f", "5be0d7ec", "05208857", "aab689a6"},
        "category": "genre",
    },
    {
        "query": "Video documentaries about family and aging",
        "expected": {"176e4cd9", "7e8268de", "88995eea", "7f8e385f", "4c2f60cf"},
        "category": "genre",
    },
    {
        "query": "Animated or motion design pieces",
        "expected": {"018cbb6a", "85d5056b", "5ae39bb8"},
        "category": "genre",
    },
    # --- AWARDS (4) ---
    {
        "query": "Emmy award winning stories",
        "expected": {"49866bbe", "988d3b60", "732657e6", "9b36adb3", "d266b644", "bac708fd", "dc0749e7", "e4cb243e", "e53b9d54"},
        "category": "awards",
    },
    {
        "query": "World Press Photo winners",
        "expected": {"44004365", "176e4cd9", "575bf728", "7cc092f6", "87f894da", "127a7e90"},
        "category": "awards",
    },
    {
        "query": "Award-winning stories about the Iraq war",
        "expected": {"e53b9d54", "5936e80e"},
        "category": "awards",
    },
    {
        "query": "Stories that won at Webby Awards",
        "expected": {"176e4cd9", "5936e80e", "44004365", "575bf728"},
        "category": "awards",
    },
    # --- EDGE CASES: should return NO relevant results (4) ---
    {
        "query": "Quantum computing breakthroughs in 2024",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Best Italian pasta recipes from Tuscany",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Taylor Swift concert tour dates",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Stock market trading strategies and cryptocurrency",
        "expected": set(),
        "category": "edge_no_match",
    },
]
| # --------------------------------------------------------------------------- | |
| # Runner | |
| # --------------------------------------------------------------------------- | |
| def _resolve_uids(queries: list[dict], all_ids: list[str]) -> list[dict]: | |
| """Resolve short UID prefixes to full UUIDs from ChromaDB.""" | |
| prefix_map = {} | |
| for full_id in all_ids: | |
| prefix = full_id.split("-")[0] | |
| prefix_map[prefix] = full_id | |
| resolved = [] | |
| for q in queries: | |
| expected = set() | |
| for uid in q["expected"]: | |
| if uid in prefix_map: | |
| expected.add(prefix_map[uid]) | |
| else: | |
| expected.add(uid) # keep as-is if already full | |
| resolved.append({**q, "expected": expected}) | |
| return resolved | |
async def run_eval(verbose: bool = False, quiet: bool = False, pipeline: bool = False) -> dict:
    """Run every query in EVAL_QUERIES through the retriever and aggregate metrics.

    Args:
        verbose: If True (and ``quiet`` is False), print one block per query with
            its metrics and any expected IDs that were missed.
        quiet: If True, print nothing; the metrics dict is still returned.
        pipeline: If True, run the full pipeline (retriever + Gemini filter).
            Requires GEMINI_API_KEY. Measures what the user actually sees.

    Returns:
        A dict with averaged Precision@1 / Recall@5 / MRR / NDCG@5 for the
        semantic categories, Precision@1 / Recall@5 for the filter categories,
        the edge-case pass rate, and the per-query rows under ``"details"``.
    """
    from mediastorm.config import CHROMADB_PATH, BM25_INDEX_PATH
    from mediastorm.vectorize.store import VectorStore
    from mediastorm.vectorize.embedder import Embedder
    from mediastorm.vectorize.bm25_store import BM25Store
    from mediastorm.rag.retriever import HybridRetriever
    from mediastorm.rag.router import QueryRouter

    store = VectorStore(path=CHROMADB_PATH)
    embedder = Embedder()
    bm25 = BM25Store(path=BM25_INDEX_PATH)
    bm25.load()
    router = QueryRouter()
    retriever = HybridRetriever(
        vector_store=store,
        bm25_store=bm25,
        embedder=embedder,
        router=router,
        top_k_final=5,
    )

    # Pipeline mode: build link lookup and import generator
    link_lookup: dict[str, str] = {}
    if pipeline:
        from mediastorm.api import _build_link_lookup
        from mediastorm.rag.generator import generate_response

        link_lookup = await _build_link_lookup()

    # Resolve short UIDs to full UUIDs
    all_ids = store._stories.get(include=[])["ids"]
    queries = _resolve_uids(EVAL_QUERIES, all_ids)

    results = []
    category_results: dict[str, list] = {}

    if not quiet:
        print("=" * 70)
        print("MediaStorm RAG — Retrieval Evaluation")
        print("=" * 70)
        print()

    for i, q in enumerate(queries):
        query = q["query"]
        expected = q["expected"]
        category = q["category"]

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments. Only the retrieval call itself is timed.
        start = time.perf_counter()
        retrieval = await retriever.retrieve(query)
        duration = time.perf_counter() - start

        # Pipeline mode: filter through Gemini (same as /api/search)
        if pipeline and retrieval.stories:
            full_text = await generate_response(
                query, retrieval, link_lookup=link_lookup,
            )
            # Keep only stories whose link appears in the generated text.
            retrieval_stories = []
            for s in retrieval.stories:
                link = link_lookup.get(s["id"], "")
                if link and link in full_text:
                    retrieval_stories.append(s)
        else:
            retrieval_stories = retrieval.stories

        retrieved_ids = [s["id"] for s in retrieval_stories]

        if category == "edge_no_match":
            # Edge queries pass only when nothing at all is returned.
            success = len(retrieval_stories) == 0
            row = {
                "query": query,
                "category": category,
                "success": success,
                "num_returned": len(retrieval_stories),
                "expected": [],
                "duration": duration,
            }
            status = "PASS" if success else "FAIL"
            if verbose and not quiet:
                print(f" [{status}] Q{i+1}: {query}")
                if not success:
                    names = [s.get("metadata", {}).get("name", s["id"]) for s in retrieval_stories]
                    print(f" Unexpected results: {names}")
        else:
            p1 = precision_at_1(retrieved_ids, expected)
            r5 = recall_at_k(retrieved_ids, expected, k=5)
            m = mrr(retrieved_ids, expected)
            n5 = ndcg_at_k(retrieved_ids, expected, k=5)
            missed = expected - set(retrieved_ids)
            row = {
                "query": query,
                "category": category,
                "precision_at_1": p1,
                "recall_at_5": r5,
                "mrr": m,
                "ndcg_at_5": n5,
                "retrieved": retrieved_ids,
                # Sorted so the detail rows are deterministic across runs
                # (set iteration order is arbitrary).
                "expected": sorted(expected),
                "missed": sorted(missed),
                "duration": duration,
            }
            if verbose and not quiet:
                status = "PASS" if r5 > 0 else "MISS"
                print(f" [{status}] Q{i+1}: {query}")
                print(f" P@1={p1:.0f} R@5={r5:.2f} MRR={m:.2f} NDCG@5={n5:.2f} ({duration:.1f}s)")
                if r5 < 1.0:
                    if missed:
                        print(f" Missed: {missed}")

        results.append(row)
        category_results.setdefault(category, []).append(row)

    # Split into semantic, filter, and edge queries
    _SEMANTIC_CATS = {"geographic", "thematic", "people"}
    _FILTER_CATS = {"temporal", "genre", "awards"}
    semantic = [r for r in results if r["category"] in _SEMANTIC_CATS]
    filter_q = [r for r in results if r["category"] in _FILTER_CATS]
    edge = [r for r in results if r["category"] == "edge_no_match"]

    def _avg(rows, key):
        """Mean of ``row[key]`` over ``rows``; 0.0 for an empty list."""
        return sum(r[key] for r in rows) / len(rows) if rows else 0.0

    edge_pass = sum(1 for r in edge if r["success"])

    if not quiet:
        print()
        print("-" * 70)
        print("CORE SEMANTIC SEARCH (people, thematic, geographic)")
        print("-" * 70)
        print(f" Precision@1: {_avg(semantic, 'precision_at_1'):.2f} (target ≥ 0.85)")
        print(f" Recall@5: {_avg(semantic, 'recall_at_5'):.2f} (target ≥ 0.90)")
        print(f" MRR: {_avg(semantic, 'mrr'):.2f}")
        print(f" NDCG@5: {_avg(semantic, 'ndcg_at_5'):.2f}")
        print()
        print("FILTER QUERIES (temporal, genre, awards)")
        print("-" * 70)
        print(f" Precision@1: {_avg(filter_q, 'precision_at_1'):.2f}")
        print(f" Recall@5: {_avg(filter_q, 'recall_at_5'):.2f}")
        print(f" MRR: {_avg(filter_q, 'mrr'):.2f}")
        print()
        print("EDGE CASES")
        print("-" * 70)
        print(f" Correctly rejected: {edge_pass}/{len(edge)}")
        print()
        # Per-category breakdown
        print("PER-CATEGORY BREAKDOWN")
        print("-" * 70)
        for cat, rows in category_results.items():
            if cat == "edge_no_match":
                passed = sum(1 for r in rows if r["success"])
                print(f" {cat:20s} {passed}/{len(rows)} rejected")
            else:
                label = "semantic" if cat in _SEMANTIC_CATS else "filter"
                cat_r5 = _avg(rows, "recall_at_5")
                cat_p1 = _avg(rows, "precision_at_1")
                print(f" {cat:20s} P@1={cat_p1:.2f} R@5={cat_r5:.2f} ({len(rows)} queries) [{label}]")
        print("=" * 70)

    return {
        "semantic_precision_at_1": _avg(semantic, "precision_at_1"),
        "semantic_recall_at_5": _avg(semantic, "recall_at_5"),
        "semantic_mrr": _avg(semantic, "mrr"),
        "semantic_ndcg_at_5": _avg(semantic, "ndcg_at_5"),
        "filter_precision_at_1": _avg(filter_q, "precision_at_1"),
        "filter_recall_at_5": _avg(filter_q, "recall_at_5"),
        "edge_pass_rate": edge_pass / len(edge) if edge else 1.0,
        "details": results,
    }
if __name__ == "__main__":
    import sys

    # CLI flags: --verbose / -v enables per-query output; --pipeline turns on
    # run_eval's pipeline mode.
    argv = sys.argv
    verbose = any(flag in argv for flag in ("--verbose", "-v"))
    pipeline = "--pipeline" in argv
    asyncio.run(run_eval(verbose=verbose, pipeline=pipeline))