# water3/agent/core/semantic_cache.py
# Repository page header: uploaded by onewayto ("Upload 187 files"), commit 070daf8 (verified).
"""
Semantic Caching Layer - Intelligent caching using embedding similarity
"""
import hashlib
import json
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """One cached result plus the data needed to match and expire it."""

    key: str  # identifier the entry is stored under
    query: str  # query text the result was produced for
    embedding: List[float]  # vector representation of the query
    result: Any  # cached payload
    metadata: Dict[str, Any]  # free-form extra data attached to the entry
    created_at: float  # creation timestamp, on the same clock as time.time()
    ttl: int  # lifetime in seconds, counted from created_at

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at`` according to ``time.time()``."""
        now = time.time()
        return now - self.created_at

    @property
    def is_expired(self) -> bool:
        """True once the entry's age is strictly greater than its ``ttl``."""
        return self.age_seconds > self.ttl
@dataclass
class CacheHit:
    """What a successful cache lookup hands back to the caller."""

    result: Any  # payload of the matched cache entry
    similarity: float  # similarity score between the lookup query and the matched entry
    age_seconds: float  # age of the matched entry, in seconds
    savings_tokens: int  # tokens credited as saved by this hit
    savings_cost: float  # cost credited as saved by this hit
class SemanticCache:
    """
    In-memory semantic cache using embedding similarity.

    Caches: Tool execution results, LLM completions, web search results, code execution outputs

    Entries are stored under the MD5 hex digest of the exact query text and
    also carry an embedding of that text. A lookup succeeds either when the
    identical query was stored before, or when a stored query's embedding is
    more similar to the lookup query than ``similarity_threshold``.
    """

    # Number of buckets in the hashed n-gram embedding.
    _EMBEDDING_DIM = 128
    # Length of the character n-grams the embedding is built from.
    _NGRAM_SIZE = 3

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        default_ttl: int = 604800,  # 7 days
        max_size: int = 10000
    ):
        """
        Create an empty cache.

        Args:
            similarity_threshold: Cosine similarity a stored entry must exceed
                (strictly) to count as a semantic match.
            default_ttl: Lifetime in seconds for entries stored without an
                explicit ``ttl``.
            max_size: Entry count at which ``store`` starts making room
                before inserting a new key.
        """
        self.similarity_threshold = similarity_threshold
        self.default_ttl = default_ttl
        self.max_size = max_size
        self._cache: Dict[str, CacheEntry] = {}
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0

    def _compute_embedding(self, text: str) -> List[float]:
        """
        Compute a simple embedding for the text.

        In production, use OpenAI or sentence-transformers.
        For now, character trigrams of the lower-cased, stripped text are
        counted into 128 hash buckets and the vector is L2-normalised.

        Text shorter than one n-gram yields the all-zero vector, which never
        matches anything by similarity.

        Args:
            text: Text to embed.

        Returns:
            A list of 128 floats: unit length, or all zeros when the text has
            no n-grams.
        """
        text = text.lower().strip()
        size = self._NGRAM_SIZE
        dim = self._EMBEDDING_DIM

        # Count each distinct n-gram once so it is only hashed once.
        counts: Dict[str, int] = {}
        for start in range(len(text) - size + 1):
            ngram = text[start:start + size]
            counts[ngram] = counts.get(ngram, 0) + 1

        embedding = [0.0] * dim
        for ngram, count in counts.items():
            # The builtin hash() of a str is salted per process, which would
            # make embeddings differ between runs; a digest is stable.
            # "surrogatepass" keeps lone surrogates from raising here.
            digest = hashlib.md5(ngram.encode("utf-8", "surrogatepass")).digest()
            embedding[int.from_bytes(digest[:4], "big") % dim] += count

        # Normalize
        norm = sum(x * x for x in embedding) ** 0.5
        if norm > 0:
            embedding = [x / norm for x in embedding]
        return embedding

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """
        Compute cosine similarity between two vectors.

        Returns 0.0 when either vector has zero length. If the vectors differ
        in size, the dot product only covers the common prefix (``zip``).
        """
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def _generate_key(self, query: str) -> str:
        """Generate a cache key (MD5 hex digest) from the exact query text."""
        # Same bytes as a plain .encode() for ordinary text; "surrogatepass"
        # only matters for lone surrogates, which would otherwise raise.
        return hashlib.md5(query.encode("utf-8", "surrogatepass")).hexdigest()

    async def check(self, query: str) -> Optional["CacheHit"]:
        """
        Check if the same or a similar query exists in the cache.

        An unexpired entry stored under the identical query text is returned
        directly with similarity 1.0. This also covers queries too short to
        have a usable embedding. Otherwise every unexpired entry is compared
        by cosine similarity and the best one above ``similarity_threshold``
        wins.

        Hit/miss counters and the saved-token/cost totals are updated as a
        side effect.

        Args:
            query: Query text to look up.

        Returns:
            A ``CacheHit`` describing the matched entry, or ``None`` on a miss.
        """
        best_match = None
        best_similarity = 0.0

        # Fast path: identical query text.
        exact = self._cache.get(self._generate_key(query))
        if exact is not None and not exact.is_expired:
            best_match = exact
            best_similarity = 1.0
        else:
            query_embedding = self._compute_embedding(query)
            # Find best matching entry
            for entry in self._cache.values():
                if entry.is_expired:
                    continue
                similarity = self._cosine_similarity(query_embedding, entry.embedding)
                if similarity > self.similarity_threshold and similarity > best_similarity:
                    best_similarity = similarity
                    best_match = entry

        if best_match is None:
            self._misses += 1
            return None

        self._hits += 1
        tokens_saved = best_match.metadata.get("prompt_tokens", 0)
        cost_saved = best_match.metadata.get("estimated_cost", 0.0)
        self._tokens_saved += tokens_saved
        self._cost_saved += cost_saved
        logger.info(
            "Cache hit: similarity=%.3f, key=%s", best_similarity, best_match.key[:8]
        )
        return CacheHit(
            result=best_match.result,
            similarity=best_similarity,
            age_seconds=best_match.age_seconds,
            savings_tokens=tokens_saved,
            savings_cost=cost_saved
        )

    async def store(
        self,
        query: str,
        result: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl: Optional[int] = None
    ) -> str:
        """
        Store result with embedding for future semantic matching.

        Storing under a query that is already cached replaces that entry.
        When a new key would be added to a full cache, expired entries are
        removed first and, if that is not enough, the single oldest entry is
        evicted.

        Args:
            query: Query text the result belongs to.
            result: Value to cache.
            metadata: Optional extra data; ``check`` reads the
                ``"prompt_tokens"`` and ``"estimated_cost"`` keys from it.
            ttl: Lifetime in seconds; ``None`` means ``default_ttl``.
                An explicit ``0`` is honoured rather than replaced.

        Returns:
            The cache key the entry was stored under.
        """
        key = self._generate_key(query)

        # Replacing an existing key does not grow the cache, so room is only
        # needed for a genuinely new key.
        if key not in self._cache and len(self._cache) >= self.max_size:
            # Clean up expired entries if cache is full
            self._cleanup_expired()
            # If still full, remove the oldest entry. The emptiness guard
            # covers max_size <= 0, where min() would otherwise raise.
            if self._cache and len(self._cache) >= self.max_size:
                oldest_key = min(
                    self._cache,
                    key=lambda k: self._cache[k].created_at
                )
                del self._cache[oldest_key]

        entry = CacheEntry(
            key=key,
            query=query,
            embedding=self._compute_embedding(query),
            result=result,
            metadata=metadata if metadata is not None else {},
            created_at=time.time(),
            ttl=self.default_ttl if ttl is None else ttl
        )
        self._cache[key] = entry
        logger.info("Cache stored: key=%s, entries=%d", key[:8], len(self._cache))
        return key

    def _cleanup_expired(self):
        """Remove expired entries."""
        # Collect first; deleting while iterating the dict would raise.
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]
        if expired_keys:
            logger.info("Cleaned up %d expired cache entries", len(expired_keys))

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            A dict with ``hits``, ``misses``, ``hit_rate`` (0.0 when no lookup
            has happened yet), ``entries``, ``tokens_saved`` and ``cost_saved``.
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "entries": len(self._cache),
            "tokens_saved": self._tokens_saved,
            "cost_saved": self._cost_saved
        }

    def clear(self):
        """Clear all cache entries and reset every statistic."""
        self._cache.clear()
        self._hits = 0
        self._misses = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0
        logger.info("Cache cleared")
# Global cache instance, created at import time with the default settings.
semantic_cache = SemanticCache()