| """
|
| Advanced caching system with semantic similarity and persistence
|
| """
|
| import pickle
|
| import hashlib
|
| import time
|
| from typing import Dict, Any, Optional, Tuple
|
| from dataclasses import dataclass
|
| import numpy as np
|
| from sklearn.feature_extraction.text import TfidfVectorizer
|
| from sklearn.metrics.pairwise import cosine_similarity
|
| import os
|
| import logging
|
| from threading import Lock
|
|
|
# Logger shared by everything in this module, named after the module itself.
logger = logging.getLogger(__name__)
|
|
|
@dataclass
class CacheEntry:
    """A single cached result together with its bookkeeping metadata."""

    # Payload returned to the caller on a cache hit.
    result: Dict[str, Any]
    # time.time() value recorded when the entry was stored; used for TTL checks.
    timestamp: float
    # How many times the entry has been stored or served (1 on insert).
    access_count: int
    # Dense TF-IDF vector of the cached text, or None when no vector exists.
    semantic_vector: Optional[np.ndarray] = None
|
|
|
class AdvancedCache:
    """In-memory result cache with exact lookup, TF-IDF similarity fallback and pickle persistence.

    Lookups first try an exact SHA-256 key built from the lower-cased
    ``prompt|response|question`` text.  On an exact miss, the query text is
    vectorized with a TF-IDF vectorizer and compared (cosine similarity)
    against the vectors stored with each entry; an entry at or above
    ``similarity_threshold`` is returned as a "semantic hit".

    ``get``, ``set`` and ``clear`` serialize on ``self.lock``.  The lock is not
    re-entrant, so the private helpers must only be called while it is held
    and must never try to acquire it themselves.

    Attributes:
        max_size: Entry count at which ``set`` triggers an eviction pass.
        ttl: Maximum entry age in seconds before it is treated as expired.
        similarity_threshold: Minimum cosine similarity for a semantic hit.
        cache: Mapping of cache key to ``CacheEntry``.
        access_times: Mapping of cache key to last access time (``time.time()``).
        cache_file: Path of the pickle file used for persistence.
        hits / semantic_hits / misses: Request counters reported by ``get_stats``.

    NOTE(review): the vectorizer is fitted on the first text it sees only, so
    its vocabulary is limited to that text; two different later texts can
    map to the same vector.  Confirm this is acceptable for the intended use.
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl: int = 3600,
        similarity_threshold: float = 0.995,
        cache_file: str = "cache_persistent.pkl",
    ):
        """Create the cache and load any previously persisted state.

        Args:
            max_size: Number of entries at which eviction is triggered.
            ttl: Entry lifetime in seconds.
            similarity_threshold: Minimum cosine similarity for a semantic hit.
            cache_file: Path of the persistence file (defaults to the
                historical ``cache_persistent.pkl`` in the working directory).
        """
        self.max_size = max_size
        self.ttl = ttl
        self.similarity_threshold = similarity_threshold
        self.cache: Dict[str, "CacheEntry"] = {}
        self.access_times: Dict[str, float] = {}
        self.lock = Lock()

        self.vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
        self.is_vectorizer_fitted = False
        self.cache_file = cache_file

        # Counters are per-process and are never persisted.
        self.hits = 0
        self.misses = 0
        self.semantic_hits = 0

        self._load_cache()

    def _load_cache(self):
        """Load entries, access times and the fitted vectorizer from disk.

        Best effort: any failure is logged and the cache simply starts empty.
        """
        try:
            if not os.path.exists(self.cache_file):
                return
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file.  Only point cache_file at a location that untrusted
            # parties cannot write to.
            with open(self.cache_file, 'rb') as f:
                data = pickle.load(f)
            if not isinstance(data, dict):
                raise ValueError("unexpected persistent cache format")
            self.cache = data.get('cache', {})
            self.access_times = data.get('access_times', {})
            vectorizer = data.get('vectorizer')
            if vectorizer is not None:
                self.vectorizer = vectorizer
                self.is_vectorizer_fitted = True
            logger.info("Loaded %d entries from persistent cache", len(self.cache))
        except Exception as e:
            logger.warning("Failed to load persistent cache: %s", e)

    def _save_cache(self):
        """Persist entries, access times and the fitted vectorizer to disk.

        The data is written to a temporary sibling file and then moved over
        ``cache_file`` with ``os.replace`` so an interrupted write cannot
        leave a truncated cache file behind.  Best effort: failures are
        logged, never raised.
        """
        tmp_path = None
        try:
            tmp_path = f"{self.cache_file}.{os.getpid()}.tmp"
            data = {
                'cache': self.cache,
                'access_times': self.access_times,
                'vectorizer': self.vectorizer if self.is_vectorizer_fitted else None,
            }
            with open(tmp_path, 'wb') as f:
                pickle.dump(data, f)
            os.replace(tmp_path, self.cache_file)
        except Exception as e:
            logger.warning("Failed to save persistent cache: %s", e)
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _generate_key(self, prompt: str, response: str, question: str = "") -> str:
        """Return the SHA-256 hex digest of the lower-cased, stripped ``prompt|response|question``."""
        combined = f"{prompt}|{response}|{question}".lower().strip()
        return hashlib.sha256(combined.encode()).hexdigest()

    def _create_semantic_vector(self, prompt: str, response: str, question: str = "") -> np.ndarray:
        """Return the dense TF-IDF vector of the combined text.

        The vectorizer is fitted lazily on the first text it sees.  If a
        supposedly fitted vectorizer fails to transform (for example one
        restored from an incompatible pickle), it is refitted.  Whenever the
        vectorizer is (re)fitted, vectors already stored in the cache belong
        to a different vocabulary and are cleared so they are never compared
        against vectors from the new one.

        Raises:
            ValueError: Propagated from ``fit`` when the text yields no
                usable vocabulary (e.g. it is empty or only stop words).
        """
        combined_text = f"{prompt} {response} {question}"

        if self.is_vectorizer_fitted:
            try:
                return self.vectorizer.transform([combined_text]).toarray()[0]
            except Exception as e:
                logger.warning("Vectorizer transform failed, refitting: %s", e)
                self.is_vectorizer_fitted = False

        self.vectorizer.fit([combined_text])
        self.is_vectorizer_fitted = True

        # Old vectors were produced with a different vocabulary.
        for entry in self.cache.values():
            entry.semantic_vector = None

        return self.vectorizer.transform([combined_text]).toarray()[0]

    def _find_similar_entry(
        self, prompt: str, response: str, question: str = ""
    ) -> Optional[Tuple[str, "CacheEntry"]]:
        """Return ``(key, entry)`` of the most similar live entry, or ``None``.

        Entries are skipped when they have no vector, are older than ``ttl``,
        or carry a vector whose shape differs from the query vector.  Any
        error during the search is logged and treated as "no match".
        """
        if not self.cache:
            return None

        try:
            query_vector = self._create_semantic_vector(prompt, response, question)
            now = time.time()

            best_similarity = 0
            best_entry = None
            best_key = None

            for key, entry in self.cache.items():
                vector = getattr(entry, 'semantic_vector', None)
                if vector is None:
                    continue
                # Expired entries must not be served through the semantic path.
                if now - entry.timestamp > self.ttl:
                    continue
                # A vector of another dimensionality cannot be compared.
                if np.shape(vector) != np.shape(query_vector):
                    continue

                similarity = cosine_similarity([query_vector], [vector])[0][0]
                if similarity > best_similarity and similarity >= self.similarity_threshold:
                    best_similarity = similarity
                    best_entry = entry
                    best_key = key

            if best_entry is not None:
                logger.debug("Found similar entry with %.3f similarity", best_similarity)
                return best_key, best_entry

        except Exception as e:
            logger.warning("Semantic similarity search failed: %s", e)

        return None

    def get(self, prompt: str, response: str, question: str = "") -> Optional[Dict[str, Any]]:
        """Return the cached result for the inputs, or ``None`` on a miss.

        An unexpired exact-key match counts as a hit.  An expired exact match
        is removed, after which (as on any exact miss) the semantic search is
        tried; a match there counts as a semantic hit.
        """
        with self.lock:
            key = self._generate_key(prompt, response, question)
            current_time = time.time()

            entry = self.cache.get(key)
            if entry is not None:
                if current_time - entry.timestamp <= self.ttl:
                    entry.access_count += 1
                    self.access_times[key] = current_time
                    self.hits += 1
                    logger.debug("Cache hit for key: %s...", key[:8])
                    return entry.result

                # Expired: drop it and fall through to the semantic search.
                del self.cache[key]
                self.access_times.pop(key, None)

            similar_result = self._find_similar_entry(prompt, response, question)
            if similar_result is not None:
                similar_key, similar_entry = similar_result
                similar_entry.access_count += 1
                self.access_times[similar_key] = current_time
                self.semantic_hits += 1
                logger.debug("Semantic cache hit for key: %s...", similar_key[:8])
                return similar_entry.result

            self.misses += 1
            return None

    def set(self, prompt: str, response: str, question: str, result: Dict[str, Any]):
        """Store ``result`` under the key derived from the inputs.

        If the text cannot be vectorized the entry is still stored, without a
        semantic vector, so exact-key lookups keep working.  Eviction runs
        only when a new key would be added to a full cache.  The cache is
        persisted whenever its size is a multiple of 10 after the insert.
        """
        with self.lock:
            key = self._generate_key(prompt, response, question)
            current_time = time.time()

            try:
                semantic_vector = self._create_semantic_vector(prompt, response, question)
            except Exception as e:
                logger.debug("Caching entry without semantic vector: %s", e)
                semantic_vector = None

            entry = CacheEntry(
                result=result,
                timestamp=current_time,
                access_count=1,
                semantic_vector=semantic_vector,
            )

            # Overwriting an existing key does not grow the cache.
            if key not in self.cache and len(self.cache) >= self.max_size:
                self._evict_entries()

            self.cache[key] = entry
            self.access_times[key] = current_time

            if len(self.cache) % 10 == 0:
                self._save_cache()

    def _evict_entries(self):
        """Evict the least recently used fifth of the cache (at least one entry).

        Entries without a recorded access time are ranked by their creation
        timestamp.  Access-time records for keys no longer cached are dropped.
        """
        if not self.cache:
            return

        def last_access(key):
            return self.access_times.get(key, self.cache[key].timestamp)

        sorted_keys = sorted(self.cache, key=last_access)
        evict_count = max(1, len(sorted_keys) // 5)

        for key in sorted_keys[:evict_count]:
            del self.cache[key]
            self.access_times.pop(key, None)

        for key in [k for k in self.access_times if k not in self.cache]:
            del self.access_times[key]

        logger.info("Evicted %d cache entries", evict_count)

    def get_stats(self) -> Dict[str, Any]:
        """Return entry count, request counters and hit rates (in percent).

        ``total_requests`` counts exact hits, semantic hits and misses, so
        both rates are fractions of all ``get`` calls and cannot exceed 100.
        """
        total_requests = self.hits + self.semantic_hits + self.misses
        if total_requests > 0:
            hit_rate = self.hits / total_requests * 100
            semantic_hit_rate = self.semantic_hits / total_requests * 100
        else:
            hit_rate = 0
            semantic_hit_rate = 0

        return {
            "total_entries": len(self.cache),
            "hits": self.hits,
            "misses": self.misses,
            "semantic_hits": self.semantic_hits,
            "hit_rate": hit_rate,
            "semantic_hit_rate": semantic_hit_rate,
            "total_requests": total_requests,
        }

    def clear(self):
        """Remove every entry and delete the persistence file if it exists."""
        with self.lock:
            self.cache.clear()
            self.access_times.clear()
            try:
                os.remove(self.cache_file)
            except FileNotFoundError:
                pass

    def __del__(self):
        """Persist the cache when the object is destroyed (best effort)."""
        try:
            self._save_cache()
        except Exception:
            # During interpreter teardown module globals may already be gone;
            # a destructor must never raise.
            pass
|
|
|
|
|
# Shared module-level cache instance, built with the default settings at import time.
advanced_cache = AdvancedCache()
|
|
|