# DEPENDENCIES import numpy as np from typing import List from typing import Optional from numpy.typing import NDArray from config.settings import get_settings from config.logging_config import get_logger from utils.error_handler import handle_errors from utils.cache_manager import EmbeddingCache as BaseEmbeddingCache # Setup Settings and Logging settings = get_settings() logger = get_logger(__name__) class EmbeddingCache: """ Embedding cache with numpy array support and statistics: Wraps the base cache with embedding-specific features """ def __init__(self, max_size: int = None, ttl: int = None): """ Initialize embedding cache Arguments: ---------- max_size { int } : Maximum cache size ttl { int } : Time to live in seconds """ self.logger = logger self.max_size = max_size or settings.CACHE_MAX_SIZE self.ttl = ttl or settings.CACHE_TTL # Initialize base cache self.base_cache = BaseEmbeddingCache(max_size = self.max_size, ttl = self.ttl, ) # Enhanced statistics self.hits = 0 self.misses = 0 self.embeddings_generated = 0 self.logger.info(f"Initialized EmbeddingCache: max_size={self.max_size}, ttl={self.ttl}") def get_embedding(self, text: str) -> Optional[NDArray]: """ Get embedding from cache Arguments: ---------- text { str } : Input text Returns: -------- { NDArray } : Cached embedding or None """ cached = self.base_cache.get_embedding(text) if cached is not None: self.hits += 1 # Convert list back to numpy array return np.array(cached) else: self.misses += 1 return None def set_embedding(self, text: str, embedding: NDArray): """ Store embedding in cache Arguments: ---------- text { str } : Input text embedding { NDArray } : Embedding vector """ # Convert numpy array to list for serialization embedding_list = embedding.tolist() self.base_cache.set_embedding(text, embedding_list) self.embeddings_generated += 1 def batch_get_embeddings(self, texts: List[str]) -> tuple[List[Optional[NDArray]], List[str]]: """ Get multiple embeddings from cache Arguments: ---------- texts { list } : List of texts Returns: -------- { tuple } : Tuple of (cached_embeddings, missing_texts) """ cached_embeddings = list() missing_texts = list() for text in texts: embedding = self.get_embedding(text) if embedding is not None: cached_embeddings.append(embedding) else: missing_texts.append(text) cached_embeddings.append(None) return cached_embeddings, missing_texts def batch_set_embeddings(self, texts: List[str], embeddings: List[NDArray]): """ Store multiple embeddings in cache Arguments: ---------- texts { list } : List of texts embeddings { list } : List of embedding vectors """ if (len(texts) != len(embeddings)): raise ValueError("Texts and embeddings must have same length") for text, embedding in zip(texts, embeddings): self.set_embedding(text, embedding) def get_cached_embeddings(self, texts: List[str], embed_function: callable, batch_size: Optional[int] = None) -> List[NDArray]: """ Smart embedding getter: uses cache for existing, generates for missing Arguments: ---------- texts { list } : List of texts embed_function { callable } : Function to generate embeddings for missing texts batch_size { int } : Batch size for generation Returns: -------- { list } : List of embeddings """ # Get cached embeddings cached_embeddings, missing_texts = self.batch_get_embeddings(texts = texts) if not missing_texts: self.logger.debug(f"All {len(texts)} embeddings found in cache") return cached_embeddings # Generate missing embeddings self.logger.info(f"Generating {len(missing_texts)} embeddings ({(len(missing_texts)/len(texts))*100:.1f}% cache miss)") missing_embeddings = embed_function(missing_texts, batch_size = batch_size) # Store new embeddings in cache self.batch_set_embeddings(missing_texts, missing_embeddings) # Combine results result_embeddings = list() missing_idx = 0 for emb in cached_embeddings: if emb is not None: result_embeddings.append(emb) else: result_embeddings.append(missing_embeddings[missing_idx]) missing_idx += 1 return result_embeddings def clear(self): """ Clear entire cache """ self.base_cache.clear() self.hits = 0 self.misses = 0 self.embeddings_generated = 0 self.logger.info("Cleared embedding cache") def get_stats(self) -> dict: """ Get cache statistics Returns: -------- { dict } : Statistics dictionary """ base_stats = self.base_cache.get_stats() total_requests = self.hits + self.misses hit_rate = (self.hits / total_requests * 100) if (total_requests > 0) else 0 stats = {**base_stats, "hits" : self.hits, "misses" : self.misses, "hit_rate_percentage" : hit_rate, "embeddings_generated" : self.embeddings_generated, "cache_size" : self.base_cache.cache.size(), "max_size" : self.max_size, } return stats def save_to_file(self, file_path: str) -> bool: """ Save cache to file Arguments: ---------- file_path { str } : Path to save file Returns: -------- { bool } : True if successful """ return self.base_cache.save_to_file(file_path) def load_from_file(self, file_path: str) -> bool: """ Load cache from file Arguments: ---------- file_path { str } : Path to load file Returns: -------- { bool } : True if successful """ return self.base_cache.load_from_file(file_path) def warm_cache(self, texts: List[str], embed_function: callable, batch_size: Optional[int] = None): """ Pre-populate cache with embeddings Arguments: ---------- texts { list } : List of texts to warm cache with embed_function { callable } : Embedding generation function batch_size { int } : Batch size """ # Check which texts are not in cache _, missing_texts = self.batch_get_embeddings(texts = texts) if not missing_texts: self.logger.info("Cache already warm for all texts") return self.logger.info(f"Warming cache with {len(missing_texts)} embeddings") # Generate and cache embeddings embeddings = embed_function(missing_texts, batch_size = batch_size) self.batch_set_embeddings(missing_texts, embeddings) self.logger.info(f"Cache warming complete: added {len(missing_texts)} embeddings") # Global embedding cache instance _embedding_cache = None def get_embedding_cache() -> EmbeddingCache: """ Get global embedding cache instance Returns: -------- { EmbeddingCache } : EmbeddingCache instance """ global _embedding_cache if _embedding_cache is None: _embedding_cache = EmbeddingCache() return _embedding_cache def cache_embeddings(texts: List[str], embeddings: List[NDArray]): """ Convenience function to cache embeddings Arguments: ---------- texts { list } : List of texts embeddings { list } : List of embeddings """ cache = get_embedding_cache() cache.batch_set_embeddings(texts, embeddings) def get_cached_embeddings(texts: List[str], embed_function: callable, **kwargs) -> List[NDArray]: """ Convenience function to get cached embeddings Arguments: ---------- texts { list } : List of texts embed_function { callable } : Embedding function **kwargs : Additional arguments Returns: -------- { list } : List of embeddings """ cache = get_embedding_cache() return cache.get_cached_embeddings(texts, embed_function, **kwargs)