""" Redis Cache Implementation for Production """ import json import hashlib from typing import Any, Optional, Union from datetime import timedelta import redis.asyncio as aioredis from src.core.config import settings from src.core.logging import logger from src.core.exceptions import CacheError class RedisCache: """Redis cache manager with async support""" def __init__(self): self.redis: Optional[aioredis.Redis] = None self.enabled = settings.CACHE_PREDICTIONS async def connect(self): """Connect to Redis""" if not self.enabled: logger.info("Redis cache is disabled") return try: self.redis = await aioredis.from_url( settings.REDIS_URL, encoding="utf-8", decode_responses=True, max_connections=50 ) # Test connection await self.redis.ping() logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}") except Exception as e: logger.error(f"Failed to connect to Redis: {e}") self.enabled = False raise CacheError(f"Redis connection failed: {e}") async def disconnect(self): """Disconnect from Redis""" if self.redis: await self.redis.close() logger.info("Disconnected from Redis") def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str: """Generate cache key from data""" if isinstance(data, dict): data_str = json.dumps(data, sort_keys=True) else: data_str = str(data) hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16] return f"{prefix}:{hash_value}" async def get(self, key: str) -> Optional[Any]: """Get value from cache""" if not self.enabled or not self.redis: return None try: value = await self.redis.get(key) if value: logger.debug(f"Cache hit: {key}") return json.loads(value) logger.debug(f"Cache miss: {key}") return None except Exception as e: logger.warning(f"Cache get error for {key}: {e}") return None async def set( self, key: str, value: Any, ttl: Optional[int] = None ) -> bool: """Set value in cache with TTL""" if not self.enabled or not self.redis: return False try: ttl = ttl or settings.CACHE_TTL value_json = json.dumps(value) await self.redis.setex(key, ttl, value_json) logger.debug(f"Cache set: {key} (TTL: {ttl}s)") return True except Exception as e: logger.warning(f"Cache set error for {key}: {e}") return False async def delete(self, key: str) -> bool: """Delete key from cache""" if not self.enabled or not self.redis: return False try: await self.redis.delete(key) logger.debug(f"Cache delete: {key}") return True except Exception as e: logger.warning(f"Cache delete error for {key}: {e}") return False async def get_prediction( self, model_type: str, input_data: Union[str, dict] ) -> Optional[dict]: """Get cached prediction""" key = self._generate_cache_key(f"pred:{model_type}", input_data) return await self.get(key) async def set_prediction( self, model_type: str, input_data: Union[str, dict], result: dict, ttl: Optional[int] = None ) -> bool: """Cache prediction result""" key = self._generate_cache_key(f"pred:{model_type}", input_data) return await self.set(key, result, ttl) async def increment_rate_limit( self, identifier: str, window_seconds: int ) -> int: """Increment rate limit counter""" if not self.enabled or not self.redis: return 0 try: key = f"ratelimit:{identifier}" pipe = self.redis.pipeline() pipe.incr(key) pipe.expire(key, window_seconds) result = await pipe.execute() count = result[0] logger.debug(f"Rate limit count for {identifier}: {count}") return count except Exception as e: logger.warning(f"Rate limit increment error: {e}") return 0 async def get_rate_limit_count(self, identifier: str) -> int: """Get current rate limit count""" if not self.enabled or not self.redis: return 0 try: key = f"ratelimit:{identifier}" count = await self.redis.get(key) return int(count) if count else 0 except Exception as e: logger.warning(f"Rate limit get error: {e}") return 0 async def clear_all(self) -> bool: """Clear all cache (use with caution!)""" if not self.enabled or not self.redis: return False try: await self.redis.flushdb() logger.warning("All cache cleared!") return True except Exception as e: logger.error(f"Cache clear error: {e}") return False # Global cache instance cache = RedisCache() # Decorator for caching function results def cached(prefix: str, ttl: Optional[int] = None): """Decorator to cache function results""" def decorator(func): async def wrapper(*args, **kwargs): # Generate cache key from function arguments cache_data = {"args": str(args), "kwargs": str(kwargs)} cache_key = cache._generate_cache_key(prefix, cache_data) # Try to get from cache cached_result = await cache.get(cache_key) if cached_result is not None: return cached_result # Execute function result = await func(*args, **kwargs) # Cache result await cache.set(cache_key, result, ttl) return result return wrapper return decorator if __name__ == "__main__": import asyncio async def test_cache(): # Connect await cache.connect() # Test basic operations await cache.set("test_key", {"value": 123}, ttl=60) result = await cache.get("test_key") print(f"Retrieved: {result}") # Test prediction caching await cache.set_prediction( "deepfake", {"image": "test.jpg"}, {"prediction": "FAKE", "confidence": 0.95}, ttl=300 ) cached_pred = await cache.get_prediction("deepfake", {"image": "test.jpg"}) print(f"Cached prediction: {cached_pred}") # Test rate limiting for i in range(5): count = await cache.increment_rate_limit("user:123", 60) print(f"Request {i+1}: Rate limit count = {count}") # Disconnect await cache.disconnect() asyncio.run(test_cache())