# mediastorm / benchmark.py
# Author: remdms
# Commit e8e3c0f: "feat: add benchmark.py with full performance baseline"
"""Performance benchmark for MediaStorm RAG.

Measures retrieval quality (via eval_retrieval), query latency, peak memory, and
disk size, and saves a Markdown report to data/benchmarks/ so results can be
compared across optimizations.

Usage:
    python benchmark.py           # full benchmark (quality + performance)
    python benchmark.py --quick   # performance only (skips the Gemini eval queries)
"""
import asyncio
import os
import resource
import statistics
import subprocess
import sys
import time
import tracemalloc
from datetime import datetime
from pathlib import Path
def _dir_size_mb(path: str | Path) -> float:
"""Total size of a directory in MB."""
path = Path(path)
if not path.exists():
return 0.0
total = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
return total / (1024 * 1024)
def _file_size_mb(path: str | Path) -> float:
path = Path(path)
return path.stat().st_size / (1024 * 1024) if path.exists() else 0.0
# ---------------------------------------------------------------------------
# Size benchmark
# ---------------------------------------------------------------------------
def bench_size() -> dict:
    """Report the on-disk footprint of the venv, model and indexes, in MB.

    ``bm25_mb`` is rounded to 2 decimals, every other entry to 1.
    ``total_data_mb`` is the sum of the unrounded model, ChromaDB and BM25
    sizes. The venv is deliberately excluded from that total.
    """
    measured = {
        "venv_mb": _dir_size_mb(".venv"),
        "model_mb": _dir_size_mb("models"),
        "chromadb_mb": _dir_size_mb("data/chromadb"),
        "bm25_mb": _file_size_mb("data/bm25_index.pkl"),
    }
    data_total = measured["model_mb"] + measured["chromadb_mb"] + measured["bm25_mb"]
    report = {
        key: round(value, 2 if key == "bm25_mb" else 1)
        for key, value in measured.items()
    }
    report["total_data_mb"] = round(data_total, 1)
    return report
# ---------------------------------------------------------------------------
# Cold start benchmark
# ---------------------------------------------------------------------------
def bench_cold_start() -> dict:
    """Time each stage of bringing the retrieval stack up from scratch.

    Stages, in order:

    1. Import the embedder module.
    2. Construct ``Embedder``.
    3. Import and open ``VectorStore`` at ``CHROMADB_PATH``.
    4. Import ``BM25Store`` and load the index at ``BM25_INDEX_PATH``.
    5. Make one warmup embedding call.

    The VectorStore and BM25 module imports are counted in their own stage,
    not in ``import_s``. Returns seconds per stage plus their sum, each
    rounded to 3 decimals.
    """
    clock = time.perf_counter
    elapsed: dict[str, float] = {}

    mark = clock()
    from mediastorm.vectorize.embedder import Embedder
    elapsed["import_s"] = clock() - mark

    mark = clock()
    embedder = Embedder()
    elapsed["embedder_init_s"] = clock() - mark

    mark = clock()
    from mediastorm.vectorize.store import VectorStore
    from mediastorm.config import CHROMADB_PATH
    store = VectorStore(path=CHROMADB_PATH)  # kept referenced until return
    elapsed["chromadb_init_s"] = clock() - mark

    mark = clock()
    from mediastorm.vectorize.bm25_store import BM25Store
    from mediastorm.config import BM25_INDEX_PATH
    bm25 = BM25Store(path=BM25_INDEX_PATH)
    bm25.load()
    elapsed["bm25_load_s"] = clock() - mark

    mark = clock()
    embedder.embed_texts(["warmup"])
    elapsed["warmup_s"] = clock() - mark

    report = {stage: round(seconds, 3) for stage, seconds in elapsed.items()}
    report["total_cold_start_s"] = round(sum(elapsed.values()), 3)
    return report
# ---------------------------------------------------------------------------
# Latency benchmark
# ---------------------------------------------------------------------------
async def bench_latency() -> dict:
    """Measure end-to-end ``HybridRetriever.retrieve`` latency.

    Builds a fresh retrieval stack (VectorStore, Embedder, BM25Store and
    QueryRouter feeding a HybridRetriever with ``top_k_final=5``). It issues
    one untimed warmup query, then times each representative query
    sequentially.

    Returns:
        dict with the number of timed queries and mean / median / p95 / min /
        max latency in milliseconds, rounded to 0.1 ms. p95 uses the
        nearest-rank definition, so it is always one of the observed samples.
    """
    from mediastorm.config import CHROMADB_PATH, BM25_INDEX_PATH
    from mediastorm.vectorize.store import VectorStore
    from mediastorm.vectorize.embedder import Embedder
    from mediastorm.vectorize.bm25_store import BM25Store
    from mediastorm.rag.retriever import HybridRetriever
    from mediastorm.rag.router import QueryRouter

    store = VectorStore(path=CHROMADB_PATH)
    embedder = Embedder()
    bm25 = BM25Store(path=BM25_INDEX_PATH)
    bm25.load()
    router = QueryRouter()
    retriever = HybridRetriever(
        vector_store=store, bm25_store=bm25,
        embedder=embedder, router=router, top_k_final=5,
    )
    queries = [
        "Stories about the war in Congo",
        "Climate change and environmental destruction",
        "Emmy award winning stories",
        "Stories about Sebastiao Salgado",
        "MediaStorm's earliest stories from 2005-2006",
        "Photo essays in the archive",
        "Stories filmed in Latin America or Mexico",
        "Wildlife conservation and endangered species",
        "Stories about PTSD and veterans",
        "Interactive crisis guides",
    ]

    # One untimed call so any first-call setup cost stays out of the samples.
    await retriever.retrieve("warmup query")

    durations_ms = []
    for q in queries:
        t0 = time.perf_counter()
        await retriever.retrieve(q)
        durations_ms.append((time.perf_counter() - t0) * 1000)
    durations_ms.sort()

    n = len(durations_ms)
    # Nearest-rank p95: the ceil(0.95 * n)-th smallest sample (1-based).
    # The ceil is done in integer arithmetic to avoid float rounding.
    p95_rank = (95 * n + 99) // 100
    return {
        "queries": len(queries),
        "mean_ms": round(statistics.mean(durations_ms), 1),
        "median_ms": round(statistics.median(durations_ms), 1),
        "p95_ms": round(durations_ms[max(p95_rank - 1, 0)], 1),
        "min_ms": round(durations_ms[0], 1),
        "max_ms": round(durations_ms[-1], 1),
    }
# ---------------------------------------------------------------------------
# Memory benchmark
# ---------------------------------------------------------------------------
def bench_memory() -> dict:
    """Report this process's peak resident set size so far, in MB.

    The value is the ``RUSAGE_SELF`` high-water mark. It therefore covers
    everything the process has loaded up to the moment of the call.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    # ru_maxrss is in bytes on macOS and in kilobytes on Linux (assumed for
    # every other platform here).
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return {"peak_rss_mb": round(usage.ru_maxrss / divisor, 1)}
# ---------------------------------------------------------------------------
# Quality benchmark (via eval_retrieval)
# ---------------------------------------------------------------------------
async def bench_quality() -> dict:
    """Run the retrieval evaluation and return its headline metrics.

    Calls ``eval_retrieval.run_eval(verbose=True)``. It then picks out the
    aggregate semantic, filter and edge-case metrics under short report keys,
    each rounded to 3 decimals. A missing metric raises ``KeyError``.
    """
    from eval_retrieval import run_eval

    # Short report key -> metric name in run_eval()'s result mapping.
    metric_names = {
        "semantic_p1": "semantic_precision_at_1",
        "semantic_r5": "semantic_recall_at_5",
        "semantic_mrr": "semantic_mrr",
        "semantic_ndcg5": "semantic_ndcg_at_5",
        "filter_p1": "filter_precision_at_1",
        "filter_r5": "filter_recall_at_5",
        "edge_pass_rate": "edge_pass_rate",
    }
    results = await run_eval(verbose=True)
    return {short: round(results[full], 3) for short, full in metric_names.items()}
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def _format_report(
size: dict,
cold_start: dict,
latency: dict,
memory: dict,
quality: dict | None,
) -> str:
now = datetime.now().strftime("%Y-%m-%d %H:%M")
lines = [
f"# Benchmark Report — {now}",
"",
"## Size",
"",
f"| Component | Size |",
f"|---|---|",
f"| venv | {size['venv_mb']} MB |",
f"| ONNX model | {size['model_mb']} MB |",
f"| ChromaDB data | {size['chromadb_mb']} MB |",
f"| BM25 index | {size['bm25_mb']} MB |",
f"| **Total data** | **{size['total_data_mb']} MB** |",
"",
"## Cold Start",
"",
f"| Step | Time |",
f"|---|---|",
f"| Import | {cold_start['import_s']}s |",
f"| Embedder init | {cold_start['embedder_init_s']}s |",
f"| ChromaDB init | {cold_start['chromadb_init_s']}s |",
f"| BM25 load | {cold_start['bm25_load_s']}s |",
f"| Warmup embed | {cold_start['warmup_s']}s |",
f"| **Total** | **{cold_start['total_cold_start_s']}s** |",
"",
"## Latency ({} queries)".format(latency["queries"]),
"",
f"| Metric | Value |",
f"|---|---|",
f"| Mean | {latency['mean_ms']} ms |",
f"| Median (p50) | {latency['median_ms']} ms |",
f"| p95 | {latency['p95_ms']} ms |",
f"| Min | {latency['min_ms']} ms |",
f"| Max | {latency['max_ms']} ms |",
"",
"## Memory",
"",
f"| Metric | Value |",
f"|---|---|",
f"| Peak RSS | {memory['peak_rss_mb']} MB |",
]
if quality:
lines += [
"",
"## Quality (retrieval eval — 30 queries)",
"",
f"| Metric | Semantic | Filter |",
f"|---|---|---|",
f"| Precision@1 | {quality['semantic_p1']} | {quality['filter_p1']} |",
f"| Recall@5 | {quality['semantic_r5']} | {quality['filter_r5']} |",
f"| MRR | {quality['semantic_mrr']} | — |",
f"| NDCG@5 | {quality['semantic_ndcg5']} | — |",
f"| Edge rejection | {quality['edge_pass_rate']} | — |",
]
return "\n".join(lines) + "\n"
async def main():
    """Run every benchmark stage, print the Markdown report and save it.

    Passing ``--quick`` on the command line skips the quality stage (the
    retrieval eval that issues Gemini queries). The report is written as
    UTF-8 to ``data/benchmarks/<YYYY-MM-DD-HHMM>.md``.
    """
    quick = "--quick" in sys.argv

    print("=" * 60)
    print("MediaStorm RAG — Benchmark")
    print("=" * 60)
    print()

    # tracemalloc is deliberately not started. Nothing here reads its
    # statistics, and tracing every allocation adds CPU and memory overhead
    # that would distort the latency and peak-RSS numbers measured below.

    print("[1/5] Measuring size...")
    size = bench_size()

    print("[2/5] Measuring cold start...")
    cold_start = bench_cold_start()

    print("[3/5] Measuring latency (10 queries)...")
    latency = await bench_latency()

    # Peak RSS is a process-wide high-water mark, so it reflects everything
    # loaded by the stages above (but not the quality eval that follows).
    print("[4/5] Measuring memory...")
    memory = bench_memory()

    quality = None
    if quick:
        print("[5/5] Skipped (--quick mode)")
    else:
        print("[5/5] Running retrieval eval (30 queries via Gemini)...")
        quality = await bench_quality()

    report = _format_report(size, cold_start, latency, memory, quality)
    print()
    print(report)

    out_dir = Path("data/benchmarks")
    out_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y-%m-%d-%H%M")
    out_path = out_dir / f"{ts}.md"
    # Explicit UTF-8: the report contains non-ASCII characters (em dashes)
    # that the platform's default encoding may be unable to encode.
    out_path.write_text(report, encoding="utf-8")
    print(f"Saved to {out_path}")
if __name__ == "__main__":
    # Script entry point: drive the async benchmark with asyncio.run().
    asyncio.run(main())