ShreyasGosavi's picture
Upload 37 files
53bec59 verified
"""
Redis Cache Implementation for Production
"""
import functools
import hashlib
import json
from datetime import timedelta
from typing import Any, Optional, Union

import redis.asyncio as aioredis

from src.core.config import settings
from src.core.logging import logger
from src.core.exceptions import CacheError
class RedisCache:
    """Redis cache manager with async support.

    Wraps a ``redis.asyncio`` client and provides JSON-serialising get/set
    helpers, prediction-cache helpers and a fixed-window rate-limit counter.
    While the cache is disabled or has no client, every operation is a no-op
    that returns ``None``/``False``/``0``. Operational errors are logged rather
    than raised; the exception is ``connect()``, which raises ``CacheError``.
    """

    def __init__(self):
        # Client handle: set by connect(), cleared by disconnect() or a failed connect().
        self.redis: Optional[aioredis.Redis] = None
        # Opt-in via settings; connect() also turns this off if Redis is unreachable.
        self.enabled = settings.CACHE_PREDICTIONS

    async def connect(self):
        """Connect to Redis and verify the connection with ``PING``.

        Does nothing beyond logging when caching is disabled in settings.

        Raises:
            CacheError: if the client cannot be created or ``PING`` fails.
                Before raising, the cache is disabled and the half-built client
                is discarded, so later calls degrade to no-ops.
        """
        if not self.enabled:
            logger.info("Redis cache is disabled")
            return
        try:
            self.redis = await aioredis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
                max_connections=50,
            )
            # Test connection
            await self.redis.ping()
            logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.enabled = False
            await self._discard_client()
            raise CacheError(f"Redis connection failed: {e}") from e

    async def _discard_client(self) -> None:
        """Best-effort close of the current client; always leaves ``self.redis`` as ``None``."""
        client, self.redis = self.redis, None
        if client is None:
            return
        try:
            await client.close()
        except Exception as close_err:
            # Already on an error path: don't let cleanup mask the original failure.
            logger.debug(f"Ignoring error while closing Redis client: {close_err}")

    async def disconnect(self):
        """Close the Redis client, if any, and forget it.

        Safe to call repeatedly or without a prior successful ``connect()``.
        """
        client, self.redis = self.redis, None
        if client is None:
            return
        await client.close()
        logger.info("Disconnected from Redis")

    def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
        """Build a ``"{prefix}:{hash}"`` cache key for ``data``.

        Dicts are serialised with sorted keys, so key order does not affect the
        result. Any other value is passed through ``str()``. The hash is the first
        16 hex characters of the SHA-256 digest.

        Raises:
            TypeError: if ``data`` is a dict that is not JSON-serialisable.
        """
        if isinstance(data, dict):
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)
        hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
        return f"{prefix}:{hash_value}"

    async def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value stored at ``key``, or ``None``.

        ``None`` is returned on a miss, when the cache is unavailable, or when the
        read or decode fails (failures are logged, not raised). A stored JSON
        ``null`` is therefore indistinguishable from a miss.
        """
        if not self.enabled or not self.redis:
            return None
        try:
            value = await self.redis.get(key)
            # Explicit None check: Redis returns None for a missing key. Any other
            # payload, even an odd empty one, goes through the JSON decoder.
            if value is None:
                logger.debug(f"Cache miss: {key}")
                return None
            logger.debug(f"Cache hit: {key}")
            return json.loads(value)
        except Exception as e:
            logger.warning(f"Cache get error for {key}: {e}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None
    ) -> bool:
        """Store ``value`` as JSON under ``key`` with an expiry.

        Args:
            key: Redis key.
            value: any JSON-serialisable value.
            ttl: expiry in seconds. ``None`` or ``0`` falls back to
                ``settings.CACHE_TTL``.

        Returns:
            ``True`` if written. ``False`` if the cache is unavailable or the
            write or serialisation failed (logged).
        """
        if not self.enabled or not self.redis:
            return False
        try:
            ttl = ttl or settings.CACHE_TTL
            value_json = json.dumps(value)
            await self.redis.setex(key, ttl, value_json)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.warning(f"Cache set error for {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete ``key``.

        Returns ``True`` if the command succeeded, whether or not the key existed.
        Returns ``False`` if the cache is unavailable or the command failed.
        """
        if not self.enabled or not self.redis:
            return False
        try:
            await self.redis.delete(key)
            logger.debug(f"Cache delete: {key}")
            return True
        except Exception as e:
            logger.warning(f"Cache delete error for {key}: {e}")
            return False

    async def get_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict]
    ) -> Optional[dict]:
        """Return the cached prediction for ``(model_type, input_data)``, or ``None``."""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.get(key)

    async def set_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict],
        result: dict,
        ttl: Optional[int] = None
    ) -> bool:
        """Cache ``result`` as the prediction for ``(model_type, input_data)``.

        Uses the same key derivation as ``get_prediction``. ``ttl`` behaves as in ``set``.
        """
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.set(key, result, ttl)

    async def increment_rate_limit(
        self,
        identifier: str,
        window_seconds: int
    ) -> int:
        """Increment and return the fixed-window request counter for ``identifier``.

        The window starts at the first request: the expiry is armed only while
        the key has none, and the counter resets ``window_seconds`` later.
        Re-arming it on every call would keep postponing the reset, so a client
        that never paused would accumulate a counter that never resets.

        Returns:
            The post-increment count. Returns ``0`` if the cache is unavailable
            or the command failed (logged).
        """
        if not self.enabled or not self.redis:
            return 0
        try:
            key = f"ratelimit:{identifier}"
            # The pipeline is transactional by default, so INCR and TTL see the same key state.
            pipe = self.redis.pipeline()
            pipe.incr(key)
            pipe.ttl(key)
            count, ttl = await pipe.execute()
            # TTL == -1: the key exists with no expiry. Either INCR just created it,
            # or an earlier EXPIRE never ran. Arm the window now.
            if ttl == -1:
                await self.redis.expire(key, window_seconds)
            logger.debug(f"Rate limit count for {identifier}: {count}")
            return count
        except Exception as e:
            logger.warning(f"Rate limit increment error: {e}")
            return 0

    async def get_rate_limit_count(self, identifier: str) -> int:
        """Return the current counter for ``identifier``.

        Returns ``0`` if there is no counter, the cache is unavailable, or the read failed.
        """
        if not self.enabled or not self.redis:
            return 0
        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.get(key)
            return int(count) if count else 0
        except Exception as e:
            logger.warning(f"Rate limit get error: {e}")
            return 0

    async def clear_all(self) -> bool:
        """Run ``FLUSHDB`` on the connected Redis database (use with caution!).

        This removes every key in the database, not only entries written by
        ``set``. Rate-limit counters and any keys owned by other code are wiped too.
        """
        if not self.enabled or not self.redis:
            return False
        try:
            await self.redis.flushdb()
            logger.warning("All cache cleared!")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False
# Global cache instance shared by the ``cached`` decorator below.
# Built at import time, so ``settings.CACHE_PREDICTIONS`` is read when this
# module is imported. Until ``await cache.connect()`` succeeds, ``self.redis``
# is None and every operation is a no-op.
cache = RedisCache()
# Decorator for caching function results
def cached(prefix: str, ttl: Optional[int] = None):
"""Decorator to cache function results"""
def decorator(func):
async def wrapper(*args, **kwargs):
# Generate cache key from function arguments
cache_data = {"args": str(args), "kwargs": str(kwargs)}
cache_key = cache._generate_cache_key(prefix, cache_data)
# Try to get from cache
cached_result = await cache.get(cache_key)
if cached_result is not None:
return cached_result
# Execute function
result = await func(*args, **kwargs)
# Cache result
await cache.set(cache_key, result, ttl)
return result
return wrapper
return decorator
if __name__ == "__main__":
    import asyncio

    async def test_cache():
        """Smoke-test the global cache: key/value round trip, prediction cache, rate limiter."""
        await cache.connect()

        # Plain key/value round trip.
        await cache.set("test_key", {"value": 123}, ttl=60)
        fetched = await cache.get("test_key")
        print(f"Retrieved: {fetched}")

        # Prediction cache: store a result, then read it back with the same input.
        model_name = "deepfake"
        sample_input = {"image": "test.jpg"}
        sample_output = {"prediction": "FAKE", "confidence": 0.95}
        await cache.set_prediction(model_name, sample_input, sample_output, ttl=300)
        cached_pred = await cache.get_prediction(model_name, sample_input)
        print(f"Cached prediction: {cached_pred}")

        # Rate limiter: five hits against the same identifier within one window.
        for request_no in range(1, 6):
            count = await cache.increment_rate_limit("user:123", 60)
            print(f"Request {request_no}: Rate limit count = {count}")

        await cache.disconnect()

    asyncio.run(test_cache())