# zenith-backend / core/distributed_cache.py
# Uploaded with huggingface_hub (revision 235db39, verified) by teoat.
"""
Distributed Cache System with Redis Integration
Replaces the custom query cache with a proper distributed cache
that supports Redis for production and in-memory fallback for development.
"""
import hashlib
import json
import pickle
import time
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Callable, Dict, Optional
from core.logging import logger
class CacheBackend(ABC):
    """Contract implemented by every cache backend.

    Concrete implementations (in-memory for development, Redis for
    production) provide async get/set/delete/exists operations plus
    bulk clearing and a statistics snapshot.
    """

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """Return the value stored under *key*, or None on a miss."""
        ...

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store *value* under *key*; *ttl* is an optional lifetime in seconds."""
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Remove *key*; return True if something was deleted."""
        ...

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Return True when *key* is currently present."""
        ...

    @abstractmethod
    async def clear(self, pattern: Optional[str] = None) -> int:
        """Delete keys, optionally only those matching *pattern*; return the count."""
        ...

    @abstractmethod
    async def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of backend statistics."""
        ...
class InMemoryCache(CacheBackend):
    """Process-local cache backend for development.

    Entries live in a plain dict, so nothing is shared across
    processes. Expired entries are reaped lazily on access. When the
    store is full, the entry with the oldest ``created_at`` is evicted
    (insertion-age FIFO, not LRU).
    """

    def __init__(self, max_size: int = 1000):
        """
        Args:
            max_size: Maximum number of entries held before eviction starts.
        """
        self.store: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self.stats = {
            "hits": 0,
            "misses": 0,
            "sets": 0,
            "deletes": 0,
            "evictions": 0,
        }

    @staticmethod
    def _is_expired(entry: Dict[str, Any]) -> bool:
        """True when the entry carries a TTL that has already elapsed."""
        expires_at = entry.get("expires_at")
        return expires_at is not None and time.time() > expires_at

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None on a miss or expiry."""
        entry = self.store.get(key)
        if entry is None:
            self.stats["misses"] += 1
            return None
        if self._is_expired(entry):
            # Lazy expiry: an expired entry counts as a miss and is dropped.
            del self.store[key]
            self.stats["misses"] += 1
            return None
        # Fix: access_count was stored on every entry but never updated.
        entry["access_count"] += 1
        self.stats["hits"] += 1
        return entry["value"]

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store *value* under *key*.

        Args:
            ttl: Lifetime in seconds. Falsy values (None or 0) mean the
                entry never expires (historical behavior, kept as-is).
        """
        try:
            # Make room first, but never evict just to overwrite a key
            # that is already present.
            if len(self.store) >= self.max_size and key not in self.store:
                self._evict_oldest()
            expires_at = time.time() + ttl if ttl else None
            self.store[key] = {
                "value": value,
                "created_at": time.time(),
                "expires_at": expires_at,
                "access_count": 0,
            }
            self.stats["sets"] += 1
            return True
        except Exception as e:
            logger.error(f"In-memory cache set failed: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Remove *key*; True if it was present."""
        if key in self.store:
            del self.store[key]
            self.stats["deletes"] += 1
            return True
        return False

    async def exists(self, key: str) -> bool:
        """True when *key* is present and not expired (expired keys are reaped)."""
        entry = self.store.get(key)
        if entry is None:
            return False
        if self._is_expired(entry):
            del self.store[key]
            return False
        return True

    async def clear(self, pattern: Optional[str] = None) -> int:
        """Drop all keys, or only those matching the glob *pattern*.

        Returns:
            Number of keys removed.
        """
        if pattern:
            import fnmatch

            keys_to_delete = [
                k for k in self.store if fnmatch.fnmatch(k, pattern)
            ]
        else:
            keys_to_delete = list(self.store)
        for key in keys_to_delete:
            del self.store[key]
        return len(keys_to_delete)

    async def get_stats(self) -> Dict[str, Any]:
        """Counters plus derived hit ratio and a rough memory estimate."""
        total_requests = self.stats["hits"] + self.stats["misses"]
        hit_ratio = self.stats["hits"] / total_requests if total_requests > 0 else 0
        return {
            **self.stats,
            "total_keys": len(self.store),
            "hit_ratio": hit_ratio,
            "memory_usage_mb": self._estimate_memory_usage(),
        }

    def _evict_oldest(self):
        """Evict the entry with the smallest created_at timestamp."""
        if not self.store:
            return
        oldest_key = min(self.store, key=lambda k: self.store[k]["created_at"])
        del self.store[oldest_key]
        self.stats["evictions"] += 1

    def _estimate_memory_usage(self) -> float:
        """Rough pickle-based size estimate of the store, in MB."""
        try:
            total_size = sum(len(pickle.dumps(entry)) for entry in self.store.values())
            return total_size / (1024 * 1024)  # Convert to MB
        except Exception:
            return 0.0
class RedisCache(CacheBackend):
    """Redis-backed cache for production deployments.

    Values are serialized with pickle when possible, falling back to
    JSON; keys are namespaced with *key_prefix*. Every operation
    swallows backend errors (logging them and bumping the ``errors``
    counter) so a Redis outage degrades to cache misses, not crashes.
    """

    def __init__(self, redis_client, key_prefix: str = "cache:"):
        self.redis = redis_client
        self.key_prefix = key_prefix
        # Client-side counters; server-side metrics come from get_stats().
        self.stats = {
            "hits": 0,
            "misses": 0,
            "sets": 0,
            "deletes": 0,
            "errors": 0,
        }

    def _make_key(self, key: str) -> str:
        """Prepend the namespace prefix to a logical key."""
        return f"{self.key_prefix}{key}"

    @staticmethod
    def _loads(raw):
        """Decode a stored payload: pickle first, JSON as fallback.

        NOTE(review): pickle.loads is only safe because this data is
        written by this class; never point the cache at untrusted data.
        """
        try:
            return pickle.loads(raw)
        except (pickle.PickleError, TypeError):
            return json.loads(raw)

    @staticmethod
    def _dumps(value):
        """Encode a value for storage: pickle first, JSON as fallback."""
        try:
            return pickle.dumps(value)
        except (pickle.PickleError, TypeError):
            return json.dumps(value, default=str)

    async def get(self, key: str) -> Optional[Any]:
        """Fetch *key*; None on a miss or on any Redis error."""
        try:
            raw = await self.redis.get(self._make_key(key))
            if raw is None:
                self.stats["misses"] += 1
                return None
            value = self._loads(raw)
            self.stats["hits"] += 1
            return value
        except Exception as e:
            logger.error(f"Redis cache get failed: {e}")
            self.stats["errors"] += 1
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store *value* under *key*, optionally expiring after *ttl* seconds."""
        try:
            full_key = self._make_key(key)
            payload = self._dumps(value)
            # SETEX applies the TTL atomically with the write.
            if ttl:
                await self.redis.setex(full_key, ttl, payload)
            else:
                await self.redis.set(full_key, payload)
            self.stats["sets"] += 1
            return True
        except Exception as e:
            logger.error(f"Redis cache set failed: {e}")
            self.stats["errors"] += 1
            return False

    async def delete(self, key: str) -> bool:
        """Remove *key*; True only if something was actually deleted."""
        try:
            removed = await self.redis.delete(self._make_key(key))
            if removed > 0:
                self.stats["deletes"] += 1
                return True
            return False
        except Exception as e:
            logger.error(f"Redis cache delete failed: {e}")
            self.stats["errors"] += 1
            return False

    async def exists(self, key: str) -> bool:
        """True when *key* is present (errors count as absent)."""
        try:
            return await self.redis.exists(self._make_key(key)) > 0
        except Exception as e:
            logger.error(f"Redis cache exists failed: {e}")
            self.stats["errors"] += 1
            return False

    async def clear(self, pattern: Optional[str] = None) -> int:
        """Delete keys matching *pattern* (all namespaced keys by default).

        NOTE(review): KEYS blocks the Redis server and scans the whole
        keyspace; consider SCAN for very large deployments.
        """
        try:
            search_pattern = self._make_key(pattern or "*")
            matched = await self.redis.keys(search_pattern)
            if matched:
                await self.redis.delete(*matched)
            return len(matched)
        except Exception as e:
            logger.error(f"Redis cache clear failed: {e}")
            self.stats["errors"] += 1
            return 0

    async def get_stats(self) -> Dict[str, Any]:
        """Local counters merged with server-side Redis INFO metrics."""
        try:
            info = await self.redis.info()
            total_requests = self.stats["hits"] + self.stats["misses"]
            hit_ratio = self.stats["hits"] / total_requests if total_requests > 0 else 0
            return {
                **self.stats,
                "hit_ratio": hit_ratio,
                "redis_memory_used_mb": info.get("used_memory", 0) / (1024 * 1024),
                "redis_connected_clients": info.get("connected_clients", 0),
                "redis_total_commands": info.get("total_commands_processed", 0),
            }
        except Exception as e:
            logger.error(f"Failed to get Redis stats: {e}")
            return {**self.stats, "error": str(e)}
class DistributedCache:
    """Unified distributed cache interface.

    Thin async facade over a pluggable ``CacheBackend``, plus a
    ``cached`` decorator that memoizes async function results.
    """

    def __init__(self, backend: "CacheBackend"):
        self.backend = backend

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache."""
        return await self.backend.get(key)

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in cache with optional TTL (seconds)."""
        return await self.backend.set(key, value, ttl)

    async def delete(self, key: str) -> bool:
        """Delete key from cache."""
        return await self.backend.delete(key)

    async def exists(self, key: str) -> bool:
        """Check if key exists in cache."""
        return await self.backend.exists(key)

    async def clear(self, pattern: Optional[str] = None) -> int:
        """Clear cache keys, optionally restricted to a glob pattern."""
        return await self.backend.clear(pattern)

    async def get_stats(self) -> Dict[str, Any]:
        """Get backend cache statistics."""
        return await self.backend.get_stats()

    def cached(self, ttl: Optional[int] = None, key_prefix: str = ""):
        """Decorator that caches an async function's results.

        Args:
            ttl: Optional lifetime for cached results, in seconds.
            key_prefix: Prepended to every generated cache key.

        Fix: a cached result of ``None`` was previously indistinguishable
        from a miss, so functions returning None were re-executed on every
        call. A miss is now confirmed with ``exists`` before re-running
        the function (the small get/exists race is acceptable for a cache).
        """

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                cache_key = self._make_cache_key(func, args, kwargs, key_prefix)
                result = await self.get(cache_key)
                if result is not None or await self.exists(cache_key):
                    return result
                # Miss: execute the function and store the result.
                result = await func(*args, **kwargs)
                await self.set(cache_key, result, ttl)
                return result

            return wrapper

        return decorator

    def _make_cache_key(
        self, func: Callable, args: tuple, kwargs: dict, prefix: str
    ) -> str:
        """Build a deterministic cache key for a function call.

        NOTE(review): args/kwargs are stringified with ``str``; objects
        whose repr embeds a memory address will not yield stable keys.
        """
        key_data = {
            "func": f"{func.__module__}.{func.__name__}",
            "args": str(args),
            "kwargs": str(sorted(kwargs.items())),
        }
        key_string = json.dumps(key_data, sort_keys=True)
        # md5 is fine here: keys need uniqueness, not cryptographic strength.
        key_hash = hashlib.md5(key_string.encode()).hexdigest()
        return f"{prefix}{func.__module__}.{func.__name__}:{key_hash}"
# Cache factory function
def create_cache() -> DistributedCache:
    """Create a cache instance with the appropriate backend.

    Uses Redis when ENVIRONMENT=production and Redis setup succeeds;
    any failure (missing package, bad URL, ...) degrades to the
    in-memory backend instead of propagating. Non-production
    environments always get the in-memory backend.
    """
    try:
        import os

        if os.getenv("ENVIRONMENT", "development").lower() == "production":
            import redis.asyncio as redis

            redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
            redis_client = redis.from_url(redis_url)
            backend = RedisCache(redis_client)
            logger.info("Using Redis for distributed cache")
        else:
            # Deliberate control-flow jump into the fallback branch below.
            raise ImportError("Use in-memory for development")
    except Exception as e:
        # Fix: was ``except (ImportError, Exception)`` — the tuple is
        # redundant because ImportError already subclasses Exception.
        backend = InMemoryCache(max_size=1000)
        logger.info(f"Using in-memory cache for distributed cache: {e}")
    return DistributedCache(backend)
# Global cache instance (process-wide singleton; created lazily by get_cache()).
_cache_instance: Optional[DistributedCache] = None
def get_cache() -> DistributedCache:
    """Return the shared DistributedCache, building it on first use."""
    global _cache_instance
    if _cache_instance is not None:
        return _cache_instance
    _cache_instance = create_cache()
    return _cache_instance
# Legacy compatibility functions
def cached_query(ttl: int = 300):
    """Legacy decorator factory for backward compatibility.

    Fix: this was declared ``async def``, so ``@cached_query(300)``
    applied a coroutine object as the decorator and raised
    ``TypeError: 'coroutine' object is not callable``. A decorator
    factory must be a plain synchronous function.

    Args:
        ttl: Cache lifetime in seconds (default 300).

    Returns:
        A decorator that caches the wrapped async function's results.
    """
    cache = get_cache()
    return cache.cached(ttl=ttl)
async def clear_query_cache():
    """Legacy helper kept for backward compatibility.

    Clears every key in the global cache; returns the number removed.
    """
    return await get_cache().clear()
# Export main classes and functions (the module's public API surface).
__all__ = [
    "DistributedCache",
    "CacheBackend",
    "InMemoryCache",
    "RedisCache",
    "create_cache",
    "get_cache",
    # Legacy compatibility shims
    "cached_query",
    "clear_query_cache",
]