import logging
import json
import hashlib
from typing import Optional, Any

logger = logging.getLogger(__name__)


class SemanticCache:
    """
    Semantic Cache powered by Redis and sentence-transformers.

    Recommended redis.conf / Redis server settings:
        maxmemory 240mb
        maxmemory-policy allkeys-lfu
        lfu-decay-time 5
        lfu-log-factor 10

    Automatically disables itself when Redis or ML dependencies are unavailable.
    """

    def __init__(self, redis_url: Optional[str] = None, similarity_threshold: float = 0.95):
        self.enabled = False
        self.similarity_threshold = similarity_threshold
        self.redis: Any = None
        self.model: Any = None
        self.cosine_similarity: Any = None
        self.np: Any = None

        if not redis_url:
            logger.info("SemanticCache: No Redis URL provided. Cache disabled.")
            return

        # Try connecting to Redis
        try:
            import redis  # type: ignore

            self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
            self.redis.ping()
        except ImportError:
            logger.warning("SemanticCache: 'redis' package not installed. Cache disabled.")
            return
        except Exception as e:
            logger.warning(f"SemanticCache: Failed to connect to Redis at {redis_url}: {e}")
            self.redis = None
            return

        # Try loading sentence-transformers + sklearn
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore
            import numpy as np  # type: ignore
            from sklearn.metrics.pairwise import cosine_similarity

            self.cosine_similarity = cosine_similarity
            self.np = np

            logger.info("SemanticCache: Loading embedding model (all-MiniLM-L6-v2)...")
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.enabled = True
            logger.info("SemanticCache: Successfully initialized and connected to Redis!")
        except ImportError:
            logger.warning(
                "SemanticCache: 'sentence-transformers' or 'scikit-learn' not installed. Cache disabled."
            )
            self.redis = None
        except Exception as e:
            logger.warning(f"SemanticCache: Failed to load ML models: {e}")
            self.redis = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_within_memory_limit(self, safety_ratio: float = 0.90) -> bool:
        """
        Returns False when Redis has consumed >= safety_ratio of its maxmemory.
        Prevents new writes from pushing Redis over the 250 MB hard limit.
        Fails open (returns True) if the info call itself errors.
        """
        try:
            info = self.redis.info("memory")
            used = info["used_memory"]
            max_mem = info.get("maxmemory", 0)
            if max_mem == 0:
                # No maxmemory configured; rely solely on allkeys-lfu eviction.
                return True
            within = (used / max_mem) < safety_ratio
            if not within:
                logger.warning(
                    f"SemanticCache: Memory at {used / max_mem:.1%} of limit "
                    f"({used / 1_048_576:.1f} MB / {max_mem / 1_048_576:.1f} MB). "
                    "Skipping write."
                )
            return within
        except Exception as e:
            logger.warning(f"SemanticCache: Memory check failed (failing open): {e}")
            return True

    @staticmethod
    def _cache_key(query: str) -> str:
        """Stable, cross-process MD5 key for a query string."""
        query_hash = hashlib.md5(query.encode("utf-8")).hexdigest()
        return f"llmopt:cache:{query_hash}"

    @staticmethod
    def _ttl_for_response(response: str) -> int:
        """
        Longer, richer responses get a longer TTL; they are more expensive
        to regenerate and therefore more valuable to keep around.

            > 500 chars → 7 days (604 800 s)
            ≤ 500 chars → 3 days (259 200 s)
        """
        return 604_800 if len(response) > 500 else 259_200

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, query: str) -> Optional[str]:
        """
        Return the cached LLM response for a semantically similar query,
        or None on a cache miss.

        Uses a Redis pipeline to fetch all cached entries in a single round
        trip instead of one GET per key, keeping network overhead low even
        as the cache grows.
        """
        if not self.enabled:
            return None

        try:
            query_embedding = self.model.encode([query])[0]
            keys = self.redis.keys("llmopt:cache:*")
            if not keys:
                return None

            # Batch-fetch all entries in one round trip
            pipe = self.redis.pipeline()
            for key in keys:
                pipe.get(key)
            results = pipe.execute()

            best_key = None
            highest_sim = -1.0
            for key, data_str in zip(keys, results):
                if not data_str:
                    continue
                data = json.loads(data_str)
                cached_emb = self.np.array(data["embedding"])
                sim = self.cosine_similarity([query_embedding], [cached_emb])[0][0]
                if sim > highest_sim:
                    highest_sim = sim
                    best_key = key

            if highest_sim >= self.similarity_threshold and best_key:
                logger.info(f"SemanticCache HIT! Similarity: {highest_sim:.3f}")
                match_data = json.loads(self.redis.get(best_key))
                return match_data["response"]
        except Exception as e:
            logger.warning(f"SemanticCache GET error: {e}")

        return None

    def set(self, query: str, response: str) -> None:
        """
        Embed and store a query/response pair.

        Skips the write when Redis is near its memory ceiling so that the
        allkeys-lfu policy never has to evict a hot entry just to absorb a
        brand-new one.
        """
        if not self.enabled:
            return

        # Guard: don't write when we are close to the 250 MB limit
        if not self._is_within_memory_limit(safety_ratio=0.90):
            return

        try:
            query_embedding = self.model.encode([query])[0]
            key = self._cache_key(query)
            ttl = self._ttl_for_response(response)

            data = {
                "query": query,
                "embedding": query_embedding.tolist(),
                "response": response,
            }

            # Atomic set + expiry via pipeline
            pipe = self.redis.pipeline()
            pipe.set(key, json.dumps(data))
            pipe.expire(key, ttl)
            pipe.execute()

            logger.debug(
                f"SemanticCache SET: key={key} ttl={ttl}s "
                f"response_len={len(response)}"
            )
        except Exception as e:
            logger.warning(f"SemanticCache SET error: {e}")
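

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the cache itself).
# Assumes a local Redis at redis://localhost:6379/0 and that the optional
# dependencies (redis, sentence-transformers, scikit-learn, numpy) are
# installed; the URL and the call_llm stub below are hypothetical.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    cache = SemanticCache(redis_url="redis://localhost:6379/0")

    def call_llm(prompt: str) -> str:
        # Hypothetical stand-in for a real LLM call.
        return f"(model answer for: {prompt})"

    query = "How do I reset my password?"

    # Check the cache first; fall back to the LLM and store the result.
    answer = cache.get(query)
    if answer is None:
        answer = call_llm(query)
        cache.set(query, answer)

    print(answer)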