mns6rh commited on
Commit
f752e4a
·
verified ·
1 Parent(s): ee7d29b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -28
app.py CHANGED
@@ -1,6 +1,5 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
-
4
  import joblib
5
  import pandas as pd
6
  import gradio as gr
@@ -11,7 +10,6 @@ import matplotlib.pyplot as plt
11
  # =========================
12
  model = joblib.load("cat (1).joblib")
13
 
14
- # MUST match model.feature_names_ EXACTLY
15
  FEATURES = [
16
  "Engagement",
17
  "SupportiveGM",
@@ -63,7 +61,6 @@ CLUSTER_PROFILES = {
63
  },
64
  }
65
 
66
- # Defaults for hidden vars when user adjusts sliders manually (TOTAL means)
67
  HIDDEN_DEFAULTS = {"Voice": 4.23, "DecisionAutonomy": 4.50, "Workload": 4.13}
68
  VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
69
 
@@ -77,18 +74,13 @@ def clamp_1_5(x: float) -> float:
77
 
78
 
79
  def build_X(vals: dict) -> pd.DataFrame:
80
- """Build 1-row DataFrame in exact model feature order."""
81
  return pd.DataFrame([[vals[f] for f in FEATURES]], columns=FEATURES)
82
 
83
 
84
  def prob_at_risk(X: pd.DataFrame) -> float:
85
- """
86
- Return P(At Risk) using classes_ lookup.
87
- Your encoding: At Risk / low intent is class 1.
88
- """
89
  probs = model.predict_proba(X)[0]
90
  classes = list(model.classes_)
91
- idx = classes.index(1)
92
  return float(probs[idx])
93
 
94
 
@@ -97,26 +89,22 @@ def risk_status(p: float, cutoff: float = 0.5) -> str:
97
 
98
 
99
  def stable_threshold_from_cluster1() -> float:
100
- """Threshold = min of the 4 visible drivers in Cluster 1."""
101
  c1 = CLUSTER_PROFILES["Cluster 1"]
102
  return min(c1[d] for d in VISIBLE_DRIVERS)
103
 
104
 
105
  def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
106
- """
107
- Plot only Mid-level managers' 4 drivers, with a threshold line based on Cluster 1 minimum.
108
- Bars turn red if below threshold, green if at/above threshold.
109
- """
110
  threshold = stable_threshold_from_cluster1()
111
  drivers = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
112
  values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
113
  colors = ["seagreen" if v >= threshold else "firebrick" for v in values]
114
 
115
- # Deterministic layout (reduces jitter)
116
  fig, ax = plt.subplots(figsize=(11, 5), constrained_layout=True)
117
- ax.bar(drivers, values, color=colors)
118
 
 
119
  ax.axhline(threshold, linestyle="--", linewidth=2)
 
120
  ax.text(
121
  0.99,
122
  threshold,
@@ -126,19 +114,19 @@ def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironme
126
  va="center",
127
  )
128
 
129
- # Extra space above 5
130
  ax.set_ylim(1, 5.6)
131
  ax.set_ylabel("Survey Score (1–5)")
132
  ax.set_title("Key Drivers (Mid-level managers) — Probability of Being At Risk of Leaving")
133
 
134
- # Lock x-limits to avoid autoscale twitching
135
  ax.set_xlim(-0.6, len(drivers) - 0.4)
136
 
137
  return fig
138
 
139
 
140
  # =========================
141
- # Dashboard prediction
142
  # =========================
143
  def predict_dashboard(
144
  Engagement,
@@ -149,13 +137,11 @@ def predict_dashboard(
149
  DecisionAutonomy=None,
150
  Workload=None,
151
  ):
152
- # visible
153
  Engagement = clamp_1_5(Engagement)
154
  SupportiveGM = clamp_1_5(SupportiveGM)
155
  WellBeing = clamp_1_5(WellBeing)
156
  WorkEnvironment = clamp_1_5(WorkEnvironment)
157
 
158
- # hidden
159
  Voice = clamp_1_5(HIDDEN_DEFAULTS["Voice"] if Voice is None else Voice)
160
  DecisionAutonomy = clamp_1_5(HIDDEN_DEFAULTS["DecisionAutonomy"] if DecisionAutonomy is None else DecisionAutonomy)
161
  Workload = clamp_1_5(HIDDEN_DEFAULTS["Workload"] if Workload is None else Workload)
@@ -182,7 +168,6 @@ def predict_dashboard(
182
  )
183
 
184
  df_table = pd.DataFrame(rows)
185
-
186
  mid_status = df_table.loc[df_table["Management Level"] == MGMT_LABELS[2], "Risk Status"].iloc[0]
187
  headline = f"Mid-level managers: **{mid_status}**"
188
 
@@ -211,17 +196,16 @@ def apply_cluster(cluster_name: str):
211
  DecisionAutonomy=DecisionAutonomy,
212
  Workload=Workload,
213
  )
214
-
215
  return Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, table, fig
216
 
217
 
218
  # =========================
219
- # Gradio UI (CSS locks layout without Dataframe height=)
220
  # =========================
221
  CSS = """
222
  #headline_box { min-height: 44px; }
223
 
224
- /* fixed-height container for the dataframe to prevent page reflow */
225
  #table_box {
226
  height: 240px;
227
  overflow: auto;
@@ -229,6 +213,23 @@ CSS = """
229
  border-radius: 10px;
230
  padding: 6px;
231
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  """
233
 
234
  with gr.Blocks() as demo:
@@ -238,7 +239,7 @@ with gr.Blocks() as demo:
238
  "The table shows **At Risk / Not At Risk** by management level."
239
  )
240
 
241
- with gr.Row(equal_height=True):
242
  with gr.Column(scale=1, min_width=420):
243
  with gr.Row():
244
  btn_c1 = gr.Button("Load Cluster 1")
@@ -255,12 +256,13 @@ with gr.Blocks() as demo:
255
  with gr.Column(scale=1, min_width=520):
256
  headline = gr.Markdown(elem_id="headline_box")
257
 
258
- # Wrap dataframe in a fixed-height container using HTML
259
  gr.HTML('<div id="table_box">')
260
  table = gr.Dataframe(label="Risk Status by Management Level")
261
  gr.HTML("</div>")
262
 
263
- plot = gr.Plot(height=420)
 
 
264
 
265
  btn_predict.click(
266
  fn=predict_dashboard,
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
 
3
  import joblib
4
  import pandas as pd
5
  import gradio as gr
 
10
  # =========================
11
  model = joblib.load("cat (1).joblib")
12
 
 
13
  FEATURES = [
14
  "Engagement",
15
  "SupportiveGM",
 
61
  },
62
  }
63
 
 
64
  HIDDEN_DEFAULTS = {"Voice": 4.23, "DecisionAutonomy": 4.50, "Workload": 4.13}
65
  VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
66
 
 
74
 
75
 
76
  def build_X(vals: dict) -> pd.DataFrame:
 
77
  return pd.DataFrame([[vals[f] for f in FEATURES]], columns=FEATURES)
78
 
79
 
80
  def prob_at_risk(X: pd.DataFrame) -> float:
 
 
 
 
81
  probs = model.predict_proba(X)[0]
82
  classes = list(model.classes_)
83
+ idx = classes.index(1) # At Risk = 1
84
  return float(probs[idx])
85
 
86
 
 
89
 
90
 
91
  def stable_threshold_from_cluster1() -> float:
 
92
  c1 = CLUSTER_PROFILES["Cluster 1"]
93
  return min(c1[d] for d in VISIBLE_DRIVERS)
94
 
95
 
96
  def make_driver_plot_midlevel(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
 
 
 
 
97
  threshold = stable_threshold_from_cluster1()
98
  drivers = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
99
  values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
100
  colors = ["seagreen" if v >= threshold else "firebrick" for v in values]
101
 
102
+ # Deterministic layout helps reduce jitter
103
  fig, ax = plt.subplots(figsize=(11, 5), constrained_layout=True)
 
104
 
105
+ ax.bar(drivers, values, color=colors)
106
  ax.axhline(threshold, linestyle="--", linewidth=2)
107
+
108
  ax.text(
109
  0.99,
110
  threshold,
 
114
  va="center",
115
  )
116
 
117
+ # More space above 5
118
  ax.set_ylim(1, 5.6)
119
  ax.set_ylabel("Survey Score (1–5)")
120
  ax.set_title("Key Drivers (Mid-level managers) — Probability of Being At Risk of Leaving")
121
 
122
+ # Lock x-limits so bar geometry never changes
123
  ax.set_xlim(-0.6, len(drivers) - 0.4)
124
 
125
  return fig
126
 
127
 
128
  # =========================
129
+ # Prediction
130
  # =========================
131
  def predict_dashboard(
132
  Engagement,
 
137
  DecisionAutonomy=None,
138
  Workload=None,
139
  ):
 
140
  Engagement = clamp_1_5(Engagement)
141
  SupportiveGM = clamp_1_5(SupportiveGM)
142
  WellBeing = clamp_1_5(WellBeing)
143
  WorkEnvironment = clamp_1_5(WorkEnvironment)
144
 
 
145
  Voice = clamp_1_5(HIDDEN_DEFAULTS["Voice"] if Voice is None else Voice)
146
  DecisionAutonomy = clamp_1_5(HIDDEN_DEFAULTS["DecisionAutonomy"] if DecisionAutonomy is None else DecisionAutonomy)
147
  Workload = clamp_1_5(HIDDEN_DEFAULTS["Workload"] if Workload is None else Workload)
 
168
  )
169
 
170
  df_table = pd.DataFrame(rows)
 
171
  mid_status = df_table.loc[df_table["Management Level"] == MGMT_LABELS[2], "Risk Status"].iloc[0]
172
  headline = f"Mid-level managers: **{mid_status}**"
173
 
 
196
  DecisionAutonomy=DecisionAutonomy,
197
  Workload=Workload,
198
  )
 
199
  return Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, table, fig
200
 
201
 
202
  # =========================
203
+ # UI sizing via CSS wrappers (works on older Gradio too)
204
  # =========================
205
  CSS = """
206
  #headline_box { min-height: 44px; }
207
 
208
+ /* Fixed-height containers to prevent page jumping */
209
  #table_box {
210
  height: 240px;
211
  overflow: auto;
 
213
  border-radius: 10px;
214
  padding: 6px;
215
  }
216
+
217
+ #plot_box {
218
+ height: 460px;
219
+ overflow: hidden;
220
+ border: 1px solid rgba(0,0,0,0.08);
221
+ border-radius: 10px;
222
+ padding: 6px;
223
+ }
224
+
225
+ /* Make the plot fill the fixed container */
226
+ #plot_box .plot-container,
227
+ #plot_box .gradio-plot,
228
+ #plot_box canvas,
229
+ #plot_box svg {
230
+ height: 100% !important;
231
+ width: 100% !important;
232
+ }
233
  """
234
 
235
  with gr.Blocks() as demo:
 
239
  "The table shows **At Risk / Not At Risk** by management level."
240
  )
241
 
242
+ with gr.Row():
243
  with gr.Column(scale=1, min_width=420):
244
  with gr.Row():
245
  btn_c1 = gr.Button("Load Cluster 1")
 
256
  with gr.Column(scale=1, min_width=520):
257
  headline = gr.Markdown(elem_id="headline_box")
258
 
 
259
  gr.HTML('<div id="table_box">')
260
  table = gr.Dataframe(label="Risk Status by Management Level")
261
  gr.HTML("</div>")
262
 
263
+ gr.HTML('<div id="plot_box">')
264
+ plot = gr.Plot()
265
+ gr.HTML("</div>")
266
 
267
  btn_predict.click(
268
  fn=predict_dashboard,