import os
import uuid
import hashlib
from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import threading
import logging
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
class VectorDatabase:
    """Manage vector database for document embeddings using Qdrant Cloud.

    The SentenceTransformer embedding model is cached at class level so that
    repeated instantiations (e.g. one per request) do not reload the weights.
    """

    # Class-level cache: one embedding model shared by all instances.
    _embedding_model = None
    _embedding_model_name = None
    _embedding_model_lock = threading.Lock()

    def __init__(self, collection_name: str = "documents", persist_directory: str = None):
        """Initialize the Qdrant Cloud client and ensure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection to use/create.
            persist_directory: Ignored; kept for interface compatibility with
                local on-disk vector store implementations.

        Raises:
            ValueError: If QDRANT_URL or QDRANT_API_KEY is not set in the
                environment.
        """
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=60.0
        )
        self.collection_name = collection_name
        self.vector_size = 384  # Dimension of standard sentence-transformers (e.g. all-MiniLM-L6-v2)
        # Ensure collection exists before any read/write operations.
        self._ensure_collection()
        # Load (or reuse the cached) embedding model.
        self.embedding_model = self._get_or_create_embedding_model()

    def _ensure_collection(self):
        """Create the collection in Qdrant if it doesn't exist.

        Failures are logged rather than raised so a transient connectivity
        issue at startup does not crash the caller; subsequent operations
        will surface the underlying error.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)
            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self.vector_size,
                        distance=models.Distance.COSINE
                    )
                )
        except Exception:
            # Log with traceback (instead of print) so the failure reaches
            # the application's configured log handlers.
            logging.getLogger(__name__).exception("Error checking/creating collection")

    @classmethod
    def _get_or_create_embedding_model(cls):
        """Return the shared SentenceTransformer, (re)loading it if needed.

        Thread-safe: the class lock prevents two threads from loading the
        model concurrently. The model is reloaded if the EMBEDDING_MODEL
        environment variable changed since the last call.
        """
        with cls._embedding_model_lock:
            # EMBEDDING_MODEL may be set in config; default to MiniLM.
            model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
            if cls._embedding_model is None or cls._embedding_model_name != model_name:
                import torch  # Local import: torch only needed at load time
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                print(f"Loading embedding model on {device}...")
                cls._embedding_model = SentenceTransformer(model_name, device=device)
                cls._embedding_model_name = model_name
            return cls._embedding_model

    def _string_to_uuid(self, string_id: str) -> str:
        """Deterministically map an arbitrary string ID to a UUID string.

        Qdrant requires point IDs to be UUIDs or unsigned ints. MD5 is used
        purely as a stable 128-bit hash (not for security); do NOT change
        the hash function without re-indexing, or existing points will get
        new IDs.
        """
        return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))

    def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
        """Embed the given texts and upsert them into the collection.

        Args:
            texts: Document chunks to embed; an empty list is a no-op.
            metadatas: Per-document metadata dicts (may be None, or contain
                None/empty entries). The caller's dicts are never mutated.
            ids: Caller-supplied string IDs, hashed into UUID point IDs.
        """
        if not texts:
            return
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()
        # Tolerate a missing metadata list entirely.
        metadatas = metadatas if metadatas is not None else [{}] * len(texts)
        points = []
        for text, meta, doc_id, vector in zip(texts, metadatas, ids, embeddings):
            # Copy the metadata so injecting 'text' does not mutate the
            # caller's dict (the original code wrote into it in place).
            payload = dict(meta) if meta else {}
            payload['text'] = text  # Store actual text in payload for retrieval
            points.append(models.PointStruct(
                id=self._string_to_uuid(doc_id),
                vector=vector,
                payload=payload
            ))
        # upload_points auto-batches the payload (vs. one huge upsert call).
        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=100,  # Qdrant cuts the upload into chunks of 100
            wait=True  # Block until indexed so the upload finishes before returning
        )

    def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
        """Search the collection and return results in ChromaDB-style format.

        Args:
            query_text: Natural-language query to embed and search with.
            n_results: Maximum number of hits to return.
            filter_dict: Optional exact-match payload filter ({field: value}).

        Returns:
            Dict with 'documents', 'metadatas', 'distances', 'ids', each a
            single-element list wrapping the per-hit lists (Chroma layout).
            Note: 'distances' actually holds Qdrant cosine similarity scores
            (higher is better), not Chroma-style distances.
        """
        # Short-circuit on an empty collection.
        count = self.get_collection_count()
        if count == 0:
            return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}
        query_embedding = self.embedding_model.encode([query_text])[0].tolist()
        # Build a Qdrant filter of exact-match conditions if provided.
        qdrant_filter = None
        if filter_dict:
            conditions = [
                models.FieldCondition(key=k, match=models.MatchValue(value=v))
                for k, v in filter_dict.items()
            ]
            qdrant_filter = models.Filter(must=conditions)
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=qdrant_filter,
            limit=n_results
        )
        # Format output to match what HybridRetriever expects (ChromaDB style).
        docs, metas, scores, ids = [], [], [], []
        for hit in search_result:
            payload = hit.payload or {}  # Hits may carry no payload at all
            docs.append(payload.get('text', ''))
            # Strip the stored text from metadata so it mimics Chroma.
            metas.append({k: v for k, v in payload.items() if k != 'text'})
            scores.append(hit.score)
            ids.append(str(hit.id))
        return {
            "documents": [docs],
            "metadatas": [metas],
            "distances": [scores],  # Cosine similarity (higher is better), not a true distance
            "ids": [ids]
        }

    def get_collection_count(self) -> int:
        """Return the number of points in the collection, or 0 on any error."""
        try:
            return self.client.count(collection_name=self.collection_name).count
        except Exception:
            # Best-effort: an unreachable/missing collection counts as empty.
            return 0