"""Qdrant connection and utility functions.""" import os import traceback from typing import Any, Dict, List, Optional, Tuple import streamlit as st def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]: """Get Qdrant credentials from session state or environment variables. Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars """ url = ( st.session_state.get("qdrant_url_input") or os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy ) api_key = ( st.session_state.get("qdrant_key_input") or os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy ) return url, api_key def init_qdrant_client_with_creds(url: str, api_key: str): try: from qdrant_client import QdrantClient if not url: return None, "QDRANT_URL not configured" client = QdrantClient(url=url, api_key=api_key, timeout=60) client.get_collections() return client, None except Exception as e: return None, str(e) @st.cache_resource(show_spinner="Connecting to Qdrant...") def init_qdrant_client(): url, api_key = get_qdrant_credentials() return init_qdrant_client_with_creds(url, api_key) @st.cache_resource(show_spinner="Loading embedding model...") def init_embedder(model_name: str): try: from visual_rag import VisualEmbedder return VisualEmbedder(model_name=model_name), None except Exception as e: return None, f"{e}\n\n{traceback.format_exc()}" @st.cache_data(ttl=300, show_spinner="Fetching collections...") def get_collections(_url: str, _api_key: str) -> List[str]: client, err = init_qdrant_client_with_creds(_url, _api_key) if client is None: return [] try: collections = client.get_collections().collections return sorted([c.name for c in collections]) except Exception: return [] @st.cache_data(ttl=120, show_spinner="Loading collection stats...") def get_collection_stats(collection_name: str) -> Dict[str, Any]: url, api_key = get_qdrant_credentials() client, err = init_qdrant_client_with_creds(url, api_key) if client is None: return {"error": err} try: info = client.get_collection(collection_name) vectors_config = getattr(getattr(getattr(info, "config", None), "params", None), "vectors", None) vector_info = {} if vectors_config is not None: if hasattr(vectors_config, "items"): for name, cfg in vectors_config.items(): size = getattr(cfg, "size", None) multivec = getattr(cfg, "multivector_config", None) on_disk = getattr(cfg, "on_disk", None) datatype = str(getattr(cfg, "datatype", "Float32")).replace("Datatype.", "") quantization = getattr(cfg, "quantization_config", None) num_vectors = 1 if multivec is not None: comparator = getattr(multivec, "comparator", None) num_vectors = "N" if comparator else 1 vector_info[name] = { "size": size, "num_vectors": num_vectors, "is_multivector": multivec is not None, "on_disk": on_disk, "datatype": datatype, "quantization": quantization is not None, } elif hasattr(vectors_config, "size"): on_disk = getattr(vectors_config, "on_disk", None) datatype = str(getattr(vectors_config, "datatype", "Float32")).replace("Datatype.", "") multivec = getattr(vectors_config, "multivector_config", None) vector_info["default"] = { "size": getattr(vectors_config, "size", None), "num_vectors": "N" if multivec else 1, "is_multivector": multivec is not None, "on_disk": on_disk, "datatype": datatype, } return { "points_count": getattr(info, "points_count", 0), "vectors_count": getattr(info, "vectors_count", getattr(info, "points_count", 0)), "status": str(getattr(info, "status", "unknown")), "vector_info": vector_info, "indexed_vectors_count": getattr(info, "indexed_vectors_count", None), } except Exception as e: return {"error": f"{e}\n\n{traceback.format_exc()}"} @st.cache_data(ttl=60) def sample_points_cached(collection_name: str, n: int, seed: int, _url: str, _api_key: str) -> List[Dict[str, Any]]: client, err = init_qdrant_client_with_creds(_url, _api_key) if client is None: return [] try: import random rng = random.Random(seed) points, _ = client.scroll( collection_name=collection_name, limit=min(n * 10, 100), with_payload=True, with_vectors=False, ) if not points: return [] sampled = rng.sample(points, min(n, len(points))) results = [] for p in sampled: payload = dict(p.payload) if p.payload else {} results.append({ "id": str(p.id), "payload": payload, }) return results except Exception: return [] @st.cache_data(ttl=300) def get_vector_sizes(collection_name: str, _url: str, _api_key: str) -> Dict[str, int]: client, err = init_qdrant_client_with_creds(_url, _api_key) if client is None: return {} try: points, _ = client.scroll( collection_name=collection_name, limit=1, with_payload=False, with_vectors=True, ) if not points: return {} vectors = points[0].vector sizes = {} if isinstance(vectors, dict): for name, vec in vectors.items(): if isinstance(vec, list): if vec and isinstance(vec[0], list): sizes[name] = len(vec) else: sizes[name] = 1 else: sizes[name] = 1 return sizes except Exception: return {} def search_collection( collection_name: str, query: str, top_k: int = 10, mode: str = "single_full", prefetch_k: int = 256, stage1_mode: str = "tokens_vs_tiles", stage1_k: int = 1000, stage2_k: int = 300, model_name: str = "vidore/colSmol-500M", ) -> Tuple[List[Dict[str, Any]], Optional[str]]: try: import traceback from visual_rag.retrieval import MultiVectorRetriever retriever = MultiVectorRetriever( collection_name=collection_name, model_name=model_name, ) if mode == "three_stage": q_emb = retriever.embedder.embed_query(query) if hasattr(q_emb, "cpu"): q_emb = q_emb.cpu().numpy() results = retriever.search_embedded( query_embedding=q_emb, top_k=top_k, mode=mode, stage1_k=stage1_k, stage2_k=stage2_k, ) else: results = retriever.search( query=query, top_k=top_k, mode=mode, prefetch_k=prefetch_k, stage1_mode=stage1_mode, ) return results, None except Exception as e: import traceback return [], f"{e}\n\n{traceback.format_exc()}"