# sghorbal
# Initial commit
# 537db6d
import os
import logging
import secrets
from typing import Annotated, Optional
from fastapi import (
FastAPI,
Request,
HTTPException,
Security,
Depends
)
from fastapi.background import BackgroundTasks
from fastapi.responses import RedirectResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel, Field
from starlette.status import (
HTTP_200_OK,
HTTP_403_FORBIDDEN,
HTTP_412_PRECONDITION_FAILED,
HTTP_500_INTERNAL_SERVER_ERROR)
from starlette.responses import JSONResponse
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from sqlalchemy import text
from src.repository.common import get_session
from src.service.model import (
train_model_from_scratch,
predict,
all_algorithms,
)
from src.entity.model import Model
# Set up logging: INFO-level root configuration plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Populate the environment from a .env file (if one is present) before reading
# the API key, so the key can come from either source. When the key is unset,
# the FastAPI app is created without the global key-check dependency.
load_dotenv()
FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")
# Client hosts that skip the API key check entirely.
# NOTE(review): behind a reverse proxy running on the same host, every request
# may appear to come from 127.0.0.1 — confirm this bypass is intended.
safe_clients = ['127.0.0.1']
# The key is read from the "Authorization" header. auto_error=False makes a
# missing header arrive as None instead of an automatic error, leaving the
# accept/reject decision to validate_api_key.
api_key_header = APIKeyHeader(auto_error=False, name='Authorization')
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Reject the request unless it comes from a trusted host or carries the right API key.

    Used as a global app dependency when FASTAPI_API_KEY is configured. Clients
    whose host is listed in ``safe_clients`` bypass the key check.

    Args:
        request (Request): The incoming request; used only to read the client host.
        key (str): Value of the ``Authorization`` header, or None when the header
            is absent (the header scheme is declared with ``auto_error=False``).

    Raises:
        HTTPException: 403 if the client is not trusted and the key is missing,
            wrong, or no server-side key is configured.
    '''
    # request.client may be None (e.g. some ASGI servers / test transports);
    # treat that as an untrusted client instead of crashing with a 500.
    client_host = request.client.host if request.client is not None else None
    if client_host in safe_clients:
        return None

    # Fail closed: a missing header or an unset server key must never match
    # (previously str(None) == "None" could compare equal).
    if key is None or not FASTAPI_API_KEY:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )

    # Compare UTF-8 bytes: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which a client controls via the header.
    provided = str(key).encode("utf-8")
    expected = str(FASTAPI_API_KEY).encode("utf-8")
    if not secrets.compare_digest(provided, expected):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
# The API-key check is installed on every route only when a key is configured.
# Without one the API is fully open, so say so loudly instead of silently.
if not FASTAPI_API_KEY:
    logger.warning("FASTAPI_API_KEY is not set: API key authentication is disabled")
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Fraud detection ML API")
# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    '''
    Send visitors of the root URL to the interactive API docs page.

    Hidden from the OpenAPI schema (include_in_schema=False).
    '''
    docs_location = '/docs'
    return RedirectResponse(url=docs_location)
# ------------------------------------------------------------------------------
@app.get("/train_model", tags=["model"])
async def train_model(
    background_tasks: BackgroundTasks,
    limit: Optional[int] = None,
    algorithm: Optional[all_algorithms] = 'MLP',
):
    """
    Train the model in the background.

    Training is scheduled as a background task, so this endpoint answers
    immediately with a confirmation message. `limit` and `algorithm` are
    passed through unchanged to the training routine; evaluation is disabled.
    """
    training_options = {"limit": limit, "evaluate": False, "algo": algorithm}
    background_tasks.add_task(train_model_from_scratch, **training_options)
    return {"message": "Model training in progress"}
# Request body for POST /predict: one transaction plus the customer attributes
# that make_prediction forwards to predict(). Field descriptions and examples
# feed the generated OpenAPI schema.
class ModelInput(BaseModel):
    transaction_category: str = Field(description='The category of product of the transaction.', example='personal_care')
    # gt=0: amounts must be strictly positive. NOTE(review): currency/units are not specified here.
    transaction_amount: float = Field(gt=0, description="The amount of the transaction", example=2.86)
    customer_job: str = Field(description='The job of the customer.', example='Mechanical engineer')
    # Free-form string; the example suggests a two-letter state code, but no format is enforced.
    customer_address_state: str = Field(description='The state of the customer.', example='SC')
    customer_address_city: str = Field(description='The city of the customer.', example='Columbia')
    # gt=0: population must be a strictly positive integer.
    customer_address_city_population: int = Field(gt=0, description="The population of the city", example=100000)
# Response body for POST /predict (declared as its response_model).
class ModelOutput(BaseModel):
    # 1 = fraudulent, 0 = not fraudulent (per the field description).
    result: int = Field(description="The prediction result. 1 if the transaction is fraudulent, 0 otherwise.", example=1)
    # NOTE(review): presumably within [0, 1]; no validator enforces a range.
    fraud_probability: float = Field(description="The probability of the transaction being fraudulent.", example=0.95)
    # Untyped dict; the example shows model_name/version keys but no schema is enforced.
    model_metadata: dict = Field(description="The metadata of the model.", example={"model_name": "MLP", "version": "1.0"})
@app.post("/predict",
          tags=["model"],
          description="Predict the fraudulent transactions",
          response_description="Prediction result",
          response_model=ModelOutput)
async def make_prediction(params: ModelInput):
    """
    Predict the fraudulent nature of a transaction.

    Maps the request fields onto predict()'s keyword arguments, runs the
    shared model instance's pipeline, and returns the prediction together
    with the model metadata.

    Raises:
        HTTPException: 412 Precondition Failed when ./data/model.pkl does not
            exist, i.e. no model has been trained yet.
    """
    # The path is relative to the process working directory.
    # NOTE(review): confirm this matches the file Model.get_instance() loads.
    if not os.path.exists("./data/model.pkl"):
        raise HTTPException(
            status_code=HTTP_412_PRECONDITION_FAILED,
            detail="Model not trained. Please train the model first.")

    model = Model.get_instance()

    # NOTE(review): predict() is called synchronously inside an async handler,
    # so the event loop is blocked while it runs; consider a plain `def`
    # endpoint or run_in_threadpool if inference is slow.
    prediction = predict(
        pipeline=model.pipeline,
        job=params.customer_job,
        city=params.customer_address_city,
        state=params.customer_address_state,
        category=params.transaction_category,
        amt=params.transaction_amount,
        city_pop=params.customer_address_city_population
    )
    # Log through this module's logger (not the root logger), lazily formatted.
    logger.info("%s", prediction)

    return {
        "result": prediction['result'],
        "fraud_probability": prediction['fraud_probability'],
        "model_metadata": model.metadata
    }
# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the API")
async def check_health(session: Annotated[Session, Depends(get_session)]):
    """
    Check all the services in the infrastructure are working.

    Currently the only probe is a trivial ``SELECT 1`` on the injected
    database session. Returns 200 {"status": "healthy"} when it succeeds and
    500 {"status": "unhealthy"} when it raises.
    """
    try:
        # Only the DB round-trip belongs in the try block.
        session.execute(text("SELECT 1"))
    except Exception as e:  # deliberate boundary: any failure means "unhealthy"
        # Module logger with lazy formatting; logger.exception keeps the traceback.
        logger.exception("DB check failed: %s", e)
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_500_INTERNAL_SERVER_ERROR)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)