Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ import pandas as pd
|
|
| 6 |
import gradio as gr
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
|
| 9 |
-
# Optional: helps reduce tiny resize flicker in some HF setups
|
| 10 |
plt.rcParams["figure.dpi"] = 100
|
| 11 |
|
| 12 |
# =========================
|
|
@@ -26,9 +25,7 @@ FEATURES = [
|
|
| 26 |
]
|
| 27 |
|
| 28 |
# =========================
|
| 29 |
-
# Cluster anchors
|
| 30 |
-
# Start state = Cluster 3 (at-risk profile)
|
| 31 |
-
# Target state = Cluster 1 (stable profile)
|
| 32 |
# =========================
|
| 33 |
CLUSTER_1 = {
|
| 34 |
"Voice": 4.84,
|
|
@@ -53,15 +50,11 @@ CLUSTER_3 = {
|
|
| 53 |
VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
|
| 54 |
VISIBLE_LABELS = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
|
| 55 |
|
| 56 |
-
|
| 57 |
# =========================
|
| 58 |
-
# SHAP setup (
|
| 59 |
-
# Shows which features drive the current prediction.
|
| 60 |
-
# If SHAP isn't available, we fall back to model feature importance (if available).
|
| 61 |
# =========================
|
| 62 |
SHAP_AVAILABLE = False
|
| 63 |
explainer = None
|
| 64 |
-
|
| 65 |
try:
|
| 66 |
import shap # noqa: F401
|
| 67 |
from shap import TreeExplainer # type: ignore
|
|
@@ -97,12 +90,11 @@ def risk_label(p):
|
|
| 97 |
|
| 98 |
|
| 99 |
def stable_threshold():
|
| 100 |
-
# threshold line = minimum of the 4 visible drivers in the stable (Cluster 1) profile
|
| 101 |
return min(CLUSTER_1[v] for v in VISIBLE_DRIVERS)
|
| 102 |
|
| 103 |
|
| 104 |
# =========================
|
| 105 |
-
# Plot:
|
| 106 |
# =========================
|
| 107 |
def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
| 108 |
th = stable_threshold()
|
|
@@ -115,7 +107,7 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
|
| 115 |
ax.axhline(th, linestyle="--", linewidth=2)
|
| 116 |
ax.text(3.05, th, "Stable threshold", va="center")
|
| 117 |
|
| 118 |
-
ax.set_ylim(1, 5.4)
|
| 119 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 120 |
ax.set_ylabel("Survey Score (1–5)")
|
| 121 |
ax.set_title("Key Drivers vs Stable Threshold")
|
|
@@ -127,54 +119,34 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
|
| 127 |
|
| 128 |
|
| 129 |
# =========================
|
| 130 |
-
# Plot: SHAP (or fallback
|
| 131 |
# =========================
|
| 132 |
def make_shap_plot(X: pd.DataFrame):
|
| 133 |
-
# We’ll show feature impact for the CURRENT prediction (one-row SHAP bar chart).
|
| 134 |
-
# Exclude ManagementLevel from the display because you don't want mgmt info in the story.
|
| 135 |
-
display_features = [f for f in FEATURES if f != "ManagementLevel"]
|
| 136 |
-
|
| 137 |
fig, ax = plt.subplots(figsize=(10.5, 4.8))
|
| 138 |
|
| 139 |
if SHAP_AVAILABLE and explainer is not None:
|
| 140 |
shap_vals = explainer.shap_values(X)
|
| 141 |
|
| 142 |
-
# shap_values formats vary by model:
|
| 143 |
-
# - array (n, p)
|
| 144 |
-
# - list of arrays for classes
|
| 145 |
-
# We'll pick the "At Risk" class (label 1) if it's a list.
|
| 146 |
if isinstance(shap_vals, list):
|
| 147 |
-
# classes aligned with model.classes_
|
| 148 |
classes = list(model.classes_)
|
| 149 |
idx = classes.index(1)
|
| 150 |
sv = shap_vals[idx][0]
|
| 151 |
else:
|
| 152 |
sv = shap_vals[0]
|
| 153 |
|
| 154 |
-
|
| 155 |
-
s =
|
| 156 |
-
|
| 157 |
-
# Drop ManagementLevel for display
|
| 158 |
-
s = s.drop(labels=["ManagementLevel"], errors="ignore")
|
| 159 |
-
|
| 160 |
-
# Rank by absolute contribution
|
| 161 |
-
s = s.reindex(s.abs().sort_values(ascending=False).index)
|
| 162 |
-
|
| 163 |
-
# Plot top 8 (or fewer)
|
| 164 |
-
top = s.head(8)
|
| 165 |
-
ax.barh(top.index[::-1], top.values[::-1])
|
| 166 |
|
|
|
|
| 167 |
ax.set_title("What drives this prediction (SHAP impact)")
|
| 168 |
ax.set_xlabel("Impact on model output (signed)")
|
| 169 |
plt.tight_layout()
|
| 170 |
return fig
|
| 171 |
|
| 172 |
-
#
|
| 173 |
imp = None
|
| 174 |
-
# sklearn-style
|
| 175 |
if hasattr(model, "feature_importances_"):
|
| 176 |
imp = pd.Series(model.feature_importances_, index=FEATURES)
|
| 177 |
-
# CatBoost-style (sometimes)
|
| 178 |
elif hasattr(model, "get_feature_importance"):
|
| 179 |
try:
|
| 180 |
imp = pd.Series(model.get_feature_importance(), index=FEATURES)
|
|
@@ -184,17 +156,14 @@ def make_shap_plot(X: pd.DataFrame):
|
|
| 184 |
if imp is None:
|
| 185 |
ax.text(
|
| 186 |
0.5, 0.5,
|
| 187 |
-
"SHAP not available
|
| 188 |
ha="center", va="center"
|
| 189 |
)
|
| 190 |
ax.set_axis_off()
|
| 191 |
plt.tight_layout()
|
| 192 |
return fig
|
| 193 |
|
| 194 |
-
|
| 195 |
-
imp = imp.drop(labels=["ManagementLevel"], errors="ignore")
|
| 196 |
-
imp = imp.sort_values(ascending=True).tail(8)
|
| 197 |
-
|
| 198 |
ax.barh(imp.index, imp.values)
|
| 199 |
ax.set_title("Feature importance (fallback — not SHAP)")
|
| 200 |
ax.set_xlabel("Importance")
|
|
@@ -203,21 +172,18 @@ def make_shap_plot(X: pd.DataFrame):
|
|
| 203 |
|
| 204 |
|
| 205 |
# =========================
|
| 206 |
-
#
|
| 207 |
# =========================
|
| 208 |
def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
| 209 |
-
# visible
|
| 210 |
Engagement = clamp(Engagement)
|
| 211 |
SupportiveGM = clamp(SupportiveGM)
|
| 212 |
WellBeing = clamp(WellBeing)
|
| 213 |
WorkEnvironment = clamp(WorkEnvironment)
|
| 214 |
|
| 215 |
-
# IMPORTANT: model still needs hidden vars. We'll hold them at the stable (Cluster 1) levels.
|
| 216 |
-
# This keeps the story focused on the 4 drivers you’re showing.
|
| 217 |
vals = {
|
| 218 |
"Engagement": Engagement,
|
| 219 |
"SupportiveGM": SupportiveGM,
|
| 220 |
-
"ManagementLevel": 2, # fixed constant; not shown
|
| 221 |
"WellBeing": WellBeing,
|
| 222 |
"Voice": CLUSTER_1["Voice"],
|
| 223 |
"DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
|
|
@@ -227,19 +193,11 @@ def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
|
| 227 |
|
| 228 |
X = build_X(vals)
|
| 229 |
p = prob_at_risk(X)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
headline = f"Predicted Status: **{label}**"
|
| 233 |
|
| 234 |
-
|
| 235 |
-
shap_fig = make_shap_plot(X)
|
| 236 |
|
| 237 |
-
return headline, driver_fig, shap_fig
|
| 238 |
|
| 239 |
-
|
| 240 |
-
# =========================
|
| 241 |
-
# Apply recommendation = move to Cluster 1 targets
|
| 242 |
-
# =========================
|
| 243 |
def apply_recommendation():
|
| 244 |
e = CLUSTER_1["Engagement"]
|
| 245 |
s = CLUSTER_1["SupportiveGM"]
|
|
@@ -251,17 +209,21 @@ def apply_recommendation():
|
|
| 251 |
|
| 252 |
|
| 253 |
# =========================
|
| 254 |
-
# UI
|
| 255 |
# =========================
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
gr.Markdown("# Retention Recommendation Simulator")
|
| 258 |
-
gr.Markdown(
|
| 259 |
-
"Use the sliders to simulate workplace conditions. "
|
| 260 |
-
"Click **Apply Recommendation Plan** to move the profile to the stable target."
|
| 261 |
-
)
|
| 262 |
|
| 263 |
with gr.Row():
|
| 264 |
-
with gr.Column():
|
| 265 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 266 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
| 267 |
WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
|
|
@@ -270,10 +232,16 @@ with gr.Blocks() as demo:
|
|
| 270 |
btn_predict = gr.Button("Predict")
|
| 271 |
btn_recommend = gr.Button("Apply Recommendation Plan")
|
| 272 |
|
| 273 |
-
with gr.Column():
|
| 274 |
headline = gr.Markdown()
|
|
|
|
|
|
|
| 275 |
driver_plot = gr.Plot(label="Drivers vs Threshold")
|
|
|
|
|
|
|
|
|
|
| 276 |
shap_plot = gr.Plot(label="SHAP / Feature Impact")
|
|
|
|
| 277 |
|
| 278 |
btn_predict.click(
|
| 279 |
fn=predict,
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
|
|
|
|
| 9 |
plt.rcParams["figure.dpi"] = 100
|
| 10 |
|
| 11 |
# =========================
|
|
|
|
| 25 |
]
|
| 26 |
|
| 27 |
# =========================
|
| 28 |
+
# Cluster anchors
|
|
|
|
|
|
|
| 29 |
# =========================
|
| 30 |
CLUSTER_1 = {
|
| 31 |
"Voice": 4.84,
|
|
|
|
| 50 |
VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
|
| 51 |
VISIBLE_LABELS = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
|
| 52 |
|
|
|
|
| 53 |
# =========================
|
| 54 |
+
# SHAP setup (optional)
|
|
|
|
|
|
|
| 55 |
# =========================
|
| 56 |
SHAP_AVAILABLE = False
|
| 57 |
explainer = None
|
|
|
|
| 58 |
try:
|
| 59 |
import shap # noqa: F401
|
| 60 |
from shap import TreeExplainer # type: ignore
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def stable_threshold():
    """Return the lowest CLUSTER_1 score among the visible drivers.

    Each name in VISIBLE_DRIVERS is looked up in CLUSTER_1 and the smallest
    of those values is returned.
    """
    stable_scores = [CLUSTER_1[driver] for driver in VISIBLE_DRIVERS]
    return min(stable_scores)
|
| 94 |
|
| 95 |
|
| 96 |
# =========================
|
| 97 |
+
# Plot: drivers vs threshold
|
| 98 |
# =========================
|
| 99 |
def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
| 100 |
th = stable_threshold()
|
|
|
|
| 107 |
ax.axhline(th, linestyle="--", linewidth=2)
|
| 108 |
ax.text(3.05, th, "Stable threshold", va="center")
|
| 109 |
|
| 110 |
+
ax.set_ylim(1, 5.4)
|
| 111 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 112 |
ax.set_ylabel("Survey Score (1–5)")
|
| 113 |
ax.set_title("Key Drivers vs Stable Threshold")
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
# =========================
|
| 122 |
+
# Plot: SHAP (or fallback)
|
| 123 |
# =========================
|
| 124 |
def make_shap_plot(X: pd.DataFrame):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
fig, ax = plt.subplots(figsize=(10.5, 4.8))
|
| 126 |
|
| 127 |
if SHAP_AVAILABLE and explainer is not None:
|
| 128 |
shap_vals = explainer.shap_values(X)
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if isinstance(shap_vals, list):
|
|
|
|
| 131 |
classes = list(model.classes_)
|
| 132 |
idx = classes.index(1)
|
| 133 |
sv = shap_vals[idx][0]
|
| 134 |
else:
|
| 135 |
sv = shap_vals[0]
|
| 136 |
|
| 137 |
+
s = pd.Series(sv, index=X.columns).drop(labels=["ManagementLevel"], errors="ignore")
|
| 138 |
+
s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
ax.barh(s.index[::-1], s.values[::-1])
|
| 141 |
ax.set_title("What drives this prediction (SHAP impact)")
|
| 142 |
ax.set_xlabel("Impact on model output (signed)")
|
| 143 |
plt.tight_layout()
|
| 144 |
return fig
|
| 145 |
|
| 146 |
+
# fallback feature importance
|
| 147 |
imp = None
|
|
|
|
| 148 |
if hasattr(model, "feature_importances_"):
|
| 149 |
imp = pd.Series(model.feature_importances_, index=FEATURES)
|
|
|
|
| 150 |
elif hasattr(model, "get_feature_importance"):
|
| 151 |
try:
|
| 152 |
imp = pd.Series(model.get_feature_importance(), index=FEATURES)
|
|
|
|
| 156 |
if imp is None:
|
| 157 |
ax.text(
|
| 158 |
0.5, 0.5,
|
| 159 |
+
"SHAP not available.\nAdd 'shap' to requirements.txt for SHAP charts.",
|
| 160 |
ha="center", va="center"
|
| 161 |
)
|
| 162 |
ax.set_axis_off()
|
| 163 |
plt.tight_layout()
|
| 164 |
return fig
|
| 165 |
|
| 166 |
+
imp = imp.drop(labels=["ManagementLevel"], errors="ignore").sort_values(ascending=True).tail(8)
|
|
|
|
|
|
|
|
|
|
| 167 |
ax.barh(imp.index, imp.values)
|
| 168 |
ax.set_title("Feature importance (fallback — not SHAP)")
|
| 169 |
ax.set_xlabel("Importance")
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
# =========================
|
| 175 |
+
# Prediction
|
| 176 |
# =========================
|
| 177 |
def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
|
|
|
| 178 |
Engagement = clamp(Engagement)
|
| 179 |
SupportiveGM = clamp(SupportiveGM)
|
| 180 |
WellBeing = clamp(WellBeing)
|
| 181 |
WorkEnvironment = clamp(WorkEnvironment)
|
| 182 |
|
|
|
|
|
|
|
| 183 |
vals = {
|
| 184 |
"Engagement": Engagement,
|
| 185 |
"SupportiveGM": SupportiveGM,
|
| 186 |
+
"ManagementLevel": 2, # fixed constant; not shown
|
| 187 |
"WellBeing": WellBeing,
|
| 188 |
"Voice": CLUSTER_1["Voice"],
|
| 189 |
"DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
|
|
|
|
| 193 |
|
| 194 |
X = build_X(vals)
|
| 195 |
p = prob_at_risk(X)
|
| 196 |
+
headline = f"Predicted Status: **{risk_label(p)}**"
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
return headline, make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment), make_shap_plot(X)
|
|
|
|
| 199 |
|
|
|
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
def apply_recommendation():
|
| 202 |
e = CLUSTER_1["Engagement"]
|
| 203 |
s = CLUSTER_1["SupportiveGM"]
|
|
|
|
| 209 |
|
| 210 |
|
| 211 |
# =========================
|
| 212 |
+
# UI (fixed-height plot areas to prevent shaking)
|
| 213 |
# =========================
|
| 214 |
+
CSS = """
|
| 215 |
+
.fixed-plot {
|
| 216 |
+
height: 520px;
|
| 217 |
+
overflow: hidden;
|
| 218 |
+
}
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
with gr.Blocks(css=CSS) as demo:
|
| 222 |
gr.Markdown("# Retention Recommendation Simulator")
|
| 223 |
+
gr.Markdown("Use the sliders, then click **Predict**. Click **Apply Recommendation Plan** to move to the stable target.")
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
with gr.Row():
|
| 226 |
+
with gr.Column(scale=1):
|
| 227 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 228 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
| 229 |
WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
|
|
|
|
| 232 |
btn_predict = gr.Button("Predict")
|
| 233 |
btn_recommend = gr.Button("Apply Recommendation Plan")
|
| 234 |
|
| 235 |
+
with gr.Column(scale=1):
    headline = gr.Markdown()

    # Attach the fixed-height CSS class directly to each plot component.
    # Emitting gr.HTML('<div class="fixed-plot">') before a plot and
    # gr.HTML('</div>') after it does not wrap the plot: every gr.HTML is an
    # independent component, so the opening <div> is closed inside its own
    # component and only contributes an empty fixed-height box to the layout.
    driver_plot = gr.Plot(label="Drivers vs Threshold", elem_classes=["fixed-plot"])
    shap_plot = gr.Plot(label="SHAP / Feature Impact", elem_classes=["fixed-plot"])
|
| 245 |
|
| 246 |
btn_predict.click(
|
| 247 |
fn=predict,
|