Update app.py
Browse files
app.py
CHANGED
|
@@ -10,15 +10,15 @@ plt.rcParams["figure.dpi"] = 100
|
|
| 10 |
|
| 11 |
# ============================================================
|
| 12 |
# Global font styling
|
| 13 |
-
# - All non-title text: Arial Black size
|
| 14 |
-
# - Titles: Arial Black size 14
|
| 15 |
# ============================================================
|
| 16 |
plt.rcParams["font.family"] = "Arial"
|
| 17 |
plt.rcParams["font.weight"] = "black"
|
| 18 |
-
plt.rcParams["font.size"] =
|
| 19 |
|
| 20 |
TITLE_FONTSIZE = 14
|
| 21 |
-
TEXT_FONTSIZE =
|
| 22 |
|
| 23 |
# ============================================================
|
| 24 |
# Class meaning (per your note)
|
|
@@ -125,13 +125,13 @@ def prob_leave_and_stay(X: pd.DataFrame) -> tuple[float, float]:
|
|
| 125 |
return p_leave, p_stay
|
| 126 |
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# =========================
|
| 129 |
-
# Donut chart
|
| 130 |
-
# Requirements implemented:
|
| 131 |
-
# 1) remove "Probability of Leave" center text (keep only %)
|
| 132 |
-
# 2) title = "Turnover Risk"
|
| 133 |
-
# 3) legend: red turnover, blue retention
|
| 134 |
-
# 4) fonts set globally; title uses size 14
|
| 135 |
# =========================
|
| 136 |
def make_turnover_donut(p_leave: float, p_stay: float):
|
| 137 |
p_leave = max(0.0, min(1.0, float(p_leave)))
|
|
@@ -143,7 +143,7 @@ def make_turnover_donut(p_leave: float, p_stay: float):
|
|
| 143 |
else:
|
| 144 |
p_leave, p_stay = p_leave / s, p_stay / s
|
| 145 |
|
| 146 |
-
fig, ax = plt.subplots(figsize=(4.
|
| 147 |
|
| 148 |
ax.pie(
|
| 149 |
[p_leave, p_stay],
|
|
@@ -153,14 +153,14 @@ def make_turnover_donut(p_leave: float, p_stay: float):
|
|
| 153 |
wedgeprops=dict(width=0.35, edgecolor="white"),
|
| 154 |
)
|
| 155 |
|
| 156 |
-
# Center: ONLY percent
|
| 157 |
ax.text(
|
| 158 |
0,
|
| 159 |
0.00,
|
| 160 |
f"{p_leave*100:.0f}%",
|
| 161 |
ha="center",
|
| 162 |
va="center",
|
| 163 |
-
fontsize=
|
| 164 |
fontweight="black",
|
| 165 |
color=RED,
|
| 166 |
family="Arial",
|
|
@@ -184,12 +184,14 @@ def make_turnover_donut(p_leave: float, p_stay: float):
|
|
| 184 |
|
| 185 |
# =========================
|
| 186 |
# Plot: Average of key drivers + Goal Average
|
|
|
|
| 187 |
# =========================
|
| 188 |
def make_driver_plot(driver_vals: dict):
|
| 189 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 190 |
goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
|
| 191 |
|
| 192 |
-
fig, ax = plt.subplots(figsize=(
|
|
|
|
| 193 |
bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
|
| 194 |
|
| 195 |
# dashed goal segment per bar
|
|
@@ -217,7 +219,7 @@ def make_driver_plot(driver_vals: dict):
|
|
| 217 |
|
| 218 |
ax.margins(x=0.06)
|
| 219 |
plt.tight_layout()
|
| 220 |
-
plt.subplots_adjust(bottom=0.
|
| 221 |
|
| 222 |
plt.close(fig)
|
| 223 |
return fig
|
|
@@ -225,7 +227,7 @@ def make_driver_plot(driver_vals: dict):
|
|
| 225 |
|
| 226 |
# =========================
|
| 227 |
# SHAP waterfall
|
| 228 |
-
# Make
|
| 229 |
# =========================
|
| 230 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 231 |
fig, ax = plt.subplots(figsize=(8.6, 3.6))
|
|
@@ -234,10 +236,10 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 234 |
import shap
|
| 235 |
from catboost import Pool
|
| 236 |
|
| 237 |
-
#
|
| 238 |
-
# These affect shap.plots.waterfall styling.
|
| 239 |
try:
|
| 240 |
-
shap.plots.colors.
|
|
|
|
| 241 |
except Exception:
|
| 242 |
pass
|
| 243 |
|
|
@@ -277,13 +279,12 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 277 |
|
| 278 |
plt.close(fig)
|
| 279 |
|
| 280 |
-
# SHAP will create its own matplotlib figure
|
| 281 |
shap.plots.waterfall(exp, max_display=8, show=False)
|
| 282 |
|
| 283 |
fig2 = plt.gcf()
|
| 284 |
fig2.set_size_inches(8.6, 3.6)
|
| 285 |
|
| 286 |
-
#
|
| 287 |
try:
|
| 288 |
ax2 = fig2.axes[0]
|
| 289 |
if ax2.get_title():
|
|
@@ -324,6 +325,7 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 324 |
|
| 325 |
# =========================
|
| 326 |
# Core predict
|
|
|
|
| 327 |
# =========================
|
| 328 |
def predict(
|
| 329 |
Engagement,
|
|
@@ -350,10 +352,10 @@ def predict(
|
|
| 350 |
p_leave, p_stay = prob_leave_and_stay(X)
|
| 351 |
|
| 352 |
donut_fig = make_turnover_donut(p_leave, p_stay)
|
| 353 |
-
drivers_fig = make_driver_plot(driver_vals)
|
| 354 |
shap_fig = make_catboost_shap_plot(X)
|
|
|
|
| 355 |
|
| 356 |
-
return donut_fig,
|
| 357 |
|
| 358 |
|
| 359 |
# =========================
|
|
@@ -363,7 +365,7 @@ def load_at_risk_group():
|
|
| 363 |
# At-risk group = average of the two lowest clusters (Cluster 2 and 3)
|
| 364 |
target = {v: (CLUSTER_2[v] + CLUSTER_3[v]) / 2.0 for v in ALL_DRIVER_VARS}
|
| 365 |
|
| 366 |
-
donut_fig,
|
| 367 |
target["Engagement"],
|
| 368 |
target["SupportiveGM"],
|
| 369 |
target["WellBeing"],
|
|
@@ -382,8 +384,8 @@ def load_at_risk_group():
|
|
| 382 |
target["DecisionAutonomy"],
|
| 383 |
target["Workload"],
|
| 384 |
donut_fig,
|
| 385 |
-
drivers_fig,
|
| 386 |
shap_fig,
|
|
|
|
| 387 |
)
|
| 388 |
|
| 389 |
|
|
@@ -391,7 +393,7 @@ def apply_recommendation():
|
|
| 391 |
# Recommendation = highest cluster (Cluster 1)
|
| 392 |
target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
|
| 393 |
|
| 394 |
-
donut_fig,
|
| 395 |
target["Engagement"],
|
| 396 |
target["SupportiveGM"],
|
| 397 |
target["WellBeing"],
|
|
@@ -410,31 +412,24 @@ def apply_recommendation():
|
|
| 410 |
target["DecisionAutonomy"],
|
| 411 |
target["Workload"],
|
| 412 |
donut_fig,
|
| 413 |
-
drivers_fig,
|
| 414 |
shap_fig,
|
|
|
|
| 415 |
)
|
| 416 |
|
| 417 |
|
| 418 |
# =========================
|
| 419 |
-
# UI Layout
|
| 420 |
-
#
|
| 421 |
-
# -
|
| 422 |
-
# - Apply recommendation = blue
|
| 423 |
# =========================
|
| 424 |
CSS = f"""
|
| 425 |
#app-wrap {{ max-width: 1200px; margin: 0 auto; }}
|
| 426 |
-
/* Remove extra padding/margins from blocks */
|
| 427 |
.gr-block {{ padding: 10px 12px !important; }}
|
| 428 |
.gr-form {{ gap: 8px !important; }}
|
| 429 |
.gr-row {{ gap: 10px !important; }}
|
| 430 |
|
| 431 |
-
/* Button colors */
|
| 432 |
#btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
|
| 433 |
#btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
|
| 434 |
-
|
| 435 |
-
/* Make markdown tighter */
|
| 436 |
-
.compact h2 {{ margin: 0 0 6px 0; }}
|
| 437 |
-
.compact p {{ margin: 0 0 8px 0; }}
|
| 438 |
"""
|
| 439 |
|
| 440 |
with gr.Blocks(css=CSS) as demo:
|
|
@@ -451,8 +446,9 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 451 |
btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
|
| 452 |
btn_predict = gr.Button("Predict")
|
| 453 |
|
|
|
|
| 454 |
with gr.Row():
|
| 455 |
-
# LEFT
|
| 456 |
with gr.Column(scale=4):
|
| 457 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 458 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
|
@@ -462,12 +458,15 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 462 |
DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
|
| 463 |
Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
|
| 464 |
|
| 465 |
-
# RIGHT
|
| 466 |
with gr.Column(scale=6):
|
| 467 |
donut_plot = gr.Plot(label="Turnover Risk")
|
| 468 |
-
drivers_plot = gr.Plot(label="Average of key drivers")
|
| 469 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
# =========================
|
| 472 |
# WIRE UP EVENTS
|
| 473 |
# =========================
|
|
@@ -476,7 +475,7 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 476 |
btn_predict.click(
|
| 477 |
fn=predict,
|
| 478 |
inputs=slider_inputs,
|
| 479 |
-
outputs=[donut_plot,
|
| 480 |
)
|
| 481 |
|
| 482 |
btn_atrisk.click(
|
|
@@ -491,8 +490,8 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 491 |
DecisionAutonomy,
|
| 492 |
Workload,
|
| 493 |
donut_plot,
|
| 494 |
-
drivers_plot,
|
| 495 |
shap_plot,
|
|
|
|
| 496 |
],
|
| 497 |
)
|
| 498 |
|
|
@@ -508,8 +507,8 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 508 |
DecisionAutonomy,
|
| 509 |
Workload,
|
| 510 |
donut_plot,
|
| 511 |
-
drivers_plot,
|
| 512 |
shap_plot,
|
|
|
|
| 513 |
],
|
| 514 |
)
|
| 515 |
|
|
|
|
| 10 |
|
| 11 |
# ============================================================
|
| 12 |
# Global font styling
|
| 13 |
+
# - All non-title text: Arial Black size 10
|
| 14 |
+
# - Titles: Arial Black size 14 (unchanged from your spec)
|
| 15 |
# ============================================================
|
| 16 |
plt.rcParams["font.family"] = "Arial"
|
| 17 |
plt.rcParams["font.weight"] = "black"
|
| 18 |
+
plt.rcParams["font.size"] = 10
|
| 19 |
|
| 20 |
TITLE_FONTSIZE = 14
|
| 21 |
+
TEXT_FONTSIZE = 10
|
| 22 |
|
| 23 |
# ============================================================
|
| 24 |
# Class meaning (per your note)
|
|
|
|
| 125 |
return p_leave, p_stay
|
| 126 |
|
| 127 |
|
| 128 |
+
def hex_to_rgb01(h: str):
|
| 129 |
+
h = h.lstrip("#")
|
| 130 |
+
return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
# =========================
|
| 134 |
+
# Donut chart (Turnover Risk)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# =========================
|
| 136 |
def make_turnover_donut(p_leave: float, p_stay: float):
|
| 137 |
p_leave = max(0.0, min(1.0, float(p_leave)))
|
|
|
|
| 143 |
else:
|
| 144 |
p_leave, p_stay = p_leave / s, p_stay / s
|
| 145 |
|
| 146 |
+
fig, ax = plt.subplots(figsize=(4.8, 3.4))
|
| 147 |
|
| 148 |
ax.pie(
|
| 149 |
[p_leave, p_stay],
|
|
|
|
| 153 |
wedgeprops=dict(width=0.35, edgecolor="white"),
|
| 154 |
)
|
| 155 |
|
| 156 |
+
# Center: ONLY percent
|
| 157 |
ax.text(
|
| 158 |
0,
|
| 159 |
0.00,
|
| 160 |
f"{p_leave*100:.0f}%",
|
| 161 |
ha="center",
|
| 162 |
va="center",
|
| 163 |
+
fontsize=18,
|
| 164 |
fontweight="black",
|
| 165 |
color=RED,
|
| 166 |
family="Arial",
|
|
|
|
| 184 |
|
| 185 |
# =========================
|
| 186 |
# Plot: Average of key drivers + Goal Average
|
| 187 |
+
# (Longer + intended for bottom full-width placement)
|
| 188 |
# =========================
|
| 189 |
def make_driver_plot(driver_vals: dict):
|
| 190 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 191 |
goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
|
| 192 |
|
| 193 |
+
fig, ax = plt.subplots(figsize=(13.0, 3.2)) # longer/wider
|
| 194 |
+
|
| 195 |
bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
|
| 196 |
|
| 197 |
# dashed goal segment per bar
|
|
|
|
| 219 |
|
| 220 |
ax.margins(x=0.06)
|
| 221 |
plt.tight_layout()
|
| 222 |
+
plt.subplots_adjust(bottom=0.28)
|
| 223 |
|
| 224 |
plt.close(fig)
|
| 225 |
return fig
|
|
|
|
| 227 |
|
| 228 |
# =========================
|
| 229 |
# SHAP waterfall
|
| 230 |
+
# Make positive = RED, negative = BLUE
|
| 231 |
# =========================
|
| 232 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 233 |
fig, ax = plt.subplots(figsize=(8.6, 3.6))
|
|
|
|
| 236 |
import shap
|
| 237 |
from catboost import Pool
|
| 238 |
|
| 239 |
+
# force SHAP waterfall color theme (pos=red, neg=blue)
|
|
|
|
| 240 |
try:
|
| 241 |
+
shap.plots.colors.red_rgb = hex_to_rgb01(RED)
|
| 242 |
+
shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
|
| 243 |
except Exception:
|
| 244 |
pass
|
| 245 |
|
|
|
|
| 279 |
|
| 280 |
plt.close(fig)
|
| 281 |
|
|
|
|
| 282 |
shap.plots.waterfall(exp, max_display=8, show=False)
|
| 283 |
|
| 284 |
fig2 = plt.gcf()
|
| 285 |
fig2.set_size_inches(8.6, 3.6)
|
| 286 |
|
| 287 |
+
# enforce tick font sizes
|
| 288 |
try:
|
| 289 |
ax2 = fig2.axes[0]
|
| 290 |
if ax2.get_title():
|
|
|
|
| 325 |
|
| 326 |
# =========================
|
| 327 |
# Core predict
|
| 328 |
+
# Returns: donut, shap, drivers (so we can lay them out how you want)
|
| 329 |
# =========================
|
| 330 |
def predict(
|
| 331 |
Engagement,
|
|
|
|
| 352 |
p_leave, p_stay = prob_leave_and_stay(X)
|
| 353 |
|
| 354 |
donut_fig = make_turnover_donut(p_leave, p_stay)
|
|
|
|
| 355 |
shap_fig = make_catboost_shap_plot(X)
|
| 356 |
+
drivers_fig = make_driver_plot(driver_vals)
|
| 357 |
|
| 358 |
+
return donut_fig, shap_fig, drivers_fig
|
| 359 |
|
| 360 |
|
| 361 |
# =========================
|
|
|
|
| 365 |
# At-risk group = average of the two lowest clusters (Cluster 2 and 3)
|
| 366 |
target = {v: (CLUSTER_2[v] + CLUSTER_3[v]) / 2.0 for v in ALL_DRIVER_VARS}
|
| 367 |
|
| 368 |
+
donut_fig, shap_fig, drivers_fig = predict(
|
| 369 |
target["Engagement"],
|
| 370 |
target["SupportiveGM"],
|
| 371 |
target["WellBeing"],
|
|
|
|
| 384 |
target["DecisionAutonomy"],
|
| 385 |
target["Workload"],
|
| 386 |
donut_fig,
|
|
|
|
| 387 |
shap_fig,
|
| 388 |
+
drivers_fig,
|
| 389 |
)
|
| 390 |
|
| 391 |
|
|
|
|
| 393 |
# Recommendation = highest cluster (Cluster 1)
|
| 394 |
target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
|
| 395 |
|
| 396 |
+
donut_fig, shap_fig, drivers_fig = predict(
|
| 397 |
target["Engagement"],
|
| 398 |
target["SupportiveGM"],
|
| 399 |
target["WellBeing"],
|
|
|
|
| 412 |
target["DecisionAutonomy"],
|
| 413 |
target["Workload"],
|
| 414 |
donut_fig,
|
|
|
|
| 415 |
shap_fig,
|
| 416 |
+
drivers_fig,
|
| 417 |
)
|
| 418 |
|
| 419 |
|
| 420 |
# =========================
|
| 421 |
+
# UI Layout
|
| 422 |
+
# - Right side: donut then SHAP directly below
|
| 423 |
+
# - Bottom: wide drivers bar chart centered under everything
|
|
|
|
| 424 |
# =========================
|
| 425 |
CSS = f"""
|
| 426 |
#app-wrap {{ max-width: 1200px; margin: 0 auto; }}
|
|
|
|
| 427 |
.gr-block {{ padding: 10px 12px !important; }}
|
| 428 |
.gr-form {{ gap: 8px !important; }}
|
| 429 |
.gr-row {{ gap: 10px !important; }}
|
| 430 |
|
|
|
|
| 431 |
#btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
|
| 432 |
#btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
"""
|
| 434 |
|
| 435 |
with gr.Blocks(css=CSS) as demo:
|
|
|
|
| 446 |
btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
|
| 447 |
btn_predict = gr.Button("Predict")
|
| 448 |
|
| 449 |
+
# TOP MAIN ROW
|
| 450 |
with gr.Row():
|
| 451 |
+
# LEFT: sliders
|
| 452 |
with gr.Column(scale=4):
|
| 453 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 454 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
|
|
|
| 458 |
DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
|
| 459 |
Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
|
| 460 |
|
| 461 |
+
# RIGHT: donut + SHAP stacked
|
| 462 |
with gr.Column(scale=6):
|
| 463 |
donut_plot = gr.Plot(label="Turnover Risk")
|
|
|
|
| 464 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
| 465 |
|
| 466 |
+
# BOTTOM ROW: wide driver bars centered under everything
|
| 467 |
+
with gr.Row():
|
| 468 |
+
drivers_plot = gr.Plot(label="Average of key drivers")
|
| 469 |
+
|
| 470 |
# =========================
|
| 471 |
# WIRE UP EVENTS
|
| 472 |
# =========================
|
|
|
|
| 475 |
btn_predict.click(
|
| 476 |
fn=predict,
|
| 477 |
inputs=slider_inputs,
|
| 478 |
+
outputs=[donut_plot, shap_plot, drivers_plot],
|
| 479 |
)
|
| 480 |
|
| 481 |
btn_atrisk.click(
|
|
|
|
| 490 |
DecisionAutonomy,
|
| 491 |
Workload,
|
| 492 |
donut_plot,
|
|
|
|
| 493 |
shap_plot,
|
| 494 |
+
drivers_plot,
|
| 495 |
],
|
| 496 |
)
|
| 497 |
|
|
|
|
| 507 |
DecisionAutonomy,
|
| 508 |
Workload,
|
| 509 |
donut_plot,
|
|
|
|
| 510 |
shap_plot,
|
| 511 |
+
drivers_plot,
|
| 512 |
],
|
| 513 |
)
|
| 514 |
|