Spaces:
Sleeping
Sleeping
| """ | |
| Semantic Cache Module | |
| Query deduplication using embedding similarity. | |
| Caches answers to semantically similar questions. | |
| Benefits: | |
| - Identical queries: cached (similarity = 1.0) | |
| - Near-identical queries: cached (similarity > 0.90) | |
| - Similar intent queries: cached (similarity > threshold) | |
| - Saves nothing on a cache miss (including the first time a query is seen); saves 100% of the context tokens on a cache hit | |
| """ | |
| import hashlib | |
| import logging | |
| import faiss | |
| import numpy as np | |
| from typing import Optional, Dict, Tuple | |
| from dataclasses import dataclass | |
| from backend.config import settings | |
| from backend.database import get_db_connection | |
| from backend.ingestion.embedder import Embedder | |
| logger = logging.getLogger(__name__) | |
# eq=False keeps identity-based equality/hashing: a generated __eq__ would
# compare the ndarray field element-wise and raise on truth-testing.
@dataclass(eq=False)
class CachedAnswer:
    """Represents a cached query/answer pair.

    Instances are built with keyword arguments, so the class must be a
    dataclass to get a generated ``__init__``.
    """

    # Primary key of the row in the ``query_cache`` table.
    query_id: int
    # Original question text as it was stored.
    query_text: str
    # Normalized embedding of the query, kept so the index can be rebuilt.
    query_embedding: np.ndarray
    # Answer text returned for the query.
    answer: str
    # Tokens of context that were used when the answer was produced.
    context_tokens_used: int
    # Name of the model that produced the answer.
    model_used: str
    # Token reduction ratio recorded for the query.
    pruning_ratio: float
    # Comma-separated page numbers the answer was drawn from.
    source_pages: str
    # Creation timestamp as stored in the database (may be an empty string).
    created_at: str
    # Number of times this entry has been accessed.
    accessed_count: int = 0
class SemanticCache:
    """Manages query cache with embedding-based semantic similarity.

    Cached rows live in the ``query_cache`` table; an in-memory FAISS
    inner-product index over normalized query embeddings is used for lookup.

    Invariant maintained by every method: row ``i`` of ``faiss_index`` holds
    the embedding of the query whose id is ``query_order[i]``, and each id in
    ``query_order`` appears exactly once and is a key of ``cached_queries``.
    """

    # Column order is relied on by _entry_from_row (0-8) and by the
    # textbook lookup (9); keep the three in sync.
    _SELECT_ROW_SQL = (
        "SELECT id, query_text, answer, context_tokens_used, model_used, "
        "pruning_ratio, source_pages, created_at, accessed_count, textbook_id "
        "FROM query_cache"
    )

    # Number of nearest neighbours inspected per lookup, so that a match from
    # the requested textbook can be found even when the single best match
    # belongs to another textbook.
    _MAX_CANDIDATES = 10

    def __init__(self):
        """Initialize an empty semantic cache (call load_cache to populate)."""
        self.embedder = Embedder()
        self.faiss_index: Optional[faiss.IndexFlatIP] = None
        self.cached_queries: Dict[int, CachedAnswer] = {}  # id -> CachedAnswer
        self.query_order: list = []  # FAISS row position -> query id
        self._textbook_ids: Dict[int, int] = {}  # query id -> textbook id

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_index_rows(vectors) -> np.ndarray:
        """Return vectors as a contiguous float32 2-D array (one row each)."""
        rows = np.ascontiguousarray(vectors, dtype=np.float32)
        if rows.ndim == 1:
            rows = rows.reshape(1, -1)
        return rows

    @staticmethod
    def _entry_from_row(row, embedding) -> "CachedAnswer":
        """Build a CachedAnswer from a _SELECT_ROW_SQL row and its embedding."""
        return CachedAnswer(
            query_id=row[0],
            query_text=row[1],
            query_embedding=embedding,
            answer=row[2],
            context_tokens_used=row[3],
            model_used=row[4],
            pruning_ratio=row[5],
            source_pages=row[6],
            created_at=row[7],
            accessed_count=row[8] or 0,
        )

    def _replace_memory(self, index, entries, order, textbooks) -> None:
        """Swap in a complete, mutually consistent in-memory state."""
        self.faiss_index = index
        self.cached_queries = entries
        self.query_order = order
        self._textbook_ids = textbooks

    def _reset_memory(self) -> None:
        """Drop every in-memory entry and start from an empty index."""
        # Create the index first so a failure leaves the old state untouched.
        empty_index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
        self._replace_memory(empty_index, {}, [], {})

    def _retain_only(self, surviving_ids) -> None:
        """Keep only entries whose id is in surviving_ids and rebuild the index.

        The index is rebuilt from the embeddings already held in memory, so no
        query has to be embedded again.
        """
        kept = [
            qid for qid in self.query_order
            if qid in surviving_ids and qid in self.cached_queries
        ]
        index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
        if kept:
            vectors = np.vstack(
                [self.cached_queries[qid].query_embedding for qid in kept]
            )
            index.add(self._as_index_rows(vectors))
        self._replace_memory(
            index,
            {qid: self.cached_queries[qid] for qid in kept},
            kept,
            {qid: self._textbook_ids.get(qid) for qid in kept},
        )

    def _find_match(self, query_embedding, textbook_id: int) -> Optional[Tuple[int, float]]:
        """Return (query_id, similarity) of the best usable match, or None.

        A match is usable when its similarity reaches
        settings.CACHE_SIMILARITY_THRESHOLD and it was cached for the same
        textbook_id.
        """
        k = min(int(self.faiss_index.ntotal), self._MAX_CANDIDATES)
        if k < 1:
            return None

        similarities, positions = self.faiss_index.search(
            self._as_index_rows(query_embedding), k
        )
        if len(positions) == 0 or len(positions[0]) == 0:
            return None

        threshold = settings.CACHE_SIMILARITY_THRESHOLD
        # For IndexFlatIP with normalized vectors, distance = similarity score
        logger.debug(f"Best cache match similarity: {float(similarities[0][0]):.3f} (threshold: {threshold})")

        for raw_similarity, raw_position in zip(similarities[0], positions[0]):
            similarity = float(raw_similarity)
            position = int(raw_position)
            # Results arrive best-first and FAISS pads missing neighbours
            # with -1, so nothing after this point can qualify.
            if position < 0 or similarity < threshold:
                break
            if position >= len(self.query_order):
                continue
            candidate_id = self.query_order[position]
            if candidate_id not in self.cached_queries:
                continue
            # Never serve an answer that was produced for another textbook.
            if self._textbook_ids.get(candidate_id) != textbook_id:
                continue
            return candidate_id, similarity
        return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_cache(self) -> None:
        """
        Load cached queries from database into memory FAISS index.
        Run on application startup.

        Any previously loaded state is replaced, so the method can be called
        again. The new state is assembled locally and swapped in only once it
        is complete; on any error the in-memory cache is left empty.
        """
        logger.info("Loading semantic cache from database...")
        try:
            conn = get_db_connection()
            cursor = conn.cursor()

            # Get all cached queries
            cursor.execute(self._SELECT_ROW_SQL + " ORDER BY created_at DESC")
            rows = cursor.fetchall()
            if not rows:
                logger.info("Cache is empty")
                self._reset_memory()
                return

            # Embed all cached queries
            query_texts = [row[1] for row in rows]
            embeddings = self.embedder.embed_chunks(query_texts, show_progress=True)
            embeddings = self._as_index_rows(
                self.embedder.normalize_embeddings(embeddings)
            )

            # Create FAISS index
            index = faiss.IndexFlatIP(settings.EMBEDDINGS_DIMENSION)
            index.add(embeddings)

            # Build the id-keyed structures in the same order as the index rows
            entries = {}
            order = []
            textbooks = {}
            for row, vector in zip(rows, embeddings):
                entry = self._entry_from_row(row, vector)
                entries[entry.query_id] = entry
                order.append(entry.query_id)
                textbooks[entry.query_id] = row[9]

            self._replace_memory(index, entries, order, textbooks)
            logger.info(f"✅ Loaded {len(self.cached_queries)} cached queries")
        except Exception as e:
            logger.error(f"Error loading cache: {e}")
            self._reset_memory()

    def check_cache(self, query: str, textbook_id: int) -> Optional[Dict]:
        """
        Check if query exists in cache (with semantic similarity).

        Args:
            query: Student question
            textbook_id: Textbook being searched; only answers cached for
                this textbook are returned

        Returns:
            Cached answer dict if found, None otherwise (also None on error)
        """
        if self.faiss_index is None or not self.cached_queries:
            return None
        try:
            # Embed query
            query_embedding = self.embedder.embed_query(query)
            query_embedding = self.embedder.normalize_embeddings(
                np.array([query_embedding], dtype=np.float32)
            )[0]

            match = self._find_match(query_embedding, textbook_id)
            if match is None:
                return None
            query_id, similarity = match
            cached = self.cached_queries[query_id]

            # Update access statistics (best effort: a failed write must not
            # turn a cache hit into a miss)
            cached.accessed_count += 1
            try:
                conn = get_db_connection()
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE query_cache
                    SET accessed_count = ?, last_accessed = CURRENT_TIMESTAMP
                    WHERE id = ?
                """, (cached.accessed_count, query_id))
                conn.commit()
            except Exception as e:
                logger.warning(f"Error updating cache access: {e}")

            logger.info(f"✅ Cache hit! Similarity: {similarity:.3f}, Saved {cached.context_tokens_used} tokens")
            return {
                "query_id": query_id,
                "answer": cached.answer,
                "context_tokens_used": cached.context_tokens_used,
                "model_used": cached.model_used,
                "pruning_ratio": cached.pruning_ratio,
                "source_pages": cached.source_pages,
                "cache_hit": True,
                "similarity": similarity,
                "accessed_count": cached.accessed_count,
            }
        except Exception as e:
            logger.error(f"Error checking cache: {e}")
            return None

    def store_in_cache(self, query: str, answer: str, context_tokens_used: int,
                       textbook_id: int, model_used: str, pruning_ratio: float,
                       source_pages: str) -> int:
        """
        Store query/answer pair in cache.

        If a row with the same query hash already exists, the database keeps
        the stored answer and only bumps its access counter; the in-memory
        entry mirrors that and no duplicate vector is added to the index.

        Args:
            query: Original query
            answer: LLM answer
            context_tokens_used: Tokens used for this query
            textbook_id: Textbook ID
            model_used: Model name
            pruning_ratio: Token reduction ratio
            source_pages: Comma-separated page numbers

        Returns:
            Query cache ID, or -1 if storing failed
        """
        try:
            # Embed query
            query_embedding = self.embedder.embed_query(query)
            query_embedding = self.embedder.normalize_embeddings(
                np.array([query_embedding], dtype=np.float32)
            )[0]

            # Store in database
            conn = get_db_connection()
            cursor = conn.cursor()

            # Generate query hash for deduplication
            query_hash = hashlib.sha256(query.lower().encode()).hexdigest()
            cursor.execute("""
                INSERT INTO query_cache
                (query_hash, query_text, textbook_id, answer, context_tokens_used,
                model_used, pruning_ratio, source_pages)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(query_hash) DO UPDATE SET
                accessed_count = accessed_count + 1,
                last_accessed = CURRENT_TIMESTAMP
            """, (query_hash, query, textbook_id, answer, context_tokens_used,
                  model_used, pruning_ratio, source_pages))
            conn.commit()

            # Read back the stored row so memory reflects what the database
            # actually holds (on conflict that is the earlier answer).
            cursor.execute(self._SELECT_ROW_SQL + " WHERE query_hash = ?", (query_hash,))
            row = cursor.fetchone()
            cache_id = row[0]

            existing = self.cached_queries.get(cache_id)
            if existing is not None:
                # Already indexed: adding another vector would break the
                # row-position -> id alignment used by lookups.
                existing.accessed_count = row[8] or 0
                logger.debug(f"Query {cache_id} already cached; access count updated")
                return cache_id

            # Build the entry before touching shared state so a failure here
            # cannot leave the index and the id list out of step.
            entry = self._entry_from_row(row, query_embedding)

            if self.faiss_index is None:
                self._reset_memory()
            self.faiss_index.add(self._as_index_rows(query_embedding))
            self.cached_queries[cache_id] = entry
            self.query_order.append(cache_id)
            self._textbook_ids[cache_id] = row[9]

            logger.info(f"✅ Cached query {cache_id}. Cache size: {len(self.cached_queries)}")
            return cache_id
        except Exception as e:
            logger.error(f"Error storing in cache: {e}")
            return -1

    def get_cache_stats(self) -> Dict:
        """
        Get cache statistics.

        Returns:
            Dict with cache metrics, or an empty dict when the cache table is
            empty or the query fails. ``cache_hits`` counts entries accessed
            at least once and ``cache_hit_rate`` is that count divided by the
            number of entries. ``total_tokens_saved`` weights each entry's
            context tokens by its access count, since every access avoids
            spending those tokens again.
        """
        try:
            conn = get_db_connection()
            cursor = conn.cursor()
            cursor.execute("""
                SELECT
                COUNT(*) as total_queries,
                SUM(CASE WHEN accessed_count > 0 THEN 1 ELSE 0 END) as cache_hits,
                SUM(context_tokens_used * accessed_count) as total_tokens_saved,
                AVG(pruning_ratio) as avg_pruning_ratio
                FROM query_cache
            """)
            row = cursor.fetchone()

            total_queries = row[0] or 0
            if total_queries == 0:
                return {}

            cache_hits = row[1] or 0
            tokens_saved = row[2] or 0
            avg_ratio = row[3] or 0.0
            cache_hit_rate = cache_hits / total_queries

            # Estimate cost savings (based on Haiku pricing)
            baseline_cost = (tokens_saved / 1_000_000) * settings.HAIKU_INPUT_COST_PER_1M

            return {
                "total_queries": total_queries,
                "cache_hits": cache_hits,
                "cache_hit_rate": cache_hit_rate,
                "total_tokens_saved": tokens_saved,
                "avg_pruning_ratio": avg_ratio,
                "cost_saved_usd": baseline_cost,
                "cache_size_queries": len(self.cached_queries),
            }
        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {}

    def clear_cache(self, older_than_hours: Optional[int] = None) -> int:
        """
        Clear cache entries, optionally by age.

        Args:
            older_than_hours: Delete entries older than N hours
                (None or 0 = clear all)

        Returns:
            Number of query_cache rows deleted (0 on error)
        """
        try:
            conn = get_db_connection()
            cursor = conn.cursor()

            if older_than_hours:
                # The age is bound as a parameter rather than formatted into
                # the SQL text.
                cursor.execute(
                    "DELETE FROM query_cache "
                    "WHERE created_at < datetime('now', ?)",
                    (f"-{int(older_than_hours)} hours",),
                )
                deleted_count = cursor.rowcount
                conn.commit()

                # Only the deleted rows may leave memory; newer entries stay
                # usable without a restart.
                cursor.execute("SELECT id FROM query_cache")
                surviving_ids = {row[0] for row in cursor.fetchall()}
                self._retain_only(surviving_ids)
            else:
                cursor.execute("DELETE FROM query_cache")
                # Read rowcount now: the next statement would overwrite it.
                deleted_count = cursor.rowcount
                cursor.execute("DELETE FROM sqlite_sequence WHERE name='query_cache'")
                conn.commit()

                # Reset in-memory cache
                self._reset_memory()

            logger.info(f"✅ Cleared {deleted_count} cache entries")
            return deleted_count
        except Exception as e:
            logger.error(f"Error clearing cache: {e}")
            return 0
# Process-wide cache singleton; stays None until get_cache() is first called.
_cache: Optional[SemanticCache] = None


def get_cache() -> SemanticCache:
    """Return the shared SemanticCache, constructing it on first use."""
    global _cache
    if _cache is not None:
        return _cache
    _cache = SemanticCache()
    return _cache
| from dataclasses import dataclass # Import at end to avoid circular dependency | |