"""
Redis Cache Implementation for Production
"""
import json
import hashlib
from functools import wraps
from typing import Any, Optional, Union

import redis.asyncio as aioredis

from src.core.config import settings
from src.core.logging import logger
from src.core.exceptions import CacheError

class RedisCache:
    """Redis cache manager with async support"""

    def __init__(self):
        self.redis: Optional[aioredis.Redis] = None
        self.enabled = settings.CACHE_PREDICTIONS

    async def connect(self):
        """Connect to Redis and verify the connection"""
        if not self.enabled:
            logger.info("Redis cache is disabled")
            return
        try:
            # from_url() builds the client synchronously (it is not a
            # coroutine); connections are opened lazily, so the ping below
            # is the actual connectivity test.
            self.redis = aioredis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
                max_connections=50,
            )
            await self.redis.ping()
            logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.enabled = False
            raise CacheError(f"Redis connection failed: {e}")

    async def disconnect(self):
        """Disconnect from Redis"""
        if self.redis:
            await self.redis.close()
            logger.info("Disconnected from Redis")
    def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
        """Generate a deterministic cache key from data"""
        if isinstance(data, dict):
            # sort_keys keeps the key stable regardless of dict insertion order
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)
        hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
        return f"{prefix}:{hash_value}"

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache"""
        if not self.enabled or not self.redis:
            return None
        try:
            value = await self.redis.get(key)
            if value is not None:
                logger.debug(f"Cache hit: {key}")
                return json.loads(value)
            logger.debug(f"Cache miss: {key}")
            return None
        except Exception as e:
            logger.warning(f"Cache get error for {key}: {e}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None
    ) -> bool:
        """Set value in cache with TTL"""
        if not self.enabled or not self.redis:
            return False
        try:
            ttl = ttl or settings.CACHE_TTL
            value_json = json.dumps(value)
            await self.redis.setex(key, ttl, value_json)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.warning(f"Cache set error for {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete key from cache"""
        if not self.enabled or not self.redis:
            return False
        try:
            await self.redis.delete(key)
            logger.debug(f"Cache delete: {key}")
            return True
        except Exception as e:
            logger.warning(f"Cache delete error for {key}: {e}")
            return False

    async def get_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict]
    ) -> Optional[dict]:
        """Get cached prediction"""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.get(key)

    async def set_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict],
        result: dict,
        ttl: Optional[int] = None
    ) -> bool:
        """Cache prediction result"""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.set(key, result, ttl)

    async def increment_rate_limit(
        self,
        identifier: str,
        window_seconds: int
    ) -> int:
        """Increment rate limit counter for a fixed window"""
        if not self.enabled or not self.redis:
            return 0
        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.incr(key)
            if count == 1:
                # Start the expiry clock only on the first hit in a window.
                # Resetting the TTL on every request would let a steady
                # stream of traffic keep the counter alive indefinitely.
                await self.redis.expire(key, window_seconds)
            logger.debug(f"Rate limit count for {identifier}: {count}")
            return count
        except Exception as e:
            logger.warning(f"Rate limit increment error: {e}")
            return 0

    async def get_rate_limit_count(self, identifier: str) -> int:
        """Get current rate limit count"""
        if not self.enabled or not self.redis:
            return 0
        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.get(key)
            return int(count) if count else 0
        except Exception as e:
            logger.warning(f"Rate limit get error: {e}")
            return 0
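
    # A minimal sketch of how a caller might enforce a fixed-window limit on
    # top of these counters; the 100-per-minute budget and RateLimitExceeded
    # are illustrative assumptions, not part of this module:
    #
    #   count = await cache.increment_rate_limit(f"user:{user_id}", window_seconds=60)
    #   if count > 100:
    #       raise RateLimitExceeded(retry_after=60)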

    async def clear_all(self) -> bool:
        """Clear all cache (use with caution!)"""
        if not self.enabled or not self.redis:
            return False
        try:
            await self.redis.flushdb()
            logger.warning("All cache cleared!")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False


# Global cache instance
cache = RedisCache()
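
# A minimal wiring sketch, assuming a FastAPI-style app with lifespan hooks
# (FastAPI itself is an assumption here; any startup/shutdown hook works):
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       await cache.connect()
#       yield
#       await cache.disconnect()
#
#   app = FastAPI(lifespan=lifespan)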

# Decorator for caching function results
def cached(prefix: str, ttl: Optional[int] = None):
    """Decorator to cache async function results"""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Generate cache key from function arguments
            cache_data = {"args": str(args), "kwargs": str(kwargs)}
            cache_key = cache._generate_cache_key(prefix, cache_data)
            # Try to get from cache; note that a cached None is
            # indistinguishable from a miss and will be re-computed
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result
            # Execute function
            result = await func(*args, **kwargs)
            # Cache result
            await cache.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
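
# Example usage of @cached (hypothetical: compute_embedding and model.embed
# are illustrative names, not defined in this codebase):
#
#   @cached("embeddings", ttl=600)
#   async def compute_embedding(text: str) -> dict:
#       vector = await model.embed(text)  # expensive call worth caching
#       return {"vector": vector}
#
# Repeated calls with the same arguments within the TTL return the cached
# result without re-running the function body.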

if __name__ == "__main__":
    import asyncio

    async def test_cache():
        # Connect
        await cache.connect()

        # Test basic operations
        await cache.set("test_key", {"value": 123}, ttl=60)
        result = await cache.get("test_key")
        print(f"Retrieved: {result}")

        # Test prediction caching
        await cache.set_prediction(
            "deepfake",
            {"image": "test.jpg"},
            {"prediction": "FAKE", "confidence": 0.95},
            ttl=300
        )
        cached_pred = await cache.get_prediction("deepfake", {"image": "test.jpg"})
        print(f"Cached prediction: {cached_pred}")

        # Test rate limiting
        for i in range(5):
            count = await cache.increment_rate_limit("user:123", 60)
            print(f"Request {i+1}: Rate limit count = {count}")

        # Disconnect
        await cache.disconnect()

    asyncio.run(test_cache())