# Author: Sibi Krishnamoorthy
# Environment: prod
# Commit: 8a08300
"""
Redis Feature Store
Real-time feature storage and computation for fraud detection.
Implements stateful features that require historical context:
- Sliding window transaction counts (O(log N))
- Exponential moving averages for spending (O(1))
Architecture:
- Uses Redis Sorted Sets (ZSET) for time-based sliding windows
- Uses Redis Strings for EMA computation with atomic operations
- Connection pooling for low-latency concurrent requests
Author: PayShield-ML Team
"""
import time
from typing import Any, Dict, List, Optional, Tuple

import redis
from redis.client import Pipeline
from redis.connection import ConnectionPool
class RedisFeatureStore:
    """
    Redis-backed feature store for real-time fraud detection.

    Manages stateful features that cannot be computed from a single
    transaction alone (the "Stateful Feature Problem") by maintaining
    rolling windows of per-user behavior.

    Features Computed:
    1. **trans_count_24h**: Number of transactions in last 24 hours
       - Data Structure: Redis Sorted Set (ZSET)
       - Complexity: O(log N) insert, O(log N + M) range query
       - Key Format: user:{user_id}:tx_history
    2. **avg_spend_24h**: Exponential moving average of spending
       - Data Structure: Redis String (float)
       - Complexity: O(1) update
       - Key Format: user:{user_id}:avg_spend
       - Formula: EMA_new = α * amt_current + (1-α) * EMA_old
       - α = 2/(n+1) where n=24 (for 24-hour window)

    Connection Management:
    - Uses connection pooling to avoid per-request TCP overhead
    - The pool may be shared across threads serving concurrent API requests

    Example:
        >>> store = RedisFeatureStore(host="localhost", port=6379)
        >>>
        >>> # Record a new transaction
        >>> store.add_transaction(
        ...     user_id="u12345",
        ...     amount=150.00,
        ...     timestamp=1234567890
        ... )
        >>>
        >>> # Get features for inference
        >>> features = store.get_features(user_id="u12345")
        >>> print(features)
        {'trans_count_24h': 5.0, 'avg_spend_24h': 120.50}
    """

    # Length of the sliding feature window in seconds (24 hours).
    WINDOW_SECONDS: int = 86400

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        ema_alpha: Optional[float] = None,
    ) -> None:
        """
        Initialize Redis Feature Store with connection pooling.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Redis database number (0-15)
            password: Redis password (if authentication enabled)
            max_connections: Maximum connections in pool
            decode_responses: If True, decode bytes to strings
            ema_alpha: Exponential moving average smoothing factor.
                Default is 2/(24+1) ≈ 0.08 for 24-hour window.

        Raises:
            ConnectionError: If the initial PING to Redis fails.
        """
        # Create connection pool for thread-safe access.
        self.pool: ConnectionPool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_connect_timeout=2,  # 2s connection timeout
            socket_timeout=1,  # 1s operation timeout
        )
        self.client: redis.Redis = redis.Redis(connection_pool=self.pool)
        # EMA configuration: α = 2/(n+1) for n=24 hours.
        self.ema_alpha: float = ema_alpha if ema_alpha is not None else 2.0 / (24 + 1)
        # TTL for keys (7 days = 604800 seconds).
        # Prevents unbounded memory growth from inactive users.
        self.key_ttl: int = 604800
        # Fail fast with an actionable message if Redis is unreachable.
        try:
            self.client.ping()
        except redis.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Redis at {host}:{port}. Ensure Redis is running. Error: {e}"
            ) from e

    def _get_tx_history_key(self, user_id: str) -> str:
        """Generate Redis key for transaction history ZSET."""
        return f"user:{user_id}:tx_history"

    def _get_avg_spend_key(self, user_id: str) -> str:
        """Generate Redis key for average spend EMA."""
        return f"user:{user_id}:avg_spend"

    def add_transaction(self, user_id: str, amount: float, timestamp: Optional[int] = None) -> None:
        """
        Record a new transaction and update the user's rolling features.

        Pipelines five commands into a single round trip:
        1. ZADD the transaction into the sliding-window sorted set
        2. ZREMRANGEBYSCORE to drop entries older than 24 hours
        3. EXPIRE on the history key (memory bound)
        4. SET the updated EMA
        5. EXPIRE on the EMA key

        NOTE(review): the EMA update is read-modify-write — the current EMA
        is fetched with a plain GET before the pipeline executes, so two
        concurrent calls for the same user can lose one update. This is
        acceptable for an approximate smoothing feature; use WATCH/MULTI or
        a Lua script if exact serialization is ever required.

        Args:
            user_id: User identifier
            amount: Transaction amount in USD
            timestamp: Unix timestamp. If None, uses current time.

        Raises:
            redis.exceptions.RedisError: If Redis operation fails

        Example:
            >>> store.add_transaction("u12345", 150.00, 1234567890)
        """
        if timestamp is None:
            timestamp = int(time.time())
        window_start = timestamp - self.WINDOW_SECONDS
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # Fetch the current EMA first; None means this is the user's first
        # transaction, so the EMA seeds to the raw amount.
        current_ema = self.client.get(avg_key)
        if current_ema is None:
            new_ema = amount
        else:
            # EMA formula: α * x_new + (1-α) * EMA_old
            new_ema = self.ema_alpha * amount + (1 - self.ema_alpha) * float(current_ema)

        # Pipeline the writes so they cost one network round trip.
        pipe: Pipeline = self.client.pipeline()
        # Member encodes "timestamp:amount" so different amounts in the same
        # second stay distinct. Identical (timestamp, amount) pairs collapse
        # into one member — ZADD overwrites the score of an existing member.
        tx_member = f"{timestamp}:{amount}"
        pipe.zadd(tx_key, {tx_member: timestamp})
        # Trim entries that have aged out of the 24-hour window.
        pipe.zremrangebyscore(tx_key, "-inf", window_start)
        # Refresh TTL on every write: active users never expire mid-use,
        # abandoned keys are reclaimed after key_ttl seconds.
        pipe.expire(tx_key, self.key_ttl)
        pipe.set(avg_key, new_ema)
        pipe.expire(avg_key, self.key_ttl)
        pipe.execute()

    def get_features(
        self, user_id: str, current_timestamp: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Retrieve real-time features for a user.

        This is the primary method called during inference. It returns
        the stateful features needed by the fraud detection model.

        Args:
            user_id: User identifier
            current_timestamp: Current Unix timestamp. If None, uses system time.

        Returns:
            Dictionary containing:
            - trans_count_24h: Number of transactions in last 24 hours
            - avg_spend_24h: Exponential moving average of spending
              (0.0 if the user has no recorded EMA)

        Example:
            >>> features = store.get_features("u12345")
            >>> print(features)
            {'trans_count_24h': 5.0, 'avg_spend_24h': 120.50}
        """
        if current_timestamp is None:
            current_timestamp = int(time.time())
        window_start = current_timestamp - self.WINDOW_SECONDS
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # One round trip for both reads.
        pipe: Pipeline = self.client.pipeline()
        # Count transactions in window (ZCOUNT is O(log N)).
        pipe.zcount(tx_key, window_start, current_timestamp)
        pipe.get(avg_key)
        results = pipe.execute()

        trans_count = float(results[0])
        avg_spend = float(results[1]) if results[1] is not None else 0.0
        return {
            "trans_count_24h": trans_count,
            "avg_spend_24h": avg_spend,
        }

    def get_transaction_history(
        self, user_id: str, lookback_hours: int = 24
    ) -> List[Tuple[int, float]]:
        """
        Retrieve raw transaction history for a user.

        Useful for debugging and analytics. Not typically used in inference.

        Args:
            user_id: User identifier
            lookback_hours: How many hours of history to retrieve

        Returns:
            List of tuples: [(timestamp, amount), ...]
            Sorted by timestamp (newest first)

        Example:
            >>> history = store.get_transaction_history("u12345", lookback_hours=48)
            >>> for ts, amt in history:
            ...     print(f"{ts}: ${amt:.2f}")
        """
        current_time = int(time.time())
        window_start = current_time - (lookback_hours * 3600)
        tx_key = self._get_tx_history_key(user_id)

        # ZRANGEBYSCORE with WITHSCORES returns (member, score) pairs.
        raw_results = self.client.zrangebyscore(tx_key, window_start, current_time, withscores=True)

        # Member format is "timestamp:amount"; split at the first colon only
        # so a malformed member cannot raise an unpacking error here.
        transactions = []
        for member, _score in raw_results:
            timestamp_str, amount_str = member.split(":", 1)
            transactions.append((int(timestamp_str), float(amount_str)))

        # Sort by timestamp descending (newest first).
        transactions.sort(reverse=True, key=lambda x: x[0])
        return transactions

    def delete_user_data(self, user_id: str) -> int:
        """
        Delete all feature data for a user.

        Used for GDPR compliance / right to be forgotten.

        Args:
            user_id: User identifier

        Returns:
            Number of keys deleted (should be 2)

        Example:
            >>> deleted = store.delete_user_data("u12345")
            >>> print(f"Deleted {deleted} keys")
        """
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)
        return self.client.delete(tx_key, avg_key)

    def health_check(self) -> Dict[str, Any]:
        """
        Check Redis connection health and get statistics.

        Returns:
            Dictionary with health metrics. On failure, contains
            {'status': 'unhealthy', 'error': <message>} instead.

        Example:
            >>> health = store.health_check()
            >>> print(health)
            {'status': 'healthy', 'ping_ms': 0.5, 'connected_clients': 10}
        """
        try:
            start = time.time()
            self.client.ping()
            ping_ms = (time.time() - start) * 1000
            info = self.client.info("stats")
            return {
                "status": "healthy",
                "ping_ms": round(ping_ms, 2),
                "connected_clients": info.get("connected_clients", -1),
                "total_commands_processed": info.get("total_commands_processed", -1),
            }
        except Exception as e:
            # Health checks must never raise; report the failure instead.
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """
        Close the connection pool.

        Call this when shutting down the application.
        """
        self.pool.disconnect()


__all__ = ["RedisFeatureStore"]