| | """
|
| | Redis Client Module.
|
| |
|
| | Provides session state management with TTL for:
|
| | - Active honeypot sessions
|
| | - Conversation context caching
|
| | - Rate limiting counters
|
| |
|
| | Implements Task 6.2 requirements:
|
| | - AC-2.3.1: State persists across API calls
|
| | - AC-2.3.2: Session expires after 1 hour
|
| | - AC-2.3.4: Redis failure degrades gracefully
|
| | """
|
| |
|
| | from typing import Dict, Optional, Any, Callable, TypeVar
|
| | import json
|
| | import os
|
| | import time
|
| | from functools import wraps
|
| | import redis
|
| | from redis.exceptions import ConnectionError as RedisConnectionError, RedisError
|
| |
|
| | from app.config import settings
|
| | from app.utils.logger import get_logger
|
| |
|
logger = get_logger(__name__)

# Default session lifetime in seconds (one hour, per AC-2.3.2).
DEFAULT_SESSION_TTL = 3600

# Seconds to wait before retrying a Redis connection that previously failed.
_REDIS_RECHECK_INTERVAL = 60

# Generic type variable; not referenced in the functions of this section.
T = TypeVar("T")

# Shared Redis client. Stays None until init_redis_client() succeeds.
redis_client: Optional[redis.Redis] = None

# Negative-result cache for connectivity: when set, get_redis_client()
# fails fast until _REDIS_RECHECK_INTERVAL seconds after _redis_last_check.
_redis_unavailable: bool = False
_redis_last_check: float = 0

# In-process fallback store used when Redis cannot be reached.
#   _fallback_cache:     "session:<id>" -> stored state dict
#   _fallback_cache_ttl: "session:<id>" -> absolute expiry (time.time() seconds)
_fallback_cache: Dict[str, Dict[str, Any]] = {}
_fallback_cache_ttl: Dict[str, float] = {}
|
| |
|
| |
|
def init_redis_client() -> None:
    """
    Initialize the module-level Redis client from configuration.

    Does nothing if a client already exists. If ``settings.REDIS_URL`` is
    empty, a warning is logged and the client is left as ``None``; no
    exception is raised in that case. Otherwise a client is built, verified
    with ``PING``, and only then published as ``redis_client``.

    Raises:
        RedisError: If building the client or the ``PING`` fails
            (``redis.exceptions.ConnectionError`` is a subclass).
    """
    global redis_client

    if redis_client is not None:
        return

    redis_url = settings.REDIS_URL
    if not redis_url:
        logger.warning("REDIS_URL not configured. Redis operations will fail.")
        return

    client: Optional[redis.Redis] = None
    try:
        client = redis.from_url(
            redis_url,
            decode_responses=True,
            socket_connect_timeout=1,
            socket_timeout=1,
            retry_on_timeout=False,
            health_check_interval=60,
        )
        client.ping()
    except RedisError as e:
        logger.error(f"Failed to initialize Redis client: {e}")
        if client is not None:
            # Release sockets held by the pool that from_url() created.
            try:
                client.connection_pool.disconnect()
            except Exception as cleanup_err:  # best-effort cleanup only
                logger.debug(f"Ignoring error while closing Redis pool: {cleanup_err}")
        raise

    # Publish the client only after its connectivity has been confirmed,
    # so no caller ever observes an unverified client.
    redis_client = client
    logger.info("Redis client initialized successfully")
|
| |
|
| |
|
def get_redis_client() -> redis.Redis:
    """
    Return the shared Redis client, initializing it on first use.

    After a failed initialization attempt, the module records that Redis is
    unavailable. For the next ``_REDIS_RECHECK_INTERVAL`` seconds, calls fail
    fast without touching the network. Once that interval has elapsed, the
    next call tries to initialize again.

    Returns:
        The initialized ``redis.Redis`` client.

    Raises:
        ConnectionError: (built-in) If Redis was marked unavailable less than
            ``_REDIS_RECHECK_INTERVAL`` seconds ago, or if no client could be
            created because ``REDIS_URL`` is not configured.
        RedisError: Propagated from init_redis_client() when connecting or
            pinging fails. Any other exception raised there also propagates
            after Redis is marked unavailable.
    """
    global _redis_unavailable, _redis_last_check

    if _redis_unavailable:
        if time.time() - _redis_last_check < _REDIS_RECHECK_INTERVAL:
            raise ConnectionError("Redis unavailable (cached)")
        # Recheck window has elapsed: allow a fresh connection attempt.
        _redis_unavailable = False

    if redis_client is None:
        try:
            init_redis_client()
        except Exception:
            _redis_unavailable = True
            _redis_last_check = time.time()
            raise

    if redis_client is None:
        # init_redis_client() returns without a client when REDIS_URL is unset.
        _redis_unavailable = True
        _redis_last_check = time.time()
        raise ConnectionError("Redis client not initialized. Check REDIS_URL configuration.")

    return redis_client
|
| |
|
| |
|
def save_session_state(session_id: str, state: Dict[str, Any], ttl: int = 3600) -> bool:
    """
    Save session state to Redis under ``session:<session_id>`` with a TTL.

    Args:
        session_id: Unique session identifier
        state: Session state dictionary; must be JSON-serializable
        ttl: Time-to-live in seconds (default 3600, i.e. 1 hour)

    Returns:
        True if the state was written. False if it could not be serialized,
        or if Redis was unavailable or returned an error (the cause is logged).
    """
    try:
        payload = json.dumps(state)
    except (TypeError, ValueError) as e:
        # Non-serializable values or circular references would otherwise
        # escape as an exception despite the boolean contract.
        logger.error(f"Failed to serialize session state: {e}")
        return False

    try:
        client = get_redis_client()
        client.setex(f"session:{session_id}", ttl, payload)
        return True
    except (ConnectionError, RedisError) as e:
        logger.error(f"Failed to save session state: {e}")
        return False
|
| |
|
| |
|
def get_session_state(session_id: str) -> Optional[Dict[str, Any]]:
    """
    Retrieve session state from Redis.

    Args:
        session_id: Session identifier

    Returns:
        The stored state dictionary. Returns None if the key is missing or
        expired, if Redis is unavailable, or if the stored value is not
        valid JSON or not a JSON object.
    """
    try:
        client = get_redis_client()
        data = client.get(f"session:{session_id}")
    except (ConnectionError, RedisError) as e:
        logger.error(f"Failed to get session state: {e}")
        return None

    if not data:
        return None

    try:
        state = json.loads(data)
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        logger.error(f"Failed to decode session state JSON: {e}")
        return None

    if not isinstance(state, dict):
        # Callers (e.g. update_session_state) rely on a dict being returned.
        logger.error(
            f"Session state for {session_id} is not a JSON object "
            f"(got {type(state).__name__})"
        )
        return None

    return state
|
| |
|
| |
|
def delete_session_state(session_id: str) -> bool:
    """
    Remove a session's state from Redis.

    Args:
        session_id: Session identifier

    Returns:
        True when Redis reports the key was removed. False when the key did
        not exist or the Redis call failed (the failure is logged).
    """
    try:
        removed_count = get_redis_client().delete(f"session:{session_id}")
        return removed_count > 0
    except (ConnectionError, RedisError) as err:
        logger.error(f"Failed to delete session state: {err}")
    return False
|
| |
|
| |
|
def update_session_state(session_id: str, updates: Dict[str, Any]) -> bool:
    """
    Merge ``updates`` into an existing session's stored state.

    The current state is read, shallow-merged with ``updates`` via
    ``dict.update``, and written back through save_session_state() with its
    default TTL. Each successful update therefore restarts the expiry
    window. The read and the write are separate Redis calls, so concurrent
    updates to the same session can overwrite each other.

    Args:
        session_id: Session identifier
        updates: Fields to update

    Returns:
        True if the merged state was saved. False if the session could not
        be read (missing, expired, or Redis error) or the save failed.
    """
    current = get_session_state(session_id)
    if current is not None:
        current.update(updates)
        return save_session_state(session_id, current)
    return False
|
| |
|
| |
|
def increment_rate_counter(key: str, window_seconds: int = 60) -> int:
    """
    Increment a fixed-window rate limiting counter stored at ``rate_limit:<key>``.

    Args:
        key: Counter key (e.g., IP address)
        window_seconds: Lifetime of the counter, set when it has no expiry

    Returns:
        Current count within the window, or 0 if Redis is unavailable or
        returns an error (the error is logged).
    """
    counter_key = f"rate_limit:{key}"
    try:
        client = get_redis_client()
        # INCR and TTL are sent in a single MULTI/EXEC round-trip.
        with client.pipeline() as pipe:
            pipe.incr(counter_key)
            pipe.ttl(counter_key)
            count, remaining = pipe.execute()
        # TTL == -1 means the key exists without an expiry. Either this call
        # just created it, or an earlier EXPIRE never happened (e.g. it failed
        # or the process died right after INCR). Setting it here prevents an
        # immortal counter from permanently rate-limiting ``key``.
        if remaining == -1:
            client.expire(counter_key, window_seconds)
        return count
    except (ConnectionError, RedisError) as e:
        logger.error(f"Failed to increment rate counter: {e}")
        return 0
|
| |
|
| |
|
def check_rate_limit(key: str, limit: int, window_seconds: int = 60) -> bool:
    """
    Count one request against ``key`` and report whether it is still allowed.

    Args:
        key: Counter key
        limit: Maximum allowed requests per window
        window_seconds: Time window passed to increment_rate_counter()

    Returns:
        True if the updated count is <= ``limit``, False otherwise. Any
        exception is logged and treated as allowed (fail open).
    """
    try:
        within_limit = increment_rate_counter(key, window_seconds) <= limit
    except Exception as exc:
        logger.error(f"Failed to check rate limit: {exc}")
        within_limit = True  # fail open: an error never blocks the caller
    return within_limit
|
| |
|
| |
|
def health_check() -> bool:
    """
    Probe Redis with a ``PING``.

    Returns:
        True if the client could be obtained and the ping did not raise.
        False on any exception: connection and Redis errors are logged as
        warnings, anything else as an error.
    """
    try:
        get_redis_client().ping()
    except (ConnectionError, RedisError) as exc:
        logger.warning(f"Redis health check failed: {exc}")
        return False
    except Exception as exc:
        logger.error(f"Unexpected error in Redis health check: {exc}")
        return False
    return True
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _cleanup_fallback_cache() -> None:
    """Purge fallback-cache entries whose expiry timestamp is already in the past."""
    now = time.time()
    # Snapshot expired keys first; the dicts are mutated below.
    expired = [k for k in _fallback_cache_ttl if _fallback_cache_ttl[k] < now]
    for expired_key in expired:
        for store in (_fallback_cache, _fallback_cache_ttl):
            store.pop(expired_key, None)
|
| |
|
| |
|
def save_session_state_with_fallback(
    session_id: str,
    state: Dict[str, Any],
    ttl: int = DEFAULT_SESSION_TTL,
) -> bool:
    """
    Save session state to Redis, falling back to an in-memory cache.

    Implements AC-2.3.4: Redis failure degrades gracefully.

    Args:
        session_id: Unique session identifier
        state: Session state dictionary (JSON-serializable)
        ttl: Time-to-live in seconds (default 1 hour per AC-2.3.2)

    Returns:
        True if saved (Redis or fallback), False on complete failure
    """
    key = f"session:{session_id}"

    if save_session_state(session_id, state, ttl):
        # Redis now holds the authoritative copy. Drop any older in-memory
        # copy so a later Redis outage cannot resurface stale state.
        _fallback_cache.pop(key, None)
        _fallback_cache_ttl.pop(key, None)
        return True

    logger.warning(f"Redis unavailable, using fallback cache for session {session_id}")
    try:
        _cleanup_fallback_cache()
        # JSON round-trip: fully detaches the cached copy from the caller's
        # nested mutable objects and mirrors what Redis would have stored.
        _fallback_cache[key] = json.loads(json.dumps(state))
        _fallback_cache_ttl[key] = time.time() + ttl
        return True
    except Exception as e:
        logger.error(f"Fallback cache failed: {e}")
        return False
|
| |
|
| |
|
def get_session_state_with_fallback(session_id: str) -> Optional[Dict[str, Any]]:
    """
    Get session state from Redis, falling back to the in-memory cache.

    Implements AC-2.3.4: Redis failure degrades gracefully.

    Args:
        session_id: Session identifier

    Returns:
        The session state, or None if it is not found or has expired. A
        fallback hit returns a deep copy, so callers cannot mutate the
        cached entry. This matches the Redis path, which always returns a
        freshly decoded object.
    """
    import copy

    state = get_session_state(session_id)
    if state is not None:
        logger.debug(f"Session {session_id} found in Redis")
        return state

    _cleanup_fallback_cache()
    key = f"session:{session_id}"
    if key not in _fallback_cache:
        return None

    if _fallback_cache_ttl.get(key, 0) <= time.time():
        # Expired (or missing TTL record): evict and report not found.
        _fallback_cache.pop(key, None)
        _fallback_cache_ttl.pop(key, None)
        return None

    logger.debug(f"Session {session_id} retrieved from fallback cache")
    return copy.deepcopy(_fallback_cache[key])
|
| |
|
| |
|
def delete_session_state_with_fallback(session_id: str) -> bool:
    """
    Remove a session from both Redis and the in-memory fallback cache.

    Args:
        session_id: Session identifier

    Returns:
        True if the session was present in at least one of the two stores.
    """
    redis_removed = delete_session_state(session_id)

    cache_key = f"session:{session_id}"
    memory_removed = cache_key in _fallback_cache
    for store in (_fallback_cache, _fallback_cache_ttl):
        store.pop(cache_key, None)

    return redis_removed or memory_removed
|
| |
|
| |
|
def extend_session_ttl(session_id: str, additional_seconds: int = DEFAULT_SESSION_TTL) -> bool:
    """
    Push back the expiry of a Redis-backed session.

    Reads the key's remaining TTL and re-sets it to that value plus
    ``additional_seconds``. Only Redis keys are touched; in-memory fallback
    entries are not extended.

    Args:
        session_id: Session identifier
        additional_seconds: Seconds to add to the remaining TTL

    Returns:
        True if the TTL was extended. False if the key has no positive TTL
        (missing or without expiry) or a Redis error occurred (logged).
    """
    try:
        client = get_redis_client()
        session_key = f"session:{session_id}"
        remaining = client.ttl(session_key)
        if remaining <= 0:
            return False
        client.expire(session_key, remaining + additional_seconds)
        logger.debug(f"Session {session_id} TTL extended by {additional_seconds}s")
        return True
    except (ConnectionError, RedisError) as exc:
        logger.error(f"Failed to extend session TTL: {exc}")
        return False
|
| |
|
| |
|
def get_session_ttl(session_id: str) -> int:
    """
    Report how many seconds a Redis-backed session has left.

    Args:
        session_id: Session identifier

    Returns:
        The value of Redis ``TTL`` for ``session:<session_id>``: remaining
        seconds, -2 if the key does not exist, or -1 if it has no expiry.
        On a connection or Redis error, -2 is returned (and logged).
    """
    try:
        return get_redis_client().ttl(f"session:{session_id}")
    except (ConnectionError, RedisError) as exc:
        logger.error(f"Failed to get session TTL: {exc}")
    return -2
|
| |
|
| |
|
def get_active_session_count() -> int:
    """
    Get the number of active sessions.

    Counts ``session:*`` keys in Redis. If Redis is unavailable, counts
    the unexpired ``session:*`` entries in the in-memory fallback cache.

    Returns:
        Number of active sessions
    """
    try:
        client = get_redis_client()
        # SCAN walks the keyspace incrementally instead of blocking the
        # server like KEYS does. It may yield a key more than once, so the
        # results are de-duplicated before counting.
        return len(set(client.scan_iter(match="session:*", count=500)))
    except (ConnectionError, RedisError) as e:
        logger.error(f"Failed to get active session count: {e}")

    _cleanup_fallback_cache()
    return sum(1 for k in _fallback_cache if k.startswith("session:"))
|
| |
|
| |
|
def clear_all_sessions() -> int:
    """
    Clear all session keys from Redis (for testing/admin purposes).

    Keys are discovered with SCAN and deleted in bounded batches. This
    avoids the server-blocking ``KEYS`` command and a single unbounded
    ``DEL`` call. The in-memory fallback cache is not modified.

    Returns:
        Number of sessions deleted from Redis (0 on error, which is logged)
    """
    batch_size = 500
    try:
        client = get_redis_client()
        deleted = 0
        batch = []
        # Deleting keys that SCAN has already returned is safe; SCAN still
        # visits every key that exists for the whole iteration.
        for key in client.scan_iter(match="session:*", count=batch_size):
            batch.append(key)
            if len(batch) >= batch_size:
                deleted += client.delete(*batch)
                batch = []
        if batch:
            deleted += client.delete(*batch)
        if deleted:
            logger.info(f"Cleared {deleted} sessions from Redis")
        return deleted
    except (ConnectionError, RedisError) as e:
        logger.error(f"Failed to clear sessions: {e}")
        return 0
|
| |
|
| |
|
def reset_fallback_cache() -> None:
    """
    Empty the in-memory fallback cache (for testing).

    Both dicts are cleared in place rather than rebound. Code that imported
    or captured a reference to them therefore sees the reset instead of
    keeping stale entries.
    """
    _fallback_cache.clear()
    _fallback_cache_ttl.clear()
|
| |
|
| |
|
def get_fallback_cache_stats() -> Dict[str, Any]:
    """
    Summarize the in-memory fallback cache after purging expired entries.

    Returns:
        A dict with ``entries`` (number of cached keys) and
        ``total_size_bytes`` (sum of each entry's ``json.dumps`` length).
    """
    _cleanup_fallback_cache()
    encoded_lengths = (len(json.dumps(entry)) for entry in _fallback_cache.values())
    stats: Dict[str, Any] = {}
    stats["entries"] = len(_fallback_cache)
    stats["total_size_bytes"] = sum(encoded_lengths)
    return stats
|
| |
|
| |
|
def is_redis_available() -> bool:
    """
    Report whether Redis currently answers, as a plain boolean.

    Delegates to health_check(), which logs and converts any ``Exception``
    into False.

    Returns:
        True if Redis is reachable, False otherwise
    """
    reachable = health_check()
    return reachable
|
| |
|