# DEPENDENCIES import time import numpy as np from typing import List from typing import Optional from numpy.typing import NDArray from config.models import DocumentChunk from config.settings import get_settings from config.models import EmbeddingRequest from config.models import EmbeddingResponse from config.logging_config import get_logger from utils.error_handler import handle_errors from utils.error_handler import EmbeddingError from embeddings.model_loader import get_model_loader from embeddings.batch_processor import BatchProcessor # Setup Settings and Logging settings = get_settings() logger = get_logger(__name__) class BGEEmbedder: """ BGE (BAAI General Embedding) model wrapper: Optimized for BAAI/bge models with proper normalization and batching """ def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None): """ Initialize BGE embedder Arguments: ---------- model_name { str } : BGE model name (default from settings) device { str } : Device to run on """ self.logger = logger self.model_name = model_name or settings.EMBEDDING_MODEL self.device = device # Initialize components self.model_loader = get_model_loader() self.batch_processor = BatchProcessor() # Load model self.model = self.model_loader.load_model(model_name = self.model_name, device = self.device, ) # Get model info self.embedding_dim = self.model.get_sentence_embedding_dimension() self.supports_batch = True self.logger.info(f"Initialized BGEEmbedder: model={self.model_name}, dim={self.embedding_dim}, device={self.model.device}") @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True) def embed_text(self, text: str, normalize: bool = True) -> NDArray: """ Embed single text string Arguments: ---------- text { str } : Input text normalize { bool } : Normalize embeddings to unit length Returns: -------- { NDArray } : Embedding vector """ if not text or not text.strip(): raise EmbeddingError("Cannot embed empty text") try: # Encode single text embedding = self.model.encode([text], normalize_embeddings = normalize, show_progress_bar = False, ) # Return single vector return embedding[0] except Exception as e: self.logger.error(f"Failed to embed text: {repr(e)}") raise EmbeddingError(f"Text embedding failed: {repr(e)}") @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True) def embed_texts(self, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True) -> List[NDArray]: """ Embed multiple texts with batching Arguments: ---------- texts { list } : List of text strings batch_size { int } : Batch size (default from settings) normalize { bool } : Normalize embeddings Returns: -------- { list } : List of embedding vectors """ if not texts: return [] # Filter empty texts valid_texts = [t for t in texts if t and t.strip()] if (len(valid_texts) != len(texts)): self.logger.warning(f"Filtered {len(texts) - len(valid_texts)} empty texts") if not valid_texts: return [] batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE try: # Use batch processing for efficiency embeddings = self.batch_processor.process_embeddings_batch(model = self.model, texts = valid_texts, batch_size = batch_size, normalize = normalize, ) self.logger.debug(f"Generated {len(embeddings)} embeddings for {len(texts)} texts") return embeddings except Exception as e: self.logger.error(f"Batch embedding failed: {repr(e)}") raise EmbeddingError(f"Batch embedding failed: {repr(e)}") @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True) def embed_chunks(self, chunks: List[DocumentChunk], batch_size: Optional[int] = None, normalize: bool = True) -> List[DocumentChunk]: """ Embed document chunks and update them with embeddings Arguments: ---------- chunks { list } : List of DocumentChunk objects batch_size { int } : Batch size normalize { bool } : Normalize embeddings Returns: -------- { list } : Chunks with embeddings added """ if not chunks: return [] # Extract texts from chunks texts = [chunk.text for chunk in chunks] # Generate embeddings embeddings = self.embed_texts(texts = texts, batch_size = batch_size, normalize = normalize, ) # Update chunks with embeddings for chunk, embedding in zip(chunks, embeddings): # Convert numpy to list for serialization chunk.embedding = embedding.tolist() self.logger.info(f"Embedded {len(chunks)} document chunks") return chunks def process_embedding_request(self, request: EmbeddingRequest) -> EmbeddingResponse: """ Process embedding request from API Arguments: ---------- request { EmbeddingRequest } : Embedding request Returns: -------- { EmbeddingResponse } : Embedding response """ start_time = time.time() # Generate embeddings embeddings = self.embed_texts(texts = request.texts, batch_size = request.batch_size, normalize = request.normalize, ) # Convert to milliseconds processing_time = (time.time() - start_time) * 1000 # Convert to list for serialization embedding_list = [emb.tolist() for emb in embeddings] response = EmbeddingResponse(embeddings = embedding_list, dimension = self.embedding_dim, num_embeddings = len(embeddings), processing_time_ms = processing_time, ) return response def get_embedding_dimension(self) -> int: """ Get embedding dimension Returns: -------- { int } : Embedding vector dimension """ return self.embedding_dim def cosine_similarity(self, emb1: NDArray, emb2: NDArray) -> float: """ Calculate cosine similarity between two embeddings Arguments: ---------- emb1 { NDArray } : First embedding emb2 { NDArray } : Second embedding Returns: -------- { float } : Cosine similarity (-1 to 1) """ # Ensure embeddings are normalized emb1_norm = emb1 / np.linalg.norm(emb1) emb2_norm = emb2 / np.linalg.norm(emb2) return float(np.dot(emb1_norm, emb2_norm)) def validate_embedding(self, embedding: NDArray) -> bool: """ Validate embedding vector Arguments: ---------- embedding { NDArray } : Embedding vector Returns: -------- { bool } : True if valid """ if (embedding is None): return False if (not isinstance(embedding, np.ndarray)): return False if (embedding.shape != (self.embedding_dim,)): return False if (np.all(embedding == 0)): return False if (np.any(np.isnan(embedding))): return False return True def get_model_info(self) -> dict: """ Get embedder information Returns: -------- { dict } : Embedder information """ return {"model_name" : self.model_name, "embedding_dim" : self.embedding_dim, "device" : str(self.model.device), "supports_batch" : self.supports_batch, "normalize_default" : True, } # Global embedder instance _embedder = None def get_embedder(model_name: Optional[str] = None, device: Optional[str] = None) -> BGEEmbedder: """ Get global embedder instance Arguments: ---------- model_name { str } : Model name device { str } : Device Returns: -------- { BGEEmbedder } : BGEEmbedder instance """ global _embedder if _embedder is None: _embedder = BGEEmbedder(model_name, device) return _embedder def embed_texts(texts: List[str], **kwargs) -> List[NDArray]: """ Convenience function to embed texts Arguments: ---------- texts { list } : List of texts **kwargs : Additional arguments for BGEEmbedder Returns: -------- { list } : List of embeddings """ embedder = get_embedder() return embedder.embed_texts(texts, **kwargs) def embed_chunks(chunks: List[DocumentChunk], **kwargs) -> List[DocumentChunk]: """ Convenience function to embed document chunks Arguments: ---------- chunks { list } : List of DocumentChunk objects **kwargs : Additional arguments Returns: -------- { list } : Chunks with embeddings """ embedder = get_embedder() return embedder.embed_chunks(chunks, **kwargs)