# "Spaces: Sleeping" — hosting-page banner captured by the scrape; not part of the source.
| import os | |
| from typing import Any, Dict, List, Optional | |
| from urllib.parse import urlparse | |
| import numpy as np | |
| import torch | |
| try: | |
| from dotenv import load_dotenv | |
| DOTENV_AVAILABLE = True | |
| except ImportError: | |
| DOTENV_AVAILABLE = False | |
| load_dotenv = None | |
| try: | |
| from qdrant_client import QdrantClient | |
| QDRANT_AVAILABLE = True | |
| except ImportError: | |
| QDRANT_AVAILABLE = False | |
| QdrantClient = None | |
| from visual_rag.embedding.visual_embedder import VisualEmbedder | |
| from visual_rag.retrieval.single_stage import SingleStageRetriever | |
| from visual_rag.retrieval.three_stage import ThreeStageRetriever | |
| from visual_rag.retrieval.two_stage import TwoStageRetriever | |
class MultiVectorRetriever:
    """High-level retriever over a Qdrant multi-vector collection.

    Wraps three retrieval strategies (single-stage, two-stage, three-stage)
    behind one ``search`` entry point, and handles Qdrant client construction
    (env-var resolution, optional .env loading, gRPC preference with an
    HTTP fallback on permission errors).
    """

    @staticmethod
    def _maybe_load_dotenv() -> None:
        """Load ``.env`` from the current directory if python-dotenv is installed.

        Best-effort: silently does nothing when python-dotenv is missing or
        no ``.env`` file exists.

        NOTE: declared as a staticmethod — the original declared this with no
        ``self`` parameter but called it as ``self._maybe_load_dotenv()``,
        which raised ``TypeError`` at runtime.
        """
        if not DOTENV_AVAILABLE:
            return
        if os.path.exists(".env"):
            load_dotenv(".env")

    def __init__(
        self,
        collection_name: str,
        model_name: str = "vidore/colSmol-500M",
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
        qdrant_client=None,
        embedder: Optional[VisualEmbedder] = None,
    ):
        """Create a retriever bound to one Qdrant collection.

        Args:
            collection_name: Name of the Qdrant collection to query.
            model_name: Embedding model used when ``embedder`` is not given.
            qdrant_url: Qdrant endpoint; falls back to ``QDRANT_URL`` then the
                legacy ``SIGIR_QDRANT_URL`` env var.
            qdrant_api_key: API key; falls back to ``QDRANT_API_KEY`` then the
                legacy ``SIGIR_QDRANT_KEY`` env var.
            prefer_grpc: Try gRPC first (port 6334 when the URL uses 6333),
                falling back to HTTP on a 403 / PERMISSION_DENIED probe.
            request_timeout: Per-request timeout in seconds.
            max_retries: Retry count forwarded to the stage retrievers.
            retry_sleep: Sleep between retries, forwarded to the stage retrievers.
            qdrant_client: Pre-built client; when provided, all connection
                handling above is skipped.
            embedder: Pre-built embedder; built from ``model_name`` otherwise.

        Raises:
            ImportError: qdrant-client is not installed and no client was passed.
            ValueError: no Qdrant URL could be resolved.
        """
        if qdrant_client is None:
            self._maybe_load_dotenv()
            if not QDRANT_AVAILABLE:
                raise ImportError(
                    "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
                )
            qdrant_url = (
                qdrant_url or os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL")  # legacy
            )
            if not qdrant_url:
                raise ValueError("QDRANT_URL is required (pass qdrant_url or set env var).")
            qdrant_api_key = (
                qdrant_api_key
                or os.getenv("QDRANT_API_KEY")
                or os.getenv("SIGIR_QDRANT_KEY")  # legacy
            )

            # Qdrant convention: REST on 6333, gRPC on 6334. Only derive the
            # gRPC port when the URL explicitly uses the default REST port.
            grpc_port = None
            if prefer_grpc:
                try:
                    parsed = urlparse(qdrant_url)
                    port = parsed.port
                    if port == 6333:
                        grpc_port = 6334
                except Exception:
                    pass  # best-effort: fall through with grpc_port=None

            def _make_client(use_grpc: bool):
                # check_compatibility=False: skip the client/server version
                # handshake (avoids spurious warnings/failures on managed Qdrant).
                return QdrantClient(
                    url=qdrant_url,
                    api_key=qdrant_api_key,
                    timeout=request_timeout,
                    prefer_grpc=bool(use_grpc),
                    grpc_port=grpc_port,
                    check_compatibility=False,
                )

            client = _make_client(prefer_grpc)
            if prefer_grpc:
                # Probe the connection: some deployments expose 6333 but deny
                # gRPC (403). In that case transparently retry over HTTP.
                try:
                    _ = client.get_collections()
                except Exception as e:
                    msg = str(e)
                    if (
                        "StatusCode.PERMISSION_DENIED" in msg
                        or "http2 header with status: 403" in msg
                    ):
                        client = _make_client(False)
                    else:
                        raise
            qdrant_client = client

        self.client = qdrant_client
        self.collection_name = collection_name
        self.embedder = embedder or VisualEmbedder(model_name=model_name)

        # One retriever per strategy; all share the same client/collection.
        self._two_stage = TwoStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )
        self._three_stage = ThreeStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )
        self._single_stage = SingleStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """Build a Qdrant payload filter; delegates to the two-stage retriever."""
        return self._two_stage.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
        )

    def search(
        self,
        query: str,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        filter_obj=None,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Embed ``query`` and run :meth:`search_embedded` with the result.

        Returns the retriever hits as a list of dicts (shape determined by the
        underlying stage retriever).
        """
        q = self.embedder.embed_query(query)
        if isinstance(q, torch.Tensor):
            # .float() converts BFloat16 to Float32 (numpy doesn't support BFloat16)
            query_embedding = q.detach().cpu().float().numpy()
        else:
            query_embedding = np.asarray(q, dtype=np.float32)
        return self.search_embedded(
            query_embedding=query_embedding,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            filter_obj=filter_obj,
            return_embeddings=return_embeddings,
        )

    def search_embedded(
        self,
        *,
        query_embedding,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
        filter_obj=None,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Search with a pre-computed query embedding.

        ``mode`` selects the strategy:
          - ``"single_full"``: single-stage over the full multi-vectors.
          - ``"single_pooled"``: single-stage over mean-pooled vectors.
          - ``"two_stage"``: server-side prefetch + rerank.
          - ``"three_stage"``: server-side two prefetch stages + rerank.

        Raises:
            ValueError: if ``mode`` is not one of the four above.
        """
        if mode == "single_full":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="initial",
            )
        elif mode == "single_pooled":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="mean_pooling",
            )
        elif mode == "two_stage":
            return self._two_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                prefetch_k=prefetch_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        elif mode == "three_stage":
            return self._three_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                stage1_k=stage1_k,
                stage2_k=stage2_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")