# ==========================================
# IMPORTS
# ==========================================
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import pandas as pd
import numpy as np
import joblib
import tensorflow as tf
from PIL import Image
import io
import os
# Absolute directory that holds this module. The model files below are
# located relative to it, so loading does not depend on the working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# ==========================================
# INITIALIZE APP
# ==========================================
app = FastAPI(
    title="Stroke Detection API (CT + Clinical Data)",
    description="Deep Learning (DenseNet121) + ML Logistic Regression",
    version="2.0",
)

# CORS: accept every origin, method and header, with credentials allowed.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True
# makes Starlette reflect the caller's Origin header, so any website can make
# credentialed cross-origin calls. Confirm this is intended before exposing
# the service publicly; otherwise list the real front-end origins.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# ==========================================
# LOAD MODELS
# ==========================================
# All serialized artefacts live in a "models" folder next to this module.
_MODEL_DIR = os.path.join(BASE_DIR, "models")

# joblib.load unpickles its input; only ever point it at trusted files.
logistic_model = joblib.load(
    os.path.join(_MODEL_DIR, "stroke_logistic_regression_model.pkl")
)
preprocessor = joblib.load(os.path.join(_MODEL_DIR, "preprocessor.pkl"))
cnn_model = tf.keras.models.load_model(
    os.path.join(_MODEL_DIR, "dense_final_finetuned.keras")
)

# Size handed to PIL's Image.resize before images are fed to cnn_model.
IMG_SIZE = (224, 224)
# ==========================================
# Pydantic Models
# ==========================================
class StrokeInput(BaseModel):
    """Request body for ``/stroke-predict-struct``: one patient's record.

    These fields become a one-row DataFrame that feeds the fitted
    ``preprocessor``. ``work_type`` and ``smoking_status`` are simplified by
    the endpoint first. The categorical strings presumably must use the same
    spellings seen when the preprocessor was fitted; TODO confirm against the
    training data.
    """

    age: float
    avg_glucose_level: float  # also binned into Normal/Prediabetic/High by the endpoint
    bmi: float
    hypertension: int  # summed with heart_disease into the ht_hd_score feature
    heart_disease: int
    gender: str
    ever_married: str
    Residence_type: str  # name must match the 'Residence_type' feature column
    work_type: str  # remapped to work_type_simplified by the endpoint
    smoking_status: str  # remapped to smoke_simplified by the endpoint
class StrokeOutput(BaseModel):
    """Response body for ``/stroke-predict-struct``."""

    stroke_prediction: int  # class returned by logistic_model.predict
    stroke_probability: float  # predict_proba column 1, rounded to 4 decimals
# ==========================================
# HELPER FUNCTIONS
# ==========================================
def preprocess_image(image_bytes):
    """Decode raw image bytes into a normalised single-image batch for the CNN.

    The image is converted to RGB, resized to ``IMG_SIZE`` and scaled by
    1/255. A leading batch axis of length 1 is added.
    """
    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize(IMG_SIZE)
    scaled = tf.keras.preprocessing.image.img_to_array(rgb) / 255.0
    # Prepend the batch dimension expected by Model.predict.
    return scaled[np.newaxis, ...]
def predict_image_cnn(img_tensor, threshold=0.5):
    """Classify a preprocessed image batch with the CNN.

    Args:
        img_tensor: Batch produced by ``preprocess_image``. Only the first
            output value of the first sample is read.
        threshold: Probability at or above which the image is labelled
            "Stroke Detected".

    Returns:
        ``(label, probability)``, where probability is the model's raw output
        as a Python float.
    """
    # verbose=0: without it Keras prints a progress bar on every request.
    prob = float(cnn_model.predict(img_tensor, verbose=0)[0][0])
    label = "Stroke Detected" if prob >= threshold else "Normal Brain"
    return label, prob
# ==========================================
# ENDPOINT 1: STRUCTURED DATA ML MODEL
# ==========================================
# Category simplifications applied before preprocessing. Values not listed are
# left unchanged (Series.replace semantics).
_WORK_TYPE_MAP = {
    "children": "No_Work",
    "Never_worked": "No_Work",
    "Private": "Private",
    "Self-employed": "Self_Employed",
    "Govt_job": "Govt",
}
_SMOKING_MAP = {
    "formerly smoked": "Former",
    "never smoked": "Never",
    "smokes": "Smoker",
    "Unknown": "Unknown",
}
# Columns handed to the fitted preprocessor.
_SELECTED_FEATURES = [
    "age", "avg_glucose_level", "bmi", "age_glu_interaction",
    "hypertension", "heart_disease", "ht_hd_score",
    "gender", "ever_married", "Residence_type",
    "work_type_simplified", "smoke_simplified", "glucose_bin",
]


def _engineer_features(df):
    """Return *df* with derived columns added, restricted to ``_SELECTED_FEATURES``."""
    df = df.copy()
    df["age_glu_interaction"] = df["age"] * df["avg_glucose_level"]
    df["ht_hd_score"] = df["hypertension"] + df["heart_disease"]
    df["work_type_simplified"] = df["work_type"].replace(_WORK_TYPE_MAP)
    df["smoke_simplified"] = df["smoking_status"].replace(_SMOKING_MAP)
    # Bins are right-closed: (0, 100], (100, 140], (140, inf). A glucose value
    # <= 0 falls outside every bin and becomes NaN.
    # NOTE(review): confirm how the fitted preprocessor treats a NaN bin.
    df["glucose_bin"] = pd.cut(
        df["avg_glucose_level"],
        bins=[0, 100, 140, np.inf],
        labels=["Normal", "Prediabetic", "High"],
    )
    return df[_SELECTED_FEATURES]


@app.post("/stroke-predict-struct", response_model=StrokeOutput)
def predict_stroke_struct(data: StrokeInput):
    """Predict stroke from clinical/demographic fields with the logistic model.

    Builds a one-row DataFrame from the request and adds the engineered
    features. The result is transformed with the fitted ``preprocessor`` and
    scored with ``logistic_model``.

    Returns:
        A dict matching ``StrokeOutput``: the predicted class as ``int`` and
        the column-1 probability from ``predict_proba``, rounded to 4 decimals.
    """
    # Pydantic v2 renamed .dict() to .model_dump() and deprecated the old
    # name; call whichever the installed version provides.
    to_dict = getattr(data, "model_dump", None) or data.dict
    features = _engineer_features(pd.DataFrame([to_dict()]))
    processed = preprocessor.transform(features)
    # Column 1 follows logistic_model.classes_ ordering, presumably the
    # positive (stroke) class; verify against the trained model.
    prob = logistic_model.predict_proba(processed)[0][1]
    pred = logistic_model.predict(processed)[0]
    return {
        "stroke_prediction": int(pred),
        "stroke_probability": float(round(prob, 4)),
    }
# ==========================================
# ENDPOINT 2: MRI IMAGE CNN MODEL
# ==========================================
@app.post("/stroke-predict-image")
async def predict_stroke_image(file: UploadFile = File(...)):
    """Classify an uploaded image with the CNN.

    Returns a JSON body with the upload's filename, the predicted label and
    ``confidence_score``. ``confidence_score`` is the CNN's raw output rounded
    to 4 decimals, reported as-is even when the label is "Normal Brain".

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    from fastapi import HTTPException

    image_bytes = await file.read()
    try:
        img_tensor = preprocess_image(image_bytes)
    except (OSError, Image.DecompressionBombError) as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for empty or
        # non-image uploads, and OSError for truncated files. Report a client
        # error instead of letting it surface as an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        ) from err
    # NOTE(review): inference is synchronous and runs on the event loop,
    # blocking other requests while it executes. Consider
    # fastapi.concurrency.run_in_threadpool once concurrent use of the Keras
    # model has been confirmed safe.
    label, prob = predict_image_cnn(img_tensor)
    return JSONResponse({
        "filename": file.filename,
        "prediction": label,
        "confidence_score": float(round(prob, 4)),
    })