"""
Redis Feature Store

Real-time feature storage and computation for fraud detection.
Implements stateful features that require historical context:
- Sliding window transaction counts (O(log N))
- Exponential moving averages for spending (O(1))

Architecture:
- Uses Redis Sorted Sets (ZSET) for time-based sliding windows
- Uses Redis Strings to store the per-user EMA value
- Connection pooling for low-latency concurrent requests

Author: PayShield-ML Team
"""

import time
from typing import Dict, List, Optional, Tuple

import redis
from redis.client import Pipeline
from redis.connection import ConnectionPool


class RedisFeatureStore:
    """
    Redis-backed feature store for real-time fraud detection.

    This class manages stateful features that cannot be computed from a single
    transaction alone. It solves the "Stateful Feature Problem" by maintaining
    rolling windows of user behavior.

    Features Computed:
    1. **trans_count_24h**: Number of transactions in last 24 hours
       - Data Structure: Redis Sorted Set (ZSET)
       - Complexity: O(log N) insert (ZADD) and count (ZCOUNT); O(log N + M) to trim or list M entries
       - Key Format: user:{user_id}:tx_history

    2. **avg_spend_24h**: Exponential moving average of spending
       - Data Structure: Redis String (float)
       - Complexity: O(1) update
       - Key Format: user:{user_id}:avg_spend
       - Formula: EMA_new = α * amt_current + (1-α) * EMA_old
       - α = 2/(n+1) with n=24. The EMA is updated once per transaction, so n is
         a span of about 24 transactions rather than a strict 24-hour window.

    Connection Management:
    - Uses connection pooling to avoid TCP overhead
    - Thread-safe for concurrent API requests
    - Automatic reconnection on failure

    Example:
        >>> store = RedisFeatureStore(host="localhost", port=6379)
        >>>
        >>> # Record a new transaction
        >>> store.add_transaction(
        ...     user_id="u12345",
        ...     amount=150.00,
        ...     timestamp=1234567890
        ... )
        >>>
        >>> # Get features for inference
        >>> features = store.get_features(user_id="u12345")
        >>> print(features)
        {'trans_count_24h': 5.0, 'avg_spend_24h': 120.5}
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        ema_alpha: Optional[float] = None,
    ) -> None:
        """
        Initialize Redis Feature Store with connection pooling.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Redis database number (0-15)
            password: Redis password (if authentication enabled)
            max_connections: Maximum connections in pool
            decode_responses: If True, decode bytes to strings
            ema_alpha: Exponential moving average smoothing factor.
                      Default is 2/(24+1) = 0.08, i.e. a span of n=24 updates.
        """
        # Create connection pool for thread-safe access
        self.pool: ConnectionPool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_connect_timeout=2,  # 2s connection timeout
            socket_timeout=1,  # 1s operation timeout
        )

        self.client: redis.Redis = redis.Redis(connection_pool=self.pool)

        # EMA configuration: α = 2/(n+1) with n=24 updates (one update per transaction)
        self.ema_alpha: float = ema_alpha if ema_alpha is not None else 2.0 / (24 + 1)

        # TTL for keys (7 days = 604800 seconds)
        # This prevents unbounded memory growth
        self.key_ttl: int = 604800

        # Test connection
        try:
            self.client.ping()
        except redis.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Redis at {host}:{port}. Ensure Redis is running. Error: {e}"
            ) from e

    def _get_tx_history_key(self, user_id: str) -> str:
        """Generate Redis key for transaction history ZSET."""
        return f"user:{user_id}:tx_history"

    def _get_avg_spend_key(self, user_id: str) -> str:
        """Generate Redis key for average spend EMA."""
        return f"user:{user_id}:avg_spend"

    def add_transaction(self, user_id: str, amount: float, timestamp: Optional[int] = None) -> None:
        """
        Record a new transaction and update the user's features.

        This method performs three operations:
        1. Add transaction to sliding window (ZSET)
        2. Remove expired transactions (older than 24h)
        3. Update exponential moving average

        The writes are queued on one pipeline and executed together in a single
        MULTI/EXEC block.

        Args:
            user_id: User identifier
            amount: Transaction amount in USD
            timestamp: Unix timestamp. If None, uses current time.

        Raises:
            redis.exceptions.RedisError: If Redis operation fails

        Example:
            >>> store.add_transaction("u12345", 150.00, 1234567890)
        """
        if timestamp is None:
            timestamp = int(time.time())

        # Calculate window boundaries
        window_start = timestamp - 86400  # 24 hours = 86400 seconds

        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # Member encodes timestamp and amount. Two transactions with the same
        # timestamp and amount share a member, so the ZSET stores them once.
        tx_member = f"{timestamp}:{amount}"

        def _update(pipe: Pipeline) -> None:
            # avg_key is WATCHed here, so this GET runs immediately. If another
            # client changes avg_key before EXEC, redis-py retries this function.
            current_ema = pipe.get(avg_key)
            if current_ema is None:
                # First transaction for this user seeds the EMA
                new_ema = amount
            else:
                # EMA formula: α * x_new + (1-α) * EMA_old
                new_ema = self.ema_alpha * amount + (1 - self.ema_alpha) * float(current_ema)

            # Everything after multi() is queued and applied atomically at EXEC
            pipe.multi()

            # 1. Add transaction to sorted set (score = timestamp)
            pipe.zadd(tx_key, {tx_member: timestamp})

            # 2. Remove transactions older than 24 hours
            pipe.zremrangebyscore(tx_key, "-inf", window_start)

            # 3. Set TTL to prevent unbounded growth (reset on each transaction)
            pipe.expire(tx_key, self.key_ttl)

            # 4. Store the updated exponential moving average
            pipe.set(avg_key, new_ema)
            pipe.expire(avg_key, self.key_ttl)

        # Optimistic locking: WATCH avg_key, run _update, retry on conflict
        self.client.transaction(_update, avg_key)

    def get_features(
        self, user_id: str, current_timestamp: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Retrieve real-time features for a user.

        This is the primary method called during inference. It returns
        the stateful features needed by the fraud detection model.

        Args:
            user_id: User identifier
            current_timestamp: Current Unix timestamp. If None, uses system time.

        Returns:
            Dictionary containing:
            - trans_count_24h: Number of transactions in last 24 hours
            - avg_spend_24h: Exponential moving average of spending

        Example:
            >>> features = store.get_features("u12345")
            >>> print(features)
            {'trans_count_24h': 5.0, 'avg_spend_24h': 120.5}
        """
        if current_timestamp is None:
            current_timestamp = int(time.time())

        window_start = current_timestamp - 86400

        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        # Use pipeline for efficiency
        pipe: Pipeline = self.client.pipeline()

        # Count transactions in window (ZCOUNT is O(log N))
        pipe.zcount(tx_key, window_start, current_timestamp)

        # Get average spend
        pipe.get(avg_key)

        results = pipe.execute()

        trans_count = float(results[0])
        avg_spend = float(results[1]) if results[1] is not None else 0.0

        return {
            "trans_count_24h": trans_count,
            "avg_spend_24h": avg_spend,
        }

    def get_transaction_history(
        self, user_id: str, lookback_hours: int = 24
    ) -> List[Tuple[int, float]]:
        """
        Retrieve raw transaction history for a user.

        Useful for debugging and analytics. Not typically used in inference.

        Args:
            user_id: User identifier
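            lookback_hours: How many hours of history to retrieve. Entries older
                than 24 hours are trimmed on each add_transaction call, so larger
                values only return entries that have not been trimmed yet.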

        Returns:
            List of tuples: [(timestamp, amount), ...]
            Sorted by timestamp (newest first)

        Example:
            >>> history = store.get_transaction_history("u12345", lookback_hours=48)
            >>> for ts, amt in history:
            ...     print(f"{ts}: ${amt:.2f}")
        """
        current_time = int(time.time())
        window_start = current_time - (lookback_hours * 3600)

        tx_key = self._get_tx_history_key(user_id)

        # Get all transactions in window with scores (timestamps)
        # ZRANGEBYSCORE with WITHSCORES
        raw_results = self.client.zrangebyscore(tx_key, window_start, current_time, withscores=True)

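        # Parse results: member format is "timestamp:amount"
        transactions = []
        for member, _score in raw_results:
            # Members arrive as bytes when decode_responses=False
            if isinstance(member, bytes):
                member = member.decode("utf-8")
            timestamp_str, amount_str = member.split(":", 1)
            transactions.append((int(timestamp_str), float(amount_str)))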

        # Sort by timestamp descending (newest first)
        transactions.sort(reverse=True, key=lambda x: x[0])

        return transactions

    def delete_user_data(self, user_id: str) -> int:
        """
        Delete all feature data for a user.

        Used for GDPR compliance / right to be forgotten.

        Args:
            user_id: User identifier

        Returns:
            Number of keys deleted (0 to 2, depending on which keys existed)

        Example:
            >>> deleted = store.delete_user_data("u12345")
            >>> print(f"Deleted {deleted} keys")
        """
        tx_key = self._get_tx_history_key(user_id)
        avg_key = self._get_avg_spend_key(user_id)

        return self.client.delete(tx_key, avg_key)

    def health_check(self) -> Dict[str, object]:
        """
        Check Redis connection health and get statistics.

        Returns:
            Dictionary with health metrics

        Example:
            >>> health = store.health_check()
            >>> print(health)
            {'status': 'healthy', 'ping_ms': 0.5, 'connected_clients': 10}
        """
        try:
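            start = time.perf_counter()
            self.client.ping()
            ping_ms = (time.perf_counter() - start) * 1000

            # connected_clients is reported in the "clients" section of INFO,
            # total_commands_processed in the "stats" section
            clients_info = self.client.info("clients")
            stats_info = self.client.info("stats")

            return {
                "status": "healthy",
                "ping_ms": round(ping_ms, 2),
                "connected_clients": clients_info.get("connected_clients", -1),
                "total_commands_processed": stats_info.get("total_commands_processed", -1),
            }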
        except Exception as e:
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """
        Close the connection pool.

        Call this when shutting down the application.
        """
        self.pool.disconnect()
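

# Illustrative sketch, not part of the original feature store: a pure-Python
# stand-in for the sliding-window count, useful for reasoning about (or unit
# testing against) the ZSET logic without a Redis server. The name
# `reference_window_count` is hypothetical. It assumes ZCOUNT-style bounds,
# i.e. both ends of the window are inclusive, as in get_features.
import bisect  # imported here only to keep the sketch self-contained


def reference_window_count(timestamps: "List[int]", now: int, window_seconds: int = 86400) -> int:
    """Count timestamps ts with now - window_seconds <= ts <= now."""
    # A ZSET keeps members ordered by score; a sorted list plays that role here
    ordered = sorted(timestamps)
    # First index whose timestamp is >= the window start (inclusive lower bound)
    lo = bisect.bisect_left(ordered, now - window_seconds)
    # First index whose timestamp is > now (inclusive upper bound)
    hi = bisect.bisect_right(ordered, now)
    return hi - lo


# Usage: with now=1_000_000, the timestamps now-90000 and now+5 fall outside
# the window, while now-86400, now-10 and now fall inside it.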


__all__ = ["RedisFeatureStore"]
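

# Illustrative sketch, not part of the original feature store: the EMA
# recurrence from the class docstring, EMA_new = α * amt + (1-α) * EMA_old,
# folded over a list of amounts in plain Python. The name `reference_ema` is
# hypothetical. It assumes, as add_transaction does, that the first amount
# seeds the average and that α defaults to 2/(24+1).
def reference_ema(amounts: "List[float]", alpha: float = 2.0 / (24 + 1)) -> "Optional[float]":
    """Return the EMA after processing amounts oldest first, or None if empty."""
    ema = None
    for amount in amounts:
        if ema is None:
            # First observation seeds the average
            ema = float(amount)
        else:
            # Blend the new amount with the running average
            ema = alpha * amount + (1 - alpha) * ema
    return ema


# Usage: reference_ema([100.0, 200.0]) evaluates 0.08 * 200 + 0.92 * 100,
# which is about 108.0 with the default alpha.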