# Page-header residue from the hosting page: "Spaces: Sleeping".
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from fairlearn.metrics import MetricFrame, selection_rate, demographic_parity_difference
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
# -----------------------------
# Core training + metrics logic
# -----------------------------
def _resolve_upload_path(csv_file):
    """Return the filesystem path behind a Gradio ``gr.File`` value.

    Older Gradio releases pass a tempfile-like object exposing ``.name``;
    newer releases may pass the path itself as a string. Both are accepted.
    """
    return getattr(csv_file, "name", csv_file)


def _build_pipeline(numeric_cols, categorical_cols):
    """Return an unfitted pipeline: numeric passthrough + one-hot categoricals -> random forest."""
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", "passthrough", numeric_cols),
            # Categories unseen during fit encode as all-zeros instead of raising.
            ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_cols),
        ]
    )
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    return Pipeline(steps=[("preprocessor", preprocessor), ("model", model)])


def _split_with_fallback(X, y, sensitive_series, test_size=0.3, random_state=42):
    """Split features, labels and sensitive values (70/30 by default), stratified on ``y`` when possible.

    Returns ``(X_train, X_test, y_train, y_test, sens_train, sens_test)``.
    """
    try:
        return train_test_split(
            X, y, sensitive_series,
            test_size=test_size, random_state=random_state, stratify=y,
        )
    except ValueError:
        # Stratification raises when a class is too rare to appear in both
        # splits; fall back to a plain random split instead of failing.
        return train_test_split(
            X, y, sensitive_series, test_size=test_size, random_state=random_state
        )


def _select_class_shap(shap_values, class_index):
    """Pick the SHAP matrix for one class out of ``TreeExplainer.shap_values`` output.

    Depending on the SHAP version, classifier output is either a list with one
    (n_samples, n_features) array per class or a single
    (n_samples, n_features, n_classes) array; both are handled.
    """
    if isinstance(shap_values, list):
        return shap_values[class_index]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, class_index]
    return values


def _save_shap_summary(clf, X_test, numeric_cols, categorical_cols, class_index):
    """Render a SHAP summary plot for up to 200 test rows and return the PNG path.

    The PNG goes to a fresh temporary file so concurrent requests do not
    overwrite each other's plots.
    """
    # A small sample keeps TreeExplainer fast.
    X_test_sample = X_test.sample(min(200, len(X_test)), random_state=42)
    preprocessor = clf.named_steps["preprocessor"]
    X_test_transformed = preprocessor.transform(X_test_sample)
    # ColumnTransformer may return a sparse matrix when one-hot output dominates;
    # densify so SHAP and its plotting code receive a plain array.
    if hasattr(X_test_transformed, "toarray"):
        X_test_transformed = X_test_transformed.toarray()

    explainer = shap.TreeExplainer(clf.named_steps["model"])
    shap_values = explainer.shap_values(X_test_transformed)

    # Column order of the transformed matrix: numeric passthrough, then one-hot.
    feature_names = list(numeric_cols)
    if categorical_cols:
        ohe = preprocessor.named_transformers_["cat"]
        feature_names.extend(ohe.get_feature_names_out(categorical_cols).tolist())

    with tempfile.NamedTemporaryFile(prefix="shap_summary_", suffix=".png", delete=False) as tmp:
        shap_plot_path = tmp.name

    fig = plt.figure(figsize=(8, 6))
    try:
        shap.summary_plot(
            _select_class_shap(shap_values, class_index),
            X_test_transformed,
            feature_names=feature_names,
            show=False,
        )
        plt.tight_layout()
        plt.savefig(shap_plot_path, dpi=120)
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    return shap_plot_path


def train_and_evaluate(csv_file, target_col, sensitive_col, governance_threshold=0.10):
    """Train a random-forest classifier on an uploaded CSV and audit its fairness.

    The model is a ``RandomForestClassifier(n_estimators=100, random_state=42)``
    behind one-hot encoding of non-numeric columns. It is fitted on a 70/30
    train/test split. On the test split the function computes accuracy,
    per-group selection rates and the demographic parity difference with
    respect to ``sensitive_col``, plus a SHAP summary plot.

    Args:
        csv_file: Gradio upload value (object with ``.name`` or a path string), or None.
        target_col: Name of the label column; must contain exactly two classes.
        sensitive_col: Name of the sensitive-attribute column used for grouping.
        governance_threshold: Maximum allowed ``|demographic parity difference|``
            before the policy status becomes "Blocked".

    Returns:
        ``(metrics_text, group_html, shap_plot_path, preview_html)``. On invalid
        input the first element is an explanatory message and the rest are None.
    """
    if csv_file is None:
        return "Please upload a CSV.", None, None, None

    df = pd.read_csv(_resolve_upload_path(csv_file))

    if target_col not in df.columns:
        return f"Target column '{target_col}' not found in CSV.", None, None, None
    if sensitive_col not in df.columns:
        return f"Sensitive column '{sensitive_col}' not found in CSV.", None, None, None
    if target_col == sensitive_col:
        return "Target and sensitive columns must be different.", None, None, None

    # Rows without a label cannot be used for training or evaluation.
    df = df.dropna(subset=[target_col])
    if df.empty:
        return f"Target column '{target_col}' has no non-missing values.", None, None, None
    n_classes = df[target_col].nunique()
    if n_classes != 2:
        return (
            f"Demographic parity needs a binary target; '{target_col}' has {n_classes} distinct values.",
            None,
            None,
            None,
        )

    y = df[target_col]
    # NOTE(review): the sensitive column stays in X, so the model may use it
    # directly; drop it here if that is not intended.
    X = df.drop(columns=[target_col])
    sensitive_series = df[sensitive_col]

    # Every numeric dtype (int32, float32, ...) passes through; the rest is one-hot encoded.
    numeric_cols = X.select_dtypes(include="number").columns.tolist()
    categorical_cols = [c for c in X.columns if c not in numeric_cols]

    clf = _build_pipeline(numeric_cols, categorical_cols)
    X_train, X_test, y_train, y_test, _, sens_test = _split_with_fallback(X, y, sensitive_series)
    clf.fit(X_train, y_train)
    if len(clf.classes_) != 2:
        return "The training split contains only one class; upload more data.", None, None, None

    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    # sklearn sorts classes_, and SHAP's class index 1 refers to classes_[1];
    # use that same class as the "selected"/positive outcome for fairness metrics.
    positive_label = clf.classes_[1]
    # fairlearn's selection_rate counts predictions equal to pos_label (1 by
    # default). Encode the positive class as 1 so string or non-0/1 labels do
    # not silently produce all-zero rates (and a spurious "Allowed" status).
    y_true_bin = (np.asarray(y_test) == positive_label).astype(int)
    y_pred_bin = (np.asarray(y_pred) == positive_label).astype(int)

    mf = MetricFrame(
        metrics=selection_rate,
        y_true=y_true_bin,
        y_pred=y_pred_bin,
        sensitive_features=sens_test,
    )
    group_selection_rates = mf.by_group
    dp_diff = demographic_parity_difference(
        y_true=y_true_bin,
        y_pred=y_pred_bin,
        sensitive_features=sens_test,
    )

    policy_status = (
        "Blocked: Demographic parity difference exceeds threshold."
        if abs(dp_diff) > governance_threshold
        else "Allowed: Within governance threshold."
    )

    shap_plot_path = _save_shap_summary(
        clf, X_test, numeric_cols, categorical_cols, class_index=1
    )

    metrics_text = "\n".join(
        [
            f"Accuracy: {acc:.3f}",
            f"Positive class: {positive_label}",
            "",
            "Selection rate by sensitive group:",
            str(group_selection_rates),
            "",
            f"Demographic Parity Difference: {dp_diff:.3f}",
            f"Governance Threshold: {governance_threshold:.3f}",
            f"Policy Status: {policy_status}",
        ]
    )

    group_df = group_selection_rates.reset_index()
    group_df.columns = [sensitive_col, "selection_rate"]
    group_html = group_df.to_html(index=False)

    return metrics_text, group_html, shap_plot_path, df.head().to_html(index=False)
# -----------------------------
# Gradio interface
# -----------------------------
def get_columns(csv_file):
    """Populate the target and sensitive-attribute dropdowns from a CSV header.

    Returns two ``gr.update`` objects: (target dropdown, sensitive dropdown).
    The target defaults to the last column and the sensitive attribute to the
    first. Both dropdowns are cleared (choices and value) when no file is
    uploaded or the file is empty.
    """
    if csv_file is None:
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)
    # Older Gradio hands over a tempfile-like object with ``.name``; newer
    # releases may pass the path string directly.
    path = getattr(csv_file, "name", csv_file)
    try:
        # Only the header row is needed to list column names.
        cols = pd.read_csv(path, nrows=0).columns.tolist()
    except pd.errors.EmptyDataError:
        cols = []
    if not cols:
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)
    return gr.update(choices=cols, value=cols[-1]), gr.update(choices=cols, value=cols[0])
# Build the UI at module level so `demo` is available to importers; the
# server itself only starts when this file is run as a script (see bottom).
with gr.Blocks(title="AI Governance Lab - CSV + Fairness + SHAP") as demo:
    gr.Markdown(
        "# 🧭 AI Governance Lab\n"
        "Upload a CSV, pick target and sensitive columns, train, and inspect fairness + SHAP."
    )

    with gr.Row():
        csv_input = gr.File(label="Upload CSV", file_types=[".csv"])

    with gr.Row():
        target_dropdown = gr.Dropdown(
            label="Target column (label)",
            choices=[],
            interactive=True,
        )
        sensitive_dropdown = gr.Dropdown(
            label="Sensitive attribute column (e.g., sex, race)",
            choices=[],
            interactive=True,
        )

    # Refresh both dropdowns whenever a new CSV is uploaded (or cleared).
    csv_input.change(
        fn=get_columns,
        inputs=csv_input,
        outputs=[target_dropdown, sensitive_dropdown],
    )

    run_button = gr.Button("Train & Evaluate")

    metrics_output = gr.Textbox(
        label="Model & Fairness Metrics",
        lines=12,
    )
    group_table_output = gr.HTML(label="Group Selection Rates")
    shap_image_output = gr.Image(label="SHAP Summary Plot")
    preview_output = gr.HTML(label="Data Preview (first 5 rows)")

    # Output order must match train_and_evaluate's 4-tuple return value.
    run_button.click(
        fn=train_and_evaluate,
        inputs=[csv_input, target_dropdown, sensitive_dropdown],
        outputs=[metrics_output, group_table_output, shap_image_output, preview_output],
    )

if __name__ == "__main__":
    # Only start the server when executed directly, not on import.
    demo.launch()