Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding: utf-8
|
| 3 |
-
|
| 4 |
import joblib
|
| 5 |
import pandas as pd
|
| 6 |
import gradio as gr
|
|
@@ -11,7 +10,6 @@ import matplotlib.pyplot as plt
|
|
| 11 |
# =========================
|
| 12 |
model = joblib.load("cat (1).joblib")
|
| 13 |
|
| 14 |
-
# MUST match model.feature_names_ EXACTLY
|
| 15 |
FEATURES = [
|
| 16 |
"Engagement",
|
| 17 |
"SupportiveGM",
|
|
@@ -63,7 +61,6 @@ CLUSTER_PROFILES = {
|
|
| 63 |
},
|
| 64 |
}
|
| 65 |
|
| 66 |
-
# Defaults for hidden vars when user adjusts sliders manually (TOTAL means)
|
| 67 |
HIDDEN_DEFAULTS = {"Voice": 4.23, "DecisionAutonomy": 4.50, "Workload": 4.13}
|
| 68 |
VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
|
| 69 |
|
|
@@ -77,18 +74,13 @@ def clamp_1_5(x: float) -> float:
|
|
| 77 |
|
| 78 |
|
| 79 |
def build_X(vals: dict) -> pd.DataFrame:
|
| 80 |
-
"""Build 1-row DataFrame in exact model feature order."""
|
| 81 |
return pd.DataFrame([[vals[f] for f in FEATURES]], columns=FEATURES)
|
| 82 |
|
| 83 |
|
| 84 |
def prob_at_risk(X: pd.DataFrame) -> float:
|
| 85 |
-
"""
|
| 86 |
-
Return P(At Risk) using classes_ lookup.
|
| 87 |
-
Your encoding: At Risk / low intent is class 1.
|
| 88 |
-
"""
|
| 89 |
probs = model.predict_proba(X)[0]
|
| 90 |
classes = list(model.classes_)
|
| 91 |
-
idx = classes.index(1)
|
| 92 |
return float(probs[idx])
|
| 93 |
|
| 94 |
|
|
@@ -97,26 +89,22 @@ def risk_status(p: float, cutoff: float = 0.5) -> str:
|
|
| 97 |
|
| 98 |
|
| 99 |
def stable_threshold_from_cluster1() -> float:
|
| 100 |
-
"""Threshold = min of the 4 visible drivers in Cluster 1."""
|
| 101 |
c1 = CLUSTER_PROFILES["Cluster 1"]
|
| 102 |
return min(c1[d] for d in VISIBLE_DRIVERS)
|
| 103 |
|
| 104 |
|
| 105 |
def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
| 106 |
-
"""
|
| 107 |
-
Plot only Mid-level managers' 4 drivers, with a threshold line based on Cluster 1 minimum.
|
| 108 |
-
Bars turn red if below threshold, green if at/above threshold.
|
| 109 |
-
"""
|
| 110 |
threshold = stable_threshold_from_cluster1()
|
| 111 |
drivers = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
|
| 112 |
values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
|
| 113 |
colors = ["seagreen" if v >= threshold else "firebrick" for v in values]
|
| 114 |
|
| 115 |
-
# Deterministic layout
|
| 116 |
fig, ax = plt.subplots(figsize=(11, 5), constrained_layout=True)
|
| 117 |
-
ax.bar(drivers, values, color=colors)
|
| 118 |
|
|
|
|
| 119 |
ax.axhline(threshold, linestyle="--", linewidth=2)
|
|
|
|
| 120 |
ax.text(
|
| 121 |
0.99,
|
| 122 |
threshold,
|
|
@@ -126,19 +114,19 @@ def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironme
|
|
| 126 |
va="center",
|
| 127 |
)
|
| 128 |
|
| 129 |
-
#
|
| 130 |
ax.set_ylim(1, 5.6)
|
| 131 |
ax.set_ylabel("Survey Score (1–5)")
|
| 132 |
ax.set_title("Key Drivers (Mid-level managers) — Probability of Being At Risk of Leaving")
|
| 133 |
|
| 134 |
-
# Lock x-limits
|
| 135 |
ax.set_xlim(-0.6, len(drivers) - 0.4)
|
| 136 |
|
| 137 |
return fig
|
| 138 |
|
| 139 |
|
| 140 |
# =========================
|
| 141 |
-
#
|
| 142 |
# =========================
|
| 143 |
def predict_dashboard(
|
| 144 |
Engagement,
|
|
@@ -149,13 +137,11 @@ def predict_dashboard(
|
|
| 149 |
DecisionAutonomy=None,
|
| 150 |
Workload=None,
|
| 151 |
):
|
| 152 |
-
# visible
|
| 153 |
Engagement = clamp_1_5(Engagement)
|
| 154 |
SupportiveGM = clamp_1_5(SupportiveGM)
|
| 155 |
WellBeing = clamp_1_5(WellBeing)
|
| 156 |
WorkEnvironment = clamp_1_5(WorkEnvironment)
|
| 157 |
|
| 158 |
-
# hidden
|
| 159 |
Voice = clamp_1_5(HIDDEN_DEFAULTS["Voice"] if Voice is None else Voice)
|
| 160 |
DecisionAutonomy = clamp_1_5(HIDDEN_DEFAULTS["DecisionAutonomy"] if DecisionAutonomy is None else DecisionAutonomy)
|
| 161 |
Workload = clamp_1_5(HIDDEN_DEFAULTS["Workload"] if Workload is None else Workload)
|
|
@@ -182,7 +168,6 @@ def predict_dashboard(
|
|
| 182 |
)
|
| 183 |
|
| 184 |
df_table = pd.DataFrame(rows)
|
| 185 |
-
|
| 186 |
mid_status = df_table.loc[df_table["Management Level"] == MGMT_LABELS[2], "Risk Status"].iloc[0]
|
| 187 |
headline = f"Mid-level managers: **{mid_status}**"
|
| 188 |
|
|
@@ -211,17 +196,16 @@ def apply_cluster(cluster_name: str):
|
|
| 211 |
DecisionAutonomy=DecisionAutonomy,
|
| 212 |
Workload=Workload,
|
| 213 |
)
|
| 214 |
-
|
| 215 |
return Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, table, fig
|
| 216 |
|
| 217 |
|
| 218 |
# =========================
|
| 219 |
-
#
|
| 220 |
# =========================
|
| 221 |
CSS = """
|
| 222 |
#headline_box { min-height: 44px; }
|
| 223 |
|
| 224 |
-
/*
|
| 225 |
#table_box {
|
| 226 |
height: 240px;
|
| 227 |
overflow: auto;
|
|
@@ -229,6 +213,23 @@ CSS = """
|
|
| 229 |
border-radius: 10px;
|
| 230 |
padding: 6px;
|
| 231 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
"""
|
| 233 |
|
| 234 |
with gr.Blocks() as demo:
|
|
@@ -238,7 +239,7 @@ with gr.Blocks() as demo:
|
|
| 238 |
"The table shows **At Risk / Not At Risk** by management level."
|
| 239 |
)
|
| 240 |
|
| 241 |
-
with gr.Row(
|
| 242 |
with gr.Column(scale=1, min_width=420):
|
| 243 |
with gr.Row():
|
| 244 |
btn_c1 = gr.Button("Load Cluster 1")
|
|
@@ -255,12 +256,13 @@ with gr.Blocks() as demo:
|
|
| 255 |
with gr.Column(scale=1, min_width=520):
|
| 256 |
headline = gr.Markdown(elem_id="headline_box")
|
| 257 |
|
| 258 |
-
# Wrap dataframe in a fixed-height container using HTML
|
| 259 |
gr.HTML('<div id="table_box">')
|
| 260 |
table = gr.Dataframe(label="Risk Status by Management Level")
|
| 261 |
gr.HTML("</div>")
|
| 262 |
|
| 263 |
-
|
|
|
|
|
|
|
| 264 |
|
| 265 |
btn_predict.click(
|
| 266 |
fn=predict_dashboard,
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding: utf-8
|
|
|
|
| 3 |
import joblib
|
| 4 |
import pandas as pd
|
| 5 |
import gradio as gr
|
|
|
|
| 10 |
# =========================
|
| 11 |
model = joblib.load("cat (1).joblib")
|
| 12 |
|
|
|
|
| 13 |
FEATURES = [
|
| 14 |
"Engagement",
|
| 15 |
"SupportiveGM",
|
|
|
|
| 61 |
},
|
| 62 |
}
|
| 63 |
|
|
|
|
| 64 |
HIDDEN_DEFAULTS = {"Voice": 4.23, "DecisionAutonomy": 4.50, "Workload": 4.13}
|
| 65 |
VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
|
| 66 |
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def build_X(vals: dict) -> pd.DataFrame:
|
|
|
|
| 77 |
return pd.DataFrame([[vals[f] for f in FEATURES]], columns=FEATURES)
|
| 78 |
|
| 79 |
|
| 80 |
def prob_at_risk(X: pd.DataFrame) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
probs = model.predict_proba(X)[0]
|
| 82 |
classes = list(model.classes_)
|
| 83 |
+
idx = classes.index(1) # At Risk = 1
|
| 84 |
return float(probs[idx])
|
| 85 |
|
| 86 |
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def stable_threshold_from_cluster1() -> float:
|
|
|
|
| 92 |
c1 = CLUSTER_PROFILES["Cluster 1"]
|
| 93 |
return min(c1[d] for d in VISIBLE_DRIVERS)
|
| 94 |
|
| 95 |
|
| 96 |
def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
threshold = stable_threshold_from_cluster1()
|
| 98 |
drivers = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
|
| 99 |
values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
|
| 100 |
colors = ["seagreen" if v >= threshold else "firebrick" for v in values]
|
| 101 |
|
| 102 |
+
# Deterministic layout helps reduce jitter
|
| 103 |
fig, ax = plt.subplots(figsize=(11, 5), constrained_layout=True)
|
|
|
|
| 104 |
|
| 105 |
+
ax.bar(drivers, values, color=colors)
|
| 106 |
ax.axhline(threshold, linestyle="--", linewidth=2)
|
| 107 |
+
|
| 108 |
ax.text(
|
| 109 |
0.99,
|
| 110 |
threshold,
|
|
|
|
| 114 |
va="center",
|
| 115 |
)
|
| 116 |
|
| 117 |
+
# More space above 5
|
| 118 |
ax.set_ylim(1, 5.6)
|
| 119 |
ax.set_ylabel("Survey Score (1–5)")
|
| 120 |
ax.set_title("Key Drivers (Mid-level managers) — Probability of Being At Risk of Leaving")
|
| 121 |
|
| 122 |
+
# Lock x-limits so bar geometry never changes
|
| 123 |
ax.set_xlim(-0.6, len(drivers) - 0.4)
|
| 124 |
|
| 125 |
return fig
|
| 126 |
|
| 127 |
|
| 128 |
# =========================
|
| 129 |
+
# Prediction
|
| 130 |
# =========================
|
| 131 |
def predict_dashboard(
|
| 132 |
Engagement,
|
|
|
|
| 137 |
DecisionAutonomy=None,
|
| 138 |
Workload=None,
|
| 139 |
):
|
|
|
|
| 140 |
Engagement = clamp_1_5(Engagement)
|
| 141 |
SupportiveGM = clamp_1_5(SupportiveGM)
|
| 142 |
WellBeing = clamp_1_5(WellBeing)
|
| 143 |
WorkEnvironment = clamp_1_5(WorkEnvironment)
|
| 144 |
|
|
|
|
| 145 |
Voice = clamp_1_5(HIDDEN_DEFAULTS["Voice"] if Voice is None else Voice)
|
| 146 |
DecisionAutonomy = clamp_1_5(HIDDEN_DEFAULTS["DecisionAutonomy"] if DecisionAutonomy is None else DecisionAutonomy)
|
| 147 |
Workload = clamp_1_5(HIDDEN_DEFAULTS["Workload"] if Workload is None else Workload)
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
df_table = pd.DataFrame(rows)
|
|
|
|
| 171 |
mid_status = df_table.loc[df_table["Management Level"] == MGMT_LABELS[2], "Risk Status"].iloc[0]
|
| 172 |
headline = f"Mid-level managers: **{mid_status}**"
|
| 173 |
|
|
|
|
| 196 |
DecisionAutonomy=DecisionAutonomy,
|
| 197 |
Workload=Workload,
|
| 198 |
)
|
|
|
|
| 199 |
return Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, table, fig
|
| 200 |
|
| 201 |
|
| 202 |
# =========================
|
| 203 |
+
# UI sizing via CSS wrappers (works on older Gradio too)
|
| 204 |
# =========================
|
| 205 |
CSS = """
|
| 206 |
#headline_box { min-height: 44px; }
|
| 207 |
|
| 208 |
+
/* Fixed-height containers to prevent page jumping */
|
| 209 |
#table_box {
|
| 210 |
height: 240px;
|
| 211 |
overflow: auto;
|
|
|
|
| 213 |
border-radius: 10px;
|
| 214 |
padding: 6px;
|
| 215 |
}
|
| 216 |
+
|
| 217 |
+
#plot_box {
|
| 218 |
+
height: 460px;
|
| 219 |
+
overflow: hidden;
|
| 220 |
+
border: 1px solid rgba(0,0,0,0.08);
|
| 221 |
+
border-radius: 10px;
|
| 222 |
+
padding: 6px;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
/* Make the plot fill the fixed container */
|
| 226 |
+
#plot_box .plot-container,
|
| 227 |
+
#plot_box .gradio-plot,
|
| 228 |
+
#plot_box canvas,
|
| 229 |
+
#plot_box svg {
|
| 230 |
+
height: 100% !important;
|
| 231 |
+
width: 100% !important;
|
| 232 |
+
}
|
| 233 |
"""
|
| 234 |
|
| 235 |
with gr.Blocks() as demo:
|
|
|
|
| 239 |
"The table shows **At Risk / Not At Risk** by management level."
|
| 240 |
)
|
| 241 |
|
| 242 |
+
with gr.Row():
|
| 243 |
with gr.Column(scale=1, min_width=420):
|
| 244 |
with gr.Row():
|
| 245 |
btn_c1 = gr.Button("Load Cluster 1")
|
|
|
|
| 256 |
with gr.Column(scale=1, min_width=520):
|
| 257 |
headline = gr.Markdown(elem_id="headline_box")
|
| 258 |
|
|
|
|
| 259 |
gr.HTML('<div id="table_box">')
|
| 260 |
table = gr.Dataframe(label="Risk Status by Management Level")
|
| 261 |
gr.HTML("</div>")
|
| 262 |
|
| 263 |
+
gr.HTML('<div id="plot_box">')
|
| 264 |
+
plot = gr.Plot()
|
| 265 |
+
gr.HTML("</div>")
|
| 266 |
|
| 267 |
btn_predict.click(
|
| 268 |
fn=predict_dashboard,
|