Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify, render_template, send_file | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import sqlite3 | |
| import io | |
| from datetime import datetime | |
| import xlsxwriter | |
| from flask_httpauth import HTTPBasicAuth | |
| import os | |
app = Flask(__name__, template_folder='templates')
# Flask itself does not read this key; it only matters if Flask-WTF is in use.
app.config['WTF_CSRF_ENABLED'] = False  # Disable CSRF for testing
auth = HTTPBasicAuth()

# Global variable to hold the model (lazy-loaded on the first prediction).
model = None

# Admin credentials for report access.
# The password is taken from the ADMIN_PASSWORD environment variable so the
# real secret never has to be committed. The placeholder below is only a
# backward-compatible fallback and must not be relied on in a deployment.
users = {
    "admin": os.environ.get("ADMIN_PASSWORD", "your_secure_password"),
}
def verify_password(username, password):
    """Check a username/password pair against the module-level ``users`` map.

    NOTE(review): no ``@auth.verify_password`` decorator is visible on this
    function in this view; confirm it is registered with ``auth``.

    Args:
        username: Account name supplied by the client.
        password: Password supplied by the client.

    Returns:
        ``username`` when the credentials match, otherwise ``None``.
    """
    import hmac

    expected = users.get(username)
    if expected is None or not isinstance(password, str):
        return None
    # Constant-time comparison so response timing does not leak how much of
    # the password matched. Compare bytes: compare_digest rejects non-ASCII str.
    if hmac.compare_digest(password.encode('utf-8'), expected.encode('utf-8')):
        return username
    return None
# Initialize SQLite database
def init_db(db_path='/tmp/submissions.db'):
    """Create the ``submissions`` table if it does not already exist.

    Failures are logged with ``print`` and swallowed, so a database problem
    never prevents the module from being imported.

    Args:
        db_path: Location of the SQLite file. Defaults to the path used by
            the rest of this module.
    """
    print(f"Initializing database at {db_path}")
    # Start as None so the cleanup below is safe when connect() itself fails;
    # otherwise ``finally`` would raise on the unbound name and hide the error.
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        c.execute('''
            CREATE TABLE IF NOT EXISTS submissions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                age REAL,
                gender TEXT,
                ever_married TEXT,
                residence_type TEXT,
                work_type TEXT,
                hypertension INTEGER,
                heart_disease INTEGER,
                avg_glucose_level REAL,
                bmi REAL,
                smoking_status TEXT,
                probability INTEGER,
                risk_level TEXT
            )
        ''')
        conn.commit()
        print("Database initialized successfully")
    except Exception as e:
        print(f"Error initializing database: {str(e)}")
    finally:
        if conn is not None:
            conn.close()


init_db()
def home():
    """Render the ``index.html`` template and return the result.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    print("Serving home page")
    page = render_template('index.html')
    return page
def _parse_features(input_data):
    """Convert the raw ``input`` payload into the typed feature dict.

    Numeric fields that are missing, empty or zero become ``0``; categorical
    fields fall back to ``'Unknown'`` (``'No'`` for ``ever_married``).

    Raises:
        TypeError, ValueError: If a numeric field cannot be converted.
    """
    def number(key, cast):
        raw = input_data.get(key)
        return cast(raw) if raw else 0

    return {
        'age': number('age', float),
        'gender': input_data.get('gender', 'Unknown'),
        'ever_married': input_data.get('ever_married', 'No'),
        'residence_type': input_data.get('residence_type', 'Unknown'),
        'work_type': input_data.get('work_type', 'Unknown'),
        'hypertension': number('hypertension', int),
        'heart_disease': number('heart_disease', int),
        'avg_glucose_level': number('avg_glucose_level', float),
        'bmi': number('bmi', float),
        'smoking_status': input_data.get('smoking_status', 'Unknown'),
    }


def _encode_features(features, expected_columns):
    """Build the single-row model input with exactly ``expected_columns``."""
    categorical_cols = ['gender', 'ever_married', 'residence_type', 'work_type', 'smoking_status']
    # drop_first must stay off here: a one-row frame has a single level per
    # categorical column, so drop_first would remove every dummy column and
    # the model would never see the categorical inputs.
    frame = pd.get_dummies(pd.DataFrame([features]), columns=categorical_cols, dtype=int)
    print(f"DataFrame after get_dummies: {frame.columns.tolist()}")
    # Columns the model does not know (e.g. baseline levels) are dropped;
    # columns it expects but this row did not produce are filled with 0.
    return frame.reindex(columns=expected_columns, fill_value=0)


def _store_submission(features, probability, risk_prediction):
    """Best-effort insert of one submission; failures are logged, not raised."""
    conn = None  # stays None when connect() fails, so cleanup cannot raise
    try:
        conn = sqlite3.connect('/tmp/submissions.db')
        c = conn.cursor()
        c.execute('''
            INSERT INTO submissions (
                timestamp, age, gender, ever_married, residence_type, work_type,
                hypertension, heart_disease, avg_glucose_level, bmi, smoking_status,
                probability, risk_level
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            datetime.now().isoformat(),
            features['age'],
            features['gender'],
            features['ever_married'],
            features['residence_type'],
            features['work_type'],
            features['hypertension'],
            features['heart_disease'],
            features['avg_glucose_level'],
            features['bmi'],
            features['smoking_status'],
            round(probability),
            risk_prediction,
        ))
        conn.commit()
        print("Data successfully written to database")
    except Exception as db_error:
        print(f"Database write error (non-critical): {str(db_error)}")
    finally:
        if conn is not None:
            conn.close()


def predict():
    """Score one submission and return the stroke-risk prediction as JSON.

    Expects a JSON body of the form ``{"input": {...}}`` carrying the fields
    read by ``_parse_features``. The model is loaded from
    ``stroke_prediction_model.pkl`` in the working directory on first use and
    cached in the module-level ``model``.

    NOTE(review): no route decorator is visible on this function in this
    view; the log message suggests it serves ``/predict``. Confirm.

    Returns:
        A JSON response with ``success``, ``prediction``, ``probability``
        (rounded percentage) and ``contributingFactors``; or a JSON error
        with status 400 for malformed input and 500 for model/server errors.
    """
    global model
    print("Received predict request at /predict endpoint")

    # Load model if not already loaded; the file is only needed for that.
    if model is None:
        model_path = os.path.join(os.getcwd(), 'stroke_prediction_model.pkl')
        print(f"Checking model file at: {model_path}")
        if not os.path.exists(model_path):
            print(f"Model file not found at {model_path}")
            return jsonify({'success': False, 'error': f'Model file not found at {model_path}'}), 500
        try:
            # joblib unpickles the file: only ever load a trusted model artifact.
            model = joblib.load(model_path)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            return jsonify({'success': False, 'error': f'Model failed to load: {str(e)}'}), 500

    try:
        # silent=True yields None for a missing/invalid JSON body instead of
        # raising, so bad requests get the 400 below rather than a 500.
        data = request.get_json(silent=True)
        print(f"Received JSON data: {data}")
        if not isinstance(data, dict) or not isinstance(data.get('input'), dict):
            print("Invalid input data format")
            return jsonify({'success': False, 'error': 'Invalid input data'}), 400
        input_data = data['input']
        print(f"Processed input data: {input_data}")

        # Validate and convert input data; bad numbers are a client error.
        try:
            features = _parse_features(input_data)
        except (TypeError, ValueError) as conv_error:
            print(f"Invalid feature value: {str(conv_error)}")
            return jsonify({'success': False, 'error': f'Invalid input data: {str(conv_error)}'}), 400
        print(f"Converted features: {features}")

        # Define expected columns based on model training
        trained_names = getattr(model, 'feature_names_in_', None)
        if trained_names is not None:
            expected_columns = [str(name) for name in trained_names]
        else:
            expected_columns = [
                'age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi',
                'gender_Male', 'gender_Other', 'ever_married_Yes', 'residence_type_Urban',
                'work_type_Govt_job', 'work_type_Never_worked', 'work_type_Private',
                'work_type_Self-employed',
                'smoking_status_formerly smoked', 'smoking_status_never smoked',
                'smoking_status_smokes',
            ]
        print(f"Expected columns: {expected_columns}")

        # One-hot encode and align to the model's column order
        df = _encode_features(features, expected_columns)
        print(f"Aligned DataFrame columns: {df.columns.tolist()}")

        # Make prediction
        try:
            # float() turns the numpy scalar into a plain, JSON-safe number.
            probability = float(model.predict_proba(df)[0][1]) * 100
            risk_prediction = "Stroke Risk" if probability > 50 else "No Stroke Risk"
            print(f"Prediction result: probability={probability}%, prediction={risk_prediction}")
        except Exception as pred_error:
            print(f"Prediction error: {str(pred_error)}")
            return jsonify({'success': False, 'error': f'Prediction failed: {str(pred_error)}'}), 500

        # Determine contributing factors
        contributing_factors = {
            'glucose': features['avg_glucose_level'] > 140,
            'hypertension': features['hypertension'] == 1,
            'heartDisease': features['heart_disease'] == 1,
            'smoking': features['smoking_status'] in ['smokes', 'formerly smoked'],
        }
        print(f"Contributing factors: {contributing_factors}")

        # Persisting is non-critical: a database failure must not fail the request.
        _store_submission(features, probability, risk_prediction)

        # Return prediction result
        return jsonify({
            'success': True,
            'prediction': risk_prediction,
            'probability': round(probability),
            'contributingFactors': contributing_factors,
        })
    except Exception as e:
        print(f"Unexpected error during prediction: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def admin_report(db_path='/tmp/submissions.db'):
    """Build an Excel report of today's submissions.

    Rows whose ``timestamp`` falls on the current local date are read from the
    ``submissions`` table, the 0/1 flags are rendered as ``Yes``/``No``, and
    the result is sent as ``report_<date>.xlsx``.

    NOTE(review): neither a route nor an ``auth.login_required`` decorator is
    visible on this function in this view; confirm the endpoint is protected.

    Args:
        db_path: Location of the SQLite file. Defaults to the path used by
            the rest of this module.

    Returns:
        The spreadsheet as an attachment; ``("No submissions for today.", 200)``
        when there are no rows; or an error message with status 500.
    """
    try:
        today = datetime.now().strftime('%Y-%m-%d')
        conn = sqlite3.connect(db_path)
        try:
            df = pd.read_sql_query(
                'SELECT * FROM submissions WHERE DATE(timestamp) = ?',
                conn,
                params=[today],
            )
        finally:
            # Close even when the query fails, so the handle is never leaked.
            conn.close()
        if df.empty:
            return "No submissions for today.", 200

        df['hypertension'] = df['hypertension'].apply(lambda x: 'Yes' if x == 1 else 'No')
        df['heart_disease'] = df['heart_disease'].apply(lambda x: 'Yes' if x == 1 else 'No')
        # Drop the internal id and fix the column order for the report.
        df = df[[
            'timestamp', 'age', 'gender', 'ever_married', 'residence_type', 'work_type',
            'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi', 'smoking_status',
            'probability', 'risk_level'
        ]]
        df.columns = [
            'Timestamp', 'Age', 'Gender', 'Ever Married', 'Residence Type', 'Work Type',
            'Hypertension', 'Heart Disease', 'Avg Glucose Level', 'BMI', 'Smoking Status',
            'Probability (%)', 'Risk Level'
        ]

        output = io.BytesIO()
        with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
            df.to_excel(writer, sheet_name='Daily Report', index=False)
            worksheet = writer.sheets['Daily Report']
            # Size each column to its longest cell (or header) plus padding.
            for idx, col in enumerate(df.columns):
                max_len = max(df[col].astype(str).map(len).max(), len(col)) + 2
                worksheet.set_column(idx, idx, max_len)
        # Rewind so send_file streams the workbook from the beginning.
        output.seek(0)
        return send_file(
            output,
            download_name=f'report_{today}.xlsx',
            as_attachment=True,
            mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
        )
    except Exception as e:
        print(f"Error generating report: {str(e)}")
        return f"Error generating report: {str(e)}", 500