import os
from sentence_transformers import SentenceTransformer
import numpy as np
import logging
from typing import List, Dict, Optional
from app.config import Config
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from qdrant_client.http.exceptions import UnexpectedResponse
class EmbeddingHandler:
    """
    Handles all embedding-related operations including:
    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """Initialize the embedding model and the Qdrant vector-store client.

        Raises:
            RuntimeError: If model or client initialization fails; the original
                exception is chained as the cause.
        """
        self.logger = logging.getLogger(__name__)
        try:
            # Embedding model name comes from the central app configuration.
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Dimension is model-dependent; collections must be created with
            # a matching vector size (see create_collection).
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30,  # Connection timeout in seconds
            )
            # Connection test can be uncommented for local development
            # self._verify_connection()
        except Exception as e:
            self.logger.error("Error initializing embedding handler: %s", e, exc_info=True)
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed

        Returns:
            np.ndarray: 2D array of shape (len(texts), embedding_dim)

        Raises:
            Exception: If embedding generation fails
        """
        # Short-circuit on empty input so callers always get a well-shaped
        # (0, dim) array instead of whatever encode([]) happens to return.
        if not texts:
            return np.empty((0, self.embedding_dim), dtype=np.float32)
        try:
            return self.model.encode(
                texts,
                show_progress_bar=True,  # Visual progress indicator
                batch_size=32,  # Optimal batch size for most GPUs
                convert_to_numpy=True,  # Return as numpy array for efficiency
            )
        except Exception as e:
            self.logger.error("Error generating embeddings: %s", e, exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a new Qdrant collection for storing vectors.

        Args:
            collection_name: Name of the collection to create

        Returns:
            bool: True if collection was created or already exists

        Raises:
            Exception: If collection creation fails (except for already exists case)
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,  # Must match model's embedding dimension
                    distance=Distance.COSINE,  # Using cosine similarity
                ),
            )
            self.logger.info("Created collection %s", collection_name)
            return True
        except UnexpectedResponse as e:
            # An existing collection is usable as-is, so treat it as success.
            # Check both the message substring and (when exposed by the client)
            # the HTTP 409 Conflict status, which is the canonical signal.
            if "already exists" in str(e) or getattr(e, "status_code", None) == 409:
                self.logger.info("Collection %s already exists", collection_name)
                return True
            self.logger.error("Error creating collection: %s", e)
            raise
        except Exception as e:
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise

    def add_to_collection(self, collection_name: str, embeddings: np.ndarray, payloads: List[dict]) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name
            embeddings: Numpy array (or list) of embeddings to add
            payloads: List of metadata dictionaries corresponding to each embedding

        Returns:
            bool: True if operation succeeded

        Raises:
            ValueError: If embeddings and payloads have different lengths
            Exception: If the upsert operation fails
        """
        try:
            # Convert numpy arrays to lists for Qdrant compatibility
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()
            # zip() would silently drop the extra items on a length mismatch,
            # losing data without any error — fail loudly instead.
            if len(embeddings) != len(payloads):
                raise ValueError(
                    f"embeddings ({len(embeddings)}) and payloads ({len(payloads)}) "
                    "must have the same length"
                )
            points = [
                PointStruct(
                    id=idx,  # Sequential ID (NOTE: re-adding overwrites ids 0..n-1)
                    vector=embedding,
                    payload=payload,  # Associated metadata
                )
                for idx, (embedding, payload) in enumerate(zip(embeddings, payloads))
            ]
            # Process in batches to avoid overwhelming the server
            batch_size = 100  # Optimal batch size for Qdrant Cloud
            for i in range(0, len(points), batch_size):
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=points[i:i + batch_size],
                    wait=True,  # Ensure immediate persistence
                )
            self.logger.info("Added %d vectors to collection %s", len(points), collection_name)
            return True
        except Exception as e:
            self.logger.error("Error adding to collection: %s", e, exc_info=True)
            raise

    async def search_collection(self, collection_name: str, query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for vectors similar to the query.

        NOTE(review): declared async for API compatibility, but the body is
        fully synchronous (sync Qdrant client) and contains no awaits.

        Args:
            collection_name: Name of collection to search
            query: Text query to search for
            k: Number of similar results to return (default: 5)

        Returns:
            Dict: {
                "status": "success"|"error",
                "results": List[Dict] (empty on error),
                "message": str (only on error)
            }
        """
        try:
            # Generate embedding for the query text
            query_embedding = self.model.encode(query).tolist()
            # Perform similarity search in Qdrant
            results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,  # Number of results to return
                with_payload=True,  # Include metadata
                with_vectors=False,  # Exclude raw vectors to save bandwidth
            )
            # Format results for a consistent API response
            formatted_results = []
            for hit in results:
                payload = hit.payload or {}  # Normalize missing payloads once
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),  # Similarity score
                    "payload": payload,  # Associated metadata
                    "text": payload.get("text", ""),  # Extracted text
                })
            return {
                "status": "success",
                "results": formatted_results,
            }
        except Exception as e:
            self.logger.error("Search error: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": [],
            }

    # Deprecated FAISS methods (maintained for backward compatibility)
    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")