Yeroyan commited on
Commit
9513cca
·
verified ·
1 Parent(s): d9f2c00

sync v0.1.3

Browse files
benchmarks/__init__.py CHANGED
@@ -8,3 +8,4 @@ work in Docker/Spaces environments.
8
  """
9
 
10
  __all__ = []
 
 
8
  """
9
 
10
  __all__ = []
11
+
benchmarks/quick_test.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick Benchmark - Validate retrieval quality with ViDoRe data.
4
+
5
+ This script:
6
+ 1. Downloads samples from ViDoRe (with ground truth relevance)
7
+ 2. Embeds with ColSmol-500M
8
+ 3. Tests retrieval strategies (exhaustive vs two-stage)
9
+ 4. Computes METRICS: NDCG@K, MRR@K, Recall@K
10
+ 5. Compares speed and quality
11
+
12
+ Usage:
13
+ python quick_test.py --samples 100
14
+ python quick_test.py --samples 500 --skip-exhaustive # Faster
15
+ """
16
+
17
+ import sys
18
+ import time
19
+ import argparse
20
+ import logging
21
+ from pathlib import Path
22
+ from typing import List, Dict, Any
23
+
24
+ # Add parent directory to Python path (so we can import visual_rag)
25
+ # This allows running the script directly without pip install
26
+ _script_dir = Path(__file__).parent
27
+ _parent_dir = _script_dir.parent
28
+ if str(_parent_dir) not in sys.path:
29
+ sys.path.insert(0, str(_parent_dir))
30
+
31
+ import numpy as np
32
+ from tqdm import tqdm
33
+
34
+ # Visual RAG imports (now works without pip install)
35
+ from visual_rag.embedding import VisualEmbedder
36
+ from visual_rag.embedding.pooling import (
37
+ tile_level_mean_pooling,
38
+ compute_maxsim_score,
39
+ )
40
+
41
+ # Optional: datasets for ViDoRe
42
+ try:
43
+ from datasets import load_dataset as hf_load_dataset
44
+ HAS_DATASETS = True
45
+ except ImportError:
46
+ HAS_DATASETS = False
47
+
48
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ def load_vidore_sample(num_samples: int = 100) -> List[Dict]:
53
+ """
54
+ Load sample from ViDoRe DocVQA with ground truth.
55
+
56
+ Each sample has a query and its relevant document (1:1 mapping).
57
+ This allows computing retrieval metrics.
58
+ """
59
+ if not HAS_DATASETS:
60
+ logger.error("Install datasets: pip install datasets")
61
+ sys.exit(1)
62
+
63
+ logger.info(f"📥 Loading {num_samples} samples from ViDoRe DocVQA...")
64
+
65
+ ds = hf_load_dataset("vidore/docvqa_test_subsampled", split="test")
66
+
67
+ samples = []
68
+ for i, example in enumerate(ds):
69
+ if i >= num_samples:
70
+ break
71
+
72
+ samples.append({
73
+ "id": i,
74
+ "doc_id": f"doc_{i}",
75
+ "query_id": f"q_{i}",
76
+ "image": example.get("image", example.get("page_image")),
77
+ "query": example.get("query", example.get("question", "")),
78
+ # Ground truth: query i is relevant to doc i
79
+ "relevant_doc": f"doc_{i}",
80
+ })
81
+
82
+ logger.info(f"✅ Loaded {len(samples)} samples with ground truth")
83
+ return samples
84
+
85
+
86
+ def embed_all(
87
+ samples: List[Dict],
88
+ model_name: str = "vidore/colSmol-500M",
89
+ ) -> Dict[str, Any]:
90
+ """Embed all documents and queries."""
91
+ logger.info(f"\n🤖 Loading model: {model_name}")
92
+ embedder = VisualEmbedder(model_name=model_name)
93
+
94
+ images = [s["image"] for s in samples]
95
+ queries = [s["query"] for s in samples if s["query"]]
96
+
97
+ # Embed images
98
+ logger.info(f"🎨 Embedding {len(images)} documents...")
99
+ start_time = time.time()
100
+
101
+ embeddings, token_infos = embedder.embed_images(
102
+ images, batch_size=4, return_token_info=True
103
+ )
104
+
105
+ doc_embed_time = time.time() - start_time
106
+ logger.info(f" Time: {doc_embed_time:.2f}s ({doc_embed_time/len(images)*1000:.1f}ms/doc)")
107
+
108
+ # Process embeddings: extract visual tokens + tile-level pooling
109
+ doc_data = {}
110
+ for i, (emb, token_info) in enumerate(zip(embeddings, token_infos)):
111
+ if hasattr(emb, 'cpu'):
112
+ emb = emb.cpu()
113
+ emb_np = emb.numpy() if hasattr(emb, 'numpy') else np.array(emb)
114
+
115
+ # Extract visual tokens only (filter special tokens)
116
+ visual_indices = token_info["visual_token_indices"]
117
+ visual_emb = emb_np[visual_indices].astype(np.float32)
118
+
119
+ # Tile-level pooling
120
+ n_rows = token_info.get("n_rows", 4)
121
+ n_cols = token_info.get("n_cols", 3)
122
+ num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13
123
+
124
+ tile_pooled = tile_level_mean_pooling(visual_emb, num_tiles, patches_per_tile=64)
125
+
126
+ doc_data[f"doc_{i}"] = {
127
+ "embedding": visual_emb,
128
+ "pooled": tile_pooled,
129
+ "num_visual_tokens": len(visual_indices),
130
+ "num_tiles": tile_pooled.shape[0],
131
+ }
132
+
133
+ # Embed queries
134
+ logger.info(f"🔍 Embedding {len(queries)} queries...")
135
+ start_time = time.time()
136
+
137
+ query_data = {}
138
+ for i, query in enumerate(tqdm(queries, desc="Queries")):
139
+ q_emb = embedder.embed_query(query)
140
+ if hasattr(q_emb, 'cpu'):
141
+ q_emb = q_emb.cpu()
142
+ q_np = q_emb.numpy() if hasattr(q_emb, 'numpy') else np.array(q_emb)
143
+ query_data[f"q_{i}"] = q_np.astype(np.float32)
144
+
145
+ query_embed_time = time.time() - start_time
146
+
147
+ return {
148
+ "docs": doc_data,
149
+ "queries": query_data,
150
+ "samples": samples,
151
+ "doc_embed_time": doc_embed_time,
152
+ "query_embed_time": query_embed_time,
153
+ "model": model_name,
154
+ }
155
+
156
+
157
+ def search_exhaustive(query_emb: np.ndarray, docs: Dict, top_k: int = 10) -> List[Dict]:
158
+ """Exhaustive MaxSim search over all documents."""
159
+ scores = []
160
+ for doc_id, doc in docs.items():
161
+ score = compute_maxsim_score(query_emb, doc["embedding"])
162
+ scores.append({"id": doc_id, "score": score})
163
+
164
+ scores.sort(key=lambda x: x["score"], reverse=True)
165
+ return scores[:top_k]
166
+
167
+
168
+ def search_two_stage(
169
+ query_emb: np.ndarray,
170
+ docs: Dict,
171
+ prefetch_k: int = 20,
172
+ top_k: int = 10,
173
+ ) -> List[Dict]:
174
+ """
175
+ Two-stage retrieval with tile-level pooling.
176
+
177
+ Stage 1: Fast prefetch using tile-pooled vectors
178
+ Stage 2: Exact MaxSim reranking on candidates
179
+ """
180
+ # Stage 1: Tile-level pooled search
181
+ query_pooled = query_emb.mean(axis=0)
182
+ query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)
183
+
184
+ stage1_scores = []
185
+ for doc_id, doc in docs.items():
186
+ doc_pooled = doc["pooled"]
187
+ doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
188
+ tile_sims = np.dot(doc_norm, query_pooled)
189
+ score = float(tile_sims.max())
190
+ stage1_scores.append({"id": doc_id, "score": score})
191
+
192
+ stage1_scores.sort(key=lambda x: x["score"], reverse=True)
193
+ candidates = stage1_scores[:prefetch_k]
194
+
195
+ # Stage 2: Exact MaxSim on candidates
196
+ reranked = []
197
+ for cand in candidates:
198
+ doc_id = cand["id"]
199
+ score = compute_maxsim_score(query_emb, docs[doc_id]["embedding"])
200
+ reranked.append({"id": doc_id, "score": score, "stage1_rank": stage1_scores.index(cand) + 1})
201
+
202
+ reranked.sort(key=lambda x: x["score"], reverse=True)
203
+ return reranked[:top_k]
204
+
205
+
206
+ def compute_metrics(
207
+ results: Dict[str, List[Dict]],
208
+ samples: List[Dict],
209
+ k_values: List[int] = [1, 3, 5, 7, 10],
210
+ ) -> Dict[str, float]:
211
+ """
212
+ Compute retrieval metrics.
213
+
214
+ Since ViDoRe has 1:1 query-doc mapping (1 relevant doc per query):
215
+ - Recall@K (Hit Rate): Is the relevant doc in top-K? (0 or 1)
216
+ - Precision@K: (# relevant in top-K) / K
217
+ - MRR@K: 1/rank if found in top-K, else 0
218
+ - NDCG@K: DCG / IDCG with binary relevance
219
+ """
220
+ metrics = {}
221
+
222
+ # Also track per-query ranks for analysis
223
+ all_ranks = []
224
+
225
+ for k in k_values:
226
+ recalls = []
227
+ precisions = []
228
+ mrrs = []
229
+ ndcgs = []
230
+
231
+ for sample in samples:
232
+ query_id = sample["query_id"]
233
+ relevant_doc = sample["relevant_doc"]
234
+
235
+ if query_id not in results:
236
+ continue
237
+
238
+ ranking = results[query_id][:k]
239
+ ranked_ids = [r["id"] for r in ranking]
240
+
241
+ # Find rank of relevant doc (1-indexed, 0 if not found)
242
+ rank = 0
243
+ for i, doc_id in enumerate(ranked_ids):
244
+ if doc_id == relevant_doc:
245
+ rank = i + 1
246
+ break
247
+
248
+ # Recall@K (Hit Rate): 1 if found in top-K
249
+ found = 1.0 if rank > 0 else 0.0
250
+ recalls.append(found)
251
+
252
+ # Precision@K: (# relevant found) / K
253
+ # With 1 relevant doc: 1/K if found, 0 otherwise
254
+ precision = found / k
255
+ precisions.append(precision)
256
+
257
+ # MRR@K: 1/rank if found
258
+ mrr = 1.0 / rank if rank > 0 else 0.0
259
+ mrrs.append(mrr)
260
+
261
+ # NDCG@K (binary relevance)
262
+ # DCG = 1/log2(rank+1) if found, 0 otherwise
263
+ # IDCG = 1/log2(2) = 1 (best case: relevant at rank 1)
264
+ dcg = 1.0 / np.log2(rank + 1) if rank > 0 else 0.0
265
+ idcg = 1.0
266
+ ndcg = dcg / idcg
267
+ ndcgs.append(ndcg)
268
+
269
+ # Track actual rank for analysis (only for k=10)
270
+ if k == max(k_values):
271
+ full_ranking = results[query_id]
272
+ full_rank = 0
273
+ for i, r in enumerate(full_ranking):
274
+ if r["id"] == relevant_doc:
275
+ full_rank = i + 1
276
+ break
277
+ all_ranks.append(full_rank)
278
+
279
+ metrics[f"Recall@{k}"] = np.mean(recalls)
280
+ metrics[f"P@{k}"] = np.mean(precisions)
281
+ metrics[f"MRR@{k}"] = np.mean(mrrs)
282
+ metrics[f"NDCG@{k}"] = np.mean(ndcgs)
283
+
284
+ # Add summary stats
285
+ if all_ranks:
286
+ found_ranks = [r for r in all_ranks if r > 0]
287
+ metrics["avg_rank"] = np.mean(found_ranks) if found_ranks else float('inf')
288
+ metrics["median_rank"] = np.median(found_ranks) if found_ranks else float('inf')
289
+ metrics["not_found"] = sum(1 for r in all_ranks if r == 0)
290
+
291
+ return metrics
292
+
293
+
294
+ def run_benchmark(
295
+ data: Dict,
296
+ skip_exhaustive: bool = False,
297
+ prefetch_k: int = None,
298
+ top_k: int = 10,
299
+ ) -> Dict[str, Dict]:
300
+ """Run retrieval benchmark with metrics."""
301
+ docs = data["docs"]
302
+ queries = data["queries"]
303
+ samples = data["samples"]
304
+ num_docs = len(docs)
305
+
306
+ # Auto-set prefetch_k to be meaningful (default: 20, or 20% of docs if >100 docs)
307
+ if prefetch_k is None:
308
+ if num_docs <= 100:
309
+ prefetch_k = 20 # Default: prefetch 20, rerank to top-10
310
+ else:
311
+ prefetch_k = max(20, min(100, int(num_docs * 0.2))) # 20% for larger collections
312
+
313
+ # Ensure prefetch_k < num_docs for meaningful two-stage comparison
314
+ if prefetch_k >= num_docs:
315
+ logger.warning(f"⚠️ prefetch_k={prefetch_k} >= num_docs={num_docs}")
316
+ logger.warning(f" Two-stage will fetch ALL docs (same as exhaustive)")
317
+ logger.warning(f" Use --samples > {prefetch_k * 3} for meaningful comparison")
318
+
319
+ logger.info(f"📊 Benchmark config: {num_docs} docs, prefetch_k={prefetch_k}, top_k={top_k}")
320
+ logger.info(f" (Both methods return top-{top_k} results - realistic retrieval scenario)")
321
+
322
+ results = {}
323
+
324
+ # Two-stage retrieval (NOVEL)
325
+ logger.info(f"\n🔬 Running Two-Stage retrieval (prefetch top-{prefetch_k}, rerank to top-{top_k})...")
326
+ two_stage_results = {}
327
+ two_stage_times = []
328
+
329
+ for sample in tqdm(samples, desc="Two-Stage"):
330
+ query_id = sample["query_id"]
331
+ query_emb = queries[query_id]
332
+
333
+ start = time.time()
334
+ ranking = search_two_stage(query_emb, docs, prefetch_k=prefetch_k, top_k=top_k)
335
+ two_stage_times.append(time.time() - start)
336
+
337
+ two_stage_results[query_id] = ranking
338
+
339
+ two_stage_metrics = compute_metrics(two_stage_results, samples)
340
+ two_stage_metrics["avg_time_ms"] = np.mean(two_stage_times) * 1000
341
+ two_stage_metrics["prefetch_k"] = prefetch_k
342
+ two_stage_metrics["top_k"] = top_k
343
+ results["two_stage"] = two_stage_metrics
344
+
345
+ # Exhaustive search (baseline)
346
+ if not skip_exhaustive:
347
+ logger.info(f"🔬 Running Exhaustive MaxSim (searches ALL {num_docs} docs, returns top-{top_k})...")
348
+ exhaustive_results = {}
349
+ exhaustive_times = []
350
+
351
+ for sample in tqdm(samples, desc="Exhaustive"):
352
+ query_id = sample["query_id"]
353
+ query_emb = queries[query_id]
354
+
355
+ start = time.time()
356
+ ranking = search_exhaustive(query_emb, docs, top_k=top_k)
357
+ exhaustive_times.append(time.time() - start)
358
+
359
+ exhaustive_results[query_id] = ranking
360
+
361
+ exhaustive_metrics = compute_metrics(exhaustive_results, samples)
362
+ exhaustive_metrics["avg_time_ms"] = np.mean(exhaustive_times) * 1000
363
+ exhaustive_metrics["top_k"] = top_k
364
+ results["exhaustive"] = exhaustive_metrics
365
+
366
+ return results
367
+
368
+
369
+ def print_results(data: Dict, benchmark_results: Dict, show_precision: bool = False):
370
+ """Print benchmark results."""
371
+ print("\n" + "=" * 80)
372
+ print("📊 BENCHMARK RESULTS")
373
+ print("=" * 80)
374
+
375
+ num_docs = len(data['docs'])
376
+ print(f"\n🤖 Model: {data['model']}")
377
+ print(f"📄 Documents: {num_docs}")
378
+ print(f"🔍 Queries: {len(data['queries'])}")
379
+
380
+ # Embedding stats
381
+ sample_doc = list(data['docs'].values())[0]
382
+ print(f"\n📏 Embedding (after visual token filtering):")
383
+ print(f" Visual tokens per doc: {sample_doc['num_visual_tokens']}")
384
+ print(f" Tile-pooled vectors: {sample_doc['num_tiles']}")
385
+
386
+ if "two_stage" in benchmark_results:
387
+ prefetch_k = benchmark_results["two_stage"].get("prefetch_k", "?")
388
+ print(f" Two-stage prefetch_k: {prefetch_k} (of {num_docs} docs)")
389
+
390
+ # Method labels - clearer naming
391
+ def get_label(method):
392
+ if method == "two_stage":
393
+ return "Pooled+Rerank" # Tile-pooled prefetch + MaxSim rerank
394
+ else:
395
+ return "Full MaxSim" # Exhaustive MaxSim on all docs
396
+
397
+ # Recall / Hit Rate table
398
+ print(f"\n🎯 RECALL (Hit Rate) @ K:")
399
+ print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
400
+ print(f" {'-'*60}")
401
+
402
+ for method, metrics in benchmark_results.items():
403
+ print(f" {get_label(method):<20} "
404
+ f"{metrics.get('Recall@1', 0):>8.3f} "
405
+ f"{metrics.get('Recall@3', 0):>8.3f} "
406
+ f"{metrics.get('Recall@5', 0):>8.3f} "
407
+ f"{metrics.get('Recall@7', 0):>8.3f} "
408
+ f"{metrics.get('Recall@10', 0):>8.3f}")
409
+
410
+ # Precision table (optional)
411
+ if show_precision:
412
+ print(f"\n📐 PRECISION @ K:")
413
+ print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
414
+ print(f" {'-'*60}")
415
+
416
+ for method, metrics in benchmark_results.items():
417
+ print(f" {get_label(method):<20} "
418
+ f"{metrics.get('P@1', 0):>8.3f} "
419
+ f"{metrics.get('P@3', 0):>8.3f} "
420
+ f"{metrics.get('P@5', 0):>8.3f} "
421
+ f"{metrics.get('P@7', 0):>8.3f} "
422
+ f"{metrics.get('P@10', 0):>8.3f}")
423
+
424
+ # NDCG table
425
+ print(f"\n📈 NDCG @ K:")
426
+ print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
427
+ print(f" {'-'*60}")
428
+
429
+ for method, metrics in benchmark_results.items():
430
+ print(f" {get_label(method):<20} "
431
+ f"{metrics.get('NDCG@1', 0):>8.3f} "
432
+ f"{metrics.get('NDCG@3', 0):>8.3f} "
433
+ f"{metrics.get('NDCG@5', 0):>8.3f} "
434
+ f"{metrics.get('NDCG@7', 0):>8.3f} "
435
+ f"{metrics.get('NDCG@10', 0):>8.3f}")
436
+
437
+ # MRR table
438
+ print(f"\n🔍 MRR @ K:")
439
+ print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
440
+ print(f" {'-'*60}")
441
+
442
+ for method, metrics in benchmark_results.items():
443
+ print(f" {get_label(method):<20} "
444
+ f"{metrics.get('MRR@1', 0):>8.3f} "
445
+ f"{metrics.get('MRR@3', 0):>8.3f} "
446
+ f"{metrics.get('MRR@5', 0):>8.3f} "
447
+ f"{metrics.get('MRR@7', 0):>8.3f} "
448
+ f"{metrics.get('MRR@10', 0):>8.3f}")
449
+
450
+ # Speed comparison
451
+ top_k = benchmark_results.get("two_stage", benchmark_results.get("exhaustive", {})).get("top_k", 10)
452
+ print(f"\n⏱️ SPEED (both return top-{top_k} results):")
453
+ print(f" {'Method':<20} {'Time (ms)':>12} {'Docs searched':>15}")
454
+ print(f" {'-'*50}")
455
+
456
+ for method, metrics in benchmark_results.items():
457
+ if method == "two_stage":
458
+ searched = metrics.get("prefetch_k", "?")
459
+ label = f"{searched} (stage-1)"
460
+ else:
461
+ searched = num_docs
462
+ label = f"{searched} (all)"
463
+ print(f" {get_label(method):<20} {metrics.get('avg_time_ms', 0):>12.2f} {label:>15}")
464
+
465
+ # Comparison summary
466
+ if "exhaustive" in benchmark_results and "two_stage" in benchmark_results:
467
+ ex = benchmark_results["exhaustive"]
468
+ ts = benchmark_results["two_stage"]
469
+
470
+ print(f"\n💡 POOLED+RERANK vs FULL MAXSIM:")
471
+
472
+ for k in [1, 5, 10]:
473
+ ex_recall = ex.get(f"Recall@{k}", 0)
474
+ ts_recall = ts.get(f"Recall@{k}", 0)
475
+ if ex_recall > 0:
476
+ retention = ts_recall / ex_recall * 100
477
+ print(f" • Recall@{k} retention: {retention:.1f}% ({ts_recall:.3f} vs {ex_recall:.3f})")
478
+
479
+ speedup = ex["avg_time_ms"] / ts["avg_time_ms"] if ts["avg_time_ms"] > 0 else 0
480
+ print(f" • Speedup: {speedup:.1f}x")
481
+
482
+ # Rank stats with explanation
483
+ if "avg_rank" in ts:
484
+ prefetch_k = ts.get("prefetch_k", "?")
485
+ top_k = ts.get("top_k", 10)
486
+ not_found = ts.get("not_found", 0)
487
+ total = len(data["queries"])
488
+
489
+ print(f"\n📊 POOLED+RERANK STATISTICS:")
490
+ print(f" Stage-1 (pooled prefetch):")
491
+ print(f" • Searches top-{prefetch_k} candidates using tile-pooled vectors")
492
+ print(f" • {total - not_found}/{total} queries ({100 - not_found/total*100:.1f}%) had relevant doc in prefetch")
493
+ print(f" • {not_found}/{total} queries ({not_found/total*100:.1f}%) missed (relevant doc ranked >{prefetch_k})")
494
+ print(f" Stage-2 (MaxSim reranking):")
495
+ print(f" • Reranks prefetch candidates with exact MaxSim")
496
+ print(f" • Returns final top-{top_k} results")
497
+ if ts['avg_rank'] < float('inf'):
498
+ print(f" • Avg rank of relevant doc (when found): {ts['avg_rank']:.1f}")
499
+ print(f" • Median rank: {ts['median_rank']:.1f}")
500
+ print(f"\n 💡 The {not_found/total*100:.1f}% miss rate is for stage-1 prefetch.")
501
+ print(f" Final Recall@{top_k} shows how many relevant docs ARE in top-{top_k} results.")
502
+
503
+ print("\n" + "=" * 80)
504
+ print("✅ Benchmark complete!")
505
+
506
+
507
+ def main():
508
+ parser = argparse.ArgumentParser(
509
+ description="Quick benchmark for visual-rag-toolkit",
510
+ formatter_class=argparse.RawDescriptionHelpFormatter,
511
+ )
512
+ parser.add_argument(
513
+ "--samples", type=int, default=100,
514
+ help="Number of samples (default: 100)"
515
+ )
516
+ parser.add_argument(
517
+ "--model", type=str, default="vidore/colSmol-500M",
518
+ help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3"
519
+ )
520
+ parser.add_argument(
521
+ "--prefetch-k", type=int, default=None,
522
+ help="Stage 1 candidates for two-stage (default: 20 for <=100 docs, auto for larger)"
523
+ )
524
+ parser.add_argument(
525
+ "--skip-exhaustive", action="store_true",
526
+ help="Skip exhaustive baseline (faster)"
527
+ )
528
+ parser.add_argument(
529
+ "--show-precision", action="store_true",
530
+ help="Show Precision@K metrics (hidden by default)"
531
+ )
532
+ parser.add_argument(
533
+ "--top-k", type=int, default=10,
534
+ help="Number of results to return (default: 10, realistic retrieval scenario)"
535
+ )
536
+
537
+ args = parser.parse_args()
538
+
539
+ print("\n" + "=" * 70)
540
+ print("🧪 VISUAL RAG TOOLKIT - RETRIEVAL BENCHMARK")
541
+ print("=" * 70)
542
+
543
+ # Load samples
544
+ samples = load_vidore_sample(args.samples)
545
+
546
+ if not samples:
547
+ logger.error("No samples loaded!")
548
+ sys.exit(1)
549
+
550
+ # Embed all
551
+ data = embed_all(samples, args.model)
552
+
553
+ # Run benchmark
554
+ benchmark_results = run_benchmark(
555
+ data,
556
+ skip_exhaustive=args.skip_exhaustive,
557
+ prefetch_k=args.prefetch_k,
558
+ top_k=args.top_k,
559
+ )
560
+
561
+ # Print results
562
+ print_results(data, benchmark_results, show_precision=args.show_precision)
563
+
564
+
565
+ if __name__ == "__main__":
566
+ main()
visual_rag/__init__.py CHANGED
@@ -14,16 +14,16 @@ Components:
14
  Quick Start:
15
  ------------
16
  >>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
17
- >>>
18
  >>> # Process PDFs
19
  >>> processor = PDFProcessor(dpi=140)
20
  >>> images, texts = processor.process_pdf("report.pdf")
21
- >>>
22
  >>> # Generate embeddings
23
  >>> embedder = VisualEmbedder()
24
  >>> embeddings = embedder.embed_images(images)
25
  >>> query_emb = embedder.embed_query("What is the budget?")
26
- >>>
27
  >>> # Search with two-stage retrieval
28
  >>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
29
  >>> results = retriever.search(query_emb, top_k=10)
@@ -31,7 +31,7 @@ Quick Start:
31
  Each component works independently - use only what you need.
32
  """
33
 
34
- __version__ = "0.1.0"
35
 
36
  # Import main classes at package level for convenience
37
  # These are optional - if dependencies aren't installed, we catch the error
@@ -71,13 +71,17 @@ try:
71
  except ImportError:
72
  QdrantAdmin = None
73
 
 
 
 
 
 
74
  # Config utilities (always available)
75
- from visual_rag.config import load_config, get, get_section
76
 
77
  __all__ = [
78
  # Version
79
  "__version__",
80
-
81
  # Main classes
82
  "VisualEmbedder",
83
  "PDFProcessor",
@@ -86,7 +90,7 @@ __all__ = [
86
  "TwoStageRetriever",
87
  "MultiVectorRetriever",
88
  "QdrantAdmin",
89
-
90
  # Config utilities
91
  "load_config",
92
  "get",
 
14
  Quick Start:
15
  ------------
16
  >>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
17
+ >>>
18
  >>> # Process PDFs
19
  >>> processor = PDFProcessor(dpi=140)
20
  >>> images, texts = processor.process_pdf("report.pdf")
21
+ >>>
22
  >>> # Generate embeddings
23
  >>> embedder = VisualEmbedder()
24
  >>> embeddings = embedder.embed_images(images)
25
  >>> query_emb = embedder.embed_query("What is the budget?")
26
+ >>>
27
  >>> # Search with two-stage retrieval
28
  >>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
29
  >>> results = retriever.search(query_emb, top_k=10)
 
31
  Each component works independently - use only what you need.
32
  """
33
 
34
+ __version__ = "0.1.3"
35
 
36
  # Import main classes at package level for convenience
37
  # These are optional - if dependencies aren't installed, we catch the error
 
71
  except ImportError:
72
  QdrantAdmin = None
73
 
74
+ try:
75
+ from visual_rag.demo_runner import demo
76
+ except ImportError:
77
+ demo = None
78
+
79
  # Config utilities (always available)
80
+ from visual_rag.config import get, get_section, load_config
81
 
82
  __all__ = [
83
  # Version
84
  "__version__",
 
85
  # Main classes
86
  "VisualEmbedder",
87
  "PDFProcessor",
 
90
  "TwoStageRetriever",
91
  "MultiVectorRetriever",
92
  "QdrantAdmin",
93
+ "demo",
94
  # Config utilities
95
  "load_config",
96
  "get",
visual_rag/cli/__init__.py CHANGED
@@ -1,3 +1 @@
1
  """CLI entry point for visual-rag-toolkit."""
2
-
3
-
 
1
  """CLI entry point for visual-rag-toolkit."""
 
 
visual_rag/cli/main.py CHANGED
@@ -10,20 +10,19 @@ Provides command-line interface for:
10
  Usage:
11
  # Process PDFs (like process_pdfs_saliency_v2.py)
12
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
13
-
14
  # Search
15
  visual-rag search --query "budget allocation" --collection my_docs
16
-
17
  # Show collection info
18
  visual-rag info --collection my_docs
19
  """
20
 
21
- import os
22
- import sys
23
  import argparse
24
  import logging
 
 
25
  from pathlib import Path
26
- from typing import Optional
27
  from urllib.parse import urlparse
28
 
29
  from dotenv import load_dotenv
@@ -44,38 +43,38 @@ def setup_logging(debug: bool = False):
44
  def cmd_process(args):
45
  """
46
  Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.
47
-
48
  Equivalent to process_pdfs_saliency_v2.py
49
  """
50
- from visual_rag import VisualEmbedder, QdrantIndexer, CloudinaryUploader, load_config
51
  from visual_rag.indexing.pipeline import ProcessingPipeline
52
-
53
  # Load environment
54
  load_dotenv()
55
-
56
  # Load config
57
  config = {}
58
  if args.config and Path(args.config).exists():
59
  config = load_config(args.config)
60
-
61
  # Get PDFs
62
  reports_dir = Path(args.reports_dir)
63
  if not reports_dir.exists():
64
  logger.error(f"❌ Reports directory not found: {reports_dir}")
65
  sys.exit(1)
66
-
67
  pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
68
  if not pdf_paths:
69
  logger.error(f"❌ No PDF files found in: {reports_dir}")
70
  sys.exit(1)
71
-
72
  logger.info(f"📁 Found {len(pdf_paths)} PDF files")
73
-
74
  # Load metadata mapping
75
  metadata_mapping = {}
76
  if args.metadata_file:
77
  metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))
78
-
79
  # Dry run - just show summary
80
  if args.dry_run:
81
  logger.info("🏃 DRY RUN MODE")
@@ -83,21 +82,24 @@ def cmd_process(args):
83
  logger.info(f" Metadata entries: {len(metadata_mapping)}")
84
  logger.info(f" Collection: {args.collection}")
85
  logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
86
-
87
  for pdf in pdf_paths[:10]:
88
  has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
89
  logger.info(f" {has_meta} {pdf.name}")
90
  if len(pdf_paths) > 10:
91
  logger.info(f" ... and {len(pdf_paths) - 10} more")
92
  return
93
-
94
  # Get settings
95
  model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
96
- collection_name = args.collection or config.get("qdrant", {}).get("collection_name", "visual_documents")
97
-
 
 
98
  torch_dtype = None
99
  if args.torch_dtype != "auto":
100
  import torch
 
101
  torch_dtype = {
102
  "float32": torch.float32,
103
  "float16": torch.float16,
@@ -111,20 +113,22 @@ def cmd_process(args):
111
  torch_dtype=torch_dtype,
112
  processor_speed=str(getattr(args, "processor_speed", "fast")),
113
  )
114
-
115
  # Initialize Qdrant indexer
116
- qdrant_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
 
 
117
  qdrant_api_key = (
118
  os.getenv("SIGIR_QDRANT_KEY")
119
  or os.getenv("SIGIR_QDRANT_API_KEY")
120
  or os.getenv("DEST_QDRANT_API_KEY")
121
  or os.getenv("QDRANT_API_KEY")
122
  )
123
-
124
  if not qdrant_url:
125
  logger.error("❌ QDRANT_URL environment variable not set")
126
  sys.exit(1)
127
-
128
  logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
129
  indexer = QdrantIndexer(
130
  url=qdrant_url,
@@ -133,7 +137,7 @@ def cmd_process(args):
133
  prefer_grpc=args.prefer_grpc,
134
  vector_datatype=args.qdrant_vector_dtype,
135
  )
136
-
137
  # Create collection if needed
138
  indexer.create_collection(force_recreate=args.force_recreate)
139
  inferred_fields = []
@@ -166,7 +170,7 @@ def cmd_process(args):
166
  inferred_fields.append({"field": k, "type": inferred_type})
167
 
168
  indexer.create_payload_indexes(fields=inferred_fields)
169
-
170
  # Initialize Cloudinary uploader (optional)
171
  cloudinary_uploader = None
172
  if not args.no_cloudinary:
@@ -176,7 +180,7 @@ def cmd_process(args):
176
  except ValueError as e:
177
  logger.warning(f"⚠️ Cloudinary not configured: {e}")
178
  logger.warning(" Continuing without Cloudinary uploads")
179
-
180
  # Create pipeline
181
  pipeline = ProcessingPipeline(
182
  embedder=embedder,
@@ -186,42 +190,44 @@ def cmd_process(args):
186
  config=config,
187
  embedding_strategy=args.strategy,
188
  crop_empty=bool(getattr(args, "crop_empty", False)),
189
- crop_empty_percentage_to_remove=float(getattr(args, "crop_empty_percentage_to_remove", 0.9)),
 
 
190
  crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
191
  )
192
-
193
  # Process PDFs
194
  total_uploaded = 0
195
  total_skipped = 0
196
  total_failed = 0
197
-
198
  skip_existing = not args.no_skip_existing
199
-
200
  for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
201
  logger.info(f"\n{'='*60}")
202
  logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
203
  logger.info(f"{'='*60}")
204
-
205
  result = pipeline.process_pdf(
206
  pdf_path,
207
  skip_existing=skip_existing,
208
  upload_to_cloudinary=(not args.no_cloudinary),
209
  upload_to_qdrant=True,
210
  )
211
-
212
  total_uploaded += result["uploaded"]
213
  total_skipped += result["skipped"]
214
  total_failed += result["failed"]
215
-
216
  # Summary
217
  logger.info(f"\n{'='*60}")
218
- logger.info(f"📊 SUMMARY")
219
  logger.info(f"{'='*60}")
220
  logger.info(f" Total PDFs: {len(pdf_paths)}")
221
  logger.info(f" Uploaded: {total_uploaded}")
222
  logger.info(f" Skipped: {total_skipped}")
223
  logger.info(f" Failed: {total_failed}")
224
-
225
  info = indexer.get_collection_info()
226
  if info:
227
  logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
@@ -229,29 +235,34 @@ def cmd_process(args):
229
 
230
  def cmd_search(args):
231
  """Search documents."""
232
- from visual_rag import VisualEmbedder
233
- from visual_rag.retrieval import TwoStageRetriever, SingleStageRetriever
234
  from qdrant_client import QdrantClient
235
-
 
 
 
236
  load_dotenv()
237
-
238
- qdrant_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
 
 
239
  qdrant_api_key = (
240
  os.getenv("SIGIR_QDRANT_KEY")
241
  or os.getenv("SIGIR_QDRANT_API_KEY")
242
  or os.getenv("DEST_QDRANT_API_KEY")
243
  or os.getenv("QDRANT_API_KEY")
244
  )
245
-
246
  if not qdrant_url:
247
  logger.error("❌ QDRANT_URL not set")
248
  sys.exit(1)
249
-
250
  # Initialize
251
  logger.info(f"🤖 Loading model: {args.model}")
252
- embedder = VisualEmbedder(model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast")))
 
 
253
 
254
- logger.info(f"🔌 Connecting to Qdrant")
255
  grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
256
  client = QdrantClient(
257
  url=qdrant_url,
@@ -262,11 +273,11 @@ def cmd_search(args):
262
  )
263
  two_stage = TwoStageRetriever(client, args.collection)
264
  single_stage = SingleStageRetriever(client, args.collection)
265
-
266
  # Embed query
267
  logger.info(f"🔍 Query: {args.query}")
268
  query_embedding = embedder.embed_query(args.query)
269
-
270
  # Build filter
271
  filter_obj = None
272
  if args.year or args.source or args.district:
@@ -275,7 +286,7 @@ def cmd_search(args):
275
  source=args.source,
276
  district=args.district,
277
  )
278
-
279
  # Search
280
  query_np = query_embedding.detach().cpu().numpy()
281
  if args.strategy == "single_full":
@@ -307,21 +318,21 @@ def cmd_search(args):
307
  filter_obj=filter_obj,
308
  stage1_mode=args.stage1_mode,
309
  )
310
-
311
  # Display results
312
  logger.info(f"\n📊 Results ({len(results)}):")
313
  for i, result in enumerate(results, 1):
314
  payload = result.get("payload", {})
315
  score = result.get("score_final", result.get("score_stage1", 0))
316
-
317
  filename = payload.get("filename", "N/A")
318
  page_num = payload.get("page_number", "N/A")
319
  year = payload.get("year", "N/A")
320
  source = payload.get("source", "N/A")
321
-
322
  logger.info(f" {i}. {filename} p.{page_num}")
323
  logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")
324
-
325
  # Text snippet
326
  text = payload.get("text", "")
327
  if text and args.show_text:
@@ -332,21 +343,23 @@ def cmd_search(args):
332
  def cmd_info(args):
333
  """Show collection info."""
334
  from qdrant_client import QdrantClient
335
-
336
  load_dotenv()
337
-
338
- qdrant_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
 
 
339
  qdrant_api_key = (
340
  os.getenv("SIGIR_QDRANT_KEY")
341
  or os.getenv("SIGIR_QDRANT_API_KEY")
342
  or os.getenv("DEST_QDRANT_API_KEY")
343
  or os.getenv("QDRANT_API_KEY")
344
  )
345
-
346
  if not qdrant_url:
347
  logger.error("❌ QDRANT_URL not set")
348
  sys.exit(1)
349
-
350
  grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
351
  client = QdrantClient(
352
  url=qdrant_url,
@@ -355,29 +368,29 @@ def cmd_info(args):
355
  grpc_port=grpc_port,
356
  check_compatibility=False,
357
  )
358
-
359
  try:
360
  info = client.get_collection(args.collection)
361
-
362
  status = info.status
363
  if hasattr(status, "value"):
364
  status = status.value
365
-
366
  indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
367
  if isinstance(indexed_count, dict):
368
  indexed_count = sum(indexed_count.values())
369
-
370
  logger.info(f"📊 Collection: {args.collection}")
371
  logger.info(f" Status: {status}")
372
  logger.info(f" Points: {info.points_count}")
373
  logger.info(f" Indexed vectors: {indexed_count}")
374
-
375
  # Show vector config
376
  if hasattr(info, "config") and hasattr(info.config, "params"):
377
  vectors = getattr(info.config.params, "vectors", {})
378
  if vectors:
379
  logger.info(f" Vectors: {list(vectors.keys())}")
380
-
381
  except Exception as e:
382
  logger.error(f"❌ Could not get collection info: {e}")
383
  sys.exit(1)
@@ -393,24 +406,24 @@ def main():
393
  Examples:
394
  # Process PDFs (like process_pdfs_saliency_v2.py)
395
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
396
-
397
  # Process without Cloudinary
398
  visual-rag process --reports-dir ./pdfs --no-cloudinary
399
-
400
  # Search
401
  visual-rag search --query "budget allocation" --collection my_docs
402
-
403
  # Search with filters
404
  visual-rag search --query "budget" --year 2023 --source "Local Government"
405
-
406
  # Show collection info
407
  visual-rag info --collection my_docs
408
  """,
409
  )
410
  parser.add_argument("--debug", action="store_true", help="Enable debug logging")
411
-
412
  subparsers = parser.add_subparsers(dest="command", help="Command")
413
-
414
  # =========================================================================
415
  # PROCESS command
416
  # =========================================================================
@@ -420,32 +433,26 @@ Examples:
420
  formatter_class=argparse.RawDescriptionHelpFormatter,
421
  )
422
  process_parser.add_argument(
423
- "--reports-dir", type=str, required=True,
424
- help="Directory containing PDF files"
425
- )
426
- process_parser.add_argument(
427
- "--metadata-file", type=str,
428
- help="JSON file with filename → metadata mapping (like filename_metadata.json)"
429
- )
430
- process_parser.add_argument(
431
- "--collection", type=str, default="visual_documents",
432
- help="Qdrant collection name"
433
  )
434
  process_parser.add_argument(
435
- "--model", type=str, default="vidore/colSmol-500M",
436
- help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)"
 
437
  )
438
  process_parser.add_argument(
439
- "--batch-size", type=int, default=8,
440
- help="Embedding batch size"
441
  )
442
  process_parser.add_argument(
443
- "--config", type=str,
444
- help="Path to config.yaml file"
 
 
445
  )
 
 
446
  process_parser.add_argument(
447
- "--no-cloudinary", action="store_true",
448
- help="Skip Cloudinary uploads"
449
  )
450
  process_parser.add_argument(
451
  "--crop-empty",
@@ -464,22 +471,23 @@ Examples:
464
  help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
465
  )
466
  process_parser.add_argument(
467
- "--no-skip-existing", action="store_true",
468
- help="Process all pages even if they exist in Qdrant"
 
469
  )
470
  process_parser.add_argument(
471
- "--force-recreate", action="store_true",
472
- help="Delete and recreate collection"
473
  )
474
  process_parser.add_argument(
475
- "--dry-run", action="store_true",
476
- help="Show what would be processed without doing it"
477
  )
478
  process_parser.add_argument(
479
- "--strategy", type=str, default="pooling",
 
 
480
  choices=["pooling", "standard", "all"],
481
  help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
482
- "'all' (embed once, store BOTH for comparison)"
483
  )
484
  process_parser.add_argument(
485
  "--torch-dtype",
@@ -517,7 +525,7 @@ Examples:
517
  help="Disable gRPC for Qdrant client.",
518
  )
519
  process_parser.set_defaults(func=cmd_process)
520
-
521
  # =========================================================================
522
  # SEARCH command
523
  # =========================================================================
@@ -525,17 +533,12 @@ Examples:
525
  "search",
526
  help="Search documents",
527
  )
 
528
  search_parser.add_argument(
529
- "--query", type=str, required=True,
530
- help="Search query"
531
  )
532
  search_parser.add_argument(
533
- "--collection", type=str, default="visual_documents",
534
- help="Qdrant collection name"
535
- )
536
- search_parser.add_argument(
537
- "--model", type=str, default="vidore/colSmol-500M",
538
- help="Model name"
539
  )
540
  search_parser.add_argument(
541
  "--processor-speed",
@@ -544,39 +547,29 @@ Examples:
544
  choices=["fast", "slow", "auto"],
545
  help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
546
  )
 
547
  search_parser.add_argument(
548
- "--top-k", type=int, default=10,
549
- help="Number of results"
550
- )
551
- search_parser.add_argument(
552
- "--strategy", type=str, default="single_full",
553
  choices=["single_full", "single_tiles", "single_global", "two_stage"],
554
- help="Search strategy"
555
  )
556
  search_parser.add_argument(
557
- "--prefetch-k", type=int, default=200,
558
- help="Prefetch candidates for two-stage retrieval"
559
  )
560
  search_parser.add_argument(
561
- "--stage1-mode", type=str, default="pooled_query_vs_tiles",
 
 
562
  choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
563
- help="Stage 1 mode for two-stage retrieval"
564
  )
 
 
 
565
  search_parser.add_argument(
566
- "--year", type=int,
567
- help="Filter by year"
568
- )
569
- search_parser.add_argument(
570
- "--source", type=str,
571
- help="Filter by source"
572
- )
573
- search_parser.add_argument(
574
- "--district", type=str,
575
- help="Filter by district"
576
- )
577
- search_parser.add_argument(
578
- "--show-text", action="store_true",
579
- help="Show text snippets in results"
580
  )
581
  search_grpc_group = search_parser.add_mutually_exclusive_group()
582
  search_grpc_group.add_argument(
@@ -593,7 +586,7 @@ Examples:
593
  help="Disable gRPC for Qdrant client.",
594
  )
595
  search_parser.set_defaults(func=cmd_search)
596
-
597
  # =========================================================================
598
  # INFO command
599
  # =========================================================================
@@ -602,8 +595,7 @@ Examples:
602
  help="Show collection info",
603
  )
604
  info_parser.add_argument(
605
- "--collection", type=str, default="visual_documents",
606
- help="Qdrant collection name"
607
  )
608
  info_grpc_group = info_parser.add_mutually_exclusive_group()
609
  info_grpc_group.add_argument(
@@ -620,16 +612,16 @@ Examples:
620
  help="Disable gRPC for Qdrant client.",
621
  )
622
  info_parser.set_defaults(func=cmd_info)
623
-
624
  # Parse and execute
625
  args = parser.parse_args()
626
-
627
  setup_logging(args.debug)
628
-
629
  if not args.command:
630
  parser.print_help()
631
  sys.exit(0)
632
-
633
  args.func(args)
634
 
635
 
 
10
  Usage:
11
  # Process PDFs (like process_pdfs_saliency_v2.py)
12
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
13
+
14
  # Search
15
  visual-rag search --query "budget allocation" --collection my_docs
16
+
17
  # Show collection info
18
  visual-rag info --collection my_docs
19
  """
20
 
 
 
21
  import argparse
22
  import logging
23
+ import os
24
+ import sys
25
  from pathlib import Path
 
26
  from urllib.parse import urlparse
27
 
28
  from dotenv import load_dotenv
 
43
  def cmd_process(args):
44
  """
45
  Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.
46
+
47
  Equivalent to process_pdfs_saliency_v2.py
48
  """
49
+ from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
50
  from visual_rag.indexing.pipeline import ProcessingPipeline
51
+
52
  # Load environment
53
  load_dotenv()
54
+
55
  # Load config
56
  config = {}
57
  if args.config and Path(args.config).exists():
58
  config = load_config(args.config)
59
+
60
  # Get PDFs
61
  reports_dir = Path(args.reports_dir)
62
  if not reports_dir.exists():
63
  logger.error(f"❌ Reports directory not found: {reports_dir}")
64
  sys.exit(1)
65
+
66
  pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
67
  if not pdf_paths:
68
  logger.error(f"❌ No PDF files found in: {reports_dir}")
69
  sys.exit(1)
70
+
71
  logger.info(f"📁 Found {len(pdf_paths)} PDF files")
72
+
73
  # Load metadata mapping
74
  metadata_mapping = {}
75
  if args.metadata_file:
76
  metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))
77
+
78
  # Dry run - just show summary
79
  if args.dry_run:
80
  logger.info("🏃 DRY RUN MODE")
 
82
  logger.info(f" Metadata entries: {len(metadata_mapping)}")
83
  logger.info(f" Collection: {args.collection}")
84
  logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
85
+
86
  for pdf in pdf_paths[:10]:
87
  has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
88
  logger.info(f" {has_meta} {pdf.name}")
89
  if len(pdf_paths) > 10:
90
  logger.info(f" ... and {len(pdf_paths) - 10} more")
91
  return
92
+
93
  # Get settings
94
  model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
95
+ collection_name = args.collection or config.get("qdrant", {}).get(
96
+ "collection_name", "visual_documents"
97
+ )
98
+
99
  torch_dtype = None
100
  if args.torch_dtype != "auto":
101
  import torch
102
+
103
  torch_dtype = {
104
  "float32": torch.float32,
105
  "float16": torch.float16,
 
113
  torch_dtype=torch_dtype,
114
  processor_speed=str(getattr(args, "processor_speed", "fast")),
115
  )
116
+
117
  # Initialize Qdrant indexer
118
+ qdrant_url = (
119
+ os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
120
+ )
121
  qdrant_api_key = (
122
  os.getenv("SIGIR_QDRANT_KEY")
123
  or os.getenv("SIGIR_QDRANT_API_KEY")
124
  or os.getenv("DEST_QDRANT_API_KEY")
125
  or os.getenv("QDRANT_API_KEY")
126
  )
127
+
128
  if not qdrant_url:
129
  logger.error("❌ QDRANT_URL environment variable not set")
130
  sys.exit(1)
131
+
132
  logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
133
  indexer = QdrantIndexer(
134
  url=qdrant_url,
 
137
  prefer_grpc=args.prefer_grpc,
138
  vector_datatype=args.qdrant_vector_dtype,
139
  )
140
+
141
  # Create collection if needed
142
  indexer.create_collection(force_recreate=args.force_recreate)
143
  inferred_fields = []
 
170
  inferred_fields.append({"field": k, "type": inferred_type})
171
 
172
  indexer.create_payload_indexes(fields=inferred_fields)
173
+
174
  # Initialize Cloudinary uploader (optional)
175
  cloudinary_uploader = None
176
  if not args.no_cloudinary:
 
180
  except ValueError as e:
181
  logger.warning(f"⚠️ Cloudinary not configured: {e}")
182
  logger.warning(" Continuing without Cloudinary uploads")
183
+
184
  # Create pipeline
185
  pipeline = ProcessingPipeline(
186
  embedder=embedder,
 
190
  config=config,
191
  embedding_strategy=args.strategy,
192
  crop_empty=bool(getattr(args, "crop_empty", False)),
193
+ crop_empty_percentage_to_remove=float(
194
+ getattr(args, "crop_empty_percentage_to_remove", 0.9)
195
+ ),
196
  crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
197
  )
198
+
199
  # Process PDFs
200
  total_uploaded = 0
201
  total_skipped = 0
202
  total_failed = 0
203
+
204
  skip_existing = not args.no_skip_existing
205
+
206
  for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
207
  logger.info(f"\n{'='*60}")
208
  logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
209
  logger.info(f"{'='*60}")
210
+
211
  result = pipeline.process_pdf(
212
  pdf_path,
213
  skip_existing=skip_existing,
214
  upload_to_cloudinary=(not args.no_cloudinary),
215
  upload_to_qdrant=True,
216
  )
217
+
218
  total_uploaded += result["uploaded"]
219
  total_skipped += result["skipped"]
220
  total_failed += result["failed"]
221
+
222
  # Summary
223
  logger.info(f"\n{'='*60}")
224
+ logger.info("📊 SUMMARY")
225
  logger.info(f"{'='*60}")
226
  logger.info(f" Total PDFs: {len(pdf_paths)}")
227
  logger.info(f" Uploaded: {total_uploaded}")
228
  logger.info(f" Skipped: {total_skipped}")
229
  logger.info(f" Failed: {total_failed}")
230
+
231
  info = indexer.get_collection_info()
232
  if info:
233
  logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
 
235
 
236
  def cmd_search(args):
237
  """Search documents."""
 
 
238
  from qdrant_client import QdrantClient
239
+
240
+ from visual_rag import VisualEmbedder
241
+ from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever
242
+
243
  load_dotenv()
244
+
245
+ qdrant_url = (
246
+ os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
247
+ )
248
  qdrant_api_key = (
249
  os.getenv("SIGIR_QDRANT_KEY")
250
  or os.getenv("SIGIR_QDRANT_API_KEY")
251
  or os.getenv("DEST_QDRANT_API_KEY")
252
  or os.getenv("QDRANT_API_KEY")
253
  )
254
+
255
  if not qdrant_url:
256
  logger.error("❌ QDRANT_URL not set")
257
  sys.exit(1)
258
+
259
  # Initialize
260
  logger.info(f"🤖 Loading model: {args.model}")
261
+ embedder = VisualEmbedder(
262
+ model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast"))
263
+ )
264
 
265
+ logger.info("🔌 Connecting to Qdrant")
266
  grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
267
  client = QdrantClient(
268
  url=qdrant_url,
 
273
  )
274
  two_stage = TwoStageRetriever(client, args.collection)
275
  single_stage = SingleStageRetriever(client, args.collection)
276
+
277
  # Embed query
278
  logger.info(f"🔍 Query: {args.query}")
279
  query_embedding = embedder.embed_query(args.query)
280
+
281
  # Build filter
282
  filter_obj = None
283
  if args.year or args.source or args.district:
 
286
  source=args.source,
287
  district=args.district,
288
  )
289
+
290
  # Search
291
  query_np = query_embedding.detach().cpu().numpy()
292
  if args.strategy == "single_full":
 
318
  filter_obj=filter_obj,
319
  stage1_mode=args.stage1_mode,
320
  )
321
+
322
  # Display results
323
  logger.info(f"\n📊 Results ({len(results)}):")
324
  for i, result in enumerate(results, 1):
325
  payload = result.get("payload", {})
326
  score = result.get("score_final", result.get("score_stage1", 0))
327
+
328
  filename = payload.get("filename", "N/A")
329
  page_num = payload.get("page_number", "N/A")
330
  year = payload.get("year", "N/A")
331
  source = payload.get("source", "N/A")
332
+
333
  logger.info(f" {i}. {filename} p.{page_num}")
334
  logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")
335
+
336
  # Text snippet
337
  text = payload.get("text", "")
338
  if text and args.show_text:
 
343
  def cmd_info(args):
344
  """Show collection info."""
345
  from qdrant_client import QdrantClient
346
+
347
  load_dotenv()
348
+
349
+ qdrant_url = (
350
+ os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
351
+ )
352
  qdrant_api_key = (
353
  os.getenv("SIGIR_QDRANT_KEY")
354
  or os.getenv("SIGIR_QDRANT_API_KEY")
355
  or os.getenv("DEST_QDRANT_API_KEY")
356
  or os.getenv("QDRANT_API_KEY")
357
  )
358
+
359
  if not qdrant_url:
360
  logger.error("❌ QDRANT_URL not set")
361
  sys.exit(1)
362
+
363
  grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
364
  client = QdrantClient(
365
  url=qdrant_url,
 
368
  grpc_port=grpc_port,
369
  check_compatibility=False,
370
  )
371
+
372
  try:
373
  info = client.get_collection(args.collection)
374
+
375
  status = info.status
376
  if hasattr(status, "value"):
377
  status = status.value
378
+
379
  indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
380
  if isinstance(indexed_count, dict):
381
  indexed_count = sum(indexed_count.values())
382
+
383
  logger.info(f"📊 Collection: {args.collection}")
384
  logger.info(f" Status: {status}")
385
  logger.info(f" Points: {info.points_count}")
386
  logger.info(f" Indexed vectors: {indexed_count}")
387
+
388
  # Show vector config
389
  if hasattr(info, "config") and hasattr(info.config, "params"):
390
  vectors = getattr(info.config.params, "vectors", {})
391
  if vectors:
392
  logger.info(f" Vectors: {list(vectors.keys())}")
393
+
394
  except Exception as e:
395
  logger.error(f"❌ Could not get collection info: {e}")
396
  sys.exit(1)
 
406
  Examples:
407
  # Process PDFs (like process_pdfs_saliency_v2.py)
408
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
409
+
410
  # Process without Cloudinary
411
  visual-rag process --reports-dir ./pdfs --no-cloudinary
412
+
413
  # Search
414
  visual-rag search --query "budget allocation" --collection my_docs
415
+
416
  # Search with filters
417
  visual-rag search --query "budget" --year 2023 --source "Local Government"
418
+
419
  # Show collection info
420
  visual-rag info --collection my_docs
421
  """,
422
  )
423
  parser.add_argument("--debug", action="store_true", help="Enable debug logging")
424
+
425
  subparsers = parser.add_subparsers(dest="command", help="Command")
426
+
427
  # =========================================================================
428
  # PROCESS command
429
  # =========================================================================
 
433
  formatter_class=argparse.RawDescriptionHelpFormatter,
434
  )
435
  process_parser.add_argument(
436
+ "--reports-dir", type=str, required=True, help="Directory containing PDF files"
 
 
 
 
 
 
 
 
 
437
  )
438
  process_parser.add_argument(
439
+ "--metadata-file",
440
+ type=str,
441
+ help="JSON file with filename → metadata mapping (like filename_metadata.json)",
442
  )
443
  process_parser.add_argument(
444
+ "--collection", type=str, default="visual_documents", help="Qdrant collection name"
 
445
  )
446
  process_parser.add_argument(
447
+ "--model",
448
+ type=str,
449
+ default="vidore/colSmol-500M",
450
+ help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
451
  )
452
+ process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
453
+ process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
454
  process_parser.add_argument(
455
+ "--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
 
456
  )
457
  process_parser.add_argument(
458
  "--crop-empty",
 
471
  help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
472
  )
473
  process_parser.add_argument(
474
+ "--no-skip-existing",
475
+ action="store_true",
476
+ help="Process all pages even if they exist in Qdrant",
477
  )
478
  process_parser.add_argument(
479
+ "--force-recreate", action="store_true", help="Delete and recreate collection"
 
480
  )
481
  process_parser.add_argument(
482
+ "--dry-run", action="store_true", help="Show what would be processed without doing it"
 
483
  )
484
  process_parser.add_argument(
485
+ "--strategy",
486
+ type=str,
487
+ default="pooling",
488
  choices=["pooling", "standard", "all"],
489
  help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
490
+ "'all' (embed once, store BOTH for comparison)",
491
  )
492
  process_parser.add_argument(
493
  "--torch-dtype",
 
525
  help="Disable gRPC for Qdrant client.",
526
  )
527
  process_parser.set_defaults(func=cmd_process)
528
+
529
  # =========================================================================
530
  # SEARCH command
531
  # =========================================================================
 
533
  "search",
534
  help="Search documents",
535
  )
536
+ search_parser.add_argument("--query", type=str, required=True, help="Search query")
537
  search_parser.add_argument(
538
+ "--collection", type=str, default="visual_documents", help="Qdrant collection name"
 
539
  )
540
  search_parser.add_argument(
541
+ "--model", type=str, default="vidore/colSmol-500M", help="Model name"
 
 
 
 
 
542
  )
543
  search_parser.add_argument(
544
  "--processor-speed",
 
547
  choices=["fast", "slow", "auto"],
548
  help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
549
  )
550
+ search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
551
  search_parser.add_argument(
552
+ "--strategy",
553
+ type=str,
554
+ default="single_full",
 
 
555
  choices=["single_full", "single_tiles", "single_global", "two_stage"],
556
+ help="Search strategy",
557
  )
558
  search_parser.add_argument(
559
+ "--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
 
560
  )
561
  search_parser.add_argument(
562
+ "--stage1-mode",
563
+ type=str,
564
+ default="pooled_query_vs_tiles",
565
  choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
566
+ help="Stage 1 mode for two-stage retrieval",
567
  )
568
+ search_parser.add_argument("--year", type=int, help="Filter by year")
569
+ search_parser.add_argument("--source", type=str, help="Filter by source")
570
+ search_parser.add_argument("--district", type=str, help="Filter by district")
571
  search_parser.add_argument(
572
+ "--show-text", action="store_true", help="Show text snippets in results"
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  )
574
  search_grpc_group = search_parser.add_mutually_exclusive_group()
575
  search_grpc_group.add_argument(
 
586
  help="Disable gRPC for Qdrant client.",
587
  )
588
  search_parser.set_defaults(func=cmd_search)
589
+
590
  # =========================================================================
591
  # INFO command
592
  # =========================================================================
 
595
  help="Show collection info",
596
  )
597
  info_parser.add_argument(
598
+ "--collection", type=str, default="visual_documents", help="Qdrant collection name"
 
599
  )
600
  info_grpc_group = info_parser.add_mutually_exclusive_group()
601
  info_grpc_group.add_argument(
 
612
  help="Disable gRPC for Qdrant client.",
613
  )
614
  info_parser.set_defaults(func=cmd_info)
615
+
616
  # Parse and execute
617
  args = parser.parse_args()
618
+
619
  setup_logging(args.debug)
620
+
621
  if not args.command:
622
  parser.print_help()
623
  sys.exit(0)
624
+
625
  args.func(args)
626
 
627
 
visual_rag/config.py CHANGED
@@ -7,57 +7,56 @@ Provides:
7
  - Convenience getters for common settings
8
  """
9
 
10
- import os
11
  import logging
 
12
  from pathlib import Path
13
- from typing import Any, Optional, Dict
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
- # Global config cache
18
- _config_cache: Optional[Dict[str, Any]] = None
 
19
 
20
 
21
  def _env_qdrant_url() -> Optional[str]:
22
- return os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
 
23
 
24
 
25
  def _env_qdrant_api_key() -> Optional[str]:
26
- return (
27
- os.getenv("SIGIR_QDRANT_KEY")
28
- or os.getenv("SIGIR_QDRANT_API_KEY")
29
- or os.getenv("DEST_QDRANT_API_KEY")
30
- or os.getenv("QDRANT_API_KEY")
31
- )
32
 
33
 
34
  def load_config(
35
  config_path: Optional[str] = None,
36
  force_reload: bool = False,
 
37
  ) -> Dict[str, Any]:
38
  """
39
  Load configuration from YAML file.
40
-
41
  Uses caching to avoid repeated file I/O.
42
  Environment variables can override config values.
43
-
44
  Args:
45
  config_path: Path to config file (auto-detected if None)
46
  force_reload: Bypass cache and reload from file
47
-
48
  Returns:
49
  Configuration dictionary
50
  """
51
- global _config_cache
52
-
53
- # Return cached config if available
54
- if _config_cache is not None and not force_reload:
55
- return _config_cache
56
-
57
  # Find config file
58
  if config_path is None:
59
  config_path = os.getenv("VISUALRAG_CONFIG")
60
-
61
  if config_path is None:
62
  # Check common locations
63
  search_paths = [
@@ -65,65 +64,75 @@ def load_config(
65
  Path.cwd() / "visual_rag.yaml",
66
  Path.home() / ".visual_rag" / "config.yaml",
67
  ]
68
-
69
  for path in search_paths:
70
  if path.exists():
71
  config_path = str(path)
72
  break
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # Load YAML if file exists
75
  config = {}
76
  if config_path and Path(config_path).exists():
77
  try:
78
  import yaml
79
-
80
  with open(config_path, "r") as f:
81
  config = yaml.safe_load(f) or {}
82
-
83
  logger.info(f"Loaded config from: {config_path}")
84
  except ImportError:
85
  logger.warning("PyYAML not installed, using environment variables only")
86
  except Exception as e:
87
  logger.warning(f"Could not load config file: {e}")
88
-
89
- # Apply environment variable overrides
90
- config = _apply_env_overrides(config)
91
-
92
- _config_cache = config
93
- return config
 
 
94
 
95
 
96
  def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
97
  """Apply environment variable overrides."""
98
-
99
  env_mappings = {
100
  # Qdrant
101
  "QDRANT_URL": ["qdrant", "url"],
102
  "QDRANT_API_KEY": ["qdrant", "api_key"],
103
  "QDRANT_COLLECTION": ["qdrant", "collection"],
104
-
105
  # Model
106
  "VISUALRAG_MODEL": ["model", "name"],
107
  "COLPALI_MODEL_NAME": ["model", "name"], # Alias
108
  "EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
109
-
110
  # Cloudinary
111
  "CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
112
  "CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
113
  "CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
114
-
115
  # Processing
116
  "PDF_DPI": ["processing", "dpi"],
117
  "JPEG_QUALITY": ["processing", "jpeg_quality"],
118
-
119
  # Search
120
  "SEARCH_STRATEGY": ["search", "strategy"],
121
  "PREFETCH_K": ["search", "prefetch_k"],
122
-
123
  # Special token handling
124
  "VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
125
  }
126
-
127
  for env_var, path in env_mappings.items():
128
  value = os.getenv(env_var)
129
  if value is not None:
@@ -133,50 +142,51 @@ def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
133
  if key not in current:
134
  current[key] = {}
135
  current = current[key]
136
-
137
  # Convert value to appropriate type
138
  final_key = path[-1]
139
  if final_key in current:
140
  existing_type = type(current[final_key])
141
- if existing_type == bool:
 
142
  value = value.lower() in ("true", "1", "yes", "on")
143
- elif existing_type == int:
144
  value = int(value)
145
- elif existing_type == float:
146
  value = float(value)
147
-
148
  current[final_key] = value
149
  logger.debug(f"Config override: {'.'.join(path)} = {value}")
150
-
151
  return config
152
 
153
 
154
  def get(key: str, default: Any = None) -> Any:
155
  """
156
  Get a configuration value by dot-notation path.
157
-
158
  Examples:
159
  >>> get("qdrant.url")
160
  >>> get("model.name", "vidore/colSmol-500M")
161
  >>> get("search.strategy", "multi_vector")
162
  """
163
- config = load_config()
164
-
165
  keys = key.split(".")
166
  current = config
167
-
168
  for k in keys:
169
  if isinstance(current, dict) and k in current:
170
  current = current[k]
171
  else:
172
  return default
173
-
174
  return current
175
 
176
 
177
- def get_section(section: str) -> Dict[str, Any]:
178
  """Get an entire configuration section."""
179
- config = load_config()
180
  return config.get(section, {})
181
 
182
 
@@ -215,5 +225,3 @@ def get_search_config() -> Dict[str, Any]:
215
  "prefetch_k": get("search.prefetch_k", 200),
216
  "top_k": get("search.top_k", 10),
217
  }
218
-
219
-
 
7
  - Convenience getters for common settings
8
  """
9
 
10
+ import copy
11
  import logging
12
+ import os
13
  from pathlib import Path
14
+ from typing import Any, Dict, Optional
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Global config cache (raw YAML only; env overrides applied on demand)
19
+ _raw_config_cache: Optional[Dict[str, Any]] = None
20
+ _raw_config_cache_path: Optional[str] = None
21
 
22
 
23
  def _env_qdrant_url() -> Optional[str]:
24
+ """Get Qdrant URL from environment. Prefers QDRANT_URL."""
25
+ return os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy fallback
26
 
27
 
28
  def _env_qdrant_api_key() -> Optional[str]:
29
+ """Get Qdrant API key from environment. Prefers QDRANT_API_KEY."""
30
+ return os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy fallback
 
 
 
 
31
 
32
 
33
  def load_config(
34
  config_path: Optional[str] = None,
35
  force_reload: bool = False,
36
+ apply_env_overrides: bool = True,
37
  ) -> Dict[str, Any]:
38
  """
39
  Load configuration from YAML file.
40
+
41
  Uses caching to avoid repeated file I/O.
42
  Environment variables can override config values.
43
+
44
  Args:
45
  config_path: Path to config file (auto-detected if None)
46
  force_reload: Bypass cache and reload from file
47
+
48
  Returns:
49
  Configuration dictionary
50
  """
51
+ global _raw_config_cache, _raw_config_cache_path
52
+
53
+ # Determine the effective config path (used for caching)
54
+ effective_path: Optional[str] = None
55
+
 
56
  # Find config file
57
  if config_path is None:
58
  config_path = os.getenv("VISUALRAG_CONFIG")
59
+
60
  if config_path is None:
61
  # Check common locations
62
  search_paths = [
 
64
  Path.cwd() / "visual_rag.yaml",
65
  Path.home() / ".visual_rag" / "config.yaml",
66
  ]
67
+
68
  for path in search_paths:
69
  if path.exists():
70
  config_path = str(path)
71
  break
72
+ effective_path = str(config_path) if config_path else None
73
+
74
+ # Return cached raw config if available.
75
+ # - If caller doesn't specify a path (effective_path is None), use whatever was
76
+ # loaded most recently (common pattern in apps).
77
+ # - If a path is specified, only reuse cache when it matches.
78
+ if (
79
+ _raw_config_cache is not None
80
+ and not force_reload
81
+ and (effective_path is None or _raw_config_cache_path == effective_path)
82
+ ):
83
+ cfg = copy.deepcopy(_raw_config_cache)
84
+ return _apply_env_overrides(cfg) if apply_env_overrides else cfg
85
+
86
  # Load YAML if file exists
87
  config = {}
88
  if config_path and Path(config_path).exists():
89
  try:
90
  import yaml
91
+
92
  with open(config_path, "r") as f:
93
  config = yaml.safe_load(f) or {}
94
+
95
  logger.info(f"Loaded config from: {config_path}")
96
  except ImportError:
97
  logger.warning("PyYAML not installed, using environment variables only")
98
  except Exception as e:
99
  logger.warning(f"Could not load config file: {e}")
100
+
101
+ # Cache RAW config (no env overrides)
102
+ _raw_config_cache = copy.deepcopy(config)
103
+ _raw_config_cache_path = effective_path
104
+
105
+ # Return resolved or raw depending on caller preference
106
+ cfg = copy.deepcopy(config)
107
+ return _apply_env_overrides(cfg) if apply_env_overrides else cfg
108
 
109
 
110
  def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
111
  """Apply environment variable overrides."""
112
+
113
  env_mappings = {
114
  # Qdrant
115
  "QDRANT_URL": ["qdrant", "url"],
116
  "QDRANT_API_KEY": ["qdrant", "api_key"],
117
  "QDRANT_COLLECTION": ["qdrant", "collection"],
 
118
  # Model
119
  "VISUALRAG_MODEL": ["model", "name"],
120
  "COLPALI_MODEL_NAME": ["model", "name"], # Alias
121
  "EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
 
122
  # Cloudinary
123
  "CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
124
  "CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
125
  "CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
 
126
  # Processing
127
  "PDF_DPI": ["processing", "dpi"],
128
  "JPEG_QUALITY": ["processing", "jpeg_quality"],
 
129
  # Search
130
  "SEARCH_STRATEGY": ["search", "strategy"],
131
  "PREFETCH_K": ["search", "prefetch_k"],
 
132
  # Special token handling
133
  "VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
134
  }
135
+
136
  for env_var, path in env_mappings.items():
137
  value = os.getenv(env_var)
138
  if value is not None:
 
142
  if key not in current:
143
  current[key] = {}
144
  current = current[key]
145
+
146
  # Convert value to appropriate type
147
  final_key = path[-1]
148
  if final_key in current:
149
  existing_type = type(current[final_key])
150
+ # Use `is` for type comparisons (Ruff E721).
151
+ if existing_type is bool:
152
  value = value.lower() in ("true", "1", "yes", "on")
153
+ elif existing_type is int:
154
  value = int(value)
155
+ elif existing_type is float:
156
  value = float(value)
157
+
158
  current[final_key] = value
159
  logger.debug(f"Config override: {'.'.join(path)} = {value}")
160
+
161
  return config
162
 
163
 
164
  def get(key: str, default: Any = None) -> Any:
165
  """
166
  Get a configuration value by dot-notation path.
167
+
168
  Examples:
169
  >>> get("qdrant.url")
170
  >>> get("model.name", "vidore/colSmol-500M")
171
  >>> get("search.strategy", "multi_vector")
172
  """
173
+ config = load_config(apply_env_overrides=True)
174
+
175
  keys = key.split(".")
176
  current = config
177
+
178
  for k in keys:
179
  if isinstance(current, dict) and k in current:
180
  current = current[k]
181
  else:
182
  return default
183
+
184
  return current
185
 
186
 
187
+ def get_section(section: str, *, apply_env_overrides: bool = True) -> Dict[str, Any]:
188
  """Get an entire configuration section."""
189
+ config = load_config(apply_env_overrides=apply_env_overrides)
190
  return config.get(section, {})
191
 
192
 
 
225
  "prefetch_k": get("search.prefetch_k", 200),
226
  "top_k": get("search.top_k", 10),
227
  }
 
 
visual_rag/demo_runner.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Launch the Streamlit demo from an installed package.
3
+
4
+ Why:
5
+ - After `pip install visual-rag-toolkit`, the repo layout isn't present.
6
+ - We package the `demo/` module and expose `visual_rag.demo()` + `visual-rag-demo`.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import importlib
13
+ import os
14
+ import subprocess
15
+ import sys
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+
20
+ def demo(
21
+ *,
22
+ host: str = "0.0.0.0",
23
+ port: int = 7860,
24
+ headless: bool = True,
25
+ open_browser: bool = False,
26
+ extra_args: Optional[list[str]] = None,
27
+ ) -> int:
28
+ """
29
+ Launch the Streamlit demo UI.
30
+
31
+ Requirements:
32
+ - `visual-rag-toolkit[ui,qdrant,embedding,pdf]` (or `visual-rag-toolkit[all]`)
33
+
34
+ Returns:
35
+ Streamlit process exit code.
36
+ """
37
+ try:
38
+ import streamlit # noqa: F401
39
+ except Exception as e: # pragma: no cover
40
+ raise RuntimeError(
41
+ "Streamlit is not installed. Install with:\n"
42
+ ' pip install "visual-rag-toolkit[ui,qdrant,embedding,pdf]"'
43
+ ) from e
44
+
45
+ # Resolve the installed demo entrypoint path.
46
+ mod = importlib.import_module("demo.app")
47
+ app_path = Path(getattr(mod, "__file__", "")).resolve()
48
+ if not app_path.exists(): # pragma: no cover
49
+ raise RuntimeError("Could not locate installed demo app (demo.app).")
50
+
51
+ # Build a stable Streamlit invocation.
52
+ cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
53
+ cmd += ["--server.address", str(host)]
54
+ cmd += ["--server.port", str(int(port))]
55
+ cmd += ["--server.headless", "true" if headless else "false"]
56
+ cmd += ["--browser.gatherUsageStats", "false"]
57
+ cmd += ["--server.runOnSave", "false"]
58
+ cmd += ["--browser.serverAddress", str(host)]
59
+ if not open_browser:
60
+ cmd += ["--browser.serverPort", str(int(port))]
61
+ cmd += ["--browser.open", "false"]
62
+
63
+ if extra_args:
64
+ cmd += list(extra_args)
65
+
66
+ env = os.environ.copy()
67
+ # Make sure the demo doesn't spam internal Streamlit warnings in logs.
68
+ env.setdefault("STREAMLIT_BROWSER_GATHER_USAGE_STATS", "false")
69
+
70
+ return subprocess.call(cmd, env=env)
71
+
72
+
73
+ def main() -> None:
74
+ p = argparse.ArgumentParser(description="Launch the Visual RAG Toolkit Streamlit demo.")
75
+ p.add_argument("--host", default="0.0.0.0")
76
+ p.add_argument("--port", type=int, default=7860)
77
+ p.add_argument(
78
+ "--no-headless", action="store_true", help="Run with a browser window (not headless)."
79
+ )
80
+ p.add_argument("--open", action="store_true", help="Open browser automatically.")
81
+ args, unknown = p.parse_known_args()
82
+
83
+ rc = demo(
84
+ host=args.host,
85
+ port=args.port,
86
+ headless=(not args.no_headless),
87
+ open_browser=bool(args.open),
88
+ extra_args=unknown,
89
+ )
90
+ raise SystemExit(rc)
visual_rag/embedding/__init__.py CHANGED
@@ -6,19 +6,18 @@ Provides:
6
  - Pooling utilities: tile-level, global, MaxSim scoring
7
  """
8
 
9
- from visual_rag.embedding.visual_embedder import VisualEmbedder, ColPaliEmbedder
10
  from visual_rag.embedding.pooling import (
11
- tile_level_mean_pooling,
12
- global_mean_pooling,
13
- compute_maxsim_score,
14
  compute_maxsim_batch,
 
 
 
15
  )
 
16
 
17
  __all__ = [
18
  # Main embedder
19
  "VisualEmbedder",
20
  "ColPaliEmbedder", # Backward compatibility alias
21
-
22
  # Pooling functions
23
  "tile_level_mean_pooling",
24
  "global_mean_pooling",
 
6
  - Pooling utilities: tile-level, global, MaxSim scoring
7
  """
8
 
 
9
  from visual_rag.embedding.pooling import (
 
 
 
10
  compute_maxsim_batch,
11
+ compute_maxsim_score,
12
+ global_mean_pooling,
13
+ tile_level_mean_pooling,
14
  )
15
+ from visual_rag.embedding.visual_embedder import ColPaliEmbedder, VisualEmbedder
16
 
17
  __all__ = [
18
  # Main embedder
19
  "VisualEmbedder",
20
  "ColPaliEmbedder", # Backward compatibility alias
 
21
  # Pooling functions
22
  "tile_level_mean_pooling",
23
  "global_mean_pooling",
visual_rag/embedding/pooling.py CHANGED
@@ -7,10 +7,11 @@ Provides:
7
  - MaxSim scoring for ColBERT-style late interaction
8
  """
9
 
 
 
 
10
  import numpy as np
11
  import torch
12
- from typing import Union, Optional
13
- import logging
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -39,24 +40,24 @@ def tile_level_mean_pooling(
39
  ) -> np.ndarray:
40
  """
41
  Compute tile-level mean pooling for multi-vector embeddings.
42
-
43
  Instead of collapsing to 1×dim (global pooling), this preserves spatial
44
  structure by computing mean per tile → num_tiles × dim.
45
-
46
  This is our NOVEL contribution for scalable visual retrieval:
47
  - Faster than full MaxSim (fewer vectors to compare)
48
  - More accurate than global pooling (preserves spatial info)
49
  - Ideal for two-stage retrieval (prefetch with pooled, rerank with full)
50
-
51
  Args:
52
  embedding: Visual token embeddings [num_visual_tokens, dim]
53
  num_tiles: Number of tiles (including global tile)
54
  patches_per_tile: Patches per tile (64 for ColSmol)
55
  output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
56
-
57
  Returns:
58
  Tile-level pooled embeddings [num_tiles, dim]
59
-
60
  Example:
61
  >>> # Image with 4×3 tiles + 1 global = 13 tiles
62
  >>> # Each tile has 64 patches → 832 visual tokens
@@ -71,31 +72,29 @@ def tile_level_mean_pooling(
71
  emb_np = embedding.cpu().numpy().astype(np.float32)
72
  else:
73
  emb_np = np.array(embedding, dtype=np.float32)
74
-
75
  num_visual_tokens = emb_np.shape[0]
76
  expected_tokens = num_tiles * patches_per_tile
77
-
78
  if num_visual_tokens != expected_tokens:
79
- logger.debug(
80
- f"Token count mismatch: {num_visual_tokens} vs expected {expected_tokens}"
81
- )
82
  actual_tiles = num_visual_tokens // patches_per_tile
83
  if actual_tiles * patches_per_tile != num_visual_tokens:
84
  actual_tiles += 1
85
  num_tiles = actual_tiles
86
-
87
  tile_embeddings = []
88
  for tile_idx in range(num_tiles):
89
  start_idx = tile_idx * patches_per_tile
90
  end_idx = min(start_idx + patches_per_tile, num_visual_tokens)
91
-
92
  if start_idx >= num_visual_tokens:
93
  break
94
-
95
  tile_patches = emb_np[start_idx:end_idx]
96
  tile_mean = tile_patches.mean(axis=0)
97
  tile_embeddings.append(tile_mean)
98
-
99
  return np.array(tile_embeddings, dtype=out_dtype)
100
 
101
 
@@ -116,7 +115,9 @@ def colpali_row_mean_pooling(
116
  num_tokens, dim = emb_np.shape
117
  expected = int(grid_size) * int(grid_size)
118
  if num_tokens != expected:
119
- raise ValueError(f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}")
 
 
120
 
121
  grid = emb_np.reshape(int(grid_size), int(grid_size), int(dim))
122
  pooled = grid.mean(axis=1)
@@ -157,7 +158,9 @@ def colsmol_experimental_pooling(
157
  last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)
158
 
159
  prefix = emb_np[:last_tile_start]
160
- last_tile = emb_np[last_tile_start : min(last_tile_start + int(patches_per_tile), num_visual_tokens)]
 
 
161
 
162
  if prefix.size:
163
  prefix_tiles = prefix.reshape(-1, int(patches_per_tile), int(dim))
@@ -174,7 +177,7 @@ def colpali_experimental_pooling_from_rows(
174
  ) -> np.ndarray:
175
  """
176
  Experimental "convolution-style" pooling with window size 3.
177
-
178
  For N input rows, produces N + 2 output vectors:
179
  - Position 0: row[0] alone (1 row)
180
  - Position 1: mean(rows[0:2]) (2 rows)
@@ -182,7 +185,7 @@ def colpali_experimental_pooling_from_rows(
182
  - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
183
  - Position N: mean(rows[N-2:N]) (last 2 rows)
184
  - Position N+1: row[N-1] alone (last row)
185
-
186
  For N=32 rows: produces 34 vectors.
187
  """
188
  out_dtype = _infer_output_dtype(row_vectors, output_dtype)
@@ -202,13 +205,16 @@ def colpali_experimental_pooling_from_rows(
202
  if n == 2:
203
  return np.stack([rows[0], rows[:2].mean(axis=0), rows[1]], axis=0).astype(out_dtype)
204
  if n == 3:
205
- return np.stack([
206
- rows[0],
207
- rows[:2].mean(axis=0),
208
- rows[:3].mean(axis=0),
209
- rows[1:3].mean(axis=0),
210
- rows[2],
211
- ], axis=0).astype(out_dtype)
 
 
 
212
 
213
  out = np.zeros((n + 2, dim), dtype=np.float32)
214
  out[0] = rows[0]
@@ -227,14 +233,14 @@ def global_mean_pooling(
227
  ) -> np.ndarray:
228
  """
229
  Compute global mean pooling → single vector.
230
-
231
  This is the simplest pooling but loses all spatial information.
232
  Use for fastest retrieval when accuracy can be sacrificed.
233
-
234
  Args:
235
  embedding: Multi-vector embeddings [num_tokens, dim]
236
  output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
237
-
238
  Returns:
239
  Pooled vector [dim]
240
  """
@@ -246,7 +252,7 @@ def global_mean_pooling(
246
  emb_np = embedding.cpu().numpy()
247
  else:
248
  emb_np = np.array(embedding)
249
-
250
  return emb_np.mean(axis=0).astype(out_dtype)
251
 
252
 
@@ -257,21 +263,21 @@ def compute_maxsim_score(
257
  ) -> float:
258
  """
259
  Compute ColBERT-style MaxSim late interaction score.
260
-
261
  For each query token, finds max similarity with any document token,
262
  then sums across query tokens.
263
-
264
  This is the standard scoring for ColBERT/ColPali:
265
  score = Σ_q max_d (sim(q, d))
266
-
267
  Args:
268
  query_embedding: Query embeddings [num_query_tokens, dim]
269
  doc_embedding: Document embeddings [num_doc_tokens, dim]
270
  normalize: L2 normalize embeddings before scoring (recommended)
271
-
272
  Returns:
273
  MaxSim score (higher is better)
274
-
275
  Example:
276
  >>> query = embedder.embed_query("budget allocation")
277
  >>> doc = embeddings[0] # From embed_images
@@ -282,22 +288,20 @@ def compute_maxsim_score(
282
  query_norm = query_embedding / (
283
  np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
284
  )
285
- doc_norm = doc_embedding / (
286
- np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8
287
- )
288
  else:
289
  query_norm = query_embedding
290
  doc_norm = doc_embedding
291
-
292
  # Compute similarity matrix: [num_query, num_doc]
293
  similarity_matrix = np.dot(query_norm, doc_norm.T)
294
-
295
  # MaxSim: For each query token, take max similarity with any doc token
296
  max_similarities = similarity_matrix.max(axis=1)
297
-
298
  # Sum across query tokens
299
  score = float(max_similarities.sum())
300
-
301
  return score
302
 
303
 
@@ -308,12 +312,12 @@ def compute_maxsim_batch(
308
  ) -> list:
309
  """
310
  Compute MaxSim scores for multiple documents efficiently.
311
-
312
  Args:
313
  query_embedding: Query embeddings [num_query_tokens, dim]
314
  doc_embeddings: List of document embeddings
315
  normalize: L2 normalize embeddings
316
-
317
  Returns:
318
  List of MaxSim scores
319
  """
@@ -324,18 +328,16 @@ def compute_maxsim_batch(
324
  )
325
  else:
326
  query_norm = query_embedding
327
-
328
  scores = []
329
  for doc_emb in doc_embeddings:
330
  if normalize:
331
- doc_norm = doc_emb / (
332
- np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8
333
- )
334
  else:
335
  doc_norm = doc_emb
336
-
337
  sim_matrix = np.dot(query_norm, doc_norm.T)
338
  max_sims = sim_matrix.max(axis=1)
339
  scores.append(float(max_sims.sum()))
340
-
341
  return scores
 
7
  - MaxSim scoring for ColBERT-style late interaction
8
  """
9
 
10
+ import logging
11
+ from typing import Optional, Union
12
+
13
  import numpy as np
14
  import torch
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
40
  ) -> np.ndarray:
41
  """
42
  Compute tile-level mean pooling for multi-vector embeddings.
43
+
44
  Instead of collapsing to 1×dim (global pooling), this preserves spatial
45
  structure by computing mean per tile → num_tiles × dim.
46
+
47
  This is our NOVEL contribution for scalable visual retrieval:
48
  - Faster than full MaxSim (fewer vectors to compare)
49
  - More accurate than global pooling (preserves spatial info)
50
  - Ideal for two-stage retrieval (prefetch with pooled, rerank with full)
51
+
52
  Args:
53
  embedding: Visual token embeddings [num_visual_tokens, dim]
54
  num_tiles: Number of tiles (including global tile)
55
  patches_per_tile: Patches per tile (64 for ColSmol)
56
  output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
57
+
58
  Returns:
59
  Tile-level pooled embeddings [num_tiles, dim]
60
+
61
  Example:
62
  >>> # Image with 4×3 tiles + 1 global = 13 tiles
63
  >>> # Each tile has 64 patches → 832 visual tokens
 
72
  emb_np = embedding.cpu().numpy().astype(np.float32)
73
  else:
74
  emb_np = np.array(embedding, dtype=np.float32)
75
+
76
  num_visual_tokens = emb_np.shape[0]
77
  expected_tokens = num_tiles * patches_per_tile
78
+
79
  if num_visual_tokens != expected_tokens:
80
+ logger.debug(f"Token count mismatch: {num_visual_tokens} vs expected {expected_tokens}")
 
 
81
  actual_tiles = num_visual_tokens // patches_per_tile
82
  if actual_tiles * patches_per_tile != num_visual_tokens:
83
  actual_tiles += 1
84
  num_tiles = actual_tiles
85
+
86
  tile_embeddings = []
87
  for tile_idx in range(num_tiles):
88
  start_idx = tile_idx * patches_per_tile
89
  end_idx = min(start_idx + patches_per_tile, num_visual_tokens)
90
+
91
  if start_idx >= num_visual_tokens:
92
  break
93
+
94
  tile_patches = emb_np[start_idx:end_idx]
95
  tile_mean = tile_patches.mean(axis=0)
96
  tile_embeddings.append(tile_mean)
97
+
98
  return np.array(tile_embeddings, dtype=out_dtype)
99
 
100
 
 
115
  num_tokens, dim = emb_np.shape
116
  expected = int(grid_size) * int(grid_size)
117
  if num_tokens != expected:
118
+ raise ValueError(
119
+ f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}"
120
+ )
121
 
122
  grid = emb_np.reshape(int(grid_size), int(grid_size), int(dim))
123
  pooled = grid.mean(axis=1)
 
158
  last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)
159
 
160
  prefix = emb_np[:last_tile_start]
161
+ last_tile = emb_np[
162
+ last_tile_start : min(last_tile_start + int(patches_per_tile), num_visual_tokens)
163
+ ]
164
 
165
  if prefix.size:
166
  prefix_tiles = prefix.reshape(-1, int(patches_per_tile), int(dim))
 
177
  ) -> np.ndarray:
178
  """
179
  Experimental "convolution-style" pooling with window size 3.
180
+
181
  For N input rows, produces N + 2 output vectors:
182
  - Position 0: row[0] alone (1 row)
183
  - Position 1: mean(rows[0:2]) (2 rows)
 
185
  - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
186
  - Position N: mean(rows[N-2:N]) (last 2 rows)
187
  - Position N+1: row[N-1] alone (last row)
188
+
189
  For N=32 rows: produces 34 vectors.
190
  """
191
  out_dtype = _infer_output_dtype(row_vectors, output_dtype)
 
205
  if n == 2:
206
  return np.stack([rows[0], rows[:2].mean(axis=0), rows[1]], axis=0).astype(out_dtype)
207
  if n == 3:
208
+ return np.stack(
209
+ [
210
+ rows[0],
211
+ rows[:2].mean(axis=0),
212
+ rows[:3].mean(axis=0),
213
+ rows[1:3].mean(axis=0),
214
+ rows[2],
215
+ ],
216
+ axis=0,
217
+ ).astype(out_dtype)
218
 
219
  out = np.zeros((n + 2, dim), dtype=np.float32)
220
  out[0] = rows[0]
 
233
  ) -> np.ndarray:
234
  """
235
  Compute global mean pooling → single vector.
236
+
237
  This is the simplest pooling but loses all spatial information.
238
  Use for fastest retrieval when accuracy can be sacrificed.
239
+
240
  Args:
241
  embedding: Multi-vector embeddings [num_tokens, dim]
242
  output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)
243
+
244
  Returns:
245
  Pooled vector [dim]
246
  """
 
252
  emb_np = embedding.cpu().numpy()
253
  else:
254
  emb_np = np.array(embedding)
255
+
256
  return emb_np.mean(axis=0).astype(out_dtype)
257
 
258
 
 
263
  ) -> float:
264
  """
265
  Compute ColBERT-style MaxSim late interaction score.
266
+
267
  For each query token, finds max similarity with any document token,
268
  then sums across query tokens.
269
+
270
  This is the standard scoring for ColBERT/ColPali:
271
  score = Σ_q max_d (sim(q, d))
272
+
273
  Args:
274
  query_embedding: Query embeddings [num_query_tokens, dim]
275
  doc_embedding: Document embeddings [num_doc_tokens, dim]
276
  normalize: L2 normalize embeddings before scoring (recommended)
277
+
278
  Returns:
279
  MaxSim score (higher is better)
280
+
281
  Example:
282
  >>> query = embedder.embed_query("budget allocation")
283
  >>> doc = embeddings[0] # From embed_images
 
288
  query_norm = query_embedding / (
289
  np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
290
  )
291
+ doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
 
 
292
  else:
293
  query_norm = query_embedding
294
  doc_norm = doc_embedding
295
+
296
  # Compute similarity matrix: [num_query, num_doc]
297
  similarity_matrix = np.dot(query_norm, doc_norm.T)
298
+
299
  # MaxSim: For each query token, take max similarity with any doc token
300
  max_similarities = similarity_matrix.max(axis=1)
301
+
302
  # Sum across query tokens
303
  score = float(max_similarities.sum())
304
+
305
  return score
306
 
307
 
 
312
  ) -> list:
313
  """
314
  Compute MaxSim scores for multiple documents efficiently.
315
+
316
  Args:
317
  query_embedding: Query embeddings [num_query_tokens, dim]
318
  doc_embeddings: List of document embeddings
319
  normalize: L2 normalize embeddings
320
+
321
  Returns:
322
  List of MaxSim scores
323
  """
 
328
  )
329
  else:
330
  query_norm = query_embedding
331
+
332
  scores = []
333
  for doc_emb in doc_embeddings:
334
  if normalize:
335
+ doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
 
 
336
  else:
337
  doc_norm = doc_emb
338
+
339
  sim_matrix = np.dot(query_norm, doc_norm.T)
340
  max_sims = sim_matrix.max(axis=1)
341
  scores.append(float(max_sims.sum()))
342
+
343
  return scores
visual_rag/embedding/visual_embedder.py CHANGED
@@ -12,12 +12,12 @@ The embedder is BACKEND-AGNOSTIC - configure which model to use via the
12
  """
13
 
14
  import gc
15
- import os
16
  import logging
17
- from typing import List, Dict, Any, Optional, Tuple, Union
 
18
 
19
- import torch
20
  import numpy as np
 
21
  from PIL import Image
22
  from tqdm import tqdm
23
 
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
27
  class VisualEmbedder:
28
  """
29
  Visual document embedder supporting multiple backends.
30
-
31
  Currently supports:
32
  - ColPali family (ColSmol-500M, ColPali, ColQwen2)
33
  - More backends can be added
34
-
35
  Args:
36
  model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
37
  backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
@@ -39,23 +39,23 @@ class VisualEmbedder:
39
  torch_dtype: Data type for model weights
40
  batch_size: Batch size for image processing
41
  filter_special_tokens: Filter special tokens from query embeddings
42
-
43
  Example:
44
  >>> # Auto-detect backend from model name
45
  >>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
46
- >>>
47
  >>> # Embed images
48
  >>> image_embeddings = embedder.embed_images(images)
49
- >>>
50
  >>> # Embed query
51
  >>> query_embedding = embedder.embed_query("What is the budget?")
52
- >>>
53
  >>> # Get token info for saliency maps
54
  >>> embeddings, token_infos = embedder.embed_images(
55
  ... images, return_token_info=True
56
  ... )
57
  """
58
-
59
  # Known model families and their backends
60
  MODEL_BACKENDS = {
61
  "colsmol": "colpali",
@@ -63,7 +63,7 @@ class VisualEmbedder:
63
  "colqwen": "colpali",
64
  "colidefics": "colpali",
65
  }
66
-
67
  def __init__(
68
  self,
69
  model_name: str = "vidore/colSmol-500M",
@@ -81,15 +81,15 @@ class VisualEmbedder:
81
  if processor_speed not in ("fast", "slow", "auto"):
82
  raise ValueError("processor_speed must be one of: fast, slow, auto")
83
  self.processor_speed = processor_speed
84
-
85
  if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
86
  self.filter_special_tokens = False
87
  logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")
88
-
89
  if backend == "auto":
90
  backend = self._detect_backend(model_name)
91
  self.backend = backend
92
-
93
  if device is None:
94
  if torch.cuda.is_available():
95
  device = "cuda"
@@ -98,53 +98,55 @@ class VisualEmbedder:
98
  else:
99
  device = "cpu"
100
  self.device = device
101
-
102
  if torch_dtype is None:
103
  if device == "cuda":
104
  torch_dtype = torch.bfloat16
105
  else:
106
  torch_dtype = torch.float32
107
  self.torch_dtype = torch_dtype
108
-
109
  if output_dtype is None:
110
  if torch_dtype == torch.float16:
111
  output_dtype = np.float16
112
  else:
113
  output_dtype = np.float32
114
  self.output_dtype = output_dtype
115
-
116
  self._model = None
117
  self._processor = None
118
  self._image_token_id = None
119
-
120
- logger.info(f"🤖 VisualEmbedder initialized")
121
  logger.info(f" Model: {model_name}")
122
  logger.info(f" Backend: {backend}")
123
- logger.info(f" Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}")
124
-
 
 
125
  def _detect_backend(self, model_name: str) -> str:
126
  """Auto-detect backend from model name."""
127
  model_lower = model_name.lower()
128
-
129
  for key, backend in self.MODEL_BACKENDS.items():
130
  if key in model_lower:
131
  logger.debug(f"Detected backend '{backend}' from model name")
132
  return backend
133
-
134
  # Default to colpali for unknown models
135
  logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
136
  return "colpali"
137
-
138
  def _load_model(self):
139
  """Lazy load the model when first needed."""
140
  if self._model is not None:
141
  return
142
-
143
  if self.backend == "colpali":
144
  self._load_colpali_model()
145
  else:
146
  raise ValueError(f"Unknown backend: {self.backend}")
147
-
148
  def _load_colpali_model(self):
149
  """Load ColPali-family model."""
150
  try:
@@ -162,7 +164,7 @@ class VisualEmbedder:
162
  "pip install visual-rag-toolkit[embedding] or "
163
  "pip install colpali-engine"
164
  )
165
-
166
  logger.info(f"🤖 Loading ColPali model: {self.model_name}")
167
  logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")
168
 
@@ -170,7 +172,7 @@ class VisualEmbedder:
170
  if self.processor_speed == "auto":
171
  return {}
172
  return {"use_fast": self.processor_speed == "fast"}
173
-
174
  from transformers import AutoConfig
175
 
176
  cfg = AutoConfig.from_pretrained(self.model_name)
@@ -183,12 +185,16 @@ class VisualEmbedder:
183
  device_map=self.device,
184
  ).eval()
185
  try:
186
- self._processor = ColPaliProcessor.from_pretrained(self.model_name, **_processor_kwargs())
 
 
187
  except TypeError:
188
  self._processor = ColPaliProcessor.from_pretrained(self.model_name)
189
  except Exception:
190
  if self.processor_speed == "fast":
191
- self._processor = ColPaliProcessor.from_pretrained(self.model_name, use_fast=False)
 
 
192
  else:
193
  raise
194
  self._image_token_id = self._processor.image_token_id
@@ -202,12 +208,18 @@ class VisualEmbedder:
202
  device_map=self.device,
203
  ).eval()
204
  try:
205
- self._processor = ColQwen2Processor.from_pretrained(self.model_name, device_map=self.device, **_processor_kwargs())
 
 
206
  except TypeError:
207
- self._processor = ColQwen2Processor.from_pretrained(self.model_name, device_map=self.device)
 
 
208
  except Exception:
209
  if self.processor_speed == "fast":
210
- self._processor = ColQwen2Processor.from_pretrained(self.model_name, device_map=self.device, use_fast=False)
 
 
211
  else:
212
  raise
213
  self._image_token_id = self._processor.image_token_id
@@ -231,33 +243,37 @@ class VisualEmbedder:
231
  attn_implementation=attn_implementation,
232
  ).eval()
233
  try:
234
- self._processor = ColIdefics3Processor.from_pretrained(self.model_name, **_processor_kwargs())
 
 
235
  except TypeError:
236
  self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
237
  except Exception:
238
  if self.processor_speed == "fast":
239
- self._processor = ColIdefics3Processor.from_pretrained(self.model_name, use_fast=False)
 
 
240
  else:
241
  raise
242
  self._image_token_id = self._processor.image_token_id
243
-
244
  logger.info("✅ Model loaded successfully")
245
-
246
  @property
247
  def model(self):
248
  self._load_model()
249
  return self._model
250
-
251
  @property
252
  def processor(self):
253
  self._load_model()
254
  return self._processor
255
-
256
  @property
257
  def image_token_id(self):
258
  self._load_model()
259
  return self._image_token_id
260
-
261
  def embed_query(
262
  self,
263
  query_text: str,
@@ -265,31 +281,31 @@ class VisualEmbedder:
265
  ) -> torch.Tensor:
266
  """
267
  Generate embedding for a text query.
268
-
269
  By default, filters out special tokens (CLS, SEP, PAD) to keep only
270
  meaningful text tokens for better MaxSim matching.
271
-
272
  Args:
273
  query_text: Natural language query string
274
  filter_special_tokens: Override instance-level setting
275
-
276
  Returns:
277
  Query embedding tensor of shape [num_tokens, embedding_dim]
278
  """
279
  should_filter = (
280
- filter_special_tokens
281
- if filter_special_tokens is not None
282
  else self.filter_special_tokens
283
  )
284
-
285
  with torch.no_grad():
286
  processed = self.processor.process_queries([query_text]).to(self.model.device)
287
  embedding = self.model(**processed)
288
-
289
  # Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
290
  if embedding.dim() == 3:
291
  embedding = embedding.squeeze(0)
292
-
293
  if should_filter:
294
  # Filter special tokens based on attention mask
295
  attention_mask = processed.get("attention_mask")
@@ -297,7 +313,7 @@ class VisualEmbedder:
297
  # Keep only tokens with attention_mask = 1
298
  valid_mask = attention_mask.squeeze(0).bool()
299
  embedding = embedding[valid_mask]
300
-
301
  # Additionally filter padding tokens if present
302
  input_ids = processed.get("input_ids")
303
  if input_ids is not None:
@@ -307,11 +323,11 @@ class VisualEmbedder:
307
  non_special_mask = input_ids >= 4
308
  if non_special_mask.any():
309
  embedding = embedding[non_special_mask]
310
-
311
  logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
312
  else:
313
  logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")
314
-
315
  return embedding
316
 
317
  def embed_queries(
@@ -327,7 +343,9 @@ class VisualEmbedder:
327
  Returns a list of tensors, each of shape [num_tokens, embedding_dim].
328
  """
329
  should_filter = (
330
- filter_special_tokens if filter_special_tokens is not None else self.filter_special_tokens
 
 
331
  )
332
  batch_size = batch_size or self.batch_size
333
 
@@ -368,7 +386,7 @@ class VisualEmbedder:
368
  torch.mps.empty_cache()
369
 
370
  return outputs
371
-
372
  def embed_images(
373
  self,
374
  images: List[Image.Image],
@@ -378,19 +396,19 @@ class VisualEmbedder:
378
  ) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
379
  """
380
  Generate embeddings for a list of images.
381
-
382
  Args:
383
  images: List of PIL Images
384
  batch_size: Override instance batch size
385
  return_token_info: Also return token metadata (for saliency maps)
386
  show_progress: Show progress bar
387
-
388
  Returns:
389
  If return_token_info=False:
390
  List of embedding tensors [num_patches, dim]
391
  If return_token_info=True:
392
  Tuple of (embeddings, token_infos)
393
-
394
  Token info contains:
395
  - visual_token_indices: Indices of visual tokens in embedding
396
  - num_visual_tokens: Count of visual tokens
@@ -398,54 +416,60 @@ class VisualEmbedder:
398
  - num_tiles: Total tiles (n_rows × n_cols + 1 global)
399
  """
400
  batch_size = batch_size or self.batch_size
401
- if self.device == "mps" and "colpali" in (self.model_name or "").lower() and int(batch_size) > 1:
 
 
 
 
402
  batch_size = 1
403
-
404
  embeddings = []
405
  token_infos = [] if return_token_info else None
406
-
407
  iterator = range(0, len(images), batch_size)
408
  if show_progress:
409
  iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")
410
-
411
  for i in iterator:
412
- batch = images[i:i + batch_size]
413
-
414
  with torch.no_grad():
415
  processed = self.processor.process_images(batch).to(self.model.device)
416
-
417
  # Extract token info before model forward
418
  if return_token_info:
419
  input_ids = processed["input_ids"]
420
  batch_n_rows = processed.get("n_rows")
421
  batch_n_cols = processed.get("n_cols")
422
-
423
  for j in range(input_ids.shape[0]):
424
  # Find visual token indices
425
- image_token_mask = (input_ids[j] == self.image_token_id)
426
  visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()
427
-
428
  n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
429
  n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None
430
-
431
- token_infos.append({
432
- "visual_token_indices": visual_indices,
433
- "num_visual_tokens": len(visual_indices),
434
- "n_rows": n_rows,
435
- "n_cols": n_cols,
436
- "num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None,
437
- })
438
-
 
 
439
  # Generate embeddings
440
  batch_embeddings = self.model(**processed)
441
-
442
  # Extract per-image embeddings
443
  if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
444
  for j in range(batch_embeddings.shape[0]):
445
  embeddings.append(batch_embeddings[j].cpu())
446
  else:
447
  embeddings.extend([e.cpu() for e in batch_embeddings])
448
-
449
  # Memory cleanup
450
  del processed, batch_embeddings
451
  gc.collect()
@@ -453,11 +477,11 @@ class VisualEmbedder:
453
  torch.cuda.empty_cache()
454
  elif torch.backends.mps.is_available():
455
  torch.mps.empty_cache()
456
-
457
  if return_token_info:
458
  return embeddings, token_infos
459
  return embeddings
460
-
461
  def extract_visual_embedding(
462
  self,
463
  full_embedding: torch.Tensor,
@@ -465,18 +489,18 @@ class VisualEmbedder:
465
  ) -> np.ndarray:
466
  """
467
  Extract only visual token embeddings from full embedding.
468
-
469
  Filters out special tokens, keeping only visual patches for MaxSim.
470
-
471
  Args:
472
  full_embedding: Full embedding [all_tokens, dim]
473
  token_info: Token info dict from embed_images
474
-
475
  Returns:
476
  Visual embedding array [num_visual_tokens, dim]
477
  """
478
  visual_indices = token_info["visual_token_indices"]
479
-
480
  if isinstance(full_embedding, torch.Tensor):
481
  if full_embedding.dtype == torch.bfloat16:
482
  visual_emb = full_embedding[visual_indices].cpu().float().numpy()
@@ -484,7 +508,7 @@ class VisualEmbedder:
484
  visual_emb = full_embedding[visual_indices].cpu().numpy()
485
  else:
486
  visual_emb = np.array(full_embedding)[visual_indices]
487
-
488
  return visual_emb.astype(self.output_dtype)
489
 
490
  def mean_pool_visual_embedding(
@@ -511,17 +535,23 @@ class VisualEmbedder:
511
  n_rows = (token_info or {}).get("n_rows")
512
  n_cols = (token_info or {}).get("n_cols")
513
  num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
514
- return tile_level_mean_pooling(visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype)
 
 
515
 
516
  num_tokens = int(visual_np.shape[0])
517
  grid = int(round(float(num_tokens) ** 0.5))
518
  if grid * grid != num_tokens:
519
- raise ValueError(f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}")
 
 
520
  if int(target_vectors) != int(grid):
521
  raise ValueError(
522
  f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
523
  )
524
- return colpali_row_mean_pooling(visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype)
 
 
525
 
526
  def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
527
  if mean_pool.size == 0:
@@ -536,7 +566,10 @@ class VisualEmbedder:
536
  target_vectors: int = 32,
537
  mean_pool: Optional[np.ndarray] = None,
538
  ) -> np.ndarray:
539
- from visual_rag.embedding.pooling import colpali_experimental_pooling_from_rows, colsmol_experimental_pooling
 
 
 
540
 
541
  model_lower = (self.model_name or "").lower()
542
  is_colsmol = "colsmol" in model_lower
@@ -550,7 +583,11 @@ class VisualEmbedder:
550
  visual_np = np.array(visual_embedding, dtype=np.float32)
551
 
552
  if is_colsmol:
553
- if mean_pool is not None and getattr(mean_pool, "shape", None) is not None and int(mean_pool.shape[0]) > 0:
 
 
 
 
554
  num_tiles = int(mean_pool.shape[0])
555
  else:
556
  num_tiles = (token_info or {}).get("num_tiles")
@@ -563,14 +600,23 @@ class VisualEmbedder:
563
  if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
564
  num_tiles = int(num_tiles) + 1
565
  num_tiles = int(num_tiles)
566
- return colsmol_experimental_pooling(visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype)
 
 
567
 
568
- rows = mean_pool if mean_pool is not None else self.mean_pool_visual_embedding(visual_np, token_info, target_vectors=target_vectors)
 
 
 
 
 
 
569
  if int(rows.shape[0]) != int(target_vectors):
570
  raise ValueError(
571
  f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
572
  )
573
  return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)
574
 
 
575
  # Backward compatibility alias
576
  ColPaliEmbedder = VisualEmbedder
 
12
  """
13
 
14
  import gc
 
15
  import logging
16
+ import os
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
 
 
19
  import numpy as np
20
+ import torch
21
  from PIL import Image
22
  from tqdm import tqdm
23
 
 
27
  class VisualEmbedder:
28
  """
29
  Visual document embedder supporting multiple backends.
30
+
31
  Currently supports:
32
  - ColPali family (ColSmol-500M, ColPali, ColQwen2)
33
  - More backends can be added
34
+
35
  Args:
36
  model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
37
  backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
 
39
  torch_dtype: Data type for model weights
40
  batch_size: Batch size for image processing
41
  filter_special_tokens: Filter special tokens from query embeddings
42
+
43
  Example:
44
  >>> # Auto-detect backend from model name
45
  >>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
46
+ >>>
47
  >>> # Embed images
48
  >>> image_embeddings = embedder.embed_images(images)
49
+ >>>
50
  >>> # Embed query
51
  >>> query_embedding = embedder.embed_query("What is the budget?")
52
+ >>>
53
  >>> # Get token info for saliency maps
54
  >>> embeddings, token_infos = embedder.embed_images(
55
  ... images, return_token_info=True
56
  ... )
57
  """
58
+
59
  # Known model families and their backends
60
  MODEL_BACKENDS = {
61
  "colsmol": "colpali",
 
63
  "colqwen": "colpali",
64
  "colidefics": "colpali",
65
  }
66
+
67
  def __init__(
68
  self,
69
  model_name: str = "vidore/colSmol-500M",
 
81
  if processor_speed not in ("fast", "slow", "auto"):
82
  raise ValueError("processor_speed must be one of: fast, slow, auto")
83
  self.processor_speed = processor_speed
84
+
85
  if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
86
  self.filter_special_tokens = False
87
  logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")
88
+
89
  if backend == "auto":
90
  backend = self._detect_backend(model_name)
91
  self.backend = backend
92
+
93
  if device is None:
94
  if torch.cuda.is_available():
95
  device = "cuda"
 
98
  else:
99
  device = "cpu"
100
  self.device = device
101
+
102
  if torch_dtype is None:
103
  if device == "cuda":
104
  torch_dtype = torch.bfloat16
105
  else:
106
  torch_dtype = torch.float32
107
  self.torch_dtype = torch_dtype
108
+
109
  if output_dtype is None:
110
  if torch_dtype == torch.float16:
111
  output_dtype = np.float16
112
  else:
113
  output_dtype = np.float32
114
  self.output_dtype = output_dtype
115
+
116
  self._model = None
117
  self._processor = None
118
  self._image_token_id = None
119
+
120
+ logger.info("🤖 VisualEmbedder initialized")
121
  logger.info(f" Model: {model_name}")
122
  logger.info(f" Backend: {backend}")
123
+ logger.info(
124
+ f" Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}"
125
+ )
126
+
127
  def _detect_backend(self, model_name: str) -> str:
128
  """Auto-detect backend from model name."""
129
  model_lower = model_name.lower()
130
+
131
  for key, backend in self.MODEL_BACKENDS.items():
132
  if key in model_lower:
133
  logger.debug(f"Detected backend '{backend}' from model name")
134
  return backend
135
+
136
  # Default to colpali for unknown models
137
  logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
138
  return "colpali"
139
+
140
  def _load_model(self):
141
  """Lazy load the model when first needed."""
142
  if self._model is not None:
143
  return
144
+
145
  if self.backend == "colpali":
146
  self._load_colpali_model()
147
  else:
148
  raise ValueError(f"Unknown backend: {self.backend}")
149
+
150
  def _load_colpali_model(self):
151
  """Load ColPali-family model."""
152
  try:
 
164
  "pip install visual-rag-toolkit[embedding] or "
165
  "pip install colpali-engine"
166
  )
167
+
168
  logger.info(f"🤖 Loading ColPali model: {self.model_name}")
169
  logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")
170
 
 
172
  if self.processor_speed == "auto":
173
  return {}
174
  return {"use_fast": self.processor_speed == "fast"}
175
+
176
  from transformers import AutoConfig
177
 
178
  cfg = AutoConfig.from_pretrained(self.model_name)
 
185
  device_map=self.device,
186
  ).eval()
187
  try:
188
+ self._processor = ColPaliProcessor.from_pretrained(
189
+ self.model_name, **_processor_kwargs()
190
+ )
191
  except TypeError:
192
  self._processor = ColPaliProcessor.from_pretrained(self.model_name)
193
  except Exception:
194
  if self.processor_speed == "fast":
195
+ self._processor = ColPaliProcessor.from_pretrained(
196
+ self.model_name, use_fast=False
197
+ )
198
  else:
199
  raise
200
  self._image_token_id = self._processor.image_token_id
 
208
  device_map=self.device,
209
  ).eval()
210
  try:
211
+ self._processor = ColQwen2Processor.from_pretrained(
212
+ self.model_name, device_map=self.device, **_processor_kwargs()
213
+ )
214
  except TypeError:
215
+ self._processor = ColQwen2Processor.from_pretrained(
216
+ self.model_name, device_map=self.device
217
+ )
218
  except Exception:
219
  if self.processor_speed == "fast":
220
+ self._processor = ColQwen2Processor.from_pretrained(
221
+ self.model_name, device_map=self.device, use_fast=False
222
+ )
223
  else:
224
  raise
225
  self._image_token_id = self._processor.image_token_id
 
243
  attn_implementation=attn_implementation,
244
  ).eval()
245
  try:
246
+ self._processor = ColIdefics3Processor.from_pretrained(
247
+ self.model_name, **_processor_kwargs()
248
+ )
249
  except TypeError:
250
  self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
251
  except Exception:
252
  if self.processor_speed == "fast":
253
+ self._processor = ColIdefics3Processor.from_pretrained(
254
+ self.model_name, use_fast=False
255
+ )
256
  else:
257
  raise
258
  self._image_token_id = self._processor.image_token_id
259
+
260
  logger.info("✅ Model loaded successfully")
261
+
262
  @property
263
  def model(self):
264
  self._load_model()
265
  return self._model
266
+
267
  @property
268
  def processor(self):
269
  self._load_model()
270
  return self._processor
271
+
272
  @property
273
  def image_token_id(self):
274
  self._load_model()
275
  return self._image_token_id
276
+
277
  def embed_query(
278
  self,
279
  query_text: str,
 
281
  ) -> torch.Tensor:
282
  """
283
  Generate embedding for a text query.
284
+
285
  By default, filters out special tokens (CLS, SEP, PAD) to keep only
286
  meaningful text tokens for better MaxSim matching.
287
+
288
  Args:
289
  query_text: Natural language query string
290
  filter_special_tokens: Override instance-level setting
291
+
292
  Returns:
293
  Query embedding tensor of shape [num_tokens, embedding_dim]
294
  """
295
  should_filter = (
296
+ filter_special_tokens
297
+ if filter_special_tokens is not None
298
  else self.filter_special_tokens
299
  )
300
+
301
  with torch.no_grad():
302
  processed = self.processor.process_queries([query_text]).to(self.model.device)
303
  embedding = self.model(**processed)
304
+
305
  # Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
306
  if embedding.dim() == 3:
307
  embedding = embedding.squeeze(0)
308
+
309
  if should_filter:
310
  # Filter special tokens based on attention mask
311
  attention_mask = processed.get("attention_mask")
 
313
  # Keep only tokens with attention_mask = 1
314
  valid_mask = attention_mask.squeeze(0).bool()
315
  embedding = embedding[valid_mask]
316
+
317
  # Additionally filter padding tokens if present
318
  input_ids = processed.get("input_ids")
319
  if input_ids is not None:
 
323
  non_special_mask = input_ids >= 4
324
  if non_special_mask.any():
325
  embedding = embedding[non_special_mask]
326
+
327
  logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
328
  else:
329
  logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")
330
+
331
  return embedding
332
 
333
  def embed_queries(
 
343
  Returns a list of tensors, each of shape [num_tokens, embedding_dim].
344
  """
345
  should_filter = (
346
+ filter_special_tokens
347
+ if filter_special_tokens is not None
348
+ else self.filter_special_tokens
349
  )
350
  batch_size = batch_size or self.batch_size
351
 
 
386
  torch.mps.empty_cache()
387
 
388
  return outputs
389
+
390
  def embed_images(
391
  self,
392
  images: List[Image.Image],
 
396
  ) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
397
  """
398
  Generate embeddings for a list of images.
399
+
400
  Args:
401
  images: List of PIL Images
402
  batch_size: Override instance batch size
403
  return_token_info: Also return token metadata (for saliency maps)
404
  show_progress: Show progress bar
405
+
406
  Returns:
407
  If return_token_info=False:
408
  List of embedding tensors [num_patches, dim]
409
  If return_token_info=True:
410
  Tuple of (embeddings, token_infos)
411
+
412
  Token info contains:
413
  - visual_token_indices: Indices of visual tokens in embedding
414
  - num_visual_tokens: Count of visual tokens
 
416
  - num_tiles: Total tiles (n_rows × n_cols + 1 global)
417
  """
418
  batch_size = batch_size or self.batch_size
419
+ if (
420
+ self.device == "mps"
421
+ and "colpali" in (self.model_name or "").lower()
422
+ and int(batch_size) > 1
423
+ ):
424
  batch_size = 1
425
+
426
  embeddings = []
427
  token_infos = [] if return_token_info else None
428
+
429
  iterator = range(0, len(images), batch_size)
430
  if show_progress:
431
  iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")
432
+
433
  for i in iterator:
434
+ batch = images[i : i + batch_size]
435
+
436
  with torch.no_grad():
437
  processed = self.processor.process_images(batch).to(self.model.device)
438
+
439
  # Extract token info before model forward
440
  if return_token_info:
441
  input_ids = processed["input_ids"]
442
  batch_n_rows = processed.get("n_rows")
443
  batch_n_cols = processed.get("n_cols")
444
+
445
  for j in range(input_ids.shape[0]):
446
  # Find visual token indices
447
+ image_token_mask = input_ids[j] == self.image_token_id
448
  visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()
449
+
450
  n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
451
  n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None
452
+
453
+ token_infos.append(
454
+ {
455
+ "visual_token_indices": visual_indices,
456
+ "num_visual_tokens": len(visual_indices),
457
+ "n_rows": n_rows,
458
+ "n_cols": n_cols,
459
+ "num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None,
460
+ }
461
+ )
462
+
463
  # Generate embeddings
464
  batch_embeddings = self.model(**processed)
465
+
466
  # Extract per-image embeddings
467
  if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
468
  for j in range(batch_embeddings.shape[0]):
469
  embeddings.append(batch_embeddings[j].cpu())
470
  else:
471
  embeddings.extend([e.cpu() for e in batch_embeddings])
472
+
473
  # Memory cleanup
474
  del processed, batch_embeddings
475
  gc.collect()
 
477
  torch.cuda.empty_cache()
478
  elif torch.backends.mps.is_available():
479
  torch.mps.empty_cache()
480
+
481
  if return_token_info:
482
  return embeddings, token_infos
483
  return embeddings
484
+
485
  def extract_visual_embedding(
486
  self,
487
  full_embedding: torch.Tensor,
 
489
  ) -> np.ndarray:
490
  """
491
  Extract only visual token embeddings from full embedding.
492
+
493
  Filters out special tokens, keeping only visual patches for MaxSim.
494
+
495
  Args:
496
  full_embedding: Full embedding [all_tokens, dim]
497
  token_info: Token info dict from embed_images
498
+
499
  Returns:
500
  Visual embedding array [num_visual_tokens, dim]
501
  """
502
  visual_indices = token_info["visual_token_indices"]
503
+
504
  if isinstance(full_embedding, torch.Tensor):
505
  if full_embedding.dtype == torch.bfloat16:
506
  visual_emb = full_embedding[visual_indices].cpu().float().numpy()
 
508
  visual_emb = full_embedding[visual_indices].cpu().numpy()
509
  else:
510
  visual_emb = np.array(full_embedding)[visual_indices]
511
+
512
  return visual_emb.astype(self.output_dtype)
513
 
514
  def mean_pool_visual_embedding(
 
535
  n_rows = (token_info or {}).get("n_rows")
536
  n_cols = (token_info or {}).get("n_cols")
537
  num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
538
+ return tile_level_mean_pooling(
539
+ visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
540
+ )
541
 
542
  num_tokens = int(visual_np.shape[0])
543
  grid = int(round(float(num_tokens) ** 0.5))
544
  if grid * grid != num_tokens:
545
+ raise ValueError(
546
+ f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}"
547
+ )
548
  if int(target_vectors) != int(grid):
549
  raise ValueError(
550
  f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
551
  )
552
+ return colpali_row_mean_pooling(
553
+ visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype
554
+ )
555
 
556
  def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
557
  if mean_pool.size == 0:
 
566
  target_vectors: int = 32,
567
  mean_pool: Optional[np.ndarray] = None,
568
  ) -> np.ndarray:
569
+ from visual_rag.embedding.pooling import (
570
+ colpali_experimental_pooling_from_rows,
571
+ colsmol_experimental_pooling,
572
+ )
573
 
574
  model_lower = (self.model_name or "").lower()
575
  is_colsmol = "colsmol" in model_lower
 
583
  visual_np = np.array(visual_embedding, dtype=np.float32)
584
 
585
  if is_colsmol:
586
+ if (
587
+ mean_pool is not None
588
+ and getattr(mean_pool, "shape", None) is not None
589
+ and int(mean_pool.shape[0]) > 0
590
+ ):
591
  num_tiles = int(mean_pool.shape[0])
592
  else:
593
  num_tiles = (token_info or {}).get("num_tiles")
 
600
  if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
601
  num_tiles = int(num_tiles) + 1
602
  num_tiles = int(num_tiles)
603
+ return colsmol_experimental_pooling(
604
+ visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
605
+ )
606
 
607
+ rows = (
608
+ mean_pool
609
+ if mean_pool is not None
610
+ else self.mean_pool_visual_embedding(
611
+ visual_np, token_info, target_vectors=target_vectors
612
+ )
613
+ )
614
  if int(rows.shape[0]) != int(target_vectors):
615
  raise ValueError(
616
  f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
617
  )
618
  return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)
619
 
620
+
621
  # Backward compatibility alias
622
  ColPaliEmbedder = VisualEmbedder
visual_rag/indexing/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Indexing module - PDF processing, embedding storage, and CDN uploads.
3
+
4
+ Components:
5
+ - PDFProcessor: Convert PDFs to images and extract text
6
+ - QdrantIndexer: Upload embeddings to Qdrant vector database
7
+ - CloudinaryUploader: Upload images to Cloudinary CDN
8
+ - ProcessingPipeline: End-to-end PDF → Qdrant pipeline
9
+ """
10
+
11
+ # Lazy imports to avoid failures when optional dependencies aren't installed
12
+
13
+ try:
14
+ from visual_rag.indexing.pdf_processor import PDFProcessor
15
+ except ImportError:
16
+ PDFProcessor = None
17
+
18
+ try:
19
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
20
+ except ImportError:
21
+ QdrantIndexer = None
22
+
23
+ try:
24
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
25
+ except ImportError:
26
+ CloudinaryUploader = None
27
+
28
+ try:
29
+ from visual_rag.indexing.pipeline import ProcessingPipeline
30
+ except ImportError:
31
+ ProcessingPipeline = None
32
+
33
+ __all__ = [
34
+ "PDFProcessor",
35
+ "QdrantIndexer",
36
+ "CloudinaryUploader",
37
+ "ProcessingPipeline",
38
+ ]
visual_rag/indexing/cloudinary_uploader.py CHANGED
@@ -15,14 +15,15 @@ Environment Variables:
15
  """
16
 
17
  import io
18
- import os
19
- import time
20
- import signal
21
  import logging
 
22
  import platform
 
23
  import threading
 
 
 
24
  from typing import Optional
25
- from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
26
 
27
  from PIL import Image
28
 
@@ -34,9 +35,9 @@ THREAD_SAFE_MODE = os.getenv("VISUAL_RAG_THREAD_SAFE", "").lower() in ("1", "tru
34
  class CloudinaryUploader:
35
  """
36
  Upload images to Cloudinary CDN.
37
-
38
  Works independently - just needs PIL images.
39
-
40
  Args:
41
  cloud_name: Cloudinary cloud name
42
  api_key: Cloudinary API key
@@ -44,7 +45,7 @@ class CloudinaryUploader:
44
  folder: Base folder for uploads
45
  max_retries: Number of retry attempts
46
  timeout_seconds: Timeout per upload
47
-
48
  Example:
49
  >>> uploader = CloudinaryUploader(
50
  ... cloud_name="my-cloud",
@@ -52,11 +53,11 @@ class CloudinaryUploader:
52
  ... api_secret="yyy",
53
  ... folder="my-project",
54
  ... )
55
- >>>
56
  >>> url = uploader.upload(image, "doc_page_1")
57
  >>> print(url) # https://res.cloudinary.com/.../doc_page_1.jpg
58
  """
59
-
60
  def __init__(
61
  self,
62
  cloud_name: Optional[str] = None,
@@ -71,19 +72,19 @@ class CloudinaryUploader:
71
  self.cloud_name = cloud_name or os.getenv("CLOUDINARY_CLOUD_NAME")
72
  self.api_key = api_key or os.getenv("CLOUDINARY_API_KEY")
73
  self.api_secret = api_secret or os.getenv("CLOUDINARY_API_SECRET")
74
-
75
  if not all([self.cloud_name, self.api_key, self.api_secret]):
76
  raise ValueError(
77
  "Cloudinary credentials required. Set CLOUDINARY_CLOUD_NAME, "
78
  "CLOUDINARY_API_KEY, CLOUDINARY_API_SECRET environment variables "
79
  "or pass them as arguments."
80
  )
81
-
82
  self.folder = folder
83
  self.max_retries = max_retries
84
  self.timeout_seconds = timeout_seconds
85
  self.jpeg_quality = jpeg_quality
86
-
87
  # Check dependency
88
  try:
89
  import cloudinary # noqa
@@ -92,10 +93,10 @@ class CloudinaryUploader:
92
  "Cloudinary not installed. "
93
  "Install with: pip install visual-rag-toolkit[cloudinary]"
94
  )
95
-
96
- logger.info(f"☁️ Cloudinary uploader initialized")
97
  logger.info(f" Folder: {folder}")
98
-
99
  def upload(
100
  self,
101
  image: Image.Image,
@@ -104,34 +105,34 @@ class CloudinaryUploader:
104
  ) -> Optional[str]:
105
  """
106
  Upload a single image to Cloudinary.
107
-
108
  Args:
109
  image: PIL Image to upload
110
  public_id: Public ID (filename without extension)
111
  subfolder: Optional subfolder within base folder
112
-
113
  Returns:
114
  Secure URL of uploaded image, or None if failed
115
  """
116
  import cloudinary
117
  import cloudinary.uploader
118
-
119
  # Prepare buffer
120
  buffer = io.BytesIO()
121
  image.save(buffer, format="JPEG", quality=self.jpeg_quality, optimize=True)
122
-
123
  # Configure Cloudinary
124
  cloudinary.config(
125
  cloud_name=self.cloud_name,
126
  api_key=self.api_key,
127
  api_secret=self.api_secret,
128
  )
129
-
130
  # Build folder path
131
  folder_path = self.folder
132
  if subfolder:
133
  folder_path = f"{self.folder}/{subfolder}"
134
-
135
  def do_upload():
136
  buffer.seek(0)
137
  result = cloudinary.uploader.upload(
@@ -143,14 +144,14 @@ class CloudinaryUploader:
143
  timeout=self.timeout_seconds,
144
  )
145
  return result["secure_url"]
146
-
147
  # Use thread-safe mode for Streamlit/Flask/threaded contexts
148
  # Set VISUAL_RAG_THREAD_SAFE=1 to enable
149
  if THREAD_SAFE_MODE or threading.current_thread() is not threading.main_thread():
150
  return self._upload_with_thread_timeout(do_upload, public_id)
151
  else:
152
  return self._upload_with_signal_timeout(do_upload, public_id)
153
-
154
  def _upload_with_thread_timeout(self, do_upload, public_id: str) -> Optional[str]:
155
  """Thread-safe upload with ThreadPoolExecutor timeout."""
156
  for attempt in range(self.max_retries):
@@ -158,64 +159,60 @@ class CloudinaryUploader:
158
  with ThreadPoolExecutor(max_workers=1) as executor:
159
  future = executor.submit(do_upload)
160
  return future.result(timeout=self.timeout_seconds)
161
-
162
  except FuturesTimeoutError:
163
  logger.warning(
164
  f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
165
  )
166
  if attempt < self.max_retries - 1:
167
- time.sleep(2 ** attempt)
168
-
169
  except Exception as e:
170
- logger.warning(
171
- f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}"
172
- )
173
  if attempt < self.max_retries - 1:
174
- time.sleep(2 ** attempt)
175
-
176
  logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
177
  return None
178
-
179
  def _upload_with_signal_timeout(self, do_upload, public_id: str) -> Optional[str]:
180
  """Signal-based upload timeout (main thread only, Unix/macOS)."""
181
  use_timeout = platform.system() != "Windows"
182
-
183
  class SignalTimeoutError(Exception):
184
  pass
185
-
186
  def timeout_handler(signum, frame):
187
  raise SignalTimeoutError(f"Upload timed out after {self.timeout_seconds}s")
188
-
189
  for attempt in range(self.max_retries):
190
  try:
191
  if use_timeout:
192
  old_handler = signal.signal(signal.SIGALRM, timeout_handler)
193
  signal.alarm(self.timeout_seconds)
194
-
195
  try:
196
  return do_upload()
197
  finally:
198
  if use_timeout:
199
  signal.alarm(0)
200
  signal.signal(signal.SIGALRM, old_handler)
201
-
202
  except SignalTimeoutError:
203
  logger.warning(
204
  f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
205
  )
206
  if attempt < self.max_retries - 1:
207
- time.sleep(2 ** attempt)
208
-
209
  except Exception as e:
210
- logger.warning(
211
- f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}"
212
- )
213
  if attempt < self.max_retries - 1:
214
- time.sleep(2 ** attempt)
215
-
216
  logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
217
  return None
218
-
219
  def upload_original_and_resized(
220
  self,
221
  original_image: Image.Image,
@@ -224,12 +221,12 @@ class CloudinaryUploader:
224
  ) -> tuple:
225
  """
226
  Upload both original and resized versions.
227
-
228
  Args:
229
  original_image: Original PDF page image
230
  resized_image: Resized image for ColPali
231
  base_public_id: Base public ID (e.g., "doc_page_1")
232
-
233
  Returns:
234
  Tuple of (original_url, resized_url) - either can be None on failure
235
  """
@@ -238,13 +235,13 @@ class CloudinaryUploader:
238
  base_public_id,
239
  subfolder="original",
240
  )
241
-
242
  resized_url = self.upload(
243
  resized_image,
244
  base_public_id,
245
  subfolder="resized",
246
  )
247
-
248
  return original_url, resized_url
249
 
250
  def upload_original_cropped_and_resized(
@@ -275,5 +272,3 @@ class CloudinaryUploader:
275
  )
276
 
277
  return original_url, cropped_url, resized_url
278
-
279
-
 
15
  """
16
 
17
  import io
 
 
 
18
  import logging
19
+ import os
20
  import platform
21
+ import signal
22
  import threading
23
+ import time
24
+ from concurrent.futures import ThreadPoolExecutor
25
+ from concurrent.futures import TimeoutError as FuturesTimeoutError
26
  from typing import Optional
 
27
 
28
  from PIL import Image
29
 
 
35
  class CloudinaryUploader:
36
  """
37
  Upload images to Cloudinary CDN.
38
+
39
  Works independently - just needs PIL images.
40
+
41
  Args:
42
  cloud_name: Cloudinary cloud name
43
  api_key: Cloudinary API key
 
45
  folder: Base folder for uploads
46
  max_retries: Number of retry attempts
47
  timeout_seconds: Timeout per upload
48
+
49
  Example:
50
  >>> uploader = CloudinaryUploader(
51
  ... cloud_name="my-cloud",
 
53
  ... api_secret="yyy",
54
  ... folder="my-project",
55
  ... )
56
+ >>>
57
  >>> url = uploader.upload(image, "doc_page_1")
58
  >>> print(url) # https://res.cloudinary.com/.../doc_page_1.jpg
59
  """
60
+
61
  def __init__(
62
  self,
63
  cloud_name: Optional[str] = None,
 
72
  self.cloud_name = cloud_name or os.getenv("CLOUDINARY_CLOUD_NAME")
73
  self.api_key = api_key or os.getenv("CLOUDINARY_API_KEY")
74
  self.api_secret = api_secret or os.getenv("CLOUDINARY_API_SECRET")
75
+
76
  if not all([self.cloud_name, self.api_key, self.api_secret]):
77
  raise ValueError(
78
  "Cloudinary credentials required. Set CLOUDINARY_CLOUD_NAME, "
79
  "CLOUDINARY_API_KEY, CLOUDINARY_API_SECRET environment variables "
80
  "or pass them as arguments."
81
  )
82
+
83
  self.folder = folder
84
  self.max_retries = max_retries
85
  self.timeout_seconds = timeout_seconds
86
  self.jpeg_quality = jpeg_quality
87
+
88
  # Check dependency
89
  try:
90
  import cloudinary # noqa
 
93
  "Cloudinary not installed. "
94
  "Install with: pip install visual-rag-toolkit[cloudinary]"
95
  )
96
+
97
+ logger.info("☁️ Cloudinary uploader initialized")
98
  logger.info(f" Folder: {folder}")
99
+
100
  def upload(
101
  self,
102
  image: Image.Image,
 
105
  ) -> Optional[str]:
106
  """
107
  Upload a single image to Cloudinary.
108
+
109
  Args:
110
  image: PIL Image to upload
111
  public_id: Public ID (filename without extension)
112
  subfolder: Optional subfolder within base folder
113
+
114
  Returns:
115
  Secure URL of uploaded image, or None if failed
116
  """
117
  import cloudinary
118
  import cloudinary.uploader
119
+
120
  # Prepare buffer
121
  buffer = io.BytesIO()
122
  image.save(buffer, format="JPEG", quality=self.jpeg_quality, optimize=True)
123
+
124
  # Configure Cloudinary
125
  cloudinary.config(
126
  cloud_name=self.cloud_name,
127
  api_key=self.api_key,
128
  api_secret=self.api_secret,
129
  )
130
+
131
  # Build folder path
132
  folder_path = self.folder
133
  if subfolder:
134
  folder_path = f"{self.folder}/{subfolder}"
135
+
136
  def do_upload():
137
  buffer.seek(0)
138
  result = cloudinary.uploader.upload(
 
144
  timeout=self.timeout_seconds,
145
  )
146
  return result["secure_url"]
147
+
148
  # Use thread-safe mode for Streamlit/Flask/threaded contexts
149
  # Set VISUAL_RAG_THREAD_SAFE=1 to enable
150
  if THREAD_SAFE_MODE or threading.current_thread() is not threading.main_thread():
151
  return self._upload_with_thread_timeout(do_upload, public_id)
152
  else:
153
  return self._upload_with_signal_timeout(do_upload, public_id)
154
+
155
  def _upload_with_thread_timeout(self, do_upload, public_id: str) -> Optional[str]:
156
  """Thread-safe upload with ThreadPoolExecutor timeout."""
157
  for attempt in range(self.max_retries):
 
159
  with ThreadPoolExecutor(max_workers=1) as executor:
160
  future = executor.submit(do_upload)
161
  return future.result(timeout=self.timeout_seconds)
162
+
163
  except FuturesTimeoutError:
164
  logger.warning(
165
  f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
166
  )
167
  if attempt < self.max_retries - 1:
168
+ time.sleep(2**attempt)
169
+
170
  except Exception as e:
171
+ logger.warning(f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}")
 
 
172
  if attempt < self.max_retries - 1:
173
+ time.sleep(2**attempt)
174
+
175
  logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
176
  return None
177
+
178
  def _upload_with_signal_timeout(self, do_upload, public_id: str) -> Optional[str]:
179
  """Signal-based upload timeout (main thread only, Unix/macOS)."""
180
  use_timeout = platform.system() != "Windows"
181
+
182
  class SignalTimeoutError(Exception):
183
  pass
184
+
185
  def timeout_handler(signum, frame):
186
  raise SignalTimeoutError(f"Upload timed out after {self.timeout_seconds}s")
187
+
188
  for attempt in range(self.max_retries):
189
  try:
190
  if use_timeout:
191
  old_handler = signal.signal(signal.SIGALRM, timeout_handler)
192
  signal.alarm(self.timeout_seconds)
193
+
194
  try:
195
  return do_upload()
196
  finally:
197
  if use_timeout:
198
  signal.alarm(0)
199
  signal.signal(signal.SIGALRM, old_handler)
200
+
201
  except SignalTimeoutError:
202
  logger.warning(
203
  f"Upload timeout (attempt {attempt + 1}/{self.max_retries}): {public_id}"
204
  )
205
  if attempt < self.max_retries - 1:
206
+ time.sleep(2**attempt)
207
+
208
  except Exception as e:
209
+ logger.warning(f"Upload failed (attempt {attempt + 1}/{self.max_retries}): {e}")
 
 
210
  if attempt < self.max_retries - 1:
211
+ time.sleep(2**attempt)
212
+
213
  logger.error(f"❌ Upload failed after {self.max_retries} attempts: {public_id}")
214
  return None
215
+
216
  def upload_original_and_resized(
217
  self,
218
  original_image: Image.Image,
 
221
  ) -> tuple:
222
  """
223
  Upload both original and resized versions.
224
+
225
  Args:
226
  original_image: Original PDF page image
227
  resized_image: Resized image for ColPali
228
  base_public_id: Base public ID (e.g., "doc_page_1")
229
+
230
  Returns:
231
  Tuple of (original_url, resized_url) - either can be None on failure
232
  """
 
235
  base_public_id,
236
  subfolder="original",
237
  )
238
+
239
  resized_url = self.upload(
240
  resized_image,
241
  base_public_id,
242
  subfolder="resized",
243
  )
244
+
245
  return original_url, resized_url
246
 
247
  def upload_original_cropped_and_resized(
 
272
  )
273
 
274
  return original_url, cropped_url, resized_url
 
 
visual_rag/indexing/pdf_processor.py CHANGED
@@ -11,10 +11,10 @@ Features:
11
  """
12
 
13
  import gc
14
- import re
15
  import logging
 
16
  from pathlib import Path
17
- from typing import List, Dict, Any, Optional, Tuple, Generator
18
 
19
  from PIL import Image
20
 
@@ -24,26 +24,26 @@ logger = logging.getLogger(__name__)
24
  class PDFProcessor:
25
  """
26
  Process PDFs into images and text for visual retrieval.
27
-
28
  Works independently - no embedding or storage dependencies.
29
-
30
  Args:
31
  dpi: DPI for image conversion (higher = better quality)
32
  output_format: Image format (RGB, L, etc.)
33
  page_batch_size: Pages per batch for memory efficiency
34
-
35
  Example:
36
  >>> processor = PDFProcessor(dpi=140)
37
- >>>
38
  >>> # Convert single PDF
39
  >>> images, texts = processor.process_pdf(Path("report.pdf"))
40
- >>>
41
  >>> # Stream large PDFs
42
  >>> for images, texts in processor.stream_pdf(Path("large.pdf"), batch_size=10):
43
  ... # Process each batch
44
  ... pass
45
  """
46
-
47
  def __init__(
48
  self,
49
  dpi: int = 140,
@@ -53,17 +53,24 @@ class PDFProcessor:
53
  self.dpi = dpi
54
  self.output_format = output_format
55
  self.page_batch_size = page_batch_size
56
-
57
- # Check dependencies
 
 
 
58
  try:
59
- from pdf2image import convert_from_path # noqa
60
- from pypdf import PdfReader # noqa
61
- except ImportError:
 
 
 
 
62
  raise ImportError(
63
- "PDF processing requires pdf2image and pypdf. "
64
- "Install with: pip install visual-rag-toolkit[pdf]"
65
  )
66
-
67
  def process_pdf(
68
  self,
69
  pdf_path: Path,
@@ -71,38 +78,39 @@ class PDFProcessor:
71
  ) -> Tuple[List[Image.Image], List[str]]:
72
  """
73
  Convert PDF to images and extract text.
74
-
75
  Args:
76
  pdf_path: Path to PDF file
77
  dpi: Override default DPI
78
-
79
  Returns:
80
  Tuple of (list of images, list of page texts)
81
  """
 
82
  from pdf2image import convert_from_path
83
  from pypdf import PdfReader
84
-
85
  dpi = dpi or self.dpi
86
  pdf_path = Path(pdf_path)
87
-
88
  logger.info(f"📄 Processing PDF: {pdf_path.name}")
89
-
90
  # Extract text
91
  reader = PdfReader(str(pdf_path))
92
  total_pages = len(reader.pages)
93
-
94
  page_texts = []
95
  for page in reader.pages:
96
  text = page.extract_text() or ""
97
  # Handle surrogate characters
98
  text = self._sanitize_text(text)
99
  page_texts.append(text)
100
-
101
  # Convert to images in batches
102
  all_images = []
103
  for start_page in range(1, total_pages + 1, self.page_batch_size):
104
  end_page = min(start_page + self.page_batch_size - 1, total_pages)
105
-
106
  batch_images = convert_from_path(
107
  str(pdf_path),
108
  dpi=dpi,
@@ -110,19 +118,19 @@ class PDFProcessor:
110
  first_page=start_page,
111
  last_page=end_page,
112
  )
113
-
114
  all_images.extend(batch_images)
115
-
116
  del batch_images
117
  gc.collect()
118
-
119
- assert len(all_images) == len(page_texts), (
120
- f"Mismatch: {len(all_images)} images vs {len(page_texts)} texts"
121
- )
122
-
123
  logger.info(f"✅ Processed {len(all_images)} pages")
124
  return all_images, page_texts
125
-
126
  def stream_pdf(
127
  self,
128
  pdf_path: Path,
@@ -131,39 +139,40 @@ class PDFProcessor:
131
  ) -> Generator[Tuple[List[Image.Image], List[str], int], None, None]:
132
  """
133
  Stream PDF processing for large files.
134
-
135
  Yields batches of (images, texts, start_page) without loading
136
  entire PDF into memory.
137
-
138
  Args:
139
  pdf_path: Path to PDF file
140
  batch_size: Pages per batch
141
  dpi: Override default DPI
142
-
143
  Yields:
144
  Tuple of (batch_images, batch_texts, start_page_number)
145
  """
 
146
  from pdf2image import convert_from_path
147
  from pypdf import PdfReader
148
-
149
  dpi = dpi or self.dpi
150
  pdf_path = Path(pdf_path)
151
-
152
  reader = PdfReader(str(pdf_path))
153
  total_pages = len(reader.pages)
154
-
155
  logger.info(f"📄 Streaming PDF: {pdf_path.name} ({total_pages} pages)")
156
-
157
  for start_idx in range(0, total_pages, batch_size):
158
  end_idx = min(start_idx + batch_size, total_pages)
159
-
160
  # Extract text for batch
161
  batch_texts = []
162
  for page_idx in range(start_idx, end_idx):
163
  text = reader.pages[page_idx].extract_text() or ""
164
  text = self._sanitize_text(text)
165
  batch_texts.append(text)
166
-
167
  # Convert images for batch
168
  batch_images = convert_from_path(
169
  str(pdf_path),
@@ -172,18 +181,20 @@ class PDFProcessor:
172
  first_page=start_idx + 1, # 1-indexed
173
  last_page=end_idx,
174
  )
175
-
176
  yield batch_images, batch_texts, start_idx + 1
177
-
178
  del batch_images
179
  gc.collect()
180
-
181
  def get_page_count(self, pdf_path: Path) -> int:
182
  """Get number of pages in PDF without loading images."""
 
183
  from pypdf import PdfReader
 
184
  reader = PdfReader(str(pdf_path))
185
  return len(reader.pages)
186
-
187
  def resize_for_colpali(
188
  self,
189
  image: Image.Image,
@@ -192,19 +203,23 @@ class PDFProcessor:
192
  ) -> Tuple[Image.Image, int, int]:
193
  """
194
  Resize image following ColPali/Idefics3 processor logic.
195
-
196
  Resizes to fit within tile grid without black padding.
197
-
198
  Args:
199
  image: PIL Image
200
  max_edge: Maximum edge length
201
  tile_size: Size of each tile
202
-
203
  Returns:
204
  Tuple of (resized_image, tile_rows, tile_cols)
205
  """
 
 
 
 
206
  w, h = image.size
207
-
208
  # Step 1: Resize so longest edge = max_edge
209
  if w > h:
210
  new_w = max_edge
@@ -212,25 +227,25 @@ class PDFProcessor:
212
  else:
213
  new_h = max_edge
214
  new_w = int(w * (max_edge / h))
215
-
216
  # Step 2: Calculate tile grid
217
  tile_cols = (new_w + tile_size - 1) // tile_size
218
  tile_rows = (new_h + tile_size - 1) // tile_size
219
-
220
  # Step 3: Calculate exact dimensions for tiles
221
  final_w = tile_cols * tile_size
222
  final_h = tile_rows * tile_size
223
-
224
  # Step 4: Scale to fit within tile grid
225
  scale_w = final_w / w
226
  scale_h = final_h / h
227
  scale = min(scale_w, scale_h)
228
-
229
  scaled_w = int(w * scale)
230
  scaled_h = int(h * scale)
231
-
232
  resized = image.resize((scaled_w, scaled_h), Image.LANCZOS)
233
-
234
  # Center on white canvas if needed
235
  if scaled_w != final_w or scaled_h != final_h:
236
  canvas = Image.new("RGB", (final_w, final_h), (255, 255, 255))
@@ -238,19 +253,17 @@ class PDFProcessor:
238
  offset_y = (final_h - scaled_h) // 2
239
  canvas.paste(resized, (offset_x, offset_y))
240
  resized = canvas
241
-
242
  return resized, tile_rows, tile_cols
243
-
244
  def _sanitize_text(self, text: str) -> str:
245
  """Remove invalid Unicode characters (surrogates) from text."""
246
  if not text:
247
  return ""
248
-
249
  # Remove surrogate characters (U+D800-U+DFFF)
250
- return text.encode("utf-8", errors="surrogatepass").decode(
251
- "utf-8", errors="ignore"
252
- )
253
-
254
  def extract_metadata_from_filename(
255
  self,
256
  filename: str,
@@ -258,47 +271,45 @@ class PDFProcessor:
258
  ) -> Dict[str, Any]:
259
  """
260
  Extract metadata from PDF filename.
261
-
262
  Uses mapping if provided, otherwise falls back to pattern matching.
263
-
264
  Args:
265
  filename: PDF filename (with or without .pdf extension)
266
  mapping: Optional mapping dict {filename: metadata}
267
-
268
  Returns:
269
  Metadata dict with year, source, district, etc.
270
  """
271
  # Remove extension
272
  stem = Path(filename).stem
273
  stem_lower = stem.lower().strip()
274
-
275
  # Try mapping first
276
  if mapping:
277
  if stem_lower in mapping:
278
  return mapping[stem_lower].copy()
279
-
280
  # Try without .pdf
281
  stem_no_ext = stem_lower.replace(".pdf", "")
282
  if stem_no_ext in mapping:
283
  return mapping[stem_no_ext].copy()
284
-
285
  # Fallback: pattern matching
286
  metadata = {"filename": filename}
287
-
288
  # Extract year
289
  year_match = re.search(r"(20\d{2})", stem)
290
  if year_match:
291
  metadata["year"] = int(year_match.group(1))
292
-
293
  # Detect source type
294
  if "consolidated" in stem_lower or ("annual" in stem_lower and "oag" in stem_lower):
295
  metadata["source"] = "Consolidated"
296
  elif "dlg" in stem_lower or "district local government" in stem_lower:
297
  metadata["source"] = "Local Government"
298
  # Try to extract district name
299
- district_match = re.search(
300
- r"([a-z]+)\s+(?:dlg|district local government)", stem_lower
301
- )
302
  if district_match:
303
  metadata["district"] = district_match.group(1).title()
304
  elif "hospital" in stem_lower or "referral" in stem_lower:
@@ -309,7 +320,5 @@ class PDFProcessor:
309
  metadata["source"] = "Project"
310
  else:
311
  metadata["source"] = "Unknown"
312
-
313
- return metadata
314
-
315
 
 
 
11
  """
12
 
13
  import gc
 
14
  import logging
15
+ import re
16
  from pathlib import Path
17
+ from typing import Any, Dict, Generator, List, Optional, Tuple
18
 
19
  from PIL import Image
20
 
 
24
  class PDFProcessor:
25
  """
26
  Process PDFs into images and text for visual retrieval.
27
+
28
  Works independently - no embedding or storage dependencies.
29
+
30
  Args:
31
  dpi: DPI for image conversion (higher = better quality)
32
  output_format: Image format (RGB, L, etc.)
33
  page_batch_size: Pages per batch for memory efficiency
34
+
35
  Example:
36
  >>> processor = PDFProcessor(dpi=140)
37
+ >>>
38
  >>> # Convert single PDF
39
  >>> images, texts = processor.process_pdf(Path("report.pdf"))
40
+ >>>
41
  >>> # Stream large PDFs
42
  >>> for images, texts in processor.stream_pdf(Path("large.pdf"), batch_size=10):
43
  ... # Process each batch
44
  ... pass
45
  """
46
+
47
  def __init__(
48
  self,
49
  dpi: int = 140,
 
53
  self.dpi = dpi
54
  self.output_format = output_format
55
  self.page_batch_size = page_batch_size
56
+
57
+ # PDF deps are optional: we only require them when calling PDF-specific methods.
58
+ # This keeps the class usable for helper utilities like `resize_for_colpali()`
59
+ # even in minimal installs.
60
+ self._pdf_deps_available = True
61
  try:
62
+ import pdf2image # noqa: F401
63
+ import pypdf # noqa: F401
64
+ except Exception:
65
+ self._pdf_deps_available = False
66
+
67
+ def _require_pdf_deps(self) -> None:
68
+ if not self._pdf_deps_available:
69
  raise ImportError(
70
+ "PDF processing requires `pdf2image` and `pypdf`.\n"
71
+ 'Install with: pip install "visual-rag-toolkit[pdf]"'
72
  )
73
+
74
  def process_pdf(
75
  self,
76
  pdf_path: Path,
 
78
  ) -> Tuple[List[Image.Image], List[str]]:
79
  """
80
  Convert PDF to images and extract text.
81
+
82
  Args:
83
  pdf_path: Path to PDF file
84
  dpi: Override default DPI
85
+
86
  Returns:
87
  Tuple of (list of images, list of page texts)
88
  """
89
+ self._require_pdf_deps()
90
  from pdf2image import convert_from_path
91
  from pypdf import PdfReader
92
+
93
  dpi = dpi or self.dpi
94
  pdf_path = Path(pdf_path)
95
+
96
  logger.info(f"📄 Processing PDF: {pdf_path.name}")
97
+
98
  # Extract text
99
  reader = PdfReader(str(pdf_path))
100
  total_pages = len(reader.pages)
101
+
102
  page_texts = []
103
  for page in reader.pages:
104
  text = page.extract_text() or ""
105
  # Handle surrogate characters
106
  text = self._sanitize_text(text)
107
  page_texts.append(text)
108
+
109
  # Convert to images in batches
110
  all_images = []
111
  for start_page in range(1, total_pages + 1, self.page_batch_size):
112
  end_page = min(start_page + self.page_batch_size - 1, total_pages)
113
+
114
  batch_images = convert_from_path(
115
  str(pdf_path),
116
  dpi=dpi,
 
118
  first_page=start_page,
119
  last_page=end_page,
120
  )
121
+
122
  all_images.extend(batch_images)
123
+
124
  del batch_images
125
  gc.collect()
126
+
127
+ assert len(all_images) == len(
128
+ page_texts
129
+ ), f"Mismatch: {len(all_images)} images vs {len(page_texts)} texts"
130
+
131
  logger.info(f"✅ Processed {len(all_images)} pages")
132
  return all_images, page_texts
133
+
134
  def stream_pdf(
135
  self,
136
  pdf_path: Path,
 
139
  ) -> Generator[Tuple[List[Image.Image], List[str], int], None, None]:
140
  """
141
  Stream PDF processing for large files.
142
+
143
  Yields batches of (images, texts, start_page) without loading
144
  entire PDF into memory.
145
+
146
  Args:
147
  pdf_path: Path to PDF file
148
  batch_size: Pages per batch
149
  dpi: Override default DPI
150
+
151
  Yields:
152
  Tuple of (batch_images, batch_texts, start_page_number)
153
  """
154
+ self._require_pdf_deps()
155
  from pdf2image import convert_from_path
156
  from pypdf import PdfReader
157
+
158
  dpi = dpi or self.dpi
159
  pdf_path = Path(pdf_path)
160
+
161
  reader = PdfReader(str(pdf_path))
162
  total_pages = len(reader.pages)
163
+
164
  logger.info(f"📄 Streaming PDF: {pdf_path.name} ({total_pages} pages)")
165
+
166
  for start_idx in range(0, total_pages, batch_size):
167
  end_idx = min(start_idx + batch_size, total_pages)
168
+
169
  # Extract text for batch
170
  batch_texts = []
171
  for page_idx in range(start_idx, end_idx):
172
  text = reader.pages[page_idx].extract_text() or ""
173
  text = self._sanitize_text(text)
174
  batch_texts.append(text)
175
+
176
  # Convert images for batch
177
  batch_images = convert_from_path(
178
  str(pdf_path),
 
181
  first_page=start_idx + 1, # 1-indexed
182
  last_page=end_idx,
183
  )
184
+
185
  yield batch_images, batch_texts, start_idx + 1
186
+
187
  del batch_images
188
  gc.collect()
189
+
190
  def get_page_count(self, pdf_path: Path) -> int:
191
  """Get number of pages in PDF without loading images."""
192
+ self._require_pdf_deps()
193
  from pypdf import PdfReader
194
+
195
  reader = PdfReader(str(pdf_path))
196
  return len(reader.pages)
197
+
198
  def resize_for_colpali(
199
  self,
200
  image: Image.Image,
 
203
  ) -> Tuple[Image.Image, int, int]:
204
  """
205
  Resize image following ColPali/Idefics3 processor logic.
206
+
207
  Resizes to fit within tile grid without black padding.
208
+
209
  Args:
210
  image: PIL Image
211
  max_edge: Maximum edge length
212
  tile_size: Size of each tile
213
+
214
  Returns:
215
  Tuple of (resized_image, tile_rows, tile_cols)
216
  """
217
+ # Ensure consistent mode for downstream processors (and predictable tests)
218
+ if image.mode != "RGB":
219
+ image = image.convert("RGB")
220
+
221
  w, h = image.size
222
+
223
  # Step 1: Resize so longest edge = max_edge
224
  if w > h:
225
  new_w = max_edge
 
227
  else:
228
  new_h = max_edge
229
  new_w = int(w * (max_edge / h))
230
+
231
  # Step 2: Calculate tile grid
232
  tile_cols = (new_w + tile_size - 1) // tile_size
233
  tile_rows = (new_h + tile_size - 1) // tile_size
234
+
235
  # Step 3: Calculate exact dimensions for tiles
236
  final_w = tile_cols * tile_size
237
  final_h = tile_rows * tile_size
238
+
239
  # Step 4: Scale to fit within tile grid
240
  scale_w = final_w / w
241
  scale_h = final_h / h
242
  scale = min(scale_w, scale_h)
243
+
244
  scaled_w = int(w * scale)
245
  scaled_h = int(h * scale)
246
+
247
  resized = image.resize((scaled_w, scaled_h), Image.LANCZOS)
248
+
249
  # Center on white canvas if needed
250
  if scaled_w != final_w or scaled_h != final_h:
251
  canvas = Image.new("RGB", (final_w, final_h), (255, 255, 255))
 
253
  offset_y = (final_h - scaled_h) // 2
254
  canvas.paste(resized, (offset_x, offset_y))
255
  resized = canvas
256
+
257
  return resized, tile_rows, tile_cols
258
+
259
  def _sanitize_text(self, text: str) -> str:
260
  """Remove invalid Unicode characters (surrogates) from text."""
261
  if not text:
262
  return ""
263
+
264
  # Remove surrogate characters (U+D800-U+DFFF)
265
+ return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
266
+
 
 
267
  def extract_metadata_from_filename(
268
  self,
269
  filename: str,
 
271
  ) -> Dict[str, Any]:
272
  """
273
  Extract metadata from PDF filename.
274
+
275
  Uses mapping if provided, otherwise falls back to pattern matching.
276
+
277
  Args:
278
  filename: PDF filename (with or without .pdf extension)
279
  mapping: Optional mapping dict {filename: metadata}
280
+
281
  Returns:
282
  Metadata dict with year, source, district, etc.
283
  """
284
  # Remove extension
285
  stem = Path(filename).stem
286
  stem_lower = stem.lower().strip()
287
+
288
  # Try mapping first
289
  if mapping:
290
  if stem_lower in mapping:
291
  return mapping[stem_lower].copy()
292
+
293
  # Try without .pdf
294
  stem_no_ext = stem_lower.replace(".pdf", "")
295
  if stem_no_ext in mapping:
296
  return mapping[stem_no_ext].copy()
297
+
298
  # Fallback: pattern matching
299
  metadata = {"filename": filename}
300
+
301
  # Extract year
302
  year_match = re.search(r"(20\d{2})", stem)
303
  if year_match:
304
  metadata["year"] = int(year_match.group(1))
305
+
306
  # Detect source type
307
  if "consolidated" in stem_lower or ("annual" in stem_lower and "oag" in stem_lower):
308
  metadata["source"] = "Consolidated"
309
  elif "dlg" in stem_lower or "district local government" in stem_lower:
310
  metadata["source"] = "Local Government"
311
  # Try to extract district name
312
+ district_match = re.search(r"([a-z]+)\s+(?:dlg|district local government)", stem_lower)
 
 
313
  if district_match:
314
  metadata["district"] = district_match.group(1).title()
315
  elif "hospital" in stem_lower or "referral" in stem_lower:
 
320
  metadata["source"] = "Project"
321
  else:
322
  metadata["source"] = "Unknown"
 
 
 
323
 
324
+ return metadata
visual_rag/indexing/pipeline.py CHANGED
@@ -16,11 +16,10 @@ The metadata stored includes everything needed for saliency visualization:
16
  """
17
 
18
  import gc
19
- import time
20
  import hashlib
21
  import logging
22
  from pathlib import Path
23
- from typing import Dict, Any, List, Optional, Set, Tuple
24
 
25
  import numpy as np
26
  import torch
@@ -31,7 +30,7 @@ logger = logging.getLogger(__name__)
31
  class ProcessingPipeline:
32
  """
33
  End-to-end pipeline for PDF processing and indexing.
34
-
35
  This pipeline:
36
  1. Converts PDFs to images
37
  2. Resizes for ColPali processing
@@ -39,7 +38,7 @@ class ProcessingPipeline:
39
  4. Computes pooling (strategy-dependent)
40
  5. Uploads images to Cloudinary (optional)
41
  6. Stores in Qdrant with full saliency metadata
42
-
43
  Args:
44
  embedder: VisualEmbedder instance
45
  indexer: QdrantIndexer instance (optional)
@@ -52,34 +51,34 @@ class ProcessingPipeline:
52
  This is our NOVEL contribution - preserves spatial structure while reducing size.
53
  - "standard": Push ALL tokens as-is (including special tokens, padding)
54
  This is the baseline approach for comparison.
55
-
56
  Example:
57
  >>> from visual_rag import VisualEmbedder, QdrantIndexer, CloudinaryUploader
58
  >>> from visual_rag.indexing.pipeline import ProcessingPipeline
59
- >>>
60
  >>> # Our novel pooling strategy (default)
61
  >>> pipeline = ProcessingPipeline(
62
  ... embedder=VisualEmbedder(),
63
  ... indexer=QdrantIndexer(url, api_key, "my_collection"),
64
  ... embedding_strategy="pooling", # Visual tokens only + tile pooling
65
  ... )
66
- >>>
67
  >>> # Standard baseline (all tokens, no filtering)
68
  >>> pipeline_baseline = ProcessingPipeline(
69
  ... embedder=VisualEmbedder(),
70
  ... indexer=QdrantIndexer(url, api_key, "my_collection_baseline"),
71
  ... embedding_strategy="standard", # All tokens as-is
72
  ... )
73
- >>>
74
  >>> pipeline.process_pdf(Path("report.pdf"))
75
  """
76
-
77
  # Valid embedding strategies
78
  # - "pooling": Visual tokens only + tile-level pooling (NOVEL)
79
  # - "standard": All tokens + global mean (BASELINE)
80
  # - "all": Embed once, push BOTH representations (efficient comparison)
81
  STRATEGIES = ["pooling", "standard", "all"]
82
-
83
  def __init__(
84
  self,
85
  embedder=None,
@@ -92,13 +91,15 @@ class ProcessingPipeline:
92
  crop_empty: bool = False,
93
  crop_empty_percentage_to_remove: float = 0.9,
94
  crop_empty_remove_page_number: bool = False,
 
 
95
  ):
96
  self.embedder = embedder
97
  self.indexer = indexer
98
  self.cloudinary_uploader = cloudinary_uploader
99
  self.metadata_mapping = metadata_mapping or {}
100
  self.config = config or {}
101
-
102
  # Validate and set embedding strategy
103
  if embedding_strategy not in self.STRATEGIES:
104
  raise ValueError(
@@ -110,41 +111,50 @@ class ProcessingPipeline:
110
  self.crop_empty = bool(crop_empty)
111
  self.crop_empty_percentage_to_remove = float(crop_empty_percentage_to_remove)
112
  self.crop_empty_remove_page_number = bool(crop_empty_remove_page_number)
113
-
 
 
 
 
114
  logger.info(f"📊 Embedding strategy: {embedding_strategy}")
115
  if embedding_strategy == "pooling":
116
  logger.info(" → Visual tokens only + tile-level mean pooling (NOVEL)")
117
  else:
118
  logger.info(" → All tokens as-is (BASELINE)")
119
-
120
  # Create PDF processor if not provided
121
  if pdf_processor is None:
122
  from visual_rag.indexing.pdf_processor import PDFProcessor
 
123
  dpi = self.config.get("processing", {}).get("dpi", 140)
124
  pdf_processor = PDFProcessor(dpi=dpi)
125
  self.pdf_processor = pdf_processor
126
-
127
  # Config defaults
128
  self.embedding_batch_size = self.config.get("batching", {}).get("embedding_batch_size", 8)
129
  self.upload_batch_size = self.config.get("batching", {}).get("upload_batch_size", 8)
130
  self.delay_between_uploads = self.config.get("delays", {}).get("between_uploads", 0.5)
131
-
132
  def process_pdf(
133
  self,
134
  pdf_path: Path,
135
  skip_existing: bool = True,
136
  upload_to_cloudinary: bool = True,
137
  upload_to_qdrant: bool = True,
 
 
138
  ) -> Dict[str, Any]:
139
  """
140
  Process a single PDF end-to-end.
141
-
142
  Args:
143
  pdf_path: Path to PDF file
144
  skip_existing: Skip pages that already exist in Qdrant
145
  upload_to_cloudinary: Upload images to Cloudinary
146
  upload_to_qdrant: Upload embeddings to Qdrant
147
-
 
 
148
  Returns:
149
  Dict with processing results:
150
  {
@@ -157,62 +167,73 @@ class ProcessingPipeline:
157
  }
158
  """
159
  pdf_path = Path(pdf_path)
160
- logger.info(f"📚 Processing PDF: {pdf_path.name}")
161
-
 
162
  # Check existing pages
163
  existing_ids: Set[str] = set()
164
  if skip_existing and self.indexer:
165
- existing_ids = self.indexer.get_existing_ids(pdf_path.name)
166
  if existing_ids:
167
  logger.info(f" Found {len(existing_ids)} existing pages")
168
-
169
- # Convert PDF to images
170
- logger.info(f"🖼️ Converting PDF to images...")
 
171
  images, texts = self.pdf_processor.process_pdf(pdf_path)
172
  total_pages = len(images)
173
  logger.info(f" ✅ Converted {total_pages} pages")
174
-
175
- # Get extra metadata
176
- extra_metadata = self._get_extra_metadata(pdf_path.name)
 
177
  if extra_metadata:
178
  logger.info(f" 📋 Found extra metadata: {list(extra_metadata.keys())}")
179
-
180
  # Process in batches
181
  uploaded = 0
182
  skipped = 0
183
  failed = 0
184
  all_pages = []
185
  upload_queue = []
186
-
187
  for batch_start in range(0, total_pages, self.embedding_batch_size):
188
  batch_end = min(batch_start + self.embedding_batch_size, total_pages)
189
  batch_images = images[batch_start:batch_end]
190
  batch_texts = texts[batch_start:batch_end]
191
-
192
  logger.info(f"📦 Processing pages {batch_start + 1}-{batch_end}/{total_pages}")
193
-
194
- # Filter pages that need processing
 
 
 
 
 
 
195
  pages_to_process = []
196
  for i, (img, text) in enumerate(zip(batch_images, batch_texts)):
197
  page_num = batch_start + i + 1
198
- chunk_id = self.generate_chunk_id(pdf_path.name, page_num)
199
-
200
  if skip_existing and chunk_id in existing_ids:
201
  skipped += 1
202
  continue
203
-
204
- pages_to_process.append({
205
- "index": i,
206
- "page_num": page_num,
207
- "chunk_id": chunk_id,
208
- "raw_image": img,
209
- "text": text,
210
- })
211
-
 
 
212
  if not pages_to_process:
213
  logger.info(" All pages in batch exist, skipping...")
214
  continue
215
-
216
  # Generate embeddings with token info
217
  logger.info(f"🤖 Generating embeddings for {len(pages_to_process)} pages...")
218
  from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
@@ -226,6 +247,10 @@ class ProcessingPipeline:
226
  config=CropEmptyConfig(
227
  percentage_to_remove=float(self.crop_empty_percentage_to_remove),
228
  remove_page_number=bool(self.crop_empty_remove_page_number),
 
 
 
 
229
  ),
230
  )
231
  p["embed_image"] = cropped_img
@@ -235,15 +260,14 @@ class ProcessingPipeline:
235
  p["embed_image"] = raw_img
236
  p["crop_meta"] = None
237
  images_to_embed.append(raw_img)
238
-
239
  embeddings, token_infos = self.embedder.embed_images(
240
  images_to_embed,
241
  batch_size=self.embedding_batch_size,
242
  return_token_info=True,
243
- show_progress=True,
244
  )
245
-
246
- # Process each page
247
  for idx, page_info in enumerate(pages_to_process):
248
  raw_img = page_info["raw_image"]
249
  embed_img = page_info["embed_image"]
@@ -253,10 +277,19 @@ class ProcessingPipeline:
253
  text = page_info["text"]
254
  embedding = embeddings[idx]
255
  token_info = token_infos[idx]
256
-
 
 
 
 
 
 
 
 
257
  try:
258
  page_data = self._process_single_page(
259
- pdf_path=pdf_path,
 
260
  page_num=page_num,
261
  chunk_id=chunk_id,
262
  total_pages=total_pages,
@@ -269,46 +302,49 @@ class ProcessingPipeline:
269
  upload_to_cloudinary=upload_to_cloudinary,
270
  crop_meta=crop_meta,
271
  )
272
-
273
  all_pages.append(page_data)
274
-
275
  if upload_to_qdrant and self.indexer:
276
  upload_queue.append(page_data)
277
-
278
  # Upload in batches
279
  if len(upload_queue) >= self.upload_batch_size:
280
  count = self._upload_batch(upload_queue)
281
  uploaded += count
282
  upload_queue = []
283
-
284
  except Exception as e:
285
  logger.error(f" ❌ Failed page {page_num}: {e}")
286
  failed += 1
287
-
288
  # Memory cleanup
289
  gc.collect()
290
  if torch.cuda.is_available():
291
  torch.cuda.empty_cache()
292
-
293
  # Upload remaining pages
294
  if upload_queue and upload_to_qdrant and self.indexer:
295
  count = self._upload_batch(upload_queue)
296
  uploaded += count
297
-
298
- logger.info(f"✅ Completed {pdf_path.name}: {uploaded} uploaded, {skipped} skipped, {failed} failed")
299
-
 
 
300
  return {
301
- "filename": pdf_path.name,
302
  "total_pages": total_pages,
303
  "uploaded": uploaded,
304
  "skipped": skipped,
305
  "failed": failed,
306
  "pages": all_pages,
307
  }
308
-
309
  def _process_single_page(
310
  self,
311
- pdf_path: Path,
 
312
  page_num: int,
313
  chunk_id: str,
314
  total_pages: int,
@@ -323,17 +359,17 @@ class ProcessingPipeline:
323
  ) -> Dict[str, Any]:
324
  """Process a single page with full metadata for saliency."""
325
  from visual_rag.embedding.pooling import global_mean_pooling
326
-
327
  # Resize image for ColPali
328
  resized_img, tile_rows, tile_cols = self.pdf_processor.resize_for_colpali(embed_img)
329
-
330
  # Use processor's tile info if available (more accurate)
331
  proc_n_rows = token_info.get("n_rows")
332
  proc_n_cols = token_info.get("n_cols")
333
  if proc_n_rows and proc_n_cols:
334
  tile_rows = proc_n_rows
335
  tile_cols = proc_n_cols
336
-
337
  # Convert embedding to numpy
338
  if isinstance(embedding, torch.Tensor):
339
  if embedding.dtype == torch.bfloat16:
@@ -343,24 +379,30 @@ class ProcessingPipeline:
343
  else:
344
  full_embedding = np.array(embedding)
345
  full_embedding = full_embedding.astype(np.float32)
346
-
347
  # Token info for metadata
348
  visual_indices = token_info["visual_token_indices"]
349
  num_visual_tokens = token_info["num_visual_tokens"]
350
-
351
  # =========================================================================
352
  # STRATEGY: "pooling" (NOVEL) vs "standard" (BASELINE) vs "all" (BOTH)
353
  # =========================================================================
354
-
355
  # Always compute visual-only embedding (needed for pooling and saliency)
356
  visual_embedding = full_embedding[visual_indices]
357
-
358
- tile_pooled = self.embedder.mean_pool_visual_embedding(visual_embedding, token_info, target_vectors=32)
 
 
359
  experimental_pooled = self.embedder.experimental_pool_visual_embedding(
360
  visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
361
  )
362
  global_pooled = global_mean_pooling(full_embedding)
363
- global_pooling = self.embedder.global_pool_from_mean_pool(tile_pooled) if tile_pooled.size else global_pooled
 
 
 
 
364
 
365
  num_tiles = int(tile_pooled.shape[0])
366
  patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0
@@ -369,64 +411,70 @@ class ProcessingPipeline:
369
  else:
370
  tile_rows = token_info.get("n_rows") or None
371
  tile_cols = token_info.get("n_cols") or None
372
-
373
  if self.embedding_strategy == "pooling":
374
  # NOVEL APPROACH: Visual tokens only + tile-level pooling
375
  embedding_for_initial = visual_embedding
376
  embedding_for_pooling = tile_pooled
377
- global_pooling = self.embedder.global_pool_from_mean_pool(tile_pooled) if tile_pooled.size else global_pooled
378
-
 
 
 
 
379
  elif self.embedding_strategy == "standard":
380
  # BASELINE: All tokens + global mean
381
  embedding_for_initial = full_embedding
382
  embedding_for_pooling = global_pooled.reshape(1, -1)
383
  global_pooling = global_pooled
384
-
385
  else: # "all" - Push BOTH representations (efficient for comparison)
386
  # Embed once, store multiple vector representations
387
  # This allows comparing both strategies without re-embedding
388
  embedding_for_initial = visual_embedding # Use visual for search
389
- embedding_for_pooling = tile_pooled # Use tile-level for fast prefetch
390
- global_pooling = self.embedder.global_pool_from_mean_pool(tile_pooled) if tile_pooled.size else global_pooled
391
-
 
 
 
 
392
  # ALSO store standard representations as additional vectors
393
  # These will be added to metadata for optional use
394
  pass # Extra vectors handled in return dict below
395
-
396
  # Upload to Cloudinary
397
  original_url = None
398
  cropped_url = None
399
  resized_url = None
400
-
401
  if upload_to_cloudinary and self.cloudinary_uploader:
402
- base_filename = f"{pdf_path.stem}_page_{page_num}"
403
  if self.crop_empty:
404
- original_url, cropped_url, resized_url = self.cloudinary_uploader.upload_original_cropped_and_resized(
405
- raw_img, embed_img, resized_img, base_filename
 
 
406
  )
407
  else:
408
  original_url, resized_url = self.cloudinary_uploader.upload_original_and_resized(
409
  raw_img, resized_img, base_filename
410
  )
411
-
412
  # Sanitize text
413
  safe_text = self._sanitize_text(text[:10000]) if text else ""
414
-
415
- # Build metadata (everything needed for saliency)
416
  metadata = {
417
- # Document info
418
- "filename": pdf_path.name,
419
  "page_number": page_num,
420
  "total_pages": total_pages,
421
  "has_text": bool(text and text.strip()),
422
  "text": safe_text,
423
-
424
  # Image URLs
425
  "page": resized_url or "", # For display
426
  "original_url": original_url or "",
427
  "cropped_url": cropped_url or "",
428
  "resized_url": resized_url or "",
429
-
430
  # Dimensions (needed for saliency overlay)
431
  "original_width": raw_img.width,
432
  "original_height": raw_img.height,
@@ -434,35 +482,33 @@ class ProcessingPipeline:
434
  "cropped_height": int(embed_img.height) if self.crop_empty else int(raw_img.height),
435
  "resized_width": resized_img.width,
436
  "resized_height": resized_img.height,
437
-
438
  # Tile structure (needed for saliency)
439
  "num_tiles": num_tiles,
440
  "tile_rows": tile_rows,
441
  "tile_cols": tile_cols,
442
  "patches_per_tile": patches_per_tile,
443
-
444
  # Token info (needed for saliency)
445
  "num_visual_tokens": num_visual_tokens,
446
  "visual_token_indices": visual_indices,
447
  "total_tokens": len(full_embedding), # Total tokens in raw embedding
448
-
449
  # Strategy used (important for paper comparison)
450
  "embedding_strategy": self.embedding_strategy,
451
-
452
  "model_name": getattr(self.embedder, "model_name", None),
453
-
454
  "crop_empty_enabled": bool(self.crop_empty),
455
  "crop_empty_crop_box": (crop_meta or {}).get("crop_box"),
456
  "crop_empty_remove_page_number": bool(self.crop_empty_remove_page_number),
457
  "crop_empty_percentage_to_remove": float(self.crop_empty_percentage_to_remove),
458
-
 
 
 
459
  # Extra metadata (year, district, etc.)
460
  **extra_metadata,
461
  }
462
-
463
  result = {
464
  "id": chunk_id,
465
- "visual_embedding": embedding_for_initial, # "initial" vector in Qdrant
466
  "tile_pooled_embedding": embedding_for_pooling, # "mean_pooling" vector in Qdrant
467
  "experimental_pooled_embedding": experimental_pooled, # "experimental_pooling" vector in Qdrant
468
  "global_pooled_embedding": global_pooling, # "global_pooling" vector in Qdrant
@@ -470,70 +516,70 @@ class ProcessingPipeline:
470
  "image": raw_img,
471
  "resized_image": resized_img,
472
  }
473
-
474
  # For "all" strategy, include BOTH representations for comparison
475
  if self.embedding_strategy == "all":
476
  result["extra_vectors"] = {
477
  # Standard baseline vectors (for comparison)
478
- "full_embedding": full_embedding, # All tokens [total, 128]
479
- "global_pooled": global_pooled, # Global mean [128]
480
  # Pooling vectors (already in main result)
481
- "visual_embedding": visual_embedding, # Visual only [visual, 128]
482
- "tile_pooled": tile_pooled, # Tile-level [tiles, 128]
483
  }
484
-
485
  return result
486
-
487
  def _upload_batch(self, upload_queue: List[Dict[str, Any]]) -> int:
488
  """Upload batch to Qdrant."""
489
  if not upload_queue or not self.indexer:
490
  return 0
491
-
492
  logger.info(f"📤 Uploading batch of {len(upload_queue)} pages...")
493
-
494
  count = self.indexer.upload_batch(
495
  upload_queue,
496
  delay_between_batches=self.delay_between_uploads,
497
  )
498
-
499
  return count
500
-
501
  def _get_extra_metadata(self, filename: str) -> Dict[str, Any]:
502
  """Get extra metadata for a filename."""
503
  if not self.metadata_mapping:
504
  return {}
505
-
506
  # Normalize filename
507
  filename_clean = filename.replace(".pdf", "").replace(".PDF", "").strip().lower()
508
-
509
  # Try exact match
510
  if filename_clean in self.metadata_mapping:
511
  return self.metadata_mapping[filename_clean].copy()
512
-
513
  # Try fuzzy match
514
  from difflib import SequenceMatcher
515
-
516
  best_match = None
517
  best_score = 0.0
518
-
519
  for known_filename, metadata in self.metadata_mapping.items():
520
  score = SequenceMatcher(None, filename_clean, known_filename.lower()).ratio()
521
  if score > best_score and score > 0.75:
522
  best_score = score
523
  best_match = metadata
524
-
525
  if best_match:
526
  logger.debug(f"Fuzzy matched '{filename}' with score {best_score:.2f}")
527
  return best_match.copy()
528
-
529
  return {}
530
-
531
  def _sanitize_text(self, text: str) -> str:
532
  """Remove invalid Unicode characters."""
533
  if not text:
534
  return ""
535
  return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
536
-
537
  @staticmethod
538
  def generate_chunk_id(filename: str, page_number: int) -> str:
539
  """Generate deterministic chunk ID."""
@@ -541,12 +587,12 @@ class ProcessingPipeline:
541
  hash_obj = hashlib.sha256(content.encode())
542
  hex_str = hash_obj.hexdigest()[:32]
543
  return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
544
-
545
  @staticmethod
546
  def load_metadata_mapping(json_path: Path) -> Dict[str, Dict[str, Any]]:
547
  """
548
  Load metadata mapping from JSON file.
549
-
550
  Expected format:
551
  {
552
  "filenames": {
@@ -554,7 +600,7 @@ class ProcessingPipeline:
554
  ...
555
  }
556
  }
557
-
558
  Or simple format:
559
  {
560
  "Report Name 2023": {"year": 2023, "source": "Local Government", ...},
@@ -562,22 +608,21 @@ class ProcessingPipeline:
562
  }
563
  """
564
  import json
565
-
566
  with open(json_path, "r") as f:
567
  data = json.load(f)
568
-
569
  # Check if nested under "filenames"
570
  if "filenames" in data and isinstance(data["filenames"], dict):
571
  mapping = data["filenames"]
572
  else:
573
  mapping = data
574
-
575
  # Normalize keys to lowercase
576
  normalized = {}
577
  for filename, metadata in mapping.items():
578
  key = filename.lower().strip().replace(".pdf", "")
579
  normalized[key] = metadata
580
-
581
  logger.info(f"📖 Loaded metadata for {len(normalized)} files")
582
  return normalized
583
-
 
16
  """
17
 
18
  import gc
 
19
  import hashlib
20
  import logging
21
  from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Set
23
 
24
  import numpy as np
25
  import torch
 
30
  class ProcessingPipeline:
31
  """
32
  End-to-end pipeline for PDF processing and indexing.
33
+
34
  This pipeline:
35
  1. Converts PDFs to images
36
  2. Resizes for ColPali processing
 
38
  4. Computes pooling (strategy-dependent)
39
  5. Uploads images to Cloudinary (optional)
40
  6. Stores in Qdrant with full saliency metadata
41
+
42
  Args:
43
  embedder: VisualEmbedder instance
44
  indexer: QdrantIndexer instance (optional)
 
51
  This is our NOVEL contribution - preserves spatial structure while reducing size.
52
  - "standard": Push ALL tokens as-is (including special tokens, padding)
53
  This is the baseline approach for comparison.
54
+
55
  Example:
56
  >>> from visual_rag import VisualEmbedder, QdrantIndexer, CloudinaryUploader
57
  >>> from visual_rag.indexing.pipeline import ProcessingPipeline
58
+ >>>
59
  >>> # Our novel pooling strategy (default)
60
  >>> pipeline = ProcessingPipeline(
61
  ... embedder=VisualEmbedder(),
62
  ... indexer=QdrantIndexer(url, api_key, "my_collection"),
63
  ... embedding_strategy="pooling", # Visual tokens only + tile pooling
64
  ... )
65
+ >>>
66
  >>> # Standard baseline (all tokens, no filtering)
67
  >>> pipeline_baseline = ProcessingPipeline(
68
  ... embedder=VisualEmbedder(),
69
  ... indexer=QdrantIndexer(url, api_key, "my_collection_baseline"),
70
  ... embedding_strategy="standard", # All tokens as-is
71
  ... )
72
+ >>>
73
  >>> pipeline.process_pdf(Path("report.pdf"))
74
  """
75
+
76
  # Valid embedding strategies
77
  # - "pooling": Visual tokens only + tile-level pooling (NOVEL)
78
  # - "standard": All tokens + global mean (BASELINE)
79
  # - "all": Embed once, push BOTH representations (efficient comparison)
80
  STRATEGIES = ["pooling", "standard", "all"]
81
+
82
  def __init__(
83
  self,
84
  embedder=None,
 
91
  crop_empty: bool = False,
92
  crop_empty_percentage_to_remove: float = 0.9,
93
  crop_empty_remove_page_number: bool = False,
94
+ crop_empty_preserve_border_px: int = 1,
95
+ crop_empty_uniform_rowcol_std_threshold: float = 0.0,
96
  ):
97
  self.embedder = embedder
98
  self.indexer = indexer
99
  self.cloudinary_uploader = cloudinary_uploader
100
  self.metadata_mapping = metadata_mapping or {}
101
  self.config = config or {}
102
+
103
  # Validate and set embedding strategy
104
  if embedding_strategy not in self.STRATEGIES:
105
  raise ValueError(
 
111
  self.crop_empty = bool(crop_empty)
112
  self.crop_empty_percentage_to_remove = float(crop_empty_percentage_to_remove)
113
  self.crop_empty_remove_page_number = bool(crop_empty_remove_page_number)
114
+ self.crop_empty_preserve_border_px = int(crop_empty_preserve_border_px)
115
+ self.crop_empty_uniform_rowcol_std_threshold = float(
116
+ crop_empty_uniform_rowcol_std_threshold
117
+ )
118
+
119
  logger.info(f"📊 Embedding strategy: {embedding_strategy}")
120
  if embedding_strategy == "pooling":
121
  logger.info(" → Visual tokens only + tile-level mean pooling (NOVEL)")
122
  else:
123
  logger.info(" → All tokens as-is (BASELINE)")
124
+
125
  # Create PDF processor if not provided
126
  if pdf_processor is None:
127
  from visual_rag.indexing.pdf_processor import PDFProcessor
128
+
129
  dpi = self.config.get("processing", {}).get("dpi", 140)
130
  pdf_processor = PDFProcessor(dpi=dpi)
131
  self.pdf_processor = pdf_processor
132
+
133
  # Config defaults
134
  self.embedding_batch_size = self.config.get("batching", {}).get("embedding_batch_size", 8)
135
  self.upload_batch_size = self.config.get("batching", {}).get("upload_batch_size", 8)
136
  self.delay_between_uploads = self.config.get("delays", {}).get("between_uploads", 0.5)
137
+
138
  def process_pdf(
139
  self,
140
  pdf_path: Path,
141
  skip_existing: bool = True,
142
  upload_to_cloudinary: bool = True,
143
  upload_to_qdrant: bool = True,
144
+ original_filename: Optional[str] = None,
145
+ progress_callback: Optional[callable] = None,
146
  ) -> Dict[str, Any]:
147
  """
148
  Process a single PDF end-to-end.
149
+
150
  Args:
151
  pdf_path: Path to PDF file
152
  skip_existing: Skip pages that already exist in Qdrant
153
  upload_to_cloudinary: Upload images to Cloudinary
154
  upload_to_qdrant: Upload embeddings to Qdrant
155
+ original_filename: Original filename (use this instead of pdf_path.name for temp files)
156
+ progress_callback: Optional callback(stage, current, total, message) for progress updates
157
+
158
  Returns:
159
  Dict with processing results:
160
  {
 
167
  }
168
  """
169
  pdf_path = Path(pdf_path)
170
+ filename = original_filename or pdf_path.name
171
+ logger.info(f"📚 Processing PDF: {filename}")
172
+
173
  # Check existing pages
174
  existing_ids: Set[str] = set()
175
  if skip_existing and self.indexer:
176
+ existing_ids = self.indexer.get_existing_ids(filename)
177
  if existing_ids:
178
  logger.info(f" Found {len(existing_ids)} existing pages")
179
+
180
+ logger.info("🖼️ Converting PDF to images...")
181
+ if progress_callback:
182
+ progress_callback("convert", 0, 0, "Converting PDF to images...")
183
  images, texts = self.pdf_processor.process_pdf(pdf_path)
184
  total_pages = len(images)
185
  logger.info(f" ✅ Converted {total_pages} pages")
186
+ if progress_callback:
187
+ progress_callback("convert", total_pages, total_pages, f"Converted {total_pages} pages")
188
+
189
+ extra_metadata = self._get_extra_metadata(filename)
190
  if extra_metadata:
191
  logger.info(f" 📋 Found extra metadata: {list(extra_metadata.keys())}")
192
+
193
  # Process in batches
194
  uploaded = 0
195
  skipped = 0
196
  failed = 0
197
  all_pages = []
198
  upload_queue = []
199
+
200
  for batch_start in range(0, total_pages, self.embedding_batch_size):
201
  batch_end = min(batch_start + self.embedding_batch_size, total_pages)
202
  batch_images = images[batch_start:batch_end]
203
  batch_texts = texts[batch_start:batch_end]
204
+
205
  logger.info(f"📦 Processing pages {batch_start + 1}-{batch_end}/{total_pages}")
206
+ if progress_callback:
207
+ progress_callback(
208
+ "embed",
209
+ batch_start,
210
+ total_pages,
211
+ f"Embedding pages {batch_start + 1}-{batch_end}",
212
+ )
213
+
214
  pages_to_process = []
215
  for i, (img, text) in enumerate(zip(batch_images, batch_texts)):
216
  page_num = batch_start + i + 1
217
+ chunk_id = self.generate_chunk_id(filename, page_num)
218
+
219
  if skip_existing and chunk_id in existing_ids:
220
  skipped += 1
221
  continue
222
+
223
+ pages_to_process.append(
224
+ {
225
+ "index": i,
226
+ "page_num": page_num,
227
+ "chunk_id": chunk_id,
228
+ "raw_image": img,
229
+ "text": text,
230
+ }
231
+ )
232
+
233
  if not pages_to_process:
234
  logger.info(" All pages in batch exist, skipping...")
235
  continue
236
+
237
  # Generate embeddings with token info
238
  logger.info(f"🤖 Generating embeddings for {len(pages_to_process)} pages...")
239
  from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
 
247
  config=CropEmptyConfig(
248
  percentage_to_remove=float(self.crop_empty_percentage_to_remove),
249
  remove_page_number=bool(self.crop_empty_remove_page_number),
250
+ preserve_border_px=int(self.crop_empty_preserve_border_px),
251
+ uniform_rowcol_std_threshold=float(
252
+ self.crop_empty_uniform_rowcol_std_threshold
253
+ ),
254
  ),
255
  )
256
  p["embed_image"] = cropped_img
 
260
  p["embed_image"] = raw_img
261
  p["crop_meta"] = None
262
  images_to_embed.append(raw_img)
263
+
264
  embeddings, token_infos = self.embedder.embed_images(
265
  images_to_embed,
266
  batch_size=self.embedding_batch_size,
267
  return_token_info=True,
268
+ show_progress=False,
269
  )
270
+
 
271
  for idx, page_info in enumerate(pages_to_process):
272
  raw_img = page_info["raw_image"]
273
  embed_img = page_info["embed_image"]
 
277
  text = page_info["text"]
278
  embedding = embeddings[idx]
279
  token_info = token_infos[idx]
280
+
281
+ if progress_callback:
282
+ progress_callback(
283
+ "process",
284
+ page_num,
285
+ total_pages,
286
+ f"Processing page {page_num}/{total_pages}",
287
+ )
288
+
289
  try:
290
  page_data = self._process_single_page(
291
+ filename=filename,
292
+ pdf_stem=pdf_path.stem,
293
  page_num=page_num,
294
  chunk_id=chunk_id,
295
  total_pages=total_pages,
 
302
  upload_to_cloudinary=upload_to_cloudinary,
303
  crop_meta=crop_meta,
304
  )
305
+
306
  all_pages.append(page_data)
307
+
308
  if upload_to_qdrant and self.indexer:
309
  upload_queue.append(page_data)
310
+
311
  # Upload in batches
312
  if len(upload_queue) >= self.upload_batch_size:
313
  count = self._upload_batch(upload_queue)
314
  uploaded += count
315
  upload_queue = []
316
+
317
  except Exception as e:
318
  logger.error(f" ❌ Failed page {page_num}: {e}")
319
  failed += 1
320
+
321
  # Memory cleanup
322
  gc.collect()
323
  if torch.cuda.is_available():
324
  torch.cuda.empty_cache()
325
+
326
  # Upload remaining pages
327
  if upload_queue and upload_to_qdrant and self.indexer:
328
  count = self._upload_batch(upload_queue)
329
  uploaded += count
330
+
331
+ logger.info(
332
+ f"✅ Completed {filename}: {uploaded} uploaded, {skipped} skipped, {failed} failed"
333
+ )
334
+
335
  return {
336
+ "filename": filename,
337
  "total_pages": total_pages,
338
  "uploaded": uploaded,
339
  "skipped": skipped,
340
  "failed": failed,
341
  "pages": all_pages,
342
  }
343
+
344
  def _process_single_page(
345
  self,
346
+ filename: str,
347
+ pdf_stem: str,
348
  page_num: int,
349
  chunk_id: str,
350
  total_pages: int,
 
359
  ) -> Dict[str, Any]:
360
  """Process a single page with full metadata for saliency."""
361
  from visual_rag.embedding.pooling import global_mean_pooling
362
+
363
  # Resize image for ColPali
364
  resized_img, tile_rows, tile_cols = self.pdf_processor.resize_for_colpali(embed_img)
365
+
366
  # Use processor's tile info if available (more accurate)
367
  proc_n_rows = token_info.get("n_rows")
368
  proc_n_cols = token_info.get("n_cols")
369
  if proc_n_rows and proc_n_cols:
370
  tile_rows = proc_n_rows
371
  tile_cols = proc_n_cols
372
+
373
  # Convert embedding to numpy
374
  if isinstance(embedding, torch.Tensor):
375
  if embedding.dtype == torch.bfloat16:
 
379
  else:
380
  full_embedding = np.array(embedding)
381
  full_embedding = full_embedding.astype(np.float32)
382
+
383
  # Token info for metadata
384
  visual_indices = token_info["visual_token_indices"]
385
  num_visual_tokens = token_info["num_visual_tokens"]
386
+
387
  # =========================================================================
388
  # STRATEGY: "pooling" (NOVEL) vs "standard" (BASELINE) vs "all" (BOTH)
389
  # =========================================================================
390
+
391
  # Always compute visual-only embedding (needed for pooling and saliency)
392
  visual_embedding = full_embedding[visual_indices]
393
+
394
+ tile_pooled = self.embedder.mean_pool_visual_embedding(
395
+ visual_embedding, token_info, target_vectors=32
396
+ )
397
  experimental_pooled = self.embedder.experimental_pool_visual_embedding(
398
  visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
399
  )
400
  global_pooled = global_mean_pooling(full_embedding)
401
+ global_pooling = (
402
+ self.embedder.global_pool_from_mean_pool(tile_pooled)
403
+ if tile_pooled.size
404
+ else global_pooled
405
+ )
406
 
407
  num_tiles = int(tile_pooled.shape[0])
408
  patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0
 
411
  else:
412
  tile_rows = token_info.get("n_rows") or None
413
  tile_cols = token_info.get("n_cols") or None
414
+
415
  if self.embedding_strategy == "pooling":
416
  # NOVEL APPROACH: Visual tokens only + tile-level pooling
417
  embedding_for_initial = visual_embedding
418
  embedding_for_pooling = tile_pooled
419
+ global_pooling = (
420
+ self.embedder.global_pool_from_mean_pool(tile_pooled)
421
+ if tile_pooled.size
422
+ else global_pooled
423
+ )
424
+
425
  elif self.embedding_strategy == "standard":
426
  # BASELINE: All tokens + global mean
427
  embedding_for_initial = full_embedding
428
  embedding_for_pooling = global_pooled.reshape(1, -1)
429
  global_pooling = global_pooled
430
+
431
  else: # "all" - Push BOTH representations (efficient for comparison)
432
  # Embed once, store multiple vector representations
433
  # This allows comparing both strategies without re-embedding
434
  embedding_for_initial = visual_embedding # Use visual for search
435
+ embedding_for_pooling = tile_pooled # Use tile-level for fast prefetch
436
+ global_pooling = (
437
+ self.embedder.global_pool_from_mean_pool(tile_pooled)
438
+ if tile_pooled.size
439
+ else global_pooled
440
+ )
441
+
442
  # ALSO store standard representations as additional vectors
443
  # These will be added to metadata for optional use
444
  pass # Extra vectors handled in return dict below
445
+
446
  # Upload to Cloudinary
447
  original_url = None
448
  cropped_url = None
449
  resized_url = None
450
+
451
  if upload_to_cloudinary and self.cloudinary_uploader:
452
+ base_filename = f"{pdf_stem}_page_{page_num}"
453
  if self.crop_empty:
454
+ original_url, cropped_url, resized_url = (
455
+ self.cloudinary_uploader.upload_original_cropped_and_resized(
456
+ raw_img, embed_img, resized_img, base_filename
457
+ )
458
  )
459
  else:
460
  original_url, resized_url = self.cloudinary_uploader.upload_original_and_resized(
461
  raw_img, resized_img, base_filename
462
  )
463
+
464
  # Sanitize text
465
  safe_text = self._sanitize_text(text[:10000]) if text else ""
466
+
 
467
  metadata = {
468
+ "filename": filename,
 
469
  "page_number": page_num,
470
  "total_pages": total_pages,
471
  "has_text": bool(text and text.strip()),
472
  "text": safe_text,
 
473
  # Image URLs
474
  "page": resized_url or "", # For display
475
  "original_url": original_url or "",
476
  "cropped_url": cropped_url or "",
477
  "resized_url": resized_url or "",
 
478
  # Dimensions (needed for saliency overlay)
479
  "original_width": raw_img.width,
480
  "original_height": raw_img.height,
 
482
  "cropped_height": int(embed_img.height) if self.crop_empty else int(raw_img.height),
483
  "resized_width": resized_img.width,
484
  "resized_height": resized_img.height,
 
485
  # Tile structure (needed for saliency)
486
  "num_tiles": num_tiles,
487
  "tile_rows": tile_rows,
488
  "tile_cols": tile_cols,
489
  "patches_per_tile": patches_per_tile,
 
490
  # Token info (needed for saliency)
491
  "num_visual_tokens": num_visual_tokens,
492
  "visual_token_indices": visual_indices,
493
  "total_tokens": len(full_embedding), # Total tokens in raw embedding
 
494
  # Strategy used (important for paper comparison)
495
  "embedding_strategy": self.embedding_strategy,
 
496
  "model_name": getattr(self.embedder, "model_name", None),
 
497
  "crop_empty_enabled": bool(self.crop_empty),
498
  "crop_empty_crop_box": (crop_meta or {}).get("crop_box"),
499
  "crop_empty_remove_page_number": bool(self.crop_empty_remove_page_number),
500
  "crop_empty_percentage_to_remove": float(self.crop_empty_percentage_to_remove),
501
+ "crop_empty_preserve_border_px": int(self.crop_empty_preserve_border_px),
502
+ "crop_empty_uniform_rowcol_std_threshold": float(
503
+ self.crop_empty_uniform_rowcol_std_threshold
504
+ ),
505
  # Extra metadata (year, district, etc.)
506
  **extra_metadata,
507
  }
508
+
509
  result = {
510
  "id": chunk_id,
511
+ "visual_embedding": embedding_for_initial, # "initial" vector in Qdrant
512
  "tile_pooled_embedding": embedding_for_pooling, # "mean_pooling" vector in Qdrant
513
  "experimental_pooled_embedding": experimental_pooled, # "experimental_pooling" vector in Qdrant
514
  "global_pooled_embedding": global_pooling, # "global_pooling" vector in Qdrant
 
516
  "image": raw_img,
517
  "resized_image": resized_img,
518
  }
519
+
520
  # For "all" strategy, include BOTH representations for comparison
521
  if self.embedding_strategy == "all":
522
  result["extra_vectors"] = {
523
  # Standard baseline vectors (for comparison)
524
+ "full_embedding": full_embedding, # All tokens [total, 128]
525
+ "global_pooled": global_pooled, # Global mean [128]
526
  # Pooling vectors (already in main result)
527
+ "visual_embedding": visual_embedding, # Visual only [visual, 128]
528
+ "tile_pooled": tile_pooled, # Tile-level [tiles, 128]
529
  }
530
+
531
  return result
532
+
533
  def _upload_batch(self, upload_queue: List[Dict[str, Any]]) -> int:
534
  """Upload batch to Qdrant."""
535
  if not upload_queue or not self.indexer:
536
  return 0
537
+
538
  logger.info(f"📤 Uploading batch of {len(upload_queue)} pages...")
539
+
540
  count = self.indexer.upload_batch(
541
  upload_queue,
542
  delay_between_batches=self.delay_between_uploads,
543
  )
544
+
545
  return count
546
+
547
  def _get_extra_metadata(self, filename: str) -> Dict[str, Any]:
548
  """Get extra metadata for a filename."""
549
  if not self.metadata_mapping:
550
  return {}
551
+
552
  # Normalize filename
553
  filename_clean = filename.replace(".pdf", "").replace(".PDF", "").strip().lower()
554
+
555
  # Try exact match
556
  if filename_clean in self.metadata_mapping:
557
  return self.metadata_mapping[filename_clean].copy()
558
+
559
  # Try fuzzy match
560
  from difflib import SequenceMatcher
561
+
562
  best_match = None
563
  best_score = 0.0
564
+
565
  for known_filename, metadata in self.metadata_mapping.items():
566
  score = SequenceMatcher(None, filename_clean, known_filename.lower()).ratio()
567
  if score > best_score and score > 0.75:
568
  best_score = score
569
  best_match = metadata
570
+
571
  if best_match:
572
  logger.debug(f"Fuzzy matched '{filename}' with score {best_score:.2f}")
573
  return best_match.copy()
574
+
575
  return {}
576
+
577
  def _sanitize_text(self, text: str) -> str:
578
  """Remove invalid Unicode characters."""
579
  if not text:
580
  return ""
581
  return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="ignore")
582
+
583
  @staticmethod
584
  def generate_chunk_id(filename: str, page_number: int) -> str:
585
  """Generate deterministic chunk ID."""
 
587
  hash_obj = hashlib.sha256(content.encode())
588
  hex_str = hash_obj.hexdigest()[:32]
589
  return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
590
+
591
  @staticmethod
592
  def load_metadata_mapping(json_path: Path) -> Dict[str, Dict[str, Any]]:
593
  """
594
  Load metadata mapping from JSON file.
595
+
596
  Expected format:
597
  {
598
  "filenames": {
 
600
  ...
601
  }
602
  }
603
+
604
  Or simple format:
605
  {
606
  "Report Name 2023": {"year": 2023, "source": "Local Government", ...},
 
608
  }
609
  """
610
  import json
611
+
612
  with open(json_path, "r") as f:
613
  data = json.load(f)
614
+
615
  # Check if nested under "filenames"
616
  if "filenames" in data and isinstance(data["filenames"], dict):
617
  mapping = data["filenames"]
618
  else:
619
  mapping = data
620
+
621
  # Normalize keys to lowercase
622
  normalized = {}
623
  for filename, metadata in mapping.items():
624
  key = filename.lower().strip().replace(".pdf", "")
625
  normalized[key] = metadata
626
+
627
  logger.info(f"📖 Loaded metadata for {len(normalized)} files")
628
  return normalized
 
visual_rag/indexing/qdrant_indexer.py CHANGED
@@ -11,43 +11,61 @@ Features:
11
  - Configurable payload indexes
12
  """
13
 
14
- import time
15
  import hashlib
16
  import logging
17
- from typing import List, Dict, Any, Optional, Set
 
18
  from urllib.parse import urlparse
 
19
  import numpy as np
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
24
  class QdrantIndexer:
25
  """
26
  Upload visual embeddings to Qdrant.
27
-
28
  Works independently - just needs embeddings and metadata.
29
-
30
  Args:
31
  url: Qdrant server URL
32
  api_key: Qdrant API key
33
  collection_name: Name of the collection
34
  timeout: Request timeout in seconds
35
  prefer_grpc: Use gRPC protocol (faster but may have issues)
36
-
37
  Example:
38
  >>> indexer = QdrantIndexer(
39
  ... url="https://your-cluster.qdrant.io:6333",
40
  ... api_key="your-api-key",
41
  ... collection_name="my_collection",
42
  ... )
43
- >>>
44
  >>> # Create collection
45
  >>> indexer.create_collection()
46
- >>>
47
  >>> # Upload points
48
  >>> indexer.upload_batch(points)
49
  """
50
-
51
  def __init__(
52
  self,
53
  url: str,
@@ -57,14 +75,12 @@ class QdrantIndexer:
57
  prefer_grpc: bool = False,
58
  vector_datatype: str = "float32",
59
  ):
60
- try:
61
- from qdrant_client import QdrantClient
62
- except ImportError:
63
  raise ImportError(
64
  "Qdrant client not installed. "
65
  "Install with: pip install visual-rag-toolkit[qdrant]"
66
  )
67
-
68
  self.collection_name = collection_name
69
  self.timeout = timeout
70
  if vector_datatype not in ("float32", "float16"):
@@ -81,7 +97,7 @@ class QdrantIndexer:
81
  grpc_port = 6334
82
  except Exception:
83
  grpc_port = None
84
-
85
  def _make_client(use_grpc: bool):
86
  return QdrantClient(
87
  url=url,
@@ -102,16 +118,16 @@ class QdrantIndexer:
102
  self.client = _make_client(False)
103
  else:
104
  raise
105
-
106
  logger.info(f"🔌 Connected to Qdrant: {url}")
107
  logger.info(f" Collection: {collection_name}")
108
  logger.info(f" Vector datatype: {self.vector_datatype}")
109
-
110
  def collection_exists(self) -> bool:
111
  """Check if collection exists."""
112
  collections = self.client.get_collections().collections
113
  return any(c.name == self.collection_name for c in collections)
114
-
115
  def create_collection(
116
  self,
117
  embedding_dim: int = 128,
@@ -122,32 +138,22 @@ class QdrantIndexer:
122
  ) -> bool:
123
  """
124
  Create collection with multi-vector support.
125
-
126
  Creates named vectors:
127
  - initial: Full multi-vector embeddings (num_patches × dim)
128
  - mean_pooling: Tile-level pooled vectors (num_tiles × dim)
129
  - experimental_pooling: Experimental multi-vector pooling (varies by model)
130
  - global_pooling: Single vector pooled representation (dim)
131
-
132
  Args:
133
  embedding_dim: Embedding dimension (128 for ColSmol)
134
  force_recreate: Delete and recreate if exists
135
  enable_quantization: Enable int8 quantization
136
  indexing_threshold: Qdrant optimizer indexing threshold (set 0 to always build ANN indexes)
137
-
138
  Returns:
139
  True if created, False if already existed
140
  """
141
- from qdrant_client.http import models
142
- from qdrant_client.http.models import (
143
- Distance,
144
- VectorParams,
145
- OptimizersConfigDiff,
146
- HnswConfigDiff,
147
- ScalarQuantizationConfig,
148
- ScalarType,
149
- )
150
-
151
  if self.collection_exists():
152
  if force_recreate:
153
  logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
@@ -155,120 +161,99 @@ class QdrantIndexer:
155
  else:
156
  logger.info(f"✅ Collection already exists: {self.collection_name}")
157
  return False
158
-
159
  logger.info(f"📦 Creating collection: {self.collection_name}")
160
-
161
  # Multi-vector config for ColBERT-style MaxSim
162
- multivector_config = models.MultiVectorConfig(
163
- comparator=models.MultiVectorComparator.MAX_SIM
164
  )
165
-
166
- # HNSW config for pooled vectors
167
- hnsw_config = HnswConfigDiff(
168
- m=32,
169
- ef_construct=100,
170
- full_scan_threshold=int(full_scan_threshold),
171
- on_disk=True,
172
  )
173
-
174
- # Optional quantization
175
- quantization_config = None
176
- if enable_quantization:
177
- logger.info(" Quantization: ENABLED (int8)")
178
- quantization_config = ScalarQuantizationConfig(
179
- type=ScalarType.INT8,
180
- quantile=0.99,
181
- always_ram=True,
182
- )
183
-
184
- # Vector configs
185
- datatype = models.Datatype.FLOAT16 if self.vector_datatype == "float16" else models.Datatype.FLOAT32
186
  vectors_config = {
187
  "initial": VectorParams(
188
  size=embedding_dim,
189
  distance=Distance.COSINE,
190
  on_disk=True,
191
  multivector_config=multivector_config,
192
- hnsw_config=hnsw_config,
193
  datatype=datatype,
194
- quantization_config=quantization_config,
195
  ),
196
  "mean_pooling": VectorParams(
197
  size=embedding_dim,
198
  distance=Distance.COSINE,
199
- on_disk=False, # Keep in RAM for fast prefetch
200
  multivector_config=multivector_config,
201
- hnsw_config=hnsw_config,
202
  datatype=datatype,
203
- quantization_config=quantization_config,
204
  ),
205
  "experimental_pooling": VectorParams(
206
  size=embedding_dim,
207
  distance=Distance.COSINE,
208
- on_disk=False, # Keep in RAM for fast prefetch
209
  multivector_config=multivector_config,
210
- hnsw_config=hnsw_config,
211
  datatype=datatype,
212
- quantization_config=quantization_config,
213
  ),
214
  "global_pooling": VectorParams(
215
  size=embedding_dim,
216
  distance=Distance.COSINE,
217
- on_disk=False, # Keep in RAM for fast prefetch
218
- hnsw_config=hnsw_config,
219
  datatype=datatype,
220
- quantization_config=quantization_config,
221
  ),
222
  }
223
-
224
- # Optimizer config for low-RAM clusters
225
- optimizer_config = OptimizersConfigDiff(
226
- indexing_threshold=int(indexing_threshold),
227
- memmap_threshold=0, # Use mmap immediately
228
- flush_interval_sec=5, # Flush WAL frequently
229
- )
230
-
231
  self.client.create_collection(
232
  collection_name=self.collection_name,
233
  vectors_config=vectors_config,
234
- optimizers_config=optimizer_config,
235
- hnsw_config=hnsw_config,
236
  )
237
-
 
 
 
 
 
 
 
 
 
 
 
 
238
  logger.info(f"✅ Collection created: {self.collection_name}")
239
  return True
240
-
241
  def create_payload_indexes(
242
  self,
243
  fields: Optional[List[Dict[str, str]]] = None,
244
  ):
245
  """
246
  Create payload indexes for filtering.
247
-
248
  Args:
249
  fields: List of {field, type} dicts
250
  type can be: integer, keyword, bool, float, text
251
  """
252
- from qdrant_client.http import models
253
-
254
  type_mapping = {
255
- "integer": models.PayloadSchemaType.INTEGER,
256
- "keyword": models.PayloadSchemaType.KEYWORD,
257
- "bool": models.PayloadSchemaType.BOOL,
258
- "float": models.PayloadSchemaType.FLOAT,
259
- "text": models.PayloadSchemaType.TEXT,
260
  }
261
-
262
  if not fields:
263
  return
264
-
265
  logger.info("📇 Creating payload indexes...")
266
-
267
  for field_config in fields:
268
  field_name = field_config["field"]
269
  field_type_str = field_config.get("type", "keyword")
270
- field_type = type_mapping.get(field_type_str, models.PayloadSchemaType.KEYWORD)
271
-
272
  try:
273
  self.client.create_payload_index(
274
  collection_name=self.collection_name,
@@ -278,7 +263,7 @@ class QdrantIndexer:
278
  logger.info(f" ✅ {field_name} ({field_type_str})")
279
  except Exception as e:
280
  logger.debug(f" Index {field_name} might already exist: {e}")
281
-
282
  def upload_batch(
283
  self,
284
  points: List[Dict[str, Any]],
@@ -289,7 +274,7 @@ class QdrantIndexer:
289
  ) -> int:
290
  """
291
  Upload a batch of points to Qdrant.
292
-
293
  Each point should have:
294
  - id: Unique point ID (string or UUID)
295
  - visual_embedding: Full embedding [num_patches, dim]
@@ -297,28 +282,28 @@ class QdrantIndexer:
297
  - experimental_pooled_embedding: Experimental pooled embedding [*, dim]
298
  - global_pooled_embedding: Pooled embedding [dim]
299
  - metadata: Payload dict
300
-
301
  Args:
302
  points: List of point dicts
303
  max_retries: Retry attempts on failure
304
  delay_between_batches: Delay after upload
305
  wait: Wait for operation to complete on Qdrant server
306
  stop_event: Optional threading.Event used to cancel uploads early
307
-
308
  Returns:
309
  Number of successfully uploaded points
310
  """
311
- from qdrant_client.http import models
312
-
313
  if not points:
314
  return 0
315
 
316
  def _is_cancelled() -> bool:
317
  return stop_event is not None and getattr(stop_event, "is_set", lambda: False)()
318
-
319
  def _is_payload_too_large_error(e: Exception) -> bool:
320
  msg = str(e)
321
- if ("JSON payload" in msg and "larger than allowed" in msg) or ("Payload error:" in msg and "limit:" in msg):
 
 
322
  return True
323
  content = getattr(e, "content", None)
324
  if content is not None:
@@ -329,7 +314,9 @@ class QdrantIndexer:
329
  text = str(content)
330
  except Exception:
331
  text = ""
332
- if ("JSON payload" in text and "larger than allowed" in text) or ("Payload error" in text and "limit" in text):
 
 
333
  return True
334
  resp = getattr(e, "response", None)
335
  if resp is not None:
@@ -337,7 +324,9 @@ class QdrantIndexer:
337
  text = str(getattr(resp, "text", "") or "")
338
  except Exception:
339
  text = ""
340
- if ("JSON payload" in text and "larger than allowed" in text) or ("Payload error" in text and "limit" in text):
 
 
341
  return True
342
  return False
343
 
@@ -346,8 +335,8 @@ class QdrantIndexer:
346
  return val.tolist()
347
  return val
348
 
349
- def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[models.PointStruct]:
350
- qdrant_points: List[models.PointStruct] = []
351
  for p in batch_points:
352
  global_pooled = p.get("global_pooled_embedding")
353
  if global_pooled is None:
@@ -355,15 +344,19 @@ class QdrantIndexer:
355
  global_pooled = tile_pooled.mean(axis=0)
356
  global_pooled = np.array(global_pooled, dtype=np.float32).reshape(-1)
357
 
358
- initial = np.array(p["visual_embedding"], dtype=np.float32).astype(self._np_vector_dtype, copy=False)
359
- mean_pooling = np.array(p["tile_pooled_embedding"], dtype=np.float32).astype(self._np_vector_dtype, copy=False)
360
- experimental_pooling = np.array(p["experimental_pooled_embedding"], dtype=np.float32).astype(
 
361
  self._np_vector_dtype, copy=False
362
  )
 
 
 
363
  global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
364
 
365
  qdrant_points.append(
366
- models.PointStruct(
367
  id=p["id"],
368
  vector={
369
  "initial": _to_list(initial),
@@ -375,7 +368,7 @@ class QdrantIndexer:
375
  )
376
  )
377
  return qdrant_points
378
-
379
  # Upload with retry
380
  for attempt in range(max_retries):
381
  try:
@@ -421,11 +414,11 @@ class QdrantIndexer:
421
  if attempt < max_retries - 1:
422
  if _is_cancelled():
423
  return 0
424
- time.sleep(2 ** attempt) # Exponential backoff
425
-
426
  logger.error(f"❌ Upload failed after {max_retries} attempts")
427
  return 0
428
-
429
  def check_exists(self, chunk_id: str) -> bool:
430
  """Check if a point already exists."""
431
  try:
@@ -438,50 +431,78 @@ class QdrantIndexer:
438
  return len(result) > 0
439
  except Exception:
440
  return False
441
-
442
  def get_existing_ids(self, filename: str) -> Set[str]:
443
- """Get all point IDs for a specific file."""
444
- from qdrant_client.models import Filter, FieldCondition, MatchValue
445
-
 
 
446
  existing_ids = set()
447
  offset = None
448
-
449
- while True:
450
- results = self.client.scroll(
451
- collection_name=self.collection_name,
452
- scroll_filter=Filter(
453
- must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
454
- ),
455
- limit=100,
456
- offset=offset,
457
- with_payload=["page_number"],
458
- with_vectors=False,
459
- )
460
-
461
- points, next_offset = results
462
-
463
- for point in points:
464
- existing_ids.add(str(point.id))
465
-
466
- if next_offset is None or len(points) == 0:
467
- break
468
- offset = next_offset
469
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  return existing_ids
471
-
472
  def get_collection_info(self) -> Optional[Dict[str, Any]]:
473
  """Get collection statistics."""
474
  try:
475
  info = self.client.get_collection(self.collection_name)
476
-
477
  status = info.status
478
  if hasattr(status, "value"):
479
  status = status.value
480
-
481
  indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
482
  if isinstance(indexed_count, dict):
483
  indexed_count = sum(indexed_count.values())
484
-
485
  return {
486
  "status": str(status),
487
  "points_count": getattr(info, "points_count", 0),
@@ -490,12 +511,12 @@ class QdrantIndexer:
490
  except Exception as e:
491
  logger.warning(f"Could not get collection info: {e}")
492
  return None
493
-
494
  @staticmethod
495
  def generate_point_id(filename: str, page_number: int) -> str:
496
  """
497
  Generate deterministic point ID from filename and page.
498
-
499
  Returns a valid UUID string.
500
  """
501
  content = f"{filename}:page:{page_number}"
@@ -503,5 +524,3 @@ class QdrantIndexer:
503
  hex_str = hash_obj.hexdigest()[:32]
504
  # Format as UUID
505
  return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
506
-
507
-
 
11
  - Configurable payload indexes
12
  """
13
 
 
14
  import hashlib
15
  import logging
16
+ import time
17
+ from typing import Any, Dict, List, Optional, Set
18
  from urllib.parse import urlparse
19
+
20
  import numpy as np
21
 
22
+ try:
23
+ from qdrant_client import QdrantClient
24
+ from qdrant_client.http import models as qdrant_models
25
+ from qdrant_client.http.models import Distance, VectorParams
26
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
27
+
28
+ QDRANT_AVAILABLE = True
29
+ except ImportError:
30
+ QDRANT_AVAILABLE = False
31
+ QdrantClient = None
32
+ qdrant_models = None
33
+ Distance = None
34
+ VectorParams = None
35
+ FieldCondition = None
36
+ Filter = None
37
+ MatchValue = None
38
+
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
  class QdrantIndexer:
43
  """
44
  Upload visual embeddings to Qdrant.
45
+
46
  Works independently - just needs embeddings and metadata.
47
+
48
  Args:
49
  url: Qdrant server URL
50
  api_key: Qdrant API key
51
  collection_name: Name of the collection
52
  timeout: Request timeout in seconds
53
  prefer_grpc: Use gRPC protocol (faster but may have issues)
54
+
55
  Example:
56
  >>> indexer = QdrantIndexer(
57
  ... url="https://your-cluster.qdrant.io:6333",
58
  ... api_key="your-api-key",
59
  ... collection_name="my_collection",
60
  ... )
61
+ >>>
62
  >>> # Create collection
63
  >>> indexer.create_collection()
64
+ >>>
65
  >>> # Upload points
66
  >>> indexer.upload_batch(points)
67
  """
68
+
69
  def __init__(
70
  self,
71
  url: str,
 
75
  prefer_grpc: bool = False,
76
  vector_datatype: str = "float32",
77
  ):
78
+ if not QDRANT_AVAILABLE:
 
 
79
  raise ImportError(
80
  "Qdrant client not installed. "
81
  "Install with: pip install visual-rag-toolkit[qdrant]"
82
  )
83
+
84
  self.collection_name = collection_name
85
  self.timeout = timeout
86
  if vector_datatype not in ("float32", "float16"):
 
97
  grpc_port = 6334
98
  except Exception:
99
  grpc_port = None
100
+
101
  def _make_client(use_grpc: bool):
102
  return QdrantClient(
103
  url=url,
 
118
  self.client = _make_client(False)
119
  else:
120
  raise
121
+
122
  logger.info(f"🔌 Connected to Qdrant: {url}")
123
  logger.info(f" Collection: {collection_name}")
124
  logger.info(f" Vector datatype: {self.vector_datatype}")
125
+
126
  def collection_exists(self) -> bool:
127
  """Check if collection exists."""
128
  collections = self.client.get_collections().collections
129
  return any(c.name == self.collection_name for c in collections)
130
+
131
  def create_collection(
132
  self,
133
  embedding_dim: int = 128,
 
138
  ) -> bool:
139
  """
140
  Create collection with multi-vector support.
141
+
142
  Creates named vectors:
143
  - initial: Full multi-vector embeddings (num_patches × dim)
144
  - mean_pooling: Tile-level pooled vectors (num_tiles × dim)
145
  - experimental_pooling: Experimental multi-vector pooling (varies by model)
146
  - global_pooling: Single vector pooled representation (dim)
147
+
148
  Args:
149
  embedding_dim: Embedding dimension (128 for ColSmol)
150
  force_recreate: Delete and recreate if exists
151
  enable_quantization: Enable int8 quantization
152
  indexing_threshold: Qdrant optimizer indexing threshold (set 0 to always build ANN indexes)
153
+
154
  Returns:
155
  True if created, False if already existed
156
  """
 
 
 
 
 
 
 
 
 
 
157
  if self.collection_exists():
158
  if force_recreate:
159
  logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
 
161
  else:
162
  logger.info(f"✅ Collection already exists: {self.collection_name}")
163
  return False
164
+
165
  logger.info(f"📦 Creating collection: {self.collection_name}")
166
+
167
  # Multi-vector config for ColBERT-style MaxSim
168
+ multivector_config = qdrant_models.MultiVectorConfig(
169
+ comparator=qdrant_models.MultiVectorComparator.MAX_SIM
170
  )
171
+
172
+ # Vector configs - simplified for compatibility
173
+ datatype = (
174
+ qdrant_models.Datatype.FLOAT16
175
+ if self.vector_datatype == "float16"
176
+ else qdrant_models.Datatype.FLOAT32
 
177
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  vectors_config = {
179
  "initial": VectorParams(
180
  size=embedding_dim,
181
  distance=Distance.COSINE,
182
  on_disk=True,
183
  multivector_config=multivector_config,
 
184
  datatype=datatype,
 
185
  ),
186
  "mean_pooling": VectorParams(
187
  size=embedding_dim,
188
  distance=Distance.COSINE,
189
+ on_disk=False,
190
  multivector_config=multivector_config,
 
191
  datatype=datatype,
 
192
  ),
193
  "experimental_pooling": VectorParams(
194
  size=embedding_dim,
195
  distance=Distance.COSINE,
196
+ on_disk=False,
197
  multivector_config=multivector_config,
 
198
  datatype=datatype,
 
199
  ),
200
  "global_pooling": VectorParams(
201
  size=embedding_dim,
202
  distance=Distance.COSINE,
203
+ on_disk=False,
 
204
  datatype=datatype,
 
205
  ),
206
  }
207
+
 
 
 
 
 
 
 
208
  self.client.create_collection(
209
  collection_name=self.collection_name,
210
  vectors_config=vectors_config,
 
 
211
  )
212
+
213
+ # Create required payload index for skip_existing functionality
214
+ # This index is needed for filtering by filename when checking existing docs
215
+ try:
216
+ self.client.create_payload_index(
217
+ collection_name=self.collection_name,
218
+ field_name="filename",
219
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
220
+ )
221
+ logger.info(" 📇 Created payload index: filename")
222
+ except Exception as e:
223
+ logger.warning(f" ⚠️ Could not create filename index: {e}")
224
+
225
  logger.info(f"✅ Collection created: {self.collection_name}")
226
  return True
227
+
228
  def create_payload_indexes(
229
  self,
230
  fields: Optional[List[Dict[str, str]]] = None,
231
  ):
232
  """
233
  Create payload indexes for filtering.
234
+
235
  Args:
236
  fields: List of {field, type} dicts
237
  type can be: integer, keyword, bool, float, text
238
  """
 
 
239
  type_mapping = {
240
+ "integer": qdrant_models.PayloadSchemaType.INTEGER,
241
+ "keyword": qdrant_models.PayloadSchemaType.KEYWORD,
242
+ "bool": qdrant_models.PayloadSchemaType.BOOL,
243
+ "float": qdrant_models.PayloadSchemaType.FLOAT,
244
+ "text": qdrant_models.PayloadSchemaType.TEXT,
245
  }
246
+
247
  if not fields:
248
  return
249
+
250
  logger.info("📇 Creating payload indexes...")
251
+
252
  for field_config in fields:
253
  field_name = field_config["field"]
254
  field_type_str = field_config.get("type", "keyword")
255
+ field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
256
+
257
  try:
258
  self.client.create_payload_index(
259
  collection_name=self.collection_name,
 
263
  logger.info(f" ✅ {field_name} ({field_type_str})")
264
  except Exception as e:
265
  logger.debug(f" Index {field_name} might already exist: {e}")
266
+
267
  def upload_batch(
268
  self,
269
  points: List[Dict[str, Any]],
 
274
  ) -> int:
275
  """
276
  Upload a batch of points to Qdrant.
277
+
278
  Each point should have:
279
  - id: Unique point ID (string or UUID)
280
  - visual_embedding: Full embedding [num_patches, dim]
 
282
  - experimental_pooled_embedding: Experimental pooled embedding [*, dim]
283
  - global_pooled_embedding: Pooled embedding [dim]
284
  - metadata: Payload dict
285
+
286
  Args:
287
  points: List of point dicts
288
  max_retries: Retry attempts on failure
289
  delay_between_batches: Delay after upload
290
  wait: Wait for operation to complete on Qdrant server
291
  stop_event: Optional threading.Event used to cancel uploads early
292
+
293
  Returns:
294
  Number of successfully uploaded points
295
  """
 
 
296
  if not points:
297
  return 0
298
 
299
  def _is_cancelled() -> bool:
300
  return stop_event is not None and getattr(stop_event, "is_set", lambda: False)()
301
+
302
  def _is_payload_too_large_error(e: Exception) -> bool:
303
  msg = str(e)
304
+ if ("JSON payload" in msg and "larger than allowed" in msg) or (
305
+ "Payload error:" in msg and "limit:" in msg
306
+ ):
307
  return True
308
  content = getattr(e, "content", None)
309
  if content is not None:
 
314
  text = str(content)
315
  except Exception:
316
  text = ""
317
+ if ("JSON payload" in text and "larger than allowed" in text) or (
318
+ "Payload error" in text and "limit" in text
319
+ ):
320
  return True
321
  resp = getattr(e, "response", None)
322
  if resp is not None:
 
324
  text = str(getattr(resp, "text", "") or "")
325
  except Exception:
326
  text = ""
327
+ if ("JSON payload" in text and "larger than allowed" in text) or (
328
+ "Payload error" in text and "limit" in text
329
+ ):
330
  return True
331
  return False
332
 
 
335
  return val.tolist()
336
  return val
337
 
338
+ def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[qdrant_models.PointStruct]:
339
+ qdrant_points: List[qdrant_models.PointStruct] = []
340
  for p in batch_points:
341
  global_pooled = p.get("global_pooled_embedding")
342
  if global_pooled is None:
 
344
  global_pooled = tile_pooled.mean(axis=0)
345
  global_pooled = np.array(global_pooled, dtype=np.float32).reshape(-1)
346
 
347
+ initial = np.array(p["visual_embedding"], dtype=np.float32).astype(
348
+ self._np_vector_dtype, copy=False
349
+ )
350
+ mean_pooling = np.array(p["tile_pooled_embedding"], dtype=np.float32).astype(
351
  self._np_vector_dtype, copy=False
352
  )
353
+ experimental_pooling = np.array(
354
+ p["experimental_pooled_embedding"], dtype=np.float32
355
+ ).astype(self._np_vector_dtype, copy=False)
356
  global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
357
 
358
  qdrant_points.append(
359
+ qdrant_models.PointStruct(
360
  id=p["id"],
361
  vector={
362
  "initial": _to_list(initial),
 
368
  )
369
  )
370
  return qdrant_points
371
+
372
  # Upload with retry
373
  for attempt in range(max_retries):
374
  try:
 
414
  if attempt < max_retries - 1:
415
  if _is_cancelled():
416
  return 0
417
+ time.sleep(2**attempt) # Exponential backoff
418
+
419
  logger.error(f"❌ Upload failed after {max_retries} attempts")
420
  return 0
421
+
422
  def check_exists(self, chunk_id: str) -> bool:
423
  """Check if a point already exists."""
424
  try:
 
431
  return len(result) > 0
432
  except Exception:
433
  return False
434
+
435
  def get_existing_ids(self, filename: str) -> Set[str]:
436
+ """Get all point IDs for a specific file.
437
+
438
+ Requires a payload index on 'filename' field. If the index doesn't exist,
439
+ this method will attempt to create it automatically.
440
+ """
441
  existing_ids = set()
442
  offset = None
443
+
444
+ try:
445
+ while True:
446
+ results = self.client.scroll(
447
+ collection_name=self.collection_name,
448
+ scroll_filter=Filter(
449
+ must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
450
+ ),
451
+ limit=100,
452
+ offset=offset,
453
+ with_payload=["page_number"],
454
+ with_vectors=False,
455
+ )
456
+
457
+ points, next_offset = results
458
+
459
+ for point in points:
460
+ existing_ids.add(str(point.id))
461
+
462
+ if next_offset is None or len(points) == 0:
463
+ break
464
+ offset = next_offset
465
+
466
+ except Exception as e:
467
+ error_msg = str(e).lower()
468
+ if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
469
+ # Missing payload index - try to create it
470
+ logger.warning(
471
+ "⚠️ Missing 'filename' payload index. Creating it now... "
472
+ "(skip_existing requires this index for filtering)"
473
+ )
474
+ try:
475
+ self.client.create_payload_index(
476
+ collection_name=self.collection_name,
477
+ field_name="filename",
478
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
479
+ )
480
+ logger.info(" ✅ Created 'filename' index. Retrying query...")
481
+ # Retry the query
482
+ return self.get_existing_ids(filename)
483
+ except Exception as idx_err:
484
+ logger.warning(f" ❌ Could not create index: {idx_err}")
485
+ logger.warning(" Returning empty set - all pages will be processed")
486
+ return set()
487
+ else:
488
+ logger.warning(f"⚠️ Error checking existing IDs: {e}")
489
+ return set()
490
+
491
  return existing_ids
492
+
493
  def get_collection_info(self) -> Optional[Dict[str, Any]]:
494
  """Get collection statistics."""
495
  try:
496
  info = self.client.get_collection(self.collection_name)
497
+
498
  status = info.status
499
  if hasattr(status, "value"):
500
  status = status.value
501
+
502
  indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
503
  if isinstance(indexed_count, dict):
504
  indexed_count = sum(indexed_count.values())
505
+
506
  return {
507
  "status": str(status),
508
  "points_count": getattr(info, "points_count", 0),
 
511
  except Exception as e:
512
  logger.warning(f"Could not get collection info: {e}")
513
  return None
514
+
515
  @staticmethod
516
  def generate_point_id(filename: str, page_number: int) -> str:
517
  """
518
  Generate deterministic point ID from filename and page.
519
+
520
  Returns a valid UUID string.
521
  """
522
  content = f"{filename}:page:{page_number}"
 
524
  hex_str = hash_obj.hexdigest()[:32]
525
  # Format as UUID
526
  return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
 
 
visual_rag/preprocessing/__init__.py CHANGED
@@ -1,5 +1,3 @@
1
  from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
2
 
3
  __all__ = ["CropEmptyConfig", "crop_empty"]
4
-
5
-
 
1
  from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty
2
 
3
  __all__ = ["CropEmptyConfig", "crop_empty"]
 
 
visual_rag/preprocessing/crop_empty.py CHANGED
@@ -20,7 +20,9 @@ class CropEmptyConfig:
20
  uniform_rowcol_std_threshold: float = 0.0
21
 
22
 
23
- def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Image, Dict[str, Any]]:
 
 
24
  img = image.convert("RGB")
25
  arr = np.array(img)
26
  intensity = arr.mean(axis=2)
@@ -31,7 +33,9 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
31
  pixels = intensity[i, :] if axis == 0 else intensity[:, i]
32
  white = float(np.mean(pixels > config.color_threshold))
33
  non_white = 1.0 - white
34
- if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(config.uniform_rowcol_std_threshold):
 
 
35
  continue
36
  if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
37
  return int(i)
@@ -43,7 +47,9 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
43
  pixels = intensity[i, :] if axis == 0 else intensity[:, i]
44
  white = float(np.mean(pixels > config.color_threshold))
45
  non_white = 1.0 - white
46
- if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(config.uniform_rowcol_std_threshold):
 
 
47
  continue
48
  if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
49
  return int(i + 1)
@@ -53,8 +59,12 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
53
  left = _find_border_start(1, min_content_density_threshold=float(config.content_density_sides))
54
  right = _find_border_end(1, min_content_density_threshold=float(config.content_density_sides))
55
 
56
- main_text_end = _find_border_end(0, min_content_density_threshold=float(config.content_density_main_text))
57
- last_content_end = _find_border_end(0, min_content_density_threshold=float(config.content_density_any))
 
 
 
 
58
  bottom = main_text_end if config.remove_page_number else last_content_end
59
 
60
  width, height = img.size
@@ -108,5 +118,3 @@ def crop_empty(image: Image.Image, *, config: CropEmptyConfig) -> Tuple[Image.Im
108
  "uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
109
  },
110
  }
111
-
112
-
 
20
  uniform_rowcol_std_threshold: float = 0.0
21
 
22
 
23
+ def crop_empty(
24
+ image: Image.Image, *, config: CropEmptyConfig
25
+ ) -> Tuple[Image.Image, Dict[str, Any]]:
26
  img = image.convert("RGB")
27
  arr = np.array(img)
28
  intensity = arr.mean(axis=2)
 
33
  pixels = intensity[i, :] if axis == 0 else intensity[:, i]
34
  white = float(np.mean(pixels > config.color_threshold))
35
  non_white = 1.0 - white
36
+ if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
37
+ config.uniform_rowcol_std_threshold
38
+ ):
39
  continue
40
  if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
41
  return int(i)
 
47
  pixels = intensity[i, :] if axis == 0 else intensity[:, i]
48
  white = float(np.mean(pixels > config.color_threshold))
49
  non_white = 1.0 - white
50
+ if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
51
+ config.uniform_rowcol_std_threshold
52
+ ):
53
  continue
54
  if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
55
  return int(i + 1)
 
59
  left = _find_border_start(1, min_content_density_threshold=float(config.content_density_sides))
60
  right = _find_border_end(1, min_content_density_threshold=float(config.content_density_sides))
61
 
62
+ main_text_end = _find_border_end(
63
+ 0, min_content_density_threshold=float(config.content_density_main_text)
64
+ )
65
+ last_content_end = _find_border_end(
66
+ 0, min_content_density_threshold=float(config.content_density_any)
67
+ )
68
  bottom = main_text_end if config.remove_page_number else last_content_end
69
 
70
  width, height = img.size
 
118
  "uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
119
  },
120
  }
 
 
visual_rag/qdrant_admin.py CHANGED
@@ -33,9 +33,16 @@ def _resolve_qdrant_connection(
33
  import os
34
 
35
  _maybe_load_dotenv()
36
- resolved_url = url or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
 
 
 
 
 
37
  if not resolved_url:
38
- raise ValueError("Qdrant URL not set (pass url= or set SIGIR_QDRANT_URL/DEST_QDRANT_URL/QDRANT_URL).")
 
 
39
  resolved_key = (
40
  api_key
41
  or os.getenv("SIGIR_QDRANT_KEY")
@@ -105,7 +112,11 @@ class QdrantAdmin:
105
  from qdrant_client.http import models as m
106
 
107
  hnsw_diff = m.HnswConfigDiff(**hnsw_config) if isinstance(hnsw_config, dict) else None
108
- params_diff = m.CollectionParamsDiff(**collection_params) if isinstance(collection_params, dict) else None
 
 
 
 
109
  if hnsw_diff is None and params_diff is None:
110
  raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")
111
  return bool(
@@ -143,7 +154,9 @@ class QdrantAdmin:
143
 
144
  missing = [str(k) for k in (vectors or {}).keys() if existing and str(k) not in existing]
145
  if missing:
146
- raise ValueError(f"Vectors do not exist in collection '{collection_name}': {missing}. Existing: {sorted(existing)}")
 
 
147
 
148
  ok = True
149
  for name, cfg in (vectors or {}).items():
@@ -158,13 +171,16 @@ class QdrantAdmin:
158
  )
159
  }
160
 
161
- ok = bool(
162
- self.client.update_collection(
163
- collection_name=collection_name,
164
- vectors_config=vectors_diff,
165
- timeout=int(timeout) if timeout is not None else None,
 
 
166
  )
167
- ) and ok
 
168
 
169
  return ok
170
 
@@ -192,7 +208,9 @@ class QdrantAdmin:
192
  vectors[str(vname)] = {"on_disk": True, "hnsw_config": {"on_disk": True}}
193
 
194
  if vectors:
195
- self.modify_collection_vector_config(collection_name=collection_name, vectors=vectors, timeout=timeout)
 
 
196
 
197
  self.modify_collection_config(
198
  collection_name=collection_name,
@@ -202,4 +220,3 @@ class QdrantAdmin:
202
  )
203
 
204
  return self.get_collection_info(collection_name=collection_name)
205
-
 
33
  import os
34
 
35
  _maybe_load_dotenv()
36
+ resolved_url = (
37
+ url
38
+ or os.getenv("SIGIR_QDRANT_URL")
39
+ or os.getenv("DEST_QDRANT_URL")
40
+ or os.getenv("QDRANT_URL")
41
+ )
42
  if not resolved_url:
43
+ raise ValueError(
44
+ "Qdrant URL not set (pass url= or set SIGIR_QDRANT_URL/DEST_QDRANT_URL/QDRANT_URL)."
45
+ )
46
  resolved_key = (
47
  api_key
48
  or os.getenv("SIGIR_QDRANT_KEY")
 
112
  from qdrant_client.http import models as m
113
 
114
  hnsw_diff = m.HnswConfigDiff(**hnsw_config) if isinstance(hnsw_config, dict) else None
115
+ params_diff = (
116
+ m.CollectionParamsDiff(**collection_params)
117
+ if isinstance(collection_params, dict)
118
+ else None
119
+ )
120
  if hnsw_diff is None and params_diff is None:
121
  raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")
122
  return bool(
 
154
 
155
  missing = [str(k) for k in (vectors or {}).keys() if existing and str(k) not in existing]
156
  if missing:
157
+ raise ValueError(
158
+ f"Vectors do not exist in collection '{collection_name}': {missing}. Existing: {sorted(existing)}"
159
+ )
160
 
161
  ok = True
162
  for name, cfg in (vectors or {}).items():
 
171
  )
172
  }
173
 
174
+ ok = (
175
+ bool(
176
+ self.client.update_collection(
177
+ collection_name=collection_name,
178
+ vectors_config=vectors_diff,
179
+ timeout=int(timeout) if timeout is not None else None,
180
+ )
181
  )
182
+ and ok
183
+ )
184
 
185
  return ok
186
 
 
208
  vectors[str(vname)] = {"on_disk": True, "hnsw_config": {"on_disk": True}}
209
 
210
  if vectors:
211
+ self.modify_collection_vector_config(
212
+ collection_name=collection_name, vectors=vectors, timeout=timeout
213
+ )
214
 
215
  self.modify_collection_config(
216
  collection_name=collection_name,
 
220
  )
221
 
222
  return self.get_collection_info(collection_name=collection_name)
 
visual_rag/retrieval/__init__.py CHANGED
@@ -6,10 +6,10 @@ Components:
6
  - SingleStageRetriever: Direct multi-vector or pooled search
7
  """
8
 
9
- from visual_rag.retrieval.two_stage import TwoStageRetriever
10
- from visual_rag.retrieval.single_stage import SingleStageRetriever
11
  from visual_rag.retrieval.multi_vector import MultiVectorRetriever
 
12
  from visual_rag.retrieval.three_stage import ThreeStageRetriever
 
13
 
14
  __all__ = [
15
  "TwoStageRetriever",
 
6
  - SingleStageRetriever: Direct multi-vector or pooled search
7
  """
8
 
 
 
9
  from visual_rag.retrieval.multi_vector import MultiVectorRetriever
10
+ from visual_rag.retrieval.single_stage import SingleStageRetriever
11
  from visual_rag.retrieval.three_stage import ThreeStageRetriever
12
+ from visual_rag.retrieval.two_stage import TwoStageRetriever
13
 
14
  __all__ = [
15
  "TwoStageRetriever",
visual_rag/retrieval/multi_vector.py CHANGED
@@ -2,18 +2,35 @@ import os
2
  from typing import Any, Dict, List, Optional
3
  from urllib.parse import urlparse
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from visual_rag.embedding.visual_embedder import VisualEmbedder
6
  from visual_rag.retrieval.single_stage import SingleStageRetriever
7
- from visual_rag.retrieval.two_stage import TwoStageRetriever
8
  from visual_rag.retrieval.three_stage import ThreeStageRetriever
 
9
 
10
 
11
  class MultiVectorRetriever:
12
  @staticmethod
13
  def _maybe_load_dotenv() -> None:
14
- try:
15
- from dotenv import load_dotenv
16
- except ImportError:
17
  return
18
  if os.path.exists(".env"):
19
  load_dotenv(".env")
@@ -33,83 +50,84 @@ class MultiVectorRetriever:
33
  ):
34
  if qdrant_client is None:
35
  self._maybe_load_dotenv()
36
- try:
37
- from qdrant_client import QdrantClient
38
- except ImportError as e:
39
  raise ImportError(
40
  "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
41
- ) from e
42
 
43
  qdrant_url = (
44
  qdrant_url
45
- or os.getenv("SIGIR_QDRANT_URL")
46
- or os.getenv("DEST_QDRANT_URL")
47
  or os.getenv("QDRANT_URL")
 
48
  )
49
  if not qdrant_url:
50
  raise ValueError(
51
- "QDRANT_URL is required (pass qdrant_url or set env var). "
52
- "You can also set DEST_QDRANT_URL to override."
53
  )
54
 
55
  qdrant_api_key = (
56
  qdrant_api_key
57
- or os.getenv("SIGIR_QDRANT_KEY")
58
- or os.getenv("SIGIR_QDRANT_API_KEY")
59
- or os.getenv("DEST_QDRANT_API_KEY")
60
  or os.getenv("QDRANT_API_KEY")
 
61
  )
62
 
63
  grpc_port = None
64
  if prefer_grpc:
65
  try:
66
- if urlparse(qdrant_url).port == 6333:
 
 
67
  grpc_port = 6334
68
  except Exception:
69
- grpc_port = None
 
70
  def _make_client(use_grpc: bool):
71
  return QdrantClient(
72
  url=qdrant_url,
73
  api_key=qdrant_api_key,
 
74
  prefer_grpc=bool(use_grpc),
75
  grpc_port=grpc_port,
76
- timeout=int(request_timeout),
77
  check_compatibility=False,
78
  )
79
 
80
- qdrant_client = _make_client(prefer_grpc)
81
  if prefer_grpc:
82
  try:
83
- _ = qdrant_client.get_collections()
84
  except Exception as e:
85
  msg = str(e)
86
  if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
87
- qdrant_client = _make_client(False)
88
  else:
89
  raise
 
90
 
91
  self.client = qdrant_client
92
  self.collection_name = collection_name
 
93
  self.embedder = embedder or VisualEmbedder(model_name=model_name)
94
 
95
  self._two_stage = TwoStageRetriever(
96
- self.client,
97
- collection_name=self.collection_name,
98
- request_timeout=int(request_timeout),
99
- max_retries=int(max_retries),
100
- retry_sleep=float(retry_sleep),
101
  )
102
  self._three_stage = ThreeStageRetriever(
103
- self.client,
104
- collection_name=self.collection_name,
105
- request_timeout=int(request_timeout),
106
- max_retries=int(max_retries),
107
- retry_sleep=float(retry_sleep),
108
  )
109
  self._single_stage = SingleStageRetriever(
110
- self.client,
111
- collection_name=self.collection_name,
112
- request_timeout=int(request_timeout),
 
 
113
  )
114
 
115
  def build_filter(
@@ -139,14 +157,10 @@ class MultiVectorRetriever:
139
  return_embeddings: bool = False,
140
  ) -> List[Dict[str, Any]]:
141
  q = self.embedder.embed_query(query)
142
- try:
143
- import torch
144
- except ImportError:
145
- torch = None
146
- if torch is not None and isinstance(q, torch.Tensor):
147
  query_embedding = q.detach().cpu().numpy()
148
  else:
149
- query_embedding = q.numpy()
150
 
151
  return self.search_embedded(
152
  query_embedding=query_embedding,
@@ -175,27 +189,17 @@ class MultiVectorRetriever:
175
  return self._single_stage.search(
176
  query_embedding=query_embedding,
177
  top_k=top_k,
178
- strategy="multi_vector",
179
- filter_obj=filter_obj,
180
- )
181
-
182
- if mode == "single_tiles":
183
- return self._single_stage.search(
184
- query_embedding=query_embedding,
185
- top_k=top_k,
186
- strategy="tiles_maxsim",
187
  filter_obj=filter_obj,
 
188
  )
189
-
190
- if mode == "single_global":
191
  return self._single_stage.search(
192
  query_embedding=query_embedding,
193
  top_k=top_k,
194
- strategy="pooled_global",
195
  filter_obj=filter_obj,
 
196
  )
197
-
198
- if mode == "two_stage":
199
  return self._two_stage.search_server_side(
200
  query_embedding=query_embedding,
201
  top_k=top_k,
@@ -203,18 +207,14 @@ class MultiVectorRetriever:
203
  filter_obj=filter_obj,
204
  stage1_mode=stage1_mode,
205
  )
206
-
207
- if mode == "three_stage":
208
- s1 = int(stage1_k) if stage1_k is not None else 1000
209
- s2 = int(stage2_k) if stage2_k is not None else 300
210
  return self._three_stage.search_server_side(
211
  query_embedding=query_embedding,
212
  top_k=top_k,
213
- stage1_k=s1,
214
- stage2_k=s2,
215
  filter_obj=filter_obj,
 
216
  )
217
-
218
- raise ValueError(f"Unknown mode: {mode}")
219
-
220
-
 
2
  from typing import Any, Dict, List, Optional
3
  from urllib.parse import urlparse
4
 
5
+ import numpy as np
6
+ import torch
7
+
8
+ try:
9
+ from dotenv import load_dotenv
10
+
11
+ DOTENV_AVAILABLE = True
12
+ except ImportError:
13
+ DOTENV_AVAILABLE = False
14
+ load_dotenv = None
15
+
16
+ try:
17
+ from qdrant_client import QdrantClient
18
+
19
+ QDRANT_AVAILABLE = True
20
+ except ImportError:
21
+ QDRANT_AVAILABLE = False
22
+ QdrantClient = None
23
+
24
  from visual_rag.embedding.visual_embedder import VisualEmbedder
25
  from visual_rag.retrieval.single_stage import SingleStageRetriever
 
26
  from visual_rag.retrieval.three_stage import ThreeStageRetriever
27
+ from visual_rag.retrieval.two_stage import TwoStageRetriever
28
 
29
 
30
  class MultiVectorRetriever:
31
  @staticmethod
32
  def _maybe_load_dotenv() -> None:
33
+ if not DOTENV_AVAILABLE:
 
 
34
  return
35
  if os.path.exists(".env"):
36
  load_dotenv(".env")
 
50
  ):
51
  if qdrant_client is None:
52
  self._maybe_load_dotenv()
53
+ if not QDRANT_AVAILABLE:
 
 
54
  raise ImportError(
55
  "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
56
+ )
57
 
58
  qdrant_url = (
59
  qdrant_url
 
 
60
  or os.getenv("QDRANT_URL")
61
+ or os.getenv("SIGIR_QDRANT_URL") # legacy
62
  )
63
  if not qdrant_url:
64
  raise ValueError(
65
+ "QDRANT_URL is required (pass qdrant_url or set env var)."
 
66
  )
67
 
68
  qdrant_api_key = (
69
  qdrant_api_key
 
 
 
70
  or os.getenv("QDRANT_API_KEY")
71
+ or os.getenv("SIGIR_QDRANT_KEY") # legacy
72
  )
73
 
74
  grpc_port = None
75
  if prefer_grpc:
76
  try:
77
+ parsed = urlparse(qdrant_url)
78
+ port = parsed.port
79
+ if port == 6333:
80
  grpc_port = 6334
81
  except Exception:
82
+ pass
83
+
84
  def _make_client(use_grpc: bool):
85
  return QdrantClient(
86
  url=qdrant_url,
87
  api_key=qdrant_api_key,
88
+ timeout=request_timeout,
89
  prefer_grpc=bool(use_grpc),
90
  grpc_port=grpc_port,
 
91
  check_compatibility=False,
92
  )
93
 
94
+ client = _make_client(prefer_grpc)
95
  if prefer_grpc:
96
  try:
97
+ _ = client.get_collections()
98
  except Exception as e:
99
  msg = str(e)
100
  if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
101
+ client = _make_client(False)
102
  else:
103
  raise
104
+ qdrant_client = client
105
 
106
  self.client = qdrant_client
107
  self.collection_name = collection_name
108
+
109
  self.embedder = embedder or VisualEmbedder(model_name=model_name)
110
 
111
  self._two_stage = TwoStageRetriever(
112
+ qdrant_client=qdrant_client,
113
+ collection_name=collection_name,
114
+ request_timeout=request_timeout,
115
+ max_retries=max_retries,
116
+ retry_sleep=retry_sleep,
117
  )
118
  self._three_stage = ThreeStageRetriever(
119
+ qdrant_client=qdrant_client,
120
+ collection_name=collection_name,
121
+ request_timeout=request_timeout,
122
+ max_retries=max_retries,
123
+ retry_sleep=retry_sleep,
124
  )
125
  self._single_stage = SingleStageRetriever(
126
+ qdrant_client=qdrant_client,
127
+ collection_name=collection_name,
128
+ request_timeout=request_timeout,
129
+ max_retries=max_retries,
130
+ retry_sleep=retry_sleep,
131
  )
132
 
133
  def build_filter(
 
157
  return_embeddings: bool = False,
158
  ) -> List[Dict[str, Any]]:
159
  q = self.embedder.embed_query(query)
160
+ if isinstance(q, torch.Tensor):
 
 
 
 
161
  query_embedding = q.detach().cpu().numpy()
162
  else:
163
+ query_embedding = np.asarray(q)
164
 
165
  return self.search_embedded(
166
  query_embedding=query_embedding,
 
189
  return self._single_stage.search(
190
  query_embedding=query_embedding,
191
  top_k=top_k,
 
 
 
 
 
 
 
 
 
192
  filter_obj=filter_obj,
193
+ using="initial",
194
  )
195
+ elif mode == "single_pooled":
 
196
  return self._single_stage.search(
197
  query_embedding=query_embedding,
198
  top_k=top_k,
 
199
  filter_obj=filter_obj,
200
+ using="mean_pooling",
201
  )
202
+ elif mode == "two_stage":
 
203
  return self._two_stage.search_server_side(
204
  query_embedding=query_embedding,
205
  top_k=top_k,
 
207
  filter_obj=filter_obj,
208
  stage1_mode=stage1_mode,
209
  )
210
+ elif mode == "three_stage":
 
 
 
211
  return self._three_stage.search_server_side(
212
  query_embedding=query_embedding,
213
  top_k=top_k,
214
+ stage1_k=stage1_k,
215
+ stage2_k=stage2_k,
216
  filter_obj=filter_obj,
217
+ stage1_mode=stage1_mode,
218
  )
219
+ else:
220
+ raise ValueError(f"Unknown mode: {mode}")
 
 
visual_rag/retrieval/single_stage.py CHANGED
@@ -9,7 +9,8 @@ Use when:
9
  """
10
 
11
  import logging
12
- from typing import List, Dict, Any, Optional, Union
 
13
  import numpy as np
14
  import torch
15
 
@@ -19,22 +20,22 @@ logger = logging.getLogger(__name__)
19
  class SingleStageRetriever:
20
  """
21
  Single-stage visual document retrieval using native Qdrant search.
22
-
23
  Supports strategies:
24
  - multi_vector: Native MaxSim on full embeddings (using="initial")
25
  - tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
26
  - pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
27
  - pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")
28
-
29
  Args:
30
  qdrant_client: Connected Qdrant client
31
  collection_name: Name of the Qdrant collection
32
-
33
  Example:
34
  >>> retriever = SingleStageRetriever(client, "my_collection")
35
  >>> results = retriever.search(query, top_k=10)
36
  """
37
-
38
  def __init__(
39
  self,
40
  qdrant_client,
@@ -44,7 +45,7 @@ class SingleStageRetriever:
44
  self.client = qdrant_client
45
  self.collection_name = collection_name
46
  self.request_timeout = int(request_timeout)
47
-
48
  def search(
49
  self,
50
  query_embedding: Union[torch.Tensor, np.ndarray],
@@ -54,47 +55,47 @@ class SingleStageRetriever:
54
  ) -> List[Dict[str, Any]]:
55
  """
56
  Single-stage search with configurable strategy.
57
-
58
  Args:
59
  query_embedding: Query embeddings [num_tokens, dim]
60
  top_k: Number of results
61
  strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
62
  filter_obj: Qdrant filter
63
-
64
  Returns:
65
  List of results with scores and metadata
66
  """
67
  query_np = self._to_numpy(query_embedding)
68
-
69
  if strategy == "multi_vector":
70
  # Native multi-vector MaxSim
71
  vector_name = "initial"
72
  query_vector = query_np.tolist()
73
  logger.debug(f"🎯 Multi-vector search on '{vector_name}'")
74
-
75
  elif strategy == "tiles_maxsim":
76
  # Native multi-vector MaxSim against tile vectors
77
  vector_name = "mean_pooling"
78
  query_vector = query_np.tolist()
79
  logger.debug(f"🎯 Tile MaxSim search on '{vector_name}'")
80
-
81
  elif strategy == "pooled_tile":
82
  # Tile-level pooled
83
  vector_name = "mean_pooling"
84
  query_pooled = query_np.mean(axis=0)
85
  query_vector = query_pooled.tolist()
86
  logger.debug(f"🔍 Tile-pooled search on '{vector_name}'")
87
-
88
  elif strategy == "pooled_global":
89
  # Global pooled vector (single vector)
90
  vector_name = "global_pooling"
91
  query_pooled = query_np.mean(axis=0)
92
  query_vector = query_pooled.tolist()
93
  logger.debug(f"🔍 Global-pooled search on '{vector_name}'")
94
-
95
  else:
96
  raise ValueError(f"Unknown strategy: {strategy}")
97
-
98
  results = self.client.query_points(
99
  collection_name=self.collection_name,
100
  query=query_vector,
@@ -105,7 +106,7 @@ class SingleStageRetriever:
105
  with_vectors=False,
106
  timeout=self.request_timeout,
107
  ).points
108
-
109
  return [
110
  {
111
  "id": r.id,
@@ -115,7 +116,7 @@ class SingleStageRetriever:
115
  }
116
  for r in results
117
  ]
118
-
119
  def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
120
  """Convert embedding to numpy array."""
121
  if isinstance(embedding, torch.Tensor):
@@ -123,5 +124,3 @@ class SingleStageRetriever:
123
  return embedding.cpu().float().numpy()
124
  return embedding.cpu().numpy()
125
  return np.array(embedding, dtype=np.float32)
126
-
127
-
 
9
  """
10
 
11
  import logging
12
+ from typing import Any, Dict, List, Union
13
+
14
  import numpy as np
15
  import torch
16
 
 
20
  class SingleStageRetriever:
21
  """
22
  Single-stage visual document retrieval using native Qdrant search.
23
+
24
  Supports strategies:
25
  - multi_vector: Native MaxSim on full embeddings (using="initial")
26
  - tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
27
  - pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
28
  - pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")
29
+
30
  Args:
31
  qdrant_client: Connected Qdrant client
32
  collection_name: Name of the Qdrant collection
33
+
34
  Example:
35
  >>> retriever = SingleStageRetriever(client, "my_collection")
36
  >>> results = retriever.search(query, top_k=10)
37
  """
38
+
39
  def __init__(
40
  self,
41
  qdrant_client,
 
45
  self.client = qdrant_client
46
  self.collection_name = collection_name
47
  self.request_timeout = int(request_timeout)
48
+
49
  def search(
50
  self,
51
  query_embedding: Union[torch.Tensor, np.ndarray],
 
55
  ) -> List[Dict[str, Any]]:
56
  """
57
  Single-stage search with configurable strategy.
58
+
59
  Args:
60
  query_embedding: Query embeddings [num_tokens, dim]
61
  top_k: Number of results
62
  strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
63
  filter_obj: Qdrant filter
64
+
65
  Returns:
66
  List of results with scores and metadata
67
  """
68
  query_np = self._to_numpy(query_embedding)
69
+
70
  if strategy == "multi_vector":
71
  # Native multi-vector MaxSim
72
  vector_name = "initial"
73
  query_vector = query_np.tolist()
74
  logger.debug(f"🎯 Multi-vector search on '{vector_name}'")
75
+
76
  elif strategy == "tiles_maxsim":
77
  # Native multi-vector MaxSim against tile vectors
78
  vector_name = "mean_pooling"
79
  query_vector = query_np.tolist()
80
  logger.debug(f"🎯 Tile MaxSim search on '{vector_name}'")
81
+
82
  elif strategy == "pooled_tile":
83
  # Tile-level pooled
84
  vector_name = "mean_pooling"
85
  query_pooled = query_np.mean(axis=0)
86
  query_vector = query_pooled.tolist()
87
  logger.debug(f"🔍 Tile-pooled search on '{vector_name}'")
88
+
89
  elif strategy == "pooled_global":
90
  # Global pooled vector (single vector)
91
  vector_name = "global_pooling"
92
  query_pooled = query_np.mean(axis=0)
93
  query_vector = query_pooled.tolist()
94
  logger.debug(f"🔍 Global-pooled search on '{vector_name}'")
95
+
96
  else:
97
  raise ValueError(f"Unknown strategy: {strategy}")
98
+
99
  results = self.client.query_points(
100
  collection_name=self.collection_name,
101
  query=query_vector,
 
106
  with_vectors=False,
107
  timeout=self.request_timeout,
108
  ).points
109
+
110
  return [
111
  {
112
  "id": r.id,
 
116
  }
117
  for r in results
118
  ]
119
+
120
  def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
121
  """Convert embedding to numpy array."""
122
  if isinstance(embedding, torch.Tensor):
 
124
  return embedding.cpu().float().numpy()
125
  return embedding.cpu().numpy()
126
  return np.array(embedding, dtype=np.float32)
 
 
visual_rag/retrieval/three_stage.py CHANGED
@@ -43,7 +43,7 @@ class ThreeStageRetriever:
43
  last_err = e
44
  if attempt >= self.max_retries - 1:
45
  break
46
- time.sleep(self.retry_sleep * (2 ** attempt))
47
  if last_err is not None:
48
  raise last_err
49
 
@@ -171,4 +171,3 @@ class ThreeStageRetriever:
171
  }
172
  )
173
  return out
174
-
 
43
  last_err = e
44
  if attempt >= self.max_retries - 1:
45
  break
46
+ time.sleep(self.retry_sleep * (2**attempt))
47
  if last_err is not None:
48
  raise last_err
49
 
 
171
  }
172
  )
173
  return out
 
visual_rag/retrieval/two_stage.py CHANGED
@@ -17,47 +17,54 @@ Research Context:
17
  """
18
 
19
  import logging
20
- from typing import List, Dict, Any, Optional, Union
 
 
21
  import numpy as np
22
  import torch
23
 
 
 
 
 
 
24
  logger = logging.getLogger(__name__)
25
 
26
 
27
  class TwoStageRetriever:
28
  """
29
  Two-stage visual document retrieval with pooling and reranking.
30
-
31
  Stage 1 (Prefetch):
32
  Uses tile-level mean-pooled vectors for fast HNSW search.
33
  Retrieves prefetch_k candidates (e.g., 100-500).
34
-
35
  Stage 2 (Rerank):
36
  Fetches full multi-vector embeddings for candidates.
37
  Computes exact MaxSim scores for precise ranking.
38
  Returns top_k results (e.g., 10).
39
-
40
  Args:
41
  qdrant_client: Connected Qdrant client
42
  collection_name: Name of the Qdrant collection
43
  full_vector_name: Name of full multi-vector field (default: "initial")
44
  pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
45
-
46
  Example:
47
  >>> retriever = TwoStageRetriever(client, "my_collection")
48
- >>>
49
  >>> # Two-stage search: prefetch 200, return top 10
50
  >>> results = retriever.search(
51
  ... query_embedding=query,
52
  ... top_k=10,
53
  ... prefetch_k=200,
54
  ... )
55
- >>>
56
  >>> # Compare latency:
57
  >>> # Full MaxSim (1000 docs): ~500ms
58
  >>> # Two-stage (200→10): ~50ms
59
  """
60
-
61
  def __init__(
62
  self,
63
  qdrant_client,
@@ -81,8 +88,6 @@ class TwoStageRetriever:
81
  self.retry_sleep = float(retry_sleep)
82
 
83
  def _retry_call(self, fn):
84
- import time
85
-
86
  last_err = None
87
  for attempt in range(self.max_retries):
88
  try:
@@ -91,7 +96,7 @@ class TwoStageRetriever:
91
  last_err = e
92
  if attempt >= self.max_retries - 1:
93
  break
94
- time.sleep(self.retry_sleep * (2 ** attempt))
95
  if last_err is not None:
96
  raise last_err
97
 
@@ -105,27 +110,25 @@ class TwoStageRetriever:
105
  ) -> List[Dict[str, Any]]:
106
  """
107
  Two-stage retrieval using Qdrant's native prefetch (all server-side).
108
-
109
  This is MUCH faster than search() because it avoids network transfer
110
  of large multi-vector embeddings. All computation happens in Qdrant.
111
-
112
  Args:
113
  query_embedding: Query embeddings [num_tokens, dim]
114
  top_k: Final number of results
115
  prefetch_k: Candidates for stage 1 (default: 10x top_k)
116
  filter_obj: Qdrant filter
117
  stage1_mode: How to do stage 1 prefetch
118
-
119
  Returns:
120
  List of results with scores
121
  """
122
- from qdrant_client.http import models
123
-
124
  query_np = self._to_numpy(query_embedding)
125
-
126
  if prefetch_k is None:
127
  prefetch_k = max(100, top_k * 10)
128
-
129
  if stage1_mode == "pooled_query_vs_tiles":
130
  prefetch_query = query_np.mean(axis=0).tolist()
131
  prefetch_using = self.pooled_vector_name
@@ -143,9 +146,9 @@ class TwoStageRetriever:
143
  prefetch_using = self.global_vector_name
144
  else:
145
  raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
146
-
147
  rerank_query = query_np.tolist()
148
-
149
  def _do_query():
150
  return self.client.query_points(
151
  collection_name=self.collection_name,
@@ -154,9 +157,9 @@ class TwoStageRetriever:
154
  limit=top_k,
155
  query_filter=filter_obj,
156
  with_payload=True,
157
- search_params=models.SearchParams(exact=True),
158
  prefetch=[
159
- models.Prefetch(
160
  query=prefetch_query,
161
  using=prefetch_using,
162
  limit=prefetch_k,
@@ -164,9 +167,9 @@ class TwoStageRetriever:
164
  ],
165
  timeout=self.request_timeout,
166
  ).points
167
-
168
  results = self._retry_call(_do_query)
169
-
170
  return [
171
  {
172
  "id": r.id,
@@ -177,7 +180,7 @@ class TwoStageRetriever:
177
  }
178
  for r in results
179
  ]
180
-
181
  def search(
182
  self,
183
  query_embedding: Union[torch.Tensor, np.ndarray],
@@ -190,7 +193,7 @@ class TwoStageRetriever:
190
  ) -> List[Dict[str, Any]]:
191
  """
192
  Two-stage retrieval: prefetch with pooling, rerank with MaxSim.
193
-
194
  Args:
195
  query_embedding: Query embeddings [num_tokens, dim]
196
  top_k: Final number of results to return
@@ -202,7 +205,7 @@ class TwoStageRetriever:
202
  - "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
203
  - "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
204
  - "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")
205
-
206
  Returns:
207
  List of results with scores and metadata:
208
  [
@@ -218,11 +221,11 @@ class TwoStageRetriever:
218
  """
219
  # Convert to numpy
220
  query_np = self._to_numpy(query_embedding)
221
-
222
  # Auto-set prefetch_k
223
  if prefetch_k is None:
224
  prefetch_k = max(100, top_k * 10)
225
-
226
  # Stage 1: Prefetch with pooled vectors
227
  logger.info(f"🔍 Stage 1: Prefetching {prefetch_k} candidates ({stage1_mode})")
228
  candidates = self._stage1_prefetch(
@@ -231,16 +234,16 @@ class TwoStageRetriever:
231
  filter_obj=filter_obj,
232
  stage1_mode=stage1_mode,
233
  )
234
-
235
  if not candidates:
236
  logger.warning("No candidates found in stage 1")
237
  return []
238
-
239
  logger.info(f"✅ Stage 1: Retrieved {len(candidates)} candidates")
240
-
241
  # Stage 2: Rerank with full embeddings
242
  if use_reranking and len(candidates) > top_k:
243
- logger.info(f"🎯 Stage 2: Reranking with MaxSim...")
244
  results = self._stage2_rerank(
245
  query_np=query_np,
246
  candidates=candidates,
@@ -254,9 +257,9 @@ class TwoStageRetriever:
254
  for r in results:
255
  r["score_final"] = r["score_stage1"]
256
  logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
257
-
258
  return results
259
-
260
  def search_single_stage(
261
  self,
262
  query_embedding: Union[torch.Tensor, np.ndarray],
@@ -266,18 +269,18 @@ class TwoStageRetriever:
266
  ) -> List[Dict[str, Any]]:
267
  """
268
  Single-stage search (either pooled or full multi-vector).
269
-
270
  Args:
271
  query_embedding: Query embeddings
272
  top_k: Number of results
273
  filter_obj: Qdrant filter
274
  use_pooling: Use pooled vectors (faster) or full (more accurate)
275
-
276
  Returns:
277
  List of results
278
  """
279
  query_np = self._to_numpy(query_embedding)
280
-
281
  if use_pooling:
282
  # Pool query and search pooled vectors
283
  query_pooled = query_np.mean(axis=0)
@@ -289,7 +292,7 @@ class TwoStageRetriever:
289
  vector_name = self.full_vector_name
290
  query_vector = query_np.tolist()
291
  logger.info(f"🎯 Multi-vector search: {vector_name}")
292
-
293
  results = self.client.query_points(
294
  collection_name=self.collection_name,
295
  query=query_vector,
@@ -300,7 +303,7 @@ class TwoStageRetriever:
300
  with_vectors=False,
301
  timeout=120,
302
  ).points
303
-
304
  return [
305
  {
306
  "id": r.id,
@@ -310,7 +313,7 @@ class TwoStageRetriever:
310
  }
311
  for r in results
312
  ]
313
-
314
  def _stage1_prefetch(
315
  self,
316
  query_np: np.ndarray,
@@ -330,7 +333,7 @@ class TwoStageRetriever:
330
  vector_name = self.global_vector_name
331
  else:
332
  raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
333
-
334
  def _do_query():
335
  return self.client.query_points(
336
  collection_name=self.collection_name,
@@ -344,7 +347,7 @@ class TwoStageRetriever:
344
  ).points
345
 
346
  results = self._retry_call(_do_query)
347
-
348
  return [
349
  {
350
  "id": r.id,
@@ -353,7 +356,7 @@ class TwoStageRetriever:
353
  }
354
  for r in results
355
  ]
356
-
357
  def _stage2_rerank(
358
  self,
359
  query_np: np.ndarray,
@@ -362,11 +365,9 @@ class TwoStageRetriever:
362
  return_embeddings: bool = False,
363
  ) -> List[Dict[str, Any]]:
364
  """Stage 2: Rerank with full multi-vector MaxSim scoring."""
365
- from visual_rag.embedding.pooling import compute_maxsim_score
366
-
367
  # Fetch full embeddings for candidates
368
  candidate_ids = [c["id"] for c in candidates]
369
-
370
  # Retrieve points with vectors
371
  def _do_retrieve():
372
  return self.client.retrieve(
@@ -378,7 +379,7 @@ class TwoStageRetriever:
378
  )
379
 
380
  points = self._retry_call(_do_retrieve)
381
-
382
  # Build ID to embedding map
383
  id_to_embedding = {}
384
  for point in points:
@@ -386,13 +387,13 @@ class TwoStageRetriever:
386
  id_to_embedding[point.id] = np.array(
387
  point.vector[self.full_vector_name], dtype=np.float32
388
  )
389
-
390
  # Compute MaxSim scores
391
  reranked = []
392
  for candidate in candidates:
393
  point_id = candidate["id"]
394
  doc_embedding = id_to_embedding.get(point_id)
395
-
396
  if doc_embedding is None:
397
  # Fallback to stage 1 score
398
  candidate["score_stage2"] = candidate["score_stage1"]
@@ -402,17 +403,17 @@ class TwoStageRetriever:
402
  maxsim_score = compute_maxsim_score(query_np, doc_embedding)
403
  candidate["score_stage2"] = maxsim_score
404
  candidate["score_final"] = maxsim_score
405
-
406
  if return_embeddings:
407
  candidate["embedding"] = doc_embedding
408
-
409
  reranked.append(candidate)
410
-
411
  # Sort by final score (descending)
412
  reranked.sort(key=lambda x: x["score_final"], reverse=True)
413
-
414
  return reranked[:top_k]
415
-
416
  def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
417
  """Convert embedding to numpy array."""
418
  if isinstance(embedding, torch.Tensor):
@@ -420,7 +421,7 @@ class TwoStageRetriever:
420
  return embedding.cpu().float().numpy()
421
  return embedding.cpu().numpy()
422
  return np.array(embedding, dtype=np.float32)
423
-
424
  def build_filter(
425
  self,
426
  year: Optional[Any] = None,
@@ -431,60 +432,38 @@ class TwoStageRetriever:
431
  ):
432
  """
433
  Build Qdrant filter from parameters.
434
-
435
  Supports single values or lists (using MatchAny).
436
  """
437
- from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny
438
-
439
  conditions = []
440
-
441
  if year is not None:
442
  if isinstance(year, list):
443
  year_values = [int(y) if isinstance(y, str) else y for y in year]
444
- conditions.append(
445
- FieldCondition(key="year", match=MatchAny(any=year_values))
446
- )
447
  else:
448
  year_value = int(year) if isinstance(year, str) else year
449
- conditions.append(
450
- FieldCondition(key="year", match=MatchValue(value=year_value))
451
- )
452
-
453
  if source is not None:
454
  if isinstance(source, list):
455
- conditions.append(
456
- FieldCondition(key="source", match=MatchAny(any=source))
457
- )
458
  else:
459
- conditions.append(
460
- FieldCondition(key="source", match=MatchValue(value=source))
461
- )
462
-
463
  if district is not None:
464
  if isinstance(district, list):
465
- conditions.append(
466
- FieldCondition(key="district", match=MatchAny(any=district))
467
- )
468
  else:
469
- conditions.append(
470
- FieldCondition(key="district", match=MatchValue(value=district))
471
- )
472
-
473
  if filename is not None:
474
  if isinstance(filename, list):
475
- conditions.append(
476
- FieldCondition(key="filename", match=MatchAny(any=filename))
477
- )
478
  else:
479
- conditions.append(
480
- FieldCondition(key="filename", match=MatchValue(value=filename))
481
- )
482
-
483
- if has_text is not None:
484
- conditions.append(
485
- FieldCondition(key="has_text", match=MatchValue(value=has_text))
486
- )
487
-
488
- return Filter(must=conditions) if conditions else None
489
 
 
 
490
 
 
 
17
  """
18
 
19
  import logging
20
+ import time
21
+ from typing import Any, Dict, List, Optional, Union
22
+
23
  import numpy as np
24
  import torch
25
 
26
+ from qdrant_client.http import models as qdrant_models
27
+ from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue
28
+
29
+ from visual_rag.embedding.pooling import compute_maxsim_score
30
+
31
  logger = logging.getLogger(__name__)
32
 
33
 
34
  class TwoStageRetriever:
35
  """
36
  Two-stage visual document retrieval with pooling and reranking.
37
+
38
  Stage 1 (Prefetch):
39
  Uses tile-level mean-pooled vectors for fast HNSW search.
40
  Retrieves prefetch_k candidates (e.g., 100-500).
41
+
42
  Stage 2 (Rerank):
43
  Fetches full multi-vector embeddings for candidates.
44
  Computes exact MaxSim scores for precise ranking.
45
  Returns top_k results (e.g., 10).
46
+
47
  Args:
48
  qdrant_client: Connected Qdrant client
49
  collection_name: Name of the Qdrant collection
50
  full_vector_name: Name of full multi-vector field (default: "initial")
51
  pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
52
+
53
  Example:
54
  >>> retriever = TwoStageRetriever(client, "my_collection")
55
+ >>>
56
  >>> # Two-stage search: prefetch 200, return top 10
57
  >>> results = retriever.search(
58
  ... query_embedding=query,
59
  ... top_k=10,
60
  ... prefetch_k=200,
61
  ... )
62
+ >>>
63
  >>> # Compare latency:
64
  >>> # Full MaxSim (1000 docs): ~500ms
65
  >>> # Two-stage (200→10): ~50ms
66
  """
67
+
68
  def __init__(
69
  self,
70
  qdrant_client,
 
88
  self.retry_sleep = float(retry_sleep)
89
 
90
  def _retry_call(self, fn):
 
 
91
  last_err = None
92
  for attempt in range(self.max_retries):
93
  try:
 
96
  last_err = e
97
  if attempt >= self.max_retries - 1:
98
  break
99
+ time.sleep(self.retry_sleep * (2**attempt))
100
  if last_err is not None:
101
  raise last_err
102
 
 
110
  ) -> List[Dict[str, Any]]:
111
  """
112
  Two-stage retrieval using Qdrant's native prefetch (all server-side).
113
+
114
  This is MUCH faster than search() because it avoids network transfer
115
  of large multi-vector embeddings. All computation happens in Qdrant.
116
+
117
  Args:
118
  query_embedding: Query embeddings [num_tokens, dim]
119
  top_k: Final number of results
120
  prefetch_k: Candidates for stage 1 (default: 10x top_k)
121
  filter_obj: Qdrant filter
122
  stage1_mode: How to do stage 1 prefetch
123
+
124
  Returns:
125
  List of results with scores
126
  """
 
 
127
  query_np = self._to_numpy(query_embedding)
128
+
129
  if prefetch_k is None:
130
  prefetch_k = max(100, top_k * 10)
131
+
132
  if stage1_mode == "pooled_query_vs_tiles":
133
  prefetch_query = query_np.mean(axis=0).tolist()
134
  prefetch_using = self.pooled_vector_name
 
146
  prefetch_using = self.global_vector_name
147
  else:
148
  raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
149
+
150
  rerank_query = query_np.tolist()
151
+
152
  def _do_query():
153
  return self.client.query_points(
154
  collection_name=self.collection_name,
 
157
  limit=top_k,
158
  query_filter=filter_obj,
159
  with_payload=True,
160
+ search_params=qdrant_models.SearchParams(exact=True),
161
  prefetch=[
162
+ qdrant_models.Prefetch(
163
  query=prefetch_query,
164
  using=prefetch_using,
165
  limit=prefetch_k,
 
167
  ],
168
  timeout=self.request_timeout,
169
  ).points
170
+
171
  results = self._retry_call(_do_query)
172
+
173
  return [
174
  {
175
  "id": r.id,
 
180
  }
181
  for r in results
182
  ]
183
+
184
  def search(
185
  self,
186
  query_embedding: Union[torch.Tensor, np.ndarray],
 
193
  ) -> List[Dict[str, Any]]:
194
  """
195
  Two-stage retrieval: prefetch with pooling, rerank with MaxSim.
196
+
197
  Args:
198
  query_embedding: Query embeddings [num_tokens, dim]
199
  top_k: Final number of results to return
 
205
  - "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
206
  - "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
207
  - "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")
208
+
209
  Returns:
210
  List of results with scores and metadata:
211
  [
 
221
  """
222
  # Convert to numpy
223
  query_np = self._to_numpy(query_embedding)
224
+
225
  # Auto-set prefetch_k
226
  if prefetch_k is None:
227
  prefetch_k = max(100, top_k * 10)
228
+
229
  # Stage 1: Prefetch with pooled vectors
230
  logger.info(f"🔍 Stage 1: Prefetching {prefetch_k} candidates ({stage1_mode})")
231
  candidates = self._stage1_prefetch(
 
234
  filter_obj=filter_obj,
235
  stage1_mode=stage1_mode,
236
  )
237
+
238
  if not candidates:
239
  logger.warning("No candidates found in stage 1")
240
  return []
241
+
242
  logger.info(f"✅ Stage 1: Retrieved {len(candidates)} candidates")
243
+
244
  # Stage 2: Rerank with full embeddings
245
  if use_reranking and len(candidates) > top_k:
246
+ logger.info("🎯 Stage 2: Reranking with MaxSim...")
247
  results = self._stage2_rerank(
248
  query_np=query_np,
249
  candidates=candidates,
 
257
  for r in results:
258
  r["score_final"] = r["score_stage1"]
259
  logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
260
+
261
  return results
262
+
263
  def search_single_stage(
264
  self,
265
  query_embedding: Union[torch.Tensor, np.ndarray],
 
269
  ) -> List[Dict[str, Any]]:
270
  """
271
  Single-stage search (either pooled or full multi-vector).
272
+
273
  Args:
274
  query_embedding: Query embeddings
275
  top_k: Number of results
276
  filter_obj: Qdrant filter
277
  use_pooling: Use pooled vectors (faster) or full (more accurate)
278
+
279
  Returns:
280
  List of results
281
  """
282
  query_np = self._to_numpy(query_embedding)
283
+
284
  if use_pooling:
285
  # Pool query and search pooled vectors
286
  query_pooled = query_np.mean(axis=0)
 
292
  vector_name = self.full_vector_name
293
  query_vector = query_np.tolist()
294
  logger.info(f"🎯 Multi-vector search: {vector_name}")
295
+
296
  results = self.client.query_points(
297
  collection_name=self.collection_name,
298
  query=query_vector,
 
303
  with_vectors=False,
304
  timeout=120,
305
  ).points
306
+
307
  return [
308
  {
309
  "id": r.id,
 
313
  }
314
  for r in results
315
  ]
316
+
317
  def _stage1_prefetch(
318
  self,
319
  query_np: np.ndarray,
 
333
  vector_name = self.global_vector_name
334
  else:
335
  raise ValueError(f"Unknown stage1_mode: {stage1_mode}")
336
+
337
  def _do_query():
338
  return self.client.query_points(
339
  collection_name=self.collection_name,
 
347
  ).points
348
 
349
  results = self._retry_call(_do_query)
350
+
351
  return [
352
  {
353
  "id": r.id,
 
356
  }
357
  for r in results
358
  ]
359
+
360
  def _stage2_rerank(
361
  self,
362
  query_np: np.ndarray,
 
365
  return_embeddings: bool = False,
366
  ) -> List[Dict[str, Any]]:
367
  """Stage 2: Rerank with full multi-vector MaxSim scoring."""
 
 
368
  # Fetch full embeddings for candidates
369
  candidate_ids = [c["id"] for c in candidates]
370
+
371
  # Retrieve points with vectors
372
  def _do_retrieve():
373
  return self.client.retrieve(
 
379
  )
380
 
381
  points = self._retry_call(_do_retrieve)
382
+
383
  # Build ID to embedding map
384
  id_to_embedding = {}
385
  for point in points:
 
387
  id_to_embedding[point.id] = np.array(
388
  point.vector[self.full_vector_name], dtype=np.float32
389
  )
390
+
391
  # Compute MaxSim scores
392
  reranked = []
393
  for candidate in candidates:
394
  point_id = candidate["id"]
395
  doc_embedding = id_to_embedding.get(point_id)
396
+
397
  if doc_embedding is None:
398
  # Fallback to stage 1 score
399
  candidate["score_stage2"] = candidate["score_stage1"]
 
403
  maxsim_score = compute_maxsim_score(query_np, doc_embedding)
404
  candidate["score_stage2"] = maxsim_score
405
  candidate["score_final"] = maxsim_score
406
+
407
  if return_embeddings:
408
  candidate["embedding"] = doc_embedding
409
+
410
  reranked.append(candidate)
411
+
412
  # Sort by final score (descending)
413
  reranked.sort(key=lambda x: x["score_final"], reverse=True)
414
+
415
  return reranked[:top_k]
416
+
417
  def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
418
  """Convert embedding to numpy array."""
419
  if isinstance(embedding, torch.Tensor):
 
421
  return embedding.cpu().float().numpy()
422
  return embedding.cpu().numpy()
423
  return np.array(embedding, dtype=np.float32)
424
+
425
  def build_filter(
426
  self,
427
  year: Optional[Any] = None,
 
432
  ):
433
  """
434
  Build Qdrant filter from parameters.
435
+
436
  Supports single values or lists (using MatchAny).
437
  """
 
 
438
  conditions = []
439
+
440
  if year is not None:
441
  if isinstance(year, list):
442
  year_values = [int(y) if isinstance(y, str) else y for y in year]
443
+ conditions.append(FieldCondition(key="year", match=MatchAny(any=year_values)))
 
 
444
  else:
445
  year_value = int(year) if isinstance(year, str) else year
446
+ conditions.append(FieldCondition(key="year", match=MatchValue(value=year_value)))
447
+
 
 
448
  if source is not None:
449
  if isinstance(source, list):
450
+ conditions.append(FieldCondition(key="source", match=MatchAny(any=source)))
 
 
451
  else:
452
+ conditions.append(FieldCondition(key="source", match=MatchValue(value=source)))
453
+
 
 
454
  if district is not None:
455
  if isinstance(district, list):
456
+ conditions.append(FieldCondition(key="district", match=MatchAny(any=district)))
 
 
457
  else:
458
+ conditions.append(FieldCondition(key="district", match=MatchValue(value=district)))
459
+
 
 
460
  if filename is not None:
461
  if isinstance(filename, list):
462
+ conditions.append(FieldCondition(key="filename", match=MatchAny(any=filename)))
 
 
463
  else:
464
+ conditions.append(FieldCondition(key="filename", match=MatchValue(value=filename)))
 
 
 
 
 
 
 
 
 
465
 
466
+ if has_text is not None:
467
+ conditions.append(FieldCondition(key="has_text", match=MatchValue(value=has_text)))
468
 
469
+ return Filter(must=conditions) if conditions else None
visual_rag/visualization/__init__.py CHANGED
@@ -7,8 +7,8 @@ This module provides:
7
  """
8
 
9
  from visual_rag.visualization.saliency import (
10
- generate_saliency_map,
11
  create_saliency_overlay,
 
12
  visualize_search_results,
13
  )
14
 
 
7
  """
8
 
9
  from visual_rag.visualization.saliency import (
 
10
  create_saliency_overlay,
11
+ generate_saliency_map,
12
  visualize_search_results,
13
  )
14
 
visual_rag/visualization/saliency.py CHANGED
@@ -5,10 +5,11 @@ Generates attention/saliency maps to visualize which parts of documents
5
  are most relevant to a query.
6
  """
7
 
8
- import numpy as np
9
- from PIL import Image, ImageDraw, ImageFont
10
- from typing import List, Dict, Any, Optional, Tuple, Union
11
  import logging
 
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -24,9 +25,9 @@ def generate_saliency_map(
24
  ) -> Tuple[Image.Image, np.ndarray]:
25
  """
26
  Generate saliency map showing which parts of the image match the query.
27
-
28
  Computes patch-level relevance scores and overlays them on the image.
29
-
30
  Args:
31
  query_embedding: Query embeddings [num_query_tokens, dim]
32
  doc_embedding: Document visual embeddings [num_visual_tokens, dim]
@@ -35,10 +36,10 @@ def generate_saliency_map(
35
  colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
36
  alpha: Overlay transparency (0-1)
37
  threshold_percentile: Only highlight patches above this percentile
38
-
39
  Returns:
40
  Tuple of (annotated_image, patch_scores)
41
-
42
  Example:
43
  >>> query = embedder.embed_query("budget allocation")
44
  >>> doc = visual_embedding # From embed_images
@@ -51,57 +52,57 @@ def generate_saliency_map(
51
  >>> annotated.save("saliency.png")
52
  """
53
  # Ensure numpy arrays
54
- if hasattr(query_embedding, 'numpy'):
55
  query_np = query_embedding.numpy()
56
- elif hasattr(query_embedding, 'cpu'):
57
  query_np = query_embedding.cpu().numpy()
58
  else:
59
  query_np = np.array(query_embedding, dtype=np.float32)
60
-
61
- if hasattr(doc_embedding, 'numpy'):
62
  doc_np = doc_embedding.numpy()
63
- elif hasattr(doc_embedding, 'cpu'):
64
  doc_np = doc_embedding.cpu().numpy()
65
  else:
66
  doc_np = np.array(doc_embedding, dtype=np.float32)
67
-
68
  # Normalize embeddings
69
  query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
70
  doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)
71
-
72
  # Compute similarity matrix: [num_query, num_doc]
73
  similarity_matrix = np.dot(query_norm, doc_norm.T)
74
-
75
  # Get max similarity per document patch (best match from any query token)
76
  patch_scores = similarity_matrix.max(axis=0)
77
-
78
  # Normalize to [0, 1]
79
  score_min, score_max = patch_scores.min(), patch_scores.max()
80
  if score_max - score_min > 1e-8:
81
  patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
82
  else:
83
  patch_scores_norm = np.zeros_like(patch_scores)
84
-
85
  # Determine grid dimensions
86
  if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
87
  n_rows = token_info["n_rows"]
88
  n_cols = token_info["n_cols"]
89
  num_tiles = n_rows * n_cols + 1 # +1 for global tile
90
  patches_per_tile = 64 # ColSmol standard
91
-
92
  # Reshape to tile grid (excluding global tile)
93
  try:
94
  # Skip global tile patches at the end
95
  tile_patches = num_tiles * patches_per_tile
96
  if len(patch_scores_norm) >= tile_patches:
97
- grid_patches = patch_scores_norm[:n_rows * n_cols * patches_per_tile]
98
  else:
99
  grid_patches = patch_scores_norm
100
-
101
  # Reshape: [tiles * patches_per_tile] -> [tiles, patches_per_tile]
102
  # Then mean per tile
103
  num_grid_tiles = n_rows * n_cols
104
- grid_patches = grid_patches[:num_grid_tiles * patches_per_tile]
105
  tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
106
  tile_scores = tile_scores.reshape(n_rows, n_cols)
107
  except Exception as e:
@@ -110,7 +111,7 @@ def generate_saliency_map(
110
  else:
111
  tile_scores = None
112
  n_rows = n_cols = None
113
-
114
  # Create overlay
115
  annotated = create_saliency_overlay(
116
  image=image,
@@ -121,7 +122,7 @@ def generate_saliency_map(
121
  grid_rows=n_rows,
122
  grid_cols=n_cols,
123
  )
124
-
125
  return annotated, patch_scores
126
 
127
 
@@ -136,7 +137,7 @@ def create_saliency_overlay(
136
  ) -> Image.Image:
137
  """
138
  Create colored overlay on image based on scores.
139
-
140
  Args:
141
  image: Base PIL Image
142
  scores: Score array - 1D [num_patches] or 2D [rows, cols]
@@ -144,7 +145,7 @@ def create_saliency_overlay(
144
  alpha: Overlay transparency
145
  threshold_percentile: Only color patches above this percentile
146
  grid_rows, grid_cols: Grid dimensions (auto-detected if not provided)
147
-
148
  Returns:
149
  Annotated PIL Image
150
  """
@@ -153,10 +154,10 @@ def create_saliency_overlay(
153
  except ImportError:
154
  logger.warning("matplotlib not installed, returning original image")
155
  return image
156
-
157
  img_array = np.array(image)
158
  h, w = img_array.shape[:2]
159
-
160
  # Handle 2D scores (tile grid)
161
  if scores.ndim == 2:
162
  rows, cols = scores.shape
@@ -171,58 +172,58 @@ def create_saliency_overlay(
171
  aspect = w / h
172
  cols = int(np.sqrt(num_patches * aspect))
173
  rows = max(1, num_patches // cols)
174
- scores = scores[:rows * cols].reshape(rows, cols)
175
  else:
176
  # Auto-estimate grid
177
  num_patches = len(scores) if scores.ndim == 1 else scores.size
178
  aspect = w / h
179
  cols = max(1, int(np.sqrt(num_patches * aspect)))
180
  rows = max(1, num_patches // cols)
181
-
182
  if rows * cols > len(scores) if scores.ndim == 1 else scores.size:
183
  cols = max(1, cols - 1)
184
-
185
  if scores.ndim == 1:
186
- scores = scores[:rows * cols].reshape(rows, cols)
187
-
188
  # Get colormap
189
  cmap = plt.cm.get_cmap(colormap)
190
-
191
  # Calculate threshold
192
  threshold = np.percentile(scores, threshold_percentile)
193
-
194
  # Calculate cell dimensions
195
  cell_h = h // rows
196
  cell_w = w // cols
197
-
198
  # Create RGBA overlay
199
  overlay = np.zeros((h, w, 4), dtype=np.uint8)
200
-
201
  for i in range(rows):
202
  for j in range(cols):
203
  score = scores[i, j]
204
-
205
  if score >= threshold:
206
  y1 = i * cell_h
207
  y2 = min((i + 1) * cell_h, h)
208
  x1 = j * cell_w
209
  x2 = min((j + 1) * cell_w, w)
210
-
211
  # Normalize score for coloring (above threshold)
212
  norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
213
  norm_score = min(1.0, max(0.0, norm_score))
214
-
215
  # Get color
216
  color = cmap(norm_score)[:3]
217
  color_uint8 = (np.array(color) * 255).astype(np.uint8)
218
-
219
  overlay[y1:y2, x1:x2, :3] = color_uint8
220
  overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)
221
-
222
  # Blend with original
223
  overlay_img = Image.fromarray(overlay, "RGBA")
224
  result = Image.alpha_composite(image.convert("RGBA"), overlay_img)
225
-
226
  return result.convert("RGB")
227
 
228
 
@@ -237,7 +238,7 @@ def visualize_search_results(
237
  ) -> Optional[Image.Image]:
238
  """
239
  Visualize search results as a grid of images with scores.
240
-
241
  Args:
242
  query: Original query text
243
  results: List of search results with 'payload' containing 'page' (image URL/base64)
@@ -246,7 +247,7 @@ def visualize_search_results(
246
  output_path: Path to save visualization (optional)
247
  max_results: Maximum results to show
248
  show_saliency: Generate saliency overlays (requires query_embedding & embeddings)
249
-
250
  Returns:
251
  Combined visualization image if successful
252
  """
@@ -255,32 +256,32 @@ def visualize_search_results(
255
  except ImportError:
256
  logger.error("matplotlib required for visualization")
257
  return None
258
-
259
  results = results[:max_results]
260
  n = len(results)
261
-
262
  if n == 0:
263
  logger.warning("No results to visualize")
264
  return None
265
-
266
  fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
267
  if n == 1:
268
  axes = [axes]
269
-
270
  for idx, (result, ax) in enumerate(zip(results, axes)):
271
  payload = result.get("payload", {})
272
  score = result.get("score_final", result.get("score_stage1", 0))
273
-
274
  # Try to load image from payload
275
  page_data = payload.get("page", "")
276
  image = None
277
-
278
  if page_data.startswith("data:image"):
279
  # Base64 encoded
280
  try:
281
  import base64
282
  from io import BytesIO
283
-
284
  b64_data = page_data.split(",")[1]
285
  image = Image.open(BytesIO(base64.b64decode(b64_data)))
286
  except Exception as e:
@@ -290,50 +291,45 @@ def visualize_search_results(
290
  try:
291
  import urllib.request
292
  from io import BytesIO
293
-
294
  with urllib.request.urlopen(page_data, timeout=5) as response:
295
  image = Image.open(BytesIO(response.read()))
296
  except Exception as e:
297
  logger.debug(f"Could not fetch image URL: {e}")
298
-
299
  if image:
300
  ax.imshow(image)
301
  else:
302
  # Show placeholder
303
- ax.text(
304
- 0.5, 0.5, "No image",
305
- ha="center", va="center",
306
- fontsize=12, color="gray"
307
- )
308
-
309
  # Add title
310
  title = f"Rank {idx + 1}\nScore: {score:.3f}"
311
  if payload.get("filename"):
312
  title += f"\n{payload['filename'][:30]}"
313
  if payload.get("page_number") is not None:
314
  title += f" p.{payload['page_number'] + 1}"
315
-
316
  ax.set_title(title, fontsize=9)
317
  ax.axis("off")
318
-
319
  # Add query as suptitle
320
  query_display = query[:80] + "..." if len(query) > 80 else query
321
  plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
322
  plt.tight_layout()
323
-
324
  if output_path:
325
  plt.savefig(output_path, dpi=150, bbox_inches="tight")
326
  logger.info(f"💾 Saved visualization to: {output_path}")
327
-
328
  # Convert to PIL Image for return
329
  from io import BytesIO
 
330
  buf = BytesIO()
331
  plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
332
  buf.seek(0)
333
  result_image = Image.open(buf)
334
-
335
- plt.close()
336
-
337
- return result_image
338
 
 
339
 
 
 
5
  are most relevant to a query.
6
  """
7
 
 
 
 
8
  import logging
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ from PIL import Image
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
25
  ) -> Tuple[Image.Image, np.ndarray]:
26
  """
27
  Generate saliency map showing which parts of the image match the query.
28
+
29
  Computes patch-level relevance scores and overlays them on the image.
30
+
31
  Args:
32
  query_embedding: Query embeddings [num_query_tokens, dim]
33
  doc_embedding: Document visual embeddings [num_visual_tokens, dim]
 
36
  colormap: Matplotlib colormap name (Reds, viridis, jet, etc.)
37
  alpha: Overlay transparency (0-1)
38
  threshold_percentile: Only highlight patches above this percentile
39
+
40
  Returns:
41
  Tuple of (annotated_image, patch_scores)
42
+
43
  Example:
44
  >>> query = embedder.embed_query("budget allocation")
45
  >>> doc = visual_embedding # From embed_images
 
52
  >>> annotated.save("saliency.png")
53
  """
54
  # Ensure numpy arrays
55
+ if hasattr(query_embedding, "numpy"):
56
  query_np = query_embedding.numpy()
57
+ elif hasattr(query_embedding, "cpu"):
58
  query_np = query_embedding.cpu().numpy()
59
  else:
60
  query_np = np.array(query_embedding, dtype=np.float32)
61
+
62
+ if hasattr(doc_embedding, "numpy"):
63
  doc_np = doc_embedding.numpy()
64
+ elif hasattr(doc_embedding, "cpu"):
65
  doc_np = doc_embedding.cpu().numpy()
66
  else:
67
  doc_np = np.array(doc_embedding, dtype=np.float32)
68
+
69
  # Normalize embeddings
70
  query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
71
  doc_norm = doc_np / (np.linalg.norm(doc_np, axis=1, keepdims=True) + 1e-8)
72
+
73
  # Compute similarity matrix: [num_query, num_doc]
74
  similarity_matrix = np.dot(query_norm, doc_norm.T)
75
+
76
  # Get max similarity per document patch (best match from any query token)
77
  patch_scores = similarity_matrix.max(axis=0)
78
+
79
  # Normalize to [0, 1]
80
  score_min, score_max = patch_scores.min(), patch_scores.max()
81
  if score_max - score_min > 1e-8:
82
  patch_scores_norm = (patch_scores - score_min) / (score_max - score_min)
83
  else:
84
  patch_scores_norm = np.zeros_like(patch_scores)
85
+
86
  # Determine grid dimensions
87
  if token_info and token_info.get("n_rows") and token_info.get("n_cols"):
88
  n_rows = token_info["n_rows"]
89
  n_cols = token_info["n_cols"]
90
  num_tiles = n_rows * n_cols + 1 # +1 for global tile
91
  patches_per_tile = 64 # ColSmol standard
92
+
93
  # Reshape to tile grid (excluding global tile)
94
  try:
95
  # Skip global tile patches at the end
96
  tile_patches = num_tiles * patches_per_tile
97
  if len(patch_scores_norm) >= tile_patches:
98
+ grid_patches = patch_scores_norm[: n_rows * n_cols * patches_per_tile]
99
  else:
100
  grid_patches = patch_scores_norm
101
+
102
  # Reshape: [tiles * patches_per_tile] -> [tiles, patches_per_tile]
103
  # Then mean per tile
104
  num_grid_tiles = n_rows * n_cols
105
+ grid_patches = grid_patches[: num_grid_tiles * patches_per_tile]
106
  tile_scores = grid_patches.reshape(num_grid_tiles, patches_per_tile).mean(axis=1)
107
  tile_scores = tile_scores.reshape(n_rows, n_cols)
108
  except Exception as e:
 
111
  else:
112
  tile_scores = None
113
  n_rows = n_cols = None
114
+
115
  # Create overlay
116
  annotated = create_saliency_overlay(
117
  image=image,
 
122
  grid_rows=n_rows,
123
  grid_cols=n_cols,
124
  )
125
+
126
  return annotated, patch_scores
127
 
128
 
 
137
  ) -> Image.Image:
138
  """
139
  Create colored overlay on image based on scores.
140
+
141
  Args:
142
  image: Base PIL Image
143
  scores: Score array - 1D [num_patches] or 2D [rows, cols]
 
145
  alpha: Overlay transparency
146
  threshold_percentile: Only color patches above this percentile
147
  grid_rows, grid_cols: Grid dimensions (auto-detected if not provided)
148
+
149
  Returns:
150
  Annotated PIL Image
151
  """
 
154
  except ImportError:
155
  logger.warning("matplotlib not installed, returning original image")
156
  return image
157
+
158
  img_array = np.array(image)
159
  h, w = img_array.shape[:2]
160
+
161
  # Handle 2D scores (tile grid)
162
  if scores.ndim == 2:
163
  rows, cols = scores.shape
 
172
  aspect = w / h
173
  cols = int(np.sqrt(num_patches * aspect))
174
  rows = max(1, num_patches // cols)
175
+ scores = scores[: rows * cols].reshape(rows, cols)
176
  else:
177
  # Auto-estimate grid
178
  num_patches = len(scores) if scores.ndim == 1 else scores.size
179
  aspect = w / h
180
  cols = max(1, int(np.sqrt(num_patches * aspect)))
181
  rows = max(1, num_patches // cols)
182
+
183
  if rows * cols > len(scores) if scores.ndim == 1 else scores.size:
184
  cols = max(1, cols - 1)
185
+
186
  if scores.ndim == 1:
187
+ scores = scores[: rows * cols].reshape(rows, cols)
188
+
189
  # Get colormap
190
  cmap = plt.cm.get_cmap(colormap)
191
+
192
  # Calculate threshold
193
  threshold = np.percentile(scores, threshold_percentile)
194
+
195
  # Calculate cell dimensions
196
  cell_h = h // rows
197
  cell_w = w // cols
198
+
199
  # Create RGBA overlay
200
  overlay = np.zeros((h, w, 4), dtype=np.uint8)
201
+
202
  for i in range(rows):
203
  for j in range(cols):
204
  score = scores[i, j]
205
+
206
  if score >= threshold:
207
  y1 = i * cell_h
208
  y2 = min((i + 1) * cell_h, h)
209
  x1 = j * cell_w
210
  x2 = min((j + 1) * cell_w, w)
211
+
212
  # Normalize score for coloring (above threshold)
213
  norm_score = (score - threshold) / (1.0 - threshold + 1e-8)
214
  norm_score = min(1.0, max(0.0, norm_score))
215
+
216
  # Get color
217
  color = cmap(norm_score)[:3]
218
  color_uint8 = (np.array(color) * 255).astype(np.uint8)
219
+
220
  overlay[y1:y2, x1:x2, :3] = color_uint8
221
  overlay[y1:y2, x1:x2, 3] = int(alpha * 255 * norm_score)
222
+
223
  # Blend with original
224
  overlay_img = Image.fromarray(overlay, "RGBA")
225
  result = Image.alpha_composite(image.convert("RGBA"), overlay_img)
226
+
227
  return result.convert("RGB")
228
 
229
 
 
238
  ) -> Optional[Image.Image]:
239
  """
240
  Visualize search results as a grid of images with scores.
241
+
242
  Args:
243
  query: Original query text
244
  results: List of search results with 'payload' containing 'page' (image URL/base64)
 
247
  output_path: Path to save visualization (optional)
248
  max_results: Maximum results to show
249
  show_saliency: Generate saliency overlays (requires query_embedding & embeddings)
250
+
251
  Returns:
252
  Combined visualization image if successful
253
  """
 
256
  except ImportError:
257
  logger.error("matplotlib required for visualization")
258
  return None
259
+
260
  results = results[:max_results]
261
  n = len(results)
262
+
263
  if n == 0:
264
  logger.warning("No results to visualize")
265
  return None
266
+
267
  fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
268
  if n == 1:
269
  axes = [axes]
270
+
271
  for idx, (result, ax) in enumerate(zip(results, axes)):
272
  payload = result.get("payload", {})
273
  score = result.get("score_final", result.get("score_stage1", 0))
274
+
275
  # Try to load image from payload
276
  page_data = payload.get("page", "")
277
  image = None
278
+
279
  if page_data.startswith("data:image"):
280
  # Base64 encoded
281
  try:
282
  import base64
283
  from io import BytesIO
284
+
285
  b64_data = page_data.split(",")[1]
286
  image = Image.open(BytesIO(base64.b64decode(b64_data)))
287
  except Exception as e:
 
291
  try:
292
  import urllib.request
293
  from io import BytesIO
294
+
295
  with urllib.request.urlopen(page_data, timeout=5) as response:
296
  image = Image.open(BytesIO(response.read()))
297
  except Exception as e:
298
  logger.debug(f"Could not fetch image URL: {e}")
299
+
300
  if image:
301
  ax.imshow(image)
302
  else:
303
  # Show placeholder
304
+ ax.text(0.5, 0.5, "No image", ha="center", va="center", fontsize=12, color="gray")
305
+
 
 
 
 
306
  # Add title
307
  title = f"Rank {idx + 1}\nScore: {score:.3f}"
308
  if payload.get("filename"):
309
  title += f"\n{payload['filename'][:30]}"
310
  if payload.get("page_number") is not None:
311
  title += f" p.{payload['page_number'] + 1}"
312
+
313
  ax.set_title(title, fontsize=9)
314
  ax.axis("off")
315
+
316
  # Add query as suptitle
317
  query_display = query[:80] + "..." if len(query) > 80 else query
318
  plt.suptitle(f"Query: {query_display}", fontsize=11, fontweight="bold")
319
  plt.tight_layout()
320
+
321
  if output_path:
322
  plt.savefig(output_path, dpi=150, bbox_inches="tight")
323
  logger.info(f"💾 Saved visualization to: {output_path}")
324
+
325
  # Convert to PIL Image for return
326
  from io import BytesIO
327
+
328
  buf = BytesIO()
329
  plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
330
  buf.seek(0)
331
  result_image = Image.open(buf)
 
 
 
 
332
 
333
+ plt.close()
334
 
335
+ return result_image