# Fix BFloat16 numpy conversion (commit 51cf2c0, verified) — Yeroyan
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import numpy as np
import torch

# Optional dependency: python-dotenv, used only to pick up a local .env file.
try:
    from dotenv import load_dotenv

    DOTENV_AVAILABLE = True
except ImportError:
    DOTENV_AVAILABLE = False
    load_dotenv = None

# Optional dependency: qdrant-client; required unless a pre-built client is
# injected into MultiVectorRetriever.
try:
    from qdrant_client import QdrantClient

    QDRANT_AVAILABLE = True
except ImportError:
    QDRANT_AVAILABLE = False
    QdrantClient = None

from visual_rag.embedding.visual_embedder import VisualEmbedder
from visual_rag.retrieval.single_stage import SingleStageRetriever
from visual_rag.retrieval.three_stage import ThreeStageRetriever
from visual_rag.retrieval.two_stage import TwoStageRetriever
class MultiVectorRetriever:
    """Facade for multi-vector visual retrieval over a Qdrant collection.

    Embeds text queries with a :class:`VisualEmbedder` and dispatches search
    to single-stage, two-stage, or three-stage retrievers that all share one
    Qdrant client, collection, and retry configuration.
    """

    @staticmethod
    def _maybe_load_dotenv() -> None:
        """Best-effort: load ``./.env`` into the environment.

        No-op when python-dotenv is not installed or no ``.env`` file exists,
        so QDRANT_URL / QDRANT_API_KEY can come from either source.
        """
        if not DOTENV_AVAILABLE:
            return
        if os.path.exists(".env"):
            load_dotenv(".env")

    def __init__(
        self,
        collection_name: str,
        model_name: str = "vidore/colSmol-500M",
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
        qdrant_client=None,
        embedder: Optional[VisualEmbedder] = None,
    ):
        """Build (or adopt) a Qdrant client and the stage retrievers.

        Args:
            collection_name: Target Qdrant collection.
            model_name: Embedder checkpoint; used only when ``embedder`` is None.
            qdrant_url: Qdrant endpoint. Falls back to the QDRANT_URL env var,
                then the legacy SIGIR_QDRANT_URL.
            qdrant_api_key: API key. Falls back to QDRANT_API_KEY, then the
                legacy SIGIR_QDRANT_KEY.
            prefer_grpc: Try gRPC first; on a 403/permission-denied probe
                failure, silently fall back to an HTTP client.
            request_timeout: Per-request timeout in seconds.
            max_retries: Retry count handed to each stage retriever.
            retry_sleep: Sleep between retries, in seconds.
            qdrant_client: Pre-built client; when given, all connection setup
                (env vars, dotenv, gRPC probing) is skipped.
            embedder: Pre-built embedder; when given, no model is loaded.

        Raises:
            ImportError: qdrant-client is not installed and no client was given.
            ValueError: no Qdrant URL could be resolved.
        """
        if qdrant_client is None:
            self._maybe_load_dotenv()
            if not QDRANT_AVAILABLE:
                raise ImportError(
                    "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
                )
            qdrant_url = (
                qdrant_url or os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL")  # legacy
            )
            if not qdrant_url:
                raise ValueError("QDRANT_URL is required (pass qdrant_url or set env var).")
            qdrant_api_key = (
                qdrant_api_key
                or os.getenv("QDRANT_API_KEY")
                or os.getenv("SIGIR_QDRANT_KEY")  # legacy
            )

            # Qdrant convention: REST on 6333 implies gRPC on 6334. Leave
            # grpc_port unset if the URL cannot be parsed (best-effort).
            grpc_port = None
            if prefer_grpc:
                try:
                    if urlparse(qdrant_url).port == 6333:
                        grpc_port = 6334
                except Exception:
                    pass

            def _make_client(use_grpc: bool):
                # Local factory so the client can be rebuilt with gRPC off.
                return QdrantClient(
                    url=qdrant_url,
                    api_key=qdrant_api_key,
                    timeout=request_timeout,
                    prefer_grpc=bool(use_grpc),
                    grpc_port=grpc_port,
                    check_compatibility=False,
                )

            client = _make_client(prefer_grpc)
            if prefer_grpc:
                # Probe the connection: some deployments reject gRPC with a
                # 403; fall back to HTTP then, re-raise anything else.
                try:
                    _ = client.get_collections()
                except Exception as e:
                    msg = str(e)
                    if (
                        "StatusCode.PERMISSION_DENIED" in msg
                        or "http2 header with status: 403" in msg
                    ):
                        client = _make_client(False)
                    else:
                        raise
            qdrant_client = client

        self.client = qdrant_client
        self.collection_name = collection_name
        self.embedder = embedder or VisualEmbedder(model_name=model_name)

        # All stage retrievers share the same client/collection/retry config.
        self._two_stage = TwoStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )
        self._three_stage = ThreeStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )
        self._single_stage = SingleStageRetriever(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """Build a Qdrant payload filter; delegates to the two-stage retriever."""
        return self._two_stage.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
        )

    def search(
        self,
        query: str,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        filter_obj=None,
        return_embeddings: bool = False,
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Embed ``query`` as float32 and run :meth:`search_embedded`.

        ``stage1_k`` / ``stage2_k`` (appended, default None, backward
        compatible) are forwarded so three-stage search is fully
        configurable from here, matching ``search_embedded``'s signature.
        """
        q = self.embedder.embed_query(query)
        if isinstance(q, torch.Tensor):
            # .float() upcasts BFloat16 to Float32 first — numpy has no
            # bfloat16 dtype, so a direct .numpy() would raise.
            query_embedding = q.detach().cpu().float().numpy()
        else:
            query_embedding = np.asarray(q, dtype=np.float32)
        return self.search_embedded(
            query_embedding=query_embedding,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            stage1_k=stage1_k,
            stage2_k=stage2_k,
            filter_obj=filter_obj,
            return_embeddings=return_embeddings,
        )

    def search_embedded(
        self,
        *,
        query_embedding,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
        filter_obj=None,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Search with a precomputed query embedding.

        Args:
            query_embedding: Multi-vector query embedding (numpy).
            top_k: Number of final results.
            mode: One of "single_full", "single_pooled", "two_stage",
                "three_stage" — selects the retriever and vector space.
            prefetch_k: Candidate pool size for two-stage prefetch.
            stage1_mode: First-stage scoring mode for staged retrievers.
            stage1_k: Stage-1 candidate count (three-stage only).
            stage2_k: Stage-2 candidate count (three-stage only).
            filter_obj: Optional Qdrant filter (see :meth:`build_filter`).
            return_embeddings: NOTE(review): accepted but not forwarded to
                any stage retriever below — confirm whether the stage APIs
                support it before relying on this flag.

        Raises:
            ValueError: ``mode`` is not one of the four recognized values.
        """
        if mode == "single_full":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="initial",
            )
        elif mode == "single_pooled":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="mean_pooling",
            )
        elif mode == "two_stage":
            return self._two_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                prefetch_k=prefetch_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        elif mode == "three_stage":
            return self._three_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                stage1_k=stage1_k,
                stage2_k=stage2_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")