mns6rh commited on
Commit
f418855
·
verified ·
1 Parent(s): 589c115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -22
app.py CHANGED
@@ -11,7 +11,7 @@ plt.rcParams["figure.dpi"] = 100
11
  # ============================================================
12
  # Global font styling
13
  # - All non-title text: Arial Black size 10
14
- # - Titles: Arial Black size 14 (unchanged from your spec)
15
  # ============================================================
16
  plt.rcParams["font.family"] = "Arial"
17
  plt.rcParams["font.weight"] = "black"
@@ -130,8 +130,24 @@ def hex_to_rgb01(h: str):
130
  return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # =========================
134
  # Donut chart (Turnover Risk)
 
 
135
  # =========================
136
  def make_turnover_donut(p_leave: float, p_stay: float):
137
  p_leave = max(0.0, min(1.0, float(p_leave)))
@@ -143,7 +159,8 @@ def make_turnover_donut(p_leave: float, p_stay: float):
143
  else:
144
  p_leave, p_stay = p_leave / s, p_stay / s
145
 
146
- fig, ax = plt.subplots(figsize=(4.8, 3.4))
 
147
 
148
  ax.pie(
149
  [p_leave, p_stay],
@@ -166,31 +183,39 @@ def make_turnover_donut(p_leave: float, p_stay: float):
166
  family="Arial",
167
  )
168
 
169
- ax.legend(
170
  ["Probability of Turnover", "Probability of Retention"],
171
  loc="center left",
172
  bbox_to_anchor=(1.02, 0.5),
173
  frameon=False,
174
  prop={"family": "Arial", "weight": "black", "size": TEXT_FONTSIZE},
 
 
175
  )
 
 
 
 
176
 
177
- ax.set_title("Turnover Risk", pad=10, fontweight="black", fontsize=TITLE_FONTSIZE, family="Arial")
178
  ax.set_aspect("equal")
179
 
 
180
  plt.tight_layout()
 
 
181
  plt.close(fig)
182
  return fig
183
 
184
 
185
  # =========================
186
  # Plot: Average of key drivers + Goal Average
187
- # (Longer + intended for bottom full-width placement)
188
  # =========================
189
  def make_driver_plot(driver_vals: dict):
190
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
191
  goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
192
 
193
- fig, ax = plt.subplots(figsize=(13.0, 3.2)) # longer/wider
194
 
195
  bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
196
 
@@ -214,9 +239,11 @@ def make_driver_plot(driver_vals: dict):
214
 
215
  ax.set_ylim(1, 5.4)
216
  ax.set_yticks([1, 2, 3, 4, 5])
217
- ax.set_ylabel("Score (1–5)", fontsize=TEXT_FONTSIZE, fontweight="black", family="Arial")
218
  ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
219
 
 
 
220
  ax.margins(x=0.06)
221
  plt.tight_layout()
222
  plt.subplots_adjust(bottom=0.28)
@@ -227,16 +254,17 @@ def make_driver_plot(driver_vals: dict):
227
 
228
  # =========================
229
  # SHAP waterfall
230
- # Make positive = RED, negative = BLUE
 
231
  # =========================
232
  def make_catboost_shap_plot(X: pd.DataFrame):
233
- fig, ax = plt.subplots(figsize=(8.6, 3.6))
234
 
235
  try:
236
  import shap
237
  from catboost import Pool
238
 
239
- # force SHAP waterfall color theme (pos=red, neg=blue)
240
  try:
241
  shap.plots.colors.red_rgb = hex_to_rgb01(RED)
242
  shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
@@ -282,13 +310,12 @@ def make_catboost_shap_plot(X: pd.DataFrame):
282
  shap.plots.waterfall(exp, max_display=8, show=False)
283
 
284
  fig2 = plt.gcf()
285
- fig2.set_size_inches(8.6, 3.6)
286
 
287
- # enforce tick font sizes
288
  try:
289
  ax2 = fig2.axes[0]
290
- if ax2.get_title():
291
- ax2.set_title(ax2.get_title(), fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
292
  for t in ax2.get_xticklabels() + ax2.get_yticklabels():
293
  t.set_fontsize(TEXT_FONTSIZE)
294
  t.set_fontweight("black")
@@ -325,7 +352,7 @@ def make_catboost_shap_plot(X: pd.DataFrame):
325
 
326
  # =========================
327
  # Core predict
328
- # Returns: donut, shap, drivers (so we can lay them out how you want)
329
  # =========================
330
  def predict(
331
  Engagement,
@@ -419,8 +446,11 @@ def apply_recommendation():
419
 
420
  # =========================
421
  # UI Layout
422
- # - Right side: donut then SHAP directly below
423
- # - Bottom: wide drivers bar chart centered under everything
 
 
 
424
  # =========================
425
  CSS = f"""
426
  #app-wrap {{ max-width: 1200px; margin: 0 auto; }}
@@ -430,6 +460,9 @@ CSS = f"""
430
 
431
  #btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
432
  #btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
 
 
 
433
  """
434
 
435
  with gr.Blocks(css=CSS) as demo:
@@ -440,15 +473,14 @@ with gr.Blocks(css=CSS) as demo:
440
  "</div>"
441
  )
442
 
443
- # TOP BUTTON ROW
444
  with gr.Row():
445
  btn_atrisk = gr.Button("At risk group", elem_id="btn_atrisk")
446
  btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
447
- btn_predict = gr.Button("Predict")
448
 
449
- # TOP MAIN ROW
450
  with gr.Row():
451
- # LEFT: sliders
452
  with gr.Column(scale=4):
453
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
454
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
@@ -458,7 +490,9 @@ with gr.Blocks(css=CSS) as demo:
458
  DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
459
  Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
460
 
461
- # RIGHT: donut + SHAP stacked
 
 
462
  with gr.Column(scale=6):
463
  donut_plot = gr.Plot(label="Turnover Risk")
464
  shap_plot = gr.Plot(label="Feature Importance (Shap)")
 
11
  # ============================================================
12
  # Global font styling
13
  # - All non-title text: Arial Black size 10
14
+ # - Titles: Arial Black size 14
15
  # ============================================================
16
  plt.rcParams["font.family"] = "Arial"
17
  plt.rcParams["font.weight"] = "black"
 
130
  return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
131
 
132
 
133
+ def stylize_axis_fonts(ax):
134
+ # Apply consistent fonts to ticks/labels
135
+ for t in ax.get_xticklabels() + ax.get_yticklabels():
136
+ t.set_fontsize(TEXT_FONTSIZE)
137
+ t.set_fontweight("black")
138
+ t.set_family("Arial")
139
+ ax.xaxis.label.set_fontsize(TEXT_FONTSIZE)
140
+ ax.xaxis.label.set_fontweight("black")
141
+ ax.xaxis.label.set_family("Arial")
142
+ ax.yaxis.label.set_fontsize(TEXT_FONTSIZE)
143
+ ax.yaxis.label.set_fontweight("black")
144
+ ax.yaxis.label.set_family("Arial")
145
+
146
+
147
  # =========================
148
  # Donut chart (Turnover Risk)
149
+ # - Legend made visible by widening figure + right margin
150
+ # - No center subtitle text
151
  # =========================
152
  def make_turnover_donut(p_leave: float, p_stay: float):
153
  p_leave = max(0.0, min(1.0, float(p_leave)))
 
159
  else:
160
  p_leave, p_stay = p_leave / s, p_stay / s
161
 
162
+ # Wider so legend isn't clipped
163
+ fig, ax = plt.subplots(figsize=(8.6, 3.2))
164
 
165
  ax.pie(
166
  [p_leave, p_stay],
 
183
  family="Arial",
184
  )
185
 
186
+ leg = ax.legend(
187
  ["Probability of Turnover", "Probability of Retention"],
188
  loc="center left",
189
  bbox_to_anchor=(1.02, 0.5),
190
  frameon=False,
191
  prop={"family": "Arial", "weight": "black", "size": TEXT_FONTSIZE},
192
+ handlelength=1.2,
193
+ borderaxespad=0.0,
194
  )
195
+ for txt in leg.get_texts():
196
+ txt.set_fontsize(TEXT_FONTSIZE)
197
+ txt.set_fontweight("black")
198
+ txt.set_family("Arial")
199
 
200
+ ax.set_title("Turnover Risk", pad=8, fontweight="black", fontsize=TITLE_FONTSIZE, family="Arial")
201
  ax.set_aspect("equal")
202
 
203
+ # Leave room on the right for legend
204
  plt.tight_layout()
205
+ plt.subplots_adjust(right=0.80)
206
+
207
  plt.close(fig)
208
  return fig
209
 
210
 
211
  # =========================
212
  # Plot: Average of key drivers + Goal Average
 
213
  # =========================
214
  def make_driver_plot(driver_vals: dict):
215
  values = [driver_vals[v] for v in ALL_DRIVER_VARS]
216
  goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
217
 
218
+ fig, ax = plt.subplots(figsize=(13.0, 3.1)) # long/wide
219
 
220
  bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
221
 
 
239
 
240
  ax.set_ylim(1, 5.4)
241
  ax.set_yticks([1, 2, 3, 4, 5])
242
+ ax.set_ylabel("Score (1–5)")
243
  ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
244
 
245
+ stylize_axis_fonts(ax)
246
+
247
  ax.margins(x=0.06)
248
  plt.tight_layout()
249
  plt.subplots_adjust(bottom=0.28)
 
254
 
255
  # =========================
256
  # SHAP waterfall
257
+ # - Force positive = RED, negative = BLUE
258
+ # - Add title: "Feature Importance (Shap)"
259
  # =========================
260
  def make_catboost_shap_plot(X: pd.DataFrame):
261
+ fig, ax = plt.subplots(figsize=(8.6, 3.2))
262
 
263
  try:
264
  import shap
265
  from catboost import Pool
266
 
267
+ # Try to force SHAP waterfall palette to red/blue
268
  try:
269
  shap.plots.colors.red_rgb = hex_to_rgb01(RED)
270
  shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
 
310
  shap.plots.waterfall(exp, max_display=8, show=False)
311
 
312
  fig2 = plt.gcf()
313
+ fig2.set_size_inches(8.6, 3.2)
314
 
315
+ # Put a clean title on the SHAP plot
316
  try:
317
  ax2 = fig2.axes[0]
318
+ ax2.set_title("Feature Importance (Shap)", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
 
319
  for t in ax2.get_xticklabels() + ax2.get_yticklabels():
320
  t.set_fontsize(TEXT_FONTSIZE)
321
  t.set_fontweight("black")
 
352
 
353
  # =========================
354
  # Core predict
355
+ # Returns: donut, shap, drivers
356
  # =========================
357
  def predict(
358
  Engagement,
 
446
 
447
  # =========================
448
  # UI Layout
449
+ # - Top: At risk + Apply recommendation buttons
450
+ # - Left: sliders, then Predict button (under sliders)
451
+ # - Right: donut then SHAP directly below (same width)
452
+ # - Bottom: wide drivers bar chart centered
453
+ # - Sliders column a bit shorter to fit Predict button cleanly
454
  # =========================
455
  CSS = f"""
456
  #app-wrap {{ max-width: 1200px; margin: 0 auto; }}
 
460
 
461
  #btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
462
  #btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
463
+
464
+ /* Tighten slider spacing a bit so Predict fits nicely */
465
+ .gr-slider {{ margin-bottom: 6px !important; }}
466
  """
467
 
468
  with gr.Blocks(css=CSS) as demo:
 
473
  "</div>"
474
  )
475
 
476
+ # TOP BUTTON ROW (keep these at top)
477
  with gr.Row():
478
  btn_atrisk = gr.Button("At risk group", elem_id="btn_atrisk")
479
  btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
 
480
 
481
+ # MAIN ROW
482
  with gr.Row():
483
+ # LEFT: sliders + Predict (under sliders)
484
  with gr.Column(scale=4):
485
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
486
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
 
490
  DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
491
  Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
492
 
493
+ btn_predict = gr.Button("Predict")
494
+
495
+ # RIGHT: donut + SHAP stacked (same width)
496
  with gr.Column(scale=6):
497
  donut_plot = gr.Plot(label="Turnover Risk")
498
  shap_plot = gr.Plot(label="Feature Importance (Shap)")