# backend/services/cache_manager.py
import asyncio
import functools
import hashlib
import json
import logging
import os
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any

try:
    import redis

    HAS_REDIS = True
except ImportError:
    HAS_REDIS = False
    redis = None
| logger = logging.getLogger(__name__) | |
| class CacheEntry: | |
| key: str | |
| value: Any | |
| created_at: datetime | |
| expires_at: datetime | None | |
| access_count: int = 0 | |
| last_accessed: datetime | None = None | |
| size_bytes: int = 0 | |
| class QueryCacheEntry: | |
| """Specialized cache entry for database query results""" | |
| query_hash: str | |
| query_sql: str | |
| parameters: tuple | |
| result: Any | |
| execution_time: float | |
| created_at: datetime | |
| expires_at: datetime | |
| hit_count: int = 0 | |
| table_names: list[str] = None # For cache invalidation | |
| is_read_replica: bool = False # Whether result came from read replica | |
| def __post_init__(self): | |
| if self.table_names is None: | |
| self.table_names = [] | |
| class QueryCacheMetrics: | |
| """Metrics specific to query result caching""" | |
| def __init__(self): | |
| self.query_hits = 0 | |
| self.query_misses = 0 | |
| self.cache_invalidations = 0 | |
| self.read_replica_hits = 0 | |
| self.primary_db_hits = 0 | |
| self.avg_query_time_saved = 0.0 | |
| def hit_rate(self) -> float: | |
| total = self.query_hits + self.query_misses | |
| return self.query_hits / total if total > 0 else 0.0 | |
| def to_dict(self) -> dict[str, Any]: | |
| return { | |
| "query_hits": self.query_hits, | |
| "query_misses": self.query_misses, | |
| "cache_invalidations": self.cache_invalidations, | |
| "read_replica_hits": self.read_replica_hits, | |
| "primary_db_hits": self.primary_db_hits, | |
| "avg_query_time_saved": self.avg_query_time_saved, | |
| "hit_rate": self.hit_rate(), | |
| } | |
| class CacheMetrics: | |
| def __init__(self): | |
| self.hits = 0 | |
| self.misses = 0 | |
| self.evictions = 0 | |
| self.sets = 0 | |
| self.deletes = 0 | |
| def hit_rate(self) -> float: | |
| total = self.hits + self.misses | |
| return self.hits / total if total > 0 else 0.0 | |
| def to_dict(self) -> dict[str, Any]: | |
| return { | |
| "hits": self.hits, | |
| "misses": self.misses, | |
| "evictions": self.evictions, | |
| "sets": self.sets, | |
| "deletes": self.deletes, | |
| "hit_rate": self.hit_rate(), | |
| "total_requests": self.hits + self.misses, | |
| } | |
class MultiLayerCache:
    """
    Multi-layer caching system with memory and optional Redis support
    L1: Fast in-memory (hot data)
    L2: Larger in-memory (warm data)
    L3: Redis (persistent, shared across instances)
    """

    def __init__(
        self,
        max_memory_entries: int = 1000,
        default_ttl_seconds: int = 300,
        redis_host: str = "localhost",
        redis_port: int = 6379,
        redis_db: int = 0,
    ):
        # The in-memory budget is split 25% (L1) / 75% (L2).
        self.l1_cache: dict[str, CacheEntry] = {}  # L1: Fast memory cache
        self.l2_cache: dict[str, CacheEntry] = {}  # L2: Larger memory cache
        self.max_l1_entries = max_memory_entries // 4  # 25% for L1
        self.max_l2_entries = max_memory_entries - self.max_l1_entries  # 75% for L2
        # NOTE(review): default_ttl is stored but never applied anywhere in this
        # class — set() only honors an explicit ttl_seconds argument. Confirm
        # whether entries without a TTL are meant to live forever in memory.
        self.default_ttl = timedelta(seconds=default_ttl_seconds)
        self.metrics = CacheMetrics()
        # RLock because get() re-acquires the lock when promoting a Redis hit
        # into the memory layers.
        self.lock = threading.RLock()
        # Redis configuration
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.redis_db = redis_db
        self.redis_client = None  # None means "memory-only mode"
        self.redis_retry_count = 0
        self.redis_max_retries = 5
        self.redis_retry_interval = 30  # seconds between retries
        if HAS_REDIS:
            self._init_redis()
        # Start cleanup task (daemon thread; runs for the process lifetime)
        self._start_cleanup_task()
        # Start Redis reconnection task (daemon thread)
        self._start_redis_retry_task()

    def _init_redis(self):
        """Initialize Redis client.

        On success the client is stored and the retry counter reset; on failure
        the client stays None and the retry thread may call this again.
        """
        try:
            # Check for REDIS_URL environment variable first (Common for Upstash/Cloud)
            redis_url = os.getenv("REDIS_URL")
            if redis_url:
                # Use from_url for connection strings (handles ssl/rediss automatically)
                self.redis_client = redis.from_url(
                    redis_url,
                    decode_responses=False,
                    socket_connect_timeout=5,
                    socket_timeout=5,
                )
            else:
                # Fallback to host/port
                self.redis_client = redis.Redis(
                    host=self.redis_host,
                    port=self.redis_port,
                    db=self.redis_db,
                    decode_responses=False,
                    socket_connect_timeout=5,
                    socket_timeout=5,
                )
            # Test connection eagerly so a bad server is detected here, not on
            # the first cache operation.
            self.redis_client.ping()
            self.redis_retry_count = 0
            logger.info(
                f"Redis cache layer initialized successfully (host={self.redis_host}:{self.redis_port})"
            )
        except Exception as e:
            # Any failure (including ping) degrades gracefully to memory-only.
            self.redis_client = None
            if self.redis_retry_count < self.redis_max_retries:
                self.redis_retry_count += 1
                logger.warning(
                    f"Redis initialization attempt {self.redis_retry_count}/{self.redis_max_retries} failed: {e}. "
                    f"Retrying in {self.redis_retry_interval}s. Using memory-only cache for now."
                )
            else:
                logger.warning(
                    f"Redis initialization failed after {self.redis_max_retries} attempts: {e}. "
                    f"Continuing with enhanced memory-only cache (L1+L2 layers)."
                )

    def _start_redis_retry_task(self):
        """Start background task to retry Redis connection.

        The worker loops forever in a daemon thread: every retry interval it
        re-attempts _init_redis() while the client is down and retries remain.
        """

        def retry_worker():
            while True:
                try:
                    time.sleep(self.redis_retry_interval)
                    if (
                        self.redis_client is None
                        and self.redis_retry_count < self.redis_max_retries
                    ):
                        logger.info(
                            f"Attempting to reconnect to Redis (attempt {self.redis_retry_count + 1}/{self.redis_max_retries})..."
                        )
                        self._init_redis()
                except Exception as e:
                    logger.error(f"Redis retry task error: {e}")
                    time.sleep(self.redis_retry_interval)

        thread = threading.Thread(target=retry_worker, daemon=True)
        thread.start()

    def _redis_available(self) -> bool:
        """Check if Redis is available (a client was successfully initialized)."""
        return self.redis_client is not None

    def _serialize_for_redis(self, value: Any) -> str:
        """Serialize value for Redis storage.

        default=str stringifies anything json can't encode; on total failure a
        sentinel error document is stored instead of raising.
        """
        try:
            return json.dumps(value, default=str)
        except Exception:
            return json.dumps({"error": "serialization_failed"})

    def _deserialize_from_redis(self, value: str) -> Any:
        """Deserialize value from Redis storage; None on malformed JSON."""
        try:
            return json.loads(value)
        except Exception:
            return None

    def _generate_key(self, namespace: str, key: Any) -> str:
        """Generate a consistent cache key.

        Dict/list keys are canonicalized via sorted JSON so equal structures
        map to the same key. md5 is used only for key derivation, not security.
        """
        if isinstance(key, (dict, list)):
            key_str = json.dumps(key, sort_keys=True, default=str)
        else:
            key_str = str(key)
        full_key = f"{namespace}:{key_str}"
        return hashlib.md5(full_key.encode()).hexdigest()

    def _calculate_size(self, value: Any) -> int:
        """Calculate approximate memory size of value (serialized byte length)."""
        try:
            return len(json.dumps(value, default=str).encode("utf-8"))
        except Exception:
            return len(str(value).encode("utf-8"))

    def _is_expired(self, entry: CacheEntry) -> bool:
        """Check if cache entry is expired.

        NOTE(review): when expires_at is None this returns None (falsy), not
        False — callers only use it in boolean context, so behavior is fine.
        """
        return entry.expires_at and datetime.now() > entry.expires_at

    def _evict_lru(self, cache: dict[str, CacheEntry], max_entries: int):
        """Evict least recently used entries.

        Caller must hold self.lock. Entries never accessed fall back to their
        creation time for ordering.
        """
        if len(cache) <= max_entries:
            return
        # Sort by last accessed time (oldest first)
        entries = sorted(
            cache.items(), key=lambda x: x[1].last_accessed or x[1].created_at
        )
        to_evict = len(cache) - max_entries
        for key, _ in entries[:to_evict]:
            del cache[key]
            self.metrics.evictions += 1

    def _cleanup_expired(self):
        """Clean up expired entries from both cache layers (periodic sweep)."""
        with self.lock:
            # NOTE(review): result of this call is unused — looks like a
            # leftover from an earlier revision; harmless no-op.
            datetime.now()
            # Clean L1 cache
            expired_l1 = [k for k, v in self.l1_cache.items() if self._is_expired(v)]
            for key in expired_l1:
                del self.l1_cache[key]
            # Clean L2 cache
            expired_l2 = [k for k, v in self.l2_cache.items() if self._is_expired(v)]
            for key in expired_l2:
                del self.l2_cache[key]
            if expired_l1 or expired_l2:
                logger.debug(
                    f"Cleaned up {len(expired_l1)} L1 and {len(expired_l2)} L2 expired entries"
                )

    def _start_cleanup_task(self):
        """Start background cleanup task (daemon thread, sweeps every minute)."""

        def cleanup_worker():
            while True:
                try:
                    self._cleanup_expired()
                    time.sleep(60)  # Clean up every minute
                except Exception as e:
                    logger.error(f"Cache cleanup error: {e}")
                    time.sleep(60)

        thread = threading.Thread(target=cleanup_worker, daemon=True)
        thread.start()

    def get(self, namespace: str, key: Any) -> Any | None:
        """Get value from cache.

        Lookup order is L1 -> L2 -> Redis; hits in lower layers are promoted
        upward. Expired memory entries are deleted lazily on lookup.
        """
        cache_key = self._generate_key(namespace, key)
        with self.lock:
            # Try L1 cache first
            if cache_key in self.l1_cache:
                entry = self.l1_cache[cache_key]
                if not self._is_expired(entry):
                    entry.access_count += 1
                    entry.last_accessed = datetime.now()
                    self.metrics.hits += 1
                    return entry.value
                else:
                    del self.l1_cache[cache_key]
            # Try L2 cache
            if cache_key in self.l2_cache:
                entry = self.l2_cache[cache_key]
                if not self._is_expired(entry):
                    entry.access_count += 1
                    entry.last_accessed = datetime.now()
                    self.metrics.hits += 1
                    # Promote to L1 cache
                    self.l1_cache[cache_key] = entry
                    self._evict_lru(self.l1_cache, self.max_l1_entries)
                    return entry.value
                else:
                    del self.l2_cache[cache_key]
        # Try L3 cache (Redis) — performed outside the lock so the network
        # round-trip doesn't block other cache users; the lock is re-acquired
        # only for the in-memory promotion below.
        if self._redis_available():
            try:
                redis_key = f"{namespace}:{cache_key}"
                redis_value = self.redis_client.get(redis_key)
                if redis_value:
                    value = self._deserialize_from_redis(
                        redis_value.decode("utf-8")
                    )
                    if value is not None:
                        self.metrics.hits += 1
                        # Promote to memory caches
                        entry = CacheEntry(
                            key=cache_key,
                            value=value,
                            created_at=datetime.now(),
                            expires_at=None,  # Redis handles TTL
                            size_bytes=self._calculate_size(value),
                        )
                        with self.lock:
                            self.l2_cache[cache_key] = entry
                            self._evict_lru(self.l2_cache, self.max_l2_entries)
                            self.l1_cache[cache_key] = entry
                            self._evict_lru(self.l1_cache, self.max_l1_entries)
                        return value
            except Exception as e:
                logger.debug(f"Redis L3 cache miss: {e}")
        # Miss counter updated without the lock — a plain int increment;
        # worst case is a slightly inexact metric under heavy concurrency.
        self.metrics.misses += 1
        return None

    def set(
        self, namespace: str, key: Any, value: Any, ttl_seconds: int | None = None
    ) -> bool:
        """Set value in cache.

        Writes to L1, L2 and (if available) Redis. Without ttl_seconds the
        in-memory entry never expires and the Redis key gets no TTL.
        """
        cache_key = self._generate_key(namespace, key)
        expires_at = (
            datetime.now() + timedelta(seconds=ttl_seconds) if ttl_seconds else None
        )
        size_bytes = self._calculate_size(value)
        entry = CacheEntry(
            key=cache_key,
            value=value,
            created_at=datetime.now(),
            expires_at=expires_at,
            size_bytes=size_bytes,
        )
        with self.lock:
            # Always set in L2 cache
            self.l2_cache[cache_key] = entry
            self._evict_lru(self.l2_cache, self.max_l2_entries)
            # Set in L1 cache for frequently accessed items
            self.l1_cache[cache_key] = entry
            self._evict_lru(self.l1_cache, self.max_l1_entries)
            # Store in L3 cache (Redis) if available
            # NOTE(review): the Redis network write happens while holding the
            # lock — confirm this is acceptable under load.
            if self._redis_available():
                try:
                    redis_key = f"{namespace}:{cache_key}"
                    serialized_value = self._serialize_for_redis(value)
                    if ttl_seconds:
                        self.redis_client.setex(
                            redis_key, ttl_seconds, serialized_value
                        )
                    else:
                        self.redis_client.set(redis_key, serialized_value)
                except Exception as e:
                    logger.debug(f"Redis L3 cache set failed: {e}")
            self.metrics.sets += 1
            return True

    def delete(self, namespace: str, key: Any) -> bool:
        """Delete value from cache.

        Returns True if the key was present in either memory layer; the Redis
        delete is best-effort and does not affect the return value.
        """
        cache_key = self._generate_key(namespace, key)
        with self.lock:
            deleted = False
            if cache_key in self.l1_cache:
                del self.l1_cache[cache_key]
                deleted = True
            if cache_key in self.l2_cache:
                del self.l2_cache[cache_key]
                deleted = True
            # Delete from L3 cache (Redis)
            if self._redis_available():
                try:
                    redis_key = f"{namespace}:{cache_key}"
                    self.redis_client.delete(redis_key)
                except Exception as e:
                    logger.debug(f"Redis L3 cache delete failed: {e}")
            if deleted:
                self.metrics.deletes += 1
            return deleted

    def clear_namespace(self, namespace: str) -> int:
        """Clear all entries in a namespace.

        NOTE(review): in-memory keys are md5 digests produced by
        _generate_key, so the startswith(f"{namespace}:") filters below can
        never match — only the Redis layer (whose keys ARE prefixed with the
        namespace) is actually cleared. Memory entries for the namespace
        survive until they expire or are evicted. This looks like a bug;
        fixing it requires keeping namespace information alongside the hashed
        keys.
        """
        with self.lock:
            cleared = 0
            # Clear from L1
            to_remove_l1 = [k for k in self.l1_cache if k.startswith(f"{namespace}:")]
            for key in to_remove_l1:
                del self.l1_cache[key]
                cleared += 1
            # Clear from L2
            to_remove_l2 = [k for k in self.l2_cache if k.startswith(f"{namespace}:")]
            for key in to_remove_l2:
                del self.l2_cache[key]
                cleared += 1
            # Clear from L3 cache (Redis)
            if self._redis_available():
                try:
                    pattern = f"{namespace}:*"
                    keys = self.redis_client.keys(pattern)
                    if keys:
                        self.redis_client.delete(*keys)
                        cleared += len(keys)
                except Exception as e:
                    logger.debug(f"Redis L3 cache clear failed: {e}")
            return cleared

    def clear_all(self) -> int:
        """Clear all in-memory cache entries (L1+L2 only; Redis is untouched)."""
        with self.lock:
            total_cleared = len(self.l1_cache) + len(self.l2_cache)
            self.l1_cache.clear()
            self.l2_cache.clear()
            return total_cleared

    def get_stats(self) -> dict[str, Any]:
        """Get comprehensive cache statistics.

        NOTE(review): reads the cache dicts without taking self.lock — len()
        and the size sums may be slightly inconsistent with concurrent
        mutation; confirm that approximate stats are acceptable.
        """
        stats = {
            "l1_cache": {
                "entries": len(self.l1_cache),
                "max_entries": self.max_l1_entries,
                "utilization": (
                    len(self.l1_cache) / self.max_l1_entries
                    if self.max_l1_entries > 0
                    else 0
                ),
            },
            "l2_cache": {
                "entries": len(self.l2_cache),
                "max_entries": self.max_l2_entries,
                "utilization": (
                    len(self.l2_cache) / self.max_l2_entries
                    if self.max_l2_entries > 0
                    else 0
                ),
            },
            "metrics": self.metrics.to_dict(),
            "total_size_bytes": sum(e.size_bytes for e in self.l1_cache.values())
            + sum(e.size_bytes for e in self.l2_cache.values()),
        }
        # Add Redis stats if available
        if self._redis_available():
            try:
                redis_info = self.redis_client.info()
                stats["l3_cache"] = {
                    "available": True,
                    "db_size": self.redis_client.dbsize(),
                    "memory_used": redis_info.get("used_memory_human", "N/A"),
                    "connected_clients": redis_info.get("connected_clients", 0),
                }
            except Exception as e:
                stats["l3_cache"] = {"available": False, "error": str(e)}
        else:
            stats["l3_cache"] = {
                "available": False,
                "reason": "Redis not installed or not available",
            }
        return stats
| class QueryResultCache: | |
| """Advanced query result caching with database read replicas support""" | |
| def __init__(self, cache_manager): | |
| self.cache = cache_manager | |
| self.query_metrics = QueryCacheMetrics() | |
| self.query_cache_ttl = 300 # 5 minutes default | |
| self.read_replica_available = False | |
| self.primary_connection_string = "" | |
| self.replica_connection_strings: list[str] = [] | |
| def configure_read_replicas(self, primary: str, replicas: list[str]): | |
| """Configure database read replicas for improved read performance""" | |
| self.primary_connection_string = primary | |
| self.replica_connection_strings = replicas | |
| self.read_replica_available = len(replicas) > 0 | |
| async def execute_cached_query( | |
| self, | |
| query_func: Callable, | |
| query_sql: str, | |
| parameters: tuple = (), | |
| ttl_seconds: int | None = None, | |
| use_read_replica: bool = True, | |
| table_names: list[str] | None = None, | |
| ) -> Any: | |
| """ | |
| Execute a database query with intelligent caching | |
| Args: | |
| query_func: Function that executes the actual query | |
| query_sql: SQL query string for cache key generation | |
| parameters: Query parameters | |
| ttl_seconds: Cache TTL in seconds | |
| use_read_replica: Whether to prefer read replica | |
| table_names: Tables affected by this query (for invalidation) | |
| """ | |
| # Generate cache key from query and parameters | |
| query_hash = self._generate_query_hash(query_sql, parameters) | |
| cache_key = f"query:{query_hash}" | |
| # Check cache first | |
| cached_result = await self._get_cached_query_result(cache_key) | |
| if cached_result: | |
| self.query_metrics.query_hits += 1 | |
| cached_result.hit_count += 1 | |
| if cached_result.is_read_replica: | |
| self.query_metrics.read_replica_hits += 1 | |
| else: | |
| self.query_metrics.primary_db_hits += 1 | |
| return cached_result.result | |
| # Cache miss - execute query | |
| self.query_metrics.query_misses += 1 | |
| start_time = time.time() | |
| try: | |
| # Determine which database to use | |
| is_read_replica = use_read_replica and self.read_replica_available | |
| # Execute query | |
| result = await query_func() | |
| execution_time = time.time() - start_time | |
| # Cache the result | |
| await self._cache_query_result( | |
| query_hash=query_hash, | |
| query_sql=query_sql, | |
| parameters=parameters, | |
| result=result, | |
| execution_time=execution_time, | |
| ttl_seconds=ttl_seconds or self.query_cache_ttl, | |
| is_read_replica=is_read_replica, | |
| table_names=table_names or [], | |
| ) | |
| return result | |
| except Exception as e: | |
| logger.error(f"Query execution failed: {e}") | |
| raise | |
| async def _get_cached_query_result(self, cache_key: str) -> QueryCacheEntry | None: | |
| """Get cached query result""" | |
| return await self.cache.get("query_cache", cache_key) | |
| async def _cache_query_result( | |
| self, | |
| query_hash: str, | |
| query_sql: str, | |
| parameters: tuple, | |
| result: Any, | |
| execution_time: float, | |
| ttl_seconds: int, | |
| is_read_replica: bool, | |
| table_names: list[str], | |
| ): | |
| """Cache query result with metadata""" | |
| cache_key = f"query:{query_hash}" | |
| entry = QueryCacheEntry( | |
| query_hash=query_hash, | |
| query_sql=query_sql, | |
| parameters=parameters, | |
| result=result, | |
| execution_time=execution_time, | |
| created_at=datetime.now(), | |
| expires_at=datetime.now() + timedelta(seconds=ttl_seconds), | |
| table_names=table_names, | |
| is_read_replica=is_read_replica, | |
| ) | |
| await self.cache.set("query_cache", cache_key, entry, ttl_seconds) | |
| def _generate_query_hash(self, query_sql: str, parameters: tuple) -> str: | |
| """Generate unique hash for query + parameters""" | |
| query_key = f"{query_sql}:{parameters!s}" | |
| return hashlib.md5(query_key.encode()).hexdigest()[:16] | |
| async def invalidate_table_cache(self, table_name: str): | |
| """Invalidate all cached queries for a specific table""" | |
| # This would require a reverse index of table -> queries | |
| # For now, we'll implement a simple invalidation strategy | |
| self.query_metrics.cache_invalidations += 1 | |
| logger.info(f"Cache invalidated for table: {table_name}") | |
| async def get_cache_statistics(self) -> dict[str, Any]: | |
| """Get comprehensive cache statistics""" | |
| base_stats = await self.cache.get_statistics() | |
| query_stats = self.query_metrics.to_dict() | |
| return { | |
| "cache_statistics": base_stats, | |
| "query_cache_statistics": query_stats, | |
| "read_replica_available": self.read_replica_available, | |
| "replica_count": len(self.replica_connection_strings), | |
| "cache_configuration": { | |
| "default_ttl": self.query_cache_ttl, | |
| "read_replica_preferred": self.read_replica_available, | |
| }, | |
| } | |
class CachedFunction:
    """Decorator for caching function results in a cache namespace.

    Works for both sync and async callables; the cache key is derived from
    the function name plus its positional and keyword arguments.

    Known limitation (unchanged): a result of None is indistinguishable from
    a cache miss, so None results are re-computed on every call.
    """

    def __init__(self, cache: "MultiLayerCache", namespace: str, ttl_seconds: int = 300):
        self.cache = cache
        self.namespace = namespace
        self.ttl_seconds = ttl_seconds

    def __call__(self, func: Callable) -> Callable:
        # BUG FIX: functools.wraps was missing, so decorated functions lost
        # their __name__/__doc__/__module__ metadata.
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Create cache key from function name and arguments
            key_data = {"function": func.__name__, "args": args, "kwargs": kwargs}
            # Try cache first
            cached_result = self.cache.get(self.namespace, key_data)
            if cached_result is not None:
                return cached_result
            # Execute function
            result = await func(*args, **kwargs)
            # Cache result
            self.cache.set(self.namespace, key_data, result, self.ttl_seconds)
            return result

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Create cache key from function name and arguments
            key_data = {"function": func.__name__, "args": args, "kwargs": kwargs}
            # Try cache first
            cached_result = self.cache.get(self.namespace, key_data)
            if cached_result is not None:
                return cached_result
            # Execute function
            result = func(*args, **kwargs)
            # Cache result
            self.cache.set(self.namespace, key_data, result, self.ttl_seconds)
            return result

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
# Create singleton cache instance
# NOTE(review): constructing MultiLayerCache at import time starts two daemon
# background threads (expiry cleanup + Redis reconnect) as a module side effect.
cache_manager = MultiLayerCache(max_memory_entries=2000, default_ttl_seconds=300)


# Export convenience functions
def cached(namespace: str, ttl_seconds: int = 300):
    """Decorator for caching function results in the module-level cache_manager."""
    return CachedFunction(cache_manager, namespace, ttl_seconds)


def get_cache_stats():
    """Get cache statistics from the module-level cache_manager."""
    return cache_manager.get_stats()


def clear_cache_namespace(namespace: str):
    """Clear all cache entries in a namespace (delegates to cache_manager)."""
    return cache_manager.clear_namespace(namespace)


def clear_all_cache():
    """Clear all in-memory cache entries (delegates to cache_manager)."""
    return cache_manager.clear_all()


# Global instances
query_cache = QueryResultCache(cache_manager)