mns6rh committed on
Commit
df403ba
·
verified ·
1 Parent(s): 6cff660

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -66
app.py CHANGED
@@ -11,7 +11,7 @@ plt.rcParams["figure.dpi"] = 100
11
  # =========================
12
  # Load model (CatBoostClassifier saved via joblib)
13
  # =========================
14
- model = joblib.load("cat (3).joblib")
15
 
16
  FEATURES = [
17
  "Engagement",
@@ -37,6 +37,16 @@ CLUSTER_1 = {
37
  "Engagement": 4.9324,
38
  }
39
 
 
 
 
 
 
 
 
 
 
 
40
  CLUSTER_3 = {
41
  "Voice": 2.39,
42
  "DecisionAutonomy": 3.55,
@@ -47,13 +57,31 @@ CLUSTER_3 = {
47
  "Engagement": 3.3909,
48
  }
49
 
50
- VISIBLE_DRIVERS = ["Engagement", "SupportiveGM", "WellBeing", "WorkEnvironment"]
51
- VISIBLE_LABELS = ["Engagement", "Supportive GM", "Well-Being", "Work Environment"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # =========================
54
  # Helpers
55
  # =========================
56
- def clamp(x):
57
  return max(1.0, min(5.0, float(x)))
58
 
59
  def build_X(vals: dict) -> pd.DataFrame:
@@ -66,38 +94,30 @@ def prob_at_risk(X: pd.DataFrame) -> float:
66
  idx = classes.index(1) # class 1 = At Risk
67
  return float(probs[idx])
68
 
69
- def risk_label(p):
70
  return "At Risk" if p >= 0.5 else "Not At Risk"
71
 
72
- def stable_threshold():
73
- return min(CLUSTER_1[v] for v in VISIBLE_DRIVERS)
74
-
75
  # =========================
76
- # Plot: drivers vs threshold
77
  # =========================
78
- def make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
79
- th = stable_threshold()
80
- values = [Engagement, SupportiveGM, WellBeing, WorkEnvironment]
81
- colors = ["seagreen" if v >= th else "firebrick" for v in values]
82
-
83
- fig, ax = plt.subplots(figsize=(8.8, 3.4))
84
- ax.bar(VISIBLE_LABELS, values, color=colors)
85
 
86
- ax.axhline(th, linestyle="--", linewidth=2)
87
- ax.text(3.05, th, "Stable threshold", va="center")
88
 
89
  ax.set_ylim(1, 5.4)
90
  ax.set_yticks([1, 2, 3, 4, 5])
91
  ax.set_ylabel("Score (1–5)")
92
- ax.set_title("Key Drivers vs Stable Threshold")
93
 
94
- ax.margins(x=0.12)
95
  plt.tight_layout()
96
- plt.subplots_adjust(bottom=0.22)
97
  return fig
98
 
99
  # =========================
100
- # TRUE SHAP using CatBoost native SHAP values
101
  # =========================
102
  def make_catboost_shap_plot(X: pd.DataFrame):
103
  """
@@ -106,35 +126,33 @@ def make_catboost_shap_plot(X: pd.DataFrame):
106
  returns array shape: (n_rows, n_features + 1)
107
  last column is expected value; first n_features are SHAP contributions.
108
  """
109
- fig, ax = plt.subplots(figsize=(8.8, 3.4))
110
 
111
  try:
112
  from catboost import Pool
113
 
114
  pool = Pool(X) # 1-row
115
  shap_vals = model.get_feature_importance(pool, type="ShapValues")
116
- # shap_vals shape: (1, n_features+1)
117
  contrib = shap_vals[0, :-1] # drop expected value
118
 
119
  s = pd.Series(contrib, index=X.columns)
120
 
121
- # You don't want to talk about management level in the story
122
  s = s.drop(labels=["ManagementLevel"], errors="ignore")
123
 
124
- # top by absolute contribution
125
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
126
 
127
  ax.barh(s.index[::-1], s.values[::-1])
128
- ax.set_title("Top Drivers of This Prediction (True SHAP)")
129
  ax.set_xlabel("Impact on model log-odds (signed)")
130
  plt.tight_layout()
131
  return fig
132
 
133
  except Exception as e:
134
- # If catboost isn't installed or something fails, show the error nicely
135
  ax.text(
136
  0.5, 0.55,
137
- "True SHAP chart unavailable.\nInstall 'catboost' in requirements.txt.",
138
  ha="center", va="center", fontsize=10
139
  )
140
  ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
@@ -145,87 +163,125 @@ def make_catboost_shap_plot(X: pd.DataFrame):
145
  # =========================
146
  # Prediction
147
  # =========================
148
- def predict(Engagement, SupportiveGM, WellBeing, WorkEnvironment):
149
- Engagement = clamp(Engagement)
150
- SupportiveGM = clamp(SupportiveGM)
151
- WellBeing = clamp(WellBeing)
152
- WorkEnvironment = clamp(WorkEnvironment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- # Model needs hidden vars; keep them at stable values to keep the story focused
155
  vals = {
156
- "Engagement": Engagement,
157
- "SupportiveGM": SupportiveGM,
158
- "ManagementLevel": 2, # fixed constant, not shown
159
- "WellBeing": WellBeing,
160
- "Voice": CLUSTER_1["Voice"],
161
- "DecisionAutonomy": CLUSTER_1["DecisionAutonomy"],
162
- "Workload": CLUSTER_1["Workload"],
163
- "WorkEnvironment": WorkEnvironment,
164
  }
165
 
166
  X = build_X(vals)
167
  p = prob_at_risk(X)
168
-
169
  headline = f"Predicted Status: {risk_label(p)}"
170
- driver_fig = make_driver_plot(Engagement, SupportiveGM, WellBeing, WorkEnvironment)
 
171
  shap_fig = make_catboost_shap_plot(X)
172
 
173
- return headline, driver_fig, shap_fig
174
 
175
- def apply_recommendation():
176
- e = CLUSTER_1["Engagement"]
177
- s = CLUSTER_1["SupportiveGM"]
178
- w = CLUSTER_1["WellBeing"]
179
- env = CLUSTER_1["WorkEnvironment"]
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- headline, driver_fig, shap_fig = predict(e, s, w, env)
182
- return e, s, w, env, headline, driver_fig, shap_fig
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # =========================
185
  # UI Layout (no scrolling)
186
  # =========================
187
  CSS = """
188
  #app-wrap { max-width: 1200px; margin: 0 auto; }
189
- .compact .gr-markdown { margin-bottom: 0.4rem !important; }
190
  """
191
 
192
  with gr.Blocks(css=CSS) as demo:
193
  gr.Markdown(
194
  "<div id='app-wrap' class='compact'>"
195
- "<h2>Retention Recommendation Simulator</h2>"
196
- "<p style='margin-top:0;'>Adjust the 4 drivers and click <b>Predict</b>. "
197
- "Click <b>Apply Recommendation Plan</b> to jump to the stable target.</p>"
198
  "</div>"
199
  )
200
 
201
  with gr.Row():
202
  # LEFT: sliders + buttons
203
- with gr.Column(scale=5, min_width=420):
 
204
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
205
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
206
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
207
  WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
 
 
 
208
 
209
  with gr.Row():
210
  btn_predict = gr.Button("Predict")
211
- btn_recommend = gr.Button("Apply Recommendation Plan")
212
 
213
  # RIGHT: headline + two plots stacked
214
  with gr.Column(scale=7, min_width=520):
215
  headline = gr.Textbox(label="Result", value="", interactive=False)
216
- driver_plot = gr.Plot(label="Key Drivers vs Stable Threshold")
217
- shap_plot = gr.Plot(label="True SHAP (CatBoost)")
218
 
219
  btn_predict.click(
220
  fn=predict,
221
- inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment],
222
- outputs=[headline, driver_plot, shap_plot],
223
  )
224
 
225
- btn_recommend.click(
226
- fn=apply_recommendation,
227
  inputs=[],
228
- outputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, headline, driver_plot, shap_plot],
 
 
 
229
  )
230
 
231
  demo.launch()
 
11
  # =========================
12
  # Load model (CatBoostClassifier saved via joblib)
13
  # =========================
14
+ model = joblib.load("cat (1).joblib")
15
 
16
  FEATURES = [
17
  "Engagement",
 
37
  "Engagement": 4.9324,
38
  }
39
 
40
+ CLUSTER_2 = {
41
+ "Voice": 3.94,
42
+ "DecisionAutonomy": 4.24,
43
+ "Workload": 3.76,
44
+ "WellBeing": 4.0251,
45
+ "WorkEnvironment": 4.1484,
46
+ "SupportiveGM": 4.1275,
47
+ "Engagement": 4.2828,
48
+ }
49
+
50
  CLUSTER_3 = {
51
  "Voice": 2.39,
52
  "DecisionAutonomy": 3.55,
 
57
  "Engagement": 3.3909,
58
  }
59
 
60
+ # Per request, all survey variables are treated as key drivers.
61
+ ALL_DRIVER_VARS = [
62
+ "Engagement",
63
+ "SupportiveGM",
64
+ "WellBeing",
65
+ "WorkEnvironment",
66
+ "Voice",
67
+ "DecisionAutonomy",
68
+ "Workload",
69
+ ]
70
+
71
+ ALL_DRIVER_LABELS = [
72
+ "Engagement",
73
+ "Supportive GM",
74
+ "Well-Being",
75
+ "Work Environment",
76
+ "Voice",
77
+ "Decision Autonomy",
78
+ "Workload",
79
+ ]
80
 
81
  # =========================
82
  # Helpers
83
  # =========================
84
+ def clamp_1_5(x):
85
  return max(1.0, min(5.0, float(x)))
86
 
87
  def build_X(vals: dict) -> pd.DataFrame:
 
94
  idx = classes.index(1) # class 1 = At Risk
95
  return float(probs[idx])
96
 
97
+ def risk_label(p: float) -> str:
98
  return "At Risk" if p >= 0.5 else "Not At Risk"
99
 
 
 
 
100
  # =========================
101
+ # Plot: "Average of key drivers" (shows ALL driver vars)
102
  # =========================
103
+ def make_driver_plot(driver_vals: dict):
104
+ values = [driver_vals[v] for v in ALL_DRIVER_VARS]
 
 
 
 
 
105
 
106
+ fig, ax = plt.subplots(figsize=(8.8, 3.2))
107
+ ax.bar(ALL_DRIVER_LABELS, values)
108
 
109
  ax.set_ylim(1, 5.4)
110
  ax.set_yticks([1, 2, 3, 4, 5])
111
  ax.set_ylabel("Score (1–5)")
112
+ ax.set_title("Average of key drivers")
113
 
114
+ ax.margins(x=0.08)
115
  plt.tight_layout()
116
+ plt.subplots_adjust(bottom=0.28)
117
  return fig
118
 
119
  # =========================
120
+ # Plot: TRUE SHAP using CatBoost native SHAP values
121
  # =========================
122
  def make_catboost_shap_plot(X: pd.DataFrame):
123
  """
 
126
  returns array shape: (n_rows, n_features + 1)
127
  last column is expected value; first n_features are SHAP contributions.
128
  """
129
+ fig, ax = plt.subplots(figsize=(8.8, 3.2))
130
 
131
  try:
132
  from catboost import Pool
133
 
134
  pool = Pool(X) # 1-row
135
  shap_vals = model.get_feature_importance(pool, type="ShapValues")
 
136
  contrib = shap_vals[0, :-1] # drop expected value
137
 
138
  s = pd.Series(contrib, index=X.columns)
139
 
140
+ # Keep SHAP focused on survey drivers (exclude ManagementLevel)
141
  s = s.drop(labels=["ManagementLevel"], errors="ignore")
142
 
143
+ # Top 8 by absolute contribution
144
  s = s.reindex(s.abs().sort_values(ascending=False).index).head(8)
145
 
146
  ax.barh(s.index[::-1], s.values[::-1])
147
+ ax.set_title("Feature Importance (Shap)")
148
  ax.set_xlabel("Impact on model log-odds (signed)")
149
  plt.tight_layout()
150
  return fig
151
 
152
  except Exception as e:
 
153
  ax.text(
154
  0.5, 0.55,
155
+ "SHAP chart unavailable.\nInstall 'catboost' in requirements.txt.",
156
  ha="center", va="center", fontsize=10
157
  )
158
  ax.text(0.5, 0.40, f"Error: {str(e)[:150]}", ha="center", va="center", fontsize=9)
 
163
  # =========================
164
  # Prediction
165
  # =========================
166
+ def predict(
167
+ Engagement,
168
+ SupportiveGM,
169
+ WellBeing,
170
+ WorkEnvironment,
171
+ Voice,
172
+ DecisionAutonomy,
173
+ Workload,
174
+ ):
175
+ # Clamp sliders
176
+ driver_vals = {
177
+ "Engagement": clamp_1_5(Engagement),
178
+ "SupportiveGM": clamp_1_5(SupportiveGM),
179
+ "WellBeing": clamp_1_5(WellBeing),
180
+ "WorkEnvironment": clamp_1_5(WorkEnvironment),
181
+ "Voice": clamp_1_5(Voice),
182
+ "DecisionAutonomy": clamp_1_5(DecisionAutonomy),
183
+ "Workload": clamp_1_5(Workload),
184
+ }
185
 
186
+ # Build model row (ManagementLevel fixed internally)
187
  vals = {
188
+ **driver_vals,
189
+ "ManagementLevel": 2,
 
 
 
 
 
 
190
  }
191
 
192
  X = build_X(vals)
193
  p = prob_at_risk(X)
 
194
  headline = f"Predicted Status: {risk_label(p)}"
195
+
196
+ drivers_fig = make_driver_plot(driver_vals)
197
  shap_fig = make_catboost_shap_plot(X)
198
 
199
+ return headline, drivers_fig, shap_fig
200
 
201
+ # =========================
202
+ # Button: "At risk group" loads the average of Cluster 1 and Cluster 2 (per request)
203
+ # =========================
204
+ def at_risk_group():
205
+ avg = {}
206
+ for v in ALL_DRIVER_VARS:
207
+ avg[v] = (CLUSTER_1[v] + CLUSTER_2[v]) / 2.0
208
+
209
+ headline, drivers_fig, shap_fig = predict(
210
+ avg["Engagement"],
211
+ avg["SupportiveGM"],
212
+ avg["WellBeing"],
213
+ avg["WorkEnvironment"],
214
+ avg["Voice"],
215
+ avg["DecisionAutonomy"],
216
+ avg["Workload"],
217
+ )
218
 
219
+ # Return slider updates + outputs
220
+ return (
221
+ avg["Engagement"],
222
+ avg["SupportiveGM"],
223
+ avg["WellBeing"],
224
+ avg["WorkEnvironment"],
225
+ avg["Voice"],
226
+ avg["DecisionAutonomy"],
227
+ avg["Workload"],
228
+ headline,
229
+ drivers_fig,
230
+ shap_fig,
231
+ )
232
 
233
  # =========================
234
  # UI Layout (no scrolling)
235
  # =========================
236
  CSS = """
237
  #app-wrap { max-width: 1200px; margin: 0 auto; }
238
+ .compact .gr-markdown { margin-bottom: 0.35rem !important; }
239
  """
240
 
241
  with gr.Blocks(css=CSS) as demo:
242
  gr.Markdown(
243
  "<div id='app-wrap' class='compact'>"
244
+ "<h2>Retention Simulator</h2>"
245
+ "<p style='margin-top:0;'>Adjust all drivers and click <b>Predict</b>. "
246
+ "Click <b>At risk group</b> to load the average of Cluster 1 and Cluster 2.</p>"
247
  "</div>"
248
  )
249
 
250
  with gr.Row():
251
  # LEFT: sliders + buttons
252
+ with gr.Column(scale=5, min_width=430):
253
+ # Default starting point: Cluster 3 (most at-risk)
254
  Engagement = gr.Slider(1, 5, value=CLUSTER_3["Engagement"], step=0.01, label="Engagement")
255
  SupportiveGM = gr.Slider(1, 5, value=CLUSTER_3["SupportiveGM"], step=0.01, label="Supportive GM")
256
  WellBeing = gr.Slider(1, 5, value=CLUSTER_3["WellBeing"], step=0.01, label="Well-Being")
257
  WorkEnvironment = gr.Slider(1, 5, value=CLUSTER_3["WorkEnvironment"], step=0.01, label="Work Environment")
258
+ Voice = gr.Slider(1, 5, value=CLUSTER_3["Voice"], step=0.01, label="Voice")
259
+ DecisionAutonomy = gr.Slider(1, 5, value=CLUSTER_3["DecisionAutonomy"], step=0.01, label="Decision Autonomy")
260
+ Workload = gr.Slider(1, 5, value=CLUSTER_3["Workload"], step=0.01, label="Workload")
261
 
262
  with gr.Row():
263
  btn_predict = gr.Button("Predict")
264
+ btn_atrisk = gr.Button("At risk group")
265
 
266
  # RIGHT: headline + two plots stacked
267
  with gr.Column(scale=7, min_width=520):
268
  headline = gr.Textbox(label="Result", value="", interactive=False)
269
+ drivers_plot = gr.Plot(label="Average of key drivers")
270
+ shap_plot = gr.Plot(label="Feature Importance (Shap)")
271
 
272
  btn_predict.click(
273
  fn=predict,
274
+ inputs=[Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload],
275
+ outputs=[headline, drivers_plot, shap_plot],
276
  )
277
 
278
+ btn_atrisk.click(
279
+ fn=at_risk_group,
280
  inputs=[],
281
+ outputs=[
282
+ Engagement, SupportiveGM, WellBeing, WorkEnvironment, Voice, DecisionAutonomy, Workload,
283
+ headline, drivers_plot, shap_plot
284
+ ],
285
  )
286
 
287
  demo.launch()