Spaces:
Build error
Build error
import pickle
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# Load the trained model from the file that sits next to this script, so the
# app works no matter which directory it is launched from.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only ever
# point this at the repository's own, trusted model artifact, never at a
# user-supplied file.
MODEL_PATH = Path(__file__).resolve().with_name("salar_xgb_team.pkl")
with MODEL_PATH.open("rb") as model_file:
    model = pickle.load(model_file)

# SHAP explainer for the loaded model; predict_salary calls it on each input
# row to attribute the prediction to individual features.
explainer = shap.Explainer(model)
def predict_salary(age, education_num, sex, capital_gain, capital_loss, hours_per_week):
    """Predict the income bracket for one person and explain the prediction.

    Uses the module-level ``model`` for the prediction and ``explainer`` for
    the SHAP attribution.

    Args:
        age: Age, passed to the model as the ``age`` feature.
        education_num: Numeric education level (``education-num`` feature).
        sex: ``"Male"`` is encoded as 0; any other value is encoded as 1.
        capital_gain: Capital gain (``capital-gain`` feature).
        capital_loss: Capital loss (``capital-loss`` feature).
        hours_per_week: Weekly working hours (``hours-per-week`` feature).

    Returns:
        A ``(label, confidence, fig)`` tuple:
        ``label`` is ``">50K"`` when the model predicts class 1, else ``"<=50K"``;
        ``confidence`` is the predicted class's probability formatted as a
        percentage string with two decimals (e.g. ``"73.12%"``);
        ``fig`` is a matplotlib Figure holding a SHAP bar plot of the
        per-feature contributions (at most 6 features shown).
    """
    # NOTE(review): assumes the model was trained with Male=0 / Female=1 —
    # confirm against the training pipeline.
    sex_num = 0 if sex == "Male" else 1
    # Presumably the column names/order must match the features the model
    # was trained on — verify if the model is retrained.
    input_data = pd.DataFrame(
        [[age, education_num, sex_num, capital_gain, capital_loss, hours_per_week]],
        columns=["age", "education-num", "sex", "capital-gain", "capital-loss", "hours-per-week"],
    )

    # Prediction & confidence.
    pred = model.predict(input_data)[0]
    # Column 1 of predict_proba is the probability of class 1 (">50K"),
    # assuming model.classes_ is [0, 1].
    prob = model.predict_proba(input_data)[0][1]
    label = ">50K" if pred == 1 else "<=50K"
    # Report the probability of whichever class was actually predicted.
    predicted_class_prob = prob if pred == 1 else 1 - prob
    confidence = f"{predicted_class_prob * 100:.2f}%"

    # SHAP explanation for this single row.
    shap_values = explainer(input_data)
    # shap.plots.bar draws on pyplot's *current* figure, so create ours first.
    fig, _ = plt.subplots(figsize=(6, 2.5))
    shap.plots.bar(shap_values[0], max_display=6, show=False)
    fig.tight_layout()
    # Drop the figure from pyplot's global registry; otherwise every request
    # leaks one figure. The Figure object itself remains usable for rendering.
    plt.close(fig)
    return label, confidence, fig
# Build the Gradio UI: input form on the left, prediction + SHAP plot on the
# right, followed by example rows that pre-fill the form.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💼 Income Prediction App")
    gr.Markdown(
        """
This tool uses a trained XGBoost model to predict whether someone earns more than $50K/year based on demographic and financial information.
It also shows which features influenced the prediction the most, using SHAP explainability.
"""
    )
    with gr.Row():
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            education = gr.Number(label="Education Level (numeric)", value=10)
            sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
            cap_gain = gr.Number(label="Capital Gain", value=0)
            cap_loss = gr.Number(label="Capital Loss", value=0)
            hours = gr.Number(label="Hours per Week", value=40)
            submit_btn = gr.Button("🔮 Predict")
        with gr.Column():
            result = gr.Label(label="Predicted Income")
            confidence = gr.Label(label="Prediction Confidence")
            shap_plot = gr.Plot(label="Feature Importance (SHAP)")

    # Defined once so the example rows and the click handler always use the
    # same order, which must match predict_salary's positional parameters.
    input_components = [age, education, sex, cap_gain, cap_loss, hours]

    gr.Markdown("### 🧪 Try Example Inputs")
    # Each row: age, education level, sex, capital gain, capital loss, hours/week.
    gr.Examples(
        examples=[
            [24, 9, "Female", 0, 0, 25],
            [45, 13, "Male", 5000, 0, 50],
            [39, 10, "Female", 0, 0, 35],
            [60, 16, "Male", 0, 0, 40],
        ],
        inputs=input_components,
    )
    submit_btn.click(
        fn=predict_salary,
        inputs=input_components,
        outputs=[result, confidence, shap_plot],
    )

# Only start the server when run as a script (e.g. `python app.py`), not when
# the module is merely imported.
if __name__ == "__main__":
    demo.launch()