# zenith-backend/app/middleware/caching.py
"""
API response caching middleware
"""
import hashlib
import json
import logging
import os
import time
from functools import wraps
from typing import Any, Optional
import redis.asyncio as redis
from fastapi import Request, Response
from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)
class CacheManager:
    """Redis-backed response cache for the API layer.

    Wraps an async Redis client with JSON (de)serialization, TTL handling,
    pattern invalidation, and basic statistics. All Redis failures are
    logged and degrade to cache misses rather than propagating.
    """

    def __init__(self):
        # decode_responses=True so GET returns str and json.loads works directly.
        # redis.from_url is lazy: no connection is made until the first command.
        self.redis = redis.from_url(
            os.getenv("REDIS_URL", "redis://cache-service:6379"),
            decode_responses=True,
        )
        self.default_ttl = int(os.getenv("CACHE_DEFAULT_TTL", "300"))
        self.max_cache_size = int(os.getenv("CACHE_MAX_SIZE", "1000"))

    def _generate_cache_key(self, prefix: str, request: Request) -> str:
        """Generate a deterministic cache key for *request*.

        Keys on method, path, sorted query parameters, and the content-
        negotiation headers so that distinct representations cache separately.

        BUG FIX: Starlette exposes query parameters as ``request.query_params``;
        the original ``request.query.params`` raised AttributeError on every call.
        """
        key_data = {
            "method": request.method,
            "path": request.url.path,
            # Sorted so equivalent query strings in different order share a key.
            "query_params": str(sorted(request.query_params.items())),
            "user_agent": request.headers.get("user-agent", ""),
            "accept": request.headers.get("accept", ""),
        }
        key_string = json.dumps(key_data, sort_keys=True)
        # md5 is used for key compaction only, not for security.
        return f"{prefix}:{hashlib.md5(key_string.encode()).hexdigest()}"

    async def get_cached_response(self, cache_key: str) -> Optional[dict[str, Any]]:
        """Return the cached wrapper dict for *cache_key*, or None on miss/error.

        NOTE: the returned dict is the wrapper written by set_cached_response:
        {"data": <payload>, "timestamp": <float>, "ttl": <int>} — callers must
        read the payload from the "data" key.
        """
        try:
            cached = await self.redis.get(cache_key)
            if cached:
                return json.loads(cached)
        except Exception as e:
            logger.error(f"Cache get error: {e}")
        return None

    async def set_cached_response(self, cache_key: str, response_data: dict[str, Any], ttl: Optional[int] = None):
        """Store *response_data* under *cache_key* with a TTL (best effort).

        The payload is wrapped with its write timestamp and TTL so readers can
        compute the entry's age. Serialization/Redis errors are logged only.
        """
        try:
            ttl = ttl or self.default_ttl
            cache_data = {"data": response_data, "timestamp": time.time(), "ttl": ttl}
            await self.redis.setex(cache_key, ttl, json.dumps(cache_data))
        except Exception as e:
            logger.error(f"Cache set error: {e}")

    async def invalidate_cache_pattern(self, pattern: str) -> int:
        """Delete all keys matching *pattern*; return the number deleted.

        Uses SCAN (non-blocking, incremental) instead of KEYS, which blocks
        the Redis server while it walks the whole keyspace.
        """
        try:
            keys = [key async for key in self.redis.scan_iter(match=pattern)]
            if keys:
                deleted = await self.redis.delete(*keys)
                logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}")
                return deleted
        except Exception as e:
            logger.error(f"Cache invalidation error: {e}")
        return 0

    async def get_cache_stats(self) -> dict[str, Any]:
        """Return a summary of Redis memory/traffic stats, or {"error": ...}."""
        try:
            info = await self.redis.info()
            return {
                "used_memory": info.get("used_memory_human", "unknown"),
                "used_memory_bytes": info.get("used_memory", 0),
                "max_memory": info.get("maxmemory_human", "unknown"),
                "connected_clients": info.get("connected_clients", 0),
                "total_commands_processed": info.get("total_commands_processed", 0),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
                "hit_rate": self._calculate_hit_rate(info),
            }
        except Exception as e:
            logger.error(f"Cache stats error: {e}")
            return {"error": str(e)}

    def _calculate_hit_rate(self, info: dict[str, Any]) -> float:
        """Return the keyspace hit rate as a percentage (0.0 when no traffic)."""
        hits = info.get("keyspace_hits", 0)
        misses = info.get("keyspace_misses", 0)
        total = hits + misses
        if total > 0:
            return (hits / total) * 100
        return 0.0
# Cache middleware factory
def create_cache_middleware(ttl: Optional[int] = None, cacheable_methods: Optional[list] = None):
    """Create an ASGI-style caching middleware.

    Args:
        ttl: Cache TTL in seconds; falls back to CacheManager.default_ttl.
        cacheable_methods: HTTP methods eligible for caching (default: GET only).

    Returns:
        An async ``middleware(request, call_next)`` callable that serves JSON
        responses from Redis on a hit and caches successful JSON responses
        on a miss.

    BUG FIXES vs. the original:
    - The hit path read ``timestamp``/``ttl``/``body`` from the top level of
      the cached dict, but ``set_cached_response`` stores the payload under
      the ``"data"`` key — every hit raised KeyError. The payload is now
      unwrapped correctly.
    - ``response.body`` is ``bytes``, which ``json.dumps`` cannot serialize,
      so every cache write failed silently. The body is now decoded and
      parsed before storing.
    """
    methods = tuple(cacheable_methods or ["GET"])
    # One manager per middleware instance, not per request; the Redis
    # connection is created lazily on first use.
    cache_manager = CacheManager()

    async def middleware(request: Request, call_next):
        cache_key = cache_manager._generate_cache_key("api", request)

        # --- Hit path: serve from cache while the entry is still fresh.
        cached = await cache_manager.get_cached_response(cache_key)
        if cached:
            payload = cached.get("data") or {}
            cache_age = time.time() - cached.get("timestamp", 0.0)
            # Redis setex already expires entries; this is a belt-and-braces
            # freshness check using the stored write time.
            if cache_age < cached.get("ttl", 0):
                logger.info(f"Cache hit for: {request.method} {request.url.path}")
                return JSONResponse(
                    content=payload.get("body"),
                    status_code=payload.get("status_code", 200),
                    headers={
                        **payload.get("headers", {}),
                        "X-Cache": "HIT",
                        "X-Cache-Age": str(int(cache_age)),
                    },
                )

        # --- Miss path: run the handler, then cache eligible responses.
        response = await call_next(request)

        if (
            hasattr(response, "body")
            and hasattr(response, "status_code")
            and request.method in methods
            and response.status_code < 400
            and response.headers.get("content-type", "").startswith("application/json")
        ):
            body = response.body
            if isinstance(body, bytes):
                body = body.decode("utf-8")
            try:
                # Store the parsed JSON so the payload round-trips through
                # json.dumps in set_cached_response.
                parsed_body = json.loads(body)
            except (ValueError, TypeError):
                return response  # claims JSON but isn't parseable; skip caching
            await cache_manager.set_cached_response(
                cache_key,
                {
                    "status_code": response.status_code,
                    "headers": dict(response.headers),
                    "body": parsed_body,
                    "timestamp": time.time(),
                },
                ttl,
            )
        return response

    return middleware
# Smart caching decorators
def cache_response(ttl: int = 300, cache_key_func=None):
    """Decorator that caches an async function's dict result in Redis.

    Args:
        ttl: Seconds to keep the cached result.
        cache_key_func: Optional callable ``(*args, **kwargs) -> str`` that
            produces the variable part of the cache key; when omitted, a hash
            of the call's arguments is used.

    Only results that look like successful responses — a dict containing a
    ``status_code`` < 400 — are cached.

    BUG FIXES vs. the original:
    - The default key hashed ``str(args)`` only, so calls differing solely in
      keyword arguments collided on one cache entry. kwargs are now included
      (sorted, for determinism).
    - A hit returned a restructured dict with an injected ``timestamp`` key
      instead of the function's actual result. The result is now cached
      verbatim, so hits and misses return the same value.
    """

    def decorator(func):
        # One manager per decorated function, not per call; the Redis
        # connection is created lazily on first use.
        cache_manager = CacheManager()

        @wraps(func)
        async def wrapper(*args, **kwargs):
            if cache_key_func:
                cache_key = f"function:{func.__name__}:{cache_key_func(*args, **kwargs)}"
            else:
                raw = f"{args!r}:{sorted(kwargs.items())!r}"
                cache_key = f"function:{func.__name__}:{hashlib.md5(raw.encode()).hexdigest()}"

            cached = await cache_manager.get_cached_response(cache_key)
            if cached:
                logger.info(f"Function cache hit: {func.__name__}")
                return cached["data"]

            result = await func(*args, **kwargs)

            if isinstance(result, dict) and "status_code" in result and result["status_code"] < 400:
                # set_cached_response wraps the result with timestamp/ttl
                # metadata; the raw result comes back via cached["data"].
                await cache_manager.set_cached_response(cache_key, result, ttl)
                logger.info(f"Function cached: {func.__name__}")
            return result

        return wrapper

    return decorator
# Specialized caching for different data types
class DataCache:
    """Type-specific caching helpers layered on a shared Redis client.

    Each category (user data, query results) gets its own key namespace and
    a sensible default TTL. Values are stored as JSON strings.
    """

    def __init__(self, redis_client):
        # Injected async Redis client; ownership stays with the caller.
        self.redis = redis_client

    async def cache_user_data(self, user_id: str, data: dict[str, Any], ttl: int = 3600):
        """Cache user-specific data"""
        await self.redis.setex(f"user:{user_id}", ttl, json.dumps(data))

    async def get_user_data(self, user_id: str) -> Optional[dict[str, Any]]:
        """Get cached user data"""
        raw = await self.redis.get(f"user:{user_id}")
        return json.loads(raw) if raw else None

    async def cache_query_results(self, query_hash: str, results: list, ttl: int = 600):
        """Cache database query results"""
        await self.redis.setex(f"query:{query_hash}", ttl, json.dumps(results))

    async def get_query_results(self, query_hash: str) -> Optional[list]:
        """Get cached query results"""
        raw = await self.redis.get(f"query:{query_hash}")
        return json.loads(raw) if raw else None