| import logging |
| import asyncio |
| from typing import List, Optional, Dict, Any |
| import numpy as np |
| from sentence_transformers import SentenceTransformer |
| import torch |
| import config |
|
|
logger = logging.getLogger(__name__)


class EmbeddingService:
    """Async wrapper around a SentenceTransformer embedding model.

    Loads the configured model once at construction (falling back to
    all-MiniLM-L6-v2 on failure) and exposes coroutines that run the blocking
    encode calls in the default thread-pool executor, plus cosine-similarity
    and validation utilities over the resulting vectors.
    """

    def __init__(self):
        # NOTE(review): assumes config.config exposes EMBEDDING_MODEL — that
        # object is defined outside this file; verify against config module.
        self.config = config.config
        self.model_name = self.config.EMBEDDING_MODEL
        self.model = None
        # Prefer GPU when torch can see one; otherwise encode on CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self._load_model()

    def _load_model(self):
        """Load the configured model, falling back to all-MiniLM-L6-v2.

        Raises:
            Exception: re-raised only when the fallback model also fails,
                leaving self.model as None is never allowed past __init__.
        """
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name, device=self.device)
            logger.info(f"Embedding model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            # Best-effort fallback to a small, widely available model.
            try:
                self.model_name = "all-MiniLM-L6-v2"
                self.model = SentenceTransformer(self.model_name, device=self.device)
                logger.info(f"Loaded fallback embedding model: {self.model_name}")
            except Exception as fallback_error:
                logger.error(f"Failed to load fallback model: {str(fallback_error)}")
                raise

    async def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Generate embeddings for a list of texts.

        Blank/empty entries are filtered out before encoding, so the result
        may be SHORTER than `texts`. Use embed_chunks() when positional
        alignment with the input matters.

        Args:
            texts: Input strings to embed.
            batch_size: Number of texts encoded per executor call.

        Returns:
            One embedding (list of floats) per non-blank input text.

        Raises:
            RuntimeError: if the embedding model failed to load.
        """
        if not texts:
            return []

        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        try:
            non_empty_texts = [text for text in texts if text and text.strip()]
            if not non_empty_texts:
                logger.warning("No non-empty texts provided for embedding")
                return []

            logger.info(f"Generating embeddings for {len(non_empty_texts)} texts")

            # get_event_loop() is deprecated inside coroutines; the running
            # loop is also invariant, so fetch it once outside the batch loop.
            loop = asyncio.get_running_loop()

            all_embeddings: List[List[float]] = []
            for i in range(0, len(non_empty_texts), batch_size):
                batch = non_empty_texts[i:i + batch_size]
                # Run the blocking encode off the event loop.
                batch_embeddings = await loop.run_in_executor(
                    None,
                    self._generate_batch_embeddings,
                    batch
                )
                all_embeddings.extend(batch_embeddings)

            logger.info(f"Generated {len(all_embeddings)} embeddings")
            return all_embeddings

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def _generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts (synchronous).

        Intended to run inside an executor; encodes the whole batch in one
        model call with L2-normalized output vectors.
        """
        try:
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=len(texts)
            )
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Error in batch embedding generation: {str(e)}")
            raise

    async def generate_single_embedding(self, text: str) -> Optional[List[float]]:
        """Generate an embedding for a single text.

        Returns:
            The embedding vector, or None for blank input / on any error
            (errors are logged, not raised).
        """
        if not text or not text.strip():
            return None

        try:
            embeddings = await self.generate_embeddings([text])
            return embeddings[0] if embeddings else None
        except Exception as e:
            logger.error(f"Error generating single embedding: {str(e)}")
            return None

    def get_embedding_dimension(self) -> int:
        """Return the dimension of the model's embedding vectors.

        Raises:
            RuntimeError: if the embedding model failed to load.
        """
        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Compute the cosine similarity between two embeddings.

        Returns 0.0 on error or when either vector is all-zero (previously a
        zero vector produced a NaN via silent numpy divide-by-zero, which the
        except clause never caught).
        """
        try:
            emb1 = np.array(embedding1, dtype=float)
            emb2 = np.array(embedding2, dtype=float)

            denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
            if denom == 0.0:
                # A zero vector has no direction; treat it as orthogonal.
                return 0.0

            return float(np.dot(emb1, emb2) / denom)
        except Exception as e:
            logger.error(f"Error computing similarity: {str(e)}")
            return 0.0

    def compute_similarities(self, query_embedding: List[float], embeddings: List[List[float]]) -> List[float]:
        """Compute cosine similarities between a query and many embeddings.

        Zero-vector rows (or a zero query) yield 0.0 instead of NaN; on any
        other failure a list of zeros of matching length is returned.
        """
        try:
            query_emb = np.array(query_embedding, dtype=float)
            emb_matrix = np.array(embeddings, dtype=float)

            dots = emb_matrix @ query_emb
            denoms = np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query_emb)
            # `where=` skips the zero denominators; `out=` pre-fills those
            # positions with 0.0 so no NaN/inf ever reaches the caller.
            similarities = np.divide(dots, denoms, out=np.zeros_like(dots), where=denoms != 0)

            return similarities.tolist()
        except Exception as e:
            logger.error(f"Error computing similarities: {str(e)}")
            return [0.0] * len(embeddings)

    async def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return copies of `chunks` with an 'embedding' key added.

        Chunks whose 'content' is blank are passed through unchanged (no
        'embedding' key). BUG FIX: embeddings were previously re-joined by raw
        index even though generate_embeddings() silently drops blank texts, so
        a blank chunk could be assigned its neighbour's embedding; we now track
        the source index of every text that actually gets embedded.

        Raises:
            Exception: re-raised from generate_embeddings() on failure.
        """
        if not chunks:
            return []

        try:
            texts = [chunk.get('content', '') for chunk in chunks]

            # Keep (chunk index, text) pairs so results can be mapped back to
            # their originating chunk after blank texts are filtered out.
            indexed = [(i, text) for i, text in enumerate(texts) if text and text.strip()]
            embeddings = await self.generate_embeddings([text for _, text in indexed])

            embedding_by_index = {i: emb for (i, _), emb in zip(indexed, embeddings)}

            embedded_chunks = []
            for i, chunk in enumerate(chunks):
                if i in embedding_by_index:
                    chunk_copy = chunk.copy()
                    chunk_copy['embedding'] = embedding_by_index[i]
                    embedded_chunks.append(chunk_copy)
                else:
                    logger.warning(f"No embedding generated for chunk {i}")
                    embedded_chunks.append(chunk)

            return embedded_chunks
        except Exception as e:
            logger.error(f"Error embedding chunks: {str(e)}")
            raise

    def validate_embedding(self, embedding: List[float]) -> bool:
        """Check that an embedding is a well-formed vector for this model.

        Valid means: a non-empty list, of the model's embedding dimension,
        containing no NaN or infinite values. Any internal error counts as
        invalid.
        """
        try:
            # Type check first so non-list inputs (e.g. numpy arrays) are
            # rejected cleanly rather than via an exception from `not`.
            if not isinstance(embedding, list):
                return False

            if not embedding:
                return False

            if len(embedding) != self.get_embedding_dimension():
                return False

            emb_array = np.array(embedding)
            if np.isnan(emb_array).any() or np.isinf(emb_array).any():
                return False

            return True
        except Exception:
            return False

    async def get_model_info(self) -> Dict[str, Any]:
        """Return metadata about the loaded model (or an error dict).

        Kept async for interface stability even though no awaiting occurs.
        """
        try:
            return {
                "model_name": self.model_name,
                "device": self.device,
                "embedding_dimension": self.get_embedding_dimension(),
                "max_sequence_length": getattr(self.model, 'max_seq_length', 'unknown'),
                "model_loaded": self.model is not None
            }
        except Exception as e:
            logger.error(f"Error getting model info: {str(e)}")
            return {"error": str(e)}