Spaces:
Build error
Build error
File size: 2,753 Bytes
28ad851 358d13c f9c685d 358d13c 46691a0 358d13c 46691a0 358d13c e55cce3 358d13c e55cce3 358d13c e55cce3 358d13c e55cce3 358d13c e55cce3 358d13c a61f45b e55cce3 48ccfec 5e412ca e55cce3 6306fc6 5e412ca e55cce3 54ab004 e55cce3 a2c6fde c43a556 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import gradio as gr
import pandas as pd
import numpy as np
import pickle
import shap
import matplotlib.pyplot as plt
# Load the pickled classifier once at start-up; the prediction handler below
# reads it through the module-level name `model`.
# NOTE(review): the file name reads "salar_..." -- confirm it matches the
# artifact that is actually shipped next to this script.
MODEL_PATH = "salar_xgb_team.pkl"
with open(MODEL_PATH, "rb") as model_file:
    model = pickle.load(model_file)

# Build the SHAP explainer a single time so each prediction reuses it
# instead of constructing a new one.
explainer = shap.Explainer(model)
# Define prediction function
def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
    """Predict the income bracket for one person and plot the SHAP explanation.

    Args:
        age: Value placed in the ``age`` column.
        education_num: Value placed in the ``education-num`` column.
        sex: ``"Male"`` is encoded as 0; any other value is encoded as 1.
        capital_gain: Value placed in the ``capital-gain`` column.
        capital_loss: Value placed in the ``capital-loss`` column.
        hours_per_week: Value placed in the ``hours-per-week`` column.

    Returns:
        A ``(label, confidence, fig)`` tuple. ``label`` is ``">50K"`` when the
        model predicts class 1 and ``"<=50K"`` otherwise. ``confidence`` is
        the class-1 probability (or its complement when class 1 is not
        predicted) formatted as a percentage with two decimals. ``fig`` is
        the Matplotlib figure holding the SHAP bar plot for this row.

    Raises:
        gr.Error: If any input is ``None`` (for example, a cleared field).
    """
    # A cleared input arrives as None; fail with a readable message instead
    # of an opaque error from inside the model.
    provided = {
        "Age": age,
        "Education Level (numeric)": education_num,
        "Sex": sex,
        "Capital Gain": capital_gain,
        "Capital Loss": capital_loss,
        "Hours per Week": hours_per_week,
    }
    missing = [name for name, value in provided.items() if value is None]
    if missing:
        raise gr.Error(f"Please provide a value for: {', '.join(missing)}")

    # NOTE(review): this encoding (Male -> 0, otherwise 1) must match the one
    # used when the model was trained -- confirm against the training code.
    sex_num = 0 if sex == "Male" else 1
    input_data = pd.DataFrame(
        [[age, education_num, sex_num, capital_gain, capital_loss, hours_per_week]],
        columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'],
    )

    # Prediction & confidence
    pred = model.predict(input_data)[0]
    prob = model.predict_proba(input_data)[0][1]
    label = ">50K" if pred == 1 else "<=50K"
    confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"

    # SHAP values
    shap_values = explainer(input_data)
    # The new figure becomes pyplot's current figure, which is where
    # shap.plots.bar draws when show=False.
    fig, _ = plt.subplots(figsize=(6, 2.5))
    try:
        shap.plots.bar(shap_values[0], max_display=6, show=False)
        fig.tight_layout()
    finally:
        # Deregister the figure from pyplot so repeated calls do not pile up
        # open figures; the Figure object itself can still be rendered.
        plt.close(fig)
    return label, confidence, fig
# Build UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💼 Income Prediction App")
    gr.Markdown(
        """
This tool uses a trained XGBoost model to predict whether someone earns more than $50K/year based on demographic and financial information.
It also shows which features influenced the prediction the most, using SHAP explainability.
"""
    )

    with gr.Row():
        # Left column: the six model inputs and the trigger button.
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            education = gr.Number(label="Education Level (numeric)", value=10)
            sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
            cap_gain = gr.Number(label="Capital Gain", value=0)
            cap_loss = gr.Number(label="Capital Loss", value=0)
            hours = gr.Number(label="Hours per Week", value=40)
            submit_btn = gr.Button("🔮 Predict")
        # Right column: the three values returned by predict_salary.
        with gr.Column():
            result = gr.Label(label="Predicted Income")
            confidence = gr.Label(label="Prediction Confidence")
            shap_plot = gr.Plot(label="Feature Importance (SHAP)")

    gr.Markdown("### 🧪 Try Example Inputs")
    # Each example row is ordered like `inputs`; no function is attached, so
    # selecting a row only fills in the input widgets.
    gr.Examples(
        examples=[
            [24, 9, "Female", 0, 0, 25],
            [45, 13, "Male", 5000, 0, 50],
            [39, 10, "Female", 0, 0, 35],
            [60, 16, "Male", 0, 0, 40],
        ],
        inputs=[age, education, sex, cap_gain, cap_loss, hours],
    )

    # Event listeners have to be registered inside the Blocks context.
    submit_btn.click(
        fn=predict_salary,
        inputs=[age, education, sex, cap_gain, cap_loss, hours],
        outputs=[result, confidence, shap_plot],
    )

demo.launch()