# Import packages
import os
import joblib
import numpy as np
import pandas as pd
import requests
import streamlit as st
import shap
import tempfile
import streamlit.components.v1 as components
import scipy.sparse as sp
# --- API endpoint and model feature configuration ---
# Prediction service endpoint, taken from the environment (None when unset).
API_URL = os.environ.get("API_URL")

# Continuous inputs expected by the model.
NUMERIC_FEATURES = [
    "age",
    "alcohol_consumption_per_week",
    "physical_activity_minutes_per_week",
    "diet_score",
    "bmi",
    "cholesterol_total",
    "insulin_level",
    "map",
    "glucose_fasting",
]

# Label-valued inputs expected by the model.
CATEGORICAL_FEATURES = [
    "gender",
    "ethnicity",
    "education_level",
    "income_level",
    "employment_status",
    "smoking_status",
    "family_history_diabetes",
    "hypertension_history",
    "cardiovascular_history",
]

# Full input column list: numeric columns first, then categorical ones.
ALL_FEATURES = [*NUMERIC_FEATURES, *CATEGORICAL_FEATURES]

st.set_page_config(layout="wide", page_title="Diabetes Predictions")
# Function for SHAP plot
def st_shap(plot, height=None):
    """Render a SHAP JavaScript plot in Streamlit on a white, rounded card.

    The plot is written to a temporary standalone HTML file with
    ``shap.save_html``, read back, wrapped in a styled ``<div>`` and embedded
    through ``components.html``.

    Args:
        plot: Plot object returned by ``shap.force_plot(..., matplotlib=False)``.
        height: Height of the embedded frame in pixels; 500 is used when falsy.
    """
    # Only reserve a unique file name here and close our handle immediately:
    # on Windows a NamedTemporaryFile cannot be reopened by name while open,
    # and shap.save_html reopens it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmpfile:
        tmp_path = tmpfile.name
    try:
        shap.save_html(tmp_path, plot)
        # Read explicitly as UTF-8 instead of the platform default, and close
        # the handle. NOTE(review): assumes shap.save_html writes UTF-8 —
        # confirm for the pinned shap version.
        with open(tmp_path, "r", encoding="utf-8") as fh:
            html = fh.read()
    finally:
        # Remove the temp file even when saving or reading fails.
        os.unlink(tmp_path)

    # HTML style: white card around the SHAP output
    styled_html = f"""
    <div style="
        background-color: white;
        padding: 20px;
        border-radius: 12px;
        box-shadow: 0 0 10px rgba(0,0,0,0.05);
    ">
        {html}
    </div>
    """
    components.html(styled_html, height=height or 500, width=1000, scrolling=True)
# Function to create synthetic data
def synth_data(n_policyholders=100, seed=42):
    """Generate a reproducible synthetic cohort of policyholders.

    Every column, numeric and categorical, is drawn from one
    ``numpy.random.Generator`` seeded with ``seed``. The same
    ``(n_policyholders, seed)`` pair therefore always yields the same frame.

    Args:
        n_policyholders: Number of rows to generate.
        seed: Seed passed to ``np.random.default_rng``.

    Returns:
        DataFrame whose first column is a 1-based ``policyholder_id``,
        followed by the 9 numeric and 9 categorical feature columns.
    """
    rng = np.random.default_rng(seed)
    n = n_policyholders
    # Dict literals evaluate left to right, so the numeric draws happen in the
    # same order as before (age, alcohol, ..., glucose) and keep their values.
    df = pd.DataFrame({
        # Numeric features, clipped to plausible ranges
        "age": rng.normal(50, 15, n).clip(18, 90),
        "alcohol_consumption_per_week": rng.gamma(2, 3, n).clip(0, 40),
        "physical_activity_minutes_per_week": rng.normal(150, 60, n).clip(0, 600),
        "diet_score": rng.uniform(1, 10, n),
        "bmi": rng.normal(27, 5, n).clip(15, 50),
        "cholesterol_total": rng.normal(200, 40, n).clip(100, 400),
        "insulin_level": rng.normal(10, 5, n).clip(2, 40),
        "map": rng.normal(95, 10, n).clip(70, 130),
        "glucose_fasting": rng.normal(100, 25, n).clip(60, 250),
        # Categorical features. They are drawn from the seeded `rng` rather than
        # the global np.random state, so `seed` controls them as well.
        "gender": rng.choice(["Male", "Female", "Other"], n, p=[0.48, 0.5, 0.02]),
        "ethnicity": rng.choice(["White", "Black", "Asian", "Hispanic", "Other"], n),
        "education_level": rng.choice(
            ["High School", "Bachelor", "Master", "PhD"], n, p=[0.4, 0.35, 0.2, 0.05]
        ),
        "income_level": rng.choice(["Low", "Middle", "High"], n, p=[0.3, 0.5, 0.2]),
        "employment_status": rng.choice(
            ["Employed", "Unemployed", "Retired", "Student"], n, p=[0.6, 0.1, 0.25, 0.05]
        ),
        "smoking_status": rng.choice(["Never", "Former", "Current"], n, p=[0.6, 0.25, 0.15]),
        "family_history_diabetes": rng.choice(["Yes", "No"], n, p=[0.35, 0.65]),
        "hypertension_history": rng.choice(["Yes", "No"], n, p=[0.3, 0.7]),
        "cardiovascular_history": rng.choice(["Yes", "No"], n, p=[0.2, 0.8]),
    })
    df.insert(0, "policyholder_id", range(1, len(df) + 1))
    return df
# Sets up function to call our API
def call_api(df: pd.DataFrame):
    """Send the model features of ``df`` to the prediction API and return its probabilities.

    Args:
        df: Frame containing at least the ``ALL_FEATURES`` columns; one JSON
            record per row is sent under the ``"data"`` key.

    Returns:
        ``numpy`` array built from the ``"probabilities"`` field of the JSON
        response, with one entry per row of ``df``.

    Raises:
        requests.exceptions.MissingSchema: ``API_URL`` is not configured.
        requests.HTTPError: the API answered with a non-2xx status.
        requests.RequestException: connection problems or the 60 s timeout.
        KeyError: the response JSON has no ``"probabilities"`` field.
        ValueError: the body is not valid JSON, or the number of probabilities
            does not match the number of rows in ``df``.
    """
    if not API_URL:
        # Without this check requests fails with the cryptic "Invalid URL 'None'".
        # MissingSchema is the exception requests itself raises in that case, so
        # callers catching RequestException/ValueError keep working.
        raise requests.exceptions.MissingSchema(
            "API_URL environment variable is not set; cannot reach the prediction API."
        )
    payload = {"data": df[ALL_FEATURES].to_dict(orient="records")}
    response = requests.post(API_URL, json=payload, timeout=60)
    response.raise_for_status()
    probabilities = np.asarray(response.json()["probabilities"])
    # Fail here with a clear message instead of a confusing length mismatch
    # when the caller assigns the result to a DataFrame column.
    if probabilities.shape[:1] != (len(df),):
        raise ValueError(
            f"Expected {len(df)} probabilities from the API, "
            f"got an array of shape {probabilities.shape}."
        )
    return probabilities
# Get model for SHAP
MODEL_PATH = os.path.join("src", "diabetes_prediction_model_20251007.pkl")


@st.cache_resource
def _load_pipeline(path):
    """Load the fitted model pipeline stored at ``path``.

    Streamlit re-executes this script on every widget interaction. Caching the
    loaded object means the file is read from disk once per server process,
    not on every rerun. This file only reads the object (``named_steps`` and
    ``transform``), so one shared instance is enough.
    """
    # joblib.load unpickles arbitrary objects: only point MODEL_PATH at trusted files.
    return joblib.load(path)


pipe = _load_pipeline(MODEL_PATH)
# Initialise per-session state that must survive Streamlit reruns:
#   "df"          - the latest DataFrame of predictions (None until generated)
#   "selected_id" - the policyholder chosen for the SHAP explanation
for _state_key in ("df", "selected_id"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Sidebar with filters
with st.sidebar:
    # Simulation choices: number of policyholders (10-200), random seed and threshold (0-1)
    st.header("Simulation")
    n_policyholders = st.slider("Synthetic policyholders", 10, 200, 20, 10)
    seed = st.number_input("Random seed", 0, 99999, 42, 1)
    threshold = st.slider("Classification threshold", 0.0, 1.0, 0.24, 0.01)

    # Generate the synthetic data and ask the API for predicted risk; the risk
    # category is derived from the threshold.
    if st.button("Generate & Predict", use_container_width=True):
        df = synth_data(n_policyholders=n_policyholders, seed=int(seed))
        try:
            df["predicted_risk"] = call_api(df)
        except (requests.RequestException, KeyError, ValueError) as exc:
            # Network/HTTP failures, a missing API_URL or a malformed response:
            # report it in the UI and keep any previous results, instead of
            # replacing the page with a raw traceback.
            st.error(f"Prediction request failed: {exc}")
        else:
            df["risk_category"] = np.where(df["predicted_risk"] >= threshold, "High-risk", "Low-risk")
            st.session_state.df = df
            # reset selection to first id for consistency
            st.session_state.selected_id = int(df["policyholder_id"].iloc[0])
            st.success("Predictions received from API")

    # If a DataFrame is loaded, show a dropdown of all policyholder IDs
    if st.session_state.df is not None:
        policyholder_ids = st.session_state.df["policyholder_id"].tolist()
        current_id = st.session_state.selected_id
        st.session_state.selected_id = st.selectbox(
            "Select Policyholder ID:",
            policyholder_ids,
            # Keep the current selection if it still exists, else fall back to the first ID
            index=policyholder_ids.index(current_id) if current_id in policyholder_ids else 0,
            key="selected_id_widget",
        )
# Render the SHAP force plot and a short text explanation for one policyholder
def _render_shap_explanation(df, row):
    """Explain the model's prediction for a single policyholder with SHAP.

    Args:
        df: Results DataFrame for the whole cohort; its ``ALL_FEATURES`` columns
            are transformed and passed to ``shap.LinearExplainer`` as background data.
        row: One-row DataFrame holding the ``ALL_FEATURES`` columns of the
            selected policyholder.
    """
    preprocessor = pipe.named_steps["preprocessor"]
    model = pipe.named_steps["classifier"]
    X_all = preprocessor.transform(df[ALL_FEATURES])
    X_row = preprocessor.transform(row)
    # The preprocessor may return scipy sparse matrices. force_plot and the
    # per-feature DataFrame below need dense rows.
    if sp.issparse(X_all):
        X_all = X_all.toarray()
    if sp.issparse(X_row):
        X_row = X_row.toarray()
    feature_names = preprocessor.get_feature_names_out(ALL_FEATURES)
    explainer = shap.LinearExplainer(model, X_all)
    shap_values = explainer.shap_values(X_row)

    # Force plot
    shap_plot = shap.force_plot(
        explainer.expected_value,
        shap_values[0],
        X_row[0],
        matplotlib=False,
        feature_names=feature_names,
    )
    st_shap(shap_plot, height=250)

    # Rank features by absolute SHAP contribution for the text explanation
    shap_df = pd.DataFrame({"Feature": feature_names, "SHAP value": shap_values[0]})
    shap_df = shap_df.reindex(shap_df["SHAP value"].abs().sort_values(ascending=False).index)
    top_features = shap_df.head(5)  # the 5 most important features
    increase = top_features[top_features["SHAP value"] > 0]
    decrease = top_features[top_features["SHAP value"] < 0]
    st.markdown("### Why this prediction?")
    st.info(
        f"The model predicts a higher diabetes risk mainly due to **{', '.join(increase['Feature']) or 'none'}**, "
        f"while lower risk is influenced by **{', '.join(decrease['Feature']) or 'none'}**.")


# Main outputs of model shown in DF
st.title("Predictions of high risk diabetes")
# API_URL comes from os.getenv and is None when the variable is unset.
# Concatenating None would raise TypeError and abort the page before anything renders.
st.caption("API: " + (API_URL or "not configured (set the API_URL environment variable)"))
if st.session_state.df is None:
    st.info("Generate predictions first to enable results and SHAP explanation.")
else:
    # Re-derive the category from the *current* slider value, so moving the
    # threshold after generation relabels the table instead of showing stale
    # categories. Work on a copy so the stored predictions stay untouched.
    df = st.session_state.df.copy()
    df["risk_category"] = np.where(df["predicted_risk"] >= threshold, "High-risk", "Low-risk")
    st.write("### Results", df[["policyholder_id", "predicted_risk", "risk_category"]])
    st.metric("Average predicted diabetes risk", f"{df['predicted_risk'].mean():.2%}")
    st.header("Explain Prediction for a Policyholder")
    # Set up SHAP explanation for a chosen policyholder
    selected_id = st.session_state.selected_id
    if selected_id is not None:
        row = df.loc[df["policyholder_id"] == selected_id, ALL_FEATURES]
        if row.empty:
            # Without this guard X_row[0] would raise IndexError on an empty selection.
            st.warning(f"Policyholder {selected_id} is not in the current results.")
        else:
            _render_shap_explanation(df, row)