Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ plt.rcParams["figure.dpi"] = 100
|
|
| 11 |
# ============================================================
|
| 12 |
# Global font styling
|
| 13 |
# - All non-title text: Arial Black size 10
|
| 14 |
-
# - Titles: Arial Black size 14
|
| 15 |
# ============================================================
|
| 16 |
plt.rcParams["font.family"] = "Arial"
|
| 17 |
plt.rcParams["font.weight"] = "black"
|
|
@@ -130,8 +130,24 @@ def hex_to_rgb01(h: str):
|
|
| 130 |
return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
|
| 131 |
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
# =========================
|
| 134 |
# Donut chart (Turnover Risk)
|
|
|
|
|
|
|
| 135 |
# =========================
|
| 136 |
def make_turnover_donut(p_leave: float, p_stay: float):
|
| 137 |
p_leave = max(0.0, min(1.0, float(p_leave)))
|
|
@@ -143,7 +159,8 @@ def make_turnover_donut(p_leave: float, p_stay: float):
|
|
| 143 |
else:
|
| 144 |
p_leave, p_stay = p_leave / s, p_stay / s
|
| 145 |
|
| 146 |
-
|
|
|
|
| 147 |
|
| 148 |
ax.pie(
|
| 149 |
[p_leave, p_stay],
|
|
@@ -166,31 +183,39 @@ def make_turnover_donut(p_leave: float, p_stay: float):
|
|
| 166 |
family="Arial",
|
| 167 |
)
|
| 168 |
|
| 169 |
-
ax.legend(
|
| 170 |
["Probability of Turnover", "Probability of Retention"],
|
| 171 |
loc="center left",
|
| 172 |
bbox_to_anchor=(1.02, 0.5),
|
| 173 |
frameon=False,
|
| 174 |
prop={"family": "Arial", "weight": "black", "size": TEXT_FONTSIZE},
|
|
|
|
|
|
|
| 175 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
ax.set_title("Turnover Risk", pad=
|
| 178 |
ax.set_aspect("equal")
|
| 179 |
|
|
|
|
| 180 |
plt.tight_layout()
|
|
|
|
|
|
|
| 181 |
plt.close(fig)
|
| 182 |
return fig
|
| 183 |
|
| 184 |
|
| 185 |
# =========================
|
| 186 |
# Plot: Average of key drivers + Goal Average
|
| 187 |
-
# (Longer + intended for bottom full-width placement)
|
| 188 |
# =========================
|
| 189 |
def make_driver_plot(driver_vals: dict):
|
| 190 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 191 |
goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
|
| 192 |
|
| 193 |
-
fig, ax = plt.subplots(figsize=(13.0, 3.
|
| 194 |
|
| 195 |
bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
|
| 196 |
|
|
@@ -214,9 +239,11 @@ def make_driver_plot(driver_vals: dict):
|
|
| 214 |
|
| 215 |
ax.set_ylim(1, 5.4)
|
| 216 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 217 |
-
ax.set_ylabel("Score (1–5)"
|
| 218 |
ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
|
| 219 |
|
|
|
|
|
|
|
| 220 |
ax.margins(x=0.06)
|
| 221 |
plt.tight_layout()
|
| 222 |
plt.subplots_adjust(bottom=0.28)
|
|
@@ -227,16 +254,17 @@ def make_driver_plot(driver_vals: dict):
|
|
| 227 |
|
| 228 |
# =========================
|
| 229 |
# SHAP waterfall
|
| 230 |
-
#
|
|
|
|
| 231 |
# =========================
|
| 232 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 233 |
-
fig, ax = plt.subplots(figsize=(8.6, 3.
|
| 234 |
|
| 235 |
try:
|
| 236 |
import shap
|
| 237 |
from catboost import Pool
|
| 238 |
|
| 239 |
-
# force SHAP waterfall
|
| 240 |
try:
|
| 241 |
shap.plots.colors.red_rgb = hex_to_rgb01(RED)
|
| 242 |
shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
|
|
@@ -282,13 +310,12 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 282 |
shap.plots.waterfall(exp, max_display=8, show=False)
|
| 283 |
|
| 284 |
fig2 = plt.gcf()
|
| 285 |
-
fig2.set_size_inches(8.6, 3.
|
| 286 |
|
| 287 |
-
#
|
| 288 |
try:
|
| 289 |
ax2 = fig2.axes[0]
|
| 290 |
-
|
| 291 |
-
ax2.set_title(ax2.get_title(), fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
|
| 292 |
for t in ax2.get_xticklabels() + ax2.get_yticklabels():
|
| 293 |
t.set_fontsize(TEXT_FONTSIZE)
|
| 294 |
t.set_fontweight("black")
|
|
@@ -325,7 +352,7 @@ def make_catboost_shap_plot(X: pd.DataFrame):
|
|
| 325 |
|
| 326 |
# =========================
|
| 327 |
# Core predict
|
| 328 |
-
# Returns: donut, shap, drivers
|
| 329 |
# =========================
|
| 330 |
def predict(
|
| 331 |
Engagement,
|
|
@@ -419,8 +446,11 @@ def apply_recommendation():
|
|
| 419 |
|
| 420 |
# =========================
|
| 421 |
# UI Layout
|
| 422 |
-
# -
|
| 423 |
-
# -
|
|
|
|
|
|
|
|
|
|
| 424 |
# =========================
|
| 425 |
CSS = f"""
|
| 426 |
#app-wrap {{ max-width: 1200px; margin: 0 auto; }}
|
|
@@ -430,6 +460,9 @@ CSS = f"""
|
|
| 430 |
|
| 431 |
#btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
|
| 432 |
#btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
|
|
|
|
|
|
|
|
|
|
| 433 |
"""
|
| 434 |
|
| 435 |
with gr.Blocks(css=CSS) as demo:
|
|
@@ -440,15 +473,14 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 440 |
"</div>"
|
| 441 |
)
|
| 442 |
|
| 443 |
-
# TOP BUTTON ROW
|
| 444 |
with gr.Row():
|
| 445 |
btn_atrisk = gr.Button("At risk group", elem_id="btn_atrisk")
|
| 446 |
btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
|
| 447 |
-
btn_predict = gr.Button("Predict")
|
| 448 |
|
| 449 |
-
#
|
| 450 |
with gr.Row():
|
| 451 |
-
# LEFT: sliders
|
| 452 |
with gr.Column(scale=4):
|
| 453 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 454 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
|
@@ -458,7 +490,9 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 458 |
DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
|
| 459 |
Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
|
| 460 |
|
| 461 |
-
|
|
|
|
|
|
|
| 462 |
with gr.Column(scale=6):
|
| 463 |
donut_plot = gr.Plot(label="Turnover Risk")
|
| 464 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
|
|
|
| 11 |
# ============================================================
|
| 12 |
# Global font styling
|
| 13 |
# - All non-title text: Arial Black size 10
|
| 14 |
+
# - Titles: Arial Black size 14
|
| 15 |
# ============================================================
|
| 16 |
plt.rcParams["font.family"] = "Arial"
|
| 17 |
plt.rcParams["font.weight"] = "black"
|
|
|
|
| 130 |
return (int(h[0:2], 16) / 255.0, int(h[2:4], 16) / 255.0, int(h[4:6], 16) / 255.0)
|
| 131 |
|
| 132 |
|
| 133 |
+
def stylize_axis_fonts(ax):
|
| 134 |
+
# Apply consistent fonts to ticks/labels
|
| 135 |
+
for t in ax.get_xticklabels() + ax.get_yticklabels():
|
| 136 |
+
t.set_fontsize(TEXT_FONTSIZE)
|
| 137 |
+
t.set_fontweight("black")
|
| 138 |
+
t.set_family("Arial")
|
| 139 |
+
ax.xaxis.label.set_fontsize(TEXT_FONTSIZE)
|
| 140 |
+
ax.xaxis.label.set_fontweight("black")
|
| 141 |
+
ax.xaxis.label.set_family("Arial")
|
| 142 |
+
ax.yaxis.label.set_fontsize(TEXT_FONTSIZE)
|
| 143 |
+
ax.yaxis.label.set_fontweight("black")
|
| 144 |
+
ax.yaxis.label.set_family("Arial")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
# =========================
|
| 148 |
# Donut chart (Turnover Risk)
|
| 149 |
+
# - Legend made visible by widening figure + right margin
|
| 150 |
+
# - No center subtitle text
|
| 151 |
# =========================
|
| 152 |
def make_turnover_donut(p_leave: float, p_stay: float):
|
| 153 |
p_leave = max(0.0, min(1.0, float(p_leave)))
|
|
|
|
| 159 |
else:
|
| 160 |
p_leave, p_stay = p_leave / s, p_stay / s
|
| 161 |
|
| 162 |
+
# Wider so legend isn't clipped
|
| 163 |
+
fig, ax = plt.subplots(figsize=(8.6, 3.2))
|
| 164 |
|
| 165 |
ax.pie(
|
| 166 |
[p_leave, p_stay],
|
|
|
|
| 183 |
family="Arial",
|
| 184 |
)
|
| 185 |
|
| 186 |
+
leg = ax.legend(
|
| 187 |
["Probability of Turnover", "Probability of Retention"],
|
| 188 |
loc="center left",
|
| 189 |
bbox_to_anchor=(1.02, 0.5),
|
| 190 |
frameon=False,
|
| 191 |
prop={"family": "Arial", "weight": "black", "size": TEXT_FONTSIZE},
|
| 192 |
+
handlelength=1.2,
|
| 193 |
+
borderaxespad=0.0,
|
| 194 |
)
|
| 195 |
+
for txt in leg.get_texts():
|
| 196 |
+
txt.set_fontsize(TEXT_FONTSIZE)
|
| 197 |
+
txt.set_fontweight("black")
|
| 198 |
+
txt.set_family("Arial")
|
| 199 |
|
| 200 |
+
ax.set_title("Turnover Risk", pad=8, fontweight="black", fontsize=TITLE_FONTSIZE, family="Arial")
|
| 201 |
ax.set_aspect("equal")
|
| 202 |
|
| 203 |
+
# Leave room on the right for legend
|
| 204 |
plt.tight_layout()
|
| 205 |
+
plt.subplots_adjust(right=0.80)
|
| 206 |
+
|
| 207 |
plt.close(fig)
|
| 208 |
return fig
|
| 209 |
|
| 210 |
|
| 211 |
# =========================
|
| 212 |
# Plot: Average of key drivers + Goal Average
|
|
|
|
| 213 |
# =========================
|
| 214 |
def make_driver_plot(driver_vals: dict):
|
| 215 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 216 |
goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
|
| 217 |
|
| 218 |
+
fig, ax = plt.subplots(figsize=(13.0, 3.1)) # long/wide
|
| 219 |
|
| 220 |
bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
|
| 221 |
|
|
|
|
| 239 |
|
| 240 |
ax.set_ylim(1, 5.4)
|
| 241 |
ax.set_yticks([1, 2, 3, 4, 5])
|
| 242 |
+
ax.set_ylabel("Score (1–5)")
|
| 243 |
ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
|
| 244 |
|
| 245 |
+
stylize_axis_fonts(ax)
|
| 246 |
+
|
| 247 |
ax.margins(x=0.06)
|
| 248 |
plt.tight_layout()
|
| 249 |
plt.subplots_adjust(bottom=0.28)
|
|
|
|
| 254 |
|
| 255 |
# =========================
|
| 256 |
# SHAP waterfall
|
| 257 |
+
# - Force positive = RED, negative = BLUE
|
| 258 |
+
# - Add title: "Feature Importance (Shap)"
|
| 259 |
# =========================
|
| 260 |
def make_catboost_shap_plot(X: pd.DataFrame):
|
| 261 |
+
fig, ax = plt.subplots(figsize=(8.6, 3.2))
|
| 262 |
|
| 263 |
try:
|
| 264 |
import shap
|
| 265 |
from catboost import Pool
|
| 266 |
|
| 267 |
+
# Try to force SHAP waterfall palette to red/blue
|
| 268 |
try:
|
| 269 |
shap.plots.colors.red_rgb = hex_to_rgb01(RED)
|
| 270 |
shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)
|
|
|
|
| 310 |
shap.plots.waterfall(exp, max_display=8, show=False)
|
| 311 |
|
| 312 |
fig2 = plt.gcf()
|
| 313 |
+
fig2.set_size_inches(8.6, 3.2)
|
| 314 |
|
| 315 |
+
# Put a clean title on the SHAP plot
|
| 316 |
try:
|
| 317 |
ax2 = fig2.axes[0]
|
| 318 |
+
ax2.set_title("Feature Importance (Shap)", fontsize=TITLE_FONTSIZE, fontweight="black", family="Arial")
|
|
|
|
| 319 |
for t in ax2.get_xticklabels() + ax2.get_yticklabels():
|
| 320 |
t.set_fontsize(TEXT_FONTSIZE)
|
| 321 |
t.set_fontweight("black")
|
|
|
|
| 352 |
|
| 353 |
# =========================
|
| 354 |
# Core predict
|
| 355 |
+
# Returns: donut, shap, drivers
|
| 356 |
# =========================
|
| 357 |
def predict(
|
| 358 |
Engagement,
|
|
|
|
| 446 |
|
| 447 |
# =========================
|
| 448 |
# UI Layout
|
| 449 |
+
# - Top: At risk + Apply recommendation buttons
|
| 450 |
+
# - Left: sliders, then Predict button (under sliders)
|
| 451 |
+
# - Right: donut then SHAP directly below (same width)
|
| 452 |
+
# - Bottom: wide drivers bar chart centered
|
| 453 |
+
# - Sliders column a bit shorter to fit Predict button cleanly
|
| 454 |
# =========================
|
| 455 |
CSS = f"""
|
| 456 |
#app-wrap {{ max-width: 1200px; margin: 0 auto; }}
|
|
|
|
| 460 |
|
| 461 |
#btn_atrisk button {{ background: {RED} !important; color: white !important; border: none !important; }}
|
| 462 |
#btn_reco button {{ background: {BLUE} !important; color: white !important; border: none !important; }}
|
| 463 |
+
|
| 464 |
+
/* Tighten slider spacing a bit so Predict fits nicely */
|
| 465 |
+
.gr-slider {{ margin-bottom: 6px !important; }}
|
| 466 |
"""
|
| 467 |
|
| 468 |
with gr.Blocks(css=CSS) as demo:
|
|
|
|
| 473 |
"</div>"
|
| 474 |
)
|
| 475 |
|
| 476 |
+
# TOP BUTTON ROW (keep these at top)
|
| 477 |
with gr.Row():
|
| 478 |
btn_atrisk = gr.Button("At risk group", elem_id="btn_atrisk")
|
| 479 |
btn_reco = gr.Button("Apply recommendation", elem_id="btn_reco")
|
|
|
|
| 480 |
|
| 481 |
+
# MAIN ROW
|
| 482 |
with gr.Row():
|
| 483 |
+
# LEFT: sliders + Predict (under sliders)
|
| 484 |
with gr.Column(scale=4):
|
| 485 |
Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
|
| 486 |
SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
|
|
|
|
| 490 |
DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
|
| 491 |
Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
|
| 492 |
|
| 493 |
+
btn_predict = gr.Button("Predict")
|
| 494 |
+
|
| 495 |
+
# RIGHT: donut + SHAP stacked (same width)
|
| 496 |
with gr.Column(scale=6):
|
| 497 |
donut_plot = gr.Plot(label="Turnover Risk")
|
| 498 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|