import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from fairlearn.metrics import MetricFrame, demographic_parity_difference, selection_rate
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


# -----------------------------
# Core training + metrics logic
# -----------------------------
def _csv_path(csv_file):
    """Return a path usable by ``pd.read_csv`` for an uploaded file.

    Accepts either an object exposing a ``name`` attribute (what the original
    code relied on) or a plain path string, which is passed through unchanged.
    """
    return getattr(csv_file, "name", csv_file)


def _select_class_shap_values(shap_values):
    """Reduce SHAP output to a 2-D (samples, features) array for plotting.

    Handles both return shapes seen from ``TreeExplainer.shap_values``:
    a list with one array per class, or a single array that is 3-D with the
    class axis last. Class index 1 is used when it exists (matching the
    original ``shap_values[1]``); otherwise the only available class is used.
    """
    if isinstance(shap_values, list):
        return shap_values[1] if len(shap_values) > 1 else shap_values[0]

    values = np.asarray(shap_values)
    if values.ndim == 3:
        class_idx = 1 if values.shape[2] > 1 else 0
        return values[:, :, class_idx]
    return values


def train_and_evaluate(csv_file, target_col, sensitive_col, governance_threshold=0.10):
    """Train a random forest on an uploaded CSV and report accuracy, fairness, and SHAP.

    Args:
        csv_file: Uploaded file handle (object with ``.name``) or a path string.
            ``None`` yields a prompt to upload a file.
        target_col: Name of the label column.
        sensitive_col: Name of the column used as the sensitive feature for
            the Fairlearn metrics.
        governance_threshold: Maximum allowed demographic parity difference
            before the policy status is reported as blocked. Defaults to 0.10.

    Returns:
        A 4-tuple ``(metrics_text, group_html, shap_plot_path, preview_html)``.
        When validation fails, the first element is the error message and the
        remaining three are ``None``.
    """
    if csv_file is None:
        return "Please upload a CSV.", None, None, None

    # Load data
    df = pd.read_csv(_csv_path(csv_file))

    # Basic validation
    if target_col not in df.columns:
        return f"Target column '{target_col}' not found in CSV.", None, None, None
    if sensitive_col not in df.columns:
        return f"Sensitive column '{sensitive_col}' not found in CSV.", None, None, None

    # Drop rows with missing target
    df = df.dropna(subset=[target_col])

    # Separate features/target
    y = df[target_col]
    X = df.drop(columns=[target_col])

    # Keep a copy of the sensitive feature before encoding
    sensitive_series = df[sensitive_col]

    # Identify numeric vs categorical. "number" covers every numeric dtype,
    # not only int64/float64, so e.g. int32 columns are not one-hot encoded.
    numeric_cols = X.select_dtypes(include="number").columns.tolist()
    numeric_set = set(numeric_cols)
    categorical_cols = [c for c in X.columns if c not in numeric_set]

    # Preprocess: numeric columns pass through, everything else is one-hot encoded
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", "passthrough", numeric_cols),
            ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_cols),
        ]
    )

    # Model
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("model", model),
        ]
    )

    # Train/test split; the sensitive feature is split alongside X and y so
    # the held-out rows stay aligned with their group labels.
    X_train, X_test, y_train, y_test, sens_train, sens_test = train_test_split(
        X, y, sensitive_series, test_size=0.3, random_state=42, stratify=y
    )

    # Fit
    clf.fit(X_train, y_train)

    # Predictions
    y_pred = clf.predict(X_test)

    # -----------------
    # Standard accuracy
    # -----------------
    acc = accuracy_score(y_test, y_pred)

    # -----------------------------
    # Fairlearn: Demographic Parity
    # -----------------------------
    mf = MetricFrame(
        metrics=selection_rate,
        y_true=y_test,
        y_pred=y_pred,
        sensitive_features=sens_test,
    )

    # Selection rate per sensitive group
    group_selection_rates = mf.by_group

    # Demographic parity difference
    dp_diff = demographic_parity_difference(
        y_true=y_test,
        y_pred=y_pred,
        sensitive_features=sens_test,
    )

    # Governance threshold check
    policy_status = (
        "Blocked: Demographic parity difference exceeds threshold."
        if abs(dp_diff) > governance_threshold
        else "Allowed: Within governance threshold."
    )

    # ----------------
    # SHAP explanation
    # ----------------
    # A small sample keeps the explanation fast.
    X_test_sample = X_test.sample(min(200, len(X_test)), random_state=42)

    # Reuse the fitted preprocessor to obtain the numeric matrix the model saw.
    fitted_preprocessor = clf.named_steps["preprocessor"]
    X_test_transformed = fitted_preprocessor.transform(X_test_sample)
    # The transformer may hand back a scipy sparse matrix; SHAP and the
    # summary plot are given a dense array instead (at most 200 rows).
    if hasattr(X_test_transformed, "toarray"):
        X_test_transformed = X_test_transformed.toarray()

    rf_model = clf.named_steps["model"]

    # SHAP for tree models
    explainer = shap.TreeExplainer(rf_model)
    shap_values = explainer.shap_values(X_test_transformed)

    # Feature names after preprocessing: numeric columns first, then the
    # one-hot categories, matching the transformer order above.
    feature_names = list(numeric_cols)
    if categorical_cols:
        ohe = fitted_preprocessor.named_transformers_["cat"]
        feature_names.extend(ohe.get_feature_names_out(categorical_cols).tolist())

    # Summary plot (global importance), written to a per-call temp file so
    # simultaneous requests do not overwrite one another's image.
    plt.figure(figsize=(8, 6))
    try:
        shap.summary_plot(
            _select_class_shap_values(shap_values),
            X_test_transformed,
            feature_names=feature_names,
            show=False,
        )
        plt.tight_layout()
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            shap_plot_path = tmp.name
        plt.savefig(shap_plot_path, dpi=120)
    finally:
        # Always release the figure, even when plotting fails.
        plt.close()

    # ------------------
    # Build text outputs
    # ------------------
    metrics_text = "\n".join(
        [
            f"Accuracy: {acc:.3f}",
            "",
            "Selection rate by sensitive group:",
            str(group_selection_rates),
            "",
            f"Demographic Parity Difference: {dp_diff:.3f}",
            f"Governance Threshold: {governance_threshold:.3f}",
            f"Policy Status: {policy_status}",
        ]
    )

    # Also return a small table of group metrics as HTML
    group_df = group_selection_rates.reset_index()
    group_df.columns = [sensitive_col, "selection_rate"]
    group_html = group_df.to_html(index=False)

    return metrics_text, group_html, shap_plot_path, df.head().to_html(index=False)


# -----------------------------
# Gradio interface
# -----------------------------
def get_columns(csv_file):
    """Populate the two column dropdowns from the uploaded CSV's header.

    Returns a pair of ``gr.update`` objects: the first preselects the last
    column (target), the second preselects the first column (sensitive
    attribute). With no file, or a file without columns, both are emptied.
    """
    if csv_file is None:
        return gr.update(choices=[]), gr.update(choices=[])

    cols = pd.read_csv(_csv_path(csv_file)).columns.tolist()
    if not cols:
        # Nothing to preselect; avoids indexing into an empty list.
        return gr.update(choices=[]), gr.update(choices=[])

    return gr.update(choices=cols, value=cols[-1]), gr.update(choices=cols, value=cols[0])


with gr.Blocks(title="AI Governance Lab - CSV + Fairness + SHAP") as demo:
    gr.Markdown("# 🧭 AI Governance Lab\nUpload a CSV, pick target and sensitive columns, train, and inspect fairness + SHAP.")

    with gr.Row():
        csv_input = gr.File(label="Upload CSV", file_types=[".csv"])

    with gr.Row():
        target_dropdown = gr.Dropdown(
            label="Target column (label)",
            choices=[],
            interactive=True,
        )
        sensitive_dropdown = gr.Dropdown(
            label="Sensitive attribute column (e.g., sex, race)",
            choices=[],
            interactive=True,
        )

    # Refresh both dropdowns whenever a new file is uploaded or cleared.
    csv_input.change(
        fn=get_columns,
        inputs=csv_input,
        outputs=[target_dropdown, sensitive_dropdown],
    )

    run_button = gr.Button("Train & Evaluate")

    metrics_output = gr.Textbox(
        label="Model & Fairness Metrics",
        lines=12,
    )
    group_table_output = gr.HTML(label="Group Selection Rates")
    shap_image_output = gr.Image(label="SHAP Summary Plot")
    preview_output = gr.HTML(label="Data Preview (first 5 rows)")

    run_button.click(
        fn=train_and_evaluate,
        inputs=[csv_input, target_dropdown, sensitive_dropdown],
        outputs=[metrics_output, group_table_output, shap_image_output, preview_output],
    )

demo.launch()