File size: 6,016 Bytes
c4b5910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3736c33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import uuid
import hashlib
from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import threading
import logging
import warnings

warnings.filterwarnings('ignore', category=FutureWarning)
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)

class VectorDatabase:
    """Manage a vector database of document embeddings backed by Qdrant Cloud.

    Embeddings are produced with a sentence-transformers model that is loaded
    once per process and shared across all instances via a class-level cache.
    The public query interface mimics ChromaDB's result shape so existing
    retriever code can consume it unchanged.
    """

    # Process-wide embedding-model cache, shared by every instance.
    _embedding_model = None
    _embedding_model_name = None
    _embedding_model_lock = threading.Lock()

    def __init__(self, collection_name: str = "documents", persist_directory: str = None):
        """Connect to Qdrant Cloud and ensure the target collection exists.

        Args:
            collection_name: Name of the Qdrant collection to read/write.
            persist_directory: Ignored; kept for signature compatibility with
                local (ChromaDB-style) backends.

        Raises:
            ValueError: If QDRANT_URL or QDRANT_API_KEY is not set in the
                environment.
        """
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")

        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=60.0
        )
        self.collection_name = collection_name
        # NOTE(review): hard-coded to 384, the dimension of all-MiniLM-L6-v2.
        # If EMBEDDING_MODEL is set to a model with a different output size,
        # upserts will fail against this collection — confirm they match.
        self.vector_size = 384

        # Create the collection up front so reads/writes never 404.
        self._ensure_collection()

        # Load (or reuse the cached) embedding model.
        self.embedding_model = self._get_or_create_embedding_model()

    def _ensure_collection(self):
        """Create the collection in Qdrant if it doesn't already exist.

        Best-effort: failures are reported but not raised, so a transient
        listing error does not prevent construction.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)

            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self.vector_size,
                        distance=models.Distance.COSINE
                    )
                )
        except Exception as e:
            # Deliberately swallowed: keep construction best-effort.
            print(f"Error checking/creating collection: {e}")

    @classmethod
    def _get_or_create_embedding_model(cls):
        """Return the shared SentenceTransformer, loading it on first use.

        The model name comes from the EMBEDDING_MODEL env var (default
        all-MiniLM-L6-v2). A lock guards the cache so concurrent callers
        never load the model twice; changing the env var between calls
        triggers a reload.
        """
        with cls._embedding_model_lock:
            model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
            if cls._embedding_model is None or cls._embedding_model_name != model_name:
                import torch
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                print(f"Loading embedding model on {device}...")
                cls._embedding_model = SentenceTransformer(model_name, device=device)
                cls._embedding_model_name = model_name
            return cls._embedding_model

    def _string_to_uuid(self, string_id: str) -> str:
        """Map an arbitrary string ID to a deterministic UUID string.

        Qdrant point IDs must be UUIDs (or unsigned ints); this hashes the
        caller's custom ID with MD5 so the same input always yields the same
        point ID. Do not change the hash — existing stored points depend on
        this exact mapping.
        """
        return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))

    def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
        """Embed *texts* and upsert them into the collection.

        Args:
            texts: Document chunks to embed. An empty list is a no-op.
            metadatas: One metadata dict per text (may contain falsy entries,
                which are treated as empty). Caller dicts are never mutated.
            ids: One custom string ID per text; hashed into Qdrant UUIDs.

        Raises:
            ValueError: If the three input lists have different lengths.
        """
        if not texts:
            return
        if not (len(texts) == len(metadatas) == len(ids)):
            raise ValueError("texts, metadatas and ids must have the same length.")

        embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()

        points = []
        for text, metadata, point_id, vector in zip(texts, metadatas, ids, embeddings):
            # Copy the metadata so we never mutate the caller's dict.
            payload = dict(metadata) if metadata else {}
            payload['text'] = text  # store the raw text for later retrieval

            points.append(models.PointStruct(
                id=self._string_to_uuid(point_id),
                vector=vector,
                payload=payload
            ))

        # upload_points batches natively (chunks of `batch_size`); wait=True
        # blocks until the server acknowledges the whole upload.
        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=100,
            wait=True
        )

    def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
        """Search the collection and return results in ChromaDB's shape.

        Args:
            query_text: Natural-language query to embed and search with.
            n_results: Maximum number of hits to return.
            filter_dict: Optional exact-match payload filters ({key: value});
                all conditions must hold (Qdrant `must`).

        Returns:
            A dict with "documents", "metadatas", "distances" and "ids",
            each wrapped in a single-element outer list (Chroma style).
            Note: "distances" actually holds Qdrant cosine *similarity*
            scores, where higher is better — the opposite of Chroma's
            distance convention.
        """
        # Short-circuit on an empty collection to avoid a pointless search.
        count = self.get_collection_count()
        if count == 0:
            return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}

        query_embedding = self.embedding_model.encode([query_text])[0].tolist()

        # Translate the simple {key: value} filter into a Qdrant Filter.
        qdrant_filter = None
        if filter_dict:
            conditions = [
                models.FieldCondition(key=k, match=models.MatchValue(value=v))
                for k, v in filter_dict.items()
            ]
            qdrant_filter = models.Filter(must=conditions)

        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=qdrant_filter,
            limit=n_results
        )

        # Flatten hits into the parallel lists the HybridRetriever expects.
        docs, metas, scores, ids = [], [], [], []
        for hit in search_result:
            docs.append(hit.payload.get('text', ''))

            # Strip the stored text out of the metadata to mimic Chroma.
            meta = {k: v for k, v in hit.payload.items() if k != 'text'}
            metas.append(meta)

            scores.append(hit.score)
            ids.append(str(hit.id))

        return {
            "documents": [docs],
            "metadatas": [metas],
            "distances": [scores],  # similarity scores (higher = better), not distances
            "ids": [ids]
        }

    def get_collection_count(self) -> int:
        """Return the number of points in the collection, or 0 on any error."""
        try:
            return self.client.count(collection_name=self.collection_name).count
        except Exception:
            return 0