nicojara commited on
Commit
8fdb1a1
·
verified ·
1 Parent(s): 4b6046e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -48
app.py CHANGED
@@ -1,46 +1,68 @@
1
  import pickle
2
  import pandas as pd
3
  import shap
4
- from shap.plots._force_matplotlib import draw_additive_plot
5
  import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
- # load the model from disk
10
  loaded_model = pickle.load(open("salar_xgb_team.pkl", 'rb'))
11
 
12
- # Setup SHAP
13
- explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
14
-
15
- # Create the main function for server
16
- def main_func(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  sex = 1 if sex == "Female" else 0
18
- new_row = pd.DataFrame.from_dict({'age':age,
19
- 'education-num':education_num,'sex':sex,'capital-gain':capital_gain,
20
- 'capital-loss':capital_loss, 'hours-per-week':hours_per_week},
21
- orient = 'index').transpose()
22
-
 
 
 
 
 
23
  prob = loaded_model.predict_proba(new_row)
24
-
25
  shap_values = explainer(new_row)
26
- # plot = shap.force_plot(shap_values[0], matplotlib=True, figsize=(30,30), show=False)
27
- # plot = shap.plots.waterfall(shap_values[0], max_display=6, show=False)
28
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
29
 
30
  plt.tight_layout()
31
  local_plot = plt.gcf()
32
  plt.close()
33
-
34
  return {
35
  "Chance of Earning > $50K": float(prob[0][1]),
36
  "Chance of Earning ≤ $50K": float(prob[0][0])
37
  }, local_plot
38
 
39
- # Create the UI
40
  title = "**Household Income Predictor** 💰"
41
- description1 = """This app uses your input to predict whether a household earns more or less than $50K per year"""
42
-
43
- description2 = """Adjust the values below and click 'Analyze' to see the prediction and explanation."""
44
 
45
  with gr.Blocks(title=title) as demo:
46
  gr.Markdown(f"## {title}")
@@ -49,72 +71,77 @@ with gr.Blocks(title=title) as demo:
49
  gr.Markdown(description2)
50
  gr.Markdown("---")
51
 
52
- # 🎛 Preset scenario dropdown
53
  scenario = gr.Dropdown(
54
  ["Select a Sample",
55
- "👨‍💻 Young Tech Worker: 28 yrs, college degree, 45 hrs/week",
56
- "👵 Retired Part-Timer: 65 yrs, no college, 20 hrs/week",
57
- "👩‍🏫 Mid-Career Teacher: 42 yrs, 14 education years, 38 hrs/week",
58
- "👨‍🔧 Manual Laborer: 50 yrs, 9 education years, 60 hrs/week"],
59
- label="📋 Choose a Sample Profile (optional — autofills values to explore common cases)"
60
  )
61
 
62
- # 🎯 Inputs
63
  with gr.Row():
64
  age = gr.Number(label="🧓 Age", value=35)
65
- education_num = gr.Number(label="🎓 Education Level (numeric)", value=10)
 
 
 
 
66
  with gr.Row():
67
  sex = gr.Radio(["Male", "Female"], label="🧍 Sex")
68
  capital_gain = gr.Number(label="📈 Capital Gain", value=0)
69
  capital_loss = gr.Number(label="📉 Capital Loss", value=0)
70
  hours_per_week = gr.Number(label="⏱ Hours per Week", value=40)
71
 
72
- submit_btn = gr.Button("🔎 Analyze")
73
-
74
- # 🔁 Handle preset scenario changes
75
  def fill_scenario(scenario_choice):
76
- if scenario_choice == "👨‍💻 Young Tech Worker: 28 yrs, college degree, 45 hrs/week":
77
- return [28, 16, "Male", 0, 0, 45]
78
  elif scenario_choice == "👵 Retired Part-Timer: 65 yrs, no college, 20 hrs/week":
79
- return [65, 8, "Female", 0, 0, 20]
80
- elif scenario_choice == "👩‍🏫 Mid-Career Teacher: 42 yrs, 14 education years, 38 hrs/week":
81
- return [42, 14, "Female", 0, 0, 38]
82
- elif scenario_choice == "👨‍🔧 Manual Laborer: 50 yrs, 9 education years, 60 hrs/week":
83
- return [50, 9, "Male", 0, 0, 60]
84
  else:
85
- return [35, 10, "Male", 0, 0, 40] # Default values
86
 
87
  scenario.change(
88
  fn=fill_scenario,
89
  inputs=[scenario],
90
- outputs=[age, education_num, sex, capital_gain, capital_loss, hours_per_week]
91
  )
92
 
93
- # 🧠 Prediction output
94
  with gr.Column(visible=True) as output_col:
95
  label = gr.Label(label="🧠 Predicted Income")
96
  confidence = gr.Slider(0, 100, value=50, label="📊 Confidence in > $50K", interactive=False)
97
  local_plot = gr.Plot(label="🔍 Top SHAP Features")
98
 
99
- # 🧠 Wrap predict + confidence slider logic
100
- def wrapped_main(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
101
- result, shap_plot = main_func(age, education_num, sex, capital_gain, capital_loss, hours_per_week)
102
  return result, float(result["Chance of Earning > $50K"]) * 100, shap_plot
103
 
 
 
104
  submit_btn.click(
105
  wrapped_main,
106
- [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
107
  [label, confidence, local_plot],
108
  api_name="Salary_Predictor"
109
  )
110
 
 
111
  gr.Markdown("### 🧪 Try Some Examples:")
112
  gr.Examples(
113
  [
114
- [28, 16, "Male", 0, 0, 45],
115
- [60, 8, "Female", 0, 0, 25]
116
  ],
117
- [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
118
  [label, confidence, local_plot],
119
  wrapped_main,
120
  cache_examples=True
@@ -122,4 +149,5 @@ with gr.Blocks(title=title) as demo:
122
 
123
  demo.launch()
124
 
 
125
 
 
1
  import pickle
2
  import pandas as pd
3
  import shap
 
4
  import gradio as gr
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
 
8
+ # Load the model
9
  loaded_model = pickle.load(open("salar_xgb_team.pkl", 'rb'))
10
 
11
+ # SHAP setup
12
+ explainer = shap.Explainer(loaded_model) # DO NOT CHANGE
13
+
14
+ # Education mapping
15
+ education_map = {
16
+ "Less than 1st grade": 1,
17
+ "1st–4th grade": 2,
18
+ "5th–6th grade": 3,
19
+ "7th–8th grade": 4,
20
+ "9th grade": 5,
21
+ "10th grade": 6,
22
+ "11th grade": 7,
23
+ "12th grade (no diploma)": 8,
24
+ "High School Grad": 9,
25
+ "Some College": 10,
26
+ "Associate's Degree (Voc)": 11,
27
+ "Associate's Degree (Acad)": 12,
28
+ "Bachelor's Degree": 13,
29
+ "Master's Degree": 14,
30
+ "Professional School": 15,
31
+ "Doctorate": 16
32
+ }
33
+
34
+ # Main model logic
35
+ def main_func(age, education_level, sex, capital_gain, capital_loss, hours_per_week):
36
+ education_num = education_map[education_level]
37
  sex = 1 if sex == "Female" else 0
38
+
39
+ new_row = pd.DataFrame.from_dict({
40
+ 'age': age,
41
+ 'education-num': education_num,
42
+ 'sex': sex,
43
+ 'capital-gain': capital_gain,
44
+ 'capital-loss': capital_loss,
45
+ 'hours-per-week': hours_per_week
46
+ }, orient='index').transpose()
47
+
48
  prob = loaded_model.predict_proba(new_row)
49
+
50
  shap_values = explainer(new_row)
 
 
51
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
52
 
53
  plt.tight_layout()
54
  local_plot = plt.gcf()
55
  plt.close()
56
+
57
  return {
58
  "Chance of Earning > $50K": float(prob[0][1]),
59
  "Chance of Earning ≤ $50K": float(prob[0][0])
60
  }, local_plot
61
 
62
+ # Gradio UI
63
  title = "**Household Income Predictor** 💰"
64
+ description1 = """This app uses your input to predict whether a household earns more or less than $50K per year."""
65
+ description2 = """Adjust the values below or select a sample profile, then click 'Analyze' to see the prediction and feature impact."""
 
66
 
67
  with gr.Blocks(title=title) as demo:
68
  gr.Markdown(f"## {title}")
 
71
  gr.Markdown(description2)
72
  gr.Markdown("---")
73
 
74
+ # Sample profile dropdown
75
  scenario = gr.Dropdown(
76
  ["Select a Sample",
77
+ "👨‍💻 Young Tech Worker: 28 yrs, Bachelor's, 45 hrs/week",
78
+ "👵 Retired Part-Timer: 65 yrs, no college, 20 hrs/week",
79
+ "👩‍🏫 Mid-Career Teacher: 42 yrs, Master's, 38 hrs/week",
80
+ "👨‍🔧 Manual Laborer: 50 yrs, High School Grad, 60 hrs/week"],
81
+ label="📋 Choose a Sample Profile (optional — autofills values to explore common cases)"
82
  )
83
 
84
+ # Inputs
85
  with gr.Row():
86
  age = gr.Number(label="🧓 Age", value=35)
87
+ education_level = gr.Dropdown(
88
+ list(education_map.keys()),
89
+ label="🎓 Education Level",
90
+ value="Some College"
91
+ )
92
  with gr.Row():
93
  sex = gr.Radio(["Male", "Female"], label="🧍 Sex")
94
  capital_gain = gr.Number(label="📈 Capital Gain", value=0)
95
  capital_loss = gr.Number(label="📉 Capital Loss", value=0)
96
  hours_per_week = gr.Number(label="⏱ Hours per Week", value=40)
97
 
98
+ # Handle preset scenario changes
 
 
99
  def fill_scenario(scenario_choice):
100
+ if scenario_choice == "👨‍💻 Young Tech Worker: 28 yrs, Bachelor's, 45 hrs/week":
101
+ return [28, "Bachelor's Degree", "Male", 0, 0, 45]
102
  elif scenario_choice == "👵 Retired Part-Timer: 65 yrs, no college, 20 hrs/week":
103
+ return [65, "9th grade", "Female", 0, 0, 20]
104
+ elif scenario_choice == "👩‍🏫 Mid-Career Teacher: 42 yrs, Master's, 38 hrs/week":
105
+ return [42, "Master's Degree", "Female", 0, 0, 38]
106
+ elif scenario_choice == "👨‍🔧 Manual Laborer: 50 yrs, High School Grad, 60 hrs/week":
107
+ return [50, "High School Grad", "Male", 0, 0, 60]
108
  else:
109
+ return [35, "Some College", "Male", 0, 0, 40]
110
 
111
  scenario.change(
112
  fn=fill_scenario,
113
  inputs=[scenario],
114
+ outputs=[age, education_level, sex, capital_gain, capital_loss, hours_per_week]
115
  )
116
 
117
+ # Outputs
118
  with gr.Column(visible=True) as output_col:
119
  label = gr.Label(label="🧠 Predicted Income")
120
  confidence = gr.Slider(0, 100, value=50, label="📊 Confidence in > $50K", interactive=False)
121
  local_plot = gr.Plot(label="🔍 Top SHAP Features")
122
 
123
+ # Wrapped function for UI
124
+ def wrapped_main(age, education_level, sex, capital_gain, capital_loss, hours_per_week):
125
+ result, shap_plot = main_func(age, education_level, sex, capital_gain, capital_loss, hours_per_week)
126
  return result, float(result["Chance of Earning > $50K"]) * 100, shap_plot
127
 
128
+ # Button
129
+ submit_btn = gr.Button("🔎 Analyze")
130
  submit_btn.click(
131
  wrapped_main,
132
+ [age, education_level, sex, capital_gain, capital_loss, hours_per_week],
133
  [label, confidence, local_plot],
134
  api_name="Salary_Predictor"
135
  )
136
 
137
+ # Examples
138
  gr.Markdown("### 🧪 Try Some Examples:")
139
  gr.Examples(
140
  [
141
+ [28, "Bachelor's Degree", "Male", 0, 0, 45],
142
+ [60, "9th grade", "Female", 0, 0, 25]
143
  ],
144
+ [age, education_level, sex, capital_gain, capital_loss, hours_per_week],
145
  [label, confidence, local_plot],
146
  wrapped_main,
147
  cache_examples=True
 
149
 
150
  demo.launch()
151
 
152
+
153