# Team_2 / app.py
# Last commit: "Update app.py" (57dad6e, verified) by MatthiasLamb
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Load the trained model. The context manager guarantees the file handle is
# closed once unpickling finishes (or fails).
# NOTE: pickle can execute arbitrary code while loading, so this file must
# only ever come from a trusted source.
with open("salar_xgb_team1.pkl", 'rb') as model_file:
    loaded_model = pickle.load(model_file)
# Setup SHAP (do not change)
explainer = shap.Explainer(loaded_model)
# Feature columns in the order used to build the model input (the target
# 'salary-class' is excluded). The 'martital-status' misspelling is
# intentional: it has to match the column names the model was trained with.
_MODEL_COLUMNS = (
    'age', 'sex', 'hours-per-week',
    'education_11th', 'education_12th', 'education_1st-4th', 'education_5th-6th',
    'education_7th-8th', 'education_9th', 'education_Assoc-acdm', 'education_Assoc-voc',
    'education_Bachelors', 'education_Doctorate', 'education_HS-grad', 'education_Masters',
    'education_Preschool', 'education_Prof-school', 'education_Some-college',
    'workclass_Local-gov', 'workclass_Never-worked', 'workclass_Private',
    'workclass_Self-emp-inc', 'workclass_Self-emp-not-inc', 'workclass_State-gov',
    'workclass_Without-pay',
    'martital-status_Married-AF-spouse', 'martital-status_Married-civ-spouse',
    'martital-status_Married-spouse-absent', 'martital-status_Never-married',
    'martital-status_Separated', 'martital-status_Widowed',
    'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White',
)

# Categories that are valid inputs but have no indicator column of their own;
# they are represented by leaving every indicator of that feature at 0.
# NOTE(review): 'Divorced' is offered by the UI yet absent from the columns
# above; the other three are presumed to be the reference levels dropped when
# the training data was one-hot encoded -- confirm against the training code.
_BASELINE_COLUMNS = frozenset({
    'education_10th',
    'workclass_Federal-gov',
    'martital-status_Divorced',
    'race_Amer-Indian-Eskimo',
})


def _encode_features(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Build the single-row DataFrame the model expects from raw form values."""
    row_data = dict.fromkeys(_MODEL_COLUMNS, 0)

    # Numeric features; sex is binary-encoded with Female == 1.
    row_data['age'] = age
    row_data['sex'] = 1 if sex == "Female" else 0
    row_data['hours-per-week'] = hours_per_week

    # One-hot features: switch on the indicator matching each selection.
    selected_columns = (
        f'education_{education}',
        f'workclass_{workclass}',
        f'martital-status_{marital_status}',  # keep typo!
        f'race_{race}',
    )
    for col in selected_columns:
        if col in row_data:
            row_data[col] = 1
        elif col not in _BASELINE_COLUMNS:
            # Neither a model column nor a known baseline: most likely a
            # misspelled category, so surface it instead of failing silently.
            print(f"⚠️ Warning: Column {col} not found in model — check input spelling")

    return pd.DataFrame([row_data], columns=list(_MODEL_COLUMNS))


def main_func(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Predict the income class for one person and explain the prediction.

    Args:
        age: Age in years.
        education: Education level, e.g. "Bachelors".
        workclass: Employment category, e.g. "Private".
        marital_status: Marital status, e.g. "Never-married".
        race: Race category, e.g. "White".
        sex: "Female" or any other value (encoded as 1 and 0 respectively).
        hours_per_week: Hours worked per week.

    Returns:
        A tuple ``(probabilities, figure)``: a dict mapping the labels
        "≤ $50K" and "> $50K" to the model's predicted probabilities, and the
        matplotlib figure holding the SHAP bar plot for this input.

    Raises:
        ValueError: If ``age`` or ``hours_per_week`` is ``None`` (for example
            when a numeric input was left empty).
    """
    # Fail early with a readable message rather than letting a None value
    # reach the model as an object-typed column.
    if age is None or hours_per_week is None:
        raise ValueError("Age and Hours per Week must both be provided")

    new_row = _encode_features(
        age, education, workclass, marital_status, race, sex, hours_per_week
    )

    # Make prediction
    prob = loaded_model.predict_proba(new_row)

    # SHAP explanation: bar plot of the six largest contributions by magnitude.
    shap_values = explainer(new_row)
    shap.plots.bar(
        shap_values[0],
        max_display=6,
        order=shap.Explanation.abs,
        show_data='auto',
        show=False,
    )
    plt.tight_layout()
    local_plot = plt.gcf()
    # Deregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself is still returned to the caller.
    plt.close(local_plot)

    return {"≤ $50K": float(prob[0][0]), "> $50K": float(prob[0][1])}, local_plot
# Gradio UI
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been lost -- confirm the layout renders as intended.
title = "Income Predictor Group 2"
description1 = """This app takes demographic information to predict whether a household
earns ≤ $50K or > $50K annually"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Image("house.png.jpg")

    with gr.Row():
        # Left column: numeric and binary inputs.
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            sex = gr.Radio(choices=["Male", "Female"], label="Sex", value="Female")
            hours_per_week = gr.Number(label="Hours per Week", value=40)

        # Right column: categorical inputs. Each choice is combined with its
        # feature prefix inside main_func, so spellings must match the model's
        # column names exactly.
        with gr.Column():
            education = gr.Dropdown(
                # Every education level main_func has an indicator column for.
                choices=[
                    "Preschool", "1st-4th", "5th-6th", "7th-8th", "9th",
                    "11th", "12th", "HS-grad", "Some-college", "Assoc-voc",
                    "Assoc-acdm", "Bachelors", "Masters", "Prof-school",
                    "Doctorate",
                ],
                label="Education",
                value="Bachelors",
            )
            workclass = gr.Dropdown(
                # "Self-emp-inc" added: main_func has an indicator column for it.
                choices=[
                    "Private", "Self-emp-not-inc", "Self-emp-inc", "State-gov",
                    "Local-gov", "Never-worked", "Without-pay",
                ],
                label="Workclass",
                value="Private",
            )
            marital_status = gr.Dropdown(
                choices=[
                    "Never-married", "Married-civ-spouse", "Divorced",
                    "Separated", "Widowed", "Married-AF-spouse",
                    "Married-spouse-absent",
                ],
                label="Marital Status",
                value="Never-married",
            )
            race = gr.Dropdown(
                choices=["White", "Black", "Asian-Pac-Islander", "Other"],
                label="Race",
                value="White",
            )

    submit_btn = gr.Button("Predict Income")

    # Outputs: class probabilities plus the SHAP bar plot for the input.
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Income")
        local_plot = gr.Plot(label='SHAP Interpretation:')

    # Input order must match main_func's parameter order.
    submit_btn.click(
        main_func,
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        api_name="Income_Predictor",
    )

    gr.Markdown("### Try these examples:")
    gr.Examples(
        [
            [39, 'Bachelors', 'Local-gov', 'Separated', 'Other', 'Male', 45],
            [52, 'Masters', "State-gov", 'Married-AF-spouse', 'White', 'Female', 34],
        ],
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        main_func,
        cache_examples=True,
    )

demo.launch()