mns6rh commited on
Commit
03f90ba
·
verified ·
1 Parent(s): a05016c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -72
app.py CHANGED
@@ -1,5 +1,6 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
 
3
  import joblib
4
  import pandas as pd
5
  import gradio as gr
@@ -103,7 +104,8 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
103
  values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
104
  colors = ["seagreen" if v >= th else "firebrick" for v in values]
105
 
106
- fig, ax = plt.subplots(figsize=(10.5, 4.8))
 
107
  ax.bar(VISIBLE_LABELS, values, color=colors)
108
 
109
  ax.axhline(th, linestyle="--", linewidth=2)
@@ -111,95 +113,65 @@ def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
111
 
112
  ax.set_ylim(1, 5.4)
113
  ax.set_yticks([1, 2, 3, 4, 5])
114
- ax.set_ylabel("Survey Score (1–5)")
115
  ax.set_title("Key Drivers vs Stable Threshold")
116
 
117
- ax.margins(x=0.15)
118
  plt.tight_layout()
119
- plt.subplots_adjust(bottom=0.20)
120
  return fig
121
 
122
 
123
  # =========================
124
- # Always-visible "importance" chart
125
  # =========================
126
  def get_global_importance_series():
127
- """
128
- Returns a pandas Series indexed by FEATURES (or None if not available).
129
- Works for many CatBoost / sklearn-style models.
130
- """
131
- # sklearn-like
132
  if hasattr(model, "feature_importances_"):
133
  try:
134
- vals = model.feature_importances_
135
- return pd.Series(vals, index=FEATURES)
136
  except Exception:
137
  pass
138
-
139
- # CatBoost-like
140
  if hasattr(model, "get_feature_importance"):
141
  try:
142
- vals = model.get_feature_importance()
143
- return pd.Series(vals, index=FEATURES)
144
  except Exception:
145
  pass
146
-
147
  return None
148
 
149
 
150
  def make_importance_plot(X: pd.DataFrame):
151
- """
152
- Try SHAP for this one-row scenario.
153
- If SHAP fails, show global feature importance.
154
- ALWAYS returns a non-empty chart.
155
- """
156
- fig, ax = plt.subplots(figsize=(10.5, 4.8))
157
-
158
- # Try SHAP (scenario-level)
159
  if SHAP_OK and explainer is not None:
160
  try:
161
  shap_vals = explainer.shap_values(X)
162
 
163
- # Handle list-of-classes vs single array
164
  if isinstance(shap_vals, list):
165
  classes = list(model.classes_)
166
- idx = classes.index(1) # at-risk class
167
  sv = shap_vals[idx][0]
168
  else:
169
  sv = shap_vals[0]
170
 
171
- s = pd.Series(sv, index=X.columns)
172
-
173
- # you don't want to talk about ManagementLevel
174
- s = s.drop(labels=["ManagementLevel"], errors="ignore")
175
-
176
- # top by absolute contribution
177
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
178
 
179
  ax.barh(s.index[::-1], s.values[::-1])
180
- ax.set_title("Top drivers of THIS prediction (SHAP impact)")
181
- ax.set_xlabel("Impact on model output (signed)")
182
  plt.tight_layout()
183
  return fig
184
- except Exception as e:
185
- # fall back to global importance
186
- local_err = str(e)
187
- else:
188
- local_err = shap_err or "shap not installed"
189
 
190
  # Fallback: global importance
191
  imp = get_global_importance_series()
192
  if imp is None:
193
- # Last resort: show message on chart (still visible)
194
- ax.text(
195
- 0.5, 0.55,
196
- "SHAP chart not available in this Space,\nand feature importance not found on the model.",
197
- ha="center", va="center", fontsize=11
198
- )
199
  ax.text(
200
- 0.5, 0.40,
201
- f"Reason: {local_err[:140]}",
202
- ha="center", va="center", fontsize=9
203
  )
204
  ax.set_axis_off()
205
  plt.tight_layout()
@@ -208,7 +180,7 @@ def make_importance_plot(X: pd.DataFrame):
208
  imp = imp.drop(labels=["ManagementLevel"], errors="ignore")
209
  imp = imp.sort_values(ascending=True).tail(8)
210
  ax.barh(imp.index, imp.values)
211
- ax.set_title("Global feature importance (fallback)")
212
  ax.set_xlabel("Importance")
213
  plt.tight_layout()
214
  return fig
@@ -223,7 +195,7 @@ def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
223
  WellBeing = clamp(WellBeing)
224
  WorkEnvironment = clamp(WorkEnvironment)
225
 
226
- # Model still needs hidden vars. Hold at stable cluster values for a clean story.
227
  vals = {
228
  "Engagement": Engagement,
229
  "SupportiveGM": SupportiveGM,
@@ -238,13 +210,14 @@ def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
238
  X = build_X(vals)
239
  p = prob_at_risk(X)
240
 
241
- headline = f"Predicted Status: **{risk_label(p)}**"
242
- tech_note = f"Explainer: **{'SHAP' if SHAP_OK else 'Global importance'}**"
 
243
 
244
  driver_fig = make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment)
245
  imp_fig = make_importance_plot(X)
246
 
247
- return headline + "\n\n" + tech_note, driver_fig, imp_fig
248
 
249
 
250
  def apply_recommendation():
@@ -258,47 +231,53 @@ def apply_recommendation():
258
 
259
 
260
  # =========================
261
- # UI
262
  # =========================
263
  CSS = """
264
- .fixed-plot { height: 520px; overflow: hidden; }
 
 
265
  """
266
 
267
  with gr.Blocks(css=CSS) as demo:
268
- gr.Markdown("# Retention Recommendation Simulator")
269
- gr.Markdown("Use the sliders, then click **Predict**. Click **Apply Recommendation Plan** to move to the stable target.")
 
 
 
 
 
270
 
271
  with gr.Row():
272
- with gr.Column(scale=1):
 
273
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
274
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
275
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
276
  WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
277
 
278
- btn_predict = gr.Button("Predict")
279
- btn_recommend = gr.Button("Apply Recommendation Plan")
280
-
281
- with gr.Column(scale=1):
282
- headline = gr.Markdown()
283
 
284
- gr.HTML('<div class="fixed-plot">')
285
- driver_plot = gr.Plot(label="Drivers vs Threshold")
286
- gr.HTML('</div>')
 
287
 
288
- gr.HTML('<div class="fixed-plot">')
289
- shap_plot = gr.Plot(label="SHAP / Importance")
290
- gr.HTML('</div>')
291
 
292
  btn_predict.click(
293
  fn=predict,
294
  inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment],
295
- outputs=[headline, driver_plot, shap_plot],
296
  )
297
 
298
  btn_recommend.click(
299
  fn=apply_recommendation,
300
  inputs=[],
301
- outputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, driver_plot, shap_plot],
302
  )
303
 
304
  demo.launch()
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
+
4
  import joblib
5
  import pandas as pd
6
  import gradio as gr
 
104
  values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
105
  colors = ["seagreen" if v >= th else "firebrick" for v in values]
106
 
107
+ # slightly smaller so both charts fit on one screen
108
+ fig, ax = plt.subplots(figsize=(8.8, 3.4))
109
  ax.bar(VISIBLE_LABELS, values, color=colors)
110
 
111
  ax.axhline(th, linestyle="--", linewidth=2)
 
113
 
114
  ax.set_ylim(1, 5.4)
115
  ax.set_yticks([1, 2, 3, 4, 5])
116
+ ax.set_ylabel("Score (1–5)")
117
  ax.set_title("Key Drivers vs Stable Threshold")
118
 
119
+ ax.margins(x=0.12)
120
  plt.tight_layout()
121
+ plt.subplots_adjust(bottom=0.22)
122
  return fig
123
 
124
 
125
  # =========================
126
+ # Importance chart (SHAP if available; otherwise global importance)
127
  # =========================
128
  def get_global_importance_series():
 
 
 
 
 
129
  if hasattr(model, "feature_importances_"):
130
  try:
131
+ return pd.Series(model.feature_importances_, index=FEATURES)
 
132
  except Exception:
133
  pass
 
 
134
  if hasattr(model, "get_feature_importance"):
135
  try:
136
+ return pd.Series(model.get_feature_importance(), index=FEATURES)
 
137
  except Exception:
138
  pass
 
139
  return None
140
 
141
 
142
  def make_importance_plot(X: pd.DataFrame):
143
+ fig, ax = plt.subplots(figsize=(8.8, 3.4))
144
+
145
+ # Try SHAP for this one-row scenario
 
 
 
 
 
146
  if SHAP_OK and explainer is not None:
147
  try:
148
  shap_vals = explainer.shap_values(X)
149
 
 
150
  if isinstance(shap_vals, list):
151
  classes = list(model.classes_)
152
+ idx = classes.index(1)
153
  sv = shap_vals[idx][0]
154
  else:
155
  sv = shap_vals[0]
156
 
157
+ s = pd.Series(sv, index=X.columns).drop(labels=["ManagementLevel"], errors="ignore")
 
 
 
 
 
158
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
159
 
160
  ax.barh(s.index[::-1], s.values[::-1])
161
+ ax.set_title("Top Drivers of This Prediction (SHAP)")
162
+ ax.set_xlabel("Impact (signed)")
163
  plt.tight_layout()
164
  return fig
165
+ except Exception:
166
+ pass
 
 
 
167
 
168
  # Fallback: global importance
169
  imp = get_global_importance_series()
170
  if imp is None:
 
 
 
 
 
 
171
  ax.text(
172
+ 0.5, 0.5,
173
+ "SHAP/importance not available\n(add 'shap' to requirements.txt for SHAP)",
174
+ ha="center", va="center", fontsize=10
175
  )
176
  ax.set_axis_off()
177
  plt.tight_layout()
 
180
  imp = imp.drop(labels=["ManagementLevel"], errors="ignore")
181
  imp = imp.sort_values(ascending=True).tail(8)
182
  ax.barh(imp.index, imp.values)
183
+ ax.set_title("Global Feature Importance (Fallback)")
184
  ax.set_xlabel("Importance")
185
  plt.tight_layout()
186
  return fig
 
195
  WellBeing = clamp(WellBeing)
196
  WorkEnvironment = clamp(WorkEnvironment)
197
 
198
+ # Model still needs hidden vars. Hold at stable cluster values for clean story.
199
  vals = {
200
  "Engagement": Engagement,
201
  "SupportiveGM": SupportiveGM,
 
210
  X = build_X(vals)
211
  p = prob_at_risk(X)
212
 
213
+ # Keep headline single-line to prevent layout jump
214
+ explainer_name = "SHAP" if SHAP_OK else "Importance"
215
+ headline = f"Predicted Status: {risk_label(p)} | Explanation: {explainer_name}"
216
 
217
  driver_fig = make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment)
218
  imp_fig = make_importance_plot(X)
219
 
220
+ return headline, driver_fig, imp_fig
221
 
222
 
223
  def apply_recommendation():
 
231
 
232
 
233
  # =========================
234
+ # UI Layout (no scrolling)
235
  # =========================
236
  CSS = """
237
+ /* Make the app fit on one screen as much as possible */
238
+ #app-wrap { max-width: 1200px; margin: 0 auto; }
239
+ .compact .gr-markdown { margin-bottom: 0.4rem !important; }
240
  """
241
 
242
  with gr.Blocks(css=CSS) as demo:
243
+ gr.Markdown(
244
+ "<div id='app-wrap' class='compact'>"
245
+ "<h2>Retention Recommendation Simulator</h2>"
246
+ "<p style='margin-top:0;'>Adjust the 4 drivers and click <b>Predict</b>. "
247
+ "Click <b>Apply Recommendation Plan</b> to jump to the stable target.</p>"
248
+ "</div>"
249
+ )
250
 
251
  with gr.Row():
252
+ # LEFT: sliders + buttons
253
+ with gr.Column(scale=5, min_width=420):
254
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
255
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
256
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
257
  WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
258
 
259
+ with gr.Row():
260
+ btn_predict = gr.Button("Predict")
261
+ btn_recommend = gr.Button("Apply Recommendation Plan")
 
 
262
 
263
+ # RIGHT: headline + two plots stacked
264
+ with gr.Column(scale=7, min_width=520):
265
+ # Use Textbox (single line) to avoid markdown height jumps
266
+ headline = gr.Textbox(label="Result", value="", interactive=False)
267
 
268
+ driver_plot = gr.Plot(label="Key Drivers vs Stable Threshold")
269
+ importance_plot = gr.Plot(label="SHAP / Feature Importance")
 
270
 
271
  btn_predict.click(
272
  fn=predict,
273
  inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment],
274
+ outputs=[headline, driver_plot, importance_plot],
275
  )
276
 
277
  btn_recommend.click(
278
  fn=apply_recommendation,
279
  inputs=[],
280
+ outputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, driver_plot, importance_plot],
281
  )
282
 
283
  demo.launch()