starboywilliam commited on
Commit
46691a0
·
verified ·
1 Parent(s): 358d13c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -9,13 +9,13 @@ import matplotlib.pyplot as plt
9
  with open("salar_xgb_team.pkl", "rb") as f:
10
  model = pickle.load(f)
11
 
12
- # SHAP setup
13
  explainer = shap.Explainer(model)
14
 
15
- # Prediction function
16
- def predict_income_fn(age, education, sex, capital_gain, capital_loss, hours_per_week):
17
  sex_num = 0 if sex == "Male" else 1
18
- input_data = pd.DataFrame([[age, int(education), sex_num, capital_gain, capital_loss, hours_per_week]],
19
  columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
21
  pred = model.predict(input_data)[0]
@@ -23,63 +23,64 @@ def predict_income_fn(age, education, sex, capital_gain, capital_loss, hours_per
23
  label = ">50K" if pred == 1 else "<=50K"
24
  confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
 
26
- # SHAP
27
  shap_values = explainer(input_data)
 
28
  fig, ax = plt.subplots(figsize=(6, 3))
29
  shap.plots.bar(shap_values[0], max_display=6, show=False)
30
  plt.tight_layout()
31
 
32
  return label, confidence, fig
33
 
34
- # Gradio UI
35
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
36
- gr.Markdown(
37
- """
38
- <div style='text-align: center; max-width: 750px; margin: 0 auto;'>
39
- <h1 style='font-size: 2.5em; color: #1DB954;'>💼 Income Prediction App</h1>
40
- <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year based on financial and demographic data.</p>
 
 
 
 
41
  </div>
42
- """
43
- )
44
 
45
  with gr.Row():
46
  with gr.Column():
47
- gr.Markdown("<h3 style='text-align:center;'>Age</h3>")
48
- age = gr.Number(value=30, label="", interactive=True)
49
 
50
- gr.Markdown("<h3 style='text-align:center;'>Education Level</h3>")
51
  education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
52
 
53
- gr.Markdown("<h3 style='text-align:center;'>Sex</h3>")
54
  sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
55
 
56
- gr.Markdown("<h3 style='text-align:center;'>Capital Gain</h3>")
57
  capital_gain = gr.Number(value=0, label="", interactive=True)
58
 
59
- gr.Markdown("<h3 style='text-align:center;'>Capital Loss</h3>")
60
  capital_loss = gr.Number(value=0, label="", interactive=True)
61
 
62
- gr.Markdown("<h3 style='text-align:center;'>Hours per Week</h3>")
63
  hours_per_week = gr.Number(value=40, label="", interactive=True)
64
 
65
- gr.Markdown("<div style='text-align:center;'><br></div>")
66
  predict_btn = gr.Button("🔮 Predict", elem_id="predict-button")
67
 
68
  with gr.Row():
69
  with gr.Column():
70
- gr.Markdown("<h3 style='text-align:center;'>Prediction</h3>")
71
- output_label = gr.Textbox(label="Income", interactive=False)
72
- output_confidence = gr.Textbox(label="Confidence", interactive=False)
73
-
74
- with gr.Column():
75
- gr.Markdown("<h3 style='text-align:center;'>Feature Importance (SHAP)</h3>")
76
- shap_plot = gr.Plot(label="")
77
 
78
  predict_btn.click(
79
- fn=predict_income_fn,
80
  inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
81
- outputs=[output_label, output_confidence, shap_plot]
82
  )
83
 
84
  demo.launch()
85
 
 
 
9
  with open("salar_xgb_team.pkl", "rb") as f:
10
  model = pickle.load(f)
11
 
12
+ # Set up SHAP
13
  explainer = shap.Explainer(model)
14
 
15
+ # Define prediction function
16
+ def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
17
  sex_num = 0 if sex == "Male" else 1
18
+ input_data = pd.DataFrame([[age, int(education_num), sex_num, capital_gain, capital_loss, hours_per_week]],
19
  columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
21
  pred = model.predict(input_data)[0]
 
23
  label = ">50K" if pred == 1 else "<=50K"
24
  confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
 
26
+ # SHAP plot with dark background
27
  shap_values = explainer(input_data)
28
+ plt.style.use('dark_background')
29
  fig, ax = plt.subplots(figsize=(6, 3))
30
  shap.plots.bar(shap_values[0], max_display=6, show=False)
31
  plt.tight_layout()
32
 
33
  return label, confidence, fig
34
 
35
+ with gr.Blocks(css="""
36
+ body { background-color: #0a0a23; color: white; }
37
+ .gr-box { background-color: #1a1a3d !important; border-radius: 12px; padding: 20px; }
38
+ h1, h3, p { color: white; text-align: center; font-family: 'Segoe UI', sans-serif; }
39
+ .gr-button { background-color: #1db954 !important; color: white !important; border-radius: 10px; font-size: 1.1em; }
40
+ """) as demo:
41
+
42
+ gr.Markdown("""
43
+ <div style='max-width: 700px; margin: 0 auto;'>
44
+ <h1 style='font-size: 2.5em;'>💼 Income Prediction App</h1>
45
+ <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year using financial and demographic data, with AI explainability via SHAP.</p>
46
  </div>
47
+ """)
 
48
 
49
  with gr.Row():
50
  with gr.Column():
51
+ gr.Markdown("<h3>Age</h3>")
52
+ age = gr.Slider(minimum=0, maximum=100, step=1, value=35, label="", interactive=True)
53
 
54
+ gr.Markdown("<h3>Education Level</h3>")
55
  education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
56
 
57
+ gr.Markdown("<h3>Sex</h3>")
58
  sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
59
 
60
+ gr.Markdown("<h3>Capital Gain</h3>")
61
  capital_gain = gr.Number(value=0, label="", interactive=True)
62
 
63
+ gr.Markdown("<h3>Capital Loss</h3>")
64
  capital_loss = gr.Number(value=0, label="", interactive=True)
65
 
66
+ gr.Markdown("<h3>Hours per Week</h3>")
67
  hours_per_week = gr.Number(value=40, label="", interactive=True)
68
 
 
69
  predict_btn = gr.Button("🔮 Predict", elem_id="predict-button")
70
 
71
  with gr.Row():
72
  with gr.Column():
73
+ gr.Markdown("<h3>Prediction Result</h3>")
74
+ result = gr.Textbox(label="", interactive=False)
75
+ confidence = gr.Textbox(label="Confidence", interactive=False)
76
+ shap_plot = gr.Plot(label="SHAP Feature Importance")
 
 
 
77
 
78
  predict_btn.click(
79
+ fn=predict_salary,
80
  inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
81
+ outputs=[result, confidence, shap_plot]
82
  )
83
 
84
  demo.launch()
85
 
86
+