llmopt-server/llmopt/cache/semantic_cache.py
import logging
import json
import hashlib
from typing import Optional, Any
logger = logging.getLogger(__name__)
class SemanticCache:
"""
Semantic Cache powered by Redis and sentence-transformers.
Recommended redis.conf / Redis server settings:
maxmemory 240mb
maxmemory-policy allkeys-lfu
lfu-decay-time 5
lfu-log-factor 10
Automatically disables itself when Redis or ML dependencies are unavailable.
"""
def __init__(self, redis_url: Optional[str] = None, similarity_threshold: float = 0.95):
self.enabled = False
self.similarity_threshold = similarity_threshold
self.redis: Any = None
self.model: Any = None
self.cosine_similarity: Any = None
self.np: Any = None
if not redis_url:
logger.info("SemanticCache: No Redis URL provided. Cache disabled.")
return
# Try connecting to Redis
try:
import redis # type: ignore
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
self.redis.ping()
except ImportError:
logger.warning("SemanticCache: 'redis' package not installed. Cache disabled.")
return
except Exception as e:
logger.warning(f"SemanticCache: Failed to connect to Redis at {redis_url}: {e}")
self.redis = None
return
# Try loading sentence-transformers + sklearn
try:
from sentence_transformers import SentenceTransformer # type: ignore
import numpy as np # type: ignore
            from sklearn.metrics.pairwise import cosine_similarity  # type: ignore
self.cosine_similarity = cosine_similarity
self.np = np
logger.info("SemanticCache: Loading embedding model (all-MiniLM-L6-v2)...")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
self.enabled = True
logger.info("SemanticCache: Successfully initialized and connected to Redis!")
except ImportError:
logger.warning(
"SemanticCache: 'sentence-transformers' or 'scikit-learn' not installed. Cache disabled."
)
self.redis = None
except Exception as e:
logger.warning(f"SemanticCache: Failed to load ML models: {e}")
self.redis = None
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _is_within_memory_limit(self, safety_ratio: float = 0.90) -> bool:
"""
Returns False when Redis has consumed >= safety_ratio of its maxmemory.
Prevents new writes from pushing Redis over the 250 MB hard limit.
Fails open (returns True) if the info call itself errors.
"""
try:
info = self.redis.info("memory")
used = info["used_memory"]
max_mem = info.get("maxmemory", 0)
if max_mem == 0:
# No maxmemory configured — rely solely on allkeys-lfu eviction.
return True
within = (used / max_mem) < safety_ratio
if not within:
logger.warning(
f"SemanticCache: Memory at {used / max_mem:.1%} of limit "
f"({used / 1_048_576:.1f} MB / {max_mem / 1_048_576:.1f} MB). "
"Skipping write."
)
return within
except Exception as e:
logger.warning(f"SemanticCache: Memory check failed (failing open): {e}")
return True
@staticmethod
def _cache_key(query: str) -> str:
"""Stable, cross-process MD5 key for a query string."""
query_hash = hashlib.md5(query.encode("utf-8")).hexdigest()
return f"llmopt:cache:{query_hash}"
@staticmethod
def _ttl_for_response(response: str) -> int:
"""
Longer, richer responses get a longer TTL — they are more expensive to
regenerate and therefore more valuable to keep around.
> 500 chars → 7 days (604 800 s)
≤ 500 chars → 3 days (259 200 s)
"""
return 604_800 if len(response) > 500 else 259_200
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get(self, query: str) -> Optional[str]:
"""
Return the cached LLM response for a semantically similar query, or
None on a cache miss.
Uses a Redis pipeline to fetch all cached entries in a single round
trip instead of one GET per key, keeping network overhead low even as
the cache grows.
"""
if not self.enabled:
return None
try:
query_embedding = self.model.encode([query])[0]
keys = self.redis.keys("llmopt:cache:*")
if not keys:
return None
# Batch-fetch all entries in one round trip
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
results = pipe.execute()
            best_response = None
            highest_sim = -1.0
            for data_str in results:
                if not data_str:
                    continue
                data = json.loads(data_str)
                cached_emb = self.np.array(data["embedding"])
                sim = self.cosine_similarity([query_embedding], [cached_emb])[0][0]
                if sim > highest_sim:
                    highest_sim = sim
                    best_response = data["response"]
            if best_response is not None and highest_sim >= self.similarity_threshold:
                logger.info(f"SemanticCache HIT! Similarity: {highest_sim:.3f}")
                # Reuse the payload already fetched by the pipeline; a second GET
                # could race with key expiry and fail on a vanished entry.
                return best_response
except Exception as e:
logger.warning(f"SemanticCache GET error: {e}")
return None
def set(self, query: str, response: str) -> None:
"""
Embed and store a query/response pair.
Skips the write when Redis is near its memory ceiling so that the
allkeys-lfu policy never has to evict a hot entry just to absorb a
brand-new one.
"""
if not self.enabled:
return
# Guard: don't write when we are close to the 250 MB limit
if not self._is_within_memory_limit(safety_ratio=0.90):
return
try:
query_embedding = self.model.encode([query])[0]
key = self._cache_key(query)
ttl = self._ttl_for_response(response)
data = {
"query": query,
"embedding": query_embedding.tolist(),
"response": response,
}
# Atomic set + expiry via pipeline
pipe = self.redis.pipeline()
pipe.set(key, json.dumps(data))
pipe.expire(key, ttl)
pipe.execute()
logger.debug(
f"SemanticCache SET: key={key} ttl={ttl}s "
f"response_len={len(response)}"
)
except Exception as e:
logger.warning(f"SemanticCache SET error: {e}")