Spaces:
Running
Running
Add mapping between metadata (Qdrant collections)
Browse files- src/colpali/visual_search.py +30 -11
src/colpali/visual_search.py
CHANGED
|
@@ -8,10 +8,10 @@ All dependencies are now within src/colpali/ - no external colpali_colab_package
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import logging
|
|
|
|
| 11 |
from typing import List, Dict, Any, Optional
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
import numpy as np
|
|
|
|
| 15 |
from qdrant_client import QdrantClient
|
| 16 |
|
| 17 |
# Import from local src/colpali modules (no external dependencies)
|
|
@@ -21,8 +21,10 @@ from src.colpali.search import VisualDocumentSearch
|
|
| 21 |
# Import device detection utility
|
| 22 |
from src.utils import get_device_for_colpali
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
-
DEFAULT_MODEL = "colSmol-500M"
|
| 26 |
|
| 27 |
|
| 28 |
class VisualSearchResult:
|
|
@@ -55,8 +57,8 @@ class VisualSearchAdapter:
|
|
| 55 |
self,
|
| 56 |
qdrant_url: str,
|
| 57 |
qdrant_api_key: str,
|
| 58 |
-
collection_name: str =
|
| 59 |
-
model_name: str =
|
| 60 |
device: str = None,
|
| 61 |
batch_size: int = 4
|
| 62 |
):
|
|
@@ -66,13 +68,17 @@ class VisualSearchAdapter:
|
|
| 66 |
Args:
|
| 67 |
qdrant_url: Qdrant cluster URL
|
| 68 |
qdrant_api_key: Qdrant API key
|
| 69 |
-
collection_name: Name of the collection with visual embeddings
|
| 70 |
model_name: ColPali model name
|
| 71 |
device: Device to use (cuda/cpu/mps, auto-detected if None)
|
| 72 |
batch_size: Batch size for embedding generation
|
| 73 |
"""
|
| 74 |
logger.info("π¨ Initializing Visual Search Adapter...")
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# Auto-detect device using utility function
|
| 77 |
if device is None:
|
| 78 |
device = get_device_for_colpali()
|
|
@@ -134,11 +140,15 @@ class VisualSearchAdapter:
|
|
| 134 |
"""
|
| 135 |
logger.info(f"π Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")
|
| 136 |
|
| 137 |
-
# Generate query embedding
|
| 138 |
-
query_embedding = self.processor.embed_query(query)
|
| 139 |
|
| 140 |
# Store for saliency generation
|
| 141 |
self.last_query_embedding = query_embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
# Convert filters to Qdrant format
|
| 144 |
filter_params = {}
|
|
@@ -154,7 +164,14 @@ class VisualSearchAdapter:
|
|
| 154 |
if 'districts' in filters and filters['districts']:
|
| 155 |
filter_params['district'] = filters['districts']
|
| 156 |
if 'filenames' in filters and filters['filenames']:
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
if 'has_text' in filters:
|
| 159 |
filter_params['has_text'] = filters['has_text']
|
| 160 |
|
|
@@ -209,7 +226,7 @@ class VisualSearchAdapter:
|
|
| 209 |
def create_visual_search_adapter(
|
| 210 |
qdrant_url: Optional[str] = None,
|
| 211 |
qdrant_api_key: Optional[str] = None,
|
| 212 |
-
collection_name: str =
|
| 213 |
) -> VisualSearchAdapter:
|
| 214 |
"""
|
| 215 |
Factory function to create a visual search adapter.
|
|
@@ -217,7 +234,7 @@ def create_visual_search_adapter(
|
|
| 217 |
Args:
|
| 218 |
qdrant_url: Qdrant URL (reads from env if not provided)
|
| 219 |
qdrant_api_key: Qdrant API key (reads from env if not provided)
|
| 220 |
-
collection_name: Collection name
|
| 221 |
|
| 222 |
Returns:
|
| 223 |
Initialized VisualSearchAdapter
|
|
@@ -228,6 +245,8 @@ def create_visual_search_adapter(
|
|
| 228 |
qdrant_url = os.environ.get("QDRANT_URL")
|
| 229 |
if qdrant_api_key is None:
|
| 230 |
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
|
|
|
|
|
|
|
| 231 |
|
| 232 |
if not qdrant_url or not qdrant_api_key:
|
| 233 |
raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import logging
|
| 11 |
+
import os
|
| 12 |
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
|
| 13 |
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
from qdrant_client import QdrantClient
|
| 16 |
|
| 17 |
# Import from local src/colpali modules (no external dependencies)
|
|
|
|
| 21 |
# Import device detection utility
|
| 22 |
from src.utils import get_device_for_colpali
|
| 23 |
|
| 24 |
+
# Import filename mapping for v1 -> visual collection translation
|
| 25 |
+
from src.config.visual_filename_mapping import v1_filenames_to_visual
|
| 26 |
+
|
| 27 |
logger = logging.getLogger(__name__)
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class VisualSearchResult:
|
|
|
|
| 57 |
self,
|
| 58 |
qdrant_url: str,
|
| 59 |
qdrant_api_key: str,
|
| 60 |
+
collection_name: str = None, # Will use QDRANT_COLLECTION_VISUAL env var or default
|
| 61 |
+
model_name: str = "vidore/colSmol-500M",
|
| 62 |
device: str = None,
|
| 63 |
batch_size: int = 4
|
| 64 |
):
|
|
|
|
| 68 |
Args:
|
| 69 |
qdrant_url: Qdrant cluster URL
|
| 70 |
qdrant_api_key: Qdrant API key
|
| 71 |
+
collection_name: Name of the collection with visual embeddings (default from QDRANT_COLLECTION_VISUAL env var)
|
| 72 |
model_name: ColPali model name
|
| 73 |
device: Device to use (cuda/cpu/mps, auto-detected if None)
|
| 74 |
batch_size: Batch size for embedding generation
|
| 75 |
"""
|
| 76 |
logger.info("π¨ Initializing Visual Search Adapter...")
|
| 77 |
|
| 78 |
+
# Get collection name from env var if not provided
|
| 79 |
+
if collection_name is None:
|
| 80 |
+
collection_name = os.environ.get("QDRANT_COLLECTION_VISUAL", "colSmol-500M-v2")
|
| 81 |
+
|
| 82 |
# Auto-detect device using utility function
|
| 83 |
if device is None:
|
| 84 |
device = get_device_for_colpali()
|
|
|
|
| 140 |
"""
|
| 141 |
logger.info(f"π Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")
|
| 142 |
|
| 143 |
+
# Generate query embedding (filter special tokens by default)
|
| 144 |
+
query_embedding = self.processor.embed_query(query, filter_special_tokens=True)
|
| 145 |
|
| 146 |
# Store for saliency generation
|
| 147 |
self.last_query_embedding = query_embedding
|
| 148 |
+
self.last_query_text = query # Store query text for word selection
|
| 149 |
+
# Store processed query info for accurate word-to-token mapping
|
| 150 |
+
self.last_input_ids = getattr(self.processor, 'last_input_ids', None)
|
| 151 |
+
self.last_attention_mask = getattr(self.processor, 'last_attention_mask', None)
|
| 152 |
|
| 153 |
# Convert filters to Qdrant format
|
| 154 |
filter_params = {}
|
|
|
|
| 164 |
if 'districts' in filters and filters['districts']:
|
| 165 |
filter_params['district'] = filters['districts']
|
| 166 |
if 'filenames' in filters and filters['filenames']:
|
| 167 |
+
v1_filenames = filters['filenames']
|
| 168 |
+
visual_filenames = v1_filenames_to_visual(v1_filenames)
|
| 169 |
+
if visual_filenames:
|
| 170 |
+
logger.info(f"π Filename translation: {v1_filenames} -> {visual_filenames}")
|
| 171 |
+
filter_params['filename'] = visual_filenames
|
| 172 |
+
else:
|
| 173 |
+
logger.warning(f"β οΈ No visual filename mappings found for: {v1_filenames}")
|
| 174 |
+
filter_params['filename'] = v1_filenames
|
| 175 |
if 'has_text' in filters:
|
| 176 |
filter_params['has_text'] = filters['has_text']
|
| 177 |
|
|
|
|
| 226 |
def create_visual_search_adapter(
|
| 227 |
qdrant_url: Optional[str] = None,
|
| 228 |
qdrant_api_key: Optional[str] = None,
|
| 229 |
+
collection_name: Optional[str] = None
|
| 230 |
) -> VisualSearchAdapter:
|
| 231 |
"""
|
| 232 |
Factory function to create a visual search adapter.
|
|
|
|
| 234 |
Args:
|
| 235 |
qdrant_url: Qdrant URL (reads from env if not provided)
|
| 236 |
qdrant_api_key: Qdrant API key (reads from env if not provided)
|
| 237 |
+
collection_name: Collection name (reads from QDRANT_COLLECTION_VISUAL env var if not provided)
|
| 238 |
|
| 239 |
Returns:
|
| 240 |
Initialized VisualSearchAdapter
|
|
|
|
| 245 |
qdrant_url = os.environ.get("QDRANT_URL")
|
| 246 |
if qdrant_api_key is None:
|
| 247 |
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
|
| 248 |
+
if collection_name is None:
|
| 249 |
+
collection_name = os.environ.get("QDRANT_COLLECTION_VISUAL", "colSmol-500M-v2")
|
| 250 |
|
| 251 |
if not qdrant_url or not qdrant_api_key:
|
| 252 |
raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")
|