Spaces:
Running
Running
| import os | |
| import logging | |
| import secrets | |
| from typing import Annotated, Optional | |
| from fastapi import ( | |
| FastAPI, | |
| Request, | |
| HTTPException, | |
| Security, | |
| Depends | |
| ) | |
| from fastapi.background import BackgroundTasks | |
| from fastapi.responses import RedirectResponse | |
| from fastapi.security.api_key import APIKeyHeader | |
| from pydantic import BaseModel, Field | |
| from starlette.status import ( | |
| HTTP_200_OK, | |
| HTTP_403_FORBIDDEN, | |
| HTTP_412_PRECONDITION_FAILED, | |
| HTTP_500_INTERNAL_SERVER_ERROR) | |
| from starlette.responses import JSONResponse | |
| from dotenv import load_dotenv | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import text | |
| from src.repository.common import get_session | |
| from src.service.model import ( | |
| train_model_from_scratch, | |
| predict, | |
| all_algorithms, | |
| ) | |
| from src.entity.model import Model | |
# Logging: configure the root logger at INFO level and create a logger named
# after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------------------------------------------------------
# Authentication settings.
# load_dotenv() runs first so that the API key lookup below sees any variables
# it loads.
load_dotenv()
FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")

# Header from which the client-supplied key is read. With auto_error=False a
# missing header reaches validate_api_key as None instead of being rejected
# by FastAPI itself.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

# Client hosts that validate_api_key accepts without checking the key.
safe_clients = ["127.0.0.1"]
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    """
    Check if the API key is valid.

    Requests whose client host is listed in ``safe_clients`` are accepted
    without a key check. Every other request must carry a key equal to
    ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request; its client host is compared
            against ``safe_clients``.
        key (str): The API key extracted through ``api_key_header``
            (``None`` when the header is absent).

    Raises:
        HTTPException: 403 if the client is not in ``safe_clients`` and the
            key does not match ``FASTAPI_API_KEY``.
    """
    # request.client can be None when the server supplies no peer address;
    # treat that as "not a trusted client" rather than crashing.
    client_host = request.client.host if request.client is not None else None

    # NOTE(review): behind a reverse proxy running on the same host, every
    # request may appear to come from 127.0.0.1 - confirm the deployment.
    if client_host in safe_clients:
        return None

    # Compare as bytes: compare_digest() raises TypeError for non-ASCII str
    # arguments, which would turn a malformed header into a 500 instead of
    # a 403. The comparison itself stays constant-time.
    supplied_key = str(key).encode("utf-8")
    expected_key = str(FASTAPI_API_KEY).encode("utf-8")
    if not secrets.compare_digest(supplied_key, expected_key):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
# The API-key check is registered as an application-wide dependency only when
# a key is configured; otherwise no global dependency is installed.
app = FastAPI(
    title="Fraud detection ML API",
    dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
)
| # ------------------------------------------------------------------------------ | |
def redirect_to_docs():
    """
    Send the caller to the API documentation served at '/docs'.

    Returns:
        RedirectResponse: A redirect pointing at '/docs'.
    """
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
| # ------------------------------------------------------------------------------ | |
async def train_model(background_tasks: BackgroundTasks,
                      limit: Optional[int] = None,
                      algorithm: Optional[all_algorithms] = 'MLP'):
    """
    Schedule a model training run as a background task.

    Args:
        background_tasks (BackgroundTasks): Task queue the training job is
            added to.
        limit (Optional[int]): Forwarded unchanged to
            ``train_model_from_scratch`` as ``limit``.
        algorithm (Optional[all_algorithms]): Forwarded to
            ``train_model_from_scratch`` as ``algo``.

    Returns:
        dict: A message stating that training is in progress.
    """
    # Evaluation is always disabled for runs started through this function.
    training_options = {
        "limit": limit,
        "evaluate": False,
        "algo": algorithm,
    }
    background_tasks.add_task(train_model_from_scratch, **training_options)

    acknowledgement = {"message": "Model training in progress"}
    return acknowledgement
class ModelInput(BaseModel):
    # Request payload consumed by make_prediction: one transaction together
    # with the customer attributes that are passed on to predict().
    transaction_category: str = Field(
        description='The category of product of the transaction.',
        example='personal_care',
    )
    transaction_amount: float = Field(
        gt=0,
        description="The amount of the transaction",
        example=2.86,
    )
    customer_job: str = Field(
        description='The job of the customer.',
        example='Mechanical engineer',
    )
    customer_address_state: str = Field(
        description='The state of the customer.',
        example='SC',
    )
    customer_address_city: str = Field(
        description='The city of the customer.',
        example='Columbia',
    )
    customer_address_city_population: int = Field(
        gt=0,
        description="The population of the city",
        example=100000,
    )
class ModelOutput(BaseModel):
    # Shape of a prediction response: the class label, its probability and
    # the metadata of the model.
    # NOTE(review): under Pydantic v2 a field starting with "model_" clashes
    # with the protected namespace and emits a warning - confirm the version.
    result: int = Field(
        description="The prediction result. 1 if the transaction is fraudulent, 0 otherwise.",
        example=1,
    )
    fraud_probability: float = Field(
        description="The probability of the transaction being fraudulent.",
        example=0.95,
    )
    model_metadata: dict = Field(
        description="The metadata of the model.",
        example={"model_name": "MLP", "version": "1.0"},
    )
async def make_prediction(params: ModelInput):
    """
    Predict the fraudulent nature of a transaction.

    Args:
        params (ModelInput): Transaction and customer attributes to score.

    Returns:
        dict: ``result`` and ``fraud_probability`` taken from the prediction,
        plus ``model_metadata`` taken from the loaded model.

    Raises:
        HTTPException: 412 if './data/model.pkl' does not exist, i.e. no
            model has been trained yet.
    """
    # Guard: the path is relative to the process working directory.
    if not os.path.exists("./data/model.pkl"):
        raise HTTPException(
            status_code=HTTP_412_PRECONDITION_FAILED,
            detail="Model not trained. Please train the model first.")

    # Load the model
    model = Model.get_instance()

    # Make the prediction
    prediction = predict(
        pipeline=model.pipeline,
        job=params.customer_job,
        city=params.customer_address_city,
        state=params.customer_address_state,
        category=params.transaction_category,
        amt=params.transaction_amount,
        city_pop=params.customer_address_city_population
    )

    # Log through the module-level logger (not the root logger) so records
    # carry this module's name.
    logger.info(prediction)

    # Return the prediction
    return {
        "result": prediction['result'],
        "fraud_probability": prediction['fraud_probability'],
        "model_metadata": model.metadata
    }
| # ------------------------------------------------------------------------------ | |
async def check_health(session: Annotated[Session, Depends(get_session)]):
    """
    Check that the database is reachable.

    Runs a trivial ``SELECT 1`` on the injected session; no other service is
    probed here.

    Args:
        session (Session): Database session provided by ``get_session``.

    Returns:
        JSONResponse: ``{"status": "healthy"}`` with HTTP 200 when the query
        succeeds, ``{"status": "unhealthy"}`` with HTTP 500 otherwise.
    """
    try:
        session.execute(text("SELECT 1"))
    except Exception as e:
        # Deliberately broad: any failure of the probe is reported as an
        # unhealthy status instead of propagating. logger.exception keeps the
        # original message and adds the traceback for diagnosis.
        logger.exception("DB check failed: %s", e)
        return JSONResponse(content={"status": "unhealthy"}, status_code=HTTP_500_INTERNAL_SERVER_ERROR)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)