Spaces:
Running
Running
File size: 5,445 Bytes
537db6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import logging
import secrets
from typing import Annotated, Optional
from fastapi import (
FastAPI,
Request,
HTTPException,
Security,
Depends
)
from fastapi.background import BackgroundTasks
from fastapi.responses import RedirectResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel, Field
from starlette.status import (
HTTP_200_OK,
HTTP_403_FORBIDDEN,
HTTP_412_PRECONDITION_FAILED,
HTTP_500_INTERNAL_SERVER_ERROR)
from starlette.responses import JSONResponse
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from sqlalchemy import text
from src.repository.common import get_session
from src.service.model import (
train_model_from_scratch,
predict,
all_algorithms,
)
from src.entity.model import Model
# Process-wide logging at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Load variables from a .env file into the process environment (python-dotenv)
# before any configuration is read from it.
load_dotenv()
# Shared secret that non-local clients must send; None when not configured.
FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")
# Client hosts that validate_api_key lets through without checking the key.
safe_clients = ['127.0.0.1']
# Extracts the raw "Authorization" header. With auto_error=False a missing
# header is passed on as None instead of being rejected by this extractor.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
async def validate_api_key(request: Request, key: Optional[str] = Security(api_key_header)):
    '''
    Check if the API key is valid

    Requests whose client host appears in ``safe_clients`` skip the key check.
    Every other request must send an ``Authorization`` header whose value is
    exactly ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request; only its client host is read.
        key (str | None): The raw ``Authorization`` header value, or None when
            the header is absent (``api_key_header`` uses ``auto_error=False``).
    Raises:
        HTTPException: 403 when the client is not trusted and the key is
            missing, not configured, or wrong.
    '''
    # Starlette's Request.client is None when the ASGI scope carries no peer
    # address; treat that as an untrusted client instead of crashing with 500.
    client_host = request.client.host if request.client is not None else None

    # NOTE(review): behind a reverse proxy on the same machine, every request
    # may appear to come from 127.0.0.1 and bypass the key check -- confirm the
    # deployment topology.
    if client_host in safe_clients:
        return None

    # A missing header (or missing configured key) can never be valid. The
    # original str(None) == "None" comparison is deliberately avoided.
    if key is None or FASTAPI_API_KEY is None:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )

    # compare_digest raises TypeError for non-ASCII str arguments. Header
    # values can contain such characters, so compare UTF-8 bytes instead.
    # Equal strings still give equal bytes, and a hostile header gets a 403
    # rather than a 500.
    if not secrets.compare_digest(key.encode("utf-8"), FASTAPI_API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
if not FASTAPI_API_KEY:
    # The API-key dependency below is only installed when a key is configured.
    # An unset or empty FASTAPI_API_KEY therefore leaves every endpoint
    # unauthenticated. Say so in the logs instead of failing open silently.
    logger.warning("FASTAPI_API_KEY is not set: API key authentication is disabled")
# Every route is guarded by validate_api_key, but only when a key is configured.
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Fraud detection ML API")
# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    '''
    Send requests for the site root to the interactive API documentation
    served at /docs. This route is hidden from the OpenAPI schema.
    '''
    docs_path = '/docs'
    return RedirectResponse(docs_path)
# ------------------------------------------------------------------------------
@app.get("/train_model", tags=["model"])
async def train_model(background_tasks: BackgroundTasks,
                      limit: Optional[int] = None,
                      algorithm: Optional[all_algorithms] = 'MLP'):
    """
    Train the model

    Schedules train_model_from_scratch as a background task (without
    evaluation) and answers right away, before training has finished.
    """
    # NOTE(review): a GET request with a side effect; `limit` is passed through
    # unchanged -- presumably it caps the training data size, confirm in
    # train_model_from_scratch.
    training_options = {
        "limit": limit,
        "evaluate": False,
        "algo": algorithm,
    }
    background_tasks.add_task(train_model_from_scratch, **training_options)
    return {"message": "Model training in progress"}
# Request body for POST /predict: the transaction and customer attributes that
# make_prediction forwards to predict().
# NOTE(review): Pydantic v2 deprecates extra Field kwargs such as `example`;
# consider `examples=[...]` once the installed pydantic version is confirmed.
class ModelInput(BaseModel):
    # Product category of the purchase (example 'personal_care').
    transaction_category: str = Field(
        description='The category of product of the transaction.',
        example='personal_care',
    )
    # Monetary amount; validation rejects values <= 0.
    transaction_amount: float = Field(
        gt=0,
        description="The amount of the transaction",
        example=2.86,
    )
    # Free-text job title of the customer.
    customer_job: str = Field(
        description='The job of the customer.',
        example='Mechanical engineer',
    )
    # State of the customer's address (example 'SC').
    customer_address_state: str = Field(
        description='The state of the customer.',
        example='SC',
    )
    # City of the customer's address.
    customer_address_city: str = Field(
        description='The city of the customer.',
        example='Columbia',
    )
    # Population of that city; validation rejects values <= 0.
    customer_address_city_population: int = Field(
        gt=0,
        description="The population of the city",
        example=100000,
    )
# Response body of POST /predict (used as its response_model).
class ModelOutput(BaseModel):
    # Binary verdict: 1 for fraudulent, 0 otherwise.
    result: int = Field(
        description="The prediction result. 1 if the transaction is fraudulent, 0 otherwise.",
        example=1,
    )
    # Model-reported probability that the transaction is fraudulent.
    fraud_probability: float = Field(
        description="The probability of the transaction being fraudulent.",
        example=0.95,
    )
    # Whatever metadata the loaded model exposes (free-form dict).
    model_metadata: dict = Field(
        description="The metadata of the model.",
        example={"model_name": "MLP", "version": "1.0"},
    )
@app.post("/predict",
          tags=["model"],
          description="Predict the fraudulent transactions",
          response_description="Prediction result",
          response_model=ModelOutput)
async def make_prediction(params: ModelInput):
    """
    Predict the fraudulent nature of a transaction

    Args:
        params (ModelInput): The transaction and customer attributes to score.
    Returns:
        dict: ``result`` and ``fraud_probability`` taken from ``predict``,
        plus the loaded model's ``metadata``. FastAPI validates it against
        ``ModelOutput``.
    Raises:
        HTTPException: 412 if ``./data/model.pkl`` does not exist yet.
    """
    # Fail fast with a clear 412 when no trained model file exists.
    # NOTE(review): the path is relative to the process working directory;
    # presumably it is the same file Model.get_instance() loads -- confirm.
    if not os.path.exists("./data/model.pkl"):
        raise HTTPException(
            status_code=HTTP_412_PRECONDITION_FAILED,
            detail="Model not trained. Please train the model first.")

    # NOTE(review): get_instance() and predict() are called synchronously
    # inside this async endpoint, so they run on the event loop; if they are
    # slow they delay every other request.
    model = Model.get_instance()

    prediction = predict(
        pipeline=model.pipeline,
        job=params.customer_job,
        city=params.customer_address_city,
        state=params.customer_address_state,
        category=params.transaction_category,
        amt=params.transaction_amount,
        city_pop=params.customer_address_city_population,
    )
    # Use the module logger (configured at import) with lazy %-formatting
    # rather than the root `logging` module.
    logger.info("Prediction: %s", prediction)

    return {
        "result": prediction['result'],
        "fraud_probability": prediction['fraud_probability'],
        "model_metadata": model.metadata,
    }
# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the API")
async def check_health(session: Annotated[Session, Depends(get_session)]):
    """
    Check that the database behind the API is reachable

    Executes ``SELECT 1`` on the injected session; the database is the only
    service probed here.

    Returns:
        JSONResponse: ``{"status": "healthy"}`` with HTTP 200 when the query
        succeeds, ``{"status": "unhealthy"}`` with HTTP 500 otherwise.
    """
    # Starlette is already a dependency of this module (see its imports).
    from starlette.concurrency import run_in_threadpool

    try:
        # session.execute is blocking I/O. Inside an async endpoint it would
        # run on the event loop, and a slow or hung database would stall every
        # request. Offload it to the threadpool instead.
        await run_in_threadpool(session.execute, text("SELECT 1"))
    except Exception as e:  # health-check boundary: any failure means unhealthy
        logger.exception("DB check failed: %s", e)
        return JSONResponse(content={"status": "unhealthy"},
                            status_code=HTTP_500_INTERNAL_SERVER_ERROR)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)
|