nicojara commited on
Commit
dddb21b
·
verified ·
1 Parent(s): 1dbb401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -44
app.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import pickle
2
  import pandas as pd
3
  import shap
@@ -6,79 +12,76 @@ import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
- # load the model from disk
10
  loaded_model = pickle.load(open("salar_xgb_team.pkl", 'rb'))
11
 
12
- # Setup SHAP
13
- explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
14
 
15
- # Create the main function for server
16
- def main_func(age, education_num, sex, capital_gain, capital_loss, hours_per_week, salary_class):
17
  sex = 1 if sex == "Female" else 0
18
- new_row = pd.DataFrame.from_dict({'age':age,
19
- 'education-num':education_num,'sex':sex,'capital-gain':capital_gain,
20
- 'capital-loss':capital_loss, 'hours-per-week':hours_per_week,'salary-class':salary_class},
21
- orient = 'index').transpose()
22
-
 
 
 
 
23
  prob = loaded_model.predict_proba(new_row)
24
-
25
  shap_values = explainer(new_row)
26
- # plot = shap.force_plot(shap_values[0], matplotlib=True, figsize=(30,30), show=False)
27
- # plot = shap.plots.waterfall(shap_values[0], max_display=6, show=False)
28
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
29
 
30
  plt.tight_layout()
31
  local_plot = plt.gcf()
32
  plt.close()
33
-
34
- return {
35
- "Chance of Earning > $50K": float(prob[0][1]),
36
- "Chance of Earning ≤ $50K": float(prob[0][0])
37
- }, local_plot
38
 
39
- # Create the UI
40
- title = "**Household Income Predictor** 💰"
41
- description1 = """This app uses your input to predict whether a household earns more or less than $50K per year"""
42
 
43
- description2 = """Adjust the values below and click 'Analyze' to see the prediction and explanation."""
 
 
 
44
 
45
  with gr.Blocks(title=title) as demo:
46
  gr.Markdown(f"## {title}")
47
  gr.Markdown(description1)
48
- gr.Markdown("---")
49
  gr.Markdown(description2)
50
- gr.Markdown("---")
51
 
52
- with gr.Row():
53
- age = gr.Number(label="Age", value=35)
54
- education_num = gr.Number(label="Education Level (numeric)", value=10)
55
- with gr.Row():
56
- sex = gr.Radio(["Male", "Female"], label="Sex", type="index") # Male = 0, Female = 1
57
- capital_gain = gr.Number(label="Capital Gain", value=0)
58
- capital_loss = gr.Number(label="Capital Loss", value=0)
59
- hours_per_week = gr.Number(label="Hours per Week", value=40)
 
60
 
61
  submit_btn = gr.Button("Analyze")
62
 
63
  with gr.Column(visible=True) as output_col:
64
  label = gr.Label(label="Predicted Income")
65
- local_plot = gr.Plot(label="Top SHAP Features")
66
 
67
  submit_btn.click(
68
  main_func,
69
  [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
70
- [label, local_plot],
71
- api_name="Salary_Predictor"
72
  )
73
 
74
- gr.Markdown("### Examples:")
75
- gr.Examples([
76
- [28, 12, 0, 0, 0, 40], # Male, younger, more educated
77
- [50, 9, 1, 0, 0, 30] # Female, mid-education, fewer hours
78
- ],
79
- inputs=[age, education_num, sex, capital_gain, capital_loss, hours_per_week],
80
- outputs=[label, local_plot],
81
- fn=main_func,
82
- cache_examples=True)
83
 
84
  demo.launch()
 
1
+ """app
2
+ Automatically generated by Colab.
3
+ Original file is located at
4
+ https://colab.research.google.com/drive/1B_g2XLYu46kFDIFzNnnJzBQ0GBPssCQw
5
+ """
6
+
7
  import pickle
8
  import pandas as pd
9
  import shap
 
12
  import numpy as np
13
  import matplotlib.pyplot as plt
14
 
15
+ # Load the model
16
  loaded_model = pickle.load(open("salar_xgb_team.pkl", 'rb'))
17
 
18
+ # Setup SHAP (do not change)
19
+ explainer = shap.Explainer(loaded_model)
20
 
21
+ # Define main prediction function
22
+ def main_func(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
23
  sex = 1 if sex == "Female" else 0
24
+ new_row = pd.DataFrame.from_dict({
25
+ 'age': age,
26
+ 'education-num': education_num,
27
+ 'sex': sex,
28
+ 'capital-gain': capital_gain,
29
+ 'capital-loss': capital_loss,
30
+ 'hours-per-week': hours_per_week
31
+ }, orient='index').transpose()
32
+
33
  prob = loaded_model.predict_proba(new_row)
34
+
35
  shap_values = explainer(new_row)
 
 
36
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
37
 
38
  plt.tight_layout()
39
  local_plot = plt.gcf()
40
  plt.close()
 
 
 
 
 
41
 
42
+ return {"≤ $50K": float(prob[0][0]), "> $50K": float(prob[0][1])}, local_plot
 
 
43
 
44
+ # Gradio UI
45
+ title = "**Household Income Predictor & Interpreter** 💰"
46
+ description1 = """This app takes demographic and economic features to predict whether a household earns ≤ $50K or > $50K annually.🚀"""
47
+ description2 = """Adjust the values and click Analyze to get predictions and feature importance."""
48
 
49
  with gr.Blocks(title=title) as demo:
50
  gr.Markdown(f"## {title}")
51
  gr.Markdown(description1)
52
+ gr.Markdown("""---""")
53
  gr.Markdown(description2)
54
+ gr.Markdown("""---""")
55
 
56
+ gr.Image("Household.png")
57
+
58
+ age = gr.Number(label="Age", value=35)
59
+ education_num = gr.Number(label="Education Level (numeric)", value=10)
60
+ sex = gr.Radio(choices=["Male", "Female"], label="Sex", value="Female")
61
+ capital_gain = gr.Number(label="Capital Gain", value=0)
62
+ capital_loss = gr.Number(label="Capital Loss", value=0)
63
+ hours_per_week = gr.Number(label="Hours per Week", value=40)
64
+ # salary_class = gr.Number(label="(Optional) Salary Class for SHAP Context", value=0) # Can remove if not needed
65
 
66
  submit_btn = gr.Button("Analyze")
67
 
68
  with gr.Column(visible=True) as output_col:
69
  label = gr.Label(label="Predicted Income")
70
+ local_plot = gr.Plot(label='SHAP Interpretation:')
71
 
72
  submit_btn.click(
73
  main_func,
74
  [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
75
+ [label, local_plot], api_name="Income_Predictor"
 
76
  )
77
 
78
+ gr.Markdown("### Try these examples:")
79
+ gr.Examples(
80
+ [[39,13, "Male", 0, 0, 40], [52, 9, "Female", 0, 1876, 45]],
81
+ [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
82
+ [label, local_plot],
83
+ main_func,
84
+ cache_examples=True
85
+ )
 
86
 
87
  demo.launch()