starboywilliam commited on
Commit
358d13c
·
verified ·
1 Parent(s): fe403f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -14
app.py CHANGED
@@ -1,14 +1,42 @@
1
  import gradio as gr
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
3
  def predict_income_fn(age, education, sex, capital_gain, capital_loss, hours_per_week):
4
- # Dummy logic for demonstration
5
- return ">$50K" if capital_gain > 5000 else "<=50K"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
7
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
8
  gr.Markdown(
9
  """
10
- <div style='text-align: center; max-width: 700px; margin: 0 auto;'>
11
- <h1 style='font-size: 2.5em;'>💼 Income Prediction App</h1>
12
  <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year based on financial and demographic data.</p>
13
  </div>
14
  """
@@ -20,12 +48,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
20
  age = gr.Number(value=30, label="", interactive=True)
21
 
22
  gr.Markdown("<h3 style='text-align:center;'>Education Level</h3>")
23
- education = gr.Dropdown(
24
- choices=[str(i) for i in range(1, 17)],
25
- value="10",
26
- label="",
27
- interactive=True
28
- )
29
 
30
  gr.Markdown("<h3 style='text-align:center;'>Sex</h3>")
31
  sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
@@ -36,7 +59,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
36
  gr.Markdown("<h3 style='text-align:center;'>Capital Loss</h3>")
37
  capital_loss = gr.Number(value=0, label="", interactive=True)
38
 
39
-
40
  gr.Markdown("<h3 style='text-align:center;'>Hours per Week</h3>")
41
  hours_per_week = gr.Number(value=40, label="", interactive=True)
42
 
@@ -45,13 +67,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
45
 
46
  with gr.Row():
47
  with gr.Column():
48
- gr.Markdown("<h3 style='text-align:center;'>Prediction Result</h3>")
49
- output = gr.Textbox(label="", interactive=False)
 
 
 
 
 
50
 
51
  predict_btn.click(
52
  fn=predict_income_fn,
53
  inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
54
- outputs=[output]
55
  )
56
 
57
  demo.launch()
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import pickle
5
+ import shap
6
+ import matplotlib.pyplot as plt
7
 
8
+ # Load model
9
+ with open("salar_xgb_team.pkl", "rb") as f:
10
+ model = pickle.load(f)
11
+
12
+ # SHAP setup
13
+ explainer = shap.Explainer(model)
14
+
15
+ # Prediction function
16
  def predict_income_fn(age, education, sex, capital_gain, capital_loss, hours_per_week):
17
+ sex_num = 0 if sex == "Male" else 1
18
+ input_data = pd.DataFrame([[age, int(education), sex_num, capital_gain, capital_loss, hours_per_week]],
19
+ columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
+
21
+ pred = model.predict(input_data)[0]
22
+ prob = model.predict_proba(input_data)[0][1]
23
+ label = ">50K" if pred == 1 else "<=50K"
24
+ confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
+
26
+ # SHAP
27
+ shap_values = explainer(input_data)
28
+ fig, ax = plt.subplots(figsize=(6, 3))
29
+ shap.plots.bar(shap_values[0], max_display=6, show=False)
30
+ plt.tight_layout()
31
+
32
+ return label, confidence, fig
33
 
34
+ # Gradio UI
35
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
36
  gr.Markdown(
37
  """
38
+ <div style='text-align: center; max-width: 750px; margin: 0 auto;'>
39
+ <h1 style='font-size: 2.5em; color: #1DB954;'>💼 Income Prediction App</h1>
40
  <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year based on financial and demographic data.</p>
41
  </div>
42
  """
 
48
  age = gr.Number(value=30, label="", interactive=True)
49
 
50
  gr.Markdown("<h3 style='text-align:center;'>Education Level</h3>")
51
+ education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
 
 
 
 
 
52
 
53
  gr.Markdown("<h3 style='text-align:center;'>Sex</h3>")
54
  sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
 
59
  gr.Markdown("<h3 style='text-align:center;'>Capital Loss</h3>")
60
  capital_loss = gr.Number(value=0, label="", interactive=True)
61
 
 
62
  gr.Markdown("<h3 style='text-align:center;'>Hours per Week</h3>")
63
  hours_per_week = gr.Number(value=40, label="", interactive=True)
64
 
 
67
 
68
  with gr.Row():
69
  with gr.Column():
70
+ gr.Markdown("<h3 style='text-align:center;'>Prediction</h3>")
71
+ output_label = gr.Textbox(label="Income", interactive=False)
72
+ output_confidence = gr.Textbox(label="Confidence", interactive=False)
73
+
74
+ with gr.Column():
75
+ gr.Markdown("<h3 style='text-align:center;'>Feature Importance (SHAP)</h3>")
76
+ shap_plot = gr.Plot(label="")
77
 
78
  predict_btn.click(
79
  fn=predict_income_fn,
80
  inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
81
+ outputs=[output_label, output_confidence, shap_plot]
82
  )
83
 
84
  demo.launch()
85
+