# tenzinlodoe's picture
# Update app.py
# c43a556 verified
import gradio as gr
import pandas as pd
import numpy as np
import pickle
import shap
import matplotlib.pyplot as plt
# Load the pickled model once, at import time, so every request reuses it.
# NOTE: unpickling executes code embedded in the file, so this artifact must
# come from a trusted source.
with open("salar_xgb_team.pkl", "rb") as model_file:
    model = pickle.load(model_file)

# Build the SHAP explainer a single time up front instead of per prediction.
explainer = shap.Explainer(model)
# Define prediction function
def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
    """Predict the income bracket for one person and explain the prediction.

    Args:
        age: Numeric value for the ``age`` feature.
        education_num: Numeric value for the ``education-num`` feature.
        sex: ``"Male"`` is encoded as 0; any other value is encoded as 1.
        capital_gain: Numeric value for the ``capital-gain`` feature.
        capital_loss: Numeric value for the ``capital-loss`` feature.
        hours_per_week: Numeric value for the ``hours-per-week`` feature.

    Returns:
        A ``(label, confidence, fig)`` tuple: the predicted bracket
        (``">50K"`` or ``"<=50K"``), the probability of the predicted class
        formatted as a percentage string, and a matplotlib figure holding the
        SHAP bar plot for this single input row.

    Raises:
        gr.Error: If any numeric input is ``None`` (e.g. a cleared field).
    """
    # A cleared gr.Number field arrives as None; fail with a readable message
    # instead of an obscure error from inside the model.
    numeric_inputs = {
        "Age": age,
        "Education Level (numeric)": education_num,
        "Capital Gain": capital_gain,
        "Capital Loss": capital_loss,
        "Hours per Week": hours_per_week,
    }
    missing = [name for name, value in numeric_inputs.items() if value is None]
    if missing:
        raise gr.Error(f"Please enter a value for: {', '.join(missing)}")

    # NOTE(review): Male -> 0, anything else -> 1; confirm this matches the
    # encoding the model was trained with.
    sex_num = 0 if sex == "Male" else 1
    input_data = pd.DataFrame(
        [[age, education_num, sex_num, capital_gain, capital_loss, hours_per_week]],
        columns=['age', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week'],
    )

    # Prediction & confidence
    pred = model.predict(input_data)[0]
    prob = model.predict_proba(input_data)[0][1]  # probability of class 1
    label = ">50K" if pred == 1 else "<=50K"
    # Report the probability of whichever class was actually predicted.
    confidence = f"{prob * 100:.2f}%" if pred == 1 else f"{(1 - prob) * 100:.2f}%"

    # SHAP values
    shap_values = explainer(input_data)
    fig, _ = plt.subplots(figsize=(6, 2.5))
    try:
        # show=False keeps the plot on the current figure instead of displaying it.
        shap.plots.bar(shap_values[0], max_display=6, show=False)
        plt.tight_layout()
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request. The Figure
        # object itself stays usable for rendering after being closed.
        plt.close(fig)

    return label, confidence, fig
# Build UI
# Introductory text shown under the page title.
INTRO_MARKDOWN = (
    "\n"
    "This tool uses a trained XGBoost model to predict whether someone earns more than $50K/year based on demographic and financial information.\n"
    "It also shows which features influenced the prediction the most, using SHAP explainability.\n"
)

# Sample rows, ordered as: age, education, sex, capital gain, capital loss, hours.
EXAMPLE_ROWS = [
    [24, 9, "Female", 0, 0, 25],
    [45, 13, "Male", 5000, 0, 50],
    [39, 10, "Female", 0, 0, 35],
    [60, 16, "Male", 0, 0, 40],
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💼 Income Prediction App")
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Left column: input widgets and the submit button.
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            education = gr.Number(label="Education Level (numeric)", value=10)
            sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
            cap_gain = gr.Number(label="Capital Gain", value=0)
            cap_loss = gr.Number(label="Capital Loss", value=0)
            hours = gr.Number(label="Hours per Week", value=40)
            submit_btn = gr.Button("🔮 Predict")

        # Right column: prediction outputs.
        with gr.Column():
            result = gr.Label(label="Predicted Income")
            confidence = gr.Label(label="Prediction Confidence")
            shap_plot = gr.Plot(label="Feature Importance (SHAP)")

    # The same input widgets feed both the examples and the predict handler;
    # their order matches predict_salary's parameter order.
    feature_inputs = [age, education, sex, cap_gain, cap_loss, hours]

    gr.Markdown("### 🧪 Try Example Inputs")
    gr.Examples(examples=EXAMPLE_ROWS, inputs=feature_inputs)

    submit_btn.click(
        fn=predict_salary,
        inputs=feature_inputs,
        outputs=[result, confidence, shap_plot],
    )

demo.launch()