starboywilliam commited on
Commit
7531df6
·
verified ·
1 Parent(s): a074585

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -52
app.py CHANGED
@@ -7,82 +7,86 @@ import matplotlib.pyplot as plt
7
 
8
  # Load model
9
  with open("salar_xgb_team.pkl", "rb") as f:
10
- model = pickle.load(f)
11
 
12
  # Set up SHAP
13
  explainer = shap.Explainer(model)
14
 
15
  # Define prediction function
16
  def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
17
- sex_num = 0 if sex == "Male" else 1
18
- input_data = pd.DataFrame([[age, int(education_num), sex_num, capital_gain, capital_loss, hours_per_week]],
19
- columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
21
- pred = model.predict(input_data)[0]
22
- prob = model.predict_proba(input_data)[0][1]
23
- label = ">50K" if pred == 1 else "<=50K"
24
- confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
 
26
- # SHAP plot with dark background
27
- shap_values = explainer(input_data)
28
- plt.style.use('dark_background')
29
- fig, ax = plt.subplots(figsize=(6, 3))
30
- shap.plots.bar(shap_values[0], max_display=6, show=False)
31
- plt.tight_layout()
32
 
33
- return label, confidence, fig
 
 
 
34
 
35
  # Custom CSS for styling
36
  dark_css = """
37
- body { background-color: #0a0a23; color: white; }
38
- .gr-box { background-color: #1a1a3d !important; border-radius: 12px; padding: 20px; }
39
- h1, h3, p { color: white; text-align: center; font-family: 'Segoe UI', sans-serif; }
40
- .gr-button { background-color: #1db954 !important; color: white !important; border-radius: 10px; font-size: 1.1em; }
41
- .gr-radio .gr-form { display: flex; justify-content: center !important; }
42
  """
43
 
44
  with gr.Blocks(css=dark_css) as demo:
45
 
46
- gr.Markdown("""
47
-
48
- 💼 Income Prediction App
49
- Predict whether someone earns more than $50K/year using financial and demographic data, with AI explainability via SHAP.
50
-
51
- """)
52
 
53
- with gr.Row():
54
- with gr.Column():
55
- gr.Markdown("🎂 Age")
56
- age = gr.Slider(minimum=0, maximum=100, step=1, value=35, label="", interactive=True)
57
 
58
- gr.Markdown("🎓 Education Level")
59
- education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
60
 
61
- gr.Markdown("🚻 Sex")
62
- sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
63
 
64
- gr.Markdown("📈 Capital Gain")
65
- capital_gain = gr.Number(value=0, label="", interactive=True)
66
 
67
- gr.Markdown("📉 Capital Loss")
68
- capital_loss = gr.Number(value=0, label="", interactive=True)
69
 
70
- gr.Markdown("🕒 Hours per Week")
71
- hours_per_week = gr.Number(value=40, label="", interactive=True)
72
 
73
- predict_btn = gr.Button("🔮 Predict", elem_id="predict-button")
74
 
75
- with gr.Row():
76
- with gr.Column():
77
- gr.Markdown("📊 Prediction Result")
78
- result = gr.Textbox(label="", interactive=False)
79
- confidence = gr.Textbox(label="Confidence", interactive=False)
80
- shap_plot = gr.Plot(label="SHAP Feature Importance")
 
81
 
82
- predict_btn.click(
83
- fn=predict_salary,
84
- inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
85
- outputs=[result, confidence, shap_plot]
86
- )
87
 
88
  demo.launch()
 
7
 
8
  # Load model
9
  with open("salar_xgb_team.pkl", "rb") as f:
10
+ model = pickle.load(f)
11
 
12
  # Set up SHAP
13
  explainer = shap.Explainer(model)
14
 
15
  # Define prediction function
16
  def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
17
+ sex_num = 0 if sex == "Male" else 1
18
+ input_data = pd.DataFrame([[age, int(education_num), sex_num, capital_gain, capital_loss, hours_per_week]],
19
+ columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])
20
 
21
+ pred = model.predict(input_data)[0]
22
+ prob = model.predict_proba(input_data)[0][1]
23
+ label = ">50K" if pred == 1 else "<=50K"
24
+ confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"
25
 
26
+ # SHAP plot with dark background
27
+ shap_values = explainer(input_data)
28
+ plt.style.use('dark_background')
29
+ fig, ax = plt.subplots(figsize=(6, 3))
30
+ shap.plots.bar(shap_values[0], max_display=6, show=False)
31
+ plt.tight_layout()
32
 
33
+ # Conditional message
34
+ message = "You're halfway to six figures!" if pred == 1 else "Keep working hard!"
35
+
36
+ return label, confidence, fig, message
37
 
38
  # Custom CSS for styling
39
  dark_css = """
40
+ body { background-color: #0a0a23; color: white; }
41
+ .gr-box { background-color: #1a1a3d !important; border-radius: 12px; padding: 20px; }
42
+ h1, h3, p { color: white; text-align: center; font-family: 'Segoe UI', sans-serif; }
43
+ .gr-button { background-color: #1db954 !important; color: white !important; border-radius: 10px; font-size: 1.1em; }
44
+ .gr-radio .gr-form { display: flex; justify-content: center !important; }
45
  """
46
 
47
  with gr.Blocks(css=dark_css) as demo:
48
 
49
+ gr.Markdown("""
50
+ <div style='max-width: 700px; margin: 0 auto;'>
51
+ <h1 style='font-size: 2.5em;'>💼 Income Prediction App</h1>
52
+ <p style='font-size: 1.2em;'>Predict whether someone earns more than $50K/year using financial and demographic data, with AI explainability via SHAP.</p>
53
+ </div>
54
+ """)
55
 
56
+ with gr.Row():
57
+ with gr.Column():
58
+ gr.Markdown("<h3>🎂 Age</h3>")
59
+ age = gr.Slider(minimum=0, maximum=100, step=1, value=35, label="", interactive=True)
60
 
61
+ gr.Markdown("<h3>🎓 Education Level</h3>")
62
+ education = gr.Dropdown(choices=[str(i) for i in range(1, 17)], value="10", label="", interactive=True)
63
 
64
+ gr.Markdown("<h3>🚻 Sex</h3>")
65
+ sex = gr.Radio(choices=["Male", "Female"], value="Male", label="", interactive=True)
66
 
67
+ gr.Markdown("<h3>📈 Capital Gain</h3>")
68
+ capital_gain = gr.Number(value=0, label="", interactive=True)
69
 
70
+ gr.Markdown("<h3>📉 Capital Loss</h3>")
71
+ capital_loss = gr.Number(value=0, label="", interactive=True)
72
 
73
+ gr.Markdown("<h3>🕒 Hours per Week</h3>")
74
+ hours_per_week = gr.Number(value=40, label="", interactive=True)
75
 
76
+ predict_btn = gr.Button("🔮 Predict", elem_id="predict-button")
77
 
78
+ with gr.Row():
79
+ with gr.Column():
80
+ gr.Markdown("<h3>📊 Prediction Result</h3>")
81
+ result = gr.Textbox(label="", interactive=False)
82
+ confidence = gr.Textbox(label="Confidence", interactive=False)
83
+ shap_plot = gr.Plot(label="SHAP Feature Importance")
84
+ message = gr.Textbox(label="Message", interactive=False)
85
 
86
+ predict_btn.click(
87
+ fn=predict_salary,
88
+ inputs=[age, education, sex, capital_gain, capital_loss, hours_per_week],
89
+ outputs=[result, confidence, shap_plot, message]
90
+ )
91
 
92
  demo.launch()