import streamlit as st
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import textwrap
from .utils import display_plot_with_download, display_interactive_spatial_plot, display_plotly_with_download

# Number of static UMAP panels drawn per page (laid out two per row).
_UMAP_PANELS_PER_PAGE = 8
# Upper bound on principal components used when the UMAP must be computed on demand.
_UMAP_MAX_PCS = 50


def _to_flat_array(values):
    """Return a dense array, matrix or scipy-sparse matrix as a flat 1-D ndarray."""
    if hasattr(values, "toarray"):
        values = values.toarray()
    return np.asarray(values).ravel()


def _reaction_flux(adata, reaction_id):
    """Return the per-spot values of the ``adata.X`` column named ``reaction_id``."""
    return _to_flat_array(adata.X[:, adata.var_names.get_loc(reaction_id)])


def _pathway_mean_flux(adata, pathway):
    """Return the per-spot mean of ``adata.X`` over reactions whose ``var["subsystems"]`` equals ``pathway``."""
    reactions = adata.var[adata.var["subsystems"] == pathway].index.tolist()
    block = adata[:, reactions].X
    if hasattr(block, "toarray"):
        block = block.toarray()
    return np.asarray(block).mean(axis=1).ravel()


def _select_umap_items(adata, viz_type, plot_mode):
    """Render the reaction/pathway picker and return the user's selection.

    Returns reaction IDs (entries of ``adata.var_names``) for ``"Reaction"``
    and ``var["subsystems"]`` values for ``"Pathway"``; an empty list when
    nothing is selected or no pathway annotation exists. Interactive mode
    offers a single choice (one Plotly figure); static mode allows several.
    """
    single = plot_mode == "Interactive"

    if viz_type == "Reaction":
        if "rxn_full_names" in adata.var.columns:
            # Users pick by full name; keep the first reaction ID carrying each name.
            name_to_id = {}
            for rxn_id, full_name in zip(adata.var_names, adata.var["rxn_full_names"]):
                name_to_id.setdefault(str(full_name), rxn_id)
            options = sorted(name_to_id)
            if single:
                choice = st.selectbox("Select Reaction:", options=options, key="u_rx_single")
                return [] if choice is None else [name_to_id[choice]]
            chosen = st.multiselect("Select Reactions:", options=options, default=options[:1], key="u_rx_multi")
            return [name_to_id[name] for name in chosen if name in name_to_id]

        options = adata.var_names.tolist()
        if single:
            choice = st.selectbox("Select Reaction:", options=options, key="u_rx_single")
            return [] if choice is None else [choice]
        return list(st.multiselect("Select Reactions:", options=options, default=options[:1], key="u_rx_multi"))

    if viz_type == "Pathway":
        if "subsystems" not in adata.var.columns:
            st.warning("No pathway data.")
            return []
        import pandas as pd

        options = sorted(p for p in adata.var["subsystems"].unique() if pd.notna(p))
        if single:
            choice = st.selectbox("Select Pathway:", options=options, key="u_path_single")
            return [] if choice is None else [choice]
        return list(st.multiselect("Select Pathways:", options=options, default=options[:1], key="u_path_multi"))

    return []


def _ensure_umap(adata, max_pcs=_UMAP_MAX_PCS):
    """Compute PCA -> neighbours -> UMAP on ``adata`` in place unless ``obsm["X_umap"]`` exists."""
    if "X_umap" in adata.obsm:
        return
    # PCA needs fewer components than min(n_obs, n_vars); a fixed 50 fails on small data.
    n_pcs = min(max_pcs, adata.n_obs - 1, adata.n_vars - 1)
    with st.spinner("Calculating UMAP..."):
        sc.pp.pca(adata, n_comps=n_pcs)
        sc.pp.neighbors(adata, n_neighbors=15, n_pcs=n_pcs)
        sc.tl.umap(adata)


def _render_interactive_umap(adata, viz_type, selected_items):
    """Draw one Plotly UMAP scatter coloured by domain, a reaction, or a pathway mean."""
    import pandas as pd
    import plotly.express as px

    target = selected_items[0] if selected_items else "Domain"
    display_title = target
    is_domain = viz_type == "Domain"

    if is_domain:
        color_values = adata.obs["domain"].astype(str).values
    elif viz_type == "Reaction":
        color_values = _reaction_flux(adata, target)
        if "rxn_full_names" in adata.var.columns:
            display_title = adata.var.loc[target, "rxn_full_names"]
    else:  # Pathway: average over all reactions in the subsystem
        color_values = _pathway_mean_flux(adata, target)

    coords = adata.obsm["X_umap"]
    df_umap = pd.DataFrame({
        "UMAP1": coords[:, 0],
        "UMAP2": coords[:, 1],
        "color": color_values,
        "Domain": adata.obs["domain"].values if "domain" in adata.obs.columns else "N/A",
        "Spot": adata.obs_names,
    })
    fig = px.scatter(
        df_umap, x="UMAP1", y="UMAP2", color="color",
        hover_data=["Domain", "Spot"],
        # Domain labels are strings -> qualitative palette; flux values -> continuous "Jet".
        color_continuous_scale=None if is_domain else "Jet",
        title=f"UMAP Analysis: {display_title}",
    )

    layout = {
        "template": "simple_white",
        "width": 700,
        "height": 700,
        "xaxis": {"showgrid": False, "zeroline": False},
        # Lock the aspect ratio so UMAP distances are not visually distorted.
        "yaxis": {"scaleanchor": "x", "scaleratio": 1, "showgrid": False, "zeroline": False},
    }
    if is_domain:
        layout["legend_title_text"] = "Domain"
    else:
        layout["coloraxis_colorbar"] = {"title": "Flux"}
    fig.update_layout(**layout)

    help_msg = "Uniform Manifold Approximation and Projection (UMAP) is used for dimensionality reduction. "
    if viz_type == "Reaction":
        help_msg += f"This plot shows the flux distribution of **{display_title}** in the reduced feature space."
    elif viz_type == "Pathway":
        help_msg += f"Across the UMAP manifold, we visualize the average flux for the **{target}** pathway."
    else:
        help_msg += "Spots are colored by metabolic domain to visualize global functional clustering."

    display_plotly_with_download(fig, f"umap_{viz_type}", help_text=help_msg)


def _render_static_domain(adata):
    """Draw a single static Matplotlib UMAP coloured by ``obs["domain"]``."""
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        sc.pl.umap(adata, color=["domain"], show=False, ax=ax, size=100)
        display_plot_with_download(
            fig,
            "umap_domain",
            help_text="This static UMAP shows the distribution of metabolic domains in lower-dimensional space. Spots colored by domain help visualize how well-separated the clustered metabolic regions are.",
        )
    finally:
        plt.close(fig)


def _current_umap_page(n_pages):
    """Return the active static-panel page, clamped to ``[1, n_pages]`` and stored back in session state."""
    page = st.session_state.get("umap_page", 1)
    # Out-of-range pages (selection shrank, or a stale value) restart at page 1.
    if not 1 <= page <= n_pages:
        page = 1
    st.session_state.umap_page = page
    return page


def _render_static_panels(adata, viz_type, selected_items):
    """Draw paged static UMAP panels, one per selected reaction or pathway, plus page controls."""
    n_pages = -(-len(selected_items) // _UMAP_PANELS_PER_PAGE)  # ceiling division
    page = _current_umap_page(n_pages)
    start = (page - 1) * _UMAP_PANELS_PER_PAGE
    current = selected_items[start:start + _UMAP_PANELS_PER_PAGE]

    n_cols = min(2, len(current))
    n_rows = -(-len(current) // n_cols)
    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4.5 * n_rows), squeeze=False)
    flat_axes = axes.ravel()
    wrapper = textwrap.TextWrapper(width=40)
    has_full_names = "rxn_full_names" in adata.var.columns
    panel_names = []

    try:
        for ax, target in zip(flat_axes, current):
            if viz_type == "Reaction":
                sc.pl.umap(adata, color=[target], cmap="jet", show=False, ax=ax, size=80)
                if has_full_names:
                    full_name = str(adata.var.loc[target, "rxn_full_names"])
                    ax.set_title(wrapper.fill(text=full_name), fontsize=10)
                    panel_names.append(full_name)
                else:
                    panel_names.append(str(target))
            else:  # Pathway aggregate, plotted via a scratch obs column
                adata.obs["tmp_u"] = _pathway_mean_flux(adata, target)
                try:
                    sc.pl.umap(adata, color=["tmp_u"], cmap="jet", show=False, ax=ax, size=80)
                finally:
                    # Never leave the scratch column in the caller's AnnData.
                    if "tmp_u" in adata.obs.columns:
                        del adata.obs["tmp_u"]
                ax.set_title(wrapper.fill(text=str(target)), fontsize=10)
                panel_names.append(str(target))
            ax.axis("off")

        for ax in flat_axes[len(current):]:
            ax.axis("off")
        fig.tight_layout()

        names_str = ", ".join(panel_names)
        display_plot_with_download(
            fig,
            f"umap_p{page}",
            help_text=f"These static UMAP panels show the flux distribution for: **{names_str}**. It helps identify metabolic hotspots for these specific processes within the reduced manifold.",
        )
    finally:
        plt.close(fig)

    if n_pages > 1:
        col_prev, col_label, col_next = st.columns([1, 2, 1])
        # Buttons are disabled at the ends so the page can never leave [1, n_pages].
        if col_prev.button("Prev UMAP Page", key="u_prev", disabled=page <= 1):
            st.session_state.umap_page = page - 1
            st.rerun()
        col_label.markdown(f"\nPage {page} / {n_pages}\n", unsafe_allow_html=True)
        if col_next.button("Next UMAP Page", key="u_next", disabled=page >= n_pages):
            st.session_state.umap_page = page + 1
            st.rerun()


def render_umap_embedding(metabolic_adata):
    """Render the "UMAP Analysis" section for ``metabolic_adata``.

    The user colours the embedding by ``obs["domain"]``, by one reaction (a
    column of ``X``), or by the mean of ``X`` over a pathway
    (``var["subsystems"]``), shown as an interactive Plotly figure or as
    paged static Matplotlib panels. When ``obsm["X_umap"]`` is missing it is
    computed in place via scanpy PCA -> neighbours -> UMAP on
    ``metabolic_adata``. Rendering errors are reported with ``st.error``
    rather than raised.
    """
    # NOTE(review): this header was a plain string split across physical lines
    # (a SyntaxError); its visible content is kept verbatim via escapes.
    st.markdown("\n\nUMAP Analysis\n\n", unsafe_allow_html=True)

    # Column layout is chosen from the colour-by value in session state (the
    # selectbox below writes it) because the columns must exist before it renders.
    if st.session_state.get("u_v_t", "Domain") == "Domain":
        c1, c2 = st.columns([1.5, 1.5])
        c3 = None
    else:
        c1, c2, c3 = st.columns([1.2, 1.8, 1.2])

    with c1:
        umap_viz_type = st.selectbox("Color By:", options=["Domain", "Reaction", "Pathway"], key="u_v_t")
    # The plot-mode radio goes in the last column of whichever layout was built.
    with (c2 if c3 is None else c3):
        plot_mode = st.radio("Plot Mode:", ["Static", "Interactive"], horizontal=True, key="u_mode")

    selected_items = []
    if umap_viz_type != "Domain":
        with c2:
            selected_items = _select_umap_items(metabolic_adata, umap_viz_type, plot_mode)

    try:
        _ensure_umap(metabolic_adata)
        if plot_mode == "Interactive" and (umap_viz_type == "Domain" or selected_items):
            _render_interactive_umap(metabolic_adata, umap_viz_type, selected_items)
        elif umap_viz_type == "Domain":
            _render_static_domain(metabolic_adata)
        elif selected_items:
            _render_static_panels(metabolic_adata, umap_viz_type, selected_items)
        else:
            st.info("Select at least one reaction or pathway to plot.")
    except Exception as e:
        st.error(f"Error during UMAP visualization: {e}")