Yeroyan's picture
sync v0.1.3
9513cca verified
#!/usr/bin/env python3
"""
Quick Benchmark - Validate retrieval quality with ViDoRe data.
This script:
1. Downloads samples from ViDoRe (with ground truth relevance)
2. Embeds with ColSmol-500M
3. Tests retrieval strategies (exhaustive vs two-stage)
4. Computes METRICS: NDCG@K, MRR@K, Recall@K
5. Compares speed and quality
Usage:
python quick_test.py --samples 100
python quick_test.py --samples 500 --skip-exhaustive # Faster
"""
import sys
import time
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any
# Add parent directory to Python path (so we can import visual_rag)
# This allows running the script directly without pip install
_script_dir = Path(__file__).parent
_parent_dir = _script_dir.parent
if str(_parent_dir) not in sys.path:
sys.path.insert(0, str(_parent_dir))
import numpy as np
from tqdm import tqdm
# Visual RAG imports (now works without pip install)
from visual_rag.embedding import VisualEmbedder
from visual_rag.embedding.pooling import (
tile_level_mean_pooling,
compute_maxsim_score,
)
# Optional: datasets for ViDoRe
try:
from datasets import load_dataset as hf_load_dataset
HAS_DATASETS = True
except ImportError:
HAS_DATASETS = False
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def load_vidore_sample(num_samples: int = 100) -> List[Dict]:
"""
Load sample from ViDoRe DocVQA with ground truth.
Each sample has a query and its relevant document (1:1 mapping).
This allows computing retrieval metrics.
"""
if not HAS_DATASETS:
logger.error("Install datasets: pip install datasets")
sys.exit(1)
logger.info(f"📥 Loading {num_samples} samples from ViDoRe DocVQA...")
ds = hf_load_dataset("vidore/docvqa_test_subsampled", split="test")
samples = []
for i, example in enumerate(ds):
if i >= num_samples:
break
samples.append({
"id": i,
"doc_id": f"doc_{i}",
"query_id": f"q_{i}",
"image": example.get("image", example.get("page_image")),
"query": example.get("query", example.get("question", "")),
# Ground truth: query i is relevant to doc i
"relevant_doc": f"doc_{i}",
})
logger.info(f"✅ Loaded {len(samples)} samples with ground truth")
return samples
def embed_all(
samples: List[Dict],
model_name: str = "vidore/colSmol-500M",
) -> Dict[str, Any]:
"""Embed all documents and queries."""
logger.info(f"\n🤖 Loading model: {model_name}")
embedder = VisualEmbedder(model_name=model_name)
images = [s["image"] for s in samples]
queries = [s["query"] for s in samples if s["query"]]
# Embed images
logger.info(f"🎨 Embedding {len(images)} documents...")
start_time = time.time()
embeddings, token_infos = embedder.embed_images(
images, batch_size=4, return_token_info=True
)
doc_embed_time = time.time() - start_time
logger.info(f" Time: {doc_embed_time:.2f}s ({doc_embed_time/len(images)*1000:.1f}ms/doc)")
# Process embeddings: extract visual tokens + tile-level pooling
doc_data = {}
for i, (emb, token_info) in enumerate(zip(embeddings, token_infos)):
if hasattr(emb, 'cpu'):
emb = emb.cpu()
emb_np = emb.numpy() if hasattr(emb, 'numpy') else np.array(emb)
# Extract visual tokens only (filter special tokens)
visual_indices = token_info["visual_token_indices"]
visual_emb = emb_np[visual_indices].astype(np.float32)
# Tile-level pooling
n_rows = token_info.get("n_rows", 4)
n_cols = token_info.get("n_cols", 3)
num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13
tile_pooled = tile_level_mean_pooling(visual_emb, num_tiles, patches_per_tile=64)
doc_data[f"doc_{i}"] = {
"embedding": visual_emb,
"pooled": tile_pooled,
"num_visual_tokens": len(visual_indices),
"num_tiles": tile_pooled.shape[0],
}
# Embed queries
logger.info(f"🔍 Embedding {len(queries)} queries...")
start_time = time.time()
query_data = {}
for i, query in enumerate(tqdm(queries, desc="Queries")):
q_emb = embedder.embed_query(query)
if hasattr(q_emb, 'cpu'):
q_emb = q_emb.cpu()
q_np = q_emb.numpy() if hasattr(q_emb, 'numpy') else np.array(q_emb)
query_data[f"q_{i}"] = q_np.astype(np.float32)
query_embed_time = time.time() - start_time
return {
"docs": doc_data,
"queries": query_data,
"samples": samples,
"doc_embed_time": doc_embed_time,
"query_embed_time": query_embed_time,
"model": model_name,
}
def search_exhaustive(query_emb: np.ndarray, docs: Dict, top_k: int = 10) -> List[Dict]:
"""Exhaustive MaxSim search over all documents."""
scores = []
for doc_id, doc in docs.items():
score = compute_maxsim_score(query_emb, doc["embedding"])
scores.append({"id": doc_id, "score": score})
scores.sort(key=lambda x: x["score"], reverse=True)
return scores[:top_k]
def search_two_stage(
query_emb: np.ndarray,
docs: Dict,
prefetch_k: int = 20,
top_k: int = 10,
) -> List[Dict]:
"""
Two-stage retrieval with tile-level pooling.
Stage 1: Fast prefetch using tile-pooled vectors
Stage 2: Exact MaxSim reranking on candidates
"""
# Stage 1: Tile-level pooled search
query_pooled = query_emb.mean(axis=0)
query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)
stage1_scores = []
for doc_id, doc in docs.items():
doc_pooled = doc["pooled"]
doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
tile_sims = np.dot(doc_norm, query_pooled)
score = float(tile_sims.max())
stage1_scores.append({"id": doc_id, "score": score})
stage1_scores.sort(key=lambda x: x["score"], reverse=True)
candidates = stage1_scores[:prefetch_k]
# Stage 2: Exact MaxSim on candidates
reranked = []
for cand in candidates:
doc_id = cand["id"]
score = compute_maxsim_score(query_emb, docs[doc_id]["embedding"])
reranked.append({"id": doc_id, "score": score, "stage1_rank": stage1_scores.index(cand) + 1})
reranked.sort(key=lambda x: x["score"], reverse=True)
return reranked[:top_k]
def compute_metrics(
results: Dict[str, List[Dict]],
samples: List[Dict],
k_values: List[int] = [1, 3, 5, 7, 10],
) -> Dict[str, float]:
"""
Compute retrieval metrics.
Since ViDoRe has 1:1 query-doc mapping (1 relevant doc per query):
- Recall@K (Hit Rate): Is the relevant doc in top-K? (0 or 1)
- Precision@K: (# relevant in top-K) / K
- MRR@K: 1/rank if found in top-K, else 0
- NDCG@K: DCG / IDCG with binary relevance
"""
metrics = {}
# Also track per-query ranks for analysis
all_ranks = []
for k in k_values:
recalls = []
precisions = []
mrrs = []
ndcgs = []
for sample in samples:
query_id = sample["query_id"]
relevant_doc = sample["relevant_doc"]
if query_id not in results:
continue
ranking = results[query_id][:k]
ranked_ids = [r["id"] for r in ranking]
# Find rank of relevant doc (1-indexed, 0 if not found)
rank = 0
for i, doc_id in enumerate(ranked_ids):
if doc_id == relevant_doc:
rank = i + 1
break
# Recall@K (Hit Rate): 1 if found in top-K
found = 1.0 if rank > 0 else 0.0
recalls.append(found)
# Precision@K: (# relevant found) / K
# With 1 relevant doc: 1/K if found, 0 otherwise
precision = found / k
precisions.append(precision)
# MRR@K: 1/rank if found
mrr = 1.0 / rank if rank > 0 else 0.0
mrrs.append(mrr)
# NDCG@K (binary relevance)
# DCG = 1/log2(rank+1) if found, 0 otherwise
# IDCG = 1/log2(2) = 1 (best case: relevant at rank 1)
dcg = 1.0 / np.log2(rank + 1) if rank > 0 else 0.0
idcg = 1.0
ndcg = dcg / idcg
ndcgs.append(ndcg)
# Track actual rank for analysis (only for k=10)
if k == max(k_values):
full_ranking = results[query_id]
full_rank = 0
for i, r in enumerate(full_ranking):
if r["id"] == relevant_doc:
full_rank = i + 1
break
all_ranks.append(full_rank)
metrics[f"Recall@{k}"] = np.mean(recalls)
metrics[f"P@{k}"] = np.mean(precisions)
metrics[f"MRR@{k}"] = np.mean(mrrs)
metrics[f"NDCG@{k}"] = np.mean(ndcgs)
# Add summary stats
if all_ranks:
found_ranks = [r for r in all_ranks if r > 0]
metrics["avg_rank"] = np.mean(found_ranks) if found_ranks else float('inf')
metrics["median_rank"] = np.median(found_ranks) if found_ranks else float('inf')
metrics["not_found"] = sum(1 for r in all_ranks if r == 0)
return metrics
def run_benchmark(
data: Dict,
skip_exhaustive: bool = False,
prefetch_k: int = None,
top_k: int = 10,
) -> Dict[str, Dict]:
"""Run retrieval benchmark with metrics."""
docs = data["docs"]
queries = data["queries"]
samples = data["samples"]
num_docs = len(docs)
# Auto-set prefetch_k to be meaningful (default: 20, or 20% of docs if >100 docs)
if prefetch_k is None:
if num_docs <= 100:
prefetch_k = 20 # Default: prefetch 20, rerank to top-10
else:
prefetch_k = max(20, min(100, int(num_docs * 0.2))) # 20% for larger collections
# Ensure prefetch_k < num_docs for meaningful two-stage comparison
if prefetch_k >= num_docs:
logger.warning(f"⚠️ prefetch_k={prefetch_k} >= num_docs={num_docs}")
logger.warning(f" Two-stage will fetch ALL docs (same as exhaustive)")
logger.warning(f" Use --samples > {prefetch_k * 3} for meaningful comparison")
logger.info(f"📊 Benchmark config: {num_docs} docs, prefetch_k={prefetch_k}, top_k={top_k}")
logger.info(f" (Both methods return top-{top_k} results - realistic retrieval scenario)")
results = {}
# Two-stage retrieval (NOVEL)
logger.info(f"\n🔬 Running Two-Stage retrieval (prefetch top-{prefetch_k}, rerank to top-{top_k})...")
two_stage_results = {}
two_stage_times = []
for sample in tqdm(samples, desc="Two-Stage"):
query_id = sample["query_id"]
query_emb = queries[query_id]
start = time.time()
ranking = search_two_stage(query_emb, docs, prefetch_k=prefetch_k, top_k=top_k)
two_stage_times.append(time.time() - start)
two_stage_results[query_id] = ranking
two_stage_metrics = compute_metrics(two_stage_results, samples)
two_stage_metrics["avg_time_ms"] = np.mean(two_stage_times) * 1000
two_stage_metrics["prefetch_k"] = prefetch_k
two_stage_metrics["top_k"] = top_k
results["two_stage"] = two_stage_metrics
# Exhaustive search (baseline)
if not skip_exhaustive:
logger.info(f"🔬 Running Exhaustive MaxSim (searches ALL {num_docs} docs, returns top-{top_k})...")
exhaustive_results = {}
exhaustive_times = []
for sample in tqdm(samples, desc="Exhaustive"):
query_id = sample["query_id"]
query_emb = queries[query_id]
start = time.time()
ranking = search_exhaustive(query_emb, docs, top_k=top_k)
exhaustive_times.append(time.time() - start)
exhaustive_results[query_id] = ranking
exhaustive_metrics = compute_metrics(exhaustive_results, samples)
exhaustive_metrics["avg_time_ms"] = np.mean(exhaustive_times) * 1000
exhaustive_metrics["top_k"] = top_k
results["exhaustive"] = exhaustive_metrics
return results
def print_results(data: Dict, benchmark_results: Dict, show_precision: bool = False):
"""Print benchmark results."""
print("\n" + "=" * 80)
print("📊 BENCHMARK RESULTS")
print("=" * 80)
num_docs = len(data['docs'])
print(f"\n🤖 Model: {data['model']}")
print(f"📄 Documents: {num_docs}")
print(f"🔍 Queries: {len(data['queries'])}")
# Embedding stats
sample_doc = list(data['docs'].values())[0]
print(f"\n📏 Embedding (after visual token filtering):")
print(f" Visual tokens per doc: {sample_doc['num_visual_tokens']}")
print(f" Tile-pooled vectors: {sample_doc['num_tiles']}")
if "two_stage" in benchmark_results:
prefetch_k = benchmark_results["two_stage"].get("prefetch_k", "?")
print(f" Two-stage prefetch_k: {prefetch_k} (of {num_docs} docs)")
# Method labels - clearer naming
def get_label(method):
if method == "two_stage":
return "Pooled+Rerank" # Tile-pooled prefetch + MaxSim rerank
else:
return "Full MaxSim" # Exhaustive MaxSim on all docs
# Recall / Hit Rate table
print(f"\n🎯 RECALL (Hit Rate) @ K:")
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
print(f" {'-'*60}")
for method, metrics in benchmark_results.items():
print(f" {get_label(method):<20} "
f"{metrics.get('Recall@1', 0):>8.3f} "
f"{metrics.get('Recall@3', 0):>8.3f} "
f"{metrics.get('Recall@5', 0):>8.3f} "
f"{metrics.get('Recall@7', 0):>8.3f} "
f"{metrics.get('Recall@10', 0):>8.3f}")
# Precision table (optional)
if show_precision:
print(f"\n📐 PRECISION @ K:")
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
print(f" {'-'*60}")
for method, metrics in benchmark_results.items():
print(f" {get_label(method):<20} "
f"{metrics.get('P@1', 0):>8.3f} "
f"{metrics.get('P@3', 0):>8.3f} "
f"{metrics.get('P@5', 0):>8.3f} "
f"{metrics.get('P@7', 0):>8.3f} "
f"{metrics.get('P@10', 0):>8.3f}")
# NDCG table
print(f"\n📈 NDCG @ K:")
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
print(f" {'-'*60}")
for method, metrics in benchmark_results.items():
print(f" {get_label(method):<20} "
f"{metrics.get('NDCG@1', 0):>8.3f} "
f"{metrics.get('NDCG@3', 0):>8.3f} "
f"{metrics.get('NDCG@5', 0):>8.3f} "
f"{metrics.get('NDCG@7', 0):>8.3f} "
f"{metrics.get('NDCG@10', 0):>8.3f}")
# MRR table
print(f"\n🔍 MRR @ K:")
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
print(f" {'-'*60}")
for method, metrics in benchmark_results.items():
print(f" {get_label(method):<20} "
f"{metrics.get('MRR@1', 0):>8.3f} "
f"{metrics.get('MRR@3', 0):>8.3f} "
f"{metrics.get('MRR@5', 0):>8.3f} "
f"{metrics.get('MRR@7', 0):>8.3f} "
f"{metrics.get('MRR@10', 0):>8.3f}")
# Speed comparison
top_k = benchmark_results.get("two_stage", benchmark_results.get("exhaustive", {})).get("top_k", 10)
print(f"\n⏱️ SPEED (both return top-{top_k} results):")
print(f" {'Method':<20} {'Time (ms)':>12} {'Docs searched':>15}")
print(f" {'-'*50}")
for method, metrics in benchmark_results.items():
if method == "two_stage":
searched = metrics.get("prefetch_k", "?")
label = f"{searched} (stage-1)"
else:
searched = num_docs
label = f"{searched} (all)"
print(f" {get_label(method):<20} {metrics.get('avg_time_ms', 0):>12.2f} {label:>15}")
# Comparison summary
if "exhaustive" in benchmark_results and "two_stage" in benchmark_results:
ex = benchmark_results["exhaustive"]
ts = benchmark_results["two_stage"]
print(f"\n💡 POOLED+RERANK vs FULL MAXSIM:")
for k in [1, 5, 10]:
ex_recall = ex.get(f"Recall@{k}", 0)
ts_recall = ts.get(f"Recall@{k}", 0)
if ex_recall > 0:
retention = ts_recall / ex_recall * 100
print(f" • Recall@{k} retention: {retention:.1f}% ({ts_recall:.3f} vs {ex_recall:.3f})")
speedup = ex["avg_time_ms"] / ts["avg_time_ms"] if ts["avg_time_ms"] > 0 else 0
print(f" • Speedup: {speedup:.1f}x")
# Rank stats with explanation
if "avg_rank" in ts:
prefetch_k = ts.get("prefetch_k", "?")
top_k = ts.get("top_k", 10)
not_found = ts.get("not_found", 0)
total = len(data["queries"])
print(f"\n📊 POOLED+RERANK STATISTICS:")
print(f" Stage-1 (pooled prefetch):")
print(f" • Searches top-{prefetch_k} candidates using tile-pooled vectors")
print(f" • {total - not_found}/{total} queries ({100 - not_found/total*100:.1f}%) had relevant doc in prefetch")
print(f" • {not_found}/{total} queries ({not_found/total*100:.1f}%) missed (relevant doc ranked >{prefetch_k})")
print(f" Stage-2 (MaxSim reranking):")
print(f" • Reranks prefetch candidates with exact MaxSim")
print(f" • Returns final top-{top_k} results")
if ts['avg_rank'] < float('inf'):
print(f" • Avg rank of relevant doc (when found): {ts['avg_rank']:.1f}")
print(f" • Median rank: {ts['median_rank']:.1f}")
print(f"\n 💡 The {not_found/total*100:.1f}% miss rate is for stage-1 prefetch.")
print(f" Final Recall@{top_k} shows how many relevant docs ARE in top-{top_k} results.")
print("\n" + "=" * 80)
print("✅ Benchmark complete!")
def main():
parser = argparse.ArgumentParser(
description="Quick benchmark for visual-rag-toolkit",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--samples", type=int, default=100,
help="Number of samples (default: 100)"
)
parser.add_argument(
"--model", type=str, default="vidore/colSmol-500M",
help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3"
)
parser.add_argument(
"--prefetch-k", type=int, default=None,
help="Stage 1 candidates for two-stage (default: 20 for <=100 docs, auto for larger)"
)
parser.add_argument(
"--skip-exhaustive", action="store_true",
help="Skip exhaustive baseline (faster)"
)
parser.add_argument(
"--show-precision", action="store_true",
help="Show Precision@K metrics (hidden by default)"
)
parser.add_argument(
"--top-k", type=int, default=10,
help="Number of results to return (default: 10, realistic retrieval scenario)"
)
args = parser.parse_args()
print("\n" + "=" * 70)
print("🧪 VISUAL RAG TOOLKIT - RETRIEVAL BENCHMARK")
print("=" * 70)
# Load samples
samples = load_vidore_sample(args.samples)
if not samples:
logger.error("No samples loaded!")
sys.exit(1)
# Embed all
data = embed_all(samples, args.model)
# Run benchmark
benchmark_results = run_benchmark(
data,
skip_exhaustive=args.skip_exhaustive,
prefetch_k=args.prefetch_k,
top_k=args.top_k,
)
# Print results
print_results(data, benchmark_results, show_precision=args.show_precision)
if __name__ == "__main__":
main()