akryldigital committed · Commit 150fb2f · verified · 1 Parent(s): c0655b8

add colpali scripts

src/colpali/__init__.py ADDED
@@ -0,0 +1,29 @@
+ """
+ ColPali Visual Document Retrieval Module
+
+ This module implements visual document retrieval using ColPali (ColBERT-style multi-vector embeddings)
+ for processing PDF documents as images.
+
+ All components are self-contained within src/colpali/ - no external dependencies on colpali_colab_package.
+ """
+
+ # Core inference components
+ from .processor import ColPaliProcessor
+ from .search import VisualDocumentSearch
+ from .visual_search import VisualSearchAdapter, VisualSearchResult, create_visual_search_adapter
+
+ # Upload/management components (for data ingestion)
+ from .qdrant_manager import ColPaliQdrantManager
+ from .visualizer import generate_saliency_maps
+
+ __all__ = [
+     # Inference
+     "ColPaliProcessor",
+     "VisualDocumentSearch",
+     "VisualSearchAdapter",
+     "VisualSearchResult",
+     "create_visual_search_adapter",
+     # Data management
+     "ColPaliQdrantManager",
+     "generate_saliency_maps",
+ ]
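
For orientation, a minimal sketch of how the exports above fit together; the URL, API key, and query are placeholders, and `src.colpali` is assumed to be importable from the project root:

```python
# Hypothetical wiring of the public API above; credentials are placeholders.
from src.colpali import VisualSearchAdapter

adapter = VisualSearchAdapter(
    qdrant_url="https://example.qdrant.cloud",  # placeholder
    qdrant_api_key="YOUR_API_KEY",              # placeholder
    collection_name="colSmol-500M",
)
results = adapter.search("budget allocation by district", top_k=5)
for r in results:
    print(r.id, r.score, r.payload.get("filename"))
```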
src/colpali/processor.py ADDED
@@ -0,0 +1,120 @@
+ """
+ ColPali Query Embedding Processor
+
+ Handles query embedding generation using the ColSmol-500M model.
+ This is a standalone implementation for inference only (no PDF processing).
+ """
+
+ import logging
+ from typing import Optional
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+ # Check if colpali_engine is available
+ try:
+     from colpali_engine.models import ColIdefics3, ColIdefics3Processor
+     COLPALI_AVAILABLE = True
+ except ImportError:
+     COLPALI_AVAILABLE = False
+     logger.warning("colpali_engine not installed. Install with: pip install colpali-engine")
+
+
+ class ColPaliProcessor:
+     """
+     Processes queries using ColPali for visual document retrieval.
+
+     This is a lightweight processor focused on query embedding generation.
+     """
+
+     def __init__(
+         self,
+         model_name: str = "vidore/colSmol-500M",
+         device: str = "cpu",
+         torch_dtype: torch.dtype = torch.float32,
+         batch_size: int = 4
+     ):
+         """
+         Initialize the ColPali processor.
+
+         Args:
+             model_name: HuggingFace model name for ColPali
+             device: Device to use ("cuda", "cpu", "mps")
+             torch_dtype: Data type for model weights
+             batch_size: Batch size for processing
+         """
+         if not COLPALI_AVAILABLE:
+             raise ImportError(
+                 "colpali_engine not installed. Install with: "
+                 "pip install colpali-engine"
+             )
+
+         # Validate model name (must include organization prefix)
+         if '/' not in model_name:
+             logger.warning(f"⚠️ Model name '{model_name}' is missing an organization prefix, adding 'vidore/'")
+             model_name = f"vidore/{model_name}"
+
+         self.model_name = model_name
+         self.device = device
+         self.torch_dtype = torch_dtype
+         self.batch_size = batch_size
+
+         logger.info(f"🤖 Loading ColPali model: {model_name}")
+         logger.info(f"   Device: {device}, dtype: {torch_dtype}")
+
+         # Load model and processor
+         try:
+             # Determine attention implementation
+             attn_implementation = "eager"  # Default for compatibility
+
+             if device != "cpu":
+                 try:
+                     import flash_attn
+                     attn_implementation = "flash_attention_2"
+                     logger.info("   Using FlashAttention2 for faster inference")
+                 except ImportError:
+                     logger.info("   FlashAttention2 not available, using eager attention")
+
+             self.model = ColIdefics3.from_pretrained(
+                 model_name,
+                 dtype=torch_dtype,
+                 device_map=device,
+                 attn_implementation=attn_implementation
+             ).eval()
+
+             self.processor = ColIdefics3Processor.from_pretrained(model_name)
+
+             logger.info("✅ ColPali model loaded successfully")
+             logger.info(f"   Attention implementation: {attn_implementation}")
+
+         except Exception as e:
+             logger.error(f"❌ Failed to load ColPali model: {e}")
+             raise
+
+     def embed_query(self, query_text: str) -> torch.Tensor:
+         """
+         Generate an embedding for a text query.
+
+         Args:
+             query_text: Natural language query string
+
+         Returns:
+             Query embedding tensor of shape [num_patches, embedding_dim]
+         """
+         with torch.no_grad():
+             # Process the query using ColPali's query processing
+             processed_query = self.processor.process_queries([query_text]).to(self.model.device)
+             query_embedding = self.model(**processed_query)
+
+         return query_embedding
+
+     @property
+     def embedding_dim(self) -> int:
+         """Get the embedding dimension of the model."""
+         return self.model.config.hidden_size
+
+     @property
+     def image_token_id(self) -> int:
+         """Get the image token ID from the processor."""
+         return self.processor.image_token_id
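
As a rough sketch of what `embed_query` returns (assuming `colpali-engine` is installed and the model weights can be downloaded; the exact shape depends on the model and the tokenized query length):

```python
# Sketch: one multi-vector embedding per query, ColBERT-style.
from src.colpali.processor import ColPaliProcessor

processor = ColPaliProcessor(device="cpu")  # float32 on CPU by default
emb = processor.embed_query("solar panel installations in 2022")
# The model returns one vector per query token/patch; a leading batch
# dimension of 1 may be present depending on the colpali_engine version.
print(emb.shape)
```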
src/colpali/search.py ADDED
@@ -0,0 +1,494 @@
+ """
+ Visual Document Search Engine
+
+ Two-stage visual document retrieval:
+ 1. Fast prefetch using pooled vectors (mean/max with HNSW)
+ 2. Exact reranking using full multi-vector embeddings (ColBERT-style)
+ """
+
+ import logging
+ from typing import List, Dict, Any, Optional
+ import numpy as np
+ import torch
+ from qdrant_client import QdrantClient
+ from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny, Range
+
+ logger = logging.getLogger(__name__)
+
+
+ class VisualDocumentSearch:
+     """
+     Two-stage visual document retrieval:
+     - Stage 1: Fast HNSW search with pooled vectors (10-100ms)
+     - Stage 2: Exact ColBERT reranking with full embeddings (100-500ms)
+     """
+
+     def __init__(
+         self,
+         qdrant_client: QdrantClient,
+         collection_name: str = "colSmol-500M"
+     ):
+         """
+         Initialize the search engine.
+
+         Args:
+             qdrant_client: Connected Qdrant client
+             collection_name: Name of the collection
+         """
+         self.client = qdrant_client
+         self.collection_name = collection_name
+
+     def get_filter_options(
+         self,
+         max_points: Optional[int] = None,
+         use_cache: bool = True,
+         progress_callback=None
+     ) -> Dict[str, List[Any]]:
+         """
+         Scan the collection for all possible filter values using iterative scrolling.
+
+         Args:
+             max_points: Maximum number of points to scan (None = scan all)
+             use_cache: Reserved for result caching (not yet implemented)
+             progress_callback: Optional callback function(points_scanned, elapsed_time, iteration)
+
+         Returns:
+             Dictionary with all unique values for each filterable field
+         """
+         scan_limit = max_points if max_points else "all"
+         logger.info(f"🔍 Starting metadata scan (target: {scan_limit} points)")
+         logger.info(f"   Collection: {self.collection_name}")
+
+         # Scroll through points to collect unique values
+         years = set()
+         sources = set()
+         districts = set()
+         filenames = set()
+
+         batch_size = 900
+         points_scanned = 0
+         offset = None
+         iteration = 0
+         max_iterations = 100
+
+         import time
+         start_time = time.time()
+
+         try:
+             while True:
+                 iteration += 1
+
+                 if iteration > max_iterations:
+                     logger.warning(f"⚠️ Reached max iterations ({max_iterations}), stopping")
+                     break
+
+                 if max_points and points_scanned >= max_points:
+                     logger.info(f"✅ Reached target of {max_points} points")
+                     break
+
+                 if max_points:
+                     remaining = max_points - points_scanned
+                     current_batch_size = min(batch_size, remaining)
+                 else:
+                     current_batch_size = batch_size
+
+                 elapsed = time.time() - start_time
+                 logger.info(f"   Batch {iteration}: fetching {current_batch_size} points (scanned: {points_scanned}, {elapsed:.1f}s)")
+
+                 batch_start = time.time()
+                 try:
+                     results = self.client.scroll(
+                         collection_name=self.collection_name,
+                         limit=current_batch_size,
+                         offset=offset,
+                         with_payload=True,
+                         with_vectors=False,
+                     )
+
+                     points, next_offset = results
+                     batch_time = time.time() - batch_start
+                     logger.info(f"   ✓ Fetched {len(points)} points in {batch_time:.2f}s")
+
+                 except Exception as scroll_error:
+                     logger.error(f"❌ Scroll failed at iteration {iteration}: {scroll_error}")
+                     break
+
+                 if not points:
+                     logger.info(f"✅ Reached end of collection (scanned {points_scanned} points)")
+                     break
+
+                 for point in points:
+                     payload = point.payload
+
+                     if payload.get('year'):
+                         year_value = payload['year']
+                         if isinstance(year_value, str):
+                             try:
+                                 year_value = int(year_value)
+                             except ValueError:
+                                 year_value = None  # Unparseable year; still collect the other fields
+                         if isinstance(year_value, int):
+                             years.add(year_value)
+
+                     if payload.get('source'):
+                         sources.add(payload['source'])
+                     if payload.get('district'):
+                         districts.add(payload['district'])
+                     if payload.get('filename'):
+                         filenames.add(payload['filename'])
+
+                 points_scanned += len(points)
+                 offset = next_offset
+
+                 if progress_callback:
+                     elapsed = time.time() - start_time
+                     progress_callback(points_scanned, elapsed, iteration)
+
+                 if offset is None:
+                     elapsed = time.time() - start_time
+                     logger.info(f"✅ Completed full scan: {points_scanned} points in {elapsed:.1f}s")
+                     break
+
+             elapsed = time.time() - start_time
+             logger.info(f"✅ Scan complete: {points_scanned} points in {elapsed:.1f}s")
+             logger.info(f"   Found: {len(years)} years, {len(sources)} sources, "
+                         f"{len(districts)} districts, {len(filenames)} files")
+
+         except Exception as e:
+             logger.error(f"❌ Error scanning collection: {e}")
+
+         return {
+             'years': sorted(years),
+             'sources': sorted(sources),
+             'districts': sorted(districts),
+             'filenames': sorted(filenames)
+         }
+
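
A small usage sketch for the scan above, with a progress callback; the client credentials and collection name are placeholders:

```python
# Sketch: scan at most 5,000 points and report progress per batch.
from qdrant_client import QdrantClient
from src.colpali.search import VisualDocumentSearch

client = QdrantClient(url="https://example.qdrant.cloud", api_key="YOUR_API_KEY")  # placeholders
search = VisualDocumentSearch(client, collection_name="colSmol-500M")

def report(points_scanned: int, elapsed: float, iteration: int) -> None:
    print(f"batch {iteration}: {points_scanned} points in {elapsed:.1f}s")

options = search.get_filter_options(max_points=5000, progress_callback=report)
print(options["years"])  # e.g. [2019, 2020, 2021]
```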
+     def build_filter(
+         self,
+         year: Optional[Any] = None,
+         source: Optional[Any] = None,
+         district: Optional[Any] = None,
+         filename: Optional[Any] = None,
+         has_text: Optional[bool] = None,
+         page_range: Optional[tuple] = None
+     ) -> Optional[Filter]:
+         """
+         Build a Qdrant filter from parameters.
+
+         Supports both single values and lists (using MatchAny for lists).
+         """
+         conditions = []
+
+         if year is not None:
+             if isinstance(year, list):
+                 year_values = [int(y) if isinstance(y, str) else y for y in year]
+                 conditions.append(
+                     FieldCondition(key="year", match=MatchAny(any=year_values))
+                 )
+                 logger.info(f"🔍 Filter: year IN {year_values}")
+             else:
+                 year_value = int(year) if isinstance(year, str) else year
+                 conditions.append(
+                     FieldCondition(key="year", match=MatchValue(value=year_value))
+                 )
+                 logger.info(f"🔍 Filter: year = {year_value}")
+
+         if source is not None:
+             if isinstance(source, list):
+                 conditions.append(
+                     FieldCondition(key="source", match=MatchAny(any=source))
+                 )
+                 logger.info(f"🔍 Filter: source IN {source}")
+             else:
+                 conditions.append(
+                     FieldCondition(key="source", match=MatchValue(value=source))
+                 )
+                 logger.info(f"🔍 Filter: source = {source}")
+
+         if district is not None:
+             if isinstance(district, list):
+                 conditions.append(
+                     FieldCondition(key="district", match=MatchAny(any=district))
+                 )
+                 logger.info(f"🔍 Filter: district IN {district}")
+             else:
+                 conditions.append(
+                     FieldCondition(key="district", match=MatchValue(value=district))
+                 )
+                 logger.info(f"🔍 Filter: district = {district}")
+
+         if filename is not None:
+             if isinstance(filename, list):
+                 conditions.append(
+                     FieldCondition(key="filename", match=MatchAny(any=filename))
+                 )
+                 logger.info(f"🔍 Filter: filename IN {filename}")
+             else:
+                 conditions.append(
+                     FieldCondition(key="filename", match=MatchValue(value=filename))
+                 )
+                 logger.info(f"🔍 Filter: filename = {filename}")
+
+         if has_text is not None:
+             conditions.append(
+                 FieldCondition(key="has_text", match=MatchValue(value=has_text))
+             )
+
+         if page_range is not None:
+             min_page, max_page = page_range
+             conditions.append(
+                 FieldCondition(
+                     key="page_number",
+                     range=Range(gte=min_page, lte=max_page)
+                 )
+             )
+
+         if not conditions:
+             return None
+
+         return Filter(must=conditions)
+
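
To illustrate the scalar-vs-list handling, a sketch continuing the `search` instance from the previous example (the field values are illustrative):

```python
# Lists become MatchAny, scalars become MatchValue, page_range becomes Range.
f = search.build_filter(
    year=[2021, 2022],       # list   -> MatchAny(any=[2021, 2022])
    source="annual_report",  # scalar -> MatchValue(value="annual_report")
    page_range=(1, 10),      # inclusive page_number Range(gte=1, lte=10)
)
# f is a qdrant_client Filter with must=[...], or None if no arguments are set.
```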
+     def search_stage1_prefetch(
+         self,
+         query_embedding: torch.Tensor,
+         top_k: int = 100,
+         filter_obj: Optional[Filter] = None,
+         use_pooling: bool = False,
+         pooling_method: str = "mean"
+     ) -> List[Dict[str, Any]]:
+         """
+         Stage 1: Prefetch candidates using either multi-vector or pooled search.
+         """
+         # Convert to numpy
+         if isinstance(query_embedding, torch.Tensor):
+             query_np = query_embedding.cpu().float().numpy()
+         else:
+             query_np = np.array(query_embedding, dtype=np.float32)
+
+         # Handle batch dimension
+         if query_np.ndim == 3:
+             query_np = query_np.squeeze(0)
+
+         # Strategy 1: Pooled search (fast, approximate)
+         if use_pooling:
+             if pooling_method == "mean":
+                 query_pooled = query_np.mean(axis=0)
+                 vector_name = "mean_pooling"
+             elif pooling_method == "max":
+                 query_pooled = query_np.max(axis=0)
+                 vector_name = "max_pooling"
+             else:
+                 raise ValueError(f"Unknown pooling method: {pooling_method}")
+
+             if query_pooled.ndim != 1:
+                 raise ValueError(f"Pooling failed! Expected a 1D vector, got shape {query_pooled.shape}")
+
+             query_vector = query_pooled.tolist()
+             logger.info(f"🔍 Pooled search: vector={vector_name}, dims={len(query_vector)}")
+
+         # Strategy 2: Native multi-vector search (SOTA)
+         else:
+             vector_name = "initial"
+             query_vector = query_np.tolist()
+             logger.info(f"🎯 Multi-vector search: vector={vector_name}, patches={len(query_vector)}, dims={len(query_vector[0])}")
+
+         try:
+             results = self.client.query_points(
+                 collection_name=self.collection_name,
+                 query=query_vector,
+                 using=vector_name,
+                 query_filter=filter_obj,
+                 limit=top_k,
+                 with_payload=True,
+                 with_vectors=False,
+                 timeout=120
+             ).points
+
+             logger.info(f"✅ Stage 1: Retrieved {len(results)} candidates")
+
+         except Exception as e:
+             logger.error(f"❌ Search with vector '{vector_name}' failed: {e}")
+             raise
+
+         candidates = []
+         for result in results:
+             candidates.append({
+                 'id': result.id,
+                 'score_stage1': result.score,
+                 'payload': result.payload
+             })
+
+         return candidates
+
+     def colbert_score(
+         self,
+         query_embedding: np.ndarray,
+         doc_embedding: np.ndarray
+     ) -> float:
+         """
+         Compute the ColBERT-style late-interaction score.
+         """
+         # Normalize embeddings
+         query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8)
+         doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
+
+         # Compute the similarity matrix
+         sim_matrix = np.dot(query_norm, doc_norm.T)
+
+         # For each query patch, take the max similarity with any doc patch
+         max_sims = sim_matrix.max(axis=1)
+
+         # Average across query patches
+         score = max_sims.mean()
+
+         return float(score)
+
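
A worked toy example of the MaxSim computation above, with unit-norm vectors so the dot products are already cosines:

```python
import numpy as np

q = np.array([[1.0, 0.0],          # 2 query vectors, dim 2
              [0.0, 1.0]], dtype=np.float32)
d = np.array([[1.0, 0.0],          # 3 document vectors, dim 2
              [0.6, 0.8],
              [0.0, -1.0]], dtype=np.float32)

sim = q @ d.T                      # [[1.0, 0.6, 0.0], [0.0, 0.8, -1.0]]
max_sims = sim.max(axis=1)         # best doc match per query vector: [1.0, 0.8]
print(max_sims.mean())             # late-interaction score: 0.9
```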
+     def search_stage2_rerank(
+         self,
+         query_embedding: torch.Tensor,
+         candidates: List[Dict[str, Any]],
+         top_k: int = 10
+     ) -> List[Dict[str, Any]]:
+         """
+         Stage 2: Exact reranking using full multi-vector embeddings.
+         """
+         if isinstance(query_embedding, torch.Tensor):
+             query_np = query_embedding.cpu().float().numpy()
+         else:
+             query_np = np.array(query_embedding, dtype=np.float32)
+
+         reranked = []
+         for candidate in candidates:
+             payload = candidate['payload']
+
+             full_embedding = payload.get('full_embedding')
+             if full_embedding is None:
+                 candidate['score_final'] = candidate['score_stage1']
+                 reranked.append(candidate)
+                 continue
+
+             doc_np = np.array(full_embedding, dtype=np.float32)
+             colbert_score = self.colbert_score(query_np, doc_np)
+
+             candidate['score_stage2'] = colbert_score
+             candidate['score_final'] = colbert_score
+             reranked.append(candidate)
+
+         reranked.sort(key=lambda x: x['score_final'], reverse=True)
+
+         return reranked[:top_k]
+
+     def search(
+         self,
+         query_embedding: torch.Tensor,
+         top_k: int = 10,
+         prefetch_k: Optional[int] = None,
+         year: Optional[int] = None,
+         source: Optional[str] = None,
+         district: Optional[str] = None,
+         filename: Optional[str] = None,
+         has_text: Optional[bool] = None,
+         page_range: Optional[tuple] = None,
+         search_strategy: str = "multi_vector",
+         pooling_method: str = "mean",
+         use_reranking: bool = False
+     ) -> List[Dict[str, Any]]:
+         """
+         Multi-strategy visual document search.
+
+         Search strategies:
+         1. "multi_vector" (DEFAULT, SOTA): native multi-vector search
+         2. "pooled": pooled search (fastest, less accurate)
+         3. "hybrid": two-stage retrieval with reranking
+         """
+         # Build filter
+         filter_obj = self.build_filter(
+             year=year,
+             source=source,
+             district=district,
+             filename=filename,
+             has_text=has_text,
+             page_range=page_range
+         )
+
+         # Strategy 1: Native multi-vector search (SOTA, default)
+         if search_strategy == "multi_vector":
+             logger.info("🎯 SOTA Multi-Vector Search: Querying 'initial' vector with native MaxSim")
+             candidates = self.search_stage1_prefetch(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 filter_obj=filter_obj,
+                 use_pooling=False
+             )
+
+             if not candidates:
+                 logger.warning("❌ No results found")
+                 return []
+
+             for c in candidates:
+                 c['score_final'] = c['score_stage1']
+
+             logger.info(f"✅ Retrieved {len(candidates)} results (native MaxSim)")
+             return candidates
+
+         # Strategy 2: Pooled search (fast, approximate)
+         elif search_strategy == "pooled":
+             logger.info(f"🔍 Pooled Search: Querying '{pooling_method}_pooling' vector")
+             candidates = self.search_stage1_prefetch(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 filter_obj=filter_obj,
+                 use_pooling=True,
+                 pooling_method=pooling_method
+             )
+
+             if not candidates:
+                 logger.warning("❌ No results found")
+                 return []
+
+             for c in candidates:
+                 c['score_final'] = c['score_stage1']
+
+             logger.info(f"✅ Retrieved {len(candidates)} results (pooled)")
+             return candidates
+
+         # Strategy 3: Hybrid two-stage
+         elif search_strategy == "hybrid":
+             if prefetch_k is None:
+                 prefetch_k = max(100, top_k * 10)
+
+             logger.info(f"🔄 Hybrid Search: Stage 1 - Prefetching {prefetch_k} with {pooling_method} pooling")
+             candidates = self.search_stage1_prefetch(
+                 query_embedding=query_embedding,
+                 top_k=prefetch_k,
+                 filter_obj=filter_obj,
+                 use_pooling=True,
+                 pooling_method=pooling_method
+             )
+
+             if not candidates:
+                 logger.warning("❌ No results found in stage 1")
+                 return []
+
+             logger.info(f"✅ Stage 1: Found {len(candidates)} candidates")
+
+             if use_reranking and len(candidates) > top_k:
+                 logger.info("🎯 Stage 2: Reranking with ColBERT scoring...")
+                 results = self.search_stage2_rerank(
+                     query_embedding=query_embedding,
+                     candidates=candidates,
+                     top_k=top_k
+                 )
+                 logger.info(f"✅ Reranked to top {len(results)} results")
+                 return results
+             else:
+                 results = candidates[:top_k]
+                 for r in results:
+                     r['score_final'] = r['score_stage1']
+                 logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
+                 return results
+
+         else:
+             raise ValueError(f"Unknown search_strategy: {search_strategy}")
+
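
A sketch of the three strategies side by side, reusing the `processor` and `search` objects from the earlier sketches; note the hybrid path only changes the ranking when points store a `full_embedding` payload:

```python
query_emb = processor.embed_query("flood damage reports")

exact = search.search(query_emb, top_k=5)  # "multi_vector" (default)
fast = search.search(query_emb, top_k=5, search_strategy="pooled",
                     pooling_method="mean")
hybrid = search.search(query_emb, top_k=5, search_strategy="hybrid",
                       prefetch_k=100, use_reranking=True)

for hit in exact:
    print(hit["payload"].get("filename"), hit["score_final"])
```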
src/colpali/visual_search.py ADDED
@@ -0,0 +1,237 @@
+ """
+ Visual Document Search Adapter for the Main App
+
+ This module provides an adapter to integrate ColPali visual search
+ into the main app's retrieval pipeline.
+
+ All dependencies are now within src/colpali/ - no external colpali_colab_package needed.
+ """
+
+ import logging
+ from typing import List, Dict, Any, Optional
+ import numpy as np
+ import torch
+ from qdrant_client import QdrantClient
+
+ # Import from local src/colpali modules (no external dependencies)
+ from src.colpali.processor import ColPaliProcessor
+ from src.colpali.search import VisualDocumentSearch
+
+ # Import device detection utility
+ from src.utils import get_device_for_colpali
+
+ logger = logging.getLogger(__name__)
+
+
+ class VisualSearchResult:
+     """
+     Wrapper for visual search results matching the interface expected by app.py.
+     """
+     def __init__(self, point_id: str, score: float, payload: Dict[str, Any]):
+         self.id = point_id
+         self.score = score
+         self.payload = payload
+         self.metadata = payload  # Alias for compatibility
+
+         # Extract content for compatibility with the Document interface
+         self.page_content = payload.get('text', '')
+         self.content = self.page_content
+
+     def __repr__(self):
+         return f"VisualSearchResult(id={self.id}, score={self.score:.4f})"
+
+
+ class VisualSearchAdapter:
+     """
+     Adapter to integrate ColPali visual search into the main app.
+
+     This provides a unified interface for visual document retrieval that works
+     with the existing chatbot architecture.
+     """
+
+     def __init__(
+         self,
+         qdrant_url: str,
+         qdrant_api_key: str,
+         collection_name: str = "colSmol-500M",
+         model_name: str = "vidore/colSmol-500M",
+         device: Optional[str] = None,
+         batch_size: int = 4
+     ):
+         """
+         Initialize the visual search adapter.
+
+         Args:
+             qdrant_url: Qdrant cluster URL
+             qdrant_api_key: Qdrant API key
+             collection_name: Name of the collection with visual embeddings
+             model_name: ColPali model name
+             device: Device to use (cuda/cpu/mps, auto-detected if None)
+             batch_size: Batch size for embedding generation
+         """
+         logger.info("🎨 Initializing Visual Search Adapter...")
+
+         # Auto-detect device using the utility function
+         if device is None:
+             device = get_device_for_colpali()
+
+         self.device = device
+         logger.info(f"   Device: {device}")
+
+         # Initialize Qdrant client
+         logger.info(f"   Connecting to Qdrant: {qdrant_url}")
+         self.client = QdrantClient(
+             url=qdrant_url,
+             api_key=qdrant_api_key,
+             prefer_grpc=False,  # Use HTTP for compatibility
+             timeout=60
+         )
+
+         # Initialize search engine (from local src/colpali/search.py)
+         self.search_engine = VisualDocumentSearch(
+             qdrant_client=self.client,
+             collection_name=collection_name
+         )
+
+         # Initialize processor (from local src/colpali/processor.py)
+         logger.info(f"   Loading model: {model_name}")
+         torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
+         self.processor = ColPaliProcessor(
+             model_name=model_name,
+             device=device,
+             torch_dtype=torch_dtype,
+             batch_size=batch_size
+         )
+
+         # Store the last query embedding for saliency generation
+         self.last_query_embedding = None
+         self.collection_name = collection_name
+
+         logger.info("✅ Visual Search Adapter initialized!")
+
+     def search(
+         self,
+         query: str,
+         top_k: int = 10,
+         filters: Optional[Dict[str, Any]] = None,
+         search_strategy: str = "multi_vector",
+         **kwargs
+     ) -> List[VisualSearchResult]:
+         """
+         Search for visually similar documents.
+
+         Args:
+             query: Text query
+             top_k: Number of results to return
+             filters: Optional filters (year, source, district, filename, has_text)
+             search_strategy: Search strategy (multi_vector, pooled, hybrid)
+             **kwargs: Additional search parameters
+
+         Returns:
+             List of VisualSearchResult objects
+         """
+         logger.info(f"🔍 Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")
+
+         # Generate query embedding
+         query_embedding = self.processor.embed_query(query)
+
+         # Store for saliency generation
+         self.last_query_embedding = query_embedding
+
+         # Convert filters to Qdrant format
+         filter_params = {}
+         if filters:
+             if 'sources' in filters and filters['sources']:
+                 filter_params['source'] = filters['sources']
+             if 'years' in filters and filters['years']:
+                 years = filters['years']
+                 if isinstance(years, list):
+                     filter_params['year'] = [int(y) if isinstance(y, str) else y for y in years]
+                 else:
+                     filter_params['year'] = int(years) if isinstance(years, str) else years
+             if 'districts' in filters and filters['districts']:
+                 filter_params['district'] = filters['districts']
+             if 'filenames' in filters and filters['filenames']:
+                 filter_params['filename'] = filters['filenames']
+             if 'has_text' in filters:
+                 filter_params['has_text'] = filters['has_text']
+
+         logger.info(f"🔍 Visual search: Converted filter params: {filter_params}")
+
+         # Perform search
+         results = self.search_engine.search(
+             query_embedding=query_embedding,
+             top_k=top_k,
+             search_strategy=search_strategy,
+             **filter_params,
+             **kwargs
+         )
+
+         # Fallback: if there are 0 results with filters, retry without filters
+         if not results and filter_params:
+             logger.warning("⚠️ Visual search: 0 results with filters, retrying WITHOUT filters...")
+             results = self.search_engine.search(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 search_strategy=search_strategy,
+                 **kwargs  # No filter_params
+             )
+             if results:
+                 logger.info(f"✅ Visual search: Found {len(results)} results after removing filters")
+             else:
+                 logger.warning("❌ Visual search: Still 0 results even without filters")
+
+         # Convert to VisualSearchResult objects
+         visual_results = []
+         for result in results:
+             visual_result = VisualSearchResult(
+                 point_id=result['id'],
+                 score=result.get('score_final', result.get('score', 0.0)),
+                 payload=result['payload']
+             )
+             visual_results.append(visual_result)
+
+         logger.info(f"✅ Found {len(visual_results)} visual results")
+         return visual_results
+
+     def get_filter_options(self) -> Dict[str, List[Any]]:
+         """
+         Get available filter options from the collection.
+
+         Returns:
+             Dictionary with years, sources, districts, filenames
+         """
+         return self.search_engine.get_filter_options()
+
+
+ def create_visual_search_adapter(
+     qdrant_url: Optional[str] = None,
+     qdrant_api_key: Optional[str] = None,
+     collection_name: str = "colSmol-500M"
+ ) -> VisualSearchAdapter:
+     """
+     Factory function to create a visual search adapter.
+
+     Args:
+         qdrant_url: Qdrant URL (read from the environment if not provided)
+         qdrant_api_key: Qdrant API key (read from the environment if not provided)
+         collection_name: Collection name
+
+     Returns:
+         Initialized VisualSearchAdapter
+     """
+     import os
+
+     if qdrant_url is None:
+         qdrant_url = os.environ.get("QDRANT_URL")
+     if qdrant_api_key is None:
+         qdrant_api_key = os.environ.get("QDRANT_API_KEY")
+
+     if not qdrant_url or not qdrant_api_key:
+         raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in the environment")
+
+     return VisualSearchAdapter(
+         qdrant_url=qdrant_url,
+         qdrant_api_key=qdrant_api_key,
+         collection_name=collection_name
+     )
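
A usage sketch for the factory above, reading credentials from the environment (the values here are placeholders):

```python
import os

os.environ.setdefault("QDRANT_URL", "https://example.qdrant.cloud")  # placeholder
os.environ.setdefault("QDRANT_API_KEY", "YOUR_API_KEY")              # placeholder

from src.colpali.visual_search import create_visual_search_adapter

adapter = create_visual_search_adapter()
hits = adapter.search(
    "school renovation budgets",
    top_k=5,
    filters={"years": [2021, 2022], "sources": ["annual_report"]},  # illustrative
)
for h in hits:
    print(h)  # VisualSearchResult(id=..., score=...)
```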
src/colpali/visualizer.py ADDED
@@ -0,0 +1,236 @@
+ """
+ ColPali Visualization Module
+
+ Generates attention/saliency maps to visualize which parts of a document
+ are most relevant to a query.
+ """
+
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ from typing import List, Dict, Any, Optional
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from matplotlib.colors import LinearSegmentedColormap
+ import logging
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def generate_saliency_maps(
+     query_embedding: torch.Tensor,
+     image_embeddings: List[torch.Tensor],
+     images: List[Image.Image],
+     processor,
+     model,
+     top_k: int = 5,
+     threshold: float = 0.5
+ ) -> List[Image.Image]:
+     """
+     Generate saliency/attention maps showing which parts of the images are most relevant.
+
+     Args:
+         query_embedding: Query embedding tensor [num_query_patches, embedding_dim]
+         image_embeddings: List of image embedding tensors, each [num_patches, embedding_dim]
+         images: List of PIL Images corresponding to the embeddings
+         processor: ColPali processor for scoring
+         model: ColPali model
+         top_k: Number of top images to visualize
+         threshold: Threshold for highlighting (0-1)
+
+     Returns:
+         List of annotated images with saliency overlays
+     """
+     logger.info(f"🎨 Generating saliency maps for {len(images)} images")
+
+     # Calculate scores for all images
+     scores = []
+     for img_emb in image_embeddings:
+         # Use the processor's scoring method
+         score = processor.score_multi_vector(query_embedding.unsqueeze(0), img_emb.unsqueeze(0))
+         scores.append(score.item() if isinstance(score, torch.Tensor) else score)
+
+     # Get the top-k images
+     top_indices = np.argsort(scores)[-top_k:][::-1]
+
+     annotated_images = []
+
+     for idx in top_indices:
+         image = images[idx]
+         embedding = image_embeddings[idx]
+         score = scores[idx]
+
+         # Create the saliency map.
+         # For ColPali, we can visualize patch-level relevance:
+         # each patch in the embedding corresponds to a region in the image.
+
+         # Calculate patch-level scores.
+         # Query embedding: [num_query_patches, dim]
+         # Image embedding: [num_image_patches, dim]
+         # Compute similarity for each patch pair.
+         query_np = query_embedding.cpu().numpy()
+         img_np = embedding.cpu().numpy()
+
+         # Compute cosine similarity for each patch:
+         # normalize
+         query_norm = query_np / (np.linalg.norm(query_np, axis=1, keepdims=True) + 1e-8)
+         img_norm = img_np / (np.linalg.norm(img_np, axis=1, keepdims=True) + 1e-8)
+
+         # Compute the similarity matrix: [num_query_patches, num_image_patches]
+         similarity_matrix = np.dot(query_norm, img_norm.T)
+
+         # Get the max similarity per image patch (best match from any query patch)
+         patch_scores = similarity_matrix.max(axis=0)  # [num_image_patches]
+
+         # Normalize scores to [0, 1]
+         patch_scores = (patch_scores - patch_scores.min()) / (patch_scores.max() - patch_scores.min() + 1e-8)
+
+         # Create the overlay image
+         annotated = _create_saliency_overlay(
+             image,
+             patch_scores,
+             score,
+             threshold=threshold
+         )
+
+         annotated_images.append(annotated)
+
+     logger.info(f"✅ Generated {len(annotated_images)} saliency maps")
+
+     return annotated_images
+
+
+ def _create_saliency_overlay(
+     image: Image.Image,
+     patch_scores: np.ndarray,
+     overall_score: float,
+     threshold: float = 0.5,
+     patch_size: int = 16  # Approximate patch size in pixels
+ ) -> Image.Image:
+     """
+     Create a saliency overlay on the image.
+
+     Args:
+         image: Original PIL Image
+         patch_scores: Array of scores for each patch [num_patches]
+         overall_score: Overall relevance score
+         threshold: Threshold for highlighting
+         patch_size: Size of each patch in pixels
+
+     Returns:
+         Annotated PIL Image
+     """
+     # Convert to a numpy array
+     img_array = np.array(image)
+     h, w = img_array.shape[:2]
+
+     # Estimate grid dimensions.
+     # ColPali typically uses a grid of patches;
+     # for simplicity, assume a square grid first.
+     num_patches = len(patch_scores)
+     grid_size = int(np.sqrt(num_patches))
+
+     if grid_size * grid_size != num_patches:
+         # Non-square grid, try to estimate from
+         # common aspect ratios
+         aspect_ratio = w / h
+         cols = int(np.sqrt(num_patches * aspect_ratio))
+         rows = int(num_patches / cols)
+         if cols * rows != num_patches:
+             # Fall back to a square grid
+             grid_size = int(np.sqrt(num_patches))
+             rows = cols = grid_size
+     else:
+         rows = cols = grid_size
+
+     # Calculate patch dimensions
+     patch_h = h // rows
+     patch_w = w // cols
+
+     # Create the overlay
+     overlay = np.zeros((h, w, 4), dtype=np.uint8)  # RGBA
+
+     # Create the colormap (red for high relevance)
+     cmap = plt.cm.Reds
+
+     patch_idx = 0
+     for i in range(rows):
+         for j in range(cols):
+             if patch_idx >= len(patch_scores):
+                 break
+
+             score = patch_scores[patch_idx]
+
+             if score >= threshold:
+                 # Calculate patch bounds
+                 y1 = i * patch_h
+                 y2 = min((i + 1) * patch_h, h)
+                 x1 = j * patch_w
+                 x2 = min((j + 1) * patch_w, w)
+
+                 # Get the color from the colormap
+                 color = cmap(score)[:3]  # RGB
+                 color_uint8 = (np.array(color) * 255).astype(np.uint8)
+
+                 # Set the overlay
+                 overlay[y1:y2, x1:x2, :3] = color_uint8
+                 overlay[y1:y2, x1:x2, 3] = int(score * 128)  # Alpha based on score
+
+             patch_idx += 1
+
+     # Blend the overlay with the original image
+     overlay_img = Image.fromarray(overlay, 'RGBA')
+     annotated = Image.alpha_composite(image.convert('RGBA'), overlay_img)
+
+     # Add a text annotation with the score
+     draw = ImageDraw.Draw(annotated)
+     try:
+         font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
+     except Exception:
+         font = ImageFont.load_default()
+
+     score_text = f"Relevance: {overall_score:.3f}"
+     draw.text((10, 10), score_text, fill=(255, 255, 255, 255), font=font, stroke_width=2, stroke_fill=(0, 0, 0, 255))
+
+     return annotated.convert('RGB')
+
+
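
To make the grid estimation above concrete, a worked example with assumed numbers (768 patches on a 1240x1754 page render):

```python
import numpy as np

num_patches, w, h = 768, 1240, 1754
grid_size = int(np.sqrt(num_patches))            # 27; 27 * 27 = 729 != 768 -> non-square
aspect_ratio = w / h                             # ~0.707
cols = int(np.sqrt(num_patches * aspect_ratio))  # 23
rows = int(num_patches / cols)                   # 33; 23 * 33 = 759 != 768 -> fallback
rows = cols = grid_size                          # square fallback: a 27 x 27 grid
# In the fallback case, the last num_patches - rows * cols = 39 patch scores
# are simply never drawn by the overlay loop.
print(rows, cols)
```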
+ def visualize_retrieval_results(
+     query: str,
+     retrieved_docs: List[Dict[str, Any]],
+     output_path: Optional[str] = None
+ ) -> None:
+     """
+     Visualize retrieval results with images and scores.
+
+     Args:
+         query: Original query text
+         retrieved_docs: List of retrieved documents with images and scores
+         output_path: Optional path to save the visualization
+     """
+     num_docs = len(retrieved_docs)
+     fig, axes = plt.subplots(1, num_docs, figsize=(5 * num_docs, 5))
+
+     if num_docs == 1:
+         axes = [axes]
+
+     for idx, (doc, ax) in enumerate(zip(retrieved_docs, axes)):
+         if 'image' in doc:
+             ax.imshow(doc['image'])
+             ax.set_title(f"Rank {idx+1}\nScore: {doc.get('score', 0):.3f}")
+         ax.axis('off')
+
+     plt.suptitle(f"Query: {query}", fontsize=14, fontweight='bold')
+     plt.tight_layout()
+
+     if output_path:
+         plt.savefig(output_path, dpi=150, bbox_inches='tight')
+         logger.info(f"💾 Saved visualization to: {output_path}")
+     else:
+         plt.show()
+
+     plt.close()
+
+
+
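
Finally, a sketch of `visualize_retrieval_results` with synthetic inputs (blank images stand in for real page renders):

```python
from PIL import Image
from src.colpali.visualizer import visualize_retrieval_results

docs = [
    {"image": Image.new("RGB", (200, 280), "white"), "score": 0.91},
    {"image": Image.new("RGB", (200, 280), "gray"), "score": 0.87},
]
visualize_retrieval_results(
    query="flood damage reports",
    retrieved_docs=docs,
    output_path="results.png",  # omit to plt.show() instead
)
```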