visual-rag-toolkit / demo /evaluation.py
Yeroyan's picture
refactor evaluation
d9f2c00 verified
"""Evaluation runner with UI updates."""
import hashlib
import json
import logging
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
import numpy as np
import streamlit as st
import torch
from qdrant_client.models import FieldCondition, Filter, MatchValue
from visual_rag import VisualEmbedder
from visual_rag.retrieval import MultiVectorRetriever
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
from demo.qdrant_utils import get_qdrant_credentials
TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
def _stable_uuid(text: str) -> str:
"""Generate a stable UUID from text (same as benchmark script)."""
hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
"""Generate union point ID (same as benchmark script)."""
ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
return _stable_uuid(f"{ns}::{source_doc_id}")
def _remap_qrels_to_union_ids(
qrels: Dict[str, Dict[str, int]],
corpus: List[Any],
dataset_name: str,
collection_name: str,
) -> Dict[str, Dict[str, int]]:
"""Remap qrels doc_ids from original format to union_doc_id format (matching benchmark)."""
id_map: Dict[str, str] = {}
for doc in corpus:
source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
id_map[str(doc.doc_id)] = _union_point_id(
dataset_name=dataset_name,
source_doc_id=source_doc_id,
union_namespace=collection_name,
)
remapped: Dict[str, Dict[str, int]] = {}
for qid, rels in qrels.items():
out_rels: Dict[str, int] = {}
for did, score in rels.items():
mapped = id_map.get(str(did))
if mapped:
out_rels[mapped] = int(score)
if out_rels:
remapped[qid] = out_rels
return remapped
def get_doc_id_from_result(r: Dict[str, Any], use_original: bool = True) -> str:
"""Extract document ID from search result.
Args:
r: Search result dict with 'id' and 'payload'
use_original: If True, prefer original doc_id for matching with qrels.
If False, prefer union_doc_id (Qdrant point ID).
"""
payload = r.get("payload", {})
if use_original:
doc_id = (
payload.get("doc_id")
or payload.get("source_doc_id")
or payload.get("corpus-id")
or payload.get("union_doc_id")
or str(r.get("id", ""))
)
else:
doc_id = (
payload.get("union_doc_id")
or str(r.get("id", ""))
or payload.get("doc_id")
)
return str(doc_id)
def run_evaluation_with_ui(config: Dict[str, Any]):
st.divider()
print("=" * 60)
print("[EVAL] Starting evaluation via UI")
print("=" * 60)
url, api_key = get_qdrant_credentials()
if not url:
st.error("QDRANT_URL not configured")
return
datasets = config.get("datasets", [])
collection = config["collection"]
model = config.get("model", "vidore/colpali-v1.3")
mode = config.get("mode", "single_full")
top_k = config.get("top_k", 100)
prefetch_k = config.get("prefetch_k", 256)
stage1_mode = config.get("stage1_mode", "tokens_vs_tiles")
stage1_k = config.get("stage1_k", 1000)
stage2_k = config.get("stage2_k", 300)
prefer_grpc = config.get("prefer_grpc", True)
torch_dtype = config.get("torch_dtype", "float16")
evaluation_scope = config.get("evaluation_scope", "union")
print(f"[EVAL] ═══════════════════════════════════════════════════")
print(f"[EVAL] Collection: {collection}")
print(f"[EVAL] Model: {model}")
print(f"[EVAL] Mode: {mode}, Scope: {evaluation_scope}")
print(f"[EVAL] Datasets: {datasets}")
print(f"[EVAL] Query embedding dtype: {torch_dtype} (vectors already indexed)")
print(f"[EVAL] ═══════════════════════════════════════════════════")
phase1_container = st.container()
phase2_container = st.container()
phase3_container = st.container()
results_container = st.container()
try:
with phase1_container:
st.markdown("##### πŸ€– Phase 1: Loading Model")
model_status = st.empty()
model_status.info(f"Loading `{model.split('/')[-1]}`...")
print(f"[EVAL] Loading embedder: {model}")
torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
qdrant_dtype = config.get("qdrant_vector_dtype", "float16")
output_dtype_obj = np.float16 if qdrant_dtype == "float16" else np.float32
embedder = VisualEmbedder(
model_name=model,
torch_dtype=torch_dtype_obj,
output_dtype=output_dtype_obj,
)
embedder._load_model()
print(f"[EVAL] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_dtype})")
model_status.success(f"βœ… Model `{model.split('/')[-1]}` loaded")
retriever_status = st.empty()
retriever_status.info(f"Connecting to collection `{collection}`...")
print(f"[EVAL] Connecting to Qdrant collection: {collection}")
retriever = MultiVectorRetriever(
collection_name=collection,
model_name=model,
qdrant_url=url,
qdrant_api_key=api_key,
prefer_grpc=prefer_grpc,
embedder=embedder,
)
print(f"[EVAL] Connected to Qdrant")
retriever_status.success(f"βœ… Connected to `{collection}`")
with phase2_container:
st.markdown("##### πŸ“š Phase 2: Loading Datasets")
dataset_data = {}
total_queries = 0
max_queries_per_ds = config.get("max_queries")
for ds_name in datasets:
ds_status = st.empty()
ds_short = ds_name.split("/")[-1]
ds_status.info(f"Loading `{ds_short}`...")
print(f"[EVAL] Loading dataset: {ds_name}")
corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
print(f"[EVAL] Remapping qrels to union_doc_id format for collection={collection}")
remapped_qrels = _remap_qrels_to_union_ids(qrels, corpus, ds_name, collection)
print(f"[EVAL] Remapped {len(qrels)} -> {len(remapped_qrels)} queries with valid rels")
if evaluation_scope == "per_dataset" and max_queries_per_ds:
queries = queries[:max_queries_per_ds]
dataset_data[ds_name] = {
"queries": queries,
"qrels": remapped_qrels,
"num_docs": len(corpus),
}
total_queries += len(queries)
print(f"[EVAL] Loaded {ds_name}: {len(corpus)} docs, {len(queries)} queries")
ds_status.success(f"βœ… `{ds_short}`: {len(corpus)} docs, {len(queries)} queries")
if evaluation_scope == "union" and max_queries_per_ds and max_queries_per_ds < total_queries:
total_queries = max_queries_per_ds
print(f"[EVAL] Will limit to {total_queries} total queries (union mode)")
embed_status = st.empty()
embed_status.info(f"Embedding queries...")
with phase3_container:
st.markdown("##### 🎯 Phase 3: Running Evaluation")
metrics_collectors = {
"ndcg@5": [], "ndcg@10": [],
"recall@5": [], "recall@10": [],
"mrr@5": [], "mrr@10": [],
}
latencies = []
log_lines = []
metrics_by_dataset = {}
if evaluation_scope == "per_dataset":
overall_progress = st.progress(0.0)
datasets_done = 0
for ds_name, ds_info in dataset_data.items():
ds_short = ds_name.split("/")[-1]
st.markdown(f"**Evaluating `{ds_short}`**")
queries = ds_info["queries"]
qrels = ds_info["qrels"]
if not queries:
continue
print(f"[EVAL] Embedding {len(queries)} queries for {ds_short}...")
query_texts = [q.text for q in queries]
query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
print(f"[EVAL] Queries embedded for {ds_short}")
ds_filter = Filter(
must=[FieldCondition(key="dataset", match=MatchValue(value=ds_name))]
)
print(f"[EVAL] Using filter: dataset={ds_name}")
progress_bar = st.progress(0.0)
eval_status = st.empty()
log_area = st.empty()
ds_metrics = {"ndcg@5": [], "ndcg@10": [], "recall@5": [], "recall@10": [], "mrr@5": [], "mrr@10": []}
ds_latencies = []
ds_log_lines = []
eval_status.info(f"Evaluating {len(queries)} queries...")
print(f"[EVAL] Starting per-dataset evaluation: {ds_short}, {len(queries)} queries")
for i, (q, qemb) in enumerate(zip(queries, query_embeddings)):
start = time.time()
if isinstance(qemb, torch.Tensor):
qemb_np = qemb.detach().cpu().numpy()
else:
qemb_np = qemb.numpy() if hasattr(qemb, 'numpy') else np.array(qemb)
results = retriever.search_embedded(
query_embedding=qemb_np,
top_k=max(100, top_k),
mode=mode,
prefetch_k=prefetch_k,
stage1_mode=stage1_mode,
stage1_k=stage1_k,
stage2_k=stage2_k,
filter_obj=ds_filter,
)
ds_latencies.append((time.time() - start) * 1000)
latencies.append(ds_latencies[-1])
ranking = [str(r["id"]) for r in results]
rels = qrels.get(q.query_id, {})
if i == 0:
print(f"[EVAL] First query for {ds_short} - query_id: {q.query_id}")
print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
print(f"[EVAL] qrels has {len(qrels)} queries, this query in qrels: {q.query_id in qrels}")
if results:
r0 = results[0]
print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
p = r0.get("payload", {})
print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
has_match = any(rid in rels for rid in ranking[:10])
print(f"[EVAL] Any match in top-10? {has_match}")
for k_name, k_val in [("ndcg@5", 5), ("ndcg@10", 10)]:
ds_metrics[k_name].append(ndcg_at_k(ranking, rels, k=k_val))
for k_name, k_val in [("recall@5", 5), ("recall@10", 10)]:
ds_metrics[k_name].append(recall_at_k(ranking, rels, k=k_val))
for k_name, k_val in [("mrr@5", 5), ("mrr@10", 10)]:
ds_metrics[k_name].append(mrr_at_k(ranking, rels, k=k_val))
progress = (i + 1) / len(queries)
progress_bar.progress(progress)
eval_status.info(f"🎯 {i+1}/{len(queries)} ({int(progress*100)}%) β€” latency: {np.mean(ds_latencies):.0f}ms")
log_interval = max(5, len(queries) // 10)
if (i + 1) % log_interval == 0 and i > 0:
cur_ndcg = np.mean(ds_metrics["ndcg@10"])
cur_lat = np.mean(ds_latencies[1:]) if len(ds_latencies) > 1 else ds_latencies[0]
ds_log_lines.append(f"[{i+1}/{len(queries)}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
log_area.code("\n".join(ds_log_lines[-5:]), language="text")
print(f"[EVAL] {ds_short} {i+1}/{len(queries)}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
progress_bar.progress(1.0)
ds_final = {k: float(np.mean(v)) for k, v in ds_metrics.items()}
ds_final["avg_latency_ms"] = float(np.mean(ds_latencies))
ds_final["num_queries"] = len(queries)
metrics_by_dataset[ds_name] = ds_final
for k, v in ds_metrics.items():
metrics_collectors[k].extend(v)
eval_status.success(f"βœ… `{ds_short}`: NDCG@10={ds_final['ndcg@10']:.4f}, latency={ds_final['avg_latency_ms']:.0f}ms")
print(f"[EVAL] {ds_short} DONE: NDCG@10={ds_final['ndcg@10']:.4f}")
datasets_done += 1
overall_progress.progress(datasets_done / len(datasets))
overall_progress.progress(1.0)
embed_status.success(f"βœ… All queries embedded")
total_queries = sum(d["num_queries"] for d in metrics_by_dataset.values())
else:
all_queries = []
all_qrels = {}
for ds_name, ds_info in dataset_data.items():
all_queries.extend(ds_info["queries"])
for qid, rels in ds_info["qrels"].items():
all_qrels[qid] = rels
sample_qrel_keys = list(all_qrels.keys())[:3]
sample_doc_ids = []
for qid in sample_qrel_keys:
sample_doc_ids.extend(list(all_qrels[qid].keys())[:2])
print(f"[EVAL] all_qrels built: {len(all_qrels)} queries")
print(f"[EVAL] Sample qrel query_ids: {sample_qrel_keys}")
print(f"[EVAL] Sample qrel doc_ids: {sample_doc_ids[:5]}")
max_q = config.get("max_queries")
if max_q and max_q < len(all_queries):
all_queries = all_queries[:max_q]
total_queries = len(all_queries)
print(f"[EVAL] Embedding {total_queries} queries (union mode)...")
query_texts = [q.text for q in all_queries]
query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
print(f"[EVAL] Queries embedded")
embed_status.success(f"βœ… {total_queries} queries embedded")
progress_bar = st.progress(0.0)
eval_status = st.empty()
log_area = st.empty()
eval_status.info(f"Evaluating {total_queries} queries in `{mode}` mode...")
print(f"[EVAL] Starting union evaluation: {total_queries} queries, mode={mode}")
for i, (q, qemb) in enumerate(zip(all_queries, query_embeddings)):
start = time.time()
if isinstance(qemb, torch.Tensor):
qemb_np = qemb.detach().cpu().numpy()
else:
qemb_np = qemb.numpy() if hasattr(qemb, 'numpy') else np.array(qemb)
results = retriever.search_embedded(
query_embedding=qemb_np,
top_k=max(100, top_k),
mode=mode,
prefetch_k=prefetch_k,
stage1_mode=stage1_mode,
stage1_k=stage1_k,
stage2_k=stage2_k,
)
latencies.append((time.time() - start) * 1000)
ranking = [str(r["id"]) for r in results]
rels = all_qrels.get(q.query_id, {})
if i == 0:
print(f"[EVAL] First query - query_id: {q.query_id}")
print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
print(f"[EVAL] all_qrels has {len(all_qrels)} queries, this query in qrels: {q.query_id in all_qrels}")
if results:
r0 = results[0]
print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
p = r0.get("payload", {})
print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
has_match = any(rid in rels for rid in ranking[:10])
print(f"[EVAL] Any match in top-10? {has_match}")
metrics_collectors["ndcg@5"].append(ndcg_at_k(ranking, rels, k=5))
metrics_collectors["ndcg@10"].append(ndcg_at_k(ranking, rels, k=10))
metrics_collectors["recall@5"].append(recall_at_k(ranking, rels, k=5))
metrics_collectors["recall@10"].append(recall_at_k(ranking, rels, k=10))
metrics_collectors["mrr@5"].append(mrr_at_k(ranking, rels, k=5))
metrics_collectors["mrr@10"].append(mrr_at_k(ranking, rels, k=10))
progress = (i + 1) / total_queries
progress_bar.progress(progress)
eval_status.info(f"🎯 {i+1}/{total_queries} ({int(progress*100)}%) β€” latency: {np.mean(latencies):.0f}ms")
log_interval = max(10, total_queries // 10)
if (i + 1) % log_interval == 0 and i > 0:
cur_ndcg = np.mean(metrics_collectors["ndcg@10"])
cur_lat = np.mean(latencies[1:]) if len(latencies) > 1 else latencies[0]
log_lines.append(f"[{i+1}/{total_queries}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
log_area.code("\n".join(log_lines[-10:]), language="text")
print(f"[EVAL] Progress {i+1}/{total_queries}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
progress_bar.progress(1.0)
eval_status.success(f"βœ… Evaluation complete! ({total_queries} queries)")
with results_container:
st.markdown("##### πŸ“Š Results")
p95_latency = float(np.percentile(latencies, 95))
eval_time_s = sum(latencies) / 1000
qps = total_queries / eval_time_s if eval_time_s > 0 else 0
final_metrics = {
"ndcg@5": float(np.mean(metrics_collectors["ndcg@5"])),
"ndcg@10": float(np.mean(metrics_collectors["ndcg@10"])),
"recall@5": float(np.mean(metrics_collectors["recall@5"])),
"recall@10": float(np.mean(metrics_collectors["recall@10"])),
"mrr@5": float(np.mean(metrics_collectors["mrr@5"])),
"mrr@10": float(np.mean(metrics_collectors["mrr@10"])),
"avg_latency_ms": float(np.mean(latencies)),
"p95_latency_ms": p95_latency,
"qps": qps,
"eval_time_s": eval_time_s,
"num_queries": total_queries,
}
print("=" * 60)
print("[EVAL] FINAL RESULTS:")
print(f"[EVAL] NDCG@5: {final_metrics['ndcg@5']:.4f}")
print(f"[EVAL] NDCG@10: {final_metrics['ndcg@10']:.4f}")
print(f"[EVAL] Recall@5: {final_metrics['recall@5']:.4f}")
print(f"[EVAL] Recall@10: {final_metrics['recall@10']:.4f}")
print(f"[EVAL] MRR@5: {final_metrics['mrr@5']:.4f}")
print(f"[EVAL] MRR@10: {final_metrics['mrr@10']:.4f}")
print(f"[EVAL] Avg Latency: {final_metrics['avg_latency_ms']:.1f}ms")
print(f"[EVAL] P95 Latency: {final_metrics['p95_latency_ms']:.1f}ms")
print(f"[EVAL] QPS: {final_metrics['qps']:.2f}")
print(f"[EVAL] Queries: {final_metrics['num_queries']}")
print("=" * 60)
st.markdown("**Retrieval Metrics**")
c1, c2, c3 = st.columns(3)
with c1:
st.metric("NDCG@5", f"{final_metrics['ndcg@5']:.4f}")
st.metric("NDCG@10", f"{final_metrics['ndcg@10']:.4f}")
with c2:
st.metric("Recall@5", f"{final_metrics['recall@5']:.4f}")
st.metric("Recall@10", f"{final_metrics['recall@10']:.4f}")
with c3:
st.metric("MRR@5", f"{final_metrics['mrr@5']:.4f}")
st.metric("MRR@10", f"{final_metrics['mrr@10']:.4f}")
st.markdown("**Performance**")
c4, c5, c6, c7 = st.columns(4)
c4.metric("Avg Latency", f"{final_metrics['avg_latency_ms']:.0f}ms")
c5.metric("P95 Latency", f"{final_metrics['p95_latency_ms']:.0f}ms")
c6.metric("QPS", f"{final_metrics['qps']:.2f}")
c7.metric("Eval Time", f"{final_metrics['eval_time_s']:.1f}s")
with st.expander("πŸ“‹ Full Results JSON"):
st.json(final_metrics)
detailed_report = {
"generated_at": datetime.now().isoformat(),
"config": {
"collection": collection,
"model": model,
"datasets": datasets,
"mode": mode,
"top_k": top_k,
"evaluation_scope": config.get("evaluation_scope", "union"),
"prefer_grpc": prefer_grpc,
"torch_dtype": torch_dtype,
"max_queries": config.get("max_queries"),
},
"retrieval_metrics": {
"ndcg@5": final_metrics["ndcg@5"],
"ndcg@10": final_metrics["ndcg@10"],
"recall@5": final_metrics["recall@5"],
"recall@10": final_metrics["recall@10"],
"mrr@5": final_metrics["mrr@5"],
"mrr@10": final_metrics["mrr@10"],
},
"performance": {
"avg_latency_ms": final_metrics["avg_latency_ms"],
"p95_latency_ms": final_metrics["p95_latency_ms"],
"qps": final_metrics["qps"],
"eval_time_s": final_metrics["eval_time_s"],
"num_queries": final_metrics["num_queries"],
},
}
if mode == "two_stage":
detailed_report["config"]["stage1_mode"] = stage1_mode
detailed_report["config"]["prefetch_k"] = prefetch_k
elif mode == "three_stage":
detailed_report["config"]["stage1_k"] = stage1_k
detailed_report["config"]["stage2_k"] = stage2_k
if evaluation_scope == "per_dataset" and metrics_by_dataset:
detailed_report["metrics_by_dataset"] = metrics_by_dataset
st.markdown("**Per-Dataset Results**")
for ds_name, ds_metrics in metrics_by_dataset.items():
ds_short = ds_name.split("/")[-1]
with st.expander(f"πŸ“ {ds_short}"):
dc1, dc2, dc3, dc4 = st.columns(4)
dc1.metric("NDCG@10", f"{ds_metrics['ndcg@10']:.4f}")
dc2.metric("Recall@10", f"{ds_metrics['recall@10']:.4f}")
dc3.metric("MRR@10", f"{ds_metrics['mrr@10']:.4f}")
dc4.metric("Latency", f"{ds_metrics['avg_latency_ms']:.0f}ms")
report_json = json.dumps(detailed_report, indent=2)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"eval_report__{collection}__{mode}__{timestamp}.json"
st.download_button(
label="πŸ“₯ Download Detailed Report",
data=report_json,
file_name=filename,
mime="application/json",
use_container_width=True,
)
st.session_state["last_eval_metrics"] = final_metrics
except Exception as e:
error_msg = str(e)
if "not configured in this collection" in error_msg:
vector_name = error_msg.split("name ")[-1].split(" is")[0] if "name " in error_msg else "unknown"
st.error(f"❌ **Collection Mismatch**: Vector `{vector_name}` not found in collection `{collection}`")
st.warning(f"""
**The selected mode `{mode}` requires vectors that don't exist in this collection.**
**Suggestions:**
- Try `single_full` mode (works with basic collections)
- Use a collection indexed with the Visual RAG Toolkit
- Check that the collection has the required vector types for `{mode}` mode
""")
else:
st.error(f"❌ Error: {e}")
with st.expander("πŸ” Full Error Details"):
st.code(traceback.format_exc(), language="text")