import shap import pandas as pd import os from pathlib import Path import seaborn as sns import matplotlib.pyplot as plt import torch import numpy as np import pickle from sklearn.metrics import mean_absolute_error from scipy.stats import pearsonr from models.tabular.widedeep.ft_transformer import WDFTTransformerModel import gradio as gr from scipy import stats import warnings warnings.filterwarnings("ignore", ".*will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.*") warnings.filterwarnings("ignore", ".*is non-interactive, and thus cannot be shown*") root_dir = Path(os.getcwd()) fn_model = f"{root_dir}/data/model.ckpt" model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=fn_model) model.eval() model.freeze() feats = [ 'CXCL9', 'CCL22', 'IL6', 'PDGFB', 'CD40LG', 'IL27', 'VEGFA', 'CSF1', 'PDGFA', 'CXCL10' ] fn_shap = f"{root_dir}/data/shap.pickle" out_dir = f"{root_dir}/out" if not os.path.exists(out_dir): os.makedirs(out_dir) def predict_func(x): batch = { 'all': torch.from_numpy(np.float32(x)), 'continuous': torch.from_numpy(np.float32(x)), 'categorical': torch.from_numpy(np.int32(x[:, []])), } return model(batch).cpu().detach().numpy() with open(fn_shap, 'rb') as handle: shap_dict = pickle.load(handle) values_train = shap_dict['values_train'] shap_values_train = shap_dict['shap_values_train'] explainer = shap_dict['explainer'] def predict(input): if input.endswith('xlsx'): df = pd.read_excel(input, index_col=0) elif input.endswith('csv'): df = pd.read_csv(input, index_col=0) else: raise gr.Error(f"Unknown file type!") if "Age" not in df.columns: raise gr.Error("No 'Age' column in the input file!") missed_features = [feature for feature in feats if feature not in df.columns] if len(missed_features) > 0: raise gr.Error(f"No {', '.join(missed_features)} column(s) in the input file!") try: df = df.loc[:, feats + ['Age']] except ValueError: raise gr.Error(f"Non-numeric value in 'Age' column!") df = df.astype({'Age': 'float'}) for feat in feats: try: df = df.astype({feat: 'float'}) except ValueError: raise gr.Error(f"Non-numeric value in '{feat}' column!") df['SImAge'] = model(torch.from_numpy(df.loc[:, feats].values)).cpu().detach().numpy().ravel() df['SImAge acceleration'] = df['SImAge'] - df['Age'] df.to_excel(f'{root_dir}/out/df.xlsx') df_res = df[['SImAge']] df_res.to_excel(f'{root_dir}/out/result.xlsx') if len(df) > 1: mae = mean_absolute_error(df['Age'].values, df['SImAge'].values) rho = pearsonr(df['Age'].values, df['SImAge'].values).statistic plt.close('all') sns.set_theme(style='whitegrid') fig, ax = plt.subplots(figsize=(4, 4)) scatter = sns.scatterplot( data=df, x="Age", y="SImAge", linewidth=0.1, alpha=0.75, edgecolor="k", s=40, color='blue', ax=ax ) bisect = sns.lineplot( x=[0, 120], y=[0, 120], linestyle='--', color='black', linewidth=1.0, ax=ax ) ax.set_xlim(0, 120) ax.set_ylim(0, 120) plt.savefig(f'{root_dir}/out/scatter.svg', bbox_inches='tight') plt.close('all') if len(df) > 1: sns.set_theme(style='whitegrid') fig, ax = plt.subplots(figsize=(2, 4)) sns.violinplot( data=df, y='SImAge acceleration', density_norm='width', color='blue', saturation=0.75, ) plt.savefig(f'{root_dir}/out/violin.svg', bbox_inches='tight') plt.close('all') shap.summary_plot( shap_values=shap_values_train.values, features=values_train.values, feature_names=feats, max_display=len(feats), plot_type="violin", ) plt.savefig(f'{root_dir}/out/shap_beeswarm.svg', bbox_inches='tight') plt.close('all') if len(df) > 1: return_metrics = gr.update(value=f'MAE: {round(mae, 3)}\nPearson Rho: {round(rho, 3)}', visible=True) return_gallery = gr.update(value=[(f'{root_dir}/out/scatter.svg', 'Scatter'), (f'{root_dir}/out/violin.svg', 'Violin'), (f'{root_dir}/out/shap_beeswarm.svg', 'SHAP Beeswarm')], visible=True) else: return_metrics = gr.update(value=f'Only one sample.\nNo metrics can be calculated.', visible=True) return_gallery = gr.update(value=[(f'{root_dir}/out/scatter.svg', 'Scatter'), (f'{root_dir}/out/shap_beeswarm.svg', 'SHAP Beeswarm')], visible=True) return [return_metrics, gr.update(value=f'{root_dir}/out/result.xlsx', visible=True), return_gallery, gr.update(visible=True), gr.update(visible=True), gr.update(choices=list(df.index.values), value=list(df.index.values)[0], interactive=True, visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)] def explain(input): df = pd.read_excel(f'{root_dir}/out/df.xlsx', index_col=0) trgt_id = input shap_values_trgt = explainer.shap_values(df.loc[trgt_id, feats].values) base_value = explainer.expected_value[0] age = df.loc[trgt_id, ['Age']].values[0] simage = df.loc[trgt_id, ['SImAge']].values[0] order = np.argsort(-np.abs(shap_values_trgt)) locally_ordered_feats = [feats[i] for i in order] plt.close('all') shap.plots.waterfall( shap.Explanation( values=shap_values_trgt, base_values=base_value, data=df.loc[trgt_id, feats].values, feature_names=feats ), max_display=len(feats), show=True, ) plt.savefig(f'{root_dir}/out/waterfall_{trgt_id}.svg', bbox_inches='tight') plt.close('all') if len(df) > 1: age_window = 5 trgt_age = df.at[trgt_id, 'Age'] trgt_simage = df.at[trgt_id, 'SImAge'] trgt_simage_acc = df.at[trgt_id, 'SImAge acceleration'] ids_near = df.index[(df['Age'] >= trgt_age - age_window) & (df['Age'] < trgt_age + age_window)] trgt_simage_acc_prctl = stats.percentileofscore(df.loc[ids_near, 'SImAge acceleration'], trgt_simage_acc) sns.set(style='whitegrid', font_scale=1.5) fig, ax = plt.subplots(figsize=(10, 6)) kdeplot = sns.kdeplot( data=df.loc[ids_near, :], x='SImAge acceleration', color='gray', linewidth=4, cut=0, ax=ax ) kdeline = ax.lines[0] xs = kdeline.get_xdata() ys = kdeline.get_ydata() ax.fill_between(xs, 0, ys, where=(xs <= trgt_simage_acc), interpolate=True, facecolor='dodgerblue', alpha=0.7) ax.fill_between(xs, 0, ys, where=(xs >= trgt_simage_acc), interpolate=True, facecolor='crimson', alpha=0.7) ax.vlines(trgt_simage_acc, 0, np.interp(trgt_simage_acc, xs, ys), color='black', linewidth=6) ax.text(np.mean([min(xs), trgt_simage_acc]), 0.1 * max(ys), f"{trgt_simage_acc_prctl:0.1f}%", fontstyle="oblique", color="black", ha="center", va="center") ax.text(np.mean([max(xs), trgt_simage_acc]), 0.1 * max(ys), f"{100 - trgt_simage_acc_prctl:0.1f}%", fontstyle="oblique", color="black", ha="center", va="center") fig.savefig(f"{root_dir}/out/kde_aa_{trgt_id}.svg", bbox_inches='tight') plt.close(fig) sns.set(style='whitegrid', font_scale=0.7) n_rows = 2 n_cols = 5 fig_height = 4 fig_width = 10 fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False) for feat_id, feat in enumerate(feats): row_id, col_id = divmod(feat_id, n_cols) kdeplot = sns.kdeplot( data=df.loc[ids_near, :], x=feat, color='gray', linewidth=1, cut=0, ax=axs[row_id, col_id] ) kdeline = axs[row_id, col_id].lines[0] xs = kdeline.get_xdata() ys = kdeline.get_ydata() trgt_val = df.at[trgt_id, feat] trgt_prctl = stats.percentileofscore(df.loc[ids_near, feat], trgt_val) axs[row_id, col_id].fill_between(xs, 0, ys, where=(xs <= trgt_val), interpolate=True, facecolor='dodgerblue', alpha=0.7) axs[row_id, col_id].fill_between(xs, 0, ys, where=(xs >= trgt_val), interpolate=True, facecolor='crimson', alpha=0.7) axs[row_id, col_id].vlines(trgt_val, 0, np.interp(trgt_val, xs, ys), color='black', linewidth=1.5) axs[row_id, col_id].text(np.mean([min(xs), trgt_val]), 0.1 * max(ys), f"{trgt_prctl:0.1f}%", fontstyle="oblique", color="black", ha="center", va="center") axs[row_id, col_id].text(np.mean([max(xs), trgt_val]), 0.1 * max(ys), f"{100 - trgt_prctl:0.1f}%", fontstyle="oblique", color="black", ha="center", va="center") axs[row_id, col_id].ticklabel_format(style='scientific', scilimits=(-1, 1), axis='y', useOffset=True) fig.tight_layout() fig.savefig(f"{root_dir}/out/kde_feats_{trgt_id}.svg", bbox_inches='tight') plt.close(fig) if len(df) > 1: return_gallery = [(f'{root_dir}/out/waterfall_{trgt_id}.svg', 'Waterfall'), (f'{root_dir}/out/kde_aa_{trgt_id}.svg', 'Age Acceleration KDE'), (f'{root_dir}/out/kde_feats_{trgt_id}.svg', 'Features KDE')] else: return_gallery = [(f'{root_dir}/out/waterfall_{trgt_id}.svg', 'Waterfall')] return [f'Real age: {round(age, 3)}\nSImAge: {round(simage, 3)}', f'{locally_ordered_feats[0]}\n{locally_ordered_feats[1]}\n{locally_ordered_feats[2]}', return_gallery] def clear(): return (gr.update(interactive=False), gr.update(value=None, visible=False), gr.update(value=None, visible=False), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=False), gr.update(value=None, visible=False), gr.update(value=None, visible=False)) def check_size(input): curr_file_size = os.path.getsize(input) if curr_file_size > 1024 * 1024: raise gr.Error(f"File exceeds 1 MB limit!") else: return gr.update(interactive=True) css = """ h2 { text-align: center; display:block; } """ with gr.Blocks(css=css, theme=gr.themes.Soft(), title='SImAge') as app: gr.Markdown( """