mns6rh commited on
Commit
589c115
·
verified ·
1 Parent(s): cd33557

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -43
app.py CHANGED
@@ -10,15 +10,15 @@ plt.rcParams["figure.dpi"] = 100
10
 
11
  # ============================================================
12
  # Global font styling
13
- # - All non-title text: Arial Black size 12
14
- # - Titles: Arial Black size 14
15
  # ============================================================
16
  plt.rcParams["font.family"] = "Arial"
17
  plt.rcParams["font.weight"] = "black"
18
- plt.rcParams["font.size"] = 12
19
 
20
  TITLE_FONTSIZE = 14
21
- TEXT_FONTSIZE = 12
22
 
23
  # ============================================================
24
  # Class meaning (per your note)
@@ -125,13 +125,13 @@ def prob_leave_and_stay(X: pd.DataFrame) -> tuple[float, float]:
125
  return p_leave, p_stay
126
 
127
 
 
 
 
 
 
128
  # =========================
129
- # Donut chart
130
- # Requirements implemented:
131
- # 1) remove "Probability of Leave" center text (keep only %)
132
- # 2) title = "Turnover Risk"
133
- # 3) legend: red turnover, blue retention
134
- # 4) fonts set globally; title uses size 14
135
  # =========================
136
  def make_turnover_donut(p_leave: float, p_stay: float):
137
  p_leave = max(0.0, min(1.0, float(p_leave)))
@@ -143,7 +143,7 @@ def make_turnover_donut(p_leave: float, p_stay: float):
143
  else:
144
  p_leave, p_stay = p_leave / s, p_stay / s
145
 
146
- fig, ax = plt.subplots(figsize=(4.6, 3.6))
147
 
148
  ax.pie(
149
  [p_leave, p_stay],
@@ -153,14 +153,14 @@ def make_turnover_donut(p_leave: float, p_stay: float):
153
  wedgeprops=dict(width=0.35, edgecolor="white"),
154
  )
155
 
156
- # Center: ONLY percent (no "Probability of Leave")
157
  ax.text(
158
  0,
159
  0.00,
160
  f"{p_leave*100:.0f}%",
161
  ha="center",
162
  va="center",
163
- fontsize=20,
164
  fontweight="black",
165
  color=RED,
166
  family="Arial",
@@ -184,12 +184,14 @@ def make_turnover_donut(p_leave: float, p_stay: float):
184
 
185
  # =========================
186
  # Plot: Average of key drivers + Goal Average
 
187
  # =========================
188
  def make_driver_plot(driver_vals: dict):
189
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
190
  goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
191
 
192
- fig, ax = plt.subplots(figsize=(8.6, 3.1))
 
193
  bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
194
 
195
  # dashed goal segment per bar
@@ -217,7 +219,7 @@ def make_driver_plot(driver_vals: dict):
217
 
218
  ax.margins(x=0.06)
219
  plt.tight_layout()
220
- plt.subplots_adjust(bottom=0.30)
221
 
222
  plt.close(fig)
223
  return fig
@@ -225,7 +227,7 @@ def make_driver_plot(driver_vals: dict):
225
 
226
  # =========================
227
  # SHAP waterfall
228
- # Make "pink/red" theme consistent by setting SHAP colors to red/blue.
229
  # =========================
230
  def make_catboost_shap_plot(X: pd.DataFrame):
231
  fig, ax = plt.subplots(figsize=(8.6, 3.6))
@@ -234,10 +236,10 @@ def make_catboost_shap_plot(X: pd.DataFrame):
234
  import shap
235
  from catboost import Pool
236
 
237
- # Set SHAP waterfall colors (pos = turnover = red; neg = retention = blue)
238
- # These affect shap.plots.waterfall styling.
239
  try:
240
- shap.plots.colors.red_blue = [BLUE, RED] # keep fallback attempt harmless
 
241
  except Exception:
242
  pass
243
 
@@ -277,13 +279,12 @@ def make_catboost_shap_plot(X: pd.DataFrame):
277
 
278
  plt.close(fig)
279
 
280
- # SHAP will create its own matplotlib figure
281
  shap.plots.waterfall(exp, max_display=8, show=False)
282
 
283
  fig2 = plt.gcf()
284
  fig2.set_size_inches(8.6, 3.6)
285
 
286
- # Make SHAP title font consistent (if present)
287
  try:
288
  ax2 = fig2.axes[0]
289
  if ax2.get_title():
@@ -324,6 +325,7 @@ def make_catboost_shap_plot(X: pd.DataFrame):
324
 
325
  # =========================
326
  # Core predict
 
327
  # =========================
328
  def predict(
329
  Engagement,
@@ -350,10 +352,10 @@ def predict(
350
  p_leave, p_stay = prob_leave_and_stay(X)
351
 
352
  donut_fig = make_turnover_donut(p_leave, p_stay)
353
- drivers_fig = make_driver_plot(driver_vals)
354
  shap_fig = make_catboost_shap_plot(X)
 
355
 
356
- return donut_fig, drivers_fig, shap_fig
357
 
358
 
359
  # =========================
@@ -363,7 +365,7 @@ def load_at_risk_group():
363
  # At-risk group = average of the two lowest clusters (Cluster 2 and 3)
364
  target = {v: (CLUSTER_2[v] + CLUSTER_3[v]) / 2.0 for v in ALL_DRIVER_VARS}
365
 
366
- donut_fig, drivers_fig, shap_fig = predict(
367
  target["Engagement"],
368
  target["SupportiveGM"],
369
  target["WellBeing"],
@@ -382,8 +384,8 @@ def load_at_risk_group():
382
  target["DecisionAutonomy"],
383
  target["Workload"],
384
  donut_fig,
385
- drivers_fig,
386
  shap_fig,
 
387
  )
388
 
389
 
@@ -391,7 +393,7 @@ def apply_recommendation():
391
  # Recommendation = highest cluster (Cluster 1)
392
  target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
393
 
394
- donut_fig, drivers_fig, shap_fig = predict(
395
  target["Engagement"],
396
  target["SupportiveGM"],
397
  target["WellBeing"],
@@ -410,31 +412,24 @@ def apply_recommendation():
410
  target["DecisionAutonomy"],
411
  target["Workload"],
412
  donut_fig,
413
- drivers_fig,
414
  shap_fig,
 
415
  )
416
 
417
 
418
  # =========================
419
- # UI Layout (tight, no big gaps)
420
- # Buttons at the TOP:
421
- # - At risk group = red
422
- # - Apply recommendation = blue
423
  # =========================
424
  CSS = f"""
425
  #app-wrap {{ max-width: 1200px; margin: 0 auto; }}
426
- /* Remove extra padding/margins from blocks */
427
  .gr-block {{ padding: 10px 12px !important; }}
428
  .gr-form {{ gap: 8px !important; }}
429
  .gr-row {{ gap: 10px !important; }}
430
 
431
- /* Button colors */
432
  #btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
433
  #btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
434
-
435
- /* Make markdown tighter */
436
- .compact h2 {{ margin: 0 0 6px 0; }}
437
- .compact p {{ margin: 0 0 8px 0; }}
438
  """
439
 
440
  with gr.Blocks(css=CSS) as demo:
@@ -451,8 +446,9 @@ with gr.Blocks(css=CSS) as demo:
451
  btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
452
  btn_predict = gr.Button("Predict")
453
 
 
454
  with gr.Row():
455
- # LEFT SIDE — SLIDERS
456
  with gr.Column(scale=4):
457
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
458
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
@@ -462,12 +458,15 @@ with gr.Blocks(css=CSS) as demo:
462
  DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
463
  Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
464
 
465
- # RIGHT SIDE — DONUT + CHARTS
466
  with gr.Column(scale=6):
467
  donut_plot = gr.Plot(label="Turnover Risk")
468
- drivers_plot = gr.Plot(label="Average of key drivers")
469
  shap_plot = gr.Plot(label="Feature Importance (Shap)")
470
 
 
 
 
 
471
  # =========================
472
  # WIRE UP EVENTS
473
  # =========================
@@ -476,7 +475,7 @@ with gr.Blocks(css=CSS) as demo:
476
  btn_predict.click(
477
  fn=predict,
478
  inputs=slider_inputs,
479
- outputs=[donut_plot, drivers_plot, shap_plot],
480
  )
481
 
482
  btn_atrisk.click(
@@ -491,8 +490,8 @@ with gr.Blocks(css=CSS) as demo:
491
  DecisionAutonomy,
492
  Workload,
493
  donut_plot,
494
- drivers_plot,
495
  shap_plot,
 
496
  ],
497
  )
498
 
@@ -508,8 +507,8 @@ with gr.Blocks(css=CSS) as demo:
508
  DecisionAutonomy,
509
  Workload,
510
  donut_plot,
511
- drivers_plot,
512
  shap_plot,
 
513
  ],
514
  )
515
 
 
10
 
11
  # ============================================================
12
  # Global font styling
13
+ # - All non-title text: Arial Black size 10
14
+ # - Titles: Arial Black size 14 (unchanged from your spec)
15
  # ============================================================
16
  plt.rcParams["font.family"] = "Arial"
17
  plt.rcParams["font.weight"] = "black"
18
+ plt.rcParams["font.size"] = 10
19
 
20
  TITLE_FONTSIZE = 14
21
+ TEXT_FONTSIZE = 10
22
 
23
  # ============================================================
24
  # Class meaning (per your note)
 
125
  return p_leave, p_stay
126
 
127
 
128
+ def hex_to_rgb01(h: str):
129
+ h = h.lstrip("#")
130
+ return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
131
+
132
+
133
  # =========================
134
+ # Donut chart (Turnover Risk)
 
 
 
 
 
135
  # =========================
136
  def make_turnover_donut(p_leave: float, p_stay: float):
137
  p_leave = max(0.0, min(1.0, float(p_leave)))
 
143
  else:
144
  p_leave, p_stay = p_leave / s, p_stay / s
145
 
146
+ fig, ax = plt.subplots(figsize=(4.8, 3.4))
147
 
148
  ax.pie(
149
  [p_leave, p_stay],
 
153
  wedgeprops=dict(width=0.35, edgecolor="white"),
154
  )
155
 
156
+ # Center: ONLY percent
157
  ax.text(
158
  0,
159
  0.00,
160
  f"{p_leave*100:.0f}%",
161
  ha="center",
162
  va="center",
163
+ fontsize=18,
164
  fontweight="black",
165
  color=RED,
166
  family="Arial",
 
184
 
185
  # =========================
186
  # Plot: Average of key drivers + Goal Average
187
+ # (Longer + intended for bottom full-width placement)
188
  # =========================
189
  def make_driver_plot(driver_vals: dict):
190
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
191
  goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
192
 
193
+ fig, ax = plt.subplots(figsize=(13.0, 3.2)) # longer/wider
194
+
195
  bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
196
 
197
  # dashed goal segment per bar
 
219
 
220
  ax.margins(x=0.06)
221
  plt.tight_layout()
222
+ plt.subplots_adjust(bottom=0.28)
223
 
224
  plt.close(fig)
225
  return fig
 
227
 
228
  # =========================
229
  # SHAP waterfall
230
+ # Make positive = RED, negative = BLUE
231
  # =========================
232
  def make_catboost_shap_plot(X: pd.DataFrame):
233
  fig, ax = plt.subplots(figsize=(8.6, 3.6))
 
236
  import shap
237
  from catboost import Pool
238
 
239
+ # force SHAP waterfall color theme (pos=red, neg=blue)
 
240
  try:
241
+ shap.plots.colors.red_rgb = hex_to_rgb01(RED)
242
+ shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
243
  except Exception:
244
  pass
245
 
 
279
 
280
  plt.close(fig)
281
 
 
282
  shap.plots.waterfall(exp, max_display=8, show=False)
283
 
284
  fig2 = plt.gcf()
285
  fig2.set_size_inches(8.6, 3.6)
286
 
287
+ # enforce tick font sizes
288
  try:
289
  ax2 = fig2.axes[0]
290
  if ax2.get_title():
 
325
 
326
  # =========================
327
  # Core predict
328
+ # Returns: donut, shap, drivers (so we can lay them out how you want)
329
  # =========================
330
  def predict(
331
  Engagement,
 
352
  p_leave, p_stay = prob_leave_and_stay(X)
353
 
354
  donut_fig = make_turnover_donut(p_leave, p_stay)
 
355
  shap_fig = make_catboost_shap_plot(X)
356
+ drivers_fig = make_driver_plot(driver_vals)
357
 
358
+ return donut_fig, shap_fig, drivers_fig
359
 
360
 
361
  # =========================
 
365
  # At-risk group = average of the two lowest clusters (Cluster 2 and 3)
366
  target = {v: (CLUSTER_2[v] + CLUSTER_3[v]) / 2.0 for v in ALL_DRIVER_VARS}
367
 
368
+ donut_fig, shap_fig, drivers_fig = predict(
369
  target["Engagement"],
370
  target["SupportiveGM"],
371
  target["WellBeing"],
 
384
  target["DecisionAutonomy"],
385
  target["Workload"],
386
  donut_fig,
 
387
  shap_fig,
388
+ drivers_fig,
389
  )
390
 
391
 
 
393
  # Recommendation = highest cluster (Cluster 1)
394
  target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
395
 
396
+ donut_fig, shap_fig, drivers_fig = predict(
397
  target["Engagement"],
398
  target["SupportiveGM"],
399
  target["WellBeing"],
 
412
  target["DecisionAutonomy"],
413
  target["Workload"],
414
  donut_fig,
 
415
  shap_fig,
416
+ drivers_fig,
417
  )
418
 
419
 
420
  # =========================
421
+ # UI Layout
422
+ # - Right side: donut then SHAP directly below
423
+ # - Bottom: wide drivers bar chart centered under everything
 
424
  # =========================
425
  CSS = f"""
426
  #app-wrap {{ max-width: 1200px; margin: 0 auto; }}
 
427
  .gr-block {{ padding: 10px 12px !important; }}
428
  .gr-form {{ gap: 8px !important; }}
429
  .gr-row {{ gap: 10px !important; }}
430
 
 
431
  #btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
432
  #btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
 
 
 
 
433
  """
434
 
435
  with gr.Blocks(css=CSS) as demo:
 
446
  btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
447
  btn_predict = gr.Button("Predict")
448
 
449
+ # TOP MAIN ROW
450
  with gr.Row():
451
+ # LEFT: sliders
452
  with gr.Column(scale=4):
453
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
454
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
 
458
  DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
459
  Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
460
 
461
+ # RIGHT: donut + SHAP stacked
462
  with gr.Column(scale=6):
463
  donut_plot = gr.Plot(label="Turnover Risk")
 
464
  shap_plot = gr.Plot(label="Feature Importance (Shap)")
465
 
466
+ # BOTTOM ROW: wide driver bars centered under everything
467
+ with gr.Row():
468
+ drivers_plot = gr.Plot(label="Average of key drivers")
469
+
470
  # =========================
471
  # WIRE UP EVENTS
472
  # =========================
 
475
  btn_predict.click(
476
  fn=predict,
477
  inputs=slider_inputs,
478
+ outputs=[donut_plot, shap_plot, drivers_plot],
479
  )
480
 
481
  btn_atrisk.click(
 
490
  DecisionAutonomy,
491
  Workload,
492
  donut_plot,
 
493
  shap_plot,
494
+ drivers_plot,
495
  ],
496
  )
497
 
 
507
  DecisionAutonomy,
508
  Workload,
509
  donut_plot,
 
510
  shap_plot,
511
+ drivers_plot,
512
  ],
513
  )
514