Spaces:
Sleeping
Sleeping
File size: 7,119 Bytes
c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf de13ab3 c4ef1cf de13ab3 c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf de13ab3 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca 51cf2c0 c4ef1cf 51cf2c0 c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca c4ef1cf 9513cca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import numpy as np
import torch
try:
from dotenv import load_dotenv
DOTENV_AVAILABLE = True
except ImportError:
DOTENV_AVAILABLE = False
load_dotenv = None
try:
from qdrant_client import QdrantClient
QDRANT_AVAILABLE = True
except ImportError:
QDRANT_AVAILABLE = False
QdrantClient = None
from visual_rag.embedding.visual_embedder import VisualEmbedder
from visual_rag.retrieval.single_stage import SingleStageRetriever
from visual_rag.retrieval.three_stage import ThreeStageRetriever
from visual_rag.retrieval.two_stage import TwoStageRetriever
class MultiVectorRetriever:
    """Text-query front end over single-, two- and three-stage retrievers.

    The instance owns one Qdrant client, one collection name and one query
    embedder, and builds a ``SingleStageRetriever``, ``TwoStageRetriever`` and
    ``ThreeStageRetriever`` that all share the same client and retry settings.
    ``search`` embeds a text query and ``search_embedded`` dispatches an
    already-embedded query to the retriever selected by ``mode``.
    """

    # Modes understood by ``search_embedded``; ``search`` checks against this
    # tuple so that a typo fails before the embedding model is run.
    _MODES = ("single_full", "single_pooled", "two_stage", "three_stage")

    @staticmethod
    def _maybe_load_dotenv() -> None:
        """Load ``.env`` from the current working directory, if possible.

        Does nothing when python-dotenv is not installed or when no ``.env``
        file exists in the current directory.
        """
        if not DOTENV_AVAILABLE:
            return
        if os.path.exists(".env"):
            load_dotenv(".env")

    @classmethod
    def _connect(
        cls,
        qdrant_url: Optional[str],
        qdrant_api_key: Optional[str],
        prefer_grpc: bool,
        request_timeout: int,
    ):
        """Create a ``QdrantClient`` from explicit arguments or environment.

        Args:
            qdrant_url: Server URL. Falls back to the ``QDRANT_URL`` and then
                the legacy ``SIGIR_QDRANT_URL`` environment variables.
            qdrant_api_key: API key. Falls back to ``QDRANT_API_KEY`` and then
                the legacy ``SIGIR_QDRANT_KEY`` environment variables.
            prefer_grpc: Try gRPC first; see the fallback note below.
            request_timeout: Passed to the client as ``timeout``.

        Returns:
            The connected client. When ``prefer_grpc`` is set and the probe
            call fails with a permission-denied / HTTP 403 style message, a
            second client with ``prefer_grpc=False`` is returned instead.

        Raises:
            ImportError: ``qdrant_client`` is not installed.
            ValueError: No URL was passed and none is set in the environment.
            Exception: Any other error raised by the gRPC probe call is
                re-raised unchanged.
        """
        cls._maybe_load_dotenv()
        if not QDRANT_AVAILABLE:
            raise ImportError(
                "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
            )

        qdrant_url = (
            qdrant_url or os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL")  # legacy
        )
        if not qdrant_url:
            raise ValueError("QDRANT_URL is required (pass qdrant_url or set env var).")
        qdrant_api_key = (
            qdrant_api_key
            or os.getenv("QDRANT_API_KEY")
            or os.getenv("SIGIR_QDRANT_KEY")  # legacy
        )

        # When the URL names port 6333 explicitly, use 6334 for gRPC.
        # Otherwise leave ``grpc_port`` as None and pass that to the client.
        grpc_port = None
        if prefer_grpc:
            try:
                if urlparse(qdrant_url).port == 6333:
                    grpc_port = 6334
            except (ValueError, TypeError, AttributeError):
                # Best effort only: a URL whose port cannot be parsed simply
                # gets no explicit gRPC port; the client reports real errors.
                pass

        def _make_client(use_grpc: bool):
            return QdrantClient(
                url=qdrant_url,
                api_key=qdrant_api_key,
                timeout=request_timeout,
                prefer_grpc=bool(use_grpc),
                grpc_port=grpc_port,
                check_compatibility=False,
            )

        client = _make_client(prefer_grpc)
        if prefer_grpc:
            # Probe the gRPC connection once; if the probe is rejected with a
            # permission-denied / 403 message, rebuild the client without gRPC.
            try:
                client.get_collections()
            except Exception as e:
                msg = str(e)
                if (
                    "StatusCode.PERMISSION_DENIED" in msg
                    or "http2 header with status: 403" in msg
                ):
                    client = _make_client(False)
                else:
                    raise
        return client

    def __init__(
        self,
        collection_name: str,
        model_name: str = "vidore/colSmol-500M",
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
        qdrant_client=None,
        embedder: Optional["VisualEmbedder"] = None,
    ):
        """Set up the client, the embedder and the three stage retrievers.

        Args:
            collection_name: Qdrant collection every retriever searches.
            model_name: Model used to build a ``VisualEmbedder`` when
                ``embedder`` is not supplied.
            qdrant_url: Server URL; only used when ``qdrant_client`` is None.
            qdrant_api_key: API key; only used when ``qdrant_client`` is None.
            prefer_grpc: Prefer gRPC when creating a client here.
            request_timeout: Client timeout, also handed to each retriever.
            max_retries: Handed to each retriever.
            retry_sleep: Handed to each retriever.
            qdrant_client: Existing client to reuse. When given, no ``.env``
                loading, environment lookup or connection attempt happens.
            embedder: Existing embedder to reuse instead of loading
                ``model_name``.

        Raises:
            ImportError: A client must be created but ``qdrant_client`` is
                not installed.
            ValueError: A client must be created but no URL is available.
        """
        if qdrant_client is None:
            qdrant_client = self._connect(
                qdrant_url=qdrant_url,
                qdrant_api_key=qdrant_api_key,
                prefer_grpc=prefer_grpc,
                request_timeout=request_timeout,
            )

        self.client = qdrant_client
        self.collection_name = collection_name
        # ``is None`` rather than truthiness, so a supplied embedder is never
        # replaced just because it happens to evaluate as falsy.
        self.embedder = (
            embedder if embedder is not None else VisualEmbedder(model_name=model_name)
        )

        retriever_kwargs = dict(
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            request_timeout=request_timeout,
            max_retries=max_retries,
            retry_sleep=retry_sleep,
        )
        self._two_stage = TwoStageRetriever(**retriever_kwargs)
        self._three_stage = ThreeStageRetriever(**retriever_kwargs)
        self._single_stage = SingleStageRetriever(**retriever_kwargs)

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """Build a filter object by delegating to the two-stage retriever.

        All arguments are forwarded unchanged to
        ``TwoStageRetriever.build_filter``; its return value is returned
        as-is and can be passed as ``filter_obj`` to the search methods.
        """
        return self._two_stage.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
        )

    def search(
        self,
        query: str,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        filter_obj=None,
        return_embeddings: bool = False,
        *,
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Embed a text query and run it through ``search_embedded``.

        Args:
            query: Text to embed with ``self.embedder.embed_query``.
            top_k: Number of results requested.
            mode: One of ``"single_full"``, ``"single_pooled"``,
                ``"two_stage"`` or ``"three_stage"``.
            prefetch_k: Forwarded for ``"two_stage"``.
            stage1_mode: Forwarded for ``"two_stage"`` and ``"three_stage"``.
            filter_obj: Optional filter, e.g. from ``build_filter``.
            return_embeddings: Forwarded to ``search_embedded``.
            stage1_k: Forwarded for ``"three_stage"`` (keyword-only).
            stage2_k: Forwarded for ``"three_stage"`` (keyword-only).

        Returns:
            Whatever the selected retriever returns.

        Raises:
            ValueError: ``mode`` is not one of the supported modes. This is
                checked before the query is embedded.
        """
        # Fail fast: do not pay for an embedding pass on a misspelled mode.
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode}")

        q = self.embedder.embed_query(query)
        if isinstance(q, torch.Tensor):
            # .float() converts BFloat16 to Float32 (numpy doesn't support BFloat16)
            query_embedding = q.detach().cpu().float().numpy()
        else:
            query_embedding = np.asarray(q, dtype=np.float32)

        return self.search_embedded(
            query_embedding=query_embedding,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            stage1_k=stage1_k,
            stage2_k=stage2_k,
            filter_obj=filter_obj,
            return_embeddings=return_embeddings,
        )

    def search_embedded(
        self,
        *,
        query_embedding,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
        filter_obj=None,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Dispatch an already-embedded query to the retriever for ``mode``.

        Dispatch table:
            * ``"single_full"``   -> single-stage search, ``using="initial"``
            * ``"single_pooled"`` -> single-stage search,
              ``using="mean_pooling"``
            * ``"two_stage"``     -> two-stage ``search_server_side`` with
              ``prefetch_k`` and ``stage1_mode``
            * ``"three_stage"``   -> three-stage ``search_server_side`` with
              ``stage1_k``, ``stage2_k`` and ``stage1_mode``

        NOTE(review): ``return_embeddings`` is accepted for interface
        compatibility but is not forwarded to any retriever here — confirm
        whether the retrievers support it before wiring it through.

        Raises:
            ValueError: ``mode`` is not one of the four modes above.
        """
        if mode == "single_full":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="initial",
            )
        if mode == "single_pooled":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filter_obj=filter_obj,
                using="mean_pooling",
            )
        if mode == "two_stage":
            return self._two_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                prefetch_k=prefetch_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        if mode == "three_stage":
            return self._three_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                stage1_k=stage1_k,
                stage2_k=stage2_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )
        raise ValueError(f"Unknown mode: {mode}")
|