"""EpInflammAge Gradio app.

Loads the per-marker inflammatory models and the EpInflammAge model, lets a
user upload CpG methylation data, predicts log inflammatory marker levels and
EpInflammAge for every sample, and explains single-sample predictions with
SHAP values.
"""
import os
import shutil
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy
import scipy.stats
import shap
from plotly.subplots import make_subplots
from pytorch_tabular import TabularModel
from sklearn.impute import KNNImputer, SimpleImputer
from sklearn.metrics import mean_absolute_error
from tqdm import tqdm

# Upload size limit enforced by check_size (1 GB).
MAX_FILE_SIZE = 1024 * 1024 * 1024
# Number of background samples handed to the SHAP explainer.
XAI_N_CLOSEST = 200

dir_root = Path(os.getcwd())
dir_out = f"{dir_root}/out"
# Every launch starts with an empty per-session cache directory.
if Path(dir_out).exists():
    shutil.rmtree(Path(dir_out))
Path(dir_out).mkdir(parents=True, exist_ok=True)

df_imms = pd.read_excel(f"{dir_root}/models/InflammatoryMarkers/InflammatoryMarkers.xlsx", index_col='feature')
imms = df_imms.index.values
imms_log = [f"{f}_log" for f in imms]
cpgs = pd.read_excel(f"{dir_root}/models/InflammatoryMarkers/CpGs.xlsx", index_col=0).index.values

models_imms = {}
for imm in (pbar := tqdm(imms)):
    pbar.set_description(f"Loading model for {imm}")
    models_imms[imm] = TabularModel.load_model(f"{dir_root}/models/InflammatoryMarkers/{imm}")
model_age = TabularModel.load_model(f"{dir_root}/models/EpInflammAge")

bkgrd_xai = pd.read_pickle(f"{dir_root}/models/background-xai.pkl")
bkgrd_imp = pd.read_pickle(f"{dir_root}/models/background-imputation.pkl")

# Factories for the imputers selectable with the "Imputation method" radio.
IMPUTERS = {
    "KNN": lambda: KNNImputer(n_neighbors=5),
    "Mean": lambda: SimpleImputer(strategy='mean'),
    "Median": lambda: SimpleImputer(strategy='median'),
}


def predict_func(X):
    """Predict EpInflammAge from a 2-D array of log inflammatory marker values.

    Used as the model function of the SHAP explainer; the columns of ``X``
    must follow the order of ``imms_log``.
    """
    X_df = pd.DataFrame(data=X, columns=imms_log)
    return model_age.predict(X_df)['Age_prediction'].values


def session_dir(request: gr.Request) -> str:
    """Return the cache directory that belongs to the request's session."""
    return f"{dir_out}/{request.session_hash}"


def delete_directory(request: gr.Request):
    """Remove the session's cache directory if it exists (on unload and clear)."""
    dir_to_del = session_dir(request)
    if Path(dir_to_del).is_dir():
        print(f"Delete cache: {dir_to_del}")
        shutil.rmtree(dir_to_del)


def signed(value) -> str:
    """Format ``value`` with two decimals and an explicit '+' when positive."""
    return f"+{value:0.2f}" if value > 0 else f"{value:0.2f}"


def impute_missing(data, imp_method, progress):
    """Impute missing CpG values of ``data`` in place and return it.

    The user samples are imputed together with the reference methylation data
    ``bkgrd_imp``, so neighbours/statistics come from the whole pool.

    Raises:
        gr.Error: if ``imp_method`` is not one of ``IMPUTERS``.
    """
    # Only CpG columns are imputed, so only their NaNs are relevant.
    n_nans = int(data.loc[:, cpgs].isna().sum().sum())
    if n_nans == 0:
        return data
    if imp_method not in IMPUTERS:
        raise gr.Error(f"Unknown imputation method: {imp_method}!")

    print(f"Imputation of {n_nans} missing values using {imp_method} method")
    progress(0.8, desc=f"Imputation of {n_nans} missing values using {imp_method} method")
    # Suffix background IDs on a NEW frame so they cannot collide with user
    # IDs and the shared module-level frame is never mutated between requests.
    bkgrd = bkgrd_imp.set_axis(bkgrd_imp.index.astype(str) + f'_imputation_{imp_method}', axis=0)
    data_all = pd.concat([data, bkgrd], axis=0, verify_integrity=True)
    imputer = IMPUTERS[imp_method]()
    data_all.loc[:, cpgs] = imputer.fit_transform(data_all.loc[:, cpgs].values)
    data.loc[:, cpgs] = data_all.loc[data.index, cpgs].values
    return data


def _axis_style(**overrides):
    """Return the shared axis styling of the result figure, updated with ``overrides``."""
    style = dict(
        automargin=True,
        showgrid=False,
        zeroline=False,
        linecolor='black',
        showline=True,
        gridcolor='gainsboro',
        gridwidth=0.05,
        mirror=True,
        ticks='outside',
        title_font=dict(color='black', size=20),
        showticklabels=True,
        tickangle=0,
        tickfont=dict(color='black', size=16),
        exponentformat='e',
        showexponent='all',
    )
    style.update(overrides)
    return style


def build_scatter_figure(data):
    """Build the 'EpInflammAge vs Age' scatter and the Age Acceleration violin."""
    fig = make_subplots(rows=1, cols=2, shared_yaxes=False, shared_xaxes=False, column_widths=[5, 3], horizontal_spacing=0.15)

    # Common range for both scatter axes, padded by 10% on each side.
    min_plot_age = data[["Age", "EpInflammAge"]].min().min()
    max_plot_age = data[["Age", "EpInflammAge"]].max().max()
    shift_plot_age = max_plot_age - min_plot_age
    age_range = [min_plot_age - 0.1 * shift_plot_age, max_plot_age + 0.1 * shift_plot_age]

    # Identity line: points on it have zero age acceleration.
    fig.add_trace(
        go.Scatter(x=age_range, y=age_range, showlegend=False, mode='lines', line=dict(color='black', width=2, dash='dot')),
        row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(
            name='Scatter',
            x=data['Age'].values,
            y=data['EpInflammAge'].values,
            text=data.index.values,
            hovertext=data.index.values,
            showlegend=False,
            mode='markers',
            marker=dict(size=10, opacity=0.75, line=dict(width=1, color='black'), color='crimson'),
        ),
        row=1, col=1,
    )
    fig.update_xaxes(row=1, col=1, **_axis_style(title_text="Age", autorange=False, range=age_range))
    fig.update_yaxes(row=1, col=1, **_axis_style(title_text="EpInflammAge", autorange=False, range=age_range))

    aa = data['Age Acceleration'].values
    aa_spread = np.ptp(aa)
    fig.add_trace(
        go.Violin(
            y=aa,
            hovertext=data.index.values,
            name="Violin",
            box_visible=True,
            meanline_visible=True,
            showlegend=False,
            line_color='black',
            fillcolor='crimson',
            marker=dict(color='crimson', line=dict(color='black', width=0.5), opacity=0.75),
            points='all',
            # Let Plotly choose the bandwidth when there is no spread (e.g. one sample).
            bandwidth=aa_spread / 32 if aa_spread > 0 else None,
            opacity=0.75,
        ),
        row=1, col=2,
    )
    fig.update_yaxes(row=1, col=2, **_axis_style(title_text="Age Acceleration", autorange=True, zeroline=True))
    fig.update_xaxes(
        row=1, col=2,
        automargin=True,
        autorange=False,
        range=[-0.5, 0.3],
        showgrid=False,
        showline=True,
        zeroline=False,
        showticklabels=False,
        mirror=True,
        ticks='outside',
        tickvals=[],
    )
    fig.update_layout(template="simple_white", width=800, height=450)
    return fig


def build_shap_figure(df_shap, df_less_more, age, aa):
    """Build the SHAP waterfall (left) and marker percentile bars (right).

    Args:
        df_shap: per-marker contributions in column 'shap' (sorted by absolute
            value) and their running sum in column 'cumsum'.
        df_less_more: 'Less'/'More' percentages indexed by marker.
        age: chronological age of the sample.
        aa: age acceleration (EpInflammAge - Age) of the sample.
    """
    n_imms = df_shap.shape[0]
    positions = list(range(n_imms))
    outline = {"color": "black", "width": 1}
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.05, row_titles=[''])

    # Outer waterfall: chronological age (absolute bar) plus total age acceleration.
    fig.add_trace(
        go.Waterfall(
            hovertext=["Chrono Age", "EpInflammAge"],
            orientation="h",
            measure=['absolute', 'relative'],
            y=[-1.5, n_imms + 0.5],
            x=[age, aa],
            base=0,
            text=[f"{age:0.2f}", signed(aa)],
            textposition="auto",
            decreasing={"marker": {"color": "deepskyblue", "line": outline}},
            increasing={"marker": {"color": "crimson", "line": outline}},
            totals={"marker": {"color": "dimgray", "line": outline}},
            connector={"mode": "between", "line": {"width": 1, "color": "black", "dash": "dot"}},
        ),
        row=1, col=1,
    )
    # Inner waterfall: per-marker contributions stacked on chronological age.
    fig.add_trace(
        go.Waterfall(
            hovertext=df_shap.index.values,
            orientation="h",
            measure=["relative"] * n_imms,
            y=positions,
            x=df_shap['shap'].values,
            base=age,
            text=[signed(x) for x in df_shap['shap'].values],
            textposition="auto",
            decreasing={"marker": {"color": "lightblue", "line": outline}},
            increasing={"marker": {"color": "lightcoral", "line": outline}},
            connector={"mode": "between", "line": {"width": 1, "color": "black", "dash": "solid"}},
        ),
        row=1, col=1,
    )
    fig.update_yaxes(
        row=1, col=1,
        automargin=True,
        tickmode="array",
        tickvals=[-1.5] + positions + [n_imms + 0.5],
        ticktext=["Chrono Age"] + df_shap.index.to_list() + ["EpInflammAge"],
        tickfont=dict(size=16),
    )
    max_shift = df_shap['cumsum'].abs().max() * 1.25
    fig.update_xaxes(
        row=1, col=1,
        automargin=True,
        title=dict(text='Age', font=dict(size=20)),
        range=[age - max_shift, age + max_shift],
    )
    fig.update_traces(row=1, col=1, showlegend=False)

    order = df_shap.index.values
    for column, color in (('Less', 'steelblue'), ('More', 'violet')):
        fig.add_trace(
            go.Bar(
                hovertext=order,
                orientation="h",
                name=column,
                x=df_less_more.loc[order, column],
                y=positions,
                marker=dict(color=color, line=dict(color="black", width=1)),
                text=df_less_more.loc[order, column],
                textposition='auto',
            ),
            row=1, col=2,
        )
    hidden_axis = dict(row=1, col=2, automargin=True, showgrid=False, showline=False, zeroline=False, showticklabels=False)
    fig.update_xaxes(**hidden_axis)
    fig.update_yaxes(**hidden_axis)

    fig.update_layout(
        barmode="relative",
        legend=dict(
            title=dict(text="Inflammatory Markers distribution<br>in samples with same age", side="top"),
            orientation="h",
            yanchor="bottom",
            y=0.95,
            xanchor="center",
            x=0.84,
        ),
        template="none",
        width=800,
        height=800,
    )
    return fig


js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

with gr.Blocks(theme=gr.themes.Soft(), title='EpImAge', js=js_func, delete_cache=(3600, 3600)) as app:
    gr.Markdown(
        """
        # EpInflammAge Calculator
        ## Submit epigenetics data
        """
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ### Instruction
                - Upload your methylation data for 2228 CpGs from [File](https://github.com/GillianGrayson/EpInflammAge/tree/main/data/CpGs.xlsx).
                - The first column must be a sample ID.
                - Your data must contain `Age` column for metrics (MAE and Pearson Rho) and Age Acceleration calculation.
                - Missing values should be `NA` in the corresponding cells.
                - [Imputation](https://scikit-learn.org/stable/modules/impute.html) of missing values can be performed using KNN, Mean, and Median methods with all methylation data from the [Paper](https://www.mdpi.com/1422-0067/26/13/6284).
                - Data example for GSE87571: [File](https://github.com/GillianGrayson/EpInflammAge/tree/main/data/examples/GSE87571.xlsx).
                - Calculations take a few minutes, the plot can be displayed slightly later than the results. If imputation is performed, the calculations will take longer.
                """,
            )
            # Extensions need a leading dot; 'csv' alone is not a valid file type.
            input_file = gr.File(label='Methylation Data File', file_count='single', file_types=['.xlsx', '.csv'])
            calc_radio = gr.Radio(choices=list(IMPUTERS), value="KNN", label="Imputation method")
            calc_button = gr.Button("Submit data", variant="primary", interactive=False)
        with gr.Column(min_width=800):
            with gr.Row():
                output_file = gr.File(label='Result File', file_types=['.xlsx'], interactive=False, visible=False)
                calc_mae = gr.Text(label='Mean Absolute Error', visible=False)
                calc_rho = gr.Text(label='Pearson Correlation', visible=False)
            with gr.Row():
                calc_plot = gr.Plot(visible=False, show_label=False)
    shap_markdown_main = gr.Markdown(
        """
        ## Age Acceleration Explanation using XAI
        """,
        visible=False,
    )
    with gr.Row():
        with gr.Column():
            shap_dropdown = gr.Dropdown(label='Choose a sample to get an explanation of the EpInflammAge prediction', filterable=True, visible=False)
            shap_button = gr.Button("Get explanation", variant="primary", visible=False)
            with gr.Row():
                shap_text_id = gr.Text(label='Sample', visible=False)
                shap_text_age = gr.Text(label='Age', visible=False)
                shap_text_epimage = gr.Text(label='EpInflammAge', visible=False)
            shap_markdown_cytokines = gr.Markdown(
                """
                ### Most important cytokines:
                """,
                visible=False,
            )
        with gr.Column(min_width=800):
            shap_plot = gr.Plot(label='Explanation', visible=False, show_label=False)

    def check_size(file_path):
        """Enable the submit button, or raise gr.Error if the upload exceeds MAX_FILE_SIZE."""
        if os.path.getsize(file_path) > MAX_FILE_SIZE:
            raise gr.Error("File exceeds 1 GB limit!")
        return gr.update(interactive=True)

    def hidden_metrics_outputs():
        """Updates that empty and hide the result file and metric boxes."""
        return {
            output_file: gr.update(value=None, visible=False),
            calc_mae: gr.update(value=None, visible=False),
            calc_rho: gr.update(value=None, visible=False),
        }

    def hidden_shap_outputs():
        """Updates that empty and hide every explanation component."""
        return {
            shap_markdown_main: gr.update(visible=False),
            shap_dropdown: gr.update(value=None, visible=False),
            shap_button: gr.update(visible=False),
            shap_text_id: gr.update(value=None, visible=False),
            shap_text_age: gr.update(value=None, visible=False),
            shap_text_epimage: gr.update(value=None, visible=False),
            shap_markdown_cytokines: gr.update(visible=False),
            shap_plot: gr.update(value=None, visible=False),
        }

    def clear(request: gr.Request):
        """Reset the UI and drop the session cache when the uploaded file is removed."""
        delete_directory(request)
        return {
            calc_button: gr.update(interactive=False),
            **hidden_metrics_outputs(),
            calc_plot: gr.update(value=None, visible=False),
            **hidden_shap_outputs(),
        }

    def progress_for_calc():
        """Hide stale results before calc_epimage runs; the plot stays visible.

        Presumably the plot is kept visible so calc_epimage's progress is shown
        on it (calc_plot is one of its outputs).
        """
        return {
            **hidden_metrics_outputs(),
            calc_plot: gr.update(visible=True),
            **hidden_shap_outputs(),
        }

    def progress_for_shap():
        """Hide the previous explanation; the SHAP plot stays visible for explain's progress."""
        return {
            shap_text_id: gr.update(value=None, visible=False),
            shap_text_age: gr.update(value=None, visible=False),
            shap_text_epimage: gr.update(value=None, visible=False),
            shap_markdown_cytokines: gr.update(value=None, visible=False),
            shap_plot: gr.update(value=None, visible=True),
        }

    def explain(sample_id, request: gr.Request, progress=gr.Progress()):
        """Explain the EpInflammAge prediction of one cached sample with SHAP.

        SHAP values are computed against the XAI_N_CLOSEST background samples
        whose `EpImAge` is closest to the sample's chronological age, then
        rescaled so that they sum to the sample's age acceleration.
        """
        cache_dir = session_dir(request)
        print(f"Read from cache: {cache_dir}")
        progress(0.0, desc='SHAP values calculation')
        data = pd.read_pickle(f"{cache_dir}/data.pkl")

        trgt_age = data.at[sample_id, 'Age']
        trgt_pred = data.at[sample_id, 'EpInflammAge']
        trgt_aa = trgt_pred - trgt_age

        closest = (bkgrd_xai['EpImAge'] - trgt_age).abs().argsort()[:XAI_N_CLOSEST]
        data_closest = bkgrd_xai.iloc[closest]
        explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, imms_log])
        shap_values = explainer.shap_values(data.loc[[sample_id], imms_log].values)[0]
        # SHAP values sum to (prediction - expected_value); rescale them so the
        # contributions sum to the age acceleration (prediction - age) instead.
        shap_values = shap_values * trgt_aa / (trgt_pred - explainer.expected_value)

        # Percentage of same-age background samples with a lower marker level.
        less = [
            int(round(scipy.stats.percentileofscore(data_closest[f"{imm}_log"].values, data.at[sample_id, f"{imm}_log"])))
            for imm in imms
        ]
        df_less_more = pd.DataFrame({'Less': less}, index=imms)
        df_less_more['More'] = 100 - df_less_more['Less']

        df_shap = pd.DataFrame({'shap': shap_values}, index=imms)
        df_shap.sort_values(by='shap', key=abs, inplace=True)
        df_shap['cumsum'] = df_shap['shap'].cumsum()

        fig = build_shap_figure(df_shap, df_less_more, trgt_age, trgt_aa)

        # df_shap is sorted by ascending |SHAP|, so the last three are the most important.
        top_imms = df_shap.index.values[::-1][:3]
        return {
            shap_text_id: gr.update(value=sample_id, visible=True),
            shap_text_age: gr.update(value=f"{trgt_age:.3f}", visible=True),
            shap_text_epimage: gr.update(value=f"{trgt_pred:.3f}", visible=True),
            shap_markdown_cytokines: gr.update(
                value="### Most important cytokines:\n" + '\n'.join(df_imms.loc[top_imms, 'Text'].to_list()),
                visible=True,
            ),
            shap_plot: gr.update(value=fig, visible=True),
        }

    def calc_epimage(inputs, request: gr.Request, progress=gr.Progress()):
        """Predict inflammatory markers and EpInflammAge for the uploaded samples.

        ``inputs`` maps the input components (input_file, calc_radio) to their
        values. Results are cached per session (data.pkl for explain) and
        written to Result.xlsx for download.

        Raises:
            gr.Error: on unknown file type, duplicate sample IDs, missing CpGs
                or a missing `Age` column.
        """
        cache_dir = session_dir(request)
        print(f"Create cache: {cache_dir}")
        Path(cache_dir).mkdir(parents=True, exist_ok=True)

        # Read input data file
        progress(0.0, desc='Reading input data file')
        file_path = inputs[input_file]
        if file_path.lower().endswith('xlsx'):
            data = pd.read_excel(file_path, index_col=0)
        elif file_path.lower().endswith('csv'):
            data = pd.read_csv(file_path, index_col=0)
        else:
            raise gr.Error("Unknown file type!")
        # String IDs round-trip through the dropdown unchanged (numeric IDs would not).
        data.index = data.index.astype(str)
        if not data.index.is_unique:
            raise gr.Error("Sample IDs in the first column must be unique!")

        # Check features in input file
        progress(0.2, desc='Checking features in input file')
        missed_cpgs = list(set(cpgs) - set(data.columns.values))
        if missed_cpgs:
            raise gr.Error(f"Missed {len(missed_cpgs)} CpGs in the input file!")
        if 'Age' not in data.columns:
            raise gr.Error("Missed `Age` column in the input file!")

        # Imputation of missing values
        data.replace({'NA': np.nan}, inplace=True)
        data = impute_missing(data, inputs[calc_radio], progress)

        # Models' inference
        progress(0.9, desc="Inflammatory models' inference")
        for imm in imms:
            data[f"{imm}_log"] = models_imms[imm].predict(data)
        progress(0.95, desc='EpInflammAge model inference')
        data['EpInflammAge'] = model_age.predict(data.loc[:, imms_log])
        data['Age Acceleration'] = data['EpInflammAge'] - data['Age']

        data.to_pickle(f'{cache_dir}/data.pkl')
        data_res = data[['Age', 'EpInflammAge', 'Age Acceleration'] + imms_log]
        data_res.rename(columns={f"{imm}_log": imm for imm in imms}).to_excel(f'{cache_dir}/Result.xlsx', index_label='ID')

        # MAE is defined for any number of samples; Pearson's rho needs at least two.
        ages = data['Age'].values
        preds = data['EpInflammAge'].values
        mae = mean_absolute_error(ages, preds)
        if len(data) > 1:
            rho_text = f"{scipy.stats.pearsonr(ages, preds).statistic:.3f}"
        else:
            rho_text = 'Only one sample.\nPearson correlation cannot be calculated.'

        progress(0.98, desc='Plotting scatter')
        fig = build_scatter_figure(data)

        sample_ids = list(data.index.values)
        return {
            output_file: gr.update(value=f'{cache_dir}/Result.xlsx', visible=True),
            calc_plot: gr.update(value=fig, visible=True),
            calc_mae: gr.update(value=f"{mae:.3f}", visible=True),
            calc_rho: gr.update(value=rho_text, visible=True),
            shap_markdown_main: gr.update(visible=True),
            shap_dropdown: gr.update(choices=sample_ids, value=sample_ids[0], interactive=True, visible=True),
            shap_button: gr.update(visible=True),
        }

    calc_button.click(
        fn=progress_for_calc,
        inputs=[],
        outputs=[output_file, calc_plot, calc_mae, calc_rho, shap_markdown_main, shap_dropdown, shap_button, shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot],
    )
    calc_button.click(
        fn=calc_epimage,
        inputs={input_file, calc_radio},
        outputs=[output_file, calc_plot, calc_mae, calc_rho, shap_markdown_main, shap_dropdown, shap_button],
    )
    input_file.clear(
        fn=clear,
        inputs=[],
        outputs=[calc_button, calc_mae, calc_rho, output_file, calc_plot, shap_markdown_main, shap_dropdown, shap_button, shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot],
    )
    input_file.upload(
        fn=check_size,
        inputs=[input_file],
        outputs=[calc_button],
    )
    shap_button.click(
        fn=progress_for_shap,
        inputs=[],
        outputs=[shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot],
    )
    shap_button.click(
        fn=explain,
        inputs=[shap_dropdown],
        outputs=[shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot],
    )
    app.unload(delete_directory)

app.launch(show_error=True)