#!/usr/bin/env python
# coding: utf-8
"""Gradio app: predict intent to stay / turnover risk from survey drivers.

A classifier loaded from ``final.joblib`` scores seven driver sliders plus a
management level. The UI shows the leave probability, a SHAP waterfall for the
prediction, and the current drivers compared with a goal profile.
"""
import gradio as gr
import joblib
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams["figure.dpi"] = 100

# ============================================================
# Fonts
# ============================================================
# Arial is often missing on Linux hosts. Asking for it through the generic
# sans-serif list (instead of as the hard family name) lets matplotlib fall
# back to the next available family, e.g. its bundled DejaVu Sans.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Arial"] + [
    name for name in plt.rcParams["font.sans-serif"] if name != "Arial"
]
plt.rcParams["font.weight"] = "black"
plt.rcParams["font.size"] = 10
TITLE_FONTSIZE = 14
TEXT_FONTSIZE = 10

# ============================================================
# Colors
# ============================================================
RED = "#d62728"   # "leave" / risk colour
BLUE = "#1f77b4"  # "stay" colour

# ============================================================
# Class meaning
# ============================================================
# Label values in the model's target column.
LEAVE_CLASS = 1
STAY_CLASS = 0

# =========================
# Load model
# =========================
# NOTE: joblib.load unpickles the file; only ship trusted model artifacts.
model = joblib.load("final.joblib")

# Column order of the feature matrix passed to the model.
FEATURES = [
    "Engagement",
    "SupportiveGM",
    "ManagementLevel",
    "WellBeing",
    "Voice",
    "DecisionAutonomy",
    "Workload",
    "WorkEnvironment",
]

# =========================
# Cluster anchors
# =========================
# Reference driver scores for three profiles: CLUSTER_1 has the highest
# values and CLUSTER_3 the lowest. They are presumably cluster means from the
# survey data; confirm with the analysis that produced them.
CLUSTER_1 = {
    "Voice": 4.84,
    "DecisionAutonomy": 4.90,
    "Workload": 4.72,
    "WellBeing": 4.8397,
    "WorkEnvironment": 4.8858,
    "SupportiveGM": 4.8583,
    "Engagement": 4.9324,
}

CLUSTER_2 = {
    "Voice": 3.94,
    "DecisionAutonomy": 4.24,
    "Workload": 3.76,
    "WellBeing": 4.0251,
    "WorkEnvironment": 4.1484,
    "SupportiveGM": 4.1275,
    "Engagement": 4.2828,
}

CLUSTER_3 = {
    "Voice": 2.39,
    "DecisionAutonomy": 3.55,
    "Workload": 2.68,
    "WellBeing": 3.0299,
    "WorkEnvironment": 3.4537,
    "SupportiveGM": 3.2208,
    "Engagement": 3.3909,
}

# Drivers exposed as sliders, in display order; labels align index-by-index.
ALL_DRIVER_VARS = [
    "Engagement",
    "SupportiveGM",
    "WellBeing",
    "WorkEnvironment",
    "Voice",
    "DecisionAutonomy",
    "Workload",
]

ALL_DRIVER_LABELS = [
    "Engagement",
    "Supportive GM",
    "Well-Being",
    "Work Environment",
    "Voice",
    "Decision Autonomy",
    "Workload",
]

# Goal value per driver: the CLUSTER_1 (highest-scoring) profile.
GOAL_AVG = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
# ============================================================
# Helpers
# ============================================================
# ManagementLevel is a model feature with no slider in the UI; every
# prediction is made at this fixed level.
DEFAULT_MANAGEMENT_LEVEL = 2

# Single colour for the dashed goal markers in the driver chart.
GOAL_LINE_COLOR = "black"

# Background of the "Hilton Heroes" button.
HEROES_BUTTON_COLOR = "#0a5eb8"


def clamp_1_5(x):
    """Coerce *x* to float and clamp it into the closed range [1.0, 5.0]."""
    return max(1.0, min(5.0, float(x)))


def build_X(vals):
    """Return a one-row DataFrame with columns in FEATURES order.

    Args:
        vals: mapping with a value for every name in FEATURES; extra keys are
            ignored.
    """
    return pd.DataFrame([[vals[f] for f in FEATURES]], columns=FEATURES)


def prob_leave_and_stay(X):
    """Return ``(p_leave, p_stay)`` for the first row of *X*.

    The probability columns are found through ``model.classes_``, so the result
    does not depend on the order in which the model stores its classes.
    """
    probs = model.predict_proba(X)[0]
    classes = list(model.classes_)
    p_leave = float(probs[classes.index(LEAVE_CLASS)])
    p_stay = float(probs[classes.index(STAY_CLASS)])
    return p_leave, p_stay


def hex_to_rgb01(h):
    """Convert a ``#rrggbb`` string to an ``(r, g, b)`` tuple of floats in [0, 1]."""
    h = h.lstrip("#")
    return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4))


# ============================================================
# Donut chart
# ============================================================
def make_turnover_donut(p_leave, p_stay):
    """Draw a leave-vs-stay donut with the leave percentage in the centre.

    Args:
        p_leave: probability of LEAVE_CLASS (drawn in RED).
        p_stay: probability of STAY_CLASS (drawn in BLUE).

    Returns:
        The matplotlib Figure. It is already closed in pyplot and is handed to
        Gradio as a Figure object.
    """
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.pie(
        [p_leave, p_stay],
        startangle=90,
        colors=[RED, BLUE],
        wedgeprops=dict(width=0.35, edgecolor="white"),
    )
    ax.text(
        0,
        0,
        f"{p_leave*100:.0f}%",
        ha="center",
        va="center",
        fontsize=18,
        color=RED,
        fontweight="black",
    )
    ax.set_title("Turnover Risk", fontsize=TITLE_FONTSIZE)
    # Drop pyplot's reference so repeated requests do not accumulate figures.
    plt.close(fig)
    return fig


# ============================================================
# Driver chart
# ============================================================
def make_driver_plot(vals):
    """Bar chart of the current driver values with a dashed goal marker per bar.

    Args:
        vals: mapping containing every name in ALL_DRIVER_VARS.

    Returns:
        A closed matplotlib Figure.
    """
    values = [vals[v] for v in ALL_DRIVER_VARS]
    goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
    positions = range(len(values))

    fig, ax = plt.subplots(figsize=(12, 3))
    bars = ax.bar(positions, values, color=BLUE)
    for bar, goal in zip(bars, goals):
        left = bar.get_x()
        # Use one explicit colour. Otherwise each plot() call advances the
        # colour cycle and every goal marker comes out a different colour.
        ax.plot(
            [left, left + bar.get_width()],
            [goal, goal],
            linestyle="--",
            color=GOAL_LINE_COLOR,
        )
    ax.set_xticks(positions)
    ax.set_xticklabels(ALL_DRIVER_LABELS)
    ax.set_ylim(1, 5.4)
    ax.set_title("Average of key drivers", fontsize=TITLE_FONTSIZE)
    plt.close(fig)
    return fig


# ============================================================
# SHAP (original waterfall restored, red forced)
# ============================================================
def make_catboost_shap_plot(X):
    """Return a SHAP waterfall figure explaining the LEAVE_CLASS score for X's first row.

    If shap/catboost cannot be imported, or the explanation fails for any
    reason, the error is logged and a placeholder "SHAP unavailable" figure is
    returned instead, so the rest of the UI still renders.
    """
    import logging

    try:
        # The imports live inside the try so that a missing optional package
        # triggers the fallback figure instead of crashing predict().
        import shap
        from catboost import Pool

        # "Red forced": the waterfall reads these module-level colours, so
        # overriding them matches the app palette. This mutates shap's global
        # state for the whole process.
        shap.plots.colors.red_rgb = hex_to_rgb01(RED)
        shap.plots.colors.blue_rgb = hex_to_rgb01(BLUE)

        shap_vals = model.get_feature_importance(Pool(X), type="ShapValues")
        classes = list(model.classes_)
        class_idx = classes.index(LEAVE_CLASS)

        if shap_vals.ndim == 3:
            # Multiclass layout per CatBoost docs:
            # (n_objects, n_classes, n_features + 1); the last slot is the base value.
            base = float(shap_vals[0, class_idx, -1])
            values = shap_vals[0, class_idx, :-1]
        else:
            # Binary layout: (n_objects, n_features + 1), explaining classes_[1].
            base = float(shap_vals[0, -1])
            values = shap_vals[0, :-1]
            if len(classes) == 2 and class_idx == 0:
                # LEAVE_CLASS is the negative class: flip the log-odds contributions.
                base = -base
                values = -values

        exp = shap.Explanation(
            values=values,
            base_values=base,
            data=X.iloc[0].values,
            feature_names=list(X.columns),
        )
        shap.plots.waterfall(exp, max_display=8, show=False)
        fig = plt.gcf()
        fig.set_size_inches(8, 3)
        fig.tight_layout()
        plt.close(fig)
        return fig
    except Exception:
        # Best-effort UI fallback, but do not hide the cause.
        logging.getLogger(__name__).exception("SHAP waterfall plot failed")
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "SHAP unavailable", ha="center", va="center")
        ax.set_axis_off()
        plt.close(fig)
        return fig


# ============================================================
# Predict
# ============================================================
def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload):
    """Score one set of driver values and build the three output figures.

    Each driver is clamped to [1, 5]. ManagementLevel is fixed at
    DEFAULT_MANAGEMENT_LEVEL.

    Returns:
        ``(donut_fig, shap_fig, drivers_fig)`` matplotlib Figures.
    """
    vals = {
        "Engagement": clamp_1_5(Engagement),
        "SupportiveGM": clamp_1_5(SupportiveGM),
        "WellBeing": clamp_1_5(WellBeing),
        "WorkEnvironment": clamp_1_5(WorkEnvironment),
        "Voice": clamp_1_5(Voice),
        "DecisionAutonomy": clamp_1_5(DecisionAutonomy),
        "Workload": clamp_1_5(Workload),
        "ManagementLevel": DEFAULT_MANAGEMENT_LEVEL,
    }
    X = build_X(vals)
    p_leave, p_stay = prob_leave_and_stay(X)

    donut_fig = make_turnover_donut(p_leave, p_stay)
    # Named shap_fig (not "shap") to avoid shadowing the shap module name.
    shap_fig = make_catboost_shap_plot(X)
    drivers_fig = make_driver_plot(vals)
    return donut_fig, shap_fig, drivers_fig


# ============================================================
# Buttons
# ============================================================
def _preset_outputs(target):
    """Return the slider values (ALL_DRIVER_VARS order) followed by predict()'s figures."""
    slider_values = tuple(target[v] for v in ALL_DRIVER_VARS)
    return (*slider_values, *predict(**target))


def load_risk():
    """Preset every driver to the midpoint of CLUSTER_2 and CLUSTER_3."""
    target = {v: (CLUSTER_2[v] + CLUSTER_3[v]) / 2 for v in ALL_DRIVER_VARS}
    return _preset_outputs(target)


def apply_recommendation():
    """Preset every driver to its CLUSTER_1 (goal) value."""
    target = {v: CLUSTER_1[v] for v in ALL_DRIVER_VARS}
    return _preset_outputs(target)


def hilton_heroes():
    """Preset for the "Hilton Heroes" button.

    NOTE(review): currently identical to apply_recommendation (CLUSTER_1);
    confirm that this is intended.
    """
    return apply_recommendation()


# ============================================================
# CSS
# ============================================================
# Gradio may put elem_id on the <button> itself rather than on a wrapper, so
# both the element and a descendant button are targeted.
CSS = f"""
#btn_risk, #btn_risk button {{ background:{RED}; color:white; }}
#btn_heroes, #btn_heroes button {{ background:{HEROES_BUTTON_COLOR}; color:white; }}
#btn_reco, #btn_reco button {{ background:{BLUE}; color:white; }}
"""

# ============================================================
# UI
# ============================================================
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("## Predicting Intent to Stay")

    with gr.Row():
        btn_risk = gr.Button("Immediate and Silent Risk", elem_id="btn_risk")
        btn_heroes = gr.Button("Hilton Heroes", elem_id="btn_heroes")
        btn_reco = gr.Button("Apply Recommendation", elem_id="btn_reco")

    with gr.Row():
        with gr.Column():
            Engagement = gr.Slider(1, 5, value=3, label="Engagement")
            SupportiveGM = gr.Slider(1, 5, value=3, label="Supportive GM")
            WellBeing = gr.Slider(1, 5, value=3, label="Well Being")
            WorkEnvironment = gr.Slider(1, 5, value=3, label="Work Environment")
            Voice = gr.Slider(1, 5, value=3, label="Voice")
            DecisionAutonomy = gr.Slider(1, 5, value=3, label="Decision Autonomy")
            Workload = gr.Slider(1, 5, value=3, label="Workload")
            btn_predict = gr.Button("Predict")

        with gr.Column():
            donut_plot = gr.Plot()
            shap_plot = gr.Plot()
            drivers_plot = gr.Plot()

    # Slider order must match predict()'s positional parameters.
    sliders = [
        Engagement,
        SupportiveGM,
        WellBeing,
        WorkEnvironment,
        Voice,
        DecisionAutonomy,
        Workload,
    ]
    plots = [donut_plot, shap_plot, drivers_plot]

    btn_predict.click(predict, sliders, plots)
    # Preset buttons update the sliders as well as the plots.
    btn_risk.click(load_risk, [], sliders + plots)
    btn_heroes.click(hilton_heroes, [], sliders + plots)
    btn_reco.click(apply_recommendation, [], sliders + plots)
# Start the web server only when run as a script, not when the module is
# imported (e.g. by tests or tooling that just needs `demo`).
if __name__ == "__main__":
    demo.launch()