# EpInflammAge / app.py
# kalyakulina's picture
# Added link to the original paper
# 86cacd1 verified
import gradio as gr
from pytorch_tabular import TabularModel
import shap
import pandas as pd
import os
from pathlib import Path
import numpy as np
from sklearn.metrics import mean_absolute_error
import scipy
import scipy.stats
import gradio as gr
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.impute import KNNImputer, SimpleImputer
import shutil
# Everything is resolved relative to the directory the app is started from.
dir_root = Path.cwd()
dir_out = f"{dir_root}/out"

# Start with an empty output directory: anything left over from a previous
# run is removed before the directory is (re)created.
_out_path = Path(dir_out)
if _out_path.exists():
    shutil.rmtree(_out_path)
_out_path.mkdir(parents=True, exist_ok=True)

# Table of inflammatory markers (indexed by the 'feature' column) and the
# names derived from it; the "_log" names are the inputs of the age model.
df_imms = pd.read_excel(
    f"{dir_root}/models/InflammatoryMarkers/InflammatoryMarkers.xlsx",
    index_col='feature',
)
imms = df_imms.index.values
imms_log = [f"{marker}_log" for marker in imms]

# Names of the CpGs are taken from the index of the CpGs table.
cpgs = pd.read_excel(
    f"{dir_root}/models/InflammatoryMarkers/CpGs.xlsx",
    index_col=0,
).index.values

# One saved model per inflammatory marker, loaded behind a progress bar.
models_imms = {}
pbar = tqdm(imms)
for imm in pbar:
    pbar.set_description(f"Loading model for {imm}")
    models_imms[imm] = TabularModel.load_model(f"{dir_root}/models/InflammatoryMarkers/{imm}")

# Age model plus the pickled background tables used for XAI and imputation.
model_age = TabularModel.load_model(f"{dir_root}/models/EpInflammAge")
bkgrd_xai = pd.read_pickle(f"{dir_root}/models/background-xai.pkl")
bkgrd_imp = pd.read_pickle(f"{dir_root}/models/background-imputation.pkl")
def predict_func(X):
    """Run the age model on a raw matrix of marker values.

    Args:
        X: 2-D array-like whose columns follow the order of `imms_log`.

    Returns:
        The values of the `Age_prediction` column produced by
        `model_age.predict`.
    """
    frame = pd.DataFrame(data=X, columns=imms_log)
    predictions = model_age.predict(frame)
    return predictions['Age_prediction'].values
# JavaScript snippet handed to `gr.Blocks(js=...)`: when the page URL does not
# carry `__theme=light`, it sets that query parameter and navigates to the
# updated URL.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Page layout. The upper part takes the methylation file and shows the
# prediction results; the lower part (hidden until results exist) explains a
# single sample. Markdown bodies are kept flush-left so their text is unchanged.
# NOTE(review): the row/column nesting below is reconstructed from the order of
# the statements — confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), title='EpImAge', js=js_func, delete_cache=(3600, 3600)) as app:
    gr.Markdown(
        """
# EpInflammAge Calculator
## Submit epigenetics data
"""
    )
    with gr.Row():
        # Left column: instructions and inputs.
        with gr.Column():
            gr.Markdown(
                """
### Instruction
- Upload your methylation data for 2228 CpGs from [File](https://github.com/GillianGrayson/EpInflammAge/tree/main/data/CpGs.xlsx).
- The first column must be a sample ID.
- Your data must contain `Age` column for metrics (MAE and Pearson Rho) and Age Acceleration calculation.
- Missing values should be `NA` in the corresponding cells.
- [Imputation](https://scikit-learn.org/stable/modules/impute.html) of missing values can be performed using KNN, Mean, and Median methods with all methylation data from the [Paper](https://www.mdpi.com/1422-0067/26/13/6284).
- Data example for GSE87571: [File](https://github.com/GillianGrayson/EpInflammAge/tree/main/data/examples/GSE87571.xlsx).
- Calculations take a few minutes, the plot can be displayed slightly later than the results. If imputation is performed, the calculations will take longer.
""",
            )
            # Extensions need the leading dot; a bare 'csv' is interpreted as
            # a file-type category instead of the ".csv" extension.
            input_file = gr.File(label='Methylation Data File', file_count='single', file_types=['.xlsx', '.csv'])
            calc_radio = gr.Radio(choices=["KNN", "Mean", "Median"], value="KNN", label="Imputation method")
            # Stays disabled until an uploaded file passes the size check.
            calc_button = gr.Button("Submit data", variant="primary", interactive=False)
        # Right column: result file, metrics and the scatter/violin plot.
        with gr.Column(min_width=800):
            with gr.Row():
                output_file = gr.File(label='Result File', file_types=['.xlsx'], interactive=False, visible=False)
                calc_mae = gr.Text(label='Mean Absolute Error', visible=False)
                calc_rho = gr.Text(label='Pearson Correlation', visible=False)
            with gr.Row():
                calc_plot = gr.Plot(visible=False, show_label=False)
    shap_markdown_main = gr.Markdown(
        """
## Age Acceleration Explanation using XAI
""",
        visible=False)
    with gr.Row():
        # Left column: sample selection and textual summary of the explanation.
        with gr.Column():
            shap_dropdown = gr.Dropdown(label='Choose a sample to get an explanation of the EpInflammAge prediction', filterable=True, visible=False)
            shap_button = gr.Button("Get explanation", variant="primary", visible=False)
            with gr.Row():
                shap_text_id = gr.Text(label='Sample', visible=False)
                shap_text_age = gr.Text(label='Age', visible=False)
                shap_text_epimage = gr.Text(label='EpInflammAge', visible=False)
            shap_markdown_cytokines = gr.Markdown(
                """
### Most important cytokines:
""",
                visible=False
            )
        # Right column: the explanation figure.
        with gr.Column(min_width=800):
            shap_plot = gr.Plot(label='Explanation', visible=False, show_label=False)
def check_size(input):
    """Enable the submit button when the uploaded file is within the size limit.

    Args:
        input: Path of the uploaded file.

    Returns:
        A `gr.update` that makes the target component interactive.

    Raises:
        gr.Error: If the file is larger than 1 GB.
    """
    max_bytes = 1024 * 1024 * 1024
    if os.path.getsize(input) > max_bytes:
        raise gr.Error("File exceeds 1 GB limit!")
    return gr.update(interactive=True)
def clear(request: gr.Request):
    """Drop this session's cache directory and reset the whole result UI.

    Args:
        request: Gradio request; its `session_hash` names the cache directory
            under `dir_out`.

    Returns:
        Dict mapping each affected component to its `gr.update(...)`: the
        submit button is disabled and every result/explanation widget hidden.
    """
    session_dir = f"{dir_out}/{str(request.session_hash)}"
    # is_dir() is False for a missing path, so it covers the existence check.
    if Path(session_dir).is_dir():
        print(f"Delete cache: {session_dir}")
        shutil.rmtree(session_dir)

    updates = {calc_button: gr.update(interactive=False)}
    # Widgets that carry a value are emptied as well as hidden.
    for widget in (
        output_file,
        calc_mae,
        calc_rho,
        calc_plot,
        shap_dropdown,
        shap_text_id,
        shap_text_age,
        shap_text_epimage,
        shap_plot,
    ):
        updates[widget] = gr.update(value=None, visible=False)
    # The remaining widgets are only hidden.
    for widget in (shap_markdown_main, shap_button, shap_markdown_cytokines):
        updates[widget] = gr.update(visible=False)
    return updates
def delete_directory(request: gr.Request):
    """Remove the cache directory that belongs to the requesting session.

    Does nothing when `<dir_out>/<session_hash>` is not an existing directory.

    Args:
        request: Gradio request whose `session_hash` names the directory.
    """
    session_dir = f"{dir_out}/{str(request.session_hash)}"
    # is_dir() already returns False when the path does not exist.
    if not Path(session_dir).is_dir():
        return
    print(f"Delete cache: {session_dir}")
    shutil.rmtree(session_dir)
def progress_for_calc():
    """Reset the result widgets right after the submit button is pressed.

    Returns:
        Dict mapping components to `gr.update(...)`: previous results and the
        whole explanation section are hidden, while `calc_plot` is made
        visible (presumably so that progress can be shown in its place).
    """
    updates = {}
    # Widgets that carry a value are emptied as well as hidden.
    for widget in (
        output_file,
        calc_mae,
        calc_rho,
        shap_dropdown,
        shap_text_id,
        shap_text_age,
        shap_text_epimage,
        shap_plot,
    ):
        updates[widget] = gr.update(value=None, visible=False)
    # These widgets are only hidden.
    for widget in (shap_markdown_main, shap_button, shap_markdown_cytokines):
        updates[widget] = gr.update(visible=False)
    updates[calc_plot] = gr.update(visible=True)
    return updates
def progress_for_shap():
    """Reset the explanation widgets right after "Get explanation" is pressed.

    Returns:
        Dict mapping components to `gr.update(...)`: the text fields and the
        cytokine markdown are emptied and hidden, while `shap_plot` is
        emptied but kept visible.
    """
    updates = {
        widget: gr.update(value=None, visible=False)
        for widget in (
            shap_text_id,
            shap_text_age,
            shap_text_epimage,
            shap_markdown_cytokines,
        )
    }
    updates[shap_plot] = gr.update(value=None, visible=True)
    return updates
def explain(input, request: gr.Request, progress=gr.Progress()):
    """Explain the age acceleration of one sample with SHAP and plot it.

    Loads this session's cached results, picks the background samples whose
    `EpImAge` is closest to the sample's age, estimates SHAP values for the
    `*_log` marker inputs of the age model, and draws a waterfall of the
    contributions next to the sample's percentile position for every marker.

    Args:
        input: ID (index label) of the sample chosen in the dropdown.
        request: Gradio request; its `session_hash` names the cache directory.
        progress: Gradio progress tracker.

    Returns:
        Dict mapping the explanation widgets to their `gr.update(...)`.

    Raises:
        gr.Error: If no sample is chosen, the session cache is no longer on
            disk, or the sample is not part of the cached results.
    """
    session_dir = f"{dir_out}/{str(request.session_hash)}"
    print(f"Read from cache: {session_dir}")
    progress(0.0, desc='SHAP values calculation')

    if input is None:
        raise gr.Error("Choose a sample first!")
    path_data = f"{session_dir}/data.pkl"
    # The cache can disappear (file cleared, session unloaded) while the
    # explanation widgets are still on screen.
    if not Path(path_data).is_file():
        raise gr.Error("Results for this session are no longer available. Please submit the data again.")
    # This pickle is the one calc_epimage() writes for the session; it is not
    # user-supplied.
    data = pd.read_pickle(path_data)

    trgt_id = input
    if trgt_id not in data.index:
        raise gr.Error(f"Sample {trgt_id} is not found in the results!")
    trgt_age = data.at[trgt_id, 'Age']
    trgt_pred = data.at[trgt_id, 'EpInflammAge']
    trgt_aa = trgt_pred - trgt_age

    # Background for the explainer: the samples closest in age to the target.
    n_closest = 200
    data_closest = bkgrd_xai.iloc[(bkgrd_xai['EpImAge'] - trgt_age).abs().argsort()[:n_closest]]

    explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, imms_log])
    shap_values = explainer.shap_values(data.loc[[trgt_id], imms_log].values)[0]
    # Rescale the attributions from (prediction - expected value) to the age
    # acceleration (prediction - age). When the prediction coincides with the
    # expected value the factor is undefined, so the raw values are kept
    # instead of filling the plot with inf/NaN.
    baseline_gap = trgt_pred - explainer.expected_value
    if not np.isclose(baseline_gap, 0.0):
        shap_values = shap_values * trgt_aa / baseline_gap

    # Percentage of background samples below ('Less') / not below ('More')
    # the target's value, per marker.
    df_less_more = pd.DataFrame(index=imms, columns=['Less', 'More'])
    for marker in df_less_more.index:
        df_less_more.at[marker, 'Less'] = round(scipy.stats.percentileofscore(data_closest.loc[:, f"{marker}_log"].values, data.at[trgt_id, f"{marker}_log"]))
        df_less_more.at[marker, 'More'] = 100.0 - df_less_more.at[marker, 'Less']

    # Ascending by absolute contribution: the strongest markers come last.
    df_shap = pd.DataFrame(index=imms, data=shap_values, columns=[trgt_id])
    df_shap.sort_values(by=trgt_id, key=abs, inplace=True)
    df_shap['cumsum'] = df_shap[trgt_id].cumsum()

    n_feats = df_shap.shape[0]
    positions = list(range(n_feats))
    ordered_markers = df_shap.index.values
    half_range = df_shap['cumsum'].abs().max() * 1.25

    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.05, row_titles=[''])
    # Outer bars: chronological age at the bottom, age acceleration on top.
    fig.add_trace(
        go.Waterfall(
            hovertext=["Chrono Age", "EpInflammAge"],
            orientation="h",
            measure=['absolute', 'relative'],
            y=[-1.5, n_feats + 0.5],
            x=[trgt_age, trgt_aa],
            base=0,
            text=[f"{trgt_age:0.2f}", f"+{trgt_aa:0.2f}" if trgt_aa > 0 else f"{trgt_aa:0.2f}"],
            textposition="auto",
            decreasing={"marker": {"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
            increasing={"marker": {"color": "crimson", "line": {"color": "black", "width": 1}}},
            totals={"marker": {"color": "dimgray", "line": {"color": "black", "width": 1}}},
            connector={
                "mode": "between",
                "line": {"width": 1, "color": "black", "dash": "dot"},
            },
        ),
        row=1,
        col=1
    )
    # Inner bars: per-marker contributions stacked from the chronological age.
    fig.add_trace(
        go.Waterfall(
            hovertext=ordered_markers,
            orientation="h",
            measure=["relative"] * n_feats,
            y=positions,
            x=df_shap[trgt_id].values,
            base=trgt_age,
            text=[f"+{x:0.2f}" if x > 0 else f"{x:0.2f}" for x in df_shap[trgt_id].values],
            textposition="auto",
            decreasing={"marker": {"color": "lightblue", "line": {"color": "black", "width": 1}}},
            increasing={"marker": {"color": "lightcoral", "line": {"color": "black", "width": 1}}},
            connector={
                "mode": "between",
                "line": {"width": 1, "color": "black", "dash": "solid"},
            },
        ),
        row=1,
        col=1,
    )
    fig.update_yaxes(
        row=1,
        col=1,
        automargin=True,
        tickmode="array",
        tickvals=[-1.5] + positions + [n_feats + 0.5],
        ticktext=["Chrono Age"] + df_shap.index.to_list() + ["EpInflammAge"],
        tickfont=dict(size=16),
    )
    # `title_font` replaces the deprecated `titlefont` alias.
    fig.update_xaxes(
        row=1,
        col=1,
        automargin=True,
        title_text='Age',
        title_font=dict(size=20),
        range=[trgt_age - half_range, trgt_age + half_range],
    )
    fig.update_traces(row=1, col=1, showlegend=False)

    # Right panel: stacked Less/More percentages for every marker.
    for name, color in (('Less', 'steelblue'), ('More', 'violet')):
        share = df_less_more.loc[ordered_markers, name]
        fig.add_trace(
            go.Bar(
                hovertext=ordered_markers,
                orientation="h",
                name=name,
                x=share,
                y=positions,
                marker=dict(color=color, line=dict(color="black", width=1)),
                text=share,
                textposition='auto',
            ),
            row=1,
            col=2,
        )
    hidden_axis = dict(
        automargin=True,
        showgrid=False,
        showline=False,
        zeroline=False,
        showticklabels=False,
    )
    fig.update_xaxes(row=1, col=2, **hidden_axis)
    fig.update_yaxes(row=1, col=2, **hidden_axis)

    fig.update_layout(barmode="relative")
    fig.update_layout(
        legend=dict(
            title=dict(text="Inflammatory Markers distribution<br>in samples with same age", side="top"),
            orientation="h",
            yanchor="bottom",
            y=0.95,
            xanchor="center",
            x=0.84
        ),
    )
    fig.update_layout(
        template="none",
        width=800,
        height=800,
    )

    # The last three entries of the ascending order, reversed: the three
    # markers with the largest absolute contribution.
    top_markers = ordered_markers[:-4:-1]
    return {
        shap_text_id: gr.update(value=input, visible=True),
        shap_text_age: gr.update(value=f"{trgt_age:.3f}", visible=True),
        shap_text_epimage: gr.update(value=f"{trgt_pred:.3f}", visible=True),
        shap_markdown_cytokines: gr.update(
            value="### Most important cytokines:\n" + '\n'.join(df_imms.loc[top_markers, 'Text'].to_list()),
            visible=True
        ),
        shap_plot: gr.update(value=fig, visible=True),
    }
def calc_epimage(input, request: gr.Request, progress=gr.Progress()):
    """Compute EpInflammAge for the uploaded samples and build the result view.

    Reads the uploaded table, checks that all required CpGs and the `Age`
    column are present, imputes missing values when there are any, runs the
    per-marker models followed by the age model, stores the results in this
    session's cache directory and draws the scatter/violin figure.

    Args:
        input: Dict keyed by the input components; `input[input_file]` is the
            path of the uploaded file and `input[calc_radio]` the imputation
            method.
        request: Gradio request; its `session_hash` names the cache directory.
        progress: Gradio progress tracker.

    Returns:
        Dict mapping the result widgets to their `gr.update(...)`.

    Raises:
        gr.Error: If no file is given, the file type is unknown, the table is
            empty, CpGs or the `Age` column are missing, or the imputation
            method is unknown.
    """
    session_dir = f"{dir_out}/{str(request.session_hash)}"
    print(f"Create cache: {session_dir}")
    Path(session_dir).mkdir(parents=True, exist_ok=True)

    # Read input data file
    progress(0.0, desc='Reading input data file')
    path_input = input[input_file]
    if path_input is None:
        raise gr.Error("No input file!")
    if path_input.endswith('xlsx'):
        data = pd.read_excel(path_input, index_col=0)
    elif path_input.endswith('csv'):
        data = pd.read_csv(path_input, index_col=0)
    else:
        raise gr.Error("Unknown file type!")
    if data.shape[0] == 0:
        raise gr.Error("No samples in the input file!")

    # Check features in input file
    progress(0.2, desc='Checking features in input file')
    missed_cpgs = list(set(cpgs) - set(data.columns.values))
    if len(missed_cpgs) > 0:
        raise gr.Error(f"Missed {len(missed_cpgs)} CpGs in the input file!")
    # Age is needed below for the age acceleration, the metrics and the plot.
    if 'Age' not in data.columns:
        raise gr.Error("Missed `Age` column in the input file!")

    # Imputation of missing values
    imp_method = input[calc_radio]
    data.replace({'NA': np.nan}, inplace=True)
    n_nans = data.isna().sum().sum()
    if n_nans > 0:
        print(f"Imputation of {n_nans} missing values using {imp_method} method")
        progress(0.8, desc=f"Imputation of {n_nans} missing values using {imp_method} method")
        if imp_method == "KNN":
            imputer = KNNImputer(n_neighbors=5)
        elif imp_method == 'Mean':
            imputer = SimpleImputer(strategy='mean')
        elif imp_method == 'Median':
            imputer = SimpleImputer(strategy='median')
        else:
            raise gr.Error(f"Unknown imputation method: {imp_method}!")
        # Re-label a copy of the background so its IDs cannot collide with the
        # user's. The shared module-level table must stay untouched, otherwise
        # every request would append another suffix to its index.
        background = bkgrd_imp.set_index(bkgrd_imp.index.astype(str) + f'_imputation_{imp_method}')
        data_all = pd.concat([data, background], axis=0, verify_integrity=True)
        data_all.loc[:, cpgs] = imputer.fit_transform(data_all.loc[:, cpgs].values)
        data.loc[data.index, cpgs] = data_all.loc[data.index, cpgs]

    # Models' inference
    progress(0.9, desc="Inflammatory models' inference")
    for imm in imms:
        data[f"{imm}_log"] = models_imms[imm].predict(data)
    progress(0.95, desc='EpInflammAge model inference')
    data['EpInflammAge'] = model_age.predict(data.loc[:, imms_log])
    data['Age Acceleration'] = data['EpInflammAge'] - data['Age']

    # Cache the full table for explain() and export the user-facing result.
    data.to_pickle(f'{session_dir}/data.pkl')
    path_result = f'{session_dir}/Result.xlsx'
    data_res = data[['Age', 'EpInflammAge', 'Age Acceleration'] + list(imms_log)]
    data_res.rename(columns={f"{imm}_log": imm for imm in imms}).to_excel(path_result, index_label='ID')

    # Metrics are only reported for more than one sample. Both widget updates
    # are prepared here so that the single-sample case never touches an
    # undefined metric.
    if len(data_res) > 1:
        mae = mean_absolute_error(data['Age'].values, data['EpInflammAge'].values)
        rho = scipy.stats.pearsonr(data['Age'].values, data['EpInflammAge'].values).statistic
        mae_update = gr.update(value=f"{mae:.3f}", visible=True)
        rho_update = gr.update(value=f"{rho:.3f}", visible=True)
    else:
        mae_update = gr.update(value=None, visible=False)
        rho_update = gr.update(value='Only one sample.\nNo metrics can be calculated.', visible=True)

    # Plot scatter
    progress(0.98, desc='Plotting scatter')
    fig = make_subplots(rows=1, cols=2, shared_yaxes=False, shared_xaxes=False, column_widths=[5, 3], horizontal_spacing=0.15)
    # Common axis limits for both scatter axes, padded by 10% on each side.
    min_plot_age = data[["Age", "EpInflammAge"]].min().min()
    max_plot_age = data[["Age", "EpInflammAge"]].max().max()
    shift_plot_age = max_plot_age - min_plot_age
    min_plot_age -= 0.1 * shift_plot_age
    max_plot_age += 0.1 * shift_plot_age
    # Identity line.
    fig.add_trace(
        go.Scatter(
            x=[min_plot_age, max_plot_age],
            y=[min_plot_age, max_plot_age],
            showlegend=False,
            mode='lines',
            line=dict(color='black', width=2, dash='dot')
        ),
        row=1,
        col=1
    )
    fig.add_trace(
        go.Scatter(
            name='Scatter',
            x=data.loc[:, 'Age'].values,
            y=data.loc[:, 'EpInflammAge'].values,
            text=data.index.values,
            hovertext=data.index.values,
            showlegend=False,
            mode='markers',
            marker=dict(
                size=10,
                opacity=0.75,
                line=dict(
                    width=1,
                    color='black'
                ),
                color='crimson'
            )
        ),
        row=1,
        col=1
    )
    # Styling shared by the framed axes; `title_font` replaces the deprecated
    # `titlefont` alias.
    framed_axis = dict(
        automargin=True,
        showgrid=False,
        linecolor='black',
        showline=True,
        gridcolor='gainsboro',
        gridwidth=0.05,
        mirror=True,
        ticks='outside',
        title_font=dict(color='black', size=20),
        showticklabels=True,
        tickangle=0,
        tickfont=dict(color='black', size=16),
        exponentformat='e',
        showexponent='all',
    )
    scatter_axis = dict(
        autorange=False,
        range=[min_plot_age, max_plot_age],
        zeroline=False,
        **framed_axis,
    )
    fig.update_xaxes(row=1, col=1, title_text="Age", **scatter_axis)
    fig.update_yaxes(row=1, col=1, title_text="EpInflammAge", **scatter_axis)

    fig.add_trace(
        go.Violin(
            y=data.loc[:, 'Age Acceleration'].values,
            hovertext=data.index.values,
            name="Violin",
            box_visible=True,
            meanline_visible=True,
            showlegend=False,
            line_color='black',
            fillcolor='crimson',
            marker=dict(color='crimson', line=dict(color='black', width=0.5), opacity=0.75),
            points='all',
            bandwidth=np.ptp(data.loc[:, 'Age Acceleration'].values) / 32,
            opacity=0.75
        ),
        row=1,
        col=2
    )
    fig.update_yaxes(
        row=1,
        col=2,
        title_text="Age Acceleration",
        autorange=True,
        zeroline=True,
        **framed_axis,
    )
    fig.update_xaxes(
        row=1,
        col=2,
        automargin=True,
        autorange=False,
        range=[-0.5, 0.3],
        showgrid=False,
        showline=True,
        zeroline=False,
        showticklabels=False,
        mirror=True,
        ticks='outside',
        tickvals=[],
    )
    fig.update_layout(
        template="simple_white",
        width=800,
        height=450,
    )

    # Resulted gradio dict
    sample_ids = list(data.index.values)
    return {
        output_file: gr.update(value=path_result, visible=True),
        calc_plot: gr.update(value=fig, visible=True),
        calc_mae: mae_update,
        calc_rho: rho_update,
        shap_markdown_main: gr.update(visible=True),
        shap_dropdown: gr.update(choices=sample_ids, value=sample_ids[0], interactive=True, visible=True),
        shap_button: gr.update(visible=True),
    }
# Submit: first reset the result widgets, then run the computation. The two
# steps are chained so the reset always lands before the results; as two
# independent click listeners their relative order would not be guaranteed.
calc_button.click(
    fn=progress_for_calc,
    inputs=[],
    outputs=[output_file, calc_plot, calc_mae, calc_rho, shap_markdown_main, shap_dropdown, shap_button, shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot]
).then(
    fn=calc_epimage,
    inputs={input_file, calc_radio},
    outputs=[output_file, calc_plot, calc_mae, calc_rho, shap_markdown_main, shap_dropdown, shap_button]
)
# Removing the uploaded file drops the session cache and resets the page.
input_file.clear(
    fn=clear,
    inputs=[],
    outputs=[calc_button, calc_mae, calc_rho, output_file, calc_plot, shap_markdown_main, shap_dropdown, shap_button, shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot]
)
# A fresh upload enables the submit button once the size check passes.
input_file.upload(
    fn=check_size,
    inputs=[input_file],
    outputs=[calc_button]
)
# Explanation: reset the explanation widgets, then compute and draw it
# (chained for the same ordering reason as the submit button).
shap_button.click(
    fn=progress_for_shap,
    inputs=[],
    outputs=[shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot]
).then(
    fn=explain,
    inputs=[shap_dropdown],
    outputs=[shap_text_id, shap_text_age, shap_text_epimage, shap_markdown_cytokines, shap_plot]
)
# Remove the session's cache directory when the session is unloaded.
app.unload(delete_directory)
# Start serving the app; `show_error=True` is passed through to Gradio's launch().
app.launch(show_error=True)