|
|
from fastapi import FastAPI, Request, HTTPException
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
|
|
|
# FastAPI application object; the handlers in this module are registered on it
# through the @app.get("/") and @app.post("/") decorators.
app = FastAPI()
|
|
|
|
|
|
|
|
|
def get_risk_level(probability):
    """Map a stroke probability onto one of five human-readable risk bands.

    The bands are checked from lowest to highest and each upper bound is
    exclusive:

    * below 0.2 -> "Very Low Risk"
    * below 0.4 -> "Low Risk"
    * below 0.6 -> "Moderate Risk"
    * below 0.8 -> "High Risk"
    * anything else (0.8 and above, or a value such as NaN that fails every
      ``<`` test) -> "Very High Risk"
    """
    upper_bounds_and_labels = (
        (0.2, "Very Low Risk"),
        (0.4, "Low Risk"),
        (0.6, "Moderate Risk"),
        (0.8, "High Risk"),
    )
    # The first band whose exclusive upper bound exceeds the probability wins;
    # if none does, fall back to the top band.
    return next(
        (label for upper, label in upper_bounds_and_labels if probability < upper),
        "Very High Risk",
    )
|
|
|
|
|
|
|
|
|
def predict_risk(data):
    """Estimate stroke probability by counting simple risk factors.

    This is a rule-based heuristic. Each of the six conditions below that
    holds adds one point, and the point total is mapped to a fixed
    probability (4 or more points share the top value).

    Args:
        data: Mapping of patient fields, as decoded from the request JSON.
            A missing key and an explicit ``None`` (JSON ``null``) are both
            treated as "factor absent".

    Returns:
        A ``(probability, risk_level)`` tuple, where ``risk_level`` is the
        label returned by ``get_risk_level(probability)``.

    Raises:
        AttributeError: if ``data`` has no ``.get`` method (e.g. a list).
        TypeError: if ``age``, ``avg_glucose_level`` or ``bmi`` holds a value
            that cannot be compared with a number (e.g. a string).
    """

    def field(key, default):
        # dict.get only applies the default when the key is absent; a JSON
        # null arrives as None and would otherwise crash the ``>`` checks.
        value = data.get(key)
        return default if value is None else value

    # Booleans sum as 0/1, giving the number of risk factors present.
    risk_factors = sum((
        field('hypertension', 0) == 1,
        field('heart_disease', 0) == 1,
        field('age', 0) > 65,
        field('smoking_status', '') == 'smokes',
        field('avg_glucose_level', 0) > 140,
        field('bmi', 0) > 30,
    ))

    # Fixed probability for 0, 1, 2, 3, and "4 or more" risk factors.
    probabilities = (0.05, 0.15, 0.30, 0.60, 0.80)
    probability = probabilities[min(risk_factors, len(probabilities) - 1)]

    return probability, get_risk_level(probability)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Health/info endpoint: confirm the service is up and show a sample POST body."""
    sample_patient = {
        "gender": "Male",
        "age": 67,
        "hypertension": 1,
        "heart_disease": 0,
        "avg_glucose_level": 228.69,
        "bmi": 36.6,
        "smoking_status": "formerly smoked",
    }
    return {
        "message": "Stroke Prediction API is running",
        "usage": "Send a POST request to / with patient data",
        "example": sample_patient,
    }
|
|
|
|
|
|
@app.post("/")
async def predict(request: Request):
    """Score a patient record posted as a JSON object.

    Returns:
        A dict with:
        - ``probability``: the heuristic probability from ``predict_risk``.
        - ``prediction``: its risk-band label.
        - ``stroke_prediction``: 1 if the probability exceeds 0.5, else 0.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, or contains a field value the scorer cannot compare
            (TypeError). Any other exception is a server-side bug and is left
            to propagate rather than being reported as bad client input.
    """
    try:
        data = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}") from e

    # predict_risk reads fields with .get(); reject lists/strings/numbers/null up front.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid input: expected a JSON object")

    try:
        probability, risk_level = predict_risk(data)
    except TypeError as e:
        # A numeric field (e.g. "age") held a value that can't be compared with a number.
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}") from e

    return {
        "probability": float(probability),
        "prediction": risk_level,
        "stroke_prediction": int(probability > 0.5),
    }