# Author: harshpatel080503 · last commit "Update main.py" (04f2d23, verified)
# ==========================================
# IMPORTS
# ==========================================
import io

import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
# ==========================================
# INITIALIZE APP
# ==========================================
# Metadata shown in the auto-generated OpenAPI / Swagger documentation.
_API_METADATA = {
    "title": "Stroke Detection API (CT + Clinical Data)",
    "description": "Deep Learning (DenseNet121) + ML Logistic Regression",
    "version": "2.0",
}
app = FastAPI(**_API_METADATA)

# CORS setup: every origin, method and header is accepted.
# NOTE(review): pairing allow_origins=["*"] with allow_credentials=True lets
# Starlette reflect the caller's Origin for cookie-bearing requests, so any
# site could make credentialed calls. Nothing visible here uses cookies or
# auth -- consider allow_credentials=False or an explicit origin list.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ==========================================
# LOAD MODELS
# ==========================================
# All artifacts are loaded once, when the module is imported, and are then
# shared by every request. The paths are relative, so they resolve against
# the process working directory: start the server from the folder that
# holds these files.
# NOTE(review): joblib artifacts are pickles -- only ever load trusted files.
_LOGISTIC_MODEL_PATH = "stroke_logistic_regression_model.pkl"
_PREPROCESSOR_PATH = "preprocessor.pkl"
_CNN_MODEL_PATH = "dense_final_finetuned.keras"

logistic_model = joblib.load(_LOGISTIC_MODEL_PATH)
preprocessor = joblib.load(_PREPROCESSOR_PATH)
cnn_model = tf.keras.models.load_model(_CNN_MODEL_PATH)

# Target size passed to PIL's Image.resize, i.e. (width, height).
IMG_SIZE = (224, 224)
# ==========================================
# Pydantic Models
# ==========================================
class StrokeInput(BaseModel):
    """Request body for ``/stroke-predict-struct``: one patient's clinical record.

    The endpoint turns this model into a one-row DataFrame, so the field
    names (including the capitalised ``Residence_type``) are used verbatim
    as column names and must match what the fitted preprocessor expects.
    """

    age: float
    # Also binned by the endpoint into Normal (<=100), Prediabetic (<=140), High.
    avg_glucose_level: float
    bmi: float
    # Summed with heart_disease into ht_hd_score; presumably a 0/1 flag -- confirm.
    hypertension: int
    # Presumably a 0/1 flag -- confirm.
    heart_disease: int
    gender: str
    ever_married: str
    Residence_type: str
    # Endpoint collapses 'children', 'Never_worked', 'Private', 'Self-employed',
    # 'Govt_job'; any other value is passed through unchanged.
    work_type: str
    # Endpoint collapses 'formerly smoked', 'never smoked', 'smokes', 'Unknown';
    # any other value is passed through unchanged.
    smoking_status: str
class StrokeOutput(BaseModel):
    """Response body for ``/stroke-predict-struct``."""

    # Class label from the logistic-regression model's predict(), cast to int.
    stroke_prediction: int
    # predict_proba column 1 (presumably the "stroke" class -- confirm against
    # the model's classes_), rounded to 4 decimals by the endpoint.
    stroke_probability: float
# ==========================================
# HELPER FUNCTIONS
# ==========================================
def preprocess_image(image_bytes):
    """Decode encoded image bytes into a normalised one-image batch for the CNN.

    The picture is forced to 3-channel RGB, resized to ``IMG_SIZE`` and
    rescaled from the 0-255 pixel range to 0-1.

    Args:
        image_bytes: Raw contents of an image file in any format PIL can open.

    Returns:
        A float array with a leading batch axis of length 1 (presumably
        ``(1, 224, 224, 3)`` under Keras' default channels-last layout).
    """
    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize(IMG_SIZE)
    pixels = tf.keras.preprocessing.image.img_to_array(rgb)
    scaled = pixels / 255.0
    # Keras models consume batches, so prepend a batch dimension of size 1.
    return scaled[np.newaxis, ...]
def predict_image_cnn(img_tensor, threshold=0.5):
    """Score a preprocessed image batch with the CNN and threshold the result.

    Args:
        img_tensor: Batch produced by ``preprocess_image``; only the first
            sample's score is used.
        threshold: Scores ``>= threshold`` are labelled "Stroke Detected".

    Returns:
        ``(label, prob)``: ``label`` is "Stroke Detected" or "Normal Brain",
        and ``prob`` is the raw model score as a Python float.
    """
    # verbose=0: Model.predict otherwise draws a progress bar on every call,
    # which in a server means one line of log noise per request.
    raw = cnn_model.predict(img_tensor, verbose=0)
    # NOTE(review): assumes a single sigmoid output where higher means
    # "stroke" -- confirm against how dense_final_finetuned.keras was trained.
    prob = float(raw[0][0])
    label = "Stroke Detected" if prob >= threshold else "Normal Brain"
    return label, prob
# ==========================================
# ENDPOINT 1: STRUCTURED DATA ML MODEL
# ==========================================
# Category collapsing applied before encoding. Values not listed in a map are
# left unchanged by Series.replace.
_WORK_TYPE_MAP = {
    'children': 'No_Work',
    'Never_worked': 'No_Work',
    'Private': 'Private',
    'Self-employed': 'Self_Employed',
    'Govt_job': 'Govt',
}
_SMOKING_MAP = {
    'formerly smoked': 'Former',
    'never smoked': 'Never',
    'smokes': 'Smoker',
    'Unknown': 'Unknown',
}
# Glucose bins are right-closed: (0, 100] Normal, (100, 140] Prediabetic,
# (140, inf) High. Values <= 0 fall outside every bin and become NaN.
_GLUCOSE_BINS = [0, 100, 140, np.inf]
_GLUCOSE_LABELS = ['Normal', 'Prediabetic', 'High']
# Columns, in order, handed to the fitted preprocessor.
_SELECTED_FEATURES = [
    'age', 'avg_glucose_level', 'bmi', 'age_glu_interaction',
    'hypertension', 'heart_disease', 'ht_hd_score',
    'gender', 'ever_married', 'Residence_type',
    'work_type_simplified', 'smoke_simplified', 'glucose_bin',
]


def _model_to_dict(model):
    """Return a pydantic model's field values as a dict on pydantic v1 or v2."""
    # Pydantic v2 deprecates .dict() in favour of .model_dump(); v1 lacks the latter.
    dump = getattr(model, "model_dump", None)
    return dump() if dump is not None else model.dict()


def _engineer_features(df):
    """Derive the engineered columns and return only the preprocessor's inputs.

    Does not modify ``df``; a new frame is returned.
    """
    features = df.assign(
        age_glu_interaction=df['age'] * df['avg_glucose_level'],
        ht_hd_score=df['hypertension'] + df['heart_disease'],
        work_type_simplified=df['work_type'].replace(_WORK_TYPE_MAP),
        smoke_simplified=df['smoking_status'].replace(_SMOKING_MAP),
        glucose_bin=pd.cut(
            df['avg_glucose_level'],
            bins=_GLUCOSE_BINS,
            labels=_GLUCOSE_LABELS,
        ),
    )
    return features[_SELECTED_FEATURES]


@app.post("/stroke-predict-struct", response_model=StrokeOutput)
def predict_stroke_struct(data: StrokeInput):
    """Predict stroke risk from structured clinical data.

    The record is feature-engineered exactly as the preprocessor expects,
    transformed, and scored by the logistic-regression model.

    Returns:
        ``{"stroke_prediction": int, "stroke_probability": float}``, with the
        probability taken from ``predict_proba`` column 1 and rounded to 4
        decimals.
    """
    raw = pd.DataFrame([_model_to_dict(data)])
    processed = preprocessor.transform(_engineer_features(raw))
    prob = logistic_model.predict_proba(processed)[0][1]
    pred = logistic_model.predict(processed)[0]
    return {
        "stroke_prediction": int(pred),
        "stroke_probability": float(round(prob, 4)),
    }
# ==========================================
# ENDPOINT 2: BRAIN-SCAN IMAGE CNN MODEL
# ==========================================
@app.post("/stroke-predict-image")
async def predict_stroke_image(file: UploadFile = File(...)):
    """Classify an uploaded brain-scan image with the CNN.

    Image decoding and model inference are CPU-bound. They run in FastAPI's
    worker threadpool so that they do not block the event loop, which
    serves every other request.

    Returns:
        JSON containing the upload's ``filename``, the ``prediction`` label and
        ``confidence_score`` (the model score, rounded to 4 decimals).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        img_tensor = await run_in_threadpool(preprocess_image, image_bytes)
    except (OSError, ValueError) as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot recognise, and OSError for truncated files -- client errors.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        ) from err
    label, prob = await run_in_threadpool(predict_image_cnn, img_tensor)
    return JSONResponse({
        "filename": file.filename,
        "prediction": label,
        # This is the score compared against the stroke threshold, not the
        # confidence of the returned label: a low value accompanies
        # "Normal Brain".
        "confidence_score": float(round(prob, 4)),
    })