import os
import logging
from typing import List, Dict, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from qdrant_client.http.exceptions import UnexpectedResponse

from app.config import Config


class EmbeddingHandler:
    """
    Handles all embedding-related operations including:
    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """Initialize the embedding handler with model and vector store client.

        Raises:
            RuntimeError: If the embedding model or Qdrant client fails to
                initialize (original exception is chained as the cause).
        """
        self.logger = logging.getLogger(__name__)
        try:
            # Model name is supplied by application configuration.
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Dimension must match the vector size used when creating collections.
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30,  # Connection timeout in seconds
            )
            # Connection test can be uncommented for local development
            # self._verify_connection()
        except Exception as e:
            # Lazy %-args: the message is only built if the record is emitted.
            self.logger.error(
                "Error initializing embedding handler: %s", e, exc_info=True
            )
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed

        Returns:
            np.ndarray: Array of embeddings (2D numpy array)

        Raises:
            Exception: If embedding generation fails
        """
        try:
            return self.model.encode(
                texts,
                show_progress_bar=True,  # Visual progress indicator
                batch_size=32,  # Optimal batch size for most GPUs
                convert_to_numpy=True,  # Return as numpy array for efficiency
            )
        except Exception as e:
            self.logger.error("Error generating embeddings: %s", e, exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a new Qdrant collection for storing vectors.

        Args:
            collection_name: Name of the collection to create

        Returns:
            bool: True if collection was created or already exists

        Raises:
            Exception: If collection creation fails (except for already
                exists case, which is treated as success)
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,  # Must match model's embedding dimension
                    distance=Distance.COSINE,  # Using cosine similarity
                ),
            )
            self.logger.info(f"Created collection {collection_name}")
            return True
        except UnexpectedResponse as e:
            # Handle case where collection already exists
            if "already exists" in str(e):
                self.logger.info(f"Collection {collection_name} already exists")
                return True
            self.logger.error("Error creating collection: %s", e)
            raise
        except Exception as e:
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise

    def add_to_collection(self, collection_name: str,
                          embeddings: np.ndarray,
                          payloads: List[dict]) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name
            embeddings: Numpy array of embeddings to add
            payloads: List of metadata dictionaries corresponding to each embedding

        Returns:
            bool: True if operation succeeded

        Raises:
            ValueError: If embeddings and payloads differ in length
                (previously zip() silently truncated, dropping data)
            Exception: If the upsert operation fails
        """
        try:
            # Guard against silent data loss: zip() would otherwise drop the
            # unmatched tail of the longer input without any warning.
            if len(embeddings) != len(payloads):
                raise ValueError(
                    "embeddings and payloads must have the same length: "
                    f"{len(embeddings)} != {len(payloads)}"
                )

            # Convert numpy arrays to lists for Qdrant compatibility
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()

            # Prepare points in batches for efficient processing
            batch_size = 100  # Optimal batch size for Qdrant Cloud
            points = [
                PointStruct(
                    id=idx,  # Sequential ID — NOTE: re-running overwrites ids 0..n-1
                    vector=embedding,
                    payload=payload,  # Associated metadata
                )
                for idx, (embedding, payload) in enumerate(zip(embeddings, payloads))
            ]

            # Process in batches to avoid overwhelming the server
            for i in range(0, len(points), batch_size):
                batch = points[i:i + batch_size]
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=batch,
                    wait=True,  # Ensure immediate persistence
                )

            self.logger.info(
                f"Added {len(points)} vectors to collection {collection_name}"
            )
            return True
        except Exception as e:
            self.logger.error("Error adding to collection: %s", e, exc_info=True)
            raise

    async def search_collection(self, collection_name: str,
                                query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for similar vectors to the query.

        Args:
            collection_name: Name of collection to search
            query: Text query to search for
            k: Number of similar results to return (default: 5)

        Returns:
            Dict: {
                "status": "success"|"error",
                "results": List[Dict] (if success),
                "message": str (if error)
            }
        """
        try:
            # Generate embedding for the query text
            query_embedding = self.model.encode(query).tolist()

            # Perform similarity search in Qdrant
            results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,  # Number of results to return
                with_payload=True,  # Include metadata
                with_vectors=False,  # Exclude raw vectors to save bandwidth
            )

            # Format results for consistent API response
            formatted_results = []
            for hit in results:
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),  # Similarity score
                    "payload": hit.payload or {},  # Associated metadata
                    "text": hit.payload.get("text", "") if hit.payload else "",
                })

            return {
                "status": "success",
                "results": formatted_results,
            }
        except Exception as e:
            self.logger.error("Search error: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": [],
            }

    # Deprecated FAISS methods (maintained for backward compatibility)
    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")