"""
Blood Pressure Gene Prioritisation — Streamlit App
===================================================
Two top-level tabs:
  1. Blood Pressure (Pre-trained) — live prioritisation + SHAP from local BP assets
  2. User-Defined Model — upload your own .pkl + features CSV
"""
import importlib
import io
import pickle
import re
from urllib.parse import quote

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import shap
import streamlit as st
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import matplotlib

matplotlib.use("Agg")  # headless backend; figures are handed to st.pyplot
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

st.set_page_config(
    page_title="Blood Pressure Gene Prioritisation",
    page_icon="🩺",
    layout="wide",
)

seed = 0

# Column names given to the three predict_proba outputs, in class order 0, 1, 2.
PROBABILITY_COLUMNS = [
    "Probability_Most_Likely",
    "Probability_Probable",
    "Probability_Least_Likely",
]
# Display names for the same three classes, in the same order.
CLASS_NAMES = ["Most likely", "Probable", "Least likely"]

# Optional tree libraries, checked one by one so a missing package does not
# hide the models of the packages that are installed.
_OPTIONAL_TREE_CLASSES = (
    ("xgboost", "XGBClassifier"),
    ("lightgbm", "LGBMClassifier"),
    ("catboost", "CatBoostClassifier"),
)


# ── Helpers ──────────────────────────────────────────────────────────────────
def parse_gene_list(text: str):
    """Split *text* on commas/whitespace and return the non-empty tokens in order."""
    return [s for s in re.split(r"[,\s]+", text.strip()) if s]


@st.cache_data
def to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* (without its index) to UTF-8 CSV bytes for st.download_button."""
    return df.to_csv(index=False).encode("utf-8")


def normalise_shap(shap_values):
    """Return a list of 2-D per-class SHAP arrays regardless of SHAP version.

    Newer SHAP releases return one ``(n_samples, n_features, n_classes)``
    array; older ones return a list with one 2-D array per class. Any
    other input is returned unchanged.
    """
    if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
        return [shap_values[:, :, i] for i in range(shap_values.shape[2])]
    return shap_values


def is_tree_model(model) -> bool:
    """Return True if *model* is one of the recognised tree-ensemble classifiers.

    scikit-learn classes are always checked. xgboost, lightgbm and catboost
    are each imported independently: a missing library only removes its own
    class from the check. A model of that type could not have been unpickled
    without the library anyway.
    """
    from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
    from sklearn.tree import DecisionTreeClassifier

    tree_types = [RandomForestClassifier, ExtraTreesClassifier, DecisionTreeClassifier]
    for module_name, class_name in _OPTIONAL_TREE_CLASSES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        cls = getattr(module, class_name, None)
        if cls is not None:
            tree_types.append(cls)
    return isinstance(model, tuple(tree_types))


def build_prioritisation_table(annotations: pd.DataFrame, model) -> pd.DataFrame:
    """Compute all-gene prioritisation directly from model + feature table.

    Returns the three probability columns followed by the feature columns,
    indexed like *annotations*.

    Raises:
        ValueError: if ``predict_proba`` does not return exactly three
            columns (the app's Most Likely / Probable / Least Likely contract).
    """
    probabilities = np.asarray(model.predict_proba(annotations))
    if probabilities.ndim != 2 or probabilities.shape[1] != len(PROBABILITY_COLUMNS):
        raise ValueError(
            f"predict_proba returned shape {probabilities.shape}; expected "
            f"(n_genes, {len(PROBABILITY_COLUMNS)}) for the three-class model."
        )
    prob_df = pd.DataFrame(
        probabilities, index=annotations.index, columns=PROBABILITY_COLUMNS
    )
    return pd.concat([prob_df, annotations], axis=1)


def show_current_figure():
    """Render and close the current matplotlib figure safely in Streamlit."""
    fig = plt.gcf()
    st.pyplot(fig)
    plt.close(fig)


# ── Load BP assets (cached so they only load once) ──────────────────────────
@st.cache_resource
def load_bp_assets():
    """Load the BP feature table, the fitted BP model and its TreeExplainer (local files)."""
    annotations = pd.read_csv("all_genes_imputed_features.csv")
    annotations.fillna(0, inplace=True)
    annotations = annotations.set_index("Gene")
    with open("best_model_fitted.pkl", "rb") as f:
        model = pickle.load(f)  # bundled, trusted asset
    explainer = shap.TreeExplainer(model)
    return annotations, model, explainer


@st.cache_data(show_spinner="Computing SHAP clustering for all genes…")
def compute_bp_shap_pca(n_clusters: int = 3, random_state: int = seed):
    """Project every BP gene's Most-Likely SHAP profile to 2-D PCA and cluster it.

    Cached because SHAP over all genes is expensive and Streamlit re-runs
    the whole script on every widget interaction.

    Returns:
        (genes, sv_pca, labels): gene names in feature-table order, the
        ``(n_genes, 2)`` PCA coordinates, and the KMeans label per gene.
    """
    annotations, _, explainer = load_bp_assets()
    shap_full = normalise_shap(explainer.shap_values(annotations))
    sv_pca = PCA(n_components=2).fit_transform(np.asarray(shap_full[0]))
    labels = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(sv_pca)
    return annotations.index.to_numpy(), sv_pca, labels


def load_most_likely_training_genes(path: str = "training_cleaned.csv"):
    """Return training gene names for the overlay, or None if *path* does not exist.

    When a column whose name contains ``label_encoded`` exists, only rows
    with value 0 are kept. Otherwise every training gene is returned.
    NOTE(review): label 0 is assumed to mean "Most Likely", matching the
    model's class order — confirm.
    """
    try:
        training = pd.read_csv(path)
    except FileNotFoundError:
        return None
    label_col = next(
        (c for c in training.columns if "label_encoded" in c.lower()), None
    )
    if label_col:
        training = training[training[label_col] == 0]
    return training.set_index("Gene").index.tolist()


# ── Shared gene-prioritisation + SHAP panel ─────────────────────────────────
def _render_interactive_beeswarm(sv_class, df_shap, genes, top_n=20):
    """Plotly strip plot of per-gene SHAP values for the *top_n* most important features."""
    st.subheader(f"Interactive SHAP Plot (Top {top_n} Features — Most Likely Class)")
    sv_class = np.asarray(sv_class)
    importance = np.abs(sv_class).mean(axis=0)
    # argsort is ascending: keep the last top_n, then flip so the most important is first.
    top_idx = np.argsort(importance)[-top_n:][::-1]
    features_top = df_shap.columns[top_idx]
    sv_top = sv_class[:, top_idx]

    x_vals, y_vals, hover = [], [], []
    for j, fname in enumerate(features_top):
        for gene, val in zip(genes, sv_top[:, j]):
            x_vals.append(val)
            y_vals.append(fname)
            hover.append(f"{gene}: {val:.3f}")

    fig_bee = go.Figure(data=go.Scatter(
        x=x_vals, y=y_vals, mode="markers",
        marker=dict(color=x_vals, colorbar=dict(title="SHAP"),
                    colorscale=[(0, "blue"), (1, "red")]),
        text=hover, hoverinfo="text+x",
    ))
    fig_bee.update_layout(
        title=f"SHAP Summary — Top {top_n} Features",
        xaxis_title="SHAP Value",
        yaxis=dict(autorange="reversed", title="Feature"),
        showlegend=False,
    )
    st.plotly_chart(fig_bee, use_container_width=True)


def _render_multi_gene_panel(df_total, gene_list, feature_cols, explainer, prefix):
    """Show probabilities, a CSV download and SHAP plots for the queried genes."""
    df = df_total[df_total.index.isin(gene_list)].copy()
    missing = [g for g in dict.fromkeys(gene_list) if g not in df_total.index]
    if missing:
        st.caption("Not found in this feature set: " + ", ".join(missing))
    if df.empty:
        # SHAP and the summary plots cannot handle an empty selection.
        st.warning("None of the entered genes were found in the dataset.")
        return

    df.insert(0, "Gene", df.index)
    st.dataframe(df)
    st.download_button(
        "Download Gene Prioritisation",
        to_csv_bytes(df[["Gene"] + PROBABILITY_COLUMNS]),
        "gene_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_multi",
    )

    df_shap = df[feature_cols]
    sv = normalise_shap(explainer.shap_values(df_shap))

    # 2x2 grid: global bar summary, then one beeswarm per class.
    panels = [("Global SHAP Summary", None)] + [
        (f"{name} Prediction", i) for i, name in enumerate(CLASS_NAMES)
    ]
    for row_start in range(0, len(panels), 2):
        row = panels[row_start:row_start + 2]
        for column, (title, class_idx) in zip(st.columns(2), row):
            with column:
                st.subheader(title)
                if class_idx is None:
                    shap.summary_plot(sv, df_shap, plot_type="bar",
                                      class_names=CLASS_NAMES, show=False)
                else:
                    shap.summary_plot(sv[class_idx], df_shap, show=False)
                show_current_figure()

    _render_interactive_beeswarm(sv[0], df_shap, list(df["Gene"]))


def _render_single_gene_panel(df_total, feature_cols, explainer, prefix):
    """Show one gene's row, per-class SHAP force plots and external links."""
    input_single = st.text_input("Input a single HGNC gene:", key=f"{prefix}_single_gene")
    gene = input_single.strip()
    if not gene:
        return
    if " " in gene or "," in gene:
        st.warning("Please enter only one gene name (no spaces or commas).")
        return
    if gene not in df_total.index:
        st.warning(f"Gene '{gene}' not found in the dataset.")
        return

    df2 = df_total.loc[[gene]].copy()
    df2.insert(0, "Gene", df2.index)
    st.dataframe(df2)

    df2_shap = df_total.loc[[gene], feature_cols]
    sv2 = normalise_shap(explainer.shap_values(df2_shap))
    for i, cname in enumerate(CLASS_NAMES):
        st.subheader(f"Force Plot — {cname}")
        fp = shap.force_plot(
            explainer.expected_value[i], sv2[i], df2_shap,
            matplotlib=True, show=False,
        )
        # Some SHAP versions return the active pyplot figure rather than a Figure object.
        if fp is not None and hasattr(fp, "savefig"):
            st.pyplot(fp)
            plt.close(fp)
        else:
            show_current_figure()

    encoded = quote(gene, safe="")
    url = (
        "https://astrazeneca-cgr-publications.github.io/DrugnomeAI/"
        f"geneview.html?gene={encoded}"
    )
    genecards_url = f"https://www.genecards.org/cgi-bin/carddisp.pl?gene={encoded}"
    st.markdown(
        f"[{gene} Druggability in DrugnomeAI]({url}) | "
        f"[{gene} in GeneCards]({genecards_url})"
    )


def _render_full_results(df_total, prefix):
    """Show a preview of the full prioritisation table plus a download of all rows."""
    st.markdown("### Full Prioritisation Results (All Genes)")
    out = df_total.copy()
    out.insert(0, "Gene", out.index)
    preview_rows = 500
    st.dataframe(out.head(preview_rows))
    st.caption(f"Showing first {min(preview_rows, len(out))} rows of {len(out)} total genes.")
    st.download_button(
        "Download All Genes",
        to_csv_bytes(out),
        "all_genes_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_all",
    )


def render_gene_prioritisation(annotations, model, explainer, prefix=""):
    """Render the multi-gene, single-gene and full-table prioritisation panels.

    Args:
        annotations: gene-indexed feature table fed to ``model.predict_proba``
            and the SHAP explainer.
        model: fitted classifier whose ``predict_proba`` yields three classes.
        explainer: SHAP explainer built for *model*.
        prefix: namespace for Streamlit widget keys, so the panel can be
            rendered in more than one tab.

    Returns:
        The prioritisation table (probability columns, then feature columns).
    """
    df_total = build_prioritisation_table(annotations, model)
    feature_cols = [c for c in df_total.columns if c not in PROBABILITY_COLUMNS]

    # ── Multi-gene query ─────────────────────────────────────────────────────
    input_text = st.text_input(
        "Input multiple HGNC genes (comma or space separated):",
        key=f"{prefix}_multi_gene",
    )
    gene_list = parse_gene_list(input_text)
    if len(gene_list) > 1:
        _render_multi_gene_panel(df_total, gene_list, feature_cols, explainer, prefix)

    # ── Single-gene force plot ───────────────────────────────────────────────
    _render_single_gene_panel(df_total, feature_cols, explainer, prefix)

    # ── Full results table ───────────────────────────────────────────────────
    _render_full_results(df_total, prefix)
    return df_total


# ── BP clustering sub-tab ────────────────────────────────────────────────────
def render_bp_shap_clustering():
    """Plot BP genes in 2-D PCA space of their Most-Likely SHAP profiles, with overlays."""
    st.subheader("Supervised SHAP Clustering (PCA)")
    st.markdown(
        "Each point is a gene positioned using PCA on its SHAP profile for the **Most Likely** class, "
        "so clustering reflects similarity in model explanation patterns rather than raw feature values. "
        "Genes that appear close together are being prioritised for similar reasons by the model, while genes far apart "
        "are driven by different feature contributions. Use the overlays to see whether your queried genes align with "
        "the most-likely training-gene region or form distinct groups."
    )

    training_genes = load_most_likely_training_genes()
    if training_genes is None:
        st.info("training_cleaned.csv not found — most-likely training gene overlay disabled.")

    genes, sv_pca, clust_labels = compute_bp_shap_pca()

    # Reuse genes typed in the Gene Prioritisation tab (same BP tab scope)
    # so users can see their queried genes on the clustering view without retyping.
    gp_genes = parse_gene_list(st.session_state.get("bp_multi_gene", ""))
    cluster_input = st.text_input(
        "Highlight additional genes (comma/space separated):", key="bp_cluster_genes"
    )
    cluster_genes = parse_gene_list(cluster_input)
    # Combine and deduplicate, preserving input order.
    highlight_genes = list(dict.fromkeys(gp_genes + cluster_genes))
    if gp_genes:
        st.caption(
            f"Including {len(gp_genes)} gene(s) from the Gene Prioritisation input in this plot."
        )

    df_plot = pd.DataFrame({
        "PCA_1": sv_pca[:, 0],
        "PCA_2": sv_pca[:, 1],
        "Cluster": clust_labels.astype(str),
        "Gene": genes,
        "SpecialGroup": "None",
    })
    if training_genes:
        df_plot.loc[
            df_plot["Gene"].isin(training_genes), "SpecialGroup"
        ] = "Most Likely Training Gene"
    if highlight_genes:
        # Applied after the training overlay so user genes take precedence.
        user_mask = df_plot["Gene"].isin(highlight_genes)
        df_plot.loc[user_mask, "SpecialGroup"] = "User Input Gene"
        found_genes = set(df_plot.loc[user_mask, "Gene"])
        missing_genes = sorted(set(highlight_genes) - found_genes)
        if missing_genes:
            st.caption("Not found in this feature set: " + ", ".join(missing_genes))

    fig_c = go.Figure()
    cluster_colors = ["#ADD8E6", "#87CEEB", "#1E90FF"]
    for i, cluster in enumerate(sorted(df_plot["Cluster"].unique())):
        sub = df_plot[(df_plot["Cluster"] == cluster) & (df_plot["SpecialGroup"] == "None")]
        fig_c.add_trace(go.Scatter(
            x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
            name=f"Cluster {cluster}", text=sub["Gene"],
            marker=dict(color=cluster_colors[i % len(cluster_colors)]),
            hoverinfo="text+x+y",
        ))
    for group, colour in [
        ("Most Likely Training Gene", "black"),
        ("User Input Gene", "purple"),
    ]:
        sub = df_plot[df_plot["SpecialGroup"] == group]
        if not sub.empty:
            fig_c.add_trace(go.Scatter(
                x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
                name=group, text=sub["Gene"],
                marker=dict(color=colour), hoverinfo="text+x+y",
            ))
    fig_c.update_layout(
        title="Supervised SHAP Clustering with PCA",
        xaxis_title="First Principal Component",
        yaxis_title="Second Principal Component",
        legend_title_text="Gene Category",
    )
    st.plotly_chart(fig_c, use_container_width=True)


# ── User-defined model tab ───────────────────────────────────────────────────
def _load_user_model(model_file):
    """Unpickle the uploaded model file.

    SECURITY: pickle.load can execute arbitrary code embedded in the file.
    This runs on user-supplied bytes, so only trusted model files should be
    uploaded (the UI says so next to the uploader).
    """
    # getvalue() returns the full upload regardless of the stream position.
    return pickle.load(io.BytesIO(model_file.getvalue()))


def _align_user_inputs(model, annotations):
    """Check *model* meets the app's 3-class contract and align *annotations* to it.

    Returns:
        ``(aligned_annotations, error_message)``; ``error_message`` is None
        on success. When the model exposes ``feature_names_in_``, columns are
        reordered to match it and any extra columns are dropped.
    """
    if not hasattr(model, "predict_proba"):
        return annotations, f"`{type(model).__name__}` has no `predict_proba` method."
    classes = getattr(model, "classes_", None)
    if classes is not None and len(classes) != len(PROBABILITY_COLUMNS):
        return annotations, (
            f"Model predicts {len(classes)} classes, but exactly "
            f"{len(PROBABILITY_COLUMNS)} are required "
            "(Most Likely, Probable, Least Likely)."
        )
    expected = getattr(model, "feature_names_in_", None)
    if expected is None:
        return annotations, None
    expected = [str(f) for f in expected]
    missing = [f for f in expected if f not in annotations.columns]
    if missing:
        preview = ", ".join(missing[:10]) + (" …" if len(missing) > 10 else "")
        return annotations, (
            f"Features CSV is missing {len(missing)} column(s) the model was "
            f"trained on: {preview}"
        )
    return annotations[expected], None


def render_user_model_tab():
    """Upload a model + features CSV, validate them and render the shared panel."""
    st.header("User-Defined Model")
    st.markdown(
        """
        Upload your own trained model and imputed gene features to run the same
        gene prioritisation and SHAP analysis.

        **Requirements:**
        - **Model file** (`.pkl`): a scikit-learn–compatible classifier with
          `predict_proba` returning 3 classes ordered as
          **Most Likely (0)**, **Probable (1)**, **Least Likely (2)**.
          Tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) are fully supported.
        - **Features CSV**: imputed gene features with a `Gene` column and one
          numeric column per feature (same format as `all_genes_imputed_features.csv`).
        """
    )

    col_a, col_b = st.columns(2)
    with col_a:
        model_file = st.file_uploader("Upload model (.pkl)", type=["pkl"], key="user_pkl")
    with col_b:
        features_file = st.file_uploader(
            "Upload gene features CSV", type=["csv"], key="user_csv"
        )
    st.caption(
        "Model files are loaded with pickle, which can execute arbitrary code — "
        "only upload models from sources you trust."
    )

    if model_file is None or features_file is None:
        st.info("⬆️ Upload both files above to begin.")
        return

    # Load model
    try:
        user_model = _load_user_model(model_file)
    except Exception as e:
        st.error(f"Could not load model: {e}")
        return

    # Load features
    try:
        user_annotations = pd.read_csv(features_file)
    except Exception as e:
        st.error(f"Could not read features CSV: {e}")
        return
    if "Gene" not in user_annotations.columns:
        st.error("Features CSV must contain a 'Gene' column.")
        return
    user_annotations = user_annotations.fillna(0).set_index("Gene")

    n_uploaded_features = user_annotations.shape[1]
    user_annotations, error = _align_user_inputs(user_model, user_annotations)
    if error:
        st.error(error)
        return
    n_dropped = n_uploaded_features - user_annotations.shape[1]
    if n_dropped:
        st.caption(f"Ignoring {n_dropped} column(s) the model was not trained on.")

    if not is_tree_model(user_model):
        st.warning(
            f"`{type(user_model).__name__}` is not a recognised tree-based model. "
            "SHAP TreeExplainer will be attempted but may fail for non-tree models."
        )

    # Build explainer
    try:
        user_explainer = shap.TreeExplainer(user_model)
    except Exception as e:
        st.error(
            f"Could not create SHAP TreeExplainer: {e}\n"
            "Only tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) "
            "are supported for SHAP analysis."
        )
        return

    st.success(
        f"✅ **Model**: `{type(user_model).__name__}` | "
        f"**Genes**: {user_annotations.shape[0]} | "
        f"**Features**: {user_annotations.shape[1]}"
    )
    render_gene_prioritisation(
        user_annotations, user_model, user_explainer, prefix="user"
    )


# ── App layout ───────────────────────────────────────────────────────────────
st.title("Disease-specific Gene Prioritisation Post-GWAS")
st.markdown(
    "An interactive visualisation of gene prioritisation results from the machine learning pipeline: https://github.com/hlnicholls/GenePrioritiser. "
    "Use the **Blood Pressure** tab for the pre-trained BP model, "
    "or the **User-Defined Model** tab to upload your own model and run the same analysis."
)

top_tab_bp, top_tab_user = st.tabs(
    [
        "🩺 Blood Pressure (Pre-trained)",
        "📂 User-Defined Model",
    ]
)

# ════════════════════════════════════════════════════════════════════════════
# TAB 1 — Blood Pressure (Pre-trained)
# ════════════════════════════════════════════════════════════════════════════
with top_tab_bp:
    st.header("Blood Pressure Gene Prioritisation")
    st.markdown(
        "Pre-trained model using SBP, DBP and PP GWAS data. "
        "This tab performs prioritisation of genes using a model trained on the "
        "blood pressure GWAS from Keaton et al. "
        "([https://doi.org/10.1038/s41588-024-01714-w](https://doi.org/10.1038/s41588-024-01714-w)). "
        "Three-class prediction: **Most Likely**, **Probable**, **Least Likely**."
    )

    annotations_bp, model_bp, explainer_bp = load_bp_assets()

    sub_gp, sub_cluster = st.tabs(
        ["Gene Prioritisation & SHAP", "Supervised SHAP Clustering"]
    )
    with sub_gp:
        df_total_bp = render_gene_prioritisation(
            annotations_bp, model_bp, explainer_bp, prefix="bp"
        )
    with sub_cluster:
        render_bp_shap_clustering()

# ════════════════════════════════════════════════════════════════════════════
# TAB 2 — User-Defined Model
# ════════════════════════════════════════════════════════════════════════════
with top_tab_user:
    render_user_model_tab()