|
|
"""
|
|
|
Redis caching strategies and decorators for FastAPI endpoints.
|
|
|
|
|
|
This module provides caching decorators, cache invalidation patterns,
|
|
|
cache warming strategies, and monitoring for the video generation API.
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
import functools
|
|
|
import hashlib
|
|
|
import json
|
|
|
import logging
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
from fastapi import Request, Response
|
|
|
from redis.asyncio import Redis
|
|
|
|
|
|
from .redis import redis_manager, RedisKeyManager, safe_redis_operation
|
|
|
from .config import get_settings
|
|
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
# Application settings, fetched at import time.
# NOTE(review): not referenced anywhere in the visible part of this module.
settings = get_settings()
|
|
|
|
|
|
|
|
|
class CacheConfig:
    """Central constants that tune caching behavior."""

    # Time-to-live presets, expressed in seconds.
    DEFAULT_TTL = 5 * 60
    SHORT_TTL = 60
    MEDIUM_TTL = 15 * 60
    LONG_TTL = 60 * 60
    VERY_LONG_TTL = 24 * 60 * 60

    # Cache-type names used when building Redis keys.
    ENDPOINT_CACHE = "endpoint_cache"
    QUERY_CACHE = "query_cache"
    USER_CACHE = "user_cache"
    SYSTEM_CACHE = "system_cache"

    # Cache warming: entries processed per batch, and the pause (seconds)
    # inserted between consecutive batches.
    WARM_CACHE_BATCH_SIZE = 10
    WARM_CACHE_DELAY = 0.1
|
|
|
|
|
|
|
|
|
class CacheKeyGenerator:
    """Utility class for generating consistent cache keys.

    All public methods delegate the final key layout to
    ``RedisKeyManager.cache_key(cache_type, identifier)``; this class only
    builds the ``identifier`` part.
    """

    # Endpoint identifiers longer than this many characters are replaced by
    # an MD5 hex digest of the full identifier.
    _MAX_KEY_LENGTH = 200

    @staticmethod
    def _encode_params(params: Dict[str, Any]) -> str:
        """Render *params* as a sorted, percent-encoded ``k=v&k=v`` string.

        Names and values are percent-encoded so that a value containing
        ``&``, ``=`` or ``|`` cannot imitate a different parameter set
        (``{"a": "1&b=2"}`` must not collide with ``{"a": "1", "b": "2"}``).
        Plain alphanumeric names and values are rendered unchanged.
        """
        from urllib.parse import quote

        encoded = sorted(
            (quote(str(name), safe=""), quote(str(value), safe=""))
            for name, value in params.items()
        )
        return "&".join(f"{name}={value}" for name, value in encoded)

    @staticmethod
    def endpoint_key(
        method: str,
        path: str,
        query_params: Optional[Dict[str, Any]] = None,
        user_id: Optional[str] = None,
        additional_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Generate cache key for API endpoint responses.

        The identifier has the shape
        ``METHOD|path[|user:<id>][|params:<k=v&...>][|extra:<k=v&...>]``.

        Args:
            method: HTTP method (upper-cased before use)
            path: Request path
            query_params: Query parameters; order-insensitive
            user_id: User ID for user-specific caching; omitted when falsy
            additional_params: Additional parameters for key generation

        Returns:
            Generated cache key

        Note:
            When the identifier exceeds ``_MAX_KEY_LENGTH`` characters it is
            replaced by its MD5 digest. Such keys no longer contain the path
            or user id, so glob-based invalidation cannot match them; they
            only disappear when their TTL expires.
        """
        key_parts = [method.upper(), path]

        if user_id:
            key_parts.append(f"user:{user_id}")

        if query_params:
            key_parts.append(
                f"params:{CacheKeyGenerator._encode_params(query_params)}"
            )

        if additional_params:
            key_parts.append(
                f"extra:{CacheKeyGenerator._encode_params(additional_params)}"
            )

        key_string = "|".join(key_parts)
        if len(key_string) > CacheKeyGenerator._MAX_KEY_LENGTH:
            # MD5 is used only to shorten the key, not for security.
            key_string = hashlib.md5(key_string.encode()).hexdigest()

        return RedisKeyManager.cache_key(CacheConfig.ENDPOINT_CACHE, key_string)

    @staticmethod
    def query_key(query_name: str, params: Optional[Dict[str, Any]] = None) -> str:
        """Generate cache key for database queries.

        Args:
            query_name: Logical name of the query
            params: Query arguments; order-insensitive, percent-encoded

        Returns:
            Generated cache key (``query_name`` alone when *params* is empty)
        """
        key_parts = [query_name]

        if params:
            key_parts.append(CacheKeyGenerator._encode_params(params))

        return RedisKeyManager.cache_key(
            CacheConfig.QUERY_CACHE, "|".join(key_parts)
        )

    @staticmethod
    def user_key(user_id: str, data_type: str) -> str:
        """Generate cache key for user-specific data (``<user_id>:<data_type>``)."""
        return RedisKeyManager.cache_key(
            CacheConfig.USER_CACHE, f"{user_id}:{data_type}"
        )

    @staticmethod
    def system_key(component: str, metric: Optional[str] = None) -> str:
        """Generate cache key for system data.

        The identifier is ``component`` or, when *metric* is truthy,
        ``component:metric``.
        """
        key = f"{component}:{metric}" if metric else component
        return RedisKeyManager.cache_key(CacheConfig.SYSTEM_CACHE, key)
|
|
|
|
|
|
|
|
|
class CacheManager:
    """Advanced cache management with invalidation and warming strategies.

    Wraps the shared Redis client (``redis_manager.redis``) and keeps
    in-process counters for hits, misses, sets, deletes and invalidations.
    All Redis-facing methods catch exceptions, log them and return a
    fallback value instead of raising.
    """

    # Keys requested per SCAN round trip and maximum keys sent per DEL call.
    _SCAN_BATCH_SIZE = 500

    def __init__(self):
        self._cache_stats = self._fresh_stats()

    @staticmethod
    def _fresh_stats() -> Dict[str, int]:
        """Return a zeroed statistics dictionary."""
        return {
            "hits": 0,
            "misses": 0,
            "sets": 0,
            "deletes": 0,
            "invalidations": 0,
        }

    async def get(
        self,
        key: str,
        default: Any = None,
        deserialize: bool = True,
    ) -> Any:
        """
        Get value from cache with statistics tracking.

        Args:
            key: Cache key
            default: Default value if key doesn't exist
            deserialize: Whether to deserialize JSON data

        Returns:
            Cached value, or *default* when the key is missing, the stored
            value cannot be decoded as JSON, or the Redis call fails.
        """
        try:
            value = await redis_manager.redis.get(key)

            if value is None:
                self._cache_stats["misses"] += 1
                return default

            if deserialize:
                try:
                    value = json.loads(value)
                except (TypeError, ValueError):
                    # json.JSONDecodeError is a ValueError. An undecodable
                    # entry is useless to the caller, so count it as a miss
                    # rather than a hit.
                    logger.warning(
                        "Failed to deserialize cached value for key: %s", key
                    )
                    self._cache_stats["misses"] += 1
                    return default

            self._cache_stats["hits"] += 1
            return value

        except Exception as e:
            logger.error("Cache get failed for key %s: %s", key, e)
            self._cache_stats["misses"] += 1
            return default

    async def set(
        self,
        key: str,
        value: Any,
        ttl: int = CacheConfig.DEFAULT_TTL,
        serialize: bool = True,
    ) -> bool:
        """
        Set value in cache with TTL.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds
            serialize: Whether to serialize value as JSON. Objects that JSON
                cannot encode natively are converted with ``str()``.

        Returns:
            True if successful, False if the Redis call failed
        """
        try:
            if serialize:
                value = json.dumps(value, default=str)

            result = await redis_manager.redis.setex(key, ttl, value)
            self._cache_stats["sets"] += 1
            return bool(result)

        except Exception as e:
            logger.error("Cache set failed for key %s: %s", key, e)
            return False

    async def delete(self, key: str) -> bool:
        """Delete key from cache.

        Returns:
            True if a key was removed, False if it did not exist or the
            Redis call failed. The ``deletes`` counter is incremented for
            every call that reaches Redis, whether or not the key existed.
        """
        try:
            result = await redis_manager.redis.delete(key)
            self._cache_stats["deletes"] += 1
            return bool(result)

        except Exception as e:
            logger.error("Cache delete failed for key %s: %s", key, e)
            return False

    async def delete_pattern(self, pattern: str) -> int:
        """
        Delete all keys matching pattern.

        Keys are discovered with incremental ``SCAN`` instead of ``KEYS`` so
        that a large keyspace does not block the Redis server, and they are
        removed in batches of at most ``_SCAN_BATCH_SIZE``.

        Args:
            pattern: Redis key pattern (supports wildcards)

        Returns:
            Number of keys deleted. If Redis fails part-way through, the
            count of keys removed before the failure is returned.
        """
        deleted = 0
        try:
            redis_client = redis_manager.redis
            pending: List[Any] = []

            async for key in redis_client.scan_iter(
                match=pattern, count=self._SCAN_BATCH_SIZE
            ):
                pending.append(key)
                if len(pending) >= self._SCAN_BATCH_SIZE:
                    deleted += await redis_client.delete(*pending)
                    pending.clear()

            if pending:
                deleted += await redis_client.delete(*pending)

        except Exception as e:
            logger.error(
                "Cache pattern delete failed for pattern %s: %s", pattern, e
            )

        self._cache_stats["deletes"] += deleted
        return deleted

    async def invalidate_user_cache(self, user_id: str) -> int:
        """Invalidate all cache entries for a specific user.

        Returns:
            Number of keys deleted
        """
        pattern = RedisKeyManager.cache_key(CacheConfig.USER_CACHE, f"{user_id}:*")
        deleted = await self.delete_pattern(pattern)
        self._cache_stats["invalidations"] += 1
        logger.info("Invalidated %s cache entries for user %s", deleted, user_id)
        return deleted

    async def invalidate_endpoint_cache(self, path_pattern: str) -> int:
        """Invalidate cache entries for specific endpoint patterns.

        Args:
            path_pattern: Fragment matched anywhere inside endpoint cache
                identifiers (wrapped in ``*`` wildcards)

        Returns:
            Number of keys deleted
        """
        pattern = RedisKeyManager.cache_key(
            CacheConfig.ENDPOINT_CACHE, f"*{path_pattern}*"
        )
        deleted = await self.delete_pattern(pattern)
        self._cache_stats["invalidations"] += 1
        logger.info(
            "Invalidated %s cache entries for pattern %s", deleted, path_pattern
        )
        return deleted

    async def warm_cache(
        self,
        warm_functions: List[Tuple[Callable, Dict[str, Any]]],
        batch_size: int = CacheConfig.WARM_CACHE_BATCH_SIZE,
    ) -> Dict[str, Any]:
        """
        Warm cache with predefined data.

        Functions are awaited one after another; after each batch (except
        the last) the coroutine sleeps ``CacheConfig.WARM_CACHE_DELAY``
        seconds. A failing function is recorded and does not stop the run.

        Args:
            warm_functions: List of (function, kwargs) tuples to execute
            batch_size: Number of operations per batch; values below 1 are
                treated as 1

        Returns:
            Warming results and statistics: ``total_functions``,
            ``successful``, ``failed`` and ``errors`` (a list of
            ``{"function", "error"}`` dicts)
        """
        total = len(warm_functions)
        results = {
            "total_functions": total,
            "successful": 0,
            "failed": 0,
            "errors": [],
        }

        # range() rejects a zero step, so clamp nonsensical batch sizes.
        step = max(1, batch_size)

        for start in range(0, total, step):
            for func, kwargs in warm_functions[start:start + step]:
                try:
                    await func(**(kwargs or {}))
                    results["successful"] += 1
                except Exception as e:
                    # functools.partial and callable instances have no
                    # __name__; fall back to repr so error reporting itself
                    # cannot raise.
                    func_name = getattr(func, "__name__", repr(func))
                    results["failed"] += 1
                    results["errors"].append({
                        "function": func_name,
                        "error": str(e),
                    })
                    logger.error("Cache warming failed for %s: %s", func_name, e)

            if start + step < total:
                await asyncio.sleep(CacheConfig.WARM_CACHE_DELAY)

        logger.info("Cache warming completed: %s", results)
        return results

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            The raw counters plus ``total_operations`` (sum of all
            counters), ``hit_rate`` (percentage of hits among hits and
            misses, rounded to two decimals; 0 when there were no lookups)
            and an ISO-8601 ``timestamp``.
        """
        hits = self._cache_stats["hits"]
        lookups = hits + self._cache_stats["misses"]
        hit_rate = hits / lookups if lookups > 0 else 0

        return {
            **self._cache_stats,
            "total_operations": sum(self._cache_stats.values()),
            "hit_rate": round(hit_rate * 100, 2),
            "timestamp": datetime.utcnow().isoformat(),
        }

    def reset_stats(self) -> None:
        """Reset cache statistics."""
        self._cache_stats = self._fresh_stats()
|
|
|
|
|
|
|
|
|
|
|
|
# Shared CacheManager instance; the decorators and helpers in this module
# read and write the cache through it.
cache_manager = CacheManager()
|
|
|
|
|
|
|
|
|
def cache_response(
    ttl: int = CacheConfig.DEFAULT_TTL,
    key_generator: Optional[Callable] = None,
    user_specific: bool = False,
    skip_cache_header: str = "X-Skip-Cache",
    vary_on: Optional[List[str]] = None,
):
    """
    Decorator for caching FastAPI endpoint responses.

    The decorated endpoint must receive a ``Request`` argument (positional
    or keyword); without one the endpoint is called directly and nothing is
    cached. Caching is also bypassed when:

    * the request carries a non-empty ``skip_cache_header`` header;
    * ``user_specific`` is set but neither ``user_id`` nor
      ``current_user_id`` is present in the keyword arguments (falling back
      to a shared key could serve one user's data to another);
    * the endpoint returns a ``Response`` object, which cannot be stored as
      JSON.

    A cached value of ``None`` is indistinguishable from a miss, so
    endpoints returning ``None`` are re-executed on every call.

    Args:
        ttl: Time to live in seconds
        key_generator: Custom key generation function, sync or async. It is
            called as ``key_generator(request, *args, **kwargs)``; when the
            request was passed to the endpoint by keyword, that entry is
            removed from ``kwargs`` so it is not supplied twice.
        user_specific: Whether to include user ID in cache key
        skip_cache_header: Header name to skip cache
        vary_on: List of request header names to vary cache on

    Usage:
        @router.get("/api/v1/jobs")
        @cache_response(ttl=300, user_specific=True)
        async def get_jobs(request: Request, user_id: str = Depends(get_current_user_id)):
            return await job_service.get_jobs(user_id)
    """
    import inspect

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # FastAPI invokes endpoints with keyword arguments, so the
            # Request has to be looked up in kwargs as well as in args.
            request = next(
                (
                    value
                    for value in (*args, *kwargs.values())
                    if isinstance(value, Request)
                ),
                None,
            )

            if request is None:
                # No request available: nothing to build a key from.
                return await func(*args, **kwargs)

            if request.headers.get(skip_cache_header):
                return await func(*args, **kwargs)

            if key_generator:
                # Drop the keyword entry holding the request so the
                # generator does not receive it both positionally and by
                # name.
                remaining_kwargs = {
                    name: value
                    for name, value in kwargs.items()
                    if value is not request
                }
                cache_key = key_generator(request, *args, **remaining_kwargs)
                if inspect.isawaitable(cache_key):
                    cache_key = await cache_key
            else:
                user_id = None
                if user_specific:
                    user_id = kwargs.get("user_id") or kwargs.get("current_user_id")
                    if not user_id:
                        logger.warning(
                            "user_specific cache requested for %s but no user id "
                            "was found; bypassing cache",
                            request.url.path,
                        )
                        return await func(*args, **kwargs)

                vary_params = {
                    header: request.headers[header]
                    for header in (vary_on or [])
                    if header in request.headers
                }

                cache_key = CacheKeyGenerator.endpoint_key(
                    method=request.method,
                    path=request.url.path,
                    query_params=dict(request.query_params),
                    user_id=user_id,
                    additional_params=vary_params,
                )

            cached_result = await cache_manager.get(cache_key)
            if cached_result is not None:
                logger.debug("Cache hit for key: %s", cache_key)
                return cached_result

            result = await func(*args, **kwargs)

            if isinstance(result, Response):
                # json.dumps(default=str) would store the object's repr and
                # later serve that string instead of the real response.
                logger.debug("Skipping cache for Response object, key: %s", cache_key)
                return result

            await cache_manager.set(cache_key, result, ttl)
            logger.debug("Cached result for key: %s", cache_key)

            return result

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
def _bind_call_arguments(func: Callable, args: tuple, kwargs: dict) -> Dict[str, Any]:
    """Map a call's positional and keyword arguments to parameter names.

    Uses the callable's signature so that ``f(1)``, ``f(x=1)`` and a call
    relying on defaults all yield the same mapping, ``*args`` values are
    kept as a tuple instead of being dropped, and ``**kwargs`` entries are
    flattened so their call order does not matter.
    """
    import inspect

    try:
        signature = inspect.signature(func)
        bound = signature.bind_partial(*args, **kwargs)
    except (TypeError, ValueError):
        # No introspectable signature, or the call does not fit it (the
        # wrapped call will then raise on its own). Fall back to
        # index-based names so that no argument is lost from the key.
        fallback = {f"arg{index}": value for index, value in enumerate(args)}
        fallback.update(kwargs)
        return fallback

    bound.apply_defaults()

    named: Dict[str, Any] = {}
    for name, value in bound.arguments.items():
        if signature.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
            named.update(value)
        else:
            named[name] = value
    return named


def cache_query(
    ttl: int = CacheConfig.DEFAULT_TTL,
    key_prefix: str = "query",
):
    """
    Decorator for caching database query results.

    The cache key is built from ``key_prefix``, the wrapped function's name
    and its call arguments resolved to parameter names. A cached value of
    ``None`` is indistinguishable from a miss, so queries returning ``None``
    are re-executed on every call.

    NOTE(review): for bound methods the ``self`` argument becomes part of
    the key through ``str()``; confirm that is acceptable for the decorated
    callables.

    Args:
        ttl: Time to live in seconds
        key_prefix: Prefix for cache key

    Usage:
        @cache_query(ttl=600, key_prefix="user_jobs")
        async def get_user_jobs(user_id: str, status: str = None):
            # Database query logic
            return results
    """
    def decorator(func: Callable) -> Callable:
        # Callables such as functools.partial have no __name__.
        query_name = f"{key_prefix}:{getattr(func, '__name__', type(func).__name__)}"

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            cache_key = CacheKeyGenerator.query_key(
                query_name,
                _bind_call_arguments(func, args, kwargs),
            )

            cached_result = await cache_manager.get(cache_key)
            if cached_result is not None:
                logger.debug("Query cache hit for key: %s", cache_key)
                return cached_result

            result = await func(*args, **kwargs)

            await cache_manager.set(cache_key, result, ttl)
            logger.debug("Cached query result for key: %s", cache_key)

            return result

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
class CacheInvalidationManager:
    """Manages cache invalidation patterns and strategies.

    Glob fragments are always wrapped with ``RedisKeyManager.cache_key`` for
    each cache type before deletion, so only cache entries can match. A bare
    pattern such as ``*queue*`` would otherwise match every key in the Redis
    database that contains that word.
    """

    @staticmethod
    def _job_patterns(job_id: str, user_id: Optional[str] = None) -> List[str]:
        """Return the glob fragments that identify cache entries for a job."""
        patterns = [
            f"*jobs*{job_id}*",
            f"*job_status*{job_id}*",
            f"*videos*{job_id}*",
        ]

        if user_id:
            patterns.extend([
                f"*user:{user_id}*jobs*",
                # Endpoint keys place the path before the "user:" segment
                # (METHOD|path|user:<id>|...), so the reverse order is
                # needed to match e.g. "GET|/api/v1/jobs|user:<id>".
                f"*jobs*user:{user_id}*",
                f"*user_jobs*{user_id}*",
            ])

        return patterns

    @staticmethod
    async def _delete_scoped(fragments: List[str]) -> int:
        """Delete keys matching each fragment within every cache namespace.

        Assumes ``RedisKeyManager.cache_key`` embeds its identifier argument
        verbatim, as ``CacheManager.invalidate_user_cache`` and
        ``invalidate_endpoint_cache`` already rely on — TODO confirm.
        """
        cache_types = (
            CacheConfig.ENDPOINT_CACHE,
            CacheConfig.QUERY_CACHE,
            CacheConfig.USER_CACHE,
            CacheConfig.SYSTEM_CACHE,
        )

        total_deleted = 0
        for cache_type in cache_types:
            for fragment in fragments:
                pattern = RedisKeyManager.cache_key(cache_type, fragment)
                total_deleted += await cache_manager.delete_pattern(pattern)
        return total_deleted

    @staticmethod
    async def invalidate_job_related_cache(job_id: str, user_id: str = None):
        """Invalidate all cache entries related to a specific job.

        Args:
            job_id: Job whose cached data should be dropped
            user_id: Optional owner; when given, that user's cached job
                listings are dropped as well

        Returns:
            Total number of keys deleted
        """
        total_deleted = await CacheInvalidationManager._delete_scoped(
            CacheInvalidationManager._job_patterns(job_id, user_id)
        )

        logger.info("Invalidated %s cache entries for job %s", total_deleted, job_id)
        return total_deleted

    @staticmethod
    async def invalidate_user_related_cache(user_id: str):
        """Invalidate all cache entries related to a specific user.

        Returns:
            Number of keys deleted
        """
        return await cache_manager.invalidate_user_cache(user_id)

    @staticmethod
    async def invalidate_system_cache():
        """Invalidate system-wide cache entries.

        Drops the whole system cache namespace plus any cache entry whose
        identifier mentions health, metrics or queue data.

        Returns:
            Total number of keys deleted
        """
        # Everything stored under the system cache type.
        system_pattern = RedisKeyManager.cache_key(CacheConfig.SYSTEM_CACHE, "*")
        total_deleted = await cache_manager.delete_pattern(system_pattern)

        total_deleted += await CacheInvalidationManager._delete_scoped(
            ["*health*", "*metrics*", "*queue*"]
        )

        logger.info("Invalidated %s system cache entries", total_deleted)
        return total_deleted
|
|
|
|
|
|
|
|
|
|
|
|
async def warm_common_queries():
    """Pre-populate the cache with frequently requested data.

    Returns the result of ``cache_manager.warm_cache`` when warming
    functions are registered, otherwise a dict with a ``message`` entry
    explaining that nothing is configured.
    """
    # (coroutine function, kwargs) pairs to run; currently none registered.
    warming_functions = []

    if not warming_functions:
        return {"message": "No warming functions configured"}

    return await cache_manager.warm_cache(warming_functions)
|
|
|
|
|
|
|
|
|
|
|
|
async def get_cache_info() -> Dict[str, Any]:
    """Collect cache statistics together with Redis memory and connection details.

    Returns:
        On success, a dict with ``cache_stats``, ``redis_memory``,
        ``connection_info`` and ``timestamp``. If anything raises, the
        error is logged and a dict with ``error``, ``cache_stats`` and
        ``timestamp`` is returned instead.
    """
    try:
        memory_info = await redis_manager.redis.info("memory")

        stats = cache_manager.get_stats()
        memory_section = {
            "used_memory": memory_info.get("used_memory_human", "unknown"),
            "used_memory_peak": memory_info.get("used_memory_peak_human", "unknown"),
            "memory_fragmentation_ratio": memory_info.get("memory_fragmentation_ratio", 0),
        }
        connection = await redis_manager.get_connection_info()

        report = {
            "cache_stats": stats,
            "redis_memory": memory_section,
            "connection_info": connection,
            "timestamp": datetime.utcnow().isoformat(),
        }
        return report
    except Exception as e:
        logger.error(f"Failed to get cache info: {e}")

        # Still report the in-process counters when Redis is unreachable.
        failure_report = {"error": str(e)}
        failure_report["cache_stats"] = cache_manager.get_stats()
        failure_report["timestamp"] = datetime.utcnow().isoformat()
        return failure_report