mns6rh committed on
Commit
8a548c5
·
verified ·
1 Parent(s): 41087b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -65
app.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
 
9
- # Optional: helps reduce tiny resize flicker in some HF setups
10
  plt.rcParams["figure.dpi"] = 100
11
 
12
  # =========================
@@ -26,9 +25,7 @@ FEATURES = [
26
  ]
27
 
28
  # =========================
29
- # Cluster anchors (from SPSS)
30
- # Start state = Cluster 3 (at-risk profile)
31
- # Target state = Cluster 1 (stable profile)
32
  # =========================
33
  CLUSTER_1 = {
34
  "Voice": 4.84,
@@ -53,15 +50,11 @@ CLUSTER_3 = {
53
  VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
54
  VISIBLE_LABELS = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
55
 
56
-
57
  # =========================
58
- # SHAP setup (scenario-level)
59
- # Shows which features drive the current prediction.
60
- # If SHAP isn't available, we fall back to model feature importance (if available).
61
  # =========================
62
  SHAP_AVAILABLE = False
63
  explainer = None
64
-
65
  try:
66
  import shap # noqa: F401
67
  from shap import TreeExplainer # type: ignore
@@ -97,12 +90,11 @@ def risk_label(p):
97
 
98
 
99
  def stable_threshold():
100
- # threshold line = minimum of the 4 visible drivers in the stable (Cluster 1) profile
101
  return min(CLUSTER_1[v] for v in VISIBLE_DRIVERS)
102
 
103
 
104
  # =========================
105
- # Plot: driver bars + threshold
106
  # =========================
107
  def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
108
  th = stable_threshold()
@@ -115,7 +107,7 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
115
  ax.axhline(th, linestyle="--", linewidth=2)
116
  ax.text(3.05, th, "Stable threshold", va="center")
117
 
118
- ax.set_ylim(1, 5.4) # extra space above 5
119
  ax.set_yticks([1, 2, 3, 4, 5])
120
  ax.set_ylabel("Survey Score (1–5)")
121
  ax.set_title("Key Drivers vs Stable Threshold")
@@ -127,54 +119,34 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
127
 
128
 
129
  # =========================
130
- # Plot: SHAP (or fallback importance)
131
  # =========================
132
  def make_shap_plot(X: pd.DataFrame):
133
- # We’ll show feature impact for the CURRENT prediction (one-row SHAP bar chart).
134
- # Exclude ManagementLevel from the display because you don't want mgmt info in the story.
135
- display_features = [f for f in FEATURES if f != "ManagementLevel"]
136
-
137
  fig, ax = plt.subplots(figsize=(10.5, 4.8))
138
 
139
  if SHAP_AVAILABLE and explainer is not None:
140
  shap_vals = explainer.shap_values(X)
141
 
142
- # shap_values formats vary by model:
143
- # - array (n, p)
144
- # - list of arrays for classes
145
- # We'll pick the "At Risk" class (label 1) if it's a list.
146
  if isinstance(shap_vals, list):
147
- # classes aligned with model.classes_
148
  classes = list(model.classes_)
149
  idx = classes.index(1)
150
  sv = shap_vals[idx][0]
151
  else:
152
  sv = shap_vals[0]
153
 
154
- # Build series aligned to columns
155
- s = pd.Series(sv, index=X.columns)
156
-
157
- # Drop ManagementLevel for display
158
- s = s.drop(labels=["ManagementLevel"], errors="ignore")
159
-
160
- # Rank by absolute contribution
161
- s = s.reindex(s.abs().sort_values(ascending=False).index)
162
-
163
- # Plot top 8 (or fewer)
164
- top = s.head(8)
165
- ax.barh(top.index[::-1], top.values[::-1])
166
 
 
167
  ax.set_title("What drives this prediction (SHAP impact)")
168
  ax.set_xlabel("Impact on model output (signed)")
169
  plt.tight_layout()
170
  return fig
171
 
172
- # ---- Fallback: model feature importance (global) ----
173
  imp = None
174
- # sklearn-style
175
  if hasattr(model, "feature_importances_"):
176
  imp = pd.Series(model.feature_importances_, index=FEATURES)
177
- # CatBoost-style (sometimes)
178
  elif hasattr(model, "get_feature_importance"):
179
  try:
180
  imp = pd.Series(model.get_feature_importance(), index=FEATURES)
@@ -184,17 +156,14 @@ def make_shap_plot(X: pd.DataFrame):
184
  if imp is None:
185
  ax.text(
186
  0.5, 0.5,
187
- "SHAP not available in this Space.\nInstall 'shap' to show a SHAP chart.",
188
  ha="center", va="center"
189
  )
190
  ax.set_axis_off()
191
  plt.tight_layout()
192
  return fig
193
 
194
- # Drop ManagementLevel for display
195
- imp = imp.drop(labels=["ManagementLevel"], errors="ignore")
196
- imp = imp.sort_values(ascending=True).tail(8)
197
-
198
  ax.barh(imp.index, imp.values)
199
  ax.set_title("Feature importance (fallback — not SHAP)")
200
  ax.set_xlabel("Importance")
@@ -203,21 +172,18 @@ def make_shap_plot(X: pd.DataFrame):
203
 
204
 
205
  # =========================
206
- # Core prediction
207
  # =========================
208
  def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
209
- # visible
210
  Engagement = clamp(Engagement)
211
  SupportiveGM = clamp(SupportiveGM)
212
  WellBeing = clamp(WellBeing)
213
  WorkEnvironment = clamp(WorkEnvironment)
214
 
215
- # IMPORTANT: model still needs hidden vars. We'll hold them at the stable (Cluster 1) levels.
216
- # This keeps the story focused on the 4 drivers you’re showing.
217
  vals = {
218
  "Engagement": Engagement,
219
  "SupportiveGM": SupportiveGM,
220
- "ManagementLevel": 2, # fixed constant; not shown anywhere
221
  "WellBeing": WellBeing,
222
  "Voice": CLUSTER_1["Voice"],
223
  "DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
@@ -227,19 +193,11 @@ def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
227
 
228
  X = build_X(vals)
229
  p = prob_at_risk(X)
230
- label = risk_label(p)
231
-
232
- headline = f"Predicted Status: **{label}**"
233
 
234
- driver_fig = make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment)
235
- shap_fig = make_shap_plot(X)
236
 
237
- return headline, driver_fig, shap_fig
238
 
239
-
240
- # =========================
241
- # Apply recommendation = move to Cluster 1 targets
242
- # =========================
243
  def apply_recommendation():
244
  e = CLUSTER_1["Engagement"]
245
  s = CLUSTER_1["SupportiveGM"]
@@ -251,17 +209,21 @@ def apply_recommendation():
251
 
252
 
253
  # =========================
254
- # UI
255
  # =========================
256
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
257
  gr.Markdown("# Retention Recommendation Simulator")
258
- gr.Markdown(
259
- "Use the sliders to simulate workplace conditions. "
260
- "Click **Apply Recommendation Plan** to move the profile to the stable target."
261
- )
262
 
263
  with gr.Row():
264
- with gr.Column():
265
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
266
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
267
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
@@ -270,10 +232,16 @@ with gr.Blocks() as demo:
270
  btn_predict = gr.Button("Predict")
271
  btn_recommend = gr.Button("Apply Recommendation Plan")
272
 
273
- with gr.Column():
274
  headline = gr.Markdown()
 
 
275
  driver_plot = gr.Plot(label="Drivers vs Threshold")
 
 
 
276
  shap_plot = gr.Plot(label="SHAP / Feature Impact")
 
277
 
278
  btn_predict.click(
279
  fn=predict,
 
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
 
 
9
  plt.rcParams["figure.dpi"] = 100
10
 
11
  # =========================
 
25
  ]
26
 
27
  # =========================
28
+ # Cluster anchors
 
 
29
  # =========================
30
  CLUSTER_1 = {
31
  "Voice": 4.84,
 
50
  VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
51
  VISIBLE_LABELS = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
52
 
 
53
  # =========================
54
+ # SHAP setup (optional)
 
 
55
  # =========================
56
  SHAP_AVAILABLE = False
57
  explainer = None
 
58
  try:
59
  import shap # noqa: F401
60
  from shap import TreeExplainer # type: ignore
 
90
 
91
 
92
  def stable_threshold():
 
93
  return min(CLUSTER_1[v] for v in VISIBLE_DRIVERS)
94
 
95
 
96
  # =========================
97
+ # Plot: drivers vs threshold
98
  # =========================
99
  def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
100
  th = stable_threshold()
 
107
  ax.axhline(th, linestyle="--", linewidth=2)
108
  ax.text(3.05, th, "Stable threshold", va="center")
109
 
110
+ ax.set_ylim(1, 5.4)
111
  ax.set_yticks([1, 2, 3, 4, 5])
112
  ax.set_ylabel("Survey Score (1–5)")
113
  ax.set_title("Key Drivers vs Stable Threshold")
 
119
 
120
 
121
  # =========================
122
+ # Plot: SHAP (or fallback)
123
  # =========================
124
  def make_shap_plot(X: pd.DataFrame):
 
 
 
 
125
  fig, ax = plt.subplots(figsize=(10.5, 4.8))
126
 
127
  if SHAP_AVAILABLE and explainer is not None:
128
  shap_vals = explainer.shap_values(X)
129
 
 
 
 
 
130
  if isinstance(shap_vals, list):
 
131
  classes = list(model.classes_)
132
  idx = classes.index(1)
133
  sv = shap_vals[idx][0]
134
  else:
135
  sv = shap_vals[0]
136
 
137
+ s = pd.Series(sv, index=X.columns).drop(labels=["ManagementLevel"], errors="ignore")
138
+ s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
 
 
 
 
 
 
 
 
 
 
139
 
140
+ ax.barh(s.index[::-1], s.values[::-1])
141
  ax.set_title("What drives this prediction (SHAP impact)")
142
  ax.set_xlabel("Impact on model output (signed)")
143
  plt.tight_layout()
144
  return fig
145
 
146
+ # fallback feature importance
147
  imp = None
 
148
  if hasattr(model, "feature_importances_"):
149
  imp = pd.Series(model.feature_importances_, index=FEATURES)
 
150
  elif hasattr(model, "get_feature_importance"):
151
  try:
152
  imp = pd.Series(model.get_feature_importance(), index=FEATURES)
 
156
  if imp is None:
157
  ax.text(
158
  0.5, 0.5,
159
+ "SHAP not available.\nAdd 'shap' to requirements.txt for SHAP charts.",
160
  ha="center", va="center"
161
  )
162
  ax.set_axis_off()
163
  plt.tight_layout()
164
  return fig
165
 
166
+ imp = imp.drop(labels=["ManagementLevel"], errors="ignore").sort_values(ascending=True).tail(8)
 
 
 
167
  ax.barh(imp.index, imp.values)
168
  ax.set_title("Feature importance (fallback — not SHAP)")
169
  ax.set_xlabel("Importance")
 
172
 
173
 
174
  # =========================
175
+ # Prediction
176
  # =========================
177
  def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
 
178
  Engagement = clamp(Engagement)
179
  SupportiveGM = clamp(SupportiveGM)
180
  WellBeing = clamp(WellBeing)
181
  WorkEnvironment = clamp(WorkEnvironment)
182
 
 
 
183
  vals = {
184
  "Engagement": Engagement,
185
  "SupportiveGM": SupportiveGM,
186
+ "ManagementLevel": 2, # fixed constant; not shown
187
  "WellBeing": WellBeing,
188
  "Voice": CLUSTER_1["Voice"],
189
  "DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
 
193
 
194
  X = build_X(vals)
195
  p = prob_at_risk(X)
196
+ headline = f"Predicted Status: **{risk_label(p)}**"
 
 
197
 
198
+ return headline, make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment), make_shap_plot(X)
 
199
 
 
200
 
 
 
 
 
201
  def apply_recommendation():
202
  e = CLUSTER_1["Engagement"]
203
  s = CLUSTER_1["SupportiveGM"]
 
209
 
210
 
211
  # =========================
212
+ # UI (fixed-height plot areas to prevent shaking)
213
  # =========================
214
+ CSS = """
215
+ .fixed-plot {
216
+ height: 520px;
217
+ overflow: hidden;
218
+ }
219
+ """
220
+
221
+ with gr.Blocks(css=CSS) as demo:
222
  gr.Markdown("# Retention Recommendation Simulator")
223
+ gr.Markdown("Use the sliders, then click **Predict**. Click **Apply Recommendation Plan** to move to the stable target.")
 
 
 
224
 
225
  with gr.Row():
226
+ with gr.Column(scale=1):
227
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
228
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
229
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
 
232
  btn_predict = gr.Button("Predict")
233
  btn_recommend = gr.Button("Apply Recommendation Plan")
234
 
235
+ with gr.Column(scale=1):
236
  headline = gr.Markdown()
237
+
238
+ gr.HTML('<div class="fixed-plot">')
239
  driver_plot = gr.Plot(label="Drivers vs Threshold")
240
+ gr.HTML('</div>')
241
+
242
+ gr.HTML('<div class="fixed-plot">')
243
  shap_plot = gr.Plot(label="SHAP / Feature Impact")
244
+ gr.HTML('</div>')
245
 
246
  btn_predict.click(
247
  fn=predict,