Spaces:
Sleeping
Sleeping
| import pickle | |
| import pandas as pd | |
| import shap | |
| from shap.plots._force_matplotlib import draw_additive_plot | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Load the trained classifier shipped alongside the app.
# A context manager guarantees the file handle is closed after unpickling.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever load a model file you produced or trust, never a user upload.
with open("salar_xgb_team1.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP (do not change)
# Built once at import time and reused by every prediction request.
explainer = shap.Explainer(loaded_model)
# Feature columns the model was trained on (excluding 'salary-class').
# The 'martital-status' prefix is misspelled on purpose: it must match the
# training column names exactly. Presumably this is also the training column
# order — TODO confirm against the training notebook.
_MODEL_COLUMNS = [
    'age', 'sex', 'hours-per-week',
    'education_11th', 'education_12th', 'education_1st-4th', 'education_5th-6th',
    'education_7th-8th', 'education_9th', 'education_Assoc-acdm', 'education_Assoc-voc',
    'education_Bachelors', 'education_Doctorate', 'education_HS-grad', 'education_Masters',
    'education_Preschool', 'education_Prof-school', 'education_Some-college',
    'workclass_Local-gov', 'workclass_Never-worked', 'workclass_Private',
    'workclass_Self-emp-inc', 'workclass_Self-emp-not-inc', 'workclass_State-gov',
    'workclass_Without-pay',
    'martital-status_Married-AF-spouse', 'martital-status_Married-civ-spouse',
    'martital-status_Married-spouse-absent', 'martital-status_Never-married',
    'martital-status_Separated', 'martital-status_Widowed',
    'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White',
]


def _encode_features(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Build the single-row, one-hot encoded DataFrame the model is scored on.

    Every column in ``_MODEL_COLUMNS`` starts at 0. The numeric fields are
    copied in, ``sex`` is binary-encoded (``"Female"`` -> 1, anything else -> 0),
    and the matching indicator column of each categorical group is set to 1.

    A category with no matching column leaves all of its group's indicators at 0
    and prints a warning. For example, "Divorced" is offered by the UI but has
    no column; presumably it is the dropped reference level, so the warning is
    expected for it.
    """
    row_data = dict.fromkeys(_MODEL_COLUMNS, 0)
    row_data['age'] = age
    row_data['sex'] = 1 if sex == "Female" else 0
    row_data['hours-per-week'] = hours_per_week

    one_hot_cols = (
        f'education_{education}',
        f'workclass_{workclass}',
        f'martital-status_{marital_status}',  # keep typo!
        f'race_{race}',
    )
    for col in one_hot_cols:
        if col in row_data:
            row_data[col] = 1
        else:
            print(f"⚠️ Warning: Column {col} not found in model — check input spelling")
    return pd.DataFrame([row_data])


def main_func(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Predict the income class for one person and explain the prediction.

    Args:
        age: Age in years.
        education: Education level label, e.g. "Bachelors".
        workclass: Employment class label, e.g. "Private".
        marital_status: Marital status label, e.g. "Never-married".
        race: Race label, e.g. "White".
        sex: "Male" or "Female".
        hours_per_week: Hours worked per week.

    Returns:
        A tuple of:
        - a dict mapping "≤ $50K" / "> $50K" to the model's probabilities;
        - the matplotlib Figure holding a SHAP bar plot of the top 6 features.
    """
    new_row = _encode_features(age, education, workclass, marital_status,
                               race, sex, hours_per_week)

    # Assumes class 0 is "≤ $50K" and class 1 is "> $50K" — confirm against training labels.
    prob = loaded_model.predict_proba(new_row)

    # Draw the SHAP explanation onto the current pyplot figure.
    # show=False keeps the figure open so it can be captured and returned.
    shap_values = explainer(new_row)
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Unregister the figure from pyplot so figures don't pile up across
    # requests. The returned Figure object itself remains usable.
    plt.close(local_plot)

    return {"≤ $50K": float(prob[0][0]), "> $50K": float(prob[0][1])}, local_plot
# Gradio UI
title = "Income Predictor Group 2"
description1 = """This app takes demographic information to predict whether a household
earns ≤ $50K or > $50K annually"""

# Every choice below is either a category with a model column in main_func's
# column list, or one the original UI already offered (e.g. "Divorced").
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Image("house.png.jpg")

    # Input widgets: numeric/binary fields on the left, categoricals on the right.
    with gr.Row():
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            sex = gr.Radio(choices=["Male", "Female"], label="Sex", value="Female")
            hours_per_week = gr.Number(label="Hours per Week", value=40)
        with gr.Column():
            education = gr.Dropdown(
                choices=["Bachelors", "Masters", "HS-grad", "Some-college", "Doctorate",
                         # Further levels the model has columns for.
                         "Preschool", "1st-4th", "5th-6th", "7th-8th", "9th",
                         "11th", "12th", "Assoc-acdm", "Assoc-voc", "Prof-school"],
                label="Education",
                value="Bachelors",
            )
            workclass = gr.Dropdown(
                choices=["Private", "Self-emp-not-inc", "State-gov", "Local-gov",
                         "Never-worked", "Without-pay",
                         # Model has a column for this class; it was previously unreachable.
                         "Self-emp-inc"],
                label="Workclass",
                value="Private",
            )
            marital_status = gr.Dropdown(
                choices=["Never-married", "Married-civ-spouse", "Divorced", "Separated",
                         "Widowed", "Married-AF-spouse", "Married-spouse-absent"],
                label="Marital Status",
                value="Never-married",
            )
            race = gr.Dropdown(
                choices=["White", "Black", "Asian-Pac-Islander", "Other"],
                label="Race",
                value="White",
            )

    submit_btn = gr.Button("Predict Income")

    # Outputs: class probabilities plus the SHAP bar plot.
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Income")
        local_plot = gr.Plot(label='SHAP Interpretation:')

    # Input order must match main_func's positional parameters.
    submit_btn.click(
        main_func,
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        api_name="Income_Predictor",
    )

    gr.Markdown("### Try these examples:")
    # cache_examples=True asks Gradio to precompute main_func's outputs for these rows.
    gr.Examples(
        [[39, 'Bachelors', 'Local-gov', 'Separated', 'Other', 'Male', 45],
         [52, 'Masters', "State-gov", 'Married-AF-spouse', 'White', 'Female', 34]],
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        main_func,
        cache_examples=True,
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()