Abdullah1211 commited on
Commit
14a47f3
·
verified ·
1 Parent(s): 147a58d

Upload 11 files

Browse files
Files changed (4) hide show
  1. Dockerfile +7 -24
  2. app.py +205 -411
  3. model.joblib +2 -2
  4. requirements.txt +8 -8
Dockerfile CHANGED
@@ -1,29 +1,12 @@
1
- FROM python:3.10-slim
2
 
3
- WORKDIR /app
4
-
5
- # Install build dependencies
6
- RUN apt-get update && apt-get install -y \
7
- build-essential \
8
- && rm -rf /var/lib/apt/lists/*
9
-
10
- # Create a non-root user
11
- RUN useradd -m -u 1000 user
12
- USER user
13
- ENV HOME=/home/user \
14
- PATH=/home/user/.local/bin:$PATH
15
 
16
- # Copy requirements file first for better layer caching
17
- COPY --chown=user:user requirements.txt .
18
-
19
- # Install required packages from requirements.txt
20
- RUN pip install --no-cache-dir --user -r requirements.txt
21
 
22
- # Make sure python-multipart is installed - use quotes to prevent shell issues
23
- RUN pip install --no-cache-dir --user "python-multipart>=0.0.6"
24
 
25
- # Copy application files
26
- COPY --chown=user:user . .
27
 
28
- # Run the application
29
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
1
 
2
+ FROM python:3.10
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ WORKDIR /app
 
 
 
 
5
 
6
+ COPY requirements.txt .
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
 
9
+ COPY model.joblib .
10
+ COPY app.py .
11
 
12
+ CMD ["python", "app.py"]
 
app.py CHANGED
@@ -1,447 +1,241 @@
1
- try:
2
- import multipart
3
- print("python-multipart is installed: ", multipart.__version__)
4
- except ImportError:
5
- print("python-multipart is NOT installed. Installing now...")
6
- import subprocess
7
- subprocess.check_call(["pip", "install", "python-multipart"])
8
- print("python-multipart has been installed")
9
 
10
- from fastapi import FastAPI, Request, HTTPException, Form
11
- from fastapi.middleware.cors import CORSMiddleware
12
- import numpy as np
13
  import joblib
14
- import os
15
- from typing import Optional
16
- import json
17
  import pandas as pd
18
- from sklearn.ensemble import RandomForestClassifier
19
- from sklearn.preprocessing import StandardScaler, OneHotEncoder
 
 
 
20
  import time
21
- import traceback
22
-
23
- app = FastAPI()
24
-
25
- # Add CORS middleware
26
- app.add_middleware(
27
- CORSMiddleware,
28
- allow_origins=["*"], # Allows all origins
29
- allow_credentials=True,
30
- allow_methods=["*"], # Allows all methods
31
- allow_headers=["*"], # Allows all headers
32
- )
33
-
34
- # Risk categories
35
- RISK_CATEGORIES = {
36
- 'Very Low Risk': 0.1,
37
- 'Low Risk': 0.2,
38
- 'Moderate Risk': 0.4,
39
- 'High Risk': 0.6,
40
- 'Very High Risk': 0.8
41
- }
42
-
43
- # Feature importance dictionary based on medical knowledge
44
- FEATURE_IMPORTANCE = {
45
- 'age': 0.20,
46
- 'hypertension': 0.15,
47
- 'heart_disease': 0.15,
48
- 'avg_glucose_level': 0.12,
49
- 'bmi': 0.10,
50
- 'smoking_status': 0.10,
51
- 'gender': 0.08,
52
- 'work_type': 0.05,
53
- 'Residence_type': 0.03,
54
- 'ever_married': 0.02
55
- }
56
 
57
- # Load the model with better error handling
58
  print("Loading model...")
 
 
 
 
 
 
59
  try:
60
- model_path = os.path.join(os.path.dirname(__file__), "model.joblib")
61
- print(f"Model path: {model_path}")
62
-
63
- if not os.path.exists(model_path):
64
- raise FileNotFoundError(f"Model file not found at: {model_path}")
65
-
66
- print(f"Model file exists. Size: {os.path.getsize(model_path) / 1024:.2f} KB")
67
-
68
- model_data = joblib.load(model_path)
69
  print("Model loaded successfully!")
70
 
71
- # Validate model data structure
72
- if not isinstance(model_data, dict):
73
- print(f"Warning: Model data is not a dictionary as expected. Type: {type(model_data)}")
74
- model_loaded = False
75
- rf_model = None
76
- preprocessor = None
77
- encoded_cols = []
78
- numeric_cols = []
79
- else:
80
- rf_model = model_data.get('model')
81
- encoded_cols = model_data.get('encoded_cols', [])
82
- numeric_cols = model_data.get('numeric_cols', [])
83
- preprocessor = model_data.get('preprocessor')
84
-
85
- print(f"Model details: Type: {type(rf_model)}")
86
- print(f"Features: {len(numeric_cols)} numeric features, {len(encoded_cols)} encoded features")
87
-
88
- if rf_model is None:
89
- print("Warning: Model is None")
90
- model_loaded = False
91
- else:
92
- # Test if the model has the expected methods
93
- if hasattr(rf_model, 'predict_proba'):
94
- print("Model has predict_proba method: ✅")
95
- model_loaded = True
96
- else:
97
- print("Warning: Model does not have predict_proba method")
98
- model_loaded = False
99
- except Exception as e:
100
- print(f"Error loading model: {str(e)}")
101
- traceback_str = traceback.format_exc()
102
- print(f"Traceback: {traceback_str}")
103
- rf_model = None
104
- preprocessor = None
105
- encoded_cols = []
106
- numeric_cols = []
107
- model_loaded = False
108
-
109
- def get_risk_level(probability):
110
- """Get risk level based on probability score"""
111
- for category, threshold in RISK_CATEGORIES.items():
112
- if probability < threshold:
113
- return category
114
- return "Very High Risk"
115
-
116
- def preprocess_without_pandas(data):
117
- """Preprocess input data without using pandas"""
118
- # Handle numeric features
119
- numeric_features = []
120
- for col in numeric_cols:
121
- if col == 'age':
122
- numeric_features.append(float(data.get('age', 0)))
123
- elif col == 'avg_glucose_level':
124
- numeric_features.append(float(data.get('avg_glucose_level', 0)))
125
- elif col == 'bmi':
126
- numeric_features.append(float(data.get('bmi', 0)))
127
-
128
- # Create input array for categorical processing
129
- categorical_input = np.array([[
130
- data.get('gender', 'Male'),
131
- data.get('hypertension', 0),
132
- data.get('heart_disease', 0),
133
- data.get('ever_married', 'No'),
134
- data.get('work_type', 'Private'),
135
- data.get('Residence_type', 'Urban'),
136
- data.get('smoking_status', 'never smoked')
137
- ]], dtype=object)
138
 
139
- # Apply preprocessing
140
- if preprocessor is not None:
141
- try:
142
- encoded_features = preprocessor.transform(categorical_input)
143
- # Combine numeric and encoded features
144
- features = np.concatenate([numeric_features, encoded_features.flatten()])
145
- return features.reshape(1, -1)
146
- except Exception as e:
147
- print(f"Error in preprocessing: {str(e)}")
148
-
149
- # Return none if preprocessing fails
150
- return None
151
-
152
- def predict_with_model(features):
153
- """Make prediction using the loaded model"""
154
- try:
155
- if rf_model is not None and features is not None:
156
- start_time = time.time()
157
- probabilities = rf_model.predict_proba(features)
158
- end_time = time.time()
159
-
160
- stroke_probability = probabilities[0, 1] # Class 1 probability (stroke)
161
- risk_level = get_risk_level(stroke_probability)
162
- execution_time_ms = (end_time - start_time) * 1000
163
-
164
- # Get feature importances for explaining prediction
165
- # Note: This uses global feature importance rather than
166
- # instance-specific importance for simplicity
167
- important_features = []
168
- if hasattr(rf_model, 'feature_importances_'):
169
- feature_names = numeric_cols + encoded_cols
170
- importances = rf_model.feature_importances_
171
-
172
- # Get top 5 most important features
173
- imp_indices = np.argsort(importances)[-5:][::-1]
174
- for i in imp_indices:
175
- if i < len(feature_names):
176
- important_features.append({
177
- 'feature': feature_names[i],
178
- 'importance': float(importances[i])
179
- })
180
-
181
- return {
182
- 'probability': stroke_probability,
183
- 'risk_level': risk_level,
184
- 'prediction_success': True,
185
- 'execution_time_ms': execution_time_ms,
186
- 'important_features': important_features
187
- }
188
- except Exception as e:
189
- print(f"Error in model prediction: {str(e)}")
190
 
191
- return {
192
- 'probability': None,
193
- 'risk_level': None,
194
- 'prediction_success': False,
195
- 'execution_time_ms': 0,
196
- 'important_features': []
197
- }
198
-
199
- def get_top_risk_factors(data):
200
- """Get top risk factors based on data and medical knowledge"""
201
- risk_factors = []
202
-
203
- # Calculate risk contribution for each field
204
- # Age risk
205
- if 'age' in data:
206
- age = float(data.get('age', 0))
207
- if age > 75:
208
- risk_factors.append({'factor': 'Advanced Age (>75)', 'contribution': 0.20})
209
- elif age > 65:
210
- risk_factors.append({'factor': 'Elderly Age (>65)', 'contribution': 0.15})
211
- elif age > 55:
212
- risk_factors.append({'factor': 'Higher Age (>55)', 'contribution': 0.10})
213
-
214
- # Major health risk factors
215
- if data.get('hypertension', 0) == 1:
216
- risk_factors.append({'factor': 'Hypertension', 'contribution': 0.15})
217
-
218
- if data.get('heart_disease', 0) == 1:
219
- risk_factors.append({'factor': 'Heart Disease', 'contribution': 0.15})
220
-
221
- # Blood glucose levels
222
- if 'avg_glucose_level' in data:
223
- glucose = float(data.get('avg_glucose_level', 0))
224
- if glucose > 200:
225
- risk_factors.append({'factor': 'Very High Blood Glucose (>200)', 'contribution': 0.12})
226
- elif glucose > 140:
227
- risk_factors.append({'factor': 'High Blood Glucose (>140)', 'contribution': 0.10})
228
-
229
- # BMI-related risk
230
- if 'bmi' in data:
231
- bmi = float(data.get('bmi', 0))
232
- if bmi > 30:
233
- risk_factors.append({'factor': 'Obesity (BMI > 30)', 'contribution': 0.10})
234
- elif bmi > 25:
235
- risk_factors.append({'factor': 'Overweight (BMI > 25)', 'contribution': 0.07})
236
-
237
- # Smoking status
238
- if data.get('smoking_status', '') == 'smokes':
239
- risk_factors.append({'factor': 'Current Smoker', 'contribution': 0.10})
240
- elif data.get('smoking_status', '') == 'formerly smoked':
241
- risk_factors.append({'factor': 'Former Smoker', 'contribution': 0.05})
242
-
243
- # Sort by contribution (highest first)
244
- risk_factors.sort(key=lambda x: x['contribution'], reverse=True)
245
-
246
- return risk_factors
247
 
248
- def fallback_prediction(data):
249
- """Fallback prediction when model fails - improved version"""
250
- # Get top risk factors
251
- risk_factors = get_top_risk_factors(data)
252
-
253
- # Calculate total risk score
254
- total_risk_score = sum(rf['contribution'] for rf in risk_factors)
255
-
256
- # Apply sigmoid function to create probability curve
257
- # This creates more reasonable probability distribution
258
- if total_risk_score > 0:
259
- probability = 1 / (1 + np.exp(-5 * (total_risk_score - 0.5)))
260
- else:
261
- probability = 0.05 # Baseline risk
262
-
263
- return probability, get_risk_level(probability), risk_factors
264
 
265
- @app.get("/")
266
- async def root():
267
- """Root endpoint for documentation and health check"""
268
- return {
269
- "message": "Stroke Prediction API is running",
270
- "model_loaded": model_loaded,
271
- "usage": "Send a POST request to / with patient data",
272
- "example": {
273
- "gender": "Male",
274
- "age": 67,
275
- "hypertension": 1,
276
- "heart_disease": 0,
277
- "ever_married": "Yes",
278
- "work_type": "Private",
279
- "Residence_type": "Urban",
280
- "avg_glucose_level": 228.69,
281
- "bmi": 36.6,
282
- "smoking_status": "formerly smoked"
283
- },
284
- "api_endpoints": {
285
- "standard": "POST /",
286
- "form_data": "POST /api/predict"
287
- },
288
- "model_version": "1.0",
289
- "last_updated": "2023-11-15"
290
- }
291
 
292
- @app.post("/")
293
- async def predict(request: Request):
294
- """Make stroke prediction based on input data"""
295
- try:
296
- start_time = time.time()
297
- data = await request.json()
298
-
299
- # Try using the model first
300
- if model_loaded:
301
- # Preprocess the data
302
- features = preprocess_without_pandas(data)
303
-
304
- # Make prediction
305
- model_result = predict_with_model(features)
306
-
307
- if model_result['prediction_success']:
308
- # Calculate top risk factors for explanation
309
- risk_factors = get_top_risk_factors(data)
310
-
311
- end_time = time.time()
312
- execution_time_ms = (end_time - start_time) * 1000
313
-
314
- return {
315
- "probability": float(model_result['probability']),
316
- "prediction": model_result['risk_level'],
317
- "stroke_prediction": int(model_result['probability'] > 0.5),
318
- "risk_factors": [rf['factor'] for rf in risk_factors],
319
- "important_features": model_result['important_features'],
320
- "execution_time_ms": execution_time_ms,
321
- "using_model": True,
322
- "model_version": "1.0"
323
- }
324
-
325
- # Use fallback if model fails or isn't loaded
326
- probability, risk_level, risk_factors = fallback_prediction(data)
327
- end_time = time.time()
328
- execution_time_ms = (end_time - start_time) * 1000
329
-
330
- return {
331
- "probability": float(probability),
332
- "prediction": risk_level,
333
- "stroke_prediction": int(probability > 0.5),
334
- "risk_factors": [rf['factor'] for rf in risk_factors],
335
- "using_model": False,
336
- "execution_time_ms": execution_time_ms,
337
- "model_version": "fallback-1.0"
338
- }
339
-
340
- except Exception as e:
341
- raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}")
342
 
 
343
  @app.post("/api/predict")
344
- async def predict_from_form(
345
  gender: Optional[str] = Form(None),
346
  age: Optional[float] = Form(None),
347
  hypertension: Optional[int] = Form(None),
348
  heart_disease: Optional[int] = Form(None),
349
- heartDisease: Optional[int] = Form(None), # Alternative field name from frontend
350
  ever_married: Optional[str] = Form(None),
351
- everMarried: Optional[str] = Form(None), # Alternative field name from frontend
352
  work_type: Optional[str] = Form(None),
353
- workType: Optional[str] = Form(None), # Alternative field name from frontend
354
  Residence_type: Optional[str] = Form(None),
355
- residenceType: Optional[str] = Form(None), # Alternative field name from frontend
356
  avg_glucose_level: Optional[float] = Form(None),
357
- avgGlucoseLevel: Optional[float] = Form(None), # Alternative field name from frontend
358
  bmi: Optional[float] = Form(None),
359
- smoking_status: Optional[str] = Form(None),
360
- smokingStatus: Optional[str] = Form(None), # Alternative field name from frontend
361
- formDataJson: Optional[str] = Form(None)
362
  ):
363
- """API endpoint that accepts form data or JSON string for prediction"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  try:
365
- # Try to use JSON data if provided
366
- if formDataJson:
367
- try:
368
- raw_data = json.loads(formDataJson)
369
- print(f"Received JSON data: {raw_data}")
370
- except json.JSONDecodeError:
371
- raise HTTPException(status_code=400, detail="Invalid JSON in formDataJson")
372
- else:
373
- # Build data dict from form fields with alternative field names
374
- raw_data = {
375
- "gender": gender,
376
- "age": age,
377
- "hypertension": hypertension,
378
- "heart_disease": heart_disease if heart_disease is not None else heartDisease,
379
- "ever_married": ever_married if ever_married is not None else everMarried,
380
- "work_type": work_type if work_type is not None else workType,
381
- "Residence_type": Residence_type if Residence_type is not None else residenceType,
382
- "avg_glucose_level": avg_glucose_level if avg_glucose_level is not None else avgGlucoseLevel,
383
- "bmi": bmi,
384
- "smoking_status": smoking_status if smoking_status is not None else smokingStatus
385
- }
386
- print(f"Received form data: {raw_data}")
387
 
388
- # Map frontend field names to model field names if needed
389
- data = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
- # Handle potential camelCase to snake_case conversions and field name differences
392
- field_mappings = {
393
- "gender": ["gender"],
394
- "age": ["age"],
395
- "hypertension": ["hypertension"],
396
- "heart_disease": ["heart_disease", "heartDisease"],
397
- "ever_married": ["ever_married", "everMarried"],
398
- "work_type": ["work_type", "workType"],
399
- "Residence_type": ["Residence_type", "residenceType"],
400
- "avg_glucose_level": ["avg_glucose_level", "avgGlucoseLevel"],
401
- "bmi": ["bmi"],
402
- "smoking_status": ["smoking_status", "smokingStatus"]
403
  }
404
 
405
- # Fill data with the first available value from mappings
406
- for model_field, possible_keys in field_mappings.items():
407
- for key in possible_keys:
408
- if key in raw_data and raw_data[key] is not None:
409
- data[model_field] = raw_data[key]
410
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
- # Use defaults for missing fields
413
- if model_field not in data or data[model_field] is None:
414
- if model_field in ["age", "avg_glucose_level", "bmi"]:
415
- data[model_field] = 0
416
- elif model_field in ["hypertension", "heart_disease"]:
417
- data[model_field] = 0
418
- elif model_field == "gender":
419
- data[model_field] = "Male"
420
- elif model_field == "ever_married":
421
- data[model_field] = "No"
422
- elif model_field == "work_type":
423
- data[model_field] = "Private"
424
- elif model_field == "Residence_type":
425
- data[model_field] = "Urban"
426
- elif model_field == "smoking_status":
427
- data[model_field] = "never smoked"
428
 
429
- print(f"Processed data for prediction: {data}")
430
-
431
- # Create a request object with our data
432
- request = Request(scope={"type": "http"})
433
- request._json = data
 
 
 
 
 
 
 
 
 
 
 
434
 
435
- # Pass to main prediction endpoint
436
- result = await predict(request)
437
- print(f"Prediction result: {result}")
438
- return result
439
- except Exception as e:
440
- error_traceback = traceback.format_exc()
441
- print(f"Error processing request: {str(e)}")
442
- print(f"Traceback: {error_traceback}")
443
- raise HTTPException(status_code=400, detail=f"Error processing request: {str(e)}")
 
 
 
 
 
 
 
444
 
 
445
  if __name__ == "__main__":
446
- import uvicorn
447
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
1
 
 
 
 
2
  import joblib
 
 
 
3
  import pandas as pd
4
+ import numpy as np
5
+ from fastapi import FastAPI, Form, File, UploadFile, Request
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
  import time
10
+ import json
11
+ from typing import Optional, List, Union
12
+ import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Load the trained model
15
  print("Loading model...")
16
+ model_path = "/app/model.joblib"
17
+ import os
18
+ print(f"Model path: {model_path}")
19
+ print(f"Model file exists: {os.path.exists(model_path)}")
20
+ print(f"Model file size: {os.path.getsize(model_path) / 1024:.2f} KB")
21
+
22
  try:
23
+ model_info = joblib.load(model_path)
 
 
 
 
 
 
 
 
24
  print("Model loaded successfully!")
25
 
26
+ # Access model components
27
+ pipeline = model_info['model']
28
+ model = pipeline.named_steps['classifier']
29
+ print(f"Model details: Type: {type(model)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # Get preprocessing info
32
+ numeric_cols = model_info['numeric_cols']
33
+ categorical_cols = model_info['encoded_cols']
34
+ print(f"Features: {len(numeric_cols)} numeric features, {len(categorical_cols)} encoded features")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Verify model has predict_proba
37
+ has_predict_proba = hasattr(model, 'predict_proba')
38
+ print(f"Model has predict_proba method: {'Yes' if has_predict_proba else 'No'}")
39
+ except Exception as e:
40
+ print(f"Error loading model: {e}")
41
+ model_info = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Initialize FastAPI
44
+ app = FastAPI(title="Stroke Prediction Model API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # Add CORS middleware
47
+ app.add_middleware(
48
+ CORSMiddleware,
49
+ allow_origins=["*"],
50
+ allow_credentials=True,
51
+ allow_methods=["*"],
52
+ allow_headers=["*"],
53
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Check if python-multipart is installed
56
+ try:
57
+ import multipart
58
+ print("python-multipart is installed: ", multipart.__version__)
59
+ except ImportError:
60
+ print("python-multipart is NOT installed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Define prediction endpoints
63
  @app.post("/api/predict")
64
+ async def predict_stroke(
65
  gender: Optional[str] = Form(None),
66
  age: Optional[float] = Form(None),
67
  hypertension: Optional[int] = Form(None),
68
  heart_disease: Optional[int] = Form(None),
 
69
  ever_married: Optional[str] = Form(None),
 
70
  work_type: Optional[str] = Form(None),
 
71
  Residence_type: Optional[str] = Form(None),
 
72
  avg_glucose_level: Optional[float] = Form(None),
 
73
  bmi: Optional[float] = Form(None),
74
+ smoking_status: Optional[str] = Form(None)
 
 
75
  ):
76
+ start_time = time.time()
77
+
78
+ # Log the received data
79
+ form_data = {
80
+ 'gender': gender,
81
+ 'age': age,
82
+ 'hypertension': hypertension,
83
+ 'heart_disease': heart_disease,
84
+ 'ever_married': ever_married,
85
+ 'work_type': work_type,
86
+ 'Residence_type': Residence_type,
87
+ 'avg_glucose_level': avg_glucose_level,
88
+ 'bmi': bmi,
89
+ 'smoking_status': smoking_status
90
+ }
91
+ print("Received form data:", form_data)
92
+
93
+ # Process data and fill default values if needed
94
+ processed_data = {
95
+ 'gender': gender if gender else 'Male',
96
+ 'age': float(age) if age is not None else 0,
97
+ 'hypertension': int(hypertension) if hypertension is not None else 0,
98
+ 'heart_disease': int(heart_disease) if heart_disease is not None else 0,
99
+ 'ever_married': ever_married if ever_married else 'No',
100
+ 'work_type': work_type if work_type else 'Private',
101
+ 'Residence_type': Residence_type if Residence_type else 'Urban',
102
+ 'avg_glucose_level': float(avg_glucose_level) if avg_glucose_level is not None else 0,
103
+ 'bmi': float(bmi) if bmi is not None else 0,
104
+ 'smoking_status': smoking_status if smoking_status else 'never smoked'
105
+ }
106
+ print("Processed data for prediction:", processed_data)
107
+
108
+ # Create a DataFrame from the processed data
109
+ input_df = pd.DataFrame([processed_data])
110
+
111
+ # Prediction with fallback
112
  try:
113
+ if model_info is None:
114
+ raise ValueError("Model not loaded")
115
+
116
+ # Get prediction from model
117
+ prediction_proba = pipeline.predict_proba(input_df)[0][1]
118
+ prediction_binary = pipeline.predict(input_df)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # Calculate risk level
121
+ if prediction_proba < 0.1:
122
+ risk_level = "Very Low Risk"
123
+ elif prediction_proba < 0.3:
124
+ risk_level = "Low Risk"
125
+ elif prediction_proba < 0.6:
126
+ risk_level = "Moderate Risk"
127
+ else:
128
+ risk_level = "High Risk"
129
+
130
+ # Identify risk factors
131
+ risk_factors = []
132
+ if processed_data['hypertension'] == 1:
133
+ risk_factors.append("Hypertension")
134
+ if processed_data['heart_disease'] == 1:
135
+ risk_factors.append("Heart Disease")
136
+ if processed_data['age'] > 65:
137
+ risk_factors.append("Advanced Age (65+)")
138
+ if processed_data['avg_glucose_level'] > 140:
139
+ risk_factors.append("High Blood Glucose (>140)")
140
+ if processed_data['bmi'] > 30:
141
+ risk_factors.append("Obesity (BMI > 30)")
142
+ if processed_data['smoking_status'] == 'formerly smoked':
143
+ risk_factors.append("Former Smoker")
144
+ if processed_data['smoking_status'] == 'smokes':
145
+ risk_factors.append("Current Smoker")
146
 
147
+ # Return results
148
+ result = {
149
+ "probability": float(prediction_proba),
150
+ "prediction": risk_level,
151
+ "stroke_prediction": int(prediction_binary),
152
+ "risk_factors": risk_factors,
153
+ "using_model": True,
154
+ "execution_time_ms": (time.time() - start_time) * 1000,
155
+ "model_version": "stroke-prediction-1.0"
 
 
 
156
  }
157
 
158
+ except Exception as e:
159
+ print("Error in preprocessing:", e)
160
+
161
+ # Fallback risk calculation
162
+ fallback_probability = 0.05 # Default low risk
163
+
164
+ # Increase risk based on known factors
165
+ if processed_data['hypertension'] == 1:
166
+ fallback_probability += 0.1
167
+
168
+ if processed_data['heart_disease'] == 1:
169
+ fallback_probability += 0.1
170
+
171
+ if processed_data['age'] > 65:
172
+ fallback_probability += 0.15
173
+ elif processed_data['age'] > 55:
174
+ fallback_probability += 0.1
175
+
176
+ if processed_data['avg_glucose_level'] > 180:
177
+ fallback_probability += 0.1
178
+ elif processed_data['avg_glucose_level'] > 140:
179
+ fallback_probability += 0.05
180
+
181
+ if processed_data['bmi'] > 30:
182
+ fallback_probability += 0.05
183
+
184
+ if processed_data['smoking_status'] == 'smokes':
185
+ fallback_probability += 0.07
186
+ elif processed_data['smoking_status'] == 'formerly smoked':
187
+ fallback_probability += 0.03
188
+
189
+ # Cap at 80%
190
+ fallback_probability = min(fallback_probability, 0.8)
191
+
192
+ # Determine risk level
193
+ if fallback_probability < 0.1:
194
+ risk_level = "Very Low Risk"
195
+ elif fallback_probability < 0.3:
196
+ risk_level = "Low Risk"
197
+ elif fallback_probability < 0.6:
198
+ risk_level = "Moderate Risk"
199
+ else:
200
+ risk_level = "High Risk"
201
 
202
+ # Threshold for binary prediction
203
+ stroke_prediction = 1 if fallback_probability > 0.5 else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ # Identify risk factors
206
+ risk_factors = []
207
+ if processed_data['hypertension'] == 1:
208
+ risk_factors.append("Hypertension")
209
+ if processed_data['heart_disease'] == 1:
210
+ risk_factors.append("Heart Disease")
211
+ if processed_data['age'] > 65:
212
+ risk_factors.append("Advanced Age (65+)")
213
+ if processed_data['avg_glucose_level'] > 140:
214
+ risk_factors.append("High Blood Glucose (>140)")
215
+ if processed_data['bmi'] > 30:
216
+ risk_factors.append("Obesity (BMI > 30)")
217
+ if processed_data['smoking_status'] == 'formerly smoked':
218
+ risk_factors.append("Former Smoker")
219
+ if processed_data['smoking_status'] == 'smokes':
220
+ risk_factors.append("Current Smoker")
221
 
222
+ result = {
223
+ "probability": fallback_probability,
224
+ "prediction": risk_level,
225
+ "stroke_prediction": stroke_prediction,
226
+ "risk_factors": risk_factors,
227
+ "using_model": False,
228
+ "execution_time_ms": (time.time() - start_time) * 1000,
229
+ "model_version": "fallback-1.0"
230
+ }
231
+
232
+ print("Prediction result:", result)
233
+ return result
234
+
235
+ @app.get("/")
236
+ async def root():
237
+ return {"message": "Stroke Prediction API is running! Use /api/predict for predictions."}
238
 
239
+ # Run the server
240
  if __name__ == "__main__":
241
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
model.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d4f68bdffeff53794d8c886eab70b97b1180e22985a6482869bac19fb60f5a29
3
- size 6057
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:311ed211d2620c9b995f86949f49d53467cb8933678e858f98e671a93c4f4c09
3
+ size 10381670
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- fastapi==0.103.1
2
- uvicorn==0.23.2
3
- python-multipart==0.0.6
4
- numpy==1.25.2
5
- pandas==2.0.3
6
- scikit-learn==1.3.0
7
- joblib==1.3.2
8
- requests==2.31.0
 
1
+
2
+ fastapi>=0.95.1
3
+ uvicorn>=0.22.0
4
+ pandas>=1.5.3
5
+ numpy>=1.23.5
6
+ scikit-learn>=1.2.2
7
+ joblib>=1.2.0
8
+ python-multipart>=0.0.6