import numpy as np
import scipy.stats
import torch
from munkres import Munkres
from scipy.optimize import linear_sum_assignment
from sklearn import ensemble, linear_model, metrics


def generate_batch_factor_code(ground_truth_data, representation_function,
                               num_points, random_state, batch_size):
    """Samples factors and computes their representations in mini-batches.

    Args:
        ground_truth_data: GroundTruthData to be sampled from.
        representation_function: Function that takes observations as input and
            outputs a representation.
        num_points: Number of points to sample.
        random_state: Numpy random state used for randomness.
        batch_size: Batch size used to sample points.

    Returns:
        representations: Codes, (num_codes, num_points)-np array.
        factors: Factors generating the codes, (num_factors, num_points)-np array.
    """
    representations = None
    factors = None
    i = 0
    while i < num_points:
        num_points_iter = min(num_points - i, batch_size)
        current_factors, current_observations = \
            ground_truth_data.sample(num_points_iter, random_state)
        if i == 0:
            factors = current_factors
            representations = representation_function(current_observations)
        else:
            factors = np.vstack((factors, current_factors))
            representations = np.vstack((representations,
                                         representation_function(
                                             current_observations)))
        i += num_points_iter
    return np.transpose(representations), np.transpose(factors)


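# Illustrative sketch (hypothetical, not part of the module's API): any object
# exposing `sample(num, random_state) -> (factors, observations)` works as
# `ground_truth_data`; `_ToyGroundTruthData` below is a made-up stand-in.
class _ToyGroundTruthData:
    def sample(self, num, random_state):
        factors = random_state.randint(0, 5, size=(num, 3))
        observations = factors + 0.1 * random_state.rand(num, 3)
        return factors, observations


def _example_generate_batch_factor_code():
    reps, factors = generate_batch_factor_code(
        _ToyGroundTruthData(), lambda obs: 2.0 * obs, num_points=100,
        random_state=np.random.RandomState(0), batch_size=32)
    return reps.shape, factors.shape  # ((3, 100), (3, 100))

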
def compute_irs(rep, y, diff_quantile=0.99):
    """Computes the Interventional Robustness Score.

    Args:
        rep: Representation codes, (num_codes, num_points)-np array.
        y: Ground-truth factors, (num_factors, num_points)-np array.
        diff_quantile: Float value between 0 and 1 to decide what quantile of
            diffs to select (use 1.0 for the version in the paper).

    Returns:
        Dict with IRS and number of active dimensions.
    """
    active_rep = _drop_constant_dims(rep)
    if not active_rep.any():
        irs_score = 0.0
    else:
        irs_score = scalable_disentanglement_score(
            y.T, active_rep.T, diff_quantile)["avg_score"]

    score_dict = {}
    score_dict["IRS"] = irs_score
    score_dict["num_active_dims"] = active_rep.shape[0]
    return score_dict


def _drop_constant_dims(ys):
    """Returns a copy of the matrix `ys` with constant rows dropped."""
    ys = np.asarray(ys)
    if ys.ndim != 2:
        raise ValueError("Expecting a matrix.")

    variances = ys.var(axis=1)
    active_mask = variances > 0.
    return ys[active_mask, :]


def scalable_disentanglement_score(gen_factors, latents, diff_quantile=0.99):
    """Computes IRS scores of a dataset.

    Assumes no noise in X and crossed generative factors (i.e. one sample per
    combination of gen_factors). Assumes each g_i is an equally probable
    realization of g_i and all g_i are independent.

    Args:
        gen_factors: Numpy array of shape (num samples, num generative factors),
            matrix of ground truth generative factors.
        latents: Numpy array of shape (num samples, num latent dimensions),
            matrix of latent variables.
        diff_quantile: Float value between 0 and 1 to decide what quantile of
            diffs to select (use 1.0 for the version in the paper).

    Returns:
        Dictionary with IRS scores.
    """
    num_gen = gen_factors.shape[1]
    num_lat = latents.shape[1]

    # Normalizer: maximal deviation of each latent from its mean.
    max_deviations = np.max(np.abs(latents - latents.mean(axis=0)), axis=0)
    cum_deviations = np.zeros([num_lat, num_gen])
    for i in range(num_gen):
        unique_factors = np.unique(gen_factors[:, i], axis=0)
        assert unique_factors.ndim == 1
        num_distinct_factors = unique_factors.shape[0]
        for k in range(num_distinct_factors):
            # E[Z | g_i]: mean latent within the group where g_i is fixed.
            match = gen_factors[:, i] == unique_factors[k]
            e_loc = np.mean(latents[match, :], axis=0)

            # Deviation of each latent from E[Z | g_i] within that group.
            diffs = np.abs(latents[match, :] - e_loc)
            max_diffs = np.percentile(diffs, q=diff_quantile * 100, axis=0)
            cum_deviations[:, i] += max_diffs
        cum_deviations[:, i] /= num_distinct_factors
    # Normalize each latent dimension by its maximal deviation.
    normalized_deviations = cum_deviations / max_deviations[:, np.newaxis]
    irs_matrix = 1.0 - normalized_deviations
    disentanglement_scores = irs_matrix.max(axis=1)
    if np.sum(max_deviations) > 0.0:
        avg_score = np.average(disentanglement_scores, weights=max_deviations)
    else:
        avg_score = np.mean(disentanglement_scores)

    parents = irs_matrix.argmax(axis=1)
    score_dict = {}
    score_dict["disentanglement_scores"] = disentanglement_scores
    score_dict["avg_score"] = avg_score
    score_dict["parents"] = parents
    score_dict["IRS_matrix"] = irs_matrix
    score_dict["max_deviations"] = max_deviations
    return score_dict


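# Sanity-check sketch (illustrative; factors sampled randomly rather than
# fully crossed): latents that copy the generative factors one-to-one have
# zero deviation under intervention, so the average score should be ~1.0.
def _example_irs():
    rng = np.random.RandomState(0)
    gen_factors = rng.randint(0, 10, size=(5000, 2))
    latents = gen_factors.astype(np.float64)
    return scalable_disentanglement_score(gen_factors, latents)["avg_score"]

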
def _compute_dci(mus_train, ys_train, mus_test, ys_test):
    """Computes score based on both training and testing codes and factors."""
    scores = {}
    importance_matrix, train_err, test_err = compute_importance_gbt(
        mus_train, ys_train, mus_test, ys_test)
    assert importance_matrix.shape[0] == mus_train.shape[0]
    assert importance_matrix.shape[1] == ys_train.shape[0]
    # Prediction accuracy of the per-factor models is reported as
    # informativeness.
    scores["informativeness_train"] = train_err
    scores["informativeness_test"] = test_err
    disent, code_importance = disentanglement(importance_matrix)
    scores["disentanglement"] = disent
    scores["completeness"] = completeness(importance_matrix)
    return scores, importance_matrix, code_importance


def compute_importance_gbt(x_train, y_train, x_test, y_test):
    """Compute importance based on gradient boosted trees."""
    num_factors = y_train.shape[0]
    num_codes = x_train.shape[0]
    importance_matrix = np.zeros(shape=[num_codes, num_factors],
                                 dtype=np.float64)
    train_loss = []
    test_loss = []
    for i in range(num_factors):
        # Factors are assumed discrete, so fit one gradient boosted
        # classifier per factor and record its prediction accuracy (the
        # equality comparison below would be meaningless for a regressor).
        model = ensemble.GradientBoostingClassifier()
        model.fit(x_train.T, y_train[i, :])
        importance_matrix[:, i] = np.abs(model.feature_importances_)
        train_loss.append(np.mean(model.predict(x_train.T) == y_train[i, :]))
        test_loss.append(np.mean(model.predict(x_test.T) == y_test[i, :]))
    return importance_matrix, np.mean(train_loss), np.mean(test_loss)


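# Usage sketch under the layout assumed above: codes are
# (num_codes, num_points) and factors are (num_factors, num_points), with
# discrete factor values. Near-axis-aligned codes give DCI scores close to 1.
def _example_dci():
    rng = np.random.RandomState(0)
    factors = rng.randint(0, 3, size=(2, 200))
    codes = factors + 0.01 * rng.rand(2, 200)
    scores, importance, _ = _compute_dci(codes[:, :100], factors[:, :100],
                                         codes[:, 100:], factors[:, 100:])
    return scores["disentanglement"]  # close to 1.0

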
def disentanglement_per_code(importance_matrix):
    """Compute disentanglement score of each code."""
    # 1 - entropy of each code's importance distribution over factors,
    # measured in base num_factors so the score lies in [0, 1];
    # the 1e-11 avoids log(0).
    return 1. - scipy.stats.entropy(importance_matrix.T + 1e-11,
                                    base=importance_matrix.shape[1])


def disentanglement(importance_matrix):
    """Compute the disentanglement score of the representation."""
    per_code = disentanglement_per_code(importance_matrix)
    if importance_matrix.sum() == 0.:
        importance_matrix = np.ones_like(importance_matrix)
    code_importance = importance_matrix.sum(axis=1) / importance_matrix.sum()
    return np.sum(per_code * code_importance), code_importance


def completeness_per_factor(importance_matrix):
    """Compute completeness of each factor."""
    # 1 - entropy of each factor's importance distribution over codes,
    # measured in base num_codes; the 1e-11 avoids log(0).
    return 1. - scipy.stats.entropy(importance_matrix + 1e-11,
                                    base=importance_matrix.shape[0])


def completeness(importance_matrix):
    """Compute completeness of the representation."""
    per_factor = completeness_per_factor(importance_matrix)
    if importance_matrix.sum() == 0.:
        importance_matrix = np.ones_like(importance_matrix)
    factor_importance = importance_matrix.sum(axis=0) / importance_matrix.sum()
    return np.sum(per_factor * factor_importance)


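# Sanity-check sketch: a one-to-one importance matrix (each code explains
# exactly one factor) should give disentanglement and completeness of ~1.0.
def _example_perfect_importance():
    importance = np.eye(4)
    d, _ = disentanglement(importance)
    return d, completeness(importance)  # both approx. 1.0

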
def MCC(Z, Zp):
    """Mean Correlation Coefficient between true sources Z and estimates Zp.

    Builds the matrix of absolute Pearson correlations between every pair of
    true and estimated components, matches components with the Hungarian
    algorithm, and returns the mean absolute correlation of the matching.
    """
    n = np.shape(Z)[1]
    rho_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            rho_matrix[i, j] = np.abs(np.corrcoef(Z[:, i], Zp[:, j])[0, 1])

    # Maximal-correlation assignment via the Hungarian algorithm.
    r, c = linear_sum_assignment(-rho_matrix)
    return np.mean(rho_matrix[r, c])


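# Usage sketch: MCC is invariant to permutation and (nonzero) rescaling of
# the recovered sources, so this toy example scores 1.0.
def _example_mcc():
    rng = np.random.RandomState(0)
    z = rng.normal(size=(1000, 4))
    z_hat = z[:, ::-1] * np.array([1.0, -2.0, 0.5, 3.0])
    return MCC(z, z_hat)  # approx. 1.0

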
def r2_disentanglement(z, hz, mode="r2", reorder=None):
    """Measure how well hz reconstructs z, either by the coefficient of
    determination (plain or adjusted) or by the Pearson/Spearman correlation
    coefficient."""

    assert mode in ("r2", "adjusted_r2", "pearson", "spearman")

    if mode == "r2":
        # Average the R2 score over the individual latent dimensions.
        r2_i = [metrics.r2_score(z[:, i], hz[:, i]) for i in range(z.shape[1])]
        return np.mean(r2_i), None
    elif mode == "adjusted_r2":
        r2 = metrics.r2_score(z, hz)
        # Number of samples.
        n = z.shape[0]
        # Number of latent dimensions.
        p = z.shape[1]
        adjusted_r2 = 1.0 - (1.0 - r2) * (n - 1) / (n - p - 1)
        return adjusted_r2, None
    elif mode in ("spearman", "pearson"):
        dim = z.shape[-1]

        if mode == "spearman":
            raw_corr, pvalue = scipy.stats.spearmanr(z, hz)
        else:
            raw_corr = np.corrcoef(z.T, hz.T)
        corr = raw_corr[:dim, dim:]

        if reorder:
            # Try to match the right components with the Munkres algorithm.
            munk = Munkres()
            indexes = munk.compute(-np.absolute(corr))

            sort_idx = np.zeros(dim)
            hz_sort = np.zeros(z.shape)
            for i in range(dim):
                sort_idx[i] = indexes[i][1]
                hz_sort[:, i] = hz[:, indexes[i][1]]

            if mode == "spearman":
                raw_corr, pvalue = scipy.stats.spearmanr(z, hz_sort)
            else:
                raw_corr = np.corrcoef(z.T, hz_sort.T)

            corr = raw_corr[:dim, dim:]

        return np.diag(np.abs(corr)).mean(), corr


def linear_disentanglement(z, hz, mode="r2", train_test_split=None):
    """Calculate disentanglement up to linear transformations.

    Args:
        z: Ground-truth latents.
        hz: Reconstructed latents.
        mode: Can be r2, pearson, spearman.
        train_test_split: Use the first half to fit the linear model and the
            second half to evaluate it. Only relevant if there are fewer
            samples than latent dimensions.
    """
    if torch.is_tensor(hz):
        hz = hz.detach().cpu().numpy()
    if torch.is_tensor(z):
        z = z.detach().cpu().numpy()

    # Split the data to avoid an optimistic fit when the linear model is
    # under-determined.
    if train_test_split:
        n_train = len(z) // 2
        z_1 = z[:n_train]
        hz_1 = hz[:n_train]
        z_2 = z[n_train:]
        hz_2 = hz[n_train:]
    else:
        z_1 = z
        hz_1 = hz
        z_2 = z
        hz_2 = hz

    # Fit a linear map from the reconstructed to the ground-truth latents.
    model = linear_model.LinearRegression()
    model.fit(hz_1, z_1)
    hz_2 = model.predict(hz_2)

    inner_result = _disentanglement(z_2, hz_2, mode=mode, reorder=False)
    return inner_result, (z_2, hz_2)


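# Usage sketch: latents recovered up to an invertible linear mixing should
# reach an R2 close to 1.0 once the linear map is regressed out.
def _example_linear_disentanglement():
    rng = np.random.RandomState(0)
    z = rng.normal(size=(1000, 3))
    hz = z @ rng.normal(size=(3, 3))  # random (almost surely invertible) mixing
    (score, _), _ = linear_disentanglement(z, hz, mode="r2")
    return score  # approx. 1.0

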
def _disentanglement(z, hz, mode="r2", reorder=None):
    """Measure how well hz reconstructs z, either by the coefficient of
    determination (plain or adjusted) or by the Pearson/Spearman correlation
    coefficient."""

    if mode == "r2":
        return metrics.r2_score(z, hz), None
    elif mode == "adjusted_r2":
        r2 = metrics.r2_score(z, hz)
        # Number of samples.
        n = z.shape[0]
        # Number of latent dimensions.
        p = z.shape[1]
        adjusted_r2 = 1.0 - (1.0 - r2) * (n - 1) / (n - p - 1)
        return adjusted_r2, None
    elif mode in ("spearman", "pearson"):
        dim = z.shape[-1]

        if mode == "spearman":
            raw_corr, pvalue = scipy.stats.spearmanr(z, hz)
        else:
            raw_corr = np.corrcoef(z.T, hz.T)
        corr = raw_corr[:dim, dim:]

        if reorder:
            # Try to match the right components with the Munkres algorithm.
            munk = Munkres()
            indexes = munk.compute(-np.absolute(corr))

            sort_idx = np.zeros(dim)
            hz_sort = np.zeros(z.shape)
            for i in range(dim):
                sort_idx[i] = indexes[i][1]
                hz_sort[:, i] = hz[:, indexes[i][1]]

            if mode == "spearman":
                raw_corr, pvalue = scipy.stats.spearmanr(z, hz_sort)
            else:
                raw_corr = np.corrcoef(z.T, hz_sort.T)

            corr = raw_corr[:dim, dim:]

        return np.diag(np.abs(corr)).mean(), corr


def permutation_disentanglement(
    z,
    hz,
    mode="r2",
    rescaling=True,
    solver="naive",
    sign_flips=True,
    cache_permutations=None,
):
    """Measure disentanglement up to permutations by either using the Munkres
    solver or naively trying out every possible permutation.

    Args:
        z: Ground-truth latents.
        hz: Reconstructed latents.
        mode: Can be r2, pearson, spearman.
        rescaling: Rescale every individual latent to maximize the agreement
            with the ground truth.
        solver: How to find the best possible permutation. Either use the
            Munkres algorithm or naively test every possible permutation.
        sign_flips: Only relevant for the `naive` solver. Also include
            sign-flips in the set of possible permutations to test.
        cache_permutations: Only relevant for the `naive` solver. Cache the
            permutation matrices to allow faster access if called multiple
            times.
    """
    assert solver in ("naive", "munkres")
    if mode == "r2" or mode == "adjusted_r2":
        assert solver == "naive", "R2 coefficient is only supported with naive solver"

    if cache_permutations and not hasattr(
        permutation_disentanglement, "permutation_matrices"
    ):
        permutation_disentanglement.permutation_matrices = dict()

    if torch.is_tensor(hz):
        hz = hz.detach().cpu().numpy()
    if torch.is_tensor(z):
        z = z.detach().cpu().numpy()

    assert isinstance(z, np.ndarray), "Either pass a torch tensor or numpy array as z"
    assert isinstance(hz, np.ndarray), "Either pass a torch tensor or numpy array as hz"

    def test_transformation(T, reorder):
        # Score the candidate permutation (and optional sign-flip) matrix T.
        Thz = hz @ T
        if rescaling:
            assert z.shape == hz.shape
            # Per-dimension least squares on the *permuted* latents:
            # beta_i = <z_i, Thz_i> / <Thz_i, Thz_i> minimizes
            # |z_i - beta_i * Thz_i| for each latent dimension i.
            beta = np.diag((z * Thz).sum(0) / (Thz ** 2).sum(0))
            Thz = Thz @ beta

        return _disentanglement(z, Thz, mode=mode, reorder=reorder), Thz

    def gen_permutations(n):
        # Generate all n x n permutation matrices, optionally with sign-flips.
        def gen_permutation_single_row(basis, row, sign_flips=False):
            # Place a +/-1 in `row`, in every column not yet used above it.
            basis = basis.clone()
            basis[row] = 0
            for i in range(basis.shape[-1]):
                # Skip columns already occupied by a previous row.
                if torch.sum(torch.abs(basis[:row, i])) > 0:
                    continue
                signs = [1]
                if sign_flips:
                    signs += [-1]
                for sign in signs:
                    T = basis.clone()
                    T[row, i] = sign
                    yield T

        def gen_permutations_all_rows(basis, current_row=0, sign_flips=False):
            # Recursively fill all rows, yielding complete permutation matrices.
            for T in gen_permutation_single_row(basis, current_row, sign_flips):
                if current_row == len(basis) - 1:
                    yield T.numpy()
                else:
                    yield from gen_permutations_all_rows(T, current_row + 1, sign_flips)

        basis = torch.zeros((n, n))
        yield from gen_permutations_all_rows(basis, sign_flips=sign_flips)

    n = z.shape[-1]
    if cache_permutations and solver != "munkres":
        # The cached matrices depend on the dimensionality and on whether
        # sign-flips were included.
        key = (rescaling, n, sign_flips)
        if key not in permutation_disentanglement.permutation_matrices:
            permutation_disentanglement.permutation_matrices[key] = list(
                gen_permutations(n)
            )
        permutations = permutation_disentanglement.permutation_matrices[key]
    else:
        if solver == "naive":
            permutations = list(gen_permutations(n))
        elif solver == "munkres":
            # The Munkres solver reorders inside _disentanglement; only the
            # identity needs to be tested here.
            permutations = [np.eye(n, dtype=z.dtype)]

    scores = []
    for T in permutations:
        scores.append(test_transformation(T, solver == "munkres"))

    # Return the best ((score, corr), Thz) tuple across all permutations.
    return max(scores, key=lambda x: x[0][0])
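

# Usage sketch: a sign-flipped permutation of the latents should be found by
# the naive solver, yielding an R2 close to 1.0.
def _example_permutation_disentanglement():
    rng = np.random.RandomState(0)
    z = rng.normal(size=(500, 3))
    hz = z[:, [2, 0, 1]] * np.array([1.0, -1.0, 1.0])
    (score, _), _ = permutation_disentanglement(z, hz, mode="r2")
    return score  # approx. 1.0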