"""Qdrant vector retriever — handles embedding queries and searching.""" from urllib.parse import urlparse from qdrant_client import QdrantClient from qdrant_client.models import Distance, VectorParams, PointStruct from sentence_transformers import SentenceTransformer import uuid as uuid_lib from app.config import ( QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, EMBEDDING_MODEL, EMBEDDING_DIMENSION, TOP_K, ) class Retriever: """Wraps Qdrant for vector search operations.""" def __init__(self): # Parse URL into host/port for qdrant_client (avoids default port 6333 issue) parsed = urlparse(QDRANT_URL) host = parsed.hostname or "localhost" port = parsed.port or (443 if parsed.scheme == "https" else 80) use_https = parsed.scheme == "https" self.client = QdrantClient( host=host, port=port, api_key=QDRANT_API_KEY if QDRANT_API_KEY else None, prefer_grpc=False, https=use_https, timeout=30, ) self.model = SentenceTransformer(EMBEDDING_MODEL) def ensure_collection(self): """Create the collection if it doesn't exist.""" collections = [c.name for c in self.client.get_collections().collections] if COLLECTION_NAME not in collections: self.client.create_collection( collection_name=COLLECTION_NAME, vectors_config=VectorParams( size=EMBEDDING_DIMENSION, distance=Distance.COSINE, ), ) print(f"Created collection: {COLLECTION_NAME}") else: print(f"Collection '{COLLECTION_NAME}' already exists.") def embed_text(self, text: str) -> list[float]: """Embed a single text string.""" return self.model.encode(text).tolist() def embed_texts(self, texts: list[str]) -> list[list[float]]: """Embed a batch of text strings.""" return self.model.encode(texts).tolist() def upsert_chunks(self, chunks: list[dict]): """ Upsert document chunks into Qdrant. Each chunk: {"text": str, "metadata": dict} """ if not chunks: return texts = [c["text"] for c in chunks] embeddings = self.embed_texts(texts) points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): point_id = str(uuid_lib.uuid4()) payload = {**chunk["metadata"], "text": chunk["text"]} points.append( PointStruct(id=point_id, vector=embedding, payload=payload) ) # Upsert in batches of 100 batch_size = 100 for i in range(0, len(points), batch_size): batch = points[i : i + batch_size] self.client.upsert( collection_name=COLLECTION_NAME, points=batch, ) def search(self, query: str, top_k: int = TOP_K) -> list[dict]: """ Search for relevant chunks. Returns list of {"text": str, "score": float, "metadata": dict} """ query_vector = self.embed_text(query) results = self.client.search( collection_name=COLLECTION_NAME, query_vector=query_vector, limit=top_k, ) return [ { "text": hit.payload.get("text", ""), "score": hit.score, "metadata": { k: v for k, v in hit.payload.items() if k != "text" }, } for hit in results ] def get_collection_info(self) -> dict: """Get information about the collection.""" try: info = self.client.get_collection(COLLECTION_NAME) return { "name": COLLECTION_NAME, "vectors_count": info.vectors_count, "points_count": info.points_count, "status": info.status.value, } except Exception as e: return {"error": str(e)} # Global singleton — lazy loaded _retriever = None def get_retriever() -> Retriever: global _retriever if _retriever is None: _retriever = Retriever() return _retriever