T6 / app.py
mns6rh's picture
Update app.py
927a652 verified
#!/usr/bin/env python
# coding: utf-8
import joblib
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
# ============================================================
# Figure defaults and fonts
# ============================================================
# Global matplotlib defaults for every chart in the app, set in one call.
plt.rcParams.update(
    {
        "figure.dpi": 100,
        "font.family": "Arial",
        "font.weight": "black",
        "font.size": 10,
    }
)
# Point sizes for chart titles and for regular text.
TITLE_FONTSIZE = 14
TEXT_FONTSIZE = 10
# ============================================================
# Colors
# ============================================================
# Hex colors shared by the charts and the button CSS. RED is used for the
# "leave" wedge and percentage in the donut and for the risk button; BLUE is
# used for the "stay" wedge and the recommendation button.
RED = "#d62728"
BLUE = "#1f77b4"
# ============================================================
# Class meaning
# ============================================================
# Labels looked up in model.classes_ to find each outcome's probability.
LEAVE_CLASS = 1
STAY_CLASS = 0
# =========================
# Load model
# =========================
# Classifier loaded once at import time. The path is relative, so it is
# resolved against the process working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code - only ship a trusted artifact.
model = joblib.load("final.joblib")
# Column names, in order, of the single-row DataFrame that build_X() creates
# and that is passed to the model.
# NOTE(review): presumably this must match the training column order - confirm.
FEATURES = [
"Engagement",
"SupportiveGM",
# Not exposed as a slider; predict() always sets this to 2.
"ManagementLevel",
"WellBeing",
"Voice",
"DecisionAutonomy",
"Workload",
"WorkEnvironment",
]
# =========================
# Cluster anchors
# =========================
# Per-driver anchor scores for three clusters, keyed by driver name. All
# values lie inside the 1-5 range of the sliders.
# NOTE(review): presumably cluster centroids from an upstream analysis - the
# source of these numbers is not shown in this file; confirm.

# Highest-scoring cluster: supplies the goal lines (GOAL_AVG) and the preset
# loaded by apply_recommendation().
CLUSTER_1 = {
"Voice": 4.84,
"DecisionAutonomy": 4.90,
"Workload": 4.72,
"WellBeing": 4.8397,
"WorkEnvironment": 4.8858,
"SupportiveGM": 4.8583,
"Engagement": 4.9324,
}
# Middle cluster: averaged with CLUSTER_3 by load_risk().
CLUSTER_2 = {
"Voice": 3.94,
"DecisionAutonomy": 4.24,
"Workload": 3.76,
"WellBeing": 4.0251,
"WorkEnvironment": 4.1484,
"SupportiveGM": 4.1275,
"Engagement": 4.2828,
}
# Lowest-scoring cluster: averaged with CLUSTER_2 by load_risk().
CLUSTER_3 = {
"Voice": 2.39,
"DecisionAutonomy": 3.55,
"Workload": 2.68,
"WellBeing": 3.0299,
"WorkEnvironment": 3.4537,
"SupportiveGM": 3.2208,
"Engagement": 3.3909,
}
# Driver names in the order they are drawn in the driver bar chart. The same
# names are the keys of the cluster dicts and the keyword arguments of
# predict().
ALL_DRIVER_VARS = [
"Engagement",
"SupportiveGM",
"WellBeing",
"WorkEnvironment",
"Voice",
"DecisionAutonomy",
"Workload",
]
# Tick labels for the driver chart. Parallel to ALL_DRIVER_VARS: entry i
# labels the bar for ALL_DRIVER_VARS[i], so both lists must stay in step.
ALL_DRIVER_LABELS = [
"Engagement",
"Supportive GM",
"Well-Being",
"Work Environment",
"Voice",
"Decision Autonomy",
"Workload",
]
# Goal value per driver (taken from CLUSTER_1), drawn as a dashed line over
# each bar in the driver chart.
GOAL_AVG = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
# ============================================================
# Helpers
# ============================================================
def clamp_1_5(x):
    """Convert *x* to float and limit it to the closed interval [1.0, 5.0]."""
    value = float(x)
    # Cap at the top of the scale first, then lift to the bottom if needed.
    capped = value if value < 5.0 else 5.0
    return capped if capped > 1.0 else 1.0
def build_X(vals):
    """Build the single-row model input from a mapping of feature values.

    Args:
        vals: Mapping that contains a value for every name in FEATURES
            (extra keys are ignored; a missing one raises KeyError).

    Returns:
        A one-row DataFrame whose columns are FEATURES, in that order.
    """
    ordered_values = [vals[name] for name in FEATURES]
    return pd.DataFrame([ordered_values], columns=FEATURES)
def prob_leave_and_stay(X):
    """Return ``(p_leave, p_stay)`` for the first row of *X*.

    The probabilities come from ``model.predict_proba``; each one is picked
    by the position of LEAVE_CLASS / STAY_CLASS in ``model.classes_``.
    """
    first_row = model.predict_proba(X)[0]
    class_order = list(model.classes_)
    leave_pos = class_order.index(LEAVE_CLASS)
    stay_pos = class_order.index(STAY_CLASS)
    return float(first_row[leave_pos]), float(first_row[stay_pos])
def hex_to_rgb01(h):
    """Convert a hex color such as ``"#d62728"`` to an (r, g, b) tuple of
    floats in the range 0-1. Leading ``#`` characters are optional."""
    digits = h.lstrip("#")
    # Each channel is two hex digits: red at 0, green at 2, blue at 4.
    return tuple(int(digits[start:start + 2], 16) / 255.0 for start in (0, 2, 4))
# ============================================================
# Donut chart
# ============================================================
def make_turnover_donut(p_leave, p_stay):
    """Draw a donut chart of the leave and stay probabilities.

    Args:
        p_leave: Probability of leaving; drawn in RED and printed in the
            center as a whole-number percentage.
        p_stay: Probability of staying; drawn in BLUE.

    Returns:
        The matplotlib figure. It is closed in pyplot before being returned,
        so pyplot no longer tracks it.
    """
    figure, axes = plt.subplots(figsize=(8, 3))

    # A wedge width below 1 leaves the hole that turns the pie into a donut.
    ring_style = {"width": 0.35, "edgecolor": "white"}
    axes.pie(
        [p_leave, p_stay],
        colors=[RED, BLUE],
        startangle=90,
        wedgeprops=ring_style,
    )

    percent_label = f"{p_leave*100:.0f}%"
    axes.text(
        0,
        0,
        percent_label,
        color=RED,
        fontsize=18,
        fontweight="black",
        ha="center",
        va="center",
    )
    axes.set_title("Turnover Risk", fontsize=TITLE_FONTSIZE)

    plt.close(figure)
    return figure
# ============================================================
# Driver chart
# ============================================================
def make_driver_plot(vals):
    """Bar chart of every driver score with its goal drawn as a dashed line.

    Args:
        vals: Mapping from driver name to score; must contain every name in
            ALL_DRIVER_VARS.

    Returns:
        The matplotlib figure. It is closed in pyplot before being returned,
        so pyplot no longer tracks it.
    """
    scores = [vals[name] for name in ALL_DRIVER_VARS]
    goals = [GOAL_AVG[name] for name in ALL_DRIVER_VARS]
    positions = list(range(len(scores)))

    fig, ax = plt.subplots(figsize=(12, 3))
    bars = ax.bar(positions, scores, color=BLUE)

    # Every goal marker gets the same explicit color. Without one, each
    # ax.plot call takes the next color of the default cycle, so the markers
    # came out in seven different colors and the first matched the bars.
    for bar, goal in zip(bars, goals):
        left = bar.get_x()
        ax.plot(
            [left, left + bar.get_width()],
            [goal, goal],
            linestyle="--",
            color=RED,
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(ALL_DRIVER_LABELS)
    # Scores live on a 1-5 scale; the extra headroom keeps goal lines near 5
    # clear of the top edge.
    ax.set_ylim(1, 5.4)
    # Same title size as the donut chart.
    ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE)

    plt.close(fig)
    return fig
# ============================================================
# SHAP (original waterfall restored, red forced)
# ============================================================
def make_catboost_shap_plot(X):
    """Render a SHAP waterfall for the leave-class prediction of one row.

    SHAP values are computed by the model itself
    (``get_feature_importance(..., type="ShapValues")``) and drawn with
    ``shap.plots.waterfall`` after shap's red/blue are replaced by the app
    palette.

    Args:
        X: Single-row DataFrame whose columns are the model features.

    Returns:
        A matplotlib figure that has been closed in pyplot. If anything
        fails, including ``shap`` or ``catboost`` not being importable, the
        error is logged and a placeholder figure reading "SHAP unavailable"
        is returned instead.
    """
    import logging

    try:
        # Imported inside the try so that a missing optional package falls
        # back to the placeholder instead of failing the whole prediction.
        import shap
        from catboost import Pool

        # Force the waterfall to use the same red/blue as the other charts.
        shap.plots.colors.red_rgb = hex_to_rgb01(RED)
        shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)

        shap_vals = model.get_feature_importance(Pool(X), type="ShapValues")
        class_idx = list(model.classes_).index(LEAVE_CLASS)
        n_features = X.shape[1]

        # In every layout the last entry along the feature axis is the
        # expected value (base) and the rest are per-feature contributions.
        if shap_vals.ndim == 3:
            if shap_vals.shape[2] == n_features + 1:
                # Layout documented by CatBoost for multiclass models:
                # (n_objects, n_classes, n_features + 1).
                row = shap_vals[0, class_idx, :]
            else:
                # Class-last layout (n_objects, n_features + 1, n_classes),
                # kept for backward compatibility with the previous indexing.
                row = shap_vals[0, :, class_idx]
        else:
            # Binary models: (n_objects, n_features + 1).
            row = shap_vals[0]
        base = float(row[-1])
        values = row[:-1]

        exp = shap.Explanation(
            values=values,
            base_values=base,
            data=X.iloc[0].values,
            feature_names=list(X.columns),
        )
        shap.plots.waterfall(exp, max_display=8, show=False)
        fig = plt.gcf()
        fig.set_size_inches(8, 3)
        plt.tight_layout()
        plt.close(fig)
        return fig
    except Exception:
        logging.getLogger(__name__).exception(
            "SHAP waterfall could not be rendered"
        )
        # Drop a half-drawn figure the failed attempt may have left current,
        # so repeated failures do not pile up figures in pyplot. This is a
        # no-op when no figure is open.
        plt.close()
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "SHAP unavailable", ha="center")
        plt.close(fig)
        return fig
# ============================================================
# Predict
# ============================================================
def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload):
    """Score one set of slider values and build the three output charts.

    Each argument is a driver score; it is converted to float and clamped to
    1-5 before use. ``ManagementLevel`` is not an input and is fixed at 2.

    Returns:
        ``(donut, shap, drivers)``: the turnover donut, the SHAP waterfall
        and the driver bar chart, in that order.
    """
    slider_inputs = (
        ("Engagement", Engagement),
        ("SupportiveGM", SupportiveGM),
        ("WellBeing", WellBeing),
        ("WorkEnvironment", WorkEnvironment),
        ("Voice", Voice),
        ("DecisionAutonomy", DecisionAutonomy),
        ("Workload", Workload),
    )
    vals = {name: clamp_1_5(raw) for name, raw in slider_inputs}
    vals["ManagementLevel"] = 2

    features = build_X(vals)
    p_leave, p_stay = prob_leave_and_stay(features)

    # Figures are built one after another in this fixed order.
    donut_fig = make_turnover_donut(p_leave, p_stay)
    shap_fig = make_catboost_shap_plot(features)
    drivers_fig = make_driver_plot(vals)
    return donut_fig, shap_fig, drivers_fig
# ============================================================
# Buttons
# ============================================================
def load_risk():
    """Preset handler: load the midpoint between CLUSTER_2 and CLUSTER_3.

    Returns:
        A 10-tuple: the seven slider values (Engagement, SupportiveGM,
        WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload)
        followed by the three figures from predict() for those values.
    """
    midpoint = {}
    for name in ALL_DRIVER_VARS:
        midpoint[name] = (CLUSTER_2[name] + CLUSTER_3[name]) / 2

    # Order must match the slider components wired as outputs in the UI.
    slider_order = (
        "Engagement",
        "SupportiveGM",
        "WellBeing",
        "WorkEnvironment",
        "Voice",
        "DecisionAutonomy",
        "Workload",
    )
    slider_values = [midpoint[name] for name in slider_order]
    return (*slider_values, *predict(**midpoint))
def apply_recommendation():
    """Preset handler: load the CLUSTER_1 scores into the sliders.

    Returns:
        A 10-tuple: the seven slider values (Engagement, SupportiveGM,
        WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload)
        followed by the three figures from predict() for those values.
    """
    goal_scores = {}
    for name in ALL_DRIVER_VARS:
        goal_scores[name] = CLUSTER_1[name]

    # Order must match the slider components wired as outputs in the UI.
    slider_order = (
        "Engagement",
        "SupportiveGM",
        "WellBeing",
        "WorkEnvironment",
        "Voice",
        "DecisionAutonomy",
        "Workload",
    )
    slider_values = [goal_scores[name] for name in slider_order]
    return (*slider_values, *predict(**goal_scores))
def hilton_heroes():
    """Handler for the "Hilton Heroes" button.

    Delegates to apply_recommendation(), so it returns the same slider
    values and figures.
    """
    preset_outputs = apply_recommendation()
    return preset_outputs
# ============================================================
# CSS
# ============================================================
# Button colors. Gradio renders a Button's elem_id on the <button> element
# itself, so a descendant selector alone ("#id button") matches nothing and
# the colors are never applied. Each rule therefore targets the element with
# the id and, for layouts that wrap the button, a nested <button> as well.
CSS = f"""
#btn_risk, #btn_risk button {{ background:{RED}; color:white; }}
#btn_heroes, #btn_heroes button {{ background:#0a5eb8; color:white; }}
#btn_reco, #btn_reco button {{ background:{BLUE}; color:white; }}
"""
# ============================================================
# UI
# ============================================================
# Page layout: a row of preset buttons on top, then the driver sliders on the
# left and the three charts on the right.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("## Predicting Intent to Stay")

    with gr.Row():
        btn_risk = gr.Button("Immediate and Silent Risk", elem_id="btn_risk")
        btn_heroes = gr.Button("Hilton Heroes", elem_id="btn_heroes")
        btn_reco = gr.Button("Apply Recommendation", elem_id="btn_reco")

    with gr.Row():
        with gr.Column():
            Engagement = gr.Slider(minimum=1, maximum=5, value=3, label="Engagement")
            SupportiveGM = gr.Slider(minimum=1, maximum=5, value=3, label="Supportive GM")
            WellBeing = gr.Slider(minimum=1, maximum=5, value=3, label="Well Being")
            WorkEnvironment = gr.Slider(minimum=1, maximum=5, value=3, label="Work Environment")
            Voice = gr.Slider(minimum=1, maximum=5, value=3, label="Voice")
            DecisionAutonomy = gr.Slider(minimum=1, maximum=5, value=3, label="Decision Autonomy")
            Workload = gr.Slider(minimum=1, maximum=5, value=3, label="Workload")
            btn_predict = gr.Button("Predict")
        with gr.Column():
            donut_plot = gr.Plot()
            shap_plot = gr.Plot()
            drivers_plot = gr.Plot()

    # Slider order matches the positional parameters of predict() and the
    # leading values returned by the preset handlers.
    sliders = [
        Engagement,
        SupportiveGM,
        WellBeing,
        WorkEnvironment,
        Voice,
        DecisionAutonomy,
        Workload,
    ]
    plots = [donut_plot, shap_plot, drivers_plot]

    # "Predict" reads the sliders and refreshes only the charts.
    btn_predict.click(fn=predict, inputs=sliders, outputs=plots)

    # Preset buttons take no inputs; they overwrite the sliders and refresh
    # the charts in one step.
    preset_handlers = (
        (btn_risk, load_risk),
        (btn_heroes, hilton_heroes),
        (btn_reco, apply_recommendation),
    )
    for preset_button, preset_handler in preset_handlers:
        preset_button.click(fn=preset_handler, inputs=[], outputs=sliders + plots)

demo.launch()