""" Redis Cache Service for Gapura AI Provides caching layer for Google Sheets data and predictions """ import os import time import json import logging from typing import Optional, Any, Callable from datetime import timedelta from functools import wraps import redis from redis.exceptions import RedisError logger = logging.getLogger(__name__) class CacheService: """Redis-based caching service with L1 In-Memory and L2 Redis layers""" def __init__(self): self.redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") self.client = None # In-memory cache stores: key -> (value, expiry_timestamp) self.in_memory_cache = {} self.enabled = os.getenv("CACHE_ENABLED", "true").lower() == "true" self.backend = os.getenv("CACHE_BACKEND", "").lower() # 'redis' | 'memory' | '' # Auto-detect HF Spaces and prefer memory backend if REDIS_URL is not set if (os.getenv("SPACE_ID") or os.getenv("HF_TOKEN")) and not os.getenv("REDIS_URL"): if self.backend == "": self.backend = "memory" if self.enabled and self.backend != "memory": self._connect() def _connect(self): """Connect to Redis with in-memory fallback""" if not self.enabled: logger.info("Cache disabled via environment variable") return if self.backend == "memory": logger.info("Cache backend set to memory; skipping Redis connection") self.client = None return try: self.client = redis.from_url( self.redis_url, decode_responses=True, socket_connect_timeout=2, socket_timeout=2, ) self.client.ping() logger.info(f"Connected to Redis at {self.redis_url}") except RedisError as e: logger.warning(f"Failed to connect to Redis: {e}. Falling back to in-memory cache.") self.client = None def get(self, key: str) -> Optional[Any]: """Get value from cache (L1 Memory -> L2 Redis)""" if not self.enabled: return None # 1. Try L1 In-Memory Cache first if key in self.in_memory_cache: entry = self.in_memory_cache[key] # Check for tuple (value, expiry) if isinstance(entry, tuple) and len(entry) == 2: value, expiry = entry if expiry > time.time(): logger.debug(f"L1 Memory HIT: {key}") return value else: # Expired, remove it del self.in_memory_cache[key] else: # Legacy format (just value), return it logger.debug(f"L1 Memory HIT (Legacy): {key}") return entry # 2. Try L2 Redis Cache if self.client: try: value_str = self.client.get(key) if value_str: logger.debug(f"L2 Redis HIT: {key}") data = json.loads(value_str) # Backfill L1 Memory # Try to get TTL from Redis to sync expiration try: ttl = self.client.ttl(key) if ttl > 0: self.in_memory_cache[key] = (data, time.time() + ttl) else: self.in_memory_cache[key] = (data, time.time() + 300) # Default 5m except: self.in_memory_cache[key] = (data, time.time() + 300) return data except (RedisError, json.JSONDecodeError) as e: logger.warning(f"Redis get error for {key}: {e}") logger.debug(f"Cache MISS: {key}") return None def set(self, key: str, value: Any, ttl_seconds: int = 300) -> bool: """Set value in cache (L1 Memory + L2 Redis)""" if not self.enabled: return False success = False # 1. Set L2 Redis if self.client: try: serialized = json.dumps(value, default=str) self.client.setex(key, ttl_seconds, serialized) logger.debug(f"L2 Redis SET: {key}") success = True except (RedisError, TypeError) as e: logger.warning(f"Redis set error for {key}: {e}") # 2. Set L1 Memory expiry = time.time() + ttl_seconds self.in_memory_cache[key] = (value, expiry) # Memory Management: Simple FIFO if too large if len(self.in_memory_cache) > 1000: # Cleanup expired items first now = time.time() expired_keys = [k for k, v in self.in_memory_cache.items() if isinstance(v, tuple) and v[1] < now] for k in expired_keys: del self.in_memory_cache[k] # If still too big, remove oldest inserted if len(self.in_memory_cache) > 1000: first_key = next(iter(self.in_memory_cache)) self.in_memory_cache.pop(first_key) success = True return success def delete(self, key: str) -> bool: """Delete key from cache""" if not self.enabled: return False # Delete from Redis if self.client: try: self.client.delete(key) except RedisError: pass # Delete from In-Memory if key in self.in_memory_cache: del self.in_memory_cache[key] return True def delete_pattern(self, pattern: str) -> int: """Delete all keys matching pattern""" if not self.enabled: return 0 deleted_count = 0 # 1. Delete from Redis if self.client: try: keys = self.client.keys(pattern) if keys: deleted_count = self.client.delete(*keys) logger.debug(f"Redis DELETE pattern {pattern}: {deleted_count} keys") except RedisError as e: logger.warning(f"Cache delete pattern error for {pattern}: {e}") # 2. Delete from In-Memory (using simple string matching) try: mem_deleted = 0 if not self.in_memory_cache: return deleted_count # Simple wildcard matching token = pattern.replace("*", "") keys_to_check = list(self.in_memory_cache.keys()) for k in keys_to_check: if token in k: # Simple substring match for now del self.in_memory_cache[k] mem_deleted += 1 if mem_deleted > 0: logger.debug(f"In-memory DELETE pattern {pattern}: {mem_deleted} keys") # Return max of both (approximate) return max(deleted_count, mem_deleted) except Exception as e: logger.warning(f"In-memory delete pattern error for {pattern}: {e}") return deleted_count def health_check(self) -> dict: """Check cache health""" if not self.enabled: return {"status": "disabled", "message": "Caching is disabled"} status = { "backend": self.backend if self.backend else ("redis" if self.client else "memory"), "l1_items": len(self.in_memory_cache), } if self.client: try: self.client.ping() info = self.client.info("memory") status.update({ "redis_status": "connected", "redis_used_memory": info.get("used_memory_human", "unknown"), "redis_clients": self.client.client_list().__len__(), }) except RedisError as e: status["redis_status"] = f"error: {str(e)}" else: status["redis_status"] = "not_configured" return status def cached(key_prefix: str, ttl_seconds: int = 300): """ Decorator for caching function results Usage: @cached("my_prefix", ttl_seconds=300) """ def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): bypass_cache = kwargs.pop("bypass_cache", False) # Use singleton instance! cache = get_cache() if bypass_cache: logger.debug(f"Cache bypassed for {key_prefix}") return func(*args, **kwargs) # Create a consistent cache key # Filter out authentication related args if present to avoid caching user-specifics if not needed # For now, just hash everything arg_str = str(args) + str(sorted(kwargs.items())) cache_key = f"{key_prefix}:{hash(arg_str)}" cached_result = cache.get(cache_key) if cached_result is not None: return cached_result result = func(*args, **kwargs) if result is not None: cache.set(cache_key, result, ttl_seconds) return result return wrapper return decorator _cache_instance: Optional[CacheService] = None def get_cache() -> CacheService: """Get singleton cache instance""" global _cache_instance if _cache_instance is None: _cache_instance = CacheService() return _cache_instance