Spaces:
Sleeping
Sleeping
sync v0.1.3
Browse files- benchmarks/__init__.py +1 -0
- benchmarks/quick_test.py +566 -0
- visual_rag/__init__.py +11 -7
- visual_rag/cli/__init__.py +0 -2
- visual_rag/cli/main.py +125 -133
- visual_rag/config.py +61 -53
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +4 -5
- visual_rag/embedding/pooling.py +53 -51
- visual_rag/embedding/visual_embedder.py +137 -91
- visual_rag/indexing/__init__.py +38 -0
- visual_rag/indexing/cloudinary_uploader.py +46 -51
- visual_rag/indexing/pdf_processor.py +85 -76
- visual_rag/indexing/pipeline.py +170 -125
- visual_rag/indexing/qdrant_indexer.py +162 -143
- visual_rag/preprocessing/__init__.py +0 -2
- visual_rag/preprocessing/crop_empty.py +15 -7
- visual_rag/qdrant_admin.py +29 -12
- visual_rag/retrieval/__init__.py +2 -2
- visual_rag/retrieval/multi_vector.py +64 -64
- visual_rag/retrieval/single_stage.py +17 -18
- visual_rag/retrieval/three_stage.py +1 -2
- visual_rag/retrieval/two_stage.py +73 -94
- visual_rag/visualization/__init__.py +1 -1
- visual_rag/visualization/saliency.py +63 -67
benchmarks/__init__.py
CHANGED
|
@@ -8,3 +8,4 @@ work in Docker/Spaces environments.
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
__all__ = []
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
__all__ = []
|
| 11 |
+
|
benchmarks/quick_test.py
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick Benchmark - Validate retrieval quality with ViDoRe data.
|
| 4 |
+
|
| 5 |
+
This script:
|
| 6 |
+
1. Downloads samples from ViDoRe (with ground truth relevance)
|
| 7 |
+
2. Embeds with ColSmol-500M
|
| 8 |
+
3. Tests retrieval strategies (exhaustive vs two-stage)
|
| 9 |
+
4. Computes METRICS: NDCG@K, MRR@K, Recall@K
|
| 10 |
+
5. Compares speed and quality
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python quick_test.py --samples 100
|
| 14 |
+
python quick_test.py --samples 500 --skip-exhaustive # Faster
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
import argparse
|
| 20 |
+
import logging
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Dict, Any
|
| 23 |
+
|
| 24 |
+
# Add parent directory to Python path (so we can import visual_rag)
|
| 25 |
+
# This allows running the script directly without pip install
|
| 26 |
+
_script_dir = Path(__file__).parent
|
| 27 |
+
_parent_dir = _script_dir.parent
|
| 28 |
+
if str(_parent_dir) not in sys.path:
|
| 29 |
+
sys.path.insert(0, str(_parent_dir))
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
# Visual RAG imports (now works without pip install)
|
| 35 |
+
from visual_rag.embedding import VisualEmbedder
|
| 36 |
+
from visual_rag.embedding.pooling import (
|
| 37 |
+
tile_level_mean_pooling,
|
| 38 |
+
compute_maxsim_score,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Optional: datasets for ViDoRe
|
| 42 |
+
try:
|
| 43 |
+
from datasets import load_dataset as hf_load_dataset
|
| 44 |
+
HAS_DATASETS = True
|
| 45 |
+
except ImportError:
|
| 46 |
+
HAS_DATASETS = False
|
| 47 |
+
|
| 48 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_vidore_sample(num_samples: int = 100) -> List[Dict]:
|
| 53 |
+
"""
|
| 54 |
+
Load sample from ViDoRe DocVQA with ground truth.
|
| 55 |
+
|
| 56 |
+
Each sample has a query and its relevant document (1:1 mapping).
|
| 57 |
+
This allows computing retrieval metrics.
|
| 58 |
+
"""
|
| 59 |
+
if not HAS_DATASETS:
|
| 60 |
+
logger.error("Install datasets: pip install datasets")
|
| 61 |
+
sys.exit(1)
|
| 62 |
+
|
| 63 |
+
logger.info(f"📥 Loading {num_samples} samples from ViDoRe DocVQA...")
|
| 64 |
+
|
| 65 |
+
ds = hf_load_dataset("vidore/docvqa_test_subsampled", split="test")
|
| 66 |
+
|
| 67 |
+
samples = []
|
| 68 |
+
for i, example in enumerate(ds):
|
| 69 |
+
if i >= num_samples:
|
| 70 |
+
break
|
| 71 |
+
|
| 72 |
+
samples.append({
|
| 73 |
+
"id": i,
|
| 74 |
+
"doc_id": f"doc_{i}",
|
| 75 |
+
"query_id": f"q_{i}",
|
| 76 |
+
"image": example.get("image", example.get("page_image")),
|
| 77 |
+
"query": example.get("query", example.get("question", "")),
|
| 78 |
+
# Ground truth: query i is relevant to doc i
|
| 79 |
+
"relevant_doc": f"doc_{i}",
|
| 80 |
+
})
|
| 81 |
+
|
| 82 |
+
logger.info(f"✅ Loaded {len(samples)} samples with ground truth")
|
| 83 |
+
return samples
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def embed_all(
|
| 87 |
+
samples: List[Dict],
|
| 88 |
+
model_name: str = "vidore/colSmol-500M",
|
| 89 |
+
) -> Dict[str, Any]:
|
| 90 |
+
"""Embed all documents and queries."""
|
| 91 |
+
logger.info(f"\n🤖 Loading model: {model_name}")
|
| 92 |
+
embedder = VisualEmbedder(model_name=model_name)
|
| 93 |
+
|
| 94 |
+
images = [s["image"] for s in samples]
|
| 95 |
+
queries = [s["query"] for s in samples if s["query"]]
|
| 96 |
+
|
| 97 |
+
# Embed images
|
| 98 |
+
logger.info(f"🎨 Embedding {len(images)} documents...")
|
| 99 |
+
start_time = time.time()
|
| 100 |
+
|
| 101 |
+
embeddings, token_infos = embedder.embed_images(
|
| 102 |
+
images, batch_size=4, return_token_info=True
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
doc_embed_time = time.time() - start_time
|
| 106 |
+
logger.info(f" Time: {doc_embed_time:.2f}s ({doc_embed_time/len(images)*1000:.1f}ms/doc)")
|
| 107 |
+
|
| 108 |
+
# Process embeddings: extract visual tokens + tile-level pooling
|
| 109 |
+
doc_data = {}
|
| 110 |
+
for i, (emb, token_info) in enumerate(zip(embeddings, token_infos)):
|
| 111 |
+
if hasattr(emb, 'cpu'):
|
| 112 |
+
emb = emb.cpu()
|
| 113 |
+
emb_np = emb.numpy() if hasattr(emb, 'numpy') else np.array(emb)
|
| 114 |
+
|
| 115 |
+
# Extract visual tokens only (filter special tokens)
|
| 116 |
+
visual_indices = token_info["visual_token_indices"]
|
| 117 |
+
visual_emb = emb_np[visual_indices].astype(np.float32)
|
| 118 |
+
|
| 119 |
+
# Tile-level pooling
|
| 120 |
+
n_rows = token_info.get("n_rows", 4)
|
| 121 |
+
n_cols = token_info.get("n_cols", 3)
|
| 122 |
+
num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13
|
| 123 |
+
|
| 124 |
+
tile_pooled = tile_level_mean_pooling(visual_emb, num_tiles, patches_per_tile=64)
|
| 125 |
+
|
| 126 |
+
doc_data[f"doc_{i}"] = {
|
| 127 |
+
"embedding": visual_emb,
|
| 128 |
+
"pooled": tile_pooled,
|
| 129 |
+
"num_visual_tokens": len(visual_indices),
|
| 130 |
+
"num_tiles": tile_pooled.shape[0],
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
# Embed queries
|
| 134 |
+
logger.info(f"🔍 Embedding {len(queries)} queries...")
|
| 135 |
+
start_time = time.time()
|
| 136 |
+
|
| 137 |
+
query_data = {}
|
| 138 |
+
for i, query in enumerate(tqdm(queries, desc="Queries")):
|
| 139 |
+
q_emb = embedder.embed_query(query)
|
| 140 |
+
if hasattr(q_emb, 'cpu'):
|
| 141 |
+
q_emb = q_emb.cpu()
|
| 142 |
+
q_np = q_emb.numpy() if hasattr(q_emb, 'numpy') else np.array(q_emb)
|
| 143 |
+
query_data[f"q_{i}"] = q_np.astype(np.float32)
|
| 144 |
+
|
| 145 |
+
query_embed_time = time.time() - start_time
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
"docs": doc_data,
|
| 149 |
+
"queries": query_data,
|
| 150 |
+
"samples": samples,
|
| 151 |
+
"doc_embed_time": doc_embed_time,
|
| 152 |
+
"query_embed_time": query_embed_time,
|
| 153 |
+
"model": model_name,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def search_exhaustive(query_emb: np.ndarray, docs: Dict, top_k: int = 10) -> List[Dict]:
|
| 158 |
+
"""Exhaustive MaxSim search over all documents."""
|
| 159 |
+
scores = []
|
| 160 |
+
for doc_id, doc in docs.items():
|
| 161 |
+
score = compute_maxsim_score(query_emb, doc["embedding"])
|
| 162 |
+
scores.append({"id": doc_id, "score": score})
|
| 163 |
+
|
| 164 |
+
scores.sort(key=lambda x: x["score"], reverse=True)
|
| 165 |
+
return scores[:top_k]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def search_two_stage(
|
| 169 |
+
query_emb: np.ndarray,
|
| 170 |
+
docs: Dict,
|
| 171 |
+
prefetch_k: int = 20,
|
| 172 |
+
top_k: int = 10,
|
| 173 |
+
) -> List[Dict]:
|
| 174 |
+
"""
|
| 175 |
+
Two-stage retrieval with tile-level pooling.
|
| 176 |
+
|
| 177 |
+
Stage 1: Fast prefetch using tile-pooled vectors
|
| 178 |
+
Stage 2: Exact MaxSim reranking on candidates
|
| 179 |
+
"""
|
| 180 |
+
# Stage 1: Tile-level pooled search
|
| 181 |
+
query_pooled = query_emb.mean(axis=0)
|
| 182 |
+
query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)
|
| 183 |
+
|
| 184 |
+
stage1_scores = []
|
| 185 |
+
for doc_id, doc in docs.items():
|
| 186 |
+
doc_pooled = doc["pooled"]
|
| 187 |
+
doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
|
| 188 |
+
tile_sims = np.dot(doc_norm, query_pooled)
|
| 189 |
+
score = float(tile_sims.max())
|
| 190 |
+
stage1_scores.append({"id": doc_id, "score": score})
|
| 191 |
+
|
| 192 |
+
stage1_scores.sort(key=lambda x: x["score"], reverse=True)
|
| 193 |
+
candidates = stage1_scores[:prefetch_k]
|
| 194 |
+
|
| 195 |
+
# Stage 2: Exact MaxSim on candidates
|
| 196 |
+
reranked = []
|
| 197 |
+
for cand in candidates:
|
| 198 |
+
doc_id = cand["id"]
|
| 199 |
+
score = compute_maxsim_score(query_emb, docs[doc_id]["embedding"])
|
| 200 |
+
reranked.append({"id": doc_id, "score": score, "stage1_rank": stage1_scores.index(cand) + 1})
|
| 201 |
+
|
| 202 |
+
reranked.sort(key=lambda x: x["score"], reverse=True)
|
| 203 |
+
return reranked[:top_k]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def compute_metrics(
|
| 207 |
+
results: Dict[str, List[Dict]],
|
| 208 |
+
samples: List[Dict],
|
| 209 |
+
k_values: List[int] = [1, 3, 5, 7, 10],
|
| 210 |
+
) -> Dict[str, float]:
|
| 211 |
+
"""
|
| 212 |
+
Compute retrieval metrics.
|
| 213 |
+
|
| 214 |
+
Since ViDoRe has 1:1 query-doc mapping (1 relevant doc per query):
|
| 215 |
+
- Recall@K (Hit Rate): Is the relevant doc in top-K? (0 or 1)
|
| 216 |
+
- Precision@K: (# relevant in top-K) / K
|
| 217 |
+
- MRR@K: 1/rank if found in top-K, else 0
|
| 218 |
+
- NDCG@K: DCG / IDCG with binary relevance
|
| 219 |
+
"""
|
| 220 |
+
metrics = {}
|
| 221 |
+
|
| 222 |
+
# Also track per-query ranks for analysis
|
| 223 |
+
all_ranks = []
|
| 224 |
+
|
| 225 |
+
for k in k_values:
|
| 226 |
+
recalls = []
|
| 227 |
+
precisions = []
|
| 228 |
+
mrrs = []
|
| 229 |
+
ndcgs = []
|
| 230 |
+
|
| 231 |
+
for sample in samples:
|
| 232 |
+
query_id = sample["query_id"]
|
| 233 |
+
relevant_doc = sample["relevant_doc"]
|
| 234 |
+
|
| 235 |
+
if query_id not in results:
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
ranking = results[query_id][:k]
|
| 239 |
+
ranked_ids = [r["id"] for r in ranking]
|
| 240 |
+
|
| 241 |
+
# Find rank of relevant doc (1-indexed, 0 if not found)
|
| 242 |
+
rank = 0
|
| 243 |
+
for i, doc_id in enumerate(ranked_ids):
|
| 244 |
+
if doc_id == relevant_doc:
|
| 245 |
+
rank = i + 1
|
| 246 |
+
break
|
| 247 |
+
|
| 248 |
+
# Recall@K (Hit Rate): 1 if found in top-K
|
| 249 |
+
found = 1.0 if rank > 0 else 0.0
|
| 250 |
+
recalls.append(found)
|
| 251 |
+
|
| 252 |
+
# Precision@K: (# relevant found) / K
|
| 253 |
+
# With 1 relevant doc: 1/K if found, 0 otherwise
|
| 254 |
+
precision = found / k
|
| 255 |
+
precisions.append(precision)
|
| 256 |
+
|
| 257 |
+
# MRR@K: 1/rank if found
|
| 258 |
+
mrr = 1.0 / rank if rank > 0 else 0.0
|
| 259 |
+
mrrs.append(mrr)
|
| 260 |
+
|
| 261 |
+
# NDCG@K (binary relevance)
|
| 262 |
+
# DCG = 1/log2(rank+1) if found, 0 otherwise
|
| 263 |
+
# IDCG = 1/log2(2) = 1 (best case: relevant at rank 1)
|
| 264 |
+
dcg = 1.0 / np.log2(rank + 1) if rank > 0 else 0.0
|
| 265 |
+
idcg = 1.0
|
| 266 |
+
ndcg = dcg / idcg
|
| 267 |
+
ndcgs.append(ndcg)
|
| 268 |
+
|
| 269 |
+
# Track actual rank for analysis (only for k=10)
|
| 270 |
+
if k == max(k_values):
|
| 271 |
+
full_ranking = results[query_id]
|
| 272 |
+
full_rank = 0
|
| 273 |
+
for i, r in enumerate(full_ranking):
|
| 274 |
+
if r["id"] == relevant_doc:
|
| 275 |
+
full_rank = i + 1
|
| 276 |
+
break
|
| 277 |
+
all_ranks.append(full_rank)
|
| 278 |
+
|
| 279 |
+
metrics[f"Recall@{k}"] = np.mean(recalls)
|
| 280 |
+
metrics[f"P@{k}"] = np.mean(precisions)
|
| 281 |
+
metrics[f"MRR@{k}"] = np.mean(mrrs)
|
| 282 |
+
metrics[f"NDCG@{k}"] = np.mean(ndcgs)
|
| 283 |
+
|
| 284 |
+
# Add summary stats
|
| 285 |
+
if all_ranks:
|
| 286 |
+
found_ranks = [r for r in all_ranks if r > 0]
|
| 287 |
+
metrics["avg_rank"] = np.mean(found_ranks) if found_ranks else float('inf')
|
| 288 |
+
metrics["median_rank"] = np.median(found_ranks) if found_ranks else float('inf')
|
| 289 |
+
metrics["not_found"] = sum(1 for r in all_ranks if r == 0)
|
| 290 |
+
|
| 291 |
+
return metrics
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def run_benchmark(
|
| 295 |
+
data: Dict,
|
| 296 |
+
skip_exhaustive: bool = False,
|
| 297 |
+
prefetch_k: int = None,
|
| 298 |
+
top_k: int = 10,
|
| 299 |
+
) -> Dict[str, Dict]:
|
| 300 |
+
"""Run retrieval benchmark with metrics."""
|
| 301 |
+
docs = data["docs"]
|
| 302 |
+
queries = data["queries"]
|
| 303 |
+
samples = data["samples"]
|
| 304 |
+
num_docs = len(docs)
|
| 305 |
+
|
| 306 |
+
# Auto-set prefetch_k to be meaningful (default: 20, or 20% of docs if >100 docs)
|
| 307 |
+
if prefetch_k is None:
|
| 308 |
+
if num_docs <= 100:
|
| 309 |
+
prefetch_k = 20 # Default: prefetch 20, rerank to top-10
|
| 310 |
+
else:
|
| 311 |
+
prefetch_k = max(20, min(100, int(num_docs * 0.2))) # 20% for larger collections
|
| 312 |
+
|
| 313 |
+
# Ensure prefetch_k < num_docs for meaningful two-stage comparison
|
| 314 |
+
if prefetch_k >= num_docs:
|
| 315 |
+
logger.warning(f"⚠️ prefetch_k={prefetch_k} >= num_docs={num_docs}")
|
| 316 |
+
logger.warning(f" Two-stage will fetch ALL docs (same as exhaustive)")
|
| 317 |
+
logger.warning(f" Use --samples > {prefetch_k * 3} for meaningful comparison")
|
| 318 |
+
|
| 319 |
+
logger.info(f"📊 Benchmark config: {num_docs} docs, prefetch_k={prefetch_k}, top_k={top_k}")
|
| 320 |
+
logger.info(f" (Both methods return top-{top_k} results - realistic retrieval scenario)")
|
| 321 |
+
|
| 322 |
+
results = {}
|
| 323 |
+
|
| 324 |
+
# Two-stage retrieval (NOVEL)
|
| 325 |
+
logger.info(f"\n🔬 Running Two-Stage retrieval (prefetch top-{prefetch_k}, rerank to top-{top_k})...")
|
| 326 |
+
two_stage_results = {}
|
| 327 |
+
two_stage_times = []
|
| 328 |
+
|
| 329 |
+
for sample in tqdm(samples, desc="Two-Stage"):
|
| 330 |
+
query_id = sample["query_id"]
|
| 331 |
+
query_emb = queries[query_id]
|
| 332 |
+
|
| 333 |
+
start = time.time()
|
| 334 |
+
ranking = search_two_stage(query_emb, docs, prefetch_k=prefetch_k, top_k=top_k)
|
| 335 |
+
two_stage_times.append(time.time() - start)
|
| 336 |
+
|
| 337 |
+
two_stage_results[query_id] = ranking
|
| 338 |
+
|
| 339 |
+
two_stage_metrics = compute_metrics(two_stage_results, samples)
|
| 340 |
+
two_stage_metrics["avg_time_ms"] = np.mean(two_stage_times) * 1000
|
| 341 |
+
two_stage_metrics["prefetch_k"] = prefetch_k
|
| 342 |
+
two_stage_metrics["top_k"] = top_k
|
| 343 |
+
results["two_stage"] = two_stage_metrics
|
| 344 |
+
|
| 345 |
+
# Exhaustive search (baseline)
|
| 346 |
+
if not skip_exhaustive:
|
| 347 |
+
logger.info(f"🔬 Running Exhaustive MaxSim (searches ALL {num_docs} docs, returns top-{top_k})...")
|
| 348 |
+
exhaustive_results = {}
|
| 349 |
+
exhaustive_times = []
|
| 350 |
+
|
| 351 |
+
for sample in tqdm(samples, desc="Exhaustive"):
|
| 352 |
+
query_id = sample["query_id"]
|
| 353 |
+
query_emb = queries[query_id]
|
| 354 |
+
|
| 355 |
+
start = time.time()
|
| 356 |
+
ranking = search_exhaustive(query_emb, docs, top_k=top_k)
|
| 357 |
+
exhaustive_times.append(time.time() - start)
|
| 358 |
+
|
| 359 |
+
exhaustive_results[query_id] = ranking
|
| 360 |
+
|
| 361 |
+
exhaustive_metrics = compute_metrics(exhaustive_results, samples)
|
| 362 |
+
exhaustive_metrics["avg_time_ms"] = np.mean(exhaustive_times) * 1000
|
| 363 |
+
exhaustive_metrics["top_k"] = top_k
|
| 364 |
+
results["exhaustive"] = exhaustive_metrics
|
| 365 |
+
|
| 366 |
+
return results
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def print_results(data: Dict, benchmark_results: Dict, show_precision: bool = False):
|
| 370 |
+
"""Print benchmark results."""
|
| 371 |
+
print("\n" + "=" * 80)
|
| 372 |
+
print("📊 BENCHMARK RESULTS")
|
| 373 |
+
print("=" * 80)
|
| 374 |
+
|
| 375 |
+
num_docs = len(data['docs'])
|
| 376 |
+
print(f"\n🤖 Model: {data['model']}")
|
| 377 |
+
print(f"📄 Documents: {num_docs}")
|
| 378 |
+
print(f"🔍 Queries: {len(data['queries'])}")
|
| 379 |
+
|
| 380 |
+
# Embedding stats
|
| 381 |
+
sample_doc = list(data['docs'].values())[0]
|
| 382 |
+
print(f"\n📏 Embedding (after visual token filtering):")
|
| 383 |
+
print(f" Visual tokens per doc: {sample_doc['num_visual_tokens']}")
|
| 384 |
+
print(f" Tile-pooled vectors: {sample_doc['num_tiles']}")
|
| 385 |
+
|
| 386 |
+
if "two_stage" in benchmark_results:
|
| 387 |
+
prefetch_k = benchmark_results["two_stage"].get("prefetch_k", "?")
|
| 388 |
+
print(f" Two-stage prefetch_k: {prefetch_k} (of {num_docs} docs)")
|
| 389 |
+
|
| 390 |
+
# Method labels - clearer naming
|
| 391 |
+
def get_label(method):
|
| 392 |
+
if method == "two_stage":
|
| 393 |
+
return "Pooled+Rerank" # Tile-pooled prefetch + MaxSim rerank
|
| 394 |
+
else:
|
| 395 |
+
return "Full MaxSim" # Exhaustive MaxSim on all docs
|
| 396 |
+
|
| 397 |
+
# Recall / Hit Rate table
|
| 398 |
+
print(f"\n🎯 RECALL (Hit Rate) @ K:")
|
| 399 |
+
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
|
| 400 |
+
print(f" {'-'*60}")
|
| 401 |
+
|
| 402 |
+
for method, metrics in benchmark_results.items():
|
| 403 |
+
print(f" {get_label(method):<20} "
|
| 404 |
+
f"{metrics.get('Recall@1', 0):>8.3f} "
|
| 405 |
+
f"{metrics.get('Recall@3', 0):>8.3f} "
|
| 406 |
+
f"{metrics.get('Recall@5', 0):>8.3f} "
|
| 407 |
+
f"{metrics.get('Recall@7', 0):>8.3f} "
|
| 408 |
+
f"{metrics.get('Recall@10', 0):>8.3f}")
|
| 409 |
+
|
| 410 |
+
# Precision table (optional)
|
| 411 |
+
if show_precision:
|
| 412 |
+
print(f"\n📐 PRECISION @ K:")
|
| 413 |
+
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
|
| 414 |
+
print(f" {'-'*60}")
|
| 415 |
+
|
| 416 |
+
for method, metrics in benchmark_results.items():
|
| 417 |
+
print(f" {get_label(method):<20} "
|
| 418 |
+
f"{metrics.get('P@1', 0):>8.3f} "
|
| 419 |
+
f"{metrics.get('P@3', 0):>8.3f} "
|
| 420 |
+
f"{metrics.get('P@5', 0):>8.3f} "
|
| 421 |
+
f"{metrics.get('P@7', 0):>8.3f} "
|
| 422 |
+
f"{metrics.get('P@10', 0):>8.3f}")
|
| 423 |
+
|
| 424 |
+
# NDCG table
|
| 425 |
+
print(f"\n📈 NDCG @ K:")
|
| 426 |
+
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
|
| 427 |
+
print(f" {'-'*60}")
|
| 428 |
+
|
| 429 |
+
for method, metrics in benchmark_results.items():
|
| 430 |
+
print(f" {get_label(method):<20} "
|
| 431 |
+
f"{metrics.get('NDCG@1', 0):>8.3f} "
|
| 432 |
+
f"{metrics.get('NDCG@3', 0):>8.3f} "
|
| 433 |
+
f"{metrics.get('NDCG@5', 0):>8.3f} "
|
| 434 |
+
f"{metrics.get('NDCG@7', 0):>8.3f} "
|
| 435 |
+
f"{metrics.get('NDCG@10', 0):>8.3f}")
|
| 436 |
+
|
| 437 |
+
# MRR table
|
| 438 |
+
print(f"\n🔍 MRR @ K:")
|
| 439 |
+
print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
|
| 440 |
+
print(f" {'-'*60}")
|
| 441 |
+
|
| 442 |
+
for method, metrics in benchmark_results.items():
|
| 443 |
+
print(f" {get_label(method):<20} "
|
| 444 |
+
f"{metrics.get('MRR@1', 0):>8.3f} "
|
| 445 |
+
f"{metrics.get('MRR@3', 0):>8.3f} "
|
| 446 |
+
f"{metrics.get('MRR@5', 0):>8.3f} "
|
| 447 |
+
f"{metrics.get('MRR@7', 0):>8.3f} "
|
| 448 |
+
f"{metrics.get('MRR@10', 0):>8.3f}")
|
| 449 |
+
|
| 450 |
+
# Speed comparison
|
| 451 |
+
top_k = benchmark_results.get("two_stage", benchmark_results.get("exhaustive", {})).get("top_k", 10)
|
| 452 |
+
print(f"\n⏱️ SPEED (both return top-{top_k} results):")
|
| 453 |
+
print(f" {'Method':<20} {'Time (ms)':>12} {'Docs searched':>15}")
|
| 454 |
+
print(f" {'-'*50}")
|
| 455 |
+
|
| 456 |
+
for method, metrics in benchmark_results.items():
|
| 457 |
+
if method == "two_stage":
|
| 458 |
+
searched = metrics.get("prefetch_k", "?")
|
| 459 |
+
label = f"{searched} (stage-1)"
|
| 460 |
+
else:
|
| 461 |
+
searched = num_docs
|
| 462 |
+
label = f"{searched} (all)"
|
| 463 |
+
print(f" {get_label(method):<20} {metrics.get('avg_time_ms', 0):>12.2f} {label:>15}")
|
| 464 |
+
|
| 465 |
+
# Comparison summary
|
| 466 |
+
if "exhaustive" in benchmark_results and "two_stage" in benchmark_results:
|
| 467 |
+
ex = benchmark_results["exhaustive"]
|
| 468 |
+
ts = benchmark_results["two_stage"]
|
| 469 |
+
|
| 470 |
+
print(f"\n💡 POOLED+RERANK vs FULL MAXSIM:")
|
| 471 |
+
|
| 472 |
+
for k in [1, 5, 10]:
|
| 473 |
+
ex_recall = ex.get(f"Recall@{k}", 0)
|
| 474 |
+
ts_recall = ts.get(f"Recall@{k}", 0)
|
| 475 |
+
if ex_recall > 0:
|
| 476 |
+
retention = ts_recall / ex_recall * 100
|
| 477 |
+
print(f" • Recall@{k} retention: {retention:.1f}% ({ts_recall:.3f} vs {ex_recall:.3f})")
|
| 478 |
+
|
| 479 |
+
speedup = ex["avg_time_ms"] / ts["avg_time_ms"] if ts["avg_time_ms"] > 0 else 0
|
| 480 |
+
print(f" • Speedup: {speedup:.1f}x")
|
| 481 |
+
|
| 482 |
+
# Rank stats with explanation
|
| 483 |
+
if "avg_rank" in ts:
|
| 484 |
+
prefetch_k = ts.get("prefetch_k", "?")
|
| 485 |
+
top_k = ts.get("top_k", 10)
|
| 486 |
+
not_found = ts.get("not_found", 0)
|
| 487 |
+
total = len(data["queries"])
|
| 488 |
+
|
| 489 |
+
print(f"\n📊 POOLED+RERANK STATISTICS:")
|
| 490 |
+
print(f" Stage-1 (pooled prefetch):")
|
| 491 |
+
print(f" • Searches top-{prefetch_k} candidates using tile-pooled vectors")
|
| 492 |
+
print(f" • {total - not_found}/{total} queries ({100 - not_found/total*100:.1f}%) had relevant doc in prefetch")
|
| 493 |
+
print(f" • {not_found}/{total} queries ({not_found/total*100:.1f}%) missed (relevant doc ranked >{prefetch_k})")
|
| 494 |
+
print(f" Stage-2 (MaxSim reranking):")
|
| 495 |
+
print(f" • Reranks prefetch candidates with exact MaxSim")
|
| 496 |
+
print(f" • Returns final top-{top_k} results")
|
| 497 |
+
if ts['avg_rank'] < float('inf'):
|
| 498 |
+
print(f" • Avg rank of relevant doc (when found): {ts['avg_rank']:.1f}")
|
| 499 |
+
print(f" • Median rank: {ts['median_rank']:.1f}")
|
| 500 |
+
print(f"\n 💡 The {not_found/total*100:.1f}% miss rate is for stage-1 prefetch.")
|
| 501 |
+
print(f" Final Recall@{top_k} shows how many relevant docs ARE in top-{top_k} results.")
|
| 502 |
+
|
| 503 |
+
print("\n" + "=" * 80)
|
| 504 |
+
print("✅ Benchmark complete!")
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def main():
|
| 508 |
+
parser = argparse.ArgumentParser(
|
| 509 |
+
description="Quick benchmark for visual-rag-toolkit",
|
| 510 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 511 |
+
)
|
| 512 |
+
parser.add_argument(
|
| 513 |
+
"--samples", type=int, default=100,
|
| 514 |
+
help="Number of samples (default: 100)"
|
| 515 |
+
)
|
| 516 |
+
parser.add_argument(
|
| 517 |
+
"--model", type=str, default="vidore/colSmol-500M",
|
| 518 |
+
help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3"
|
| 519 |
+
)
|
| 520 |
+
parser.add_argument(
|
| 521 |
+
"--prefetch-k", type=int, default=None,
|
| 522 |
+
help="Stage 1 candidates for two-stage (default: 20 for <=100 docs, auto for larger)"
|
| 523 |
+
)
|
| 524 |
+
parser.add_argument(
|
| 525 |
+
"--skip-exhaustive", action="store_true",
|
| 526 |
+
help="Skip exhaustive baseline (faster)"
|
| 527 |
+
)
|
| 528 |
+
parser.add_argument(
|
| 529 |
+
"--show-precision", action="store_true",
|
| 530 |
+
help="Show Precision@K metrics (hidden by default)"
|
| 531 |
+
)
|
| 532 |
+
parser.add_argument(
|
| 533 |
+
"--top-k", type=int, default=10,
|
| 534 |
+
help="Number of results to return (default: 10, realistic retrieval scenario)"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
args = parser.parse_args()
|
| 538 |
+
|
| 539 |
+
print("\n" + "=" * 70)
|
| 540 |
+
print("🧪 VISUAL RAG TOOLKIT - RETRIEVAL BENCHMARK")
|
| 541 |
+
print("=" * 70)
|
| 542 |
+
|
| 543 |
+
# Load samples
|
| 544 |
+
samples = load_vidore_sample(args.samples)
|
| 545 |
+
|
| 546 |
+
if not samples:
|
| 547 |
+
logger.error("No samples loaded!")
|
| 548 |
+
sys.exit(1)
|
| 549 |
+
|
| 550 |
+
# Embed all
|
| 551 |
+
data = embed_all(samples, args.model)
|
| 552 |
+
|
| 553 |
+
# Run benchmark
|
| 554 |
+
benchmark_results = run_benchmark(
|
| 555 |
+
data,
|
| 556 |
+
skip_exhaustive=args.skip_exhaustive,
|
| 557 |
+
prefetch_k=args.prefetch_k,
|
| 558 |
+
top_k=args.top_k,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# Print results
|
| 562 |
+
print_results(data, benchmark_results, show_precision=args.show_precision)
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
if __name__ == "__main__":
|
| 566 |
+
main()
|
visual_rag/__init__.py
CHANGED
|
@@ -14,16 +14,16 @@ Components:
|
|
| 14 |
Quick Start:
|
| 15 |
------------
|
| 16 |
>>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
|
| 17 |
-
>>>
|
| 18 |
>>> # Process PDFs
|
| 19 |
>>> processor = PDFProcessor(dpi=140)
|
| 20 |
>>> images, texts = processor.process_pdf("report.pdf")
|
| 21 |
-
>>>
|
| 22 |
>>> # Generate embeddings
|
| 23 |
>>> embedder = VisualEmbedder()
|
| 24 |
>>> embeddings = embedder.embed_images(images)
|
| 25 |
>>> query_emb = embedder.embed_query("What is the budget?")
|
| 26 |
-
>>>
|
| 27 |
>>> # Search with two-stage retrieval
|
| 28 |
>>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
|
| 29 |
>>> results = retriever.search(query_emb, top_k=10)
|
|
@@ -31,7 +31,7 @@ Quick Start:
|
|
| 31 |
Each component works independently - use only what you need.
|
| 32 |
"""
|
| 33 |
|
| 34 |
-
__version__ = "0.1.
|
| 35 |
|
| 36 |
# Import main classes at package level for convenience
|
| 37 |
# These are optional - if dependencies aren't installed, we catch the error
|
|
@@ -71,13 +71,17 @@ try:
|
|
| 71 |
except ImportError:
|
| 72 |
QdrantAdmin = None
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# Config utilities (always available)
|
| 75 |
-
from visual_rag.config import
|
| 76 |
|
| 77 |
__all__ = [
|
| 78 |
# Version
|
| 79 |
"__version__",
|
| 80 |
-
|
| 81 |
# Main classes
|
| 82 |
"VisualEmbedder",
|
| 83 |
"PDFProcessor",
|
|
@@ -86,7 +90,7 @@ __all__ = [
|
|
| 86 |
"TwoStageRetriever",
|
| 87 |
"MultiVectorRetriever",
|
| 88 |
"QdrantAdmin",
|
| 89 |
-
|
| 90 |
# Config utilities
|
| 91 |
"load_config",
|
| 92 |
"get",
|
|
|
|
| 14 |
Quick Start:
|
| 15 |
------------
|
| 16 |
>>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
|
| 17 |
+
>>>
|
| 18 |
>>> # Process PDFs
|
| 19 |
>>> processor = PDFProcessor(dpi=140)
|
| 20 |
>>> images, texts = processor.process_pdf("report.pdf")
|
| 21 |
+
>>>
|
| 22 |
>>> # Generate embeddings
|
| 23 |
>>> embedder = VisualEmbedder()
|
| 24 |
>>> embeddings = embedder.embed_images(images)
|
| 25 |
>>> query_emb = embedder.embed_query("What is the budget?")
|
| 26 |
+
>>>
|
| 27 |
>>> # Search with two-stage retrieval
|
| 28 |
>>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
|
| 29 |
>>> results = retriever.search(query_emb, top_k=10)
|
|
|
|
| 31 |
Each component works independently - use only what you need.
|
| 32 |
"""
|
| 33 |
|
| 34 |
+
__version__ = "0.1.3"
|
| 35 |
|
| 36 |
# Import main classes at package level for convenience
|
| 37 |
# These are optional - if dependencies aren't installed, we catch the error
|
|
|
|
| 71 |
except ImportError:
|
| 72 |
QdrantAdmin = None
|
| 73 |
|
| 74 |
+
try:
|
| 75 |
+
from visual_rag.demo_runner import demo
|
| 76 |
+
except ImportError:
|
| 77 |
+
demo = None
|
| 78 |
+
|
| 79 |
# Config utilities (always available)
|
| 80 |
+
from visual_rag.config import get, get_section, load_config
|
| 81 |
|
| 82 |
__all__ = [
|
| 83 |
# Version
|
| 84 |
"__version__",
|
|
|
|
| 85 |
# Main classes
|
| 86 |
"VisualEmbedder",
|
| 87 |
"PDFProcessor",
|
|
|
|
| 90 |
"TwoStageRetriever",
|
| 91 |
"MultiVectorRetriever",
|
| 92 |
"QdrantAdmin",
|
| 93 |
+
"demo",
|
| 94 |
# Config utilities
|
| 95 |
"load_config",
|
| 96 |
"get",
|
visual_rag/cli/__init__.py
CHANGED
|
@@ -1,3 +1 @@
|
|
| 1 |
"""CLI entry point for visual-rag-toolkit."""
|
| 2 |
-
|
| 3 |
-
|
|
|
|
| 1 |
"""CLI entry point for visual-rag-toolkit."""
|
|
|
|
|
|
visual_rag/cli/main.py
CHANGED
|
@@ -10,20 +10,19 @@ Provides command-line interface for:
|
|
| 10 |
Usage:
|
| 11 |
# Process PDFs (like process_pdfs_saliency_v2.py)
|
| 12 |
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
|
| 13 |
-
|
| 14 |
# Search
|
| 15 |
visual-rag search --query "budget allocation" --collection my_docs
|
| 16 |
-
|
| 17 |
# Show collection info
|
| 18 |
visual-rag info --collection my_docs
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
import os
|
| 22 |
-
import sys
|
| 23 |
import argparse
|
| 24 |
import logging
|
|
|
|
|
|
|
| 25 |
from pathlib import Path
|
| 26 |
-
from typing import Optional
|
| 27 |
from urllib.parse import urlparse
|
| 28 |
|
| 29 |
from dotenv import load_dotenv
|
|
@@ -44,38 +43,38 @@ def setup_logging(debug: bool = False):
|
|
| 44 |
def cmd_process(args):
|
| 45 |
"""
|
| 46 |
Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.
|
| 47 |
-
|
| 48 |
Equivalent to process_pdfs_saliency_v2.py
|
| 49 |
"""
|
| 50 |
-
from visual_rag import
|
| 51 |
from visual_rag.indexing.pipeline import ProcessingPipeline
|
| 52 |
-
|
| 53 |
# Load environment
|
| 54 |
load_dotenv()
|
| 55 |
-
|
| 56 |
# Load config
|
| 57 |
config = {}
|
| 58 |
if args.config and Path(args.config).exists():
|
| 59 |
config = load_config(args.config)
|
| 60 |
-
|
| 61 |
# Get PDFs
|
| 62 |
reports_dir = Path(args.reports_dir)
|
| 63 |
if not reports_dir.exists():
|
| 64 |
logger.error(f"❌ Reports directory not found: {reports_dir}")
|
| 65 |
sys.exit(1)
|
| 66 |
-
|
| 67 |
pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
|
| 68 |
if not pdf_paths:
|
| 69 |
logger.error(f"❌ No PDF files found in: {reports_dir}")
|
| 70 |
sys.exit(1)
|
| 71 |
-
|
| 72 |
logger.info(f"📁 Found {len(pdf_paths)} PDF files")
|
| 73 |
-
|
| 74 |
# Load metadata mapping
|
| 75 |
metadata_mapping = {}
|
| 76 |
if args.metadata_file:
|
| 77 |
metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))
|
| 78 |
-
|
| 79 |
# Dry run - just show summary
|
| 80 |
if args.dry_run:
|
| 81 |
logger.info("🏃 DRY RUN MODE")
|
|
@@ -83,21 +82,24 @@ def cmd_process(args):
|
|
| 83 |
logger.info(f" Metadata entries: {len(metadata_mapping)}")
|
| 84 |
logger.info(f" Collection: {args.collection}")
|
| 85 |
logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
|
| 86 |
-
|
| 87 |
for pdf in pdf_paths[:10]:
|
| 88 |
has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
|
| 89 |
logger.info(f" {has_meta} {pdf.name}")
|
| 90 |
if len(pdf_paths) > 10:
|
| 91 |
logger.info(f" ... and {len(pdf_paths) - 10} more")
|
| 92 |
return
|
| 93 |
-
|
| 94 |
# Get settings
|
| 95 |
model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
|
| 96 |
-
collection_name = args.collection or config.get("qdrant", {}).get(
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
torch_dtype = None
|
| 99 |
if args.torch_dtype != "auto":
|
| 100 |
import torch
|
|
|
|
| 101 |
torch_dtype = {
|
| 102 |
"float32": torch.float32,
|
| 103 |
"float16": torch.float16,
|
|
@@ -111,20 +113,22 @@ def cmd_process(args):
|
|
| 111 |
torch_dtype=torch_dtype,
|
| 112 |
processor_speed=str(getattr(args, "processor_speed", "fast")),
|
| 113 |
)
|
| 114 |
-
|
| 115 |
# Initialize Qdrant indexer
|
| 116 |
-
qdrant_url =
|
|
|
|
|
|
|
| 117 |
qdrant_api_key = (
|
| 118 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 119 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 120 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 121 |
or os.getenv("QDRANT_API_KEY")
|
| 122 |
)
|
| 123 |
-
|
| 124 |
if not qdrant_url:
|
| 125 |
logger.error("❌ QDRANT_URL environment variable not set")
|
| 126 |
sys.exit(1)
|
| 127 |
-
|
| 128 |
logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
|
| 129 |
indexer = QdrantIndexer(
|
| 130 |
url=qdrant_url,
|
|
@@ -133,7 +137,7 @@ def cmd_process(args):
|
|
| 133 |
prefer_grpc=args.prefer_grpc,
|
| 134 |
vector_datatype=args.qdrant_vector_dtype,
|
| 135 |
)
|
| 136 |
-
|
| 137 |
# Create collection if needed
|
| 138 |
indexer.create_collection(force_recreate=args.force_recreate)
|
| 139 |
inferred_fields = []
|
|
@@ -166,7 +170,7 @@ def cmd_process(args):
|
|
| 166 |
inferred_fields.append({"field": k, "type": inferred_type})
|
| 167 |
|
| 168 |
indexer.create_payload_indexes(fields=inferred_fields)
|
| 169 |
-
|
| 170 |
# Initialize Cloudinary uploader (optional)
|
| 171 |
cloudinary_uploader = None
|
| 172 |
if not args.no_cloudinary:
|
|
@@ -176,7 +180,7 @@ def cmd_process(args):
|
|
| 176 |
except ValueError as e:
|
| 177 |
logger.warning(f"⚠️ Cloudinary not configured: {e}")
|
| 178 |
logger.warning(" Continuing without Cloudinary uploads")
|
| 179 |
-
|
| 180 |
# Create pipeline
|
| 181 |
pipeline = ProcessingPipeline(
|
| 182 |
embedder=embedder,
|
|
@@ -186,42 +190,44 @@ def cmd_process(args):
|
|
| 186 |
config=config,
|
| 187 |
embedding_strategy=args.strategy,
|
| 188 |
crop_empty=bool(getattr(args, "crop_empty", False)),
|
| 189 |
-
crop_empty_percentage_to_remove=float(
|
|
|
|
|
|
|
| 190 |
crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
|
| 191 |
)
|
| 192 |
-
|
| 193 |
# Process PDFs
|
| 194 |
total_uploaded = 0
|
| 195 |
total_skipped = 0
|
| 196 |
total_failed = 0
|
| 197 |
-
|
| 198 |
skip_existing = not args.no_skip_existing
|
| 199 |
-
|
| 200 |
for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
|
| 201 |
logger.info(f"\n{'='*60}")
|
| 202 |
logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
|
| 203 |
logger.info(f"{'='*60}")
|
| 204 |
-
|
| 205 |
result = pipeline.process_pdf(
|
| 206 |
pdf_path,
|
| 207 |
skip_existing=skip_existing,
|
| 208 |
upload_to_cloudinary=(not args.no_cloudinary),
|
| 209 |
upload_to_qdrant=True,
|
| 210 |
)
|
| 211 |
-
|
| 212 |
total_uploaded += result["uploaded"]
|
| 213 |
total_skipped += result["skipped"]
|
| 214 |
total_failed += result["failed"]
|
| 215 |
-
|
| 216 |
# Summary
|
| 217 |
logger.info(f"\n{'='*60}")
|
| 218 |
-
logger.info(
|
| 219 |
logger.info(f"{'='*60}")
|
| 220 |
logger.info(f" Total PDFs: {len(pdf_paths)}")
|
| 221 |
logger.info(f" Uploaded: {total_uploaded}")
|
| 222 |
logger.info(f" Skipped: {total_skipped}")
|
| 223 |
logger.info(f" Failed: {total_failed}")
|
| 224 |
-
|
| 225 |
info = indexer.get_collection_info()
|
| 226 |
if info:
|
| 227 |
logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
|
|
@@ -229,29 +235,34 @@ def cmd_process(args):
|
|
| 229 |
|
| 230 |
def cmd_search(args):
|
| 231 |
"""Search documents."""
|
| 232 |
-
from visual_rag import VisualEmbedder
|
| 233 |
-
from visual_rag.retrieval import TwoStageRetriever, SingleStageRetriever
|
| 234 |
from qdrant_client import QdrantClient
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
| 236 |
load_dotenv()
|
| 237 |
-
|
| 238 |
-
qdrant_url =
|
|
|
|
|
|
|
| 239 |
qdrant_api_key = (
|
| 240 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 241 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 242 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 243 |
or os.getenv("QDRANT_API_KEY")
|
| 244 |
)
|
| 245 |
-
|
| 246 |
if not qdrant_url:
|
| 247 |
logger.error("❌ QDRANT_URL not set")
|
| 248 |
sys.exit(1)
|
| 249 |
-
|
| 250 |
# Initialize
|
| 251 |
logger.info(f"🤖 Loading model: {args.model}")
|
| 252 |
-
embedder = VisualEmbedder(
|
|
|
|
|
|
|
| 253 |
|
| 254 |
-
logger.info(
|
| 255 |
grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
|
| 256 |
client = QdrantClient(
|
| 257 |
url=qdrant_url,
|
|
@@ -262,11 +273,11 @@ def cmd_search(args):
|
|
| 262 |
)
|
| 263 |
two_stage = TwoStageRetriever(client, args.collection)
|
| 264 |
single_stage = SingleStageRetriever(client, args.collection)
|
| 265 |
-
|
| 266 |
# Embed query
|
| 267 |
logger.info(f"🔍 Query: {args.query}")
|
| 268 |
query_embedding = embedder.embed_query(args.query)
|
| 269 |
-
|
| 270 |
# Build filter
|
| 271 |
filter_obj = None
|
| 272 |
if args.year or args.source or args.district:
|
|
@@ -275,7 +286,7 @@ def cmd_search(args):
|
|
| 275 |
source=args.source,
|
| 276 |
district=args.district,
|
| 277 |
)
|
| 278 |
-
|
| 279 |
# Search
|
| 280 |
query_np = query_embedding.detach().cpu().numpy()
|
| 281 |
if args.strategy == "single_full":
|
|
@@ -307,21 +318,21 @@ def cmd_search(args):
|
|
| 307 |
filter_obj=filter_obj,
|
| 308 |
stage1_mode=args.stage1_mode,
|
| 309 |
)
|
| 310 |
-
|
| 311 |
# Display results
|
| 312 |
logger.info(f"\n📊 Results ({len(results)}):")
|
| 313 |
for i, result in enumerate(results, 1):
|
| 314 |
payload = result.get("payload", {})
|
| 315 |
score = result.get("score_final", result.get("score_stage1", 0))
|
| 316 |
-
|
| 317 |
filename = payload.get("filename", "N/A")
|
| 318 |
page_num = payload.get("page_number", "N/A")
|
| 319 |
year = payload.get("year", "N/A")
|
| 320 |
source = payload.get("source", "N/A")
|
| 321 |
-
|
| 322 |
logger.info(f" {i}. {filename} p.{page_num}")
|
| 323 |
logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")
|
| 324 |
-
|
| 325 |
# Text snippet
|
| 326 |
text = payload.get("text", "")
|
| 327 |
if text and args.show_text:
|
|
@@ -332,21 +343,23 @@ def cmd_search(args):
|
|
| 332 |
def cmd_info(args):
|
| 333 |
"""Show collection info."""
|
| 334 |
from qdrant_client import QdrantClient
|
| 335 |
-
|
| 336 |
load_dotenv()
|
| 337 |
-
|
| 338 |
-
qdrant_url =
|
|
|
|
|
|
|
| 339 |
qdrant_api_key = (
|
| 340 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 341 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 342 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 343 |
or os.getenv("QDRANT_API_KEY")
|
| 344 |
)
|
| 345 |
-
|
| 346 |
if not qdrant_url:
|
| 347 |
logger.error("❌ QDRANT_URL not set")
|
| 348 |
sys.exit(1)
|
| 349 |
-
|
| 350 |
grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
|
| 351 |
client = QdrantClient(
|
| 352 |
url=qdrant_url,
|
|
@@ -355,29 +368,29 @@ def cmd_info(args):
|
|
| 355 |
grpc_port=grpc_port,
|
| 356 |
check_compatibility=False,
|
| 357 |
)
|
| 358 |
-
|
| 359 |
try:
|
| 360 |
info = client.get_collection(args.collection)
|
| 361 |
-
|
| 362 |
status = info.status
|
| 363 |
if hasattr(status, "value"):
|
| 364 |
status = status.value
|
| 365 |
-
|
| 366 |
indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
|
| 367 |
if isinstance(indexed_count, dict):
|
| 368 |
indexed_count = sum(indexed_count.values())
|
| 369 |
-
|
| 370 |
logger.info(f"📊 Collection: {args.collection}")
|
| 371 |
logger.info(f" Status: {status}")
|
| 372 |
logger.info(f" Points: {info.points_count}")
|
| 373 |
logger.info(f" Indexed vectors: {indexed_count}")
|
| 374 |
-
|
| 375 |
# Show vector config
|
| 376 |
if hasattr(info, "config") and hasattr(info.config, "params"):
|
| 377 |
vectors = getattr(info.config.params, "vectors", {})
|
| 378 |
if vectors:
|
| 379 |
logger.info(f" Vectors: {list(vectors.keys())}")
|
| 380 |
-
|
| 381 |
except Exception as e:
|
| 382 |
logger.error(f"❌ Could not get collection info: {e}")
|
| 383 |
sys.exit(1)
|
|
@@ -393,24 +406,24 @@ def main():
|
|
| 393 |
Examples:
|
| 394 |
# Process PDFs (like process_pdfs_saliency_v2.py)
|
| 395 |
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
|
| 396 |
-
|
| 397 |
# Process without Cloudinary
|
| 398 |
visual-rag process --reports-dir ./pdfs --no-cloudinary
|
| 399 |
-
|
| 400 |
# Search
|
| 401 |
visual-rag search --query "budget allocation" --collection my_docs
|
| 402 |
-
|
| 403 |
# Search with filters
|
| 404 |
visual-rag search --query "budget" --year 2023 --source "Local Government"
|
| 405 |
-
|
| 406 |
# Show collection info
|
| 407 |
visual-rag info --collection my_docs
|
| 408 |
""",
|
| 409 |
)
|
| 410 |
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
| 411 |
-
|
| 412 |
subparsers = parser.add_subparsers(dest="command", help="Command")
|
| 413 |
-
|
| 414 |
# =========================================================================
|
| 415 |
# PROCESS command
|
| 416 |
# =========================================================================
|
|
@@ -420,32 +433,26 @@ Examples:
|
|
| 420 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 421 |
)
|
| 422 |
process_parser.add_argument(
|
| 423 |
-
"--reports-dir", type=str, required=True,
|
| 424 |
-
help="Directory containing PDF files"
|
| 425 |
-
)
|
| 426 |
-
process_parser.add_argument(
|
| 427 |
-
"--metadata-file", type=str,
|
| 428 |
-
help="JSON file with filename → metadata mapping (like filename_metadata.json)"
|
| 429 |
-
)
|
| 430 |
-
process_parser.add_argument(
|
| 431 |
-
"--collection", type=str, default="visual_documents",
|
| 432 |
-
help="Qdrant collection name"
|
| 433 |
)
|
| 434 |
process_parser.add_argument(
|
| 435 |
-
"--
|
| 436 |
-
|
|
|
|
| 437 |
)
|
| 438 |
process_parser.add_argument(
|
| 439 |
-
"--
|
| 440 |
-
help="Embedding batch size"
|
| 441 |
)
|
| 442 |
process_parser.add_argument(
|
| 443 |
-
"--
|
| 444 |
-
|
|
|
|
|
|
|
| 445 |
)
|
|
|
|
|
|
|
| 446 |
process_parser.add_argument(
|
| 447 |
-
"--no-cloudinary", action="store_true",
|
| 448 |
-
help="Skip Cloudinary uploads"
|
| 449 |
)
|
| 450 |
process_parser.add_argument(
|
| 451 |
"--crop-empty",
|
|
@@ -464,22 +471,23 @@ Examples:
|
|
| 464 |
help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
|
| 465 |
)
|
| 466 |
process_parser.add_argument(
|
| 467 |
-
"--no-skip-existing",
|
| 468 |
-
|
|
|
|
| 469 |
)
|
| 470 |
process_parser.add_argument(
|
| 471 |
-
"--force-recreate", action="store_true",
|
| 472 |
-
help="Delete and recreate collection"
|
| 473 |
)
|
| 474 |
process_parser.add_argument(
|
| 475 |
-
"--dry-run", action="store_true",
|
| 476 |
-
help="Show what would be processed without doing it"
|
| 477 |
)
|
| 478 |
process_parser.add_argument(
|
| 479 |
-
"--strategy",
|
|
|
|
|
|
|
| 480 |
choices=["pooling", "standard", "all"],
|
| 481 |
help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
|
| 482 |
-
|
| 483 |
)
|
| 484 |
process_parser.add_argument(
|
| 485 |
"--torch-dtype",
|
|
@@ -517,7 +525,7 @@ Examples:
|
|
| 517 |
help="Disable gRPC for Qdrant client.",
|
| 518 |
)
|
| 519 |
process_parser.set_defaults(func=cmd_process)
|
| 520 |
-
|
| 521 |
# =========================================================================
|
| 522 |
# SEARCH command
|
| 523 |
# =========================================================================
|
|
@@ -525,17 +533,12 @@ Examples:
|
|
| 525 |
"search",
|
| 526 |
help="Search documents",
|
| 527 |
)
|
|
|
|
| 528 |
search_parser.add_argument(
|
| 529 |
-
"--
|
| 530 |
-
help="Search query"
|
| 531 |
)
|
| 532 |
search_parser.add_argument(
|
| 533 |
-
"--
|
| 534 |
-
help="Qdrant collection name"
|
| 535 |
-
)
|
| 536 |
-
search_parser.add_argument(
|
| 537 |
-
"--model", type=str, default="vidore/colSmol-500M",
|
| 538 |
-
help="Model name"
|
| 539 |
)
|
| 540 |
search_parser.add_argument(
|
| 541 |
"--processor-speed",
|
|
@@ -544,39 +547,29 @@ Examples:
|
|
| 544 |
choices=["fast", "slow", "auto"],
|
| 545 |
help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
|
| 546 |
)
|
|
|
|
| 547 |
search_parser.add_argument(
|
| 548 |
-
"--
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
search_parser.add_argument(
|
| 552 |
-
"--strategy", type=str, default="single_full",
|
| 553 |
choices=["single_full", "single_tiles", "single_global", "two_stage"],
|
| 554 |
-
help="Search strategy"
|
| 555 |
)
|
| 556 |
search_parser.add_argument(
|
| 557 |
-
"--prefetch-k", type=int, default=200,
|
| 558 |
-
help="Prefetch candidates for two-stage retrieval"
|
| 559 |
)
|
| 560 |
search_parser.add_argument(
|
| 561 |
-
"--stage1-mode",
|
|
|
|
|
|
|
| 562 |
choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
|
| 563 |
-
help="Stage 1 mode for two-stage retrieval"
|
| 564 |
)
|
|
|
|
|
|
|
|
|
|
| 565 |
search_parser.add_argument(
|
| 566 |
-
"--
|
| 567 |
-
help="Filter by year"
|
| 568 |
-
)
|
| 569 |
-
search_parser.add_argument(
|
| 570 |
-
"--source", type=str,
|
| 571 |
-
help="Filter by source"
|
| 572 |
-
)
|
| 573 |
-
search_parser.add_argument(
|
| 574 |
-
"--district", type=str,
|
| 575 |
-
help="Filter by district"
|
| 576 |
-
)
|
| 577 |
-
search_parser.add_argument(
|
| 578 |
-
"--show-text", action="store_true",
|
| 579 |
-
help="Show text snippets in results"
|
| 580 |
)
|
| 581 |
search_grpc_group = search_parser.add_mutually_exclusive_group()
|
| 582 |
search_grpc_group.add_argument(
|
|
@@ -593,7 +586,7 @@ Examples:
|
|
| 593 |
help="Disable gRPC for Qdrant client.",
|
| 594 |
)
|
| 595 |
search_parser.set_defaults(func=cmd_search)
|
| 596 |
-
|
| 597 |
# =========================================================================
|
| 598 |
# INFO command
|
| 599 |
# =========================================================================
|
|
@@ -602,8 +595,7 @@ Examples:
|
|
| 602 |
help="Show collection info",
|
| 603 |
)
|
| 604 |
info_parser.add_argument(
|
| 605 |
-
"--collection", type=str, default="visual_documents",
|
| 606 |
-
help="Qdrant collection name"
|
| 607 |
)
|
| 608 |
info_grpc_group = info_parser.add_mutually_exclusive_group()
|
| 609 |
info_grpc_group.add_argument(
|
|
@@ -620,16 +612,16 @@ Examples:
|
|
| 620 |
help="Disable gRPC for Qdrant client.",
|
| 621 |
)
|
| 622 |
info_parser.set_defaults(func=cmd_info)
|
| 623 |
-
|
| 624 |
# Parse and execute
|
| 625 |
args = parser.parse_args()
|
| 626 |
-
|
| 627 |
setup_logging(args.debug)
|
| 628 |
-
|
| 629 |
if not args.command:
|
| 630 |
parser.print_help()
|
| 631 |
sys.exit(0)
|
| 632 |
-
|
| 633 |
args.func(args)
|
| 634 |
|
| 635 |
|
|
|
|
| 10 |
Usage:
|
| 11 |
# Process PDFs (like process_pdfs_saliency_v2.py)
|
| 12 |
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
|
| 13 |
+
|
| 14 |
# Search
|
| 15 |
visual-rag search --query "budget allocation" --collection my_docs
|
| 16 |
+
|
| 17 |
# Show collection info
|
| 18 |
visual-rag info --collection my_docs
|
| 19 |
"""
|
| 20 |
|
|
|
|
|
|
|
| 21 |
import argparse
|
| 22 |
import logging
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
from pathlib import Path
|
|
|
|
| 26 |
from urllib.parse import urlparse
|
| 27 |
|
| 28 |
from dotenv import load_dotenv
|
|
|
|
| 43 |
def cmd_process(args):
|
| 44 |
"""
|
| 45 |
Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.
|
| 46 |
+
|
| 47 |
Equivalent to process_pdfs_saliency_v2.py
|
| 48 |
"""
|
| 49 |
+
from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
|
| 50 |
from visual_rag.indexing.pipeline import ProcessingPipeline
|
| 51 |
+
|
| 52 |
# Load environment
|
| 53 |
load_dotenv()
|
| 54 |
+
|
| 55 |
# Load config
|
| 56 |
config = {}
|
| 57 |
if args.config and Path(args.config).exists():
|
| 58 |
config = load_config(args.config)
|
| 59 |
+
|
| 60 |
# Get PDFs
|
| 61 |
reports_dir = Path(args.reports_dir)
|
| 62 |
if not reports_dir.exists():
|
| 63 |
logger.error(f"❌ Reports directory not found: {reports_dir}")
|
| 64 |
sys.exit(1)
|
| 65 |
+
|
| 66 |
pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
|
| 67 |
if not pdf_paths:
|
| 68 |
logger.error(f"❌ No PDF files found in: {reports_dir}")
|
| 69 |
sys.exit(1)
|
| 70 |
+
|
| 71 |
logger.info(f"📁 Found {len(pdf_paths)} PDF files")
|
| 72 |
+
|
| 73 |
# Load metadata mapping
|
| 74 |
metadata_mapping = {}
|
| 75 |
if args.metadata_file:
|
| 76 |
metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))
|
| 77 |
+
|
| 78 |
# Dry run - just show summary
|
| 79 |
if args.dry_run:
|
| 80 |
logger.info("🏃 DRY RUN MODE")
|
|
|
|
| 82 |
logger.info(f" Metadata entries: {len(metadata_mapping)}")
|
| 83 |
logger.info(f" Collection: {args.collection}")
|
| 84 |
logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
|
| 85 |
+
|
| 86 |
for pdf in pdf_paths[:10]:
|
| 87 |
has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
|
| 88 |
logger.info(f" {has_meta} {pdf.name}")
|
| 89 |
if len(pdf_paths) > 10:
|
| 90 |
logger.info(f" ... and {len(pdf_paths) - 10} more")
|
| 91 |
return
|
| 92 |
+
|
| 93 |
# Get settings
|
| 94 |
model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
|
| 95 |
+
collection_name = args.collection or config.get("qdrant", {}).get(
|
| 96 |
+
"collection_name", "visual_documents"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
torch_dtype = None
|
| 100 |
if args.torch_dtype != "auto":
|
| 101 |
import torch
|
| 102 |
+
|
| 103 |
torch_dtype = {
|
| 104 |
"float32": torch.float32,
|
| 105 |
"float16": torch.float16,
|
|
|
|
| 113 |
torch_dtype=torch_dtype,
|
| 114 |
processor_speed=str(getattr(args, "processor_speed", "fast")),
|
| 115 |
)
|
| 116 |
+
|
| 117 |
# Initialize Qdrant indexer
|
| 118 |
+
qdrant_url = (
|
| 119 |
+
os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
|
| 120 |
+
)
|
| 121 |
qdrant_api_key = (
|
| 122 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 123 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 124 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 125 |
or os.getenv("QDRANT_API_KEY")
|
| 126 |
)
|
| 127 |
+
|
| 128 |
if not qdrant_url:
|
| 129 |
logger.error("❌ QDRANT_URL environment variable not set")
|
| 130 |
sys.exit(1)
|
| 131 |
+
|
| 132 |
logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
|
| 133 |
indexer = QdrantIndexer(
|
| 134 |
url=qdrant_url,
|
|
|
|
| 137 |
prefer_grpc=args.prefer_grpc,
|
| 138 |
vector_datatype=args.qdrant_vector_dtype,
|
| 139 |
)
|
| 140 |
+
|
| 141 |
# Create collection if needed
|
| 142 |
indexer.create_collection(force_recreate=args.force_recreate)
|
| 143 |
inferred_fields = []
|
|
|
|
| 170 |
inferred_fields.append({"field": k, "type": inferred_type})
|
| 171 |
|
| 172 |
indexer.create_payload_indexes(fields=inferred_fields)
|
| 173 |
+
|
| 174 |
# Initialize Cloudinary uploader (optional)
|
| 175 |
cloudinary_uploader = None
|
| 176 |
if not args.no_cloudinary:
|
|
|
|
| 180 |
except ValueError as e:
|
| 181 |
logger.warning(f"⚠️ Cloudinary not configured: {e}")
|
| 182 |
logger.warning(" Continuing without Cloudinary uploads")
|
| 183 |
+
|
| 184 |
# Create pipeline
|
| 185 |
pipeline = ProcessingPipeline(
|
| 186 |
embedder=embedder,
|
|
|
|
| 190 |
config=config,
|
| 191 |
embedding_strategy=args.strategy,
|
| 192 |
crop_empty=bool(getattr(args, "crop_empty", False)),
|
| 193 |
+
crop_empty_percentage_to_remove=float(
|
| 194 |
+
getattr(args, "crop_empty_percentage_to_remove", 0.9)
|
| 195 |
+
),
|
| 196 |
crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
|
| 197 |
)
|
| 198 |
+
|
| 199 |
# Process PDFs
|
| 200 |
total_uploaded = 0
|
| 201 |
total_skipped = 0
|
| 202 |
total_failed = 0
|
| 203 |
+
|
| 204 |
skip_existing = not args.no_skip_existing
|
| 205 |
+
|
| 206 |
for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
|
| 207 |
logger.info(f"\n{'='*60}")
|
| 208 |
logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
|
| 209 |
logger.info(f"{'='*60}")
|
| 210 |
+
|
| 211 |
result = pipeline.process_pdf(
|
| 212 |
pdf_path,
|
| 213 |
skip_existing=skip_existing,
|
| 214 |
upload_to_cloudinary=(not args.no_cloudinary),
|
| 215 |
upload_to_qdrant=True,
|
| 216 |
)
|
| 217 |
+
|
| 218 |
total_uploaded += result["uploaded"]
|
| 219 |
total_skipped += result["skipped"]
|
| 220 |
total_failed += result["failed"]
|
| 221 |
+
|
| 222 |
# Summary
|
| 223 |
logger.info(f"\n{'='*60}")
|
| 224 |
+
logger.info("📊 SUMMARY")
|
| 225 |
logger.info(f"{'='*60}")
|
| 226 |
logger.info(f" Total PDFs: {len(pdf_paths)}")
|
| 227 |
logger.info(f" Uploaded: {total_uploaded}")
|
| 228 |
logger.info(f" Skipped: {total_skipped}")
|
| 229 |
logger.info(f" Failed: {total_failed}")
|
| 230 |
+
|
| 231 |
info = indexer.get_collection_info()
|
| 232 |
if info:
|
| 233 |
logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
|
|
|
|
| 235 |
|
| 236 |
def cmd_search(args):
|
| 237 |
"""Search documents."""
|
|
|
|
|
|
|
| 238 |
from qdrant_client import QdrantClient
|
| 239 |
+
|
| 240 |
+
from visual_rag import VisualEmbedder
|
| 241 |
+
from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever
|
| 242 |
+
|
| 243 |
load_dotenv()
|
| 244 |
+
|
| 245 |
+
qdrant_url = (
|
| 246 |
+
os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
|
| 247 |
+
)
|
| 248 |
qdrant_api_key = (
|
| 249 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 250 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 251 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 252 |
or os.getenv("QDRANT_API_KEY")
|
| 253 |
)
|
| 254 |
+
|
| 255 |
if not qdrant_url:
|
| 256 |
logger.error("❌ QDRANT_URL not set")
|
| 257 |
sys.exit(1)
|
| 258 |
+
|
| 259 |
# Initialize
|
| 260 |
logger.info(f"🤖 Loading model: {args.model}")
|
| 261 |
+
embedder = VisualEmbedder(
|
| 262 |
+
model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast"))
|
| 263 |
+
)
|
| 264 |
|
| 265 |
+
logger.info("🔌 Connecting to Qdrant")
|
| 266 |
grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
|
| 267 |
client = QdrantClient(
|
| 268 |
url=qdrant_url,
|
|
|
|
| 273 |
)
|
| 274 |
two_stage = TwoStageRetriever(client, args.collection)
|
| 275 |
single_stage = SingleStageRetriever(client, args.collection)
|
| 276 |
+
|
| 277 |
# Embed query
|
| 278 |
logger.info(f"🔍 Query: {args.query}")
|
| 279 |
query_embedding = embedder.embed_query(args.query)
|
| 280 |
+
|
| 281 |
# Build filter
|
| 282 |
filter_obj = None
|
| 283 |
if args.year or args.source or args.district:
|
|
|
|
| 286 |
source=args.source,
|
| 287 |
district=args.district,
|
| 288 |
)
|
| 289 |
+
|
| 290 |
# Search
|
| 291 |
query_np = query_embedding.detach().cpu().numpy()
|
| 292 |
if args.strategy == "single_full":
|
|
|
|
| 318 |
filter_obj=filter_obj,
|
| 319 |
stage1_mode=args.stage1_mode,
|
| 320 |
)
|
| 321 |
+
|
| 322 |
# Display results
|
| 323 |
logger.info(f"\n📊 Results ({len(results)}):")
|
| 324 |
for i, result in enumerate(results, 1):
|
| 325 |
payload = result.get("payload", {})
|
| 326 |
score = result.get("score_final", result.get("score_stage1", 0))
|
| 327 |
+
|
| 328 |
filename = payload.get("filename", "N/A")
|
| 329 |
page_num = payload.get("page_number", "N/A")
|
| 330 |
year = payload.get("year", "N/A")
|
| 331 |
source = payload.get("source", "N/A")
|
| 332 |
+
|
| 333 |
logger.info(f" {i}. {filename} p.{page_num}")
|
| 334 |
logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")
|
| 335 |
+
|
| 336 |
# Text snippet
|
| 337 |
text = payload.get("text", "")
|
| 338 |
if text and args.show_text:
|
|
|
|
| 343 |
def cmd_info(args):
|
| 344 |
"""Show collection info."""
|
| 345 |
from qdrant_client import QdrantClient
|
| 346 |
+
|
| 347 |
load_dotenv()
|
| 348 |
+
|
| 349 |
+
qdrant_url = (
|
| 350 |
+
os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
|
| 351 |
+
)
|
| 352 |
qdrant_api_key = (
|
| 353 |
os.getenv("SIGIR_QDRANT_KEY")
|
| 354 |
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 355 |
or os.getenv("DEST_QDRANT_API_KEY")
|
| 356 |
or os.getenv("QDRANT_API_KEY")
|
| 357 |
)
|
| 358 |
+
|
| 359 |
if not qdrant_url:
|
| 360 |
logger.error("❌ QDRANT_URL not set")
|
| 361 |
sys.exit(1)
|
| 362 |
+
|
| 363 |
grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
|
| 364 |
client = QdrantClient(
|
| 365 |
url=qdrant_url,
|
|
|
|
| 368 |
grpc_port=grpc_port,
|
| 369 |
check_compatibility=False,
|
| 370 |
)
|
| 371 |
+
|
| 372 |
try:
|
| 373 |
info = client.get_collection(args.collection)
|
| 374 |
+
|
| 375 |
status = info.status
|
| 376 |
if hasattr(status, "value"):
|
| 377 |
status = status.value
|
| 378 |
+
|
| 379 |
indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
|
| 380 |
if isinstance(indexed_count, dict):
|
| 381 |
indexed_count = sum(indexed_count.values())
|
| 382 |
+
|
| 383 |
logger.info(f"📊 Collection: {args.collection}")
|
| 384 |
logger.info(f" Status: {status}")
|
| 385 |
logger.info(f" Points: {info.points_count}")
|
| 386 |
logger.info(f" Indexed vectors: {indexed_count}")
|
| 387 |
+
|
| 388 |
# Show vector config
|
| 389 |
if hasattr(info, "config") and hasattr(info.config, "params"):
|
| 390 |
vectors = getattr(info.config.params, "vectors", {})
|
| 391 |
if vectors:
|
| 392 |
logger.info(f" Vectors: {list(vectors.keys())}")
|
| 393 |
+
|
| 394 |
except Exception as e:
|
| 395 |
logger.error(f"❌ Could not get collection info: {e}")
|
| 396 |
sys.exit(1)
|
|
|
|
| 406 |
Examples:
|
| 407 |
# Process PDFs (like process_pdfs_saliency_v2.py)
|
| 408 |
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
|
| 409 |
+
|
| 410 |
# Process without Cloudinary
|
| 411 |
visual-rag process --reports-dir ./pdfs --no-cloudinary
|
| 412 |
+
|
| 413 |
# Search
|
| 414 |
visual-rag search --query "budget allocation" --collection my_docs
|
| 415 |
+
|
| 416 |
# Search with filters
|
| 417 |
visual-rag search --query "budget" --year 2023 --source "Local Government"
|
| 418 |
+
|
| 419 |
# Show collection info
|
| 420 |
visual-rag info --collection my_docs
|
| 421 |
""",
|
| 422 |
)
|
| 423 |
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
| 424 |
+
|
| 425 |
subparsers = parser.add_subparsers(dest="command", help="Command")
|
| 426 |
+
|
| 427 |
# =========================================================================
|
| 428 |
# PROCESS command
|
| 429 |
# =========================================================================
|
|
|
|
| 433 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 434 |
)
|
| 435 |
process_parser.add_argument(
|
| 436 |
+
"--reports-dir", type=str, required=True, help="Directory containing PDF files"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
)
|
| 438 |
process_parser.add_argument(
|
| 439 |
+
"--metadata-file",
|
| 440 |
+
type=str,
|
| 441 |
+
help="JSON file with filename → metadata mapping (like filename_metadata.json)",
|
| 442 |
)
|
| 443 |
process_parser.add_argument(
|
| 444 |
+
"--collection", type=str, default="visual_documents", help="Qdrant collection name"
|
|
|
|
| 445 |
)
|
| 446 |
process_parser.add_argument(
|
| 447 |
+
"--model",
|
| 448 |
+
type=str,
|
| 449 |
+
default="vidore/colSmol-500M",
|
| 450 |
+
help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
|
| 451 |
)
|
| 452 |
+
process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
|
| 453 |
+
process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
|
| 454 |
process_parser.add_argument(
|
| 455 |
+
"--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
|
|
|
|
| 456 |
)
|
| 457 |
process_parser.add_argument(
|
| 458 |
"--crop-empty",
|
|
|
|
| 471 |
help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
|
| 472 |
)
|
| 473 |
process_parser.add_argument(
|
| 474 |
+
"--no-skip-existing",
|
| 475 |
+
action="store_true",
|
| 476 |
+
help="Process all pages even if they exist in Qdrant",
|
| 477 |
)
|
| 478 |
process_parser.add_argument(
|
| 479 |
+
"--force-recreate", action="store_true", help="Delete and recreate collection"
|
|
|
|
| 480 |
)
|
| 481 |
process_parser.add_argument(
|
| 482 |
+
"--dry-run", action="store_true", help="Show what would be processed without doing it"
|
|
|
|
| 483 |
)
|
| 484 |
process_parser.add_argument(
|
| 485 |
+
"--strategy",
|
| 486 |
+
type=str,
|
| 487 |
+
default="pooling",
|
| 488 |
choices=["pooling", "standard", "all"],
|
| 489 |
help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
|
| 490 |
+
"'all' (embed once, store BOTH for comparison)",
|
| 491 |
)
|
| 492 |
process_parser.add_argument(
|
| 493 |
"--torch-dtype",
|
|
|
|
| 525 |
help="Disable gRPC for Qdrant client.",
|
| 526 |
)
|
| 527 |
process_parser.set_defaults(func=cmd_process)
|
| 528 |
+
|
| 529 |
# =========================================================================
|
| 530 |
# SEARCH command
|
| 531 |
# =========================================================================
|
|
|
|
| 533 |
"search",
|
| 534 |
help="Search documents",
|
| 535 |
)
|
| 536 |
+
search_parser.add_argument("--query", type=str, required=True, help="Search query")
|
| 537 |
search_parser.add_argument(
|
| 538 |
+
"--collection", type=str, default="visual_documents", help="Qdrant collection name"
|
|
|
|
| 539 |
)
|
| 540 |
search_parser.add_argument(
|
| 541 |
+
"--model", type=str, default="vidore/colSmol-500M", help="Model name"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
)
|
| 543 |
search_parser.add_argument(
|
| 544 |
"--processor-speed",
|
|
|
|
| 547 |
choices=["fast", "slow", "auto"],
|
| 548 |
help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
|
| 549 |
)
|
| 550 |
+
search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
|
| 551 |
search_parser.add_argument(
|
| 552 |
+
"--strategy",
|
| 553 |
+
type=str,
|
| 554 |
+
default="single_full",
|
|
|
|
|
|
|
| 555 |
choices=["single_full", "single_tiles", "single_global", "two_stage"],
|
| 556 |
+
help="Search strategy",
|
| 557 |
)
|
| 558 |
search_parser.add_argument(
|
| 559 |
+
"--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
|
|
|
|
| 560 |
)
|
| 561 |
search_parser.add_argument(
|
| 562 |
+
"--stage1-mode",
|
| 563 |
+
type=str,
|
| 564 |
+
default="pooled_query_vs_tiles",
|
| 565 |
choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
|
| 566 |
+
help="Stage 1 mode for two-stage retrieval",
|
| 567 |
)
|
| 568 |
+
search_parser.add_argument("--year", type=int, help="Filter by year")
|
| 569 |
+
search_parser.add_argument("--source", type=str, help="Filter by source")
|
| 570 |
+
search_parser.add_argument("--district", type=str, help="Filter by district")
|
| 571 |
search_parser.add_argument(
|
| 572 |
+
"--show-text", action="store_true", help="Show text snippets in results"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
)
|
| 574 |
search_grpc_group = search_parser.add_mutually_exclusive_group()
|
| 575 |
search_grpc_group.add_argument(
|
|
|
|
| 586 |
help="Disable gRPC for Qdrant client.",
|
| 587 |
)
|
| 588 |
search_parser.set_defaults(func=cmd_search)
|
| 589 |
+
|
| 590 |
# =========================================================================
|
| 591 |
# INFO command
|
| 592 |
# =========================================================================
|
|
|
|
| 595 |
help="Show collection info",
|
| 596 |
)
|
| 597 |
info_parser.add_argument(
|
| 598 |
+
"--collection", type=str, default="visual_documents", help="Qdrant collection name"
|
|
|
|
| 599 |
)
|
| 600 |
info_grpc_group = info_parser.add_mutually_exclusive_group()
|
| 601 |
info_grpc_group.add_argument(
|
|
|
|
| 612 |
help="Disable gRPC for Qdrant client.",
|
| 613 |
)
|
| 614 |
info_parser.set_defaults(func=cmd_info)
|
| 615 |
+
|
| 616 |
# Parse and execute
|
| 617 |
args = parser.parse_args()
|
| 618 |
+
|
| 619 |
setup_logging(args.debug)
|
| 620 |
+
|
| 621 |
if not args.command:
|
| 622 |
parser.print_help()
|
| 623 |
sys.exit(0)
|
| 624 |
+
|
| 625 |
args.func(args)
|
| 626 |
|
| 627 |
|
visual_rag/config.py
CHANGED
|
@@ -7,57 +7,56 @@ Provides:
|
|
| 7 |
- Convenience getters for common settings
|
| 8 |
"""
|
| 9 |
|
| 10 |
-
import
|
| 11 |
import logging
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
-
from typing import Any,
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
-
# Global config cache
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def _env_qdrant_url() -> Optional[str]:
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def _env_qdrant_api_key() -> Optional[str]:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 29 |
-
or os.getenv("DEST_QDRANT_API_KEY")
|
| 30 |
-
or os.getenv("QDRANT_API_KEY")
|
| 31 |
-
)
|
| 32 |
|
| 33 |
|
| 34 |
def load_config(
|
| 35 |
config_path: Optional[str] = None,
|
| 36 |
force_reload: bool = False,
|
|
|
|
| 37 |
) -> Dict[str, Any]:
|
| 38 |
"""
|
| 39 |
Load configuration from YAML file.
|
| 40 |
-
|
| 41 |
Uses caching to avoid repeated file I/O.
|
| 42 |
Environment variables can override config values.
|
| 43 |
-
|
| 44 |
Args:
|
| 45 |
config_path: Path to config file (auto-detected if None)
|
| 46 |
force_reload: Bypass cache and reload from file
|
| 47 |
-
|
| 48 |
Returns:
|
| 49 |
Configuration dictionary
|
| 50 |
"""
|
| 51 |
-
global
|
| 52 |
-
|
| 53 |
-
#
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
# Find config file
|
| 58 |
if config_path is None:
|
| 59 |
config_path = os.getenv("VISUALRAG_CONFIG")
|
| 60 |
-
|
| 61 |
if config_path is None:
|
| 62 |
# Check common locations
|
| 63 |
search_paths = [
|
|
@@ -65,65 +64,75 @@ def load_config(
|
|
| 65 |
Path.cwd() / "visual_rag.yaml",
|
| 66 |
Path.home() / ".visual_rag" / "config.yaml",
|
| 67 |
]
|
| 68 |
-
|
| 69 |
for path in search_paths:
|
| 70 |
if path.exists():
|
| 71 |
config_path = str(path)
|
| 72 |
break
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# Load YAML if file exists
|
| 75 |
config = {}
|
| 76 |
if config_path and Path(config_path).exists():
|
| 77 |
try:
|
| 78 |
import yaml
|
| 79 |
-
|
| 80 |
with open(config_path, "r") as f:
|
| 81 |
config = yaml.safe_load(f) or {}
|
| 82 |
-
|
| 83 |
logger.info(f"Loaded config from: {config_path}")
|
| 84 |
except ImportError:
|
| 85 |
logger.warning("PyYAML not installed, using environment variables only")
|
| 86 |
except Exception as e:
|
| 87 |
logger.warning(f"Could not load config file: {e}")
|
| 88 |
-
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 97 |
"""Apply environment variable overrides."""
|
| 98 |
-
|
| 99 |
env_mappings = {
|
| 100 |
# Qdrant
|
| 101 |
"QDRANT_URL": ["qdrant", "url"],
|
| 102 |
"QDRANT_API_KEY": ["qdrant", "api_key"],
|
| 103 |
"QDRANT_COLLECTION": ["qdrant", "collection"],
|
| 104 |
-
|
| 105 |
# Model
|
| 106 |
"VISUALRAG_MODEL": ["model", "name"],
|
| 107 |
"COLPALI_MODEL_NAME": ["model", "name"], # Alias
|
| 108 |
"EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
|
| 109 |
-
|
| 110 |
# Cloudinary
|
| 111 |
"CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
|
| 112 |
"CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
|
| 113 |
"CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
|
| 114 |
-
|
| 115 |
# Processing
|
| 116 |
"PDF_DPI": ["processing", "dpi"],
|
| 117 |
"JPEG_QUALITY": ["processing", "jpeg_quality"],
|
| 118 |
-
|
| 119 |
# Search
|
| 120 |
"SEARCH_STRATEGY": ["search", "strategy"],
|
| 121 |
"PREFETCH_K": ["search", "prefetch_k"],
|
| 122 |
-
|
| 123 |
# Special token handling
|
| 124 |
"VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
|
| 125 |
}
|
| 126 |
-
|
| 127 |
for env_var, path in env_mappings.items():
|
| 128 |
value = os.getenv(env_var)
|
| 129 |
if value is not None:
|
|
@@ -133,50 +142,51 @@ def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 133 |
if key not in current:
|
| 134 |
current[key] = {}
|
| 135 |
current = current[key]
|
| 136 |
-
|
| 137 |
# Convert value to appropriate type
|
| 138 |
final_key = path[-1]
|
| 139 |
if final_key in current:
|
| 140 |
existing_type = type(current[final_key])
|
| 141 |
-
|
|
|
|
| 142 |
value = value.lower() in ("true", "1", "yes", "on")
|
| 143 |
-
elif existing_type
|
| 144 |
value = int(value)
|
| 145 |
-
elif existing_type
|
| 146 |
value = float(value)
|
| 147 |
-
|
| 148 |
current[final_key] = value
|
| 149 |
logger.debug(f"Config override: {'.'.join(path)} = {value}")
|
| 150 |
-
|
| 151 |
return config
|
| 152 |
|
| 153 |
|
| 154 |
def get(key: str, default: Any = None) -> Any:
|
| 155 |
"""
|
| 156 |
Get a configuration value by dot-notation path.
|
| 157 |
-
|
| 158 |
Examples:
|
| 159 |
>>> get("qdrant.url")
|
| 160 |
>>> get("model.name", "vidore/colSmol-500M")
|
| 161 |
>>> get("search.strategy", "multi_vector")
|
| 162 |
"""
|
| 163 |
-
config = load_config()
|
| 164 |
-
|
| 165 |
keys = key.split(".")
|
| 166 |
current = config
|
| 167 |
-
|
| 168 |
for k in keys:
|
| 169 |
if isinstance(current, dict) and k in current:
|
| 170 |
current = current[k]
|
| 171 |
else:
|
| 172 |
return default
|
| 173 |
-
|
| 174 |
return current
|
| 175 |
|
| 176 |
|
| 177 |
-
def get_section(section: str) -> Dict[str, Any]:
|
| 178 |
"""Get an entire configuration section."""
|
| 179 |
-
config = load_config()
|
| 180 |
return config.get(section, {})
|
| 181 |
|
| 182 |
|
|
@@ -215,5 +225,3 @@ def get_search_config() -> Dict[str, Any]:
|
|
| 215 |
"prefetch_k": get("search.prefetch_k", 200),
|
| 216 |
"top_k": get("search.top_k", 10),
|
| 217 |
}
|
| 218 |
-
|
| 219 |
-
|
|
|
|
| 7 |
- Convenience getters for common settings
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
import copy
|
| 11 |
import logging
|
| 12 |
+
import os
|
| 13 |
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, Optional
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
# Global config cache (raw YAML only; env overrides applied on demand)
|
| 19 |
+
_raw_config_cache: Optional[Dict[str, Any]] = None
|
| 20 |
+
_raw_config_cache_path: Optional[str] = None
|
| 21 |
|
| 22 |
|
| 23 |
def _env_qdrant_url() -> Optional[str]:
|
| 24 |
+
"""Get Qdrant URL from environment. Prefers QDRANT_URL."""
|
| 25 |
+
return os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy fallback
|
| 26 |
|
| 27 |
|
| 28 |
def _env_qdrant_api_key() -> Optional[str]:
|
| 29 |
+
"""Get Qdrant API key from environment. Prefers QDRANT_API_KEY."""
|
| 30 |
+
return os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy fallback
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def load_config(
|
| 34 |
config_path: Optional[str] = None,
|
| 35 |
force_reload: bool = False,
|
| 36 |
+
apply_env_overrides: bool = True,
|
| 37 |
) -> Dict[str, Any]:
|
| 38 |
"""
|
| 39 |
Load configuration from YAML file.
|
| 40 |
+
|
| 41 |
Uses caching to avoid repeated file I/O.
|
| 42 |
Environment variables can override config values.
|
| 43 |
+
|
| 44 |
Args:
|
| 45 |
config_path: Path to config file (auto-detected if None)
|
| 46 |
force_reload: Bypass cache and reload from file
|
| 47 |
+
|
| 48 |
Returns:
|
| 49 |
Configuration dictionary
|
| 50 |
"""
|
| 51 |
+
global _raw_config_cache, _raw_config_cache_path
|
| 52 |
+
|
| 53 |
+
# Determine the effective config path (used for caching)
|
| 54 |
+
effective_path: Optional[str] = None
|
| 55 |
+
|
|
|
|
| 56 |
# Find config file
|
| 57 |
if config_path is None:
|
| 58 |
config_path = os.getenv("VISUALRAG_CONFIG")
|
| 59 |
+
|
| 60 |
if config_path is None:
|
| 61 |
# Check common locations
|
| 62 |
search_paths = [
|
|
|
|
| 64 |
Path.cwd() / "visual_rag.yaml",
|
| 65 |
Path.home() / ".visual_rag" / "config.yaml",
|
| 66 |
]
|
| 67 |
+
|
| 68 |
for path in search_paths:
|
| 69 |
if path.exists():
|
| 70 |
config_path = str(path)
|
| 71 |
break
|
| 72 |
+
effective_path = str(config_path) if config_path else None
|
| 73 |
+
|
| 74 |
+
# Return cached raw config if available.
|
| 75 |
+
# - If caller doesn't specify a path (effective_path is None), use whatever was
|
| 76 |
+
# loaded most recently (common pattern in apps).
|
| 77 |
+
# - If a path is specified, only reuse cache when it matches.
|
| 78 |
+
if (
|
| 79 |
+
_raw_config_cache is not None
|
| 80 |
+
and not force_reload
|
| 81 |
+
and (effective_path is None or _raw_config_cache_path == effective_path)
|
| 82 |
+
):
|
| 83 |
+
cfg = copy.deepcopy(_raw_config_cache)
|
| 84 |
+
return _apply_env_overrides(cfg) if apply_env_overrides else cfg
|
| 85 |
+
|
| 86 |
# Load YAML if file exists
|
| 87 |
config = {}
|
| 88 |
if config_path and Path(config_path).exists():
|
| 89 |
try:
|
| 90 |
import yaml
|
| 91 |
+
|
| 92 |
with open(config_path, "r") as f:
|
| 93 |
config = yaml.safe_load(f) or {}
|
| 94 |
+
|
| 95 |
logger.info(f"Loaded config from: {config_path}")
|
| 96 |
except ImportError:
|
| 97 |
logger.warning("PyYAML not installed, using environment variables only")
|
| 98 |
except Exception as e:
|
| 99 |
logger.warning(f"Could not load config file: {e}")
|
| 100 |
+
|
| 101 |
+
# Cache RAW config (no env overrides)
|
| 102 |
+
_raw_config_cache = copy.deepcopy(config)
|
| 103 |
+
_raw_config_cache_path = effective_path
|
| 104 |
+
|
| 105 |
+
# Return resolved or raw depending on caller preference
|
| 106 |
+
cfg = copy.deepcopy(config)
|
| 107 |
+
return _apply_env_overrides(cfg) if apply_env_overrides else cfg
|
| 108 |
|
| 109 |
|
| 110 |
def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 111 |
"""Apply environment variable overrides."""
|
| 112 |
+
|
| 113 |
env_mappings = {
|
| 114 |
# Qdrant
|
| 115 |
"QDRANT_URL": ["qdrant", "url"],
|
| 116 |
"QDRANT_API_KEY": ["qdrant", "api_key"],
|
| 117 |
"QDRANT_COLLECTION": ["qdrant", "collection"],
|
|
|
|
| 118 |
# Model
|
| 119 |
"VISUALRAG_MODEL": ["model", "name"],
|
| 120 |
"COLPALI_MODEL_NAME": ["model", "name"], # Alias
|
| 121 |
"EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
|
|
|
|
| 122 |
# Cloudinary
|
| 123 |
"CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
|
| 124 |
"CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
|
| 125 |
"CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
|
|
|
|
| 126 |
# Processing
|
| 127 |
"PDF_DPI": ["processing", "dpi"],
|
| 128 |
"JPEG_QUALITY": ["processing", "jpeg_quality"],
|
|
|
|
| 129 |
# Search
|
| 130 |
"SEARCH_STRATEGY": ["search", "strategy"],
|
| 131 |
"PREFETCH_K": ["search", "prefetch_k"],
|
|
|
|
| 132 |
# Special token handling
|
| 133 |
"VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
|
| 134 |
}
|
| 135 |
+
|
| 136 |
for env_var, path in env_mappings.items():
|
| 137 |
value = os.getenv(env_var)
|
| 138 |
if value is not None:
|
|
|
|
| 142 |
if key not in current:
|
| 143 |
current[key] = {}
|
| 144 |
current = current[key]
|
| 145 |
+
|
| 146 |
# Convert value to appropriate type
|
| 147 |
final_key = path[-1]
|
| 148 |
if final_key in current:
|
| 149 |
existing_type = type(current[final_key])
|
| 150 |
+
# Use `is` for type comparisons (Ruff E721).
|
| 151 |
+
if existing_type is bool:
|
| 152 |
value = value.lower() in ("true", "1", "yes", "on")
|
| 153 |
+
elif existing_type is int:
|
| 154 |
value = int(value)
|
| 155 |
+
elif existing_type is float:
|
| 156 |
value = float(value)
|
| 157 |
+
|
| 158 |
current[final_key] = value
|
| 159 |
logger.debug(f"Config override: {'.'.join(path)} = {value}")
|
| 160 |
+
|
| 161 |
return config
|
| 162 |
|
| 163 |
|
| 164 |
def get(key: str, default: Any = None) -> Any:
|
| 165 |
"""
|
| 166 |
Get a configuration value by dot-notation path.
|
| 167 |
+
|
| 168 |
Examples:
|
| 169 |
>>> get("qdrant.url")
|
| 170 |
>>> get("model.name", "vidore/colSmol-500M")
|
| 171 |
>>> get("search.strategy", "multi_vector")
|
| 172 |
"""
|
| 173 |
+
config = load_config(apply_env_overrides=True)
|
| 174 |
+
|
| 175 |
keys = key.split(".")
|
| 176 |
current = config
|
| 177 |
+
|
| 178 |
for k in keys:
|
| 179 |
if isinstance(current, dict) and k in current:
|
| 180 |
current = current[k]
|
| 181 |
else:
|
| 182 |
return default
|
| 183 |
+
|
| 184 |
return current
|
| 185 |
|
| 186 |
|
| 187 |
+
def get_section(section: str, *, apply_env_overrides: bool = True) -> Dict[str, Any]:
|
| 188 |
"""Get an entire configuration section."""
|
| 189 |
+
config = load_config(apply_env_overrides=apply_env_overrides)
|
| 190 |
return config.get(section, {})
|
| 191 |
|
| 192 |
|
|
|
|
| 225 |
"prefetch_k": get("search.prefetch_k", 200),
|
| 226 |
"top_k": get("search.top_k", 10),
|
| 227 |
}
|
|
|
|
|
|
visual_rag/demo_runner.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Launch the Streamlit demo from an installed package.
|
| 3 |
+
|
| 4 |
+
Why:
|
| 5 |
+
- After `pip install visual-rag-toolkit`, the repo layout isn't present.
|
| 6 |
+
- We package the `demo/` module and expose `visual_rag.demo()` + `visual-rag-demo`.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import importlib
|
| 13 |
+
import os
|
| 14 |
+
import subprocess
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def demo(
|
| 21 |
+
*,
|
| 22 |
+
host: str = "0.0.0.0",
|
| 23 |
+
port: int = 7860,
|
| 24 |
+
headless: bool = True,
|
| 25 |
+
open_browser: bool = False,
|
| 26 |
+
extra_args: Optional[list[str]] = None,
|
| 27 |
+
) -> int:
|
| 28 |
+
"""
|
| 29 |
+
Launch the Streamlit demo UI.
|
| 30 |
+
|
| 31 |
+
Requirements:
|
| 32 |
+
- `visual-rag-toolkit[ui,qdrant,embedding,pdf]` (or `visual-rag-toolkit[all]`)
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
Streamlit process exit code.
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
import streamlit # noqa: F401
|
| 39 |
+
except Exception as e: # pragma: no cover
|
| 40 |
+
raise RuntimeError(
|
| 41 |
+
"Streamlit is not installed. Install with:\n"
|
| 42 |
+
' pip install "visual-rag-toolkit[ui,qdrant,embedding,pdf]"'
|
| 43 |
+
) from e
|
| 44 |
+
|
| 45 |
+
# Resolve the installed demo entrypoint path.
|
| 46 |
+
mod = importlib.import_module("demo.app")
|
| 47 |
+
app_path = Path(getattr(mod, "__file__", "")).resolve()
|
| 48 |
+
if not app_path.exists(): # pragma: no cover
|
| 49 |
+
raise RuntimeError("Could not locate installed demo app (demo.app).")
|
| 50 |
+
|
| 51 |
+
# Build a stable Streamlit invocation.
|
| 52 |
+
cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
|
| 53 |
+
cmd += ["--server.address", str(host)]
|
| 54 |
+
cmd += ["--server.port", str(int(port))]
|
| 55 |
+
cmd += ["--server.headless", "true" if headless else "false"]
|
| 56 |
+
cmd += ["--browser.gatherUsageStats", "false"]
|
| 57 |
+
cmd += ["--server.runOnSave", "false"]
|
| 58 |
+
cmd += ["--browser.serverAddress", str(host)]
|
| 59 |
+
if not open_browser:
|
| 60 |
+
cmd += ["--browser.serverPort", str(int(port))]
|
| 61 |
+
cmd += ["--browser.open", "false"]
|
| 62 |
+
|
| 63 |
+
if extra_args:
|
| 64 |
+
cmd += list(extra_args)
|
| 65 |
+
|
| 66 |
+
env = os.environ.copy()
|
| 67 |
+
# Make sure the demo doesn't spam internal Streamlit warnings in logs.
|
| 68 |
+
env.setdefault("STREAMLIT_BROWSER_GATHER_USAGE_STATS", "false")
|
| 69 |
+
|
| 70 |
+
return subprocess.call(cmd, env=env)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main() -> None:
|
| 74 |
+
p = argparse.ArgumentParser(description="Launch the Visual RAG Toolkit Streamlit demo.")
|
| 75 |
+
p.add_argument("--host", default="0.0.0.0")
|
| 76 |
+
p.add_argument("--port", type=int, default=7860)
|
| 77 |
+
p.add_argument(
|
| 78 |
+
"--no-headless", action="store_true", help="Run with a browser window (not headless)."
|
| 79 |
+
)
|
| 80 |
+
p.add_argument("--open", action="store_true", help="Open browser automatically.")
|
| 81 |
+
args, unknown = p.parse_known_args()
|
| 82 |
+
|
| 83 |
+
rc = demo(
|
| 84 |
+
host=args.host,
|
| 85 |
+
port=args.port,
|
| 86 |
+
headless=(not args.no_headless),
|
| 87 |
+
open_browser=bool(args.open),
|
| 88 |
+
extra_args=unknown,
|
| 89 |
+
)
|
| 90 |
+
raise SystemExit(rc)
|
visual_rag/embedding/__init__.py
CHANGED
|
@@ -6,19 +6,18 @@ Provides:
|
|
| 6 |
- Pooling utilities: tile-level, global, MaxSim scoring
|
| 7 |
"""
|
| 8 |
|
| 9 |
-
from visual_rag.embedding.visual_embedder import VisualEmbedder, ColPaliEmbedder
|
| 10 |
from visual_rag.embedding.pooling import (
|
| 11 |
-
tile_level_mean_pooling,
|
| 12 |
-
global_mean_pooling,
|
| 13 |
-
compute_maxsim_score,
|
| 14 |
compute_maxsim_batch,
|
|
|
|
|
|
|
|
|
|
| 15 |
)
|
|
|
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
# Main embedder
|
| 19 |
"VisualEmbedder",
|
| 20 |
"ColPaliEmbedder", # Backward compatibility alias
|
| 21 |
-
|
| 22 |
# Pooling functions
|
| 23 |
"tile_level_mean_pooling",
|
| 24 |
"global_mean_pooling",
|
|
|
|
| 6 |
- Pooling utilities: tile-level, global, MaxSim scoring
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 9 |
from visual_rag.embedding.pooling import (
|
|
|
|
|
|
|
|
|
|
| 10 |
compute_maxsim_batch,
|
| 11 |
+
compute_maxsim_score,
|
| 12 |
+
global_mean_pooling,
|
| 13 |
+
tile_level_mean_pooling,
|
| 14 |
)
|
| 15 |
+
from visual_rag.embedding.visual_embedder import ColPaliEmbedder, VisualEmbedder
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
# Main embedder
|
| 19 |
"VisualEmbedder",
|
| 20 |
"ColPaliEmbedder", # Backward compatibility alias
|
|
|
|
| 21 |
# Pooling functions
|
| 22 |
"tile_level_mean_pooling",
|
| 23 |
"global_mean_pooling",
|
visual_rag/embedding/pooling.py
CHANGED
|
@@ -7,10 +7,11 @@ Provides:
|
|
| 7 |
- MaxSim scoring for ColBERT-style late interaction
|
| 8 |
"""
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
-
from typing import Union, Optional
|
| 13 |
-
import logging
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
|
@@ -39,24 +40,24 @@ def tile_level_mean_pooling(
|
|
| 39 |
) -> np.ndarray:
|
| 40 |
"""
|
| 41 |
Compute tile-level mean pooling for multi-vector embeddings.
|
| 42 |
-
|
| 43 |
Instead of collapsing to 1×dim (global pooling), this preserves spatial
|
| 44 |
structure by computing mean per tile → num_tiles × dim.
|
| 45 |
-
|
| 46 |
This is our NOVEL contribution for scalable visual retrieval:
|
| 47 |
- Faster than full MaxSim (fewer vectors to compare)
|
| 48 |
- More accurate than global pooling (preserves spatial info)
|
| 49 |
- Ideal for two-stage retrieval (prefetch with pooled, rerank with full)
|
| 50 |
-
|
| 51 |
Args:
|
| 52 |
embedding: Visual token embeddings [num_visual_tokens, dim]
|
| 53 |
num_tiles: Number of tiles (including global tile)
|
| 54 |
patches_per_tile: Patches per tile (64 for ColSmol)
|
| 55 |
output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
|
| 56 |
-
|
| 57 |
Returns:
|
| 58 |
Tile-level pooled embeddings [num_tiles, dim]
|
| 59 |
-
|
| 60 |
Example:
|
| 61 |
>>> # Image with 4×3 tiles + 1 global = 13 tiles
|
| 62 |
>>> # Each tile has 64 patches → 832 visual tokens
|
|
@@ -71,31 +72,29 @@ def tile_level_mean_pooling(
|
|
| 71 |
emb_np = embedding.cpu().numpy().astype(np.float32)
|
| 72 |
else:
|
| 73 |
emb_np = np.array(embedding, dtype=np.float32)
|
| 74 |
-
|
| 75 |
num_visual_tokens = emb_np.shape[0]
|
| 76 |
expected_tokens = num_tiles * patches_per_tile
|
| 77 |
-
|
| 78 |
if num_visual_tokens != expected_tokens:
|
| 79 |
-
logger.debug(
|
| 80 |
-
f"Token count mismatch: {num_visual_tokens} vs expected {expected_tokens}"
|
| 81 |
-
)
|
| 82 |
actual_tiles = num_visual_tokens // patches_per_tile
|
| 83 |
if actual_tiles * patches_per_tile != num_visual_tokens:
|
| 84 |
actual_tiles += 1
|
| 85 |
num_tiles = actual_tiles
|
| 86 |
-
|
| 87 |
tile_embeddings = []
|
| 88 |
for tile_idx in range(num_tiles):
|
| 89 |
start_idx = tile_idx * patches_per_tile
|
| 90 |
end_idx = min(start_idx + patches_per_tile, num_visual_tokens)
|
| 91 |
-
|
| 92 |
if start_idx >= num_visual_tokens:
|
| 93 |
break
|
| 94 |
-
|
| 95 |
tile_patches = emb_np[start_idx:end_idx]
|
| 96 |
tile_mean = tile_patches.mean(axis=0)
|
| 97 |
tile_embeddings.append(tile_mean)
|
| 98 |
-
|
| 99 |
return np.array(tile_embeddings, dtype=out_dtype)
|
| 100 |
|
| 101 |
|
|
@@ -116,7 +115,9 @@ def colpali_row_mean_pooling(
|
|
| 116 |
num_tokens, dim = emb_np.shape
|
| 117 |
expected = int(grid_size) * int(grid_size)
|
| 118 |
if num_tokens != expected:
|
| 119 |
-
raise ValueError(
|
|
|
|
|
|
|
| 120 |
|
| 121 |
grid = emb_np.reshape(int(grid_size), int(grid_size), int(dim))
|
| 122 |
pooled = grid.mean(axis=1)
|
|
@@ -157,7 +158,9 @@ def colsmol_experimental_pooling(
|
|
| 157 |
last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)
|
| 158 |
|
| 159 |
prefix = emb_np[:last_tile_start]
|
| 160 |
-
last_tile = emb_np[
|
|
|
|
|
|
|
| 161 |
|
| 162 |
if prefix.size:
|
| 163 |
prefix_tiles = prefix.reshape(-1, int(patches_per_tile), int(dim))
|
|
@@ -174,7 +177,7 @@ def colpali_experimental_pooling_from_rows(
|
|
| 174 |
) -> np.ndarray:
|
| 175 |
"""
|
| 176 |
Experimental "convolution-style" pooling with window size 3.
|
| 177 |
-
|
| 178 |
For N input rows, produces N + 2 output vectors:
|
| 179 |
- Position 0: row[0] alone (1 row)
|
| 180 |
- Position 1: mean(rows[0:2]) (2 rows)
|
|
@@ -182,7 +185,7 @@ def colpali_experimental_pooling_from_rows(
|
|
| 182 |
- Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
|
| 183 |
- Position N: mean(rows[N-2:N]) (last 2 rows)
|
| 184 |
- Position N+1: row[N-1] alone (last row)
|
| 185 |
-
|
| 186 |
For N=32 rows: produces 34 vectors.
|
| 187 |
"""
|
| 188 |
out_dtype = _infer_output_dtype(row_vectors, output_dtype)
|
|
@@ -202,13 +205,16 @@ def colpali_experimental_pooling_from_rows(
|
|
| 202 |
if n == 2:
|
| 203 |
return np.stack([rows[0], rows[:2].mean(axis=0), rows[1]], axis=0).astype(out_dtype)
|
| 204 |
if n == 3:
|
| 205 |
-
return np.stack(
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
out = np.zeros((n + 2, dim), dtype=np.float32)
|
| 214 |
out[0] = rows[0]
|
|
@@ -227,14 +233,14 @@ def global_mean_pooling(
|
|
| 227 |
) -> np.ndarray:
|
| 228 |
"""
|
| 229 |
Compute global mean pooling → single vector.
|
| 230 |
-
|
| 231 |
This is the simplest pooling but loses all spatial information.
|
| 232 |
Use for fastest retrieval when accuracy can be sacrificed.
|
| 233 |
-
|
| 234 |
Args:
|
| 235 |
embedding: Multi-vector embeddings [num_tokens, dim]
|
| 236 |
output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
|
| 237 |
-
|
| 238 |
Returns:
|
| 239 |
Pooled vector [dim]
|
| 240 |
"""
|
|
@@ -246,7 +252,7 @@ def global_mean_pooling(
|
|
| 246 |
emb_np = embedding.cpu().numpy()
|
| 247 |
else:
|
| 248 |
emb_np = np.array(embedding)
|
| 249 |
-
|
| 250 |
return emb_np.mean(axis=0).astype(out_dtype)
|
| 251 |
|
| 252 |
|
|
@@ -257,21 +263,21 @@ def compute_maxsim_score(
|
|
| 257 |
) -> float:
|
| 258 |
"""
|
| 259 |
Compute ColBERT-style MaxSim late interaction score.
|
| 260 |
-
|
| 261 |
For each query token, finds max similarity with any document token,
|
| 262 |
then sums across query tokens.
|
| 263 |
-
|
| 264 |
This is the standard scoring for ColBERT/ColPali:
|
| 265 |
score = Σ_q max_d (sim(q, d))
|
| 266 |
-
|
| 267 |
Args:
|
| 268 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 269 |
doc_embedding: Document embeddings [num_doc_tokens, dim]
|
| 270 |
normalize: L2 normalize embeddings before scoring (recommended)
|
| 271 |
-
|
| 272 |
Returns:
|
| 273 |
MaxSim score (higher is better)
|
| 274 |
-
|
| 275 |
Example:
|
| 276 |
>>> query = embedder.embed_query("budget allocation")
|
| 277 |
>>> doc = embeddings[0] # From embed_images
|
|
@@ -282,22 +288,20 @@ def compute_maxsim_score(
|
|
| 282 |
query_norm = query_embedding / (
|
| 283 |
np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
|
| 284 |
)
|
| 285 |
-
doc_norm = doc_embedding / (
|
| 286 |
-
np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8
|
| 287 |
-
)
|
| 288 |
else:
|
| 289 |
query_norm = query_embedding
|
| 290 |
doc_norm = doc_embedding
|
| 291 |
-
|
| 292 |
# Compute similarity matrix: [num_query, num_doc]
|
| 293 |
similarity_matrix = np.dot(query_norm, doc_norm.T)
|
| 294 |
-
|
| 295 |
# MaxSim: For each query token, take max similarity with any doc token
|
| 296 |
max_similarities = similarity_matrix.max(axis=1)
|
| 297 |
-
|
| 298 |
# Sum across query tokens
|
| 299 |
score = float(max_similarities.sum())
|
| 300 |
-
|
| 301 |
return score
|
| 302 |
|
| 303 |
|
|
@@ -308,12 +312,12 @@ def compute_maxsim_batch(
|
|
| 308 |
) -> list:
|
| 309 |
"""
|
| 310 |
Compute MaxSim scores for multiple documents efficiently.
|
| 311 |
-
|
| 312 |
Args:
|
| 313 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 314 |
doc_embeddings: List of document embeddings
|
| 315 |
normalize: L2 normalize embeddings
|
| 316 |
-
|
| 317 |
Returns:
|
| 318 |
List of MaxSim scores
|
| 319 |
"""
|
|
@@ -324,18 +328,16 @@ def compute_maxsim_batch(
|
|
| 324 |
)
|
| 325 |
else:
|
| 326 |
query_norm = query_embedding
|
| 327 |
-
|
| 328 |
scores = []
|
| 329 |
for doc_emb in doc_embeddings:
|
| 330 |
if normalize:
|
| 331 |
-
doc_norm = doc_emb / (
|
| 332 |
-
np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8
|
| 333 |
-
)
|
| 334 |
else:
|
| 335 |
doc_norm = doc_emb
|
| 336 |
-
|
| 337 |
sim_matrix = np.dot(query_norm, doc_norm.T)
|
| 338 |
max_sims = sim_matrix.max(axis=1)
|
| 339 |
scores.append(float(max_sims.sum()))
|
| 340 |
-
|
| 341 |
return scores
|
|
|
|
| 7 |
- MaxSim scoring for ColBERT-style late interaction
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Optional, Union
|
| 12 |
+
|
| 13 |
import numpy as np
|
| 14 |
import torch
|
|
|
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
| 40 |
) -> np.ndarray:
|
| 41 |
"""
|
| 42 |
Compute tile-level mean pooling for multi-vector embeddings.
|
| 43 |
+
|
| 44 |
Instead of collapsing to 1×dim (global pooling), this preserves spatial
|
| 45 |
structure by computing mean per tile → num_tiles × dim.
|
| 46 |
+
|
| 47 |
This is our NOVEL contribution for scalable visual retrieval:
|
| 48 |
- Faster than full MaxSim (fewer vectors to compare)
|
| 49 |
- More accurate than global pooling (preserves spatial info)
|
| 50 |
- Ideal for two-stage retrieval (prefetch with pooled, rerank with full)
|
| 51 |
+
|
| 52 |
Args:
|
| 53 |
embedding: Visual token embeddings [num_visual_tokens, dim]
|
| 54 |
num_tiles: Number of tiles (including global tile)
|
| 55 |
patches_per_tile: Patches per tile (64 for ColSmol)
|
| 56 |
output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
|
| 57 |
+
|
| 58 |
Returns:
|
| 59 |
Tile-level pooled embeddings [num_tiles, dim]
|
| 60 |
+
|
| 61 |
Example:
|
| 62 |
>>> # Image with 4×3 tiles + 1 global = 13 tiles
|
| 63 |
>>> # Each tile has 64 patches → 832 visual tokens
|
|
|
|
| 72 |
emb_np = embedding.cpu().numpy().astype(np.float32)
|
| 73 |
else:
|
| 74 |
emb_np = np.array(embedding, dtype=np.float32)
|
| 75 |
+
|
| 76 |
num_visual_tokens = emb_np.shape[0]
|
| 77 |
expected_tokens = num_tiles * patches_per_tile
|
| 78 |
+
|
| 79 |
if num_visual_tokens != expected_tokens:
|
| 80 |
+
logger.debug(f"Token count mismatch: {num_visual_tokens} vs expected {expected_tokens}")
|
|
|
|
|
|
|
| 81 |
actual_tiles = num_visual_tokens // patches_per_tile
|
| 82 |
if actual_tiles * patches_per_tile != num_visual_tokens:
|
| 83 |
actual_tiles += 1
|
| 84 |
num_tiles = actual_tiles
|
| 85 |
+
|
| 86 |
tile_embeddings = []
|
| 87 |
for tile_idx in range(num_tiles):
|
| 88 |
start_idx = tile_idx * patches_per_tile
|
| 89 |
end_idx = min(start_idx + patches_per_tile, num_visual_tokens)
|
| 90 |
+
|
| 91 |
if start_idx >= num_visual_tokens:
|
| 92 |
break
|
| 93 |
+
|
| 94 |
tile_patches = emb_np[start_idx:end_idx]
|
| 95 |
tile_mean = tile_patches.mean(axis=0)
|
| 96 |
tile_embeddings.append(tile_mean)
|
| 97 |
+
|
| 98 |
return np.array(tile_embeddings, dtype=out_dtype)
|
| 99 |
|
| 100 |
|
|
|
|
| 115 |
num_tokens, dim = emb_np.shape
|
| 116 |
expected = int(grid_size) * int(grid_size)
|
| 117 |
if num_tokens != expected:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}"
|
| 120 |
+
)
|
| 121 |
|
| 122 |
grid = emb_np.reshape(int(grid_size), int(grid_size), int(dim))
|
| 123 |
pooled = grid.mean(axis=1)
|
|
|
|
| 158 |
last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)
|
| 159 |
|
| 160 |
prefix = emb_np[:last_tile_start]
|
| 161 |
+
last_tile = emb_np[
|
| 162 |
+
last_tile_start : min(last_tile_start + int(patches_per_tile), num_visual_tokens)
|
| 163 |
+
]
|
| 164 |
|
| 165 |
if prefix.size:
|
| 166 |
prefix_tiles = prefix.reshape(-1, int(patches_per_tile), int(dim))
|
|
|
|
| 177 |
) -> np.ndarray:
|
| 178 |
"""
|
| 179 |
Experimental "convolution-style" pooling with window size 3.
|
| 180 |
+
|
| 181 |
For N input rows, produces N + 2 output vectors:
|
| 182 |
- Position 0: row[0] alone (1 row)
|
| 183 |
- Position 1: mean(rows[0:2]) (2 rows)
|
|
|
|
| 185 |
- Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
|
| 186 |
- Position N: mean(rows[N-2:N]) (last 2 rows)
|
| 187 |
- Position N+1: row[N-1] alone (last row)
|
| 188 |
+
|
| 189 |
For N=32 rows: produces 34 vectors.
|
| 190 |
"""
|
| 191 |
out_dtype = _infer_output_dtype(row_vectors, output_dtype)
|
|
|
|
| 205 |
if n == 2:
|
| 206 |
return np.stack([rows[0], rows[:2].mean(axis=0), rows[1]], axis=0).astype(out_dtype)
|
| 207 |
if n == 3:
|
| 208 |
+
return np.stack(
|
| 209 |
+
[
|
| 210 |
+
rows[0],
|
| 211 |
+
rows[:2].mean(axis=0),
|
| 212 |
+
rows[:3].mean(axis=0),
|
| 213 |
+
rows[1:3].mean(axis=0),
|
| 214 |
+
rows[2],
|
| 215 |
+
],
|
| 216 |
+
axis=0,
|
| 217 |
+
).astype(out_dtype)
|
| 218 |
|
| 219 |
out = np.zeros((n + 2, dim), dtype=np.float32)
|
| 220 |
out[0] = rows[0]
|
|
|
|
| 233 |
) -> np.ndarray:
|
| 234 |
"""
|
| 235 |
Compute global mean pooling → single vector.
|
| 236 |
+
|
| 237 |
This is the simplest pooling but loses all spatial information.
|
| 238 |
Use for fastest retrieval when accuracy can be sacrificed.
|
| 239 |
+
|
| 240 |
Args:
|
| 241 |
embedding: Multi-vector embeddings [num_tokens, dim]
|
| 242 |
output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
|
| 243 |
+
|
| 244 |
Returns:
|
| 245 |
Pooled vector [dim]
|
| 246 |
"""
|
|
|
|
| 252 |
emb_np = embedding.cpu().numpy()
|
| 253 |
else:
|
| 254 |
emb_np = np.array(embedding)
|
| 255 |
+
|
| 256 |
return emb_np.mean(axis=0).astype(out_dtype)
|
| 257 |
|
| 258 |
|
|
|
|
| 263 |
) -> float:
|
| 264 |
"""
|
| 265 |
Compute ColBERT-style MaxSim late interaction score.
|
| 266 |
+
|
| 267 |
For each query token, finds max similarity with any document token,
|
| 268 |
then sums across query tokens.
|
| 269 |
+
|
| 270 |
This is the standard scoring for ColBERT/ColPali:
|
| 271 |
score = Σ_q max_d (sim(q, d))
|
| 272 |
+
|
| 273 |
Args:
|
| 274 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 275 |
doc_embedding: Document embeddings [num_doc_tokens, dim]
|
| 276 |
normalize: L2 normalize embeddings before scoring (recommended)
|
| 277 |
+
|
| 278 |
Returns:
|
| 279 |
MaxSim score (higher is better)
|
| 280 |
+
|
| 281 |
Example:
|
| 282 |
>>> query = embedder.embed_query("budget allocation")
|
| 283 |
>>> doc = embeddings[0] # From embed_images
|
|
|
|
| 288 |
query_norm = query_embedding / (
|
| 289 |
np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
|
| 290 |
)
|
| 291 |
+
doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
|
|
|
|
|
|
|
| 292 |
else:
|
| 293 |
query_norm = query_embedding
|
| 294 |
doc_norm = doc_embedding
|
| 295 |
+
|
| 296 |
# Compute similarity matrix: [num_query, num_doc]
|
| 297 |
similarity_matrix = np.dot(query_norm, doc_norm.T)
|
| 298 |
+
|
| 299 |
# MaxSim: For each query token, take max similarity with any doc token
|
| 300 |
max_similarities = similarity_matrix.max(axis=1)
|
| 301 |
+
|
| 302 |
# Sum across query tokens
|
| 303 |
score = float(max_similarities.sum())
|
| 304 |
+
|
| 305 |
return score
|
| 306 |
|
| 307 |
|
|
|
|
| 312 |
) -> list:
|
| 313 |
"""
|
| 314 |
Compute MaxSim scores for multiple documents efficiently.
|
| 315 |
+
|
| 316 |
Args:
|
| 317 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 318 |
doc_embeddings: List of document embeddings
|
| 319 |
normalize: L2 normalize embeddings
|
| 320 |
+
|
| 321 |
Returns:
|
| 322 |
List of MaxSim scores
|
| 323 |
"""
|
|
|
|
| 328 |
)
|
| 329 |
else:
|
| 330 |
query_norm = query_embedding
|
| 331 |
+
|
| 332 |
scores = []
|
| 333 |
for doc_emb in doc_embeddings:
|
| 334 |
if normalize:
|
| 335 |
+
doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
|
|
|
|
|
|
|
| 336 |
else:
|
| 337 |
doc_norm = doc_emb
|
| 338 |
+
|
| 339 |
sim_matrix = np.dot(query_norm, doc_norm.T)
|
| 340 |
max_sims = sim_matrix.max(axis=1)
|
| 341 |
scores.append(float(max_sims.sum()))
|
| 342 |
+
|
| 343 |
return scores
|
visual_rag/embedding/visual_embedder.py
CHANGED
|
@@ -12,12 +12,12 @@ The embedder is BACKEND-AGNOSTIC - configure which model to use via the
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import gc
|
| 15 |
-
import os
|
| 16 |
import logging
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
import torch
|
| 20 |
import numpy as np
|
|
|
|
| 21 |
from PIL import Image
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
|
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
|
|
| 27 |
class VisualEmbedder:
|
| 28 |
"""
|
| 29 |
Visual document embedder supporting multiple backends.
|
| 30 |
-
|
| 31 |
Currently supports:
|
| 32 |
- ColPali family (ColSmol-500M, ColPali, ColQwen2)
|
| 33 |
- More backends can be added
|
| 34 |
-
|
| 35 |
Args:
|
| 36 |
model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
|
| 37 |
backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
|
|
@@ -39,23 +39,23 @@ class VisualEmbedder:
|
|
| 39 |
torch_dtype: Data type for model weights
|
| 40 |
batch_size: Batch size for image processing
|
| 41 |
filter_special_tokens: Filter special tokens from query embeddings
|
| 42 |
-
|
| 43 |
Example:
|
| 44 |
>>> # Auto-detect backend from model name
|
| 45 |
>>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
|
| 46 |
-
>>>
|
| 47 |
>>> # Embed images
|
| 48 |
>>> image_embeddings = embedder.embed_images(images)
|
| 49 |
-
>>>
|
| 50 |
>>> # Embed query
|
| 51 |
>>> query_embedding = embedder.embed_query("What is the budget?")
|
| 52 |
-
>>>
|
| 53 |
>>> # Get token info for saliency maps
|
| 54 |
>>> embeddings, token_infos = embedder.embed_images(
|
| 55 |
... images, return_token_info=True
|
| 56 |
... )
|
| 57 |
"""
|
| 58 |
-
|
| 59 |
# Known model families and their backends
|
| 60 |
MODEL_BACKENDS = {
|
| 61 |
"colsmol": "colpali",
|
|
@@ -63,7 +63,7 @@ class VisualEmbedder:
|
|
| 63 |
"colqwen": "colpali",
|
| 64 |
"colidefics": "colpali",
|
| 65 |
}
|
| 66 |
-
|
| 67 |
def __init__(
|
| 68 |
self,
|
| 69 |
model_name: str = "vidore/colSmol-500M",
|
|
@@ -81,15 +81,15 @@ class VisualEmbedder:
|
|
| 81 |
if processor_speed not in ("fast", "slow", "auto"):
|
| 82 |
raise ValueError("processor_speed must be one of: fast, slow, auto")
|
| 83 |
self.processor_speed = processor_speed
|
| 84 |
-
|
| 85 |
if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
|
| 86 |
self.filter_special_tokens = False
|
| 87 |
logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")
|
| 88 |
-
|
| 89 |
if backend == "auto":
|
| 90 |
backend = self._detect_backend(model_name)
|
| 91 |
self.backend = backend
|
| 92 |
-
|
| 93 |
if device is None:
|
| 94 |
if torch.cuda.is_available():
|
| 95 |
device = "cuda"
|
|
@@ -98,53 +98,55 @@ class VisualEmbedder:
|
|
| 98 |
else:
|
| 99 |
device = "cpu"
|
| 100 |
self.device = device
|
| 101 |
-
|
| 102 |
if torch_dtype is None:
|
| 103 |
if device == "cuda":
|
| 104 |
torch_dtype = torch.bfloat16
|
| 105 |
else:
|
| 106 |
torch_dtype = torch.float32
|
| 107 |
self.torch_dtype = torch_dtype
|
| 108 |
-
|
| 109 |
if output_dtype is None:
|
| 110 |
if torch_dtype == torch.float16:
|
| 111 |
output_dtype = np.float16
|
| 112 |
else:
|
| 113 |
output_dtype = np.float32
|
| 114 |
self.output_dtype = output_dtype
|
| 115 |
-
|
| 116 |
self._model = None
|
| 117 |
self._processor = None
|
| 118 |
self._image_token_id = None
|
| 119 |
-
|
| 120 |
-
logger.info(
|
| 121 |
logger.info(f" Model: {model_name}")
|
| 122 |
logger.info(f" Backend: {backend}")
|
| 123 |
-
logger.info(
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
def _detect_backend(self, model_name: str) -> str:
|
| 126 |
"""Auto-detect backend from model name."""
|
| 127 |
model_lower = model_name.lower()
|
| 128 |
-
|
| 129 |
for key, backend in self.MODEL_BACKENDS.items():
|
| 130 |
if key in model_lower:
|
| 131 |
logger.debug(f"Detected backend '{backend}' from model name")
|
| 132 |
return backend
|
| 133 |
-
|
| 134 |
# Default to colpali for unknown models
|
| 135 |
logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
|
| 136 |
return "colpali"
|
| 137 |
-
|
| 138 |
def _load_model(self):
|
| 139 |
"""Lazy load the model when first needed."""
|
| 140 |
if self._model is not None:
|
| 141 |
return
|
| 142 |
-
|
| 143 |
if self.backend == "colpali":
|
| 144 |
self._load_colpali_model()
|
| 145 |
else:
|
| 146 |
raise ValueError(f"Unknown backend: {self.backend}")
|
| 147 |
-
|
| 148 |
def _load_colpali_model(self):
|
| 149 |
"""Load ColPali-family model."""
|
| 150 |
try:
|
|
@@ -162,7 +164,7 @@ class VisualEmbedder:
|
|
| 162 |
"pip install visual-rag-toolkit[embedding] or "
|
| 163 |
"pip install colpali-engine"
|
| 164 |
)
|
| 165 |
-
|
| 166 |
logger.info(f"🤖 Loading ColPali model: {self.model_name}")
|
| 167 |
logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")
|
| 168 |
|
|
@@ -170,7 +172,7 @@ class VisualEmbedder:
|
|
| 170 |
if self.processor_speed == "auto":
|
| 171 |
return {}
|
| 172 |
return {"use_fast": self.processor_speed == "fast"}
|
| 173 |
-
|
| 174 |
from transformers import AutoConfig
|
| 175 |
|
| 176 |
cfg = AutoConfig.from_pretrained(self.model_name)
|
|
@@ -183,12 +185,16 @@ class VisualEmbedder:
|
|
| 183 |
device_map=self.device,
|
| 184 |
).eval()
|
| 185 |
try:
|
| 186 |
-
self._processor = ColPaliProcessor.from_pretrained(
|
|
|
|
|
|
|
| 187 |
except TypeError:
|
| 188 |
self._processor = ColPaliProcessor.from_pretrained(self.model_name)
|
| 189 |
except Exception:
|
| 190 |
if self.processor_speed == "fast":
|
| 191 |
-
self._processor = ColPaliProcessor.from_pretrained(
|
|
|
|
|
|
|
| 192 |
else:
|
| 193 |
raise
|
| 194 |
self._image_token_id = self._processor.image_token_id
|
|
@@ -202,12 +208,18 @@ class VisualEmbedder:
|
|
| 202 |
device_map=self.device,
|
| 203 |
).eval()
|
| 204 |
try:
|
| 205 |
-
self._processor = ColQwen2Processor.from_pretrained(
|
|
|
|
|
|
|
| 206 |
except TypeError:
|
| 207 |
-
self._processor = ColQwen2Processor.from_pretrained(
|
|
|
|
|
|
|
| 208 |
except Exception:
|
| 209 |
if self.processor_speed == "fast":
|
| 210 |
-
self._processor = ColQwen2Processor.from_pretrained(
|
|
|
|
|
|
|
| 211 |
else:
|
| 212 |
raise
|
| 213 |
self._image_token_id = self._processor.image_token_id
|
|
@@ -231,33 +243,37 @@ class VisualEmbedder:
|
|
| 231 |
attn_implementation=attn_implementation,
|
| 232 |
).eval()
|
| 233 |
try:
|
| 234 |
-
self._processor = ColIdefics3Processor.from_pretrained(
|
|
|
|
|
|
|
| 235 |
except TypeError:
|
| 236 |
self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
|
| 237 |
except Exception:
|
| 238 |
if self.processor_speed == "fast":
|
| 239 |
-
self._processor = ColIdefics3Processor.from_pretrained(
|
|
|
|
|
|
|
| 240 |
else:
|
| 241 |
raise
|
| 242 |
self._image_token_id = self._processor.image_token_id
|
| 243 |
-
|
| 244 |
logger.info("✅ Model loaded successfully")
|
| 245 |
-
|
| 246 |
@property
|
| 247 |
def model(self):
|
| 248 |
self._load_model()
|
| 249 |
return self._model
|
| 250 |
-
|
| 251 |
@property
|
| 252 |
def processor(self):
|
| 253 |
self._load_model()
|
| 254 |
return self._processor
|
| 255 |
-
|
| 256 |
@property
|
| 257 |
def image_token_id(self):
|
| 258 |
self._load_model()
|
| 259 |
return self._image_token_id
|
| 260 |
-
|
| 261 |
def embed_query(
|
| 262 |
self,
|
| 263 |
query_text: str,
|
|
@@ -265,31 +281,31 @@ class VisualEmbedder:
|
|
| 265 |
) -> torch.Tensor:
|
| 266 |
"""
|
| 267 |
Generate embedding for a text query.
|
| 268 |
-
|
| 269 |
By default, filters out special tokens (CLS, SEP, PAD) to keep only
|
| 270 |
meaningful text tokens for better MaxSim matching.
|
| 271 |
-
|
| 272 |
Args:
|
| 273 |
query_text: Natural language query string
|
| 274 |
filter_special_tokens: Override instance-level setting
|
| 275 |
-
|
| 276 |
Returns:
|
| 277 |
Query embedding tensor of shape [num_tokens, embedding_dim]
|
| 278 |
"""
|
| 279 |
should_filter = (
|
| 280 |
-
filter_special_tokens
|
| 281 |
-
if filter_special_tokens is not None
|
| 282 |
else self.filter_special_tokens
|
| 283 |
)
|
| 284 |
-
|
| 285 |
with torch.no_grad():
|
| 286 |
processed = self.processor.process_queries([query_text]).to(self.model.device)
|
| 287 |
embedding = self.model(**processed)
|
| 288 |
-
|
| 289 |
# Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
|
| 290 |
if embedding.dim() == 3:
|
| 291 |
embedding = embedding.squeeze(0)
|
| 292 |
-
|
| 293 |
if should_filter:
|
| 294 |
# Filter special tokens based on attention mask
|
| 295 |
attention_mask = processed.get("attention_mask")
|
|
@@ -297,7 +313,7 @@ class VisualEmbedder:
|
|
| 297 |
# Keep only tokens with attention_mask = 1
|
| 298 |
valid_mask = attention_mask.squeeze(0).bool()
|
| 299 |
embedding = embedding[valid_mask]
|
| 300 |
-
|
| 301 |
# Additionally filter padding tokens if present
|
| 302 |
input_ids = processed.get("input_ids")
|
| 303 |
if input_ids is not None:
|
|
@@ -307,11 +323,11 @@ class VisualEmbedder:
|
|
| 307 |
non_special_mask = input_ids >= 4
|
| 308 |
if non_special_mask.any():
|
| 309 |
embedding = embedding[non_special_mask]
|
| 310 |
-
|
| 311 |
logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
|
| 312 |
else:
|
| 313 |
logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")
|
| 314 |
-
|
| 315 |
return embedding
|
| 316 |
|
| 317 |
def embed_queries(
|
|
@@ -327,7 +343,9 @@ class VisualEmbedder:
|
|
| 327 |
Returns a list of tensors, each of shape [num_tokens, embedding_dim].
|
| 328 |
"""
|
| 329 |
should_filter = (
|
| 330 |
-
filter_special_tokens
|
|
|
|
|
|
|
| 331 |
)
|
| 332 |
batch_size = batch_size or self.batch_size
|
| 333 |
|
|
@@ -368,7 +386,7 @@ class VisualEmbedder:
|
|
| 368 |
torch.mps.empty_cache()
|
| 369 |
|
| 370 |
return outputs
|
| 371 |
-
|
| 372 |
def embed_images(
|
| 373 |
self,
|
| 374 |
images: List[Image.Image],
|
|
@@ -378,19 +396,19 @@ class VisualEmbedder:
|
|
| 378 |
) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
|
| 379 |
"""
|
| 380 |
Generate embeddings for a list of images.
|
| 381 |
-
|
| 382 |
Args:
|
| 383 |
images: List of PIL Images
|
| 384 |
batch_size: Override instance batch size
|
| 385 |
return_token_info: Also return token metadata (for saliency maps)
|
| 386 |
show_progress: Show progress bar
|
| 387 |
-
|
| 388 |
Returns:
|
| 389 |
If return_token_info=False:
|
| 390 |
List of embedding tensors [num_patches, dim]
|
| 391 |
If return_token_info=True:
|
| 392 |
Tuple of (embeddings, token_infos)
|
| 393 |
-
|
| 394 |
Token info contains:
|
| 395 |
- visual_token_indices: Indices of visual tokens in embedding
|
| 396 |
- num_visual_tokens: Count of visual tokens
|
|
@@ -398,54 +416,60 @@ class VisualEmbedder:
|
|
| 398 |
- num_tiles: Total tiles (n_rows × n_cols + 1 global)
|
| 399 |
"""
|
| 400 |
batch_size = batch_size or self.batch_size
|
| 401 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
batch_size = 1
|
| 403 |
-
|
| 404 |
embeddings = []
|
| 405 |
token_infos = [] if return_token_info else None
|
| 406 |
-
|
| 407 |
iterator = range(0, len(images), batch_size)
|
| 408 |
if show_progress:
|
| 409 |
iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")
|
| 410 |
-
|
| 411 |
for i in iterator:
|
| 412 |
-
batch = images[i:i + batch_size]
|
| 413 |
-
|
| 414 |
with torch.no_grad():
|
| 415 |
processed = self.processor.process_images(batch).to(self.model.device)
|
| 416 |
-
|
| 417 |
# Extract token info before model forward
|
| 418 |
if return_token_info:
|
| 419 |
input_ids = processed["input_ids"]
|
| 420 |
batch_n_rows = processed.get("n_rows")
|
| 421 |
batch_n_cols = processed.get("n_cols")
|
| 422 |
-
|
| 423 |
for j in range(input_ids.shape[0]):
|
| 424 |
# Find visual token indices
|
| 425 |
-
image_token_mask =
|
| 426 |
visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()
|
| 427 |
-
|
| 428 |
n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
|
| 429 |
n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None
|
| 430 |
-
|
| 431 |
-
token_infos.append(
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
| 439 |
# Generate embeddings
|
| 440 |
batch_embeddings = self.model(**processed)
|
| 441 |
-
|
| 442 |
# Extract per-image embeddings
|
| 443 |
if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
|
| 444 |
for j in range(batch_embeddings.shape[0]):
|
| 445 |
embeddings.append(batch_embeddings[j].cpu())
|
| 446 |
else:
|
| 447 |
embeddings.extend([e.cpu() for e in batch_embeddings])
|
| 448 |
-
|
| 449 |
# Memory cleanup
|
| 450 |
del processed, batch_embeddings
|
| 451 |
gc.collect()
|
|
@@ -453,11 +477,11 @@ class VisualEmbedder:
|
|
| 453 |
torch.cuda.empty_cache()
|
| 454 |
elif torch.backends.mps.is_available():
|
| 455 |
torch.mps.empty_cache()
|
| 456 |
-
|
| 457 |
if return_token_info:
|
| 458 |
return embeddings, token_infos
|
| 459 |
return embeddings
|
| 460 |
-
|
| 461 |
def extract_visual_embedding(
|
| 462 |
self,
|
| 463 |
full_embedding: torch.Tensor,
|
|
@@ -465,18 +489,18 @@ class VisualEmbedder:
|
|
| 465 |
) -> np.ndarray:
|
| 466 |
"""
|
| 467 |
Extract only visual token embeddings from full embedding.
|
| 468 |
-
|
| 469 |
Filters out special tokens, keeping only visual patches for MaxSim.
|
| 470 |
-
|
| 471 |
Args:
|
| 472 |
full_embedding: Full embedding [all_tokens, dim]
|
| 473 |
token_info: Token info dict from embed_images
|
| 474 |
-
|
| 475 |
Returns:
|
| 476 |
Visual embedding array [num_visual_tokens, dim]
|
| 477 |
"""
|
| 478 |
visual_indices = token_info["visual_token_indices"]
|
| 479 |
-
|
| 480 |
if isinstance(full_embedding, torch.Tensor):
|
| 481 |
if full_embedding.dtype == torch.bfloat16:
|
| 482 |
visual_emb = full_embedding[visual_indices].cpu().float().numpy()
|
|
@@ -484,7 +508,7 @@ class VisualEmbedder:
|
|
| 484 |
visual_emb = full_embedding[visual_indices].cpu().numpy()
|
| 485 |
else:
|
| 486 |
visual_emb = np.array(full_embedding)[visual_indices]
|
| 487 |
-
|
| 488 |
return visual_emb.astype(self.output_dtype)
|
| 489 |
|
| 490 |
def mean_pool_visual_embedding(
|
|
@@ -511,17 +535,23 @@ class VisualEmbedder:
|
|
| 511 |
n_rows = (token_info or {}).get("n_rows")
|
| 512 |
n_cols = (token_info or {}).get("n_cols")
|
| 513 |
num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
|
| 514 |
-
return tile_level_mean_pooling(
|
|
|
|
|
|
|
| 515 |
|
| 516 |
num_tokens = int(visual_np.shape[0])
|
| 517 |
grid = int(round(float(num_tokens) ** 0.5))
|
| 518 |
if grid * grid != num_tokens:
|
| 519 |
-
raise ValueError(
|
|
|
|
|
|
|
| 520 |
if int(target_vectors) != int(grid):
|
| 521 |
raise ValueError(
|
| 522 |
f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
|
| 523 |
)
|
| 524 |
-
return colpali_row_mean_pooling(
|
|
|
|
|
|
|
| 525 |
|
| 526 |
def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
|
| 527 |
if mean_pool.size == 0:
|
|
@@ -536,7 +566,10 @@ class VisualEmbedder:
|
|
| 536 |
target_vectors: int = 32,
|
| 537 |
mean_pool: Optional[np.ndarray] = None,
|
| 538 |
) -> np.ndarray:
|
| 539 |
-
from visual_rag.embedding.pooling import
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
model_lower = (self.model_name or "").lower()
|
| 542 |
is_colsmol = "colsmol" in model_lower
|
|
@@ -550,7 +583,11 @@ class VisualEmbedder:
|
|
| 550 |
visual_np = np.array(visual_embedding, dtype=np.float32)
|
| 551 |
|
| 552 |
if is_colsmol:
|
| 553 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
num_tiles = int(mean_pool.shape[0])
|
| 555 |
else:
|
| 556 |
num_tiles = (token_info or {}).get("num_tiles")
|
|
@@ -563,14 +600,23 @@ class VisualEmbedder:
|
|
| 563 |
if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
|
| 564 |
num_tiles = int(num_tiles) + 1
|
| 565 |
num_tiles = int(num_tiles)
|
| 566 |
-
return colsmol_experimental_pooling(
|
|
|
|
|
|
|
| 567 |
|
| 568 |
-
rows =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
if int(rows.shape[0]) != int(target_vectors):
|
| 570 |
raise ValueError(
|
| 571 |
f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
|
| 572 |
)
|
| 573 |
return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)
|
| 574 |
|
|
|
|
| 575 |
# Backward compatibility alias
|
| 576 |
ColPaliEmbedder = VisualEmbedder
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import gc
|
|
|
|
| 15 |
import logging
|
| 16 |
+
import os
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
|
|
|
|
| 19 |
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
from PIL import Image
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
|
|
|
| 27 |
class VisualEmbedder:
|
| 28 |
"""
|
| 29 |
Visual document embedder supporting multiple backends.
|
| 30 |
+
|
| 31 |
Currently supports:
|
| 32 |
- ColPali family (ColSmol-500M, ColPali, ColQwen2)
|
| 33 |
- More backends can be added
|
| 34 |
+
|
| 35 |
Args:
|
| 36 |
model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
|
| 37 |
backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
|
|
|
|
| 39 |
torch_dtype: Data type for model weights
|
| 40 |
batch_size: Batch size for image processing
|
| 41 |
filter_special_tokens: Filter special tokens from query embeddings
|
| 42 |
+
|
| 43 |
Example:
|
| 44 |
>>> # Auto-detect backend from model name
|
| 45 |
>>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
|
| 46 |
+
>>>
|
| 47 |
>>> # Embed images
|
| 48 |
>>> image_embeddings = embedder.embed_images(images)
|
| 49 |
+
>>>
|
| 50 |
>>> # Embed query
|
| 51 |
>>> query_embedding = embedder.embed_query("What is the budget?")
|
| 52 |
+
>>>
|
| 53 |
>>> # Get token info for saliency maps
|
| 54 |
>>> embeddings, token_infos = embedder.embed_images(
|
| 55 |
... images, return_token_info=True
|
| 56 |
... )
|
| 57 |
"""
|
| 58 |
+
|
| 59 |
# Known model families and their backends
|
| 60 |
MODEL_BACKENDS = {
|
| 61 |
"colsmol": "colpali",
|
|
|
|
| 63 |
"colqwen": "colpali",
|
| 64 |
"colidefics": "colpali",
|
| 65 |
}
|
| 66 |
+
|
| 67 |
def __init__(
|
| 68 |
self,
|
| 69 |
model_name: str = "vidore/colSmol-500M",
|
|
|
|
| 81 |
if processor_speed not in ("fast", "slow", "auto"):
|
| 82 |
raise ValueError("processor_speed must be one of: fast, slow, auto")
|
| 83 |
self.processor_speed = processor_speed
|
| 84 |
+
|
| 85 |
if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
|
| 86 |
self.filter_special_tokens = False
|
| 87 |
logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")
|
| 88 |
+
|
| 89 |
if backend == "auto":
|
| 90 |
backend = self._detect_backend(model_name)
|
| 91 |
self.backend = backend
|
| 92 |
+
|
| 93 |
if device is None:
|
| 94 |
if torch.cuda.is_available():
|
| 95 |
device = "cuda"
|
|
|
|
| 98 |
else:
|
| 99 |
device = "cpu"
|
| 100 |
self.device = device
|
| 101 |
+
|
| 102 |
if torch_dtype is None:
|
| 103 |
if device == "cuda":
|
| 104 |
torch_dtype = torch.bfloat16
|
| 105 |
else:
|
| 106 |
torch_dtype = torch.float32
|
| 107 |
self.torch_dtype = torch_dtype
|
| 108 |
+
|
| 109 |
if output_dtype is None:
|
| 110 |
if torch_dtype == torch.float16:
|
| 111 |
output_dtype = np.float16
|
| 112 |
else:
|
| 113 |
output_dtype = np.float32
|
| 114 |
self.output_dtype = output_dtype
|
| 115 |
+
|
| 116 |
self._model = None
|
| 117 |
self._processor = None
|
| 118 |
self._image_token_id = None
|
| 119 |
+
|
| 120 |
+
logger.info("🤖 VisualEmbedder initialized")
|
| 121 |
logger.info(f" Model: {model_name}")
|
| 122 |
logger.info(f" Backend: {backend}")
|
| 123 |
+
logger.info(
|
| 124 |
+
f" Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}"
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
def _detect_backend(self, model_name: str) -> str:
|
| 128 |
"""Auto-detect backend from model name."""
|
| 129 |
model_lower = model_name.lower()
|
| 130 |
+
|
| 131 |
for key, backend in self.MODEL_BACKENDS.items():
|
| 132 |
if key in model_lower:
|
| 133 |
logger.debug(f"Detected backend '{backend}' from model name")
|
| 134 |
return backend
|
| 135 |
+
|
| 136 |
# Default to colpali for unknown models
|
| 137 |
logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
|
| 138 |
return "colpali"
|
| 139 |
+
|
| 140 |
def _load_model(self):
|
| 141 |
"""Lazy load the model when first needed."""
|
| 142 |
if self._model is not None:
|
| 143 |
return
|
| 144 |
+
|
| 145 |
if self.backend == "colpali":
|
| 146 |
self._load_colpali_model()
|
| 147 |
else:
|
| 148 |
raise ValueError(f"Unknown backend: {self.backend}")
|
| 149 |
+
|
| 150 |
def _load_colpali_model(self):
|
| 151 |
"""Load ColPali-family model."""
|
| 152 |
try:
|
|
|
|
| 164 |
"pip install visual-rag-toolkit[embedding] or "
|
| 165 |
"pip install colpali-engine"
|
| 166 |
)
|
| 167 |
+
|
| 168 |
logger.info(f"🤖 Loading ColPali model: {self.model_name}")
|
| 169 |
logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")
|
| 170 |
|
|
|
|
| 172 |
if self.processor_speed == "auto":
|
| 173 |
return {}
|
| 174 |
return {"use_fast": self.processor_speed == "fast"}
|
| 175 |
+
|
| 176 |
from transformers import AutoConfig
|
| 177 |
|
| 178 |
cfg = AutoConfig.from_pretrained(self.model_name)
|
|
|
|
| 185 |
device_map=self.device,
|
| 186 |
).eval()
|
| 187 |
try:
|
| 188 |
+
self._processor = ColPaliProcessor.from_pretrained(
|
| 189 |
+
self.model_name, **_processor_kwargs()
|
| 190 |
+
)
|
| 191 |
except TypeError:
|
| 192 |
self._processor = ColPaliProcessor.from_pretrained(self.model_name)
|
| 193 |
except Exception:
|
| 194 |
if self.processor_speed == "fast":
|
| 195 |
+
self._processor = ColPaliProcessor.from_pretrained(
|
| 196 |
+
self.model_name, use_fast=False
|
| 197 |
+
)
|
| 198 |
else:
|
| 199 |
raise
|
| 200 |
self._image_token_id = self._processor.image_token_id
|
|
|
|
| 208 |
device_map=self.device,
|
| 209 |
).eval()
|
| 210 |
try:
|
| 211 |
+
self._processor = ColQwen2Processor.from_pretrained(
|
| 212 |
+
self.model_name, device_map=self.device, **_processor_kwargs()
|
| 213 |
+
)
|
| 214 |
except TypeError:
|
| 215 |
+
self._processor = ColQwen2Processor.from_pretrained(
|
| 216 |
+
self.model_name, device_map=self.device
|
| 217 |
+
)
|
| 218 |
except Exception:
|
| 219 |
if self.processor_speed == "fast":
|
| 220 |
+
self._processor = ColQwen2Processor.from_pretrained(
|
| 221 |
+
self.model_name, device_map=self.device, use_fast=False
|
| 222 |
+
)
|
| 223 |
else:
|
| 224 |
raise
|
| 225 |
self._image_token_id = self._processor.image_token_id
|
|
|
|
| 243 |
attn_implementation=attn_implementation,
|
| 244 |
).eval()
|
| 245 |
try:
|
| 246 |
+
self._processor = ColIdefics3Processor.from_pretrained(
|
| 247 |
+
self.model_name, **_processor_kwargs()
|
| 248 |
+
)
|
| 249 |
except TypeError:
|
| 250 |
self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
|
| 251 |
except Exception:
|
| 252 |
if self.processor_speed == "fast":
|
| 253 |
+
self._processor = ColIdefics3Processor.from_pretrained(
|
| 254 |
+
self.model_name, use_fast=False
|
| 255 |
+
)
|
| 256 |
else:
|
| 257 |
raise
|
| 258 |
self._image_token_id = self._processor.image_token_id
|
| 259 |
+
|
| 260 |
logger.info("✅ Model loaded successfully")
|
| 261 |
+
|
| 262 |
@property
|
| 263 |
def model(self):
|
| 264 |
self._load_model()
|
| 265 |
return self._model
|
| 266 |
+
|
| 267 |
@property
|
| 268 |
def processor(self):
|
| 269 |
self._load_model()
|
| 270 |
return self._processor
|
| 271 |
+
|
| 272 |
@property
|
| 273 |
def image_token_id(self):
|
| 274 |
self._load_model()
|
| 275 |
return self._image_token_id
|
| 276 |
+
|
| 277 |
def embed_query(
|
| 278 |
self,
|
| 279 |
query_text: str,
|
|
|
|
| 281 |
) -> torch.Tensor:
|
| 282 |
"""
|
| 283 |
Generate embedding for a text query.
|
| 284 |
+
|
| 285 |
By default, filters out special tokens (CLS, SEP, PAD) to keep only
|
| 286 |
meaningful text tokens for better MaxSim matching.
|
| 287 |
+
|
| 288 |
Args:
|
| 289 |
query_text: Natural language query string
|
| 290 |
filter_special_tokens: Override instance-level setting
|
| 291 |
+
|
| 292 |
Returns:
|
| 293 |
Query embedding tensor of shape [num_tokens, embedding_dim]
|
| 294 |
"""
|
| 295 |
should_filter = (
|
| 296 |
+
filter_special_tokens
|
| 297 |
+
if filter_special_tokens is not None
|
| 298 |
else self.filter_special_tokens
|
| 299 |
)
|
| 300 |
+
|
| 301 |
with torch.no_grad():
|
| 302 |
processed = self.processor.process_queries([query_text]).to(self.model.device)
|
| 303 |
embedding = self.model(**processed)
|
| 304 |
+
|
| 305 |
# Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
|
| 306 |
if embedding.dim() == 3:
|
| 307 |
embedding = embedding.squeeze(0)
|
| 308 |
+
|
| 309 |
if should_filter:
|
| 310 |
# Filter special tokens based on attention mask
|
| 311 |
attention_mask = processed.get("attention_mask")
|
|
|
|
| 313 |
# Keep only tokens with attention_mask = 1
|
| 314 |
valid_mask = attention_mask.squeeze(0).bool()
|
| 315 |
embedding = embedding[valid_mask]
|
| 316 |
+
|
| 317 |
# Additionally filter padding tokens if present
|
| 318 |
input_ids = processed.get("input_ids")
|
| 319 |
if input_ids is not None:
|
|
|
|
| 323 |
non_special_mask = input_ids >= 4
|
| 324 |
if non_special_mask.any():
|
| 325 |
embedding = embedding[non_special_mask]
|
| 326 |
+
|
| 327 |
logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
|
| 328 |
else:
|
| 329 |
logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")
|
| 330 |
+
|
| 331 |
return embedding
|
| 332 |
|
| 333 |
def embed_queries(
|
|
|
|
| 343 |
Returns a list of tensors, each of shape [num_tokens, embedding_dim].
|
| 344 |
"""
|
| 345 |
should_filter = (
|
| 346 |
+
filter_special_tokens
|
| 347 |
+
if filter_special_tokens is not None
|
| 348 |
+
else self.filter_special_tokens
|
| 349 |
)
|
| 350 |
batch_size = batch_size or self.batch_size
|
| 351 |
|
|
|
|
| 386 |
torch.mps.empty_cache()
|
| 387 |
|
| 388 |
return outputs
|
| 389 |
+
|
| 390 |
def embed_images(
|
| 391 |
self,
|
| 392 |
images: List[Image.Image],
|
|
|
|
| 396 |
) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
|
| 397 |
"""
|
| 398 |
Generate embeddings for a list of images.
|
| 399 |
+
|
| 400 |
Args:
|
| 401 |
images: List of PIL Images
|
| 402 |
batch_size: Override instance batch size
|
| 403 |
return_token_info: Also return token metadata (for saliency maps)
|
| 404 |
show_progress: Show progress bar
|
| 405 |
+
|
| 406 |
Returns:
|
| 407 |
If return_token_info=False:
|
| 408 |
List of embedding tensors [num_patches, dim]
|
| 409 |
If return_token_info=True:
|
| 410 |
Tuple of (embeddings, token_infos)
|
| 411 |
+
|
| 412 |
Token info contains:
|
| 413 |
- visual_token_indices: Indices of visual tokens in embedding
|
| 414 |
- num_visual_tokens: Count of visual tokens
|
|
|
|
| 416 |
- num_tiles: Total tiles (n_rows × n_cols + 1 global)
|
| 417 |
"""
|
| 418 |
batch_size = batch_size or self.batch_size
|
| 419 |
+
if (
|
| 420 |
+
self.device == "mps"
|
| 421 |
+
and "colpali" in (self.model_name or "").lower()
|
| 422 |
+
and int(batch_size) > 1
|
| 423 |
+
):
|
| 424 |
batch_size = 1
|
| 425 |
+
|
| 426 |
embeddings = []
|
| 427 |
token_infos = [] if return_token_info else None
|
| 428 |
+
|
| 429 |
iterator = range(0, len(images), batch_size)
|
| 430 |
if show_progress:
|
| 431 |
iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")
|
| 432 |
+
|
| 433 |
for i in iterator:
|
| 434 |
+
batch = images[i : i + batch_size]
|
| 435 |
+
|
| 436 |
with torch.no_grad():
|
| 437 |
processed = self.processor.process_images(batch).to(self.model.device)
|
| 438 |
+
|
| 439 |
# Extract token info before model forward
|
| 440 |
if return_token_info:
|
| 441 |
input_ids = processed["input_ids"]
|
| 442 |
batch_n_rows = processed.get("n_rows")
|
| 443 |
batch_n_cols = processed.get("n_cols")
|
| 444 |
+
|
| 445 |
for j in range(input_ids.shape[0]):
|
| 446 |
# Find visual token indices
|
| 447 |
+
image_token_mask = input_ids[j] == self.image_token_id
|
| 448 |
visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()
|
| 449 |
+
|
| 450 |
n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
|
| 451 |
n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None
|
| 452 |
+
|
| 453 |
+
token_infos.append(
|
| 454 |
+
{
|
| 455 |
+
"visual_token_indices": visual_indices,
|
| 456 |
+
"num_visual_tokens": len(visual_indices),
|
| 457 |
+
"n_rows": n_rows,
|
| 458 |
+
"n_cols": n_cols,
|
| 459 |
+
"num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None,
|
| 460 |
+
}
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
# Generate embeddings
|
| 464 |
batch_embeddings = self.model(**processed)
|
| 465 |
+
|
| 466 |
# Extract per-image embeddings
|
| 467 |
if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
|
| 468 |
for j in range(batch_embeddings.shape[0]):
|
| 469 |
embeddings.append(batch_embeddings[j].cpu())
|
| 470 |
else:
|
| 471 |
embeddings.extend([e.cpu() for e in batch_embeddings])
|
| 472 |
+
|
| 473 |
# Memory cleanup
|
| 474 |
del processed, batch_embeddings
|
| 475 |
gc.collect()
|
|
|
|
| 477 |
torch.cuda.empty_cache()
|
| 478 |
elif torch.backends.mps.is_available():
|
| 479 |
torch.mps.empty_cache()
|
| 480 |
+
|
| 481 |
if return_token_info:
|
| 482 |
return embeddings, token_infos
|
| 483 |
return embeddings
|
| 484 |
+
|
| 485 |
def extract_visual_embedding(
|
| 486 |
self,
|
| 487 |
full_embedding: torch.Tensor,
|
|
|
|
| 489 |
) -> np.ndarray:
|
| 490 |
"""
|
| 491 |
Extract only visual token embeddings from full embedding.
|
| 492 |
+
|
| 493 |
Filters out special tokens, keeping only visual patches for MaxSim.
|
| 494 |
+
|
| 495 |
Args:
|
| 496 |
full_embedding: Full embedding [all_tokens, dim]
|
| 497 |
token_info: Token info dict from embed_images
|
| 498 |
+
|
| 499 |
Returns:
|
| 500 |
Visual embedding array [num_visual_tokens, dim]
|
| 501 |
"""
|
| 502 |
visual_indices = token_info["visual_token_indices"]
|
| 503 |
+
|
| 504 |
if isinstance(full_embedding, torch.Tensor):
|
| 505 |
if full_embedding.dtype == torch.bfloat16:
|
| 506 |
visual_emb = full_embedding[visual_indices].cpu().float().numpy()
|
|
|
|
| 508 |
visual_emb = full_embedding[visual_indices].cpu().numpy()
|
| 509 |
else:
|
| 510 |
visual_emb = np.array(full_embedding)[visual_indices]
|
| 511 |
+
|
| 512 |
return visual_emb.astype(self.output_dtype)
|
| 513 |
|
| 514 |
def mean_pool_visual_embedding(
|
|
|
|
| 535 |
n_rows = (token_info or {}).get("n_rows")
|
| 536 |
n_cols = (token_info or {}).get("n_cols")
|
| 537 |
num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
|
| 538 |
+
return tile_level_mean_pooling(
|
| 539 |
+
visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
|
| 540 |
+
)
|
| 541 |
|
| 542 |
num_tokens = int(visual_np.shape[0])
|
| 543 |
grid = int(round(float(num_tokens) ** 0.5))
|
| 544 |
if grid * grid != num_tokens:
|
| 545 |
+
raise ValueError(
|
| 546 |
+
f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}"
|
| 547 |
+
)
|
| 548 |
if int(target_vectors) != int(grid):
|
| 549 |
raise ValueError(
|
| 550 |
f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
|
| 551 |
)
|
| 552 |
+
return colpali_row_mean_pooling(
|
| 553 |
+
visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype
|
| 554 |
+
)
|
| 555 |
|
| 556 |
def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
|
| 557 |
if mean_pool.size == 0:
|
|
|
|
| 566 |
target_vectors: int = 32,
|
| 567 |
mean_pool: Optional[np.ndarray] = None,
|
| 568 |
) -> np.ndarray:
|
| 569 |
+
from visual_rag.embedding.pooling import (
|
| 570 |
+
colpali_experimental_pooling_from_rows,
|
| 571 |
+
colsmol_experimental_pooling,
|
| 572 |
+
)
|
| 573 |
|
| 574 |
model_lower = (self.model_name or "").lower()
|
| 575 |
is_colsmol = "colsmol" in model_lower
|
|
|
|
| 583 |
visual_np = np.array(visual_embedding, dtype=np.float32)
|
| 584 |
|
| 585 |
if is_colsmol:
|
| 586 |
+
if (
|
| 587 |
+
mean_pool is not None
|
| 588 |
+
and getattr(mean_pool, "shape", None) is not None
|
| 589 |
+
and int(mean_pool.shape[0]) > 0
|
| 590 |
+
):
|
| 591 |
num_tiles = int(mean_pool.shape[0])
|
| 592 |
else:
|
| 593 |
num_tiles = (token_info or {}).get("num_tiles")
|
|
|
|
| 600 |
if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
|
| 601 |
num_tiles = int(num_tiles) + 1
|
| 602 |
num_tiles = int(num_tiles)
|
| 603 |
+
return colsmol_experimental_pooling(
|
| 604 |
+
visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
|
| 605 |
+
)
|
| 606 |
|
| 607 |
+
rows = (
|
| 608 |
+
mean_pool
|
| 609 |
+
if mean_pool is not None
|
| 610 |
+
else self.mean_pool_visual_embedding(
|
| 611 |
+
visual_np, token_info, target_vectors=target_vectors
|
| 612 |
+
)
|
| 613 |
+
)
|
| 614 |
if int(rows.shape[0]) != int(target_vectors):
|
| 615 |
raise ValueError(
|
| 616 |
f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
|
| 617 |
)
|
| 618 |
return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)
|
| 619 |
|
| 620 |
+
|
| 621 |
# Backward compatibility alias
|
| 622 |
ColPaliEmbedder = VisualEmbedder
|
visual_rag/indexing/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Indexing module - PDF processing, embedding storage, and CDN uploads.
|
| 3 |
+
|
| 4 |
+
Components:
|
| 5 |
+
- PDFProcessor: Convert PDFs to images and extract text
|
| 6 |
+
- QdrantIndexer: Upload embeddings to Qdrant vector database
|
| 7 |
+
- CloudinaryUploader: Upload images to Cloudinary CDN
|
| 8 |
+
- ProcessingPipeline: End-to-end PDF → Qdrant pipeline
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
# Lazy imports to avoid failures when optional dependencies aren't installed
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from visual_rag.indexing.pdf_processor import PDFProcessor
|
| 15 |
+
except ImportError:
|
| 16 |
+
PDFProcessor = None
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
| 20 |
+
except ImportError:
|
| 21 |
+
QdrantIndexer = None
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
|
| 25 |
+
except ImportError:
|
| 26 |
+
CloudinaryUploader = None
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from visual_rag.indexing.pipeline import ProcessingPipeline
|
| 30 |
+
except ImportError:
|
| 31 |
+
ProcessingPipeline = None
|
| 32 |
+
|
| 33 |
+
__all__ = [
|
| 34 |
+
"PDFProcessor",
|
| 35 |
+
"QdrantIndexer",
|
| 36 |
+
"CloudinaryUploader",
|
| 37 |
+
"ProcessingPipeline",
|
| 38 |
+
]
|
visual_rag/indexing/cloudinary_uploader.py
CHANGED
|
@@ -15,14 +15,15 @@ Environment Variables:
|
|
| 15 |
"""
|
| 16 |
|
| 17 |
import io
|
| 18 |
-
import os
|
| 19 |
-
import time
|
| 20 |
-
import signal
|
| 21 |
import logging
|
|
|
|
| 22 |
import platform
|
|
|
|
| 23 |
import threading
|
|
|
|
|
|
|
|
|
|
| 24 |
from typing import Optional
|
| 25 |
-
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
| 26 |
|
| 27 |
from PIL import Image
|
| 28 |
|
|
@@ -34,9 +35,9 @@ THREAD_SAFE_MODE = os.getenv("VISUAL_RAG_THREAD_SAFE", "").lower() in ("1", "tru
|
|
| 34 |
class CloudinaryUploader:
|
| 35 |
"""
|
| 36 |
Upload images to Cloudinary CDN.
|
| 37 |
-
|
| 38 |
Works independently - just needs PIL images.
|
| 39 |
-
|
| 40 |
Args:
|
| 41 |
cloud_name: Cloudinary cloud name
|
| 42 |
api_key: Cloudinary API key
|
|
@@ -44,7 +45,7 @@ class CloudinaryUploader:
|
|
| 44 |
folder: Base folder for uploads
|
| 45 |
max_retries: Number of retry attempts
|
| 46 |
timeout_seconds: Timeout per upload
|
| 47 |
-
|
| 48 |
Example:
|
| 49 |
>>> uploader = CloudinaryUploader(
|
| 50 |
... cloud_name="my-cloud",
|
|
@@ -52,11 +53,11 @@ class CloudinaryUploader:
|
|
| 52 |
... api_secret="yyy",
|
| 53 |
... folder="my-project",
|
| 54 |
... )
|
| 55 |
-
>>>
|
| 56 |
>>> url = uploader.upload(image, "doc_page_1")
|
| 57 |
>>> print(url) # https://res.cloudinary.com/.../doc_page_1.jpg
|
| 58 |
"""
|
| 59 |
-
|
| 60 |
def __init__(
|
| 61 |
self,
|
| 62 |
cloud_name: Optional[str] = None,
|
|
@@ -71,19 +72,19 @@ class CloudinaryUploader:
|
|
| 71 |
self.cloud_name = cloud_name or os.getenv("CLOUDINARY_CLOUD_NAME")
|
| 72 |
self.api_key = api_key or os.getenv("CLOUDINARY_API_KEY")
|
| 73 |
self.api_secret = api_secret or os.getenv("CLOUDINARY_API_SECRET")
|
| 74 |
-
|
| 75 |
if not all([self.cloud_name, self.api_key, self.api_secret]):
|
| 76 |
raise ValueError(
|
| 77 |
"Cloudinary credentials required. Set CLOUDINARY_CLOUD_NAME, "
|
| 78 |
"CLOUDINARY_API_KEY, CLOUDINARY_API_SECRET environment variables "
|
| 79 |
"or pass them as arguments."
|
| 80 |
)
|
| 81 |
-
|
| 82 |
self.folder = folder
|
| 83 |
self.max_retries = max_retries
|
| 84 |
self.timeout_seconds = timeout_seconds
|
| 85 |
self.jpeg_quality = jpeg_quality
|
| 86 |
-
|
| 87 |
# Check dependency
|
| 88 |
try:
|
| 89 |
import cloudinary # noqa
|
|
@@ -92,10 +93,10 @@ class CloudinaryUploader:
|
|
| 92 |
"Cloudinary not installed. "
|
| 93 |
"Install with: pip install visual-rag-toolkit[cloudinary]"
|
| 94 |
)
|
| 95 |
-
|
| 96 |
-
logger.info(
|
| 97 |
logger.info(f" Folder: {folder}")
|
| 98 |
-
|
| 99 |
def upload(
|
| 100 |
self,
|
| 101 |
image: Image.Image,
|
|
@@ -104,34 +105,34 @@ class CloudinaryUploader:
|
|
| 104 |
) -> Optional[str]:
|
| 105 |
"""
|
| 106 |
Upload a single image to Cloudinary.
|
| 107 |
-
|
| 108 |
Args:
|
| 109 |
image: PIL Image to upload
|
| 110 |
public_id: Public ID (filename without extension)
|
| 111 |
subfolder: Optional subfolder within base folder
|
| 112 |
-
|
| 113 |
Returns:
|
| 114 |
Secure URL of uploaded image, or None if failed
|
| 115 |
"""
|
| 116 |
import cloudinary
|
| 117 |
import cloudinary.uploader
|
| 118 |
-
|
| 119 |
# Prepare buffer
|
| 120 |
buffer = io.BytesIO()
|
| 121 |
image.save(buffer, format="JPEG", quality=self.jpeg_quality, optimize=True)
|
| 122 |
-
|
| 123 |
# Configure Cloudinary
|
| 124 |
cloudinary.config(
|
| 125 |
cloud_name=self.cloud_name,
|
| 126 |
api_key=self.api_key,
|
| 127 |
api_secret=self.api_secret,
|
| 128 |
)
|
| 129 |
-
|
| 130 |
# Build folder path
|
| 131 |
folder_path = self.folder
|
| 132 |
if subfolder:
|
| 133 |
folder_path = f"{self.folder}/{subfolder}"
|
| 134 |
-
|
| 135 |
def do_upload():
|
| 136 |
buffer.seek(0)
|
| 137 |
result = cloudinary.uploader.upload(
|
|
@@ -143,14 +144,14 @@ class CloudinaryUploader:
|
|
| 143 |
timeout=self.timeout_seconds,
|
| 144 |
)
|
| 145 |
return result["secure_url"]
|
| 146 |
-
|
| 147 |
# Use thread-safe mode for Streamlit/Flask/threaded contexts
|
| 148 |
# Set VISUAL_RAG_THREAD_SAFE=1 to enable
|
| 149 |
if THREAD_SAFE_MODE or threading.current_thread() is not threading.main_thread():
|
| 150 |
return self._upload_with_thread_timeout(do_upload, public_id)
|
| 151 |
else:
|
| 152 |
return self._upload_with_signal_timeout(do_upload, public_id)
|
| 153 |
-
|
| 154 |
def _upload_with_thread_timeout(self, do_upload, public_id: str) -> Optional[str]:
|
| 155 |
"""Thread-safe upload with ThreadPoolExecutor timeout."""
|
| 156 |
for attempt in range(self.max_retries):
|
|
@@ -158,64 +159,60 @@ class CloudinaryUploader:
|
|
| 158 |
with ThreadPoolExecutor(max_workers=1) as executor:
|
| 159 |
future = executor.submit(do_upload)
|
| 160 |
return future.result(timeout=self.timeout_seconds)
|
| 161 |
-
|
| 162 |
except FuturesTimeoutError:
|
| 163 |
logger.warning(
|
| 164 |
f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
|
| 165 |
)
|
| 166 |
if attempt < self.max_retries - 1:
|
| 167 |
-
time.sleep(2
|
| 168 |
-
|
| 169 |
except Exception as e:
|
| 170 |
-
logger.warning(
|
| 171 |
-
f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}"
|
| 172 |
-
)
|
| 173 |
if attempt < self.max_retries - 1:
|
| 174 |
-
time.sleep(2
|
| 175 |
-
|
| 176 |
logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
|
| 177 |
return None
|
| 178 |
-
|
| 179 |
def _upload_with_signal_timeout(self, do_upload, public_id: str) -> Optional[str]:
|
| 180 |
"""Signal-based upload timeout (main thread only, Unix/macOS)."""
|
| 181 |
use_timeout = platform.system() != "Windows"
|
| 182 |
-
|
| 183 |
class SignalTimeoutError(Exception):
|
| 184 |
pass
|
| 185 |
-
|
| 186 |
def timeout_handler(signum, frame):
|
| 187 |
raise SignalTimeoutError(f"Upload timed out after {self.timeout_seconds}s")
|
| 188 |
-
|
| 189 |
for attempt in range(self.max_retries):
|
| 190 |
try:
|
| 191 |
if use_timeout:
|
| 192 |
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
| 193 |
signal.alarm(self.timeout_seconds)
|
| 194 |
-
|
| 195 |
try:
|
| 196 |
return do_upload()
|
| 197 |
finally:
|
| 198 |
if use_timeout:
|
| 199 |
signal.alarm(0)
|
| 200 |
signal.signal(signal.SIGALRM, old_handler)
|
| 201 |
-
|
| 202 |
except SignalTimeoutError:
|
| 203 |
logger.warning(
|
| 204 |
f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
|
| 205 |
)
|
| 206 |
if attempt < self.max_retries - 1:
|
| 207 |
-
time.sleep(2
|
| 208 |
-
|
| 209 |
except Exception as e:
|
| 210 |
-
logger.warning(
|
| 211 |
-
f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}"
|
| 212 |
-
)
|
| 213 |
if attempt < self.max_retries - 1:
|
| 214 |
-
time.sleep(2
|
| 215 |
-
|
| 216 |
logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
|
| 217 |
return None
|
| 218 |
-
|
| 219 |
def upload_original_and_resized(
|
| 220 |
self,
|
| 221 |
original_image: Image.Image,
|
|
@@ -224,12 +221,12 @@ class CloudinaryUploader:
|
|
| 224 |
) -> tuple:
|
| 225 |
"""
|
| 226 |
Upload both original and resized versions.
|
| 227 |
-
|
| 228 |
Args:
|
| 229 |
original_image: Original PDF page image
|
| 230 |
resized_image: Resized image for ColPali
|
| 231 |
base_public_id: Base public ID (e.g., "doc_page_1")
|
| 232 |
-
|
| 233 |
Returns:
|
| 234 |
Tuple of (original_url, resized_url) - either can be None on failure
|
| 235 |
"""
|
|
@@ -238,13 +235,13 @@ class CloudinaryUploader:
|
|
| 238 |
base_public_id,
|
| 239 |
subfolder="original",
|
| 240 |
)
|
| 241 |
-
|
| 242 |
resized_url = self.upload(
|
| 243 |
resized_image,
|
| 244 |
base_public_id,
|
| 245 |
subfolder="resized",
|
| 246 |
)
|
| 247 |
-
|
| 248 |
return original_url, resized_url
|
| 249 |
|
| 250 |
def upload_original_cropped_and_resized(
|
|
@@ -275,5 +272,3 @@ class CloudinaryUploader:
|
|
| 275 |
)
|
| 276 |
|
| 277 |
return original_url, cropped_url, resized_url
|
| 278 |
-
|
| 279 |
-
|
|
|
|
| 15 |
"""
|
| 16 |
|
| 17 |
import io
|
|
|
|
|
|
|
|
|
|
| 18 |
import logging
|
| 19 |
+
import os
|
| 20 |
import platform
|
| 21 |
+
import signal
|
| 22 |
import threading
|
| 23 |
+
import time
|
| 24 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 25 |
+
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
| 26 |
from typing import Optional
|
|
|
|
| 27 |
|
| 28 |
from PIL import Image
|
| 29 |
|
|
|
|
| 35 |
class CloudinaryUploader:
|
| 36 |
"""
|
| 37 |
Upload images to Cloudinary CDN.
|
| 38 |
+
|
| 39 |
Works independently - just needs PIL images.
|
| 40 |
+
|
| 41 |
Args:
|
| 42 |
cloud_name: Cloudinary cloud name
|
| 43 |
api_key: Cloudinary API key
|
|
|
|
| 45 |
folder: Base folder for uploads
|
| 46 |
max_retries: Number of retry attempts
|
| 47 |
timeout_seconds: Timeout per upload
|
| 48 |
+
|
| 49 |
Example:
|
| 50 |
>>> uploader = CloudinaryUploader(
|
| 51 |
... cloud_name="my-cloud",
|
|
|
|
| 53 |
... api_secret="yyy",
|
| 54 |
... folder="my-project",
|
| 55 |
... )
|
| 56 |
+
>>>
|
| 57 |
>>> url = uploader.upload(image, "doc_page_1")
|
| 58 |
>>> print(url) # https://res.cloudinary.com/.../doc_page_1.jpg
|
| 59 |
"""
|
| 60 |
+
|
| 61 |
def __init__(
|
| 62 |
self,
|
| 63 |
cloud_name: Optional[str] = None,
|
|
|
|
| 72 |
self.cloud_name = cloud_name or os.getenv("CLOUDINARY_CLOUD_NAME")
|
| 73 |
self.api_key = api_key or os.getenv("CLOUDINARY_API_KEY")
|
| 74 |
self.api_secret = api_secret or os.getenv("CLOUDINARY_API_SECRET")
|
| 75 |
+
|
| 76 |
if not all([self.cloud_name, self.api_key, self.api_secret]):
|
| 77 |
raise ValueError(
|
| 78 |
"Cloudinary credentials required. Set CLOUDINARY_CLOUD_NAME, "
|
| 79 |
"CLOUDINARY_API_KEY, CLOUDINARY_API_SECRET environment variables "
|
| 80 |
"or pass them as arguments."
|
| 81 |
)
|
| 82 |
+
|
| 83 |
self.folder = folder
|
| 84 |
self.max_retries = max_retries
|
| 85 |
self.timeout_seconds = timeout_seconds
|
| 86 |
self.jpeg_quality = jpeg_quality
|
| 87 |
+
|
| 88 |
# Check dependency
|
| 89 |
try:
|
| 90 |
import cloudinary # noqa
|
|
|
|
| 93 |
"Cloudinary not installed. "
|
| 94 |
"Install with: pip install visual-rag-toolkit[cloudinary]"
|
| 95 |
)
|
| 96 |
+
|
| 97 |
+
logger.info("☁️ Cloudinary uploader initialized")
|
| 98 |
logger.info(f" Folder: {folder}")
|
| 99 |
+
|
| 100 |
def upload(
|
| 101 |
self,
|
| 102 |
image: Image.Image,
|
|
|
|
| 105 |
) -> Optional[str]:
|
| 106 |
"""
|
| 107 |
Upload a single image to Cloudinary.
|
| 108 |
+
|
| 109 |
Args:
|
| 110 |
image: PIL Image to upload
|
| 111 |
public_id: Public ID (filename without extension)
|
| 112 |
subfolder: Optional subfolder within base folder
|
| 113 |
+
|
| 114 |
Returns:
|
| 115 |
Secure URL of uploaded image, or None if failed
|
| 116 |
"""
|
| 117 |
import cloudinary
|
| 118 |
import cloudinary.uploader
|
| 119 |
+
|
| 120 |
# Prepare buffer
|
| 121 |
buffer = io.BytesIO()
|
| 122 |
image.save(buffer, format="JPEG", quality=self.jpeg_quality, optimize=True)
|
| 123 |
+
|
| 124 |
# Configure Cloudinary
|
| 125 |
cloudinary.config(
|
| 126 |
cloud_name=self.cloud_name,
|
| 127 |
api_key=self.api_key,
|
| 128 |
api_secret=self.api_secret,
|
| 129 |
)
|
| 130 |
+
|
| 131 |
# Build folder path
|
| 132 |
folder_path = self.folder
|
| 133 |
if subfolder:
|
| 134 |
folder_path = f"{self.folder}/{subfolder}"
|
| 135 |
+
|
| 136 |
def do_upload():
|
| 137 |
buffer.seek(0)
|
| 138 |
result = cloudinary.uploader.upload(
|
|
|
|
| 144 |
timeout=self.timeout_seconds,
|
| 145 |
)
|
| 146 |
return result["secure_url"]
|
| 147 |
+
|
| 148 |
# Use thread-safe mode for Streamlit/Flask/threaded contexts
|
| 149 |
# Set VISUAL_RAG_THREAD_SAFE=1 to enable
|
| 150 |
if THREAD_SAFE_MODE or threading.current_thread() is not threading.main_thread():
|
| 151 |
return self._upload_with_thread_timeout(do_upload, public_id)
|
| 152 |
else:
|
| 153 |
return self._upload_with_signal_timeout(do_upload, public_id)
|
| 154 |
+
|
| 155 |
def _upload_with_thread_timeout(self, do_upload, public_id: str) -> Optional[str]:
|
| 156 |
"""Thread-safe upload with ThreadPoolExecutor timeout."""
|
| 157 |
for attempt in range(self.max_retries):
|
|
|
|
| 159 |
with ThreadPoolExecutor(max_workers=1) as executor:
|
| 160 |
future = executor.submit(do_upload)
|
| 161 |
return future.result(timeout=self.timeout_seconds)
|
| 162 |
+
|
| 163 |
except FuturesTimeoutError:
|
| 164 |
logger.warning(
|
| 165 |
f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
|
| 166 |
)
|
| 167 |
if attempt < self.max_retries - 1:
|
| 168 |
+
time.sleep(2**attempt)
|
| 169 |
+
|
| 170 |
except Exception as e:
|
| 171 |
+
logger.warning(f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}")
|
|
|
|
|
|
|
| 172 |
if attempt < self.max_retries - 1:
|
| 173 |
+
time.sleep(2**attempt)
|
| 174 |
+
|
| 175 |
logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
|
| 176 |
return None
|
| 177 |
+
|
| 178 |
def _upload_with_signal_timeout(self, do_upload, public_id: str) -> Optional[str]:
|
| 179 |
"""Signal-based upload timeout (main thread only, Unix/macOS)."""
|
| 180 |
use_timeout = platform.system() != "Windows"
|
| 181 |
+
|
| 182 |
class SignalTimeoutError(Exception):
|
| 183 |
pass
|
| 184 |
+
|
| 185 |
def timeout_handler(signum, frame):
|
| 186 |
raise SignalTimeoutError(f"Upload timed out after {self.timeout_seconds}s")
|
| 187 |
+
|
| 188 |
for attempt in range(self.max_retries):
|
| 189 |
try:
|
| 190 |
if use_timeout:
|
| 191 |
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
| 192 |
signal.alarm(self.timeout_seconds)
|
| 193 |
+
|
| 194 |
try:
|
| 195 |
return do_upload()
|
| 196 |
finally:
|
| 197 |
if use_timeout:
|
| 198 |
signal.alarm(0)
|
| 199 |
signal.signal(signal.SIGALRM, old_handler)
|
| 200 |
+
|
| 201 |
except SignalTimeoutError:
|
| 202 |
logger.warning(
|
| 203 |
f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
|
| 204 |
)
|
| 205 |
if attempt < self.max_retries - 1:
|
| 206 |
+
time.sleep(2**attempt)
|
| 207 |
+
|
| 208 |
except Exception as e:
|
| 209 |
+
logger.warning(f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}")
|
|
|
|
|
|
|
| 210 |
if attempt < self.max_retries - 1:
|
| 211 |
+
time.sleep(2**attempt)
|
| 212 |
+
|
| 213 |
logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
|
| 214 |
return None
|
| 215 |
+
|
| 216 |
def upload_original_and_resized(
|
| 217 |
self,
|
| 218 |
original_image: Image.Image,
|
|
|
|
| 221 |
) -> tuple:
|
| 222 |
"""
|
| 223 |
Upload both original and resized versions.
|
| 224 |
+
|
| 225 |
Args:
|
| 226 |
original_image: Original PDF page image
|
| 227 |
resized_image: Resized image for ColPali
|
| 228 |
base_public_id: Base public ID (e.g., "doc_page_1")
|
| 229 |
+
|
| 230 |
Returns:
|
| 231 |
Tuple of (original_url, resized_url) - either can be None on failure
|
| 232 |
"""
|
|
|
|
| 235 |
base_public_id,
|
| 236 |
subfolder="original",
|
| 237 |
)
|
| 238 |
+
|
| 239 |
resized_url = self.upload(
|
| 240 |
resized_image,
|
| 241 |
base_public_id,
|
| 242 |
subfolder="resized",
|
| 243 |
)
|
| 244 |
+
|
| 245 |
return original_url, resized_url
|
| 246 |
|
| 247 |
def upload_original_cropped_and_resized(
|
|
|
|
| 272 |
)
|
| 273 |
|
| 274 |
return original_url, cropped_url, resized_url
|
|
|
|
|
|
visual_rag/indexing/pdf_processor.py
CHANGED
|
@@ -11,10 +11,10 @@ Features:
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
import gc
|
| 14 |
-
import re
|
| 15 |
import logging
|
|
|
|
| 16 |
from pathlib import Path
|
| 17 |
-
from typing import
|
| 18 |
|
| 19 |
from PIL import Image
|
| 20 |
|
|
@@ -24,26 +24,26 @@ logger = logging.getLogger(__name__)
|
|
| 24 |
class PDFProcessor:
|
| 25 |
"""
|
| 26 |
Process PDFs into images and text for visual retrieval.
|
| 27 |
-
|
| 28 |
Works independently - no embedding or storage dependencies.
|
| 29 |
-
|
| 30 |
Args:
|
| 31 |
dpi: DPI for image conversion (higher = better quality)
|
| 32 |
output_format: Image format (RGB, L, etc.)
|
| 33 |
page_batch_size: Pages per batch for memory efficiency
|
| 34 |
-
|
| 35 |
Example:
|
| 36 |
>>> processor = PDFProcessor(dpi=140)
|
| 37 |
-
>>>
|
| 38 |
>>> # Convert single PDF
|
| 39 |
>>> images, texts = processor.process_pdf(Path("report.pdf"))
|
| 40 |
-
>>>
|
| 41 |
>>> # Stream large PDFs
|
| 42 |
>>> for images, texts in processor.stream_pdf(Path("large.pdf"), batch_size=10):
|
| 43 |
... # Process each batch
|
| 44 |
... pass
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
def __init__(
|
| 48 |
self,
|
| 49 |
dpi: int = 140,
|
|
@@ -53,17 +53,24 @@ class PDFProcessor:
|
|
| 53 |
self.dpi = dpi
|
| 54 |
self.output_format = output_format
|
| 55 |
self.page_batch_size = page_batch_size
|
| 56 |
-
|
| 57 |
-
#
|
|
|
|
|
|
|
|
|
|
| 58 |
try:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
except
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
raise ImportError(
|
| 63 |
-
"PDF processing requires pdf2image and pypdf.
|
| 64 |
-
|
| 65 |
)
|
| 66 |
-
|
| 67 |
def process_pdf(
|
| 68 |
self,
|
| 69 |
pdf_path: Path,
|
|
@@ -71,38 +78,39 @@ class PDFProcessor:
|
|
| 71 |
) -> Tuple[List[Image.Image], List[str]]:
|
| 72 |
"""
|
| 73 |
Convert PDF to images and extract text.
|
| 74 |
-
|
| 75 |
Args:
|
| 76 |
pdf_path: Path to PDF file
|
| 77 |
dpi: Override default DPI
|
| 78 |
-
|
| 79 |
Returns:
|
| 80 |
Tuple of (list of images, list of page texts)
|
| 81 |
"""
|
|
|
|
| 82 |
from pdf2image import convert_from_path
|
| 83 |
from pypdf import PdfReader
|
| 84 |
-
|
| 85 |
dpi = dpi or self.dpi
|
| 86 |
pdf_path = Path(pdf_path)
|
| 87 |
-
|
| 88 |
logger.info(f"📄 Processing PDF: {pdf_path.name}")
|
| 89 |
-
|
| 90 |
# Extract text
|
| 91 |
reader = PdfReader(str(pdf_path))
|
| 92 |
total_pages = len(reader.pages)
|
| 93 |
-
|
| 94 |
page_texts = []
|
| 95 |
for page in reader.pages:
|
| 96 |
text = page.extract_text() or ""
|
| 97 |
# Handle surrogate characters
|
| 98 |
text = self._sanitize_text(text)
|
| 99 |
page_texts.append(text)
|
| 100 |
-
|
| 101 |
# Convert to images in batches
|
| 102 |
all_images = []
|
| 103 |
for start_page in range(1, total_pages + 1, self.page_batch_size):
|
| 104 |
end_page = min(start_page + self.page_batch_size - 1, total_pages)
|
| 105 |
-
|
| 106 |
batch_images = convert_from_path(
|
| 107 |
str(pdf_path),
|
| 108 |
dpi=dpi,
|
|
@@ -110,19 +118,19 @@ class PDFProcessor:
|
|
| 110 |
first_page=start_page,
|
| 111 |
last_page=end_page,
|
| 112 |
)
|
| 113 |
-
|
| 114 |
all_images.extend(batch_images)
|
| 115 |
-
|
| 116 |
del batch_images
|
| 117 |
gc.collect()
|
| 118 |
-
|
| 119 |
-
assert len(all_images) == len(
|
| 120 |
-
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
logger.info(f"✅ Processed {len(all_images)} pages")
|
| 124 |
return all_images, page_texts
|
| 125 |
-
|
| 126 |
def stream_pdf(
|
| 127 |
self,
|
| 128 |
pdf_path: Path,
|
|
@@ -131,39 +139,40 @@ class PDFProcessor:
|
|
| 131 |
) -> Generator[Tuple[List[Image.Image], List[str], int], None, None]:
|
| 132 |
"""
|
| 133 |
Stream PDF processing for large files.
|
| 134 |
-
|
| 135 |
Yields batches of (images, texts, start_page) without loading
|
| 136 |
entire PDF into memory.
|
| 137 |
-
|
| 138 |
Args:
|
| 139 |
pdf_path: Path to PDF file
|
| 140 |
batch_size: Pages per batch
|
| 141 |
dpi: Override default DPI
|
| 142 |
-
|
| 143 |
Yields:
|
| 144 |
Tuple of (batch_images, batch_texts, start_page_number)
|
| 145 |
"""
|
|
|
|
| 146 |
from pdf2image import convert_from_path
|
| 147 |
from pypdf import PdfReader
|
| 148 |
-
|
| 149 |
dpi = dpi or self.dpi
|
| 150 |
pdf_path = Path(pdf_path)
|
| 151 |
-
|
| 152 |
reader = PdfReader(str(pdf_path))
|
| 153 |
total_pages = len(reader.pages)
|
| 154 |
-
|
| 155 |
logger.info(f"📄 Streaming PDF: {pdf_path.name} ({total_pages} pages)")
|
| 156 |
-
|
| 157 |
for start_idx in range(0, total_pages, batch_size):
|
| 158 |
end_idx = min(start_idx + batch_size, total_pages)
|
| 159 |
-
|
| 160 |
# Extract text for batch
|
| 161 |
batch_texts = []
|
| 162 |
for page_idx in range(start_idx, end_idx):
|
| 163 |
text = reader.pages[page_idx].extract_text() or ""
|
| 164 |
text = self._sanitize_text(text)
|
| 165 |
batch_texts.append(text)
|
| 166 |
-
|
| 167 |
# Convert images for batch
|
| 168 |
batch_images = convert_from_path(
|
| 169 |
str(pdf_path),
|
|
@@ -172,18 +181,20 @@ class PDFProcessor:
|
|
| 172 |
first_page=start_idx + 1, # 1-indexed
|
| 173 |
last_page=end_idx,
|
| 174 |
)
|
| 175 |
-
|
| 176 |
yield batch_images, batch_texts, start_idx + 1
|
| 177 |
-
|
| 178 |
del batch_images
|
| 179 |
gc.collect()
|
| 180 |
-
|
| 181 |
def get_page_count(self, pdf_path: Path) -> int:
|
| 182 |
"""Get number of pages in PDF without loading images."""
|
|
|
|
| 183 |
from pypdf import PdfReader
|
|
|
|
| 184 |
reader = PdfReader(str(pdf_path))
|
| 185 |
return len(reader.pages)
|
| 186 |
-
|
| 187 |
def resize_for_colpali(
|
| 188 |
self,
|
| 189 |
image: Image.Image,
|
|
@@ -192,19 +203,23 @@ class PDFProcessor:
|
|
| 192 |
) -> Tuple[Image.Image, int, int]:
|
| 193 |
"""
|
| 194 |
Resize image following ColPali/Idefics3 processor logic.
|
| 195 |
-
|
| 196 |
Resizes to fit within tile grid without black padding.
|
| 197 |
-
|
| 198 |
Args:
|
| 199 |
image: PIL Image
|
| 200 |
max_edge: Maximum edge length
|
| 201 |
tile_size: Size of each tile
|
| 202 |
-
|
| 203 |
Returns:
|
| 204 |
Tuple of (resized_image, tile_rows, tile_cols)
|
| 205 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
w, h = image.size
|
| 207 |
-
|
| 208 |
# Step 1: Resize so longest edge = max_edge
|
| 209 |
if w > h:
|
| 210 |
new_w = max_edge
|
|
@@ -212,25 +227,25 @@ class PDFProcessor:
|
|
| 212 |
else:
|
| 213 |
new_h = max_edge
|
| 214 |
new_w = int(w * (max_edge / h))
|
| 215 |
-
|
| 216 |
# Step 2: Calculate tile grid
|
| 217 |
tile_cols = (new_w + tile_size - 1) // tile_size
|
| 218 |
tile_rows = (new_h + tile_size - 1) // tile_size
|
| 219 |
-
|
| 220 |
# Step 3: Calculate exact dimensions for tiles
|
| 221 |
final_w = tile_cols * tile_size
|
| 222 |
final_h = tile_rows * tile_size
|
| 223 |
-
|
| 224 |
# Step 4: Scale to fit within tile grid
|
| 225 |
scale_w = final_w / w
|
| 226 |
scale_h = final_h / h
|
| 227 |
scale = min(scale_w, scale_h)
|
| 228 |
-
|
| 229 |
scaled_w = int(w * scale)
|
| 230 |
scaled_h = int(h * scale)
|
| 231 |
-
|
| 232 |
resized = image.resize((scaled_w, scaled_h), Image.LANCZOS)
|
| 233 |
-
|
| 234 |
# Center on white canvas if needed
|
| 235 |
if scaled_w != final_w or scaled_h != final_h:
|
| 236 |
canvas = Image.new("RGB", (final_w, final_h), (255, 255, 255))
|
|
@@ -238,19 +253,17 @@ class PDFProcessor:
|
|
| 238 |
offset_y = (final_h - scaled_h) // 2
|
| 239 |
canvas.paste(resized, (offset_x, offset_y))
|
| 240 |
resized = canvas
|
| 241 |
-
|
| 242 |
return resized, tile_rows, tile_cols
|
| 243 |
-
|
| 244 |
def _sanitize_text(self, text: str) -> str:
|
| 245 |
"""Remove invalid Unicode characters (surrogates) from text."""
|
| 246 |
if not text:
|
| 247 |
return ""
|
| 248 |
-
|
| 249 |
# Remove surrogate characters (U+D800-U+DFFF)
|
| 250 |
-
return text.encode("utf-8", errors="surrogatepass").decode(
|
| 251 |
-
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
def extract_metadata_from_filename(
|
| 255 |
self,
|
| 256 |
filename: str,
|
|
@@ -258,47 +271,45 @@ class PDFProcessor:
|
|
| 258 |
) -> Dict[str, Any]:
|
| 259 |
"""
|
| 260 |
Extract metadata from PDF filename.
|
| 261 |
-
|
| 262 |
Uses mapping if provided, otherwise falls back to pattern matching.
|
| 263 |
-
|
| 264 |
Args:
|
| 265 |
filename: PDF filename (with or without .pdf extension)
|
| 266 |
mapping: Optional mapping dict {filename: metadata}
|
| 267 |
-
|
| 268 |
Returns:
|
| 269 |
Metadata dict with year, source, district, etc.
|
| 270 |
"""
|
| 271 |
# Remove extension
|
| 272 |
stem = Path(filename).stem
|
| 273 |
stem_lower = stem.lower().strip()
|
| 274 |
-
|
| 275 |
# Try mapping first
|
| 276 |
if mapping:
|
| 277 |
if stem_lower in mapping:
|
| 278 |
return mapping[stem_lower].copy()
|
| 279 |
-
|
| 280 |
# Try without .pdf
|
| 281 |
stem_no_ext = stem_lower.replace(".pdf", "")
|
| 282 |
if stem_no_ext in mapping:
|
| 283 |
return mapping[stem_no_ext].copy()
|
| 284 |
-
|
| 285 |
# Fallback: pattern matching
|
| 286 |
metadata = {"filename": filename}
|
| 287 |
-
|
| 288 |
# Extract year
|
| 289 |
year_match = re.search(r"(20\d{2})", stem)
|
| 290 |
if year_match:
|
| 291 |
metadata["year"] = int(year_match.group(1))
|
| 292 |
-
|
| 293 |
# Detect source type
|
| 294 |
if "consolidated" in stem_lower or ("annual" in stem_lower and "oag" in stem_lower):
|
| 295 |
metadata["source"] = "Consolidated"
|
| 296 |
elif "dlg" in stem_lower or "district local government" in stem_lower:
|
| 297 |
metadata["source"] = "Local Government"
|
| 298 |
# Try to extract district name
|
| 299 |
-
district_match = re.search(
|
| 300 |
-
r"([a-z]+)\s+(?:dlg|district local government)", stem_lower
|
| 301 |
-
)
|
| 302 |
if district_match:
|
| 303 |
metadata["district"] = district_match.group(1).title()
|
| 304 |
elif "hospital" in stem_lower or "referral" in stem_lower:
|
|
@@ -309,7 +320,5 @@ class PDFProcessor:
|
|
| 309 |
metadata["source"] = "Project"
|
| 310 |
else:
|
| 311 |
metadata["source"] = "Unknown"
|
| 312 |
-
|
| 313 |
-
return metadata
|
| 314 |
-
|
| 315 |
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
import gc
|
|
|
|
| 14 |
import logging
|
| 15 |
+
import re
|
| 16 |
from pathlib import Path
|
| 17 |
+
from typing import Any, Dict, Generator, List, Optional, Tuple
|
| 18 |
|
| 19 |
from PIL import Image
|
| 20 |
|
|
|
|
| 24 |
class PDFProcessor:
|
| 25 |
"""
|
| 26 |
Process PDFs into images and text for visual retrieval.
|
| 27 |
+
|
| 28 |
Works independently - no embedding or storage dependencies.
|
| 29 |
+
|
| 30 |
Args:
|
| 31 |
dpi: DPI for image conversion (higher = better quality)
|
| 32 |
output_format: Image format (RGB, L, etc.)
|
| 33 |
page_batch_size: Pages per batch for memory efficiency
|
| 34 |
+
|
| 35 |
Example:
|
| 36 |
>>> processor = PDFProcessor(dpi=140)
|
| 37 |
+
>>>
|
| 38 |
>>> # Convert single PDF
|
| 39 |
>>> images, texts = processor.process_pdf(Path("report.pdf"))
|
| 40 |
+
>>>
|
| 41 |
>>> # Stream large PDFs
|
| 42 |
>>> for images, texts in processor.stream_pdf(Path("large.pdf"), batch_size=10):
|
| 43 |
... # Process each batch
|
| 44 |
... pass
|
| 45 |
"""
|
| 46 |
+
|
| 47 |
def __init__(
|
| 48 |
self,
|
| 49 |
dpi: int = 140,
|
|
|
|
| 53 |
self.dpi = dpi
|
| 54 |
self.output_format = output_format
|
| 55 |
self.page_batch_size = page_batch_size
|
| 56 |
+
|
| 57 |
+
# PDF deps are optional: we only require them when calling PDF-specific methods.
|
| 58 |
+
# This keeps the class usable for helper utilities like `resize_for_colpali()`
|
| 59 |
+
# even in minimal installs.
|
| 60 |
+
self._pdf_deps_available = True
|
| 61 |
try:
|
| 62 |
+
import pdf2image # noqa: F401
|
| 63 |
+
import pypdf # noqa: F401
|
| 64 |
+
except Exception:
|
| 65 |
+
self._pdf_deps_available = False
|
| 66 |
+
|
| 67 |
+
def _require_pdf_deps(self) -> None:
|
| 68 |
+
if not self._pdf_deps_available:
|
| 69 |
raise ImportError(
|
| 70 |
+
"PDF processing requires `pdf2image` and `pypdf`.\n"
|
| 71 |
+
'Install with: pip install "visual-rag-toolkit[pdf]"'
|
| 72 |
)
|
| 73 |
+
|
| 74 |
def process_pdf(
|
| 75 |
self,
|
| 76 |
pdf_path: Path,
|
|
|
|
| 78 |
) -> Tuple[List[Image.Image], List[str]]:
|
| 79 |
"""
|
| 80 |
Convert PDF to images and extract text.
|
| 81 |
+
|
| 82 |
Args:
|
| 83 |
pdf_path: Path to PDF file
|
| 84 |
dpi: Override default DPI
|
| 85 |
+
|
| 86 |
Returns:
|
| 87 |
Tuple of (list of images, list of page texts)
|
| 88 |
"""
|
| 89 |
+
self._require_pdf_deps()
|
| 90 |
from pdf2image import convert_from_path
|
| 91 |
from pypdf import PdfReader
|
| 92 |
+
|
| 93 |
dpi = dpi or self.dpi
|
| 94 |
pdf_path = Path(pdf_path)
|
| 95 |
+
|
| 96 |
logger.info(f"📄 Processing PDF: {pdf_path.name}")
|
| 97 |
+
|
| 98 |
# Extract text
|
| 99 |
reader = PdfReader(str(pdf_path))
|
| 100 |
total_pages = len(reader.pages)
|
| 101 |
+
|
| 102 |
page_texts = []
|
| 103 |
for page in reader.pages:
|
| 104 |
text = page.extract_text() or ""
|
| 105 |
# Handle surrogate characters
|
| 106 |
text = self._sanitize_text(text)
|
| 107 |
page_texts.append(text)
|
| 108 |
+
|
| 109 |
# Convert to images in batches
|
| 110 |
all_images = []
|
| 111 |
for start_page in range(1, total_pages + 1, self.page_batch_size):
|
| 112 |
end_page = min(start_page + self.page_batch_size - 1, total_pages)
|
| 113 |
+
|
| 114 |
batch_images = convert_from_path(
|
| 115 |
str(pdf_path),
|
| 116 |
dpi=dpi,
|
|
|
|
| 118 |
first_page=start_page,
|
| 119 |
last_page=end_page,
|
| 120 |
)
|
| 121 |
+
|
| 122 |
all_images.extend(batch_images)
|
| 123 |
+
|
| 124 |
del batch_images
|
| 125 |
gc.collect()
|
| 126 |
+
|
| 127 |
+
assert len(all_images) == len(
|
| 128 |
+
page_texts
|
| 129 |
+
), f"Mismatch: {len(all_images)} images vs {len(page_texts)} texts"
|
| 130 |
+
|
| 131 |
logger.info(f"✅ Processed {len(all_images)} pages")
|
| 132 |
return all_images, page_texts
|
| 133 |
+
|
| 134 |
def stream_pdf(
|
| 135 |
self,
|
| 136 |
pdf_path: Path,
|
|
|
|
| 139 |
) -> Generator[Tuple[List[Image.Image], List[str], int], None, None]:
|
| 140 |
"""
|
| 141 |
Stream PDF processing for large files.
|
| 142 |
+
|
| 143 |
Yields batches of (images, texts, start_page) without loading
|
| 144 |
entire PDF into memory.
|
| 145 |
+
|
| 146 |
Args:
|
| 147 |
pdf_path: Path to PDF file
|
| 148 |
batch_size: Pages per batch
|
| 149 |
dpi: Override default DPI
|
| 150 |
+
|
| 151 |
Yields:
|
| 152 |
Tuple of (batch_images, batch_texts, start_page_number)
|
| 153 |
"""
|
| 154 |
+
self._require_pdf_deps()
|
| 155 |
from pdf2image import convert_from_path
|
| 156 |
from pypdf import PdfReader
|
| 157 |
+
|
| 158 |
dpi = dpi or self.dpi
|
| 159 |
pdf_path = Path(pdf_path)
|
| 160 |
+
|
| 161 |
reader = PdfReader(str(pdf_path))
|
| 162 |
total_pages = len(reader.pages)
|
| 163 |
+
|
| 164 |
logger.info(f"📄 Streaming PDF: {pdf_path.name} ({total_pages} pages)")
|
| 165 |
+
|
| 166 |
for start_idx in range(0, total_pages, batch_size):
|
| 167 |
end_idx = min(start_idx + batch_size, total_pages)
|
| 168 |
+
|
| 169 |
# Extract text for batch
|
| 170 |
batch_texts = []
|
| 171 |
for page_idx in range(start_idx, end_idx):
|
| 172 |
text = reader.pages[page_idx].extract_text() or ""
|
| 173 |
text = self._sanitize_text(text)
|
| 174 |
batch_texts.append(text)
|
| 175 |
+
|
| 176 |
# Convert images for batch
|
| 177 |
batch_images = convert_from_path(
|
| 178 |
str(pdf_path),
|
|
|
|
| 181 |
first_page=start_idx + 1, # 1-indexed
|
| 182 |
last_page=end_idx,
|
| 183 |
)
|
| 184 |
+
|
| 185 |
yield batch_images, batch_texts, start_idx + 1
|
| 186 |
+
|
| 187 |
del batch_images
|
| 188 |
gc.collect()
|
| 189 |
+
|
| 190 |
def get_page_count(self, pdf_path: Path) -> int:
|
| 191 |
"""Get number of pages in PDF without loading images."""
|
| 192 |
+
self._require_pdf_deps()
|
| 193 |
from pypdf import PdfReader
|
| 194 |
+
|
| 195 |
reader = PdfReader(str(pdf_path))
|
| 196 |
return len(reader.pages)
|
| 197 |
+
|
| 198 |
def resize_for_colpali(
|
| 199 |
self,
|
| 200 |
image: Image.Image,
|
|
|
|
| 203 |
) -> Tuple[Image.Image, int, int]:
|
| 204 |
"""
|
| 205 |
Resize image following ColPali/Idefics3 processor logic.
|
| 206 |
+
|
| 207 |
Resizes to fit within tile grid without black padding.
|
| 208 |
+
|
| 209 |
Args:
|
| 210 |
image: PIL Image
|
| 211 |
max_edge: Maximum edge length
|
| 212 |
tile_size: Size of each tile
|
| 213 |
+
|
| 214 |
Returns:
|
| 215 |
Tuple of (resized_image, tile_rows, tile_cols)
|
| 216 |
"""
|
| 217 |
+
# Ensure consistent mode for downstream processors (and predictable tests)
|
| 218 |
+
if image.mode != "RGB":
|
| 219 |
+
image = image.convert("RGB")
|
| 220 |
+
|
| 221 |
w, h = image.size
|
| 222 |
+
|
| 223 |
# Step 1: Resize so longest edge = max_edge
|
| 224 |
if w > h:
|
| 225 |
new_w = max_edge
|
|
|
|
| 227 |
else:
|
| 228 |
new_h = max_edge
|
| 229 |
new_w = int(w * (max_edge / h))
|
| 230 |
+
|
| 231 |
# Step 2: Calculate tile grid
|
| 232 |
tile_cols = (new_w + tile_size - 1) // tile_size
|
| 233 |
tile_rows = (new_h + tile_size - 1) // tile_size
|
| 234 |
+
|
| 235 |
# Step 3: Calculate exact dimensions for tiles
|
| 236 |
final_w = tile_cols * tile_size
|
| 237 |
final_h = tile_rows * tile_size
|
| 238 |
+
|
| 239 |
# Step 4: Scale to fit within tile grid
|
| 240 |
scale_w = final_w / w
|
| 241 |
scale_h = final_h / h
|
| 242 |
scale = min(scale_w, scale_h)
|
| 243 |
+
|
| 244 |
scaled_w = int(w * scale)
|
| 245 |
scaled_h = int(h * scale)
|
| 246 |
+
|
| 247 |
resized = image.resize((scaled_w, scaled_h), Image.LANCZOS)
|
| 248 |
+
|
| 249 |
# Center on white canvas if needed
|
| 250 |
if scaled_w != final_w or scaled_h != final_h:
|
| 251 |
canvas = Image.new("RGB", (final_w, final_h), (255, 255, 255))
|
|
|
|
| 253 |
offset_y = (final_h - scaled_h) // 2
|
| 254 |
canvas.paste(resized, (offset_x, offset_y))
|
| 255 |
resized = canvas
|
| 256 |
+
|
| 257 |
return resized, tile_rows, tile_cols
|
| 258 |
+
|
| 259 |
def _sanitize_text(self, text: str) -> str:
|
| 260 |
"""Remove invalid Unicode characters (surrogates) from text."""
|
| 261 |
if not text:
|
| 262 |
return ""
|
| 263 |
+
|
| 264 |
# Remove surrogate characters (U+D800-U+DFFF)
|
| 265 |
+
return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
|
| 266 |
+
|
|
|
|
|
|
|
| 267 |
def extract_metadata_from_filename(
|
| 268 |
self,
|
| 269 |
filename: str,
|
|
|
|
| 271 |
) -> Dict[str, Any]:
|
| 272 |
"""
|
| 273 |
Extract metadata from PDF filename.
|
| 274 |
+
|
| 275 |
Uses mapping if provided, otherwise falls back to pattern matching.
|
| 276 |
+
|
| 277 |
Args:
|
| 278 |
filename: PDF filename (with or without .pdf extension)
|
| 279 |
mapping: Optional mapping dict {filename: metadata}
|
| 280 |
+
|
| 281 |
Returns:
|
| 282 |
Metadata dict with year, source, district, etc.
|
| 283 |
"""
|
| 284 |
# Remove extension
|
| 285 |
stem = Path(filename).stem
|
| 286 |
stem_lower = stem.lower().strip()
|
| 287 |
+
|
| 288 |
# Try mapping first
|
| 289 |
if mapping:
|
| 290 |
if stem_lower in mapping:
|
| 291 |
return mapping[stem_lower].copy()
|
| 292 |
+
|
| 293 |
# Try without .pdf
|
| 294 |
stem_no_ext = stem_lower.replace(".pdf", "")
|
| 295 |
if stem_no_ext in mapping:
|
| 296 |
return mapping[stem_no_ext].copy()
|
| 297 |
+
|
| 298 |
# Fallback: pattern matching
|
| 299 |
metadata = {"filename": filename}
|
| 300 |
+
|
| 301 |
# Extract year
|
| 302 |
year_match = re.search(r"(20\d{2})", stem)
|
| 303 |
if year_match:
|
| 304 |
metadata["year"] = int(year_match.group(1))
|
| 305 |
+
|
| 306 |
# Detect source type
|
| 307 |
if "consolidated" in stem_lower or ("annual" in stem_lower and "oag" in stem_lower):
|
| 308 |
metadata["source"] = "Consolidated"
|
| 309 |
elif "dlg" in stem_lower or "district local government" in stem_lower:
|
| 310 |
metadata["source"] = "Local Government"
|
| 311 |
# Try to extract district name
|
| 312 |
+
district_match = re.search(r"([a-z]+)\s+(?:dlg|district local government)", stem_lower)
|
|
|
|
|
|
|
| 313 |
if district_match:
|
| 314 |
metadata["district"] = district_match.group(1).title()
|
| 315 |
elif "hospital" in stem_lower or "referral" in stem_lower:
|
|
|
|
| 320 |
metadata["source"] = "Project"
|
| 321 |
else:
|
| 322 |
metadata["source"] = "Unknown"
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
+
return metadata
|
visual_rag/indexing/pipeline.py
CHANGED
|
@@ -16,11 +16,10 @@ The metadata stored includes everything needed for saliency visualization:
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
import gc
|
| 19 |
-
import time
|
| 20 |
import hashlib
|
| 21 |
import logging
|
| 22 |
from pathlib import Path
|
| 23 |
-
from typing import
|
| 24 |
|
| 25 |
import numpy as np
|
| 26 |
import torch
|
|
@@ -31,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|
| 31 |
class ProcessingPipeline:
|
| 32 |
"""
|
| 33 |
End-to-end pipeline for PDF processing and indexing.
|
| 34 |
-
|
| 35 |
This pipeline:
|
| 36 |
1. Converts PDFs to images
|
| 37 |
2. Resizes for ColPali processing
|
|
@@ -39,7 +38,7 @@ class ProcessingPipeline:
|
|
| 39 |
4. Computes pooling (strategy-dependent)
|
| 40 |
5. Uploads images to Cloudinary (optional)
|
| 41 |
6. Stores in Qdrant with full saliency metadata
|
| 42 |
-
|
| 43 |
Args:
|
| 44 |
embedder: VisualEmbedder instance
|
| 45 |
indexer: QdrantIndexer instance (optional)
|
|
@@ -52,34 +51,34 @@ class ProcessingPipeline:
|
|
| 52 |
This is our NOVEL contribution - preserves spatial structure while reducing size.
|
| 53 |
- "standard": Push ALL tokens as-is (including special tokens, padding)
|
| 54 |
This is the baseline approach for comparison.
|
| 55 |
-
|
| 56 |
Example:
|
| 57 |
>>> from visual_rag import VisualEmbedder, QdrantIndexer, CloudinaryUploader
|
| 58 |
>>> from visual_rag.indexing.pipeline import ProcessingPipeline
|
| 59 |
-
>>>
|
| 60 |
>>> # Our novel pooling strategy (default)
|
| 61 |
>>> pipeline = ProcessingPipeline(
|
| 62 |
... embedder=VisualEmbedder(),
|
| 63 |
... indexer=QdrantIndexer(url, api_key, "my_collection"),
|
| 64 |
... embedding_strategy="pooling", # Visual tokens only + tile pooling
|
| 65 |
... )
|
| 66 |
-
>>>
|
| 67 |
>>> # Standard baseline (all tokens, no filtering)
|
| 68 |
>>> pipeline_baseline = ProcessingPipeline(
|
| 69 |
... embedder=VisualEmbedder(),
|
| 70 |
... indexer=QdrantIndexer(url, api_key, "my_collection_baseline"),
|
| 71 |
... embedding_strategy="standard", # All tokens as-is
|
| 72 |
... )
|
| 73 |
-
>>>
|
| 74 |
>>> pipeline.process_pdf(Path("report.pdf"))
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
# Valid embedding strategies
|
| 78 |
# - "pooling": Visual tokens only + tile-level pooling (NOVEL)
|
| 79 |
# - "standard": All tokens + global mean (BASELINE)
|
| 80 |
# - "all": Embed once, push BOTH representations (efficient comparison)
|
| 81 |
STRATEGIES = ["pooling", "standard", "all"]
|
| 82 |
-
|
| 83 |
def __init__(
|
| 84 |
self,
|
| 85 |
embedder=None,
|
|
@@ -92,13 +91,15 @@ class ProcessingPipeline:
|
|
| 92 |
crop_empty: bool = False,
|
| 93 |
crop_empty_percentage_to_remove: float = 0.9,
|
| 94 |
crop_empty_remove_page_number: bool = False,
|
|
|
|
|
|
|
| 95 |
):
|
| 96 |
self.embedder = embedder
|
| 97 |
self.indexer = indexer
|
| 98 |
self.cloudinary_uploader = cloudinary_uploader
|
| 99 |
self.metadata_mapping = metadata_mapping or {}
|
| 100 |
self.config = config or {}
|
| 101 |
-
|
| 102 |
# Validate and set embedding strategy
|
| 103 |
if embedding_strategy not in self.STRATEGIES:
|
| 104 |
raise ValueError(
|
|
@@ -110,41 +111,50 @@ class ProcessingPipeline:
|
|
| 110 |
self.crop_empty = bool(crop_empty)
|
| 111 |
self.crop_empty_percentage_to_remove = float(crop_empty_percentage_to_remove)
|
| 112 |
self.crop_empty_remove_page_number = bool(crop_empty_remove_page_number)
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
logger.info(f"📊 Embedding strategy: {embedding_strategy}")
|
| 115 |
if embedding_strategy == "pooling":
|
| 116 |
logger.info(" → Visual tokens only + tile-level mean pooling (NOVEL)")
|
| 117 |
else:
|
| 118 |
logger.info(" → All tokens as-is (BASELINE)")
|
| 119 |
-
|
| 120 |
# Create PDF processor if not provided
|
| 121 |
if pdf_processor is None:
|
| 122 |
from visual_rag.indexing.pdf_processor import PDFProcessor
|
|
|
|
| 123 |
dpi = self.config.get("processing", {}).get("dpi", 140)
|
| 124 |
pdf_processor = PDFProcessor(dpi=dpi)
|
| 125 |
self.pdf_processor = pdf_processor
|
| 126 |
-
|
| 127 |
# Config defaults
|
| 128 |
self.embedding_batch_size = self.config.get("batching", {}).get("embedding_batch_size", 8)
|
| 129 |
self.upload_batch_size = self.config.get("batching", {}).get("upload_batch_size", 8)
|
| 130 |
self.delay_between_uploads = self.config.get("delays", {}).get("between_uploads", 0.5)
|
| 131 |
-
|
| 132 |
def process_pdf(
|
| 133 |
self,
|
| 134 |
pdf_path: Path,
|
| 135 |
skip_existing: bool = True,
|
| 136 |
upload_to_cloudinary: bool = True,
|
| 137 |
upload_to_qdrant: bool = True,
|
|
|
|
|
|
|
| 138 |
) -> Dict[str, Any]:
|
| 139 |
"""
|
| 140 |
Process a single PDF end-to-end.
|
| 141 |
-
|
| 142 |
Args:
|
| 143 |
pdf_path: Path to PDF file
|
| 144 |
skip_existing: Skip pages that already exist in Qdrant
|
| 145 |
upload_to_cloudinary: Upload images to Cloudinary
|
| 146 |
upload_to_qdrant: Upload embeddings to Qdrant
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
Returns:
|
| 149 |
Dict with processing results:
|
| 150 |
{
|
|
@@ -157,62 +167,73 @@ class ProcessingPipeline:
|
|
| 157 |
}
|
| 158 |
"""
|
| 159 |
pdf_path = Path(pdf_path)
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
# Check existing pages
|
| 163 |
existing_ids: Set[str] = set()
|
| 164 |
if skip_existing and self.indexer:
|
| 165 |
-
existing_ids = self.indexer.get_existing_ids(
|
| 166 |
if existing_ids:
|
| 167 |
logger.info(f" Found {len(existing_ids)} existing pages")
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
| 171 |
images, texts = self.pdf_processor.process_pdf(pdf_path)
|
| 172 |
total_pages = len(images)
|
| 173 |
logger.info(f" ✅ Converted {total_pages} pages")
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
| 177 |
if extra_metadata:
|
| 178 |
logger.info(f" 📋 Found extra metadata: {list(extra_metadata.keys())}")
|
| 179 |
-
|
| 180 |
# Process in batches
|
| 181 |
uploaded = 0
|
| 182 |
skipped = 0
|
| 183 |
failed = 0
|
| 184 |
all_pages = []
|
| 185 |
upload_queue = []
|
| 186 |
-
|
| 187 |
for batch_start in range(0, total_pages, self.embedding_batch_size):
|
| 188 |
batch_end = min(batch_start + self.embedding_batch_size, total_pages)
|
| 189 |
batch_images = images[batch_start:batch_end]
|
| 190 |
batch_texts = texts[batch_start:batch_end]
|
| 191 |
-
|
| 192 |
logger.info(f"📦 Processing pages {batch_start + 1}-{batch_end}/{total_pages}")
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
pages_to_process = []
|
| 196 |
for i, (img, text) in enumerate(zip(batch_images, batch_texts)):
|
| 197 |
page_num = batch_start + i + 1
|
| 198 |
-
chunk_id = self.generate_chunk_id(
|
| 199 |
-
|
| 200 |
if skip_existing and chunk_id in existing_ids:
|
| 201 |
skipped += 1
|
| 202 |
continue
|
| 203 |
-
|
| 204 |
-
pages_to_process.append(
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
if not pages_to_process:
|
| 213 |
logger.info(" All pages in batch exist, skipping...")
|
| 214 |
continue
|
| 215 |
-
|
| 216 |
# Generate embeddings with token info
|
| 217 |
logger.info(f"🤖 Generating embeddings for {len(pages_to_process)} pages...")
|
| 218 |
from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
|
|
@@ -226,6 +247,10 @@ class ProcessingPipeline:
|
|
| 226 |
config=CropEmptyConfig(
|
| 227 |
percentage_to_remove=float(self.crop_empty_percentage_to_remove),
|
| 228 |
remove_page_number=bool(self.crop_empty_remove_page_number),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
),
|
| 230 |
)
|
| 231 |
p["embed_image"] = cropped_img
|
|
@@ -235,15 +260,14 @@ class ProcessingPipeline:
|
|
| 235 |
p["embed_image"] = raw_img
|
| 236 |
p["crop_meta"] = None
|
| 237 |
images_to_embed.append(raw_img)
|
| 238 |
-
|
| 239 |
embeddings, token_infos = self.embedder.embed_images(
|
| 240 |
images_to_embed,
|
| 241 |
batch_size=self.embedding_batch_size,
|
| 242 |
return_token_info=True,
|
| 243 |
-
show_progress=
|
| 244 |
)
|
| 245 |
-
|
| 246 |
-
# Process each page
|
| 247 |
for idx, page_info in enumerate(pages_to_process):
|
| 248 |
raw_img = page_info["raw_image"]
|
| 249 |
embed_img = page_info["embed_image"]
|
|
@@ -253,10 +277,19 @@ class ProcessingPipeline:
|
|
| 253 |
text = page_info["text"]
|
| 254 |
embedding = embeddings[idx]
|
| 255 |
token_info = token_infos[idx]
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
try:
|
| 258 |
page_data = self._process_single_page(
|
| 259 |
-
|
|
|
|
| 260 |
page_num=page_num,
|
| 261 |
chunk_id=chunk_id,
|
| 262 |
total_pages=total_pages,
|
|
@@ -269,46 +302,49 @@ class ProcessingPipeline:
|
|
| 269 |
upload_to_cloudinary=upload_to_cloudinary,
|
| 270 |
crop_meta=crop_meta,
|
| 271 |
)
|
| 272 |
-
|
| 273 |
all_pages.append(page_data)
|
| 274 |
-
|
| 275 |
if upload_to_qdrant and self.indexer:
|
| 276 |
upload_queue.append(page_data)
|
| 277 |
-
|
| 278 |
# Upload in batches
|
| 279 |
if len(upload_queue) >= self.upload_batch_size:
|
| 280 |
count = self._upload_batch(upload_queue)
|
| 281 |
uploaded += count
|
| 282 |
upload_queue = []
|
| 283 |
-
|
| 284 |
except Exception as e:
|
| 285 |
logger.error(f" ❌ Failed page {page_num}: {e}")
|
| 286 |
failed += 1
|
| 287 |
-
|
| 288 |
# Memory cleanup
|
| 289 |
gc.collect()
|
| 290 |
if torch.cuda.is_available():
|
| 291 |
torch.cuda.empty_cache()
|
| 292 |
-
|
| 293 |
# Upload remaining pages
|
| 294 |
if upload_queue and upload_to_qdrant and self.indexer:
|
| 295 |
count = self._upload_batch(upload_queue)
|
| 296 |
uploaded += count
|
| 297 |
-
|
| 298 |
-
logger.info(
|
| 299 |
-
|
|
|
|
|
|
|
| 300 |
return {
|
| 301 |
-
"filename":
|
| 302 |
"total_pages": total_pages,
|
| 303 |
"uploaded": uploaded,
|
| 304 |
"skipped": skipped,
|
| 305 |
"failed": failed,
|
| 306 |
"pages": all_pages,
|
| 307 |
}
|
| 308 |
-
|
| 309 |
def _process_single_page(
|
| 310 |
self,
|
| 311 |
-
|
|
|
|
| 312 |
page_num: int,
|
| 313 |
chunk_id: str,
|
| 314 |
total_pages: int,
|
|
@@ -323,17 +359,17 @@ class ProcessingPipeline:
|
|
| 323 |
) -> Dict[str, Any]:
|
| 324 |
"""Process a single page with full metadata for saliency."""
|
| 325 |
from visual_rag.embedding.pooling import global_mean_pooling
|
| 326 |
-
|
| 327 |
# Resize image for ColPali
|
| 328 |
resized_img, tile_rows, tile_cols = self.pdf_processor.resize_for_colpali(embed_img)
|
| 329 |
-
|
| 330 |
# Use processor's tile info if available (more accurate)
|
| 331 |
proc_n_rows = token_info.get("n_rows")
|
| 332 |
proc_n_cols = token_info.get("n_cols")
|
| 333 |
if proc_n_rows and proc_n_cols:
|
| 334 |
tile_rows = proc_n_rows
|
| 335 |
tile_cols = proc_n_cols
|
| 336 |
-
|
| 337 |
# Convert embedding to numpy
|
| 338 |
if isinstance(embedding, torch.Tensor):
|
| 339 |
if embedding.dtype == torch.bfloat16:
|
|
@@ -343,24 +379,30 @@ class ProcessingPipeline:
|
|
| 343 |
else:
|
| 344 |
full_embedding = np.array(embedding)
|
| 345 |
full_embedding = full_embedding.astype(np.float32)
|
| 346 |
-
|
| 347 |
# Token info for metadata
|
| 348 |
visual_indices = token_info["visual_token_indices"]
|
| 349 |
num_visual_tokens = token_info["num_visual_tokens"]
|
| 350 |
-
|
| 351 |
# =========================================================================
|
| 352 |
# STRATEGY: "pooling" (NOVEL) vs "standard" (BASELINE) vs "all" (BOTH)
|
| 353 |
# =========================================================================
|
| 354 |
-
|
| 355 |
# Always compute visual-only embedding (needed for pooling and saliency)
|
| 356 |
visual_embedding = full_embedding[visual_indices]
|
| 357 |
-
|
| 358 |
-
tile_pooled = self.embedder.mean_pool_visual_embedding(
|
|
|
|
|
|
|
| 359 |
experimental_pooled = self.embedder.experimental_pool_visual_embedding(
|
| 360 |
visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
|
| 361 |
)
|
| 362 |
global_pooled = global_mean_pooling(full_embedding)
|
| 363 |
-
global_pooling =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
num_tiles = int(tile_pooled.shape[0])
|
| 366 |
patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0
|
|
@@ -369,64 +411,70 @@ class ProcessingPipeline:
|
|
| 369 |
else:
|
| 370 |
tile_rows = token_info.get("n_rows") or None
|
| 371 |
tile_cols = token_info.get("n_cols") or None
|
| 372 |
-
|
| 373 |
if self.embedding_strategy == "pooling":
|
| 374 |
# NOVEL APPROACH: Visual tokens only + tile-level pooling
|
| 375 |
embedding_for_initial = visual_embedding
|
| 376 |
embedding_for_pooling = tile_pooled
|
| 377 |
-
global_pooling =
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
elif self.embedding_strategy == "standard":
|
| 380 |
# BASELINE: All tokens + global mean
|
| 381 |
embedding_for_initial = full_embedding
|
| 382 |
embedding_for_pooling = global_pooled.reshape(1, -1)
|
| 383 |
global_pooling = global_pooled
|
| 384 |
-
|
| 385 |
else: # "all" - Push BOTH representations (efficient for comparison)
|
| 386 |
# Embed once, store multiple vector representations
|
| 387 |
# This allows comparing both strategies without re-embedding
|
| 388 |
embedding_for_initial = visual_embedding # Use visual for search
|
| 389 |
-
embedding_for_pooling = tile_pooled
|
| 390 |
-
global_pooling =
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
# ALSO store standard representations as additional vectors
|
| 393 |
# These will be added to metadata for optional use
|
| 394 |
pass # Extra vectors handled in return dict below
|
| 395 |
-
|
| 396 |
# Upload to Cloudinary
|
| 397 |
original_url = None
|
| 398 |
cropped_url = None
|
| 399 |
resized_url = None
|
| 400 |
-
|
| 401 |
if upload_to_cloudinary and self.cloudinary_uploader:
|
| 402 |
-
base_filename = f"{
|
| 403 |
if self.crop_empty:
|
| 404 |
-
original_url, cropped_url, resized_url =
|
| 405 |
-
|
|
|
|
|
|
|
| 406 |
)
|
| 407 |
else:
|
| 408 |
original_url, resized_url = self.cloudinary_uploader.upload_original_and_resized(
|
| 409 |
raw_img, resized_img, base_filename
|
| 410 |
)
|
| 411 |
-
|
| 412 |
# Sanitize text
|
| 413 |
safe_text = self._sanitize_text(text[:10000]) if text else ""
|
| 414 |
-
|
| 415 |
-
# Build metadata (everything needed for saliency)
|
| 416 |
metadata = {
|
| 417 |
-
|
| 418 |
-
"filename": pdf_path.name,
|
| 419 |
"page_number": page_num,
|
| 420 |
"total_pages": total_pages,
|
| 421 |
"has_text": bool(text and text.strip()),
|
| 422 |
"text": safe_text,
|
| 423 |
-
|
| 424 |
# Image URLs
|
| 425 |
"page": resized_url or "", # For display
|
| 426 |
"original_url": original_url or "",
|
| 427 |
"cropped_url": cropped_url or "",
|
| 428 |
"resized_url": resized_url or "",
|
| 429 |
-
|
| 430 |
# Dimensions (needed for saliency overlay)
|
| 431 |
"original_width": raw_img.width,
|
| 432 |
"original_height": raw_img.height,
|
|
@@ -434,35 +482,33 @@ class ProcessingPipeline:
|
|
| 434 |
"cropped_height": int(embed_img.height) if self.crop_empty else int(raw_img.height),
|
| 435 |
"resized_width": resized_img.width,
|
| 436 |
"resized_height": resized_img.height,
|
| 437 |
-
|
| 438 |
# Tile structure (needed for saliency)
|
| 439 |
"num_tiles": num_tiles,
|
| 440 |
"tile_rows": tile_rows,
|
| 441 |
"tile_cols": tile_cols,
|
| 442 |
"patches_per_tile": patches_per_tile,
|
| 443 |
-
|
| 444 |
# Token info (needed for saliency)
|
| 445 |
"num_visual_tokens": num_visual_tokens,
|
| 446 |
"visual_token_indices": visual_indices,
|
| 447 |
"total_tokens": len(full_embedding), # Total tokens in raw embedding
|
| 448 |
-
|
| 449 |
# Strategy used (important for paper comparison)
|
| 450 |
"embedding_strategy": self.embedding_strategy,
|
| 451 |
-
|
| 452 |
"model_name": getattr(self.embedder, "model_name", None),
|
| 453 |
-
|
| 454 |
"crop_empty_enabled": bool(self.crop_empty),
|
| 455 |
"crop_empty_crop_box": (crop_meta or {}).get("crop_box"),
|
| 456 |
"crop_empty_remove_page_number": bool(self.crop_empty_remove_page_number),
|
| 457 |
"crop_empty_percentage_to_remove": float(self.crop_empty_percentage_to_remove),
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
| 459 |
# Extra metadata (year, district, etc.)
|
| 460 |
**extra_metadata,
|
| 461 |
}
|
| 462 |
-
|
| 463 |
result = {
|
| 464 |
"id": chunk_id,
|
| 465 |
-
"visual_embedding": embedding_for_initial,
|
| 466 |
"tile_pooled_embedding": embedding_for_pooling, # "mean_pooling" vector in Qdrant
|
| 467 |
"experimental_pooled_embedding": experimental_pooled, # "experimental_pooling" vector in Qdrant
|
| 468 |
"global_pooled_embedding": global_pooling, # "global_pooling" vector in Qdrant
|
|
@@ -470,70 +516,70 @@ class ProcessingPipeline:
|
|
| 470 |
"image": raw_img,
|
| 471 |
"resized_image": resized_img,
|
| 472 |
}
|
| 473 |
-
|
| 474 |
# For "all" strategy, include BOTH representations for comparison
|
| 475 |
if self.embedding_strategy == "all":
|
| 476 |
result["extra_vectors"] = {
|
| 477 |
# Standard baseline vectors (for comparison)
|
| 478 |
-
"full_embedding": full_embedding,
|
| 479 |
-
"global_pooled": global_pooled,
|
| 480 |
# Pooling vectors (already in main result)
|
| 481 |
-
"visual_embedding": visual_embedding,
|
| 482 |
-
"tile_pooled": tile_pooled,
|
| 483 |
}
|
| 484 |
-
|
| 485 |
return result
|
| 486 |
-
|
| 487 |
def _upload_batch(self, upload_queue: List[Dict[str, Any]]) -> int:
|
| 488 |
"""Upload batch to Qdrant."""
|
| 489 |
if not upload_queue or not self.indexer:
|
| 490 |
return 0
|
| 491 |
-
|
| 492 |
logger.info(f"📤 Uploading batch of {len(upload_queue)} pages...")
|
| 493 |
-
|
| 494 |
count = self.indexer.upload_batch(
|
| 495 |
upload_queue,
|
| 496 |
delay_between_batches=self.delay_between_uploads,
|
| 497 |
)
|
| 498 |
-
|
| 499 |
return count
|
| 500 |
-
|
| 501 |
def _get_extra_metadata(self, filename: str) -> Dict[str, Any]:
|
| 502 |
"""Get extra metadata for a filename."""
|
| 503 |
if not self.metadata_mapping:
|
| 504 |
return {}
|
| 505 |
-
|
| 506 |
# Normalize filename
|
| 507 |
filename_clean = filename.replace(".pdf", "").replace(".PDF", "").strip().lower()
|
| 508 |
-
|
| 509 |
# Try exact match
|
| 510 |
if filename_clean in self.metadata_mapping:
|
| 511 |
return self.metadata_mapping[filename_clean].copy()
|
| 512 |
-
|
| 513 |
# Try fuzzy match
|
| 514 |
from difflib import SequenceMatcher
|
| 515 |
-
|
| 516 |
best_match = None
|
| 517 |
best_score = 0.0
|
| 518 |
-
|
| 519 |
for known_filename, metadata in self.metadata_mapping.items():
|
| 520 |
score = SequenceMatcher(None, filename_clean, known_filename.lower()).ratio()
|
| 521 |
if score > best_score and score > 0.75:
|
| 522 |
best_score = score
|
| 523 |
best_match = metadata
|
| 524 |
-
|
| 525 |
if best_match:
|
| 526 |
logger.debug(f"Fuzzy matched '{filename}' with score {best_score:.2f}")
|
| 527 |
return best_match.copy()
|
| 528 |
-
|
| 529 |
return {}
|
| 530 |
-
|
| 531 |
def _sanitize_text(self, text: str) -> str:
|
| 532 |
"""Remove invalid Unicode characters."""
|
| 533 |
if not text:
|
| 534 |
return ""
|
| 535 |
return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
|
| 536 |
-
|
| 537 |
@staticmethod
|
| 538 |
def generate_chunk_id(filename: str, page_number: int) -> str:
|
| 539 |
"""Generate deterministic chunk ID."""
|
|
@@ -541,12 +587,12 @@ class ProcessingPipeline:
|
|
| 541 |
hash_obj = hashlib.sha256(content.encode())
|
| 542 |
hex_str = hash_obj.hexdigest()[:32]
|
| 543 |
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
| 544 |
-
|
| 545 |
@staticmethod
|
| 546 |
def load_metadata_mapping(json_path: Path) -> Dict[str, Dict[str, Any]]:
|
| 547 |
"""
|
| 548 |
Load metadata mapping from JSON file.
|
| 549 |
-
|
| 550 |
Expected format:
|
| 551 |
{
|
| 552 |
"filenames": {
|
|
@@ -554,7 +600,7 @@ class ProcessingPipeline:
|
|
| 554 |
...
|
| 555 |
}
|
| 556 |
}
|
| 557 |
-
|
| 558 |
Or simple format:
|
| 559 |
{
|
| 560 |
"Report Name 2023": {"year": 2023, "source": "Local Government", ...},
|
|
@@ -562,22 +608,21 @@ class ProcessingPipeline:
|
|
| 562 |
}
|
| 563 |
"""
|
| 564 |
import json
|
| 565 |
-
|
| 566 |
with open(json_path, "r") as f:
|
| 567 |
data = json.load(f)
|
| 568 |
-
|
| 569 |
# Check if nested under "filenames"
|
| 570 |
if "filenames" in data and isinstance(data["filenames"], dict):
|
| 571 |
mapping = data["filenames"]
|
| 572 |
else:
|
| 573 |
mapping = data
|
| 574 |
-
|
| 575 |
# Normalize keys to lowercase
|
| 576 |
normalized = {}
|
| 577 |
for filename, metadata in mapping.items():
|
| 578 |
key = filename.lower().strip().replace(".pdf", "")
|
| 579 |
normalized[key] = metadata
|
| 580 |
-
|
| 581 |
logger.info(f"📖 Loaded metadata for {len(normalized)} files")
|
| 582 |
return normalized
|
| 583 |
-
|
|
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
import gc
|
|
|
|
| 19 |
import hashlib
|
| 20 |
import logging
|
| 21 |
from pathlib import Path
|
| 22 |
+
from typing import Any, Dict, List, Optional, Set
|
| 23 |
|
| 24 |
import numpy as np
|
| 25 |
import torch
|
|
|
|
| 30 |
class ProcessingPipeline:
|
| 31 |
"""
|
| 32 |
End-to-end pipeline for PDF processing and indexing.
|
| 33 |
+
|
| 34 |
This pipeline:
|
| 35 |
1. Converts PDFs to images
|
| 36 |
2. Resizes for ColPali processing
|
|
|
|
| 38 |
4. Computes pooling (strategy-dependent)
|
| 39 |
5. Uploads images to Cloudinary (optional)
|
| 40 |
6. Stores in Qdrant with full saliency metadata
|
| 41 |
+
|
| 42 |
Args:
|
| 43 |
embedder: VisualEmbedder instance
|
| 44 |
indexer: QdrantIndexer instance (optional)
|
|
|
|
| 51 |
This is our NOVEL contribution - preserves spatial structure while reducing size.
|
| 52 |
- "standard": Push ALL tokens as-is (including special tokens, padding)
|
| 53 |
This is the baseline approach for comparison.
|
| 54 |
+
|
| 55 |
Example:
|
| 56 |
>>> from visual_rag import VisualEmbedder, QdrantIndexer, CloudinaryUploader
|
| 57 |
>>> from visual_rag.indexing.pipeline import ProcessingPipeline
|
| 58 |
+
>>>
|
| 59 |
>>> # Our novel pooling strategy (default)
|
| 60 |
>>> pipeline = ProcessingPipeline(
|
| 61 |
... embedder=VisualEmbedder(),
|
| 62 |
... indexer=QdrantIndexer(url, api_key, "my_collection"),
|
| 63 |
... embedding_strategy="pooling", # Visual tokens only + tile pooling
|
| 64 |
... )
|
| 65 |
+
>>>
|
| 66 |
>>> # Standard baseline (all tokens, no filtering)
|
| 67 |
>>> pipeline_baseline = ProcessingPipeline(
|
| 68 |
... embedder=VisualEmbedder(),
|
| 69 |
... indexer=QdrantIndexer(url, api_key, "my_collection_baseline"),
|
| 70 |
... embedding_strategy="standard", # All tokens as-is
|
| 71 |
... )
|
| 72 |
+
>>>
|
| 73 |
>>> pipeline.process_pdf(Path("report.pdf"))
|
| 74 |
"""
|
| 75 |
+
|
| 76 |
# Valid embedding strategies
|
| 77 |
# - "pooling": Visual tokens only + tile-level pooling (NOVEL)
|
| 78 |
# - "standard": All tokens + global mean (BASELINE)
|
| 79 |
# - "all": Embed once, push BOTH representations (efficient comparison)
|
| 80 |
STRATEGIES = ["pooling", "standard", "all"]
|
| 81 |
+
|
| 82 |
def __init__(
|
| 83 |
self,
|
| 84 |
embedder=None,
|
|
|
|
| 91 |
crop_empty: bool = False,
|
| 92 |
crop_empty_percentage_to_remove: float = 0.9,
|
| 93 |
crop_empty_remove_page_number: bool = False,
|
| 94 |
+
crop_empty_preserve_border_px: int = 1,
|
| 95 |
+
crop_empty_uniform_rowcol_std_threshold: float = 0.0,
|
| 96 |
):
|
| 97 |
self.embedder = embedder
|
| 98 |
self.indexer = indexer
|
| 99 |
self.cloudinary_uploader = cloudinary_uploader
|
| 100 |
self.metadata_mapping = metadata_mapping or {}
|
| 101 |
self.config = config or {}
|
| 102 |
+
|
| 103 |
# Validate and set embedding strategy
|
| 104 |
if embedding_strategy not in self.STRATEGIES:
|
| 105 |
raise ValueError(
|
|
|
|
| 111 |
self.crop_empty = bool(crop_empty)
|
| 112 |
self.crop_empty_percentage_to_remove = float(crop_empty_percentage_to_remove)
|
| 113 |
self.crop_empty_remove_page_number = bool(crop_empty_remove_page_number)
|
| 114 |
+
self.crop_empty_preserve_border_px = int(crop_empty_preserve_border_px)
|
| 115 |
+
self.crop_empty_uniform_rowcol_std_threshold = float(
|
| 116 |
+
crop_empty_uniform_rowcol_std_threshold
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
logger.info(f"📊 Embedding strategy: {embedding_strategy}")
|
| 120 |
if embedding_strategy == "pooling":
|
| 121 |
logger.info(" → Visual tokens only + tile-level mean pooling (NOVEL)")
|
| 122 |
else:
|
| 123 |
logger.info(" → All tokens as-is (BASELINE)")
|
| 124 |
+
|
| 125 |
# Create PDF processor if not provided
|
| 126 |
if pdf_processor is None:
|
| 127 |
from visual_rag.indexing.pdf_processor import PDFProcessor
|
| 128 |
+
|
| 129 |
dpi = self.config.get("processing", {}).get("dpi", 140)
|
| 130 |
pdf_processor = PDFProcessor(dpi=dpi)
|
| 131 |
self.pdf_processor = pdf_processor
|
| 132 |
+
|
| 133 |
# Config defaults
|
| 134 |
self.embedding_batch_size = self.config.get("batching", {}).get("embedding_batch_size", 8)
|
| 135 |
self.upload_batch_size = self.config.get("batching", {}).get("upload_batch_size", 8)
|
| 136 |
self.delay_between_uploads = self.config.get("delays", {}).get("between_uploads", 0.5)
|
| 137 |
+
|
| 138 |
def process_pdf(
|
| 139 |
self,
|
| 140 |
pdf_path: Path,
|
| 141 |
skip_existing: bool = True,
|
| 142 |
upload_to_cloudinary: bool = True,
|
| 143 |
upload_to_qdrant: bool = True,
|
| 144 |
+
original_filename: Optional[str] = None,
|
| 145 |
+
progress_callback: Optional[callable] = None,
|
| 146 |
) -> Dict[str, Any]:
|
| 147 |
"""
|
| 148 |
Process a single PDF end-to-end.
|
| 149 |
+
|
| 150 |
Args:
|
| 151 |
pdf_path: Path to PDF file
|
| 152 |
skip_existing: Skip pages that already exist in Qdrant
|
| 153 |
upload_to_cloudinary: Upload images to Cloudinary
|
| 154 |
upload_to_qdrant: Upload embeddings to Qdrant
|
| 155 |
+
original_filename: Original filename (use this instead of pdf_path.name for temp files)
|
| 156 |
+
progress_callback: Optional callback(stage, current, total, message) for progress updates
|
| 157 |
+
|
| 158 |
Returns:
|
| 159 |
Dict with processing results:
|
| 160 |
{
|
|
|
|
| 167 |
}
|
| 168 |
"""
|
| 169 |
pdf_path = Path(pdf_path)
|
| 170 |
+
filename = original_filename or pdf_path.name
|
| 171 |
+
logger.info(f"📚 Processing PDF: {filename}")
|
| 172 |
+
|
| 173 |
# Check existing pages
|
| 174 |
existing_ids: Set[str] = set()
|
| 175 |
if skip_existing and self.indexer:
|
| 176 |
+
existing_ids = self.indexer.get_existing_ids(filename)
|
| 177 |
if existing_ids:
|
| 178 |
logger.info(f" Found {len(existing_ids)} existing pages")
|
| 179 |
+
|
| 180 |
+
logger.info("🖼️ Converting PDF to images...")
|
| 181 |
+
if progress_callback:
|
| 182 |
+
progress_callback("convert", 0, 0, "Converting PDF to images...")
|
| 183 |
images, texts = self.pdf_processor.process_pdf(pdf_path)
|
| 184 |
total_pages = len(images)
|
| 185 |
logger.info(f" ✅ Converted {total_pages} pages")
|
| 186 |
+
if progress_callback:
|
| 187 |
+
progress_callback("convert", total_pages, total_pages, f"Converted {total_pages} pages")
|
| 188 |
+
|
| 189 |
+
extra_metadata = self._get_extra_metadata(filename)
|
| 190 |
if extra_metadata:
|
| 191 |
logger.info(f" 📋 Found extra metadata: {list(extra_metadata.keys())}")
|
| 192 |
+
|
| 193 |
# Process in batches
|
| 194 |
uploaded = 0
|
| 195 |
skipped = 0
|
| 196 |
failed = 0
|
| 197 |
all_pages = []
|
| 198 |
upload_queue = []
|
| 199 |
+
|
| 200 |
for batch_start in range(0, total_pages, self.embedding_batch_size):
|
| 201 |
batch_end = min(batch_start + self.embedding_batch_size, total_pages)
|
| 202 |
batch_images = images[batch_start:batch_end]
|
| 203 |
batch_texts = texts[batch_start:batch_end]
|
| 204 |
+
|
| 205 |
logger.info(f"📦 Processing pages {batch_start + 1}-{batch_end}/{total_pages}")
|
| 206 |
+
if progress_callback:
|
| 207 |
+
progress_callback(
|
| 208 |
+
"embed",
|
| 209 |
+
batch_start,
|
| 210 |
+
total_pages,
|
| 211 |
+
f"Embedding pages {batch_start + 1}-{batch_end}",
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
pages_to_process = []
|
| 215 |
for i, (img, text) in enumerate(zip(batch_images, batch_texts)):
|
| 216 |
page_num = batch_start + i + 1
|
| 217 |
+
chunk_id = self.generate_chunk_id(filename, page_num)
|
| 218 |
+
|
| 219 |
if skip_existing and chunk_id in existing_ids:
|
| 220 |
skipped += 1
|
| 221 |
continue
|
| 222 |
+
|
| 223 |
+
pages_to_process.append(
|
| 224 |
+
{
|
| 225 |
+
"index": i,
|
| 226 |
+
"page_num": page_num,
|
| 227 |
+
"chunk_id": chunk_id,
|
| 228 |
+
"raw_image": img,
|
| 229 |
+
"text": text,
|
| 230 |
+
}
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
if not pages_to_process:
|
| 234 |
logger.info(" All pages in batch exist, skipping...")
|
| 235 |
continue
|
| 236 |
+
|
| 237 |
# Generate embeddings with token info
|
| 238 |
logger.info(f"🤖 Generating embeddings for {len(pages_to_process)} pages...")
|
| 239 |
from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
|
|
|
|
| 247 |
config=CropEmptyConfig(
|
| 248 |
percentage_to_remove=float(self.crop_empty_percentage_to_remove),
|
| 249 |
remove_page_number=bool(self.crop_empty_remove_page_number),
|
| 250 |
+
preserve_border_px=int(self.crop_empty_preserve_border_px),
|
| 251 |
+
uniform_rowcol_std_threshold=float(
|
| 252 |
+
self.crop_empty_uniform_rowcol_std_threshold
|
| 253 |
+
),
|
| 254 |
),
|
| 255 |
)
|
| 256 |
p["embed_image"] = cropped_img
|
|
|
|
| 260 |
p["embed_image"] = raw_img
|
| 261 |
p["crop_meta"] = None
|
| 262 |
images_to_embed.append(raw_img)
|
| 263 |
+
|
| 264 |
embeddings, token_infos = self.embedder.embed_images(
|
| 265 |
images_to_embed,
|
| 266 |
batch_size=self.embedding_batch_size,
|
| 267 |
return_token_info=True,
|
| 268 |
+
show_progress=False,
|
| 269 |
)
|
| 270 |
+
|
|
|
|
| 271 |
for idx, page_info in enumerate(pages_to_process):
|
| 272 |
raw_img = page_info["raw_image"]
|
| 273 |
embed_img = page_info["embed_image"]
|
|
|
|
| 277 |
text = page_info["text"]
|
| 278 |
embedding = embeddings[idx]
|
| 279 |
token_info = token_infos[idx]
|
| 280 |
+
|
| 281 |
+
if progress_callback:
|
| 282 |
+
progress_callback(
|
| 283 |
+
"process",
|
| 284 |
+
page_num,
|
| 285 |
+
total_pages,
|
| 286 |
+
f"Processing page {page_num}/{total_pages}",
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
try:
|
| 290 |
page_data = self._process_single_page(
|
| 291 |
+
filename=filename,
|
| 292 |
+
pdf_stem=pdf_path.stem,
|
| 293 |
page_num=page_num,
|
| 294 |
chunk_id=chunk_id,
|
| 295 |
total_pages=total_pages,
|
|
|
|
| 302 |
upload_to_cloudinary=upload_to_cloudinary,
|
| 303 |
crop_meta=crop_meta,
|
| 304 |
)
|
| 305 |
+
|
| 306 |
all_pages.append(page_data)
|
| 307 |
+
|
| 308 |
if upload_to_qdrant and self.indexer:
|
| 309 |
upload_queue.append(page_data)
|
| 310 |
+
|
| 311 |
# Upload in batches
|
| 312 |
if len(upload_queue) >= self.upload_batch_size:
|
| 313 |
count = self._upload_batch(upload_queue)
|
| 314 |
uploaded += count
|
| 315 |
upload_queue = []
|
| 316 |
+
|
| 317 |
except Exception as e:
|
| 318 |
logger.error(f" ❌ Failed page {page_num}: {e}")
|
| 319 |
failed += 1
|
| 320 |
+
|
| 321 |
# Memory cleanup
|
| 322 |
gc.collect()
|
| 323 |
if torch.cuda.is_available():
|
| 324 |
torch.cuda.empty_cache()
|
| 325 |
+
|
| 326 |
# Upload remaining pages
|
| 327 |
if upload_queue and upload_to_qdrant and self.indexer:
|
| 328 |
count = self._upload_batch(upload_queue)
|
| 329 |
uploaded += count
|
| 330 |
+
|
| 331 |
+
logger.info(
|
| 332 |
+
f"✅ Completed {filename}: {uploaded} uploaded, {skipped} skipped, {failed} failed"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
return {
|
| 336 |
+
"filename": filename,
|
| 337 |
"total_pages": total_pages,
|
| 338 |
"uploaded": uploaded,
|
| 339 |
"skipped": skipped,
|
| 340 |
"failed": failed,
|
| 341 |
"pages": all_pages,
|
| 342 |
}
|
| 343 |
+
|
| 344 |
def _process_single_page(
|
| 345 |
self,
|
| 346 |
+
filename: str,
|
| 347 |
+
pdf_stem: str,
|
| 348 |
page_num: int,
|
| 349 |
chunk_id: str,
|
| 350 |
total_pages: int,
|
|
|
|
| 359 |
) -> Dict[str, Any]:
|
| 360 |
"""Process a single page with full metadata for saliency."""
|
| 361 |
from visual_rag.embedding.pooling import global_mean_pooling
|
| 362 |
+
|
| 363 |
# Resize image for ColPali
|
| 364 |
resized_img, tile_rows, tile_cols = self.pdf_processor.resize_for_colpali(embed_img)
|
| 365 |
+
|
| 366 |
# Use processor's tile info if available (more accurate)
|
| 367 |
proc_n_rows = token_info.get("n_rows")
|
| 368 |
proc_n_cols = token_info.get("n_cols")
|
| 369 |
if proc_n_rows and proc_n_cols:
|
| 370 |
tile_rows = proc_n_rows
|
| 371 |
tile_cols = proc_n_cols
|
| 372 |
+
|
| 373 |
# Convert embedding to numpy
|
| 374 |
if isinstance(embedding, torch.Tensor):
|
| 375 |
if embedding.dtype == torch.bfloat16:
|
|
|
|
| 379 |
else:
|
| 380 |
full_embedding = np.array(embedding)
|
| 381 |
full_embedding = full_embedding.astype(np.float32)
|
| 382 |
+
|
| 383 |
# Token info for metadata
|
| 384 |
visual_indices = token_info["visual_token_indices"]
|
| 385 |
num_visual_tokens = token_info["num_visual_tokens"]
|
| 386 |
+
|
| 387 |
# =========================================================================
|
| 388 |
# STRATEGY: "pooling" (NOVEL) vs "standard" (BASELINE) vs "all" (BOTH)
|
| 389 |
# =========================================================================
|
| 390 |
+
|
| 391 |
# Always compute visual-only embedding (needed for pooling and saliency)
|
| 392 |
visual_embedding = full_embedding[visual_indices]
|
| 393 |
+
|
| 394 |
+
tile_pooled = self.embedder.mean_pool_visual_embedding(
|
| 395 |
+
visual_embedding, token_info, target_vectors=32
|
| 396 |
+
)
|
| 397 |
experimental_pooled = self.embedder.experimental_pool_visual_embedding(
|
| 398 |
visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
|
| 399 |
)
|
| 400 |
global_pooled = global_mean_pooling(full_embedding)
|
| 401 |
+
global_pooling = (
|
| 402 |
+
self.embedder.global_pool_from_mean_pool(tile_pooled)
|
| 403 |
+
if tile_pooled.size
|
| 404 |
+
else global_pooled
|
| 405 |
+
)
|
| 406 |
|
| 407 |
num_tiles = int(tile_pooled.shape[0])
|
| 408 |
patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0
|
|
|
|
| 411 |
else:
|
| 412 |
tile_rows = token_info.get("n_rows") or None
|
| 413 |
tile_cols = token_info.get("n_cols") or None
|
| 414 |
+
|
| 415 |
if self.embedding_strategy == "pooling":
|
| 416 |
# NOVEL APPROACH: Visual tokens only + tile-level pooling
|
| 417 |
embedding_for_initial = visual_embedding
|
| 418 |
embedding_for_pooling = tile_pooled
|
| 419 |
+
global_pooling = (
|
| 420 |
+
self.embedder.global_pool_from_mean_pool(tile_pooled)
|
| 421 |
+
if tile_pooled.size
|
| 422 |
+
else global_pooled
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
elif self.embedding_strategy == "standard":
|
| 426 |
# BASELINE: All tokens + global mean
|
| 427 |
embedding_for_initial = full_embedding
|
| 428 |
embedding_for_pooling = global_pooled.reshape(1, -1)
|
| 429 |
global_pooling = global_pooled
|
| 430 |
+
|
| 431 |
else: # "all" - Push BOTH representations (efficient for comparison)
|
| 432 |
# Embed once, store multiple vector representations
|
| 433 |
# This allows comparing both strategies without re-embedding
|
| 434 |
embedding_for_initial = visual_embedding # Use visual for search
|
| 435 |
+
embedding_for_pooling = tile_pooled # Use tile-level for fast prefetch
|
| 436 |
+
global_pooling = (
|
| 437 |
+
self.embedder.global_pool_from_mean_pool(tile_pooled)
|
| 438 |
+
if tile_pooled.size
|
| 439 |
+
else global_pooled
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
# ALSO store standard representations as additional vectors
|
| 443 |
# These will be added to metadata for optional use
|
| 444 |
pass # Extra vectors handled in return dict below
|
| 445 |
+
|
| 446 |
# Upload to Cloudinary
|
| 447 |
original_url = None
|
| 448 |
cropped_url = None
|
| 449 |
resized_url = None
|
| 450 |
+
|
| 451 |
if upload_to_cloudinary and self.cloudinary_uploader:
|
| 452 |
+
base_filename = f"{pdf_stem}_page_{page_num}"
|
| 453 |
if self.crop_empty:
|
| 454 |
+
original_url, cropped_url, resized_url = (
|
| 455 |
+
self.cloudinary_uploader.upload_original_cropped_and_resized(
|
| 456 |
+
raw_img, embed_img, resized_img, base_filename
|
| 457 |
+
)
|
| 458 |
)
|
| 459 |
else:
|
| 460 |
original_url, resized_url = self.cloudinary_uploader.upload_original_and_resized(
|
| 461 |
raw_img, resized_img, base_filename
|
| 462 |
)
|
| 463 |
+
|
| 464 |
# Sanitize text
|
| 465 |
safe_text = self._sanitize_text(text[:10000]) if text else ""
|
| 466 |
+
|
|
|
|
| 467 |
metadata = {
|
| 468 |
+
"filename": filename,
|
|
|
|
| 469 |
"page_number": page_num,
|
| 470 |
"total_pages": total_pages,
|
| 471 |
"has_text": bool(text and text.strip()),
|
| 472 |
"text": safe_text,
|
|
|
|
| 473 |
# Image URLs
|
| 474 |
"page": resized_url or "", # For display
|
| 475 |
"original_url": original_url or "",
|
| 476 |
"cropped_url": cropped_url or "",
|
| 477 |
"resized_url": resized_url or "",
|
|
|
|
| 478 |
# Dimensions (needed for saliency overlay)
|
| 479 |
"original_width": raw_img.width,
|
| 480 |
"original_height": raw_img.height,
|
|
|
|
| 482 |
"cropped_height": int(embed_img.height) if self.crop_empty else int(raw_img.height),
|
| 483 |
"resized_width": resized_img.width,
|
| 484 |
"resized_height": resized_img.height,
|
|
|
|
| 485 |
# Tile structure (needed for saliency)
|
| 486 |
"num_tiles": num_tiles,
|
| 487 |
"tile_rows": tile_rows,
|
| 488 |
"tile_cols": tile_cols,
|
| 489 |
"patches_per_tile": patches_per_tile,
|
|
|
|
| 490 |
# Token info (needed for saliency)
|
| 491 |
"num_visual_tokens": num_visual_tokens,
|
| 492 |
"visual_token_indices": visual_indices,
|
| 493 |
"total_tokens": len(full_embedding), # Total tokens in raw embedding
|
|
|
|
| 494 |
# Strategy used (important for paper comparison)
|
| 495 |
"embedding_strategy": self.embedding_strategy,
|
|
|
|
| 496 |
"model_name": getattr(self.embedder, "model_name", None),
|
|
|
|
| 497 |
"crop_empty_enabled": bool(self.crop_empty),
|
| 498 |
"crop_empty_crop_box": (crop_meta or {}).get("crop_box"),
|
| 499 |
"crop_empty_remove_page_number": bool(self.crop_empty_remove_page_number),
|
| 500 |
"crop_empty_percentage_to_remove": float(self.crop_empty_percentage_to_remove),
|
| 501 |
+
"crop_empty_preserve_border_px": int(self.crop_empty_preserve_border_px),
|
| 502 |
+
"crop_empty_uniform_rowcol_std_threshold": float(
|
| 503 |
+
self.crop_empty_uniform_rowcol_std_threshold
|
| 504 |
+
),
|
| 505 |
# Extra metadata (year, district, etc.)
|
| 506 |
**extra_metadata,
|
| 507 |
}
|
| 508 |
+
|
| 509 |
result = {
|
| 510 |
"id": chunk_id,
|
| 511 |
+
"visual_embedding": embedding_for_initial, # "initial" vector in Qdrant
|
| 512 |
"tile_pooled_embedding": embedding_for_pooling, # "mean_pooling" vector in Qdrant
|
| 513 |
"experimental_pooled_embedding": experimental_pooled, # "experimental_pooling" vector in Qdrant
|
| 514 |
"global_pooled_embedding": global_pooling, # "global_pooling" vector in Qdrant
|
|
|
|
| 516 |
"image": raw_img,
|
| 517 |
"resized_image": resized_img,
|
| 518 |
}
|
| 519 |
+
|
| 520 |
# For "all" strategy, include BOTH representations for comparison
|
| 521 |
if self.embedding_strategy == "all":
|
| 522 |
result["extra_vectors"] = {
|
| 523 |
# Standard baseline vectors (for comparison)
|
| 524 |
+
"full_embedding": full_embedding, # All tokens [total, 128]
|
| 525 |
+
"global_pooled": global_pooled, # Global mean [128]
|
| 526 |
# Pooling vectors (already in main result)
|
| 527 |
+
"visual_embedding": visual_embedding, # Visual only [visual, 128]
|
| 528 |
+
"tile_pooled": tile_pooled, # Tile-level [tiles, 128]
|
| 529 |
}
|
| 530 |
+
|
| 531 |
return result
|
| 532 |
+
|
| 533 |
def _upload_batch(self, upload_queue: List[Dict[str, Any]]) -> int:
|
| 534 |
"""Upload batch to Qdrant."""
|
| 535 |
if not upload_queue or not self.indexer:
|
| 536 |
return 0
|
| 537 |
+
|
| 538 |
logger.info(f"📤 Uploading batch of {len(upload_queue)} pages...")
|
| 539 |
+
|
| 540 |
count = self.indexer.upload_batch(
|
| 541 |
upload_queue,
|
| 542 |
delay_between_batches=self.delay_between_uploads,
|
| 543 |
)
|
| 544 |
+
|
| 545 |
return count
|
| 546 |
+
|
| 547 |
def _get_extra_metadata(self, filename: str) -> Dict[str, Any]:
|
| 548 |
"""Get extra metadata for a filename."""
|
| 549 |
if not self.metadata_mapping:
|
| 550 |
return {}
|
| 551 |
+
|
| 552 |
# Normalize filename
|
| 553 |
filename_clean = filename.replace(".pdf", "").replace(".PDF", "").strip().lower()
|
| 554 |
+
|
| 555 |
# Try exact match
|
| 556 |
if filename_clean in self.metadata_mapping:
|
| 557 |
return self.metadata_mapping[filename_clean].copy()
|
| 558 |
+
|
| 559 |
# Try fuzzy match
|
| 560 |
from difflib import SequenceMatcher
|
| 561 |
+
|
| 562 |
best_match = None
|
| 563 |
best_score = 0.0
|
| 564 |
+
|
| 565 |
for known_filename, metadata in self.metadata_mapping.items():
|
| 566 |
score = SequenceMatcher(None, filename_clean, known_filename.lower()).ratio()
|
| 567 |
if score > best_score and score > 0.75:
|
| 568 |
best_score = score
|
| 569 |
best_match = metadata
|
| 570 |
+
|
| 571 |
if best_match:
|
| 572 |
logger.debug(f"Fuzzy matched '{filename}' with score {best_score:.2f}")
|
| 573 |
return best_match.copy()
|
| 574 |
+
|
| 575 |
return {}
|
| 576 |
+
|
| 577 |
def _sanitize_text(self, text: str) -> str:
|
| 578 |
"""Remove invalid Unicode characters."""
|
| 579 |
if not text:
|
| 580 |
return ""
|
| 581 |
return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
|
| 582 |
+
|
| 583 |
@staticmethod
|
| 584 |
def generate_chunk_id(filename: str, page_number: int) -> str:
|
| 585 |
"""Generate deterministic chunk ID."""
|
|
|
|
| 587 |
hash_obj = hashlib.sha256(content.encode())
|
| 588 |
hex_str = hash_obj.hexdigest()[:32]
|
| 589 |
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
| 590 |
+
|
| 591 |
@staticmethod
|
| 592 |
def load_metadata_mapping(json_path: Path) -> Dict[str, Dict[str, Any]]:
|
| 593 |
"""
|
| 594 |
Load metadata mapping from JSON file.
|
| 595 |
+
|
| 596 |
Expected format:
|
| 597 |
{
|
| 598 |
"filenames": {
|
|
|
|
| 600 |
...
|
| 601 |
}
|
| 602 |
}
|
| 603 |
+
|
| 604 |
Or simple format:
|
| 605 |
{
|
| 606 |
"Report Name 2023": {"year": 2023, "source": "Local Government", ...},
|
|
|
|
| 608 |
}
|
| 609 |
"""
|
| 610 |
import json
|
| 611 |
+
|
| 612 |
with open(json_path, "r") as f:
|
| 613 |
data = json.load(f)
|
| 614 |
+
|
| 615 |
# Check if nested under "filenames"
|
| 616 |
if "filenames" in data and isinstance(data["filenames"], dict):
|
| 617 |
mapping = data["filenames"]
|
| 618 |
else:
|
| 619 |
mapping = data
|
| 620 |
+
|
| 621 |
# Normalize keys to lowercase
|
| 622 |
normalized = {}
|
| 623 |
for filename, metadata in mapping.items():
|
| 624 |
key = filename.lower().strip().replace(".pdf", "")
|
| 625 |
normalized[key] = metadata
|
| 626 |
+
|
| 627 |
logger.info(f"📖 Loaded metadata for {len(normalized)} files")
|
| 628 |
return normalized
|
|
|
visual_rag/indexing/qdrant_indexer.py
CHANGED
|
@@ -11,43 +11,61 @@ Features:
|
|
| 11 |
- Configurable payload indexes
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
import time
|
| 15 |
import hashlib
|
| 16 |
import logging
|
| 17 |
-
|
|
|
|
| 18 |
from urllib.parse import urlparse
|
|
|
|
| 19 |
import numpy as np
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
|
| 24 |
class QdrantIndexer:
|
| 25 |
"""
|
| 26 |
Upload visual embeddings to Qdrant.
|
| 27 |
-
|
| 28 |
Works independently - just needs embeddings and metadata.
|
| 29 |
-
|
| 30 |
Args:
|
| 31 |
url: Qdrant server URL
|
| 32 |
api_key: Qdrant API key
|
| 33 |
collection_name: Name of the collection
|
| 34 |
timeout: Request timeout in seconds
|
| 35 |
prefer_grpc: Use gRPC protocol (faster but may have issues)
|
| 36 |
-
|
| 37 |
Example:
|
| 38 |
>>> indexer = QdrantIndexer(
|
| 39 |
... url="https://your-cluster.qdrant.io:6333",
|
| 40 |
... api_key="your-api-key",
|
| 41 |
... collection_name="my_collection",
|
| 42 |
... )
|
| 43 |
-
>>>
|
| 44 |
>>> # Create collection
|
| 45 |
>>> indexer.create_collection()
|
| 46 |
-
>>>
|
| 47 |
>>> # Upload points
|
| 48 |
>>> indexer.upload_batch(points)
|
| 49 |
"""
|
| 50 |
-
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
url: str,
|
|
@@ -57,14 +75,12 @@ class QdrantIndexer:
|
|
| 57 |
prefer_grpc: bool = False,
|
| 58 |
vector_datatype: str = "float32",
|
| 59 |
):
|
| 60 |
-
|
| 61 |
-
from qdrant_client import QdrantClient
|
| 62 |
-
except ImportError:
|
| 63 |
raise ImportError(
|
| 64 |
"Qdrant client not installed. "
|
| 65 |
"Install with: pip install visual-rag-toolkit[qdrant]"
|
| 66 |
)
|
| 67 |
-
|
| 68 |
self.collection_name = collection_name
|
| 69 |
self.timeout = timeout
|
| 70 |
if vector_datatype not in ("float32", "float16"):
|
|
@@ -81,7 +97,7 @@ class QdrantIndexer:
|
|
| 81 |
grpc_port = 6334
|
| 82 |
except Exception:
|
| 83 |
grpc_port = None
|
| 84 |
-
|
| 85 |
def _make_client(use_grpc: bool):
|
| 86 |
return QdrantClient(
|
| 87 |
url=url,
|
|
@@ -102,16 +118,16 @@ class QdrantIndexer:
|
|
| 102 |
self.client = _make_client(False)
|
| 103 |
else:
|
| 104 |
raise
|
| 105 |
-
|
| 106 |
logger.info(f"🔌 Connected to Qdrant: {url}")
|
| 107 |
logger.info(f" Collection: {collection_name}")
|
| 108 |
logger.info(f" Vector datatype: {self.vector_datatype}")
|
| 109 |
-
|
| 110 |
def collection_exists(self) -> bool:
|
| 111 |
"""Check if collection exists."""
|
| 112 |
collections = self.client.get_collections().collections
|
| 113 |
return any(c.name == self.collection_name for c in collections)
|
| 114 |
-
|
| 115 |
def create_collection(
|
| 116 |
self,
|
| 117 |
embedding_dim: int = 128,
|
|
@@ -122,32 +138,22 @@ class QdrantIndexer:
|
|
| 122 |
) -> bool:
|
| 123 |
"""
|
| 124 |
Create collection with multi-vector support.
|
| 125 |
-
|
| 126 |
Creates named vectors:
|
| 127 |
- initial: Full multi-vector embeddings (num_patches × dim)
|
| 128 |
- mean_pooling: Tile-level pooled vectors (num_tiles × dim)
|
| 129 |
- experimental_pooling: Experimental multi-vector pooling (varies by model)
|
| 130 |
- global_pooling: Single vector pooled representation (dim)
|
| 131 |
-
|
| 132 |
Args:
|
| 133 |
embedding_dim: Embedding dimension (128 for ColSmol)
|
| 134 |
force_recreate: Delete and recreate if exists
|
| 135 |
enable_quantization: Enable int8 quantization
|
| 136 |
indexing_threshold: Qdrant optimizer indexing threshold (set 0 to always build ANN indexes)
|
| 137 |
-
|
| 138 |
Returns:
|
| 139 |
True if created, False if already existed
|
| 140 |
"""
|
| 141 |
-
from qdrant_client.http import models
|
| 142 |
-
from qdrant_client.http.models import (
|
| 143 |
-
Distance,
|
| 144 |
-
VectorParams,
|
| 145 |
-
OptimizersConfigDiff,
|
| 146 |
-
HnswConfigDiff,
|
| 147 |
-
ScalarQuantizationConfig,
|
| 148 |
-
ScalarType,
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
if self.collection_exists():
|
| 152 |
if force_recreate:
|
| 153 |
logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
|
|
@@ -155,120 +161,99 @@ class QdrantIndexer:
|
|
| 155 |
else:
|
| 156 |
logger.info(f"✅ Collection already exists: {self.collection_name}")
|
| 157 |
return False
|
| 158 |
-
|
| 159 |
logger.info(f"📦 Creating collection: {self.collection_name}")
|
| 160 |
-
|
| 161 |
# Multi-vector config for ColBERT-style MaxSim
|
| 162 |
-
multivector_config =
|
| 163 |
-
comparator=
|
| 164 |
)
|
| 165 |
-
|
| 166 |
-
#
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
on_disk=True,
|
| 172 |
)
|
| 173 |
-
|
| 174 |
-
# Optional quantization
|
| 175 |
-
quantization_config = None
|
| 176 |
-
if enable_quantization:
|
| 177 |
-
logger.info(" Quantization: ENABLED (int8)")
|
| 178 |
-
quantization_config = ScalarQuantizationConfig(
|
| 179 |
-
type=ScalarType.INT8,
|
| 180 |
-
quantile=0.99,
|
| 181 |
-
always_ram=True,
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
# Vector configs
|
| 185 |
-
datatype = models.Datatype.FLOAT16 if self.vector_datatype == "float16" else models.Datatype.FLOAT32
|
| 186 |
vectors_config = {
|
| 187 |
"initial": VectorParams(
|
| 188 |
size=embedding_dim,
|
| 189 |
distance=Distance.COSINE,
|
| 190 |
on_disk=True,
|
| 191 |
multivector_config=multivector_config,
|
| 192 |
-
hnsw_config=hnsw_config,
|
| 193 |
datatype=datatype,
|
| 194 |
-
quantization_config=quantization_config,
|
| 195 |
),
|
| 196 |
"mean_pooling": VectorParams(
|
| 197 |
size=embedding_dim,
|
| 198 |
distance=Distance.COSINE,
|
| 199 |
-
on_disk=False,
|
| 200 |
multivector_config=multivector_config,
|
| 201 |
-
hnsw_config=hnsw_config,
|
| 202 |
datatype=datatype,
|
| 203 |
-
quantization_config=quantization_config,
|
| 204 |
),
|
| 205 |
"experimental_pooling": VectorParams(
|
| 206 |
size=embedding_dim,
|
| 207 |
distance=Distance.COSINE,
|
| 208 |
-
on_disk=False,
|
| 209 |
multivector_config=multivector_config,
|
| 210 |
-
hnsw_config=hnsw_config,
|
| 211 |
datatype=datatype,
|
| 212 |
-
quantization_config=quantization_config,
|
| 213 |
),
|
| 214 |
"global_pooling": VectorParams(
|
| 215 |
size=embedding_dim,
|
| 216 |
distance=Distance.COSINE,
|
| 217 |
-
on_disk=False,
|
| 218 |
-
hnsw_config=hnsw_config,
|
| 219 |
datatype=datatype,
|
| 220 |
-
quantization_config=quantization_config,
|
| 221 |
),
|
| 222 |
}
|
| 223 |
-
|
| 224 |
-
# Optimizer config for low-RAM clusters
|
| 225 |
-
optimizer_config = OptimizersConfigDiff(
|
| 226 |
-
indexing_threshold=int(indexing_threshold),
|
| 227 |
-
memmap_threshold=0, # Use mmap immediately
|
| 228 |
-
flush_interval_sec=5, # Flush WAL frequently
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
self.client.create_collection(
|
| 232 |
collection_name=self.collection_name,
|
| 233 |
vectors_config=vectors_config,
|
| 234 |
-
optimizers_config=optimizer_config,
|
| 235 |
-
hnsw_config=hnsw_config,
|
| 236 |
)
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
logger.info(f"✅ Collection created: {self.collection_name}")
|
| 239 |
return True
|
| 240 |
-
|
| 241 |
def create_payload_indexes(
|
| 242 |
self,
|
| 243 |
fields: Optional[List[Dict[str, str]]] = None,
|
| 244 |
):
|
| 245 |
"""
|
| 246 |
Create payload indexes for filtering.
|
| 247 |
-
|
| 248 |
Args:
|
| 249 |
fields: List of {field, type} dicts
|
| 250 |
type can be: integer, keyword, bool, float, text
|
| 251 |
"""
|
| 252 |
-
from qdrant_client.http import models
|
| 253 |
-
|
| 254 |
type_mapping = {
|
| 255 |
-
"integer":
|
| 256 |
-
"keyword":
|
| 257 |
-
"bool":
|
| 258 |
-
"float":
|
| 259 |
-
"text":
|
| 260 |
}
|
| 261 |
-
|
| 262 |
if not fields:
|
| 263 |
return
|
| 264 |
-
|
| 265 |
logger.info("📇 Creating payload indexes...")
|
| 266 |
-
|
| 267 |
for field_config in fields:
|
| 268 |
field_name = field_config["field"]
|
| 269 |
field_type_str = field_config.get("type", "keyword")
|
| 270 |
-
field_type = type_mapping.get(field_type_str,
|
| 271 |
-
|
| 272 |
try:
|
| 273 |
self.client.create_payload_index(
|
| 274 |
collection_name=self.collection_name,
|
|
@@ -278,7 +263,7 @@ class QdrantIndexer:
|
|
| 278 |
logger.info(f" ✅ {field_name} ({field_type_str})")
|
| 279 |
except Exception as e:
|
| 280 |
logger.debug(f" Index {field_name} might already exist: {e}")
|
| 281 |
-
|
| 282 |
def upload_batch(
|
| 283 |
self,
|
| 284 |
points: List[Dict[str, Any]],
|
|
@@ -289,7 +274,7 @@ class QdrantIndexer:
|
|
| 289 |
) -> int:
|
| 290 |
"""
|
| 291 |
Upload a batch of points to Qdrant.
|
| 292 |
-
|
| 293 |
Each point should have:
|
| 294 |
- id: Unique point ID (string or UUID)
|
| 295 |
- visual_embedding: Full embedding [num_patches, dim]
|
|
@@ -297,28 +282,28 @@ class QdrantIndexer:
|
|
| 297 |
- experimental_pooled_embedding: Experimental pooled embedding [*, dim]
|
| 298 |
- global_pooled_embedding: Pooled embedding [dim]
|
| 299 |
- metadata: Payload dict
|
| 300 |
-
|
| 301 |
Args:
|
| 302 |
points: List of point dicts
|
| 303 |
max_retries: Retry attempts on failure
|
| 304 |
delay_between_batches: Delay after upload
|
| 305 |
wait: Wait for operation to complete on Qdrant server
|
| 306 |
stop_event: Optional threading.Event used to cancel uploads early
|
| 307 |
-
|
| 308 |
Returns:
|
| 309 |
Number of successfully uploaded points
|
| 310 |
"""
|
| 311 |
-
from qdrant_client.http import models
|
| 312 |
-
|
| 313 |
if not points:
|
| 314 |
return 0
|
| 315 |
|
| 316 |
def _is_cancelled() -> bool:
|
| 317 |
return stop_event is not None and getattr(stop_event, "is_set", lambda: False)()
|
| 318 |
-
|
| 319 |
def _is_payload_too_large_error(e: Exception) -> bool:
|
| 320 |
msg = str(e)
|
| 321 |
-
if ("JSON payload" in msg and "larger than allowed" in msg) or (
|
|
|
|
|
|
|
| 322 |
return True
|
| 323 |
content = getattr(e, "content", None)
|
| 324 |
if content is not None:
|
|
@@ -329,7 +314,9 @@ class QdrantIndexer:
|
|
| 329 |
text = str(content)
|
| 330 |
except Exception:
|
| 331 |
text = ""
|
| 332 |
-
if ("JSON payload" in text and "larger than allowed" in text) or (
|
|
|
|
|
|
|
| 333 |
return True
|
| 334 |
resp = getattr(e, "response", None)
|
| 335 |
if resp is not None:
|
|
@@ -337,7 +324,9 @@ class QdrantIndexer:
|
|
| 337 |
text = str(getattr(resp, "text", "") or "")
|
| 338 |
except Exception:
|
| 339 |
text = ""
|
| 340 |
-
if ("JSON payload" in text and "larger than allowed" in text) or (
|
|
|
|
|
|
|
| 341 |
return True
|
| 342 |
return False
|
| 343 |
|
|
@@ -346,8 +335,8 @@ class QdrantIndexer:
|
|
| 346 |
return val.tolist()
|
| 347 |
return val
|
| 348 |
|
| 349 |
-
def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[
|
| 350 |
-
qdrant_points: List[
|
| 351 |
for p in batch_points:
|
| 352 |
global_pooled = p.get("global_pooled_embedding")
|
| 353 |
if global_pooled is None:
|
|
@@ -355,15 +344,19 @@ class QdrantIndexer:
|
|
| 355 |
global_pooled = tile_pooled.mean(axis=0)
|
| 356 |
global_pooled = np.array(global_pooled, dtype=np.float32).reshape(-1)
|
| 357 |
|
| 358 |
-
initial = np.array(p["visual_embedding"], dtype=np.float32).astype(
|
| 359 |
-
|
| 360 |
-
|
|
|
|
| 361 |
self._np_vector_dtype, copy=False
|
| 362 |
)
|
|
|
|
|
|
|
|
|
|
| 363 |
global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
|
| 364 |
|
| 365 |
qdrant_points.append(
|
| 366 |
-
|
| 367 |
id=p["id"],
|
| 368 |
vector={
|
| 369 |
"initial": _to_list(initial),
|
|
@@ -375,7 +368,7 @@ class QdrantIndexer:
|
|
| 375 |
)
|
| 376 |
)
|
| 377 |
return qdrant_points
|
| 378 |
-
|
| 379 |
# Upload with retry
|
| 380 |
for attempt in range(max_retries):
|
| 381 |
try:
|
|
@@ -421,11 +414,11 @@ class QdrantIndexer:
|
|
| 421 |
if attempt < max_retries - 1:
|
| 422 |
if _is_cancelled():
|
| 423 |
return 0
|
| 424 |
-
time.sleep(2
|
| 425 |
-
|
| 426 |
logger.error(f"❌ Upload failed after {max_retries} attempts")
|
| 427 |
return 0
|
| 428 |
-
|
| 429 |
def check_exists(self, chunk_id: str) -> bool:
|
| 430 |
"""Check if a point already exists."""
|
| 431 |
try:
|
|
@@ -438,50 +431,78 @@ class QdrantIndexer:
|
|
| 438 |
return len(result) > 0
|
| 439 |
except Exception:
|
| 440 |
return False
|
| 441 |
-
|
| 442 |
def get_existing_ids(self, filename: str) -> Set[str]:
|
| 443 |
-
"""Get all point IDs for a specific file.
|
| 444 |
-
|
| 445 |
-
|
|
|
|
|
|
|
| 446 |
existing_ids = set()
|
| 447 |
offset = None
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
return existing_ids
|
| 471 |
-
|
| 472 |
def get_collection_info(self) -> Optional[Dict[str, Any]]:
|
| 473 |
"""Get collection statistics."""
|
| 474 |
try:
|
| 475 |
info = self.client.get_collection(self.collection_name)
|
| 476 |
-
|
| 477 |
status = info.status
|
| 478 |
if hasattr(status, "value"):
|
| 479 |
status = status.value
|
| 480 |
-
|
| 481 |
indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
|
| 482 |
if isinstance(indexed_count, dict):
|
| 483 |
indexed_count = sum(indexed_count.values())
|
| 484 |
-
|
| 485 |
return {
|
| 486 |
"status": str(status),
|
| 487 |
"points_count": getattr(info, "points_count", 0),
|
|
@@ -490,12 +511,12 @@ class QdrantIndexer:
|
|
| 490 |
except Exception as e:
|
| 491 |
logger.warning(f"Could not get collection info: {e}")
|
| 492 |
return None
|
| 493 |
-
|
| 494 |
@staticmethod
|
| 495 |
def generate_point_id(filename: str, page_number: int) -> str:
|
| 496 |
"""
|
| 497 |
Generate deterministic point ID from filename and page.
|
| 498 |
-
|
| 499 |
Returns a valid UUID string.
|
| 500 |
"""
|
| 501 |
content = f"{filename}:page:{page_number}"
|
|
@@ -503,5 +524,3 @@ class QdrantIndexer:
|
|
| 503 |
hex_str = hash_obj.hexdigest()[:32]
|
| 504 |
# Format as UUID
|
| 505 |
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
| 506 |
-
|
| 507 |
-
|
|
|
|
| 11 |
- Configurable payload indexes
|
| 12 |
"""
|
| 13 |
|
|
|
|
| 14 |
import hashlib
|
| 15 |
import logging
|
| 16 |
+
import time
|
| 17 |
+
from typing import Any, Dict, List, Optional, Set
|
| 18 |
from urllib.parse import urlparse
|
| 19 |
+
|
| 20 |
import numpy as np
|
| 21 |
|
| 22 |
+
try:
|
| 23 |
+
from qdrant_client import QdrantClient
|
| 24 |
+
from qdrant_client.http import models as qdrant_models
|
| 25 |
+
from qdrant_client.http.models import Distance, VectorParams
|
| 26 |
+
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
| 27 |
+
|
| 28 |
+
QDRANT_AVAILABLE = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
QDRANT_AVAILABLE = False
|
| 31 |
+
QdrantClient = None
|
| 32 |
+
qdrant_models = None
|
| 33 |
+
Distance = None
|
| 34 |
+
VectorParams = None
|
| 35 |
+
FieldCondition = None
|
| 36 |
+
Filter = None
|
| 37 |
+
MatchValue = None
|
| 38 |
+
|
| 39 |
logger = logging.getLogger(__name__)
|
| 40 |
|
| 41 |
|
| 42 |
class QdrantIndexer:
|
| 43 |
"""
|
| 44 |
Upload visual embeddings to Qdrant.
|
| 45 |
+
|
| 46 |
Works independently - just needs embeddings and metadata.
|
| 47 |
+
|
| 48 |
Args:
|
| 49 |
url: Qdrant server URL
|
| 50 |
api_key: Qdrant API key
|
| 51 |
collection_name: Name of the collection
|
| 52 |
timeout: Request timeout in seconds
|
| 53 |
prefer_grpc: Use gRPC protocol (faster but may have issues)
|
| 54 |
+
|
| 55 |
Example:
|
| 56 |
>>> indexer = QdrantIndexer(
|
| 57 |
... url="https://your-cluster.qdrant.io:6333",
|
| 58 |
... api_key="your-api-key",
|
| 59 |
... collection_name="my_collection",
|
| 60 |
... )
|
| 61 |
+
>>>
|
| 62 |
>>> # Create collection
|
| 63 |
>>> indexer.create_collection()
|
| 64 |
+
>>>
|
| 65 |
>>> # Upload points
|
| 66 |
>>> indexer.upload_batch(points)
|
| 67 |
"""
|
| 68 |
+
|
| 69 |
def __init__(
|
| 70 |
self,
|
| 71 |
url: str,
|
|
|
|
| 75 |
prefer_grpc: bool = False,
|
| 76 |
vector_datatype: str = "float32",
|
| 77 |
):
|
| 78 |
+
if not QDRANT_AVAILABLE:
|
|
|
|
|
|
|
| 79 |
raise ImportError(
|
| 80 |
"Qdrant client not installed. "
|
| 81 |
"Install with: pip install visual-rag-toolkit[qdrant]"
|
| 82 |
)
|
| 83 |
+
|
| 84 |
self.collection_name = collection_name
|
| 85 |
self.timeout = timeout
|
| 86 |
if vector_datatype not in ("float32", "float16"):
|
|
|
|
| 97 |
grpc_port = 6334
|
| 98 |
except Exception:
|
| 99 |
grpc_port = None
|
| 100 |
+
|
| 101 |
def _make_client(use_grpc: bool):
|
| 102 |
return QdrantClient(
|
| 103 |
url=url,
|
|
|
|
| 118 |
self.client = _make_client(False)
|
| 119 |
else:
|
| 120 |
raise
|
| 121 |
+
|
| 122 |
logger.info(f"🔌 Connected to Qdrant: {url}")
|
| 123 |
logger.info(f" Collection: {collection_name}")
|
| 124 |
logger.info(f" Vector datatype: {self.vector_datatype}")
|
| 125 |
+
|
| 126 |
def collection_exists(self) -> bool:
|
| 127 |
"""Check if collection exists."""
|
| 128 |
collections = self.client.get_collections().collections
|
| 129 |
return any(c.name == self.collection_name for c in collections)
|
| 130 |
+
|
| 131 |
def create_collection(
|
| 132 |
self,
|
| 133 |
embedding_dim: int = 128,
|
|
|
|
| 138 |
) -> bool:
|
| 139 |
"""
|
| 140 |
Create collection with multi-vector support.
|
| 141 |
+
|
| 142 |
Creates named vectors:
|
| 143 |
- initial: Full multi-vector embeddings (num_patches × dim)
|
| 144 |
- mean_pooling: Tile-level pooled vectors (num_tiles × dim)
|
| 145 |
- experimental_pooling: Experimental multi-vector pooling (varies by model)
|
| 146 |
- global_pooling: Single vector pooled representation (dim)
|
| 147 |
+
|
| 148 |
Args:
|
| 149 |
embedding_dim: Embedding dimension (128 for ColSmol)
|
| 150 |
force_recreate: Delete and recreate if exists
|
| 151 |
enable_quantization: Enable int8 quantization
|
| 152 |
indexing_threshold: Qdrant optimizer indexing threshold (set 0 to always build ANN indexes)
|
| 153 |
+
|
| 154 |
Returns:
|
| 155 |
True if created, False if already existed
|
| 156 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if self.collection_exists():
|
| 158 |
if force_recreate:
|
| 159 |
logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
|
|
|
|
| 161 |
else:
|
| 162 |
logger.info(f"✅ Collection already exists: {self.collection_name}")
|
| 163 |
return False
|
| 164 |
+
|
| 165 |
logger.info(f"📦 Creating collection: {self.collection_name}")
|
| 166 |
+
|
| 167 |
# Multi-vector config for ColBERT-style MaxSim
|
| 168 |
+
multivector_config = qdrant_models.MultiVectorConfig(
|
| 169 |
+
comparator=qdrant_models.MultiVectorComparator.MAX_SIM
|
| 170 |
)
|
| 171 |
+
|
| 172 |
+
# Vector configs - simplified for compatibility
|
| 173 |
+
datatype = (
|
| 174 |
+
qdrant_models.Datatype.FLOAT16
|
| 175 |
+
if self.vector_datatype == "float16"
|
| 176 |
+
else qdrant_models.Datatype.FLOAT32
|
|
|
|
| 177 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
vectors_config = {
|
| 179 |
"initial": VectorParams(
|
| 180 |
size=embedding_dim,
|
| 181 |
distance=Distance.COSINE,
|
| 182 |
on_disk=True,
|
| 183 |
multivector_config=multivector_config,
|
|
|
|
| 184 |
datatype=datatype,
|
|
|
|
| 185 |
),
|
| 186 |
"mean_pooling": VectorParams(
|
| 187 |
size=embedding_dim,
|
| 188 |
distance=Distance.COSINE,
|
| 189 |
+
on_disk=False,
|
| 190 |
multivector_config=multivector_config,
|
|
|
|
| 191 |
datatype=datatype,
|
|
|
|
| 192 |
),
|
| 193 |
"experimental_pooling": VectorParams(
|
| 194 |
size=embedding_dim,
|
| 195 |
distance=Distance.COSINE,
|
| 196 |
+
on_disk=False,
|
| 197 |
multivector_config=multivector_config,
|
|
|
|
| 198 |
datatype=datatype,
|
|
|
|
| 199 |
),
|
| 200 |
"global_pooling": VectorParams(
|
| 201 |
size=embedding_dim,
|
| 202 |
distance=Distance.COSINE,
|
| 203 |
+
on_disk=False,
|
|
|
|
| 204 |
datatype=datatype,
|
|
|
|
| 205 |
),
|
| 206 |
}
|
| 207 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
self.client.create_collection(
|
| 209 |
collection_name=self.collection_name,
|
| 210 |
vectors_config=vectors_config,
|
|
|
|
|
|
|
| 211 |
)
|
| 212 |
+
|
| 213 |
+
# Create required payload index for skip_existing functionality
|
| 214 |
+
# This index is needed for filtering by filename when checking existing docs
|
| 215 |
+
try:
|
| 216 |
+
self.client.create_payload_index(
|
| 217 |
+
collection_name=self.collection_name,
|
| 218 |
+
field_name="filename",
|
| 219 |
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
| 220 |
+
)
|
| 221 |
+
logger.info(" 📇 Created payload index: filename")
|
| 222 |
+
except Exception as e:
|
| 223 |
+
logger.warning(f" ⚠️ Could not create filename index: {e}")
|
| 224 |
+
|
| 225 |
logger.info(f"✅ Collection created: {self.collection_name}")
|
| 226 |
return True
|
| 227 |
+
|
| 228 |
def create_payload_indexes(
|
| 229 |
self,
|
| 230 |
fields: Optional[List[Dict[str, str]]] = None,
|
| 231 |
):
|
| 232 |
"""
|
| 233 |
Create payload indexes for filtering.
|
| 234 |
+
|
| 235 |
Args:
|
| 236 |
fields: List of {field, type} dicts
|
| 237 |
type can be: integer, keyword, bool, float, text
|
| 238 |
"""
|
|
|
|
|
|
|
| 239 |
type_mapping = {
|
| 240 |
+
"integer": qdrant_models.PayloadSchemaType.INTEGER,
|
| 241 |
+
"keyword": qdrant_models.PayloadSchemaType.KEYWORD,
|
| 242 |
+
"bool": qdrant_models.PayloadSchemaType.BOOL,
|
| 243 |
+
"float": qdrant_models.PayloadSchemaType.FLOAT,
|
| 244 |
+
"text": qdrant_models.PayloadSchemaType.TEXT,
|
| 245 |
}
|
| 246 |
+
|
| 247 |
if not fields:
|
| 248 |
return
|
| 249 |
+
|
| 250 |
logger.info("📇 Creating payload indexes...")
|
| 251 |
+
|
| 252 |
for field_config in fields:
|
| 253 |
field_name = field_config["field"]
|
| 254 |
field_type_str = field_config.get("type", "keyword")
|
| 255 |
+
field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
|
| 256 |
+
|
| 257 |
try:
|
| 258 |
self.client.create_payload_index(
|
| 259 |
collection_name=self.collection_name,
|
|
|
|
| 263 |
logger.info(f" ✅ {field_name} ({field_type_str})")
|
| 264 |
except Exception as e:
|
| 265 |
logger.debug(f" Index {field_name} might already exist: {e}")
|
| 266 |
+
|
| 267 |
def upload_batch(
|
| 268 |
self,
|
| 269 |
points: List[Dict[str, Any]],
|
|
|
|
| 274 |
) -> int:
|
| 275 |
"""
|
| 276 |
Upload a batch of points to Qdrant.
|
| 277 |
+
|
| 278 |
Each point should have:
|
| 279 |
- id: Unique point ID (string or UUID)
|
| 280 |
- visual_embedding: Full embedding [num_patches, dim]
|
|
|
|
| 282 |
- experimental_pooled_embedding: Experimental pooled embedding [*, dim]
|
| 283 |
- global_pooled_embedding: Pooled embedding [dim]
|
| 284 |
- metadata: Payload dict
|
| 285 |
+
|
| 286 |
Args:
|
| 287 |
points: List of point dicts
|
| 288 |
max_retries: Retry attempts on failure
|
| 289 |
delay_between_batches: Delay after upload
|
| 290 |
wait: Wait for operation to complete on Qdrant server
|
| 291 |
stop_event: Optional threading.Event used to cancel uploads early
|
| 292 |
+
|
| 293 |
Returns:
|
| 294 |
Number of successfully uploaded points
|
| 295 |
"""
|
|
|
|
|
|
|
| 296 |
if not points:
|
| 297 |
return 0
|
| 298 |
|
| 299 |
def _is_cancelled() -> bool:
|
| 300 |
return stop_event is not None and getattr(stop_event, "is_set", lambda: False)()
|
| 301 |
+
|
| 302 |
def _is_payload_too_large_error(e: Exception) -> bool:
|
| 303 |
msg = str(e)
|
| 304 |
+
if ("JSON payload" in msg and "larger than allowed" in msg) or (
|
| 305 |
+
"Payload error:" in msg and "limit:" in msg
|
| 306 |
+
):
|
| 307 |
return True
|
| 308 |
content = getattr(e, "content", None)
|
| 309 |
if content is not None:
|
|
|
|
| 314 |
text = str(content)
|
| 315 |
except Exception:
|
| 316 |
text = ""
|
| 317 |
+
if ("JSON payload" in text and "larger than allowed" in text) or (
|
| 318 |
+
"Payload error" in text and "limit" in text
|
| 319 |
+
):
|
| 320 |
return True
|
| 321 |
resp = getattr(e, "response", None)
|
| 322 |
if resp is not None:
|
|
|
|
| 324 |
text = str(getattr(resp, "text", "") or "")
|
| 325 |
except Exception:
|
| 326 |
text = ""
|
| 327 |
+
if ("JSON payload" in text and "larger than allowed" in text) or (
|
| 328 |
+
"Payload error" in text and "limit" in text
|
| 329 |
+
):
|
| 330 |
return True
|
| 331 |
return False
|
| 332 |
|
|
|
|
| 335 |
return val.tolist()
|
| 336 |
return val
|
| 337 |
|
| 338 |
+
def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[qdrant_models.PointStruct]:
|
| 339 |
+
qdrant_points: List[qdrant_models.PointStruct] = []
|
| 340 |
for p in batch_points:
|
| 341 |
global_pooled = p.get("global_pooled_embedding")
|
| 342 |
if global_pooled is None:
|
|
|
|
| 344 |
global_pooled = tile_pooled.mean(axis=0)
|
| 345 |
global_pooled = np.array(global_pooled, dtype=np.float32).reshape(-1)
|
| 346 |
|
| 347 |
+
initial = np.array(p["visual_embedding"], dtype=np.float32).astype(
|
| 348 |
+
self._np_vector_dtype, copy=False
|
| 349 |
+
)
|
| 350 |
+
mean_pooling = np.array(p["tile_pooled_embedding"], dtype=np.float32).astype(
|
| 351 |
self._np_vector_dtype, copy=False
|
| 352 |
)
|
| 353 |
+
experimental_pooling = np.array(
|
| 354 |
+
p["experimental_pooled_embedding"], dtype=np.float32
|
| 355 |
+
).astype(self._np_vector_dtype, copy=False)
|
| 356 |
global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
|
| 357 |
|
| 358 |
qdrant_points.append(
|
| 359 |
+
qdrant_models.PointStruct(
|
| 360 |
id=p["id"],
|
| 361 |
vector={
|
| 362 |
"initial": _to_list(initial),
|
|
|
|
| 368 |
)
|
| 369 |
)
|
| 370 |
return qdrant_points
|
| 371 |
+
|
| 372 |
# Upload with retry
|
| 373 |
for attempt in range(max_retries):
|
| 374 |
try:
|
|
|
|
| 414 |
if attempt < max_retries - 1:
|
| 415 |
if _is_cancelled():
|
| 416 |
return 0
|
| 417 |
+
time.sleep(2**attempt) # Exponential backoff
|
| 418 |
+
|
| 419 |
logger.error(f"❌ Upload failed after {max_retries} attempts")
|
| 420 |
return 0
|
| 421 |
+
|
| 422 |
def check_exists(self, chunk_id: str) -> bool:
|
| 423 |
"""Check if a point already exists."""
|
| 424 |
try:
|
|
|
|
| 431 |
return len(result) > 0
|
| 432 |
except Exception:
|
| 433 |
return False
|
| 434 |
+
|
| 435 |
def get_existing_ids(self, filename: str) -> Set[str]:
|
| 436 |
+
"""Get all point IDs for a specific file.
|
| 437 |
+
|
| 438 |
+
Requires a payload index on 'filename' field. If the index doesn't exist,
|
| 439 |
+
this method will attempt to create it automatically.
|
| 440 |
+
"""
|
| 441 |
existing_ids = set()
|
| 442 |
offset = None
|
| 443 |
+
|
| 444 |
+
try:
|
| 445 |
+
while True:
|
| 446 |
+
results = self.client.scroll(
|
| 447 |
+
collection_name=self.collection_name,
|
| 448 |
+
scroll_filter=Filter(
|
| 449 |
+
must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
|
| 450 |
+
),
|
| 451 |
+
limit=100,
|
| 452 |
+
offset=offset,
|
| 453 |
+
with_payload=["page_number"],
|
| 454 |
+
with_vectors=False,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
points, next_offset = results
|
| 458 |
+
|
| 459 |
+
for point in points:
|
| 460 |
+
existing_ids.add(str(point.id))
|
| 461 |
+
|
| 462 |
+
if next_offset is None or len(points) == 0:
|
| 463 |
+
break
|
| 464 |
+
offset = next_offset
|
| 465 |
+
|
| 466 |
+
except Exception as e:
|
| 467 |
+
error_msg = str(e).lower()
|
| 468 |
+
if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
|
| 469 |
+
# Missing payload index - try to create it
|
| 470 |
+
logger.warning(
|
| 471 |
+
"⚠️ Missing 'filename' payload index. Creating it now... "
|
| 472 |
+
"(skip_existing requires this index for filtering)"
|
| 473 |
+
)
|
| 474 |
+
try:
|
| 475 |
+
self.client.create_payload_index(
|
| 476 |
+
collection_name=self.collection_name,
|
| 477 |
+
field_name="filename",
|
| 478 |
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
| 479 |
+
)
|
| 480 |
+
logger.info(" ✅ Created 'filename' index. Retrying query...")
|
| 481 |
+
# Retry the query
|
| 482 |
+
return self.get_existing_ids(filename)
|
| 483 |
+
except Exception as idx_err:
|
| 484 |
+
logger.warning(f" ❌ Could not create index: {idx_err}")
|
| 485 |
+
logger.warning(" Returning empty set - all pages will be processed")
|
| 486 |
+
return set()
|
| 487 |
+
else:
|
| 488 |
+
logger.warning(f"⚠️ Error checking existing IDs: {e}")
|
| 489 |
+
return set()
|
| 490 |
+
|
| 491 |
return existing_ids
|
| 492 |
+
|
| 493 |
def get_collection_info(self) -> Optional[Dict[str, Any]]:
|
| 494 |
"""Get collection statistics."""
|
| 495 |
try:
|
| 496 |
info = self.client.get_collection(self.collection_name)
|
| 497 |
+
|
| 498 |
status = info.status
|
| 499 |
if hasattr(status, "value"):
|
| 500 |
status = status.value
|
| 501 |
+
|
| 502 |
indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
|
| 503 |
if isinstance(indexed_count, dict):
|
| 504 |
indexed_count = sum(indexed_count.values())
|
| 505 |
+
|
| 506 |
return {
|
| 507 |
"status": str(status),
|
| 508 |
"points_count": getattr(info, "points_count", 0),
|
|
|
|
| 511 |
except Exception as e:
|
| 512 |
logger.warning(f"Could not get collection info: {e}")
|
| 513 |
return None
|
| 514 |
+
|
| 515 |
@staticmethod
|
| 516 |
def generate_point_id(filename: str, page_number: int) -> str:
|
| 517 |
"""
|
| 518 |
Generate deterministic point ID from filename and page.
|
| 519 |
+
|
| 520 |
Returns a valid UUID string.
|
| 521 |
"""
|
| 522 |
content = f"{filename}:page:{page_number}"
|
|
|
|
| 524 |
hex_str = hash_obj.hexdigest()[:32]
|
| 525 |
# Format as UUID
|
| 526 |
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
|
|
|
|
|
visual_rag/preprocessing/__init__.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
|
| 2 |
|
| 3 |
__all__ = ["CropEmptyConfig", "crop_empty"]
|
| 4 |
-
|
| 5 |
-
|
|
|
|
| 1 |
from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
|
| 2 |
|
| 3 |
__all__ = ["CropEmptyConfig", "crop_empty"]
|
|
|
|
|
|
visual_rag/preprocessing/crop_empty.py
CHANGED
|
@@ -20,7 +20,9 @@ class CropEmptyConfig:
|
|
| 20 |
uniform_rowcol_std_threshold: float = 0.0
|
| 21 |
|
| 22 |
|
| 23 |
-
def crop_empty(
|
|
|
|
|
|
|
| 24 |
img = image.convert("RGB")
|
| 25 |
arr = np.array(img)
|
| 26 |
intensity = arr.mean(axis=2)
|
|
@@ -31,7 +33,9 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
|
|
| 31 |
pixels = intensity[i, :] if axis == 0 else intensity[:, i]
|
| 32 |
white = float(np.mean(pixels > config.color_threshold))
|
| 33 |
non_white = 1.0 - white
|
| 34 |
-
if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
|
|
|
|
|
|
|
| 35 |
continue
|
| 36 |
if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
|
| 37 |
return int(i)
|
|
@@ -43,7 +47,9 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
|
|
| 43 |
pixels = intensity[i, :] if axis == 0 else intensity[:, i]
|
| 44 |
white = float(np.mean(pixels > config.color_threshold))
|
| 45 |
non_white = 1.0 - white
|
| 46 |
-
if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
|
|
|
|
|
|
|
| 47 |
continue
|
| 48 |
if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
|
| 49 |
return int(i + 1)
|
|
@@ -53,8 +59,12 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
|
|
| 53 |
left = _find_border_start(1, min_content_density_threshold=float(config.content_density_sides))
|
| 54 |
right = _find_border_end(1, min_content_density_threshold=float(config.content_density_sides))
|
| 55 |
|
| 56 |
-
main_text_end = _find_border_end(
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
bottom = main_text_end if config.remove_page_number else last_content_end
|
| 59 |
|
| 60 |
width, height = img.size
|
|
@@ -108,5 +118,3 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
|
|
| 108 |
"uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
|
| 109 |
},
|
| 110 |
}
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 20 |
uniform_rowcol_std_threshold: float = 0.0
|
| 21 |
|
| 22 |
|
| 23 |
+
def crop_empty(
|
| 24 |
+
image: Image.Image, *, config: CropEmptyConfig
|
| 25 |
+
) -> Tuple[Image.Image, Dict[str, Any]]:
|
| 26 |
img = image.convert("RGB")
|
| 27 |
arr = np.array(img)
|
| 28 |
intensity = arr.mean(axis=2)
|
|
|
|
| 33 |
pixels = intensity[i, :] if axis == 0 else intensity[:, i]
|
| 34 |
white = float(np.mean(pixels > config.color_threshold))
|
| 35 |
non_white = 1.0 - white
|
| 36 |
+
if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
|
| 37 |
+
config.uniform_rowcol_std_threshold
|
| 38 |
+
):
|
| 39 |
continue
|
| 40 |
if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
|
| 41 |
return int(i)
|
|
|
|
| 47 |
pixels = intensity[i, :] if axis == 0 else intensity[:, i]
|
| 48 |
white = float(np.mean(pixels > config.color_threshold))
|
| 49 |
non_white = 1.0 - white
|
| 50 |
+
if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
|
| 51 |
+
config.uniform_rowcol_std_threshold
|
| 52 |
+
):
|
| 53 |
continue
|
| 54 |
if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
|
| 55 |
return int(i + 1)
|
|
|
|
| 59 |
left = _find_border_start(1, min_content_density_threshold=float(config.content_density_sides))
|
| 60 |
right = _find_border_end(1, min_content_density_threshold=float(config.content_density_sides))
|
| 61 |
|
| 62 |
+
main_text_end = _find_border_end(
|
| 63 |
+
0, min_content_density_threshold=float(config.content_density_main_text)
|
| 64 |
+
)
|
| 65 |
+
last_content_end = _find_border_end(
|
| 66 |
+
0, min_content_density_threshold=float(config.content_density_any)
|
| 67 |
+
)
|
| 68 |
bottom = main_text_end if config.remove_page_number else last_content_end
|
| 69 |
|
| 70 |
width, height = img.size
|
|
|
|
| 118 |
"uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
|
| 119 |
},
|
| 120 |
}
|
|
|
|
|
|
visual_rag/qdrant_admin.py
CHANGED
|
@@ -33,9 +33,16 @@ def _resolve_qdrant_connection(
|
|
| 33 |
import os
|
| 34 |
|
| 35 |
_maybe_load_dotenv()
|
| 36 |
-
resolved_url =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
if not resolved_url:
|
| 38 |
-
raise ValueError(
|
|
|
|
|
|
|
| 39 |
resolved_key = (
|
| 40 |
api_key
|
| 41 |
or os.getenv("SIGIR_QDRANT_KEY")
|
|
@@ -105,7 +112,11 @@ class QdrantAdmin:
|
|
| 105 |
from qdrant_client.http import models as m
|
| 106 |
|
| 107 |
hnsw_diff = m.HnswConfigDiff(**hnsw_config) if isinstance(hnsw_config, dict) else None
|
| 108 |
-
params_diff =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
if hnsw_diff is None and params_diff is None:
|
| 110 |
raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")
|
| 111 |
return bool(
|
|
@@ -143,7 +154,9 @@ class QdrantAdmin:
|
|
| 143 |
|
| 144 |
missing = [str(k) for k in (vectors or {}).keys() if existing and str(k) not in existing]
|
| 145 |
if missing:
|
| 146 |
-
raise ValueError(
|
|
|
|
|
|
|
| 147 |
|
| 148 |
ok = True
|
| 149 |
for name, cfg in (vectors or {}).items():
|
|
@@ -158,13 +171,16 @@ class QdrantAdmin:
|
|
| 158 |
)
|
| 159 |
}
|
| 160 |
|
| 161 |
-
ok =
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
)
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
return ok
|
| 170 |
|
|
@@ -192,7 +208,9 @@ class QdrantAdmin:
|
|
| 192 |
vectors[str(vname)] = {"on_disk": True, "hnsw_config": {"on_disk": True}}
|
| 193 |
|
| 194 |
if vectors:
|
| 195 |
-
self.modify_collection_vector_config(
|
|
|
|
|
|
|
| 196 |
|
| 197 |
self.modify_collection_config(
|
| 198 |
collection_name=collection_name,
|
|
@@ -202,4 +220,3 @@ class QdrantAdmin:
|
|
| 202 |
)
|
| 203 |
|
| 204 |
return self.get_collection_info(collection_name=collection_name)
|
| 205 |
-
|
|
|
|
| 33 |
import os
|
| 34 |
|
| 35 |
_maybe_load_dotenv()
|
| 36 |
+
resolved_url = (
|
| 37 |
+
url
|
| 38 |
+
or os.getenv("SIGIR_QDRANT_URL")
|
| 39 |
+
or os.getenv("DEST_QDRANT_URL")
|
| 40 |
+
or os.getenv("QDRANT_URL")
|
| 41 |
+
)
|
| 42 |
if not resolved_url:
|
| 43 |
+
raise ValueError(
|
| 44 |
+
"Qdrant URL not set (pass url= or set SIGIR_QDRANT_URL/DEST_QDRANT_URL/QDRANT_URL)."
|
| 45 |
+
)
|
| 46 |
resolved_key = (
|
| 47 |
api_key
|
| 48 |
or os.getenv("SIGIR_QDRANT_KEY")
|
|
|
|
| 112 |
from qdrant_client.http import models as m
|
| 113 |
|
| 114 |
hnsw_diff = m.HnswConfigDiff(**hnsw_config) if isinstance(hnsw_config, dict) else None
|
| 115 |
+
params_diff = (
|
| 116 |
+
m.CollectionParamsDiff(**collection_params)
|
| 117 |
+
if isinstance(collection_params, dict)
|
| 118 |
+
else None
|
| 119 |
+
)
|
| 120 |
if hnsw_diff is None and params_diff is None:
|
| 121 |
raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")
|
| 122 |
return bool(
|
|
|
|
| 154 |
|
| 155 |
missing = [str(k) for k in (vectors or {}).keys() if existing and str(k) not in existing]
|
| 156 |
if missing:
|
| 157 |
+
raise ValueError(
|
| 158 |
+
f"Vectors do not exist in collection '{collection_name}': {missing}. Existing: {sorted(existing)}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
ok = True
|
| 162 |
for name, cfg in (vectors or {}).items():
|
|
|
|
| 171 |
)
|
| 172 |
}
|
| 173 |
|
| 174 |
+
ok = (
|
| 175 |
+
bool(
|
| 176 |
+
self.client.update_collection(
|
| 177 |
+
collection_name=collection_name,
|
| 178 |
+
vectors_config=vectors_diff,
|
| 179 |
+
timeout=int(timeout) if timeout is not None else None,
|
| 180 |
+
)
|
| 181 |
)
|
| 182 |
+
and ok
|
| 183 |
+
)
|
| 184 |
|
| 185 |
return ok
|
| 186 |
|
|
|
|
| 208 |
vectors[str(vname)] = {"on_disk": True, "hnsw_config": {"on_disk": True}}
|
| 209 |
|
| 210 |
if vectors:
|
| 211 |
+
self.modify_collection_vector_config(
|
| 212 |
+
collection_name=collection_name, vectors=vectors, timeout=timeout
|
| 213 |
+
)
|
| 214 |
|
| 215 |
self.modify_collection_config(
|
| 216 |
collection_name=collection_name,
|
|
|
|
| 220 |
)
|
| 221 |
|
| 222 |
return self.get_collection_info(collection_name=collection_name)
|
|
|
visual_rag/retrieval/__init__.py
CHANGED
|
@@ -6,10 +6,10 @@ Components:
|
|
| 6 |
- SingleStageRetriever: Direct multi-vector or pooled search
|
| 7 |
"""
|
| 8 |
|
| 9 |
-
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
| 10 |
-
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
| 11 |
from visual_rag.retrieval.multi_vector import MultiVectorRetriever
|
|
|
|
| 12 |
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
|
|
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
"TwoStageRetriever",
|
|
|
|
| 6 |
- SingleStageRetriever: Direct multi-vector or pooled search
|
| 7 |
"""
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from visual_rag.retrieval.multi_vector import MultiVectorRetriever
|
| 10 |
+
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
| 11 |
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
| 12 |
+
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
"TwoStageRetriever",
|
visual_rag/retrieval/multi_vector.py
CHANGED
|
@@ -2,18 +2,35 @@ import os
|
|
| 2 |
from typing import Any, Dict, List, Optional
|
| 3 |
from urllib.parse import urlparse
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
| 6 |
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
| 7 |
-
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
| 8 |
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class MultiVectorRetriever:
|
| 12 |
@staticmethod
|
| 13 |
def _maybe_load_dotenv() -> None:
|
| 14 |
-
|
| 15 |
-
from dotenv import load_dotenv
|
| 16 |
-
except ImportError:
|
| 17 |
return
|
| 18 |
if os.path.exists(".env"):
|
| 19 |
load_dotenv(".env")
|
|
@@ -33,83 +50,84 @@ class MultiVectorRetriever:
|
|
| 33 |
):
|
| 34 |
if qdrant_client is None:
|
| 35 |
self._maybe_load_dotenv()
|
| 36 |
-
|
| 37 |
-
from qdrant_client import QdrantClient
|
| 38 |
-
except ImportError as e:
|
| 39 |
raise ImportError(
|
| 40 |
"Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
|
| 41 |
-
)
|
| 42 |
|
| 43 |
qdrant_url = (
|
| 44 |
qdrant_url
|
| 45 |
-
or os.getenv("SIGIR_QDRANT_URL")
|
| 46 |
-
or os.getenv("DEST_QDRANT_URL")
|
| 47 |
or os.getenv("QDRANT_URL")
|
|
|
|
| 48 |
)
|
| 49 |
if not qdrant_url:
|
| 50 |
raise ValueError(
|
| 51 |
-
"QDRANT_URL is required (pass qdrant_url or set env var).
|
| 52 |
-
"You can also set DEST_QDRANT_URL to override."
|
| 53 |
)
|
| 54 |
|
| 55 |
qdrant_api_key = (
|
| 56 |
qdrant_api_key
|
| 57 |
-
or os.getenv("SIGIR_QDRANT_KEY")
|
| 58 |
-
or os.getenv("SIGIR_QDRANT_API_KEY")
|
| 59 |
-
or os.getenv("DEST_QDRANT_API_KEY")
|
| 60 |
or os.getenv("QDRANT_API_KEY")
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
grpc_port = None
|
| 64 |
if prefer_grpc:
|
| 65 |
try:
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
grpc_port = 6334
|
| 68 |
except Exception:
|
| 69 |
-
|
|
|
|
| 70 |
def _make_client(use_grpc: bool):
|
| 71 |
return QdrantClient(
|
| 72 |
url=qdrant_url,
|
| 73 |
api_key=qdrant_api_key,
|
|
|
|
| 74 |
prefer_grpc=bool(use_grpc),
|
| 75 |
grpc_port=grpc_port,
|
| 76 |
-
timeout=int(request_timeout),
|
| 77 |
check_compatibility=False,
|
| 78 |
)
|
| 79 |
|
| 80 |
-
|
| 81 |
if prefer_grpc:
|
| 82 |
try:
|
| 83 |
-
_ =
|
| 84 |
except Exception as e:
|
| 85 |
msg = str(e)
|
| 86 |
if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
|
| 87 |
-
|
| 88 |
else:
|
| 89 |
raise
|
|
|
|
| 90 |
|
| 91 |
self.client = qdrant_client
|
| 92 |
self.collection_name = collection_name
|
|
|
|
| 93 |
self.embedder = embedder or VisualEmbedder(model_name=model_name)
|
| 94 |
|
| 95 |
self._two_stage = TwoStageRetriever(
|
| 96 |
-
|
| 97 |
-
collection_name=
|
| 98 |
-
request_timeout=
|
| 99 |
-
max_retries=
|
| 100 |
-
retry_sleep=
|
| 101 |
)
|
| 102 |
self._three_stage = ThreeStageRetriever(
|
| 103 |
-
|
| 104 |
-
collection_name=
|
| 105 |
-
request_timeout=
|
| 106 |
-
max_retries=
|
| 107 |
-
retry_sleep=
|
| 108 |
)
|
| 109 |
self._single_stage = SingleStageRetriever(
|
| 110 |
-
|
| 111 |
-
collection_name=
|
| 112 |
-
request_timeout=
|
|
|
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
def build_filter(
|
|
@@ -139,14 +157,10 @@ class MultiVectorRetriever:
|
|
| 139 |
return_embeddings: bool = False,
|
| 140 |
) -> List[Dict[str, Any]]:
|
| 141 |
q = self.embedder.embed_query(query)
|
| 142 |
-
|
| 143 |
-
import torch
|
| 144 |
-
except ImportError:
|
| 145 |
-
torch = None
|
| 146 |
-
if torch is not None and isinstance(q, torch.Tensor):
|
| 147 |
query_embedding = q.detach().cpu().numpy()
|
| 148 |
else:
|
| 149 |
-
query_embedding =
|
| 150 |
|
| 151 |
return self.search_embedded(
|
| 152 |
query_embedding=query_embedding,
|
|
@@ -175,27 +189,17 @@ class MultiVectorRetriever:
|
|
| 175 |
return self._single_stage.search(
|
| 176 |
query_embedding=query_embedding,
|
| 177 |
top_k=top_k,
|
| 178 |
-
strategy="multi_vector",
|
| 179 |
-
filter_obj=filter_obj,
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
if mode == "single_tiles":
|
| 183 |
-
return self._single_stage.search(
|
| 184 |
-
query_embedding=query_embedding,
|
| 185 |
-
top_k=top_k,
|
| 186 |
-
strategy="tiles_maxsim",
|
| 187 |
filter_obj=filter_obj,
|
|
|
|
| 188 |
)
|
| 189 |
-
|
| 190 |
-
if mode == "single_global":
|
| 191 |
return self._single_stage.search(
|
| 192 |
query_embedding=query_embedding,
|
| 193 |
top_k=top_k,
|
| 194 |
-
strategy="pooled_global",
|
| 195 |
filter_obj=filter_obj,
|
|
|
|
| 196 |
)
|
| 197 |
-
|
| 198 |
-
if mode == "two_stage":
|
| 199 |
return self._two_stage.search_server_side(
|
| 200 |
query_embedding=query_embedding,
|
| 201 |
top_k=top_k,
|
|
@@ -203,18 +207,14 @@ class MultiVectorRetriever:
|
|
| 203 |
filter_obj=filter_obj,
|
| 204 |
stage1_mode=stage1_mode,
|
| 205 |
)
|
| 206 |
-
|
| 207 |
-
if mode == "three_stage":
|
| 208 |
-
s1 = int(stage1_k) if stage1_k is not None else 1000
|
| 209 |
-
s2 = int(stage2_k) if stage2_k is not None else 300
|
| 210 |
return self._three_stage.search_server_side(
|
| 211 |
query_embedding=query_embedding,
|
| 212 |
top_k=top_k,
|
| 213 |
-
stage1_k=
|
| 214 |
-
stage2_k=
|
| 215 |
filter_obj=filter_obj,
|
|
|
|
| 216 |
)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
| 2 |
from typing import Any, Dict, List, Optional
|
| 3 |
from urllib.parse import urlparse
|
| 4 |
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
DOTENV_AVAILABLE = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
DOTENV_AVAILABLE = False
|
| 14 |
+
load_dotenv = None
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from qdrant_client import QdrantClient
|
| 18 |
+
|
| 19 |
+
QDRANT_AVAILABLE = True
|
| 20 |
+
except ImportError:
|
| 21 |
+
QDRANT_AVAILABLE = False
|
| 22 |
+
QdrantClient = None
|
| 23 |
+
|
| 24 |
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
| 25 |
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
|
|
|
| 26 |
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
| 27 |
+
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
| 28 |
|
| 29 |
|
| 30 |
class MultiVectorRetriever:
|
| 31 |
@staticmethod
|
| 32 |
def _maybe_load_dotenv() -> None:
|
| 33 |
+
if not DOTENV_AVAILABLE:
|
|
|
|
|
|
|
| 34 |
return
|
| 35 |
if os.path.exists(".env"):
|
| 36 |
load_dotenv(".env")
|
|
|
|
| 50 |
):
|
| 51 |
if qdrant_client is None:
|
| 52 |
self._maybe_load_dotenv()
|
| 53 |
+
if not QDRANT_AVAILABLE:
|
|
|
|
|
|
|
| 54 |
raise ImportError(
|
| 55 |
"Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
|
| 56 |
+
)
|
| 57 |
|
| 58 |
qdrant_url = (
|
| 59 |
qdrant_url
|
|
|
|
|
|
|
| 60 |
or os.getenv("QDRANT_URL")
|
| 61 |
+
or os.getenv("SIGIR_QDRANT_URL") # legacy
|
| 62 |
)
|
| 63 |
if not qdrant_url:
|
| 64 |
raise ValueError(
|
| 65 |
+
"QDRANT_URL is required (pass qdrant_url or set env var)."
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
qdrant_api_key = (
|
| 69 |
qdrant_api_key
|
|
|
|
|
|
|
|
|
|
| 70 |
or os.getenv("QDRANT_API_KEY")
|
| 71 |
+
or os.getenv("SIGIR_QDRANT_KEY") # legacy
|
| 72 |
)
|
| 73 |
|
| 74 |
grpc_port = None
|
| 75 |
if prefer_grpc:
|
| 76 |
try:
|
| 77 |
+
parsed = urlparse(qdrant_url)
|
| 78 |
+
port = parsed.port
|
| 79 |
+
if port == 6333:
|
| 80 |
grpc_port = 6334
|
| 81 |
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
def _make_client(use_grpc: bool):
|
| 85 |
return QdrantClient(
|
| 86 |
url=qdrant_url,
|
| 87 |
api_key=qdrant_api_key,
|
| 88 |
+
timeout=request_timeout,
|
| 89 |
prefer_grpc=bool(use_grpc),
|
| 90 |
grpc_port=grpc_port,
|
|
|
|
| 91 |
check_compatibility=False,
|
| 92 |
)
|
| 93 |
|
| 94 |
+
client = _make_client(prefer_grpc)
|
| 95 |
if prefer_grpc:
|
| 96 |
try:
|
| 97 |
+
_ = client.get_collections()
|
| 98 |
except Exception as e:
|
| 99 |
msg = str(e)
|
| 100 |
if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
|
| 101 |
+
client = _make_client(False)
|
| 102 |
else:
|
| 103 |
raise
|
| 104 |
+
qdrant_client = client
|
| 105 |
|
| 106 |
self.client = qdrant_client
|
| 107 |
self.collection_name = collection_name
|
| 108 |
+
|
| 109 |
self.embedder = embedder or VisualEmbedder(model_name=model_name)
|
| 110 |
|
| 111 |
self._two_stage = TwoStageRetriever(
|
| 112 |
+
qdrant_client=qdrant_client,
|
| 113 |
+
collection_name=collection_name,
|
| 114 |
+
request_timeout=request_timeout,
|
| 115 |
+
max_retries=max_retries,
|
| 116 |
+
retry_sleep=retry_sleep,
|
| 117 |
)
|
| 118 |
self._three_stage = ThreeStageRetriever(
|
| 119 |
+
qdrant_client=qdrant_client,
|
| 120 |
+
collection_name=collection_name,
|
| 121 |
+
request_timeout=request_timeout,
|
| 122 |
+
max_retries=max_retries,
|
| 123 |
+
retry_sleep=retry_sleep,
|
| 124 |
)
|
| 125 |
self._single_stage = SingleStageRetriever(
|
| 126 |
+
qdrant_client=qdrant_client,
|
| 127 |
+
collection_name=collection_name,
|
| 128 |
+
request_timeout=request_timeout,
|
| 129 |
+
max_retries=max_retries,
|
| 130 |
+
retry_sleep=retry_sleep,
|
| 131 |
)
|
| 132 |
|
| 133 |
def build_filter(
|
|
|
|
| 157 |
return_embeddings: bool = False,
|
| 158 |
) -> List[Dict[str, Any]]:
|
| 159 |
q = self.embedder.embed_query(query)
|
| 160 |
+
if isinstance(q, torch.Tensor):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
query_embedding = q.detach().cpu().numpy()
|
| 162 |
else:
|
| 163 |
+
query_embedding = np.asarray(q)
|
| 164 |
|
| 165 |
return self.search_embedded(
|
| 166 |
query_embedding=query_embedding,
|
|
|
|
| 189 |
return self._single_stage.search(
|
| 190 |
query_embedding=query_embedding,
|
| 191 |
top_k=top_k,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
filter_obj=filter_obj,
|
| 193 |
+
using="initial",
|
| 194 |
)
|
| 195 |
+
elif mode == "single_pooled":
|
|
|
|
| 196 |
return self._single_stage.search(
|
| 197 |
query_embedding=query_embedding,
|
| 198 |
top_k=top_k,
|
|
|
|
| 199 |
filter_obj=filter_obj,
|
| 200 |
+
using="mean_pooling",
|
| 201 |
)
|
| 202 |
+
elif mode == "two_stage":
|
|
|
|
| 203 |
return self._two_stage.search_server_side(
|
| 204 |
query_embedding=query_embedding,
|
| 205 |
top_k=top_k,
|
|
|
|
| 207 |
filter_obj=filter_obj,
|
| 208 |
stage1_mode=stage1_mode,
|
| 209 |
)
|
| 210 |
+
elif mode == "three_stage":
|
|
|
|
|
|
|
|
|
|
| 211 |
return self._three_stage.search_server_side(
|
| 212 |
query_embedding=query_embedding,
|
| 213 |
top_k=top_k,
|
| 214 |
+
stage1_k=stage1_k,
|
| 215 |
+
stage2_k=stage2_k,
|
| 216 |
filter_obj=filter_obj,
|
| 217 |
+
stage1_mode=stage1_mode,
|
| 218 |
)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"Unknown mode: {mode}")
|
|
|
|
|
|
visual_rag/retrieval/single_stage.py
CHANGED
|
@@ -9,7 +9,8 @@ Use when:
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import logging
|
| 12 |
-
from typing import
|
|
|
|
| 13 |
import numpy as np
|
| 14 |
import torch
|
| 15 |
|
|
@@ -19,22 +20,22 @@ logger = logging.getLogger(__name__)
|
|
| 19 |
class SingleStageRetriever:
|
| 20 |
"""
|
| 21 |
Single-stage visual document retrieval using native Qdrant search.
|
| 22 |
-
|
| 23 |
Supports strategies:
|
| 24 |
- multi_vector: Native MaxSim on full embeddings (using="initial")
|
| 25 |
- tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
|
| 26 |
- pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
|
| 27 |
- pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")
|
| 28 |
-
|
| 29 |
Args:
|
| 30 |
qdrant_client: Connected Qdrant client
|
| 31 |
collection_name: Name of the Qdrant collection
|
| 32 |
-
|
| 33 |
Example:
|
| 34 |
>>> retriever = SingleStageRetriever(client, "my_collection")
|
| 35 |
>>> results = retriever.search(query, top_k=10)
|
| 36 |
"""
|
| 37 |
-
|
| 38 |
def __init__(
|
| 39 |
self,
|
| 40 |
qdrant_client,
|
|
@@ -44,7 +45,7 @@ class SingleStageRetriever:
|
|
| 44 |
self.client = qdrant_client
|
| 45 |
self.collection_name = collection_name
|
| 46 |
self.request_timeout = int(request_timeout)
|
| 47 |
-
|
| 48 |
def search(
|
| 49 |
self,
|
| 50 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
@@ -54,47 +55,47 @@ class SingleStageRetriever:
|
|
| 54 |
) -> List[Dict[str, Any]]:
|
| 55 |
"""
|
| 56 |
Single-stage search with configurable strategy.
|
| 57 |
-
|
| 58 |
Args:
|
| 59 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 60 |
top_k: Number of results
|
| 61 |
strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
|
| 62 |
filter_obj: Qdrant filter
|
| 63 |
-
|
| 64 |
Returns:
|
| 65 |
List of results with scores and metadata
|
| 66 |
"""
|
| 67 |
query_np = self._to_numpy(query_embedding)
|
| 68 |
-
|
| 69 |
if strategy == "multi_vector":
|
| 70 |
# Native multi-vector MaxSim
|
| 71 |
vector_name = "initial"
|
| 72 |
query_vector = query_np.tolist()
|
| 73 |
logger.debug(f"🎯 Multi-vector search on '{vector_name}'")
|
| 74 |
-
|
| 75 |
elif strategy == "tiles_maxsim":
|
| 76 |
# Native multi-vector MaxSim against tile vectors
|
| 77 |
vector_name = "mean_pooling"
|
| 78 |
query_vector = query_np.tolist()
|
| 79 |
logger.debug(f"🎯 Tile MaxSim search on '{vector_name}'")
|
| 80 |
-
|
| 81 |
elif strategy == "pooled_tile":
|
| 82 |
# Tile-level pooled
|
| 83 |
vector_name = "mean_pooling"
|
| 84 |
query_pooled = query_np.mean(axis=0)
|
| 85 |
query_vector = query_pooled.tolist()
|
| 86 |
logger.debug(f"🔍 Tile-pooled search on '{vector_name}'")
|
| 87 |
-
|
| 88 |
elif strategy == "pooled_global":
|
| 89 |
# Global pooled vector (single vector)
|
| 90 |
vector_name = "global_pooling"
|
| 91 |
query_pooled = query_np.mean(axis=0)
|
| 92 |
query_vector = query_pooled.tolist()
|
| 93 |
logger.debug(f"🔍 Global-pooled search on '{vector_name}'")
|
| 94 |
-
|
| 95 |
else:
|
| 96 |
raise ValueError(f"Unknown strategy: {strategy}")
|
| 97 |
-
|
| 98 |
results = self.client.query_points(
|
| 99 |
collection_name=self.collection_name,
|
| 100 |
query=query_vector,
|
|
@@ -105,7 +106,7 @@ class SingleStageRetriever:
|
|
| 105 |
with_vectors=False,
|
| 106 |
timeout=self.request_timeout,
|
| 107 |
).points
|
| 108 |
-
|
| 109 |
return [
|
| 110 |
{
|
| 111 |
"id": r.id,
|
|
@@ -115,7 +116,7 @@ class SingleStageRetriever:
|
|
| 115 |
}
|
| 116 |
for r in results
|
| 117 |
]
|
| 118 |
-
|
| 119 |
def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 120 |
"""Convert embedding to numpy array."""
|
| 121 |
if isinstance(embedding, torch.Tensor):
|
|
@@ -123,5 +124,3 @@ class SingleStageRetriever:
|
|
| 123 |
return embedding.cpu().float().numpy()
|
| 124 |
return embedding.cpu().numpy()
|
| 125 |
return np.array(embedding, dtype=np.float32)
|
| 126 |
-
|
| 127 |
-
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import logging
|
| 12 |
+
from typing import Any, Dict, List, Union
|
| 13 |
+
|
| 14 |
import numpy as np
|
| 15 |
import torch
|
| 16 |
|
|
|
|
| 20 |
class SingleStageRetriever:
|
| 21 |
"""
|
| 22 |
Single-stage visual document retrieval using native Qdrant search.
|
| 23 |
+
|
| 24 |
Supports strategies:
|
| 25 |
- multi_vector: Native MaxSim on full embeddings (using="initial")
|
| 26 |
- tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
|
| 27 |
- pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
|
| 28 |
- pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")
|
| 29 |
+
|
| 30 |
Args:
|
| 31 |
qdrant_client: Connected Qdrant client
|
| 32 |
collection_name: Name of the Qdrant collection
|
| 33 |
+
|
| 34 |
Example:
|
| 35 |
>>> retriever = SingleStageRetriever(client, "my_collection")
|
| 36 |
>>> results = retriever.search(query, top_k=10)
|
| 37 |
"""
|
| 38 |
+
|
| 39 |
def __init__(
|
| 40 |
self,
|
| 41 |
qdrant_client,
|
|
|
|
| 45 |
self.client = qdrant_client
|
| 46 |
self.collection_name = collection_name
|
| 47 |
self.request_timeout = int(request_timeout)
|
| 48 |
+
|
| 49 |
def search(
|
| 50 |
self,
|
| 51 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
|
|
| 55 |
) -> List[Dict[str, Any]]:
|
| 56 |
"""
|
| 57 |
Single-stage search with configurable strategy.
|
| 58 |
+
|
| 59 |
Args:
|
| 60 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 61 |
top_k: Number of results
|
| 62 |
strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
|
| 63 |
filter_obj: Qdrant filter
|
| 64 |
+
|
| 65 |
Returns:
|
| 66 |
List of results with scores and metadata
|
| 67 |
"""
|
| 68 |
query_np = self._to_numpy(query_embedding)
|
| 69 |
+
|
| 70 |
if strategy == "multi_vector":
|
| 71 |
# Native multi-vector MaxSim
|
| 72 |
vector_name = "initial"
|
| 73 |
query_vector = query_np.tolist()
|
| 74 |
logger.debug(f"🎯 Multi-vector search on '{vector_name}'")
|
| 75 |
+
|
| 76 |
elif strategy == "tiles_maxsim":
|
| 77 |
# Native multi-vector MaxSim against tile vectors
|
| 78 |
vector_name = "mean_pooling"
|
| 79 |
query_vector = query_np.tolist()
|
| 80 |
logger.debug(f"🎯 Tile MaxSim search on '{vector_name}'")
|
| 81 |
+
|
| 82 |
elif strategy == "pooled_tile":
|
| 83 |
# Tile-level pooled
|
| 84 |
vector_name = "mean_pooling"
|
| 85 |
query_pooled = query_np.mean(axis=0)
|
| 86 |
query_vector = query_pooled.tolist()
|
| 87 |
logger.debug(f"🔍 Tile-pooled search on '{vector_name}'")
|
| 88 |
+
|
| 89 |
elif strategy == "pooled_global":
|
| 90 |
# Global pooled vector (single vector)
|
| 91 |
vector_name = "global_pooling"
|
| 92 |
query_pooled = query_np.mean(axis=0)
|
| 93 |
query_vector = query_pooled.tolist()
|
| 94 |
logger.debug(f"🔍 Global-pooled search on '{vector_name}'")
|
| 95 |
+
|
| 96 |
else:
|
| 97 |
raise ValueError(f"Unknown strategy: {strategy}")
|
| 98 |
+
|
| 99 |
results = self.client.query_points(
|
| 100 |
collection_name=self.collection_name,
|
| 101 |
query=query_vector,
|
|
|
|
| 106 |
with_vectors=False,
|
| 107 |
timeout=self.request_timeout,
|
| 108 |
).points
|
| 109 |
+
|
| 110 |
return [
|
| 111 |
{
|
| 112 |
"id": r.id,
|
|
|
|
| 116 |
}
|
| 117 |
for r in results
|
| 118 |
]
|
| 119 |
+
|
| 120 |
def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 121 |
"""Convert embedding to numpy array."""
|
| 122 |
if isinstance(embedding, torch.Tensor):
|
|
|
|
| 124 |
return embedding.cpu().float().numpy()
|
| 125 |
return embedding.cpu().numpy()
|
| 126 |
return np.array(embedding, dtype=np.float32)
|
|
|
|
|
|
visual_rag/retrieval/three_stage.py
CHANGED
|
@@ -43,7 +43,7 @@ class ThreeStageRetriever:
|
|
| 43 |
last_err = e
|
| 44 |
if attempt >= self.max_retries - 1:
|
| 45 |
break
|
| 46 |
-
time.sleep(self.retry_sleep * (2
|
| 47 |
if last_err is not None:
|
| 48 |
raise last_err
|
| 49 |
|
|
@@ -171,4 +171,3 @@ class ThreeStageRetriever:
|
|
| 171 |
}
|
| 172 |
)
|
| 173 |
return out
|
| 174 |
-
|
|
|
|
| 43 |
last_err = e
|
| 44 |
if attempt >= self.max_retries - 1:
|
| 45 |
break
|
| 46 |
+
time.sleep(self.retry_sleep * (2**attempt))
|
| 47 |
if last_err is not None:
|
| 48 |
raise last_err
|
| 49 |
|
|
|
|
| 171 |
}
|
| 172 |
)
|
| 173 |
return out
|
|
|
visual_rag/retrieval/two_stage.py
CHANGED
|
@@ -17,47 +17,54 @@ Research Context:
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
import logging
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
import numpy as np
|
| 22 |
import torch
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
|
| 27 |
class TwoStageRetriever:
|
| 28 |
"""
|
| 29 |
Two-stage visual document retrieval with pooling and reranking.
|
| 30 |
-
|
| 31 |
Stage 1 (Prefetch):
|
| 32 |
Uses tile-level mean-pooled vectors for fast HNSW search.
|
| 33 |
Retrieves prefetch_k candidates (e.g., 100-500).
|
| 34 |
-
|
| 35 |
Stage 2 (Rerank):
|
| 36 |
Fetches full multi-vector embeddings for candidates.
|
| 37 |
Computes exact MaxSim scores for precise ranking.
|
| 38 |
Returns top_k results (e.g., 10).
|
| 39 |
-
|
| 40 |
Args:
|
| 41 |
qdrant_client: Connected Qdrant client
|
| 42 |
collection_name: Name of the Qdrant collection
|
| 43 |
full_vector_name: Name of full multi-vector field (default: "initial")
|
| 44 |
pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
|
| 45 |
-
|
| 46 |
Example:
|
| 47 |
>>> retriever = TwoStageRetriever(client, "my_collection")
|
| 48 |
-
>>>
|
| 49 |
>>> # Two-stage search: prefetch 200, return top 10
|
| 50 |
>>> results = retriever.search(
|
| 51 |
... query_embedding=query,
|
| 52 |
... top_k=10,
|
| 53 |
... prefetch_k=200,
|
| 54 |
... )
|
| 55 |
-
>>>
|
| 56 |
>>> # Compare latency:
|
| 57 |
>>> # Full MaxSim (1000 docs): ~500ms
|
| 58 |
>>> # Two-stage (200→10): ~50ms
|
| 59 |
"""
|
| 60 |
-
|
| 61 |
def __init__(
|
| 62 |
self,
|
| 63 |
qdrant_client,
|
|
@@ -81,8 +88,6 @@ class TwoStageRetriever:
|
|
| 81 |
self.retry_sleep = float(retry_sleep)
|
| 82 |
|
| 83 |
def _retry_call(self, fn):
|
| 84 |
-
import time
|
| 85 |
-
|
| 86 |
last_err = None
|
| 87 |
for attempt in range(self.max_retries):
|
| 88 |
try:
|
|
@@ -91,7 +96,7 @@ class TwoStageRetriever:
|
|
| 91 |
last_err = e
|
| 92 |
if attempt >= self.max_retries - 1:
|
| 93 |
break
|
| 94 |
-
time.sleep(self.retry_sleep * (2
|
| 95 |
if last_err is not None:
|
| 96 |
raise last_err
|
| 97 |
|
|
@@ -105,27 +110,25 @@ class TwoStageRetriever:
|
|
| 105 |
) -> List[Dict[str, Any]]:
|
| 106 |
"""
|
| 107 |
Two-stage retrieval using Qdrant's native prefetch (all server-side).
|
| 108 |
-
|
| 109 |
This is MUCH faster than search() because it avoids network transfer
|
| 110 |
of large multi-vector embeddings. All computation happens in Qdrant.
|
| 111 |
-
|
| 112 |
Args:
|
| 113 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 114 |
top_k: Final number of results
|
| 115 |
prefetch_k: Candidates for stage 1 (default: 10x top_k)
|
| 116 |
filter_obj: Qdrant filter
|
| 117 |
stage1_mode: How to do stage 1 prefetch
|
| 118 |
-
|
| 119 |
Returns:
|
| 120 |
List of results with scores
|
| 121 |
"""
|
| 122 |
-
from qdrant_client.http import models
|
| 123 |
-
|
| 124 |
query_np = self._to_numpy(query_embedding)
|
| 125 |
-
|
| 126 |
if prefetch_k is None:
|
| 127 |
prefetch_k = max(100, top_k * 10)
|
| 128 |
-
|
| 129 |
if stage1_mode == "pooled_query_vs_tiles":
|
| 130 |
prefetch_query = query_np.mean(axis=0).tolist()
|
| 131 |
prefetch_using = self.pooled_vector_name
|
|
@@ -143,9 +146,9 @@ class TwoStageRetriever:
|
|
| 143 |
prefetch_using = self.global_vector_name
|
| 144 |
else:
|
| 145 |
raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
|
| 146 |
-
|
| 147 |
rerank_query = query_np.tolist()
|
| 148 |
-
|
| 149 |
def _do_query():
|
| 150 |
return self.client.query_points(
|
| 151 |
collection_name=self.collection_name,
|
|
@@ -154,9 +157,9 @@ class TwoStageRetriever:
|
|
| 154 |
limit=top_k,
|
| 155 |
query_filter=filter_obj,
|
| 156 |
with_payload=True,
|
| 157 |
-
search_params=
|
| 158 |
prefetch=[
|
| 159 |
-
|
| 160 |
query=prefetch_query,
|
| 161 |
using=prefetch_using,
|
| 162 |
limit=prefetch_k,
|
|
@@ -164,9 +167,9 @@ class TwoStageRetriever:
|
|
| 164 |
],
|
| 165 |
timeout=self.request_timeout,
|
| 166 |
).points
|
| 167 |
-
|
| 168 |
results = self._retry_call(_do_query)
|
| 169 |
-
|
| 170 |
return [
|
| 171 |
{
|
| 172 |
"id": r.id,
|
|
@@ -177,7 +180,7 @@ class TwoStageRetriever:
|
|
| 177 |
}
|
| 178 |
for r in results
|
| 179 |
]
|
| 180 |
-
|
| 181 |
def search(
|
| 182 |
self,
|
| 183 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
@@ -190,7 +193,7 @@ class TwoStageRetriever:
|
|
| 190 |
) -> List[Dict[str, Any]]:
|
| 191 |
"""
|
| 192 |
Two-stage retrieval: prefetch with pooling, rerank with MaxSim.
|
| 193 |
-
|
| 194 |
Args:
|
| 195 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 196 |
top_k: Final number of results to return
|
|
@@ -202,7 +205,7 @@ class TwoStageRetriever:
|
|
| 202 |
- "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
|
| 203 |
- "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
|
| 204 |
- "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")
|
| 205 |
-
|
| 206 |
Returns:
|
| 207 |
List of results with scores and metadata:
|
| 208 |
[
|
|
@@ -218,11 +221,11 @@ class TwoStageRetriever:
|
|
| 218 |
"""
|
| 219 |
# Convert to numpy
|
| 220 |
query_np = self._to_numpy(query_embedding)
|
| 221 |
-
|
| 222 |
# Auto-set prefetch_k
|
| 223 |
if prefetch_k is None:
|
| 224 |
prefetch_k = max(100, top_k * 10)
|
| 225 |
-
|
| 226 |
# Stage 1: Prefetch with pooled vectors
|
| 227 |
logger.info(f"🔍 Stage 1: Prefetching {prefetch_k} candidates ({stage1_mode})")
|
| 228 |
candidates = self._stage1_prefetch(
|
|
@@ -231,16 +234,16 @@ class TwoStageRetriever:
|
|
| 231 |
filter_obj=filter_obj,
|
| 232 |
stage1_mode=stage1_mode,
|
| 233 |
)
|
| 234 |
-
|
| 235 |
if not candidates:
|
| 236 |
logger.warning("No candidates found in stage 1")
|
| 237 |
return []
|
| 238 |
-
|
| 239 |
logger.info(f"✅ Stage 1: Retrieved {len(candidates)} candidates")
|
| 240 |
-
|
| 241 |
# Stage 2: Rerank with full embeddings
|
| 242 |
if use_reranking and len(candidates) > top_k:
|
| 243 |
-
logger.info(
|
| 244 |
results = self._stage2_rerank(
|
| 245 |
query_np=query_np,
|
| 246 |
candidates=candidates,
|
|
@@ -254,9 +257,9 @@ class TwoStageRetriever:
|
|
| 254 |
for r in results:
|
| 255 |
r["score_final"] = r["score_stage1"]
|
| 256 |
logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
|
| 257 |
-
|
| 258 |
return results
|
| 259 |
-
|
| 260 |
def search_single_stage(
|
| 261 |
self,
|
| 262 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
@@ -266,18 +269,18 @@ class TwoStageRetriever:
|
|
| 266 |
) -> List[Dict[str, Any]]:
|
| 267 |
"""
|
| 268 |
Single-stage search (either pooled or full multi-vector).
|
| 269 |
-
|
| 270 |
Args:
|
| 271 |
query_embedding: Query embeddings
|
| 272 |
top_k: Number of results
|
| 273 |
filter_obj: Qdrant filter
|
| 274 |
use_pooling: Use pooled vectors (faster) or full (more accurate)
|
| 275 |
-
|
| 276 |
Returns:
|
| 277 |
List of results
|
| 278 |
"""
|
| 279 |
query_np = self._to_numpy(query_embedding)
|
| 280 |
-
|
| 281 |
if use_pooling:
|
| 282 |
# Pool query and search pooled vectors
|
| 283 |
query_pooled = query_np.mean(axis=0)
|
|
@@ -289,7 +292,7 @@ class TwoStageRetriever:
|
|
| 289 |
vector_name = self.full_vector_name
|
| 290 |
query_vector = query_np.tolist()
|
| 291 |
logger.info(f"🎯 Multi-vector search: {vector_name}")
|
| 292 |
-
|
| 293 |
results = self.client.query_points(
|
| 294 |
collection_name=self.collection_name,
|
| 295 |
query=query_vector,
|
|
@@ -300,7 +303,7 @@ class TwoStageRetriever:
|
|
| 300 |
with_vectors=False,
|
| 301 |
timeout=120,
|
| 302 |
).points
|
| 303 |
-
|
| 304 |
return [
|
| 305 |
{
|
| 306 |
"id": r.id,
|
|
@@ -310,7 +313,7 @@ class TwoStageRetriever:
|
|
| 310 |
}
|
| 311 |
for r in results
|
| 312 |
]
|
| 313 |
-
|
| 314 |
def _stage1_prefetch(
|
| 315 |
self,
|
| 316 |
query_np: np.ndarray,
|
|
@@ -330,7 +333,7 @@ class TwoStageRetriever:
|
|
| 330 |
vector_name = self.global_vector_name
|
| 331 |
else:
|
| 332 |
raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
|
| 333 |
-
|
| 334 |
def _do_query():
|
| 335 |
return self.client.query_points(
|
| 336 |
collection_name=self.collection_name,
|
|
@@ -344,7 +347,7 @@ class TwoStageRetriever:
|
|
| 344 |
).points
|
| 345 |
|
| 346 |
results = self._retry_call(_do_query)
|
| 347 |
-
|
| 348 |
return [
|
| 349 |
{
|
| 350 |
"id": r.id,
|
|
@@ -353,7 +356,7 @@ class TwoStageRetriever:
|
|
| 353 |
}
|
| 354 |
for r in results
|
| 355 |
]
|
| 356 |
-
|
| 357 |
def _stage2_rerank(
|
| 358 |
self,
|
| 359 |
query_np: np.ndarray,
|
|
@@ -362,11 +365,9 @@ class TwoStageRetriever:
|
|
| 362 |
return_embeddings: bool = False,
|
| 363 |
) -> List[Dict[str, Any]]:
|
| 364 |
"""Stage 2: Rerank with full multi-vector MaxSim scoring."""
|
| 365 |
-
from visual_rag.embedding.pooling import compute_maxsim_score
|
| 366 |
-
|
| 367 |
# Fetch full embeddings for candidates
|
| 368 |
candidate_ids = [c["id"] for c in candidates]
|
| 369 |
-
|
| 370 |
# Retrieve points with vectors
|
| 371 |
def _do_retrieve():
|
| 372 |
return self.client.retrieve(
|
|
@@ -378,7 +379,7 @@ class TwoStageRetriever:
|
|
| 378 |
)
|
| 379 |
|
| 380 |
points = self._retry_call(_do_retrieve)
|
| 381 |
-
|
| 382 |
# Build ID to embedding map
|
| 383 |
id_to_embedding = {}
|
| 384 |
for point in points:
|
|
@@ -386,13 +387,13 @@ class TwoStageRetriever:
|
|
| 386 |
id_to_embedding[point.id] = np.array(
|
| 387 |
point.vector[self.full_vector_name], dtype=np.float32
|
| 388 |
)
|
| 389 |
-
|
| 390 |
# Compute MaxSim scores
|
| 391 |
reranked = []
|
| 392 |
for candidate in candidates:
|
| 393 |
point_id = candidate["id"]
|
| 394 |
doc_embedding = id_to_embedding.get(point_id)
|
| 395 |
-
|
| 396 |
if doc_embedding is None:
|
| 397 |
# Fallback to stage 1 score
|
| 398 |
candidate["score_stage2"] = candidate["score_stage1"]
|
|
@@ -402,17 +403,17 @@ class TwoStageRetriever:
|
|
| 402 |
maxsim_score = compute_maxsim_score(query_np, doc_embedding)
|
| 403 |
candidate["score_stage2"] = maxsim_score
|
| 404 |
candidate["score_final"] = maxsim_score
|
| 405 |
-
|
| 406 |
if return_embeddings:
|
| 407 |
candidate["embedding"] = doc_embedding
|
| 408 |
-
|
| 409 |
reranked.append(candidate)
|
| 410 |
-
|
| 411 |
# Sort by final score (descending)
|
| 412 |
reranked.sort(key=lambda x: x["score_final"], reverse=True)
|
| 413 |
-
|
| 414 |
return reranked[:top_k]
|
| 415 |
-
|
| 416 |
def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 417 |
"""Convert embedding to numpy array."""
|
| 418 |
if isinstance(embedding, torch.Tensor):
|
|
@@ -420,7 +421,7 @@ class TwoStageRetriever:
|
|
| 420 |
return embedding.cpu().float().numpy()
|
| 421 |
return embedding.cpu().numpy()
|
| 422 |
return np.array(embedding, dtype=np.float32)
|
| 423 |
-
|
| 424 |
def build_filter(
|
| 425 |
self,
|
| 426 |
year: Optional[Any] = None,
|
|
@@ -431,60 +432,38 @@ class TwoStageRetriever:
|
|
| 431 |
):
|
| 432 |
"""
|
| 433 |
Build Qdrant filter from parameters.
|
| 434 |
-
|
| 435 |
Supports single values or lists (using MatchAny).
|
| 436 |
"""
|
| 437 |
-
from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny
|
| 438 |
-
|
| 439 |
conditions = []
|
| 440 |
-
|
| 441 |
if year is not None:
|
| 442 |
if isinstance(year, list):
|
| 443 |
year_values = [int(y) if isinstance(y, str) else y for y in year]
|
| 444 |
-
conditions.append(
|
| 445 |
-
FieldCondition(key="year", match=MatchAny(any=year_values))
|
| 446 |
-
)
|
| 447 |
else:
|
| 448 |
year_value = int(year) if isinstance(year, str) else year
|
| 449 |
-
conditions.append(
|
| 450 |
-
|
| 451 |
-
)
|
| 452 |
-
|
| 453 |
if source is not None:
|
| 454 |
if isinstance(source, list):
|
| 455 |
-
conditions.append(
|
| 456 |
-
FieldCondition(key="source", match=MatchAny(any=source))
|
| 457 |
-
)
|
| 458 |
else:
|
| 459 |
-
conditions.append(
|
| 460 |
-
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
if district is not None:
|
| 464 |
if isinstance(district, list):
|
| 465 |
-
conditions.append(
|
| 466 |
-
FieldCondition(key="district", match=MatchAny(any=district))
|
| 467 |
-
)
|
| 468 |
else:
|
| 469 |
-
conditions.append(
|
| 470 |
-
|
| 471 |
-
)
|
| 472 |
-
|
| 473 |
if filename is not None:
|
| 474 |
if isinstance(filename, list):
|
| 475 |
-
conditions.append(
|
| 476 |
-
FieldCondition(key="filename", match=MatchAny(any=filename))
|
| 477 |
-
)
|
| 478 |
else:
|
| 479 |
-
conditions.append(
|
| 480 |
-
FieldCondition(key="filename", match=MatchValue(value=filename))
|
| 481 |
-
)
|
| 482 |
-
|
| 483 |
-
if has_text is not None:
|
| 484 |
-
conditions.append(
|
| 485 |
-
FieldCondition(key="has_text", match=MatchValue(value=has_text))
|
| 486 |
-
)
|
| 487 |
-
|
| 488 |
-
return Filter(must=conditions) if conditions else None
|
| 489 |
|
|
|
|
|
|
|
| 490 |
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
import logging
|
| 20 |
+
import time
|
| 21 |
+
from typing import Any, Dict, List, Optional, Union
|
| 22 |
+
|
| 23 |
import numpy as np
|
| 24 |
import torch
|
| 25 |
|
| 26 |
+
from qdrant_client.http import models as qdrant_models
|
| 27 |
+
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue
|
| 28 |
+
|
| 29 |
+
from visual_rag.embedding.pooling import compute_maxsim_score
|
| 30 |
+
|
| 31 |
logger = logging.getLogger(__name__)
|
| 32 |
|
| 33 |
|
| 34 |
class TwoStageRetriever:
|
| 35 |
"""
|
| 36 |
Two-stage visual document retrieval with pooling and reranking.
|
| 37 |
+
|
| 38 |
Stage 1 (Prefetch):
|
| 39 |
Uses tile-level mean-pooled vectors for fast HNSW search.
|
| 40 |
Retrieves prefetch_k candidates (e.g., 100-500).
|
| 41 |
+
|
| 42 |
Stage 2 (Rerank):
|
| 43 |
Fetches full multi-vector embeddings for candidates.
|
| 44 |
Computes exact MaxSim scores for precise ranking.
|
| 45 |
Returns top_k results (e.g., 10).
|
| 46 |
+
|
| 47 |
Args:
|
| 48 |
qdrant_client: Connected Qdrant client
|
| 49 |
collection_name: Name of the Qdrant collection
|
| 50 |
full_vector_name: Name of full multi-vector field (default: "initial")
|
| 51 |
pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
|
| 52 |
+
|
| 53 |
Example:
|
| 54 |
>>> retriever = TwoStageRetriever(client, "my_collection")
|
| 55 |
+
>>>
|
| 56 |
>>> # Two-stage search: prefetch 200, return top 10
|
| 57 |
>>> results = retriever.search(
|
| 58 |
... query_embedding=query,
|
| 59 |
... top_k=10,
|
| 60 |
... prefetch_k=200,
|
| 61 |
... )
|
| 62 |
+
>>>
|
| 63 |
>>> # Compare latency:
|
| 64 |
>>> # Full MaxSim (1000 docs): ~500ms
|
| 65 |
>>> # Two-stage (200→10): ~50ms
|
| 66 |
"""
|
| 67 |
+
|
| 68 |
def __init__(
|
| 69 |
self,
|
| 70 |
qdrant_client,
|
|
|
|
| 88 |
self.retry_sleep = float(retry_sleep)
|
| 89 |
|
| 90 |
def _retry_call(self, fn):
|
|
|
|
|
|
|
| 91 |
last_err = None
|
| 92 |
for attempt in range(self.max_retries):
|
| 93 |
try:
|
|
|
|
| 96 |
last_err = e
|
| 97 |
if attempt >= self.max_retries - 1:
|
| 98 |
break
|
| 99 |
+
time.sleep(self.retry_sleep * (2**attempt))
|
| 100 |
if last_err is not None:
|
| 101 |
raise last_err
|
| 102 |
|
|
|
|
| 110 |
) -> List[Dict[str, Any]]:
|
| 111 |
"""
|
| 112 |
Two-stage retrieval using Qdrant's native prefetch (all server-side).
|
| 113 |
+
|
| 114 |
This is MUCH faster than search() because it avoids network transfer
|
| 115 |
of large multi-vector embeddings. All computation happens in Qdrant.
|
| 116 |
+
|
| 117 |
Args:
|
| 118 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 119 |
top_k: Final number of results
|
| 120 |
prefetch_k: Candidates for stage 1 (default: 10x top_k)
|
| 121 |
filter_obj: Qdrant filter
|
| 122 |
stage1_mode: How to do stage 1 prefetch
|
| 123 |
+
|
| 124 |
Returns:
|
| 125 |
List of results with scores
|
| 126 |
"""
|
|
|
|
|
|
|
| 127 |
query_np = self._to_numpy(query_embedding)
|
| 128 |
+
|
| 129 |
if prefetch_k is None:
|
| 130 |
prefetch_k = max(100, top_k * 10)
|
| 131 |
+
|
| 132 |
if stage1_mode == "pooled_query_vs_tiles":
|
| 133 |
prefetch_query = query_np.mean(axis=0).tolist()
|
| 134 |
prefetch_using = self.pooled_vector_name
|
|
|
|
| 146 |
prefetch_using = self.global_vector_name
|
| 147 |
else:
|
| 148 |
raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
|
| 149 |
+
|
| 150 |
rerank_query = query_np.tolist()
|
| 151 |
+
|
| 152 |
def _do_query():
|
| 153 |
return self.client.query_points(
|
| 154 |
collection_name=self.collection_name,
|
|
|
|
| 157 |
limit=top_k,
|
| 158 |
query_filter=filter_obj,
|
| 159 |
with_payload=True,
|
| 160 |
+
search_params=qdrant_models.SearchParams(exact=True),
|
| 161 |
prefetch=[
|
| 162 |
+
qdrant_models.Prefetch(
|
| 163 |
query=prefetch_query,
|
| 164 |
using=prefetch_using,
|
| 165 |
limit=prefetch_k,
|
|
|
|
| 167 |
],
|
| 168 |
timeout=self.request_timeout,
|
| 169 |
).points
|
| 170 |
+
|
| 171 |
results = self._retry_call(_do_query)
|
| 172 |
+
|
| 173 |
return [
|
| 174 |
{
|
| 175 |
"id": r.id,
|
|
|
|
| 180 |
}
|
| 181 |
for r in results
|
| 182 |
]
|
| 183 |
+
|
| 184 |
def search(
|
| 185 |
self,
|
| 186 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
|
|
| 193 |
) -> List[Dict[str, Any]]:
|
| 194 |
"""
|
| 195 |
Two-stage retrieval: prefetch with pooling, rerank with MaxSim.
|
| 196 |
+
|
| 197 |
Args:
|
| 198 |
query_embedding: Query embeddings [num_tokens, dim]
|
| 199 |
top_k: Final number of results to return
|
|
|
|
| 205 |
- "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
|
| 206 |
- "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
|
| 207 |
- "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")
|
| 208 |
+
|
| 209 |
Returns:
|
| 210 |
List of results with scores and metadata:
|
| 211 |
[
|
|
|
|
| 221 |
"""
|
| 222 |
# Convert to numpy
|
| 223 |
query_np = self._to_numpy(query_embedding)
|
| 224 |
+
|
| 225 |
# Auto-set prefetch_k
|
| 226 |
if prefetch_k is None:
|
| 227 |
prefetch_k = max(100, top_k * 10)
|
| 228 |
+
|
| 229 |
# Stage 1: Prefetch with pooled vectors
|
| 230 |
logger.info(f"🔍 Stage 1: Prefetching {prefetch_k} candidates ({stage1_mode})")
|
| 231 |
candidates = self._stage1_prefetch(
|
|
|
|
| 234 |
filter_obj=filter_obj,
|
| 235 |
stage1_mode=stage1_mode,
|
| 236 |
)
|
| 237 |
+
|
| 238 |
if not candidates:
|
| 239 |
logger.warning("No candidates found in stage 1")
|
| 240 |
return []
|
| 241 |
+
|
| 242 |
logger.info(f"✅ Stage 1: Retrieved {len(candidates)} candidates")
|
| 243 |
+
|
| 244 |
# Stage 2: Rerank with full embeddings
|
| 245 |
if use_reranking and len(candidates) > top_k:
|
| 246 |
+
logger.info("🎯 Stage 2: Reranking with MaxSim...")
|
| 247 |
results = self._stage2_rerank(
|
| 248 |
query_np=query_np,
|
| 249 |
candidates=candidates,
|
|
|
|
| 257 |
for r in results:
|
| 258 |
r["score_final"] = r["score_stage1"]
|
| 259 |
logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
|
| 260 |
+
|
| 261 |
return results
|
| 262 |
+
|
| 263 |
def search_single_stage(
|
| 264 |
self,
|
| 265 |
query_embedding: Union[torch.Tensor, np.ndarray],
|
|
|
|
| 269 |
) -> List[Dict[str, Any]]:
|
| 270 |
"""
|
| 271 |
Single-stage search (either pooled or full multi-vector).
|
| 272 |
+
|
| 273 |
Args:
|
| 274 |
query_embedding: Query embeddings
|
| 275 |
top_k: Number of results
|
| 276 |
filter_obj: Qdrant filter
|
| 277 |
use_pooling: Use pooled vectors (faster) or full (more accurate)
|
| 278 |
+
|
| 279 |
Returns:
|
| 280 |
List of results
|
| 281 |
"""
|
| 282 |
query_np = self._to_numpy(query_embedding)
|
| 283 |
+
|
| 284 |
if use_pooling:
|
| 285 |
# Pool query and search pooled vectors
|
| 286 |
query_pooled = query_np.mean(axis=0)
|
|
|
|
| 292 |
vector_name = self.full_vector_name
|
| 293 |
query_vector = query_np.tolist()
|
| 294 |
logger.info(f"🎯 Multi-vector search: {vector_name}")
|
| 295 |
+
|
| 296 |
results = self.client.query_points(
|
| 297 |
collection_name=self.collection_name,
|
| 298 |
query=query_vector,
|
|
|
|
| 303 |
with_vectors=False,
|
| 304 |
timeout=120,
|
| 305 |
).points
|
| 306 |
+
|
| 307 |
return [
|
| 308 |
{
|
| 309 |
"id": r.id,
|
|
|
|
| 313 |
}
|
| 314 |
for r in results
|
| 315 |
]
|
| 316 |
+
|
| 317 |
def _stage1_prefetch(
|
| 318 |
self,
|
| 319 |
query_np: np.ndarray,
|
|
|
|
| 333 |
vector_name = self.global_vector_name
|
| 334 |
else:
|
| 335 |
raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
|
| 336 |
+
|
| 337 |
def _do_query():
|
| 338 |
return self.client.query_points(
|
| 339 |
collection_name=self.collection_name,
|
|
|
|
| 347 |
).points
|
| 348 |
|
| 349 |
results = self._retry_call(_do_query)
|
| 350 |
+
|
| 351 |
return [
|
| 352 |
{
|
| 353 |
"id": r.id,
|
|
|
|
| 356 |
}
|
| 357 |
for r in results
|
| 358 |
]
|
| 359 |
+
|
| 360 |
def _stage2_rerank(
|
| 361 |
self,
|
| 362 |
query_np: np.ndarray,
|
|
|
|
| 365 |
return_embeddings: bool = False,
|
| 366 |
) -> List[Dict[str, Any]]:
|
| 367 |
"""Stage 2: Rerank with full multi-vector MaxSim scoring."""
|
|
|
|
|
|
|
| 368 |
# Fetch full embeddings for candidates
|
| 369 |
candidate_ids = [c["id"] for c in candidates]
|
| 370 |
+
|
| 371 |
# Retrieve points with vectors
|
| 372 |
def _do_retrieve():
|
| 373 |
return self.client.retrieve(
|
|
|
|
| 379 |
)
|
| 380 |
|
| 381 |
points = self._retry_call(_do_retrieve)
|
| 382 |
+
|
| 383 |
# Build ID to embedding map
|
| 384 |
id_to_embedding = {}
|
| 385 |
for point in points:
|
|
|
|
| 387 |
id_to_embedding[point.id] = np.array(
|
| 388 |
point.vector[self.full_vector_name], dtype=np.float32
|
| 389 |
)
|
| 390 |
+
|
| 391 |
# Compute MaxSim scores
|
| 392 |
reranked = []
|
| 393 |
for candidate in candidates:
|
| 394 |
point_id = candidate["id"]
|
| 395 |
doc_embedding = id_to_embedding.get(point_id)
|
| 396 |
+
|
| 397 |
if doc_embedding is None:
|
| 398 |
# Fallback to stage 1 score
|
| 399 |
candidate["score_stage2"] = candidate["score_stage1"]
|
|
|
|
| 403 |
maxsim_score = compute_maxsim_score(query_np, doc_embedding)
|
| 404 |
candidate["score_stage2"] = maxsim_score
|
| 405 |
candidate["score_final"] = maxsim_score
|
| 406 |
+
|
| 407 |
if return_embeddings:
|
| 408 |
candidate["embedding"] = doc_embedding
|
| 409 |
+
|
| 410 |
reranked.append(candidate)
|
| 411 |
+
|
| 412 |
# Sort by final score (descending)
|
| 413 |
reranked.sort(key=lambda x: x["score_final"], reverse=True)
|
| 414 |
+
|
| 415 |
return reranked[:top_k]
|
| 416 |
+
|
| 417 |
def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 418 |
"""Convert embedding to numpy array."""
|
| 419 |
if isinstance(embedding, torch.Tensor):
|
|
|
|
| 421 |
return embedding.cpu().float().numpy()
|
| 422 |
return embedding.cpu().numpy()
|
| 423 |
return np.array(embedding, dtype=np.float32)
|
| 424 |
+
|
| 425 |
def build_filter(
|
| 426 |
self,
|
| 427 |
year: Optional[Any] = None,
|
|
|
|
| 432 |
):
|
| 433 |
"""
|
| 434 |
Build Qdrant filter from parameters.
|
| 435 |
+
|
| 436 |
Supports single values or lists (using MatchAny).
|
| 437 |
"""
|
|
|
|
|
|
|
| 438 |
conditions = []
|
| 439 |
+
|
| 440 |
if year is not None:
|
| 441 |
if isinstance(year, list):
|
| 442 |
year_values = [int(y) if isinstance(y, str) else y for y in year]
|
| 443 |
+
conditions.append(FieldCondition(key="year", match=MatchAny(any=year_values)))
|
|
|
|
|
|
|
| 444 |
else:
|
| 445 |
year_value = int(year) if isinstance(year, str) else year
|
| 446 |
+
conditions.append(FieldCondition(key="year", match=MatchValue(value=year_value)))
|
| 447 |
+
|
|
|
|
|
|
|
| 448 |
if source is not None:
|
| 449 |
if isinstance(source, list):
|
| 450 |
+
conditions.append(FieldCondition(key="source", match=MatchAny(any=source)))
|
|
|
|
|
|
|
| 451 |
else:
|
| 452 |
+
conditions.append(FieldCondition(key="source", match=MatchValue(value=source)))
|
| 453 |
+
|
|
|
|
|
|
|
| 454 |
if district is not None:
|
| 455 |
if isinstance(district, list):
|
| 456 |
+
conditions.append(FieldCondition(key="district", match=MatchAny(any=district)))
|
|
|
|
|
|
|
| 457 |
else:
|
| 458 |
+
conditions.append(FieldCondition(key="district", match=MatchValue(value=district)))
|
| 459 |
+
|
|
|
|
|
|
|
| 460 |
if filename is not None:
|
| 461 |
if isinstance(filename, list):
|
| 462 |
+
conditions.append(FieldCondition(key="filename", match=MatchAny(any=filename)))
|
|
|
|
|
|
|
| 463 |
else:
|
| 464 |
+
conditions.append(FieldCondition(key="filename", match=MatchValue(value=filename)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
|
| 466 |
+
if has_text is not None:
|
| 467 |
+
conditions.append(FieldCondition(key="has_text", match=MatchValue(value=has_text)))
|
| 468 |
|
| 469 |
+
return Filter(must=conditions) if conditions else None
|
visual_rag/visualization/__init__.py
CHANGED
|
@@ -7,8 +7,8 @@ This module provides:
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from visual_rag.visualization.saliency import (
|
| 10 |
-
generate_saliency_map,
|
| 11 |
create_saliency_overlay,
|
|
|
|
| 12 |
visualize_search_results,
|
| 13 |
)
|
| 14 |
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from visual_rag.visualization.saliency import (
|
|
|
|
| 10 |
create_saliency_overlay,
|
| 11 |
+
generate_saliency_map,
|
| 12 |
visualize_search_results,
|
| 13 |
)
|
| 14 |
|
visual_rag/visualization/saliency.py
CHANGED
|
@@ -5,10 +5,11 @@ Generates attention/saliency maps to visualize which parts of documents
|
|
| 5 |
are most relevant to a query.
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
import numpy as np
|
| 9 |
-
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
-
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 11 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
@@ -24,9 +25,9 @@ def generate_saliency_map(
|
|
| 24 |
) -> Tuple[Image.Image, np.ndarray]:
|
| 25 |
"""
|
| 26 |
Generate saliency map showing which parts of the image match the query.
|
| 27 |
-
|
| 28 |
Computes patch-level relevance scores and overlays them on the image.
|
| 29 |
-
|
| 30 |
Args:
|
| 31 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 32 |
doc_embedding: Document visual embeddings [num_visual_tokens, dim]
|
|
@@ -35,10 +36,10 @@ def generate_saliency_map(
|
|
| 35 |
colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
|
| 36 |
alpha: Overlay transparency (0-1)
|
| 37 |
threshold_percentile: Only highlight patches above this percentile
|
| 38 |
-
|
| 39 |
Returns:
|
| 40 |
Tuple of (annotated_image, patch_scores)
|
| 41 |
-
|
| 42 |
Example:
|
| 43 |
>>> query = embedder.embed_query("budget allocation")
|
| 44 |
>>> doc = visual_embedding # From embed_images
|
|
@@ -51,57 +52,57 @@ def generate_saliency_map(
|
|
| 51 |
>>> annotated.save("saliency.png")
|
| 52 |
"""
|
| 53 |
# Ensure numpy arrays
|
| 54 |
-
if hasattr(query_embedding,
|
| 55 |
query_np = query_embedding.numpy()
|
| 56 |
-
elif hasattr(query_embedding,
|
| 57 |
query_np = query_embedding.cpu().numpy()
|
| 58 |
else:
|
| 59 |
query_np = np.array(query_embedding, dtype=np.float32)
|
| 60 |
-
|
| 61 |
-
if hasattr(doc_embedding,
|
| 62 |
doc_np = doc_embedding.numpy()
|
| 63 |
-
elif hasattr(doc_embedding,
|
| 64 |
doc_np = doc_embedding.cpu().numpy()
|
| 65 |
else:
|
| 66 |
doc_np = np.array(doc_embedding, dtype=np.float32)
|
| 67 |
-
|
| 68 |
# Normalize embeddings
|
| 69 |
query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
|
| 70 |
doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)
|
| 71 |
-
|
| 72 |
# Compute similarity matrix: [num_query, num_doc]
|
| 73 |
similarity_matrix = np.dot(query_norm, doc_norm.T)
|
| 74 |
-
|
| 75 |
# Get max similarity per document patch (best match from any query token)
|
| 76 |
patch_scores = similarity_matrix.max(axis=0)
|
| 77 |
-
|
| 78 |
# Normalize to [0, 1]
|
| 79 |
score_min, score_max = patch_scores.min(), patch_scores.max()
|
| 80 |
if score_max - score_min > 1e-8:
|
| 81 |
patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
|
| 82 |
else:
|
| 83 |
patch_scores_norm = np.zeros_like(patch_scores)
|
| 84 |
-
|
| 85 |
# Determine grid dimensions
|
| 86 |
if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
|
| 87 |
n_rows = token_info["n_rows"]
|
| 88 |
n_cols = token_info["n_cols"]
|
| 89 |
num_tiles = n_rows * n_cols + 1 # +1 for global tile
|
| 90 |
patches_per_tile = 64 # ColSmol standard
|
| 91 |
-
|
| 92 |
# Reshape to tile grid (excluding global tile)
|
| 93 |
try:
|
| 94 |
# Skip global tile patches at the end
|
| 95 |
tile_patches = num_tiles * patches_per_tile
|
| 96 |
if len(patch_scores_norm) >= tile_patches:
|
| 97 |
-
grid_patches = patch_scores_norm[:n_rows * n_cols * patches_per_tile]
|
| 98 |
else:
|
| 99 |
grid_patches = patch_scores_norm
|
| 100 |
-
|
| 101 |
# Reshape: [tiles * patches_per_tile] -> [tiles, patches_per_tile]
|
| 102 |
# Then mean per tile
|
| 103 |
num_grid_tiles = n_rows * n_cols
|
| 104 |
-
grid_patches = grid_patches[:num_grid_tiles * patches_per_tile]
|
| 105 |
tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
|
| 106 |
tile_scores = tile_scores.reshape(n_rows, n_cols)
|
| 107 |
except Exception as e:
|
|
@@ -110,7 +111,7 @@ def generate_saliency_map(
|
|
| 110 |
else:
|
| 111 |
tile_scores = None
|
| 112 |
n_rows = n_cols = None
|
| 113 |
-
|
| 114 |
# Create overlay
|
| 115 |
annotated = create_saliency_overlay(
|
| 116 |
image=image,
|
|
@@ -121,7 +122,7 @@ def generate_saliency_map(
|
|
| 121 |
grid_rows=n_rows,
|
| 122 |
grid_cols=n_cols,
|
| 123 |
)
|
| 124 |
-
|
| 125 |
return annotated, patch_scores
|
| 126 |
|
| 127 |
|
|
@@ -136,7 +137,7 @@ def create_saliency_overlay(
|
|
| 136 |
) -> Image.Image:
|
| 137 |
"""
|
| 138 |
Create colored overlay on image based on scores.
|
| 139 |
-
|
| 140 |
Args:
|
| 141 |
image: Base PIL Image
|
| 142 |
scores: Score array - 1D [num_patches] or 2D [rows, cols]
|
|
@@ -144,7 +145,7 @@ def create_saliency_overlay(
|
|
| 144 |
alpha: Overlay transparency
|
| 145 |
threshold_percentile: Only color patches above this percentile
|
| 146 |
grid_rows, grid_cols: Grid dimensions (auto-detected if not provided)
|
| 147 |
-
|
| 148 |
Returns:
|
| 149 |
Annotated PIL Image
|
| 150 |
"""
|
|
@@ -153,10 +154,10 @@ def create_saliency_overlay(
|
|
| 153 |
except ImportError:
|
| 154 |
logger.warning("matplotlib not installed, returning original image")
|
| 155 |
return image
|
| 156 |
-
|
| 157 |
img_array = np.array(image)
|
| 158 |
h, w = img_array.shape[:2]
|
| 159 |
-
|
| 160 |
# Handle 2D scores (tile grid)
|
| 161 |
if scores.ndim == 2:
|
| 162 |
rows, cols = scores.shape
|
|
@@ -171,58 +172,58 @@ def create_saliency_overlay(
|
|
| 171 |
aspect = w / h
|
| 172 |
cols = int(np.sqrt(num_patches * aspect))
|
| 173 |
rows = max(1, num_patches // cols)
|
| 174 |
-
scores = scores[:rows * cols].reshape(rows, cols)
|
| 175 |
else:
|
| 176 |
# Auto-estimate grid
|
| 177 |
num_patches = len(scores) if scores.ndim == 1 else scores.size
|
| 178 |
aspect = w / h
|
| 179 |
cols = max(1, int(np.sqrt(num_patches * aspect)))
|
| 180 |
rows = max(1, num_patches // cols)
|
| 181 |
-
|
| 182 |
if rows * cols > len(scores) if scores.ndim == 1 else scores.size:
|
| 183 |
cols = max(1, cols - 1)
|
| 184 |
-
|
| 185 |
if scores.ndim == 1:
|
| 186 |
-
scores = scores[:rows * cols].reshape(rows, cols)
|
| 187 |
-
|
| 188 |
# Get colormap
|
| 189 |
cmap = plt.cm.get_cmap(colormap)
|
| 190 |
-
|
| 191 |
# Calculate threshold
|
| 192 |
threshold = np.percentile(scores, threshold_percentile)
|
| 193 |
-
|
| 194 |
# Calculate cell dimensions
|
| 195 |
cell_h = h // rows
|
| 196 |
cell_w = w // cols
|
| 197 |
-
|
| 198 |
# Create RGBA overlay
|
| 199 |
overlay = np.zeros((h, w, 4), dtype=np.uint8)
|
| 200 |
-
|
| 201 |
for i in range(rows):
|
| 202 |
for j in range(cols):
|
| 203 |
score = scores[i, j]
|
| 204 |
-
|
| 205 |
if score >= threshold:
|
| 206 |
y1 = i * cell_h
|
| 207 |
y2 = min((i + 1) * cell_h, h)
|
| 208 |
x1 = j * cell_w
|
| 209 |
x2 = min((j + 1) * cell_w, w)
|
| 210 |
-
|
| 211 |
# Normalize score for coloring (above threshold)
|
| 212 |
norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
|
| 213 |
norm_score = min(1.0, max(0.0, norm_score))
|
| 214 |
-
|
| 215 |
# Get color
|
| 216 |
color = cmap(norm_score)[:3]
|
| 217 |
color_uint8 = (np.array(color) * 255).astype(np.uint8)
|
| 218 |
-
|
| 219 |
overlay[y1:y2, x1:x2, :3] = color_uint8
|
| 220 |
overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)
|
| 221 |
-
|
| 222 |
# Blend with original
|
| 223 |
overlay_img = Image.fromarray(overlay, "RGBA")
|
| 224 |
result = Image.alpha_composite(image.convert("RGBA"), overlay_img)
|
| 225 |
-
|
| 226 |
return result.convert("RGB")
|
| 227 |
|
| 228 |
|
|
@@ -237,7 +238,7 @@ def visualize_search_results(
|
|
| 237 |
) -> Optional[Image.Image]:
|
| 238 |
"""
|
| 239 |
Visualize search results as a grid of images with scores.
|
| 240 |
-
|
| 241 |
Args:
|
| 242 |
query: Original query text
|
| 243 |
results: List of search results with 'payload' containing 'page' (image URL/base64)
|
|
@@ -246,7 +247,7 @@ def visualize_search_results(
|
|
| 246 |
output_path: Path to save visualization (optional)
|
| 247 |
max_results: Maximum results to show
|
| 248 |
show_saliency: Generate saliency overlays (requires query_embedding & embeddings)
|
| 249 |
-
|
| 250 |
Returns:
|
| 251 |
Combined visualization image if successful
|
| 252 |
"""
|
|
@@ -255,32 +256,32 @@ def visualize_search_results(
|
|
| 255 |
except ImportError:
|
| 256 |
logger.error("matplotlib required for visualization")
|
| 257 |
return None
|
| 258 |
-
|
| 259 |
results = results[:max_results]
|
| 260 |
n = len(results)
|
| 261 |
-
|
| 262 |
if n == 0:
|
| 263 |
logger.warning("No results to visualize")
|
| 264 |
return None
|
| 265 |
-
|
| 266 |
fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
|
| 267 |
if n == 1:
|
| 268 |
axes = [axes]
|
| 269 |
-
|
| 270 |
for idx, (result, ax) in enumerate(zip(results, axes)):
|
| 271 |
payload = result.get("payload", {})
|
| 272 |
score = result.get("score_final", result.get("score_stage1", 0))
|
| 273 |
-
|
| 274 |
# Try to load image from payload
|
| 275 |
page_data = payload.get("page", "")
|
| 276 |
image = None
|
| 277 |
-
|
| 278 |
if page_data.startswith("data:image"):
|
| 279 |
# Base64 encoded
|
| 280 |
try:
|
| 281 |
import base64
|
| 282 |
from io import BytesIO
|
| 283 |
-
|
| 284 |
b64_data = page_data.split(",")[1]
|
| 285 |
image = Image.open(BytesIO(base64.b64decode(b64_data)))
|
| 286 |
except Exception as e:
|
|
@@ -290,50 +291,45 @@ def visualize_search_results(
|
|
| 290 |
try:
|
| 291 |
import urllib.request
|
| 292 |
from io import BytesIO
|
| 293 |
-
|
| 294 |
with urllib.request.urlopen(page_data, timeout=5) as response:
|
| 295 |
image = Image.open(BytesIO(response.read()))
|
| 296 |
except Exception as e:
|
| 297 |
logger.debug(f"Could not fetch image URL: {e}")
|
| 298 |
-
|
| 299 |
if image:
|
| 300 |
ax.imshow(image)
|
| 301 |
else:
|
| 302 |
# Show placeholder
|
| 303 |
-
ax.text(
|
| 304 |
-
|
| 305 |
-
ha="center", va="center",
|
| 306 |
-
fontsize=12, color="gray"
|
| 307 |
-
)
|
| 308 |
-
|
| 309 |
# Add title
|
| 310 |
title = f"Rank {idx + 1}\nScore: {score:.3f}"
|
| 311 |
if payload.get("filename"):
|
| 312 |
title += f"\n{payload['filename'][:30]}"
|
| 313 |
if payload.get("page_number") is not None:
|
| 314 |
title += f" p.{payload['page_number'] + 1}"
|
| 315 |
-
|
| 316 |
ax.set_title(title, fontsize=9)
|
| 317 |
ax.axis("off")
|
| 318 |
-
|
| 319 |
# Add query as suptitle
|
| 320 |
query_display = query[:80] + "..." if len(query) > 80 else query
|
| 321 |
plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
|
| 322 |
plt.tight_layout()
|
| 323 |
-
|
| 324 |
if output_path:
|
| 325 |
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
| 326 |
logger.info(f"💾 Saved visualization to: {output_path}")
|
| 327 |
-
|
| 328 |
# Convert to PIL Image for return
|
| 329 |
from io import BytesIO
|
|
|
|
| 330 |
buf = BytesIO()
|
| 331 |
plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
|
| 332 |
buf.seek(0)
|
| 333 |
result_image = Image.open(buf)
|
| 334 |
-
|
| 335 |
-
plt.close()
|
| 336 |
-
|
| 337 |
-
return result_image
|
| 338 |
|
|
|
|
| 339 |
|
|
|
|
|
|
| 5 |
are most relevant to a query.
|
| 6 |
"""
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
import logging
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
|
|
|
| 25 |
) -> Tuple[Image.Image, np.ndarray]:
|
| 26 |
"""
|
| 27 |
Generate saliency map showing which parts of the image match the query.
|
| 28 |
+
|
| 29 |
Computes patch-level relevance scores and overlays them on the image.
|
| 30 |
+
|
| 31 |
Args:
|
| 32 |
query_embedding: Query embeddings [num_query_tokens, dim]
|
| 33 |
doc_embedding: Document visual embeddings [num_visual_tokens, dim]
|
|
|
|
| 36 |
colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
|
| 37 |
alpha: Overlay transparency (0-1)
|
| 38 |
threshold_percentile: Only highlight patches above this percentile
|
| 39 |
+
|
| 40 |
Returns:
|
| 41 |
Tuple of (annotated_image, patch_scores)
|
| 42 |
+
|
| 43 |
Example:
|
| 44 |
>>> query = embedder.embed_query("budget allocation")
|
| 45 |
>>> doc = visual_embedding # From embed_images
|
|
|
|
| 52 |
>>> annotated.save("saliency.png")
|
| 53 |
"""
|
| 54 |
# Ensure numpy arrays
|
| 55 |
+
if hasattr(query_embedding, "numpy"):
|
| 56 |
query_np = query_embedding.numpy()
|
| 57 |
+
elif hasattr(query_embedding, "cpu"):
|
| 58 |
query_np = query_embedding.cpu().numpy()
|
| 59 |
else:
|
| 60 |
query_np = np.array(query_embedding, dtype=np.float32)
|
| 61 |
+
|
| 62 |
+
if hasattr(doc_embedding, "numpy"):
|
| 63 |
doc_np = doc_embedding.numpy()
|
| 64 |
+
elif hasattr(doc_embedding, "cpu"):
|
| 65 |
doc_np = doc_embedding.cpu().numpy()
|
| 66 |
else:
|
| 67 |
doc_np = np.array(doc_embedding, dtype=np.float32)
|
| 68 |
+
|
| 69 |
# Normalize embeddings
|
| 70 |
query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
|
| 71 |
doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)
|
| 72 |
+
|
| 73 |
# Compute similarity matrix: [num_query, num_doc]
|
| 74 |
similarity_matrix = np.dot(query_norm, doc_norm.T)
|
| 75 |
+
|
| 76 |
# Get max similarity per document patch (best match from any query token)
|
| 77 |
patch_scores = similarity_matrix.max(axis=0)
|
| 78 |
+
|
| 79 |
# Normalize to [0, 1]
|
| 80 |
score_min, score_max = patch_scores.min(), patch_scores.max()
|
| 81 |
if score_max - score_min > 1e-8:
|
| 82 |
patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
|
| 83 |
else:
|
| 84 |
patch_scores_norm = np.zeros_like(patch_scores)
|
| 85 |
+
|
| 86 |
# Determine grid dimensions
|
| 87 |
if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
|
| 88 |
n_rows = token_info["n_rows"]
|
| 89 |
n_cols = token_info["n_cols"]
|
| 90 |
num_tiles = n_rows * n_cols + 1 # +1 for global tile
|
| 91 |
patches_per_tile = 64 # ColSmol standard
|
| 92 |
+
|
| 93 |
# Reshape to tile grid (excluding global tile)
|
| 94 |
try:
|
| 95 |
# Skip global tile patches at the end
|
| 96 |
tile_patches = num_tiles * patches_per_tile
|
| 97 |
if len(patch_scores_norm) >= tile_patches:
|
| 98 |
+
grid_patches = patch_scores_norm[: n_rows * n_cols * patches_per_tile]
|
| 99 |
else:
|
| 100 |
grid_patches = patch_scores_norm
|
| 101 |
+
|
| 102 |
# Reshape: [tiles * patches_per_tile] -> [tiles, patches_per_tile]
|
| 103 |
# Then mean per tile
|
| 104 |
num_grid_tiles = n_rows * n_cols
|
| 105 |
+
grid_patches = grid_patches[: num_grid_tiles * patches_per_tile]
|
| 106 |
tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
|
| 107 |
tile_scores = tile_scores.reshape(n_rows, n_cols)
|
| 108 |
except Exception as e:
|
|
|
|
| 111 |
else:
|
| 112 |
tile_scores = None
|
| 113 |
n_rows = n_cols = None
|
| 114 |
+
|
| 115 |
# Create overlay
|
| 116 |
annotated = create_saliency_overlay(
|
| 117 |
image=image,
|
|
|
|
| 122 |
grid_rows=n_rows,
|
| 123 |
grid_cols=n_cols,
|
| 124 |
)
|
| 125 |
+
|
| 126 |
return annotated, patch_scores
|
| 127 |
|
| 128 |
|
|
|
|
| 137 |
) -> Image.Image:
|
| 138 |
"""
|
| 139 |
Create colored overlay on image based on scores.
|
| 140 |
+
|
| 141 |
Args:
|
| 142 |
image: Base PIL Image
|
| 143 |
scores: Score array - 1D [num_patches] or 2D [rows, cols]
|
|
|
|
| 145 |
alpha: Overlay transparency
|
| 146 |
threshold_percentile: Only color patches above this percentile
|
| 147 |
grid_rows, grid_cols: Grid dimensions (auto-detected if not provided)
|
| 148 |
+
|
| 149 |
Returns:
|
| 150 |
Annotated PIL Image
|
| 151 |
"""
|
|
|
|
| 154 |
except ImportError:
|
| 155 |
logger.warning("matplotlib not installed, returning original image")
|
| 156 |
return image
|
| 157 |
+
|
| 158 |
img_array = np.array(image)
|
| 159 |
h, w = img_array.shape[:2]
|
| 160 |
+
|
| 161 |
# Handle 2D scores (tile grid)
|
| 162 |
if scores.ndim == 2:
|
| 163 |
rows, cols = scores.shape
|
|
|
|
| 172 |
aspect = w / h
|
| 173 |
cols = int(np.sqrt(num_patches * aspect))
|
| 174 |
rows = max(1, num_patches // cols)
|
| 175 |
+
scores = scores[: rows * cols].reshape(rows, cols)
|
| 176 |
else:
|
| 177 |
# Auto-estimate grid
|
| 178 |
num_patches = len(scores) if scores.ndim == 1 else scores.size
|
| 179 |
aspect = w / h
|
| 180 |
cols = max(1, int(np.sqrt(num_patches * aspect)))
|
| 181 |
rows = max(1, num_patches // cols)
|
| 182 |
+
|
| 183 |
if rows * cols > len(scores) if scores.ndim == 1 else scores.size:
|
| 184 |
cols = max(1, cols - 1)
|
| 185 |
+
|
| 186 |
if scores.ndim == 1:
|
| 187 |
+
scores = scores[: rows * cols].reshape(rows, cols)
|
| 188 |
+
|
| 189 |
# Get colormap
|
| 190 |
cmap = plt.cm.get_cmap(colormap)
|
| 191 |
+
|
| 192 |
# Calculate threshold
|
| 193 |
threshold = np.percentile(scores, threshold_percentile)
|
| 194 |
+
|
| 195 |
# Calculate cell dimensions
|
| 196 |
cell_h = h // rows
|
| 197 |
cell_w = w // cols
|
| 198 |
+
|
| 199 |
# Create RGBA overlay
|
| 200 |
overlay = np.zeros((h, w, 4), dtype=np.uint8)
|
| 201 |
+
|
| 202 |
for i in range(rows):
|
| 203 |
for j in range(cols):
|
| 204 |
score = scores[i, j]
|
| 205 |
+
|
| 206 |
if score >= threshold:
|
| 207 |
y1 = i * cell_h
|
| 208 |
y2 = min((i + 1) * cell_h, h)
|
| 209 |
x1 = j * cell_w
|
| 210 |
x2 = min((j + 1) * cell_w, w)
|
| 211 |
+
|
| 212 |
# Normalize score for coloring (above threshold)
|
| 213 |
norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
|
| 214 |
norm_score = min(1.0, max(0.0, norm_score))
|
| 215 |
+
|
| 216 |
# Get color
|
| 217 |
color = cmap(norm_score)[:3]
|
| 218 |
color_uint8 = (np.array(color) * 255).astype(np.uint8)
|
| 219 |
+
|
| 220 |
overlay[y1:y2, x1:x2, :3] = color_uint8
|
| 221 |
overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)
|
| 222 |
+
|
| 223 |
# Blend with original
|
| 224 |
overlay_img = Image.fromarray(overlay, "RGBA")
|
| 225 |
result = Image.alpha_composite(image.convert("RGBA"), overlay_img)
|
| 226 |
+
|
| 227 |
return result.convert("RGB")
|
| 228 |
|
| 229 |
|
|
|
|
| 238 |
) -> Optional[Image.Image]:
|
| 239 |
"""
|
| 240 |
Visualize search results as a grid of images with scores.
|
| 241 |
+
|
| 242 |
Args:
|
| 243 |
query: Original query text
|
| 244 |
results: List of search results with 'payload' containing 'page' (image URL/base64)
|
|
|
|
| 247 |
output_path: Path to save visualization (optional)
|
| 248 |
max_results: Maximum results to show
|
| 249 |
show_saliency: Generate saliency overlays (requires query_embedding & embeddings)
|
| 250 |
+
|
| 251 |
Returns:
|
| 252 |
Combined visualization image if successful
|
| 253 |
"""
|
|
|
|
| 256 |
except ImportError:
|
| 257 |
logger.error("matplotlib required for visualization")
|
| 258 |
return None
|
| 259 |
+
|
| 260 |
results = results[:max_results]
|
| 261 |
n = len(results)
|
| 262 |
+
|
| 263 |
if n == 0:
|
| 264 |
logger.warning("No results to visualize")
|
| 265 |
return None
|
| 266 |
+
|
| 267 |
fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
|
| 268 |
if n == 1:
|
| 269 |
axes = [axes]
|
| 270 |
+
|
| 271 |
for idx, (result, ax) in enumerate(zip(results, axes)):
|
| 272 |
payload = result.get("payload", {})
|
| 273 |
score = result.get("score_final", result.get("score_stage1", 0))
|
| 274 |
+
|
| 275 |
# Try to load image from payload
|
| 276 |
page_data = payload.get("page", "")
|
| 277 |
image = None
|
| 278 |
+
|
| 279 |
if page_data.startswith("data:image"):
|
| 280 |
# Base64 encoded
|
| 281 |
try:
|
| 282 |
import base64
|
| 283 |
from io import BytesIO
|
| 284 |
+
|
| 285 |
b64_data = page_data.split(",")[1]
|
| 286 |
image = Image.open(BytesIO(base64.b64decode(b64_data)))
|
| 287 |
except Exception as e:
|
|
|
|
| 291 |
try:
|
| 292 |
import urllib.request
|
| 293 |
from io import BytesIO
|
| 294 |
+
|
| 295 |
with urllib.request.urlopen(page_data, timeout=5) as response:
|
| 296 |
image = Image.open(BytesIO(response.read()))
|
| 297 |
except Exception as e:
|
| 298 |
logger.debug(f"Could not fetch image URL: {e}")
|
| 299 |
+
|
| 300 |
if image:
|
| 301 |
ax.imshow(image)
|
| 302 |
else:
|
| 303 |
# Show placeholder
|
| 304 |
+
ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=12, color="gray")
|
| 305 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
# Add title
|
| 307 |
title = f"Rank {idx + 1}\nScore: {score:.3f}"
|
| 308 |
if payload.get("filename"):
|
| 309 |
title += f"\n{payload['filename'][:30]}"
|
| 310 |
if payload.get("page_number") is not None:
|
| 311 |
title += f" p.{payload['page_number'] + 1}"
|
| 312 |
+
|
| 313 |
ax.set_title(title, fontsize=9)
|
| 314 |
ax.axis("off")
|
| 315 |
+
|
| 316 |
# Add query as suptitle
|
| 317 |
query_display = query[:80] + "..." if len(query) > 80 else query
|
| 318 |
plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
|
| 319 |
plt.tight_layout()
|
| 320 |
+
|
| 321 |
if output_path:
|
| 322 |
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
| 323 |
logger.info(f"💾 Saved visualization to: {output_path}")
|
| 324 |
+
|
| 325 |
# Convert to PIL Image for return
|
| 326 |
from io import BytesIO
|
| 327 |
+
|
| 328 |
buf = BytesIO()
|
| 329 |
plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
|
| 330 |
buf.seek(0)
|
| 331 |
result_image = Image.open(buf)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
+
plt.close()
|
| 334 |
|
| 335 |
+
return result_image
|