# visual-rag-toolkit / demo / qdrant_utils.py
# Last commit: "SYNC with v0.1.3" by Yeroyan (047e959, verified)
"""Qdrant connection and utility functions."""
import os
import traceback
from typing import Any, Dict, List, Optional, Tuple
import streamlit as st
def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Qdrant URL and API key.

    Each value comes from the first truthy source in this order:

    1. Streamlit session state (``qdrant_url_input`` / ``qdrant_key_input``)
    2. ``QDRANT_URL`` / ``QDRANT_API_KEY`` environment variables
    3. Legacy ``SIGIR_QDRANT_URL`` / ``SIGIR_QDRANT_KEY`` environment variables

    Returns:
        ``(url, api_key)``. An element is falsy (normally ``None``) when no
        source provides a non-empty value.
    """

    def _resolve(session_key: str, env_var: str, legacy_env_var: str) -> Optional[str]:
        # Short-circuiting ``or``: later sources are only consulted when the
        # earlier ones are missing or empty.
        return (
            st.session_state.get(session_key)
            or os.getenv(env_var)
            or os.getenv(legacy_env_var)
        )

    url = _resolve("qdrant_url_input", "QDRANT_URL", "SIGIR_QDRANT_URL")
    api_key = _resolve("qdrant_key_input", "QDRANT_API_KEY", "SIGIR_QDRANT_KEY")
    return url, api_key
def init_qdrant_client_with_creds(
    url: Optional[str],
    api_key: Optional[str],
    timeout: int = 60,
) -> Tuple[Optional[Any], Optional[str]]:
    """Create a Qdrant client for ``url`` and verify that it can connect.

    This function never raises; every failure is reported through the second
    element of the returned tuple.

    Args:
        url: Qdrant endpoint. An empty or missing value is reported as an
            error without importing or contacting anything.
        api_key: API key forwarded to ``QdrantClient``; may be ``None``.
        timeout: Request timeout forwarded to ``QdrantClient`` (defaults to
            the previously hard-coded 60).

    Returns:
        ``(client, None)`` on success, or ``(None, error_message)`` on failure.
    """
    # Validate before importing, so a missing URL produces this clear message
    # even when qdrant_client itself is not installed.
    if not url:
        return None, "QDRANT_URL not configured"
    try:
        from qdrant_client import QdrantClient

        client = QdrantClient(url=url, api_key=api_key, timeout=timeout)
        # Cheap round-trip so a bad URL or key fails here, not at first real use.
        client.get_collections()
        return client, None
    except Exception as e:
        return None, str(e)
@st.cache_resource(show_spinner="Connecting to Qdrant...")
def _connect_qdrant_cached(url: Optional[str], api_key: Optional[str]) -> Any:
    """Return a connected client for ``(url, api_key)``, cached per credential pair.

    Raises ``RuntimeError`` with the connection error message on failure.
    ``st.cache_resource`` does not store calls that raise, so a failed
    connection is retried on the next call instead of being served from cache.
    """
    client, err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        raise RuntimeError(err)
    return client


def init_qdrant_client() -> Tuple[Optional[Any], Optional[str]]:
    """Return ``(client, None)`` for the current credentials, or ``(None, error)``.

    Credentials are re-read from ``get_qdrant_credentials()`` on every call and
    passed to the cached helper as (hashed) arguments. A URL/key edited in the
    session therefore gets its own client, rather than reusing whatever was
    cached for the first credentials seen (or a cached failure).
    """
    url, api_key = get_qdrant_credentials()
    try:
        return _connect_qdrant_cached(url, api_key), None
    except Exception as e:
        return None, str(e)


# Preserve the ``init_qdrant_client.clear()`` API of the previously decorated function.
init_qdrant_client.clear = _connect_qdrant_cached.clear
@st.cache_resource(show_spinner="Loading embedding model...")
def _load_embedder_cached(model_name: str) -> Any:
    """Load and cache one ``VisualEmbedder`` per ``model_name``; raises on failure."""
    from visual_rag import VisualEmbedder

    return VisualEmbedder(model_name=model_name)


def init_embedder(model_name: str) -> Tuple[Optional[Any], Optional[str]]:
    """Return ``(embedder, None)`` for ``model_name``, or ``(None, error)``.

    The error string contains the exception message followed by its traceback.
    Loading happens in a cached helper that raises on failure. Because
    ``st.cache_resource`` does not cache calls that raise, a failed load is
    retried on the next call instead of being returned from the cache until
    the app restarts.
    """
    try:
        return _load_embedder_cached(model_name), None
    except Exception as e:
        return None, f"{e}\n\n{traceback.format_exc()}"


# Preserve the ``init_embedder.clear()`` API of the previously decorated function.
init_embedder.clear = _load_embedder_cached.clear
@st.cache_data(ttl=300, show_spinner="Fetching collections...")
def _get_collections_cached(url: Optional[str], api_key: Optional[str]) -> List[str]:
    """Cached worker for ``get_collections``; ``url`` and ``api_key`` are part of the key."""
    client, _err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        return []
    try:
        return sorted(c.name for c in client.get_collections().collections)
    except Exception:
        # Best effort for the UI: an unreachable server shows no collections.
        return []


def get_collections(_url: str, _api_key: str) -> List[str]:
    """Return the sorted collection names on the Qdrant server, or ``[]`` on failure.

    The parameters keep their historical underscore names for caller
    compatibility. However, Streamlit skips hashing underscore-prefixed
    arguments, so caching directly on this signature served one list for
    every URL/key. The work is therefore delegated to a cached helper whose
    arguments *are* hashed.
    """
    return _get_collections_cached(_url, _api_key)


# Preserve the ``get_collections.clear()`` API of the previously decorated function.
get_collections.clear = _get_collections_cached.clear
def _describe_vector_params(cfg: Any) -> Dict[str, Any]:
    """Summarize one Qdrant vector configuration object for display."""
    multivec = getattr(cfg, "multivector_config", None)
    # A multivector config with a comparator stores a variable number of
    # vectors per point, displayed as "N".
    is_variable = multivec is not None and bool(getattr(multivec, "comparator", None))
    # The ``datatype`` attribute may exist but be None; treat that like a
    # missing attribute instead of rendering the string "None".
    datatype = str(getattr(cfg, "datatype", None) or "Float32").replace("Datatype.", "")
    return {
        "size": getattr(cfg, "size", None),
        "num_vectors": "N" if is_variable else 1,
        "is_multivector": multivec is not None,
        "on_disk": getattr(cfg, "on_disk", None),
        "datatype": datatype,
        "quantization": getattr(cfg, "quantization_config", None) is not None,
    }


def _summarize_vectors(vectors_config: Any) -> Dict[str, Dict[str, Any]]:
    """Map vector name -> summary; a single unnamed vector is reported as ``"default"``."""
    if vectors_config is None:
        return {}
    if hasattr(vectors_config, "items"):
        # Named vectors: a mapping of name -> vector config.
        return {name: _describe_vector_params(cfg) for name, cfg in vectors_config.items()}
    if hasattr(vectors_config, "size"):
        return {"default": _describe_vector_params(vectors_config)}
    return {}


@st.cache_data(ttl=120, show_spinner="Loading collection stats...")
def _get_collection_stats_cached(
    collection_name: str, url: Optional[str], api_key: Optional[str]
) -> Dict[str, Any]:
    """Fetch stats for ``collection_name``; all three arguments form the cache key."""
    client, err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        return {"error": err}
    try:
        info = client.get_collection(collection_name)
        params = getattr(getattr(info, "config", None), "params", None)
        points_count = getattr(info, "points_count", 0)
        # Fall back to points_count when vectors_count is missing *or* None.
        vectors_count = getattr(info, "vectors_count", None)
        if vectors_count is None:
            vectors_count = points_count
        return {
            "points_count": points_count,
            "vectors_count": vectors_count,
            "status": str(getattr(info, "status", "unknown")),
            "vector_info": _summarize_vectors(getattr(params, "vectors", None)),
            "indexed_vectors_count": getattr(info, "indexed_vectors_count", None),
        }
    except Exception as e:
        return {"error": f"{e}\n\n{traceback.format_exc()}"}


def get_collection_stats(collection_name: str) -> Dict[str, Any]:
    """Return display statistics for ``collection_name`` on the current Qdrant server.

    On success the dict has ``points_count``, ``vectors_count``, ``status``,
    ``vector_info`` (per-vector summaries) and ``indexed_vectors_count``. On
    failure it is ``{"error": message}``; like successes, failures are cached
    for the TTL.

    Credentials are resolved on every call and included in the cache key, so
    switching the URL/key does not return stats cached for another server.
    """
    url, api_key = get_qdrant_credentials()
    return _get_collection_stats_cached(collection_name, url, api_key)


# Preserve the ``get_collection_stats.clear()`` API of the previously decorated function.
get_collection_stats.clear = _get_collection_stats_cached.clear
@st.cache_data(ttl=60)
def _sample_points_impl(
    collection_name: str, n: int, seed: int, url: Optional[str], api_key: Optional[str]
) -> List[Dict[str, Any]]:
    """Cached worker for ``sample_points_cached``; every argument is part of the key."""
    if n <= 0:
        return []
    import random

    client, _err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        return []
    try:
        # Only the first scroll page (at most 100 points) is fetched, so the
        # sample is drawn from that page rather than the whole collection.
        points, _next_offset = client.scroll(
            collection_name=collection_name,
            limit=min(n * 10, 100),
            with_payload=True,
            with_vectors=False,
        )
        if not points:
            return []
        rng = random.Random(seed)  # seeded so the same seed gives the same sample
        sampled = rng.sample(points, min(n, len(points)))
        return [
            {"id": str(p.id), "payload": dict(p.payload) if p.payload else {}}
            for p in sampled
        ]
    except Exception:
        # Best effort for the UI: any failure yields an empty sample.
        return []


def sample_points_cached(
    collection_name: str, n: int, seed: int, _url: str, _api_key: str
) -> List[Dict[str, Any]]:
    """Return up to ``n`` reproducibly sampled points as ``{"id", "payload"}`` dicts.

    Returns ``[]`` when ``n <= 0``, the collection is empty, or anything fails.
    The underscore parameter names are kept for callers. Streamlit does not
    hash underscore-prefixed arguments, so the cache lives on a helper that
    receives the credentials as ordinary (hashed) arguments; otherwise samples
    from one server would be served for another.
    """
    return _sample_points_impl(collection_name, n, seed, _url, _api_key)


# Preserve the ``sample_points_cached.clear()`` API of the previously decorated function.
sample_points_cached.clear = _sample_points_impl.clear
@st.cache_data(ttl=300)
def _get_vector_sizes_cached(
    collection_name: str, url: Optional[str], api_key: Optional[str]
) -> Dict[str, int]:
    """Cached worker for ``get_vector_sizes``; every argument is part of the key."""
    client, _err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        return {}
    try:
        points, _next_offset = client.scroll(
            collection_name=collection_name,
            limit=1,
            with_payload=False,
            with_vectors=True,
        )
        if not points:
            return {}
        vectors = points[0].vector
        if not isinstance(vectors, dict):
            # Only named vectors are reported; an unnamed vector yields {}.
            return {}
        # A list of lists is a multivector: report its row count. Anything
        # else (a flat list or other object) counts as a single vector.
        return {
            name: len(vec) if isinstance(vec, list) and vec and isinstance(vec[0], list) else 1
            for name, vec in vectors.items()
        }
    except Exception:
        return {}


def get_vector_sizes(collection_name: str, _url: str, _api_key: str) -> Dict[str, int]:
    """Return, for each named vector of the collection's first point, how many vectors it holds.

    The values are sub-vector counts (rows of a multivector), not vector
    dimensions. Returns ``{}`` when the collection is empty, uses an unnamed
    vector, or anything fails. The underscore parameter names are kept for
    callers. Caching lives on a helper whose credential arguments are hashed,
    because Streamlit skips hashing underscore-prefixed parameters.
    """
    return _get_vector_sizes_cached(collection_name, _url, _api_key)


# Preserve the ``get_vector_sizes.clear()`` API of the previously decorated function.
get_vector_sizes.clear = _get_vector_sizes_cached.clear
def search_collection(
    collection_name: str,
    query: str,
    top_k: int = 10,
    mode: str = "single_full",
    prefetch_k: int = 256,
    stage1_mode: str = "tokens_vs_tiles",
    stage1_k: int = 1000,
    stage2_k: int = 300,
    model_name: str = "vidore/colSmol-500M",
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """Search ``collection_name`` for ``query`` with visual_rag's ``MultiVectorRetriever``.

    Args:
        collection_name: Qdrant collection to search.
        query: Text query.
        top_k: Number of results requested.
        mode: Retrieval mode. ``"three_stage"`` embeds the query first and
            calls ``search_embedded``; any other mode calls ``search``.
        prefetch_k: Forwarded to ``search`` only (ignored for ``"three_stage"``).
        stage1_mode: Forwarded to ``search`` only (ignored for ``"three_stage"``).
        stage1_k: Forwarded to ``search_embedded`` only (``"three_stage"``).
        stage2_k: Forwarded to ``search_embedded`` only (``"three_stage"``).
        model_name: Embedding model name passed to the retriever.

    Returns:
        ``(results, None)`` on success, or ``([], message)`` on any failure,
        where the message includes the traceback.
    """
    try:
        from visual_rag.retrieval import MultiVectorRetriever

        # NOTE(review): a new retriever is constructed on every call and is not
        # given the URL/API key from get_qdrant_credentials(); confirm how it
        # connects to Qdrant and whether it reloads the model each time.
        retriever = MultiVectorRetriever(
            collection_name=collection_name,
            model_name=model_name,
        )
        if mode == "three_stage":
            query_embedding = retriever.embedder.embed_query(query)
            # Tensor-like embeddings (anything exposing .cpu()) become NumPy arrays.
            if hasattr(query_embedding, "cpu"):
                query_embedding = query_embedding.cpu().numpy()
            results = retriever.search_embedded(
                query_embedding=query_embedding,
                top_k=top_k,
                mode=mode,
                stage1_k=stage1_k,
                stage2_k=stage2_k,
            )
        else:
            results = retriever.search(
                query=query,
                top_k=top_k,
                mode=mode,
                prefetch_k=prefetch_k,
                stage1_mode=stage1_mode,
            )
        return results, None
    except Exception as e:
        return [], f"{e}\n\n{traceback.format_exc()}"