import hashlib
import logging
import os
import threading
import uuid
import warnings
from typing import Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer

warnings.filterwarnings('ignore', category=FutureWarning)
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


class VectorDatabase:
    """Manage vector database for document embeddings using Qdrant Cloud.

    Wraps a Qdrant Cloud collection plus a process-wide shared
    SentenceTransformer embedding model.  Query results are formatted in
    the ChromaDB style (single-element nested lists) so callers written
    against Chroma keep working unchanged.
    """

    # Process-wide embedding model cache, shared across all instances so the
    # (expensive) model load happens at most once per model name.
    _embedding_model = None
    _embedding_model_name = None
    _embedding_model_lock = threading.Lock()

    def __init__(self, collection_name: str = "documents", persist_directory: str = None):
        """Initialize the Qdrant Cloud client and load the embedding model.

        Args:
            collection_name: Name of the Qdrant collection to use/create.
            persist_directory: Ignored; kept for interface compatibility
                with a local/persistent vector-store backend.

        Raises:
            ValueError: If QDRANT_URL or QDRANT_API_KEY is not set in the
                environment.
        """
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")

        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=60.0
        )
        self.collection_name = collection_name
        # Vector size for standard sentence-transformers (e.g. all-MiniLM-L6-v2).
        # NOTE(review): if EMBEDDING_MODEL selects a model with a different
        # output dimension, this constant must be kept in sync — confirm.
        self.vector_size = 384

        # Ensure collection exists, then load the shared embedding model.
        self._ensure_collection()
        self.embedding_model = self._get_or_create_embedding_model()

    def _ensure_collection(self):
        """Create the collection in Qdrant if it doesn't exist (best effort).

        Failures are logged and swallowed deliberately: a later operation
        against a genuinely missing collection will surface a hard error.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)
            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self.vector_size,
                        distance=models.Distance.COSINE
                    )
                )
        except Exception as e:
            # Use the module logger instead of print() so the failure shows
            # up in application logs.
            logger.error("Error checking/creating collection: %s", e)

    @classmethod
    def _get_or_create_embedding_model(cls):
        """Return the shared SentenceTransformer, loading it on first use.

        Thread-safe via the class-level lock.  The model is (re)loaded only
        when no model is cached yet or the EMBEDDING_MODEL env var names a
        different model than the cached one.
        """
        with cls._embedding_model_lock:
            # EMBEDDING_MODEL may be set in the environment; default to MiniLM.
            model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
            if cls._embedding_model is None or cls._embedding_model_name != model_name:
                import torch
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                logger.info("Loading embedding model on %s...", device)
                cls._embedding_model = SentenceTransformer(model_name, device=device)
                cls._embedding_model_name = model_name
            return cls._embedding_model

    def _string_to_uuid(self, string_id: str) -> str:
        """Deterministically map an arbitrary string ID to a UUID string.

        Qdrant requires point IDs to be UUIDs (or unsigned ints); this
        hashes custom string IDs into stable UUIDs.  MD5 is fine here:
        the use is non-cryptographic ID derivation, not security.
        """
        return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))

    def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
        """Embed the given texts and upsert them into the collection.

        Args:
            texts: Document texts to embed; each is also stored under the
                payload key 'text' so it can be returned on retrieval.
            metadatas: Per-document metadata dicts (entries may be falsy).
                The caller's dicts are copied, never mutated.
            ids: Caller-side string IDs; hashed into UUIDs for Qdrant.
        """
        if not texts:
            return
        embeddings = self.embedding_model.encode(
            texts, show_progress_bar=False, batch_size=64
        ).tolist()

        points = []
        for text, metadata, doc_id, vector in zip(texts, metadatas, ids, embeddings):
            # Copy the metadata so we never mutate the caller's dict when
            # injecting the 'text' key (the original aliased it).
            payload = dict(metadata) if metadata else {}
            payload['text'] = text  # store actual text in payload for retrieval
            points.append(models.PointStruct(
                id=self._string_to_uuid(doc_id),
                vector=vector,
                payload=payload
            ))

        # upload_points auto-batches (chunks of batch_size) instead of one
        # large upsert call.
        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=100,
            wait=True  # block until the upload completes before returning
        )

    def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
        """Semantic search returning ChromaDB-style nested-list results.

        Args:
            query_text: Natural-language query to embed and search with.
            n_results: Maximum number of hits to return.
            filter_dict: Optional exact-match payload filter; all key/value
                pairs are ANDed (Qdrant `must` conditions).

        Returns:
            Dict with keys 'documents', 'metadatas', 'distances', 'ids',
            each a single-element list wrapping the per-hit lists (Chroma
            shape).  Note: 'distances' actually holds Qdrant cosine
            *similarity* scores (higher is better), unlike Chroma distances.
        """
        # Short-circuit on an empty collection to avoid a pointless search.
        if self.get_collection_count() == 0:
            return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}

        query_embedding = self.embedding_model.encode([query_text])[0].tolist()

        # Build a Qdrant filter from the exact-match dict, if provided.
        qdrant_filter = None
        if filter_dict:
            conditions = [
                models.FieldCondition(key=k, match=models.MatchValue(value=v))
                for k, v in filter_dict.items()
            ]
            qdrant_filter = models.Filter(must=conditions)

        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=qdrant_filter,
            limit=n_results
        )

        # Format output to match what HybridRetriever expects (ChromaDB style).
        docs, metas, scores, ids = [], [], [], []
        for hit in search_result:
            # Guard: Qdrant can return a None payload when payloads are
            # excluded or absent.
            payload = hit.payload or {}
            docs.append(payload.get('text', ''))
            # Strip 'text' from metadata so the shape mimics Chroma.
            metas.append({k: v for k, v in payload.items() if k != 'text'})
            scores.append(hit.score)
            ids.append(str(hit.id))

        return {
            "documents": [docs],
            "metadatas": [metas],
            "distances": [scores],
            "ids": [ids]
        }

    def get_collection_count(self) -> int:
        """Return the number of points in the collection, or 0 on any error."""
        try:
            return self.client.count(collection_name=self.collection_name).count
        except Exception:
            # Best-effort: an unreachable/missing collection reads as empty.
            return 0