akryldigital commited on
Commit
7b91d8c
Β·
verified Β·
1 Parent(s): 7628659

add mapping between metadata (qdrant collections)

Browse files
Files changed (1) hide show
  1. src/colpali/visual_search.py +30 -11
src/colpali/visual_search.py CHANGED
@@ -8,10 +8,10 @@ All dependencies are now within src/colpali/ - no external colpali_colab_package
8
  """
9
 
10
  import logging
 
11
  from typing import List, Dict, Any, Optional
12
-
13
- import torch
14
  import numpy as np
 
15
  from qdrant_client import QdrantClient
16
 
17
  # Import from local src/colpali modules (no external dependencies)
@@ -21,8 +21,10 @@ from src.colpali.search import VisualDocumentSearch
21
  # Import device detection utility
22
  from src.utils import get_device_for_colpali
23
 
 
 
 
24
  logger = logging.getLogger(__name__)
25
- DEFAULT_MODEL = "colSmol-500M"
26
 
27
 
28
  class VisualSearchResult:
@@ -55,8 +57,8 @@ class VisualSearchAdapter:
55
  self,
56
  qdrant_url: str,
57
  qdrant_api_key: str,
58
- collection_name: str = DEFAULT_MODEL,
59
- model_name: str = f"vidore/{DEFAULT_MODEL}",
60
  device: str = None,
61
  batch_size: int = 4
62
  ):
@@ -66,13 +68,17 @@ class VisualSearchAdapter:
66
  Args:
67
  qdrant_url: Qdrant cluster URL
68
  qdrant_api_key: Qdrant API key
69
- collection_name: Name of the collection with visual embeddings
70
  model_name: ColPali model name
71
  device: Device to use (cuda/cpu/mps, auto-detected if None)
72
  batch_size: Batch size for embedding generation
73
  """
74
  logger.info("🎨 Initializing Visual Search Adapter...")
75
 
 
 
 
 
76
  # Auto-detect device using utility function
77
  if device is None:
78
  device = get_device_for_colpali()
@@ -134,11 +140,15 @@ class VisualSearchAdapter:
134
  """
135
  logger.info(f"πŸ” Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")
136
 
137
- # Generate query embedding
138
- query_embedding = self.processor.embed_query(query)
139
 
140
  # Store for saliency generation
141
  self.last_query_embedding = query_embedding
 
 
 
 
142
 
143
  # Convert filters to Qdrant format
144
  filter_params = {}
@@ -154,7 +164,14 @@ class VisualSearchAdapter:
154
  if 'districts' in filters and filters['districts']:
155
  filter_params['district'] = filters['districts']
156
  if 'filenames' in filters and filters['filenames']:
157
- filter_params['filename'] = filters['filenames']
 
 
 
 
 
 
 
158
  if 'has_text' in filters:
159
  filter_params['has_text'] = filters['has_text']
160
 
@@ -209,7 +226,7 @@ class VisualSearchAdapter:
209
  def create_visual_search_adapter(
210
  qdrant_url: Optional[str] = None,
211
  qdrant_api_key: Optional[str] = None,
212
- collection_name: str = DEFAULT_MODEL
213
  ) -> VisualSearchAdapter:
214
  """
215
  Factory function to create a visual search adapter.
@@ -217,7 +234,7 @@ def create_visual_search_adapter(
217
  Args:
218
  qdrant_url: Qdrant URL (reads from env if not provided)
219
  qdrant_api_key: Qdrant API key (reads from env if not provided)
220
- collection_name: Collection name
221
 
222
  Returns:
223
  Initialized VisualSearchAdapter
@@ -228,6 +245,8 @@ def create_visual_search_adapter(
228
  qdrant_url = os.environ.get("QDRANT_URL")
229
  if qdrant_api_key is None:
230
  qdrant_api_key = os.environ.get("QDRANT_API_KEY")
 
 
231
 
232
  if not qdrant_url or not qdrant_api_key:
233
  raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")
 
8
  """
9
 
10
  import logging
11
+ import os
12
  from typing import List, Dict, Any, Optional
 
 
13
  import numpy as np
14
+ import torch
15
  from qdrant_client import QdrantClient
16
 
17
  # Import from local src/colpali modules (no external dependencies)
 
21
  # Import device detection utility
22
  from src.utils import get_device_for_colpali
23
 
24
+ # Import filename mapping for v1 -> visual collection translation
25
+ from src.config.visual_filename_mapping import v1_filenames_to_visual
26
+
27
  logger = logging.getLogger(__name__)
 
28
 
29
 
30
  class VisualSearchResult:
 
57
  self,
58
  qdrant_url: str,
59
  qdrant_api_key: str,
60
+ collection_name: str = None, # Will use QDRANT_COLLECTION_VISUAL env var or default
61
+ model_name: str = "vidore/colSmol-500M",
62
  device: str = None,
63
  batch_size: int = 4
64
  ):
 
68
  Args:
69
  qdrant_url: Qdrant cluster URL
70
  qdrant_api_key: Qdrant API key
71
+ collection_name: Name of the collection with visual embeddings (default from QDRANT_COLLECTION_VISUAL env var)
72
  model_name: ColPali model name
73
  device: Device to use (cuda/cpu/mps, auto-detected if None)
74
  batch_size: Batch size for embedding generation
75
  """
76
  logger.info("🎨 Initializing Visual Search Adapter...")
77
 
78
+ # Get collection name from env var if not provided
79
+ if collection_name is None:
80
+ collection_name = os.environ.get("QDRANT_COLLECTION_VISUAL", "colSmol-500M-v2")
81
+
82
  # Auto-detect device using utility function
83
  if device is None:
84
  device = get_device_for_colpali()
 
140
  """
141
  logger.info(f"πŸ” Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")
142
 
143
+ # Generate query embedding (filter special tokens by default)
144
+ query_embedding = self.processor.embed_query(query, filter_special_tokens=True)
145
 
146
  # Store for saliency generation
147
  self.last_query_embedding = query_embedding
148
+ self.last_query_text = query # Store query text for word selection
149
+ # Store processed query info for accurate word-to-token mapping
150
+ self.last_input_ids = getattr(self.processor, 'last_input_ids', None)
151
+ self.last_attention_mask = getattr(self.processor, 'last_attention_mask', None)
152
 
153
  # Convert filters to Qdrant format
154
  filter_params = {}
 
164
  if 'districts' in filters and filters['districts']:
165
  filter_params['district'] = filters['districts']
166
  if 'filenames' in filters and filters['filenames']:
167
+ v1_filenames = filters['filenames']
168
+ visual_filenames = v1_filenames_to_visual(v1_filenames)
169
+ if visual_filenames:
170
+ logger.info(f"πŸ”„ Filename translation: {v1_filenames} -> {visual_filenames}")
171
+ filter_params['filename'] = visual_filenames
172
+ else:
173
+ logger.warning(f"⚠️ No visual filename mappings found for: {v1_filenames}")
174
+ filter_params['filename'] = v1_filenames
175
  if 'has_text' in filters:
176
  filter_params['has_text'] = filters['has_text']
177
 
 
226
  def create_visual_search_adapter(
227
  qdrant_url: Optional[str] = None,
228
  qdrant_api_key: Optional[str] = None,
229
+ collection_name: Optional[str] = None
230
  ) -> VisualSearchAdapter:
231
  """
232
  Factory function to create a visual search adapter.
 
234
  Args:
235
  qdrant_url: Qdrant URL (reads from env if not provided)
236
  qdrant_api_key: Qdrant API key (reads from env if not provided)
237
+ collection_name: Collection name (reads from QDRANT_COLLECTION_VISUAL env var if not provided)
238
 
239
  Returns:
240
  Initialized VisualSearchAdapter
 
245
  qdrant_url = os.environ.get("QDRANT_URL")
246
  if qdrant_api_key is None:
247
  qdrant_api_key = os.environ.get("QDRANT_API_KEY")
248
+ if collection_name is None:
249
+ collection_name = os.environ.get("QDRANT_COLLECTION_VISUAL", "colSmol-500M-v2")
250
 
251
  if not qdrant_url or not qdrant_api_key:
252
  raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")