""" Visual Document Search Engine Two-stage visual document retrieval: 1. Fast prefetch using pooled vectors (mean/max with HNSW) 2. Exact reranking using full multi-vector embeddings (ColBERT-style) """ import logging from typing import List, Dict, Any, Optional import numpy as np import torch from qdrant_client import QdrantClient from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny, Range logger = logging.getLogger(__name__) class VisualDocumentSearch: """ Two-stage visual document retrieval: - Stage 1: Fast HNSW search with pooled vectors (10-100ms) - Stage 2: Exact ColBERT reranking with full embeddings (100-500ms) """ def __init__( self, qdrant_client: QdrantClient, collection_name: str = "colSmol-500M" ): """ Initialize search engine. Args: qdrant_client: Connected Qdrant client collection_name: Name of the collection """ self.client = qdrant_client self.collection_name = collection_name def get_filter_options( self, max_points: int = None, use_cache: bool = True, progress_callback=None ) -> Dict[str, List[Any]]: """ Scan collection to get all possible filter values using iterative scrolling. Args: max_points: Maximum number of points to scan (None = scan all) use_cache: Whether to cache results (default True) progress_callback: Optional callback function(points_scanned, elapsed_time, iteration) Returns: Dictionary with all unique values for each filterable field """ scan_limit = max_points if max_points else "all" logger.info(f"🔍 Starting metadata scan (target: {scan_limit} points)") logger.info(f" Collection: {self.collection_name}") # Scroll through points to collect unique values years = set() sources = set() districts = set() filenames = set() batch_size = 900 points_scanned = 0 offset = None iteration = 0 max_iterations = 100 import time start_time = time.time() try: while True: iteration += 1 if iteration > max_iterations: logger.warning(f"⚠️ Reached max iterations ({max_iterations}), stopping") break if max_points and points_scanned >= max_points: logger.info(f"✅ Reached target of {max_points} points") break if max_points: remaining = max_points - points_scanned current_batch_size = min(batch_size, remaining) else: current_batch_size = batch_size elapsed = time.time() - start_time logger.info(f" Batch {iteration}: fetching {current_batch_size} points (scanned: {points_scanned}, {elapsed:.1f}s)") batch_start = time.time() try: results = self.client.scroll( collection_name=self.collection_name, limit=current_batch_size, offset=offset, with_payload=True, with_vectors=False, ) points, next_offset = results batch_time = time.time() - batch_start logger.info(f" ✓ Fetched {len(points)} points in {batch_time:.2f}s") except Exception as scroll_error: logger.error(f"❌ Scroll failed at iteration {iteration}: {scroll_error}") break if not points: logger.info(f"✅ Reached end of collection (scanned {points_scanned} points)") break for point in points: payload = point.payload if payload.get('year'): year_value = payload['year'] if isinstance(year_value, str): try: year_value = int(year_value) except ValueError: continue if isinstance(year_value, int): years.add(year_value) if payload.get('source'): sources.add(payload['source']) if payload.get('district'): districts.add(payload['district']) if payload.get('filename'): filenames.add(payload['filename']) points_scanned += len(points) offset = next_offset if progress_callback: elapsed = time.time() - start_time progress_callback(points_scanned, elapsed, iteration) if offset is None: elapsed = time.time() - start_time logger.info(f"✅ Completed full scan: {points_scanned} points in {elapsed:.1f}s") break elapsed = time.time() - start_time logger.info(f"✅ Scan complete: {points_scanned} points in {elapsed:.1f}s") logger.info(f" Found: {len(years)} years, {len(sources)} sources, " f"{len(districts)} districts, {len(filenames)} files") except Exception as e: logger.error(f"❌ Error scanning collection: {e}") return { 'years': sorted(list(years)), 'sources': sorted(list(sources)), 'districts': sorted(list(districts)), 'filenames': sorted(list(filenames)) } def build_filter( self, year: Optional[Any] = None, source: Optional[Any] = None, district: Optional[Any] = None, filename: Optional[Any] = None, has_text: Optional[bool] = None, page_range: Optional[tuple] = None ) -> Optional[Filter]: """ Build Qdrant filter from parameters. Supports both single values and lists (using MatchAny for lists). """ conditions = [] if year is not None: if isinstance(year, list): year_values = [int(y) if isinstance(y, str) else y for y in year] conditions.append( FieldCondition(key="year", match=MatchAny(any=year_values)) ) logger.info(f"🔍 Filter: year IN {year_values}") else: year_value = int(year) if isinstance(year, str) else year conditions.append( FieldCondition(key="year", match=MatchValue(value=year_value)) ) logger.info(f"🔍 Filter: year = {year_value}") if source is not None: if isinstance(source, list): conditions.append( FieldCondition(key="source", match=MatchAny(any=source)) ) logger.info(f"🔍 Filter: source IN {source}") else: conditions.append( FieldCondition(key="source", match=MatchValue(value=source)) ) logger.info(f"🔍 Filter: source = {source}") if district is not None: if isinstance(district, list): conditions.append( FieldCondition(key="district", match=MatchAny(any=district)) ) logger.info(f"🔍 Filter: district IN {district}") else: conditions.append( FieldCondition(key="district", match=MatchValue(value=district)) ) logger.info(f"🔍 Filter: district = {district}") if filename is not None: if isinstance(filename, list): conditions.append( FieldCondition(key="filename", match=MatchAny(any=filename)) ) logger.info(f"🔍 Filter: filename IN {filename}") else: conditions.append( FieldCondition(key="filename", match=MatchValue(value=filename)) ) logger.info(f"🔍 Filter: filename = {filename}") if has_text is not None: conditions.append( FieldCondition(key="has_text", match=MatchValue(value=has_text)) ) if page_range is not None: min_page, max_page = page_range conditions.append( FieldCondition( key="page_number", range=Range(gte=min_page, lte=max_page) ) ) if not conditions: return None return Filter(must=conditions) def search_stage1_prefetch( self, query_embedding: torch.Tensor, top_k: int = 100, filter_obj: Optional[Filter] = None, use_pooling: bool = False, pooling_method: str = "mean" ) -> List[Dict[str, Any]]: """ Stage 1: Prefetch candidates using either multi-vector or pooled search. """ # Convert to numpy if isinstance(query_embedding, torch.Tensor): query_np = query_embedding.cpu().float().numpy() else: query_np = np.array(query_embedding, dtype=np.float32) # Handle batch dimension if query_np.ndim == 3: query_np = query_np.squeeze(0) # Strategy 1: Pooled search (fast, approximate) if use_pooling: if pooling_method == "mean": query_pooled = query_np.mean(axis=0) vector_name = "mean_pooling" elif pooling_method == "max": query_pooled = query_np.max(axis=0) vector_name = "max_pooling" else: raise ValueError(f"Unknown pooling method: {pooling_method}") if query_pooled.ndim != 1: raise ValueError(f"Pooling failed! Expected 1D vector, got shape {query_pooled.shape}") query_vector = query_pooled.tolist() logger.info(f"🔍 Pooled search: vector={vector_name}, dims={len(query_vector)}") # Strategy 2: Native multi-vector search (SOTA) else: vector_name = "initial" query_vector = query_np.tolist() logger.info(f"🎯 Multi-vector search: vector={vector_name}, patches={len(query_vector)}, dims={len(query_vector[0])}") try: results = self.client.query_points( collection_name=self.collection_name, query=query_vector, using=vector_name, query_filter=filter_obj, limit=top_k, with_payload=True, with_vectors=False, timeout=120 ).points logger.info(f"✅ Stage 1: Retrieved {len(results)} candidates") except Exception as e: logger.error(f"❌ Search with vector '{vector_name}' failed: {e}") raise candidates = [] for result in results: candidates.append({ 'id': result.id, 'score_stage1': result.score, 'payload': result.payload }) return candidates def colbert_score( self, query_embedding: np.ndarray, doc_embedding: np.ndarray ) -> float: """ Compute ColBERT-style late interaction score. """ # Normalize embeddings query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8) doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8) # Compute similarity matrix sim_matrix = np.dot(query_norm, doc_norm.T) # For each query patch, take max similarity with any doc patch max_sims = sim_matrix.max(axis=1) # Average across query patches score = max_sims.mean() return float(score) def search_stage2_rerank( self, query_embedding: torch.Tensor, candidates: List[Dict[str, Any]], top_k: int = 10 ) -> List[Dict[str, Any]]: """ Stage 2: Exact reranking using full multi-vector embeddings. """ if isinstance(query_embedding, torch.Tensor): query_np = query_embedding.cpu().float().numpy() else: query_np = np.array(query_embedding, dtype=np.float32) reranked = [] for candidate in candidates: payload = candidate['payload'] full_embedding = payload.get('full_embedding') if full_embedding is None: candidate['score_final'] = candidate['score_stage1'] reranked.append(candidate) continue doc_np = np.array(full_embedding, dtype=np.float32) colbert_score = self.colbert_score(query_np, doc_np) candidate['score_stage2'] = colbert_score candidate['score_final'] = colbert_score reranked.append(candidate) reranked.sort(key=lambda x: x['score_final'], reverse=True) return reranked[:top_k] def search( self, query_embedding: torch.Tensor, top_k: int = 10, prefetch_k: Optional[int] = None, year: Optional[int] = None, source: Optional[str] = None, district: Optional[str] = None, filename: Optional[str] = None, has_text: Optional[bool] = None, page_range: Optional[tuple] = None, search_strategy: str = "multi_vector", pooling_method: str = "mean", use_reranking: bool = False ) -> List[Dict[str, Any]]: """ Multi-strategy visual document search. Search Strategies: 1. "multi_vector" (DEFAULT, SOTA): Native multi-vector search 2. "pooled": Pooled search (fastest, less accurate) 3. "hybrid": Two-stage retrieval with reranking """ # Build filter filter_obj = self.build_filter( year=year, source=source, district=district, filename=filename, has_text=has_text, page_range=page_range ) # Strategy 1: Native multi-vector search (SOTA, default) if search_strategy == "multi_vector": logger.info(f"🎯 SOTA Multi-Vector Search: Querying 'initial' vector with native MaxSim") candidates = self.search_stage1_prefetch( query_embedding=query_embedding, top_k=top_k, filter_obj=filter_obj, use_pooling=False ) if not candidates: logger.warning("❌ No results found") return [] for c in candidates: c['score_final'] = c['score_stage1'] logger.info(f"✅ Retrieved {len(candidates)} results (native MaxSim)") return candidates # Strategy 2: Pooled search (fast, approximate) elif search_strategy == "pooled": logger.info(f"🔍 Pooled Search: Querying '{pooling_method}_pooling' vector") candidates = self.search_stage1_prefetch( query_embedding=query_embedding, top_k=top_k, filter_obj=filter_obj, use_pooling=True, pooling_method=pooling_method ) if not candidates: logger.warning("❌ No results found") return [] for c in candidates: c['score_final'] = c['score_stage1'] logger.info(f"✅ Retrieved {len(candidates)} results (pooled)") return candidates # Strategy 3: Hybrid two-stage elif search_strategy == "hybrid": if prefetch_k is None: prefetch_k = max(100, top_k * 10) logger.info(f"🔄 Hybrid Search: Stage 1 - Prefetching {prefetch_k} with {pooling_method} pooling") candidates = self.search_stage1_prefetch( query_embedding=query_embedding, top_k=prefetch_k, filter_obj=filter_obj, use_pooling=True, pooling_method=pooling_method ) if not candidates: logger.warning("❌ No results found in stage 1") return [] logger.info(f"✅ Stage 1: Found {len(candidates)} candidates") if use_reranking and len(candidates) > top_k: logger.info(f"🎯 Stage 2: Reranking with ColBERT scoring...") results = self.search_stage2_rerank( query_embedding=query_embedding, candidates=candidates, top_k=top_k ) logger.info(f"✅ Reranked to top {len(results)} results") return results else: results = candidates[:top_k] for r in results: r['score_final'] = r['score_stage1'] logger.info(f"⏭️ Skipping reranking, returning top {len(results)}") return results else: raise ValueError(f"Unknown search_strategy: {search_strategy}")