harshpatel080503 commited on
Commit
7cb2a8b
·
verified ·
1 Parent(s): 7896908

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dense_final_finetuned.keras filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================
2
+ # IMPORTS
3
+ # ==========================================
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from pydantic import BaseModel
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ import pandas as pd
9
+ import numpy as np
10
+ import joblib
11
+ import tensorflow as tf
12
+ from PIL import Image
13
+ import io
14
+
15
+ # ==========================================
16
+ # INITIALIZE APP
17
+ # ==========================================
18
+ app = FastAPI(
19
+ title="Stroke Detection API (CT + Clinical Data)",
20
+ description="Deep Learning (DenseNet121) + ML Logistic Regression",
21
+ version="2.0"
22
+ )
23
+
24
+ # CORS setup
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # ==========================================
34
+ # LOAD MODELS
35
+ # ==========================================
36
+ logistic_model = joblib.load("E:/Data Science Study/ML Project/Arshi/Models/stroke_logistic_regression_model.pkl")
37
+ preprocessor = joblib.load("E:/Data Science Study/ML Project/Arshi/Models/preprocessor.pkl")
38
+ cnn_model = tf.keras.models.load_model("E:/Data Science Study/ML Project/Arshi/Models/dense_final_finetuned.keras")
39
+
40
+ IMG_SIZE = (224, 224)
41
+
42
+ # ==========================================
43
+ # Pydantic Models
44
+ # ==========================================
45
+ class StrokeInput(BaseModel):
46
+ age: float
47
+ avg_glucose_level: float
48
+ bmi: float
49
+ hypertension: int
50
+ heart_disease: int
51
+ gender: str
52
+ ever_married: str
53
+ Residence_type: str
54
+ work_type: str
55
+ smoking_status: str
56
+
57
+ class StrokeOutput(BaseModel):
58
+ stroke_prediction: int
59
+ stroke_probability: float
60
+
61
+ # ==========================================
62
+ # HELPER FUNCTIONS
63
+ # ==========================================
64
+ def preprocess_image(image_bytes):
65
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
66
+ img = img.resize(IMG_SIZE)
67
+ img_arr = tf.keras.preprocessing.image.img_to_array(img) / 255.0
68
+ img_arr = np.expand_dims(img_arr, axis=0)
69
+ return img_arr
70
+
71
+ def predict_image_cnn(img_tensor, threshold=0.5):
72
+ prob = cnn_model.predict(img_tensor)[0][0]
73
+ label = "Stroke Detected" if prob >= threshold else "Normal Brain"
74
+ return label, float(prob)
75
+
76
+ # ==========================================
77
+ # ENDPOINT 1: STRUCTURED DATA ML MODEL
78
+ # ==========================================
79
+ @app.post("/stroke-predict-struct", response_model=StrokeOutput)
80
+ def predict_stroke_struct(data: StrokeInput):
81
+
82
+ df = pd.DataFrame([data.dict()])
83
+
84
+ # Feature Engineering
85
+ df['age_glu_interaction'] = df['age'] * df['avg_glucose_level']
86
+ df['ht_hd_score'] = df['hypertension'] + df['heart_disease']
87
+
88
+ df['work_type_simplified'] = df['work_type'].replace({
89
+ 'children': 'No_Work',
90
+ 'Never_worked': 'No_Work',
91
+ 'Private': 'Private',
92
+ 'Self-employed': 'Self_Employed',
93
+ 'Govt_job': 'Govt'
94
+ })
95
+
96
+ df['smoke_simplified'] = df['smoking_status'].replace({
97
+ 'formerly smoked': 'Former',
98
+ 'never smoked': 'Never',
99
+ 'smokes': 'Smoker',
100
+ 'Unknown': 'Unknown'
101
+ })
102
+
103
+ df['glucose_bin'] = pd.cut(
104
+ df['avg_glucose_level'],
105
+ bins=[0, 100, 140, np.inf],
106
+ labels=['Normal', 'Prediabetic', 'High']
107
+ )
108
+
109
+ selected_features = [
110
+ 'age','avg_glucose_level','bmi','age_glu_interaction',
111
+ 'hypertension','heart_disease','ht_hd_score',
112
+ 'gender','ever_married','Residence_type',
113
+ 'work_type_simplified','smoke_simplified','glucose_bin'
114
+ ]
115
+
116
+ df = df[selected_features]
117
+
118
+ processed = preprocessor.transform(df)
119
+ prob = logistic_model.predict_proba(processed)[0][1]
120
+ pred = logistic_model.predict(processed)[0]
121
+
122
+ return {
123
+ "stroke_prediction": int(pred),
124
+ "stroke_probability": float(round(prob, 4))
125
+ }
126
+
127
+ # ==========================================
128
+ # ENDPOINT 2: MRI IMAGE CNN MODEL
129
+ # ==========================================
130
+ @app.post("/stroke-predict-image")
131
+ async def predict_stroke_image(file: UploadFile = File(...)):
132
+ image_bytes = await file.read()
133
+ img_tensor = preprocess_image(image_bytes)
134
+
135
+ label, prob = predict_image_cnn(img_tensor)
136
+
137
+ return JSONResponse({
138
+ "filename": file.filename,
139
+ "prediction": label,
140
+ "confidence_score": float(round(prob, 4))
141
+ })
dense_final_finetuned.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d64058908f87df5502186ea7d210c0218b4df1c9ec72e3de47274d32a2b2f05b
3
+ size 39468817
preprocessor.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0a2aee25e2edc731a241d2e3c8d5ba89dd6ad269a6619967b82bab46b0a35d8
3
+ size 6020
stroke_logistic_regression_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c211ffd2f14ad218ade62283081290520f245322dd05c2f6748b442c4d848ba5
3
+ size 1023