M3T92025 / app.py
colbywentlandt's picture
u
7172cb6 verified
import gradio as gr
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Load the trained model from disk.
# The context manager closes the file handle as soon as unpickling finishes.
# NOTE: pickle can execute arbitrary code while loading; only ship a model
# file that was produced by a trusted training run.
with open("model-reduced.pkl", 'rb') as _model_file:
    loaded_model = pickle.load(_model_file)
# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
# Human-readable labels for the model's feature columns.
# Keys are the column names of the row passed to the model; values are the
# captions drawn next to each bar of the SHAP chart.
feature_names_mapping = {
    "Tenure": "Tenure",
    "LearningDevelopment": "Learning & Development",
    "WorkEnvironment": "Work Environment",
    "WorkLifeBalance": "Work-Life Balance",
    "WellBeing": "Well-Being",
    "Engagement": "Employee Engagement",
    "GM3": "General Manager Leadership",
}
# Create the main function for server
def main_func(Tenure, LearningDevelopment, WorkEnvironment, WorkLifeBalance, WellBeing, Engagement, GM3):
    """Score one employee profile and chart each feature's SHAP contribution.

    Args:
        Tenure: Tenure bucket label (one of the keys of ``tenure_key`` below).
            Any other value falls back to the numeric code 5.
        LearningDevelopment, WorkEnvironment, WorkLifeBalance, WellBeing,
        Engagement, GM3: Numeric scores passed to the model unchanged.

    Returns:
        A ``(probabilities, figure)`` tuple: a dict with the "Leave" and
        "Stay" probabilities, and a matplotlib figure holding a horizontal
        bar chart of the SHAP values for this row.
    """
    # Numeric codes the model expects for each tenure bucket.
    tenure_key = {
        '< 6 Months': 2,
        '6 Months - 2 Years': 4,
        '2 to 5 Years': 6,
        '5+ Years': 7,
    }
    default_tenure = 5
    val_tenure = tenure_key.get(Tenure, default_tenure)

    # Build a single-row frame whose column names match the keys of
    # feature_names_mapping.
    new_row = pd.DataFrame.from_dict({
        'Tenure': val_tenure,
        'LearningDevelopment': LearningDevelopment,
        'WorkEnvironment': WorkEnvironment,
        'WorkLifeBalance': WorkLifeBalance,
        'WellBeing': WellBeing,
        'Engagement': Engagement,
        'GM3': GM3}, orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # NOTE(review): assumes one SHAP value per feature for the single row
    # (values[0] is 1-D) - confirm for the pickled model type.
    values = shap_values.values[0]
    features = new_row.columns
    colors = ['#FF0000' if v < 0 else '#1E4380' for v in values]

    # Keep (at most) the seven largest contributions, ascending by magnitude.
    sorted_indices = np.argsort(np.abs(values))[-7:]
    sorted_values = values[sorted_indices]
    sorted_features = [feature_names_mapping[features[i]] for i in sorted_indices]
    sorted_colors = [colors[i] for i in sorted_indices]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(sorted_features, sorted_values, color=sorted_colors)
    ax.set_xlabel("SHAP Value Impact")
    ax.set_title("Feature Importance (SHAP)")
    # Work on this figure explicitly instead of pyplot's global "current
    # figure", which another concurrent request could have replaced.
    fig.tight_layout()
    # Deregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself is still returned for rendering.
    plt.close(fig)

    # NOTE(review): column 0 of predict_proba is reported as "Leave" - confirm
    # against the class order used at training time.
    leave_probability = float(prob[0][0])
    return {"Leave ❌": leave_probability, "Stay ✅": 1 - leave_probability}, fig
# Create the UI
# Text shown at the top of the page. The emoji at the end of the title were
# stored as mis-decoded UTF-8 bytes; they are restored to the intended glyphs.
title = "**Employee Turnover Predictor & Interpreter** 🏨🛎️"
# NOTE(review): the text below promises a "Shapley force plot", while
# main_func renders a horizontal bar chart of SHAP values - confirm wording.
description1 = """
This app predicts whether an employee is likely to stay with Hilton or leave based on their satisfaction across seven key workplace factors. The app produces two outputs: 1) the predicted probability of staying or leaving, and 2) a Shapley force plot that visualizes the impact of each factor on the stay/leave prediction."""
description2 = """
To use the app, click on one of the examples below, or adjust the values of the employee satisfaction factors, and click on Analyze. ✨
"""
# Recolor every non-transparent pixel of a logo (white by default)
def change_logo_color(logo_path, hex_color="#FFFFFF", output_path="colored_logo.png"):
    """Paint all visible pixels of an image with one color and save the result.

    Args:
        logo_path: Path of the image to recolor.
        hex_color: Replacement color as a ``#RRGGBB`` string.
        output_path: Where the recolored PNG is written.

    Returns:
        ``output_path``, so the caller can hand it straight to the UI.
    """
    # "#RRGGBB" -> three integer channels.
    red, green, blue = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))

    # Close the source file as soon as the pixels have been copied out.
    with Image.open(logo_path) as source:
        data = np.array(source.convert("RGBA"))

    # Only touch pixels that are at least partly opaque; alpha is preserved.
    visible = data[:, :, 3] > 0
    data[visible, 0] = red
    data[visible, 1] = green
    data[visible, 2] = blue

    Image.fromarray(data).save(output_path)
    return output_path
# Recolor the main logo with the default color and remember where the
# recolored copy was written.
colored_logo_path = change_logo_color(logo_path="logo.png")
def change_specific_color(logo_path, target_hex="#232d4b", new_hex="#FFFFFF", tolerance=30, output_path="modified_logo.png"):
    """Replace one color (and near matches) in an image and save the result.

    Args:
        logo_path: Path of the image to modify.
        target_hex: Color to look for, as a ``#RRGGBB`` string.
        new_hex: Color written over matching pixels, as ``#RRGGBB``.
        tolerance: Maximum Euclidean distance in RGB space for a pixel to
            count as a match.
        output_path: Where the modified PNG is written.

    Returns:
        ``output_path``.
    """
    # Convert hex colors to RGB
    target_r, target_g, target_b = (int(target_hex[i:i + 2], 16) for i in (1, 3, 5))
    new_r, new_g, new_b = (int(new_hex[i:i + 2], 16) for i in (1, 3, 5))

    # Open the image; the file is closed once the pixel data is copied out.
    with Image.open(logo_path) as source:
        data = np.array(source.convert("RGBA"))

    # Widen to int before subtracting so uint8 arithmetic cannot wrap around.
    r_diff = (data[:, :, 0].astype(int) - target_r) ** 2
    g_diff = (data[:, :, 1].astype(int) - target_g) ** 2
    b_diff = (data[:, :, 2].astype(int) - target_b) ** 2

    # Euclidean distance of every pixel from the target color
    color_distance = np.sqrt(r_diff + g_diff + b_diff)

    # Match pixels within tolerance of the target color that are not fully
    # transparent
    color_mask = (color_distance <= tolerance) & (data[:, :, 3] > 0)

    # Apply new color to matched pixels; alpha is left untouched
    data[color_mask, 0] = new_r
    data[color_mask, 1] = new_g
    data[color_mask, 2] = new_b

    # Save the modified image
    Image.fromarray(data).save(output_path)
    return output_path
# Produce the recolored UVA logo using the default target and replacement
# colors, and keep the path of the written file.
uva_logo_path = change_specific_color(logo_path="UVA-Logo.png")
# Custom CSS for the logo row.
# The disabled `body` rule is wrapped in a real CSS comment (/* ... */);
# `#` is not comment syntax in CSS.
# NOTE(review): `auto` does not appear to be a valid value for max-height,
# so browsers likely ignore that declaration - confirm and drop or replace.
css = """
.container img {
    max-width: 50%;
    max-height: auto;
    display: inline-block;
    margin: 0;
}
/* body {
    background-color: white;  (looks bad, keep here for now though)
} */
"""
# Create Gradio interface with the colored logo
# `title` carries Markdown bold markers for the on-page heading; the Blocks
# `title` argument is given the same text without the asterisks.
with gr.Blocks(title=title.replace("*", ""), css=css) as demo:
    # Logo strip
    with gr.Row(elem_classes=["container"]):
        gr.Image(colored_logo_path, show_label=False, container=False)
        gr.Image(uva_logo_path, show_label=False, container=False)

    # Heading and introductory text
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    with gr.Row():
        # Left column: model inputs
        with gr.Column():
            Tenure = gr.Dropdown(["< 6 Months", "6 Months - 2 Years", "2 to 5 Years", "5+ Years"], label="Employee Tenure", value="< 6 Months")
            LearningDevelopment = gr.Slider(label="Learning & Development", minimum=1, maximum=5, value=4, step=.1)
            WorkEnvironment = gr.Slider(label="Work Environment", minimum=1, maximum=5, value=4, step=.1)
            WorkLifeBalance = gr.Slider(label="Work-Life Balance", minimum=1, maximum=5, value=4, step=.1)
            WellBeing = gr.Slider(label="Well Being", minimum=1, maximum=5, value=4, step=.1)
            Engagement = gr.Slider(label="Employee Engagement", minimum=1, maximum=5, value=4, step=.1)
            GM3 = gr.Slider(label="General Manager Leadership (GM3)", minimum=1, maximum=5, value=4, step=.1)
            submit_btn = gr.Button("Analyze")

        # Right column: prediction and SHAP chart
        with gr.Column(visible=True, scale=1, min_width=600) as output_col:
            label = gr.Label(label="Predicted Label")
            local_plot = gr.Plot(label='Shap:')

    # Shared by the button handler and the examples so the two cannot drift;
    # the order matches main_func's parameters and return tuple.
    input_components = [Tenure, LearningDevelopment, WorkEnvironment, WorkLifeBalance, WellBeing, Engagement, GM3]
    output_components = [label, local_plot]

    submit_btn.click(
        main_func,
        input_components,
        output_components, api_name="Employee_Turnover"
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples([
        ["6 Months - 2 Years", 2.9, 3.4, 3.1, 3.3, 3.4, 3.2],
        ["6 Months - 2 Years", 4, 4.1, 3.9, 4, 4.2, 4],
        ["6 Months - 2 Years", 4.8, 4.9, 4.7, 4.8, 4.9, 4.8]],
        input_components,
        output_components, main_func, cache_examples=True)

# NOTE(review): share=True requests a public tunnel link; it is presumably
# unnecessary when the app is hosted on a platform that already serves it -
# confirm the deployment target.
demo.launch(share=True)