atsuari committed on
Commit
29315f2
·
verified ·
1 Parent(s): 280fc94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -24
app.py CHANGED
@@ -1,106 +1,243 @@
1
  import pickle
 
2
  import pandas as pd
 
3
  import shap
 
4
  import gradio as gr
 
5
  import numpy as np
 
6
  import matplotlib.pyplot as plt
7
 
8
- # Load the model
 
9
  loaded_model = pickle.load(open("salar_xgb_team.pkl", "rb"))
10
 
11
- # Setup SHAP
12
- explainer = shap.Explainer(loaded_model) # DO NOT CHANGE THIS
13
-
14
- # Main function
15
- def main_func(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
16
- # Input validation
17
- if age < 18 or age > 100 or education_num < 1 or hours_per_week < 1 or hours_per_week > 100:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  return {"≤50K": 0.0, ">50K": 0.0}, None, "❌ Invalid inputs. Please check your entries."
19
 
20
- # Process categorical
 
 
 
21
  sex_binary = 0 if sex == "Male" else 1
22
 
23
- # Create input row
 
24
  new_row = pd.DataFrame({
 
25
  'age': [age],
 
26
  'education-num': [education_num],
 
27
  'sex': [sex_binary],
 
28
  'capital-gain': [capital_gain],
 
29
  'capital-loss': [capital_loss],
 
30
  'hours-per-week': [hours_per_week]
 
31
  })
32
 
33
- # Predict
 
34
  prob = loaded_model.predict_proba(new_row)
 
35
  shap_values = explainer(new_row)
36
 
37
  # SHAP plot
 
38
  plt.figure(figsize=(8, 4))
 
39
  shap.plots.bar(shap_values[0], max_display=6, show=False)
 
40
  plt.tight_layout()
 
41
  local_plot = plt.gcf()
 
42
  plt.close()
43
 
44
- # Predicted class and confidence
 
45
  pred_class = ">50K" if prob[0][1] > 0.5 else "≤50K"
 
46
  confidence = round(prob[0][1] if pred_class == ">50K" else prob[0][0], 2)
47
-
48
  interpretation = f"💼 Prediction: **{pred_class}**\nConfidence: {confidence * 100:.2f}%"
49
 
50
  return {
 
51
  "≤50K": round(prob[0][0], 2),
 
52
  ">50K": round(prob[0][1], 2)
 
53
  }, local_plot, interpretation
54
 
55
- # ----------- Gradio UI -----------
 
56
  title = "**Salary Predictor & SHAP Explainer** 💰"
 
57
  description1 = "This app uses demographic and financial info to predict whether someone earns more than $50K annually."
58
- description2 = "Adjust the sliders and inputs below, then click **Analyze** to see the prediction and SHAP explanation."
 
59
 
60
  with gr.Blocks(title=title) as demo:
 
61
  gr.Markdown(f"## {title}")
 
62
  gr.Markdown(description1)
 
63
  gr.Markdown("---")
 
64
  gr.Markdown(description2)
 
65
  gr.Markdown("---")
66
 
67
  with gr.Row():
 
68
  with gr.Column(scale=1):
69
- age = gr.Slider(label="Age (Years)", minimum=18, maximum=100, value=35, info="Enter age between 18 and 100")
70
- education_num = gr.Slider(label="Education Level (Numerical)", minimum=1, maximum=16, value=10, info="E.g., 1 = Preschool, 16 = Doctorate")
 
 
 
 
 
 
 
 
 
 
 
71
  sex = gr.Radio(["Male", "Female"], label="Sex")
 
72
  capital_gain = gr.Number(label="Capital Gain", value=0)
 
73
  capital_loss = gr.Number(label="Capital Loss", value=0)
 
74
  hours_per_week = gr.Slider(label="Hours Worked per Week", minimum=1, maximum=100, value=40)
75
 
76
  submit_btn = gr.Button("🔍 Analyze")
77
 
78
  with gr.Column(scale=1):
 
79
  label = gr.Label(label="Predicted Probabilities")
 
80
  local_plot = gr.Plot(label="SHAP Feature Importance")
 
81
  result_text = gr.Textbox(label="Prediction Summary", lines=2)
82
 
83
  submit_btn.click(
 
84
  main_func,
85
- [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
 
 
86
  [label, local_plot, result_text],
 
87
  api_name="Salary_Predictor"
 
88
  )
89
 
90
  gr.Markdown("### Try one of the following examples:")
 
91
  gr.Examples(
 
92
  examples=[
93
- [28, 12, "Male", 0, 0, 45],
94
- [52, 14, "Female", 7688, 0, 60],
95
- [35, 9, "Male", 0, 1902, 40]
 
 
 
 
96
  ],
97
- inputs=[age, education_num, sex, capital_gain, capital_loss, hours_per_week],
 
 
98
  outputs=[label, local_plot, result_text],
 
99
  fn=main_func,
 
100
  cache_examples=True
 
101
  )
102
 
103
  gr.Markdown("---")
104
- gr.Markdown("Built with ❤️ by Tania Ramesh for the 2025 AI Applications Project.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- demo.launch()
 
 
 
1
  import pickle
2
+
3
  import pandas as pd
4
+
5
  import shap
6
+
7
  import gradio as gr
8
+
9
  import numpy as np
10
+
11
  import matplotlib.pyplot as plt
12
 
13
# Load model
# NOTE(review): unpickling executes arbitrary code embedded in the file, so
# "salar_xgb_team.pkl" must only ever come from a trusted source.
# The context manager closes the file handle as soon as the model is loaded.
with open("salar_xgb_team.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# SHAP explainer (Do not change)
explainer = shap.Explainer(loaded_model)
20
+
21
# Mapping from dropdown labels to numeric education levels
education_map = {
    "Preschool": 1,
    "1st-4th": 2,
    "5th-6th": 3,
    "7th-8th": 4,
    "9th": 5,
    "10th": 6,
    "11th": 7,
    "12th": 8,
    "HS-grad": 9,
    "Some-college": 10,
    "Assoc-voc": 11,
    "Assoc-acdm": 12,
    "Bachelors": 13,
    "Masters": 14,
    "Prof-school": 15,
    "Doctorate": 16,
}


def _in_range(value, low, high=float("inf")):
    """Return True when *value* is set and satisfies low <= value <= high."""
    return value is not None and low <= value <= high


# Main prediction function
def main_func(age, education_label, sex, capital_gain, capital_loss, hours_per_week):
    """Predict the income class for one person and explain it with SHAP.

    Args:
        age: Age in years; accepted range is 18 to 100 inclusive.
        education_label: A key of ``education_map``; unknown labels fall back
            to 9 (HS-grad).
        sex: ``"Male"`` or ``"Female"``.
        capital_gain: Non-negative capital gain.
        capital_loss: Non-negative capital loss.
        hours_per_week: Weekly working hours; accepted range is 1 to 100.

    Returns:
        A 3-tuple ``(probabilities, figure, summary)``: a dict with the
        ``"≤50K"`` and ``">50K"`` probabilities rounded to two decimals, the
        matplotlib figure holding the SHAP bar plot, and a short text summary.
        For invalid input the probabilities are both 0.0, the figure is
        ``None`` and the summary is an error message.
    """
    # Reject missing values up front: a cleared numeric field may arrive as
    # None (TODO confirm against the Gradio version in use), which would
    # otherwise raise TypeError in the range comparisons below.
    inputs_ok = (
        _in_range(age, 18, 100)
        and _in_range(hours_per_week, 1, 100)
        and _in_range(capital_gain, 0)
        and _in_range(capital_loss, 0)
        # An unset radio button must not be silently encoded as "Female".
        and sex in ("Male", "Female")
    )
    if not inputs_ok:
        return {"≤50K": 0.0, ">50K": 0.0}, None, "❌ Invalid inputs. Please check your entries."

    # Convert to model format
    education_num = education_map.get(education_label, 9)  # default to HS-grad
    sex_binary = 0 if sex == "Male" else 1

    # Build a single-row dataframe with the column names the model is fed
    new_row = pd.DataFrame({
        'age': [age],
        'education-num': [education_num],
        'sex': [sex_binary],
        'capital-gain': [capital_gain],
        'capital-loss': [capital_loss],
        'hours-per-week': [hours_per_week],
    })

    # Predict and explain
    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # SHAP plot: draw onto a fresh figure, grab it, then close it so pyplot
    # does not keep accumulating open figures across requests.
    plt.figure(figsize=(8, 4))
    shap.plots.bar(shap_values[0], max_display=6, show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    plt.close(local_plot)

    # Class and confidence. Cast to built-in floats so the values returned to
    # the UI are plain Python numbers rather than numpy scalars.
    prob_low = float(prob[0][0])
    prob_high = float(prob[0][1])
    pred_class = ">50K" if prob_high > 0.5 else "≤50K"
    confidence = prob_high if pred_class == ">50K" else prob_low
    # Format the unrounded value: rounding to two decimals first would make
    # the ".2f" percentage always end in ".00".
    interpretation = f"💼 Prediction: **{pred_class}**\nConfidence: {confidence * 100:.2f}%"

    return {
        "≤50K": round(prob_low, 2),
        ">50K": round(prob_high, 2),
    }, local_plot, interpretation
126
 
127
# UI layout
# `title` carries Markdown emphasis for the on-page heading. The browser tab
# title is plain text, so it gets its own undecorated string; otherwise the
# asterisks would be displayed literally.
title = "**Salary Predictor & SHAP Explainer** 💰"
page_title = "Salary Predictor & SHAP Explainer"
description1 = "This app uses demographic and financial info to predict whether someone earns more than $50K annually."
description2 = "Adjust the inputs and click **Analyze** to see prediction and SHAP feature contributions."

with gr.Blocks(title=page_title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    gr.Markdown(description2)
    gr.Markdown("---")

    with gr.Row():
        # Left column: model inputs
        with gr.Column(scale=1):
            age = gr.Number(label="Age", value=35, precision=0)
            education_label = gr.Dropdown(
                choices=list(education_map.keys()),
                label="Education Level",
                value="HS-grad"
            )
            sex = gr.Radio(["Male", "Female"], label="Sex")
            capital_gain = gr.Number(label="Capital Gain", value=0)
            capital_loss = gr.Number(label="Capital Loss", value=0)
            hours_per_week = gr.Slider(label="Hours Worked per Week", minimum=1, maximum=100, value=40)

            submit_btn = gr.Button("🔍 Analyze")

        # Right column: prediction outputs
        with gr.Column(scale=1):
            label = gr.Label(label="Predicted Probabilities")
            local_plot = gr.Plot(label="SHAP Feature Importance")
            result_text = gr.Textbox(label="Prediction Summary", lines=2)

    # The order of both lists must match main_func's parameters and its
    # returned 3-tuple; they are shared by the button and the examples.
    input_components = [age, education_label, sex, capital_gain, capital_loss, hours_per_week]
    output_components = [label, local_plot, result_text]

    submit_btn.click(
        main_func,
        input_components,
        output_components,
        api_name="Salary_Predictor"
    )

    gr.Markdown("### Try one of the following examples:")
    gr.Examples(
        examples=[
            [28, "Some-college", "Male", 0, 0, 45],
            [52, "Masters", "Female", 7688, 0, 60],
            [35, "HS-grad", "Male", 0, 1902, 40]
        ],
        inputs=input_components,
        outputs=output_components,
        fn=main_func,
        cache_examples=True
    )

    gr.Markdown("---")
    gr.Markdown("Built with ❤️ by Group 3 for the 2025 AI Applications Project.")

    gr.Markdown("---")
    gr.Markdown("📊 Thanks for using the Salary Predictor!")
    # NOTE(review): the image is referenced by a remote URL, so it depends on
    # that host being reachable — confirm this is acceptable for deployment.
    gr.Image(
        value="https://media.giphy.com/media/l0MYt5jPR6QX5pnqM/giphy.gif",
        label="",
        show_label=False,
        show_download_button=False,
        height=200
    )

demo.launch()
242
+
243
+