File size: 9,663 Bytes
eefb354
 
3b5d2e9
 
 
 
eefb354
 
 
 
5dd2ee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cab845
 
 
 
5dd2ee5
4cab845
 
5dd2ee5
 
 
4cab845
 
 
 
 
 
 
5dd2ee5
 
 
 
 
eefb354
 
 
 
 
 
 
3b5d2e9
4cab845
 
 
 
 
 
 
 
 
 
 
 
3b5d2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cab845
 
 
3b5d2e9
4cab845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5d2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eefb354
 
 
 
 
 
 
 
43fe2fe
eefb354
 
 
 
 
 
 
3b5d2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72956b
 
3b5d2e9
c72956b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5d2e9
 
 
c72956b
 
 
3b5d2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from qdrant_client import QdrantClient, models
import os
import uuid
import logging

logger = logging.getLogger(__name__)

# --- Qdrant Client Initialization ---

def get_qdrant_client():
    """
    Build and return a Qdrant client, trying backends in priority order.

    Order of preference:
        1. Qdrant Cloud, when both QDRANT_URL and QDRANT_API_KEY are set.
        2. A Qdrant server at QDRANT_HOST (skipped when the host is unset or
           equal to "localhost"); the port comes from QDRANT_PORT and
           defaults to 6333.
        3. File-based local storage in QDRANT_DATA_DIR (default
           "/app/data/qdrant").
        4. An in-memory client, if the storage directory cannot be used.

    Returns:
        A QdrantClient instance for the first backend that could be built

    Raises:
        Exception: Re-raised from the QdrantClient constructor when cloud
            credentials are provided but the client cannot be created.
    """
    qdrant_url = os.environ.get("QDRANT_URL")
    qdrant_api_key = os.environ.get("QDRANT_API_KEY")

    # Priority 1: Qdrant Cloud (production)
    if qdrant_url and qdrant_api_key:
        # NOTE(review): building a QdrantClient may not contact the server,
        # so this handler may only ever see configuration errors - confirm.
        try:
            client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        except Exception as e:
            logger.error(f"Failed to connect to Qdrant Cloud with provided credentials: {e}")
            raise  # If cloud credentials are provided, failure should be fatal.
        logger.info(f"Connected to Qdrant Cloud at {qdrant_url}")
        return client

    # Priority 2: Local Docker container
    qdrant_host = os.environ.get("QDRANT_HOST")
    if qdrant_host and qdrant_host != "localhost":
        try:
            # An unset or empty QDRANT_PORT keeps the historical default port.
            # A non-numeric value raises here and falls through to priority 3.
            qdrant_port = int(os.environ.get("QDRANT_PORT") or "6333")
            client = QdrantClient(host=qdrant_host, port=qdrant_port)
            logger.info(f"Connected to local Qdrant server at {qdrant_host}")
            return client
        except Exception as e:
            logger.warning(f"Failed to connect to local Qdrant server at {qdrant_host}: {e}")

    # Priority 3: Local file-based storage (fallback for development)
    # The default path only exists inside the container image, so allow an
    # override for machines where /app is not writable.
    data_dir = os.environ.get("QDRANT_DATA_DIR") or "/app/data/qdrant"
    try:
        os.makedirs(data_dir, exist_ok=True)
        client = QdrantClient(path=data_dir)
        logger.info(f"Using file-based Qdrant client at {data_dir}")
        return client
    except Exception as e:
        logger.warning(f"Failed to create file-based Qdrant client: {e}")

    # Final fallback: in-memory. Nothing is persisted, so make it visible
    # at warning level rather than info.
    client = QdrantClient(":memory:")
    logger.warning("Using in-memory Qdrant client as final fallback")
    return client

# --- Collection Management ---

def create_collection_if_not_exists(client: QdrantClient, collection_name: str, vector_size: int):
    """
    Create a Qdrant collection unless it is already present.

    The collection is created with cosine distance and the given vector size.
    An existing collection is left untouched; its configuration is not
    compared against ``vector_size``.

    Args:
        client: Qdrant client instance
        collection_name: Name of the collection to look up or create
        vector_size: Size of the embedding vectors

    Raises:
        Exception: Re-raised from ``client.create_collection`` when creation
            fails and the collection still cannot be found afterwards.
    """
    try:
        client.get_collection(collection_name=collection_name)
    except Exception:
        # A failed lookup is treated as "collection is missing".
        logger.info(f"Collection '{collection_name}' does not exist, creating it...")
    else:
        logger.info(f"Collection '{collection_name}' already exists")
        return

    try:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
        )
    except Exception as create_error:
        # Another process may have created the collection between the lookup
        # and the create call. If it can be found now, the goal is met.
        try:
            client.get_collection(collection_name=collection_name)
            exists_now = True
        except Exception:
            exists_now = False

        if not exists_now:
            logger.error(f"Failed to create collection '{collection_name}': {str(create_error)}")
            raise
        logger.info(f"Collection '{collection_name}' already exists")
        return

    logger.info(f"Created new collection '{collection_name}'")

# --- User-Specific Collection Management ---

def get_user_collection_name(user_id: uuid.UUID) -> str:
    """
    Build the name of the collection that belongs to a single user.

    Args:
        user_id: The user's UUID

    Returns:
        The prefix 'user_' followed by the string form of ``user_id`` with
        every hyphen replaced by an underscore
    """
    # Split on hyphens and re-join with underscores before adding the prefix
    segments = str(user_id).split("-")
    return "user_" + "_".join(segments)

def ensure_user_collection_exists(client: QdrantClient, user_id: uuid.UUID, vector_size: int) -> str:
    """
    Ensure that a user-specific collection exists in Qdrant.

    The collection name is derived with ``get_user_collection_name`` and the
    lookup/creation itself is delegated to ``create_collection_if_not_exists``
    so both entry points share one implementation.

    Args:
        client: Qdrant client instance
        user_id: The user's UUID
        vector_size: Size of the embedding vectors

    Returns:
        The collection name that was created or verified

    Raises:
        Exception: Any error raised while deriving the name or creating the
            collection is logged with the call arguments and re-raised.
    """
    try:
        collection_name = get_user_collection_name(user_id)
        logger.info(f"Ensuring collection exists for user {user_id}: {collection_name}")

        # Shared helper: checks for the collection and creates it if missing
        create_collection_if_not_exists(client, collection_name, vector_size)

        return collection_name

    except Exception as e:
        logger.error(f"Error in ensure_user_collection_exists: {str(e)}")
        logger.error(f"Function called with client={type(client)}, user_id={user_id}, vector_size={vector_size}")
        raise

def collection_exists(client: QdrantClient, collection_name: str) -> bool:
    """
    Report whether a collection can be fetched from Qdrant.

    Args:
        client: Qdrant client instance
        collection_name: Name of the collection to check

    Returns:
        True when ``client.get_collection`` succeeds, False when it raises
        any exception
    """
    found = True
    try:
        client.get_collection(collection_name=collection_name)
    except Exception:
        # Every lookup failure is reported as "not found"
        found = False
    return found

# --- Vector Operations ---

def upsert_vectors(client: QdrantClient, collection_name: str, vectors, payloads):
    """
    Upsert vectors and their payloads into the specified collection.

    Every point receives a freshly generated UUID as its ID. Positional IDs
    (0..n-1) would repeat on each call and overwrite the points stored by
    earlier calls to the same collection.

    Args:
        client: Qdrant client instance
        collection_name: Name of the collection to write to
        vectors: Sized sequence of embedding vectors, one per point
        payloads: Sequence of payloads aligned with ``vectors``, or None

    Returns:
        List of the point IDs (UUID strings) that were written; empty when
        ``vectors`` is empty

    Raises:
        ValueError: If ``payloads`` is given but its length differs from
            the number of vectors.
    """
    point_count = len(vectors)
    if point_count == 0:
        # Nothing to store; skip the round trip entirely
        return []

    if payloads is not None and len(payloads) != point_count:
        raise ValueError(
            f"vectors and payloads must have the same length "
            f"(got {point_count} vectors and {len(payloads)} payloads)"
        )

    # Unique IDs so repeated calls append instead of overwriting
    point_ids = [str(uuid.uuid4()) for _ in range(point_count)]

    client.upsert(
        collection_name=collection_name,
        points=models.Batch(
            ids=point_ids,
            vectors=vectors,
            payloads=payloads
        ),
        wait=True
    )
    return point_ids

def search_vectors(client: QdrantClient, collection_name: str, query_vector, limit: int = 5):
    """
    Search for vectors similar to ``query_vector`` in a collection.

    This is a best-effort call: every failure is logged and turned into an
    empty result instead of being raised.

    Args:
        client: Qdrant client instance
        collection_name: Name of the collection to search
        query_vector: Query vector for similarity search; objects exposing
            ``tolist()`` (e.g. numpy arrays) are converted to a list first
        limit: Maximum number of results to return

    Returns:
        Search results, or empty list if the collection cannot be fetched,
        is empty, or the search fails
    """
    try:
        # One lookup answers both "does it exist?" and "is it empty?"
        try:
            collection_info = client.get_collection(collection_name=collection_name)
        except Exception:
            logger.warning(f"Collection '{collection_name}' does not exist")
            return []

        if collection_info.points_count == 0:
            logger.info(f"Collection '{collection_name}' is empty")
            return []

        # Convert numpy array to list if needed
        query_vector_list = query_vector.tolist() if hasattr(query_vector, 'tolist') else query_vector

        # Prefer query_points when the installed client provides it. Checking
        # for the method up front keeps AttributeErrors raised inside the call
        # from being mistaken for "method not available".
        if hasattr(client, "query_points"):
            logger.debug(f"Attempting query_points on collection '{collection_name}'")
            result = client.query_points(
                collection_name=collection_name,
                query=query_vector_list,
                limit=limit,
                with_payload=True
            )
            # Extract points from QueryResponse
            results = result.points if hasattr(result, 'points') else result
            logger.info(f"Found {len(results)} results using query_points in collection '{collection_name}'")
            return results

        # Fallback to older search method for backward compatibility
        logger.warning("query_points is not available on this client, falling back to search method")
        results = client.search(
            collection_name=collection_name,
            query_vector=query_vector_list,
            limit=limit,
            with_payload=True
        )
        logger.info(f"Found {len(results)} results using search in collection '{collection_name}'")
        return results

    except Exception as e:
        # logger.exception records the message plus the active traceback
        logger.exception(f"Error searching collection '{collection_name}': {str(e)}")
        logger.error(f"Error type: {type(e).__name__}")
        return []

def get_collection_info(client: QdrantClient, collection_name: str) -> dict:
    """
    Summarize a collection as a plain dictionary.

    Args:
        client: Qdrant client instance
        collection_name: Name of the collection

    Returns:
        Dictionary with the keys "name", "points_count", "status" and
        "vectors_count" (None when the fetched info has no such attribute),
        or None if fetching or reading the collection info raises
    """
    try:
        details = client.get_collection(collection_name)
        summary = {"name": collection_name}
        summary["points_count"] = details.points_count
        summary["status"] = details.status
        # Not every info object carries vectors_count; default to None
        summary["vectors_count"] = getattr(details, 'vectors_count', None)
        return summary
    except Exception as e:
        logger.error(f"Error getting collection info for '{collection_name}': {str(e)}")
        return None