harshpatel080503 commited on
Commit
04f2d23
·
verified ·
1 Parent(s): 0c8ab17

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +140 -140
main.py CHANGED
@@ -1,141 +1,141 @@
1
- # ==========================================
2
- # IMPORTS
3
- # ==========================================
4
- from fastapi import FastAPI, File, UploadFile
5
- from pydantic import BaseModel
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from fastapi.responses import JSONResponse
8
- import pandas as pd
9
- import numpy as np
10
- import joblib
11
- import tensorflow as tf
12
- from PIL import Image
13
- import io
14
-
15
- # ==========================================
16
- # INITIALIZE APP
17
- # ==========================================
18
- app = FastAPI(
19
- title="Stroke Detection API (CT + Clinical Data)",
20
- description="Deep Learning (DenseNet121) + ML Logistic Regression",
21
- version="2.0"
22
- )
23
-
24
- # CORS setup
25
- app.add_middleware(
26
- CORSMiddleware,
27
- allow_origins=["*"],
28
- allow_credentials=True,
29
- allow_methods=["*"],
30
- allow_headers=["*"],
31
- )
32
-
33
- # ==========================================
34
- # LOAD MODELS
35
- # ==========================================
36
- logistic_model = joblib.load("E:/Data Science Study/ML Project/Arshi/Models/stroke_logistic_regression_model.pkl")
37
- preprocessor = joblib.load("E:/Data Science Study/ML Project/Arshi/Models/preprocessor.pkl")
38
- cnn_model = tf.keras.models.load_model("E:/Data Science Study/ML Project/Arshi/Models/dense_final_finetuned.keras")
39
-
40
- IMG_SIZE = (224, 224)
41
-
42
- # ==========================================
43
- # Pydantic Models
44
- # ==========================================
45
- class StrokeInput(BaseModel):
46
- age: float
47
- avg_glucose_level: float
48
- bmi: float
49
- hypertension: int
50
- heart_disease: int
51
- gender: str
52
- ever_married: str
53
- Residence_type: str
54
- work_type: str
55
- smoking_status: str
56
-
57
- class StrokeOutput(BaseModel):
58
- stroke_prediction: int
59
- stroke_probability: float
60
-
61
- # ==========================================
62
- # HELPER FUNCTIONS
63
- # ==========================================
64
- def preprocess_image(image_bytes):
65
- img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
66
- img = img.resize(IMG_SIZE)
67
- img_arr = tf.keras.preprocessing.image.img_to_array(img) / 255.0
68
- img_arr = np.expand_dims(img_arr, axis=0)
69
- return img_arr
70
-
71
- def predict_image_cnn(img_tensor, threshold=0.5):
72
- prob = cnn_model.predict(img_tensor)[0][0]
73
- label = "Stroke Detected" if prob >= threshold else "Normal Brain"
74
- return label, float(prob)
75
-
76
- # ==========================================
77
- # ENDPOINT 1: STRUCTURED DATA ML MODEL
78
- # ==========================================
79
- @app.post("/stroke-predict-struct", response_model=StrokeOutput)
80
- def predict_stroke_struct(data: StrokeInput):
81
-
82
- df = pd.DataFrame([data.dict()])
83
-
84
- # Feature Engineering
85
- df['age_glu_interaction'] = df['age'] * df['avg_glucose_level']
86
- df['ht_hd_score'] = df['hypertension'] + df['heart_disease']
87
-
88
- df['work_type_simplified'] = df['work_type'].replace({
89
- 'children': 'No_Work',
90
- 'Never_worked': 'No_Work',
91
- 'Private': 'Private',
92
- 'Self-employed': 'Self_Employed',
93
- 'Govt_job': 'Govt'
94
- })
95
-
96
- df['smoke_simplified'] = df['smoking_status'].replace({
97
- 'formerly smoked': 'Former',
98
- 'never smoked': 'Never',
99
- 'smokes': 'Smoker',
100
- 'Unknown': 'Unknown'
101
- })
102
-
103
- df['glucose_bin'] = pd.cut(
104
- df['avg_glucose_level'],
105
- bins=[0, 100, 140, np.inf],
106
- labels=['Normal', 'Prediabetic', 'High']
107
- )
108
-
109
- selected_features = [
110
- 'age','avg_glucose_level','bmi','age_glu_interaction',
111
- 'hypertension','heart_disease','ht_hd_score',
112
- 'gender','ever_married','Residence_type',
113
- 'work_type_simplified','smoke_simplified','glucose_bin'
114
- ]
115
-
116
- df = df[selected_features]
117
-
118
- processed = preprocessor.transform(df)
119
- prob = logistic_model.predict_proba(processed)[0][1]
120
- pred = logistic_model.predict(processed)[0]
121
-
122
- return {
123
- "stroke_prediction": int(pred),
124
- "stroke_probability": float(round(prob, 4))
125
- }
126
-
127
- # ==========================================
128
- # ENDPOINT 2: MRI IMAGE CNN MODEL
129
- # ==========================================
130
- @app.post("/stroke-predict-image")
131
- async def predict_stroke_image(file: UploadFile = File(...)):
132
- image_bytes = await file.read()
133
- img_tensor = preprocess_image(image_bytes)
134
-
135
- label, prob = predict_image_cnn(img_tensor)
136
-
137
- return JSONResponse({
138
- "filename": file.filename,
139
- "prediction": label,
140
- "confidence_score": float(round(prob, 4))
141
  })
 
1
+ # ==========================================
2
+ # IMPORTS
3
+ # ==========================================
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from pydantic import BaseModel
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ import pandas as pd
9
+ import numpy as np
10
+ import joblib
11
+ import tensorflow as tf
12
+ from PIL import Image
13
+ import io
14
+
15
+ # ==========================================
16
+ # INITIALIZE APP
17
+ # ==========================================
18
+ app = FastAPI(
19
+ title="Stroke Detection API (CT + Clinical Data)",
20
+ description="Deep Learning (DenseNet121) + ML Logistic Regression",
21
+ version="2.0"
22
+ )
23
+
24
+ # CORS setup
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # ==========================================
34
+ # LOAD MODELS
35
+ # ==========================================
36
+ logistic_model = joblib.load("stroke_logistic_regression_model.pkl")
37
+ preprocessor = joblib.load("preprocessor.pkl")
38
+ cnn_model = tf.keras.models.load_model("dense_final_finetuned.keras")
39
+
40
+ IMG_SIZE = (224, 224)
41
+
42
+ # ==========================================
43
+ # Pydantic Models
44
+ # ==========================================
45
+ class StrokeInput(BaseModel):
46
+ age: float
47
+ avg_glucose_level: float
48
+ bmi: float
49
+ hypertension: int
50
+ heart_disease: int
51
+ gender: str
52
+ ever_married: str
53
+ Residence_type: str
54
+ work_type: str
55
+ smoking_status: str
56
+
57
+ class StrokeOutput(BaseModel):
58
+ stroke_prediction: int
59
+ stroke_probability: float
60
+
61
+ # ==========================================
62
+ # HELPER FUNCTIONS
63
+ # ==========================================
64
+ def preprocess_image(image_bytes):
65
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
66
+ img = img.resize(IMG_SIZE)
67
+ img_arr = tf.keras.preprocessing.image.img_to_array(img) / 255.0
68
+ img_arr = np.expand_dims(img_arr, axis=0)
69
+ return img_arr
70
+
71
+ def predict_image_cnn(img_tensor, threshold=0.5):
72
+ prob = cnn_model.predict(img_tensor)[0][0]
73
+ label = "Stroke Detected" if prob >= threshold else "Normal Brain"
74
+ return label, float(prob)
75
+
76
+ # ==========================================
77
+ # ENDPOINT 1: STRUCTURED DATA ML MODEL
78
+ # ==========================================
79
+ @app.post("/stroke-predict-struct", response_model=StrokeOutput)
80
+ def predict_stroke_struct(data: StrokeInput):
81
+
82
+ df = pd.DataFrame([data.dict()])
83
+
84
+ # Feature Engineering
85
+ df['age_glu_interaction'] = df['age'] * df['avg_glucose_level']
86
+ df['ht_hd_score'] = df['hypertension'] + df['heart_disease']
87
+
88
+ df['work_type_simplified'] = df['work_type'].replace({
89
+ 'children': 'No_Work',
90
+ 'Never_worked': 'No_Work',
91
+ 'Private': 'Private',
92
+ 'Self-employed': 'Self_Employed',
93
+ 'Govt_job': 'Govt'
94
+ })
95
+
96
+ df['smoke_simplified'] = df['smoking_status'].replace({
97
+ 'formerly smoked': 'Former',
98
+ 'never smoked': 'Never',
99
+ 'smokes': 'Smoker',
100
+ 'Unknown': 'Unknown'
101
+ })
102
+
103
+ df['glucose_bin'] = pd.cut(
104
+ df['avg_glucose_level'],
105
+ bins=[0, 100, 140, np.inf],
106
+ labels=['Normal', 'Prediabetic', 'High']
107
+ )
108
+
109
+ selected_features = [
110
+ 'age','avg_glucose_level','bmi','age_glu_interaction',
111
+ 'hypertension','heart_disease','ht_hd_score',
112
+ 'gender','ever_married','Residence_type',
113
+ 'work_type_simplified','smoke_simplified','glucose_bin'
114
+ ]
115
+
116
+ df = df[selected_features]
117
+
118
+ processed = preprocessor.transform(df)
119
+ prob = logistic_model.predict_proba(processed)[0][1]
120
+ pred = logistic_model.predict(processed)[0]
121
+
122
+ return {
123
+ "stroke_prediction": int(pred),
124
+ "stroke_probability": float(round(prob, 4))
125
+ }
126
+
127
+ # ==========================================
128
+ # ENDPOINT 2: MRI IMAGE CNN MODEL
129
+ # ==========================================
130
+ @app.post("/stroke-predict-image")
131
+ async def predict_stroke_image(file: UploadFile = File(...)):
132
+ image_bytes = await file.read()
133
+ img_tensor = preprocess_image(image_bytes)
134
+
135
+ label, prob = predict_image_cnn(img_tensor)
136
+
137
+ return JSONResponse({
138
+ "filename": file.filename,
139
+ "prediction": label,
140
+ "confidence_score": float(round(prob, 4))
141
  })