|
|
""" |
|
|
Semantic cache that caches and retrieves similar queries using embeddings. |
|
|
More advanced than exact match caching - understands semantic similarity. |
|
|
""" |
|
|
import numpy as np |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
import sqlite3 |
|
|
import hashlib |
|
|
import json |
|
|
import time |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
import faiss |
|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
from app.hyper_config import config |
|
|
from app.ultra_fast_embeddings import get_embedder |
|
|
|
|
|
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)
|
|
|
|
|
class CacheStrategy(str, Enum):
    """Lookup strategies supported by the cache.

    The enum also subclasses ``str``, so members compare equal to their
    plain string values (``CacheStrategy.EXACT == "exact"``).
    """

    # Match only on the hash of the query text.
    EXACT = "exact"
    # Match only on embedding similarity.
    SEMANTIC = "semantic"
    # Try an exact match first, then fall back to embedding similarity.
    HYBRID = "hybrid"
|
|
|
|
|
@dataclass
class CacheEntry:
    """In-memory record of one cached query/answer pair.

    The fields mirror the columns of the ``cache_entries`` table.
    """

    # Original query text.
    query: str
    # Hash of the query text, used as the exact-match key.
    query_hash: str
    # Embedding vector of the query.
    query_embedding: np.ndarray
    # Cached answer text.
    answer: str
    # Chunks that were used to produce the answer.
    chunks_used: List[str]
    # Free-form extra information stored with the entry.
    metadata: Dict[str, Any]
    # When the entry was first stored.
    created_at: datetime
    # When the entry was last read or refreshed.
    accessed_at: datetime
    # Number of times the entry has been accessed.
    access_count: int
    # Time to live, in seconds.
    ttl_seconds: int
|
|
|
|
|
class SemanticCache:
    """
    Query/answer cache backed by SQLite with optional embedding lookup.

    Features:
    - Exact match caching keyed on a hash of the query text
    - Semantic similarity caching via a FAISS ``IndexFlatL2`` index
    - Sliding TTL (expiry is measured from the last access)
    - Least-recently-accessed eviction once ``max_cache_size`` is exceeded
    - Hit/miss counters exposed through :meth:`get_stats`

    ``entry_ids[i]`` holds the database id of the vector stored at FAISS
    row ``i``. The FAISS index is never shrunk in place, so a removed entry
    is kept as a ``None`` tombstone to preserve that row alignment; the
    tombstones disappear the next time the index is rebuilt.
    """

    # Embedding width the FAISS index is created with.
    _EMBEDDING_DIM = 384
    # Maximum number of most recently accessed rows loaded into FAISS.
    _FAISS_LOAD_LIMIT = 1000
    # Number of nearest neighbours inspected per semantic lookup.
    _SEMANTIC_TOP_K = 3
    # Entries not accessed for this many days are reported as stale.
    _STALE_AFTER_DAYS = 7

    def __init__(
        self,
        cache_dir: Optional[Path] = None,
        strategy: CacheStrategy = CacheStrategy.HYBRID,
        similarity_threshold: float = 0.85,
        max_cache_size: int = 10000,
        ttl_hours: int = 24
    ):
        """
        Configure the cache; no database or index is opened yet.

        Args:
            cache_dir: Directory for the SQLite file; falls back to
                ``config.cache_dir`` when omitted.
            strategy: Lookup strategy (enum member or its string value).
            similarity_threshold: Minimum similarity for a semantic hit.
            max_cache_size: Row count above which old entries are evicted.
            ttl_hours: Default time to live, in hours.
        """
        # Bound first so close()/__del__ stay safe if anything below raises.
        self.conn: Optional[sqlite3.Connection] = None
        self._initialized = False

        self.cache_dir = Path(cache_dir or config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Accept either the enum member or its plain string value.
        self.strategy = CacheStrategy(strategy)
        self.similarity_threshold = similarity_threshold
        self.max_cache_size = max_cache_size
        self.ttl_hours = ttl_hours

        # Database
        self.db_path = self.cache_dir / "semantic_cache.db"

        # FAISS index and the row -> database id mapping
        self.faiss_index = None
        self.embedding_dim = self._EMBEDDING_DIM
        self.entry_ids: List[Optional[int]] = []

        # Embedder (created in initialize() for semantic strategies)
        self.embedder = None

        # Metrics
        self.hits = 0
        self.misses = 0
        self.semantic_hits = 0
        self.exact_hits = 0

    def _uses_semantic(self) -> bool:
        """Return True when the strategy involves embedding lookups."""
        return self.strategy in (CacheStrategy.SEMANTIC, CacheStrategy.HYBRID)

    def _uses_exact(self) -> bool:
        """Return True when the strategy involves exact hash lookups."""
        return self.strategy in (CacheStrategy.EXACT, CacheStrategy.HYBRID)

    def initialize(self):
        """Initialize the cache database and FAISS index (idempotent)."""
        if self._initialized:
            return

        logger.info(f"🚀 Initializing SemanticCache (strategy: {self.strategy.value})")

        self._init_database()

        if self._uses_semantic():
            self.embedder = get_embedder()
            self.embedding_dim = self._EMBEDDING_DIM
            self._init_faiss_index()

        self._load_cache_entries()

        logger.info(f"✅ SemanticCache initialized with {len(self.entry_ids)} entries")
        self._initialized = True

    def _init_database(self):
        """Open the SQLite database and create the table and indexes."""
        self.conn = sqlite3.connect(self.db_path)
        cursor = self.conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS cache_entries (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query TEXT NOT NULL,
                query_hash TEXT UNIQUE NOT NULL,
                query_embedding BLOB,
                answer TEXT NOT NULL,
                chunks_used_json TEXT NOT NULL,
                metadata_json TEXT NOT NULL,
                created_at TIMESTAMP NOT NULL,
                accessed_at TIMESTAMP NOT NULL,
                access_count INTEGER DEFAULT 1,
                ttl_seconds INTEGER NOT NULL,
                embedding_hash TEXT
            )
        """)

        cursor.execute("CREATE INDEX IF NOT EXISTS idx_query_hash ON cache_entries(query_hash)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_accessed_at ON cache_entries(accessed_at)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_embedding_hash ON cache_entries(embedding_hash)")

        self.conn.commit()

    def _init_faiss_index(self):
        """Create an empty FAISS index for semantic search."""
        self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
        self.entry_ids = []

    def _load_cache_entries(self):
        """Load the most recently accessed stored embeddings into FAISS.

        Rows whose blob does not hold exactly ``embedding_dim`` float32
        values are skipped so one bad row cannot abort the whole load.
        """
        if not self._uses_semantic() or self.faiss_index is None:
            return

        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT id, query_embedding FROM cache_entries
            WHERE query_embedding IS NOT NULL
            ORDER BY accessed_at DESC
            LIMIT ?
        """, (self._FAISS_LOAD_LIMIT,))

        expected_bytes = self.embedding_dim * np.dtype(np.float32).itemsize
        vectors = []
        loaded_ids = []
        skipped = 0

        for entry_id, embedding_blob in cursor.fetchall():
            if not embedding_blob:
                continue
            if len(embedding_blob) != expected_bytes:
                skipped += 1
                continue
            vectors.append(np.frombuffer(embedding_blob, dtype=np.float32))
            loaded_ids.append(entry_id)

        if vectors:
            # One batched add keeps FAISS rows and entry_ids in the same order.
            matrix = np.ascontiguousarray(np.vstack(vectors), dtype=np.float32)
            self.faiss_index.add(matrix)
            self.entry_ids.extend(loaded_ids)

        if skipped:
            logger.warning(
                "Skipped %d cached embeddings whose size does not match dimension %d",
                skipped, self.embedding_dim,
            )

        logger.info(f"Loaded {len(self.entry_ids)} entries into FAISS index")

    def get(self, query: str) -> Optional[Tuple[str, List[str]]]:
        """
        Get cached answer for query.

        The exact lookup runs first (EXACT/HYBRID), then the semantic lookup
        (SEMANTIC/HYBRID). Hit and miss counters are updated accordingly.

        Returns:
            Tuple of (answer, chunks_used) or None if not found
        """
        if not self._initialized:
            self.initialize()

        query_hash = self._hash_query(query)

        if self._uses_exact():
            result = self._get_exact(query_hash)
            if result is not None:
                self.exact_hits += 1
                self.hits += 1
                return result

        if self._uses_semantic():
            result = self._get_semantic(query)
            if result is not None:
                self.semantic_hits += 1
                self.hits += 1
                return result

        self.misses += 1
        return None

    def _get_exact(self, query_hash: str) -> Optional[Tuple[str, List[str]]]:
        """Return the entry stored under ``query_hash``, or None.

        An expired entry is deleted and reported as a miss; a live entry has
        its access time and access count refreshed.
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT answer, chunks_used_json, accessed_at, ttl_seconds
            FROM cache_entries
            WHERE query_hash = ?
            LIMIT 1
        """, (query_hash,))

        row = cursor.fetchone()
        if not row:
            return None

        answer, chunks_used_json, accessed_at_str, ttl_seconds = row

        accessed_at = datetime.fromisoformat(accessed_at_str)
        if self._is_expired(accessed_at, ttl_seconds):
            self._delete_entry(query_hash)
            return None

        self._update_access_time(query_hash)

        chunks_used = json.loads(chunks_used_json)
        return answer, chunks_used

    def _get_semantic(self, query: str) -> Optional[Tuple[str, List[str]]]:
        """Return the closest cached entry that passes the threshold, or None.

        The query is embedded and its nearest neighbours are examined in the
        order FAISS returns them. Tombstoned rows are skipped and expired
        entries are deleted along the way.
        """
        if self.embedder is None or self.faiss_index is None or not self.entry_ids:
            return None

        query_embedding = np.asarray(self.embedder.embed_single(query))
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        distances, indices = self.faiss_index.search(query_embedding, self._SEMANTIC_TOP_K)

        cursor = self.conn.cursor()
        for distance, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            # FAISS pads missing neighbours with -1.
            if idx < 0 or idx >= len(self.entry_ids):
                continue

            entry_id = self.entry_ids[idx]
            if entry_id is None:
                # Tombstone: the entry behind this row was removed.
                continue

            # IndexFlatL2 reports squared L2 distance; map it into (0, 1].
            similarity = 1.0 / (1.0 + distance)
            if similarity < self.similarity_threshold:
                continue

            cursor.execute("""
                SELECT answer, chunks_used_json, accessed_at, ttl_seconds, query
                FROM cache_entries
                WHERE id = ?
                LIMIT 1
            """, (entry_id,))

            row = cursor.fetchone()
            if not row:
                # The index references a row that no longer exists.
                self._remove_from_faiss(entry_id)
                continue

            answer, chunks_used_json, accessed_at_str, ttl_seconds, original_query = row

            accessed_at = datetime.fromisoformat(accessed_at_str)
            if self._is_expired(accessed_at, ttl_seconds):
                self._delete_by_id(entry_id)
                continue

            self._update_access_by_id(entry_id)

            chunks_used = json.loads(chunks_used_json)

            logger.debug(f"Semantic cache hit: similarity={similarity:.3f}, "
                         f"original='{original_query[:30]}...', "
                         f"current='{query[:30]}...'")

            return answer, chunks_used

        return None

    def put(
        self,
        query: str,
        answer: str,
        chunks_used: List[str],
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None
    ):
        """
        Store query and answer in cache.

        If the query is already cached, the existing row is refreshed with
        the new answer, chunks, metadata and TTL instead.

        Args:
            query: The user query
            answer: Generated answer
            chunks_used: List of chunks used for answer
            metadata: Additional metadata
            ttl_seconds: Time to live in seconds; ``None`` selects the
                default derived from ``ttl_hours``
        """
        if not self._initialized:
            self.initialize()

        query_hash = self._hash_query(query)
        # Test against None so an explicit 0 is not replaced by the default.
        ttl = self.ttl_hours * 3600 if ttl_seconds is None else ttl_seconds

        embedding_vector = None
        query_embedding = None
        embedding_hash = None

        if self._uses_semantic() and self.embedder is not None:
            embedding_vector = np.asarray(self.embedder.embed_single(query)).astype(np.float32)
            query_embedding = embedding_vector.tobytes()
            embedding_hash = hashlib.md5(query_embedding).hexdigest()

        chunks_used_json = json.dumps(chunks_used)
        metadata_json = json.dumps(metadata or {})
        now = datetime.now().isoformat()

        cursor = self.conn.cursor()

        try:
            cursor.execute("""
                INSERT INTO cache_entries (
                    query, query_hash, query_embedding, answer, chunks_used_json,
                    metadata_json, created_at, accessed_at, ttl_seconds, embedding_hash
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                query, query_hash, query_embedding, answer, chunks_used_json,
                metadata_json, now, now, ttl, embedding_hash
            ))
        except sqlite3.IntegrityError:
            # query_hash is UNIQUE: the query is already cached, refresh it.
            self.conn.rollback()
            self._update_entry(query_hash, answer, chunks_used_json, metadata_json, now, ttl)
            return

        entry_id = cursor.lastrowid
        # Commit before touching the index so FAISS never references a row
        # that failed to persist.
        self.conn.commit()

        if embedding_vector is not None and self.faiss_index is not None:
            if embedding_vector.size == self.embedding_dim:
                self.faiss_index.add(np.ascontiguousarray(embedding_vector.reshape(1, -1)))
                self.entry_ids.append(entry_id)
            else:
                logger.warning(
                    "Embedding has %d values but the index expects %d; "
                    "entry stored without semantic indexing",
                    embedding_vector.size, self.embedding_dim,
                )

        logger.debug(f"Cached query: '{query[:50]}...'")

        self._evict_if_needed()

    def _update_entry(
        self,
        query_hash: str,
        answer: str,
        chunks_used_json: str,
        metadata_json: str,
        timestamp: str,
        ttl_seconds: int
    ):
        """Overwrite the payload of an existing entry and bump its access data."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET answer = ?, chunks_used_json = ?, metadata_json = ?,
                accessed_at = ?, ttl_seconds = ?, access_count = access_count + 1
            WHERE query_hash = ?
        """, (answer, chunks_used_json, metadata_json, timestamp, ttl_seconds, query_hash))
        self.conn.commit()

    def _update_access_time(self, query_hash: str):
        """Record an access for the entry identified by ``query_hash``."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET accessed_at = ?, access_count = access_count + 1
            WHERE query_hash = ?
        """, (datetime.now().isoformat(), query_hash))
        self.conn.commit()

    def _update_access_by_id(self, entry_id: int):
        """Record an access for the entry identified by its database id."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET accessed_at = ?, access_count = access_count + 1
            WHERE id = ?
        """, (datetime.now().isoformat(), entry_id))
        self.conn.commit()

    def _delete_entry(self, query_hash: str):
        """Delete the entry stored under ``query_hash``, if any."""
        cursor = self.conn.cursor()

        # The index mapping is keyed by database id, so resolve it first.
        cursor.execute("SELECT id FROM cache_entries WHERE query_hash = ?", (query_hash,))
        row = cursor.fetchone()

        if row:
            self._remove_from_faiss(row[0])

        cursor.execute("DELETE FROM cache_entries WHERE query_hash = ?", (query_hash,))
        self.conn.commit()

    def _delete_by_id(self, entry_id: int):
        """Delete the entry with the given database id."""
        self._remove_from_faiss(entry_id)

        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM cache_entries WHERE id = ?", (entry_id,))
        self.conn.commit()

    def _remove_from_faiss(self, entry_id: int):
        """Stop semantic lookups from returning ``entry_id``.

        The vector itself stays in the FAISS index, so the mapping slot is
        replaced by a ``None`` tombstone instead of being deleted: deleting
        it would shift every later slot and make FAISS row numbers resolve
        to the wrong database ids.
        """
        try:
            position = self.entry_ids.index(entry_id)
        except ValueError:
            return
        self.entry_ids[position] = None

    def _evict_if_needed(self):
        """Evict the least recently accessed entries beyond ``max_cache_size``."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM cache_entries")
        count = cursor.fetchone()[0]

        if count <= self.max_cache_size:
            return

        cursor.execute("""
            DELETE FROM cache_entries
            WHERE id IN (
                SELECT id FROM cache_entries
                ORDER BY accessed_at ASC
                LIMIT ?
            )
        """, (count - self.max_cache_size,))
        self.conn.commit()

        # Evicted rows may still be indexed; rebuild from what is left.
        if self._uses_semantic():
            self._rebuild_faiss_index()

    def _rebuild_faiss_index(self):
        """Rebuild the FAISS index and id mapping from the database."""
        if self.faiss_index is not None:
            self.faiss_index.reset()
            self.entry_ids = []
            self._load_cache_entries()

    def _hash_query(self, query: str) -> str:
        """Return the MD5 hex digest of the query text (the exact-match key)."""
        return hashlib.md5(query.encode()).hexdigest()

    def _is_expired(self, accessed_at: datetime, ttl_seconds: int) -> bool:
        """Return True when ``ttl_seconds`` have elapsed since ``accessed_at``."""
        expiry_time = accessed_at + timedelta(seconds=ttl_seconds)
        return datetime.now() > expiry_time

    def clear(self):
        """Clear all cache entries from the database and the FAISS index."""
        if not self._initialized:
            self.initialize()

        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM cache_entries")
        self.conn.commit()

        if self.faiss_index is not None:
            self.faiss_index.reset()
            self.entry_ids = []

        logger.info("Cache cleared")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dict with row counts from the database, the in-process hit/miss
            counters, the configuration, and the number of live FAISS rows.
        """
        if not self._initialized:
            self.initialize()

        cursor = self.conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM cache_entries")
        total_entries = cursor.fetchone()[0]

        cursor.execute("SELECT SUM(access_count) FROM cache_entries")
        total_accesses = cursor.fetchone()[0] or 0

        # accessed_at is written with datetime.now() (naive local time), so
        # the cutoff is computed the same way; SQLite's 'now' would be UTC.
        stale_cutoff = (datetime.now() - timedelta(days=self._STALE_AFTER_DAYS)).isoformat()
        cursor.execute("""
            SELECT COUNT(*) FROM cache_entries
            WHERE datetime(accessed_at) < datetime(?)
        """, (stale_cutoff,))
        stale_entries = cursor.fetchone()[0]

        lookups = self.hits + self.misses
        hit_rate = self.hits / lookups if lookups > 0 else 0

        return {
            "total_entries": total_entries,
            "total_accesses": total_accesses,
            "stale_entries": stale_entries,
            "hits": self.hits,
            "misses": self.misses,
            "exact_hits": self.exact_hits,
            "semantic_hits": self.semantic_hits,
            "hit_rate": hit_rate,
            "strategy": self.strategy.value,
            "similarity_threshold": self.similarity_threshold,
            # Tombstones are not live entries.
            "faiss_entries": sum(1 for entry_id in self.entry_ids if entry_id is not None)
        }

    def close(self):
        """Close the database connection; the cache re-initializes on next use."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
            self.conn = None
            self._initialized = False

    def __del__(self):
        """Cleanup."""
        # getattr keeps this safe when __init__ raised before conn was bound.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
|
|
|
|
|
|
|
|
# Process-wide cache instance, created on first use by get_semantic_cache().
_cache_instance: Optional[SemanticCache] = None


def get_semantic_cache() -> SemanticCache:
    """Get or create the global semantic cache instance.

    The instance is published only after ``initialize()`` succeeds, so a
    failed initialization is retried on the next call instead of leaving a
    half-built object in the global.
    """
    global _cache_instance
    if _cache_instance is None:
        cache = SemanticCache(
            strategy=CacheStrategy.HYBRID,
            similarity_threshold=0.85,
            max_cache_size=5000,
            ttl_hours=24
        )
        cache.initialize()
        _cache_instance = cache
    return _cache_instance
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: exact hit, semantic hit, expected miss, statistics.
    import tempfile

    logging.basicConfig(level=logging.INFO)

    print("\n🧪 Testing SemanticCache...")

    # Run against a throwaway directory: the demo ends with cache.clear(),
    # which would otherwise wipe the entries in the default cache directory.
    with tempfile.TemporaryDirectory() as demo_dir:
        cache = SemanticCache(
            cache_dir=Path(demo_dir),
            strategy=CacheStrategy.HYBRID,
            similarity_threshold=0.8,
            max_cache_size=100
        )
        cache.initialize()

        try:
            print("\n📝 Testing exact caching...")
            query1 = "What is machine learning?"
            answer1 = "Machine learning is a subset of AI that enables systems to learn from data."
            chunks1 = ["chunk1", "chunk2"]

            cache.put(query1, answer1, chunks1)

            cached = cache.get(query1)
            if cached:
                print(f" Exact cache HIT: {cached[0][:50]}...")
            else:
                print(" Exact cache MISS")

            print("\n📝 Testing semantic caching...")
            similar_query = "Can you explain machine learning?"

            cached = cache.get(similar_query)
            if cached:
                print(f" Semantic cache HIT: {cached[0][:50]}...")
            else:
                print(" Semantic cache MISS (might need lower threshold)")

            print("\n📝 Testing non-similar query...")
            different_query = "What is the capital of France?"

            cached = cache.get(different_query)
            if cached:
                print(f" Unexpected HIT: {cached[0][:50]}...")
            else:
                print(" Expected MISS")

            stats = cache.get_stats()
            print("\n📊 Cache Statistics:")
            for key, value in stats.items():
                print(f" {key}: {value}")

            cache.clear()
            print("\n🧹 Cache cleared")
        finally:
            # Release the database file before the temporary directory is removed.
            if cache.conn is not None:
                cache.conn.close()
|
|
|