File size: 8,612 Bytes
e272f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import os
from sentence_transformers import SentenceTransformer
import numpy as np
import logging
from typing import List, Dict, Optional
from app.config import Config
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from qdrant_client.http.exceptions import UnexpectedResponse

class EmbeddingHandler:
    """
    Handles all embedding-related operations including:
    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """
        Initialize the embedding handler with model and vector store client.

        Loads the SentenceTransformer named by ``Config.EMBEDDING_MODEL``, records
        its output dimension in ``self.embedding_dim`` and creates a Qdrant client
        from ``Config.QDRANT_URL`` / ``Config.QDRANT_API_KEY``.

        Raises:
            RuntimeError: If the model or the client cannot be created; the
                original exception is chained as ``__cause__``.
        """
        self.logger = logging.getLogger(__name__)
        try:
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Collections are created with this vector size (see create_collection),
            # so it must come from the same model that produces the embeddings.
            self.embedding_dim = self.model.get_sentence_embedding_dimension()

            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30,  # seconds; passed straight through to QdrantClient
            )
        except Exception as e:
            self.logger.error("Error initializing embedding handler: %s", e, exc_info=True)
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(
        self,
        texts: List[str],
        *,
        batch_size: int = 32,
        show_progress_bar: bool = True,
    ) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed.
            batch_size: Number of texts per ``model.encode`` batch (default 32).
            show_progress_bar: Whether ``model.encode`` displays a progress bar
                (default True).

        Returns:
            np.ndarray: The array returned by ``model.encode`` with
            ``convert_to_numpy=True`` (one embedding row per input text).

        Raises:
            Exception: Any error from ``model.encode`` is logged and re-raised.
        """
        try:
            return self.model.encode(
                texts,
                show_progress_bar=show_progress_bar,
                batch_size=batch_size,
                convert_to_numpy=True,
            )
        except Exception as e:
            self.logger.error("Error generating embeddings: %s", e, exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a Qdrant collection sized for this handler's embedding model.

        The collection uses ``self.embedding_dim`` as its vector size and cosine
        distance, so it accepts vectors produced by ``generate_embeddings``.

        Args:
            collection_name: Name of the collection to create.

        Returns:
            bool: True if the collection was created or already exists.

        Raises:
            UnexpectedResponse: For server errors other than "already exists".
            Exception: Any other client failure is logged and re-raised.
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,  # must match the model's output dimension
                    distance=Distance.COSINE,
                ),
            )
        except UnexpectedResponse as e:
            # An existing collection counts as success so callers can invoke this
            # idempotently; detection relies on the server's error text.
            if "already exists" in str(e):
                self.logger.info("Collection %s already exists", collection_name)
                return True
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise
        except Exception as e:
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise

        self.logger.info("Created collection %s", collection_name)
        return True

    def add_to_collection(
        self,
        collection_name: str,
        embeddings: np.ndarray,
        payloads: List[dict],
        ids: Optional[List] = None,
        batch_size: int = 100,
    ) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name.
            embeddings: Embeddings to add (numpy array or sequence of vectors).
            payloads: Metadata dictionaries, one per embedding.
            ids: Optional point IDs, one per embedding. When omitted the IDs are
                ``0 .. len(embeddings) - 1``. Because points are written with
                ``upsert``, repeated calls without ``ids`` overwrite the points
                stored by earlier calls to the same collection; pass distinct IDs
                when appending to existing data.
            batch_size: Number of points sent per upsert request (default 100).

        Returns:
            bool: True if every batch was upserted.

        Raises:
            ValueError: If ``embeddings``, ``payloads`` and ``ids`` differ in
                length, or ``batch_size`` is less than 1.
            Exception: Any client failure is logged and re-raised. Batches
                already sent before a failure remain stored.
        """
        try:
            # Qdrant expects plain Python lists, not numpy arrays.
            if isinstance(embeddings, np.ndarray):
                vectors = embeddings.tolist()
            else:
                vectors = list(embeddings)
            payloads = list(payloads)
            point_ids = list(range(len(vectors))) if ids is None else list(ids)

            # zip() would silently drop trailing items on a length mismatch.
            if len(payloads) != len(vectors):
                raise ValueError(
                    f"Got {len(vectors)} embeddings but {len(payloads)} payloads"
                )
            if len(point_ids) != len(vectors):
                raise ValueError(
                    f"Got {len(vectors)} embeddings but {len(point_ids)} ids"
                )
            if batch_size < 1:
                raise ValueError(f"batch_size must be >= 1, got {batch_size}")

            # Build every point first so a malformed point fails before anything is sent.
            points = [
                PointStruct(id=point_id, vector=vector, payload=payload)
                for point_id, vector, payload in zip(point_ids, vectors, payloads)
            ]

            for start in range(0, len(points), batch_size):
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=points[start:start + batch_size],
                    wait=True,  # wait for the server to apply the batch before continuing
                )

            self.logger.info("Added %d vectors to collection %s", len(points), collection_name)
            return True

        except Exception as e:
            self.logger.error("Error adding to collection: %s", e, exc_info=True)
            raise

    async def search_collection(self, collection_name: str, query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for vectors similar to the query text.

        Args:
            collection_name: Name of the collection to search.
            query: Text query; embedded with ``self.model`` before searching.
            k: Maximum number of hits to return (default: 5).

        Returns:
            Dict with keys:
                "status": "success" or "error".
                "results": list of hits (empty on error), each a dict with
                    "id", "score" (float), "payload" (dict, {} when absent) and
                    "text" (``payload["text"]`` or "").
                "message": error text (present only on error).

        Exceptions are logged and reported through the "error" status rather
        than raised.
        """
        # NOTE(review): although declared ``async``, model.encode and
        # qdrant_client.search below are synchronous and block the event loop
        # while they run. Consider offloading them (e.g. asyncio.to_thread)
        # after confirming thread-safety of the model and client.
        try:
            query_embedding = self.model.encode(query).tolist()

            results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,
                with_payload=True,  # include metadata
                with_vectors=False,  # raw vectors are not needed by callers
            )

            formatted_results = []
            for hit in results:
                payload = hit.payload or {}
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),
                    "payload": payload,
                    "text": payload.get("text", ""),
                })

            return {
                "status": "success",
                "results": formatted_results
            }

        except Exception as e:
            self.logger.error("Search error: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": []
            }

    # Deprecated FAISS methods (maintained for backward compatibility)
    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")