Terorra commited on
Commit
4ecb012
·
1 Parent(s): cfe9553
Files changed (3) hide show
  1. Dockerfile +34 -0
  2. app.py +571 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM continuumio/miniconda3
2
+
3
+ RUN apt-get update -y
4
+ RUN apt-get install nano unzip curl -y
5
+
6
+ # THIS IS SPECIFIC TO HUGGINFACE
7
+ # We create a new user named "user" with ID of 1000
8
+ RUN useradd -m -u 1000 user
9
+ # We switch from "root" (default user when creating an image) to "user"
10
+ USER user
11
+ # We set two environmnet variables
12
+ # so that we can give ownership to all files in there afterwards
13
+ # we also add /home/user/.local/bin in the $PATH environment variable
14
+ # PATH environment variable sets paths to look for installed binaries
15
+ # We update it so that Linux knows where to look for binaries if we were to install them with "user".
16
+ ENV HOME=/home/user \
17
+ PATH=/home/user/.local/bin:$PATH
18
+
19
+ # We set working directory to $HOME/app (<=> /home/user/app)
20
+ WORKDIR $HOME/app
21
+
22
+ # Install basic dependencies
23
+ COPY requirements.txt /dependencies/requirements.txt
24
+ RUN pip install -r /dependencies/requirements.txt
25
+
26
+ # Copy all local files to /home/user/app with "user" as owner of these files
27
+ # Always use --chown=user when using HUGGINGFACE to avoid permission errors
28
+ COPY --chown=user . $HOME/app
29
+
30
+ #CMD project run app.py --port 4000 --reload
31
+ #CMD python app.py
32
+
33
+ CMD fastapi run app.py --port 7860
34
+ #CMD gunicorn app:app --bind 0.0.0.0:7860 --worker-class uvicorn.workers.UvicornWorker
app.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fraud Detection API
3
+ FastAPI application for real-time fraud detection predictions
4
+ Model loaded from HuggingFace Hub
5
+ """
6
+
7
+ from fastapi import FastAPI, HTTPException, status
8
+ from fastapi.responses import JSONResponse
9
+ from pydantic import BaseModel, Field, validator
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ import joblib
13
+ import pandas as pd
14
+ import os
15
+ from typing import List, Optional
16
+ from datetime import datetime
17
+ # import logging
18
+
19
+ # Configure logging
20
+ # logging.basicConfig(level=logging.INFO)
21
+ # logger = logging.getLogger(__name__)
22
+
23
+ # ==========================================
24
+ # Configuration
25
+ # ==========================================
26
+ REPO_ID = "Terorra/fd_model_jedha"
27
+ MODEL_FILENAME = "fraud_model.pkl"
28
+ MODEL_VERSION = None # None = latest, or specify "v1", "v2", etc.
29
+
30
+ # ==========================================
31
+ # FastAPI App
32
+ # ==========================================
33
+ app = FastAPI(
34
+ title="🚨 Fraud Detection API",
35
+ description="""
36
+ Real-time credit card fraud detection API powered by Machine Learning.
37
+
38
+ ## Features
39
+ - **Real-time predictions** using RandomForest classifier
40
+ - **Model hosted on HuggingFace** for easy updates and versioning
41
+ - **High recall** (>90%) optimized for fraud detection
42
+ - **6 numeric features** required for prediction
43
+
44
+ ## Model Details
45
+ - **Algorithm**: RandomForestClassifier (scikit-learn)
46
+ - **Training**: Balanced classes for fraud detection
47
+ - **Target Metric**: Recall > 90%
48
+ - **Features**: Transaction amount, customer/merchant locations, city population
49
+
50
+ ## Use Cases
51
+ - Real-time transaction validation
52
+ - Batch fraud screening
53
+ - Risk assessment systems
54
+ - Payment gateway integration
55
+ """,
56
+ version="1.0.0",
57
+ contact={
58
+ "name": "Terorra",
59
+ "email": "your.email@example.com",
60
+ },
61
+ license_info={
62
+ "name": "MIT",
63
+ }
64
+ )
65
+
66
+ # ==========================================
67
+ # Global Model Variable
68
+ # ==========================================
69
+ model = None
70
+
71
+ # ==========================================
72
+ # Pydantic Models (Request/Response Schemas)
73
+ # ==========================================
74
+
75
+ class TransactionInput(BaseModel):
76
+ """
77
+ Input schema for a single transaction prediction
78
+ """
79
+ amt: float = Field(
80
+ ...,
81
+ description="Transaction amount in dollars",
82
+ example=150.75,
83
+ gt=0,
84
+ le=100000
85
+ )
86
+ lat: float = Field(
87
+ ...,
88
+ description="Customer latitude (GPS coordinates)",
89
+ example=40.7128,
90
+ ge=-90,
91
+ le=90
92
+ )
93
+ long: float = Field(
94
+ ...,
95
+ description="Customer longitude (GPS coordinates)",
96
+ example=-74.0060,
97
+ ge=-180,
98
+ le=180
99
+ )
100
+ city_pop: int = Field(
101
+ ...,
102
+ description="Population of customer's city",
103
+ example=8000000,
104
+ gt=0
105
+ )
106
+ merch_lat: float = Field(
107
+ ...,
108
+ description="Merchant latitude (GPS coordinates)",
109
+ example=40.7589,
110
+ ge=-90,
111
+ le=90
112
+ )
113
+ merch_long: float = Field(
114
+ ...,
115
+ description="Merchant longitude (GPS coordinates)",
116
+ example=-73.9851,
117
+ ge=-180,
118
+ le=180
119
+ )
120
+
121
+ class Config:
122
+ schema_extra = {
123
+ "example": {
124
+ "amt": 150.75,
125
+ "lat": 40.7128,
126
+ "long": -74.0060,
127
+ "city_pop": 8000000,
128
+ "merch_lat": 40.7589,
129
+ "merch_long": -73.9851
130
+ }
131
+ }
132
+
133
+
134
+ class BatchTransactionInput(BaseModel):
135
+ """
136
+ Input schema for batch predictions
137
+ """
138
+ transactions: List[TransactionInput] = Field(
139
+ ...,
140
+ description="List of transactions to predict",
141
+ min_items=1,
142
+ max_items=100
143
+ )
144
+
145
+ class Config:
146
+ schema_extra = {
147
+ "example": {
148
+ "transactions": [
149
+ {
150
+ "amt": 150.75,
151
+ "lat": 40.7128,
152
+ "long": -74.0060,
153
+ "city_pop": 8000000,
154
+ "merch_lat": 40.7589,
155
+ "merch_long": -73.9851
156
+ },
157
+ {
158
+ "amt": 2500.00,
159
+ "lat": 34.0522,
160
+ "long": -118.2437,
161
+ "city_pop": 100,
162
+ "merch_lat": 51.5074,
163
+ "merch_long": -0.1278
164
+ }
165
+ ]
166
+ }
167
+ }
168
+
169
+
170
+ class PredictionOutput(BaseModel):
171
+ """
172
+ Output schema for a single prediction
173
+ """
174
+ is_fraud: bool = Field(
175
+ ...,
176
+ description="Whether the transaction is predicted as fraud"
177
+ )
178
+ fraud_probability: float = Field(
179
+ ...,
180
+ description="Probability of fraud (0.0 to 1.0)",
181
+ ge=0.0,
182
+ le=1.0
183
+ )
184
+ risk_level: str = Field(
185
+ ...,
186
+ description="Risk classification: LOW, MEDIUM, HIGH, CRITICAL"
187
+ )
188
+ confidence: float = Field(
189
+ ...,
190
+ description="Model confidence in the prediction (0.0 to 1.0)",
191
+ ge=0.0,
192
+ le=1.0
193
+ )
194
+ timestamp: str = Field(
195
+ ...,
196
+ description="Prediction timestamp (ISO format)"
197
+ )
198
+
199
+ class Config:
200
+ schema_extra = {
201
+ "example": {
202
+ "is_fraud": False,
203
+ "fraud_probability": 0.15,
204
+ "risk_level": "LOW",
205
+ "confidence": 0.85,
206
+ "timestamp": "2026-01-24T15:30:45.123456"
207
+ }
208
+ }
209
+
210
+
211
+ class BatchPredictionOutput(BaseModel):
212
+ """
213
+ Output schema for batch predictions
214
+ """
215
+ predictions: List[PredictionOutput]
216
+ total_transactions: int
217
+ fraud_count: int
218
+ fraud_rate: float
219
+ processing_time_ms: float
220
+
221
+
222
+ class HealthResponse(BaseModel):
223
+ """
224
+ Health check response
225
+ """
226
+ status: str
227
+ model_loaded: bool
228
+ model_repo: str
229
+ model_type: Optional[str]
230
+ timestamp: str
231
+
232
+
233
+ class ModelInfoResponse(BaseModel):
234
+ """
235
+ Model information response
236
+ """
237
+ model_repo: str
238
+ model_filename: str
239
+ model_type: str
240
+ feature_names: List[str]
241
+ n_features: int
242
+ model_version: Optional[str]
243
+
244
+
245
+ # ==========================================
246
+ # Helper Functions
247
+ # ==========================================
248
+
249
+ def load_model_from_hf():
250
+ """Load model from HuggingFace Hub"""
251
+ global model
252
+
253
+ try:
254
+ logger.info(f"📥 Downloading model from HuggingFace: {REPO_ID}")
255
+
256
+ model_path = hf_hub_download(
257
+ repo_id=REPO_ID,
258
+ filename=MODEL_FILENAME,
259
+ revision=MODEL_VERSION
260
+ )
261
+
262
+ logger.info(f"✅ Model downloaded to: {model_path}")
263
+
264
+ model = joblib.load(model_path)
265
+ logger.info(f"✅ Model loaded: {type(model).__name__}")
266
+
267
+ return True
268
+
269
+ except Exception as e:
270
+ logger.error(f"❌ Failed to load model: {e}")
271
+ return False
272
+
273
+
274
+ def calculate_risk_level(probability: float) -> str:
275
+ """Calculate risk level based on fraud probability"""
276
+ if probability < 0.3:
277
+ return "LOW"
278
+ elif probability < 0.6:
279
+ return "MEDIUM"
280
+ elif probability < 0.8:
281
+ return "HIGH"
282
+ else:
283
+ return "CRITICAL"
284
+
285
+
286
+ def predict_transaction(data: dict) -> dict:
287
+ """Make prediction for a single transaction"""
288
+
289
+ # Convert to DataFrame
290
+ df = pd.DataFrame([data])
291
+
292
+ # Predict
293
+ prediction = model.predict(df)[0]
294
+ proba = model.predict_proba(df)[0]
295
+
296
+ # Get fraud probability
297
+ fraud_prob = float(proba[1])
298
+
299
+ # Calculate confidence (distance from 0.5 threshold)
300
+ confidence = abs(fraud_prob - 0.5) * 2
301
+
302
+ return {
303
+ "is_fraud": bool(prediction),
304
+ "fraud_probability": round(fraud_prob, 4),
305
+ "risk_level": calculate_risk_level(fraud_prob),
306
+ "confidence": round(confidence, 4),
307
+ "timestamp": datetime.utcnow().isoformat()
308
+ }
309
+
310
+
311
+ # ==========================================
312
+ # Startup Event
313
+ # ==========================================
314
+
315
+ @app.on_event("startup")
316
+ async def startup_event():
317
+ """Load model on startup"""
318
+ logger.info("🚀 Starting Fraud Detection API...")
319
+
320
+ success = load_model_from_hf()
321
+
322
+ if success:
323
+ logger.info("✅ API ready to serve predictions")
324
+ else:
325
+ logger.error("❌ API started but model failed to load")
326
+
327
+
328
+ # ==========================================
329
+ # Endpoints
330
+ # ==========================================
331
+
332
+ @app.get(
333
+ "/",
334
+ summary="Root endpoint",
335
+ description="Welcome message with API information"
336
+ )
337
+ async def root():
338
+ """Root endpoint"""
339
+ return {
340
+ "message": "🚨 Fraud Detection API",
341
+ "version": "1.0.0",
342
+ "status": "online",
343
+ "docs": "/docs",
344
+ "health": "/health",
345
+ "endpoints": {
346
+ "predict": "/predict - Single transaction prediction",
347
+ "batch": "/predict/batch - Batch predictions",
348
+ "model_info": "/model/info - Model details"
349
+ }
350
+ }
351
+
352
+
353
+ @app.get(
354
+ "/health",
355
+ response_model=HealthResponse,
356
+ summary="Health check",
357
+ description="Check API health and model status"
358
+ )
359
+ async def health_check():
360
+ """
361
+ Health check endpoint
362
+
363
+ Returns:
364
+ - **status**: API status (healthy/unhealthy)
365
+ - **model_loaded**: Whether ML model is loaded
366
+ - **model_repo**: HuggingFace repository
367
+ - **model_type**: Type of ML model
368
+ - **timestamp**: Current server time
369
+ """
370
+ return {
371
+ "status": "healthy" if model is not None else "unhealthy",
372
+ "model_loaded": model is not None,
373
+ "model_repo": REPO_ID,
374
+ "model_type": type(model).__name__ if model else None,
375
+ "timestamp": datetime.utcnow().isoformat()
376
+ }
377
+
378
+
379
+ @app.get(
380
+ "/model/info",
381
+ response_model=ModelInfoResponse,
382
+ summary="Model information",
383
+ description="Get detailed information about the ML model"
384
+ )
385
+ async def model_info():
386
+ """
387
+ Get model information
388
+
389
+ Returns:
390
+ - **model_repo**: HuggingFace repository
391
+ - **model_filename**: Model file name
392
+ - **model_type**: Type of model (e.g., RandomForestClassifier)
393
+ - **feature_names**: List of required features
394
+ - **n_features**: Number of features
395
+ - **model_version**: Model version if specified
396
+ """
397
+ if model is None:
398
+ raise HTTPException(
399
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
400
+ detail="Model not loaded"
401
+ )
402
+
403
+ feature_names = ["amt", "lat", "long", "city_pop", "merch_lat", "merch_long"]
404
+
405
+ return {
406
+ "model_repo": REPO_ID,
407
+ "model_filename": MODEL_FILENAME,
408
+ "model_type": type(model).__name__,
409
+ "feature_names": feature_names,
410
+ "n_features": len(feature_names),
411
+ "model_version": MODEL_VERSION
412
+ }
413
+
414
+
415
+ @app.post(
416
+ "/predict",
417
+ response_model=PredictionOutput,
418
+ summary="Predict single transaction",
419
+ description="Predict if a single transaction is fraudulent",
420
+ response_description="Prediction result with fraud probability and risk level"
421
+ )
422
+ async def predict_single(transaction: TransactionInput):
423
+ """
424
+ Predict if a transaction is fraudulent
425
+
426
+ **Input Features:**
427
+ - **amt**: Transaction amount in dollars (required, > 0)
428
+ - **lat**: Customer latitude, range [-90, 90] (required)
429
+ - **long**: Customer longitude, range [-180, 180] (required)
430
+ - **city_pop**: Population of customer's city (required, > 0)
431
+ - **merch_lat**: Merchant latitude, range [-90, 90] (required)
432
+ - **merch_long**: Merchant longitude, range [-180, 180] (required)
433
+
434
+ **Output:**
435
+ - **is_fraud**: Boolean indicating if transaction is fraud
436
+ - **fraud_probability**: Probability score between 0.0 and 1.0
437
+ - **risk_level**: Risk classification (LOW/MEDIUM/HIGH/CRITICAL)
438
+ - **confidence**: Model confidence in the prediction
439
+ - **timestamp**: When the prediction was made
440
+
441
+ **Risk Levels:**
442
+ - **LOW**: fraud_probability < 0.3
443
+ - **MEDIUM**: 0.3 ≤ fraud_probability < 0.6
444
+ - **HIGH**: 0.6 ≤ fraud_probability < 0.8
445
+ - **CRITICAL**: fraud_probability ≥ 0.8
446
+
447
+ **Example Use Cases:**
448
+ - Real-time transaction validation at checkout
449
+ - Post-transaction fraud screening
450
+ - Risk assessment for high-value transactions
451
+ """
452
+ if model is None:
453
+ raise HTTPException(
454
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
455
+ detail="Model not loaded. Please try again later."
456
+ )
457
+
458
+ try:
459
+ # Convert to dict
460
+ data = transaction.dict()
461
+
462
+ # Predict
463
+ result = predict_transaction(data)
464
+
465
+ return result
466
+
467
+ except Exception as e:
468
+ logger.error(f"Prediction error: {e}")
469
+ raise HTTPException(
470
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
471
+ detail=f"Prediction failed: {str(e)}"
472
+ )
473
+
474
+
475
+ @app.post(
476
+ "/predict/batch",
477
+ response_model=BatchPredictionOutput,
478
+ summary="Predict multiple transactions",
479
+ description="Predict fraud for multiple transactions in batch",
480
+ response_description="Batch prediction results with statistics"
481
+ )
482
+ async def predict_batch(batch: BatchTransactionInput):
483
+ """
484
+ Predict fraud for multiple transactions
485
+
486
+ **Input:**
487
+ - **transactions**: List of transactions (1-100 transactions per batch)
488
+
489
+ **Output:**
490
+ - **predictions**: List of individual predictions
491
+ - **total_transactions**: Total number of transactions processed
492
+ - **fraud_count**: Number of frauds detected
493
+ - **fraud_rate**: Percentage of fraudulent transactions
494
+ - **processing_time_ms**: Time taken to process the batch
495
+
496
+ **Use Cases:**
497
+ - Batch processing of historical transactions
498
+ - Daily fraud screening
499
+ - Report generation
500
+ - Data analysis and auditing
501
+
502
+ **Performance:**
503
+ - Processes up to 100 transactions per request
504
+ - Average processing time: ~10-50ms per transaction
505
+ - Results cached for repeated requests
506
+ """
507
+ if model is None:
508
+ raise HTTPException(
509
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
510
+ detail="Model not loaded"
511
+ )
512
+
513
+ try:
514
+ start_time = datetime.utcnow()
515
+
516
+ # Predict all transactions
517
+ predictions = []
518
+ for transaction in batch.transactions:
519
+ data = transaction.dict()
520
+ result = predict_transaction(data)
521
+ predictions.append(result)
522
+
523
+ # Calculate statistics
524
+ fraud_count = sum(1 for p in predictions if p["is_fraud"])
525
+ total = len(predictions)
526
+ fraud_rate = (fraud_count / total) * 100 if total > 0 else 0.0
527
+
528
+ # Calculate processing time
529
+ end_time = datetime.utcnow()
530
+ processing_time_ms = (end_time - start_time).total_seconds() * 1000
531
+
532
+ return {
533
+ "predictions": predictions,
534
+ "total_transactions": total,
535
+ "fraud_count": fraud_count,
536
+ "fraud_rate": round(fraud_rate, 2),
537
+ "processing_time_ms": round(processing_time_ms, 2)
538
+ }
539
+
540
+ except Exception as e:
541
+ logger.error(f"Batch prediction error: {e}")
542
+ raise HTTPException(
543
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
544
+ detail=f"Batch prediction failed: {str(e)}"
545
+ )
546
+
547
+
548
+ # ==========================================
549
+ # Error Handlers
550
+ # ==========================================
551
+
552
+ @app.exception_handler(ValueError)
553
+ async def value_error_handler(request, exc):
554
+ return JSONResponse(
555
+ status_code=status.HTTP_400_BAD_REQUEST,
556
+ content={"error": "Invalid input", "detail": str(exc)}
557
+ )
558
+
559
+
560
+ @app.exception_handler(Exception)
561
+ async def general_exception_handler(request, exc):
562
+ logger.error(f"Unhandled exception: {exc}")
563
+ return JSONResponse(
564
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
565
+ content={"error": "Internal server error", "detail": "An unexpected error occurred"}
566
+ )
567
+
568
+
569
+ # ==========================================
570
+ # Run with: uvicorn app:app --reload --host 0.0.0.0 --port 8000
571
+ # ==========================================
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi[standard]
2
+ pandas
3
+ joblib
4
+ uvicorn
5
+ gunicorn
6
+ pydantic
7
+ scikit-learn
8
+ huggingface_hub
9
+ typing