# Source path: audit_assistant/src/colpali/visual_search.py
# Repository page header (not part of the module): "akryldigital's picture"
# Last commit: e0bf59c (verified) - "DEFAULT_MODEL = ColSmol-500M"
"""
Visual Document Search Adapter for Main App
This module provides an adapter to integrate ColPali visual search
into the main app's retrieval pipeline.
All dependencies are now within src/colpali/ - no external colpali_colab_package needed.
"""
import logging
from typing import List, Dict, Any, Optional
import torch
import numpy as np
from qdrant_client import QdrantClient
# Import from local src/colpali modules (no external dependencies)
from src.colpali.processor import ColPaliProcessor
from src.colpali.search import VisualDocumentSearch
# Import device detection utility
from src.utils import get_device_for_colpali
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default model identifier. It is used both as the default Qdrant collection
# name and, prefixed with "vidore/", as the default model name.
DEFAULT_MODEL = "colSmol-500M"
class VisualSearchResult:
    """
    Wrapper for visual search results to match the interface expected by app.py.

    The same data is exposed under several attribute names so the object can
    be consumed wherever a Document-like result is expected:
    ``payload`` / ``metadata`` refer to the same dict, and ``page_content`` /
    ``content`` hold the same text.

    Args:
        point_id: Identifier of the matched point.
        score: Relevance score of the match.
        payload: Payload dict of the matched point. ``None`` is treated as an
            empty payload.
    """

    def __init__(self, point_id: str, score: float, payload: Optional[Dict[str, Any]]):
        # Guard against a missing payload so the aliases below are always
        # usable. Only None is replaced; a caller-supplied dict (even an empty
        # one) is stored as-is so identity is preserved.
        if payload is None:
            payload = {}

        self.id = point_id
        self.score = score
        self.payload = payload
        self.metadata = payload  # Alias for compatibility

        # Extract content for compatibility with Document interface.
        # An explicit ``'text': None`` entry is normalised to "" as well, so
        # page_content is never None.
        text = payload.get('text')
        self.page_content = '' if text is None else text
        self.content = self.page_content

    def __repr__(self):
        return f"VisualSearchResult(id={self.id}, score={self.score:.4f})"
class VisualSearchAdapter:
    """
    Adapter to integrate ColPali visual search into the main app.

    This provides a unified interface for visual document retrieval that works
    with the existing chatbot architecture.
    """

    def __init__(
        self,
        qdrant_url: str,
        qdrant_api_key: str,
        collection_name: str = DEFAULT_MODEL,
        model_name: str = f"vidore/{DEFAULT_MODEL}",
        device: Optional[str] = None,
        batch_size: int = 4
    ):
        """
        Initialize visual search adapter.

        Args:
            qdrant_url: Qdrant cluster URL
            qdrant_api_key: Qdrant API key
            collection_name: Name of the collection with visual embeddings
            model_name: ColPali model name
            device: Device to use (cuda/cpu/mps, auto-detected if None)
            batch_size: Batch size for embedding generation
        """
        logger.info("🎨 Initializing Visual Search Adapter...")

        # Auto-detect device using utility function
        if device is None:
            device = get_device_for_colpali()
        self.device = device
        logger.info(" Device: %s", device)

        # Initialize Qdrant client
        logger.info(" Connecting to Qdrant: %s", qdrant_url)
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            prefer_grpc=False,  # Use HTTP for compatibility
            timeout=60
        )

        # Initialize search engine (from local src/colpali/search.py)
        self.search_engine = VisualDocumentSearch(
            qdrant_client=self.client,
            collection_name=collection_name
        )

        # Initialize processor (from local src/colpali/processor.py)
        logger.info(" Loading model: %s", model_name)
        # bfloat16 on CUDA, float32 everywhere else. A prefix match is used so
        # that indexed devices such as "cuda:0" are treated as CUDA too.
        on_cuda = str(device).startswith("cuda")
        torch_dtype = torch.bfloat16 if on_cuda else torch.float32
        self.processor = ColPaliProcessor(
            model_name=model_name,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size
        )

        # Store last query embedding for saliency generation
        self.last_query_embedding = None
        self.collection_name = collection_name
        logger.info("✅ Visual Search Adapter initialized!")

    @staticmethod
    def _build_filter_params(filters: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Translate app-level filter keys into search-engine keyword arguments."""
        params: Dict[str, Any] = {}
        if not filters:
            return params

        # Plural app-side keys map onto singular keyword names. Missing or
        # empty values are skipped.
        if filters.get('sources'):
            params['source'] = filters['sources']

        years = filters.get('years')
        if years:
            # Years given as strings are converted to int; other values are
            # passed through unchanged.
            if isinstance(years, list):
                params['year'] = [int(y) if isinstance(y, str) else y for y in years]
            else:
                params['year'] = int(years) if isinstance(years, str) else years

        if filters.get('districts'):
            params['district'] = filters['districts']
        if filters.get('filenames'):
            params['filename'] = filters['filenames']

        # has_text is forwarded whenever the key is present, because False is
        # a meaningful value for it.
        if 'has_text' in filters:
            params['has_text'] = filters['has_text']
        return params

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
        search_strategy: str = "multi_vector",
        **kwargs
    ) -> List[VisualSearchResult]:
        """
        Search for visually similar documents.

        If a filtered search yields nothing, the search is retried once
        without filters, so the results may fall outside the requested filters.

        Args:
            query: Text query
            top_k: Number of results to return
            filters: Optional filters (sources, years, districts, filenames, has_text)
            search_strategy: Search strategy (multi_vector, pooled, hybrid)
            **kwargs: Additional search parameters

        Returns:
            List of VisualSearchResult objects
        """
        logger.info(
            "🔍 Visual search: '%s' (top_k=%s, strategy=%s)",
            query, top_k, search_strategy,
        )

        # Generate query embedding
        query_embedding = self.processor.embed_query(query)
        # Store for saliency generation
        self.last_query_embedding = query_embedding

        # Convert filters to the keyword arguments the search engine takes
        filter_params = self._build_filter_params(filters)
        if filters:
            logger.info("🔍 Visual search: Converted filter params: %s", filter_params)

        # Perform search
        results = self.search_engine.search(
            query_embedding=query_embedding,
            top_k=top_k,
            search_strategy=search_strategy,
            **filter_params,
            **kwargs
        )

        # Fallback: If 0 results with filters, retry without filters
        if not results and filter_params:
            logger.warning("⚠️ Visual search: 0 results with filters, retrying WITHOUT filters...")
            results = self.search_engine.search(
                query_embedding=query_embedding,
                top_k=top_k,
                search_strategy=search_strategy,
                **kwargs  # No filter_params
            )
            if results:
                logger.info(
                    "✅ Visual search: Found %s results after removing filters",
                    len(results),
                )
            else:
                logger.warning("❌ Visual search: Still 0 results even without filters")

        # Convert to VisualSearchResult objects; 'score_final' is preferred
        # over 'score' when both are present.
        visual_results = [
            VisualSearchResult(
                point_id=hit['id'],
                score=hit.get('score_final', hit.get('score', 0.0)),
                payload=hit['payload']
            )
            for hit in results
        ]
        logger.info("✅ Found %s visual results", len(visual_results))
        return visual_results

    def get_filter_options(self) -> Dict[str, List[Any]]:
        """
        Get available filter options from the collection.

        Delegates to the underlying search engine.

        Returns:
            Dictionary with years, sources, districts, filenames
        """
        return self.search_engine.get_filter_options()
def create_visual_search_adapter(
    qdrant_url: Optional[str] = None,
    qdrant_api_key: Optional[str] = None,
    collection_name: str = DEFAULT_MODEL
) -> VisualSearchAdapter:
    """
    Build a VisualSearchAdapter, falling back to environment variables for
    any Qdrant credential that is not passed in.

    Args:
        qdrant_url: Qdrant URL (taken from QDRANT_URL when None)
        qdrant_api_key: Qdrant API key (taken from QDRANT_API_KEY when None)
        collection_name: Collection name

    Returns:
        Initialized VisualSearchAdapter

    Raises:
        ValueError: If the URL or the API key is still missing or empty after
            the environment lookup.
    """
    import os

    # Only an explicit None triggers the environment lookup. An empty string
    # supplied by the caller is kept and rejected by the check below.
    url = os.getenv("QDRANT_URL") if qdrant_url is None else qdrant_url
    api_key = os.getenv("QDRANT_API_KEY") if qdrant_api_key is None else qdrant_api_key

    if not (url and api_key):
        raise ValueError("QDRANT_URL and QDRANT_API_KEY must be provided or set in environment")

    adapter = VisualSearchAdapter(
        qdrant_url=url,
        qdrant_api_key=api_key,
        collection_name=collection_name,
    )
    return adapter