"""Semantic Caching Layer - Intelligent caching using embedding similarity."""
import hashlib
import json
import logging
import time
import zlib
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cache entry."""

    key: str                   # md5 hex digest of the original query
    query: str                 # original query text, kept for inspection
    embedding: List[float]     # L2-normalized trigram embedding of the query
    result: Any                # cached payload returned on a hit
    metadata: Dict[str, Any]   # e.g. prompt_tokens / estimated_cost, used for savings accounting
    created_at: float          # epoch seconds at store time
    ttl: int                   # time-to-live in seconds

    @property
    def is_expired(self) -> bool:
        """True once the entry has outlived its TTL."""
        return time.time() - self.created_at > self.ttl

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since the entry was stored."""
        return time.time() - self.created_at


@dataclass
class CacheHit:
    """Result of a cache hit."""

    result: Any          # the cached payload
    similarity: float    # cosine similarity that triggered the hit
    age_seconds: float   # age of the matched entry
    savings_tokens: int  # prompt tokens avoided (from entry metadata)
    savings_cost: float  # estimated cost avoided (from entry metadata)


class SemanticCache:
    """
    In-memory semantic cache using embedding similarity.

    Caches: Tool execution results, LLM completions, web search results,
    code execution outputs.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        default_ttl: int = 604800,  # 7 days
        max_size: int = 10000
    ):
        """
        Args:
            similarity_threshold: minimum cosine similarity for a hit (0..1).
            default_ttl: entry lifetime in seconds when store() gets no ttl.
            max_size: maximum number of entries before eviction kicks in.
        """
        self.similarity_threshold = similarity_threshold
        self.default_ttl = default_ttl
        self.max_size = max_size
        self._cache: Dict[str, CacheEntry] = {}
        # Running statistics, reported by get_stats().
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0

    def _compute_embedding(self, text: str) -> List[float]:
        """
        Compute a simple embedding for the text.

        In production, use OpenAI or sentence-transformers. For now, use a
        character-trigram bag hashed into 128 buckets, L2-normalized.

        NOTE: texts shorter than 3 characters yield the zero vector, which
        has similarity 0.0 with everything and therefore never matches.
        """
        text = text.lower().strip()
        embedding = [0.0] * 128
        # BUG FIX: the original used the builtin hash(), which is randomized
        # per process (PYTHONHASHSEED), so bucket assignments — and hence
        # embeddings — differed between runs. zlib.crc32 is deterministic.
        for i in range(len(text) - 2):
            idx = zlib.crc32(text[i:i + 3].encode("utf-8")) % 128
            embedding[idx] += 1.0
        # L2-normalize so cosine similarity reduces to a dot product of
        # unit vectors; leave the zero vector untouched.
        norm = sum(x * x for x in embedding) ** 0.5
        if norm > 0:
            embedding = [x / norm for x in embedding]
        return embedding

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Compute cosine similarity between two vectors (0.0 if either is zero)."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def _generate_key(self, query: str) -> str:
        """Generate a cache key from the query (md5 used for speed, not security)."""
        return hashlib.md5(query.encode()).hexdigest()

    async def check(self, query: str) -> Optional[CacheHit]:
        """
        Check if a semantically similar query exists in the cache.

        Returns a CacheHit for the best non-expired match above the
        similarity threshold, or None (and counts a miss) otherwise.
        """
        query_embedding = self._compute_embedding(query)
        best_match: Optional[CacheEntry] = None
        best_similarity = 0.0
        # Linear scan for the best matching non-expired entry.
        for entry in self._cache.values():
            if entry.is_expired:
                continue
            similarity = self._cosine_similarity(query_embedding, entry.embedding)
            if similarity > self.similarity_threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = entry
        if best_match:
            self._hits += 1
            tokens_saved = best_match.metadata.get("prompt_tokens", 0)
            cost_saved = best_match.metadata.get("estimated_cost", 0.0)
            self._tokens_saved += tokens_saved
            self._cost_saved += cost_saved
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info("Cache hit: similarity=%.3f, key=%s", best_similarity, best_match.key[:8])
            return CacheHit(
                result=best_match.result,
                similarity=best_similarity,
                age_seconds=best_match.age_seconds,
                savings_tokens=tokens_saved,
                savings_cost=cost_saved
            )
        self._misses += 1
        return None

    async def store(
        self,
        query: str,
        result: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl: Optional[int] = None
    ) -> str:
        """
        Store a result with its embedding for future semantic matching.

        Args:
            query: the query text the result answers.
            result: arbitrary payload to cache.
            metadata: optional savings accounting (prompt_tokens, estimated_cost).
            ttl: lifetime in seconds; defaults to default_ttl when None.

        Returns:
            The cache key (md5 hex digest of the query).
        """
        # Clean up expired entries if the cache is full.
        if len(self._cache) >= self.max_size:
            self._cleanup_expired()
            # If still full, evict the single oldest entry to make room.
            if len(self._cache) >= self.max_size:
                oldest_key = min(
                    self._cache.keys(),
                    key=lambda k: self._cache[k].created_at
                )
                del self._cache[oldest_key]
        embedding = self._compute_embedding(query)
        key = self._generate_key(query)
        entry = CacheEntry(
            key=key,
            query=query,
            embedding=embedding,
            result=result,
            metadata=metadata or {},
            created_at=time.time(),
            # BUG FIX: the original used `ttl or self.default_ttl`, which
            # silently replaced an explicit ttl=0 with the default.
            ttl=ttl if ttl is not None else self.default_ttl
        )
        self._cache[key] = entry
        logger.info("Cache stored: key=%s, entries=%d", key[:8], len(self._cache))
        return key

    def _cleanup_expired(self) -> None:
        """Remove all expired entries from the cache."""
        expired_keys = [
            k for k, v in self._cache.items()
            if v.is_expired
        ]
        for key in expired_keys:
            del self._cache[key]
        if expired_keys:
            logger.info("Cleaned up %d expired cache entries", len(expired_keys))

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (hit rate is 0.0 before any lookups)."""
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "entries": len(self._cache),
            "tokens_saved": self._tokens_saved,
            "cost_saved": self._cost_saved
        }

    def clear(self) -> None:
        """Clear all cache entries and reset statistics."""
        self._cache.clear()
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0
        logger.info("Cache cleared")


# Global cache instance
semantic_cache = SemanticCache()