# eaglelandsonce's picture
# Create app.py
# 941a235
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from fairlearn.metrics import MetricFrame, selection_rate, demographic_parity_difference
import shap
import matplotlib.pyplot as plt
# -----------------------------
# Core training + metrics logic
# -----------------------------
def _resolve_upload_path(csv_file):
    """Return a path usable by ``pd.read_csv`` for an uploaded file.

    Plain strings and ``os.PathLike`` objects are returned unchanged; any
    other object is expected to expose the path as its ``.name`` attribute.
    """
    if isinstance(csv_file, str) or hasattr(csv_file, "__fspath__"):
        return csv_file
    return csv_file.name


def _pick_positive_class(class_labels):
    """Return ``(label, index)`` of the class treated as the selected outcome.

    The label ``1`` is used when present (matching the 0/1 convention);
    otherwise the last entry of ``class_labels`` is used.
    """
    labels = list(class_labels)
    label = 1 if 1 in labels else labels[-1]
    return label, labels.index(label)


def _class_shap_values(shap_values, class_index):
    """Return the 2-D SHAP matrix for one class.

    Handles a per-class list, a 3-D array indexed as
    ``[sample, feature, class]``, and an already 2-D array.
    """
    if isinstance(shap_values, list):
        return shap_values[class_index]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, class_index]
    return values


def train_and_evaluate(csv_file, target_col, sensitive_col, governance_threshold=0.10):
    """Train a random forest on an uploaded CSV and report fairness + SHAP.

    Args:
        csv_file: Uploaded file. Either a path (``str`` / ``os.PathLike``) or
            an object whose ``.name`` attribute is the path. ``None`` means
            nothing was uploaded.
        target_col: Name of the label column.
        sensitive_col: Name of the column used to group rows for the
            demographic parity metrics. It also stays in the feature set.
        governance_threshold: Largest demographic parity difference that is
            still reported as "Allowed".

    Returns:
        A 4-tuple ``(metrics_text, group_html, shap_plot_path, preview_html)``.
        When the input cannot be used, the first element is a message and the
        remaining three are ``None``.
    """
    if csv_file is None:
        return "Please upload a CSV.", None, None, None

    # Load data
    df = pd.read_csv(_resolve_upload_path(csv_file))

    # Basic validation
    if target_col not in df.columns:
        return f"Target column '{target_col}' not found in CSV.", None, None, None
    if sensitive_col not in df.columns:
        return f"Sensitive column '{sensitive_col}' not found in CSV.", None, None, None

    # Drop rows with missing target; nothing can be trained if none remain.
    df = df.dropna(subset=[target_col])
    if df.empty:
        return "No rows with a non-missing target value.", None, None, None

    # Separate features/target
    y = df[target_col]
    X = df.drop(columns=[target_col])

    # Keep a copy of sensitive feature before encoding
    sensitive_series = df[sensitive_col]

    # Identify numeric vs categorical. "number" covers every numeric width,
    # not just int64/float64, so e.g. int32 columns are not one-hot encoded.
    numeric_cols = X.select_dtypes(include="number").columns.tolist()
    categorical_cols = [c for c in X.columns if c not in numeric_cols]

    # Preprocess: numeric columns pass through, the rest are one-hot encoded.
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", "passthrough", numeric_cols),
            ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_cols),
        ]
    )

    # Model
    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("model", RandomForestClassifier(n_estimators=100, random_state=42)),
        ]
    )

    # Train/test split. Stratifying requires at least two rows per class, so
    # fall back to a plain split when some class is rarer than that.
    stratify_labels = y if y.value_counts().min() >= 2 else None
    X_train, X_test, y_train, y_test, _sens_train, sens_test = train_test_split(
        X,
        y,
        sensitive_series,
        test_size=0.3,
        random_state=42,
        stratify=stratify_labels,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # Standard accuracy on the original labels
    acc = accuracy_score(y_test, y_pred)

    # The fairness metrics count predictions equal to 1 as "selected".
    # Re-express labels as 1 = positive class, 0 = anything else, so that
    # non-0/1 labels (e.g. "yes"/"no") do not yield an all-zero selection rate.
    rf_model = clf.named_steps["model"]
    positive_label, positive_index = _pick_positive_class(rf_model.classes_.tolist())
    y_test_selected = (y_test == positive_label).astype(int)
    y_pred_selected = (np.asarray(y_pred) == positive_label).astype(int)

    # Fairlearn: selection rate per sensitive group
    mf = MetricFrame(
        metrics=selection_rate,
        y_true=y_test_selected,
        y_pred=y_pred_selected,
        sensitive_features=sens_test,
    )
    group_selection_rates = mf.by_group

    # Demographic parity difference
    dp_diff = demographic_parity_difference(
        y_true=y_test_selected,
        y_pred=y_pred_selected,
        sensitive_features=sens_test,
    )

    policy_status = (
        "Blocked: Demographic parity difference exceeds threshold."
        if abs(dp_diff) > governance_threshold
        else "Allowed: Within governance threshold."
    )

    # SHAP explanation on a small sample for speed
    X_test_sample = X_test.sample(min(200, len(X_test)), random_state=42)
    X_test_transformed = clf.named_steps["preprocessor"].transform(X_test_sample)
    # The column transformer may hand back a sparse matrix; densify it so the
    # explainer and the plot both receive a regular array.
    if hasattr(X_test_transformed, "toarray"):
        X_test_transformed = X_test_transformed.toarray()

    explainer = shap.TreeExplainer(rf_model)
    shap_values = explainer.shap_values(X_test_transformed)
    class_shap = _class_shap_values(shap_values, positive_index)

    # Feature names after preprocessing: numeric first, then one-hot
    # categories, matching the transformer order above.
    feature_names = list(numeric_cols)
    if categorical_cols:
        ohe = clf.named_steps["preprocessor"].named_transformers_["cat"]
        feature_names.extend(ohe.get_feature_names_out(categorical_cols).tolist())

    # Summary plot (global importance); always close the figure, even when
    # plotting fails, so figures do not accumulate across requests.
    shap_plot_path = "shap_summary.png"
    plt.figure(figsize=(8, 6))
    try:
        shap.summary_plot(
            class_shap,
            X_test_transformed,
            feature_names=feature_names,
            show=False,
        )
        plt.tight_layout()
        plt.savefig(shap_plot_path, dpi=120)
    finally:
        plt.close()

    # Build text outputs
    metrics_text = "\n".join(
        [
            f"Accuracy: {acc:.3f}",
            "",
            f"Positive (selected) class: {positive_label}",
            "Selection rate by sensitive group:",
            str(group_selection_rates),
            "",
            f"Demographic Parity Difference: {dp_diff:.3f}",
            f"Governance Threshold: {governance_threshold:.3f}",
            f"Policy Status: {policy_status}",
        ]
    )

    # Also return a small table of group metrics as HTML
    group_df = group_selection_rates.reset_index()
    group_df.columns = [sensitive_col, "selection_rate"]
    group_html = group_df.to_html(index=False)

    return metrics_text, group_html, shap_plot_path, df.head().to_html(index=False)
# -----------------------------
# Gradio interface
# -----------------------------
def get_columns(csv_file):
    """Populate the target and sensitive dropdowns from a CSV header.

    Args:
        csv_file: Uploaded file. Either a path (``str`` / ``os.PathLike``) or
            an object whose ``.name`` attribute is the path. ``None`` means
            the upload was cleared.

    Returns:
        Two ``gr.update`` payloads, for the target and the sensitive dropdown
        in that order. The target defaults to the last column and the
        sensitive attribute to the first. With no file, both dropdowns are
        emptied and their selection cleared.
    """
    if csv_file is None:
        # Clear the value too, otherwise a column name from the previous
        # file stays selected after the upload is removed.
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)

    if isinstance(csv_file, str) or hasattr(csv_file, "__fspath__"):
        csv_path = csv_file
    else:
        csv_path = csv_file.name

    # Only the header is needed, so skip parsing the data rows.
    cols = pd.read_csv(csv_path, nrows=0).columns.tolist()
    return (
        gr.update(choices=cols, value=cols[-1]),
        gr.update(choices=cols, value=cols[0]),
    )
# Page layout, top to bottom: heading, CSV upload, column pickers,
# train button, then the four result panels.
with gr.Blocks(title="AI Governance Lab - CSV + Fairness + SHAP") as demo:
    gr.Markdown("# 🧭 AI Governance Lab\nUpload a CSV, pick target and sensitive columns, train, and inspect fairness + SHAP.")

    # Row 1: file upload, restricted to .csv
    with gr.Row():
        csv_input = gr.File(file_types=[".csv"], label="Upload CSV")

    # Row 2: column pickers; they start empty and are filled on upload
    with gr.Row():
        target_dropdown = gr.Dropdown(
            choices=[],
            interactive=True,
            label="Target column (label)",
        )
        sensitive_dropdown = gr.Dropdown(
            choices=[],
            interactive=True,
            label="Sensitive attribute column (e.g., sex, race)",
        )

    # Refresh both dropdowns whenever the uploaded file changes.
    csv_input.change(
        inputs=csv_input,
        outputs=[target_dropdown, sensitive_dropdown],
        fn=get_columns,
    )

    run_button = gr.Button("Train & Evaluate")

    # Result panels, in the order train_and_evaluate returns its values.
    metrics_output = gr.Textbox(lines=12, label="Model & Fairness Metrics")
    group_table_output = gr.HTML(label="Group Selection Rates")
    shap_image_output = gr.Image(label="SHAP Summary Plot")
    preview_output = gr.HTML(label="Data Preview (first 5 rows)")

    # Train and evaluate on click, fanning the four results out to the panels.
    run_button.click(
        inputs=[csv_input, target_dropdown, sensitive_dropdown],
        outputs=[metrics_output, group_table_output, shap_image_output, preview_output],
        fn=train_and_evaluate,
    )

demo.launch()