# NOTE: "Spaces: Sleeping" is Hugging Face Spaces page status text captured with this file; it is not part of the module.
import asyncio
import logging
import os
from typing import List, Dict, Optional

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer

from app.config import Config
class EmbeddingHandler:
    """
    Handles all embedding-related operations including:

    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """
        Initialize the embedding handler with model and vector store client.

        Reads ``Config.EMBEDDING_MODEL``, ``Config.QDRANT_URL`` and
        ``Config.QDRANT_API_KEY``.

        Raises:
            RuntimeError: If the model or the Qdrant client cannot be created.
                The original exception is chained as ``__cause__``.
        """
        self.logger = logging.getLogger(__name__)
        try:
            # Embedding model named by the application configuration.
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Collections created by this handler must use this vector size.
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30,  # Client timeout in seconds
            )
        except Exception as e:
            self.logger.error(
                "Error initializing embedding handler: %s", e, exc_info=True
            )
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed.

        Returns:
            np.ndarray: Whatever ``self.model.encode`` returns with
            ``convert_to_numpy=True`` (expected to be one row per input text).

        Raises:
            Exception: Any error raised by the model is logged and re-raised.
        """
        try:
            return self.model.encode(
                texts,
                show_progress_bar=True,
                batch_size=32,  # Fixed encode batch size; not tuned per device
                convert_to_numpy=True,
            )
        except Exception as e:
            self.logger.error("Error generating embeddings: %s", e, exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a new Qdrant collection for storing vectors.

        The collection uses cosine distance and a vector size equal to
        ``self.embedding_dim``.

        Args:
            collection_name: Name of the collection to create.

        Returns:
            bool: True if the collection was created or already exists.

        Raises:
            Exception: If collection creation fails for any reason other than
                the collection already existing.
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,  # Must match the model's output size
                    distance=Distance.COSINE,
                ),
            )
            self.logger.info("Created collection %s", collection_name)
            return True
        except UnexpectedResponse as e:
            # An "already exists" response is treated as success so that the
            # call is safe to repeat.
            if "already exists" in str(e):
                self.logger.info("Collection %s already exists", collection_name)
                return True
            self.logger.error("Error creating collection: %s", e)
            raise
        except Exception as e:
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise

    def add_to_collection(
        self,
        collection_name: str,
        embeddings: np.ndarray,
        payloads: List[dict],
        ids: Optional[List] = None,
        batch_size: int = 100,
    ) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name.
            embeddings: Embeddings to add (numpy array or sequence of vectors).
            payloads: Metadata dictionaries, one per embedding.
            ids: Optional point IDs, one per embedding. When omitted, points
                are numbered 0..n-1 on every call, so a second call on the
                same collection upserts over the points written by the first.
                Pass explicit IDs to append instead.
            batch_size: Number of points sent per upsert request.

        Returns:
            bool: True if the operation succeeded.

        Raises:
            ValueError: If ``embeddings``, ``payloads`` and ``ids`` differ in
                length, or ``batch_size`` is less than 1.
            Exception: Any client error is logged and re-raised.
        """
        try:
            # Qdrant point structs take plain lists rather than numpy arrays.
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()
            else:
                embeddings = list(embeddings)
            payloads = list(payloads)

            # zip() would silently drop the unmatched tail; fail loudly instead.
            if len(embeddings) != len(payloads):
                raise ValueError(
                    f"embeddings ({len(embeddings)}) and payloads "
                    f"({len(payloads)}) must have the same length"
                )
            if ids is None:
                ids = range(len(embeddings))
            elif len(ids) != len(embeddings):
                raise ValueError(
                    f"ids ({len(ids)}) and embeddings ({len(embeddings)}) "
                    "must have the same length"
                )
            if batch_size < 1:
                raise ValueError("batch_size must be at least 1")

            points = [
                PointStruct(id=point_id, vector=vector, payload=payload)
                for point_id, vector, payload in zip(ids, embeddings, payloads)
            ]

            # Send in slices so a large upload is not a single request.
            for start in range(0, len(points), batch_size):
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=points[start:start + batch_size],
                    wait=True,  # Block until the server has applied the batch
                )
            self.logger.info(
                "Added %d vectors to collection %s", len(points), collection_name
            )
            return True
        except Exception as e:
            self.logger.error("Error adding to collection: %s", e, exc_info=True)
            raise

    def _search_sync(self, collection_name: str, query: str, k: int):
        """Encode ``query`` and run the blocking Qdrant similarity search."""
        query_vector = self.model.encode(query).tolist()
        return self.qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=k,
            with_payload=True,  # Include metadata
            with_vectors=False,  # Raw vectors are not used by the caller
        )

    async def search_collection(self, collection_name: str, query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for vectors similar to the query.

        The model encoding and the Qdrant request are both blocking calls, so
        they run in the event loop's default executor to keep the loop free.

        Args:
            collection_name: Name of collection to search.
            query: Text query to search for.
            k: Number of similar results to return (default: 5).

        Returns:
            Dict: ``{"status": "success", "results": [...]}`` on success, where
            each result has ``id``, ``score``, ``payload`` and ``text`` keys;
            ``{"status": "error", "message": str, "results": []}`` on failure.
            Errors are logged and reported in the dict, never raised.
        """
        try:
            loop = asyncio.get_running_loop()
            hits = await loop.run_in_executor(
                None, self._search_sync, collection_name, query, k
            )

            formatted_results = []
            for hit in hits:
                payload = hit.payload or {}
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),
                    "payload": payload,
                    "text": payload.get("text", ""),
                })
            return {
                "status": "success",
                "results": formatted_results,
            }
        except Exception as e:
            self.logger.error("Search error: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": [],
            }

    # Deprecated FAISS methods (maintained for backward compatibility)
    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")