eaglelandsonce commited on
Commit
941a235
·
1 Parent(s): 1846e5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.preprocessing import OneHotEncoder
7
+ from sklearn.compose import ColumnTransformer
8
+ from sklearn.pipeline import Pipeline
9
+ from sklearn.metrics import accuracy_score
10
+
11
+ from sklearn.ensemble import RandomForestClassifier
12
+
13
+ from fairlearn.metrics import MetricFrame, selection_rate, demographic_parity_difference
14
+ import shap
15
+ import matplotlib.pyplot as plt
16
+
17
+ # -----------------------------
18
+ # Core training + metrics logic
19
+ # -----------------------------
20
+ def train_and_evaluate(csv_file, target_col, sensitive_col):
21
+ if csv_file is None:
22
+ return "Please upload a CSV.", None, None, None
23
+
24
+ # Load data
25
+ df = pd.read_csv(csv_file.name)
26
+
27
+ # Basic validation
28
+ if target_col not in df.columns:
29
+ return f"Target column '{target_col}' not found in CSV.", None, None, None
30
+ if sensitive_col not in df.columns:
31
+ return f"Sensitive column '{sensitive_col}' not found in CSV.", None, None, None
32
+
33
+ # Drop rows with missing target
34
+ df = df.dropna(subset=[target_col])
35
+
36
+ # Separate features/target
37
+ y = df[target_col]
38
+ X = df.drop(columns=[target_col])
39
+
40
+ # Keep a copy of sensitive feature before encoding
41
+ sensitive_series = df[sensitive_col]
42
+
43
+ # Identify numeric vs categorical
44
+ numeric_cols = X.select_dtypes(include=["int64", "float64"]).columns.tolist()
45
+ categorical_cols = [c for c in X.columns if c not in numeric_cols]
46
+
47
+ # Preprocess
48
+ numeric_transformer = "passthrough"
49
+ categorical_transformer = OneHotEncoder(handle_unknown="ignore")
50
+
51
+ preprocessor = ColumnTransformer(
52
+ transformers=[
53
+ ("num", numeric_transformer, numeric_cols),
54
+ ("cat", categorical_transformer, categorical_cols),
55
+ ]
56
+ )
57
+
58
+ # Model
59
+ model = RandomForestClassifier(
60
+ n_estimators=100,
61
+ random_state=42
62
+ )
63
+
64
+ clf = Pipeline(
65
+ steps=[
66
+ ("preprocessor", preprocessor),
67
+ ("model", model),
68
+ ]
69
+ )
70
+
71
+ # Train/test split
72
+ X_train, X_test, y_train, y_test, sens_train, sens_test = train_test_split(
73
+ X, y, sensitive_series, test_size=0.3, random_state=42, stratify=y
74
+ )
75
+
76
+ # Fit
77
+ clf.fit(X_train, y_train)
78
+
79
+ # Predictions
80
+ y_pred = clf.predict(X_test)
81
+
82
+ # -----------------
83
+ # Standard accuracy
84
+ # -----------------
85
+ acc = accuracy_score(y_test, y_pred)
86
+
87
+ # -------------------------
88
+ # Fairlearn: Demographic Parity
89
+ # -------------------------
90
+ # selection_rate expects y_pred and sensitive features
91
+ mf = MetricFrame(
92
+ metrics=selection_rate,
93
+ y_true=y_test,
94
+ y_pred=y_pred,
95
+ sensitive_features=sens_test
96
+ )
97
+
98
+ # Overall selection rate by group
99
+ group_selection_rates = mf.by_group
100
+
101
+ # Demographic parity difference
102
+ dp_diff = demographic_parity_difference(
103
+ y_true=y_test,
104
+ y_pred=y_pred,
105
+ sensitive_features=sens_test
106
+ )
107
+
108
+ # Governance threshold example
109
+ governance_threshold = 0.10
110
+ policy_status = (
111
+ "Blocked: Demographic parity difference exceeds threshold."
112
+ if abs(dp_diff) > governance_threshold
113
+ else "Allowed: Within governance threshold."
114
+ )
115
+
116
+ # -----------------
117
+ # SHAP explanation
118
+ # -----------------
119
+ # Extract trained model and transformed data for SHAP
120
+ # We use a small sample for speed
121
+ X_test_sample = X_test.sample(min(200, len(X_test)), random_state=42)
122
+
123
+ # Fit a separate preprocessing-only transform to get numeric matrix
124
+ X_test_transformed = clf.named_steps["preprocessor"].transform(X_test_sample)
125
+ rf_model = clf.named_steps["model"]
126
+
127
+ # SHAP for tree models
128
+ explainer = shap.TreeExplainer(rf_model)
129
+ shap_values = explainer.shap_values(X_test_transformed)
130
+
131
+ # Get feature names after preprocessing
132
+ # numeric + one-hot categories
133
+ feature_names = []
134
+ feature_names.extend(numeric_cols)
135
+
136
+ if categorical_cols:
137
+ ohe = clf.named_steps["preprocessor"].named_transformers_["cat"]
138
+ ohe_feature_names = ohe.get_feature_names_out(categorical_cols).tolist()
139
+ feature_names.extend(ohe_feature_names)
140
+
141
+ # Summary plot (global importance)
142
+ plt.figure(figsize=(8, 6))
143
+ shap.summary_plot(
144
+ shap_values[1] if isinstance(shap_values, list) else shap_values,
145
+ X_test_transformed,
146
+ feature_names=feature_names,
147
+ show=False
148
+ )
149
+ plt.tight_layout()
150
+ shap_plot_path = "shap_summary.png"
151
+ plt.savefig(shap_plot_path, dpi=120)
152
+ plt.close()
153
+
154
+ # -----------------
155
+ # Build text outputs
156
+ # -----------------
157
+ metrics_text = []
158
+ metrics_text.append(f"Accuracy: {acc:.3f}")
159
+ metrics_text.append("")
160
+ metrics_text.append("Selection rate by sensitive group:")
161
+ metrics_text.append(str(group_selection_rates))
162
+ metrics_text.append("")
163
+ metrics_text.append(f"Demographic Parity Difference: {dp_diff:.3f}")
164
+ metrics_text.append(f"Governance Threshold: {governance_threshold:.3f}")
165
+ metrics_text.append(f"Policy Status: {policy_status}")
166
+
167
+ metrics_text = "\n".join(metrics_text)
168
+
169
+ # Also return a small table of group metrics as HTML
170
+ group_df = group_selection_rates.reset_index()
171
+ group_df.columns = [sensitive_col, "selection_rate"]
172
+ group_html = group_df.to_html(index=False)
173
+
174
+ return metrics_text, group_html, shap_plot_path, df.head().to_html(index=False)
175
+
176
+
177
+ # -----------------------------
178
+ # Gradio interface
179
+ # -----------------------------
180
+ def get_columns(csv_file):
181
+ if csv_file is None:
182
+ return gr.update(choices=[]), gr.update(choices=[])
183
+ df = pd.read_csv(csv_file.name)
184
+ cols = df.columns.tolist()
185
+ return gr.update(choices=cols, value=cols[-1]), gr.update(choices=cols, value=cols[0])
186
+
187
+
188
+ with gr.Blocks(title="AI Governance Lab - CSV + Fairness + SHAP") as demo:
189
+ gr.Markdown("# 🧭 AI Governance Lab\nUpload a CSV, pick target and sensitive columns, train, and inspect fairness + SHAP.")
190
+
191
+ with gr.Row():
192
+ csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
193
+
194
+ with gr.Row():
195
+ target_dropdown = gr.Dropdown(
196
+ label="Target column (label)",
197
+ choices=[],
198
+ interactive=True
199
+ )
200
+ sensitive_dropdown = gr.Dropdown(
201
+ label="Sensitive attribute column (e.g., sex, race)",
202
+ choices=[],
203
+ interactive=True
204
+ )
205
+
206
+ csv_input.change(
207
+ fn=get_columns,
208
+ inputs=csv_input,
209
+ outputs=[target_dropdown, sensitive_dropdown]
210
+ )
211
+
212
+ run_button = gr.Button("Train & Evaluate")
213
+
214
+ metrics_output = gr.Textbox(
215
+ label="Model & Fairness Metrics",
216
+ lines=12
217
+ )
218
+ group_table_output = gr.HTML(label="Group Selection Rates")
219
+ shap_image_output = gr.Image(label="SHAP Summary Plot")
220
+ preview_output = gr.HTML(label="Data Preview (first 5 rows)")
221
+
222
+ run_button.click(
223
+ fn=train_and_evaluate,
224
+ inputs=[csv_input, target_dropdown, sensitive_dropdown],
225
+ outputs=[metrics_output, group_table_output, shap_image_output, preview_output]
226
+ )
227
+
228
+ demo.launch()
229
+