# (removed non-code scrape residue: "Spaces: / Sleeping / Sleeping" — page-status text, not part of the source)
# Author: Juan Parras & Patricia A. Apellániz
# Email: patricia.alonsod@upm.es
# Date: 05/08/2025

# Package imports
import os
import sys
import sympy
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from tqdm import tqdm
from kan import ex_round
from copy import deepcopy
from tueplots import bundles

# Make the project root importable so the src.* modules below resolve
# regardless of the directory this script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(parent_dir)

# Project-local imports (must come after the sys.path manipulation above)
from src.data import load_data
from src.utils import get_config, create_results_folder
from src.models.models import Kan_model
from src.models.model_utils import get_metrics
from src.representation import plot_binary_explanation_plot, plot_sorted_variances, radar_factory
def run_kan_gam():
    """Train the KAN model and evaluate it on the train/test split.

    NOTE(review): relies on module-level globals set in ``__main__``
    (``model``, ``x_train``, ``x_test``, ``y_train``, ``y_test``).

    Returns:
        Tuple of (test class probabilities, binary-problem flag,
        number of classes, number of logits, test metrics, train metrics).
    """
    model.run_model(x_train, x_test, y_train, y_test)
    test_arr = x_test.to_numpy()
    train_arr = x_train.to_numpy()
    prob_test = model.predict_proba(test_arr)
    prob_train = model.predict_proba(train_arr)
    pred_test = model.predict(test_arr)
    pred_train = model.predict(train_arr)
    # A two-column probability output means a binary problem, which is
    # modelled with a single (delta) logit; otherwise one logit per class.
    n_prob_cols = prob_test.shape[1]
    is_binary = n_prob_cols == 2
    n_classes = 2 if is_binary else n_prob_cols
    n_logits = 1 if is_binary else n_prob_cols
    metrics_test = get_metrics(y_test.to_numpy(), pred_test, prob_test)
    metrics_train = get_metrics(y_train.to_numpy(), pred_train, prob_train)
    return prob_test, is_binary, n_classes, n_logits, metrics_test, metrics_train
def get_patient_values(formula, x_in):
    """Evaluate an additive (KAAM) sympy formula per patient and per covariate.

    The formula is assumed to be a sum of terms where each non-constant term
    depends on exactly one covariate (additive separability). For every row of
    ``x_in`` the contribution of each term is accumulated into the column of
    the covariate it depends on; constant terms go into an extra ``const``
    column.

    Args:
        formula: sympy expression (a sum of single-variable terms and/or a
            constant; a bare constant is also accepted).
        x_in: DataFrame of covariates whose column names match the free
            symbols of ``formula``.

    Returns:
        DataFrame of shape (n_patients, n_covariates + 1) with per-covariate
        contributions and a trailing 'const' column.
    """
    x = x_in.to_numpy()
    # One column per covariate plus one extra for the constant term.
    delta = np.zeros((x.shape[0], x.shape[1] + 1))
    # Add.make_args splits a sum into its terms and returns (formula,) for a
    # non-Add expression — this correctly handles both a bare constant and a
    # single-term formula, which iterating formula.args would mishandle.
    for term in sympy.Add.make_args(formula):
        if term.is_number:  # Constant term (Float, Integer or Rational after rounding)
            delta[:, -1] += float(term)
        else:
            # Since it is a KAAM, every non-constant term depends on a single variable
            assert len(term.free_symbols) == 1
            variable = next(iter(term.free_symbols))
            col_idx = x_in.columns.get_loc(str(variable))  # hoisted out of the per-patient loop
            for i in range(x.shape[0]):  # For each patient
                delta[i, col_idx] += float(term.subs(variable, x[i, col_idx]))
    delta = pd.DataFrame(delta, columns=x_in.columns.tolist() + ['const'])
    return delta
def adjust_polynomial(tr_x, tr_y, test_x, test_y, metrics_train, metrics_test, binary_dataset, n_logits_dataset, results_folder, dataset_name):
    """Extract a symbolic additive formula from the trained KAN and evaluate it.

    NOTE(review): relies on the module-level ``model`` global (fitted Kan_model).

    Writes the per-logit formulas to ``<results_folder>/<dataset_name>/formula.txt``,
    the per-patient additive contributions to delta_{train,test}_{i}.csv, the
    metric comparison to metrics.csv, and (binary case) a probability plot.

    Args:
        tr_x, tr_y: training covariates / labels (DataFrames).
        test_x, test_y: test covariates / labels (DataFrames).
        metrics_train, metrics_test: metrics of the raw (non-symbolic) model.
        binary_dataset: True when the problem has exactly two classes.
        n_logits_dataset: number of logits (1 in the binary case).
        results_folder: base folder where results are written.
        dataset_name: subfolder name for this dataset.

    Returns:
        Tuple (delta_formula, delta_train, delta_test, proba_train_numpy,
        proba_test_numpy, tr_x, test_x) where tr_x/test_x are restricted to
        the variables actually kept in the formula.
    """
    # lib = ['x', 'x^2', 'x^3', 'x^4', 'x^5'] # Polynomial model, used because it tends to provide a nice symbolic formula
    model.model(model.dataset['train_input'])  # Forward pass so activations are updated before symbolification!
    model.model.auto_symbolic()  # Add the lib in here if desired
    formula = model.model.symbolic_formula()[0]  # We have several formulae, one per logit!
    formula = [ex_round(f, 3) for f in formula]  # Number of digits to approximate
    if binary_dataset:
        delta_formula = [formula[1] - formula[0]]  # Keep a single formula (this is the delta)
    else:
        # Keep all the formulae, one per logit. IMPORTANT NOTE: we could simplify the
        # formulae, but this may get rid of the additive separability property!!
        delta_formula = [f for f in formula]
    # In the delta_formula, replace x_i by its column name
    # NOTE(review): sequential substitution could collide if a column is itself
    # named like 'x_<k>' — assumed not to happen; verify against the datasets.
    for i, col in enumerate(tr_x.columns):
        for j in range(n_logits_dataset):
            delta_formula[j] = delta_formula[j].subs(sympy.symbols(f'x_{i + 1}'), sympy.symbols(col))
    # Save formula as a text file
    with open(os.path.join(results_folder, dataset_name, 'formula.txt'), 'w') as f:
        for i in range(n_logits_dataset):
            f.write(f"Logit {i}: {delta_formula[i]}\n")
            f.write(f"Logit {i} Latex: {sympy.latex(delta_formula[i])}\n")
    # Since the formula may have pruned variables, we keep only the variables that are present in the formula
    actual_vars = []
    for f in delta_formula:
        actual_vars += [str(s) for s in f.free_symbols]
    actual_vars = list(set(actual_vars))  # Remove duplicates
    tr_x = tr_x[actual_vars]
    test_x = test_x[actual_vars]
    # Per-patient additive contributions, one DataFrame per logit
    delta_train, delta_test = [], []
    for i in range(n_logits_dataset):
        d = get_patient_values(delta_formula[i], tr_x)
        delta_train.append(d)
        d.to_csv(os.path.join(results_folder, dataset_name, f'delta_train_{i}.csv'), index=False)
        d = get_patient_values(delta_formula[i], test_x)
        delta_test.append(d)
        d.to_csv(os.path.join(results_folder, dataset_name, f'delta_test_{i}.csv'), index=False)
    if binary_dataset:
        # Sigmoid of the summed contributions (single delta logit)
        proba_train_numpy = 1 / (1 + np.exp(-delta_train[0].sum(axis=1).values))
        proba_test_numpy = 1 / (1 + np.exp(-delta_test[0].sum(axis=1).values))
        values_train_numpy = (proba_train_numpy > 0.5).astype(int)
        values_test_numpy = (proba_test_numpy > 0.5).astype(int)
    else:
        # Softmax across logits of the summed contributions
        proba_train_numpy = np.array(
            [np.exp(np.array(d).sum(axis=1)) / np.sum(np.exp(np.array(delta_train).sum(axis=2)), axis=0) for d
             in delta_train]).T
        proba_test_numpy = np.array(
            [np.exp(np.array(d).sum(axis=1)) / np.sum(np.exp(np.array(delta_test).sum(axis=2)), axis=0) for d in
             delta_test]).T
        values_train_numpy = np.argmax(proba_train_numpy, axis=1)
        values_test_numpy = np.argmax(proba_test_numpy, axis=1)
    metrics_train_numpy = get_metrics(tr_y, values_train_numpy, proba_train_numpy)
    metrics_test_numpy = get_metrics(test_y, values_test_numpy, proba_test_numpy)
    print("\n--------Metrics comparison--------")
    print(f"Train metrics without formula: {metrics_train}")
    print(f"Test metrics without formula: {metrics_test}")
    print(f"Train metrics with formula: {metrics_train_numpy}")
    print(f"Test metrics with formula: {metrics_test_numpy}")
    # Save metrics results in df
    metrics_df = pd.DataFrame([metrics_train, metrics_test, metrics_train_numpy, metrics_test_numpy],
                              index=['train', 'test', 'train_formula', 'test_formula'])
    # Fix: use the results_folder parameter instead of the global args['results_folder'],
    # so the function works when called outside this script's __main__ context.
    metrics_df.to_csv(os.path.join(results_folder, dataset_name, 'metrics.csv'))
    if binary_dataset:
        with plt.rc_context({**bundles.icml2024(column='half', nrows=1, ncols=1)}):
            plot_binary_explanation_plot(test_y,
                                         proba_test_numpy,
                                         ['0', '1'],
                                         0.5,
                                         os.path.join(results_folder, dataset_name, 'prob_plot_formula.pdf'),
                                         title='Probability of positive class')
            plt.close()
    print(f"\nNumber of variables in formula: {len(actual_vars)}\nVariables: {actual_vars}")
    for i in range(n_logits_dataset):
        print(f"\nFormula for logit {i}: {delta_formula[i]}")
    return delta_formula, delta_train, delta_test, proba_train_numpy, proba_test_numpy, tr_x, test_x
if __name__ == '__main__':
    # Interpretability pipeline: for each dataset, (re)build the tuned KAN,
    # extract a symbolic additive formula, and produce per-patient explanation
    # artifacts (CSVs, radar plots and per-feature contribution curves).
    # Get the configuration
    args = get_config('interpretability')
    create_results_folder(args['results_folder'], args)
    for dataset in args['datasets']:
        # Check if the best model is saved
        if os.path.exists(os.path.join(args['base_folder'], 'results_performance', dataset, 'kan_gam.pkl')):
            with open(os.path.join(args['base_folder'], 'results_performance', dataset, 'kan_gam.pkl'), 'rb') as f:
                # NOTE(review): pickle.load on a file produced by this project's
                # own pipeline — do not point this at untrusted files.
                metrics = pickle.load(f)
            print(f"\n\n------------Model for dataset {dataset} found------------\nUsing the following parameters:")
            # The pickle mixes hyper-parameters with metric values; print only the former
            for param in metrics:
                if param not in ['accuracy', 'precision', 'recall', 'f1', 'roc_auc', 'time', 'dataset']:
                    print(f"{param}: {metrics[param]}")
            print('\n')
            # Rebuild the model with the tuned hyper-parameters from the pickle
            model = Kan_model(hidden_dim=metrics['hidden_dim'],
                              batch_size=metrics['batch_size'],
                              grid=metrics['grid'],
                              k=metrics['k'],
                              seed=metrics['seed'],
                              lr=metrics['lr'],
                              early_stop=metrics['early_stop'],
                              steps=metrics['steps'],
                              lamb=metrics['lamb'],
                              lamb_entropy=metrics['lamb_entropy'],
                              weight=metrics['weight'],
                              sparse_init=metrics['sparse_init'],
                              mult_kan=metrics['mult_kan'])
        else:
            model = Kan_model()
            print(f"Model for dataset {dataset} not found. Using default parameters.")
        # Load the data and run model (run_kan_gam reads the globals set just above)
        x_train, x_test, y_train, y_test = load_data(dataset, args)
        pred_prob_test, binary, n_classes, n_logits, test_m, train_m = run_kan_gam()
        # Rename y_train to have a "class_target" column
        y_train = pd.DataFrame(y_train.values, columns=['class_target'])
        y_test = pd.DataFrame(y_test.values, columns=['class_target'])
        # Represent predictions in histograms
        if binary:
            plot_binary_explanation_plot(y_test,
                                         pred_prob_test[:, 1],
                                         ['0', '1'],
                                         0.5,
                                         os.path.join(args['results_folder'], dataset, 'prob_plot.pdf'),
                                         title='Probability of positive class')
        # Diagram of the trained KAN itself
        with plt.rc_context({**bundles.icml2024(column='half', nrows=1, ncols=1, usetex=True)}):
            model.model.plot(folder=os.path.join(args['results_folder'], dataset, 'kan'),
                             in_vars=x_train.columns.tolist(),
                             out_vars=[f'logit_{i}' for i in range(n_classes)],
                             varscale=0.2,
                             scale=1)
            plt.savefig(os.path.join(args['results_folder'], dataset, 'kan.pdf'), bbox_inches='tight', dpi=300)
            plt.close()
        # Adjust a polynomial model (also narrows x_train/x_test to the formula's variables)
        (delta_formula, delta_train, delta_test, proba_train_numpy, proba_test_numpy, x_train, x_test) = adjust_polynomial(x_train, y_train, x_test, y_test, train_m, test_m, binary, n_logits, args['results_folder'], dataset)
        # Plot of the sorted variances in training and testing for each logit (feat imp is the variance of the delta vals)
        plot_sorted_variances(x_train, x_test, binary, delta_train, delta_test, n_logits, args, dataset)
        ##### PATIENTS #####
        # Create a folder for the patients
        patients_results_folder = os.path.join(args['results_folder'], dataset, 'patients')
        os.makedirs(patients_results_folder, exist_ok=True)
        if binary:  # Add a dimension to the proba arrays so indexing [i][l] works in both cases
            proba_train_numpy = proba_train_numpy[:, None]
            proba_test_numpy = proba_test_numpy[:, None]
        n_dists = args['n_dists']                    # how many closest train patients to show
        max_atribs_radar = args['max_atribs_radar']  # max features in the radar plot
        max_pats_to_save = args['max_pats_to_save']  # cap on per-patient artifacts
        max_plot_curves = args['max_plot_curves']    # max features in the curves plot
        for l in range(n_logits):
            logit = 1 if binary else l  # in the binary case the single delta logit is labelled "1"
            print(f"Processing logit {logit}")
            # Get the most important features for the radar plot, be careful to use only training data!
            variances = delta_train[l].var(axis=0)
            all_cols = delta_train[l].columns.tolist()
            idx_vars = np.argsort(variances.values)[::-1]  # descending variance order
            num_of_zero_var = (variances < 1e-6).sum()
            # NOTE(review): when num_of_zero_var == 0 this slice is [:-0] == [:0],
            # which EMPTIES idx_vars — looks like a bug; confirm intended behavior.
            idx_vars = idx_vars[:-num_of_zero_var]
            if delta_train[l].shape[1] > max_atribs_radar:
                idx_vars = idx_vars[:max_atribs_radar]
            for i in tqdm(range(min(x_test.shape[0], max_pats_to_save))):
                actual_label = y_test.iloc[i]['class_target']
                # Covariates of the test patient plus its predicted probability and true label
                current_patient_info = np.concatenate((x_test.iloc[i].values, [proba_test_numpy[i][l], actual_label]))
                # Find the n_dists closest patients in the training set (distance in delta space)
                dists = np.linalg.norm(delta_train[l] - delta_test[l].iloc[i], axis=1)
                idx_closest = np.argsort(dists)[:n_dists].tolist()
                # NOTE(review): sigmoid is applied here even in the multi-class case — confirm.
                pred_prob = (1 / (1 + np.exp(-delta_train[l].iloc[idx_closest].sum(axis=1)))).values
                real_label = y_train.iloc[idx_closest].values
                closest_data = x_train.iloc[idx_closest].values
                closest_data = np.concatenate((closest_data, pred_prob[:, None], real_label), axis=1)
                closest_data = np.vstack((current_patient_info[None, :], closest_data))  # Add the current patient as the first row
                new_df = pd.DataFrame(closest_data, columns=x_train.columns.tolist() + ['pred_prob', 'real_label'])
                # Limit all new_df values to having 3 decimal numbers at most
                new_df = new_df.map(lambda x: round(x, 3) if isinstance(x, float) else x)
                new_df.to_csv(os.path.join(patients_results_folder, f'patient_{i}_closest_{n_dists}_logit_{logit}.csv'), index=False)
                # Prepare the radar plot, show only the attributes with highest variance
                # Change the order of idx_vars and cols_vars to have the importance in clockwise order in the plot
                idx_vars = idx_vars[::-1]
                cols_vars = [all_cols[i] for i in idx_vars.tolist()]
                n_feats = min(max_atribs_radar, len(cols_vars))  # Number of features to show in the radar plot
                if n_feats >= 3:  # We need at least 3 features to plot a proper radar plot
                    theta = radar_factory(n_feats, frame='polygon')
                    if binary:
                        avg_proba = 1 / (1 + np.exp(-delta_train[l].mean(axis=0).sum())) * np.ones(len(cols_vars))  # Average probability (i.e., "average patient")
                        title = f"Test Patient \n Predicted: {proba_test_numpy[i][l]:.3f} | True: {actual_label} | Average: {avg_proba[0]:.3f}"
                    else:
                        # Softmax of the mean contributions across logits ("average patient")
                        avg_proba = np.exp(delta_train[l].mean(axis=0).sum()) / sum(
                            [np.exp(d.mean(axis=0).sum()) for d in delta_train]) * np.ones(len(cols_vars))
                        title = f"Test Patient \n Predicted: {proba_test_numpy[i][l]:.3f} | True: {actual_label} | Average: {avg_proba[0]:.3f}"
                    with plt.rc_context({**bundles.icml2024(column='half', ncols=1, nrows=1, usetex=True)}):
                        fig, ax = plt.subplots(subplot_kw=dict(projection='radar'))
                        ax.set_rgrids([0.2, 0.4, 0.6, 0.8])
                        ax.set_title(title)
                        # Plot the average of all train patients
                        _ = ax.plot(theta, avg_proba, label='Average', color='tab:blue', linewidth=0.5)
                        ax.fill(theta, avg_proba, alpha=0.1, color='tab:blue')
                        # Prepare for individual patient plotting: each radar axis shows the
                        # probability when ONE covariate takes the patient's value and the
                        # rest stay at the training average (diagonal substitution below).
                        avg_delta = delta_train[l].mean(axis=0).values[None, :]
                        avg_matrix = np.repeat(avg_delta, delta_train[l].shape[1], axis=0)
                        # Plot the closest patients
                        for j in range(n_dists):
                            label = 'Closest patients' if j == 0 else None  # label only once for the legend
                            np.fill_diagonal(avg_matrix, delta_train[l].iloc[idx_closest[j]].values)
                            if binary:
                                pat_proba = 1 / (1 + np.exp(-avg_matrix.sum(axis=1)))
                            else:
                                # Softmax denominator: same diagonal substitution on every logit
                                den_term = np.zeros(delta_train[0].shape[1])
                                for ll in range(n_logits):
                                    den_matrix = np.repeat(delta_train[ll].mean(axis=0).values[None, :],
                                                           delta_train[ll].shape[1],
                                                           axis=0)
                                    np.fill_diagonal(den_matrix, delta_train[ll].iloc[idx_closest[j]].values)
                                    den_term += np.exp(den_matrix.sum(axis=1))
                                pat_proba = np.exp(avg_matrix.sum(axis=1)) / den_term
                            _ = ax.plot(theta, pat_proba[idx_vars], label=label, color='tab:green', alpha=0.5)
                            ax.fill(theta, pat_proba[idx_vars], alpha=0.1, color='tab:green')
                        # Plot the current patient (avg_matrix diagonal is overwritten again)
                        np.fill_diagonal(avg_matrix, delta_test[l].iloc[i].values)
                        if binary:
                            pat_proba = 1 / (1 + np.exp(-avg_matrix.sum(axis=1)))
                        else:
                            den_term = np.zeros(delta_train[0].shape[1])
                            for ll in range(n_logits):
                                den_matrix = np.repeat(delta_train[ll].mean(axis=0).values[None, :],
                                                       delta_train[ll].shape[1],
                                                       axis=0)
                                np.fill_diagonal(den_matrix, delta_test[ll].iloc[i].values)
                                den_term += np.exp(den_matrix.sum(axis=1))
                            pat_proba = np.exp(avg_matrix.sum(axis=1)) / den_term
                        _ = ax.plot(theta, pat_proba[idx_vars], label='Test patient', color='tab:red')
                        ax.fill(theta, pat_proba[idx_vars], alpha=0.1, color='tab:red')
                        ax.set_varlabels(cols_vars, fontsize=6)
                        plt.legend(loc='lower center',
                                   bbox_to_anchor=(0.5, -0.5),
                                   ncol=3)  # Note: this can be uncommented, but may clutter the plot
                        plt.savefig(os.path.join(patients_results_folder, f'patient_{i}_radar_logit_{logit}.pdf'),
                                    bbox_inches='tight',
                                    dpi=600)
                        plt.close()
                # Curves plot: show only the ones that do matter!!
                # Revert again the order to have the right plot order
                idx_vars = idx_vars[::-1]
                cols_vars = [all_cols[i] for i in idx_vars.tolist()]
                n_feats = min(len(cols_vars), max_plot_curves)  # Number of features to show in the curves plot
                with plt.rc_context({**bundles.icml2024(column='full', ncols=1, nrows=1, usetex=True)}):
                    if n_feats > 0:  # There is something to show
                        if binary:
                            # Extra top panel: logit -> probability sigmoid with this patient marked
                            fig, axs = plt.subplots(n_feats + 1, 1, figsize=(3, 5.5))
                            x_vals = np.arange(delta_train[l].sum(axis=1).min(), delta_train[l].sum(axis=1).max(), 0.01)
                            theor_proba = 1 / (1 + np.exp(-x_vals))
                            axs[0].plot(x_vals, theor_proba, 'b', alpha=0.2)
                            axs[0].scatter(delta_test[l].sum(axis=1)[i], proba_test_numpy[i][l], color='r')
                            axs[0].set_xlabel('Logit')
                            axs[0].set_ylabel('Probability')
                            axs[0].set_title(f'Patient {i}')
                        else:
                            fig, axs = plt.subplots(n_feats, 1, figsize=(3, 5.5))
                        for idj, feat_name in enumerate(cols_vars):
                            if idj < n_feats:  # Only plot the first n_feats features
                                if binary:
                                    j = idj + 1  # The first plot is already used for the theoretical curve
                                else:
                                    j = idj
                                # Keep only unique values of x_test[feat_name]
                                idxs = np.unique(x_train[feat_name].values, return_index=True)[1]
                                if n_feats > 1:  # Multiple plots: axs is an array of Axes
                                    # Training contribution curve for this feature
                                    axs[j].plot(x_train[feat_name].values[idxs],
                                                delta_train[l][feat_name].values[idxs],
                                                color='tab:blue')
                                    axs[j].scatter(x_train[feat_name].values[idxs],
                                                   delta_train[l][feat_name].values[idxs],
                                                   color='tab:blue',
                                                   s=6,
                                                   alpha=0.4)
                                    # The current test patient
                                    axs[j].scatter(x_test[feat_name].values[i],
                                                   delta_test[l][feat_name].values[i],
                                                   color='tab:red',
                                                   s=40,
                                                   alpha=1)
                                    # The closest training patients
                                    for jj in range(n_dists):
                                        axs[j].scatter(x_train.iloc[idx_closest[jj]][feat_name],
                                                       delta_train[l].iloc[idx_closest[jj]][feat_name],
                                                       color='tab:green',
                                                       s=8,
                                                       alpha=1)
                                    axs[j].set_ylabel('Contribution')
                                    axs[j].set_xlabel(feat_name)
                                else:  # Single plot: axs is a bare Axes object
                                    axs.plot(x_train[feat_name].values[idxs],
                                             delta_train[l][feat_name].values[idxs],
                                             color='tab:blue')
                                    axs.scatter(x_train[feat_name].values[idxs],
                                                delta_train[l][feat_name].values[idxs],
                                                color='tab:blue',
                                                s=6,
                                                alpha=0.4)
                                    axs.scatter(x_test[feat_name].values[i],
                                                delta_test[l][feat_name].values[i],
                                                color='tab:red',
                                                s=40,
                                                alpha=1)
                                    for jj in range(n_dists):
                                        axs.scatter(x_train.iloc[idx_closest[jj]][feat_name],
                                                    delta_train[l].iloc[idx_closest[jj]][feat_name],
                                                    color='tab:green',
                                                    s=8,
                                                    alpha=1)
                                    axs.set_ylabel('Contribution')
                                    axs.set_xlabel(feat_name)
                        # Shared legend built from proxy patches (scatter colors above)
                        red_patch = mpatches.Patch(color='tab:red', label='Test patient')
                        green_patch = mpatches.Patch(color='tab:green', label='Closest patients')
                        blue_patch = mpatches.Patch(color='tab:blue', label='Train patients')
                        plt.tight_layout(rect=[0, 0.05, 1, 1])
                        # TODO: Adjust legend box location based on dataset!!!
                        if n_feats > 1:
                            axs[0].set_title(f'Patient PDPs')
                            axs[-1].legend(handles=[red_patch, green_patch, blue_patch], loc='lower center',
                                           bbox_to_anchor=(0.5, -1.0), ncol=3)
                        else:
                            axs.set_title(f'Patient PDPs')
                            axs.legend(handles=[red_patch, green_patch, blue_patch], loc='lower center',
                                       bbox_to_anchor=(0.5, -0.7), ncol=3)
                        plt.rcParams.update(bundles.icml2024(usetex=False))
                        plt.savefig(os.path.join(patients_results_folder, f'patient_{i}_curves_logit_{logit}.pdf'),
                                    bbox_inches='tight',
                                    dpi=600)
                        plt.close()