Spaces:
Sleeping
Sleeping
| """ | |
| Visual Document Search Adapter for Main App | |
| This module provides an adapter to integrate ColPali visual search | |
| into the main app's retrieval pipeline. | |
| All dependencies are now within src/colpali/ - no external colpali_colab_package needed. | |
| """ | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| import torch | |
| import numpy as np | |
| from qdrant_client import QdrantClient | |
| # Import from local src/colpali modules (no external dependencies) | |
| from src.colpali.processor import ColPaliProcessor | |
| from src.colpali.search import VisualDocumentSearch | |
| # Import device detection utility | |
| from src.utils import get_device_for_colpali | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default model identifier. It doubles as the default Qdrant collection name
# and, prefixed with "vidore/", as the default ColPali model name used by
# VisualSearchAdapter and create_visual_search_adapter.
DEFAULT_MODEL = "colSmol-500M"
class VisualSearchResult:
    """
    Wrapper for visual search results to match the interface expected by app.py.

    Attributes:
        id: Identifier of the matched point.
        score: Relevance score of the match.
        payload: Payload dictionary attached to the point (never None).
        metadata: Alias of ``payload`` (same object), kept for compatibility.
        page_content: Value of the payload's ``'text'`` entry, or ``''`` when
            the entry is missing or empty.
        content: Alias of ``page_content``.
    """

    def __init__(self, point_id: str, score: float, payload: Optional[Dict[str, Any]]):
        """
        Build a result from a point id, its score and its payload.

        Args:
            point_id: Identifier of the matched point.
            score: Relevance score of the match.
            payload: Payload dictionary; ``None`` is treated as an empty dict
                so attribute access on the result never fails.
        """
        # A point may come back without a payload; normalise to an empty dict
        # instead of crashing on ``None.get``.
        if payload is None:
            payload = {}

        self.id = point_id
        self.score = score
        self.payload = payload
        self.metadata = payload  # Alias for compatibility

        # Extract content for compatibility with Document interface.
        # ``or ''`` also covers an explicit ``'text': None`` entry, so the
        # text attributes are always strings.
        self.page_content = payload.get('text') or ''
        self.content = self.page_content

    def __repr__(self):
        return f"VisualSearchResult(id={self.id}, score={self.score:.4f})"
class VisualSearchAdapter:
    """
    Adapter to integrate ColPali visual search into the main app.

    This provides a unified interface for visual document retrieval that works
    with the existing chatbot architecture.
    """

    def __init__(
        self,
        qdrant_url: str,
        qdrant_api_key: str,
        collection_name: str = DEFAULT_MODEL,
        model_name: str = f"vidore/{DEFAULT_MODEL}",
        device: Optional[str] = None,
        batch_size: int = 4
    ):
        """
        Initialize visual search adapter.

        Args:
            qdrant_url: Qdrant cluster URL
            qdrant_api_key: Qdrant API key
            collection_name: Name of the collection with visual embeddings
            model_name: ColPali model name
            device: Device to use (cuda/cpu/mps, auto-detected if None).
                Indexed CUDA devices such as "cuda:0" are treated as CUDA.
            batch_size: Batch size for embedding generation
        """
        logger.info("π¨ Initializing Visual Search Adapter...")

        # Auto-detect device using utility function
        if device is None:
            device = get_device_for_colpali()
        self.device = device
        logger.info(f" Device: {device}")

        # Initialize Qdrant client
        logger.info(f" Connecting to Qdrant: {qdrant_url}")
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            prefer_grpc=False,  # Use HTTP for compatibility
            timeout=60
        )

        # Initialize search engine (from local src/colpali/search.py)
        self.search_engine = VisualDocumentSearch(
            qdrant_client=self.client,
            collection_name=collection_name
        )

        # Initialize processor (from local src/colpali/processor.py)
        logger.info(f" Loading model: {model_name}")
        # bfloat16 on any CUDA device, float32 elsewhere. A prefix match is
        # used so that "cuda:0" / "cuda:1" are not mistaken for non-CUDA
        # devices and silently loaded in float32.
        on_cuda = str(device).startswith("cuda")
        torch_dtype = torch.bfloat16 if on_cuda else torch.float32
        self.processor = ColPaliProcessor(
            model_name=model_name,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size
        )

        # Store last query embedding for saliency generation
        self.last_query_embedding = None
        self.collection_name = collection_name
        logger.info("β Visual Search Adapter initialized!")

    @staticmethod
    def _build_filter_params(filters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate app-level filter keys into search-engine keyword arguments."""
        params: Dict[str, Any] = {}

        if filters.get('sources'):
            params['source'] = filters['sources']

        years = filters.get('years')
        if years:
            # Year values given as strings are converted to int; other
            # values are passed through unchanged.
            if isinstance(years, (list, tuple)):
                params['year'] = [int(y) if isinstance(y, str) else y for y in years]
            else:
                params['year'] = int(years) if isinstance(years, str) else years

        if filters.get('districts'):
            params['district'] = filters['districts']

        if filters.get('filenames'):
            params['filename'] = filters['filenames']

        # has_text is forwarded whenever the key is present, even if falsy,
        # because False is a meaningful value for it.
        if 'has_text' in filters:
            params['has_text'] = filters['has_text']

        return params

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
        search_strategy: str = "multi_vector",
        **kwargs
    ) -> List[VisualSearchResult]:
        """
        Search for visually similar documents.

        If a filtered search returns nothing, the search is retried once
        without any filters, so results may fall outside the requested filters.

        Args:
            query: Text query
            top_k: Number of results to return
            filters: Optional filters. Recognised keys are 'sources', 'years',
                'districts', 'filenames' and 'has_text'; other keys are ignored.
            search_strategy: Search strategy (multi_vector, pooled, hybrid)
            **kwargs: Additional search parameters, forwarded to the search engine

        Returns:
            List of VisualSearchResult objects
        """
        logger.info(f"π Visual search: '{query}' (top_k={top_k}, strategy={search_strategy})")

        # Generate query embedding and keep it for saliency generation
        query_embedding = self.processor.embed_query(query)
        self.last_query_embedding = query_embedding

        # Convert filters to the keyword arguments the search engine expects
        filter_params: Dict[str, Any] = {}
        if filters:
            filter_params = self._build_filter_params(filters)
            logger.info(f"π Visual search: Converted filter params: {filter_params}")

        # Perform search
        results = self.search_engine.search(
            query_embedding=query_embedding,
            top_k=top_k,
            search_strategy=search_strategy,
            **filter_params,
            **kwargs
        )

        # Fallback: If 0 results with filters, retry without filters
        if not results and filter_params:
            logger.warning("β οΈ Visual search: 0 results with filters, retrying WITHOUT filters...")
            results = self.search_engine.search(
                query_embedding=query_embedding,
                top_k=top_k,
                search_strategy=search_strategy,
                **kwargs  # No filter_params
            )
            if results:
                logger.info(f"β Visual search: Found {len(results)} results after removing filters")
            else:
                logger.warning("β Visual search: Still 0 results even without filters")

        # Convert to VisualSearchResult objects; 'score_final' is preferred
        # over 'score' when the engine provides it.
        visual_results = [
            VisualSearchResult(
                point_id=hit['id'],
                score=hit.get('score_final', hit.get('score', 0.0)),
                payload=hit['payload']
            )
            for hit in results
        ]
        logger.info(f"β Found {len(visual_results)} visual results")
        return visual_results

    def get_filter_options(self) -> Dict[str, List[Any]]:
        """
        Get available filter options from the collection.

        Delegates to the underlying search engine.

        Returns:
            Dictionary with years, sources, districts, filenames
        """
        return self.search_engine.get_filter_options()
def create_visual_search_adapter(
    qdrant_url: Optional[str] = None,
    qdrant_api_key: Optional[str] = None,
    collection_name: str = DEFAULT_MODEL
) -> VisualSearchAdapter:
    """
    Build a VisualSearchAdapter, taking Qdrant credentials from the environment
    when they are not passed explicitly.

    Args:
        qdrant_url: Qdrant URL; falls back to the QDRANT_URL environment
            variable when None
        qdrant_api_key: Qdrant API key; falls back to the QDRANT_API_KEY
            environment variable when None
        collection_name: Collection name

    Returns:
        Initialized VisualSearchAdapter

    Raises:
        ValueError: If the URL or the API key is still missing or empty
            after the environment lookup
    """
    import os

    # Only None triggers the environment lookup; an explicitly passed empty
    # string is kept and rejected by the check below.
    resolved_url = os.environ.get("QDRANT_URL") if qdrant_url is None else qdrant_url
    resolved_key = (
        os.environ.get("QDRANT_API_KEY") if qdrant_api_key is None else qdrant_api_key
    )

    if not (resolved_url and resolved_key):
        raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")

    adapter = VisualSearchAdapter(
        qdrant_url=resolved_url,
        qdrant_api_key=resolved_key,
        collection_name=collection_name
    )
    return adapter