| | """
|
| | Redis Cache Implementation for Production
|
| | """
|
| |
|
import hashlib
import json
from functools import wraps
from typing import Any, Optional, Union

import redis.asyncio as aioredis

from src.core.config import settings
from src.core.exceptions import CacheError
from src.core.logging import logger


class RedisCache:
    """Redis cache manager with async support"""

    def __init__(self):
        self.redis: Optional[aioredis.Redis] = None
        self.enabled = settings.CACHE_PREDICTIONS

    async def connect(self):
        """Connect to Redis"""
        if not self.enabled:
            logger.info("Redis cache is disabled")
            return

        try:
            # redis-py's asyncio from_url() is synchronous: it only builds
            # the client and connection pool, so it must not be awaited.
            # The first awaited command (the ping below) opens the connection.
            self.redis = aioredis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
                max_connections=50
            )
            await self.redis.ping()
            logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.enabled = False
            raise CacheError(f"Redis connection failed: {e}")

    async def disconnect(self):
        """Disconnect from Redis"""
        if self.redis:
            await self.redis.close()
            logger.info("Disconnected from Redis")

    def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
        """Generate a cache key of the form "<prefix>:<16-hex-char hash>"."""
        if isinstance(data, dict):
            # sort_keys makes the serialization, and therefore the key,
            # deterministic regardless of dict insertion order
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)

        hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
        return f"{prefix}:{hash_value}"

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache"""
        if not self.enabled or not self.redis:
            return None

        try:
            value = await self.redis.get(key)
            if value:
                logger.debug(f"Cache hit: {key}")
                return json.loads(value)
            logger.debug(f"Cache miss: {key}")
            return None
        except Exception as e:
            logger.warning(f"Cache get error for {key}: {e}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None
    ) -> bool:
        """Set value in cache with TTL"""
        if not self.enabled or not self.redis:
            return False

        try:
            ttl = ttl or settings.CACHE_TTL
            value_json = json.dumps(value)
            await self.redis.setex(key, ttl, value_json)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.warning(f"Cache set error for {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete key from cache"""
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.delete(key)
            logger.debug(f"Cache delete: {key}")
            return True
        except Exception as e:
            logger.warning(f"Cache delete error for {key}: {e}")
            return False

    async def get_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict]
    ) -> Optional[dict]:
        """Get cached prediction"""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.get(key)

    async def set_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict],
        result: dict,
        ttl: Optional[int] = None
    ) -> bool:
        """Cache prediction result"""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.set(key, result, ttl)

    async def increment_rate_limit(
        self,
        identifier: str,
        window_seconds: int
    ) -> int:
        """Increment rate limit counter (fixed window)"""
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.incr(key)
            # Start the expiry window only when the key is first created.
            # Refreshing the TTL on every request would keep the counter
            # alive (and growing) indefinitely under steady traffic.
            if count == 1:
                await self.redis.expire(key, window_seconds)
            logger.debug(f"Rate limit count for {identifier}: {count}")
            return count
        except Exception as e:
            logger.warning(f"Rate limit increment error: {e}")
            return 0

    async def get_rate_limit_count(self, identifier: str) -> int:
        """Get current rate limit count"""
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.get(key)
            return int(count) if count else 0
        except Exception as e:
            logger.warning(f"Rate limit get error: {e}")
            return 0

    async def clear_all(self) -> bool:
        """Clear all cache (use with caution!)"""
        if not self.enabled or not self.redis:
            return False

        try:
            # flushdb wipes the entire Redis database, not just the keys
            # written by this module
            await self.redis.flushdb()
            logger.warning("All cache cleared!")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False

# Module-level cache instance shared across the application
cache = RedisCache()
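# The instance is expected to be connected during application startup and
# closed on shutdown; a minimal sketch (the surrounding framework wiring is
# not shown in this module):
#
#     await cache.connect()     # on startup
#     await cache.disconnect()  # on shutdown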


def cached(prefix: str, ttl: Optional[int] = None):
    """Decorator to cache async function results"""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Key off the string form of the call arguments; this assumes
            # the arguments have stable, meaningful reprs. The result must
            # be JSON-serializable, since cache.set() stores it via json.dumps.
            cache_data = {"args": str(args), "kwargs": str(kwargs)}
            cache_key = cache._generate_cache_key(prefix, cache_data)

            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = await func(*args, **kwargs)
            await cache.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
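# Example usage (illustrative only; `fetch_report` is a hypothetical coroutine):
#
#     @cached(prefix="report", ttl=120)
#     async def fetch_report(report_id: int) -> dict:
#         ...  # expensive work, cached for 120 seconds per argument set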


if __name__ == "__main__":
    import asyncio

    async def test_cache():
        await cache.connect()

        # Basic set/get round trip
        await cache.set("test_key", {"value": 123}, ttl=60)
        result = await cache.get("test_key")
        print(f"Retrieved: {result}")

        # Cache and retrieve a model prediction
        await cache.set_prediction(
            "deepfake",
            {"image": "test.jpg"},
            {"prediction": "FAKE", "confidence": 0.95},
            ttl=300
        )
        cached_pred = await cache.get_prediction("deepfake", {"image": "test.jpg"})
        print(f"Cached prediction: {cached_pred}")

        # Exercise the fixed-window rate limiter
        for i in range(5):
            count = await cache.increment_rate_limit("user:123", 60)
            print(f"Request {i+1}: Rate limit count = {count}")
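
        # Exercise the @cached decorator; `square` is a toy coroutine added
        # here only as a sketch of the decorator in use
        @cached(prefix="square", ttl=60)
        async def square(n: int) -> dict:
            return {"n": n, "squared": n * n}

        first = await square(7)   # computed, then cached
        second = await square(7)  # served from the cache
        print(f"Decorator results: {first} / {second}")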

        await cache.disconnect()

    asyncio.run(test_cache())