Spaces:
Paused
Paused
| """ | |
| API response caching middleware | |
| """ | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from functools import wraps | |
| from typing import Any, Optional | |
| import redis.asyncio as redis | |
| from fastapi import Request, Response | |
| from fastapi.responses import JSONResponse | |
# Module-level logger; records are emitted under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
class CacheManager:
    """Advanced API response caching backed by Redis.

    Configuration is read from the environment when an instance is created:

    * ``REDIS_URL`` -- connection URL (default ``redis://cache-service:6379``).
    * ``CACHE_DEFAULT_TTL`` -- TTL in seconds used when none is supplied (default 300).
    * ``CACHE_MAX_SIZE`` -- stored on the instance; not enforced by this class.
    """

    def __init__(self):
        # decode_responses=True makes GET return str, which json.loads consumes directly.
        self.redis = redis.from_url(os.getenv("REDIS_URL", "redis://cache-service:6379"), decode_responses=True)
        self.default_ttl = int(os.getenv("CACHE_DEFAULT_TTL", "300"))
        # NOTE(review): not referenced anywhere in this class; any size limit must come
        # from Redis itself (e.g. maxmemory policy) -- confirm this is intended.
        self.max_cache_size = int(os.getenv("CACHE_MAX_SIZE", "1000"))

    def _generate_cache_key(self, prefix: str, request: Request) -> str:
        """Build a deterministic cache key for *request*.

        The key covers the HTTP method, URL path, all query parameters (repeated
        parameters included, order-insensitive) and the ``User-Agent`` / ``Accept``
        headers. It is returned as ``"<prefix>:<md5 hex digest>"``.
        """
        # multi_items() keeps every value of a repeated parameter (?tag=a&tag=b);
        # .items() keeps one value per name and would let distinct queries collide.
        query_params = sorted(request.query_params.multi_items())
        key_data = {
            "method": request.method,
            "path": request.url.path,
            "query_params": str(query_params),
            "user_agent": request.headers.get("user-agent", ""),
            "accept": request.headers.get("accept", ""),
        }
        key_string = json.dumps(key_data, sort_keys=True)
        # MD5 only shortens the key; it is not used for any security purpose.
        digest = hashlib.md5(key_string.encode(), usedforsecurity=False).hexdigest()
        return f"{prefix}:{digest}"

    async def get_cached_response(self, cache_key: str) -> Optional[dict[str, Any]]:
        """Return the decoded cache envelope stored at *cache_key*, or ``None``.

        The envelope is the dict written by ``set_cached_response``
        (``{"data": ..., "timestamp": ..., "ttl": ...}``). Redis and JSON errors
        are logged and treated as a cache miss.
        """
        try:
            cached = await self.redis.get(cache_key)
            if cached:
                return json.loads(cached)
        except Exception as e:
            logger.error("Cache get error: %s", e)
        return None

    async def set_cached_response(self, cache_key: str, response_data: dict[str, Any], ttl: Optional[int] = None):
        """Store *response_data* under *cache_key* with an expiry.

        Args:
            cache_key: Redis key to write.
            response_data: JSON-serialisable payload, stored under ``"data"``.
            ttl: Expiry in seconds; a falsy value falls back to ``default_ttl``.

        Failures (including non-JSON-serialisable data) are logged, not raised.
        """
        try:
            ttl = ttl or self.default_ttl
            cache_data = {"data": response_data, "timestamp": time.time(), "ttl": ttl}
            await self.redis.setex(cache_key, ttl, json.dumps(cache_data))
        except Exception as e:
            logger.error("Cache set error: %s", e)

    async def invalidate_cache_pattern(self, pattern: str, batch_size: int = 500) -> int:
        """Delete every key matching the glob *pattern* and return how many were removed.

        Keys are enumerated with incremental ``SCAN`` instead of ``KEYS``, so a
        large keyspace does not block the Redis server. They are deleted in
        batches of *batch_size*. On error the problem is logged and 0 is returned.
        """
        deleted = 0
        try:
            batch: list[str] = []
            async for key in self.redis.scan_iter(match=pattern, count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await self.redis.delete(*batch)
                    batch.clear()
            if batch:
                deleted += await self.redis.delete(*batch)
        except Exception as e:
            logger.error("Cache invalidation error: %s", e)
            return 0
        if deleted:
            logger.info("Invalidated %s cache entries matching pattern: %s", deleted, pattern)
        return deleted

    async def get_cache_stats(self) -> dict[str, Any]:
        """Summarise Redis ``INFO`` output.

        Returns a dict of memory, client and keyspace counters plus a derived
        hit rate (percent). On failure it returns ``{"error": <message>}``.
        """
        try:
            info = await self.redis.info()
            return {
                "used_memory": info.get("used_memory_human", "unknown"),
                "used_memory_bytes": info.get("used_memory", 0),
                "max_memory": info.get("maxmemory_human", "unknown"),
                "connected_clients": info.get("connected_clients", 0),
                "total_commands_processed": info.get("total_commands_processed", 0),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
                "hit_rate": self._calculate_hit_rate(info),
            }
        except Exception as e:
            logger.error("Cache stats error: %s", e)
            return {"error": str(e)}

    def _calculate_hit_rate(self, info: dict[str, Any]) -> float:
        """Return keyspace hits as a percentage of lookups (0.0 when there were none)."""
        hits = info.get("keyspace_hits", 0)
        misses = info.get("keyspace_misses", 0)
        total = hits + misses
        if total > 0:
            return (hits / total) * 100
        return 0.0
# Cache middleware factory
def create_cache_middleware(ttl: Optional[int] = None, cacheable_methods: list = None):
    """Create an HTTP middleware that caches successful JSON responses in Redis.

    The returned coroutine has the ``(request, call_next)`` signature used by
    Starlette/FastAPI HTTP middleware. For a request whose method is cacheable it:

    1. looks up ``CacheManager._generate_cache_key("api", request)`` and, on a
       fresh hit, replays the stored status/headers/body with ``X-Cache: HIT`` and
       ``X-Cache-Age`` headers added;
    2. otherwise calls the downstream app. It stores the response if the status is
       below 400, the content type starts with ``application/json``, the response
       sets no cookies, and the body is valid UTF-8.

    Args:
        ttl: Entry lifetime in seconds; ``None`` uses ``CacheManager.default_ttl``.
        cacheable_methods: HTTP methods to look up and store; defaults to ``["GET"]``.
    """
    methods = tuple(cacheable_methods or ["GET"])
    # One manager (and therefore one Redis connection pool) per middleware instance,
    # created lazily on the first request instead of once per request.
    manager: Optional[CacheManager] = None

    def _get_manager() -> CacheManager:
        """Return the shared CacheManager, creating it on first use."""
        nonlocal manager
        if manager is None:
            manager = CacheManager()
        return manager

    def _is_cacheable(response: Response) -> bool:
        """Return True for successful, cookie-free JSON responses."""
        return (
            response.status_code < 400
            and response.headers.get("content-type", "").startswith("application/json")
            # A Set-Cookie header is per-client; replaying it to others would leak it.
            and "set-cookie" not in response.headers
        )

    async def middleware(request: Request, call_next):
        if request.method not in methods:
            return await call_next(request)

        cache_manager = _get_manager()
        cache_key = cache_manager._generate_cache_key("api", request)

        # Check cache before processing. The stored envelope is
        # {"data": {...}, "timestamp": ..., "ttl": ...} (see set_cached_response).
        cached = await cache_manager.get_cached_response(cache_key)
        if cached:
            cache_age = time.time() - cached.get("timestamp", 0)
            entry = cached.get("data")
            if isinstance(entry, dict) and "body" in entry and cache_age < cached.get("ttl", 0):
                logger.info("Cache hit for: %s %s", request.method, request.url.path)
                # Replay the exact stored bytes (not a JSON re-render) so the cached
                # content-length header stays accurate.
                return Response(
                    content=entry["body"],
                    status_code=entry.get("status_code", 200),
                    headers={
                        **entry.get("headers", {}),
                        "X-Cache": "HIT",
                        "X-Cache-Age": str(int(cache_age)),
                    },
                )

        # Process request and cache response.
        response = await call_next(request)
        if not _is_cacheable(response):
            return response

        body = getattr(response, "body", None)
        if body is None:
            # call_next may return a streaming response with no .body. Drain it so
            # it can be stored, then hand the client an equivalent response built
            # from the same bytes.
            chunks = [
                chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
                async for chunk in response.body_iterator
            ]
            body = b"".join(chunks)
            response = Response(
                content=body,
                status_code=response.status_code,
                headers=dict(response.headers),
                background=getattr(response, "background", None),
            )

        try:
            # Stored as text because the cache envelope is serialised with json.dumps,
            # which cannot encode bytes.
            body_text = body.decode("utf-8")
        except UnicodeDecodeError:
            logger.warning("Not caching %s %s: body is not UTF-8", request.method, request.url.path)
            return response

        response_data = {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "body": body_text,
            "timestamp": time.time(),
        }
        await cache_manager.set_cached_response(cache_key, response_data, ttl)
        return response

    return middleware
# Smart caching decorators
def cache_response(ttl: int = 300, cache_key_func=None):
    """Decorator that caches the results of an async function in Redis.

    Results are cached only when they are a dict with a ``"status_code"`` below
    400. A cache hit returns the stored result after a JSON round-trip (see
    ``CacheManager.set_cached_response``), so e.g. tuples come back as lists.

    Args:
        ttl: Entry lifetime in seconds.
        cache_key_func: Optional callable taking the decorated function's
            arguments and returning the key suffix. By default the suffix is an
            MD5 of the positional arguments (plus keyword arguments, when given).
    """

    def decorator(func):
        # NOTE(review): keys use func.__name__ only, so same-named functions in
        # different modules share a key space -- kept for key-format compatibility.
        manager = None

        def _get_manager():
            """Return this function's shared CacheManager, creating it on first use."""
            nonlocal manager
            if manager is None:
                manager = CacheManager()
            return manager

        def _build_key(args, kwargs) -> str:
            """Compute the Redis key for one call of *func*."""
            if cache_key_func:
                return f"function:{func.__name__}:{cache_key_func(*args, **kwargs)}"
            # Keyword arguments are part of the call's identity. When there are none,
            # the digest input is exactly str(args), as before.
            raw = str(args) if not kwargs else str(args) + str(sorted(kwargs.items()))
            return f"function:{func.__name__}:{hashlib.md5(raw.encode()).hexdigest()}"

        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache_manager = _get_manager()
            cache_key = _build_key(args, kwargs)

            # Check cache first.
            cached = await cache_manager.get_cached_response(cache_key)
            if cached:
                logger.info("Function cache hit: %s", func.__name__)
                return cached["data"]

            result = await func(*args, **kwargs)

            # Cache successful responses as-is, so a hit returns what a miss returned.
            if isinstance(result, dict) and "status_code" in result and result["status_code"] < 400:
                await cache_manager.set_cached_response(cache_key, result, ttl)
                logger.info("Function cached: %s", func.__name__)
            return result

        return wrapper

    return decorator
# Specialized caching for different data types
class DataCache:
    """JSON-serialising helpers over an injected async Redis-style client.

    User records live under ``user:<user_id>`` and query results under
    ``query:<query_hash>``. Every write goes through ``setex`` with a TTL.
    """

    def __init__(self, redis_client):
        self.redis = redis_client

    async def _store(self, key, payload, ttl):
        """Serialise *payload* to JSON and write it to *key* with expiry *ttl*."""
        encoded = json.dumps(payload)
        await self.redis.setex(key, ttl, encoded)

    async def _load(self, key):
        """Read *key* and decode it from JSON; a missing or empty value yields None."""
        raw = await self.redis.get(key)
        if not raw:
            return None
        return json.loads(raw)

    async def cache_user_data(self, user_id: str, data: dict[str, Any], ttl: int = 3600):
        """Store one user's data dict for *ttl* seconds (default one hour)."""
        await self._store(f"user:{user_id}", data, ttl)

    async def get_user_data(self, user_id: str) -> Optional[dict[str, Any]]:
        """Fetch a user's cached data dict, or None when absent."""
        return await self._load(f"user:{user_id}")

    async def cache_query_results(self, query_hash: str, results: list, ttl: int = 600):
        """Store a list of query results for *ttl* seconds (default ten minutes)."""
        await self._store(f"query:{query_hash}", results, ttl)

    async def get_query_results(self, query_hash: str) -> Optional[list]:
        """Fetch cached query results, or None when absent."""
        return await self._load(f"query:{query_hash}")