Update app.py
Browse files
app.py
CHANGED
|
@@ -77,6 +77,11 @@ ALL_DRIVER_LABELS = [
|
|
| 77 |
"Workload",
|
| 78 |
]
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# =========================
|
| 82 |
# Helpers
|
|
@@ -101,30 +106,30 @@ def risk_label(p: float) -> str:
|
|
| 101 |
return "At Risk" if p >= 0.5 else "Not At Risk"
|
| 102 |
|
| 103 |
|
| 104 |
-
def not_at_risk_threshold_from_cluster3():
|
| 105 |
-
"""
|
| 106 |
-
threshold line = MIN of Cluster 3 across the driver vars in the averages chart,
|
| 107 |
-
labeled "Not at-risk threshold".
|
| 108 |
-
"""
|
| 109 |
-
return min(CLUSTER_3[v] for v in ALL_DRIVER_VARS)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
# =========================
|
| 113 |
-
# Plot: Average of key drivers +
|
| 114 |
# =========================
|
| 115 |
def make_driver_plot(driver_vals: dict):
|
| 116 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 117 |
-
|
| 118 |
|
| 119 |
fig, ax = plt.subplots(figsize=(8.6, 3.1))
|
| 120 |
-
ax.bar(ALL_DRIVER_LABELS, values)
|
| 121 |
|
| 122 |
-
#
|
| 123 |
-
ax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
ax.text(
|
| 125 |
-
len(ALL_DRIVER_LABELS) - 0.
|
| 126 |
-
|
| 127 |
-
"
|
| 128 |
va="center",
|
| 129 |
ha="right",
|
| 130 |
)
|
|
@@ -256,6 +261,7 @@ def load_at_risk_group():
|
|
| 256 |
shap_fig,
|
| 257 |
)
|
| 258 |
|
|
|
|
| 259 |
def apply_recommendation():
|
| 260 |
# Recommendation = move toward high performers
|
| 261 |
# average of Cluster 1 and Cluster 2
|
|
@@ -290,12 +296,10 @@ def apply_recommendation():
|
|
| 290 |
# =========================
|
| 291 |
CSS = """
|
| 292 |
#app-wrap { max-width: 1200px; margin: 0 auto; }
|
| 293 |
-
|
| 294 |
/* Remove extra padding/margins from blocks */
|
| 295 |
.gr-block { padding: 10px 12px !important; }
|
| 296 |
.gr-form { gap: 8px !important; }
|
| 297 |
.gr-row { gap: 10px !important; }
|
| 298 |
-
|
| 299 |
/* Make markdown tighter */
|
| 300 |
.compact h2 { margin: 0 0 6px 0; }
|
| 301 |
.compact p { margin: 0 0 8px 0; }
|
|
@@ -335,7 +339,7 @@ with gr.Blocks(css=CSS) as demo:
|
|
| 335 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
| 336 |
|
| 337 |
# =========================
|
| 338 |
-
# WIRE UP EVENTS
|
| 339 |
# =========================
|
| 340 |
slider_inputs = [Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload]
|
| 341 |
|
|
|
|
| 77 |
"Workload",
|
| 78 |
]
|
| 79 |
|
| 80 |
+
# =========================
|
| 81 |
+
# UPDATED: Goal Average per driver = Cluster 3 mean per driver
|
| 82 |
+
# =========================
|
| 83 |
+
GOAL_AVG = {v: CLUSTER_3[v] for v in ALL_DRIVER_VARS}
|
| 84 |
+
|
| 85 |
|
| 86 |
# =========================
|
| 87 |
# Helpers
|
|
|
|
| 106 |
return "At Risk" if p >= 0.5 else "Not At Risk"
|
| 107 |
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
# =========================
|
| 110 |
+
# Plot: Average of key drivers + Goal Average (Cluster 3 means)
|
| 111 |
# =========================
|
| 112 |
def make_driver_plot(driver_vals: dict):
|
| 113 |
values = [driver_vals[v] for v in ALL_DRIVER_VARS]
|
| 114 |
+
goals = [GOAL_AVG[v] for v in ALL_DRIVER_VARS]
|
| 115 |
|
| 116 |
fig, ax = plt.subplots(figsize=(8.6, 3.1))
|
|
|
|
| 117 |
|
| 118 |
+
# Bars (current values)
|
| 119 |
+
bars = ax.bar(range(len(ALL_DRIVER_LABELS)), values, tick_label=ALL_DRIVER_LABELS)
|
| 120 |
+
|
| 121 |
+
# UPDATED: draw a dashed goal line for each bar at that driver’s Cluster 3 mean
|
| 122 |
+
for i, b in enumerate(bars):
|
| 123 |
+
y = goals[i]
|
| 124 |
+
left = b.get_x()
|
| 125 |
+
right = b.get_x() + b.get_width()
|
| 126 |
+
ax.plot([left, right], [y, y], linestyle="--", linewidth=2)
|
| 127 |
+
|
| 128 |
+
# UPDATED: label it once as "Goal Average"
|
| 129 |
ax.text(
|
| 130 |
+
len(ALL_DRIVER_LABELS) - 0.02,
|
| 131 |
+
goals[-1],
|
| 132 |
+
"Goal Average",
|
| 133 |
va="center",
|
| 134 |
ha="right",
|
| 135 |
)
|
|
|
|
| 261 |
shap_fig,
|
| 262 |
)
|
| 263 |
|
| 264 |
+
|
| 265 |
def apply_recommendation():
|
| 266 |
# Recommendation = move toward high performers
|
| 267 |
# average of Cluster 1 and Cluster 2
|
|
|
|
| 296 |
# =========================
|
| 297 |
CSS = """
|
| 298 |
#app-wrap { max-width: 1200px; margin: 0 auto; }
|
|
|
|
| 299 |
/* Remove extra padding/margins from blocks */
|
| 300 |
.gr-block { padding: 10px 12px !important; }
|
| 301 |
.gr-form { gap: 8px !important; }
|
| 302 |
.gr-row { gap: 10px !important; }
|
|
|
|
| 303 |
/* Make markdown tighter */
|
| 304 |
.compact h2 { margin: 0 0 6px 0; }
|
| 305 |
.compact p { margin: 0 0 8px 0; }
|
|
|
|
| 339 |
shap_plot = gr.Plot(label="Feature Importance (Shap)")
|
| 340 |
|
| 341 |
# =========================
|
| 342 |
+
# WIRE UP EVENTS
|
| 343 |
# =========================
|
| 344 |
slider_inputs = [Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload]
|
| 345 |
|