| """ | |
| Visual Document Search Engine | |
| Two-stage visual document retrieval: | |
| 1. Fast prefetch using pooled vectors (mean/max with HNSW) | |
| 2. Exact reranking using full multi-vector embeddings (ColBERT-style) | |
| """ | |
import logging
import time
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue, Range

logger = logging.getLogger(__name__)


class VisualDocumentSearch:
    """
    Two-stage visual document retrieval:
    - Stage 1: Fast HNSW search with pooled vectors (10-100ms)
    - Stage 2: Exact ColBERT reranking with full embeddings (100-500ms)
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        collection_name: str = "colSmol-500M"
    ):
        """
        Initialize the search engine.

        Args:
            qdrant_client: Connected Qdrant client
            collection_name: Name of the collection to search
        """
        self.client = qdrant_client
        self.collection_name = collection_name

    def get_filter_options(
        self,
        max_points: Optional[int] = None,
        use_cache: bool = True,
        progress_callback: Optional[Callable[[int, float, int], None]] = None
    ) -> Dict[str, List[Any]]:
        """
        Scan the collection to collect all possible filter values using iterative scrolling.

        Args:
            max_points: Maximum number of points to scan (None = scan all)
            use_cache: Whether to cache results (currently unused; reserved for caching)
            progress_callback: Optional callback, called after each batch as
                progress_callback(points_scanned, elapsed_time, iteration)

        Returns:
            Dictionary with all unique values for each filterable field
        """
        scan_limit = max_points if max_points else "all"
        logger.info(f"🔍 Starting metadata scan (target: {scan_limit} points)")
        logger.info(f"   Collection: {self.collection_name}")

        # Scroll through points to collect unique values
        years = set()
        sources = set()
        districts = set()
        filenames = set()

        batch_size = 900
        points_scanned = 0
        offset = None
        iteration = 0
        max_iterations = 100
        start_time = time.time()

        try:
            while True:
                iteration += 1
                if iteration > max_iterations:
                    logger.warning(f"⚠️ Reached max iterations ({max_iterations}), stopping")
                    break
                if max_points and points_scanned >= max_points:
                    logger.info(f"✅ Reached target of {max_points} points")
                    break

                if max_points:
                    remaining = max_points - points_scanned
                    current_batch_size = min(batch_size, remaining)
                else:
                    current_batch_size = batch_size

                elapsed = time.time() - start_time
                logger.info(
                    f"   Batch {iteration}: fetching {current_batch_size} points "
                    f"(scanned: {points_scanned}, {elapsed:.1f}s)"
                )

                batch_start = time.time()
                try:
                    results = self.client.scroll(
                        collection_name=self.collection_name,
                        limit=current_batch_size,
                        offset=offset,
                        with_payload=True,
                        with_vectors=False,
                    )
                    points, next_offset = results
                    batch_time = time.time() - batch_start
                    logger.info(f"   ✅ Fetched {len(points)} points in {batch_time:.2f}s")
                except Exception as scroll_error:
                    logger.error(f"❌ Scroll failed at iteration {iteration}: {scroll_error}")
                    break

                if not points:
                    logger.info(f"✅ Reached end of collection (scanned {points_scanned} points)")
                    break

                for point in points:
                    payload = point.payload
                    if payload.get('year'):
                        year_value = payload['year']
                        # Normalize string years to ints; ignore unparseable values
                        # (but still collect the point's other fields below)
                        if isinstance(year_value, str):
                            try:
                                year_value = int(year_value)
                            except ValueError:
                                year_value = None
                        if isinstance(year_value, int):
                            years.add(year_value)
                    if payload.get('source'):
                        sources.add(payload['source'])
                    if payload.get('district'):
                        districts.add(payload['district'])
                    if payload.get('filename'):
                        filenames.add(payload['filename'])

                points_scanned += len(points)
                offset = next_offset

                if progress_callback:
                    elapsed = time.time() - start_time
                    progress_callback(points_scanned, elapsed, iteration)

                if offset is None:
                    elapsed = time.time() - start_time
                    logger.info(f"✅ Completed full scan: {points_scanned} points in {elapsed:.1f}s")
                    break

            elapsed = time.time() - start_time
            logger.info(f"✅ Scan complete: {points_scanned} points in {elapsed:.1f}s")
            logger.info(f"   Found: {len(years)} years, {len(sources)} sources, "
                        f"{len(districts)} districts, {len(filenames)} files")
        except Exception as e:
            logger.error(f"❌ Error scanning collection: {e}")

        return {
            'years': sorted(years),
            'sources': sorted(sources),
            'districts': sorted(districts),
            'filenames': sorted(filenames)
        }

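    # A minimal usage sketch for the scan above. The names and values here
    # are hypothetical, not part of this module's API:
    #
    #   search = VisualDocumentSearch(QdrantClient(url="http://localhost:6333"))
    #   options = search.get_filter_options(
    #       max_points=5_000,
    #       progress_callback=lambda n, t, i: print(f"batch {i}: {n} points in {t:.1f}s"),
    #   )
    #   print(options['years'], options['sources'])
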
    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[Any] = None,
        district: Optional[Any] = None,
        filename: Optional[Any] = None,
        has_text: Optional[bool] = None,
        page_range: Optional[tuple] = None
    ) -> Optional[Filter]:
        """
        Build a Qdrant filter from parameters.

        Supports both single values and lists (using MatchAny for lists).
        """
        conditions = []

        if year is not None:
            if isinstance(year, list):
                year_values = [int(y) if isinstance(y, str) else y for y in year]
                conditions.append(
                    FieldCondition(key="year", match=MatchAny(any=year_values))
                )
                logger.info(f"🔍 Filter: year IN {year_values}")
            else:
                year_value = int(year) if isinstance(year, str) else year
                conditions.append(
                    FieldCondition(key="year", match=MatchValue(value=year_value))
                )
                logger.info(f"🔍 Filter: year = {year_value}")

        if source is not None:
            if isinstance(source, list):
                conditions.append(
                    FieldCondition(key="source", match=MatchAny(any=source))
                )
                logger.info(f"🔍 Filter: source IN {source}")
            else:
                conditions.append(
                    FieldCondition(key="source", match=MatchValue(value=source))
                )
                logger.info(f"🔍 Filter: source = {source}")

        if district is not None:
            if isinstance(district, list):
                conditions.append(
                    FieldCondition(key="district", match=MatchAny(any=district))
                )
                logger.info(f"🔍 Filter: district IN {district}")
            else:
                conditions.append(
                    FieldCondition(key="district", match=MatchValue(value=district))
                )
                logger.info(f"🔍 Filter: district = {district}")

        if filename is not None:
            if isinstance(filename, list):
                conditions.append(
                    FieldCondition(key="filename", match=MatchAny(any=filename))
                )
                logger.info(f"🔍 Filter: filename IN {filename}")
            else:
                conditions.append(
                    FieldCondition(key="filename", match=MatchValue(value=filename))
                )
                logger.info(f"🔍 Filter: filename = {filename}")

        if has_text is not None:
            conditions.append(
                FieldCondition(key="has_text", match=MatchValue(value=has_text))
            )

        if page_range is not None:
            min_page, max_page = page_range
            conditions.append(
                FieldCondition(
                    key="page_number",
                    range=Range(gte=min_page, lte=max_page)
                )
            )

        if not conditions:
            return None
        return Filter(must=conditions)

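    # Sketch of how the filter composes (hypothetical values): a list becomes
    # MatchAny, a scalar becomes MatchValue, and all conditions are AND-ed
    # under Filter(must=[...]).
    #
    #   f = search.build_filter(year=[2022, 2023], source="annual_report",
    #                           page_range=(1, 10))
    #   # -> Filter(must=[
    #   #        FieldCondition(key="year", match=MatchAny(any=[2022, 2023])),
    #   #        FieldCondition(key="source", match=MatchValue(value="annual_report")),
    #   #        FieldCondition(key="page_number", range=Range(gte=1, lte=10)),
    #   #    ])
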
    def search_stage1_prefetch(
        self,
        query_embedding: torch.Tensor,
        top_k: int = 100,
        filter_obj: Optional[Filter] = None,
        use_pooling: bool = False,
        pooling_method: str = "mean"
    ) -> List[Dict[str, Any]]:
        """
        Stage 1: Prefetch candidates using either multi-vector or pooled search.
        """
        # Convert to numpy
        if isinstance(query_embedding, torch.Tensor):
            query_np = query_embedding.cpu().float().numpy()
        else:
            query_np = np.array(query_embedding, dtype=np.float32)

        # Drop a leading batch dimension if present
        if query_np.ndim == 3:
            query_np = query_np.squeeze(0)

        # Strategy 1: Pooled search (fast, approximate)
        if use_pooling:
            if pooling_method == "mean":
                query_pooled = query_np.mean(axis=0)
                vector_name = "mean_pooling"
            elif pooling_method == "max":
                query_pooled = query_np.max(axis=0)
                vector_name = "max_pooling"
            else:
                raise ValueError(f"Unknown pooling method: {pooling_method}")

            if query_pooled.ndim != 1:
                raise ValueError(f"Pooling failed! Expected 1D vector, got shape {query_pooled.shape}")

            query_vector = query_pooled.tolist()
            logger.info(f"🔍 Pooled search: vector={vector_name}, dims={len(query_vector)}")

        # Strategy 2: Native multi-vector search (SOTA)
        else:
            vector_name = "initial"
            query_vector = query_np.tolist()
            logger.info(
                f"🎯 Multi-vector search: vector={vector_name}, "
                f"patches={len(query_vector)}, dims={len(query_vector[0])}"
            )

        try:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=120
            ).points
            logger.info(f"✅ Stage 1: Retrieved {len(results)} candidates")
        except Exception as e:
            logger.error(f"❌ Search with vector '{vector_name}' failed: {e}")
            raise

        candidates = []
        for result in results:
            candidates.append({
                'id': result.id,
                'score_stage1': result.score,
                'payload': result.payload
            })
        return candidates

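    # Note on query shapes for the prefetch above: with use_pooling=True the
    # query sent to Qdrant is a single 1D vector (e.g. 128 floats, assuming a
    # 128-dim embedding) matched against the "mean_pooling"/"max_pooling"
    # named vectors; with use_pooling=False it is a list of per-patch vectors
    # (e.g. 32 x 128) matched against the multivector "initial" field via
    # Qdrant's native MaxSim. The vector names and dims reflect this module's
    # assumptions about the collection schema.
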
    def colbert_score(
        self,
        query_embedding: np.ndarray,
        doc_embedding: np.ndarray
    ) -> float:
        """
        Compute the ColBERT-style late-interaction (MaxSim) score: for each
        query patch, take the maximum cosine similarity against all document
        patches, then average over query patches.
        """
        # Normalize embeddings row-wise so dot products become cosine similarities
        query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8)
        doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)

        # Similarity matrix: [n_query_patches, n_doc_patches]
        sim_matrix = np.dot(query_norm, doc_norm.T)

        # For each query patch, take max similarity with any doc patch
        max_sims = sim_matrix.max(axis=1)

        # Average across query patches
        score = max_sims.mean()
        return float(score)

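    # A tiny worked example of the MaxSim score above (hypothetical numbers):
    # with 2 query patches and 3 doc patches, suppose the cosine-similarity
    # matrix is
    #   [[0.9, 0.1, 0.3],
    #    [0.2, 0.8, 0.4]]
    # The per-query-patch maxima are [0.9, 0.8], so the score is their
    # mean: 0.85.
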
    def search_stage2_rerank(
        self,
        query_embedding: torch.Tensor,
        candidates: List[Dict[str, Any]],
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Stage 2: Exact reranking using full multi-vector embeddings.
        """
        if isinstance(query_embedding, torch.Tensor):
            query_np = query_embedding.cpu().float().numpy()
        else:
            query_np = np.array(query_embedding, dtype=np.float32)

        # Drop a leading batch dimension if present (mirrors stage 1)
        if query_np.ndim == 3:
            query_np = query_np.squeeze(0)

        reranked = []
        for candidate in candidates:
            payload = candidate['payload']
            full_embedding = payload.get('full_embedding')

            # Fall back to the stage-1 score when no full embedding is stored
            if full_embedding is None:
                candidate['score_final'] = candidate['score_stage1']
                reranked.append(candidate)
                continue

            doc_np = np.array(full_embedding, dtype=np.float32)
            colbert_score = self.colbert_score(query_np, doc_np)
            candidate['score_stage2'] = colbert_score
            candidate['score_final'] = colbert_score
            reranked.append(candidate)

        reranked.sort(key=lambda x: x['score_final'], reverse=True)
        return reranked[:top_k]

    def search(
        self,
        query_embedding: torch.Tensor,
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        year: Optional[int] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
        page_range: Optional[tuple] = None,
        search_strategy: str = "multi_vector",
        pooling_method: str = "mean",
        use_reranking: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Multi-strategy visual document search.

        Search strategies:
        1. "multi_vector" (default, SOTA): native multi-vector search
        2. "pooled": pooled search (fastest, less accurate)
        3. "hybrid": two-stage retrieval with optional reranking
        """
        # Build filter
        filter_obj = self.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
            page_range=page_range
        )

        # Strategy 1: Native multi-vector search (SOTA, default)
        if search_strategy == "multi_vector":
            logger.info("🎯 SOTA Multi-Vector Search: Querying 'initial' vector with native MaxSim")
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                use_pooling=False
            )
            if not candidates:
                logger.warning("⚠️ No results found")
                return []
            for c in candidates:
                c['score_final'] = c['score_stage1']
            logger.info(f"✅ Retrieved {len(candidates)} results (native MaxSim)")
            return candidates

        # Strategy 2: Pooled search (fast, approximate)
        elif search_strategy == "pooled":
            logger.info(f"🔍 Pooled Search: Querying '{pooling_method}_pooling' vector")
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                use_pooling=True,
                pooling_method=pooling_method
            )
            if not candidates:
                logger.warning("⚠️ No results found")
                return []
            for c in candidates:
                c['score_final'] = c['score_stage1']
            logger.info(f"✅ Retrieved {len(candidates)} results (pooled)")
            return candidates

        # Strategy 3: Hybrid two-stage
        elif search_strategy == "hybrid":
            if prefetch_k is None:
                prefetch_k = max(100, top_k * 10)
            logger.info(f"🔍 Hybrid Search: Stage 1 - Prefetching {prefetch_k} with {pooling_method} pooling")
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=prefetch_k,
                filter_obj=filter_obj,
                use_pooling=True,
                pooling_method=pooling_method
            )
            if not candidates:
                logger.warning("⚠️ No results found in stage 1")
                return []
            logger.info(f"✅ Stage 1: Found {len(candidates)} candidates")

            if use_reranking and len(candidates) > top_k:
                logger.info("🎯 Stage 2: Reranking with ColBERT scoring...")
                results = self.search_stage2_rerank(
                    query_embedding=query_embedding,
                    candidates=candidates,
                    top_k=top_k
                )
                logger.info(f"✅ Reranked to top {len(results)} results")
                return results
            else:
                results = candidates[:top_k]
                for r in results:
                    r['score_final'] = r['score_stage1']
                logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")
                return results

        else:
            raise ValueError(f"Unknown search_strategy: {search_strategy}")
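

# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch. Everything below is illustrative, not part
# of the module's API: the Qdrant URL, the collection name, and the random
# tensor standing in for a real ColBERT-style query embedding (shape:
# [n_query_patches, dim]) are all assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    client = QdrantClient(url="http://localhost:6333")  # hypothetical endpoint
    search = VisualDocumentSearch(client, collection_name="colSmol-500M")

    # Stand-in for an encoder's output; a real query embedding would come
    # from the same model that produced the indexed document embeddings.
    query_embedding = torch.randn(32, 128)

    hits = search.search(
        query_embedding=query_embedding,
        top_k=5,
        year=2023,                   # optional metadata filter
        search_strategy="hybrid",    # prefetch with pooled vectors...
        pooling_method="mean",
        use_reranking=True,          # ...then rerank with full embeddings
    )
    for hit in hits:
        print(hit['id'], hit['score_final'])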