import joblib
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Matplotlib defaults shared by every chart in the app.
# ---------------------------------------------------------------------------
plt.rcParams.update(
    {
        "figure.dpi": 100,
        "font.family": "Arial",
        "font.weight": "black",
        "font.size": 10,
    }
)

# Font sizes for chart titles and body text.
TITLE_FONTSIZE = 14
TEXT_FONTSIZE = 10

# ---------------------------------------------------------------------------
# Palette: RED is used for the "leave" share, BLUE for the "stay" share.
# ---------------------------------------------------------------------------
RED = "#d62728"
BLUE = "#1f77b4"

# ---------------------------------------------------------------------------
# Class labels as they appear in the trained model's ``classes_``.
# ---------------------------------------------------------------------------
LEAVE_CLASS = 1
STAY_CLASS = 0

# ---------------------------------------------------------------------------
# Trained classifier, loaded once at import time. The path is resolved
# relative to the current working directory. joblib.load unpickles, so only
# point this at a trusted artifact.
# ---------------------------------------------------------------------------
model = joblib.load("final.joblib")

# Column order used by build_X when assembling the model input.
# NOTE(review): presumably this must match the training-time column order;
# confirm before reordering.
FEATURES = [
    "Engagement",
    "SupportiveGM",
    "ManagementLevel",
    "WellBeing",
    "Voice",
    "DecisionAutonomy",
    "Workload",
    "WorkEnvironment",
]

# ---------------------------------------------------------------------------
# Per-cluster driver averages. CLUSTER_1 doubles as the goal profile
# (GOAL_AVG below and apply_recommendation); the midpoint of CLUSTER_2 and
# CLUSTER_3 is the "risk" preset used by load_risk.
# ---------------------------------------------------------------------------
CLUSTER_1 = dict(
    Voice=4.84,
    DecisionAutonomy=4.90,
    Workload=4.72,
    WellBeing=4.8397,
    WorkEnvironment=4.8858,
    SupportiveGM=4.8583,
    Engagement=4.9324,
)

CLUSTER_2 = dict(
    Voice=3.94,
    DecisionAutonomy=4.24,
    Workload=3.76,
    WellBeing=4.0251,
    WorkEnvironment=4.1484,
    SupportiveGM=4.1275,
    Engagement=4.2828,
)

CLUSTER_3 = dict(
    Voice=2.39,
    DecisionAutonomy=3.55,
    Workload=2.68,
    WellBeing=3.0299,
    WorkEnvironment=3.4537,
    SupportiveGM=3.2208,
    Engagement=3.3909,
)

# Driver variable names paired with their display labels, kept side by side
# so the two derived lists below cannot drift out of step.
_DRIVERS = (
    ("Engagement", "Engagement"),
    ("SupportiveGM", "Supportive GM"),
    ("WellBeing", "Well-Being"),
    ("WorkEnvironment", "Work Environment"),
    ("Voice", "Voice"),
    ("DecisionAutonomy", "Decision Autonomy"),
    ("Workload", "Workload"),
)

ALL_DRIVER_VARS = [var for var, _label in _DRIVERS]
ALL_DRIVER_LABELS = [label for _var, label in _DRIVERS]

# Goal line for each driver in the driver chart: the CLUSTER_1 average,
# keyed in ALL_DRIVER_VARS order.
GOAL_AVG = {name: CLUSTER_1[name] for name in ALL_DRIVER_VARS}

def clamp_1_5(x):
    """Coerce ``x`` to float and pin it into the slider range [1.0, 5.0].

    Values below 1.0 become 1.0 and values above 5.0 become 5.0. The upper
    bound is applied first and the comparisons are ordered like
    ``max(1.0, min(5.0, x))``, so NaN maps to 5.0.
    """
    value = float(x)
    capped = value if value < 5.0 else 5.0
    return capped if capped > 1.0 else 1.0

def build_X(vals):
    """Build the one-row model input frame from a mapping of feature values.

    Columns are taken from FEATURES, in FEATURES order; extra keys in ``vals``
    are ignored.

    Raises:
        KeyError: if ``vals`` is missing any name listed in FEATURES.
    """
    ordered = [vals[name] for name in FEATURES]
    return pd.DataFrame([ordered], columns=FEATURES)

def prob_leave_and_stay(X):
    """Return ``(p_leave, p_stay)`` as floats for the first row of ``X``.

    The columns of ``predict_proba`` follow ``model.classes_``, so each class
    column is located by its label instead of being assumed to sit at a
    fixed index.

    Raises:
        ValueError: if LEAVE_CLASS or STAY_CLASS is not in ``model.classes_``.
    """
    first_row = model.predict_proba(X)[0]
    labels = list(model.classes_)
    leave_col = labels.index(LEAVE_CLASS)
    stay_col = labels.index(STAY_CLASS)
    return float(first_row[leave_col]), float(first_row[stay_col])

def hex_to_rgb01(h):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of floats in [0, 1].

    Accepts ``"#rrggbb"`` / ``"rrggbb"`` as before, and now also the CSS
    shorthand ``"#rgb"`` (each digit doubled, so ``"#f00"`` == ``"#ff0000"``).
    Surrounding whitespace is ignored. Any characters after the first six hex
    digits (e.g. an alpha channel) are ignored, as before.

    Raises:
        ValueError: if the string does not contain valid hex digits.
    """
    digits = h.strip().lstrip("#")
    if len(digits) == 3:
        # CSS shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    return tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))


# ---------------------------------------------------------------------------
# Figure builders
# ---------------------------------------------------------------------------

def make_turnover_donut(p_leave, p_stay):
    """Draw the leave/stay split as a donut with the leave percentage centred.

    Args:
        p_leave: probability of the leave class; drawn in RED and printed in
            the centre as a whole-number percentage.
        p_stay: probability of the stay class; drawn in BLUE.

    Returns:
        The matplotlib Figure, already closed via ``plt.close`` so pyplot
        stops tracking it.
    """
    fig, ax = plt.subplots(figsize=(8, 3))

    ring = {"width": 0.35, "edgecolor": "white"}
    ax.pie([p_leave, p_stay], colors=[RED, BLUE], startangle=90, wedgeprops=ring)

    centre_text = f"{p_leave*100:.0f}%"
    ax.text(
        0, 0, centre_text,
        ha="center", va="center",
        fontsize=18, color=RED, fontweight="black",
    )

    ax.set_title("Turnover Risk", fontsize=TITLE_FONTSIZE)

    plt.close(fig)
    return fig

def make_driver_plot(vals):
    """Bar chart of the current driver scores with a dashed goal marker per bar.

    Bars are drawn in ALL_DRIVER_VARS order and labelled with
    ALL_DRIVER_LABELS. Each bar carries a dashed horizontal line at its
    GOAL_AVG value.

    Args:
        vals: mapping that contains every name in ALL_DRIVER_VARS (extra keys
            are ignored).

    Returns:
        The matplotlib Figure, already closed via ``plt.close``.
    """
    values = [vals[v] for v in ALL_DRIVER_VARS]
    goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
    positions = range(len(values))

    fig, ax = plt.subplots(figsize=(12, 3))

    # Colours are explicit. Left to the default line cycle, every goal line
    # got a different colour, and the first one matched the bars, so it
    # vanished whenever that bar reached its goal.
    bars = ax.bar(positions, values, color=BLUE)
    for bar, goal in zip(bars, goals):
        left = bar.get_x()
        ax.plot(
            [left, left + bar.get_width()],
            [goal, goal],
            linestyle="--",
            color="black",
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(ALL_DRIVER_LABELS)

    ax.set_ylim(1, 5.4)
    ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE)

    plt.close(fig)
    return fig

def make_catboost_shap_plot(X):
    """Render a SHAP waterfall plot for LEAVE_CLASS on the first row of ``X``.

    SHAP values come from CatBoost's built-in ``ShapValues`` feature
    importance. Any failure is logged and a placeholder figure reading
    "SHAP unavailable" is returned instead, so the rest of the dashboard still
    renders. This includes shap/catboost not being installed.

    Args:
        X: feature frame as produced by build_X; only row 0 is explained.

    Returns:
        The matplotlib Figure, already closed via ``plt.close``.
    """
    import logging

    # Remember existing figures so that, on failure, only figures created by
    # this call (e.g. a half-drawn waterfall) are closed.
    figs_before = set(plt.get_fignums())
    try:
        # Imported inside the try so a missing optional dependency falls back
        # to the placeholder instead of breaking the whole prediction.
        import shap
        from catboost import Pool

        # Recolour SHAP's positive/negative contributions with the app
        # palette. NOTE: this mutates shap's module-level colour settings.
        shap.plots.colors.red_rgb = hex_to_rgb01(RED)
        shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)

        shap_vals = model.get_feature_importance(Pool(X), type="ShapValues")
        class_idx = list(model.classes_).index(LEAVE_CLASS)

        # On the feature axis, the last column holds the expected (base) value.
        if shap_vals.ndim == 3:
            # Multiclass: shape is (objects, classes, features + 1), so the
            # class is selected on axis 1, not on the feature axis.
            base = float(shap_vals[0, class_idx, -1])
            values = shap_vals[0, class_idx, :-1]
        else:
            # Binary: a single (objects, features + 1) array explaining the
            # raw score of classes_[1]. The raw score of classes_[0] is its
            # negation.
            base = float(shap_vals[0, -1])
            values = shap_vals[0, :-1]
            if class_idx == 0:
                base, values = -base, -values

        explanation = shap.Explanation(
            values=values,
            base_values=base,
            data=X.iloc[0].values,
            feature_names=list(X.columns),
        )

        # waterfall draws on the current pyplot figure.
        shap.plots.waterfall(explanation, max_display=8, show=False)
        fig = plt.gcf()
        fig.set_size_inches(8, 3)
        fig.tight_layout()
        plt.close(fig)
        return fig

    except Exception:
        # Broad on purpose: the SHAP panel is best-effort at a UI boundary.
        # Log the cause rather than swallowing it silently.
        logging.getLogger(__name__).exception("SHAP waterfall plot failed")
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)

        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "SHAP unavailable", ha="center", va="center")
        ax.axis("off")
        plt.close(fig)
        return fig

def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload):
    """Score one driver profile and build the three dashboard figures.

    Every driver argument is clamped to [1, 5] with clamp_1_5 before scoring.

    Returns:
        ``(donut_figure, shap_figure, driver_figure)``: the turnover donut,
        the SHAP waterfall and the driver bar chart, in that order.
    """
    raw_scores = {
        "Engagement": Engagement,
        "SupportiveGM": SupportiveGM,
        "WellBeing": WellBeing,
        "WorkEnvironment": WorkEnvironment,
        "Voice": Voice,
        "DecisionAutonomy": DecisionAutonomy,
        "Workload": Workload,
    }
    vals = {name: clamp_1_5(score) for name, score in raw_scores.items()}
    # ManagementLevel is a model feature with no slider; it is pinned to 2.
    # NOTE(review): confirm what level 2 denotes in the training data.
    vals["ManagementLevel"] = 2

    X = build_X(vals)
    p_leave, p_stay = prob_leave_and_stay(X)

    donut_fig = make_turnover_donut(p_leave, p_stay)
    shap_fig = make_catboost_shap_plot(X)
    driver_fig = make_driver_plot(vals)
    return donut_fig, shap_fig, driver_fig


# ---------------------------------------------------------------------------
# Preset buttons
# ---------------------------------------------------------------------------

def load_risk():
    """Preset: set each driver to the midpoint of its CLUSTER_2 and CLUSTER_3 averages.

    Returns:
        The seven slider values in slider order (Engagement, SupportiveGM,
        WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload),
        followed by the three figures returned by predict().
    """
    target = {}
    for name in ALL_DRIVER_VARS:
        target[name] = (CLUSTER_2[name] + CLUSTER_3[name]) / 2

    slider_order = (
        "Engagement",
        "SupportiveGM",
        "WellBeing",
        "WorkEnvironment",
        "Voice",
        "DecisionAutonomy",
        "Workload",
    )
    slider_values = tuple(target[name] for name in slider_order)
    return slider_values + tuple(predict(**target))

def apply_recommendation():
    """Preset: move every driver to its CLUSTER_1 average and re-score.

    Returns:
        The seven slider values in slider order (Engagement, SupportiveGM,
        WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload),
        followed by the three figures returned by predict().
    """
    target = {}
    for name in ALL_DRIVER_VARS:
        target[name] = CLUSTER_1[name]

    slider_order = (
        "Engagement",
        "SupportiveGM",
        "WellBeing",
        "WorkEnvironment",
        "Voice",
        "DecisionAutonomy",
        "Workload",
    )
    slider_values = tuple(target[name] for name in slider_order)
    return slider_values + tuple(predict(**target))

def hilton_heroes():
    """Preset for the "Hilton Heroes" button.

    Delegates to apply_recommendation(), so it returns exactly the same slider
    values and figures. NOTE(review): two buttons producing identical output
    may be unintentional; confirm whether this preset should differ.
    """
    outputs = apply_recommendation()
    return outputs


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

# Preset-button colours. Gradio puts ``elem_id`` on the <button> element
# itself, so the previous descendant selectors (``#btn_risk button``) matched
# nothing and the colours never applied. The bare id selector targets the
# button. The descendant form is kept in case a Gradio version wraps the
# button in a container carrying the id.
CSS = f"""
#btn_risk, #btn_risk button {{ background:{RED}; color:white; }}
#btn_heroes, #btn_heroes button {{ background:#0a5eb8; color:white; }}
#btn_reco, #btn_reco button {{ background:{BLUE}; color:white; }}
"""

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("## Predicting Intent to Stay")

    # Preset buttons across the top.
    with gr.Row():
        btn_risk = gr.Button("Immediate and Silent Risk", elem_id="btn_risk")
        btn_heroes = gr.Button("Hilton Heroes", elem_id="btn_heroes")
        btn_reco = gr.Button("Apply Recommendation", elem_id="btn_reco")

    with gr.Row():
        # Left column: one 1-5 slider per driver, then the Predict button.
        # Slider order must match predict()'s parameter order and the order
        # of values returned by the preset handlers.
        with gr.Column():
            sliders = [
                gr.Slider(1, 5, value=3, label=caption)
                for caption in (
                    "Engagement",
                    "Supportive GM",
                    "Well Being",
                    "Work Environment",
                    "Voice",
                    "Decision Autonomy",
                    "Workload",
                )
            ]
            (
                Engagement,
                SupportiveGM,
                WellBeing,
                WorkEnvironment,
                Voice,
                DecisionAutonomy,
                Workload,
            ) = sliders

            btn_predict = gr.Button("Predict")

        # Right column: turnover donut above the SHAP waterfall.
        with gr.Column():
            donut_plot = gr.Plot()
            shap_plot = gr.Plot()

    # Driver bar chart beneath both columns.
    # NOTE(review): the source's indentation was lost; full-width placement is
    # inferred from the chart's wide figsize. Confirm it was not inside the
    # right-hand column.
    drivers_plot = gr.Plot()

    plot_outputs = [donut_plot, shap_plot, drivers_plot]

    btn_predict.click(fn=predict, inputs=sliders, outputs=plot_outputs)
    btn_risk.click(fn=load_risk, inputs=[], outputs=sliders + plot_outputs)
    btn_heroes.click(fn=hilton_heroes, inputs=[], outputs=sliders + plot_outputs)
    btn_reco.click(fn=apply_recommendation, inputs=[], outputs=sliders + plot_outputs)

demo.launch()