# DEPENDENCIES
import numpy as np
from typing import Callable
from typing import List
from typing import Optional
from numpy.typing import NDArray
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from chunking.token_counter import get_token_counter
from sentence_transformers import SentenceTransformer
from utils.helpers import BatchProcessor as BaseBatchProcessor


# Setup Settings and Logging
settings = get_settings()
logger   = get_logger(__name__)


class BatchProcessor:
    """
    Efficient batch processing for embeddings: Handles large batches with memory optimization and progress tracking
    """
    def __init__(self):
        self.logger         = logger
        self.base_processor = BaseBatchProcessor()

        # Batch processing statistics
        # NOTE: total_batches / total_texts are only incremented after a successful encode,
        # failed_batches is incremented once per failed encode call
        self.total_batches  = 0
        self.total_texts    = 0
        self.failed_batches = 0


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def process_embeddings_batch(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True, **kwargs) -> List[NDArray]:
        """
        Process embeddings in optimized batches

        Arguments:
        ----------
            model      { SentenceTransformer } : Embedding model

            texts      { list }                : List of texts to embed

            batch_size { int }                 : Batch size (default from settings)

            normalize  { bool }                : Normalize embeddings

            **kwargs                           : Additional model.encode parameters

        Raises:
        -------
            EmbeddingError                     : If encoding fails for any reason (original exception is chained as the cause)

        Returns:
        --------
                 { list }                      : List of embedding vectors
        """
        if not texts:
            return []

        # A falsy batch_size (None or 0) falls back to the configured default
        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        self.logger.debug(f"Processing {len(texts)} texts in batches of {batch_size}")

        try:
            # Use model's built-in batching with optimization
            embeddings = model.encode(texts,
                                      batch_size           = batch_size,
                                      normalize_embeddings = normalize,
                                      show_progress_bar    = False,
                                      convert_to_numpy     = True,
                                      **kwargs
                                     )

            # Update statistics: ceil(len(texts) / batch_size) internal batches were run
            self.total_batches += ((len(texts) + batch_size - 1) // batch_size)
            self.total_texts   += len(texts)

            self.logger.debug(f"Successfully generated {len(embeddings)} embeddings")

            # Convert to list of arrays
            return list(embeddings)

        except Exception as e:
            self.failed_batches += 1
            self.logger.error(f"Batch embedding failed: {repr(e)}")

            # Chain the original exception so the root cause stays in the traceback
            raise EmbeddingError(f"Batch processing failed: {repr(e)}") from e


    def process_embeddings_with_fallback(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True) -> List[NDArray]:
        """
        Process embeddings with automatic batch size reduction on failure

        Arguments:
        ----------
            model      { SentenceTransformer } : Embedding model

            texts      { list }                : List of texts

            batch_size { int }                 : Initial batch size

            normalize  { bool }                : Normalize embeddings

        Raises:
        -------
            EmbeddingError                     : If the retry with the halved batch size also fails, or if
                                                 the batch size cannot be reduced any further

        Returns:
        --------
                 { list }                      : List of embeddings
        """
        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        try:
            return self.process_embeddings_batch(model      = model,
                                                 texts      = texts,
                                                 batch_size = batch_size,
                                                 normalize  = normalize,
                                                )

        # process_embeddings_batch wraps every failure in EmbeddingError, so it must be
        # caught here as well or the fallback would never be reached
        except (MemoryError, RuntimeError, EmbeddingError):
            reduced_batch_size = batch_size // 2

            # A reduced size of 0 is falsy and would silently fall back to the (larger)
            # default batch size inside process_embeddings_batch: nothing left to reduce
            if (reduced_batch_size < 1):
                raise

            self.logger.warning(f"Batch size {batch_size} failed, reducing to {reduced_batch_size}")

            # Reduce batch size and retry
            return self.process_embeddings_batch(model      = model,
                                                 texts      = texts,
                                                 batch_size = reduced_batch_size,
                                                 normalize  = normalize,
                                                )


    def split_into_optimal_batches(self, texts: List[str], target_batch_size: int, max_batch_size: int = 1000) -> List[List[str]]:
        """
        Split texts into optimal batches considering token counts

        Arguments:
        ----------
            texts             { list } : List of texts

            target_batch_size { int }  : Target batch size in texts

            max_batch_size    { int }  : Maximum batch size to allow

        Returns:
        --------
                  { list }             : List of text batches (input order is preserved across batches)
        """
        if not texts:
            return []

        token_counter  = get_token_counter()

        # Count every text exactly once; the counts are reused for the estimate and the split
        token_counts   = [token_counter.count_tokens(text) for text in texts]

        # Estimate tokens per text (average of first 10 or all if less)
        sample_tokens  = token_counts[:10]
        avg_tokens     = sum(sample_tokens) / len(sample_tokens)

        # Target tokens per batch (approximate)
        target_tokens  = target_batch_size * avg_tokens

        batches        = list()
        current_batch  = list()
        current_tokens = 0

        for text, text_tokens in zip(texts, token_counts):
            # If single text is too large, put it in its own batch
            if (text_tokens > (target_tokens * 0.8)):
                # Flush the pending batch first so the original ordering is kept
                if current_batch:
                    batches.append(current_batch)
                    current_batch  = list()
                    current_tokens = 0

                batches.append([text])
                continue

            # Check if adding this text would exceed limits
            exceeds_tokens = ((current_tokens + text_tokens) > target_tokens) and bool(current_batch)
            exceeds_count  = (len(current_batch) >= max_batch_size)

            if exceeds_tokens or exceeds_count:
                batches.append(current_batch)
                current_batch  = list()
                current_tokens = 0

            current_batch.append(text)
            current_tokens += text_tokens

        # Add final batch
        if current_batch:
            batches.append(current_batch)

        self.logger.debug(f"Split {len(texts)} texts into {len(batches)} optimal batches")

        return batches


    def process_batches_with_progress(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, progress_callback: Optional[Callable[[float, str], None]] = None,
                                      **kwargs) -> List[Optional[NDArray]]:
        """
        Process batches with progress reporting

        Arguments:
        ----------
            model             { SentenceTransformer } : Embedding model

            texts             { list }                : List of texts

            batch_size        { int }                 : Batch size

            progress_callback { callable }            : Callback for progress updates, called as
                                                        progress_callback(percentage, message)

            **kwargs                                  : Additional parameters

        Returns:
        --------
                      { list }                        : List of embeddings, one entry per input text; texts that
                                                        belonged to a failed batch are represented by None
        """
        if not texts:
            return []

        batch_size     = batch_size or settings.EMBEDDING_BATCH_SIZE

        # Split into batches
        batches        = self.split_into_optimal_batches(texts             = texts,
                                                         target_batch_size = batch_size,
                                                        )

        num_batches    = len(batches)
        all_embeddings = list()

        for i, batch_texts in enumerate(batches):
            if progress_callback:
                progress = (i / num_batches) * 100
                progress_callback(progress, f"Processing batch {i + 1}/{num_batches}")

            try:
                # Each pre-split batch is encoded in a single pass
                batch_embeddings = self.process_embeddings_batch(model      = model,
                                                                 texts      = batch_texts,
                                                                 batch_size = len(batch_texts),
                                                                 **kwargs
                                                                )

                all_embeddings.extend(batch_embeddings)

                self.logger.debug(f"Processed batch {i + 1}/{num_batches}: {len(batch_texts)} texts")

            except Exception as e:
                # Deliberate best-effort: one failing batch must not abort the remaining ones
                self.logger.error(f"Failed to process batch {i + 1}: {repr(e)}")

                # Add None placeholders for failed batch
                all_embeddings.extend([None] * len(batch_texts))

        if progress_callback:
            progress_callback(100, "Embedding complete")

        return all_embeddings


    def validate_embeddings_batch(self, embeddings: List[Optional[NDArray]], expected_count: int) -> bool:
        """
        Validate a batch of embeddings

        Arguments:
        ----------
            embeddings     { list } : List of embedding vectors

            expected_count { int }  : Expected number of embeddings

        Returns:
        --------
                 { bool }           : True if the count matches and at least 90% of the embeddings are
                                      1-D numpy arrays without NaN values
        """
        if (len(embeddings) != expected_count):
            self.logger.error(f"Embedding count mismatch: expected {expected_count}, got {len(embeddings)}")
            return False

        # Nothing expected and nothing received: valid, and avoids dividing by zero below
        if (expected_count == 0):
            return True

        valid_count = 0

        for i, emb in enumerate(embeddings):
            if emb is None:
                self.logger.warning(f"None embedding at index {i}")
                continue

            if not isinstance(emb, np.ndarray):
                self.logger.warning(f"Invalid embedding type at index {i}: {type(emb)}")
                continue

            if (emb.ndim != 1):
                self.logger.warning(f"Invalid embedding dimension at index {i}: {emb.ndim}")
                continue

            if np.any(np.isnan(emb)):
                self.logger.warning(f"NaN values in embedding at index {i}")
                continue

            valid_count += 1

        validity_ratio = valid_count / expected_count

        if (validity_ratio < 0.9):
            self.logger.warning(f"Low embedding validity: {valid_count}/{expected_count} ({validity_ratio:.1%})")
            return False

        return True


    def get_processing_stats(self) -> dict:
        """
        Get batch processing statistics

        Returns:
        --------
            { dict } : Statistics dictionary with keys total_batches, total_texts, failed_batches,
                       success_rate (percentage, 0-100) and avg_batch_size
        """
        # total_batches only counts batches of successful calls, so the number of attempts
        # is successes plus failures; the rate therefore always stays within 0-100
        attempted    = self.total_batches + self.failed_batches
        success_rate = (self.total_batches / attempted * 100) if (attempted > 0) else 100

        stats        = {"total_batches"  : self.total_batches,
                        "total_texts"    : self.total_texts,
                        "failed_batches" : self.failed_batches,
                        "success_rate"   : success_rate,
                        "avg_batch_size" : self.total_texts / self.total_batches if (self.total_batches > 0) else 0,
                       }

        return stats


    def reset_stats(self):
        """
        Reset processing statistics
        """
        self.total_batches  = 0
        self.total_texts    = 0
        self.failed_batches = 0

        self.logger.debug("Reset batch processing statistics")


# Global batch processor instance
_batch_processor = None


def get_batch_processor() -> BatchProcessor:
    """
    Get global batch processor instance (created on first call)

    Returns:
    --------
        { BatchProcessor } : BatchProcessor instance
    """
    global _batch_processor

    if _batch_processor is None:
        _batch_processor = BatchProcessor()

    return _batch_processor


def process_embeddings_batch(model: SentenceTransformer, texts: List[str], **kwargs) -> List[NDArray]:
    """
    Convenience function for batch embedding

    Arguments:
    ----------
        model  { SentenceTransformer } : Embedding model

        texts  { list }                : List of texts

        **kwargs                       : Additional arguments

    Returns:
    --------
             { list }                  : List of embeddings
    """
    processor = get_batch_processor()

    return processor.process_embeddings_batch(model, texts, **kwargs)