mediastorm / eval_retrieval.py
remdms's picture
feat(eval): add --pipeline flag for full pipeline eval (retriever + Gemini filter)
69e1201
"""Retrieval evaluation script for MediaStorm RAG.
Measures Precision@1, Recall@5, MRR, NDCG@5 against curated ground truth queries.
Zero dependencies beyond the project itself.
Usage:
    python eval_retrieval.py [-v|--verbose] [--pipeline]
"""
import asyncio
import math
import time
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def precision_at_1(retrieved_ids: list[str], expected_ids: set[str]) -> float:
    """Score only the top-ranked result: 1.0 when it is relevant, else 0.0.

    Returns 0.0 when nothing was retrieved or when no documents are expected.
    """
    if retrieved_ids and expected_ids and retrieved_ids[0] in expected_ids:
        return 1.0
    return 0.0
def recall_at_k(retrieved_ids: list[str], expected_ids: set[str], k: int = 5) -> float:
    """Return the share of ``expected_ids`` found among the first ``k`` results.

    Duplicate hits are counted once. An empty ``expected_ids`` is treated as
    vacuously fully recalled and scores 1.0.
    """
    if not expected_ids:
        return 1.0  # vacuous truth
    hits = {rid for rid in retrieved_ids[:k] if rid in expected_ids}
    return len(hits) / len(expected_ids)
def mrr(retrieved_ids: list[str], expected_ids: set[str]) -> float:
    """Reciprocal rank of the first relevant result (1-based), or 0.0 if none.

    Averaging this over queries gives Mean Reciprocal Rank. Returns 0.0 when
    no documents are expected.
    """
    if not expected_ids:
        return 0.0
    return next(
        (1.0 / rank for rank, rid in enumerate(retrieved_ids, start=1) if rid in expected_ids),
        0.0,
    )
def ndcg_at_k(retrieved_ids: list[str], expected_ids: set[str], k: int = 5) -> float:
    """Normalized Discounted Cumulative Gain at k with binary relevance.

    Each relevant document contributes ``1 / log2(rank + 1)`` at its first
    position within the top ``k``. The sum is divided by the ideal DCG, in
    which ``min(len(expected_ids), k)`` relevant documents fill the top ranks.

    Args:
        retrieved_ids: Ranked result IDs, best first.
        expected_ids: IDs considered relevant.
        k: Cut-off rank. Non-positive values score 0.0.

    Returns:
        A score in [0.0, 1.0]. Returns 0.0 when nothing is expected.
    """
    if not expected_ids or k <= 0:
        return 0.0
    # DCG: credit each relevant document only once. Otherwise a duplicated
    # hit would push the score above the ideal DCG (NDCG > 1.0).
    credited: set[str] = set()
    dcg = 0.0
    for i, rid in enumerate(retrieved_ids[:k]):
        if rid in expected_ids and rid not in credited:
            credited.add(rid)
            dcg += 1.0 / math.log2(i + 2)  # i+2 because log2(1)=0
    # Ideal DCG (all relevant docs at top)
    ideal_count = min(len(expected_ids), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_count))
    return dcg / idcg if idcg > 0 else 0.0
# ---------------------------------------------------------------------------
# Ground truth queries
# ---------------------------------------------------------------------------
# Hand-curated ground truth for the retrieval evaluation. Each entry has:
#   "query"    - natural-language search string passed to the retriever
#   "expected" - short story-ID prefixes (the first "-"-separated segment of
#                the full ID); _resolve_uids() expands them to full IDs
#   "category" - grouping for the report. run_eval() treats geographic,
#                thematic, and people as "semantic"; temporal, genre, and
#                awards as "filter"; edge_no_match entries are scored by
#                pass/fail rather than ranking metrics.
EVAL_QUERIES = [
    # --- GEOGRAPHIC (5) ---
    {
        "query": "Stories about the war in Congo",
        "expected": {"f6cea0f9", "3d1c98e7", "e1018d23", "2f638ea0"},
        "category": "geographic",
    },
    {
        "query": "Documentaries set in Afghanistan",
        "expected": {"d7982140", "f38a381a", "5e34b1f2", "b3fc2f4f", "dab272b4"},
        "category": "geographic",
    },
    {
        "query": "Stories about East Africa — Kenya, Ethiopia, Somalia",
        "expected": {"44004365", "b37d6691", "d371bdd4", "f9427ae8", "1d972cf2"},
        "category": "geographic",
    },
    {
        "query": "Stories filmed in Latin America or Mexico",
        "expected": {"13e631b1", "6ed4cd6a", "7708bf66", "ebe28ea7", "5fa24a64"},
        "category": "geographic",
    },
    {
        "query": "Stories about the Middle East conflict — Israel Palestine",
        "expected": {"aab689a6"},
        "category": "geographic",
    },
    # --- THEMATIC (5) ---
    {
        "query": "Stories about PTSD and veterans returning from war",
        "expected": {"e53b9d54", "b3fc2f4f", "dab272b4", "8f3f7b47", "5936e80e"},
        "category": "thematic",
    },
    {
        "query": "Climate change and environmental destruction",
        "expected": {"7233cf20", "44004365", "d371bdd4", "deb75fcf", "18b8f8d9", "b6f35a10"},
        "category": "thematic",
    },
    {
        "query": "Child marriage and women's rights",
        "expected": {"5e34b1f2", "b37d6691", "9b36adb3", "65c1fa57", "ebe28ea7"},
        "category": "thematic",
    },
    {
        "query": "Wildlife conservation and endangered species",
        "expected": {"f9427ae8", "c6fc31b3", "e1018d23", "39431e99", "b8c0f1a0"},
        "category": "thematic",
    },
    {
        "query": "Immigration and refugee stories",
        "expected": {"49866bbe", "f4fe3cbf", "6ed4cd6a", "ce125ae0", "6e613637"},
        "category": "thematic",
    },
    # --- TEMPORAL (4) ---
    {
        "query": "MediaStorm's earliest stories from 2005-2006",
        "expected": {"5936e80e", "88995eea", "0a2d36a5", "13e631b1", "5fa24a64", "214b44a0", "9e5325aa"},
        "category": "temporal",
    },
    {
        "query": "Recent stories from 2022 to 2025",
        "expected": {"b8c0f1a0", "609d8d9f", "64b3132a", "f9427ae8", "051af735"},
        "category": "temporal",
    },
    {
        "query": "Stories from the 2008 financial crisis era",
        "expected": {"5ce6a28d", "b333209b", "732657e6"},
        "category": "temporal",
    },
    {
        "query": "Stories published around 2010",
        "expected": {"44004365", "575bf728", "9b36adb3", "fbd54b9c", "826d329f"},
        "category": "temporal",
    },
    # --- PEOPLE (4) ---
    {
        "query": "Stories about Sebastiao Salgado",
        "expected": {"3fa4c5e5"},
        "category": "people",
    },
    {
        "query": "Stories featuring Don McCullin",
        "expected": {"3f5a3517"},
        "category": "people",
    },
    {
        "query": "Stories about Ai Weiwei",
        "expected": {"c346cb01"},
        "category": "people",
    },
    {
        "query": "Stories about Angelina Jolie and humanitarian work",
        "expected": {"fade2a94"},
        "category": "people",
    },
    # --- GENRE/FORMAT (4) ---
    {
        "query": "Photo essays in the archive",
        "expected": {"0a2d36a5", "13e631b1", "34d720e4", "732657e6", "d7982140", "c3d52625", "9e5325aa", "e53b9d54"},
        "category": "genre",
    },
    {
        "query": "Interactive multimedia projects or crisis guides",
        "expected": {"1815903a", "6b84f13f", "5be0d7ec", "05208857", "aab689a6"},
        "category": "genre",
    },
    {
        "query": "Video documentaries about family and aging",
        "expected": {"176e4cd9", "7e8268de", "88995eea", "7f8e385f", "4c2f60cf"},
        "category": "genre",
    },
    {
        "query": "Animated or motion design pieces",
        "expected": {"018cbb6a", "85d5056b", "5ae39bb8"},
        "category": "genre",
    },
    # --- AWARDS (4) ---
    {
        "query": "Emmy award winning stories",
        "expected": {"49866bbe", "988d3b60", "732657e6", "9b36adb3", "d266b644", "bac708fd", "dc0749e7", "e4cb243e", "e53b9d54"},
        "category": "awards",
    },
    {
        "query": "World Press Photo winners",
        "expected": {"44004365", "176e4cd9", "575bf728", "7cc092f6", "87f894da", "127a7e90"},
        "category": "awards",
    },
    {
        "query": "Award-winning stories about the Iraq war",
        "expected": {"e53b9d54", "5936e80e"},
        "category": "awards",
    },
    {
        "query": "Stories that won at Webby Awards",
        "expected": {"176e4cd9", "5936e80e", "44004365", "575bf728"},
        "category": "awards",
    },
    # --- EDGE CASES: should return NO relevant results (4) ---
    # "expected" is empty; run_eval() counts such a query as passed only when
    # the retriever (or, in pipeline mode, the filtered result) returns no
    # stories at all.
    {
        "query": "Quantum computing breakthroughs in 2024",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Best Italian pasta recipes from Tuscany",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Taylor Swift concert tour dates",
        "expected": set(),
        "category": "edge_no_match",
    },
    {
        "query": "Stock market trading strategies and cryptocurrency",
        "expected": set(),
        "category": "edge_no_match",
    },
]
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def _resolve_uids(queries: list[dict], all_ids: list[str]) -> list[dict]:
"""Resolve short UID prefixes to full UUIDs from ChromaDB."""
prefix_map = {}
for full_id in all_ids:
prefix = full_id.split("-")[0]
prefix_map[prefix] = full_id
resolved = []
for q in queries:
expected = set()
for uid in q["expected"]:
if uid in prefix_map:
expected.add(prefix_map[uid])
else:
expected.add(uid) # keep as-is if already full
resolved.append({**q, "expected": expected})
return resolved
async def run_eval(verbose: bool = False, quiet: bool = False, pipeline: bool = False) -> dict:
    """Run evaluation over ``EVAL_QUERIES`` and return aggregate metrics.

    Each query is sent through ``HybridRetriever.retrieve``. Scored categories
    get P@1 / R@5 / MRR / NDCG@5. ``edge_no_match`` queries pass only when no
    stories are returned.

    Args:
        verbose: Print one line per query (ignored when ``quiet`` is set).
        quiet: Suppress all console output.
        pipeline: If True, run the full pipeline (retriever + Gemini filter).
            Requires GEMINI_API_KEY. Measures what the user actually sees:
            only stories whose link appears in the generated response are
            scored, and the per-query duration includes the generation step.

    Returns:
        Dict with averaged semantic and filter metrics, ``edge_pass_rate``, and
        the per-query rows under ``"details"``.
    """
    from mediastorm.config import CHROMADB_PATH, BM25_INDEX_PATH
    from mediastorm.vectorize.store import VectorStore
    from mediastorm.vectorize.embedder import Embedder
    from mediastorm.vectorize.bm25_store import BM25Store
    from mediastorm.rag.retriever import HybridRetriever
    from mediastorm.rag.router import QueryRouter

    store = VectorStore(path=CHROMADB_PATH)
    embedder = Embedder()
    bm25 = BM25Store(path=BM25_INDEX_PATH)
    bm25.load()
    router = QueryRouter()
    retriever = HybridRetriever(
        vector_store=store,
        bm25_store=bm25,
        embedder=embedder,
        router=router,
        top_k_final=5,
    )

    # Pipeline mode: build link lookup and import generator
    link_lookup: dict[str, str] = {}
    if pipeline:
        from mediastorm.api import _build_link_lookup
        from mediastorm.rag.generator import generate_response

        link_lookup = await _build_link_lookup()

    # Resolve short UIDs to full UUIDs.
    # NOTE(review): reaches into VectorStore's private ``_stories`` collection.
    all_ids = store._stories.get(include=[])["ids"]
    queries = _resolve_uids(EVAL_QUERIES, all_ids)

    results: list[dict] = []
    category_results: dict[str, list] = {}
    show_rows = verbose and not quiet

    if not quiet:
        print("=" * 70)
        print("MediaStorm RAG — Retrieval Evaluation")
        print("=" * 70)
        print()

    for i, q in enumerate(queries):
        query = q["query"]
        expected = q["expected"]
        category = q["category"]

        # perf_counter is monotonic; time.time() can jump with wall-clock changes.
        start = time.perf_counter()
        retrieval = await retriever.retrieve(query)

        # Pipeline mode: filter through Gemini (same as /api/search). Keep only
        # stories whose link the generated answer actually contains.
        if pipeline and retrieval.stories:
            full_text = await generate_response(
                query, retrieval, link_lookup=link_lookup,
            )
            retrieval_stories = [
                s for s in retrieval.stories
                if (link := link_lookup.get(s["id"])) and link in full_text
            ]
        else:
            retrieval_stories = retrieval.stories
        # Measured after filtering so pipeline latency reflects the full request.
        duration = time.perf_counter() - start

        retrieved_ids = [s["id"] for s in retrieval_stories]

        if category == "edge_no_match":
            success = len(retrieval_stories) == 0
            row = {
                "query": query,
                "category": category,
                "success": success,
                "num_returned": len(retrieval_stories),
                "expected": [],
                "duration": duration,
            }
            if show_rows:
                status = "PASS" if success else "FAIL"
                print(f" [{status}] Q{i+1}: {query}")
                if not success:
                    names = [s.get("metadata", {}).get("name", s["id"]) for s in retrieval_stories]
                    print(f" Unexpected results: {names}")
        else:
            p1 = precision_at_1(retrieved_ids, expected)
            r5 = recall_at_k(retrieved_ids, expected, k=5)
            m = mrr(retrieved_ids, expected)
            n5 = ndcg_at_k(retrieved_ids, expected, k=5)
            missed = expected - set(retrieved_ids)
            row = {
                "query": query,
                "category": category,
                "precision_at_1": p1,
                "recall_at_5": r5,
                "mrr": m,
                "ndcg_at_5": n5,
                "retrieved": retrieved_ids,
                "expected": list(expected),
                "missed": list(missed),
                "duration": duration,
            }
            if show_rows:
                status = "PASS" if r5 > 0 else "MISS"
                print(f" [{status}] Q{i+1}: {query}")
                print(f" P@1={p1:.0f} R@5={r5:.2f} MRR={m:.2f} NDCG@5={n5:.2f} ({duration:.1f}s)")
                if r5 < 1.0 and missed:
                    print(f" Missed: {missed}")

        results.append(row)
        category_results.setdefault(category, []).append(row)

    # Split into semantic, filter, and edge queries
    _SEMANTIC_CATS = {"geographic", "thematic", "people"}
    _FILTER_CATS = {"temporal", "genre", "awards"}
    semantic = [r for r in results if r["category"] in _SEMANTIC_CATS]
    filter_q = [r for r in results if r["category"] in _FILTER_CATS]
    edge = [r for r in results if r["category"] == "edge_no_match"]

    def _avg(rows, key):
        """Mean of ``key`` across ``rows``; 0.0 for an empty group."""
        return sum(r[key] for r in rows) / len(rows) if rows else 0.0

    # Compute each aggregate once; used for both the report and the return value.
    summary = {
        "semantic_precision_at_1": _avg(semantic, "precision_at_1"),
        "semantic_recall_at_5": _avg(semantic, "recall_at_5"),
        "semantic_mrr": _avg(semantic, "mrr"),
        "semantic_ndcg_at_5": _avg(semantic, "ndcg_at_5"),
        "filter_precision_at_1": _avg(filter_q, "precision_at_1"),
        "filter_recall_at_5": _avg(filter_q, "recall_at_5"),
        "filter_mrr": _avg(filter_q, "mrr"),
        "filter_ndcg_at_5": _avg(filter_q, "ndcg_at_5"),
    }
    edge_pass = sum(1 for r in edge if r["success"])

    if not quiet:
        print()
        print("-" * 70)
        print("CORE SEMANTIC SEARCH (people, thematic, geographic)")
        print("-" * 70)
        print(f" Precision@1: {summary['semantic_precision_at_1']:.2f} (target ≥ 0.85)")
        print(f" Recall@5: {summary['semantic_recall_at_5']:.2f} (target ≥ 0.90)")
        print(f" MRR: {summary['semantic_mrr']:.2f}")
        print(f" NDCG@5: {summary['semantic_ndcg_at_5']:.2f}")
        print()
        print("FILTER QUERIES (temporal, genre, awards)")
        print("-" * 70)
        print(f" Precision@1: {summary['filter_precision_at_1']:.2f}")
        print(f" Recall@5: {summary['filter_recall_at_5']:.2f}")
        print(f" MRR: {summary['filter_mrr']:.2f}")
        print(f" NDCG@5: {summary['filter_ndcg_at_5']:.2f}")
        print()
        print("EDGE CASES")
        print("-" * 70)
        print(f" Correctly rejected: {edge_pass}/{len(edge)}")
        print()
        # Per-category breakdown
        print("PER-CATEGORY BREAKDOWN")
        print("-" * 70)
        for cat, rows in category_results.items():
            if cat == "edge_no_match":
                passed = sum(1 for r in rows if r["success"])
                print(f" {cat:20s} {passed}/{len(rows)} rejected")
            else:
                label = "semantic" if cat in _SEMANTIC_CATS else "filter"
                cat_r5 = _avg(rows, "recall_at_5")
                cat_p1 = _avg(rows, "precision_at_1")
                print(f" {cat:20s} P@1={cat_p1:.2f} R@5={cat_r5:.2f} ({len(rows)} queries) [{label}]")
        print("=" * 70)

    return {
        **summary,
        # No edge queries means nothing could be wrongly returned.
        "edge_pass_rate": edge_pass / len(edge) if edge else 1.0,
        "details": results,
    }
if __name__ == "__main__":
    # argparse rejects unknown flags. A typo such as "--pipline" now errors
    # instead of silently running the retriever-only evaluation.
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate MediaStorm RAG retrieval against curated ground-truth queries.",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="print per-query results",
    )
    parser.add_argument(
        "--pipeline",
        action="store_true",
        help="score the full pipeline (retriever + Gemini filter); requires GEMINI_API_KEY",
    )
    args = parser.parse_args()
    asyncio.run(run_eval(verbose=args.verbose, pipeline=args.pipeline))