# NOTE: removed non-code artifact (web-page header / line-number gutter captured during extraction).
"""
FastAPI Real-Time Fraud Detection Service.
Production-grade inference API with sub-50ms latency target.
Integrates with Redis Feature Store for real-time feature injection.
"""
import json
import logging
import time
from pathlib import Path
import pandas as pd
from typing import Dict, Any, Optional
import joblib
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from src.api.config import settings
from src.api.logger import log_shadow_prediction
from src.api.schemas import PredictionRequest, PredictionResponse, HealthResponse
from src.features.store import RedisFeatureStore
from src.explainability import FraudExplainer
# ---------------------------------------------------------------------------
# Application setup (runs once at import time)
# ---------------------------------------------------------------------------

# FastAPI application, titled/versioned from the central settings object.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    description="Real-time fraud detection API with Redis feature store integration",
)

# Fully permissive CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec — tighten before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global resources, populated by the startup hook so no per-request
# loading cost is paid.
pipeline = None  # trained pipeline loaded via joblib
threshold = None  # decision threshold read from the threshold JSON artifact
feature_store: Optional[RedisFeatureStore] = None
explainer: Optional[FraudExplainer] = None

# Module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.on_event("startup")
async def load_resources():
    """
    Load the model, threshold, feature store, and explainer at startup.

    Runs exactly once when the API boots, so requests never pay the
    artifact-loading cost.

    Raises:
        FileNotFoundError: if the model or threshold artifact is absent.
    """
    global pipeline, threshold, feature_store, explainer
    logger.info("Loading model and resources...")

    # Trained pipeline — required; fail fast if the artifact is missing.
    model_path = Path(settings.model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")
    pipeline = joblib.load(model_path)
    logger.info(f"✓ Loaded model from {model_path}")

    # Decision threshold — required, stored as JSON.
    threshold_path = Path(settings.threshold_path)
    if not threshold_path.exists():
        raise FileNotFoundError(f"Threshold file not found: {threshold_path}")
    threshold = json.loads(threshold_path.read_text())["optimal_threshold"]
    logger.info(f"✓ Loaded threshold: {threshold:.4f}")

    # Redis feature store — optional; degrade gracefully when unreachable.
    try:
        feature_store = RedisFeatureStore(
            host=settings.redis_host,
            port=settings.redis_port,
            db=settings.redis_db,
            password=settings.redis_password,
        )
        logger.info("✓ Connected to Redis Feature Store")
    except Exception as exc:
        logger.warning(f"Redis connection failed: {exc}. Feature store disabled.")
        feature_store = None

    # SHAP explainer — optional; disable explainability on failure.
    try:
        explainer = FraudExplainer(str(model_path))
        logger.info("✓ Initialized SHAP Explainer")
    except Exception as exc:
        logger.warning(f"SHAP initialization failed: {exc}. Explainability disabled.")
        explainer = None

    banner = "=" * 60
    logger.info(banner)
    logger.info("API Ready!")
    logger.info(f"Shadow Mode: {settings.shadow_mode}")
    logger.info(f"Max Latency Target: {settings.max_latency_ms}ms")
    logger.info(banner)
@app.on_event("shutdown")
async def shutdown_resources():
    """Release external connections when the server stops."""
    global feature_store
    if feature_store:
        # Only Redis holds an open connection; the model needs no teardown.
        feature_store.close()
        logger.info("✓ Closed Redis connection")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Monitoring probe.

    Reports overall service status (model + threshold loaded) and whether
    the Redis feature store answers its own health check.
    """
    redis_connected = False
    if feature_store:
        try:
            redis_connected = feature_store.health_check()["status"] == "healthy"
        except Exception:
            # A failing/unreachable store is simply reported as disconnected.
            redis_connected = False

    model_ready = pipeline is not None and threshold is not None
    return HealthResponse(
        status="healthy" if model_ready else "unhealthy",
        model_loaded=pipeline is not None,
        redis_connected=redis_connected,
        version=settings.api_version,
    )
@app.post("/v1/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Real-time fraud detection endpoint.

    Workflow:
        1. Parse & validate the request
        2. Resolve real-time features (override > Redis > default)
        3. Run pipeline inference on the combined feature row
        4. Apply the decision threshold
        5. Shadow-mode override (log decision, always approve)
        6. Return the decision with latency tracking

    Args:
        request: Transaction data.

    Returns:
        Fraud decision with probability, risk score, and latency.

    Raises:
        HTTPException: 503 when the model is not loaded; 500 on any
            prediction-time failure.
    """
    t0 = time.time()

    # Refuse traffic until the startup hook has loaded the artifacts.
    if pipeline is None or threshold is None:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    try:
        payload = request.dict()

        # Real-time features resolve in priority order:
        # explicit request override > Redis lookup > static default.
        count_24h = request.trans_count_24h
        avg_24h = request.avg_spend_24h
        ratio_24h = request.amt_to_avg_ratio_24h
        avg_all_time = request.user_avg_amt_all_time

        needs_lookup = count_24h is None or avg_24h is None or avg_all_time is None
        if needs_lookup and feature_store:
            try:
                # Lookup is keyed on the transaction's own timestamp.
                ts = int(pd.to_datetime(request.trans_date_trans_time).timestamp())
                features = feature_store.get_features(request.user_id, ts)
                if count_24h is None:
                    count_24h = features.get("trans_count_24h", 0)
                if avg_24h is None:
                    avg_24h = features.get("avg_spend_24h", request.amt)
                # The store does not track an all-time average yet; the 24h
                # average stands in as a proxy when the key is absent.
                if avg_all_time is None:
                    avg_all_time = features.get("user_avg_amt_all_time", avg_24h)
            except Exception as e:
                logger.warning(f"Redis feature lookup failed: {e}. Using defaults for missing values.")

        # Static fallbacks for anything still unresolved.
        if count_24h is None:
            count_24h = 0
        if avg_24h is None:
            avg_24h = request.amt
        if avg_all_time is None:
            avg_all_time = avg_24h  # 24h average as all-time proxy

        # Derived ratio, unless the caller supplied it explicitly.
        if ratio_24h is None:
            ratio_24h = request.amt / avg_24h if avg_24h > 0 else 1.0

        # Inject the resolved features into the model's input row.
        payload["trans_count_24h"] = count_24h
        payload["avg_spend_24h"] = avg_24h
        payload["amt_to_avg_ratio_24h"] = ratio_24h
        payload["amt_relative_to_all_time"] = 1.0  # default; never derived here

        # Single-row frame for the pipeline, then fraud probability.
        frame = pd.DataFrame([payload])
        prob = pipeline.predict_proba(frame)[:, 1][0]
        real_decision = "BLOCK" if prob >= threshold else "APPROVE"

        # Latency is measured here, before explainability/persistence work.
        latency_ms = (time.time() - t0) * 1000

        final_decision = real_decision
        if settings.shadow_mode:
            # Shadow mode: record what we *would* have decided...
            log_shadow_prediction(
                request_data=payload,
                probability=prob,
                real_decision=real_decision,
                latency_ms=latency_ms,
            )
            # ...but always approve.
            final_decision = "APPROVE"

        if latency_ms > settings.max_latency_ms:
            logger.warning(
                f"Latency exceeded target: {latency_ms:.2f}ms > {settings.max_latency_ms}ms"
            )

        # Echo the resolved feature values back to the caller.
        features_used = {
            "trans_count_24h": count_24h,
            "avg_spend_24h": avg_24h,
            "amt_to_avg_ratio_24h": ratio_24h,
            "user_avg_amt_all_time": avg_all_time,
        }

        # Optional SHAP attribution: top features keyed by name.
        shap_contributions = {}
        if explainer is not None and settings.enable_explainability:
            try:
                explanation = explainer.explain_prediction(frame, threshold=threshold)
                shap_contributions = {
                    item["feature"]: item["impact"]
                    for item in explanation["top_features"]
                }
            except Exception as e:
                logger.warning(f"SHAP computation failed: {e}")

        # Persist the transaction so velocity features accumulate, but only
        # for organic live traffic: never in shadow mode, and never when the
        # caller overrode features (to avoid polluting real data).
        if feature_store and not settings.shadow_mode:
            no_overrides = (
                request.trans_count_24h is None
                and request.avg_spend_24h is None
                and request.user_avg_amt_all_time is None
            )
            if no_overrides:
                try:
                    ts = int(pd.to_datetime(request.trans_date_trans_time).timestamp())
                    feature_store.add_transaction(
                        user_id=request.user_id,
                        amount=request.amt,
                        timestamp=ts,
                    )
                except Exception as e:
                    logger.warning(f"Failed to persist transaction to Redis: {e}")

        return PredictionResponse(
            decision=final_decision,
            probability=float(prob),
            risk_score=float(prob * 100),
            latency_ms=latency_ms,
            shadow_mode=settings.shadow_mode,
            features=features_used,
            shap_values=shap_contributions,
        )
    except Exception as e:
        logger.error(f"Prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
@app.get("/")
async def root():
    """Root endpoint: service identity and a map of available routes."""
    endpoints = {
        "predict": "/v1/predict (POST)",
        "health": "/health (GET)",
        "docs": "/docs (GET)",
    }
    return {
        "service": settings.api_title,
        "version": settings.api_version,
        "status": "running",
        "endpoints": endpoints,
    }


__all__ = ["app"]