File size: 7,637 Bytes
4c086ce
8116ac7
 
 
 
 
 
 
 
 
 
 
4c086ce
020f1ee
8116ac7
 
 
 
 
 
 
 
 
 
4c086ce
8116ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54590c3
8116ac7
 
 
 
 
6e20a97
fabcece
 
8116ac7
54590c3
8116ac7
 
 
 
 
 
 
54590c3
8116ac7
2222e30
8116ac7
 
54590c3
8116ac7
 
 
 
 
 
 
 
 
54590c3
8116ac7
 
 
 
 
 
 
 
 
 
 
 
54590c3
8116ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c086ce
a64fe69
8116ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Import packages
import os
import joblib
import numpy as np
import pandas as pd
import requests
import streamlit as st
import shap
import tempfile
import streamlit.components.v1 as components
import scipy.sparse as sp 

# Setup API and features 
# Prediction API endpoint, taken from the environment (None when not set).
API_URL = os.environ.get("API_URL")

# Model input columns. ALL_FEATURES (numeric first, then categorical) is the
# column selection used for the API payload and for the SHAP preprocessor.
NUMERIC_FEATURES = [
    "age",
    "alcohol_consumption_per_week",
    "physical_activity_minutes_per_week",
    "diet_score",
    "bmi",
    "cholesterol_total",
    "insulin_level",
    "map",
    "glucose_fasting",
]
CATEGORICAL_FEATURES = [
    "gender",
    "ethnicity",
    "education_level",
    "income_level",
    "employment_status",
    "smoking_status",
    "family_history_diabetes",
    "hypertension_history",
    "cardiovascular_history",
]
ALL_FEATURES = [*NUMERIC_FEATURES, *CATEGORICAL_FEATURES]

# Page-level settings: wide layout and browser-tab title.
st.set_page_config(layout="wide", page_title="Diabetes Predictions")

# Function for SHAP plot
def st_shap(plot, height=None):
    """Render a SHAP plot in Streamlit inside a white, padded, rounded card.

    The plot is written to a temporary HTML file with ``shap.save_html``, read
    back as text, wrapped in a styled ``<div>`` and embedded with
    ``components.html`` (fixed width 1000px, scrollable).

    Args:
        plot: A SHAP visualizer accepted by ``shap.save_html`` (e.g. the
            object returned by ``shap.force_plot(..., matplotlib=False)``).
        height: Iframe height in pixels; 500 is used when None (or 0).
    """
    # Create the temp file and close our handle right away so shap can reopen
    # it by name (an open NamedTemporaryFile cannot be reopened on Windows).
    fd, tmp_path = tempfile.mkstemp(suffix=".html")
    os.close(fd)
    try:
        shap.save_html(tmp_path, plot)
        # Read with an explicit encoding and a context manager so the handle
        # is closed and the text is not decoded with the platform default.
        with open(tmp_path, "r", encoding="utf-8") as fh:
            html = fh.read()
    finally:
        # Remove the temp file even if saving or reading failed.
        os.unlink(tmp_path)

    # HTML style
    styled_html = f"""
    <div style="
        background-color: white;
        padding: 20px;
        border-radius: 12px;
        box-shadow: 0 0 10px rgba(0,0,0,0.05);
    ">
        {html}
    </div>
    """

    components.html(styled_html, height=height or 500, width=1000, scrolling=True)

# Function to create synthetic data
def synth_data(n_policyholders=100, seed=42):
    """Generate a reproducible synthetic cohort of policyholders.

    Every column, numeric and categorical, is drawn from a single
    ``numpy.random.Generator`` seeded with *seed*, so identical arguments
    always yield an identical DataFrame.

    Args:
        n_policyholders: Number of rows to generate.
        seed: Seed passed to ``numpy.random.default_rng``.

    Returns:
        DataFrame with a 1-based ``policyholder_id`` column followed by the
        nine numeric and nine categorical model feature columns.
    """
    rng = np.random.default_rng(seed)
    n = n_policyholders

    # Numeric features: parametric draws clipped to plausible ranges.
    age = rng.normal(50, 15, n).clip(18, 90)
    alcohol = rng.gamma(2, 3, n).clip(0, 40)
    activity = rng.normal(150, 60, n).clip(0, 600)
    diet = rng.uniform(1, 10, n)
    bmi = rng.normal(27, 5, n).clip(15, 50)
    chol = rng.normal(200, 40, n).clip(100, 400)
    insulin = rng.normal(10, 5, n).clip(2, 40)
    map_ = rng.normal(95, 10, n).clip(70, 130)
    glucose = rng.normal(100, 25, n).clip(60, 250)

    # Categorical features. Draw from the seeded generator rather than the
    # global np.random state; otherwise these columns change on every call
    # even when the same seed is supplied.
    gender = rng.choice(['Male', 'Female', 'Other'], n, p=[0.48, 0.5, 0.02])
    ethnicity = rng.choice(['White', 'Black', 'Asian', 'Hispanic', 'Other'], n)
    edu = rng.choice(['High School', 'Bachelor', 'Master', 'PhD'], n, p=[0.4, 0.35, 0.2, 0.05])
    income = rng.choice(['Low', 'Middle', 'High'], n, p=[0.3, 0.5, 0.2])
    emp = rng.choice(['Employed', 'Unemployed', 'Retired', 'Student'], n, p=[0.6, 0.1, 0.25, 0.05])
    smoke = rng.choice(['Never', 'Former', 'Current'], n, p=[0.6, 0.25, 0.15])
    fam = rng.choice(['Yes', 'No'], n, p=[0.35, 0.65])
    hyper = rng.choice(['Yes', 'No'], n, p=[0.3, 0.7])
    cardio = rng.choice(['Yes', 'No'], n, p=[0.2, 0.8])

    df = pd.DataFrame({
        'age': age,
        'alcohol_consumption_per_week': alcohol,
        'physical_activity_minutes_per_week': activity,
        'diet_score': diet,
        'bmi': bmi,
        'cholesterol_total': chol,
        'insulin_level': insulin,
        'map': map_,
        'glucose_fasting': glucose,
        'gender': gender,
        'ethnicity': ethnicity,
        'education_level': edu,
        'income_level': income,
        'employment_status': emp,
        'smoking_status': smoke,
        'family_history_diabetes': fam,
        'hypertension_history': hyper,
        'cardiovascular_history': cardio,
    })
    df.insert(0, 'policyholder_id', range(1, len(df) + 1))
    return df

# Sets up function to call our API
def call_api(df: pd.DataFrame):
    """Score *df* through the remote prediction API.

    Posts the ``ALL_FEATURES`` columns as a list of row records under the
    ``"data"`` key and reads back the ``"probabilities"`` list.

    Args:
        df: Frame containing (at least) every column in ``ALL_FEATURES``.

    Returns:
        ``numpy`` array of the probabilities returned by the API, one per row.

    Raises:
        RuntimeError: ``API_URL`` is not configured.
        requests.HTTPError: the API answered with an error status.
        ValueError: the API returned a different number of probabilities
            than rows were sent.
    """
    if not API_URL:
        # Fail with an actionable message instead of requests' "Invalid URL 'None'".
        raise RuntimeError("API_URL environment variable is not set; cannot reach the prediction API.")
    payload = {"data": df[ALL_FEATURES].to_dict(orient="records")}
    response = requests.post(API_URL, json=payload, timeout=60)
    response.raise_for_status()
    probabilities = np.array(response.json()["probabilities"])
    # Callers assign the result as a column of df, so lengths must agree.
    if len(probabilities) != len(df):
        raise ValueError(
            f"API returned {len(probabilities)} probabilities for {len(df)} rows."
        )
    return probabilities

# Get model for SHAP 
MODEL_PATH = os.path.join("src", "diabetes_prediction_model_20251007.pkl")


@st.cache_resource
def _load_pipeline(path):
    """Load the fitted pipeline stored at *path*, cached across reruns and sessions.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model file would be deserialized on every click.
    NOTE: joblib.load unpickles the file; only load trusted model artifacts.
    """
    return joblib.load(path)


pipe = _load_pipeline(MODEL_PATH)

# Initiate state -> helps store information across reruns of the same session
# Seed the per-session keys once; session_state survives the script reruns
# Streamlit performs on every interaction, so existing values are kept.
for _state_key in ("df", "selected_id"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Sidebar with filters 
with st.sidebar:
    # Simulation controls: cohort size (10-200), RNG seed, and the probability
    # cut-off (0-1) at or above which a policyholder is labelled "High-risk".
    st.header("Simulation")
    n_policyholders = st.slider("Synthetic policyholders", 10, 200, 20, 10)
    seed = st.number_input("Random seed", 0, 99999, 42, 1)
    threshold = st.slider("Classification threshold", 0.0, 1.0, 0.24, 0.01)

    # Generate a cohort, score it through the API and store it in session state.
    if st.button("Generate & Predict", use_container_width=True):
        df = synth_data(n_policyholders=n_policyholders, seed=int(seed))
        try:
            probs = call_api(df)
        except (requests.RequestException, KeyError, ValueError, RuntimeError) as err:
            # Show API/network failures in the UI instead of a raw traceback;
            # previously stored results (if any) are left untouched.
            st.error(f"Prediction API request failed: {err}")
        else:
            df["predicted_risk"] = probs
            df["risk_category"] = np.where(df["predicted_risk"] >= threshold, "High-risk", "Low-risk")
            st.session_state.df = df
            # reset selection to first id for consistency
            st.session_state.selected_id = int(df["policyholder_id"].iloc[0])
            st.success("Predictions received from API")

    # Once results exist, offer a dropdown of all policyholder IDs, keeping
    # the current selection when it is still present.
    if st.session_state.df is not None:
        ids = st.session_state.df["policyholder_id"].tolist()
        current = st.session_state.selected_id
        st.session_state.selected_id = st.selectbox(
            "Select Policyholder ID:",
            ids,
            index=ids.index(current) if current in ids else 0,
            key="selected_id_widget",
        )

# Main outputs of model shown in DF
def _explain_policyholder(df, selected_id):
    """Render a SHAP force plot and a short text summary for one policyholder.

    Args:
        df: Results frame holding ``policyholder_id`` and the ``ALL_FEATURES`` columns.
        selected_id: ``policyholder_id`` of the row to explain.
    """
    row = df.loc[df["policyholder_id"] == selected_id, ALL_FEATURES]
    if row.empty:
        # Guard against a stale selection; transforming an empty frame would fail.
        st.warning(f"Policyholder {selected_id} is not in the current results.")
        return

    preprocessor = pipe.named_steps["preprocessor"]
    model = pipe.named_steps["classifier"]

    X_all = preprocessor.transform(df[ALL_FEATURES])
    X_row = preprocessor.transform(row)
    # The preprocessor may emit scipy sparse matrices (e.g. from one-hot
    # encoding); X_row[0] and force_plot below need dense arrays.
    if sp.issparse(X_all):
        X_all = X_all.toarray()
    if sp.issparse(X_row):
        X_row = X_row.toarray()
    feature_names = preprocessor.get_feature_names_out(ALL_FEATURES)

    # The whole current cohort serves as background data for the explainer.
    explainer = shap.LinearExplainer(model, X_all)
    shap_values = explainer.shap_values(X_row)

    # Force plot
    shap_plot = shap.force_plot(
        explainer.expected_value,
        shap_values[0],
        X_row[0],
        matplotlib=False,
        feature_names=feature_names,
    )
    st_shap(shap_plot, height=250)

    # Rank features by absolute SHAP contribution for the text explanation.
    shap_df = pd.DataFrame({"Feature": feature_names, "SHAP value": shap_values[0]})
    shap_df = shap_df.sort_values("SHAP value", key=lambda s: s.abs(), ascending=False)

    top_features = shap_df.head(5)  # the 5 most important features
    increase = top_features[top_features["SHAP value"] > 0]
    decrease = top_features[top_features["SHAP value"] < 0]

    st.markdown("### Why this prediction?")
    st.info(
        f"The model predicts a higher diabetes risk mainly due to **{', '.join(increase['Feature']) or 'none'}**, "
        f"while lower risk is influenced by **{', '.join(decrease['Feature']) or 'none'}**."
    )


st.title("Predictions of high risk diabetes")
# API_URL comes from the environment and may be None; "str" + None would raise
# TypeError and take down the whole page.
st.caption(f"API: {API_URL or '(API_URL not set)'}")

if st.session_state.df is None:
    st.info("Generate predictions first to enable results and SHAP explanation.")
else:
    # Copy so relabelling below does not mutate the stored results.
    df = st.session_state.df.copy()
    # Re-apply the current slider value: the threshold can be moved after the
    # predictions were generated, and the displayed labels must follow it.
    df["risk_category"] = np.where(df["predicted_risk"] >= threshold, "High-risk", "Low-risk")
    st.write("### Results", df[["policyholder_id", "predicted_risk", "risk_category"]])
    st.metric("Average predicted diabetes risk", f"{df['predicted_risk'].mean():.2%}")

    st.header("Explain Prediction for a Policyholder")
    # Set up SHAP explanation for a chosen policyholder
    if st.session_state.selected_id is not None:
        _explain_policyholder(df, st.session_state.selected_id)