Spaces:
Runtime error
Runtime error
| """Caching system for embeddings, queries, and responses.""" | |
| import hashlib | |
| import time | |
| from typing import Optional, Dict, Any, List | |
| from datetime import datetime, timedelta | |
| import json | |
class CacheEntry:
    """Represents a cached item with expiration.

    Expiry is measured against ``time.monotonic()`` so that wall-clock
    adjustments (NTP corrections, manual clock changes) cannot make an entry
    expire early or live forever. ``created_at`` is still recorded as a
    wall-clock epoch timestamp so callers that sort entries by age keep working.
    """

    def __init__(self, value: Any, ttl: int = 3600):
        """
        Initialize cache entry.

        Args:
            value: Value to cache
            ttl: Time to live in seconds (default: 1 hour)
        """
        self.value = value
        # Wall-clock creation time (epoch seconds); used for age ordering.
        self.created_at = time.time()
        # Monotonic reference point used only for the TTL check.
        self._created_monotonic = time.monotonic()
        self.ttl = ttl
        self.hit_count = 0

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have elapsed since creation."""
        return time.monotonic() - self._created_monotonic > self.ttl

    def increment_hits(self):
        """Increment hit counter."""
        self.hit_count += 1
class EmbeddingCache:
    """Cache for query embeddings to avoid re-computing.

    Keys are MD5 digests of the lower-cased, whitespace-stripped query text,
    so queries that differ only in case or surrounding whitespace share an
    entry. When the cache is full, the oldest 10% of ``max_size`` entries
    (at least one) are evicted.
    """

    def __init__(self, max_size: int = 1000, ttl: int = 86400):
        """
        Initialize embedding cache.

        Args:
            max_size: Maximum number of entries (default: 1000)
            ttl: Time to live in seconds (default: 24 hours)
        """
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.ttl = ttl
        self.hits = 0
        self.misses = 0

    def _generate_key(self, text: str) -> str:
        """Generate cache key from text (case- and outer-whitespace-insensitive)."""
        return hashlib.md5(text.lower().strip().encode()).hexdigest()

    def get(self, text: str) -> Optional[List[float]]:
        """
        Get cached embedding.

        An expired entry is removed on access and counted as a miss.

        Args:
            text: Query text

        Returns:
            The stored embedding object (not a copy), or None on a miss.
        """
        key = self._generate_key(text)
        if key in self.cache:
            entry = self.cache[key]
            if not entry.is_expired():
                entry.increment_hits()
                self.hits += 1
                return entry.value
            # Remove expired entry
            del self.cache[key]
        self.misses += 1
        return None

    def set(self, text: str, embedding: List[float]):
        """
        Cache an embedding.

        Args:
            text: Query text
            embedding: Embedding vector
        """
        key = self._generate_key(text)
        # Replacing an existing key does not grow the cache, so only evict
        # when a genuinely new key would push it past max_size.
        if key not in self.cache and len(self.cache) >= self.max_size:
            self._evict_oldest()
        self.cache[key] = CacheEntry(embedding, ttl=self.ttl)

    def _evict_oldest(self):
        """Remove the oldest 10% of ``max_size`` entries (at least one)."""
        num_to_remove = max(1, self.max_size // 10)
        # Sort by creation time and remove oldest
        sorted_keys = sorted(self.cache, key=lambda k: self.cache[k].created_at)
        for key in sorted_keys[:num_to_remove]:
            del self.cache[key]

    def clear(self):
        """Clear all cached embeddings and reset hit/miss counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (hit_rate is a percentage rounded to 2 places)."""
        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": round(hit_rate, 2),
            "total_requests": total_requests
        }
class QueryResponseCache:
    """Cache for complete query responses.

    Entries are keyed by the normalized (query, ticker, doc_types, top_k)
    combination. When full, expired entries are dropped first, then the
    least frequently hit live entries (oldest first among equal hit counts).
    """

    def __init__(self, max_size: int = 500, ttl: int = 3600):
        """
        Initialize response cache.

        Args:
            max_size: Maximum number of entries (default: 500)
            ttl: Time to live in seconds (default: 1 hour)
        """
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.ttl = ttl
        self.hits = 0
        self.misses = 0

    def _generate_key(
        self,
        query: str,
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None,
        top_k: int = 10
    ) -> str:
        """Generate cache key from query parameters.

        The query is lower-cased and stripped, the ticker lower-cased, and
        doc_types sorted, so equivalent requests map to the same key.
        """
        query_normalized = query.lower().strip()
        ticker_normalized = ticker.lower() if ticker else ""
        doc_types_normalized = sorted(doc_types) if doc_types else []
        # Serialize as a JSON array rather than joining with "|" / ",": JSON
        # quotes and escapes every field, so separator characters inside a
        # value (query "a|b" vs. query "a" + ticker "b", or doc type "x,y"
        # vs. ["x", "y"]) can no longer produce colliding keys.
        key_string = json.dumps(
            [query_normalized, ticker_normalized, doc_types_normalized, str(top_k)]
        )
        return hashlib.md5(key_string.encode()).hexdigest()

    def get(
        self,
        query: str,
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None,
        top_k: int = 10
    ) -> Optional[Dict[str, Any]]:
        """
        Get cached response.

        An expired entry is removed on access and counted as a miss.

        Args:
            query: Query text
            ticker: Ticker filter
            doc_types: Document type filters
            top_k: Number of results

        Returns:
            Cached response or None
        """
        key = self._generate_key(query, ticker, doc_types, top_k)
        if key in self.cache:
            entry = self.cache[key]
            if not entry.is_expired():
                entry.increment_hits()
                self.hits += 1
                return entry.value
            del self.cache[key]
        self.misses += 1
        return None

    def set(
        self,
        query: str,
        response: Dict[str, Any],
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None,
        top_k: int = 10
    ):
        """
        Cache a response.

        Args:
            query: Query text
            response: Response to cache
            ticker: Ticker filter
            doc_types: Document type filters
            top_k: Number of results
        """
        key = self._generate_key(query, ticker, doc_types, top_k)
        # Overwriting an existing key does not grow the cache; only evict
        # when a new key would exceed max_size.
        if key not in self.cache and len(self.cache) >= self.max_size:
            self._evict_lru()
        self.cache[key] = CacheEntry(response, ttl=self.ttl)

    def _evict_lru(self):
        """Free space in a full cache.

        Expired entries are purged first because they can never be served
        again. If the cache is still at or above ``max_size``, the 10% of
        ``max_size`` (at least one) live entries with the fewest hits are
        removed, oldest first among ties. Despite the method name, this is a
        least-*frequently*-used policy: entries carry a hit count but no
        last-access time.
        """
        expired_keys = [key for key, entry in self.cache.items() if entry.is_expired()]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) < self.max_size:
            return

        num_to_remove = max(1, self.max_size // 10)
        sorted_keys = sorted(
            self.cache,
            key=lambda k: (self.cache[k].hit_count, self.cache[k].created_at),
        )
        for key in sorted_keys[:num_to_remove]:
            del self.cache[key]

    def clear(self):
        """Clear all cached responses and reset hit/miss counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics, including a rough USD savings estimate."""
        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0
        # Calculate cost savings (assuming $0.0001 per query)
        cost_per_query = 0.0001  # Approximate cost per LLM call
        estimated_savings = self.hits * cost_per_query
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": round(hit_rate, 2),
            "total_requests": total_requests,
            "estimated_savings_usd": round(estimated_savings, 4)
        }
class DocumentCache:
    """Cache for retrieved documents to avoid vector searches.

    Entries are keyed by the normalized (query, ticker, doc_types)
    combination. When full, the oldest 10% of ``max_size`` entries
    (at least one) are evicted.
    """

    def __init__(self, max_size: int = 200, ttl: int = 7200):
        """
        Initialize document cache.

        Args:
            max_size: Maximum number of entries (default: 200)
            ttl: Time to live in seconds (default: 2 hours)
        """
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.ttl = ttl
        self.hits = 0
        self.misses = 0

    def _generate_key(
        self,
        query: str,
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None
    ) -> str:
        """Generate cache key from search parameters.

        The query is lower-cased and stripped, the ticker lower-cased, and
        doc_types sorted, so equivalent searches map to the same key.
        """
        query_normalized = query.lower().strip()
        ticker_normalized = ticker.lower() if ticker else ""
        doc_types_normalized = sorted(doc_types) if doc_types else []
        # JSON-encode the fields instead of joining with "|" / ",", so that
        # separator characters inside a value cannot make distinct searches
        # share a key (e.g. doc type "x,y" vs. ["x", "y"]).
        key_string = json.dumps([query_normalized, ticker_normalized, doc_types_normalized])
        return hashlib.md5(key_string.encode()).hexdigest()

    def get(
        self,
        query: str,
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None
    ) -> Optional[List[Any]]:
        """Get cached documents, or None on a miss (expired entries are dropped)."""
        key = self._generate_key(query, ticker, doc_types)
        if key in self.cache:
            entry = self.cache[key]
            if not entry.is_expired():
                entry.increment_hits()
                self.hits += 1
                return entry.value
            del self.cache[key]
        self.misses += 1
        return None

    def set(
        self,
        query: str,
        documents: List[Any],
        ticker: Optional[str] = None,
        doc_types: Optional[List[str]] = None
    ):
        """Cache retrieved documents."""
        key = self._generate_key(query, ticker, doc_types)
        # Overwriting an existing key does not grow the cache; only evict
        # when a new key would exceed max_size.
        if key not in self.cache and len(self.cache) >= self.max_size:
            self._evict_oldest()
        self.cache[key] = CacheEntry(documents, ttl=self.ttl)

    def _evict_oldest(self):
        """Remove the oldest 10% of ``max_size`` entries (at least one)."""
        num_to_remove = max(1, self.max_size // 10)
        sorted_keys = sorted(self.cache, key=lambda k: self.cache[k].created_at)
        for key in sorted_keys[:num_to_remove]:
            del self.cache[key]

    def clear(self):
        """Clear all cached documents and reset hit/miss counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (hit_rate is a percentage rounded to 2 places)."""
        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": round(hit_rate, 2),
            "total_requests": total_requests
        }
class CacheManager:
    """Centralized cache management.

    Owns one embedding cache, one response cache and one document cache. It
    also exposes key-based accessors on the response cache for callers that
    compute their own cache keys.
    """

    def __init__(self):
        """Initialize all caches."""
        self.embedding_cache = EmbeddingCache(max_size=1000, ttl=86400)  # 24h
        self.response_cache = QueryResponseCache(max_size=500, ttl=3600)  # 1h
        self.document_cache = DocumentCache(max_size=200, ttl=7200)  # 2h

    def get_response(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """
        Get cached response by pre-computed key.

        Updates the response cache's hit/miss counters. An expired entry is
        removed and counted as a miss.

        Args:
            cache_key: Pre-computed cache key from rag_chain

        Returns:
            Cached response data or None
        """
        cache = self.response_cache.cache
        if cache_key in cache:
            entry = cache[cache_key]
            if not entry.is_expired():
                entry.increment_hits()
                self.response_cache.hits += 1
                return entry.value
            # Remove expired entry
            del cache[cache_key]
        self.response_cache.misses += 1
        return None

    def set_response(self, cache_key: str, response_data: Dict[str, Any]):
        """
        Cache a response with pre-computed key.

        Args:
            cache_key: Pre-computed cache key from rag_chain
            response_data: Response data to cache
        """
        cache = self.response_cache.cache
        # Evict only when inserting a new key into a full cache; replacing an
        # existing entry does not grow the cache.
        if cache_key not in cache and len(cache) >= self.response_cache.max_size:
            self.response_cache._evict_lru()
        # Store with the pre-computed key
        cache[cache_key] = CacheEntry(
            response_data,
            ttl=self.response_cache.ttl
        )

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics (alias for get_all_stats).

        Returns:
            Dictionary with stats for all caches
        """
        return self.get_all_stats()

    def clear(self):
        """Clear all caches (alias for clear_all)."""
        self.clear_all()

    def clear_all(self):
        """Clear all caches and reset their hit/miss counters."""
        self.embedding_cache.clear()
        self.response_cache.clear()
        self.document_cache.clear()

    def get_all_stats(self) -> Dict[str, Any]:
        """Get statistics for all caches plus a local-time ISO-8601 timestamp."""
        return {
            "embedding_cache": self.embedding_cache.get_stats(),
            "response_cache": self.response_cache.get_stats(),
            "document_cache": self.document_cache.get_stats(),
            "timestamp": datetime.now().isoformat()
        }
# Module-level CacheManager created at import time and shared by every importer.
cache_manager: CacheManager = CacheManager()