File size: 2,753 Bytes
28ad851
358d13c
 
 
 
 
f9c685d
358d13c
 
 
 
46691a0
358d13c
 
46691a0
 
358d13c
e55cce3
358d13c
 
e55cce3
358d13c
 
 
 
 
e55cce3
358d13c
e55cce3
358d13c
 
e55cce3
358d13c
a61f45b
e55cce3
 
 
 
 
 
 
 
 
48ccfec
5e412ca
 
e55cce3
 
 
 
 
 
 
6306fc6
5e412ca
e55cce3
 
 
 
 
 
 
 
 
 
 
 
 
54ab004
 
e55cce3
 
 
a2c6fde
c43a556
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import pandas as pd
import numpy as np
import pickle
import shap
import matplotlib.pyplot as plt

# Pickled, pre-trained classifier artifact, opened relative to the current
# working directory.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only ever point this at an artifact you produced yourself.
MODEL_PATH = "salar_xgb_team.pkl"

with open(MODEL_PATH, "rb") as model_file:
    model = pickle.load(model_file)

# A single SHAP explainer built once for the loaded model and reused by every
# prediction request.
explainer = shap.Explainer(model)

# Define prediction function
def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
    sex_num = 0 if sex == "Male" else 1
    input_data = pd.DataFrame([[age, education_num, sex_num, capital_gain, capital_loss, hours_per_week]],
                              columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'])

    # Prediction & confidence
    pred = model.predict(input_data)[0]
    prob = model.predict_proba(input_data)[0][1]
    label = ">50K" if pred == 1 else "<=50K"
    confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"

    # SHAP values
    shap_values = explainer(input_data)
    fig, ax = plt.subplots(figsize=(6, 2.5))
    shap.plots.bar(shap_values[0], max_display=6, show=False)
    plt.tight_layout()
    
    return label, confidence, fig

# Build UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💼 Income Prediction App")
    gr.Markdown(
        """
        This tool uses a trained XGBoost model to predict whether someone earns more than $50K/year based on demographic and financial information.
        It also shows which features influenced the prediction the most, using SHAP explainability.
        """
    )

    with gr.Row():
        # Left column: model inputs.
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            education = gr.Number(label="Education Level (numeric)", value=10)
            sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
            cap_gain = gr.Number(label="Capital Gain", value=0)
            cap_loss = gr.Number(label="Capital Loss", value=0)
            hours = gr.Number(label="Hours per Week", value=40)
            submit_btn = gr.Button("🔮 Predict")

        # Right column: prediction outputs, in the order predict_salary returns them.
        with gr.Column():
            result = gr.Label(label="Predicted Income")
            confidence = gr.Label(label="Prediction Confidence")
            shap_plot = gr.Plot(label="Feature Importance (SHAP)")

    # One shared list so the examples and the click handler always feed the
    # components in the same order as predict_salary's parameters.
    feature_inputs = [age, education, sex, cap_gain, cap_loss, hours]

    gr.Markdown("### 🧪 Try Example Inputs")
    # Clicking an example only fills in the input components; it does not
    # run the prediction (no fn/outputs are passed).
    gr.Examples(
        examples=[
            [24, 9, "Female", 0, 0, 25],
            [45, 13, "Male", 5000, 0, 50],
            [39, 10, "Female", 0, 0, 35],
            [60, 16, "Male", 0, 0, 40],
        ],
        inputs=feature_inputs,
    )

    submit_btn.click(fn=predict_salary,
                     inputs=feature_inputs,
                     outputs=[result, confidence, shap_plot])

# Only start the server when run as a script, so importing this module (e.g. for
# tests or Gradio's reload mode) does not launch a second server.
if __name__ == "__main__":
    demo.launch()