Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -72,8 +72,9 @@ def predict():
|
|
| 72 |
global model
|
| 73 |
print("Received predict request at /predict endpoint")
|
| 74 |
|
| 75 |
-
# Check model file
|
| 76 |
-
model_path = 'stroke_prediction_model.pkl'
|
|
|
|
| 77 |
if not os.path.exists(model_path):
|
| 78 |
print(f"Model file not found at {model_path}")
|
| 79 |
return jsonify({'success': False, 'error': f'Model file not found at {model_path}'}), 500
|
|
@@ -97,17 +98,17 @@ def predict():
|
|
| 97 |
input_data = data['input']
|
| 98 |
print(f"Processed input data: {input_data}")
|
| 99 |
|
| 100 |
-
# Validate and convert input data
|
| 101 |
features = {
|
| 102 |
-
'age': float(input_data.get('age', 0)),
|
| 103 |
'gender': input_data.get('gender', 'Unknown'),
|
| 104 |
'ever_married': input_data.get('ever_married', 'No'),
|
| 105 |
'residence_type': input_data.get('residence_type', 'Unknown'),
|
| 106 |
'work_type': input_data.get('work_type', 'Unknown'),
|
| 107 |
-
'hypertension': int(input_data.get('hypertension', 0)),
|
| 108 |
-
'heart_disease': int(input_data.get('heart_disease', 0)),
|
| 109 |
-
'avg_glucose_level': float(input_data.get('avg_glucose_level', 0)),
|
| 110 |
-
'bmi': float(input_data.get('bmi', 0)),
|
| 111 |
'smoking_status': input_data.get('smoking_status', 'Unknown')
|
| 112 |
}
|
| 113 |
print(f"Converted features: {features}")
|
|
@@ -137,9 +138,13 @@ def predict():
|
|
| 137 |
print(f"Aligned DataFrame columns: {df.columns.tolist()}")
|
| 138 |
|
| 139 |
# Make prediction
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# Determine contributing factors
|
| 145 |
contributing_factors = {
|
|
@@ -150,15 +155,14 @@ def predict():
|
|
| 150 |
}
|
| 151 |
print(f"Contributing factors: {contributing_factors}")
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
"""
|
| 155 |
try:
|
| 156 |
conn = sqlite3.connect('/tmp/submissions.db')
|
| 157 |
c = conn.cursor()
|
| 158 |
c.execute('''
|
| 159 |
INSERT INTO submissions (
|
| 160 |
timestamp, age, gender, ever_married, residence_type, work_type,
|
| 161 |
-
hypertension, heart_disease, avg_glucose_level, bmi,
|
| 162 |
probability, risk_level
|
| 163 |
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 164 |
''', (
|
|
@@ -179,10 +183,9 @@ def predict():
|
|
| 179 |
conn.commit()
|
| 180 |
print("Data successfully written to database")
|
| 181 |
except Exception as db_error:
|
| 182 |
-
print(f"Database write error: {str(db_error)}")
|
| 183 |
finally:
|
| 184 |
conn.close()
|
| 185 |
-
"""
|
| 186 |
|
| 187 |
# Return prediction result
|
| 188 |
return jsonify({
|
|
@@ -192,7 +195,7 @@ def predict():
|
|
| 192 |
'contributingFactors': contributing_factors
|
| 193 |
})
|
| 194 |
except Exception as e:
|
| 195 |
-
print(f"
|
| 196 |
return jsonify({'success': False, 'error': str(e)}), 500
|
| 197 |
|
| 198 |
@app.route('/admin/report', methods=['GET'])
|
|
|
|
| 72 |
global model
|
| 73 |
print("Received predict request at /predict endpoint")
|
| 74 |
|
| 75 |
+
# Check and resolve model file path
|
| 76 |
+
model_path = os.path.join(os.getcwd(), 'stroke_prediction_model.pkl')
|
| 77 |
+
print(f"Checking model file at: {model_path}")
|
| 78 |
if not os.path.exists(model_path):
|
| 79 |
print(f"Model file not found at {model_path}")
|
| 80 |
return jsonify({'success': False, 'error': f'Model file not found at {model_path}'}), 500
|
|
|
|
| 98 |
input_data = data['input']
|
| 99 |
print(f"Processed input data: {input_data}")
|
| 100 |
|
| 101 |
+
# Validate and convert input data with fallback
|
| 102 |
features = {
|
| 103 |
+
'age': float(input_data.get('age', 0)) if input_data.get('age') else 0,
|
| 104 |
'gender': input_data.get('gender', 'Unknown'),
|
| 105 |
'ever_married': input_data.get('ever_married', 'No'),
|
| 106 |
'residence_type': input_data.get('residence_type', 'Unknown'),
|
| 107 |
'work_type': input_data.get('work_type', 'Unknown'),
|
| 108 |
+
'hypertension': int(input_data.get('hypertension', 0)) if input_data.get('hypertension') else 0,
|
| 109 |
+
'heart_disease': int(input_data.get('heart_disease', 0)) if input_data.get('heart_disease') else 0,
|
| 110 |
+
'avg_glucose_level': float(input_data.get('avg_glucose_level', 0)) if input_data.get('avg_glucose_level') else 0,
|
| 111 |
+
'bmi': float(input_data.get('bmi', 0)) if input_data.get('bmi') else 0,
|
| 112 |
'smoking_status': input_data.get('smoking_status', 'Unknown')
|
| 113 |
}
|
| 114 |
print(f"Converted features: {features}")
|
|
|
|
| 138 |
print(f"Aligned DataFrame columns: {df.columns.tolist()}")
|
| 139 |
|
| 140 |
# Make prediction
|
| 141 |
+
try:
|
| 142 |
+
probability = model.predict_proba(df)[0][1] * 100
|
| 143 |
+
risk_prediction = "Stroke Risk" if probability > 50 else "No Stroke Risk"
|
| 144 |
+
print(f"Prediction result: probability={probability}%, prediction={risk_prediction}")
|
| 145 |
+
except Exception as pred_error:
|
| 146 |
+
print(f"Prediction error: {str(pred_error)}")
|
| 147 |
+
return jsonify({'success': False, 'error': f'Prediction failed: {str(pred_error)}'}), 500
|
| 148 |
|
| 149 |
# Determine contributing factors
|
| 150 |
contributing_factors = {
|
|
|
|
| 155 |
}
|
| 156 |
print(f"Contributing factors: {contributing_factors}")
|
| 157 |
|
| 158 |
+
# Attempt to store in database with fallback
|
|
|
|
| 159 |
try:
|
| 160 |
conn = sqlite3.connect('/tmp/submissions.db')
|
| 161 |
c = conn.cursor()
|
| 162 |
c.execute('''
|
| 163 |
INSERT INTO submissions (
|
| 164 |
timestamp, age, gender, ever_married, residence_type, work_type,
|
| 165 |
+
hypertension, heart_disease, avg_glucose_level, bmi, smoking_status,
|
| 166 |
probability, risk_level
|
| 167 |
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 168 |
''', (
|
|
|
|
| 183 |
conn.commit()
|
| 184 |
print("Data successfully written to database")
|
| 185 |
except Exception as db_error:
|
| 186 |
+
print(f"Database write error (non-critical): {str(db_error)}")
|
| 187 |
finally:
|
| 188 |
conn.close()
|
|
|
|
| 189 |
|
| 190 |
# Return prediction result
|
| 191 |
return jsonify({
|
|
|
|
| 195 |
'contributingFactors': contributing_factors
|
| 196 |
})
|
| 197 |
except Exception as e:
|
| 198 |
+
print(f"Unexpected error during prediction: {str(e)}")
|
| 199 |
return jsonify({'success': False, 'error': str(e)}), 500
|
| 200 |
|
| 201 |
@app.route('/admin/report', methods=['GET'])
|