| | """ |
| | Semantic Caching Layer - Intelligent caching using embedding similarity |
| | """ |
| |
|
| | import hashlib |
| | import json |
| | import logging |
| | import time |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class CacheEntry:
    """One cached query/result pair, with its embedding and expiry bookkeeping."""
    key: str                  # md5-derived lookup key
    query: str                # original query text
    embedding: List[float]    # embedding vector used for similarity matching
    result: Any               # cached payload
    metadata: Dict[str, Any]  # arbitrary bookkeeping (e.g. token/cost figures)
    created_at: float         # epoch timestamp at insertion
    ttl: int                  # lifetime in seconds

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since this entry was stored."""
        return time.time() - self.created_at

    @property
    def is_expired(self) -> bool:
        """True once the entry has outlived its TTL."""
        return self.age_seconds > self.ttl
| |
|
| |
|
@dataclass
class CacheHit:
    """Result of a cache hit"""
    # Cached payload returned to the caller in place of a fresh computation.
    result: Any
    # Cosine similarity between the incoming query and the matched entry.
    similarity: float
    # Age of the matched cache entry at lookup time, in seconds.
    age_seconds: float
    # Tokens avoided by serving from cache (taken from the entry's metadata).
    savings_tokens: int
    # Estimated cost avoided (taken from the entry's metadata).
    savings_cost: float
| |
|
| |
|
class SemanticCache:
    """
    In-memory semantic cache using embedding similarity.
    Caches: Tool execution results, LLM completions, web search results, code execution outputs
    """

    # Dimensionality of the hash-based bag-of-trigrams embedding.
    _EMBEDDING_DIM = 128

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        default_ttl: int = 604800,
        max_size: int = 10000
    ):
        """
        Args:
            similarity_threshold: Minimum cosine similarity for a cache hit.
            default_ttl: Entry lifetime in seconds (604800 = 7 days).
            max_size: Maximum number of entries before eviction kicks in.
        """
        self.similarity_threshold = similarity_threshold
        self.default_ttl = default_ttl
        self.max_size = max_size
        self._cache: Dict[str, CacheEntry] = {}
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0

    def _compute_embedding(self, text: str) -> List[float]:
        """
        Compute a simple L2-normalized bag-of-character-trigrams embedding.

        In production, use OpenAI or sentence-transformers; this is a cheap
        stand-in. Trigram buckets are derived from md5 rather than the builtin
        hash(), because str hashing is salted per process (PYTHONHASHSEED),
        which would make embeddings — and thus hit behavior — inconsistent
        across interpreter runs.
        """
        text = text.lower().strip()
        embedding = [0.0] * self._EMBEDDING_DIM

        for i in range(len(text) - 2):
            trigram = text[i:i + 3]
            # Stable, process-independent bucket index (md5 is used only as a
            # fast deterministic hash here, not for security).
            digest = hashlib.md5(trigram.encode("utf-8")).digest()
            idx = int.from_bytes(digest[:4], "big") % self._EMBEDDING_DIM
            embedding[idx] += 1.0

        # L2-normalize so cosine similarity reduces to a plain dot product.
        norm = sum(x * x for x in embedding) ** 0.5
        if norm > 0:
            embedding = [x / norm for x in embedding]

        return embedding

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Compute cosine similarity between two vectors (0.0 if either is zero)."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot / (norm_a * norm_b)

    def _generate_key(self, query: str) -> str:
        """Generate a cache key from the query (md5 used as a non-cryptographic digest)."""
        return hashlib.md5(query.encode()).hexdigest()

    async def check(self, query: str) -> Optional[CacheHit]:
        """
        Check whether a sufficiently similar query exists in the cache.

        Scans all live entries (O(n) — acceptable for an in-memory cache of
        this size), returns a CacheHit for the best match above the
        similarity threshold, or None. Updates hit/miss and savings counters.
        """
        query_embedding = self._compute_embedding(query)

        best_match: Optional[CacheEntry] = None
        best_similarity = 0.0

        for entry in self._cache.values():
            # Expired entries are skipped here; actual removal happens lazily
            # in _cleanup_expired() when the cache fills up.
            if entry.is_expired:
                continue

            similarity = self._cosine_similarity(query_embedding, entry.embedding)

            if similarity > self.similarity_threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = entry

        if best_match:
            self._hits += 1
            tokens_saved = best_match.metadata.get("prompt_tokens", 0)
            cost_saved = best_match.metadata.get("estimated_cost", 0.0)
            self._tokens_saved += tokens_saved
            self._cost_saved += cost_saved

            logger.info(f"Cache hit: similarity={best_similarity:.3f}, key={best_match.key[:8]}")

            return CacheHit(
                result=best_match.result,
                similarity=best_similarity,
                age_seconds=best_match.age_seconds,
                savings_tokens=tokens_saved,
                savings_cost=cost_saved
            )

        self._misses += 1
        return None

    async def store(
        self,
        query: str,
        result: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl: Optional[int] = None
    ) -> str:
        """
        Store a result with its embedding for future semantic matching.

        Args:
            query: The query text keying this entry.
            result: Arbitrary payload to cache.
            metadata: Optional bookkeeping (e.g. prompt_tokens, estimated_cost).
            ttl: Lifetime in seconds; defaults to default_ttl. An explicit 0 is
                honored (entry expires immediately) rather than being silently
                replaced by the default.

        Returns:
            The cache key for the stored entry.
        """
        # First try to make room by dropping expired entries.
        if len(self._cache) >= self.max_size:
            self._cleanup_expired()

        # Still full: evict the oldest live entry.
        if len(self._cache) >= self.max_size:
            oldest_key = min(
                self._cache.keys(),
                key=lambda k: self._cache[k].created_at
            )
            del self._cache[oldest_key]

        embedding = self._compute_embedding(query)
        key = self._generate_key(query)

        entry = CacheEntry(
            key=key,
            query=query,
            embedding=embedding,
            result=result,
            metadata=metadata or {},
            created_at=time.time(),
            # `ttl or default` would discard an explicit ttl=0; compare to None.
            ttl=ttl if ttl is not None else self.default_ttl
        )

        self._cache[key] = entry
        logger.info(f"Cache stored: key={key[:8]}, entries={len(self._cache)}")

        return key

    def _cleanup_expired(self) -> None:
        """Remove all expired entries from the cache."""
        expired_keys = [
            k for k, v in self._cache.items()
            if v.is_expired
        ]
        for key in expired_keys:
            del self._cache[key]

        if expired_keys:
            logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics: hit/miss counts, hit rate, size, and savings."""
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "entries": len(self._cache),
            "tokens_saved": self._tokens_saved,
            "cost_saved": self._cost_saved
        }

    def clear(self) -> None:
        """Clear all cache entries and reset statistics."""
        self._cache.clear()
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0
        logger.info("Cache cleared")
| |
|
| |
|
| | |
# Module-level singleton instance with default thresholds/TTL.
semantic_cache = SemanticCache()
| |
|