File size: 5,445 Bytes
537db6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import logging
import secrets
from typing import Annotated, Optional
from fastapi import (
    FastAPI,
    Request,
    HTTPException,
    Security,
    Depends
)
from fastapi.background import BackgroundTasks
from fastapi.responses import RedirectResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel, Field
from starlette.status import (
    HTTP_200_OK,
    HTTP_403_FORBIDDEN,
    HTTP_412_PRECONDITION_FAILED,
    HTTP_500_INTERNAL_SERVER_ERROR)
from starlette.responses import JSONResponse
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from sqlalchemy import text

from src.repository.common import get_session
from src.service.model import (
    train_model_from_scratch,
    predict,
    all_algorithms,
)
from src.entity.model import Model

# Set up logging: configure the root logger at INFO level and create the
# module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------------------------------------------------------

# Load variables from a .env file (if present) before reading the environment.
load_dotenv()
# Expected API key; None when the FASTAPI_API_KEY variable is not set.
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts that validate_api_key accepts without checking the API key.
# NOTE(review): behind a reverse proxy every request may appear to come from
# the proxy's address - confirm this list matches the deployment.
safe_clients = ['127.0.0.1']

# Security scheme that reads the API key from the `Authorization` header.
# auto_error=False leaves the decision to reject a request to validate_api_key.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)

async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    """
    Check that the request is allowed to use the API.

    Requests whose client host is listed in ``safe_clients`` are accepted
    without a key. Every other request must carry a key equal to
    ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request; only ``request.client`` is read.
        key (str): The API key extracted by ``api_key_header``.

    Raises:
        HTTPException: 403 if the client is not trusted and the key does not match.
    """
    # request.client can be None when the peer address is unknown; guard the
    # attribute access so such requests go through the key check instead of
    # failing with an AttributeError (HTTP 500).
    client = request.client
    if client is not None and client.host in safe_clients:
        return None

    # Compare bytes rather than str: secrets.compare_digest raises TypeError
    # for str arguments containing non-ASCII characters, which would let a
    # client turn a bad key into an HTTP 500 instead of a 403.
    provided = str(key).encode("utf-8")
    expected = str(FASTAPI_API_KEY).encode("utf-8")
    if not secrets.compare_digest(provided, expected):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None

# The API-key check is installed as a global dependency only when
# FASTAPI_API_KEY holds a non-empty value; otherwise no dependency is passed.
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Fraud detection ML API")


# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """Send requests for the root path to the documentation page at ``/docs``."""
    docs_redirect = RedirectResponse(url='/docs')
    return docs_redirect


# ------------------------------------------------------------------------------
@app.get("/train_model", tags=["model"])
async def train_model(background_tasks: BackgroundTasks,
                      limit: Optional[int] = None,
                      algorithm: Optional[all_algorithms] = 'MLP'):
    """Train the model"""
    # The docstring text is kept as-is: no explicit description is given in the
    # decorator, so FastAPI publishes it as the operation description.

    # Training is queued as a background task, so the response below is sent
    # without waiting for it to finish.
    training_options = {
        "limit": limit,
        "evaluate": False,
        "algo": algorithm,
    }
    background_tasks.add_task(train_model_from_scratch, **training_options)

    acknowledgement = {"message": "Model training in progress"}
    return acknowledgement

class ModelInput(BaseModel):
    # Request body accepted by the prediction endpoint. Documented with
    # comments rather than a class docstring so the generated schema gains no
    # model-level description.

    # --- transaction attributes ---
    transaction_category: str = Field(
        description='The category of product of the transaction.',
        example='personal_care',
    )
    transaction_amount: float = Field(
        gt=0,
        description="The amount of the transaction",
        example=2.86,
    )

    # --- customer attributes ---
    customer_job: str = Field(
        description='The job of the customer.',
        example='Mechanical engineer',
    )
    customer_address_state: str = Field(
        description='The state of the customer.',
        example='SC',
    )
    customer_address_city: str = Field(
        description='The city of the customer.',
        example='Columbia',
    )
    customer_address_city_population: int = Field(
        gt=0,
        description="The population of the city",
        example=100000,
    )

class ModelOutput(BaseModel):
    # Response body of the prediction endpoint. Documented with comments
    # rather than a class docstring so the generated schema gains no
    # model-level description.
    result: int = Field(
        description="The prediction result. 1 if the transaction is fraudulent, 0 otherwise.",
        example=1,
    )
    fraud_probability: float = Field(
        description="The probability of the transaction being fraudulent.",
        example=0.95,
    )
    model_metadata: dict = Field(
        description="The metadata of the model.",
        example={"model_name": "MLP", "version": "1.0"},
    )

@app.post("/predict",
         tags=["model"],
         description="Predict the fraudulent transactions",
         response_description="Prediction result",
         response_model=ModelOutput)
async def make_prediction(params: ModelInput):
    """
    Predict the fraudulent nature of a transaction.

    Args:
        params (ModelInput): The transaction and customer attributes to score.

    Returns:
        dict: The ``result`` and ``fraud_probability`` taken from the
        prediction, plus the ``model_metadata`` of the loaded model.

    Raises:
        HTTPException: 412 if ``./data/model.pkl`` does not exist, i.e. no
            model has been trained yet.
    """
    # check the presence of 'model.pkl' file in data/
    # (the path is relative to the process working directory)
    if not os.path.exists("./data/model.pkl"):
        raise HTTPException(
            status_code=HTTP_412_PRECONDITION_FAILED, detail="Model not trained. Please train the model first.")

    # Load the model
    model = Model.get_instance()

    # Make the prediction
    prediction = predict(
        pipeline=model.pipeline,
        job=params.customer_job,
        city=params.customer_address_city,
        state=params.customer_address_state,
        category=params.transaction_category,
        amt=params.transaction_amount,
        city_pop=params.customer_address_city_population
    )

    # Log through the module-level logger (not the root logger) so the record
    # carries this module's name, like the rest of the file's logging setup.
    logger.info(prediction)

    # Return the prediction
    return {
        "result": prediction['result'],
        "fraud_probability": prediction['fraud_probability'],
        "model_metadata": model.metadata
    }

# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the API")
async def check_health(session: Annotated[Session, Depends(get_session)]):
    """
    Check that the database is reachable.

    Runs ``SELECT 1`` on the injected session; this is the only service probed.

    Args:
        session (Session): Database session provided by ``get_session``.

    Returns:
        JSONResponse: ``{"status": "healthy"}`` with HTTP 200 when the query
        succeeds, ``{"status": "unhealthy"}`` with HTTP 500 when it raises.
    """
    try:
        session.execute(text("SELECT 1"))
    except Exception as e:
        # Broad catch is deliberate: any failure of the probe means "unhealthy".
        # logger.exception keeps the traceback, which the plain error message lost.
        logger.exception("DB check failed: %s", e)
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_500_INTERNAL_SERVER_ERROR)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)