Spaces:
Sleeping
Sleeping
File size: 11,395 Bytes
8a08300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 |
"""
Redis Feature Store
Real-time feature storage and computation for fraud detection.
Implements stateful features that require historical context:
- Sliding window transaction counts (O(log N))
- Exponential moving averages for spending (O(1))
Architecture:
- Uses Redis Sorted Sets (ZSET) for time-based sliding windows
- Uses Redis Strings for EMA computation with atomic operations
- Connection pooling for low-latency concurrent requests
Author: PayShield-ML Team
"""
import time
from typing import Any, Dict, List, Optional, Tuple

import redis
from redis.client import Pipeline
from redis.connection import ConnectionPool
class RedisFeatureStore:
    """
    Redis-backed feature store for real-time fraud detection.

    This class manages stateful features that cannot be computed from a single
    transaction alone. It solves the "Stateful Feature Problem" by maintaining
    rolling windows of user behavior.

    Features Computed:
        1. **trans_count_24h**: Number of transactions in last 24 hours
           - Data Structure: Redis Sorted Set (ZSET)
           - Complexity: O(log N) insert, O(log N + M) range query
           - Key Format: user:{user_id}:tx_history
        2. **avg_spend_24h**: Exponential moving average of spending
           - Data Structure: Redis String (float)
           - Complexity: O(1) update
           - Key Format: user:{user_id}:avg_spend
           - Formula: EMA_new = α * amt_current + (1-α) * EMA_old
           - α = 2/(n+1) where n=24 (for 24-hour window)

    Connection Management:
        - Uses connection pooling to avoid TCP overhead
        - Thread-safe for concurrent API requests
        - EMA updates use WATCH/MULTI/EXEC so concurrent writers for the
          same user cannot lose updates

    Example:
        >>> store = RedisFeatureStore(host="localhost", port=6379)
        >>>
        >>> # Record a new transaction
        >>> store.add_transaction(
        ...     user_id="u12345",
        ...     amount=150.00,
        ...     timestamp=1234567890
        ... )
        >>>
        >>> # Get features for inference
        >>> features = store.get_features(user_id="u12345")
        >>> print(features)
        {'trans_count_24h': 5, 'avg_spend_24h': 120.50}
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        ema_alpha: Optional[float] = None,
    ) -> None:
        """
        Initialize Redis Feature Store with connection pooling.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Redis database number (0-15)
            password: Redis password (if authentication enabled)
            max_connections: Maximum connections in pool
            decode_responses: If True, decode bytes to strings
            ema_alpha: Exponential moving average smoothing factor.
                Default is 2/(24+1) ≈ 0.08 for 24-hour window.

        Raises:
            ConnectionError: If the initial PING to Redis fails.
        """
        # Create connection pool for thread-safe access
        self.pool: ConnectionPool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_connect_timeout=2,  # 2s connection timeout
            socket_timeout=1,  # 1s operation timeout
        )
        self.client: redis.Redis = redis.Redis(connection_pool=self.pool)
        # EMA configuration: α = 2/(n+1) for n=24 hours
        self.ema_alpha: float = ema_alpha if ema_alpha is not None else 2.0 / (24 + 1)
        # TTL for keys (7 days = 604800 seconds)
        # This prevents unbounded memory growth
        self.key_ttl: int = 604800
        # Fail fast at construction time rather than on the first request
        try:
            self.client.ping()
        except redis.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Redis at {host}:{port}. Ensure Redis is running. Error: {e}"
            ) from e

    def _get_tx_history_key(self, user_id: str) -> str:
        """Generate Redis key for transaction history ZSET."""
        return f"user:{user_id}:tx_history"

    def _get_avg_spend_key(self, user_id: str) -> str:
        """Generate Redis key for average spend EMA."""
        return f"user:{user_id}:avg_spend"

    def add_transaction(self, user_id: str, amount: float, timestamp: Optional[int] = None) -> None:
        """
        Record a new transaction and update features atomically.

        This method performs the following operations as a single Redis
        MULTI/EXEC transaction, guarded by WATCH on the EMA key:
            1. Add transaction to sliding window (ZSET)
            2. Remove expired transactions (older than 24h)
            3. Refresh key TTLs
            4. Update exponential moving average

        BUGFIX: the previous implementation read the current EMA with a
        plain GET *outside* the pipeline, so two concurrent writers for
        the same user could both read the same old EMA and one update
        would be silently lost. ``client.transaction()`` WATCHes the EMA
        key and automatically retries the whole callable if another
        client modifies it between the read and EXEC.

        Args:
            user_id: User identifier
            amount: Transaction amount in USD
            timestamp: Unix timestamp. If None, uses current time.

        Raises:
            redis.exceptions.RedisError: If Redis operation fails

        Example:
            >>> store.add_transaction("u12345", 150.00, 1234567890)
        """
        if timestamp is None:
            timestamp = int(time.time())
        # Calculate window boundaries (24 hours = 86400 seconds)
        window_start = timestamp - 86400
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)
        # Member encodes timestamp+amount so distinct transactions with the
        # same amount (at different seconds) remain distinct ZSET members.
        # NOTE(review): two transactions with identical (timestamp, amount)
        # still collapse into one member — confirm this is acceptable.
        tx_member = f"{timestamp}:{amount}"

        def _update(pipe: Pipeline) -> None:
            # While WATCHing, the pipeline runs in immediate-execution
            # mode, so this GET returns the current EMA right away.
            current_ema = pipe.get(avg_key)
            if current_ema is None:
                # First transaction for this user: seed EMA with the amount.
                new_ema = float(amount)
            else:
                # EMA formula: α * x_new + (1-α) * EMA_old
                new_ema = self.ema_alpha * amount + (1 - self.ema_alpha) * float(current_ema)
            # Everything queued after MULTI is applied atomically on EXEC.
            pipe.multi()
            pipe.zadd(tx_key, {tx_member: timestamp})
            pipe.zremrangebyscore(tx_key, "-inf", window_start)
            # Reset TTL on each write to prevent unbounded memory growth
            pipe.expire(tx_key, self.key_ttl)
            pipe.set(avg_key, new_ema)
            pipe.expire(avg_key, self.key_ttl)

        # Retries _update automatically on redis.WatchError (i.e. when a
        # concurrent writer touched avg_key between WATCH and EXEC).
        self.client.transaction(_update, avg_key)

    def get_features(
        self, user_id: str, current_timestamp: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Retrieve real-time features for a user.

        This is the primary method called during inference. It returns
        the stateful features needed by the fraud detection model.

        Args:
            user_id: User identifier
            current_timestamp: Current Unix timestamp. If None, uses system time.

        Returns:
            Dictionary containing:
                - trans_count_24h: Number of transactions in last 24 hours
                - avg_spend_24h: Exponential moving average of spending
                  (0.0 if the user has no recorded transactions)

        Example:
            >>> features = store.get_features("u12345")
            >>> print(features)
            {'trans_count_24h': 5.0, 'avg_spend_24h': 120.50}
        """
        if current_timestamp is None:
            current_timestamp = int(time.time())
        window_start = current_timestamp - 86400
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)
        # Pipeline both reads into a single round trip
        pipe: Pipeline = self.client.pipeline()
        # Count transactions in window (ZCOUNT is O(log N); bounds inclusive)
        pipe.zcount(tx_key, window_start, current_timestamp)
        # Get average spend
        pipe.get(avg_key)
        results = pipe.execute()
        trans_count = float(results[0])
        avg_spend = float(results[1]) if results[1] is not None else 0.0
        return {
            "trans_count_24h": trans_count,
            "avg_spend_24h": avg_spend,
        }

    def get_transaction_history(
        self, user_id: str, lookback_hours: int = 24
    ) -> List[Tuple[int, float]]:
        """
        Retrieve raw transaction history for a user.

        Useful for debugging and analytics. Not typically used in inference.

        Args:
            user_id: User identifier
            lookback_hours: How many hours of history to retrieve

        Returns:
            List of tuples: [(timestamp, amount), ...]
            Sorted by timestamp (newest first)

        Example:
            >>> history = store.get_transaction_history("u12345", lookback_hours=48)
            >>> for ts, amt in history:
            ...     print(f"{ts}: ${amt:.2f}")
        """
        current_time = int(time.time())
        window_start = current_time - (lookback_hours * 3600)
        tx_key = self._get_tx_history_key(user_id)
        # Get all transactions in window with scores (timestamps)
        raw_results = self.client.zrangebyscore(tx_key, window_start, current_time, withscores=True)
        # Parse results: member format is "timestamp:amount" (see
        # add_transaction). maxsplit=1 guards against any future member
        # suffix that might itself contain a colon.
        # NOTE(review): assumes decode_responses=True; with bytes members
        # the split below would need b":" — confirm store configuration.
        transactions = []
        for member, _score in raw_results:
            timestamp_str, amount_str = member.split(":", 1)
            transactions.append((int(timestamp_str), float(amount_str)))
        # Sort by timestamp descending (newest first)
        transactions.sort(reverse=True, key=lambda x: x[0])
        return transactions

    def delete_user_data(self, user_id: str) -> int:
        """
        Delete all feature data for a user.

        Used for GDPR compliance / right to be forgotten.

        Args:
            user_id: User identifier

        Returns:
            Number of keys deleted (should be 2)

        Example:
            >>> deleted = store.delete_user_data("u12345")
            >>> print(f"Deleted {deleted} keys")
        """
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)
        return self.client.delete(tx_key, avg_key)

    def health_check(self) -> Dict[str, Any]:
        """
        Check Redis connection health and get statistics.

        Returns:
            Dictionary with health metrics on success
            ({'status': 'healthy', 'ping_ms': ..., 'connected_clients': ...,
            'total_commands_processed': ...}) or
            {'status': 'unhealthy', 'error': ...} on any failure.

        Example:
            >>> health = store.health_check()
            >>> print(health)
            {'status': 'healthy', 'ping_ms': 0.5, 'connected_clients': 10}
        """
        try:
            # perf_counter is monotonic and high-resolution; time.time()
            # can jump under NTP adjustment, skewing latency measurement.
            start = time.perf_counter()
            self.client.ping()
            ping_ms = (time.perf_counter() - start) * 1000
            # BUGFIX: fetch the full INFO payload. "connected_clients"
            # lives in the "clients" section, so the previous
            # info("stats") call always fell back to -1 for it.
            info = self.client.info()
            return {
                "status": "healthy",
                "ping_ms": round(ping_ms, 2),
                "connected_clients": info.get("connected_clients", -1),
                "total_commands_processed": info.get("total_commands_processed", -1),
            }
        except Exception as e:
            # Health checks must never raise; report the failure instead.
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """
        Close the connection pool.

        Call this when shutting down the application.
        """
        self.pool.disconnect()
__all__ = ["RedisFeatureStore"]
|