# ML_RFA1 / main.py
# Last commit by LazyBoss: "Update main.py" (669c133, verified)
import asyncio
import logging
import os

import joblib
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO)

# 1. Load the trained model.
# The path is relative, so it is resolved against the process's current
# working directory, not against this file's directory.
MODEL_PATH = "frauddetection.pkl"

# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only ever load a model file from a trusted source.
try:
    # Try the load directly (EAFP) instead of checking os.path.exists first:
    # the check was racy and only re-raised an error we then caught ourselves.
    model = joblib.load(MODEL_PATH)
except FileNotFoundError:
    logging.error("Error loading model: Model file '%s' not found.", MODEL_PATH)
    model = None
except Exception:
    # Any other load failure (corrupt file, incompatible library version, ...)
    # is logged WITH its traceback; the service still starts with model=None
    # so the routes can report the problem instead of the process crashing.
    logging.exception("Error loading model")
    model = None
# 2. Define the input data schema using Pydantic BaseModel
class InputData(BaseModel):
    """Request body for ``POST /predict/``: the features of one transaction.

    Every field is declared as ``int``, so Pydantic validation rejects
    non-numeric values. Fields whose names suggest categorical data
    (MerchantName, MerchantCity, MerchantState, mcc, UseChip) must therefore
    arrive already numerically encoded.
    NOTE(review): presumably these are the same encodings used when the
    model was trained -- confirm against the training pipeline.
    """

    Year: int
    Month: int
    UseChip: int
    # NOTE(review): typed int -- confirm Amount was integer-encoded
    # (e.g. whole units) when the model was trained.
    Amount: int
    MerchantName: int
    MerchantCity: int
    MerchantState: int
    mcc: int  # presumably the merchant category code -- confirm
# 3. Create the FastAPI application object; the routes below register on it.
app = FastAPI()


@app.get('/')
def welcome():
    """Serve a fixed greeting at the API root ("/")."""
    payload = {"Welcome": "This is the home page of the API"}
    return payload
# 4. Define the prediction route
# Order in which InputData fields are handed to the model. It must match the
# column order the model was trained with.
FEATURE_ORDER = (
    "Year",
    "Month",
    "UseChip",
    "Amount",
    "MerchantName",
    "MerchantCity",
    "MerchantState",
    "mcc",
)


@app.post('/predict/')
async def predict(data: InputData):
    """Classify a single transaction as fraudulent or not.

    Args:
        data: Validated request body containing the model's input features.

    Returns:
        ``{"prediction": "Fraud"}`` when the model's first prediction equals
        1, otherwise ``{"prediction": "Not a Fraud"}``.

    Raises:
        HTTPException: status 500 if the model failed to load at startup, or
            if running the model / interpreting its output raises.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded properly.")

    # Read the features by attribute in a single, explicit order; this avoids
    # data.dict(), which is deprecated in Pydantic v2.
    feature_list = [getattr(data, name) for name in FEATURE_ORDER]

    try:
        # model.predict is synchronous (and CPU-bound); calling it directly in
        # an async route would block the event loop for every other request,
        # so run it in the loop's default thread pool instead.
        loop = asyncio.get_running_loop()
        prediction = await loop.run_in_executor(None, model.predict, [feature_list])
        result = "Fraud" if prediction[0] == 1 else "Not a Fraud"
    except Exception as exc:
        # Log with traceback; the client only gets a generic message.
        logging.exception("Prediction error: %s", exc)
        raise HTTPException(
            status_code=500, detail="An error occurred during prediction."
        ) from exc
    return {"prediction": result}
# 5. Serve the API with uvicorn when this file is executed as a script.
# Binds to the loopback interface only: http://127.0.0.1:8080
if __name__ == '__main__':
    bind_host, bind_port = "127.0.0.1", 8080
    uvicorn.run(app, host=bind_host, port=bind_port)