Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ plt.rcParams["figure.dpi"] = 100
|
|
| 11 |
# =========================
|
| 12 |
# Load model (CatBoostClassifier saved via joblib)
|
| 13 |
# =========================
|
| 14 |
-
model = joblib.load("cat (
|
| 15 |
|
| 16 |
FEATURES = [
|
| 17 |
"Engagement",
|
|
@@ -37,6 +37,16 @@ CLUSTER_1 = {
|
|
| 37 |
"Engagement": 4.9324,
|
| 38 |
}
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
CLUSTER_3 = {
|
| 41 |
"Voice": 2.39,
|
| 42 |
"DecisionAutonomy": 3.55,
|
|
@@ -47,13 +57,31 @@ CLUSTER_3 = {
|
|
| 47 |
"Engagement": 3.3909,
|
| 48 |
}
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# =========================
|
| 54 |
# Helpers
|
| 55 |
# =========================
|
| 56 |
-
def
|
| 57 |
return max(1.0, min(5.0, float(x)))
|
| 58 |
|
| 59 |
def build_X(vals: dict) -> pd.DataFrame:
|
|
@@ -66,38 +94,30 @@ def prob_at_risk(X: pd.DataFrame) -> float:
|
|
| 66 |
idx = classes.index(1) # class 1 = At Risk
|
| 67 |
return float(probs[idx])
|
| 68 |
|
| 69 |
-
def risk_label(p):
|
| 70 |
return "At Risk" if p >= 0.5 else "Not At Risk"
|
| 71 |
|
| 72 |
-
def stable_threshold():
|
| 73 |
-
return min(CLUSTER_1[v] for v in VISIBLE_DRIVERS)
|
| 74 |
-
|
| 75 |
# =========================
|
| 76 |
-
# Plot: drivers
|
| 77 |
# =========================
|
| 78 |
-
def make_driver_plot(
|
| 79 |
-
|
| 80 |
-
values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
|
| 81 |
-
colors = ["seagreen" if v >= th else "firebrick" for v in values]
|
| 82 |
-
|
| 83 |
-
fig, ax = plt.subplots(figsize=(8.8, 3.4))
|
| 84 |
-
ax.bar(VISIBLE_LABELS, values, color=colors)
|
| 85 |
|
| 86 |
-
ax.
|
| 87 |
-
ax.
|
| 88 |
|
| 89 |
ax.set_ylim(1, 5.4)
|
| 90 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 91 |
ax.set_ylabel("Score (1–5)")
|
| 92 |
-
ax.set_title("
|
| 93 |
|
| 94 |
-
ax.margins(x=0.
|
| 95 |
plt.tight_layout()
|
| 96 |
-
plt.subplots_adjust(bottom=0.
|
| 97 |
return fig
|
| 98 |
|
| 99 |
# =========================
|
| 100 |
-
# TRUE SHAP using CatBoost native SHAP values
|
| 101 |
# =========================
|
| 102 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 103 |
"""
|
|
@@ -106,35 +126,33 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 106 |
returns array shape: (n_rows, n_features + 1)
|
| 107 |
last column is expected value; first n_features are SHAP contributions.
|
| 108 |
"""
|
| 109 |
-
fig, ax = plt.subplots(figsize=(8.8, 3.
|
| 110 |
|
| 111 |
try:
|
| 112 |
from catboost import Pool
|
| 113 |
|
| 114 |
pool = Pool(X) # 1-row
|
| 115 |
shap_vals = model.get_feature_importance(pool, type="ShapValues")
|
| 116 |
-
# shap_vals shape: (1, n_features+1)
|
| 117 |
contrib = shap_vals[0, :-1] # drop expected value
|
| 118 |
|
| 119 |
s = pd.Series(contrib, index=X.columns)
|
| 120 |
|
| 121 |
-
#
|
| 122 |
s = s.drop(labels=["ManagementLevel"], errors="ignore")
|
| 123 |
|
| 124 |
-
#
|
| 125 |
s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
|
| 126 |
|
| 127 |
ax.barh(s.index[::-1], s.values[::-1])
|
| 128 |
-
ax.set_title("
|
| 129 |
ax.set_xlabel("Impact on model log-odds (signed)")
|
| 130 |
plt.tight_layout()
|
| 131 |
return fig
|
| 132 |
|
| 133 |
except Exception as e:
|
| 134 |
-
# If catboost isn't installed or something fails, show the error nicely
|
| 135 |
ax.text(
|
| 136 |
0.5, 0.55,
|
| 137 |
-
"
|
| 138 |
ha="center", va="center", fontsize=10
|
| 139 |
)
|
| 140 |
ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
|
|
@@ -145,87 +163,125 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 145 |
# =========================
|
| 146 |
# Prediction
|
| 147 |
# =========================
|
| 148 |
-
def predict(
|
| 149 |
-
Engagement
|
| 150 |
-
SupportiveGM
|
| 151 |
-
WellBeing
|
| 152 |
-
WorkEnvironment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
#
|
| 155 |
vals = {
|
| 156 |
-
|
| 157 |
-
"
|
| 158 |
-
"ManagementLevel": 2, # fixed constant, not shown
|
| 159 |
-
"WellBeing": WellBeing,
|
| 160 |
-
"Voice": CLUSTER_1["Voice"],
|
| 161 |
-
"DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
|
| 162 |
-
"Workload": CLUSTER_1["Workload"],
|
| 163 |
-
"WorkEnvironment": WorkEnvironment,
|
| 164 |
}
|
| 165 |
|
| 166 |
X = build_X(vals)
|
| 167 |
p = prob_at_risk(X)
|
| 168 |
-
|
| 169 |
headline = f"Predicted Status: {risk_label(p)}"
|
| 170 |
-
|
|
|
|
| 171 |
shap_fig = make_catboost_shap_plot(X)
|
| 172 |
|
| 173 |
-
return headline,
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
# =========================
|
| 185 |
# UI Layout (no scrolling)
|
| 186 |
# =========================
|
| 187 |
CSS = """
|
| 188 |
#app-wrap { max-width: 1200px; margin: 0 auto; }
|
| 189 |
-
.compact .gr-markdown { margin-bottom: 0.
|
| 190 |
"""
|
| 191 |
|
| 192 |
with gr.Blocks(css=CSS) as demo:
|
| 193 |
gr.Markdown(
|
| 194 |
"<div id='app-wrap' class='compact'>"
|
| 195 |
-
"<h2>Retention
|
| 196 |
-
"<p style='margin-top:0;'>Adjust
|
| 197 |
-
"Click <b>
|
| 198 |
"</div>"
|
| 199 |
)
|
| 200 |
|
| 201 |
with gr.Row():
|
| 202 |
# LEFT: sliders + buttons
|
| 203 |
-
with gr.Column(scale=5, min_width=
|
|
|
|
| 204 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 205 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
| 206 |
WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
|
| 207 |
WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
with gr.Row():
|
| 210 |
btn_predict = gr.Button("Predict")
|
| 211 |
-
|
| 212 |
|
| 213 |
# RIGHT: headline + two plots stacked
|
| 214 |
with gr.Column(scale=7, min_width=520):
|
| 215 |
headline = gr.Textbox(label="Result", value="", interactive=False)
|
| 216 |
-
|
| 217 |
-
shap_plot = gr.Plot(label="
|
| 218 |
|
| 219 |
btn_predict.click(
|
| 220 |
fn=predict,
|
| 221 |
-
inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment],
|
| 222 |
-
outputs=[headline,
|
| 223 |
)
|
| 224 |
|
| 225 |
-
|
| 226 |
-
fn=
|
| 227 |
inputs=[],
|
| 228 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
| 229 |
)
|
| 230 |
|
| 231 |
demo.launch()
|
|
|
|
| 11 |
# =========================
|
| 12 |
# Load model (CatBoostClassifier saved via joblib)
|
| 13 |
# =========================
|
| 14 |
+
model = joblib.load("cat (1).joblib")
|
| 15 |
|
| 16 |
FEATURES = [
|
| 17 |
"Engagement",
|
|
|
|
| 37 |
"Engagement": 4.9324,
|
| 38 |
}
|
| 39 |
|
| 40 |
+
CLUSTER_2 = {
|
| 41 |
+
"Voice": 3.94,
|
| 42 |
+
"DecisionAutonomy": 4.24,
|
| 43 |
+
"Workload": 3.76,
|
| 44 |
+
"WellBeing": 4.0251,
|
| 45 |
+
"WorkEnvironment": 4.1484,
|
| 46 |
+
"SupportiveGM": 4.1275,
|
| 47 |
+
"Engagement": 4.2828,
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
CLUSTER_3 = {
|
| 51 |
"Voice": 2.39,
|
| 52 |
"DecisionAutonomy": 3.55,
|
|
|
|
| 57 |
"Engagement": 3.3909,
|
| 58 |
}
|
| 59 |
|
| 60 |
+
# You asked: "MAKE all THE VARS the key drivers" (we treat all survey vars as drivers)
|
| 61 |
+
ALL_DRIVER_VARS = [
|
| 62 |
+
"Engagement",
|
| 63 |
+
"SupportiveGM",
|
| 64 |
+
"WellBeing",
|
| 65 |
+
"WorkEnvironment",
|
| 66 |
+
"Voice",
|
| 67 |
+
"DecisionAutonomy",
|
| 68 |
+
"Workload",
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
ALL_DRIVER_LABELS = [
|
| 72 |
+
"Engagement",
|
| 73 |
+
"Supportive GM",
|
| 74 |
+
"Well-Being",
|
| 75 |
+
"Work Environment",
|
| 76 |
+
"Voice",
|
| 77 |
+
"Decision Autonomy",
|
| 78 |
+
"Workload",
|
| 79 |
+
]
|
| 80 |
|
| 81 |
# =========================
|
| 82 |
# Helpers
|
| 83 |
# =========================
|
| 84 |
+
def clamp_1_5(x):
|
| 85 |
return max(1.0, min(5.0, float(x)))
|
| 86 |
|
| 87 |
def build_X(vals: dict) -> pd.DataFrame:
|
|
|
|
| 94 |
idx = classes.index(1) # class 1 = At Risk
|
| 95 |
return float(probs[idx])
|
| 96 |
|
| 97 |
+
def risk_label(p: float) -> str:
|
| 98 |
return "At Risk" if p >= 0.5 else "Not At Risk"
|
| 99 |
|
|
|
|
|
|
|
|
|
|
| 100 |
# =========================
|
| 101 |
+
# Plot: "Average of key drivers" (shows ALL driver vars)
|
| 102 |
# =========================
|
| 103 |
+
def make_driver_plot(driver_vals: dict):
|
| 104 |
+
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
+
fig, ax = plt.subplots(figsize=(8.8, 3.2))
|
| 107 |
+
ax.bar(ALL_DRIVER_LABELS, values)
|
| 108 |
|
| 109 |
ax.set_ylim(1, 5.4)
|
| 110 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 111 |
ax.set_ylabel("Score (1–5)")
|
| 112 |
+
ax.set_title("Average of key drivers")
|
| 113 |
|
| 114 |
+
ax.margins(x=0.08)
|
| 115 |
plt.tight_layout()
|
| 116 |
+
plt.subplots_adjust(bottom=0.28)
|
| 117 |
return fig
|
| 118 |
|
| 119 |
# =========================
|
| 120 |
+
# Plot: TRUE SHAP using CatBoost native SHAP values
|
| 121 |
# =========================
|
| 122 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 123 |
"""
|
|
|
|
| 126 |
returns array shape: (n_rows, n_features + 1)
|
| 127 |
last column is expected value; first n_features are SHAP contributions.
|
| 128 |
"""
|
| 129 |
+
fig, ax = plt.subplots(figsize=(8.8, 3.2))
|
| 130 |
|
| 131 |
try:
|
| 132 |
from catboost import Pool
|
| 133 |
|
| 134 |
pool = Pool(X) # 1-row
|
| 135 |
shap_vals = model.get_feature_importance(pool, type="ShapValues")
|
|
|
|
| 136 |
contrib = shap_vals[0, :-1] # drop expected value
|
| 137 |
|
| 138 |
s = pd.Series(contrib, index=X.columns)
|
| 139 |
|
| 140 |
+
# Keep SHAP focused on survey drivers (exclude ManagementLevel)
|
| 141 |
s = s.drop(labels=["ManagementLevel"], errors="ignore")
|
| 142 |
|
| 143 |
+
# Top 8 by absolute contribution
|
| 144 |
s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
|
| 145 |
|
| 146 |
ax.barh(s.index[::-1], s.values[::-1])
|
| 147 |
+
ax.set_title("Feature Importance (Shap)")
|
| 148 |
ax.set_xlabel("Impact on model log-odds (signed)")
|
| 149 |
plt.tight_layout()
|
| 150 |
return fig
|
| 151 |
|
| 152 |
except Exception as e:
|
|
|
|
| 153 |
ax.text(
|
| 154 |
0.5, 0.55,
|
| 155 |
+
"SHAP chart unavailable.\nInstall 'catboost' in requirements.txt.",
|
| 156 |
ha="center", va="center", fontsize=10
|
| 157 |
)
|
| 158 |
ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
|
|
|
|
| 163 |
# =========================
|
| 164 |
# Prediction
|
| 165 |
# =========================
|
| 166 |
+
def predict(
|
| 167 |
+
Engagement,
|
| 168 |
+
SupportiveGM,
|
| 169 |
+
WellBeing,
|
| 170 |
+
WorkEnvironment,
|
| 171 |
+
Voice,
|
| 172 |
+
DecisionAutonomy,
|
| 173 |
+
Workload,
|
| 174 |
+
):
|
| 175 |
+
# Clamp sliders
|
| 176 |
+
driver_vals = {
|
| 177 |
+
"Engagement": clamp_1_5(Engagement),
|
| 178 |
+
"SupportiveGM": clamp_1_5(SupportiveGM),
|
| 179 |
+
"WellBeing": clamp_1_5(WellBeing),
|
| 180 |
+
"WorkEnvironment": clamp_1_5(WorkEnvironment),
|
| 181 |
+
"Voice": clamp_1_5(Voice),
|
| 182 |
+
"DecisionAutonomy": clamp_1_5(DecisionAutonomy),
|
| 183 |
+
"Workload": clamp_1_5(Workload),
|
| 184 |
+
}
|
| 185 |
|
| 186 |
+
# Build model row (ManagementLevel fixed internally)
|
| 187 |
vals = {
|
| 188 |
+
**driver_vals,
|
| 189 |
+
"ManagementLevel": 2,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
}
|
| 191 |
|
| 192 |
X = build_X(vals)
|
| 193 |
p = prob_at_risk(X)
|
|
|
|
| 194 |
headline = f"Predicted Status: {risk_label(p)}"
|
| 195 |
+
|
| 196 |
+
drivers_fig = make_driver_plot(driver_vals)
|
| 197 |
shap_fig = make_catboost_shap_plot(X)
|
| 198 |
|
| 199 |
+
return headline, drivers_fig, shap_fig
|
| 200 |
|
| 201 |
+
# =========================
|
| 202 |
+
# Button: At risk group = average of Cluster 1 and Cluster 2 (as you requested)
|
| 203 |
+
# =========================
|
| 204 |
+
def at_risk_group():
|
| 205 |
+
avg = {}
|
| 206 |
+
for v in ALL_DRIVER_VARS:
|
| 207 |
+
avg[v] = (CLUSTER_1[v] + CLUSTER_2[v]) / 2.0
|
| 208 |
+
|
| 209 |
+
headline, drivers_fig, shap_fig = predict(
|
| 210 |
+
avg["Engagement"],
|
| 211 |
+
avg["SupportiveGM"],
|
| 212 |
+
avg["WellBeing"],
|
| 213 |
+
avg["WorkEnvironment"],
|
| 214 |
+
avg["Voice"],
|
| 215 |
+
avg["DecisionAutonomy"],
|
| 216 |
+
avg["Workload"],
|
| 217 |
+
)
|
| 218 |
|
| 219 |
+
# Return slider updates + outputs
|
| 220 |
+
return (
|
| 221 |
+
avg["Engagement"],
|
| 222 |
+
avg["SupportiveGM"],
|
| 223 |
+
avg["WellBeing"],
|
| 224 |
+
avg["WorkEnvironment"],
|
| 225 |
+
avg["Voice"],
|
| 226 |
+
avg["DecisionAutonomy"],
|
| 227 |
+
avg["Workload"],
|
| 228 |
+
headline,
|
| 229 |
+
drivers_fig,
|
| 230 |
+
shap_fig,
|
| 231 |
+
)
|
| 232 |
|
| 233 |
# =========================
|
| 234 |
# UI Layout (no scrolling)
|
| 235 |
# =========================
|
| 236 |
CSS = """
|
| 237 |
#app-wrap { max-width: 1200px; margin: 0 auto; }
|
| 238 |
+
.compact .gr-markdown { margin-bottom: 0.35rem !important; }
|
| 239 |
"""
|
| 240 |
|
| 241 |
with gr.Blocks(css=CSS) as demo:
|
| 242 |
gr.Markdown(
|
| 243 |
"<div id='app-wrap' class='compact'>"
|
| 244 |
+
"<h2>Retention Simulator</h2>"
|
| 245 |
+
"<p style='margin-top:0;'>Adjust all drivers and click <b>Predict</b>. "
|
| 246 |
+
"Click <b>At risk group</b> to load the average of Cluster 1 and Cluster 2.</p>"
|
| 247 |
"</div>"
|
| 248 |
)
|
| 249 |
|
| 250 |
with gr.Row():
|
| 251 |
# LEFT: sliders + buttons
|
| 252 |
+
with gr.Column(scale=5, min_width=430):
|
| 253 |
+
# Default starting point: Cluster 3 (most at-risk)
|
| 254 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 255 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
| 256 |
WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
|
| 257 |
WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
|
| 258 |
+
Voice = gr.Slider(1, 5, value=CLUSTER_3["Voice"], step=0.01, label="Voice")
|
| 259 |
+
DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
|
| 260 |
+
Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
|
| 261 |
|
| 262 |
with gr.Row():
|
| 263 |
btn_predict = gr.Button("Predict")
|
| 264 |
+
btn_atrisk = gr.Button("At risk group")
|
| 265 |
|
| 266 |
# RIGHT: headline + two plots stacked
|
| 267 |
with gr.Column(scale=7, min_width=520):
|
| 268 |
headline = gr.Textbox(label="Result", value="", interactive=False)
|
| 269 |
+
drivers_plot = gr.Plot(label="Average of key drivers")
|
| 270 |
+
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
| 271 |
|
| 272 |
btn_predict.click(
|
| 273 |
fn=predict,
|
| 274 |
+
inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload],
|
| 275 |
+
outputs=[headline, drivers_plot, shap_plot],
|
| 276 |
)
|
| 277 |
|
| 278 |
+
btn_atrisk.click(
|
| 279 |
+
fn=at_risk_group,
|
| 280 |
inputs=[],
|
| 281 |
+
outputs=[
|
| 282 |
+
Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload,
|
| 283 |
+
headline, drivers_plot, shap_plot
|
| 284 |
+
],
|
| 285 |
)
|
| 286 |
|
| 287 |
demo.launch()
|