File size: 3,518 Bytes
20cbff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel, field_validator

from fraud_model import FraudDetector

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global detector instance. ``lifespan`` assigns it on startup and resets it
# to None on shutdown, so it is None whenever no model is loaded; route
# handlers must check for None before using it.
detector: Optional[FraudDetector] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the FraudDetector on startup and release it on shutdown.

    The loaded model is stored in the module-level ``detector`` global that
    the route handlers read.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).

    Raises:
        RuntimeError: If constructing ``FraudDetector`` raises; the original
            exception is chained as the cause and startup is aborted.
    """
    global detector
    try:
        logger.info("Loading FraudDetector model...")
        detector = FraudDetector()
    except Exception as e:
        # logger.exception records the full traceback; ``from e`` keeps the
        # original error attached to the RuntimeError for whoever handles it.
        logger.exception("FATAL: Failed to initialize FraudDetector: %s", e)
        raise RuntimeError(f"Model failed to load: {e}") from e
    else:
        logger.info("FraudDetector loaded successfully.")
    try:
        yield
    finally:
        # Runs even when the lifespan is exited via an exception, so the
        # global is always cleared on shutdown.
        detector = None
        logger.info("FraudDetector shut down.")


# ASGI application. ``lifespan`` loads the FraudDetector on startup and
# clears it on shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Bank Fraud Detection API",
    version="1.0.0",
    description="API for detecting fraudulent bank transactions using AI.",
)


# --- Request / Response Models ---

class PredictionRequest(BaseModel):
    """Request body carrying a single transaction text to score."""

    text: str

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v):
        """Return ``v`` without surrounding whitespace; reject blank input."""
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("text must not be empty")


class BatchPredictionRequest(BaseModel):
    """Request body carrying several transaction texts to score in one call.

    Blank entries are dropped rather than rejected, so the validated list can
    be shorter than the one the client sent and item positions may shift.
    """

    texts: List[str]

    @field_validator("texts")
    @classmethod
    def texts_must_not_be_empty(cls, v):
        """Strip every entry, discard blank ones, and require at least one left.

        Raises:
            ValueError: If the list is empty or every entry is blank.
        """
        if not v:
            raise ValueError("texts list must not be empty")
        kept = []
        for item in v:
            trimmed = item.strip() if item else ""
            if trimmed:
                kept.append(trimmed)
        if kept:
            return kept
        raise ValueError("texts list contains only empty strings")


class PredictionResponse(BaseModel):
    # Response schema for POST /predict and (as a list) POST /predict/batch.
    # FastAPI validates whatever the detector returns against these fields.
    text: str  # the (stripped) input text that was scored
    fraud_score: float  # NOTE(review): scale/range is defined by FraudDetector — confirm
    risk_level: str  # label produced by FraudDetector; allowed values not visible here


class AnalyzeResponse(BaseModel):
    # Response schema for POST /analyze: the PredictionResponse fields plus a
    # boolean verdict and a textual detection result from FraudDetector.analyze.
    text: str  # the (stripped) input text that was analyzed
    fraud_score: float  # NOTE(review): scale/range is defined by FraudDetector — confirm
    risk_level: str  # label produced by FraudDetector; allowed values not visible here
    is_fraud: bool  # presumably derived from fraud_score by FraudDetector — verify
    detection: str  # free-form detection info from FraudDetector; format not shown here


# --- Routes ---

@app.get("/health")
def health_check(response: Response):
    """Report whether the fraud model is loaded.

    Returns a JSON status dict. When no model is loaded the HTTP status is
    set to 503, so probes and load balancers that only inspect the status
    code treat the instance as unhealthy. The response body is unchanged.
    """
    # Read the global once so the check and the attribute access use the
    # same object even if ``lifespan`` clears it concurrently.
    model = detector
    if model is None:
        response.status_code = 503
        return {"status": "unhealthy", "error": "Model not loaded"}
    return {"status": "healthy", "model": model.model_name}


@app.post("/predict", response_model=PredictionResponse)
def predict_single(request: PredictionRequest):
    """Score one transaction text with the loaded detector.

    The value returned by ``detector.predict`` is validated by FastAPI
    against ``PredictionResponse``.

    Raises:
        HTTPException: 503 when no model is loaded; 500 when prediction fails.
    """
    # Read the global once so the None check and the call use the same
    # object even if ``lifespan`` clears it in between.
    model = detector
    if model is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        return model.predict(request.text)
    except Exception as e:
        # Keep the traceback in the server log, but do not echo internal
        # exception text back to the API client.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Prediction failed") from e


@app.post("/predict/batch", response_model=List[PredictionResponse])
def predict_batch(request: BatchPredictionRequest):
    """Score several transaction texts in one call.

    ``request.texts`` has already been stripped and blank-filtered by
    ``BatchPredictionRequest``; the list returned by
    ``detector.predict_batch`` is validated against ``List[PredictionResponse]``.

    Raises:
        HTTPException: 503 when no model is loaded; 500 when prediction fails.
    """
    # Read the global once so the None check and the call use the same
    # object even if ``lifespan`` clears it in between.
    model = detector
    if model is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        return model.predict_batch(request.texts)
    except Exception as e:
        # Keep the traceback in the server log, but do not echo internal
        # exception text back to the API client.
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Batch prediction failed") from e


@app.post("/analyze", response_model=AnalyzeResponse)
def analyze(request: PredictionRequest):
    """Run the detector's full analysis on one transaction text.

    The value returned by ``detector.analyze`` is validated by FastAPI
    against ``AnalyzeResponse``.

    Raises:
        HTTPException: 503 when no model is loaded; 500 when analysis fails.
    """
    # Read the global once so the None check and the call use the same
    # object even if ``lifespan`` clears it in between.
    model = detector
    if model is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        return model.analyze(request.text)
    except Exception as e:
        # Keep the traceback in the server log, but do not echo internal
        # exception text back to the API client.
        logger.exception("Analyze error: %s", e)
        raise HTTPException(status_code=500, detail="Analysis failed") from e


if __name__ == "__main__":
    import os

    # The listening port can be overridden with the PORT environment
    # variable; without it the server keeps the original default of 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))