Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,1087 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import joblib
|
| 4 |
+
import shap
|
| 5 |
+
import json
|
| 6 |
+
import plotly
|
| 7 |
+
import plotly.graph_objs as go
|
| 8 |
+
import plotly.express as px
|
| 9 |
+
from flask import Flask, render_template, request, jsonify
|
| 10 |
+
import math
|
| 11 |
+
import networkx as nx
|
| 12 |
+
import traceback # For detailed error logging
|
| 13 |
+
import os # For environment variables
|
| 14 |
+
|
| 15 |
+
app = Flask(__name__)

# --- Load Models ---
# Both models are loaded once at import time; the app cannot serve
# predictions without them, so any load failure terminates the process.
MODEL_PATH_NA = 'non-adherence_XAI.pkl'
MODEL_PATH_R = 'readmission_XAI.pkl'
try:
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only ship .pkl files from a trusted source.
    model_na = joblib.load(MODEL_PATH_NA)
    model_r = joblib.load(MODEL_PATH_R)
    print("Models loaded successfully.")
    # The rest of the file calls .predict() and treats the result as a score,
    # so warn when a model exposes predict_proba or lacks predict.
    for _path, _model in ((MODEL_PATH_NA, model_na), (MODEL_PATH_R, model_r)):
        if not hasattr(_model, 'predict') or hasattr(_model, 'predict_proba'):
            print(f"Warning: Model {_path} might not be a regressor.")

except FileNotFoundError as e:
    print(f"FATAL ERROR: Model file not found: {e}. Make sure the .pkl files are in the correct directory.")
    print("Attempted paths:", os.path.abspath(MODEL_PATH_NA), os.path.abspath(MODEL_PATH_R))
    # Exit with a non-zero status so supervisors see the failure; the bare
    # exit() helper comes from the site module and reports success (0).
    raise SystemExit(1)
except Exception as e:
    print(f"FATAL ERROR: An unexpected error occurred loading models: {e}")
    traceback.print_exc()
    raise SystemExit(1)
|
| 37 |
+
|
| 38 |
+
# --- Mappings ---
|
| 39 |
+
# Categorical form values -> integer codes fed to the models.
gender_map = dict(Female=0, Male=1)
why_map = {
    reason: code
    for code, reason in enumerate(
        ('Bone', 'Brain', 'Heart', 'Infection', 'Lung', 'Stomach', 'Surgery')
    )
}
yes_no_map = dict(Yes=1, No=0)

# Integer codes -> display labels (inverse of the maps above).
reverse_gender_map = {code: label for label, code in gender_map.items()}
reverse_why_map = {code: label for label, code in why_map.items()}
reverse_yes_no_map = {code: label for label, code in yes_no_map.items()}

# --- Feature Order (Crucial for model input) ---
# Column order used when building the prediction DataFrame.
feature_order = [
    'Age',
    'Gender',
    'Why in Hospital',
    'Hospital Days',
    'Was in ICU (1=Yes)',
    'ICU Days',
    'Number of Medicines',
    'Cost per Medicine (₹)',
    'Days Medicine Lasts',
    'Total Dosage per Day (mg)',
    'Total Pills Given',
    'Medicine Availability (0-1)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
]
|
| 56 |
+
|
| 57 |
+
# --- Risk Level Logic ---
|
| 58 |
+
def get_risk_level(score):
    """Map a risk score to a Low / Medium / High badge.

    The score is converted to float and clamped to [0, 1]. Values that
    cannot be converted are reported on stdout and treated as 0.

    Returns a dict with keys 'level', 'color' and 'percentage', where
    'percentage' is the clamped score as a whole-number percent string.
    Bands: below 0.3 is Low, below 0.7 is Medium, otherwise High.
    """
    try:
        numeric = float(score)
    except (ValueError, TypeError):
        print(f"Warning: Invalid score '{score}' received. Defaulting to 0.")
        bounded = 0.0
    else:
        bounded = max(0.0, min(1.0, numeric))

    label = f"{round(bounded * 100)}%"

    if bounded >= 0.7:
        level, color = 'High', 'red'
    elif bounded >= 0.3:
        level, color = 'Medium', 'yellow'
    else:
        level, color = 'Low', 'green'

    return {'level': level, 'color': color, 'percentage': label}
|
| 73 |
+
|
| 74 |
+
# --- Feature Metadata for UI ---
|
| 75 |
+
# Per-feature form metadata: description, question, help text, and input type with its range or options.
|
| 76 |
+
def _describe_field(description, question, help_text, **input_spec):
    """Build one feature_info entry.

    The three text fields come first, followed by the input-specific keys
    (ideal_range/options, type, step) in the order they are passed.
    """
    return {
        'description': description,
        'question': question,
        'help_text': help_text,
        **input_spec,
    }


# Keyed by model feature name; 'type' is either 'number' or 'select'.
feature_info = {
    'Age': _describe_field(
        'Patient age in years',
        'How old is the patient?',
        'Age is a significant factor in both medication adherence and hospital readmission risk.',
        ideal_range='18-100', type='number',
    ),
    'Gender': _describe_field(
        'Patient gender',
        "What is the patient's gender?",
        'Gender can influence medication adherence patterns and readmission risk for certain conditions.',
        options=list(gender_map), type='select',
    ),
    'Why in Hospital': _describe_field(
        'Primary reason for hospitalization',
        'What is the primary reason for hospitalization?',
        'Different conditions have varying impacts on adherence and readmission patterns.',
        options=list(why_map), type='select',
    ),
    'Hospital Days': _describe_field(
        'Total days spent in hospital during this admission',
        'How many days did the patient spend in the hospital?',
        'Longer hospital stays often correlate with more complex cases and higher readmission risks.',
        ideal_range='1-30', type='number',
    ),
    'Was in ICU (1=Yes)': _describe_field(
        'Whether patient spent time in ICU',
        'Did the patient spend time in the ICU?',
        'ICU stays indicate higher severity and may impact post-discharge outcomes.',
        options=list(yes_no_map), type='select',
    ),
    'ICU Days': _describe_field(
        'Total days spent in ICU (if applicable)',
        'How many days did the patient spend in the ICU? (Enter 0 if not in ICU)',
        'Longer ICU stays typically indicate more severe conditions requiring careful post-discharge planning.',
        ideal_range='0-15', type='number',
    ),
    'Number of Medicines': _describe_field(
        'Total number of different medications prescribed',
        'How many different medications is the patient prescribed?',
        'Higher medication counts increase complexity and can lead to reduced adherence.',
        ideal_range='1-12', type='number',
    ),
    'Cost per Medicine (₹)': _describe_field(
        'Average cost per medication in rupees',
        'What is the average cost per medication (in ₹)?',
        'Higher medication costs can impact adherence due to financial constraints.',
        ideal_range='10-1000', type='number',
    ),
    'Days Medicine Lasts': _describe_field(
        'Number of days the prescribed medication will last',
        'How many days will the prescribed medication last?',
        'Longer durations between refills can affect adherence patterns.',
        ideal_range='7-90', type='number',
    ),
    'Total Dosage per Day (mg)': _describe_field(
        'Total medication dosage per day in milligrams',
        'What is the total medication dosage per day (in mg)?',
        'Higher daily dosages may indicate more severe conditions and can affect adherence.',
        ideal_range='5-500', type='number',
    ),
    'Total Pills Given': _describe_field(
        'Total number of pills provided at discharge',
        'How many total pills were given to the patient?',
        'Pill burden is a known factor in medication adherence.',
        ideal_range='10-300', type='number',
    ),
    'Medicine Availability (0-1)': _describe_field(
        'Availability score of prescribed medication (0=low, 1=high)',
        'How available is the medication (0=low, 1=high)?',
        'Limited availability can significantly impact medication adherence.',
        # 'step' lets the form accept fractional values.
        ideal_range='0-1', type='number', step='0.01',
    ),
    'Took Medicine Day 1 (1=Yes)': _describe_field(
        'Whether patient took medication on day 1 post-discharge',
        'Did the patient take their medication on day 1 after discharge?',
        'Early adherence patterns are strong predictors of overall medication adherence.',
        options=list(yes_no_map), type='select',
    ),
    'Took Medicine Day 2 (1=Yes)': _describe_field(
        'Whether patient took medication on day 2 post-discharge',
        'Did the patient take their medication on day 2 after discharge?',
        'Consistent adherence in the first days after discharge indicates better overall adherence.',
        options=list(yes_no_map), type='select',
    ),
    'Took Medicine Day 3 (1=Yes)': _describe_field(
        'Whether patient took medication on day 3 post-discharge',
        'Did the patient take their medication on day 3 after discharge?',
        'Patterns established in the first few days often continue throughout treatment.',
        options=list(yes_no_map), type='select',
    ),
}
|
| 168 |
+
|
| 169 |
+
# --- Risk Explanations for UI ---
|
| 170 |
+
# Display text for each risk type: title, description, per-level guidance, consequences, and interventions.
|
| 171 |
+
# Static copy shown next to each prediction, keyed by risk type.
# 'levels' is keyed by the level names produced by get_risk_level.
risk_explanations = dict(
    non_adherence=dict(
        title='Medication Non-Adherence Risk',
        description='Medication non-adherence refers to the degree to which a patient does not follow their prescription medication regimen as directed by their healthcare provider. This includes missing doses, taking incorrect doses, or stopping treatment early.',
        levels=dict(
            Low='Patient is likely to follow medication regimen as prescribed with minimal intervention needed.',
            Medium='Patient may need additional support such as reminders or follow-up calls to ensure adherence.',
            High='Patient is at significant risk of not taking medications as prescribed. Consider intensive interventions and close monitoring.',
        ),
        consequences=[
            'Reduced treatment effectiveness',
            'Disease progression or complications',
            'Increased hospitalization rates',
            'Higher healthcare costs',
            'Poorer health outcomes',
        ],
        interventions=[
            'Medication reminder systems',
            'Simplified medication regimens',
            'Patient education on importance of adherence',
            'Regular follow-up calls',
            'Addressing barriers (financial, logistical, etc.)',
        ],
    ),
    readmission=dict(
        title='Hospital Readmission Risk',
        description='Hospital readmission risk refers to the likelihood that a patient will need to return to the hospital within a short period (typically 30 days) after being discharged.',
        levels=dict(
            Low='Patient has minimal risk factors for readmission and can likely be managed with standard follow-up care.',
            Medium='Patient has moderate risk of readmission and may benefit from enhanced discharge planning and follow-up.',
            High='Patient is at high risk for readmission and requires comprehensive discharge planning, early follow-up, and possibly home health services.',
        ),
        consequences=[
            'Increased patient suffering',
            'Higher healthcare costs',
            'Potential complications',
            'Disruption to patient recovery',
            'Reduced hospital quality metrics',
        ],
        interventions=[
            'Comprehensive discharge planning',
            'Medication reconciliation',
            'Early (within 7 days) follow-up appointments',
            'Home health services when appropriate',
            'Patient and caregiver education',
        ],
    ),
)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# --- SHAP Explainer Initialization ---
|
| 212 |
+
# Explainers are created on first use and cached per model key.
|
| 213 |
+
# Cache of SHAP explainers, keyed by model key.
shap_explainers = {}


def get_shap_explainer(model_key, model):
    """Return the cached SHAP explainer for model_key, creating it if needed.

    A TreeExplainer is used when the model has `feature_importances_` or its
    type name mentions xgboost; otherwise the generic shap.Explainer is tried.
    Returns None (after printing the error) when construction fails; failures
    are not cached, so the next call tries again.
    """
    if model_key in shap_explainers:
        return shap_explainers[model_key]

    print(f"Initializing SHAP explainer for {model_key}...")
    try:
        is_tree_like = (
            hasattr(model, 'feature_importances_')
            or 'xgboost' in str(type(model)).lower()
        )
        if not is_tree_like:
            print(f"Warning: Model for {model_key} might not be a tree model. Using generic SHAP Explainer.")
            try:
                shap_explainers[model_key] = shap.Explainer(model)
                print(f"Initialized generic SHAP Explainer for {model_key}.")
            except Exception as gen_e:
                print(f"ERROR initializing generic SHAP explainer for {model_key}: {gen_e}")
                return None
        else:
            shap_explainers[model_key] = shap.TreeExplainer(model)
            print(f"SHAP TreeExplainer for {model_key} initialized.")
    except Exception as e:
        print(f"ERROR initializing SHAP Explainer for {model_key}: {e}")
        traceback.print_exc()
        return None

    return shap_explainers.get(model_key)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# --- Flask Routes ---
|
| 241 |
+
@app.route('/')
def home():
    """Serve the input form, passing the feature metadata and risk copy to the template."""
    template_context = {
        'feature_info': feature_info,
        'feature_order': feature_order,
        'risk_explanations': risk_explanations,
    }
    return render_template('index.html', **template_context)
|
| 245 |
+
|
| 246 |
+
# Form field ids sent by the front end -> model feature names.
# NOTE(review): the ids presumably mirror getFieldId in the page's JS
# (per the original comment) - confirm against the template.
_JS_KEY_TO_FEATURE = {
    'age': 'Age', 'gender': 'Gender', 'why-in-hospital': 'Why in Hospital',
    'hospital-days': 'Hospital Days', 'was-in-icu-1yes': 'Was in ICU (1=Yes)',
    'icu-days': 'ICU Days', 'number-of-medicines': 'Number of Medicines',
    'cost-per-medicine-rupees': 'Cost per Medicine (₹)', 'days-medicine-lasts': 'Days Medicine Lasts',
    'total-dosage-per-day-mg': 'Total Dosage per Day (mg)', 'total-pills-given': 'Total Pills Given',
    'medicine-availability-0-1': 'Medicine Availability (0-1)',
    'took-medicine-day-1-1yes': 'Took Medicine Day 1 (1=Yes)',
    'took-medicine-day-2-1yes': 'Took Medicine Day 2 (1=Yes)',
    'took-medicine-day-3-1yes': 'Took Medicine Day 3 (1=Yes)'
}
# Inverse map, built once instead of scanning the dict for every feature.
_FEATURE_TO_JS_KEY = {feature: js_key for js_key, feature in _JS_KEY_TO_FEATURE.items()}
# Select features whose options are Yes/No (matched case-insensitively on the first letter).
_YES_NO_FEATURES = frozenset({
    'Was in ICU (1=Yes)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
})


def _collect_model_inputs(data):
    """Validate the submitted form values and convert them to model inputs.

    Args:
        data: dict of form field id -> submitted value.

    Returns:
        (user_input, missing_features, invalid_features) where user_input maps
        feature name -> numeric value, missing_features lists features with no
        value, and invalid_features maps feature name -> error message.
    """
    user_input = {}
    missing_features = []
    invalid_features = {}

    # str() keeps non-string JSON values (null, numbers, booleans) from
    # raising AttributeError on .capitalize().
    is_icu_no = str(data.get('was-in-icu-1yes', '')).capitalize() == 'No'

    for js_key, feature in _JS_KEY_TO_FEATURE.items():
        if feature not in feature_order:
            print(f"Warning: Feature '{feature}' from js_key_map not in expected feature_order list.")
            continue

        is_icu_days = (feature == 'ICU Days')
        value = data.get(js_key)

        # Handle missing values
        if value is None or str(value).strip() == '':
            # ICU days may be omitted only when the patient was not in the ICU.
            if is_icu_days and is_icu_no:
                user_input[feature] = 0.0
                print("Setting ICU Days to 0 as Was in ICU is No.")
            else:
                missing_features.append(feature)
            continue

        f_info = feature_info.get(feature)
        if not f_info:
            print(f"Warning: No feature_info found for '{feature}'. Skipping.")
            continue

        try:
            if f_info['type'] == 'number':
                number = float(value)
                if not math.isfinite(number):
                    # float() accepts 'nan' and 'inf'; never pass those to the models.
                    invalid_features[feature] = f"Value must be a finite number, got '{value}'"
                elif is_icu_days and is_icu_no:
                    if number != 0:
                        print(f"Warning: Forcing ICU Days to 0 because Was in ICU is No, but user entered {value}.")
                    user_input[feature] = 0.0
                elif is_icu_days and number < 0:
                    invalid_features[feature] = "ICU Days cannot be negative if patient was in ICU."
                else:
                    user_input[feature] = number

            elif f_info['type'] == 'select':
                choice = str(value)
                if feature in _YES_NO_FEATURES:
                    mapped_value = yes_no_map.get(choice.capitalize())
                elif feature == 'Gender':
                    mapped_value = gender_map.get(choice)
                elif feature == 'Why in Hospital':
                    mapped_value = why_map.get(choice)
                else:
                    mapped_value = None

                if mapped_value is None:
                    invalid_features[feature] = f"Invalid option: '{value}'"
                else:
                    user_input[feature] = mapped_value
            else:
                invalid_features[feature] = f"Unknown feature type: '{f_info['type']}'"

        except (ValueError, TypeError) as e:
            invalid_features[feature] = f"Invalid format for value '{value}': {e}"

    return user_input, missing_features, invalid_features


def _shap_breakdown(explainer, df_user, data, tag, title, pred_score):
    """Compute the per-feature SHAP contributions for one model.

    Args:
        explainer: SHAP explainer, or None when none could be created.
        df_user: single-row DataFrame with columns in feature_order.
        data: the raw submitted form values (used for display only).
        tag: short label used in log messages ('NA' or 'R').
        title: human-readable model name used in the error payload.
        pred_score: clamped model prediction, logged next to the SHAP total.

    Returns:
        (rows, base_value, had_error). On success rows is a list of dicts
        sorted by absolute SHAP value, largest first. On failure rows is a
        one-element list holding an {'error': ...} dict, base_value is 0.5
        and had_error is True.
    """
    if explainer is None:
        return [{"error": f"SHAP explainer {tag} not available."}], 0.5, True

    try:
        raw = explainer.shap_values(df_user)
        if isinstance(raw, list):
            vec = raw[0][0]  # Multi-output? Take the first output's first row.
        elif isinstance(raw, np.ndarray) and raw.ndim == 2:
            vec = raw[0]
        elif isinstance(raw, np.ndarray) and raw.ndim == 1:
            vec = raw
        else:
            raise TypeError(f"Unexpected SHAP {tag} format: {type(raw)}")

        expected = explainer.expected_value
        if isinstance(expected, (list, np.ndarray)):
            # ravel() also copes with a 0-d array, where [0] would fail.
            base_value = float(np.ravel(expected)[0])
        elif expected is not None:
            base_value = float(expected)
        else:
            base_value = 0.5
            print(f"Warning: SHAP {tag} expected_value is None. Using 0.5.")

        if vec is None or len(vec) != len(feature_order):
            raise ValueError(f"SHAP {tag} vector length mismatch or None.")

        print(f"SHAP {tag}: Base={base_value:.4f}, Sum={np.sum(vec):.4f}, Pred={pred_score:.4f}, Total={base_value + np.sum(vec):.4f}")
        rows = []
        for i, feature in enumerate(feature_order):
            f_info = feature_info.get(feature, {})
            rows.append({
                'feature': feature, 'shap_value': float(vec[i]),
                'feature_value': str(data.get(_FEATURE_TO_JS_KEY.get(feature), "N/A")),
                'numeric_value': float(df_user.iloc[0, i]),
                'description': f_info.get('description', ''), 'help_text': f_info.get('help_text', '')
            })
        rows.sort(key=lambda row: abs(row.get('shap_value', 0)), reverse=True)
        return rows, base_value, False
    except Exception as e:
        # Explanations are best-effort: the prediction is still returned.
        print(f"Error calculating SHAP {tag}: {e}")
        traceback.print_exc()
        return [{"error": f"Could not calculate SHAP values for {title}."}], 0.5, True


@app.route('/predict', methods=['POST'])
def predict():
    """Handle a prediction request and return predictions plus explanations as JSON.

    Expects a JSON object keyed by form field id. Responds with 400 and
    {'success': False, 'error': ...} for a missing/unparseable body or
    missing/invalid fields, 500 for internal failures, and otherwise a
    payload with 'predictions', 'explanations', 'recommendations' and
    'visualizations'.
    """
    try:
        # silent=True returns None for a malformed or non-JSON body instead of
        # raising, so the client gets a 400 rather than the generic 500 below.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': 'No data received'}), 400
        if not isinstance(data, dict):
            return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400
        print("Received data:", data)

        # --- Input Processing and Validation ---
        user_input, missing_features, invalid_features = _collect_model_inputs(data)

        if missing_features:
            return jsonify({'success': False, 'error': f"Missing input for: {', '.join(missing_features)}"}), 400
        if invalid_features:
            error_msg = "; ".join([f"{k}: {v}" for k, v in invalid_features.items()])
            return jsonify({'success': False, 'error': f"Invalid input: {error_msg}"}), 400

        # Ensure all features are present before creating DataFrame
        if len(user_input) != len(feature_order):
            provided = set(user_input.keys())
            expected = set(feature_order)
            missing = list(expected - provided)
            extra = list(provided - expected)
            err = f"Feature mismatch after processing. Missing: {missing}. Extra: {extra}."
            print(f"ERROR: {err}")
            return jsonify({'success': False, 'error': f"Internal error: Feature mismatch. Please check processing logic. Missing: {missing}"}), 500

        # Create DataFrame in the correct order
        df_user = pd.DataFrame([user_input], columns=feature_order)
        print("Processed DataFrame for prediction:\n", df_user.to_string())

        # --- Model Predictions ---
        pred_na_raw = model_na.predict(df_user)[0]
        pred_r_raw = model_r.predict(df_user)[0]
        # Scores are clamped to [0, 1] before being reported.
        pred_na_score = max(0.0, min(1.0, float(pred_na_raw)))
        pred_r_score = max(0.0, min(1.0, float(pred_r_raw)))
        print(f"Predictions - NA Score: {pred_na_score:.4f}, R Score: {pred_r_score:.4f}")
        risk_level_na = get_risk_level(pred_na_score)
        risk_level_r = get_risk_level(pred_r_score)

        # --- SHAP Explanations ---
        shap_explainer_na = get_shap_explainer('non_adherence', model_na)
        shap_explainer_r = get_shap_explainer('readmission', model_r)
        shap_data_na, base_value_na, shap_error_na = _shap_breakdown(
            shap_explainer_na, df_user, data, 'NA', 'Non-Adherence', pred_na_score
        )
        shap_data_r, base_value_r, shap_error_r = _shap_breakdown(
            shap_explainer_r, df_user, data, 'R', 'Readmission', pred_r_score
        )

        # --- Counterfactuals & Recommendations ---
        cf_data_na = generate_comprehensive_counterfactuals(df_user, model_na, "Non-Adherence", shap_data_na, shap_error_na)
        cf_data_r = generate_comprehensive_counterfactuals(df_user, model_r, "Readmission", shap_data_r, shap_error_r)
        recommendations = generate_recommendations(shap_data_na, shap_data_r, pred_na_score, pred_r_score, shap_error_na, shap_error_r)

        # --- Visualizations ---
        gauges = generate_gauge_charts(pred_na_score, pred_r_score)
        additional_visualizations = generate_additional_visualizations(
            df_user, shap_data_na, shap_data_r, base_value_na, base_value_r, shap_error_na, shap_error_r
        )

        # --- Response ---
        response = {
            'success': True,
            'predictions': {
                'non_adherence': round(pred_na_score, 3),
                'readmission': round(pred_r_score, 3),
                'risk_level_na': risk_level_na,
                'risk_level_r': risk_level_r,
                'base_value_na': round(base_value_na, 3) if base_value_na is not None else None,
                'base_value_r': round(base_value_r, 3) if base_value_r is not None else None
            },
            'explanations': {
                'shap_values_na': shap_data_na,
                'shap_values_r': shap_data_r,
                'counterfactuals_na': cf_data_na,
                'counterfactuals_r': cf_data_r,
                'shap_error_na': shap_error_na,
                'shap_error_r': shap_error_r
            },
            'recommendations': recommendations,
            'visualizations': {
                'gauges': gauges,
                'additional': additional_visualizations
            }
        }
        return jsonify(response)

    except ValueError as ve:
        print(f"Value Error during prediction processing: {ve}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Invalid input: {str(ve)}"}), 400
    except KeyError as ke:
        print(f"Key Error during prediction processing: {ke}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Missing expected data field: {str(ke)}"}), 400
    except Exception as e:
        # Last-resort boundary: log the details, return a generic message.
        print(f"An unexpected error occurred during prediction: {e}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': "An internal server error occurred. Please try again later."}), 500
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# --- Placeholder JSON for Plots on Error ---
|
| 503 |
+
# Builds an empty Plotly figure (as JSON) that shows a message instead of a chart.
|
| 504 |
+
def create_placeholder_plot(title_suffix, message="Required data not available for this chart."):
    """Build the JSON for an empty Plotly figure that only shows a centered notice.

    Args:
        title_suffix: Chart name; rendered in the title followed by
            " (Not Generated)".
        message: Text of the annotation placed in the middle of the plot area.

    Returns:
        A JSON string (encoded with ``plotly.utils.PlotlyJSONEncoder``) holding
        a figure with no traces, hidden axes and transparent backgrounds.
    """
    transparent = 'rgba(0,0,0,0)'

    # Paper-relative coordinates (0.5, 0.5) put the notice in the middle of
    # the figure regardless of axis ranges.
    notice = {
        'text': message,
        'xref': 'paper', 'yref': 'paper',
        'x': 0.5, 'y': 0.5, 'showarrow': False,
        'font': {'size': 12, 'color': '#888'}
    }

    layout = {}
    layout['title'] = f'{title_suffix} (Not Generated)'
    layout['xaxis'] = {'visible': False}
    layout['yaxis'] = {'visible': False}
    layout['annotations'] = [notice]
    layout['plot_bgcolor'] = transparent
    layout['paper_bgcolor'] = transparent

    figure = {'data': [], 'layout': layout}
    return json.dumps(figure, cls=plotly.utils.PlotlyJSONEncoder)
|
| 522 |
+
|
| 523 |
+
# --- Counterfactual Generation ---
|
| 524 |
+
# Single-feature "what if" simulations, driven by the SHAP ranking of modifiable features.
|
| 525 |
+
def generate_comprehensive_counterfactuals(df_original, model, target_type, shap_data, shap_error):
    """Generate simplified counterfactuals based on SHAP values and feature modifiability.

    Features are visited in order of decreasing absolute SHAP value. For each
    modifiable feature a single changed value is proposed, the model is re-run
    on a copy of the input with only that feature replaced, and the suggestion
    is kept when the predicted score drops by more than 0.01. At most five
    suggestions are returned.

    Args:
        df_original: Model input row(s); must support ``.copy()`` and
            ``obj[feature] = value``. It is never mutated.
        model: Object whose ``predict(data)[0]`` yields the risk score.
        target_type: Risk name used in messages and in each ``risk_type``.
        shap_data: List of dicts with ``feature``, ``shap_value``,
            ``feature_value`` (display text) and ``numeric_value``.
        shap_error: Truthy when SHAP computation failed upstream.

    Returns:
        A list of dicts. Normally each has ``feature``, ``current_value``,
        ``suggested_change``, ``potential_outcome``, ``impact_magnitude`` and
        ``risk_type``. When nothing qualifies the list holds one ``{"notes":
        ...}`` dict; on failure it holds one ``{"error": ...}`` dict.
    """
    print(f"Generating counterfactuals for {target_type}...")

    shap_unusable = bool(shap_error) or not isinstance(shap_data, list) or not shap_data
    # Only inspect the first entry once we know it exists, and only call
    # ``.get`` on an actual dict so a malformed entry cannot raise here.
    if not shap_unusable and isinstance(shap_data[0], dict) and shap_data[0].get("error"):
        shap_unusable = True
    if shap_unusable:
        print(f"Skipping counterfactuals for {target_type} due to SHAP errors or missing data.")
        return [{"error": f"Could not generate counterfactuals for {target_type} because SHAP data is missing or invalid."}]

    df = df_original.copy()
    counterfactuals = []
    try:
        current_pred_score = _cf_clamp_score(model.predict(df)[0])
        goal_direction = "decrease"
        max_cfs = 5

        for feature, shap_value, numeric_value, display_value in _cf_ranked_candidates(shap_data):
            if len(counterfactuals) >= max_cfs:
                break

            proposal = _cf_propose_change(feature, shap_value, numeric_value)
            if proposal is None:
                continue
            new_value, suggested_val_str = proposal

            df_cf = df.copy()
            df_cf[feature] = new_value
            try:
                cf_pred_score = _cf_clamp_score(model.predict(df_cf)[0])
            except Exception as sim_e:
                # A failed simulation only drops this one suggestion.
                print(f"Error simulating counterfactual for {feature}: {sim_e}")
                continue

            change_in_score = cf_pred_score - current_pred_score
            if goal_direction == "decrease" and change_in_score < -0.01:
                counterfactuals.append({
                    "feature": feature,
                    "current_value": display_value,
                    "suggested_change": suggested_val_str,
                    "potential_outcome": f"Est. new risk: {cf_pred_score:.1%} ({change_in_score:+.1%})",
                    "impact_magnitude": _cf_impact_label(change_in_score),
                    "risk_type": target_type,
                })

        cf_count = len(counterfactuals)
        if not counterfactuals:
            counterfactuals.append({"notes": f"No simple, impactful counterfactual changes identified among top modifiable factors for {target_type} risk."})
        print(f"Finished generating {cf_count} counterfactual entries for {target_type}.")
        return counterfactuals
    except Exception as e:
        print(f"Error generating counterfactuals for {target_type}: {str(e)}")
        traceback.print_exc()
        return [{"error": f"Could not generate counterfactuals for {target_type}: {str(e)}"}]


# Features that are never offered as something to change.
_CF_NON_MODIFIABLE_FEATURES = frozenset({
    "Age", "Gender", "Why in Hospital", "Hospital Days", "Was in ICU (1=Yes)", "ICU Days",
})


def _cf_clamp_score(raw_score):
    """Convert a raw model output to a float limited to the range [0.0, 1.0]."""
    return max(0.0, min(1.0, float(raw_score)))


def _cf_finite_float(value):
    """Return ``value`` as a finite float, or None when it is not one."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _cf_ranked_candidates(shap_data):
    """Return usable (feature, shap, numeric, display) tuples, largest |SHAP| first.

    Entries that are not dicts, carry an "error" key, lack a usable number,
    have |SHAP| below 0.01 or name a non-modifiable feature are skipped
    instead of aborting the whole run.
    """
    candidates = []
    for entry in shap_data:
        if not isinstance(entry, dict) or "error" in entry:
            continue
        feature = entry.get('feature')
        shap_value = _cf_finite_float(entry.get('shap_value', 0))
        numeric_value = _cf_finite_float(entry.get('numeric_value'))
        if not isinstance(feature, str) or shap_value is None or numeric_value is None:
            continue
        if abs(shap_value) < 0.01 or feature in _CF_NON_MODIFIABLE_FEATURES:
            continue
        candidates.append((feature, shap_value, numeric_value, entry.get('feature_value', 'N/A')))
    # list.sort is stable, so ties keep their original relative order.
    candidates.sort(key=lambda candidate: abs(candidate[1]), reverse=True)
    return candidates


def _cf_propose_change(feature, shap_value, current_value):
    """Return (new_value, suggestion_text) for one feature, or None if no change applies."""
    if feature.startswith("Took Medicine Day"):
        if current_value == 0 and shap_value > 0.01:
            return 1, "Ensure Adherence ('Yes')"
        return None

    if feature == "Medicine Availability (0-1)":
        if shap_value < -0.01 and current_value < 0.95:
            return 1.0, "Improve towards High (1.0)"
        return None

    # Other numeric features: move 20% against the sign of the SHAP value.
    if abs(shap_value) <= 0.02:
        return None
    direction = "decrease" if shap_value > 0 else "increase"
    change_perc = 0.20
    target_factor = (1 - change_perc) if direction == "decrease" else (1 + change_perc)
    tentative = current_value * target_factor
    if not math.isfinite(tentative):
        return None

    verb = direction.capitalize()
    if feature in ("Number of Medicines", "Total Pills Given"):
        new_value = max(1, math.floor(tentative))
        suggestion = f"{verb} towards ~{int(new_value)}"
    elif feature == "Days Medicine Lasts":
        new_value = max(7, math.floor(tentative))
        suggestion = f"{verb} towards ~{int(new_value)} days"
    elif feature == "Cost per Medicine (₹)":
        new_value = max(10.0, round(tentative, 0))
        suggestion = f"{verb} towards ~₹{new_value:.0f}"
    elif feature == "Total Dosage per Day (mg)":
        new_value = max(5.0, round(tentative, 0))
        suggestion = f"{verb} towards ~{new_value:.0f} mg"
    else:
        new_value = round(tentative, 2)
        suggestion = f"{verb} towards ~{new_value:.2f}"

    # Rounding and floors can cancel the change; do not simulate tiny moves.
    if abs(new_value - current_value) > 0.01 * abs(current_value) + 0.1:
        return new_value, suggestion
    return None


def _cf_impact_label(change_in_score):
    """Classify the size of a score change as Significant, Moderate or Minor."""
    magnitude = abs(change_in_score)
    if magnitude > 0.10:
        return "Significant"
    if magnitude > 0.05:
        return "Moderate"
    return "Minor"
|
| 632 |
+
|
| 633 |
+
# --- Recommendation Generation ---
|
| 634 |
+
# Combines the overall risk levels with per-feature SHAP drivers.
|
| 635 |
+
def generate_recommendations(shap_data_na, shap_data_r, pred_na_prob, pred_r_prob, shap_error_na, shap_error_r):
    """Generate actionable recommendations based on risk levels and SHAP factors.

    General recommendations come from the two overall risk levels; up to seven
    feature-specific ones are added for features whose SHAP value exceeds 0.02
    for at least one risk.

    Args:
        shap_data_na: SHAP entries for non-adherence, each a dict with
            ``feature``, ``shap_value``, ``feature_value``, ``numeric_value``.
        shap_data_r: The same for readmission.
        pred_na_prob: Non-adherence score passed to ``get_risk_level``.
        pred_r_prob: Readmission score passed to ``get_risk_level``.
        shap_error_na: Truthy when non-adherence SHAP failed.
        shap_error_r: Truthy when readmission SHAP failed.

    Returns:
        A list of dicts with ``category``, ``priority``, ``recommendation``
        and ``action``, sorted Critical, High, Medium, Standard, Info. When
        neither SHAP result is usable the list is returned early, unsorted,
        with an "Explanations" entry appended.
    """
    print("Generating recommendations...")
    recommendations = []

    def usable_entries(shap_data, shap_error):
        """Return the non-error dict entries, or None if the whole result is unusable."""
        if shap_error or not isinstance(shap_data, list) or not shap_data:
            return None
        first = shap_data[0]
        if isinstance(first, dict) and "error" in first:
            return None
        return [entry for entry in shap_data if isinstance(entry, dict) and "error" not in entry]

    def to_number(value, caster, fallback):
        """Convert with ``caster``; fall back on None or an unconvertible value."""
        if value is None:
            return fallback
        try:
            return caster(value)
        except (TypeError, ValueError, OverflowError):
            return fallback

    na_entries = usable_entries(shap_data_na, shap_error_na)
    r_entries = usable_entries(shap_data_r, shap_error_r)

    # get_risk_level is defined elsewhere in this file; only its 'level' and
    # 'percentage' keys are read here.
    na_risk_level_info = get_risk_level(pred_na_prob)
    r_risk_level_info = get_risk_level(pred_r_prob)
    na_level = na_risk_level_info['level']
    r_level = r_risk_level_info['level']

    # --- General Recommendations ---
    if na_level == 'High':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "Critical",
            "recommendation": f"High ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Implement intensive adherence support: daily reminders, regimen simplification consult, frequent follow-up (e.g., within 3 days), assess/address specific barriers (cost, access, understanding)."
        })
    elif na_level == 'Medium':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "High",
            "recommendation": f"Medium ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Provide adherence aids (e.g., pillbox, app reminders), schedule follow-up call within 1 week, reinforce importance of medication."
        })
    if r_level == 'High':
        recommendations.append({
            "category": "Overall Readmission", "priority": "Critical",
            "recommendation": f"High ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Comprehensive discharge plan crucial: schedule follow-up within 7 days, ensure medication reconciliation, consider home health/transitional care referral, detailed patient/caregiver education using teach-back."
        })
    elif r_level == 'Medium':
        recommendations.append({
            "category": "Overall Readmission", "priority": "High",
            "recommendation": f"Medium ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Enhanced discharge process: ensure clear instructions (written/verbal), schedule follow-up within 10-14 days, confirm patient understanding of red flags and who to contact."
        })

    # --- Specific Recommendations ---
    if na_entries is None and r_entries is None:
        recommendations.append({
            "category": "Explanations", "priority": "Medium",
            "recommendation": "Detailed factor analysis unavailable due to SHAP errors.",
            "action": "Focus on general risk levels and standard protocols. Investigate SHAP calculation issues if persistent."})
        return recommendations  # Early exit

    # Merge both explainers per feature. Each entry is tagged by the list it
    # came from, so an unusable or non-list result from the other explainer
    # can neither raise nor be credited with the entry.
    combined_shap = {}
    for slot, entries in (('data_na', na_entries or []), ('data_r', r_entries or [])):
        for item in entries:
            feature = item.get('feature')
            if not feature:
                continue
            merged = combined_shap.setdefault(feature, {'abs_shap_sum': 0, 'data_na': None, 'data_r': None})
            merged['abs_shap_sum'] += abs(item.get('shap_value', 0))
            merged[slot] = item
    sorted_features = sorted(combined_shap, key=lambda f: combined_shap[f]['abs_shap_sum'], reverse=True)

    rec_count = 0
    max_recs = 7
    shap_threshold = 0.02  # Minimum SHAP value to consider as a driver

    for feature in sorted_features:
        if rec_count >= max_recs:
            break

        na_item = combined_shap[feature]['data_na']
        r_item = combined_shap[feature]['data_r']
        na_shap = na_item.get('shap_value', 0) if na_item else 0
        r_shap = r_item.get('shap_value', 0) if r_item else 0
        item_for_val = na_item or r_item
        current_val = item_for_val.get('feature_value', 'N/A') if item_for_val else 'N/A'
        numeric_val = item_for_val.get('numeric_value') if item_for_val else None

        # Only positive contributions (pushing the risk up) count as drivers.
        is_na_driver = na_shap > shap_threshold
        is_r_driver = r_shap > shap_threshold
        if not (is_na_driver or is_r_driver):
            continue

        priority = "Medium"
        if max(abs(na_shap), abs(r_shap)) > 0.1:
            priority = "High"
        if (is_na_driver and na_level == 'High') or (is_r_driver and r_level == 'High'):
            priority = "High"
        if (na_level == 'High' and na_shap > 0.15) or (r_level == 'High' and r_shap > 0.15):
            priority = "Critical"

        rec_made = False
        action_text = ""
        rec_category = ""
        rec_recommendation = ""

        # --- Generate Recommendation Content ---
        if feature.startswith("Took Medicine Day") and current_val == "No":
            if is_na_driver:  # Primarily an adherence issue
                priority = "Critical"
                rec_category = "Early Adherence Failure"
                rec_recommendation = f"Missed medication on {feature.split('(')[0].strip()} is a strong indicator of future non-adherence."
                action_text = "Immediate intervention required: follow-up call TODAY, assess reasons, counsel, establish reminders/support."
                rec_made = True
        elif feature == "Number of Medicines":
            # At least one driver flag is set here, so the label is never empty.
            risk_labels = " & ".join(
                label for label, active in (("NA", is_na_driver), ("R", is_r_driver)) if active
            )
            rec_category = "Medication Complexity"
            rec_recommendation = f"High number of medicines ({current_val}) associated with increased risk ({risk_labels})."
            action_text = "Review list for simplification/consolidation. Consider pharmacist consult for polypharmacy review."
            rec_made = True
        elif feature == "Cost per Medicine (₹)":
            cost_val = to_number(numeric_val, float, 0)
            if cost_val > 100 and is_na_driver:  # Primarily adherence cost barrier
                rec_category = "Medication Cost Barrier"
                rec_recommendation = f"High average medication cost (₹{current_val}) may be a barrier to adherence."
                action_text = "Discuss cost concerns. Explore generics, assistance programs, or lower-cost alternatives."
                rec_made = True
        elif feature == "Medicine Availability (0-1)":
            avail_val = to_number(numeric_val, float, 1.0)
            if avail_val < 0.5 and is_na_driver:  # Primarily adherence access issue
                rec_category = "Medication Access Issue"
                rec_recommendation = f"Reported low medication availability ({current_val}) likely hinders adherence."
                action_text = "Verify pharmacy stock pre-discharge. Help patient identify reliable source or discuss alternatives."
                rec_made = True
        elif feature == "ICU Days":
            icu_days_val = to_number(numeric_val, int, 0)
            if icu_days_val > 0 and is_r_driver:  # Primarily readmission risk
                # NOTE(review): a stay over 2 days sets "High" even when the
                # SHAP-derived priority was "Critical" — confirm this is intended.
                if icu_days_val > 2:
                    priority = "High"
                rec_category = "ICU History Impact"
                rec_recommendation = f"Prior ICU stay ({current_val} days) significantly increases readmission risk."
                action_text = "Intensive post-discharge support: early follow-up (≤7 days), consider transitional care/home health, meticulous med review & education."
                rec_made = True
        elif feature == "Hospital Days":
            days_val = to_number(numeric_val, int, 0)
            if days_val > 7 and is_r_driver:  # Primarily readmission risk
                rec_category = "Length of Stay Impact"
                rec_recommendation = f"Longer hospital stay ({current_val} days) associated with increased readmission risk."
                action_text = "Reinforces need for thorough discharge planning, clear instructions, med reconciliation, and prompt follow-up (≤10 days)."
                rec_made = True
        elif feature == "Age":
            age_val = to_number(numeric_val, int, 0)
            if age_val > 75 and is_r_driver:  # Readmission context factor
                priority = "Medium"
                rec_category = "Age Factor (Context)"
                rec_recommendation = f"Patient's age ({current_val}) contributes moderately to readmission risk."
                action_text = "Consider age-related needs in discharge plan (support, mobility, cognition, simple instructions)."
                rec_made = True
        elif feature == "Why in Hospital":
            if is_r_driver:  # Readmission context factor; keeps the SHAP-derived priority
                rec_category = "Diagnosis Factor (Context)"
                rec_recommendation = f"Primary diagnosis ({current_val}) contributes to risk."
                # str() keeps a non-string diagnosis value from raising on .lower().
                action_text = f"Ensure condition-specific discharge education (red flags, follow-up) and care plan for managing '{str(current_val).lower()}' are emphasized."
                rec_made = True

        # Append the recommendation if one was generated
        if rec_made:
            recommendations.append({
                "category": rec_category,
                "priority": priority,
                "recommendation": rec_recommendation,
                "action": action_text
            })
            rec_count += 1

    # Add default recommendation if few specific ones generated but risk elevated
    if (na_level in ['High', 'Medium'] or r_level in ['High', 'Medium']) and len(recommendations) < 3:
        recommendations.append({
            "category": "General Follow-up", "priority": "Medium",
            "recommendation": "Overall risk is elevated. Review standard discharge protocols.",
            "action": "Ensure robust basics: use teach-back, confirm follow-up appointments, provide clear contact info."})

    # Sort final recommendations by priority
    priority_map = {"Critical": 0, "High": 1, "Medium": 2, "Standard": 3, "Info": 4}
    recommendations.sort(key=lambda rec: priority_map.get(rec.get("priority", "Medium"), 99))

    print(f"Finished generating {len(recommendations)} recommendations.")
    return recommendations
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
# --- Visualization Generation ---
|
| 819 |
+
# Gauge figures for the two predicted risk scores.
|
| 820 |
+
def generate_gauge_charts(pred_na_prob, pred_r_prob):
    """Generate Plotly gauge chart JSON objects.

    Args:
        pred_na_prob: Non-adherence score; limited to [0.0, 1.0] before use.
        pred_r_prob: Readmission score; limited to [0.0, 1.0] before use.

    Returns:
        A dict with keys ``non_adherence`` and ``readmission``. Each value is
        a JSON string (encoded with ``plotly.utils.PlotlyJSONEncoder``) for an
        indicator figure showing the score as a whole-number percentage.
    """
    print("Generating gauge charts...")
    color_low = '#28a745'; color_medium = '#ffc107'; color_high = '#dc3545'
    color_low_bg = '#d4edda'; color_medium_bg = '#fff3cd'; color_high_bg = '#f8d7da'
    # get_risk_level (defined elsewhere in this file) names a color; map it
    # to the hex value used for the bar and the number.
    bar_colors = {'green': color_low, 'yellow': color_medium, 'red': color_high}

    pred_na_prob = max(0.0, min(1.0, pred_na_prob))
    pred_r_prob = max(0.0, min(1.0, pred_r_prob))

    def build_gauge(title, probability):
        """Serialize one gauge figure.

        Every nested dict is created fresh on each call, so the two gauges
        share no mutable state (a shallow ``dict.copy()`` of a common
        template would alias the nested 'bar' and 'threshold' dicts).
        """
        risk_info = get_risk_level(probability)
        value_perc = round(probability * 100)
        accent = bar_colors.get(risk_info['color'], '#888')
        gauge = {
            'axis': {'range': [0, 100], 'tickwidth': 1, 'tickcolor': "#aaa", 'tickfont': {'size': 10}},
            'bar': {'color': accent, 'thickness': 0.3},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            'steps': [
                {'range': [0, 30], 'color': color_low_bg},
                {'range': [30, 70], 'color': color_medium_bg},
                {'range': [70, 100], 'color': color_high_bg},
            ],
            'threshold': {'line': {'color': '#666', 'width': 4}, 'thickness': 0.75, 'value': value_perc},
        }
        trace = {
            'type': 'indicator', 'mode': 'gauge+number', 'value': value_perc,
            'title': {'text': f"<b>{title}</b><br><span style='font-size:0.9em;'>Level: {risk_info['level']}</span>", 'font': {'size': 14}},
            'gauge': gauge,
            'number': {'suffix': "%", 'font': {'size': 24, 'color': accent}},
        }
        layout = {
            'height': 220,
            'margin': {'t': 50, 'b': 10, 'l': 20, 'r': 20},
            'font': {'color': "#333", 'family': "Arial, sans-serif", 'size': 12},
            'paper_bgcolor': 'rgba(0,0,0,0)',
            'plot_bgcolor': 'rgba(0,0,0,0)',
            'autosize': True,
        }
        return json.dumps({'data': [trace], 'layout': layout}, cls=plotly.utils.PlotlyJSONEncoder)

    gauges = {
        'non_adherence': build_gauge("Non-Adherence Risk", pred_na_prob),
        'readmission': build_gauge("Readmission Risk", pred_r_prob),
    }

    print("Finished generating gauge charts.")
    return gauges
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
# In app.py
|
| 850 |
+
def generate_additional_visualizations(df_user, shap_data_na, shap_data_r,
|
| 851 |
+
base_value_na, base_value_r,
|
| 852 |
+
shap_error_na, shap_error_r):
|
| 853 |
+
"""Generate additional Plotly JSON for various XAI visualizations, including
|
| 854 |
+
waterfall, heatmap, feature comparison, intervention impact, and network graph."""
|
| 855 |
+
print("Generating additional visualizations...")
|
| 856 |
+
visualizations = {}
|
| 857 |
+
plot_bgcolor = 'rgba(0,0,0,0)'
|
| 858 |
+
paper_bgcolor = 'rgba(0,0,0,0)'
|
| 859 |
+
font_family = "Arial, sans-serif"
|
| 860 |
+
font_color = "#333"
|
| 861 |
+
|
| 862 |
+
# Determine validity
|
| 863 |
+
na_shap_valid = not shap_error_na and isinstance(shap_data_na, list) and shap_data_na and "error" not in shap_data_na[0]
|
| 864 |
+
r_shap_valid = not shap_error_r and isinstance(shap_data_r, list) and shap_data_r and "error" not in shap_data_r[0]
|
| 865 |
+
|
| 866 |
+
# --- 1. Waterfall Charts ---
|
| 867 |
+
def create_waterfall(shap_data, base_value, title_suffix, is_valid):
|
| 868 |
+
if not is_valid or base_value is None:
|
| 869 |
+
return create_placeholder_plot(f'Waterfall: {title_suffix}')
|
| 870 |
+
try:
|
| 871 |
+
valid = [item for item in shap_data if isinstance(item, dict) and "error" not in item]
|
| 872 |
+
top = sorted(valid, key=lambda x: abs(x['shap_value']), reverse=True)[:10]
|
| 873 |
+
values = [item['shap_value'] for item in top]
|
| 874 |
+
labels = [
|
| 875 |
+
f"{item['feature'][:25]}{'...' if len(item['feature'])>25 else ''} = {item['feature_value']}"
|
| 876 |
+
for item in top
|
| 877 |
+
]
|
| 878 |
+
measures = ["absolute"] + ["relative"]*len(top) + ["total"]
|
| 879 |
+
y = ["Average Model Output"] + labels + ["Final Prediction"]
|
| 880 |
+
x = [base_value] + values + [base_value + sum(values)]
|
| 881 |
+
text = [f"{v:+.3f}" if 0< i < len(x)-1 else f"{v:.3f}" for i, v in enumerate(x)]
|
| 882 |
+
|
| 883 |
+
fig = go.Figure(go.Waterfall(
|
| 884 |
+
orientation="h", measure=measures, y=y, x=x, text=text,
|
| 885 |
+
textposition="outside", base=0,
|
| 886 |
+
connector={"line":{"color":"rgb(150,150,150)", "width":1}},
|
| 887 |
+
increasing={"marker":{"color":"#dc3545","line":{"width":1,"color":"#dc3545"}}},
|
| 888 |
+
decreasing={"marker":{"color":"#28a745","line":{"width":1,"color":"#28a745"}}},
|
| 889 |
+
totals={"marker":{"color":"#007bff","line":{"width":1,"color":"#007bff"}}}
|
| 890 |
+
))
|
| 891 |
+
# autoscale axis
|
| 892 |
+
all_vals = x
|
| 893 |
+
mn, mx = min(all_vals), max(all_vals)
|
| 894 |
+
pad = (mx-mn)*0.15 if (mx-mn)>0.01 else 0.1
|
| 895 |
+
fig.update_layout(
|
| 896 |
+
title=f'How Factors Contribute to {title_suffix}',
|
| 897 |
+
showlegend=False,
|
| 898 |
+
height=max(450,40*len(y)),
|
| 899 |
+
margin=dict(t=50,l=250,r=50,b=50),
|
| 900 |
+
yaxis={'autorange':'reversed','automargin':True},
|
| 901 |
+
xaxis={'range':[mn-pad,mx+pad],'title':'Contribution to Score'},
|
| 902 |
+
plot_bgcolor=plot_bgcolor,
|
| 903 |
+
paper_bgcolor=paper_bgcolor,
|
| 904 |
+
font=dict(family=font_family,color=font_color),
|
| 905 |
+
autosize=True
|
| 906 |
+
)
|
| 907 |
+
return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
|
| 908 |
+
except Exception as e:
|
| 909 |
+
print(f"Error creating waterfall {title_suffix}: {e}")
|
| 910 |
+
return create_placeholder_plot(f'Waterfall: {title_suffix}')
|
| 911 |
+
|
| 912 |
+
visualizations['waterfall_na'] = create_waterfall(shap_data_na, base_value_na, 'Non-Adherence Risk', na_shap_valid)
|
| 913 |
+
visualizations['waterfall_r'] = create_waterfall(shap_data_r, base_value_r, 'Readmission Risk', r_shap_valid)
|
| 914 |
+
|
| 915 |
+
# --- 2. Combined SHAP for heatmap & bar chart ---
|
| 916 |
+
# collect valid shap items
|
| 917 |
+
valid_na = [i for i in shap_data_na if na_shap_valid and "error" not in i]
|
| 918 |
+
valid_r = [i for i in shap_data_r if r_shap_valid and "error" not in i]
|
| 919 |
+
combined_impact = {}
|
| 920 |
+
for i in (valid_na + valid_r):
|
| 921 |
+
f=i['feature']; combined_impact[f]=combined_impact.get(f,0)+abs(i['shap_value'])
|
| 922 |
+
top_feats = sorted(combined_impact, key=lambda f:combined_impact[f], reverse=True)[:10]
|
| 923 |
+
|
| 924 |
+
# 2a. Heatmap
|
| 925 |
+
try:
|
| 926 |
+
z=[]; text=[]
|
| 927 |
+
for f in top_feats:
|
| 928 |
+
na_v = next((i['shap_value'] for i in valid_na if i['feature']==f),0)
|
| 929 |
+
r_v = next((i['shap_value'] for i in valid_r if i['feature']==f),0)
|
| 930 |
+
z.append([na_v,r_v])
|
| 931 |
+
fv_na = next((i['feature_value'] for i in valid_na if i['feature']==f),None)
|
| 932 |
+
fv_r = next((i['feature_value'] for i in valid_r if i['feature']==f),None)
|
| 933 |
+
display = fv_na if fv_na is not None else fv_r
|
| 934 |
+
text.append([f"<b>{f}={display}</b><br>NA: {na_v:.3f}", f"<b>{f}={display}</b><br>R: {r_v:.3f}"])
|
| 935 |
+
hm = {
|
| 936 |
+
'data':[{
|
| 937 |
+
'type':'heatmap','z':z,'text':text,'hoverinfo':'text',
|
| 938 |
+
'x':['Non-Adherence','Readmission'],'y':top_feats,
|
| 939 |
+
'colorscale':'RdBu_r','zmid':0,'xgap':1,'ygap':1
|
| 940 |
+
}],
|
| 941 |
+
'layout':{
|
| 942 |
+
'title':'Top Factor Impact Comparison',
|
| 943 |
+
'height':max(400,40*len(top_feats)),
|
| 944 |
+
'margin':{'t':60,'l':250,'b':80,'r':50},
|
| 945 |
+
'yaxis':{'autorange':'reversed','automargin':True},
|
| 946 |
+
'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
|
| 947 |
+
'font':{'family':font_family,'color':font_color},
|
| 948 |
+
'autosize':True
|
| 949 |
+
}
|
| 950 |
+
}
|
| 951 |
+
visualizations['risk_heatmap']=json.dumps(hm, cls=plotly.utils.PlotlyJSONEncoder)
|
| 952 |
+
except Exception as e:
|
| 953 |
+
print(f"Error heatmap: {e}")
|
| 954 |
+
visualizations['risk_heatmap']=create_placeholder_plot('Risk Factor Heatmap')
|
| 955 |
+
|
| 956 |
+
# 2b. Bar chart
|
| 957 |
+
try:
|
| 958 |
+
bars = {
|
| 959 |
+
'data':[
|
| 960 |
+
{'type':'bar','x':top_feats,'y':[abs(next((i['shap_value'] for i in valid_na if i['feature']==f),0)) for f in top_feats],
|
| 961 |
+
'name':'NA Impact'},
|
| 962 |
+
{'type':'bar','x':top_feats,'y':[abs(next((i['shap_value'] for i in valid_r if i['feature']==f),0)) for f in top_feats],
|
| 963 |
+
'name':'Readmission Impact'}
|
| 964 |
+
],
|
| 965 |
+
'layout':{
|
| 966 |
+
'title':'Feature Importance (Absolute SHAP)',
|
| 967 |
+
'barmode':'group','bargap':0.15,'bargroupgap':0.1,
|
| 968 |
+
'height':450,'margin':{'t':50,'b':150,'l':60,'r':20},
|
| 969 |
+
'xaxis':{'tickangle':-45,'automargin':True},
|
| 970 |
+
'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
|
| 971 |
+
'font':{'family':font_family,'color':font_color},
|
| 972 |
+
'autosize':True
|
| 973 |
+
}
|
| 974 |
+
}
|
| 975 |
+
visualizations['feature_comparison']=json.dumps(bars, cls=plotly.utils.PlotlyJSONEncoder)
|
| 976 |
+
except Exception as e:
|
| 977 |
+
print(f"Error bar chart: {e}")
|
| 978 |
+
visualizations['feature_comparison']=create_placeholder_plot('Feature Importance Comparison')
|
| 979 |
+
|
| 980 |
+
# --- 3. Intervention Impact ---
|
| 981 |
+
try:
|
| 982 |
+
interventions=[]
|
| 983 |
+
mod_set={'Number of Medicines','Cost per Medicine (₹)','Days Medicine Lasts',
|
| 984 |
+
'Total Dosage per Day (mg)','Total Pills Given','Medicine Availability (0-1)',
|
| 985 |
+
'Took Medicine Day 1 (1=Yes)','Took Medicine Day 2 (1=Yes)',
|
| 986 |
+
'Took Medicine Day 3 (1=Yes)'}
|
| 987 |
+
map_label={
|
| 988 |
+
'Number of Medicines':'Reduce # Medicines',
|
| 989 |
+
'Cost per Medicine (₹)':'Reduce Med Cost',
|
| 990 |
+
'Days Medicine Lasts':'Optimize Refill',
|
| 991 |
+
'Total Dosage per Day (mg)':'Optimize Dosage',
|
| 992 |
+
'Total Pills Given':'Reduce Pill Burden',
|
| 993 |
+
'Medicine Availability (0-1)':'Improve Availability',
|
| 994 |
+
**{f:f.replace('Took Medicine','Ensure') for f in mod_set if f.startswith('Took Medicine')}
|
| 995 |
+
}
|
| 996 |
+
thresh=0.015
|
| 997 |
+
for f in mod_set:
|
| 998 |
+
na_v=abs(next((i['shap_value'] for i in valid_na if i['feature']==f),0))
|
| 999 |
+
r_v =abs(next((i['shap_value'] for i in valid_r if i['feature']==f),0))
|
| 1000 |
+
if na_v>thresh or r_v>thresh:
|
| 1001 |
+
interventions.append({'intervention':map_label.get(f,f),'na':na_v,'r':r_v})
|
| 1002 |
+
topi=sorted(interventions, key=lambda x:x['na']+x['r'],reverse=True)[:6]
|
| 1003 |
+
if topi:
|
| 1004 |
+
chart={'data':[
|
| 1005 |
+
{'type':'bar','orientation':'h','y':[i['intervention'] for i in topi],
|
| 1006 |
+
'x':[i['na'] for i in topi],'name':'NA Reduction','text':[f"{i['na']:.3f}" for i in topi],
|
| 1007 |
+
'textposition':'outside'},
|
| 1008 |
+
{'type':'bar','orientation':'h','y':[i['intervention'] for i in topi],
|
| 1009 |
+
'x':[i['r'] for i in topi],'name':'R Reduction','text':[f"{i['r']:.3f}" for i in topi],
|
| 1010 |
+
'textposition':'outside'}
|
| 1011 |
+
],
|
| 1012 |
+
'layout':{
|
| 1013 |
+
'title':'Top Potential Intervention Impacts',
|
| 1014 |
+
'barmode':'group','height':max(350,50*len(topi)),
|
| 1015 |
+
'margin':{'t':50,'l':200,'b':50,'r':50},
|
| 1016 |
+
'yaxis':{'autorange':'reversed','automargin':True},
|
| 1017 |
+
'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
|
| 1018 |
+
'font':{'family':font_family,'color':font_color},
|
| 1019 |
+
'autosize':True
|
| 1020 |
+
}}
|
| 1021 |
+
visualizations['intervention_impact']=json.dumps(chart, cls=plotly.utils.PlotlyJSONEncoder)
|
| 1022 |
+
else:
|
| 1023 |
+
visualizations['intervention_impact']=create_placeholder_plot(
|
| 1024 |
+
'Potential Intervention Impact',
|
| 1025 |
+
message="No significant interventions identified."
|
| 1026 |
+
)
|
| 1027 |
+
except Exception as e:
|
| 1028 |
+
print(f"Error interventions: {e}")
|
| 1029 |
+
visualizations['intervention_impact']=create_placeholder_plot('Potential Intervention Impact')
|
| 1030 |
+
|
| 1031 |
+
# --- 4. Network Graph ---
|
| 1032 |
+
try:
|
| 1033 |
+
import networkx as nx
|
| 1034 |
+
G=nx.Graph()
|
| 1035 |
+
# Use combined_impact from above
|
| 1036 |
+
N=8
|
| 1037 |
+
topN=sorted(combined_impact, key=lambda f:combined_impact[f], reverse=True)[:N]
|
| 1038 |
+
for f in topN:
|
| 1039 |
+
G.add_node(f, size=combined_impact[f])
|
| 1040 |
+
for i,a in enumerate(topN):
|
| 1041 |
+
for b in topN[i+1:]:
|
| 1042 |
+
w=combined_impact[a]+combined_impact[b]
|
| 1043 |
+
G.add_edge(a,b,weight=w)
|
| 1044 |
+
pos=nx.spring_layout(G,k=0.5,iterations=50,seed=42)
|
| 1045 |
+
ex,ey=[],[]
|
| 1046 |
+
for u,v in G.edges():
|
| 1047 |
+
x0,y0=pos[u]; x1,y1=pos[v]
|
| 1048 |
+
ex+=[x0,x1,None]; ey+=[y0,y1,None]
|
| 1049 |
+
nx_,ny_=[],[]
|
| 1050 |
+
ns=[G.nodes[n]['size']*50 for n in G.nodes()]
|
| 1051 |
+
for n in G.nodes():
|
| 1052 |
+
x,y=pos[n]
|
| 1053 |
+
nx_.append(x); ny_.append(y)
|
| 1054 |
+
net={
|
| 1055 |
+
'data':[
|
| 1056 |
+
{'type':'scatter','x':ex,'y':ey,'mode':'lines','line':{'width':1,'color':'#888'},'hoverinfo':'none'},
|
| 1057 |
+
{'type':'scatter','x':nx_,'y':ny_,'mode':'markers+text',
|
| 1058 |
+
'marker':{'size':ns,'color':'#1f77b4','opacity':0.8},
|
| 1059 |
+
'text':list(G.nodes()),'textposition':'top center','hoverinfo':'text'}
|
| 1060 |
+
],
|
| 1061 |
+
'layout':{
|
| 1062 |
+
'title':'Risk Factor Network',
|
| 1063 |
+
'showlegend':False,
|
| 1064 |
+
'xaxis':{'visible':False},'yaxis':{'visible':False},
|
| 1065 |
+
'plot_bgcolor':paper_bgcolor,'paper_bgcolor':paper_bgcolor,
|
| 1066 |
+
'margin':{'l':20,'r':20,'t':40,'b':20},'autosize':True
|
| 1067 |
+
}
|
| 1068 |
+
}
|
| 1069 |
+
visualizations['network_graph']=json.dumps(net, cls=plotly.utils.PlotlyJSONEncoder)
|
| 1070 |
+
except Exception as e:
|
| 1071 |
+
print(f"Error creating network graph: {e}")
|
| 1072 |
+
visualizations['network_graph']=create_placeholder_plot(
|
| 1073 |
+
'Risk Factor Network',
|
| 1074 |
+
message="Network data unavailable."
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
print("Finished generating additional visualizations.")
|
| 1078 |
+
return visualizations
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
# --- Main Execution ---
|
| 1082 |
+
if __name__ == '__main__':
    # Script entry point: the server is started only when this file is run
    # directly, not when it is imported as a module.
    print("Starting Flask application...")
    # The listening port is read from the PORT environment variable and
    # falls back to 5000 when that variable is not set.
    port = int(os.getenv("PORT", 5000))
    # Keep debug=False when deploying; debug=True is for local development ONLY.
    app.run(host='0.0.0.0', port=port, debug=False)
|