# ConvoBot / src / embeddings.py
# Author: ashish-ninehertz
# Commit: "changes" (e272f4f)
import os
from sentence_transformers import SentenceTransformer
import numpy as np
import logging
from typing import List, Dict, Optional
from app.config import Config
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from qdrant_client.http.exceptions import UnexpectedResponse
class EmbeddingHandler:
    """
    Handles all embedding-related operations including:

    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """Initialize the embedding handler with model and vector store client.

        Raises:
            RuntimeError: If the embedding model or the Qdrant client fails
                to initialize; the original exception is chained as the cause.
        """
        self.logger = logging.getLogger(__name__)
        try:
            # Initialize embedding model with configuration from Config.
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Cache the model's output dimension; collections created later
            # must use exactly this vector size.
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            # Initialize Qdrant client with configuration from Config.
            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30,         # Connection timeout in seconds
            )
            # Connection test can be uncommented for local development
            # self._verify_connection()
        except Exception as e:
            self.logger.error("Error initializing embedding handler: %s", e, exc_info=True)
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed.

        Returns:
            np.ndarray: 2D array of embeddings, shape (len(texts), dim).
                An empty input yields a well-shaped (0, dim) array.

        Raises:
            Exception: If embedding generation fails (re-raised after logging).
        """
        if not texts:
            # Short-circuit: avoids running the encoder on an empty batch and
            # guarantees a consistent (0, dim) shape for downstream callers.
            return np.empty((0, self.embedding_dim), dtype=np.float32)
        try:
            return self.model.encode(
                texts,
                show_progress_bar=True,   # Visual progress indicator
                batch_size=32,            # Optimal batch size for most GPUs
                convert_to_numpy=True,    # Return as numpy array for efficiency
            )
        except Exception as e:
            self.logger.error("Error generating embeddings: %s", e, exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a new Qdrant collection for storing vectors.

        Args:
            collection_name: Name of the collection to create.

        Returns:
            bool: True if the collection was created or already exists.

        Raises:
            Exception: If collection creation fails for any reason other
                than the collection already existing.
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,   # Must match model's embedding dimension
                    distance=Distance.COSINE,  # Using cosine similarity
                ),
            )
            self.logger.info("Created collection %s", collection_name)
            return True
        except UnexpectedResponse as e:
            # Qdrant signals "collection already exists" via an HTTP error;
            # treat that specific case as success, re-raise anything else.
            if "already exists" in str(e):
                self.logger.info("Collection %s already exists", collection_name)
                return True
            self.logger.error("Error creating collection: %s", e)
            raise
        except Exception as e:
            self.logger.error("Error creating collection: %s", e, exc_info=True)
            raise

    def add_to_collection(
        self,
        collection_name: str,
        embeddings: np.ndarray,
        payloads: List[dict],
        id_offset: int = 0,
    ) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name.
            embeddings: Numpy array (or list) of embeddings to add.
            payloads: List of metadata dictionaries, one per embedding.
            id_offset: Starting point ID (default 0, preserving previous
                behavior). Pass the current collection size when appending,
                otherwise IDs restart at 0 and overwrite existing points.

        Returns:
            bool: True if the operation succeeded.

        Raises:
            ValueError: If embeddings and payloads have different lengths
                (zip() would otherwise silently drop the unmatched tail).
            Exception: If the upsert operation fails.
        """
        try:
            # Convert numpy arrays to lists for Qdrant compatibility.
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()
            if len(embeddings) != len(payloads):
                raise ValueError(
                    f"embeddings/payloads length mismatch: "
                    f"{len(embeddings)} vs {len(payloads)}"
                )
            batch_size = 100  # Optimal batch size for Qdrant Cloud
            points = [
                PointStruct(
                    id=id_offset + idx,  # Sequential ID from the given offset
                    vector=embedding,
                    payload=payload,     # Associated metadata
                )
                for idx, (embedding, payload) in enumerate(zip(embeddings, payloads))
            ]
            # Process in batches to avoid overwhelming the server.
            for i in range(0, len(points), batch_size):
                batch = points[i:i + batch_size]
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=batch,
                    wait=True,  # Ensure immediate persistence
                )
            self.logger.info("Added %d vectors to collection %s", len(points), collection_name)
            return True
        except Exception as e:
            self.logger.error("Error adding to collection: %s", e, exc_info=True)
            raise

    async def search_collection(self, collection_name: str, query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for vectors similar to the query.

        NOTE(review): declared async for the callers' interface, but every
        call inside is synchronous (no await) — the coroutine blocks the
        event loop for the duration of encode + search.

        Args:
            collection_name: Name of collection to search.
            query: Text query to search for.
            k: Number of similar results to return (default: 5).

        Returns:
            Dict: {
                "status": "success" | "error",
                "results": List[Dict] (empty on error),
                "message": str (only on error),
            }
        """
        try:
            # Embed the query text with the same model used for indexing.
            query_embedding = self.model.encode(query).tolist()
            # Perform similarity search in Qdrant.
            results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,            # Number of results to return
                with_payload=True,  # Include metadata
                with_vectors=False, # Exclude raw vectors to save bandwidth
            )
            # Format results for a consistent API response.
            formatted_results = []
            for hit in results:
                payload = hit.payload or {}
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),        # Similarity score
                    "payload": payload,               # Associated metadata
                    "text": payload.get("text", ""),  # Extracted text
                })
            return {
                "status": "success",
                "results": formatted_results,
            }
        except Exception as e:
            self.logger.error("Search error: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": [],
            }

    # Deprecated FAISS methods (maintained for backward compatibility)

    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")