mns6rh commited on
Commit
1074df2
·
verified ·
1 Parent(s): 8fdcc57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -30
app.py CHANGED
@@ -78,50 +78,69 @@ ALL_DRIVER_LABELS = [
78
  "Workload",
79
  ]
80
 
 
81
  # =========================
82
  # Helpers
83
  # =========================
84
  def clamp_1_5(x):
85
  return max(1.0, min(5.0, float(x)))
86
 
 
87
  def build_X(vals: dict) -> pd.DataFrame:
88
  row = {f: vals[f] for f in FEATURES}
89
  return pd.DataFrame([[row[f] for f in FEATURES]], columns=FEATURES)
90
 
 
91
  def prob_at_risk(X: pd.DataFrame) -> float:
92
  probs = model.predict_proba(X)[0]
93
  classes = list(model.classes_)
94
  idx = classes.index(1) # class 1 = At Risk
95
  return float(probs[idx])
96
 
 
97
  def risk_label(p: float) -> str:
98
  return "At Risk" if p >= 0.5 else "Not At Risk"
99
 
 
 
 
 
 
 
 
 
 
 
100
  # =========================
101
- # Plot: Average of key drivers
102
- # NOTE: fixed figsize + fixed container height prevents "vibration"
103
  # =========================
104
  def make_driver_plot(driver_vals: dict):
105
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
 
106
 
107
- fig, ax = plt.subplots(figsize=(8.8, 3.2))
108
  ax.bar(ALL_DRIVER_LABELS, values)
109
 
 
 
 
 
110
  ax.set_ylim(1, 5.4)
111
  ax.set_yticks([1, 2, 3, 4, 5])
112
  ax.set_ylabel("Score (1–5)")
113
  ax.set_title("Average of key drivers")
114
 
115
- ax.margins(x=0.08)
116
  plt.tight_layout()
117
  plt.subplots_adjust(bottom=0.30)
118
  return fig
119
 
 
120
  # =========================
121
  # TRUE SHAP using CatBoost native SHAP values
122
  # =========================
123
  def make_catboost_shap_plot(X: pd.DataFrame):
124
- fig, ax = plt.subplots(figsize=(8.8, 3.2))
125
 
126
  try:
127
  from catboost import Pool
@@ -131,7 +150,7 @@ def make_catboost_shap_plot(X: pd.DataFrame):
131
  contrib = shap_vals[0, :-1] # drop expected value
132
 
133
  s = pd.Series(contrib, index=X.columns)
134
- s = s.drop(labels=["ManagementLevel"], errors="ignore") # hide mgmt level from story
135
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
136
 
137
  ax.barh(s.index[::-1], s.values[::-1])
@@ -142,15 +161,19 @@ def make_catboost_shap_plot(X: pd.DataFrame):
142
 
143
  except Exception as e:
144
  ax.text(
145
- 0.5, 0.55,
 
146
  "SHAP chart unavailable.\nInstall 'catboost' in requirements.txt.",
147
- ha="center", va="center", fontsize=10
 
 
148
  )
149
  ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
150
  ax.set_axis_off()
151
  plt.tight_layout()
152
  return fig
153
 
 
154
  # =========================
155
  # Core predict
156
  # =========================
@@ -184,11 +207,12 @@ def predict(
184
 
185
  return headline, drivers_fig, shap_fig
186
 
 
187
  # =========================
188
  # Buttons
189
  # =========================
190
  def load_at_risk_group():
191
- # At risk group = average of Cluster 1 and Cluster 2 (as you requested)
192
  avg = {v: (CLUSTER_1[v] + CLUSTER_2[v]) / 2.0 for v in ALL_DRIVER_VARS}
193
 
194
  headline, drivers_fig, shap_fig = predict(
@@ -214,6 +238,7 @@ def load_at_risk_group():
214
  shap_fig,
215
  )
216
 
 
217
  def apply_recommendation():
218
  # Apply recommendation = move to Cluster 1 target levels
219
  target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
@@ -241,35 +266,37 @@ def apply_recommendation():
241
  shap_fig,
242
  )
243
 
 
244
  # =========================
245
- # UI Layout (fix vibration)
246
- # Key fixes:
247
- # - Use Textbox (fixed height) instead of Markdown
248
- # - Wrap plots in fixed-height containers using CSS
249
  # =========================
250
  CSS = """
251
  #app-wrap { max-width: 1200px; margin: 0 auto; }
252
 
253
- /* Make output panels stable height so the page doesn't reflow */
254
- .fixed-plot { height: 360px; overflow: hidden; }
 
 
 
 
 
255
 
256
- /* Reduce extra vertical whitespace */
257
- .compact .gr-box, .compact .gr-panel { padding-top: 8px !important; padding-bottom: 8px !important; }
 
258
  """
259
 
260
  with gr.Blocks(css=CSS) as demo:
261
  gr.Markdown(
262
  "<div id='app-wrap' class='compact'>"
263
  "<h2>Retention Simulator</h2>"
264
- "<p style='margin-top:0;'>Use the sliders and click <b>Predict</b>. "
265
- "Or click <b>At risk group</b> / <b>Apply recommendation</b>.</p>"
266
  "</div>"
267
  )
268
 
269
  with gr.Row():
270
  # LEFT: sliders + buttons
271
  with gr.Column(scale=5, min_width=430):
272
- # Start at Cluster 3 (most at-risk)
273
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
274
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
275
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
@@ -283,42 +310,58 @@ with gr.Blocks(css=CSS) as demo:
283
  btn_atrisk = gr.Button("At risk group")
284
  btn_reco = gr.Button("Apply recommendation")
285
 
286
- # RIGHT: headline + two plots stacked
287
  with gr.Column(scale=7, min_width=520):
288
  headline = gr.Textbox(label="Result", value="", interactive=False)
289
 
290
  gr.HTML('<div class="fixed-plot">')
291
  drivers_plot = gr.Plot(label="Average of key drivers")
292
- gr.HTML('</div>')
293
 
294
  gr.HTML('<div class="fixed-plot">')
295
  shap_plot = gr.Plot(label="Feature Importance (Shap)")
296
- gr.HTML('</div>')
297
 
298
- # Manual predict (does NOT change sliders)
299
  btn_predict.click(
300
  fn=predict,
301
  inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload],
302
  outputs=[headline, drivers_plot, shap_plot],
303
  )
304
 
305
- # Button: At risk group (updates sliders + outputs)
306
  btn_atrisk.click(
307
  fn=load_at_risk_group,
308
  inputs=[],
309
  outputs=[
310
- Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload,
311
- headline, drivers_plot, shap_plot
 
 
 
 
 
 
 
 
312
  ],
313
  )
314
 
315
- # Button: Apply recommendation (updates sliders + outputs)
316
  btn_reco.click(
317
  fn=apply_recommendation,
318
  inputs=[],
319
  outputs=[
320
- Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload,
321
- headline, drivers_plot, shap_plot
 
 
 
 
 
 
 
 
322
  ],
323
  )
324
 
 
78
  "Workload",
79
  ]
80
 
81
+
82
  # =========================
83
  # Helpers
84
  # =========================
85
  def clamp_1_5(x):
86
  return max(1.0, min(5.0, float(x)))
87
 
88
+
89
  def build_X(vals: dict) -> pd.DataFrame:
90
  row = {f: vals[f] for f in FEATURES}
91
  return pd.DataFrame([[row[f] for f in FEATURES]], columns=FEATURES)
92
 
93
+
94
  def prob_at_risk(X: pd.DataFrame) -> float:
95
  probs = model.predict_proba(X)[0]
96
  classes = list(model.classes_)
97
  idx = classes.index(1) # class 1 = At Risk
98
  return float(probs[idx])
99
 
100
+
101
  def risk_label(p: float) -> str:
102
  return "At Risk" if p >= 0.5 else "Not At Risk"
103
 
104
+
105
+ def not_at_risk_threshold_from_cluster3():
106
+ """
107
+ As requested:
108
+ threshold line = MIN of Cluster 3 across the driver vars in the averages chart,
109
+ labeled "Not at-risk threshold".
110
+ """
111
+ return min(CLUSTER_3[v] for v in ALL_DRIVER_VARS)
112
+
113
+
114
  # =========================
115
+ # Plot: Average of key drivers + threshold
 
116
  # =========================
117
  def make_driver_plot(driver_vals: dict):
118
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
119
+ th = not_at_risk_threshold_from_cluster3()
120
 
121
+ fig, ax = plt.subplots(figsize=(8.6, 3.1))
122
  ax.bar(ALL_DRIVER_LABELS, values)
123
 
124
+ # threshold line
125
+ ax.axhline(th, linestyle="--", linewidth=2)
126
+ ax.text(len(ALL_DRIVER_LABELS) - 0.1, th, "Not at-risk threshold", va="center", ha="right")
127
+
128
  ax.set_ylim(1, 5.4)
129
  ax.set_yticks([1, 2, 3, 4, 5])
130
  ax.set_ylabel("Score (1–5)")
131
  ax.set_title("Average of key drivers")
132
 
133
+ ax.margins(x=0.06)
134
  plt.tight_layout()
135
  plt.subplots_adjust(bottom=0.30)
136
  return fig
137
 
138
+
139
  # =========================
140
  # TRUE SHAP using CatBoost native SHAP values
141
  # =========================
142
  def make_catboost_shap_plot(X: pd.DataFrame):
143
+ fig, ax = plt.subplots(figsize=(8.6, 3.1))
144
 
145
  try:
146
  from catboost import Pool
 
150
  contrib = shap_vals[0, :-1] # drop expected value
151
 
152
  s = pd.Series(contrib, index=X.columns)
153
+ s = s.drop(labels=["ManagementLevel"], errors="ignore")
154
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
155
 
156
  ax.barh(s.index[::-1], s.values[::-1])
 
161
 
162
  except Exception as e:
163
  ax.text(
164
+ 0.5,
165
+ 0.55,
166
  "SHAP chart unavailable.\nInstall 'catboost' in requirements.txt.",
167
+ ha="center",
168
+ va="center",
169
+ fontsize=10,
170
  )
171
  ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
172
  ax.set_axis_off()
173
  plt.tight_layout()
174
  return fig
175
 
176
+
177
  # =========================
178
  # Core predict
179
  # =========================
 
207
 
208
  return headline, drivers_fig, shap_fig
209
 
210
+
211
  # =========================
212
  # Buttons
213
  # =========================
214
  def load_at_risk_group():
215
+ # At risk group = average of Cluster 1 and Cluster 2
216
  avg = {v: (CLUSTER_1[v] + CLUSTER_2[v]) / 2.0 for v in ALL_DRIVER_VARS}
217
 
218
  headline, drivers_fig, shap_fig = predict(
 
238
  shap_fig,
239
  )
240
 
241
+
242
  def apply_recommendation():
243
  # Apply recommendation = move to Cluster 1 target levels
244
  target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
 
266
  shap_fig,
267
  )
268
 
269
+
270
  # =========================
271
+ # UI Layout (tight, no big gaps)
 
 
 
272
  # =========================
273
  CSS = """
274
  #app-wrap { max-width: 1200px; margin: 0 auto; }
275
 
276
+ /* Remove extra padding/margins from blocks */
277
+ .gr-block { padding: 10px 12px !important; }
278
+ .gr-form { gap: 8px !important; }
279
+ .gr-row { gap: 10px !important; }
280
+
281
+ /* Make plot containers stable but NOT huge (reduces empty space) */
282
+ .fixed-plot { height: 330px; overflow: hidden; }
283
 
284
+ /* Make markdown tighter */
285
+ .compact h2 { margin: 0 0 6px 0; }
286
+ .compact p { margin: 0 0 8px 0; }
287
  """
288
 
289
  with gr.Blocks(css=CSS) as demo:
290
  gr.Markdown(
291
  "<div id='app-wrap' class='compact'>"
292
  "<h2>Retention Simulator</h2>"
293
+ "<p>Use sliders + <b>Predict</b>, or click <b>At risk group</b> / <b>Apply recommendation</b>.</p>"
 
294
  "</div>"
295
  )
296
 
297
  with gr.Row():
298
  # LEFT: sliders + buttons
299
  with gr.Column(scale=5, min_width=430):
 
300
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
301
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
302
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
 
310
  btn_atrisk = gr.Button("At risk group")
311
  btn_reco = gr.Button("Apply recommendation")
312
 
313
+ # RIGHT: result + two plots stacked
314
  with gr.Column(scale=7, min_width=520):
315
  headline = gr.Textbox(label="Result", value="", interactive=False)
316
 
317
  gr.HTML('<div class="fixed-plot">')
318
  drivers_plot = gr.Plot(label="Average of key drivers")
319
+ gr.HTML("</div>")
320
 
321
  gr.HTML('<div class="fixed-plot">')
322
  shap_plot = gr.Plot(label="Feature Importance (Shap)")
323
+ gr.HTML("</div>")
324
 
325
+ # Predict (does NOT change sliders)
326
  btn_predict.click(
327
  fn=predict,
328
  inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload],
329
  outputs=[headline, drivers_plot, shap_plot],
330
  )
331
 
332
+ # At risk group (updates sliders + outputs)
333
  btn_atrisk.click(
334
  fn=load_at_risk_group,
335
  inputs=[],
336
  outputs=[
337
+ Engagement,
338
+ SupportiveGM,
339
+ WellBeing,
340
+ WorkEnvironment,
341
+ Voice,
342
+ DecisionAutonomy,
343
+ Workload,
344
+ headline,
345
+ drivers_plot,
346
+ shap_plot,
347
  ],
348
  )
349
 
350
+ # Apply recommendation (updates sliders + outputs)
351
  btn_reco.click(
352
  fn=apply_recommendation,
353
  inputs=[],
354
  outputs=[
355
+ Engagement,
356
+ SupportiveGM,
357
+ WellBeing,
358
+ WorkEnvironment,
359
+ Voice,
360
+ DecisionAutonomy,
361
+ Workload,
362
+ headline,
363
+ drivers_plot,
364
+ shap_plot,
365
  ],
366
  )
367