Spaces:
Sleeping
Sleeping
| """ | |
Blood Pressure Gene Prioritisation — Streamlit App
===================================================
Two top-level tabs:
1. Blood Pressure (Pre-trained) — live prioritisation + SHAP from local BP assets
2. User-Defined Model — upload your own .pkl + features CSV
| """ | |
| import io | |
| import re | |
| import numpy as np | |
| import pandas as pd | |
| import pickle | |
| import streamlit as st | |
| import shap | |
| import plotly.graph_objects as go | |
| from sklearn.cluster import KMeans | |
| from sklearn.decomposition import PCA | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
# Browser-tab title, icon and wide layout for the whole app.
st.set_page_config(
    page_title="Blood Pressure Gene Prioritisation",
    # Stethoscope emoji; the previous value was a mis-decoded copy of it.
    page_icon="🩺",
    layout="wide",
)

# Random seed shared by the stochastic steps further down (e.g. KMeans).
seed = 0
# ── Helpers ──────────────────────────────────────────────────────────────────
def parse_gene_list(text: str):
    """Split free-text input into a list of gene symbols.

    Symbols may be separated by commas, whitespace, or any mix of the two.
    Empty fragments (from leading/trailing/repeated separators) are dropped,
    so blank input yields an empty list.
    """
    tokens = re.split(r"[,\s]+", text.strip())
    return list(filter(None, tokens))
def to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* as CSV (index omitted) and return it UTF-8 encoded."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
def normalise_shap(shap_values):
    """Return per-class SHAP values in list form.

    A 3-D NumPy array is split along its last axis into a list of 2-D
    arrays (one per index of that axis). Any other input is returned
    unchanged, on the assumption that it is already a per-class sequence.
    """
    is_stacked = isinstance(shap_values, np.ndarray) and shap_values.ndim == 3
    if not is_stacked:
        return shap_values
    n_classes = shap_values.shape[2]
    return [shap_values[..., k] for k in range(n_classes)]
def is_tree_model(model) -> bool:
    """Return True when *model* is one of the recognised tree-based classifiers.

    Recognised classes are XGBClassifier, LGBMClassifier, CatBoostClassifier
    and scikit-learn's RandomForestClassifier. Each library is probed on its
    own, so a missing optional dependency only removes its class from the
    check instead of making every model look unrecognised.
    """
    import importlib

    candidates = (
        ("xgboost", "XGBClassifier"),
        ("lightgbm", "LGBMClassifier"),
        ("catboost", "CatBoostClassifier"),
        ("sklearn.ensemble", "RandomForestClassifier"),
    )
    tree_types = []
    for module_name, class_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Library not installed: the model cannot be of this type anyway.
            continue
        cls = getattr(module, class_name, None)
        if isinstance(cls, type):
            tree_types.append(cls)
    # isinstance(x, ()) is False, which covers "no library available".
    return isinstance(model, tuple(tree_types))
def build_prioritisation_table(annotations: pd.DataFrame, model) -> pd.DataFrame:
    """Compute the all-gene prioritisation table from a model and feature table.

    Args:
        annotations: Feature table; its index is reused for the result and the
            frame itself is passed to ``model.predict_proba``.
        model: Object exposing ``predict_proba(annotations)`` that returns one
            row per gene and exactly three columns, taken in order as
            Most likely / Probable / Least likely.

    Returns:
        A new DataFrame with the three probability columns followed by every
        column of ``annotations``.

    Raises:
        ValueError: If ``predict_proba`` does not return a 2-D array with
            three columns.
    """
    probability_columns = [
        "Probability_Most_Likely", "Probability_Probable", "Probability_Least_Likely"
    ]
    probabilities = np.asarray(model.predict_proba(annotations))
    # Fail with an actionable message instead of pandas' generic shape error.
    if probabilities.ndim != 2 or probabilities.shape[1] != len(probability_columns):
        raise ValueError(
            f"Expected predict_proba to return {len(probability_columns)} class "
            f"probabilities per gene, got an array of shape {probabilities.shape}."
        )
    prob_df = pd.DataFrame(
        probabilities, index=annotations.index, columns=probability_columns
    )
    return pd.concat([prob_df, annotations], axis=1)
def show_current_figure():
    """Render the current matplotlib figure in Streamlit, then close it.

    The figure is closed even when rendering raises, so a failed plot cannot
    remain pyplot's "current figure" and be picked up by the next call.
    """
    fig = plt.gcf()
    try:
        st.pyplot(fig)
    finally:
        plt.close(fig)
# ── Load BP assets (cached so they only load once) ───────────────────────────
@st.cache_resource
def load_bp_assets():
    """Load the pre-trained blood-pressure assets from the working directory.

    Reads ``all_genes_imputed_features.csv`` (missing values replaced with 0,
    then indexed by the ``Gene`` column), unpickles ``best_model_fitted.pkl``
    and builds a SHAP ``TreeExplainer`` for the model.

    The result is cached with ``st.cache_resource`` so the files are read and
    the explainer is built once rather than on every script rerun. The cached
    objects are shared between reruns, so callers should treat them as
    read-only.

    Returns:
        Tuple ``(annotations, model, explainer)``.
    """
    annotations = pd.read_csv("all_genes_imputed_features.csv")
    annotations = annotations.fillna(0).set_index("Gene")
    with open("best_model_fitted.pkl", "rb") as f:
        model = pickle.load(f)
    explainer = shap.TreeExplainer(model)
    return annotations, model, explainer
# ── Shared gene-prioritisation + SHAP panel ──────────────────────────────────
def _render_multi_gene_panel(
    df_total, gene_list, explainer, probability_columns, class_names, prefix
):
    """Show probabilities, SHAP summary plots and the top-20 beeswarm for the queried genes."""
    df = df_total[df_total.index.isin(gene_list)].copy()

    # Report symbols the lookup dropped, preserving the order they were typed in.
    found = set(df.index)
    missing = [gene for gene in dict.fromkeys(gene_list) if gene not in found]
    if missing:
        st.caption("Not found in this feature set: " + ", ".join(missing))
    if df.empty:
        # Nothing to explain: SHAP and the summary plots need at least one row.
        st.warning("None of the entered genes were found in the dataset.")
        return

    df.insert(0, "Gene", df.index)
    st.dataframe(df)
    st.download_button(
        "Download Gene Prioritisation",
        to_csv_bytes(df[["Gene"] + probability_columns]),
        "gene_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_multi",
    )

    # SHAP is computed on the feature columns only.
    df_shap = df.drop(columns=probability_columns + ["Gene"])
    sv = normalise_shap(explainer.shap_values(df_shap))

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Global SHAP Summary")
        shap.summary_plot(
            sv, df_shap, plot_type="bar", class_names=class_names, show=False
        )
        show_current_figure()
    with col2:
        st.subheader(f"{class_names[0]} Prediction")
        shap.summary_plot(sv[0], df_shap, show=False)
        show_current_figure()

    col3, col4 = st.columns(2)
    with col3:
        st.subheader(f"{class_names[1]} Prediction")
        shap.summary_plot(sv[1], df_shap, show=False)
        show_current_figure()
    with col4:
        st.subheader(f"{class_names[2]} Prediction")
        shap.summary_plot(sv[2], df_shap, show=False)
        show_current_figure()

    # Interactive beeswarm (top-20 features, class 0)
    st.subheader("Interactive SHAP Plot (Top 20 Features — Most Likely Class)")
    fi = np.abs(sv[0]).mean(axis=0)
    # Indices of the 20 largest mean |SHAP| values, most important first.
    order = np.argsort(fi)[::-1][:20]
    features_top = df_shap.columns[order]
    sv_top = sv[0][:, order]
    genes_in_plot = list(df["Gene"])

    x_vals, y_vals, hover = [], [], []
    for col, fname in enumerate(features_top):
        for gene, val in zip(genes_in_plot, sv_top[:, col]):
            x_vals.append(val)
            y_vals.append(fname)
            hover.append(f"{gene}: {val:.3f}")

    fig_bee = go.Figure(data=go.Scatter(
        x=x_vals, y=y_vals, mode="markers",
        marker=dict(color=x_vals, colorbar=dict(title="SHAP"),
                    colorscale=[(0, "blue"), (1, "red")]),
        text=hover, hoverinfo="text+x",
    ))
    fig_bee.update_layout(
        title="SHAP Summary — Top 20 Features",
        xaxis_title="SHAP Value",
        yaxis=dict(autorange="reversed", title="Feature"),
        showlegend=False,
    )
    st.plotly_chart(fig_bee, use_container_width=True)


def _render_single_gene_panel(
    df_total, explainer, probability_columns, class_names, prefix
):
    """Show one gene's probabilities, per-class force plots and external links."""
    raw_input = st.text_input(
        "Input a single HGNC gene:", key=f"{prefix}_single_gene"
    )
    # Ignore stray leading/trailing whitespace so "AGT " is treated as "AGT".
    input_single = raw_input.strip()
    if not input_single:
        return
    if " " in input_single or "," in input_single:
        st.warning("Please enter only one gene name (no spaces or commas).")
        return
    if input_single not in df_total.index:
        st.warning(f"Gene '{input_single}' not found in the dataset.")
        return

    df2 = df_total.loc[[input_single]].copy()
    df2.insert(0, "Gene", df2.index)
    st.dataframe(df2)

    df2_shap = df_total.loc[
        [input_single],
        [c for c in df_total.columns if c not in probability_columns],
    ]
    sv2 = normalise_shap(explainer.shap_values(df2_shap))
    for i, cname in enumerate(class_names):
        st.subheader(f"Force Plot — {cname}")
        fp = shap.force_plot(
            explainer.expected_value[i], sv2[i], df2_shap,
            matplotlib=True, show=False,
        )
        # Some SHAP versions return the active pyplot figure rather than a Figure object.
        if fp is not None and hasattr(fp, "savefig"):
            st.pyplot(fp)
            plt.close(fp)
        else:
            show_current_figure()

    url = (
        "https://astrazeneca-cgr-publications.github.io/DrugnomeAI/"
        f"geneview.html?gene={input_single}"
    )
    genecards_url = f"https://www.genecards.org/cgi-bin/carddisp.pl?gene={input_single}"
    st.markdown(
        f"[{input_single} Druggability in DrugnomeAI]({url}) | "
        f"[{input_single} in GeneCards]({genecards_url})"
    )


def render_gene_prioritisation(annotations, model, explainer, prefix=""):
    """Render the gene-prioritisation panel and return the full results table.

    The panel has three parts, drawn in this order: a multi-gene query with
    SHAP summary plots, a single-gene query with per-class force plots, and a
    preview/download of the table for all genes.

    Args:
        annotations: Feature table indexed by gene; passed to the model and to
            the SHAP explainer.
        model: Fitted classifier exposing ``predict_proba``; its three output
            columns are labelled Most likely / Probable / Least likely.
        explainer: SHAP explainer providing ``shap_values`` and
            ``expected_value``.
        prefix: Prepended to every widget key so the panel can be rendered
            more than once in the same app.

    Returns:
        DataFrame with the probability columns followed by the feature
        columns, one row per gene in ``annotations``.
    """
    probability_columns = [
        "Probability_Most_Likely", "Probability_Probable", "Probability_Least_Likely"
    ]
    class_names = ["Most likely", "Probable", "Least likely"]
    df_total = build_prioritisation_table(annotations, model)

    # ── Multi-gene query ──────────────────────────────────────────────────────
    input_text = st.text_input(
        "Input multiple HGNC genes (comma or space separated):",
        key=f"{prefix}_multi_gene",
    )
    gene_list = parse_gene_list(input_text)
    # A single symbol is handled by the dedicated single-gene box below.
    if len(gene_list) > 1:
        _render_multi_gene_panel(
            df_total, gene_list, explainer, probability_columns, class_names, prefix
        )

    # ── Single-gene force plot ────────────────────────────────────────────────
    _render_single_gene_panel(
        df_total, explainer, probability_columns, class_names, prefix
    )

    # ── Full results table ────────────────────────────────────────────────────
    st.markdown("### Full Prioritisation Results (All Genes)")
    out = df_total.copy()
    out.insert(0, "Gene", out.index)
    st.dataframe(out.head(500))
    # Report the number of rows actually displayed when the table is short.
    st.caption(
        f"Showing first {min(500, len(out))} rows of {len(out)} total genes."
    )
    st.download_button(
        "Download All Genes",
        to_csv_bytes(out),
        "all_genes_prioritisation.csv",
        "text/csv",
        key=f"{prefix}_dl_all",
    )
    return df_total
# ── App layout ───────────────────────────────────────────────────────────────
# Page heading and short introduction shown above both tabs.
st.title("Disease-specific Gene Prioritisation Post-GWAS")
st.markdown(
    "An interactive visualisation of gene prioritisation results from the machine learning pipeline: https://github.com/hlnicholls/GenePrioritiser. "
    "Use the **Blood Pressure** tab for the pre-trained BP model, "
    "or the **User-Defined Model** tab to upload your own model and run the same analysis."
)

# One top-level tab per workflow; the returned containers are filled below.
top_tab_bp, top_tab_user = st.tabs(
    [
        "🩺 Blood Pressure (Pre-trained)",
        # NOTE(review): the icon on this label is mis-encoded in this copy of
        # the file and the original glyph cannot be recovered — restore it
        # from the upstream source.
        "π User-Defined Model",
    ]
)
# ══════════════════════════════════════════════════════════════════════════════
# TAB 1 — Blood Pressure (Pre-trained)
# ══════════════════════════════════════════════════════════════════════════════
# Contents of the pre-trained blood-pressure tab: the shared prioritisation
# panel plus a PCA/KMeans view of the per-gene SHAP profiles.
with top_tab_bp:
    st.header("Blood Pressure Gene Prioritisation")
    st.markdown(
        "Pre-trained model using SBP, DBP and PP GWAS data. "
        "This tab performs prioritisation of genes using a model trained on the "
        "blood pressure GWAS from Keaton et al. "
        "([https://doi.org/10.1038/s41588-024-01714-w](https://doi.org/10.1038/s41588-024-01714-w)). "
        "Three-class prediction: **Most Likely**, **Probable**, **Least Likely**."
    )
    annotations_bp, model_bp, explainer_bp = load_bp_assets()
    sub_gp, sub_cluster = st.tabs(
        ["Gene Prioritisation & SHAP", "Supervised SHAP Clustering"]
    )

    with sub_gp:
        df_total_bp = render_gene_prioritisation(
            annotations_bp, model_bp, explainer_bp, prefix="bp"
        )

    with sub_cluster:
        st.subheader("Supervised SHAP Clustering (PCA)")
        st.markdown(
            "Each point is a gene positioned using PCA on its SHAP profile for the **Most Likely** class, "
            "so clustering reflects similarity in model explanation patterns rather than raw feature values. "
            "Genes that appear close together are being prioritised for similar reasons by the model, while genes far apart "
            "are driven by different feature contributions. Use the overlays to see whether your queried genes align with "
            "the most-likely training-gene region or form distinct groups."
        )

        # Optional overlay: training genes whose encoded label is 0.
        try:
            training_genes = pd.read_csv("training_cleaned.csv")
            label_col = next(
                (c for c in training_genes.columns if "label_encoded" in c.lower()), None
            )
            if label_col:
                training_genes = training_genes[training_genes[label_col] == 0]
            training_genes = training_genes.set_index("Gene")
        except FileNotFoundError:
            training_genes = pd.DataFrame()
            st.info("training_cleaned.csv not found — most-likely training gene overlay disabled.")

        # NOTE(review): SHAP values for every gene, the PCA and the KMeans fit
        # are recomputed each time this script runs; consider caching them if
        # reruns feel slow.
        shap_full = normalise_shap(explainer_bp.shap_values(annotations_bp))
        sv_arr = np.array(shap_full[0])
        # Seed PCA as well as KMeans so the layout stays reproducible if
        # scikit-learn selects a randomized SVD solver for this input.
        sv_pca = PCA(n_components=2, random_state=seed).fit_transform(sv_arr)
        clust_labels = KMeans(n_clusters=3, random_state=seed).fit_predict(sv_pca)

        # Reuse genes typed in the Gene Prioritisation tab (same BP tab scope)
        # so users can see their queried genes on the clustering view without retyping.
        gp_input_text = st.session_state.get("bp_multi_gene", "")
        gp_genes = parse_gene_list(gp_input_text)
        cluster_input = st.text_input(
            "Highlight additional genes (comma/space separated):", key="bp_cluster_genes"
        )
        cluster_genes = parse_gene_list(cluster_input)
        # Combine and deduplicate, preserving input order.
        highlight_genes = list(dict.fromkeys(gp_genes + cluster_genes))
        if gp_genes:
            st.caption(
                f"Including {len(gp_genes)} gene(s) from the Gene Prioritisation input in this plot."
            )

        df_plot = pd.DataFrame({
            "PCA_1": sv_pca[:, 0], "PCA_2": sv_pca[:, 1],
            "Cluster": clust_labels.astype(str),
            "Gene": annotations_bp.index, "SpecialGroup": "None",
        })
        if not training_genes.empty:
            df_plot.loc[
                df_plot["Gene"].isin(training_genes.index), "SpecialGroup"
            ] = "Most Likely Training Gene"
        if highlight_genes:
            # Applied after the training-gene label, so user genes take precedence.
            is_highlighted = df_plot["Gene"].isin(highlight_genes)
            df_plot.loc[is_highlighted, "SpecialGroup"] = "User Input Gene"

            found_genes = sorted(set(df_plot.loc[is_highlighted, "Gene"]))
            missing_genes = sorted(set(highlight_genes) - set(found_genes))
            if missing_genes:
                st.caption(
                    "Not found in this feature set: " + ", ".join(missing_genes)
                )

        fig_c = go.Figure()
        cluster_colors = ["#ADD8E6", "#87CEEB", "#1E90FF"]
        # Background layer: one trace per cluster, excluding highlighted genes.
        for i, cluster in enumerate(sorted(df_plot["Cluster"].unique())):
            sub = df_plot[(df_plot["Cluster"] == cluster) & (df_plot["SpecialGroup"] == "None")]
            fig_c.add_trace(go.Scatter(
                x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
                name=f"Cluster {cluster}", text=sub["Gene"],
                marker=dict(color=cluster_colors[i % len(cluster_colors)]),
                hoverinfo="text+x+y",
            ))
        # Overlay layers are added last so they are drawn on top.
        for group, colour in [
            ("Most Likely Training Gene", "black"),
            ("User Input Gene", "purple"),
        ]:
            sub = df_plot[df_plot["SpecialGroup"] == group]
            if not sub.empty:
                fig_c.add_trace(go.Scatter(
                    x=sub["PCA_1"], y=sub["PCA_2"], mode="markers",
                    name=group, text=sub["Gene"],
                    marker=dict(color=colour), hoverinfo="text+x+y",
                ))
        fig_c.update_layout(
            title="Supervised SHAP Clustering with PCA",
            xaxis_title="First Principal Component",
            yaxis_title="Second Principal Component",
            legend_title_text="Gene Category",
        )
        st.plotly_chart(fig_c, use_container_width=True)
# ══════════════════════════════════════════════════════════════════════════════
# TAB 2 — User-Defined Model
# ══════════════════════════════════════════════════════════════════════════════
# Contents of the user-defined-model tab: upload a pickled classifier and a
# features CSV, validate them, then render the shared prioritisation panel.
with top_tab_user:
    st.header("User-Defined Model")
    st.markdown(
        """
        Upload your own trained model and imputed gene features to run the same
        gene prioritisation and SHAP analysis.

        **Requirements:**

        - **Model file** (`.pkl`): a scikit-learn-compatible classifier with `predict_proba`
          returning 3 classes ordered as **Most Likely (0)**, **Probable (1)**, **Least Likely (2)**.
          Tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) are fully supported.
        - **Features CSV**: imputed gene features with a `Gene` column and one numeric column
          per feature (same format as `all_genes_imputed_features.csv`).
        """
    )
    col_a, col_b = st.columns(2)
    with col_a:
        model_file = st.file_uploader("Upload model (.pkl)", type=["pkl"], key="user_pkl")
    with col_b:
        features_file = st.file_uploader(
            "Upload gene features CSV", type=["csv"], key="user_csv"
        )

    if model_file is None or features_file is None:
        st.info("⬆️ Upload both files above to begin.")
    else:
        # Load model
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # This loads an untrusted upload in the server process — only expose
        # this tab to trusted users, or replace pickle with a safe model
        # format. Flagged for review; behaviour intentionally left unchanged.
        try:
            # getvalue() returns the whole payload regardless of the buffer's
            # current read position.
            user_model = pickle.loads(model_file.getvalue())
        except Exception as e:
            st.error(f"Could not load model: {e}")
            st.stop()

        # The prioritisation table is built from predict_proba, so check for
        # it here instead of failing deeper in the panel.
        if not callable(getattr(user_model, "predict_proba", None)):
            st.error(
                f"`{type(user_model).__name__}` has no `predict_proba` method; "
                "a scikit-learn-compatible classifier is required."
            )
            st.stop()

        # Load features
        try:
            features_file.seek(0)
            user_annotations = pd.read_csv(features_file)
        except Exception as e:
            st.error(f"Could not read features CSV: {e}")
            st.stop()
        if "Gene" not in user_annotations.columns:
            st.error("Features CSV must contain a 'Gene' column.")
            st.stop()
        user_annotations = user_annotations.fillna(0).set_index("Gene")

        if not is_tree_model(user_model):
            st.warning(
                f"`{type(user_model).__name__}` is not a recognised tree-based model. "
                "SHAP TreeExplainer will be attempted but may fail for non-tree models."
            )

        # Build explainer
        try:
            user_explainer = shap.TreeExplainer(user_model)
        except Exception as e:
            st.error(
                f"Could not create SHAP TreeExplainer: {e}\n"
                "Only tree-based models (XGBoost, LightGBM, CatBoost, RandomForest) "
                "are supported for SHAP analysis."
            )
            st.stop()

        # NOTE(review): the leading glyph was mis-encoded in this copy of the
        # file; a check-mark emoji is assumed — confirm against upstream.
        st.success(
            f"✅ **Model**: `{type(user_model).__name__}` | "
            f"**Genes**: {user_annotations.shape[0]} | "
            f"**Features**: {user_annotations.shape[1]}"
        )
        render_gene_prioritisation(
            user_annotations, user_model, user_explainer, prefix="user"
        )