Spaces:
Sleeping
Sleeping
| """Qdrant vector retriever — handles embedding queries and searching.""" | |
| from urllib.parse import urlparse | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct | |
| from sentence_transformers import SentenceTransformer | |
| import uuid as uuid_lib | |
| from app.config import ( | |
| QDRANT_URL, | |
| QDRANT_API_KEY, | |
| COLLECTION_NAME, | |
| EMBEDDING_MODEL, | |
| EMBEDDING_DIMENSION, | |
| TOP_K, | |
| ) | |
class Retriever:
    """Wraps Qdrant for vector search operations.

    Embeds text with a SentenceTransformer model and stores/queries the
    resulting vectors in a single Qdrant collection (COLLECTION_NAME).
    """

    def __init__(self):
        # Parse URL into host/port for qdrant_client (avoids default port 6333 issue)
        parsed = urlparse(QDRANT_URL)
        host = parsed.hostname or "localhost"
        # Fall back to the scheme's standard port when the URL omits one.
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        use_https = parsed.scheme == "https"
        self.client = QdrantClient(
            host=host,
            port=port,
            api_key=QDRANT_API_KEY if QDRANT_API_KEY else None,
            prefer_grpc=False,
            https=use_https,
            timeout=30,
        )
        # Loaded eagerly; model download/initialization happens here.
        self.model = SentenceTransformer(EMBEDDING_MODEL)

    def ensure_collection(self):
        """Create the collection if it doesn't exist (idempotent)."""
        collections = [c.name for c in self.client.get_collections().collections]
        if COLLECTION_NAME not in collections:
            self.client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=EMBEDDING_DIMENSION,
                    distance=Distance.COSINE,
                ),
            )
            print(f"Created collection: {COLLECTION_NAME}")
        else:
            print(f"Collection '{COLLECTION_NAME}' already exists.")

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text string."""
        return self.model.encode(text).tolist()

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of text strings."""
        return self.model.encode(texts).tolist()

    def upsert_chunks(self, chunks: list[dict]):
        """Upsert document chunks into Qdrant.

        Each chunk: {"text": str, "metadata": dict}
        """
        if not chunks:
            return
        texts = [c["text"] for c in chunks]
        embeddings = self.embed_texts(texts)
        # Build points in one pass (original enumerate() index was unused).
        # The chunk text is stored in the payload next to its metadata so
        # search() can return it without a second lookup.
        points = [
            PointStruct(
                id=str(uuid_lib.uuid4()),
                vector=embedding,
                payload={**chunk["metadata"], "text": chunk["text"]},
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        # Upsert in batches of 100 to keep individual request sizes bounded.
        batch_size = 100
        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=COLLECTION_NAME,
                points=points[start : start + batch_size],
            )

    def search(self, query: str, top_k: int = TOP_K) -> list[dict]:
        """Search for relevant chunks.

        Returns list of {"text": str, "score": float, "metadata": dict}
        """
        query_vector = self.embed_text(query)
        results = self.client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=top_k,
        )
        # Split each hit's payload back into text + metadata.
        return [
            {
                "text": hit.payload.get("text", ""),
                "score": hit.score,
                "metadata": {k: v for k, v in hit.payload.items() if k != "text"},
            }
            for hit in results
        ]

    def get_collection_info(self) -> dict:
        """Get information about the collection.

        Best-effort: returns {"error": str} instead of raising, so callers
        can use it for diagnostics without try/except.
        """
        try:
            info = self.client.get_collection(COLLECTION_NAME)
            return {
                "name": COLLECTION_NAME,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value,
            }
        except Exception as e:
            # Deliberately broad: any client/connection failure is reported
            # as data rather than propagated.
            return {"error": str(e)}
# Lazily-instantiated module-level singleton.
_retriever = None


def get_retriever() -> Retriever:
    """Return the shared Retriever, building it on first use."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = Retriever()
    return _retriever