import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Load the model. The context manager closes the file handle once unpickled.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only ever load a model file from a trusted source.
with open("salar_xgb_team1.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP (do not change)
explainer = shap.Explainer(loaded_model)

# Training feature columns, in training order (excluding 'salary-class').
# The 'martital-status' prefix reproduces a typo in the training data and
# must be kept exactly as-is.
MODEL_COLUMNS = [
    'age', 'sex', 'hours-per-week',
    'education_11th', 'education_12th', 'education_1st-4th',
    'education_5th-6th', 'education_7th-8th', 'education_9th',
    'education_Assoc-acdm', 'education_Assoc-voc', 'education_Bachelors',
    'education_Doctorate', 'education_HS-grad', 'education_Masters',
    'education_Preschool', 'education_Prof-school', 'education_Some-college',
    'workclass_Local-gov', 'workclass_Never-worked', 'workclass_Private',
    'workclass_Self-emp-inc', 'workclass_Self-emp-not-inc',
    'workclass_State-gov', 'workclass_Without-pay',
    'martital-status_Married-AF-spouse', 'martital-status_Married-civ-spouse',
    'martital-status_Married-spouse-absent', 'martital-status_Never-married',
    'martital-status_Separated', 'martital-status_Widowed',
    'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White',
]

# Category levels that have no column of their own. Each looks like the
# alphabetically-first level of its feature, presumably dropped as the
# reference level during one-hot encoding (drop_first) -- TODO confirm
# against the training pipeline. Such a level is encoded as all of that
# feature's columns being 0, so it is valid input and must not trigger the
# "not found" warning (e.g. 'Divorced', which the UI offers).
_BASELINE_COLUMNS = {
    'education_10th',
    'workclass_Federal-gov',
    'martital-status_Divorced',
    'race_Amer-Indian-Eskimo',
}


def _encode_row(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Build one model-ready feature row (dict keyed by MODEL_COLUMNS) from UI inputs.

    Categorical inputs are one-hot encoded. A level without its own column
    leaves that feature all zeros; unless it is a known baseline level, a
    console warning is printed so misspelled inputs are noticed.
    """
    row_data = dict.fromkeys(MODEL_COLUMNS, 0)

    # Continuous features, plus binary encoding for sex (Female -> 1, else 0).
    row_data['age'] = age
    row_data['sex'] = 1 if sex == "Female" else 0
    row_data['hours-per-week'] = hours_per_week

    one_hot_columns = (
        f'education_{education}',
        f'workclass_{workclass}',
        f'martital-status_{marital_status}',  # keep typo!
        f'race_{race}',
    )
    for col in one_hot_columns:
        if col in row_data:
            row_data[col] = 1
        elif col not in _BASELINE_COLUMNS:
            print(f"⚠️ Warning: Column {col} not found in model — check input spelling")
    return row_data


def main_func(age, education, workclass, marital_status, race, sex, hours_per_week):
    """Predict the income class for one person and explain the prediction with SHAP.

    Returns:
        A ``(label_probs, figure)`` pair. ``label_probs`` maps "≤ $50K" and
        "> $50K" to column 0 and column 1 of ``predict_proba`` (fed to
        ``gr.Label``). ``figure`` is a matplotlib SHAP bar plot of the six
        largest absolute contributions (fed to ``gr.Plot``).

    Raises:
        gr.Error: if Age or Hours per Week is empty. A cleared ``gr.Number``
            yields None, which the model cannot consume.
    """
    if age is None or hours_per_week is None:
        raise gr.Error("Please enter both Age and Hours per Week.")

    row_data = _encode_row(age, education, workclass, marital_status, race, sex, hours_per_week)
    # Explicit column order guarantees the layout the model was trained on.
    new_row = pd.DataFrame([row_data], columns=MODEL_COLUMNS)

    # Make prediction
    prob = loaded_model.predict_proba(new_row)

    # SHAP explanation (drawn onto the current pyplot figure)
    shap_values = explainer(new_row)
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Close the figure so pyplot does not accumulate open figures across
    # requests; the figure object itself is still handed to gr.Plot.
    plt.close(local_plot)

    return {"≤ $50K": float(prob[0][0]), "> $50K": float(prob[0][1])}, local_plot


# Gradio UI
title = "Income Predictor Group 2"
description1 = """This app takes demographic information to predict whether a household earns ≤ $50K or > $50K annually"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Image("house.png.jpg")

    with gr.Row():
        with gr.Column():
            age = gr.Number(label="Age", value=35)
            sex = gr.Radio(choices=["Male", "Female"], label="Sex", value="Female")
            hours_per_week = gr.Number(label="Hours per Week", value=40)

        with gr.Column():
            education = gr.Dropdown(
                choices=["Bachelors", "Masters", "HS-grad", "Some-college", "Doctorate"],
                label="Education",
                value="Bachelors",
            )
            workclass = gr.Dropdown(
                choices=["Private", "Self-emp-not-inc", "State-gov", "Local-gov",
                         "Never-worked", "Without-pay"],
                label="Workclass",
                value="Private",
            )
            marital_status = gr.Dropdown(
                choices=["Never-married", "Married-civ-spouse", "Divorced", "Separated",
                         "Widowed", "Married-AF-spouse", "Married-spouse-absent"],
                label="Marital Status",
                value="Never-married",
            )
            race = gr.Dropdown(
                choices=["White", "Black", "Asian-Pac-Islander", "Other"],
                label="Race",
                value="White",
            )
            submit_btn = gr.Button("Predict Income")

        with gr.Column(visible=True) as output_col:
            label = gr.Label(label="Predicted Income")
            local_plot = gr.Plot(label='SHAP Interpretation:')

    submit_btn.click(
        main_func,
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        api_name="Income_Predictor",
    )

    gr.Markdown("### Try these examples:")
    gr.Examples(
        [[39, 'Bachelors', 'Local-gov', 'Separated', 'Other', 'Male', 45],
         [52, 'Masters', "State-gov", 'Married-AF-spouse', 'White', 'Female', 34]],
        [age, education, workclass, marital_status, race, sex, hours_per_week],
        [label, local_plot],
        main_func,
        cache_examples=True,
    )

# Launch only when run as a script (e.g. `python app.py`), so importing the
# module -- such as Gradio's reload mode looking up `demo` -- does not start
# a server.
if __name__ == "__main__":
    demo.launch()