hlnicholls's picture
Update app.py
789c5ff verified
"""
Blood Pressure Gene Prioritisation — Streamlit App
===================================================
Two top-level tabs:
1. Blood Pressure (Pre-trained) — live prioritisation + SHAP from local BP assets
2. User-Defined Model — upload your own .pkl + features CSV
"""
import io
import re
import numpy as np
import pandas as pd
import pickle
import streamlit as st
import shap
import plotly.graph_objects as go
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Page-level configuration, issued ahead of every other st.* call in the script.
st.set_page_config(
    layout="wide",
    page_icon="🩺",
    page_title="Blood Pressure Gene Prioritisation",
)
# Shared random seed; used as the random_state for the SHAP clustering step.
seed = 0
# ── Helpers ──────────────────────────────────────────────────────────────────
def parse_gene_list(text: str):
    """Split free-text gene input on commas and/or whitespace.

    Empty fragments (from leading, trailing or repeated separators) are
    discarded, so ``"A, B  C"`` gives ``["A", "B", "C"]`` and ``""``
    gives ``[]``. Order and duplicates are preserved.
    """
    pieces = re.split(r"[,\s]+", text.strip())
    # filter(None, ...) drops the empty strings that re.split can produce.
    return list(filter(None, pieces))
@st.cache_data
def to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* to UTF-8 CSV bytes (without the index) for download buttons.

    Wrapped in ``st.cache_data`` so the same frame is not re-encoded on
    every Streamlit rerun.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
def normalise_shap(shap_values):
    """Return SHAP output as a list of 2-D per-class arrays regardless of SHAP version.

    A 3-D ndarray laid out as (samples, features, classes) is split along
    its last axis into one (samples, features) array per class. Any other
    input, such as a per-class list, is returned unchanged (same object).
    """
    is_class_stacked = isinstance(shap_values, np.ndarray) and shap_values.ndim == 3
    if not is_class_stacked:
        return shap_values
    n_classes = shap_values.shape[2]
    return [shap_values[..., k] for k in range(n_classes)]
def is_tree_model(model) -> bool:
    """Return True if *model* is one of the tree classifiers this app recognises.

    Recognised types are XGBClassifier, LGBMClassifier, CatBoostClassifier
    and scikit-learn's RandomForestClassifier. Each library is imported
    independently. A missing library only removes its own class from the
    check and does not disable recognition of the others.
    """
    import importlib

    candidates = (
        ("xgboost", "XGBClassifier"),
        ("lightgbm", "LGBMClassifier"),
        ("catboost", "CatBoostClassifier"),
        ("sklearn.ensemble", "RandomForestClassifier"),
    )
    tree_types = []
    for module_name, class_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Not installed, so skip only this library. The original checked
            # all imports in one try, which made a single missing package
            # (e.g. catboost) report every model as non-tree.
            continue
        cls = getattr(module, class_name, None)
        if cls is not None:
            tree_types.append(cls)
    # isinstance(x, ()) is False, so an empty list safely yields False.
    return isinstance(model, tuple(tree_types))
def build_prioritisation_table(annotations: pd.DataFrame, model) -> pd.DataFrame:
    """Compute all-gene prioritisation directly from model + feature table.

    Args:
        annotations: Feature table (one row per gene) passed as-is to
            ``model.predict_proba``.
        model: Fitted classifier whose ``predict_proba`` returns one column
            per class, ordered Most Likely, Probable, Least Likely.

    Returns:
        A frame with the three ``Probability_*`` columns followed by the
        original feature columns, sharing ``annotations``' index.

    Raises:
        ValueError: If ``predict_proba`` does not return a 2-D array with
            exactly three columns.
    """
    probability_columns = [
        "Probability_Most_Likely", "Probability_Probable", "Probability_Least_Likely"
    ]
    probabilities = np.asarray(model.predict_proba(annotations))
    # Fail with an explicit message instead of pandas' opaque shape error
    # when a (user-supplied) model predicts a different number of classes.
    if probabilities.ndim != 2 or probabilities.shape[1] != len(probability_columns):
        raise ValueError(
            f"Expected predict_proba to return {len(probability_columns)} class "
            "probabilities per gene (Most Likely, Probable, Least Likely); "
            f"got an array of shape {probabilities.shape}."
        )
    prob_df = pd.DataFrame(
        probabilities, index=annotations.index, columns=probability_columns
    )
    return pd.concat([prob_df, annotations], axis=1)
def show_current_figure():
    """Render the active matplotlib figure in Streamlit, then close it.

    Closing removes the figure from pyplot's global state so later plots
    do not draw onto it. The close runs in ``finally`` so that a rendering
    error in ``st.pyplot`` does not leave the figure open across reruns.
    """
    fig = plt.gcf()
    try:
        st.pyplot(fig)
    finally:
        plt.close(fig)
# ── Load BP assets (cached so they only load once) ──────────────────────────
@st.cache_resource
def load_bp_assets():
    """Load the bundled BP feature table, fitted model and its SHAP explainer.

    Wrapped in ``st.cache_resource`` so the files are read and the explainer
    is built once rather than on every Streamlit rerun.

    Returns:
        (annotations, model, explainer): the features from
        ``all_genes_imputed_features.csv`` with missing values filled with 0
        and indexed by ``Gene``; the classifier unpickled from the local
        ``best_model_fitted.pkl`` asset; and a ``shap.TreeExplainer`` for it.
    """
    features = pd.read_csv("all_genes_imputed_features.csv").fillna(0)
    features = features.set_index("Gene")
    with open("best_model_fitted.pkl", "rb") as handle:
        fitted = pickle.load(handle)
    return features, fitted, shap.TreeExplainer(fitted)
# ── Shared gene-prioritisation + SHAP panel ─────────────────────────────────
def _render_shap_summary_grid(sv, df_shap, class_names):
    """Draw the global SHAP bar summary plus one beeswarm summary per class (2x2 grid)."""
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Global SHAP Summary")
        shap.summary_plot(
            sv, df_shap, plot_type="bar", class_names=class_names, show=False
        )
        show_current_figure()
    with col2:
        st.subheader(f"{class_names[0]} Prediction")
        shap.summary_plot(sv[0], df_shap, show=False)
        show_current_figure()
    col3, col4 = st.columns(2)
    for column, class_idx in ((col3, 1), (col4, 2)):
        with column:
            st.subheader(f"{class_names[class_idx]} Prediction")
            shap.summary_plot(sv[class_idx], df_shap, show=False)
            show_current_figure()


def _render_top_feature_scatter(class_sv, df_shap, genes):
    """Plot per-gene SHAP values of the 20 most important features (interactive).

    Importance is the mean absolute SHAP value across the queried genes, and
    features are listed most-important first.
    """
    st.subheader("Interactive SHAP Plot (Top 20 Features — Most Likely Class)")
    class_sv = np.asarray(class_sv)
    importance = np.abs(class_sv).mean(axis=0)
    # argsort is ascending: keep the last 20 and reverse so the most important lead.
    top_idx = np.argsort(importance)[-20:][::-1]
    features_top = df_shap.columns[top_idx]
    sv_top = class_sv[:, top_idx]
    x_vals, y_vals, hover = [], [], []
    for i, fname in enumerate(features_top):
        for gene, val in zip(genes, sv_top[:, i]):
            x_vals.append(val)
            y_vals.append(fname)
            hover.append(f"{gene}: {val:.3f}")
    fig_bee = go.Figure(data=go.Scatter(
        x=x_vals, y=y_vals, mode="markers",
        marker=dict(color=x_vals, colorbar=dict(title="SHAP"),
                    colorscale=[(0, "blue"), (1, "red")]),
        text=hover, hoverinfo="text+x",
    ))
    fig_bee.update_layout(
        title="SHAP Summary — Top 20 Features",
        xaxis_title="SHAP Value",
        yaxis=dict(autorange="reversed", title="Feature"),
        showlegend=False,
    )
    st.plotly_chart(fig_bee, use_container_width=True)


def _render_multi_gene_section(df_total, explainer, probability_columns, class_names, prefix):
    """Render the multi-gene query: probability table, download and SHAP plots."""
    input_text = st.text_input(
        "Input multiple HGNC genes (comma or space separated):",
        key=f"{prefix}_multi_gene",
    )
    gene_list = parse_gene_list(input_text)
    # Only two or more genes are handled here; one gene goes to the single-gene input.
    if len(gene_list) <= 1:
        return
    known = set(df_total.index)
    # dict.fromkeys de-duplicates while keeping the user's input order.
    missing = [g for g in dict.fromkeys(gene_list) if g not in known]
    if missing:
        st.caption("Not found in this feature set: " + ", ".join(missing))
    df = df_total[df_total.index.isin(gene_list)].copy()
    if df.empty:
        # Skip SHAP entirely: there are no rows to explain or plot.
        st.warning("None of the entered genes were found in the dataset.")
        return
    df.insert(0, "Gene", df.index)
    st.dataframe(df)
    st.download_button(
        "Download Gene Prioritisation",
        to_csv_bytes(df[["Gene"] + probability_columns]),
        "gene_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_multi",
    )
    df_shap = df.drop(columns=probability_columns + ["Gene"])
    sv = normalise_shap(explainer.shap_values(df_shap))
    _render_shap_summary_grid(sv, df_shap, class_names)
    _render_top_feature_scatter(sv[0], df_shap, list(df["Gene"]))


def _render_single_gene_section(df_total, explainer, probability_columns, class_names, prefix):
    """Render the single-gene query: its row, per-class force plots and external links."""
    from urllib.parse import quote

    input_single = st.text_input("Input a single HGNC gene:", key=f"{prefix}_single_gene")
    # Ignore surrounding whitespace (e.g. from copy-paste) before validating.
    gene = input_single.strip()
    if not gene:
        return
    if re.search(r"[\s,]", gene):
        st.warning("Please enter only one gene name (no spaces or commas).")
        return
    if gene not in df_total.index:
        st.warning(f"Gene '{gene}' not found in the dataset.")
        return
    df2 = df_total.loc[[gene]].copy()
    df2.insert(0, "Gene", df2.index)
    st.dataframe(df2)
    feature_columns = [c for c in df_total.columns if c not in probability_columns]
    df2_shap = df_total.loc[[gene], feature_columns]
    sv2 = normalise_shap(explainer.shap_values(df2_shap))
    for i, cname in enumerate(class_names):
        st.subheader(f"Force Plot — {cname}")
        fp = shap.force_plot(
            explainer.expected_value[i], sv2[i], df2_shap,
            matplotlib=True, show=False,
        )
        # Some SHAP versions return the active pyplot figure rather than a Figure object.
        if fp is not None and hasattr(fp, "savefig"):
            st.pyplot(fp)
            plt.close(fp)
        else:
            show_current_figure()
    # URL-encode the symbol so unusual characters cannot break the links.
    encoded = quote(gene)
    url = (
        "https://astrazeneca-cgr-publications.github.io/DrugnomeAI/"
        f"geneview.html?gene={encoded}"
    )
    genecards_url = f"https://www.genecards.org/cgi-bin/carddisp.pl?gene={encoded}"
    st.markdown(
        f"[{gene} Druggability in DrugnomeAI]({url}) | "
        f"[{gene} in GeneCards]({genecards_url})"
    )


def _render_full_results(df_total, prefix):
    """Show the first 500 rows of the all-gene table plus a full CSV download."""
    st.markdown("### Full Prioritisation Results (All Genes)")
    out = df_total.copy()
    out.insert(0, "Gene", out.index)
    st.dataframe(out.head(500))
    shown = min(500, len(out))
    st.caption(f"Showing first {shown} rows of {len(out)} total genes.")
    st.download_button(
        "Download All Genes",
        to_csv_bytes(out),
        "all_genes_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_all",
    )


def render_gene_prioritisation(annotations, model, explainer, prefix=""):
    """Render the gene-prioritisation and SHAP panel for one model.

    Draws three sections in the current Streamlit container:

    1. Multi-gene query: class probabilities for the entered genes, SHAP
       summary plots per class and an interactive top-20-feature scatter
       for the "Most likely" class.
    2. Single-gene query: the gene's row, one SHAP force plot per class and
       links to DrugnomeAI and GeneCards.
    3. The full all-gene prioritisation table (first 500 rows shown) with a
       CSV download.

    Args:
        annotations: Feature table indexed by gene symbol, passed to
            ``model.predict_proba`` and ``explainer.shap_values``.
        model: Fitted classifier returning three class probabilities
            (Most likely, Probable, Least likely).
        explainer: SHAP explainer built for ``model``.
        prefix: Prefix for widget keys so several panels can coexist. The
            multi-gene input's key is ``f"{prefix}_multi_gene"``.

    Returns:
        The prioritisation table (probability columns followed by the
        feature columns), indexed like ``annotations``.
    """
    probability_columns = [
        "Probability_Most_Likely", "Probability_Probable", "Probability_Least_Likely"
    ]
    class_names = ["Most likely", "Probable", "Least likely"]
    df_total = build_prioritisation_table(annotations, model)
    _render_multi_gene_section(df_total, explainer, probability_columns, class_names, prefix)
    _render_single_gene_section(df_total, explainer, probability_columns, class_names, prefix)
    _render_full_results(df_total, prefix)
    return df_total
# ── App layout ───────────────────────────────────────────────────────────────
# Page title, introduction and the two top-level tabs.
st.title("Disease-specific Gene Prioritisation Post-GWAS")
st.markdown(
    "An interactive visualisation of gene prioritisation results from the machine learning pipeline: https://github.com/hlnicholls/GenePrioritiser. "
    "Use the **Blood Pressure** tab for the pre-trained BP model, "
    "or the **User-Defined Model** tab to upload your own model and run the same analysis."
)
top_tab_bp, top_tab_user = st.tabs(
    [
        "🩺 Blood Pressure (Pre-trained)",
        "📂 User-Defined Model",
    ]
)
# ════════════════════════════════════════════════════════════════════════════
# TAB 1 β€” Blood Pressure (Pre-trained)
# ════════════════════════════════════════════════════════════════════════════
@st.cache_data
def _bp_shap_pca_clusters(random_state: int):
    """Project every BP gene's Most-Likely SHAP profile to 2-D and cluster it.

    This runs SHAP over the entire BP feature table, so it is cached.
    Without caching it would rerun on every widget interaction.

    Returns:
        (sv_pca, clust_labels): PCA coordinates (n_genes x 2), row-aligned
        with the BP annotations index, and the KMeans label of each gene.
    """
    annotations, _model, explainer = load_bp_assets()
    shap_full = normalise_shap(explainer.shap_values(annotations))
    sv_arr = np.array(shap_full[0])
    # Seed PCA as well: its solver may be randomised on large inputs.
    sv_pca = PCA(n_components=2, random_state=random_state).fit_transform(sv_arr)
    clust_labels = KMeans(n_clusters=3, random_state=random_state).fit_predict(sv_pca)
    return sv_pca, clust_labels


with top_tab_bp:
    st.header("Blood Pressure Gene Prioritisation")
    st.markdown(
        "Pre-trained model using SBP, DBP and PP GWAS data. "
        "This tab performs prioritisation of genes using a model trained on the "
        "blood pressure GWAS from Keaton et al. "
        "([https://doi.org/10.1038/s41588-024-01714-w](https://doi.org/10.1038/s41588-024-01714-w)). "
        "Three-class prediction: **Most Likely**, **Probable**, **Least Likely**."
    )
    annotations_bp, model_bp, explainer_bp = load_bp_assets()
    sub_gp, sub_cluster = st.tabs(
        ["Gene Prioritisation & SHAP", "Supervised SHAP Clustering"]
    )
    with sub_gp:
        df_total_bp = render_gene_prioritisation(
            annotations_bp, model_bp, explainer_bp, prefix="bp"
        )
    with sub_cluster:
        st.subheader("Supervised SHAP Clustering (PCA)")
        st.markdown(
            "Each point is a gene positioned using PCA on its SHAP profile for the **Most Likely** class, "
            "so clustering reflects similarity in model explanation patterns rather than raw feature values. "
            "Genes that appear close together are being prioritised for similar reasons by the model, while genes far apart "
            "are driven by different feature contributions. Use the overlays to see whether your queried genes align with "
            "the most-likely training-gene region or form distinct groups."
        )
        try:
            training_genes = pd.read_csv("training_cleaned.csv")
        except FileNotFoundError:
            training_genes = pd.DataFrame()
            st.info("training_cleaned.csv not found — most-likely training gene overlay disabled.")
        else:
            label_col = next(
                (c for c in training_genes.columns if "label_encoded" in c.lower()), None
            )
            # Label 0 marks the "Most Likely" training genes used for the overlay.
            if label_col:
                training_genes = training_genes[training_genes[label_col] == 0]
            training_genes = training_genes.set_index("Gene")

        sv_pca, clust_labels = _bp_shap_pca_clusters(seed)

        # Reuse genes typed in the Gene Prioritisation tab (same BP tab scope)
        # so users can see their queried genes on the clustering view without retyping.
        gp_input_text = st.session_state.get("bp_multi_gene", "")
        gp_genes = parse_gene_list(gp_input_text)
        cluster_input = st.text_input(
            "Highlight additional genes (comma/space separated):", key="bp_cluster_genes"
        )
        cluster_genes = parse_gene_list(cluster_input)
        # Combine and deduplicate, preserving input order.
        highlight_genes = list(dict.fromkeys(gp_genes + cluster_genes))
        if gp_genes:
            st.caption(
                f"Including {len(gp_genes)} gene(s) from the Gene Prioritisation input in this plot."
            )

        df_plot = pd.DataFrame({
            "PCA_1": sv_pca[:, 0], "PCA_2": sv_pca[:, 1],
            "Cluster": clust_labels.astype(str),
            "Gene": annotations_bp.index, "SpecialGroup": "None",
        })
        if not training_genes.empty:
            df_plot.loc[
                df_plot["Gene"].isin(training_genes.index), "SpecialGroup"
            ] = "Most Likely Training Gene"
        if highlight_genes:
            # Applied after the training overlay so user genes take precedence.
            df_plot.loc[
                df_plot["Gene"].isin(highlight_genes), "SpecialGroup"
            ] = "User Input Gene"
            found_genes = sorted(set(df_plot.loc[df_plot["Gene"].isin(highlight_genes), "Gene"]))
            missing_genes = sorted(set(highlight_genes) - set(found_genes))
            if missing_genes:
                st.caption(
                    "Not found in this feature set: " + ", ".join(missing_genes)
                )

        fig_c = go.Figure()
        cluster_colors = ["#ADD8E6", "#87CEEB", "#1E90FF"]
        for i, cluster in enumerate(sorted(df_plot["Cluster"].unique())):
            sub = df_plot[(df_plot["Cluster"] == cluster) & (df_plot["SpecialGroup"] == "None")]
            fig_c.add_trace(go.Scatter(
                x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
                name=f"Cluster {cluster}", text=sub["Gene"],
                marker=dict(color=cluster_colors[i % len(cluster_colors)]),
                hoverinfo="text+x+y",
            ))
        for group, colour in [
            ("Most Likely Training Gene", "black"),
            ("User Input Gene", "purple"),
        ]:
            sub = df_plot[df_plot["SpecialGroup"] == group]
            if not sub.empty:
                fig_c.add_trace(go.Scatter(
                    x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
                    name=group, text=sub["Gene"],
                    marker=dict(color=colour), hoverinfo="text+x+y",
                ))
        fig_c.update_layout(
            title="Supervised SHAP Clustering with PCA",
            xaxis_title="First Principal Component",
            yaxis_title="Second Principal Component",
            legend_title_text="Gene Category",
        )
        st.plotly_chart(fig_c, use_container_width=True)
# ════════════════════════════════════════════════════════════════════════════
# TAB 2 β€” User-Defined Model
# ════════════════════════════════════════════════════════════════════════════
with top_tab_user:
    st.header("User-Defined Model")
    st.markdown(
        """
        Upload your own trained model and imputed gene features to run the same
        gene prioritisation and SHAP analysis.
        **Requirements:**
        - **Model file** (`.pkl`): a scikit-learn–compatible classifier with `predict_proba`
          returning 3 classes ordered as **Most Likely (0)**, **Probable (1)**, **Least Likely (2)**.
          Tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) are fully supported.
        - **Features CSV**: imputed gene features with a `Gene` column and one numeric column
          per feature (same format as `all_genes_imputed_features.csv`).
        """
    )
    col_a, col_b = st.columns(2)
    with col_a:
        model_file = st.file_uploader("Upload model (.pkl)", type=["pkl"], key="user_pkl")
    with col_b:
        features_file = st.file_uploader(
            "Upload gene features CSV", type=["csv"], key="user_csv"
        )
    if model_file is None or features_file is None:
        st.info("⬆️ Upload both files above to begin.")
    else:
        # SECURITY: unpickling executes arbitrary code embedded in the file, on
        # the server hosting this app. Only expose this tab to trusted users.
        # getvalue() returns the whole upload regardless of the stream position.
        try:
            user_model = pickle.loads(model_file.getvalue())
        except Exception as e:
            st.error(f"Could not load model: {e}")
            st.stop()
        if not hasattr(user_model, "predict_proba"):
            st.error(
                f"`{type(user_model).__name__}` has no `predict_proba` method; "
                "a probabilistic classifier is required."
            )
            st.stop()
        # Most sklearn-style classifiers expose classes_; check it when present
        # so a wrong class count fails here with a clear message.
        model_classes = getattr(user_model, "classes_", None)
        if model_classes is not None and len(model_classes) != 3:
            st.error(
                "The model must predict exactly 3 classes (Most Likely, Probable, "
                f"Least Likely); `{type(user_model).__name__}` has {len(model_classes)}."
            )
            st.stop()

        # Only the parse is inside try; validation runs outside the broad handler.
        try:
            user_annotations = pd.read_csv(features_file)
        except Exception as e:
            st.error(f"Could not read features CSV: {e}")
            st.stop()
        if "Gene" not in user_annotations.columns:
            st.error("Features CSV must contain a 'Gene' column.")
            st.stop()
        user_annotations = user_annotations.fillna(0).set_index("Gene")

        if not is_tree_model(user_model):
            st.warning(
                f"`{type(user_model).__name__}` is not a recognised tree-based model. "
                "SHAP TreeExplainer will be attempted but may fail for non-tree models."
            )
        try:
            user_explainer = shap.TreeExplainer(user_model)
        except Exception as e:
            st.error(
                f"Could not create SHAP TreeExplainer: {e}\n"
                "Only tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) "
                "are supported for SHAP analysis."
            )
            st.stop()
        st.success(
            f"✅ **Model**: `{type(user_model).__name__}` | "
            f"**Genes**: {user_annotations.shape[0]} | "
            f"**Features**: {user_annotations.shape[1]}"
        )
        render_gene_prioritisation(
            user_annotations, user_model, user_explainer, prefix="user"
        )