build-small-hackathon / backend /cache /semantic_cache.py
shankarsain's picture
Clean up repository and update codebase for vidyabot-build-small
a2398a7
Raw
History Blame Contribute Delete
13.3 kB
"""
Semantic Cache Module
Query deduplication using embedding similarity.
Caches answers to semantically similar questions.
Benefits:
- Identical queries: cached (similarity = 1.0)
- Near-identical queries: cached (similarity > 0.90)
- Similar intent queries: cached (similarity > threshold)
- Saves ~0% tokens on first query, ~0% on cache miss, 100% on cache hit
"""
import hashlib
import logging
import faiss
import numpy as np
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
from backend.config import settings
from backend.database import get_db_connection
from backend.ingestion.embedder import Embedder
logger = logging.getLogger(__name__)
@dataclass
class CachedAnswer:
"""Represents a cached query/answer pair."""
query_id: int
query_text: str
query_embedding: np.ndarray
answer: str
context_tokens_used: int
model_used: str
pruning_ratio: float
source_pages: str
created_at: str
accessed_count: int = 0
class SemanticCache:
"""Manages query cache with embedding-based semantic similarity."""
def __init__(self):
"""Initialize semantic cache."""
self.embedder = Embedder()
self.faiss_index: Optional[faiss.IndexFlatIP] = None
self.cached_queries: Dict[int, CachedAnswer] = {} # id -> CachedAnswer
self.query_order: list = [] # Tracks insertion order for LRU
def load_cache(self) -> None:
"""
Load cached queries from database into memory FAISS index.
Run on application startup.
"""
logger.info("Loading semantic cache from database...")
try:
conn = get_db_connection()
cursor = conn.cursor()
# Get all cached queries
cursor.execute("""
SELECT id, query_text, answer, context_tokens_used, model_used,
pruning_ratio, source_pages, created_at, accessed_count
FROM query_cache
ORDER BY created_at DESC
""")
rows = cursor.fetchall()
if not rows:
logger.info("Cache is empty")
self.faiss_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
return
# Embed all cached queries
query_texts = [row[1] for row in rows]
embeddings = self.embedder.embed_chunks(query_texts, show_progress=True)
embeddings = self.embedder.normalize_embeddings(embeddings)
# Create FAISS index
self.faiss_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
self.faiss_index.add(embeddings)
# Load into memory cache
for i, row in enumerate(rows):
cache_item = CachedAnswer(
query_id=row[0],
query_text=row[1],
query_embedding=embeddings[i], # Save embedding for later
answer=row[2],
context_tokens_used=row[3],
model_used=row[4],
pruning_ratio=row[5],
source_pages=row[6],
created_at=row[7],
accessed_count=row[8]
)
self.cached_queries[row[0]] = cache_item
self.query_order.append(row[0])
logger.info(f"✅ Loaded {len(self.cached_queries)} cached queries")
except Exception as e:
logger.error(f"Error loading cache: {e}")
self.faiss_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
def check_cache(self, query: str, textbook_id: int) -> Optional[Dict]:
"""
Check if query exists in cache (with semantic similarity).
Args:
query: Student question
textbook_id: Textbook being searched
Returns:
Cached answer dict if found, None otherwise
"""
if self.faiss_index is None or len(self.cached_queries) == 0:
return None
try:
# Embed query
query_embedding = self.embedder.embed_query(query)
query_embedding = self.embedder.normalize_embeddings(
np.array([query_embedding], dtype=np.float32)
)[0]
# Search FAISS index
query_reshaped = query_embedding.reshape(1, -1)
distances, indices = self.faiss_index.search(query_reshaped, 1)
if len(indices) == 0 or len(indices[0]) == 0:
return None
# For IndexFlatIP with normalized vectors, distance = similarity score
similarity = float(distances[0][0])
best_idx = int(indices[0][0])
logger.debug(f"Best cache match similarity: {similarity:.3f} (threshold: {settings.CACHE_SIMILARITY_THRESHOLD})")
# Check if above threshold
if similarity < settings.CACHE_SIMILARITY_THRESHOLD:
return None
# Get cached answer
query_ids = list(self.cached_queries.keys())
if best_idx >= len(query_ids):
return None
query_id = query_ids[best_idx]
cached = self.cached_queries[query_id]
# Update access statistics
cached.accessed_count += 1
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("""
UPDATE query_cache
SET accessed_count = ?, last_accessed = CURRENT_TIMESTAMP
WHERE id = ?
""", (cached.accessed_count, query_id))
conn.commit()
except Exception as e:
logger.warning(f"Error updating cache access: {e}")
logger.info(f"✅ Cache hit! Similarity: {similarity:.3f}, Saved {cached.context_tokens_used} tokens")
return {
"query_id": query_id,
"answer": cached.answer,
"context_tokens_used": cached.context_tokens_used,
"model_used": cached.model_used,
"pruning_ratio": cached.pruning_ratio,
"source_pages": cached.source_pages,
"cache_hit": True,
"similarity": similarity,
"accessed_count": cached.accessed_count
}
except Exception as e:
logger.error(f"Error checking cache: {e}")
return None
def store_in_cache(self, query: str, answer: str, context_tokens_used: int,
textbook_id: int, model_used: str, pruning_ratio: float,
source_pages: str) -> int:
"""
Store query/answer pair in cache.
Args:
query: Original query
answer: LLM answer
context_tokens_used: Tokens used for this query
textbook_id: Textbook ID
model_used: Model name
pruning_ratio: Token reduction ratio
source_pages: Comma-separated page numbers
Returns:
Query cache ID
"""
try:
# Embed query
query_embedding = self.embedder.embed_query(query)
query_embedding = self.embedder.normalize_embeddings(
np.array([query_embedding], dtype=np.float32)
)[0]
# Store in database
conn = get_db_connection()
cursor = conn.cursor()
# Generate query hash for deduplication
query_hash = hashlib.sha256(query.lower().encode()).hexdigest()
cursor.execute("""
INSERT INTO query_cache
(query_hash, query_text, textbook_id, answer, context_tokens_used,
model_used, pruning_ratio, source_pages)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(query_hash) DO UPDATE SET
accessed_count = accessed_count + 1,
last_accessed = CURRENT_TIMESTAMP
""", (query_hash, query, textbook_id, answer, context_tokens_used,
model_used, pruning_ratio, source_pages))
conn.commit()
# Get inserted ID
cursor.execute("SELECT id FROM query_cache WHERE query_hash = ?", (query_hash,))
cache_id = cursor.fetchone()[0]
# Add to FAISS index if not already there
if self.faiss_index is None or self.faiss_index.ntotal == 0:
# Create new index
self.faiss_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
self.faiss_index.add(query_embedding.reshape(1, -1))
# Add to memory cache
cached = CachedAnswer(
query_id=cache_id,
query_text=query,
query_embedding=query_embedding,
answer=answer,
context_tokens_used=context_tokens_used,
model_used=model_used,
pruning_ratio=pruning_ratio,
source_pages=source_pages,
created_at=""
)
self.cached_queries[cache_id] = cached
self.query_order.append(cache_id)
logger.info(f"✅ Cached query {cache_id}. Cache size: {len(self.cached_queries)}")
return cache_id
except Exception as e:
logger.error(f"Error storing in cache: {e}")
return -1
def get_cache_stats(self) -> Dict:
"""
Get cache statistics.
Returns:
Dict with cache metrics
"""
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("""
SELECT
COUNT(*) as total_queries,
SUM(CASE WHEN accessed_count > 0 THEN 1 ELSE 0 END) as cache_hits,
SUM(context_tokens_used) as total_tokens_saved,
AVG(pruning_ratio) as avg_pruning_ratio,
MAX(created_at) as latest_query
FROM query_cache
""")
row = cursor.fetchone()
total_queries = row[0] or 0
if total_queries == 0:
return {}
cache_hits = row[1] or 0
tokens_saved = row[2] or 0
avg_ratio = row[3] or 0.0
cache_hit_rate = cache_hits / total_queries if total_queries > 0 else 0.0
# Estimate cost savings (based on Haiku pricing)
baseline_cost = (tokens_saved / 1_000_000) * settings.HAIKU_INPUT_COST_PER_1M
return {
"total_queries": total_queries,
"cache_hits": cache_hits,
"cache_hit_rate": cache_hit_rate,
"total_tokens_saved": tokens_saved,
"avg_pruning_ratio": avg_ratio,
"cost_saved_usd": baseline_cost,
"cache_size_queries": len(self.cached_queries)
}
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return {}
def clear_cache(self, older_than_hours: Optional[int] = None) -> int:
"""
Clear cache entries, optionally by age.
Args:
older_than_hours: Delete entries older than N hours (None = clear all)
Returns:
Number of entries deleted
"""
try:
conn = get_db_connection()
cursor = conn.cursor()
if older_than_hours:
cursor.execute(f"""
DELETE FROM query_cache
WHERE created_at < datetime('now', '-{older_than_hours} hours')
""")
else:
cursor.execute("DELETE FROM query_cache")
cursor.execute("DELETE FROM sqlite_sequence WHERE name='query_cache'")
conn.commit()
deleted_count = cursor.rowcount
# Reset in-memory cache
self.cached_queries.clear()
self.query_order.clear()
self.faiss_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
logger.info(f"✅ Cleared {deleted_count} cache entries")
return deleted_count
except Exception as e:
logger.error(f"Error clearing cache: {e}")
return 0
# Global cache instance
_cache: Optional[SemanticCache] = None
def get_cache() -> SemanticCache:
"""Get or create global cache instance."""
global _cache
if _cache is None:
_cache = SemanticCache()
return _cache
from dataclasses import dataclass # Import at end to avoid circular dependency