# Source: akryldigital — "add colpali scripts" (commit 150fb2f, verified)
"""
Visual Document Search Engine
Two-stage visual document retrieval:
1. Fast prefetch using pooled vectors (mean/max with HNSW)
2. Exact reranking using full multi-vector embeddings (ColBERT-style)
"""
import logging
import time
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue, Range
logger = logging.getLogger(__name__)
class VisualDocumentSearch:
    """
    Two-stage visual document retrieval over a Qdrant collection:
    - Stage 1: Fast HNSW search with pooled vectors (10-100ms)
    - Stage 2: Exact ColBERT reranking with full embeddings (100-500ms)

    The collection is expected to expose named vectors "initial" (full
    multi-vector patch embeddings), "mean_pooling" and "max_pooling"
    (single pooled vectors) — TODO confirm against the ingestion script.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        collection_name: str = "colSmol-500M"
    ):
        """
        Initialize search engine.

        Args:
            qdrant_client: Connected Qdrant client
            collection_name: Name of the collection
        """
        self.client = qdrant_client
        self.collection_name = collection_name
        # Cache for get_filter_options() results, keyed by max_points.
        self._filter_options_cache: Dict[Any, Dict[str, List[Any]]] = {}

    def get_filter_options(
        self,
        max_points: Optional[int] = None,
        use_cache: bool = True,
        progress_callback=None
    ) -> Dict[str, List[Any]]:
        """
        Scan collection to get all possible filter values using iterative scrolling.

        Args:
            max_points: Maximum number of points to scan (None = scan all)
            use_cache: Whether to reuse/store cached results (default True)
            progress_callback: Optional callback function(points_scanned, elapsed_time, iteration)

        Returns:
            Dictionary with sorted unique values for each filterable field:
            'years', 'sources', 'districts', 'filenames'.
        """
        # Fix: the original accepted `use_cache` but never cached anything.
        if use_cache and max_points in self._filter_options_cache:
            logger.info("βœ… Returning cached filter options")
            cached = self._filter_options_cache[max_points]
            return {key: list(values) for key, values in cached.items()}

        scan_limit = max_points if max_points else "all"
        logger.info("πŸ” Starting metadata scan (target: %s points)", scan_limit)
        logger.info("   Collection: %s", self.collection_name)

        # Unique values collected per filterable payload field.
        years: set = set()
        sources: set = set()
        districts: set = set()
        filenames: set = set()

        batch_size = 900
        points_scanned = 0
        offset = None
        iteration = 0
        max_iterations = 100  # hard stop so a misbehaving scroll can't loop forever
        start_time = time.time()

        try:
            while True:
                iteration += 1
                if iteration > max_iterations:
                    logger.warning("⚠️ Reached max iterations (%d), stopping", max_iterations)
                    break
                if max_points and points_scanned >= max_points:
                    logger.info("βœ… Reached target of %d points", max_points)
                    break

                # Shrink the last batch so we never overshoot max_points.
                if max_points:
                    current_batch_size = min(batch_size, max_points - points_scanned)
                else:
                    current_batch_size = batch_size

                elapsed = time.time() - start_time
                logger.info(
                    "   Batch %d: fetching %d points (scanned: %d, %.1fs)",
                    iteration, current_batch_size, points_scanned, elapsed
                )

                batch_start = time.time()
                try:
                    points, next_offset = self.client.scroll(
                        collection_name=self.collection_name,
                        limit=current_batch_size,
                        offset=offset,
                        with_payload=True,
                        with_vectors=False,  # payload-only scan; vectors are large
                    )
                    logger.info(
                        "   βœ“ Fetched %d points in %.2fs",
                        len(points), time.time() - batch_start
                    )
                except Exception as scroll_error:
                    # Best effort: return whatever was collected so far.
                    logger.error("❌ Scroll failed at iteration %d: %s", iteration, scroll_error)
                    break

                if not points:
                    logger.info("βœ… Reached end of collection (scanned %d points)", points_scanned)
                    break

                for point in points:
                    self._collect_filter_values(
                        point.payload, years, sources, districts, filenames
                    )

                points_scanned += len(points)
                offset = next_offset

                if progress_callback:
                    progress_callback(points_scanned, time.time() - start_time, iteration)

                if offset is None:
                    # Qdrant signals end-of-collection with a None next offset.
                    logger.info(
                        "βœ… Completed full scan: %d points in %.1fs",
                        points_scanned, time.time() - start_time
                    )
                    break

            elapsed = time.time() - start_time
            logger.info("βœ… Scan complete: %d points in %.1fs", points_scanned, elapsed)
            logger.info(
                "   Found: %d years, %d sources, %d districts, %d files",
                len(years), len(sources), len(districts), len(filenames)
            )
        except Exception as e:
            logger.error("❌ Error scanning collection: %s", e)

        options = {
            'years': sorted(years),
            'sources': sorted(sources),
            'districts': sorted(districts),
            'filenames': sorted(filenames)
        }
        if use_cache:
            # Store a private copy so callers can't mutate the cache.
            self._filter_options_cache[max_points] = {
                key: list(values) for key, values in options.items()
            }
        return options

    @staticmethod
    def _collect_filter_values(payload, years, sources, districts, filenames) -> None:
        """
        Accumulate one point's filterable payload values into the given sets.

        Fix: the original `continue`-d out of the whole point on an
        unparseable year string, silently dropping that point's
        source/district/filename too; now only the bad field is skipped.
        """
        year_value = payload.get('year')
        if year_value:
            # Years may be stored as strings; coerce to int for consistency.
            if isinstance(year_value, str):
                try:
                    year_value = int(year_value)
                except ValueError:
                    year_value = None  # unparseable year: skip only this field
            if isinstance(year_value, int):
                years.add(year_value)
        if payload.get('source'):
            sources.add(payload['source'])
        if payload.get('district'):
            districts.add(payload['district'])
        if payload.get('filename'):
            filenames.add(payload['filename'])

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[Any] = None,
        district: Optional[Any] = None,
        filename: Optional[Any] = None,
        has_text: Optional[bool] = None,
        page_range: Optional[tuple] = None
    ) -> Optional[Filter]:
        """
        Build Qdrant filter from parameters.

        Each of year/source/district/filename accepts either a single value
        (MatchValue) or a list of values (MatchAny). Year strings are coerced
        to int to match the indexed payload type.

        Args:
            year: Year(s) to match.
            source: Source name(s) to match.
            district: District name(s) to match.
            filename: Filename(s) to match (values redacted in logs).
            has_text: Require the `has_text` payload flag to equal this value.
            page_range: Inclusive (min_page, max_page) tuple for page_number.

        Returns:
            A Filter with all conditions AND-ed together, or None if no
            condition was requested.

        Raises:
            ValueError: If a year string cannot be converted to int.
        """
        # Coerce year strings to ints so they match the indexed integer payload.
        if year is not None:
            if isinstance(year, list):
                year = [int(y) if isinstance(y, str) else y for y in year]
            elif isinstance(year, str):
                year = int(year)

        conditions = []
        # One data-driven pass replaces four copy-pasted branches.
        # Filename values are intentionally redacted in the logs.
        for key, value, log_value in (
            ("year", year, year),
            ("source", source, source),
            ("district", district, district),
            ("filename", filename, "(unknown)"),
        ):
            if value is None:
                continue
            if isinstance(value, list):
                conditions.append(FieldCondition(key=key, match=MatchAny(any=value)))
                logger.info("πŸ” Filter: %s IN %s", key, log_value)
            else:
                conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
                logger.info("πŸ” Filter: %s = %s", key, log_value)

        if has_text is not None:
            conditions.append(
                FieldCondition(key="has_text", match=MatchValue(value=has_text))
            )

        if page_range is not None:
            min_page, max_page = page_range
            conditions.append(
                FieldCondition(
                    key="page_number",
                    range=Range(gte=min_page, lte=max_page)
                )
            )

        return Filter(must=conditions) if conditions else None

    def search_stage1_prefetch(
        self,
        query_embedding: torch.Tensor,
        top_k: int = 100,
        filter_obj: Optional[Filter] = None,
        use_pooling: bool = False,
        pooling_method: str = "mean"
    ) -> List[Dict[str, Any]]:
        """
        Stage 1: Prefetch candidates using either multi-vector or pooled search.

        Args:
            query_embedding: Query patch embeddings, shape (patches, dim) or
                (1, patches, dim) — torch.Tensor or array-like.
            top_k: Number of candidates to retrieve.
            filter_obj: Optional pre-built Qdrant filter.
            use_pooling: If True, pool the query to a single vector and search
                the "<method>_pooling" named vector; otherwise run a native
                multi-vector (MaxSim) query against the "initial" vector.
            pooling_method: "mean" or "max" (only used when use_pooling=True).

        Returns:
            List of dicts with 'id', 'score_stage1' and 'payload'.

        Raises:
            ValueError: Unknown pooling_method or malformed pooled vector.
        """
        # Normalize input to a float32 numpy array of shape (patches, dim).
        if isinstance(query_embedding, torch.Tensor):
            query_np = query_embedding.cpu().float().numpy()
        else:
            query_np = np.array(query_embedding, dtype=np.float32)
        if query_np.ndim == 3:
            query_np = query_np.squeeze(0)  # drop leading batch dimension

        if use_pooling:
            # Strategy 1: pooled single-vector search (fast, approximate).
            if pooling_method == "mean":
                query_pooled = query_np.mean(axis=0)
            elif pooling_method == "max":
                query_pooled = query_np.max(axis=0)
            else:
                raise ValueError(f"Unknown pooling method: {pooling_method}")
            vector_name = f"{pooling_method}_pooling"
            if query_pooled.ndim != 1:
                raise ValueError(f"Pooling failed! Expected 1D vector, got shape {query_pooled.shape}")
            query_vector = query_pooled.tolist()
            logger.info("πŸ” Pooled search: vector=%s, dims=%d", vector_name, len(query_vector))
        else:
            # Strategy 2: native multi-vector search (SOTA).
            vector_name = "initial"
            query_vector = query_np.tolist()
            logger.info(
                "🎯 Multi-vector search: vector=%s, patches=%d, dims=%d",
                vector_name, len(query_vector), len(query_vector[0])
            )

        try:
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=120
            ).points
            logger.info("βœ… Stage 1: Retrieved %d candidates", len(results))
        except Exception as e:
            logger.error("❌ Search with vector '%s' failed: %s", vector_name, e)
            raise

        return [
            {'id': result.id, 'score_stage1': result.score, 'payload': result.payload}
            for result in results
        ]

    def colbert_score(
        self,
        query_embedding: np.ndarray,
        doc_embedding: np.ndarray
    ) -> float:
        """
        Compute ColBERT-style late-interaction (MaxSim) score.

        Each row is L2-normalized (epsilon guards zero vectors), then for each
        query patch we take its maximum cosine similarity over all document
        patches and average across query patches.

        Args:
            query_embedding: Query patch embeddings, shape (Q, D).
            doc_embedding: Document patch embeddings, shape (P, D).

        Returns:
            Scalar MaxSim score as a Python float.
        """
        query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8)
        doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
        # (Q, P) cosine-similarity matrix; max over doc patches, mean over query patches.
        sim_matrix = np.dot(query_norm, doc_norm.T)
        return float(sim_matrix.max(axis=1).mean())

    def search_stage2_rerank(
        self,
        query_embedding: torch.Tensor,
        candidates: List[Dict[str, Any]],
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Stage 2: Exact reranking using full multi-vector embeddings.

        Candidates whose payload contains 'full_embedding' are rescored with
        colbert_score(); those without fall back to their stage-1 score so
        they still compete in the final ranking.

        Args:
            query_embedding: Query patch embeddings (torch.Tensor or array-like).
            candidates: Stage-1 candidate dicts (mutated in place: adds
                'score_stage2' where possible, always sets 'score_final').
            top_k: Number of results to keep after reranking.

        Returns:
            Top-k candidates sorted by 'score_final' descending.
        """
        if isinstance(query_embedding, torch.Tensor):
            query_np = query_embedding.cpu().float().numpy()
        else:
            query_np = np.array(query_embedding, dtype=np.float32)

        for candidate in candidates:
            full_embedding = candidate['payload'].get('full_embedding')
            if full_embedding is None:
                # No stored multi-vector: keep the approximate stage-1 score.
                candidate['score_final'] = candidate['score_stage1']
                continue
            doc_np = np.array(full_embedding, dtype=np.float32)
            score = self.colbert_score(query_np, doc_np)
            candidate['score_stage2'] = score
            candidate['score_final'] = score

        return sorted(candidates, key=lambda c: c['score_final'], reverse=True)[:top_k]

    def search(
        self,
        query_embedding: torch.Tensor,
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        year: Optional[int] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
        page_range: Optional[tuple] = None,
        search_strategy: str = "multi_vector",
        pooling_method: str = "mean",
        use_reranking: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Multi-strategy visual document search.

        Search Strategies:
            1. "multi_vector" (DEFAULT, SOTA): Native multi-vector search
            2. "pooled": Pooled search (fastest, less accurate)
            3. "hybrid": Two-stage retrieval with optional reranking

        Args:
            query_embedding: Query patch embeddings.
            top_k: Number of final results.
            prefetch_k: Hybrid-only; stage-1 candidate count
                (default max(100, top_k * 10)).
            year/source/district/filename/has_text/page_range: Metadata
                filters, forwarded to build_filter().
            search_strategy: One of "multi_vector", "pooled", "hybrid".
            pooling_method: "mean" or "max" for pooled/hybrid strategies.
            use_reranking: Hybrid-only; enable stage-2 ColBERT reranking.

        Returns:
            List of result dicts, each with 'id', 'payload' and 'score_final'.

        Raises:
            ValueError: For an unknown search_strategy.
        """
        filter_obj = self.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
            page_range=page_range
        )

        if search_strategy == "multi_vector":
            # Strategy 1: native multi-vector search (SOTA, default).
            logger.info("🎯 SOTA Multi-Vector Search: Querying 'initial' vector with native MaxSim")
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                use_pooling=False
            )
            if not candidates:
                logger.warning("❌ No results found")
                return []
            for c in candidates:
                c['score_final'] = c['score_stage1']
            logger.info("βœ… Retrieved %d results (native MaxSim)", len(candidates))
            return candidates

        elif search_strategy == "pooled":
            # Strategy 2: pooled search (fast, approximate).
            logger.info("πŸ” Pooled Search: Querying '%s_pooling' vector", pooling_method)
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                use_pooling=True,
                pooling_method=pooling_method
            )
            if not candidates:
                logger.warning("❌ No results found")
                return []
            for c in candidates:
                c['score_final'] = c['score_stage1']
            logger.info("βœ… Retrieved %d results (pooled)", len(candidates))
            return candidates

        elif search_strategy == "hybrid":
            # Strategy 3: pooled prefetch + optional exact reranking.
            if prefetch_k is None:
                prefetch_k = max(100, top_k * 10)
            logger.info("πŸ”„ Hybrid Search: Stage 1 - Prefetching %d with %s pooling", prefetch_k, pooling_method)
            candidates = self.search_stage1_prefetch(
                query_embedding=query_embedding,
                top_k=prefetch_k,
                filter_obj=filter_obj,
                use_pooling=True,
                pooling_method=pooling_method
            )
            if not candidates:
                logger.warning("❌ No results found in stage 1")
                return []
            logger.info("βœ… Stage 1: Found %d candidates", len(candidates))

            if use_reranking and len(candidates) > top_k:
                logger.info("🎯 Stage 2: Reranking with ColBERT scoring...")
                results = self.search_stage2_rerank(
                    query_embedding=query_embedding,
                    candidates=candidates,
                    top_k=top_k
                )
                logger.info("βœ… Reranked to top %d results", len(results))
                return results

            results = candidates[:top_k]
            for r in results:
                r['score_final'] = r['score_stage1']
            logger.info("⏭️ Skipping reranking, returning top %d", len(results))
            return results

        raise ValueError(f"Unknown search_strategy: {search_strategy}")