Spaces:
Sleeping
Sleeping
| """ | |
| Redis Feature Store | |
| Real-time feature storage and computation for fraud detection. | |
| Implements stateful features that require historical context: | |
| - Sliding window transaction counts (O(log N)) | |
| - Exponential moving averages for spending (O(1)) | |
| Architecture: | |
| - Uses Redis Sorted Sets (ZSET) for time-based sliding windows | |
| - Uses Redis Strings for EMA computation with atomic operations | |
| - Connection pooling for low-latency concurrent requests | |
| Author: PayShield-ML Team | |
| """ | |
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple

import redis
from redis.client import Pipeline
from redis.connection import ConnectionPool
class RedisFeatureStore:
    """
    Redis-backed feature store for real-time fraud detection.

    Manages stateful features that cannot be computed from a single
    transaction alone. It solves the "Stateful Feature Problem" by
    maintaining rolling windows of user behavior.

    Features computed:
        1. **trans_count_24h**: Number of transactions in the last 24 hours.
           - Data structure: Redis Sorted Set (ZSET), score = timestamp
           - Complexity: O(log N) insert, O(log N) count
           - Key format: user:{user_id}:tx_history
        2. **avg_spend_24h**: Exponential moving average of spending.
           - Data structure: Redis String (float)
           - Complexity: O(1) update
           - Key format: user:{user_id}:avg_spend
           - Formula: EMA_new = alpha * amt_current + (1 - alpha) * EMA_old
           - alpha = 2 / (n + 1) where n = 24 (for a 24-hour window)

    Connection management:
        - Uses connection pooling to avoid per-request TCP overhead.
        - A single pool is shared by all commands issued via ``self.client``.

    Example:
        >>> store = RedisFeatureStore(host="localhost", port=6379)
        >>> store.add_transaction(
        ...     user_id="u12345",
        ...     amount=150.00,
        ...     timestamp=1234567890,
        ... )
        >>> store.get_features(user_id="u12345")
        {'trans_count_24h': 5.0, 'avg_spend_24h': 120.50}
    """

    # Length of the sliding window shared by both features (24 hours).
    WINDOW_SECONDS: int = 86400

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        ema_alpha: Optional[float] = None,
    ) -> None:
        """
        Initialize Redis Feature Store with connection pooling.

        Args:
            host: Redis server hostname.
            port: Redis server port.
            db: Redis database number (0-15).
            password: Redis password (if authentication enabled).
            max_connections: Maximum connections kept in the pool.
            decode_responses: If True, decode bytes to strings. Must be
                True for ``get_transaction_history``, which parses str
                ZSET members.
            ema_alpha: Exponential moving average smoothing factor.
                Default is 2/(24+1) ~= 0.08 for a 24-hour window.

        Raises:
            ConnectionError: If the initial PING to Redis fails.
        """
        # Shared connection pool: thread-safe for concurrent API requests.
        self.pool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_connect_timeout=2,  # seconds to establish a connection
            socket_timeout=1,  # seconds per blocking command
        )
        self.client = redis.Redis(connection_pool=self.pool)

        # EMA configuration: alpha = 2 / (n + 1) with n = 24 hours.
        self.ema_alpha: float = ema_alpha if ema_alpha is not None else 2.0 / (24 + 1)

        # TTL on every key (7 days = 604800 s) prevents unbounded memory
        # growth for users who stop transacting.
        self.key_ttl: int = 604800

        # Fail fast with an actionable message if Redis is unreachable.
        try:
            self.client.ping()
        except redis.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Redis at {host}:{port}. Ensure Redis is running. Error: {e}"
            ) from e

    def _get_tx_history_key(self, user_id: str) -> str:
        """Generate the Redis key for the transaction-history ZSET."""
        return f"user:{user_id}:tx_history"

    def _get_avg_spend_key(self, user_id: str) -> str:
        """Generate the Redis key for the average-spend EMA."""
        return f"user:{user_id}:avg_spend"

    def _update_ema(self, previous: Optional[float], amount: float) -> float:
        """
        Return the updated EMA value.

        Seeds the average with ``amount`` when there is no prior value;
        otherwise applies: alpha * amount + (1 - alpha) * previous.
        """
        if previous is None:
            return amount
        return self.ema_alpha * amount + (1 - self.ema_alpha) * previous

    @staticmethod
    def _parse_member(member: str) -> Tuple[int, float]:
        """
        Decode a ZSET member of the form "timestamp:amount[:nonce]".

        Accepts both the legacy two-field format and the current
        three-field format (the trailing nonce is ignored).
        """
        timestamp_str, amount_str = member.split(":", 2)[:2]
        return int(timestamp_str), float(amount_str)

    def add_transaction(self, user_id: str, amount: float, timestamp: Optional[int] = None) -> None:
        """
        Record a new transaction and update both features.

        The writes (ZADD, window trim, TTL refresh, EMA SET) are batched
        in a single pipelined MULTI/EXEC round trip. NOTE: the EMA update
        is a read-modify-write whose GET happens in a separate round trip,
        so two concurrent writers for the SAME user can lose one EMA
        update; counts are unaffected.

        Args:
            user_id: User identifier.
            amount: Transaction amount in USD.
            timestamp: Unix timestamp. If None, uses current time.

        Raises:
            redis.exceptions.RedisError: If a Redis operation fails.

        Example:
            >>> store.add_transaction("u12345", 150.00, 1234567890)
        """
        if timestamp is None:
            timestamp = int(time.time())

        window_start = timestamp - self.WINDOW_SECONDS
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # Read the current EMA before queuing the pipelined writes
        # (None on the very first transaction for this user).
        current_ema = self.client.get(avg_key)
        new_ema = self._update_ema(
            float(current_ema) if current_ema is not None else None, amount
        )

        # Random nonce makes each member unique: without it, two
        # transactions with identical (timestamp, amount) would collapse
        # into one ZSET member and undercount trans_count_24h.
        tx_member = f"{timestamp}:{amount}:{uuid.uuid4().hex[:8]}"

        # Batch all writes into one MULTI/EXEC round trip.
        pipe = self.client.pipeline()
        # 1. Add transaction to the sliding window (score = timestamp).
        pipe.zadd(tx_key, {tx_member: timestamp})
        # 2. Drop entries older than the 24-hour window.
        pipe.zremrangebyscore(tx_key, "-inf", window_start)
        # 3. Refresh TTLs so active users' keys never expire mid-use.
        pipe.expire(tx_key, self.key_ttl)
        # 4. Persist the new EMA.
        pipe.set(avg_key, new_ema)
        pipe.expire(avg_key, self.key_ttl)
        pipe.execute()

    def get_features(
        self, user_id: str, current_timestamp: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Retrieve real-time features for a user.

        This is the primary method called during inference. It returns
        the stateful features needed by the fraud detection model.

        Args:
            user_id: User identifier.
            current_timestamp: Current Unix timestamp. If None, uses
                system time.

        Returns:
            Dictionary containing:
                - trans_count_24h: Number of transactions in last 24 hours
                - avg_spend_24h: Exponential moving average of spending
                  (0.0 if the user has no recorded transactions)

        Example:
            >>> store.get_features("u12345")
            {'trans_count_24h': 5.0, 'avg_spend_24h': 120.50}
        """
        if current_timestamp is None:
            current_timestamp = int(time.time())

        window_start = current_timestamp - self.WINDOW_SECONDS
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # One pipelined round trip for both reads.
        pipe = self.client.pipeline()
        # ZCOUNT over the window is O(log N).
        pipe.zcount(tx_key, window_start, current_timestamp)
        pipe.get(avg_key)
        count_result, ema_result = pipe.execute()

        return {
            "trans_count_24h": float(count_result),
            "avg_spend_24h": float(ema_result) if ema_result is not None else 0.0,
        }

    def get_transaction_history(
        self, user_id: str, lookback_hours: int = 24
    ) -> List[Tuple[int, float]]:
        """
        Retrieve raw transaction history for a user.

        Useful for debugging and analytics; not typically used in
        inference. Note that entries older than 24 hours are trimmed on
        each write, so lookbacks beyond 24h only see what has not yet
        been trimmed.

        Args:
            user_id: User identifier.
            lookback_hours: How many hours of history to retrieve.

        Returns:
            List of (timestamp, amount) tuples, newest first.

        Example:
            >>> history = store.get_transaction_history("u12345", lookback_hours=48)
            >>> for ts, amt in history:
            ...     print(f"{ts}: ${amt:.2f}")
        """
        current_time = int(time.time())
        window_start = current_time - (lookback_hours * 3600)
        tx_key = self._get_tx_history_key(user_id)

        # ZRANGEBYSCORE ... WITHSCORES returns (member, score) pairs.
        raw_results = self.client.zrangebyscore(
            tx_key, window_start, current_time, withscores=True
        )

        transactions = [self._parse_member(member) for member, _score in raw_results]
        # Newest first.
        transactions.sort(reverse=True, key=lambda tx: tx[0])
        return transactions

    def delete_user_data(self, user_id: str) -> int:
        """
        Delete all feature data for a user.

        Used for GDPR compliance / right to be forgotten.

        Args:
            user_id: User identifier.

        Returns:
            Number of keys deleted (2 when both keys existed).

        Example:
            >>> deleted = store.delete_user_data("u12345")
            >>> print(f"Deleted {deleted} keys")
        """
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)
        return self.client.delete(tx_key, avg_key)

    def health_check(self) -> Dict[str, Any]:
        """
        Check Redis connection health and get statistics.

        Returns:
            Dictionary with health metrics; never raises — on failure
            returns {'status': 'unhealthy', 'error': ...}.

        Example:
            >>> store.health_check()
            {'status': 'healthy', 'ping_ms': 0.5, 'connected_clients': 10, ...}
        """
        try:
            start = time.time()
            self.client.ping()
            ping_ms = (time.time() - start) * 1000

            # Fetch the full INFO payload: 'connected_clients' lives in
            # the Clients section and 'total_commands_processed' in Stats,
            # so INFO("stats") alone would always miss the former.
            info = self.client.info()
            return {
                "status": "healthy",
                "ping_ms": round(ping_ms, 2),
                "connected_clients": info.get("connected_clients", -1),
                "total_commands_processed": info.get("total_commands_processed", -1),
            }
        except Exception as e:
            # A health check must report failure, never propagate it.
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """
        Close the connection pool.

        Call this when shutting down the application.
        """
        self.pool.disconnect()
# Public API of this module.
__all__ = ["RedisFeatureStore"]