Spaces:
Runtime error
Runtime error
| # Copyright 2019 Seth V. Neel, Michael J. Kearns, Aaron L. Roth, Zhiwei Steven Wu | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not | |
| # use this file except in compliance with the License. You may obtain a copy of | |
| # the License at http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software distributed | |
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| # specific language governing permissions and limitations under the License. | |
| """Class GerryFairClassifier implementing the 'FairFictPlay' Algorithm of [KRNW18]. | |
| This module contains functionality to instantiate, fit, and predict | |
| using the FairFictPlay algorithm of: | |
| https://arxiv.org/abs/1711.05144 | |
| It also contains the ability to audit arbitrary classifiers for | |
| rich subgroup unfairness, where rich subgroups are defined by hyperplanes | |
| over the sensitive attributes. This iteration of the codebase supports hyperplanes, trees, | |
| kernel methods, and support vector machines. For usage examples refer to examples/gerry_plots.ipynb | |
| """ | |
| import copy | |
| from aif360.algorithms.inprocessing.gerryfair import heatmap | |
| from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple | |
| from aif360.algorithms.inprocessing.gerryfair.learner import Learner | |
| from aif360.algorithms.inprocessing.gerryfair.auditor import * | |
| from aif360.algorithms.inprocessing.gerryfair.classifier_history import ClassifierHistory | |
| from aif360.algorithms import Transformer | |
| class GerryFairClassifier(Transformer): | |
| """Model is an algorithm for learning classifiers that are fair with respect | |
| to rich subgroups. | |
| Rich subgroups are defined by (linear) functions over the sensitive | |
| attributes, and fairness notions are statistical: false positive, false | |
| negative, and statistical parity rates. This implementation uses a max of | |
| two regressions as a cost-sensitive classification oracle, and supports | |
| linear regression, support vector machines, decision trees, and kernel | |
| regression. For details see: | |
| References: | |
| .. [1] "Preventing Fairness Gerrymandering: Auditing and Learning for | |
| Subgroup Fairness." Michale Kearns, Seth Neel, Aaron Roth, Steven Wu. | |
| ICML '18. | |
| .. [2] "An Empirical Study of Rich Subgroup Fairness for Machine | |
| Learning". Michael Kearns, Seth Neel, Aaron Roth, Steven Wu. FAT '19. | |
| """ | |
| def __init__(self, C=10, printflag=False, heatmapflag=False, | |
| heatmap_iter=10, heatmap_path='.', max_iters=10, gamma=0.01, | |
| fairness_def='FP', predictor=linear_model.LinearRegression()): | |
| """Initialize Model Object and set hyperparameters. | |
| Args: | |
| C: Maximum L1 Norm for the Dual Variables (hyperparameter) | |
| printflag: Print Output Flag | |
| heatmapflag: Save Heatmaps every heatmap_iter Flag | |
| heatmap_iter: Save Heatmaps every heatmap_iter | |
| heatmap_path: Save Heatmaps path | |
| max_iters: Time Horizon for the fictitious play dynamic. | |
| gamma: Fairness Approximation Paramater | |
| fairness_def: Fairness notion, FP, FN, SP. | |
| errors: see fit() | |
| fairness_violations: see fit() | |
| predictor: Hypothesis class for the Learner. Supports LR, SVM, KR, | |
| Trees. | |
| """ | |
| super(GerryFairClassifier, self).__init__() | |
| self.C = C | |
| self.printflag = printflag | |
| self.heatmapflag = heatmapflag | |
| self.heatmap_iter = heatmap_iter | |
| self.heatmap_path = heatmap_path | |
| self.max_iters = max_iters | |
| self.gamma = gamma | |
| self.fairness_def = fairness_def | |
| self.predictor = predictor | |
| self.classifiers = None | |
| self.errors = None | |
| self.fairness_violations = None | |
| if self.fairness_def not in ['FP', 'FN']: | |
| raise Exception( | |
| 'This metric is not yet supported for learning. Metric specified: {}.' | |
| .format(self.fairness_def)) | |
| def fit(self, dataset, early_termination=True): | |
| """Run Fictitious play to compute the approximately fair classifier. | |
| Args: | |
| dataset: dataset object with its own class definition in datasets | |
| folder inherits from class StandardDataset. | |
| early_termination: Terminate Early if Auditor can't find fairness | |
| violation of more than gamma. | |
| Returns: | |
| Self | |
| """ | |
| # defining variables and data structures for algorithm | |
| X, X_prime, y = clean.extract_df_from_ds(dataset) | |
| learner = Learner(X, y, self.predictor) | |
| auditor = Auditor(dataset, self.fairness_def) | |
| history = ClassifierHistory() | |
| # initialize variables | |
| n = X.shape[0] | |
| costs_0, costs_1, X_0 = auditor.initialize_costs(n) | |
| metric_baseline = 0 | |
| predictions = [0.0] * n | |
| # scaling variables for heatmap | |
| vmin = None | |
| vmax = None | |
| # print output variables | |
| errors = [] | |
| fairness_violations = [] | |
| iteration = 1 | |
| while iteration < self.max_iters: | |
| # learner's best response: solve the CSC problem, get mixture decisions on X to update prediction probabilities | |
| history.append_classifier(learner.best_response(costs_0, costs_1)) | |
| error, predictions = learner.generate_predictions( | |
| history.get_most_recent(), predictions, iteration) | |
| # auditor's best response: find group, update costs | |
| metric_baseline = auditor.get_baseline(y, predictions) | |
| group = auditor.get_group(predictions, metric_baseline) | |
| costs_0, costs_1 = auditor.update_costs(costs_0, costs_1, group, | |
| self.C, iteration, | |
| self.gamma) | |
| # outputs | |
| errors.append(error) | |
| fairness_violations.append(group.weighted_disparity) | |
| self.print_outputs(iteration, error, group) | |
| vmin, vmax = self.save_heatmap( | |
| iteration, dataset, | |
| history.get_most_recent().predict(X), vmin, vmax) | |
| iteration += 1 | |
| # early termination: | |
| if early_termination and (len(errors) >= 5) and ( | |
| (errors[-1] == errors[-2]) or fairness_violations[-1] == fairness_violations[-2]) and \ | |
| fairness_violations[-1] < self.gamma: | |
| iteration = self.max_iters | |
| self.classifiers = history.classifiers | |
| self.errors = errors | |
| self.fairness_violations = fairness_violations | |
| return self | |
| def predict(self, dataset, threshold=.5): | |
| """Return dataset object where labels are the predictions returned by | |
| the fitted model. | |
| Args: | |
| dataset: dataset object with its own class definition in datasets | |
| folder inherits from class StandardDataset. | |
| threshold: The positive prediction cutoff for the soft-classifier. | |
| Returns: | |
| dataset_new: modified dataset object where the labels attribute are | |
| the predictions returned by the self model | |
| """ | |
| # Generates predictions. | |
| dataset_new = copy.deepcopy(dataset) | |
| data, _, _ = clean.extract_df_from_ds(dataset_new) | |
| num_classifiers = len(self.classifiers) | |
| y_hat = None | |
| for hyp in self.classifiers: | |
| new_predictions = hyp.predict(data)/num_classifiers | |
| if y_hat is None: | |
| y_hat = new_predictions | |
| else: | |
| y_hat = np.add(y_hat, new_predictions) | |
| if threshold: | |
| dataset_new.labels = np.asarray( | |
| [1 if y >= threshold else 0 for y in y_hat]) | |
| else: | |
| dataset_new.labels = np.asarray([y for y in y_hat]) | |
| # ensure ndarray is formatted correctly | |
| dataset_new.labels.resize(dataset.labels.shape, refcheck=True) | |
| return dataset_new | |
| def print_outputs(self, iteration, error, group): | |
| """Helper function to print outputs at each iteration of fit. | |
| Args: | |
| iteration: current iter | |
| error: most recent error | |
| group: most recent group found by the auditor | |
| """ | |
| if self.printflag: | |
| print( | |
| 'iteration: {}, error: {}, fairness violation: {}, violated group size: {}' | |
| .format(int(iteration), error, group.weighted_disparity, | |
| group.group_size)) | |
| def save_heatmap(self, iteration, dataset, predictions, vmin, vmax): | |
| """Helper Function to save the heatmap. | |
| Args: | |
| iteration: current iteration | |
| dataset: dataset object with its own class definition in datasets | |
| folder inherits from class StandardDataset. | |
| predictions: predictions of the model self on dataset. | |
| vmin: see documentation of heatmap.py heat_map function | |
| vmax: see documentation of heatmap.py heat_map function | |
| Returns: | |
| (vmin, vmax) | |
| """ | |
| X, X_prime, y = clean.extract_df_from_ds(dataset) | |
| # save heatmap every heatmap_iter iterations or the last iteration | |
| if (self.heatmapflag and (iteration % self.heatmap_iter) == 0): | |
| # initial heat map | |
| X_prime_heat = X_prime.iloc[:, 0:2] | |
| eta = 0.1 | |
| minmax = heatmap.heat_map( | |
| X, X_prime_heat, y, predictions, eta, | |
| self.heatmap_path + '/heatmap_iteration_{}'.format(iteration), | |
| vmin, vmax) | |
| if iteration == 1: | |
| vmin = minmax[0] | |
| vmax = minmax[1] | |
| return vmin, vmax | |
| def generate_heatmap(self, | |
| dataset, | |
| predictions, | |
| vmin=None, | |
| vmax=None, | |
| cols_index=[0, 1], | |
| eta=.1): | |
| """Helper Function to generate the heatmap at the current time. | |
| Args: | |
| iteration:current iteration | |
| dataset: dataset object with its own class definition in datasets | |
| folder inherits from class StandardDataset. | |
| predictions: predictions of the model self on dataset. | |
| vmin: see documentation of heatmap.py heat_map function | |
| vmax: see documentation of heatmap.py heat_map function | |
| """ | |
| X, X_prime, y = clean.extract_df_from_ds(dataset) | |
| # save heatmap every heatmap_iter iterations or the last iteration | |
| X_prime_heat = X_prime.iloc[:, cols_index] | |
| minmax = heatmap.heat_map(X, X_prime_heat, y, predictions, eta, | |
| self.heatmap_path, vmin, vmax) | |
| def pareto(self, dataset, gamma_list): | |
| """Assumes Model has FP specified for metric. Trains for each value of | |
| gamma, returns error, FP (via training), and FN (via auditing) values. | |
| Args: | |
| dataset: dataset object with its own class definition in datasets | |
| folder inherits from class StandardDataset. | |
| gamma_list: the list of gamma values to generate the pareto curve | |
| Returns: | |
| list of errors, list of fp violations of those models, list of fn | |
| violations of those models | |
| """ | |
| C = self.C | |
| max_iters = self.max_iters | |
| # Store errors and fp over time for each gamma | |
| # change var names, but no real dependence on FP logic | |
| all_errors = [] | |
| all_fp_violations = [] | |
| all_fn_violations = [] | |
| self.C = C | |
| self.max_iters = max_iters | |
| auditor = Auditor(dataset, 'FN') | |
| for g in gamma_list: | |
| self.gamma = g | |
| fitted_model = self.fit(dataset, early_termination=True) | |
| errors, fairness_violations = fitted_model.errors, fitted_model.fairness_violations | |
| predictions = array_to_tuple((self.predict(dataset)).labels) | |
| _, fn_violation = auditor.audit(predictions) | |
| all_errors.append(errors[-1]) | |
| all_fp_violations.append(fairness_violations[-1]) | |
| all_fn_violations.append(fn_violation) | |
| return all_errors, all_fp_violations, all_fn_violations | |