Spaces:
Sleeping
Sleeping
File size: 4,904 Bytes
28783f9 57dad6e 28783f9 537dc6d 28783f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Load the pickled model once at import time. The context manager makes sure
# the file handle is closed as soon as unpickling finishes.
# NOTE(review): pickle executes arbitrary code while loading; this is only
# safe as long as the .pkl file is a trusted artifact shipped with the app.
with open("salar_xgb_team1.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP (do not change)
explainer = shap.Explainer(loaded_model)
# Feature columns, in the order used to build the single-row model input
# (the 'salary-class' target is excluded). The 'martital-status' misspelling
# is intentional and must be kept so the names match the ones the app has
# always sent to the model.
_MODEL_COLUMNS = (
    'age', 'sex', 'hours-per-week',
    'education_11th', 'education_12th', 'education_1st-4th', 'education_5th-6th',
    'education_7th-8th', 'education_9th', 'education_Assoc-acdm', 'education_Assoc-voc',
    'education_Bachelors', 'education_Doctorate', 'education_HS-grad', 'education_Masters',
    'education_Preschool', 'education_Prof-school', 'education_Some-college',
    'workclass_Local-gov', 'workclass_Never-worked', 'workclass_Private',
    'workclass_Self-emp-inc', 'workclass_Self-emp-not-inc', 'workclass_State-gov',
    'workclass_Without-pay',
    'martital-status_Married-AF-spouse', 'martital-status_Married-civ-spouse',
    'martital-status_Married-spouse-absent', 'martital-status_Never-married',
    'martital-status_Separated', 'martital-status_Widowed',
    'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White',
)


def _build_feature_row(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Return a {column: value} dict holding one encoded observation."""
    # A cleared numeric field arrives as None; fail here with a clear message
    # instead of letting an object-dtype column reach the model.
    for field_label, value in (("Age", age), ("Hours per Week", hours_per_week)):
        if value is None:
            raise ValueError(f"{field_label} is required")

    # Start with every column at 0, then fill in what the inputs select.
    row_data = dict.fromkeys(_MODEL_COLUMNS, 0)
    row_data['age'] = age
    row_data['sex'] = 1 if sex == "Female" else 0  # binary encoding for sex
    row_data['hours-per-week'] = hours_per_week

    # One-hot encode the categorical inputs.
    categorical_inputs = (
        ('education', education),
        ('workclass', workclass),
        ('martital-status', marital_status),  # keep typo!
        ('race', race),
    )
    for prefix, choice in categorical_inputs:
        column = f'{prefix}_{choice}'
        if column in row_data:
            row_data[column] = 1
        else:
            # The whole group stays at zero. This also happens for valid
            # choices that have no column of their own (e.g. "Divorced"),
            # presumably the reference level dropped at training time —
            # TODO confirm against the training code.
            print(
                f"⚠️ Warning: Column {column} not found in model — "
                "encoded as all zeros (reference category or misspelled input)"
            )
    return row_data


def main_func(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Predict the income class for one person and explain the prediction.

    Args:
        age: Age as a number; must not be None.
        education: Education level, matched against the 'education_*' columns.
        workclass: Work class, matched against the 'workclass_*' columns.
        marital_status: Marital status, matched against the
            'martital-status_*' columns.
        race: Race, matched against the 'race_*' columns.
        sex: "Female" is encoded as 1, any other value as 0.
        hours_per_week: Hours worked per week; must not be None.

    Returns:
        A tuple ``(probabilities, figure)``: a dict mapping "≤ $50K" and
        "> $50K" to the model's class probabilities as floats, and the
        matplotlib figure containing the SHAP bar plot for this observation.

    Raises:
        ValueError: If ``age`` or ``hours_per_week`` is None.
    """
    row_data = _build_feature_row(
        age, education, workclass, marital_status, race, sex, hours_per_week
    )

    # Create DataFrame for prediction
    new_row = pd.DataFrame([row_data])

    # Make prediction
    prob = loaded_model.predict_proba(new_row)

    # SHAP explanation
    shap_values = explainer(new_row)
    try:
        shap.plots.bar(
            shap_values[0],
            max_display=6,
            order=shap.Explanation.abs,
            show_data='auto',
            show=False,
        )
        plt.tight_layout()
        figure = plt.gcf()
    finally:
        # Always release the pyplot-managed figure, even when rendering
        # fails, so a later call never draws on top of a stale one. The
        # Figure object itself stays usable by the caller after close().
        plt.close()

    return {"≤ $50K": float(prob[0][0]), "> $50K": float(prob[0][1])}, figure
# Gradio UI
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation was lost; confirm that the button and the output column
# belong directly under the Blocks container rather than inside the input row.
title = "Income Predictor Group 2"
description1 = """This app takes demographic information to predict whether a household
earns ≤ $50K or > $50K annually"""

with gr.Blocks(title=title) as demo:
    # Page header
    gr.Markdown(value=f"## {title}")
    gr.Markdown(value=description1)
    gr.Markdown(value="""---""")
    gr.Image(value="house.png.jpg")

    # Input widgets: numeric/binary fields on the left, categorical on the right
    with gr.Row():
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            sex = gr.Radio(
                choices=["Male", "Female"],
                label="Sex",
                value="Female",
            )
            hours_per_week = gr.Number(label="Hours per Week", value=40)

        with gr.Column():
            education_choices = [
                "Bachelors", "Masters", "HS-grad", "Some-college", "Doctorate",
            ]
            education = gr.Dropdown(
                choices=education_choices,
                label="Education",
                value="Bachelors",
            )

            workclass_choices = [
                "Private", "Self-emp-not-inc", "State-gov", "Local-gov",
                "Never-worked", "Without-pay",
            ]
            workclass = gr.Dropdown(
                choices=workclass_choices,
                label="Workclass",
                value="Private",
            )

            marital_status_choices = [
                "Never-married", "Married-civ-spouse", "Divorced", "Separated",
                "Widowed", "Married-AF-spouse", "Married-spouse-absent",
            ]
            marital_status = gr.Dropdown(
                choices=marital_status_choices,
                label="Marital Status",
                value="Never-married",
            )

            race_choices = ["White", "Black", "Asian-Pac-Islander", "Other"]
            race = gr.Dropdown(
                choices=race_choices,
                label="Race",
                value="White",
            )

    submit_btn = gr.Button(value="Predict Income")

    # Output widgets: class probabilities plus the SHAP figure
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Income")
        local_plot = gr.Plot(label='SHAP Interpretation:')

    # Inputs are listed in the same order as main_func's parameters.
    submit_btn.click(
        fn=main_func,
        inputs=[age, education, workclass, marital_status, race, sex, hours_per_week],
        outputs=[label, local_plot],
        api_name="Income_Predictor",
    )

    gr.Markdown(value="### Try these examples:")
    example_rows = [
        [39, 'Bachelors', 'Local-gov', 'Separated', 'Other', 'Male', 45],
        [52, 'Masters', "State-gov", 'Married-AF-spouse', 'White', 'Female', 34],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=[age, education, workclass, marital_status, race, sex, hours_per_week],
        outputs=[label, local_plot],
        fn=main_func,
        cache_examples=True,
    )

demo.launch()
|