tenzinlodoe commited on
Commit
e55cce3
·
verified ·
1 Parent(s): 75dd2d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -52
app.py CHANGED
@@ -15,75 +15,61 @@ explainer = shap.Explainer(model)
15
  # Define prediction function
16
  def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
17
  sex_num = 0 if sex == "Male" else 1
18
- input_data = pd.DataFrame([[age, int(education_num), sex_num, capital_gain, capital_loss, hours_per_week]],
19
  columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
 
21
  pred = model.predict(input_data)[0]
22
  prob = model.predict_proba(input_data)[0][1]
23
  label = ">50K" if pred == 1 else "<=50K"
24
  confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
 
26
- # SHAP plot with dark background
27
  shap_values = explainer(input_data)
28
- plt.style.use('dark_background')
29
- fig, ax = plt.subplots(figsize=(6, 3))
30
  shap.plots.bar(shap_values[0], max_display=6, show=False)
31
  plt.tight_layout()
32
-
33
  return label, confidence, fig
34
 
35
- with gr.Blocks(css="""
36
- body { background-color: #0a0a23; color: white; }
37
- .gr-box { background-color: #1a1a3d !important; border-radius: 12px; padding: 20px; }
38
- h1, h3, p { color: white; text-align: center; font-family: 'Segoe UI', sans-serif; }
39
- .gr-button { background-color: #1db954 !important; color: white !important; border-radius: 10px; font-size: 1.1em; }
40
- .gr-radio .gr-form { display: flex; justify-content: center !important; }
41
- """
42
- ) as demo:
43
-
44
- gr.Markdown("""
45
- <div style='max-width: 700px; margin: 0 auto;'>
46
- <h1 style='font-size: 2.5em;'>💼 Income Prediction App</h1>
47
- <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year using financial and demographic data, with AI explainability via SHAP.</p>
48
- </div>
49
- """)
50
 
51
  with gr.Row():
52
  with gr.Column():
53
- gr.Markdown("<h3>Age</h3>")
54
- age = gr.Slider(minimum=0, maximum=100, step=1, value=35, label="", interactive=True)
55
-
56
- gr.Markdown("<h3>Education Level</h3>")
57
- education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
58
-
59
- gr.Markdown("<h3>Sex</h3>")
60
- sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
61
-
62
- gr.Markdown("<h3>Capital Gain</h3>")
63
- capital_gain = gr.Number(value=0, label="", interactive=True)
64
-
65
- gr.Markdown("<h3>Capital Loss</h3>")
66
- capital_loss = gr.Number(value=0, label="", interactive=True)
67
-
68
- gr.Markdown("<h3>Hours per Week</h3>")
69
- hours_per_week = gr.Number(value=40, label="", interactive=True)
70
 
71
- predict_btn = gr.Button("🔮 Predict", elem_id="predict-button")
72
-
73
- with gr.Row():
74
  with gr.Column():
75
- gr.Markdown("<h3>Prediction Result</h3>")
76
- result = gr.Textbox(label="", interactive=False)
77
- confidence = gr.Textbox(label="Confidence", interactive=False)
78
- shap_plot = gr.Plot(label="SHAP Feature Importance")
79
-
80
- predict_btn.click(
81
- fn=predict_salary,
82
- inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
83
- outputs=[result, confidence, shap_plot]
 
 
 
 
84
  )
85
 
86
- demo.launch()
87
-
88
-
89
 
 
 
15
  # Define prediction function
16
  def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
17
  sex_num = 0 if sex == "Male" else 1
18
+ input_data = pd.DataFrame([[age, education_num, sex_num, capital_gain, capital_loss, hours_per_week]],
19
  columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
21
+ # Prediction & confidence
22
  pred = model.predict(input_data)[0]
23
  prob = model.predict_proba(input_data)[0][1]
24
  label = ">50K" if pred == 1 else "<=50K"
25
  confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
26
 
27
+ # SHAP values
28
  shap_values = explainer(input_data)
29
+ fig, ax = plt.subplots(figsize=(6, 2.5))
 
30
  shap.plots.bar(shap_values[0], max_display=6, show=False)
31
  plt.tight_layout()
32
+
33
  return label, confidence, fig
34
 
35
+ # Build UI
36
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
37
+ gr.Markdown("## 💼 Income Prediction App")
38
+ gr.Markdown(
39
+ """
40
+ This tool uses a trained XGBoost model to predict whether someone earns more than $50K/year based on demographic and financial information.
41
+ It also shows which features influenced the prediction the most, using SHAP explainability.
42
+ """
43
+ )
 
 
 
 
 
 
44
 
45
  with gr.Row():
46
  with gr.Column():
47
+ age = gr.Number(label="Age", value=35)
48
+ education = gr.Number(label="Education Level (numeric)", value=10)
49
+ sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
50
+ cap_gain = gr.Number(label="Capital Gain", value=0)
51
+ cap_loss = gr.Number(label="Capital Loss", value=0)
52
+ hours = gr.Number(label="Hours per Week", value=40)
53
+ submit_btn = gr.Button("🔮 Predict")
 
 
 
 
 
 
 
 
 
 
54
 
 
 
 
55
  with gr.Column():
56
+ result = gr.Label(label="Predicted Income")
57
+ confidence = gr.Label(label="Prediction Confidence")
58
+ shap_plot = gr.Plot(label="Feature Importance (SHAP)")
59
+
60
+ gr.Markdown("### 🧪 Try Example Inputs")
61
+ gr.Examples(
62
+ examples=[
63
+ [24, 9, "Female", 0, 0, 25],
64
+ [45, 13, "Male", 5000, 0, 50],
65
+ [39, 10, "Female", 0, 0, 35],
66
+ [60, 16, "Male", 0, 0, 40],
67
+ ],
68
+ inputs=[age, education, sex, cap_gain, cap_loss, hours],
69
  )
70
 
71
+ submit_btn.click(fn=predict_salary,
72
+ inputs=[age, education, sex, cap_gain, cap_loss, hours],
73
+ outputs=[result, confidence, shap_plot])
74
 
75
+ demo.launch()