# Author: Signe22
# Commit 54590c3 (verified): "Update src/streamlit_app.py"
# Import packages
import io
import os
import tempfile

import joblib
import numpy as np
import pandas as pd
import requests
import scipy.sparse as sp
import shap
import streamlit as st
import streamlit.components.v1 as components
# Prediction API endpoint (read from the environment; None when unset) and
# the ordered input schema the model expects.
API_URL = os.environ.get('API_URL')

NUMERIC_FEATURES = [
    'age',
    'alcohol_consumption_per_week',
    'physical_activity_minutes_per_week',
    'diet_score',
    'bmi',
    'cholesterol_total',
    'insulin_level',
    'map',
    'glucose_fasting',
]

CATEGORICAL_FEATURES = [
    'gender',
    'ethnicity',
    'education_level',
    'income_level',
    'employment_status',
    'smoking_status',
    'family_history_diabetes',
    'hypertension_history',
    'cardiovascular_history',
]

# Numeric columns first, then categorical — this order is sent to the API
# and fed to the preprocessor.
ALL_FEATURES = [*NUMERIC_FEATURES, *CATEGORICAL_FEATURES]
# Configure the browser tab title and use the full-width page layout.
# NOTE(review): keep this ahead of any other st.* call — Streamlit has
# historically required set_page_config to be the first command; confirm
# against the Streamlit version in use.
st.set_page_config(page_title="Diabetes Predictions", layout="wide")
# Function for SHAP plot
def st_shap(plot, height=None, width=1000):
    """Render a SHAP force plot inside Streamlit on a white, rounded card.

    The plot is serialized with ``shap.save_html`` (which also embeds the SHAP
    JavaScript bundle) into an in-memory buffer. It is then wrapped in a styled
    ``<div>`` and embedded with ``streamlit.components.v1.html``.

    Args:
        plot: Visualizer returned by ``shap.force_plot(..., matplotlib=False)``.
        height: Component height in pixels; falls back to 500 when falsy.
        width: Component width in pixels (default 1000).
    """
    # Serialize to memory rather than a temp file. This avoids a leaked file
    # handle and the Windows "file in use" error from reopening a still-open
    # NamedTemporaryFile. It also avoids decoding shap's UTF-8 output with the
    # platform-default encoding.
    buffer = io.StringIO()
    shap.save_html(buffer, plot)
    html = buffer.getvalue()
    # HTML style: white card so the plot is readable on dark themes
    styled_html = f"""
    <div style="
        background-color: white;
        padding: 20px;
        border-radius: 12px;
        box-shadow: 0 0 10px rgba(0,0,0,0.05);
    ">
        {html}
    </div>
    """
    components.html(styled_html, height=height or 500, width=width, scrolling=True)
# Function to create synthetic data
def synth_data(n_policyholders=100, seed=42):
    """Generate a reproducible synthetic cohort of policyholders.

    Every column, numeric and categorical, is drawn from a single
    ``numpy.random.Generator`` seeded with *seed*. The same
    ``(n_policyholders, seed)`` pair therefore always yields the same frame.

    Args:
        n_policyholders: Number of rows to generate.
        seed: Seed passed to ``numpy.random.default_rng``.

    Returns:
        pandas.DataFrame whose first column is a 1-based ``policyholder_id``,
        followed by 9 numeric and 9 categorical feature columns.
    """
    rng = np.random.default_rng(seed)
    n = n_policyholders

    # Numeric features: random draws clipped to fixed bounds
    age = rng.normal(50, 15, n).clip(18, 90)
    alcohol = rng.gamma(2, 3, n).clip(0, 40)
    activity = rng.normal(150, 60, n).clip(0, 600)
    diet = rng.uniform(1, 10, n)
    bmi = rng.normal(27, 5, n).clip(15, 50)
    chol = rng.normal(200, 40, n).clip(100, 400)
    insulin = rng.normal(10, 5, n).clip(2, 40)
    map_ = rng.normal(95, 10, n).clip(70, 130)
    glucose = rng.normal(100, 25, n).clip(60, 250)

    # Categorical features use the same seeded generator. Drawing them from the
    # global np.random state would make these columns ignore `seed`.
    gender = rng.choice(['Male', 'Female', 'Other'], n, p=[0.48, 0.5, 0.02])
    ethnicity = rng.choice(['White', 'Black', 'Asian', 'Hispanic', 'Other'], n)
    edu = rng.choice(['High School', 'Bachelor', 'Master', 'PhD'], n, p=[0.4, 0.35, 0.2, 0.05])
    income = rng.choice(['Low', 'Middle', 'High'], n, p=[0.3, 0.5, 0.2])
    emp = rng.choice(['Employed', 'Unemployed', 'Retired', 'Student'], n, p=[0.6, 0.1, 0.25, 0.05])
    smoke = rng.choice(['Never', 'Former', 'Current'], n, p=[0.6, 0.25, 0.15])
    fam = rng.choice(['Yes', 'No'], n, p=[0.35, 0.65])
    hyper = rng.choice(['Yes', 'No'], n, p=[0.3, 0.7])
    cardio = rng.choice(['Yes', 'No'], n, p=[0.2, 0.8])

    df = pd.DataFrame({
        'age': age,
        'alcohol_consumption_per_week': alcohol,
        'physical_activity_minutes_per_week': activity,
        'diet_score': diet,
        'bmi': bmi,
        'cholesterol_total': chol,
        'insulin_level': insulin,
        'map': map_,
        'glucose_fasting': glucose,
        'gender': gender,
        'ethnicity': ethnicity,
        'education_level': edu,
        'income_level': income,
        'employment_status': emp,
        'smoking_status': smoke,
        'family_history_diabetes': fam,
        'hypertension_history': hyper,
        'cardiovascular_history': cardio,
    })
    df.insert(0, 'policyholder_id', range(1, len(df) + 1))
    return df
# Sets up function to call our API
def call_api(df: pd.DataFrame, timeout=60):
    """POST the model features of *df* to the prediction API.

    Args:
        df: Frame containing at least the columns in ``ALL_FEATURES``.
        timeout: Request timeout in seconds (default 60).

    Returns:
        numpy.ndarray with one predicted probability per row of *df*.

    Raises:
        RuntimeError: if the ``API_URL`` environment variable is not set.
        ValueError: if the API returns a different number of probabilities
            than rows sent.
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
    """
    if not API_URL:
        # Fail with a clear message instead of requests' MissingSchema error.
        raise RuntimeError("API_URL environment variable is not set")
    payload = {'data': df[ALL_FEATURES].to_dict(orient="records")}
    response = requests.post(API_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    probabilities = np.array(response.json()["probabilities"])
    # The caller assigns this straight into a DataFrame column, so the
    # array must be 1-D with one value per row.
    if probabilities.shape != (len(df),):
        raise ValueError(
            f"API returned probabilities of shape {probabilities.shape}, "
            f"expected ({len(df)},)"
        )
    return probabilities
# Get model for SHAP
MODEL_PATH = os.path.join("src", "diabetes_prediction_model_20251007.pkl")


@st.cache_resource
def _load_pipeline(path):
    """Load the fitted model pipeline from *path* once per server process.

    Streamlit re-executes this script on every widget interaction.
    ``st.cache_resource`` keeps the unpickled object in memory, so the model
    is not re-read from disk on each rerun.

    NOTE: joblib files are pickles; only load model files from a trusted source.
    """
    return joblib.load(path)


pipe = _load_pipeline(MODEL_PATH)
# Seed per-session state with None on the first run. Streamlit keeps
# st.session_state across reruns of the same browser session.
for _state_key in ("df", "selected_id"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Sidebar with filters
with st.sidebar:
    # Simulation controls: cohort size (10-200), random seed and the
    # probability threshold (0-1) at or above which a row is "High-risk".
    st.header("Simulation")
    n_policyholders = st.slider("Synthetic policyholders", 10, 200, 20, 10)
    seed = st.number_input("Random seed", 0, 99999, 42, 1)
    threshold = st.slider("Classification threshold", 0.0, 1.0, 0.24, 0.01)
    # Generate a fresh cohort, score it through the API and store it in session state
    if st.button("Generate & Predict", use_container_width=True):
        df = synth_data(n_policyholders=n_policyholders, seed=int(seed))
        try:
            probs = call_api(df)
        except (requests.RequestException, RuntimeError, ValueError, KeyError) as err:
            # Report the failure in the UI instead of aborting the script with
            # a traceback. Previously generated results stay in session state.
            st.error(f"Prediction request failed: {err}")
        else:
            df["predicted_risk"] = probs
            df['risk_category'] = np.where(df['predicted_risk'] >= threshold, "High-risk", "Low-risk")
            st.session_state.df = df
            # reset selection to first id for consistency
            st.session_state.selected_id = int(df["policyholder_id"].iloc[0])
            st.success("Predictions received from API")
    # If a dataframe is loaded, show a dropdown of all policyholder IDs
    if st.session_state.df is not None:
        policyholder_ids = st.session_state.df["policyholder_id"].tolist()
        current_id = st.session_state.selected_id
        st.session_state.selected_id = st.selectbox(
            "Select Policyholder ID:",
            policyholder_ids,
            index=policyholder_ids.index(current_id) if current_id in policyholder_ids else 0,
            key="selected_id_widget",
        )
# Main outputs of model shown in DF
st.title("Predictions of high risk diabetes")
# API_URL is None when the environment variable is unset; string
# concatenation with None would raise TypeError.
st.caption(f"API: {API_URL or 'not configured (set the API_URL environment variable)'}")
if st.session_state.df is None:
    st.info("Generate predictions first to enable results and SHAP explanation.")
else:
    # Re-derive categories from the current slider value, so moving the
    # threshold after prediction updates the table. Work on a copy so the
    # stored predictions are left untouched.
    df = st.session_state.df.copy()
    df["risk_category"] = np.where(df["predicted_risk"] >= threshold, "High-risk", "Low-risk")
    st.write("### Results", df[["policyholder_id", "predicted_risk", "risk_category"]])
    st.metric("Average predicted diabetes risk", f"{df['predicted_risk'].mean():.2%}")
    st.header("Explain Prediction for a Policyholder")
    # Set up SHAP explanation for the chosen policyholder
    selected_id = st.session_state.selected_id
    row = df.loc[df["policyholder_id"] == selected_id, ALL_FEATURES]
    if selected_id is None or row.empty:
        st.info("Select a policyholder ID in the sidebar to see the explanation.")
    else:
        preprocessor = pipe.named_steps["preprocessor"]
        model = pipe.named_steps["classifier"]
        X_all = preprocessor.transform(df[ALL_FEATURES])
        X_row = preprocessor.transform(row)
        feature_names = preprocessor.get_feature_names_out(ALL_FEATURES)
        explainer = shap.LinearExplainer(model, X_all)
        shap_values = explainer.shap_values(X_row)
        # The preprocessor may emit a scipy sparse matrix. Indexing that gives
        # a 1xN sparse matrix, not a 1-D array, so densify the single row
        # before handing it to force_plot.
        row_features = X_row.toarray()[0] if sp.issparse(X_row) else X_row[0]
        # Force plot
        shap_plot = shap.force_plot(
            explainer.expected_value,
            shap_values[0],
            row_features,
            matplotlib=False,
            feature_names=feature_names,
        )
        st_shap(shap_plot, height=250)
        # SHAP values ranked by absolute contribution for the text explanation
        shap_df = pd.DataFrame({"Feature": feature_names, "SHAP value": shap_values[0]})
        shap_df = shap_df.reindex(shap_df["SHAP value"].abs().sort_values(ascending=False).index)
        # Text explanation of the 5 most influential features
        top_features = shap_df.head(5)
        increase = top_features[top_features["SHAP value"] > 0]
        decrease = top_features[top_features["SHAP value"] < 0]
        st.markdown("### Why this prediction?")
        st.info(
            f"The model predicts a higher diabetes risk mainly due to **{', '.join(increase['Feature']) or 'none'}**, "
            f"while lower risk is influenced by **{', '.join(decrease['Feature']) or 'none'}**.")