Spaces:
Sleeping
Sleeping
File size: 7,134 Bytes
941a235 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 | import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from fairlearn.metrics import MetricFrame, selection_rate, demographic_parity_difference
import shap
import matplotlib.pyplot as plt
# -----------------------------
# Core training + metrics logic
# -----------------------------
def _build_pipeline(numeric_cols, categorical_cols):
    """Build the unfitted preprocessing + random-forest pipeline.

    Numeric columns pass through unchanged; every other column is one-hot
    encoded, ignoring categories that only appear at predict time.
    """
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", "passthrough", numeric_cols),
            ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_cols),
        ]
    )
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    return Pipeline(steps=[("preprocessor", preprocessor), ("model", model)])


def _positive_label(classes, pos_label=None):
    """Choose the label treated as the favourable ("selected") prediction.

    An explicit ``pos_label`` wins. Otherwise ``1`` is used when it is one of
    ``classes`` (fairlearn's own default). Failing that, the last entry of
    ``classes`` is used; for a fitted sklearn classifier ``classes_`` is
    sorted, so this is e.g. ``"yes"`` for ``["no", "yes"]``.
    """
    if pos_label is not None:
        return pos_label
    labels = list(classes)
    return 1 if 1 in labels else labels[-1]


def _selection_metrics(y_true, y_pred, sensitive, pos_label):
    """Return ``(selection rate per group, demographic parity difference)``.

    fairlearn's ``selection_rate`` / ``demographic_parity_difference`` count
    predictions equal to ``1``. Labels are therefore binarised against
    ``pos_label`` first. Without this, string targets such as "yes"/"no"
    would give every group a selection rate of 0 and a parity gap of 0.
    """
    y_true_bin = (np.asarray(y_true) == pos_label).astype(int)
    y_pred_bin = (np.asarray(y_pred) == pos_label).astype(int)
    frame = MetricFrame(
        metrics=selection_rate,
        y_true=y_true_bin,
        y_pred=y_pred_bin,
        sensitive_features=sensitive,
    )
    dp_diff = demographic_parity_difference(
        y_true=y_true_bin,
        y_pred=y_pred_bin,
        sensitive_features=sensitive,
    )
    return frame.by_group, dp_diff


def _shap_summary_png(clf, X_sample, numeric_cols, categorical_cols, class_idx):
    """Render a SHAP summary plot for ``X_sample`` into a new temp PNG; return its path.

    ``class_idx`` is the position in ``clf.classes_`` of the class whose
    SHAP values are plotted.
    """
    import tempfile

    preprocessor = clf.named_steps["preprocessor"]
    X_transformed = preprocessor.transform(X_sample)
    # ColumnTransformer can return a scipy sparse matrix when the one-hot
    # output is mostly zeros; hand SHAP a plain dense array instead.
    if hasattr(X_transformed, "toarray"):
        X_transformed = X_transformed.toarray()

    explainer = shap.TreeExplainer(clf.named_steps["model"])
    shap_values = explainer.shap_values(X_transformed)
    # Older SHAP releases return one array per class, newer ones a single
    # (n_samples, n_features, n_classes) array: select the class either way.
    if isinstance(shap_values, list):
        class_values = shap_values[class_idx]
    else:
        class_values = np.asarray(shap_values)
        if class_values.ndim == 3:
            class_values = class_values[:, :, class_idx]

    # Column order after preprocessing: numeric passthrough, then one-hot.
    feature_names = list(numeric_cols)
    if categorical_cols:
        ohe = preprocessor.named_transformers_["cat"]
        feature_names.extend(ohe.get_feature_names_out(categorical_cols).tolist())

    # Unique file per call so simultaneous sessions cannot overwrite each
    # other's plot; the handle is closed before matplotlib writes to it.
    with tempfile.NamedTemporaryFile(
        prefix="shap_summary_", suffix=".png", delete=False
    ) as tmp:
        plot_path = tmp.name

    plt.figure(figsize=(8, 6))
    try:
        shap.summary_plot(
            class_values,
            X_transformed,
            feature_names=feature_names,
            show=False,
        )
        plt.tight_layout()
        plt.savefig(plot_path, dpi=120)
    finally:
        plt.close()
    return plot_path


def train_and_evaluate(
    csv_file,
    target_col,
    sensitive_col,
    *,
    governance_threshold=0.10,
    pos_label=None,
):
    """Train a random forest on an uploaded CSV and report accuracy, fairness and SHAP.

    Parameters
    ----------
    csv_file:
        Upload from ``gr.File`` (an object with a ``.name`` path) or a plain
        path string; ``None`` when nothing has been uploaded.
    target_col, sensitive_col:
        Names of the label column and of the sensitive attribute column.
    governance_threshold:
        Largest allowed absolute demographic parity difference before the
        policy status reports "Blocked".
    pos_label:
        Label treated as the favourable ("selected") prediction. Defaults to
        ``1`` when present among the labels, else the last sorted class label.

    Returns
    -------
    tuple
        ``(metrics_text, group_html, shap_plot_path, preview_html)``. On a
        validation failure the first item is an error message and the other
        three are ``None``.
    """
    if csv_file is None:
        return "Please upload a CSV.", None, None, None
    # Gradio passes an object exposing ``.name``; accept plain paths as well.
    df = pd.read_csv(getattr(csv_file, "name", csv_file))

    if target_col not in df.columns:
        return f"Target column '{target_col}' not found in CSV.", None, None, None
    if sensitive_col not in df.columns:
        return f"Sensitive column '{sensitive_col}' not found in CSV.", None, None, None
    if target_col == sensitive_col:
        return "Target and sensitive columns must be different.", None, None, None

    df = df.dropna(subset=[target_col])
    if df.empty:
        return "No rows with a non-missing target value.", None, None, None

    y = df[target_col]
    # The sensitive attribute stays in X as a model feature; the separate copy
    # is only used to group test predictions for the fairness metrics.
    X = df.drop(columns=[target_col])
    sensitive_series = df[sensitive_col]

    numeric_cols = X.select_dtypes(include="number").columns.tolist()
    categorical_cols = [c for c in X.columns if c not in numeric_cols]
    clf = _build_pipeline(numeric_cols, categorical_cols)

    # Stratification needs at least two rows per class; fall back to a plain
    # random split rather than letting train_test_split raise.
    stratify = y if y.value_counts().min() >= 2 else None
    X_train, X_test, y_train, y_test, _sens_train, sens_test = train_test_split(
        X, y, sensitive_series, test_size=0.3, random_state=42, stratify=stratify
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    classes = list(clf.classes_)
    positive = _positive_label(classes, pos_label)
    if positive not in classes:
        return (
            f"Positive label {positive!r} not found among target values {classes}.",
            None,
            None,
            None,
        )

    group_selection_rates, dp_diff = _selection_metrics(
        y_test, y_pred, sens_test, positive
    )
    policy_status = (
        "Blocked: Demographic parity difference exceeds threshold."
        if abs(dp_diff) > governance_threshold
        else "Allowed: Within governance threshold."
    )

    # Explain at most 200 test rows to keep SHAP fast.
    X_test_sample = X_test.sample(min(200, len(X_test)), random_state=42)
    shap_plot_path = _shap_summary_png(
        clf, X_test_sample, numeric_cols, categorical_cols, classes.index(positive)
    )

    metrics_text = "\n".join(
        [
            f"Accuracy: {acc:.3f}",
            f"Positive label: {positive}",
            "",
            "Selection rate by sensitive group:",
            str(group_selection_rates),
            "",
            f"Demographic Parity Difference: {dp_diff:.3f}",
            f"Governance Threshold: {governance_threshold:.3f}",
            f"Policy Status: {policy_status}",
        ]
    )

    # Same per-group rates as a small HTML table.
    group_df = group_selection_rates.reset_index()
    group_df.columns = [sensitive_col, "selection_rate"]
    group_html = group_df.to_html(index=False)
    return metrics_text, group_html, shap_plot_path, df.head().to_html(index=False)
# -----------------------------
# Gradio interface
# -----------------------------
def get_columns(csv_file):
    """Fill the target/sensitive dropdowns from an uploaded CSV's header row.

    The target defaults to the last column and the sensitive attribute to
    the first. When ``csv_file`` is ``None`` (e.g. the upload was cleared),
    both dropdowns get empty choices and their stale values are reset.
    """
    if csv_file is None:
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)
    # Only column names are needed, so parse the header and skip the rows.
    # Gradio passes an object exposing ``.name``; accept plain paths as well.
    cols = pd.read_csv(getattr(csv_file, "name", csv_file), nrows=0).columns.tolist()
    return gr.update(choices=cols, value=cols[-1]), gr.update(choices=cols, value=cols[0])
# Layout: CSV upload -> column pickers -> train button -> metrics, group table,
# SHAP plot and data preview. ``demo`` stays module-level so tooling that
# imports this file can find it.
with gr.Blocks(title="AI Governance Lab - CSV + Fairness + SHAP") as demo:
    gr.Markdown("# 🧭 AI Governance Lab\nUpload a CSV, pick target and sensitive columns, train, and inspect fairness + SHAP.")

    with gr.Row():
        csv_input = gr.File(label="Upload CSV", file_types=[".csv"])

    with gr.Row():
        target_dropdown = gr.Dropdown(
            label="Target column (label)",
            choices=[],
            interactive=True,
        )
        sensitive_dropdown = gr.Dropdown(
            label="Sensitive attribute column (e.g., sex, race)",
            choices=[],
            interactive=True,
        )

    # Re-populate both dropdowns whenever a file is uploaded or cleared.
    csv_input.change(
        fn=get_columns,
        inputs=csv_input,
        outputs=[target_dropdown, sensitive_dropdown],
    )

    run_button = gr.Button("Train & Evaluate")

    metrics_output = gr.Textbox(
        label="Model & Fairness Metrics",
        lines=12,
    )
    group_table_output = gr.HTML(label="Group Selection Rates")
    shap_image_output = gr.Image(label="SHAP Summary Plot")
    preview_output = gr.HTML(label="Data Preview (first 5 rows)")

    # Output order must match train_and_evaluate's return tuple.
    run_button.click(
        fn=train_and_evaluate,
        inputs=[csv_input, target_dropdown, sensitive_dropdown],
        outputs=[metrics_output, group_table_output, shap_image_output, preview_output],
    )

# Launch only when run as a script, not as a side effect of importing the module.
if __name__ == "__main__":
    demo.launch()