import streamlit as st
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import textwrap
from .utils import display_plot_with_download, display_interactive_spatial_plot, display_plotly_with_download
def _reaction_label_map(var):
    """Map each distinct reaction full name to the first reaction ID that has it.

    Args:
        var: The AnnData ``.var`` frame. It must contain an ``rxn_full_names``
            column; its index supplies the reaction IDs.

    Returns:
        dict mapping ``str(full_name)`` -> reaction ID. When several reactions
        share a full name, only the first one (in ``var`` order) is kept, so
        the others cannot be picked by name.
    """
    label_to_id = {}
    # Walk the column directly rather than DataFrame.iterrows(): same order,
    # without building a Series per row.
    for rxn_id, full_name in zip(var.index, var["rxn_full_names"]):
        label_to_id.setdefault(str(full_name), rxn_id)
    return label_to_id


def _select_umap_items(metabolic_adata, viz_type, plot_mode):
    """Render the reaction/pathway picker for *viz_type* and return the selection.

    Interactive mode plots a single feature, so it offers a selectbox. Static
    mode can tile several panels, so it offers a multiselect preselected with
    the first option.

    Returns:
        list of reaction IDs (``viz_type == "Reaction"``) or pathway names
        (``viz_type == "Pathway"``). Empty when nothing is selected or no
        pathway annotation exists.
    """
    var = metabolic_adata.var
    interactive = plot_mode == "Interactive"

    if viz_type == "Reaction":
        if "rxn_full_names" in var.columns:
            # Users pick by readable name; translate back to the reaction ID.
            label_to_id = _reaction_label_map(var)
            options = sorted(label_to_id)
            if interactive:
                choice = st.selectbox("Select Reaction:", options=options, key="u_rx_single")
                return [] if choice is None else [label_to_id[choice]]
            chosen = st.multiselect("Select Reactions:", options=options, default=options[:1], key="u_rx_multi")
            return [label_to_id[name] for name in chosen if name in label_to_id]
        options = metabolic_adata.var_names.tolist()
        if interactive:
            choice = st.selectbox("Select Reaction:", options=options, key="u_rx_single")
            return [] if choice is None else [choice]
        return list(st.multiselect("Select Reactions:", options=options, default=options[:1], key="u_rx_multi"))

    if viz_type == "Pathway":
        if "subsystems" not in var.columns:
            st.warning("No pathway data.")
            return []
        import pandas as pd

        options = sorted(p for p in var["subsystems"].unique() if pd.notna(p))
        if interactive:
            choice = st.selectbox("Select Pathway:", options=options, key="u_path_single")
            return [] if choice is None else [choice]
        return list(st.multiselect("Select Pathways:", options=options, default=options[:1], key="u_path_multi"))

    return []


def _ensure_umap(adata, n_pcs=50, n_neighbors=15):
    """Run PCA -> neighbor graph -> UMAP on *adata* unless ``obsm['X_umap']`` exists.

    scanpy stores its results (PCA, neighbor graph, ``obsm['X_umap']``) on
    *adata* itself, so later reruns skip the computation.

    The PCA size is capped at ``min(n_pcs, n_obs - 1, n_vars - 1)``. A fixed
    50 components can fail on data with 50 or fewer spots or reactions.

    Raises:
        ValueError: if there are fewer than 2 spots or 2 reactions.
    """
    if "X_umap" in adata.obsm:
        return
    n_obs, n_vars = adata.shape
    n_comps = min(n_pcs, n_obs - 1, n_vars - 1)
    if n_comps < 1:
        raise ValueError(f"Not enough data for UMAP ({n_obs} spots x {n_vars} reactions).")
    with st.spinner("Calculating UMAP..."):
        sc.pp.pca(adata, n_comps=n_comps)
        sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_comps)
        sc.tl.umap(adata)


def _reaction_flux(adata, rxn_id):
    """Return the ``adata.X`` column of reaction *rxn_id* as a dense 1-D array."""
    column = adata.X[:, adata.var_names.get_loc(rxn_id)]
    if hasattr(column, "toarray"):  # sparse matrix
        column = column.toarray()
    return np.asarray(column).ravel()


def _pathway_mean_flux(adata, pathway):
    """Return the per-spot mean of ``adata.X`` over reactions whose ``subsystems`` equals *pathway*."""
    member_cols = np.flatnonzero((adata.var["subsystems"] == pathway).to_numpy())
    block = adata.X[:, member_cols]
    if hasattr(block, "toarray"):  # sparse matrix
        block = block.toarray()
    return np.asarray(block).mean(axis=1).ravel()


def _render_interactive_umap(adata, viz_type, selected_items):
    """Draw a Plotly UMAP scatter colored by domain, by one reaction, or by one pathway.

    For "Reaction"/"Pathway" only ``selected_items[0]`` is plotted; the caller
    guarantees it is non-empty for those view types.
    """
    import pandas as pd
    import plotly.express as px

    help_msg = "Uniform Manifold Approximation and Projection (UMAP) is used for dimensionality reduction. "
    if viz_type == "Domain":
        display_title = "Domain"
        values = adata.obs["domain"].astype(str).values
        help_msg += "Spots are colored by metabolic domain to visualize global functional clustering."
    elif viz_type == "Reaction":
        target = selected_items[0]
        display_title = target
        if "rxn_full_names" in adata.var.columns:
            display_title = adata.var.loc[target, "rxn_full_names"]
        values = _reaction_flux(adata, target)
        help_msg += f"This plot shows the flux distribution of **{display_title}** in the reduced feature space."
    else:  # "Pathway"
        target = selected_items[0]
        display_title = target
        values = _pathway_mean_flux(adata, target)
        help_msg += f"Across the UMAP manifold, we visualize the average flux for the **{target}** pathway."

    continuous = viz_type != "Domain"
    coords = adata.obsm["X_umap"]
    df_umap = pd.DataFrame({
        "UMAP1": coords[:, 0],
        "UMAP2": coords[:, 1],
        "color": values,
        "Domain": adata.obs["domain"].values if "domain" in adata.obs.columns else "N/A",
        "Spot": np.asarray(adata.obs_names),
    })
    fig = px.scatter(
        df_umap, x="UMAP1", y="UMAP2", color="color",
        hover_data=["Domain", "Spot"],
        # None lets Plotly use its default qualitative palette for domains.
        color_continuous_scale="Jet" if continuous else None,
        title=f"UMAP Analysis: {display_title}",
    )
    fig.update_layout(
        template="simple_white",
        width=700, height=700,
        xaxis=dict(showgrid=False, zeroline=False),
        # scaleanchor/scaleratio=1 keeps both UMAP axes on the same scale.
        yaxis=dict(scaleanchor="x", scaleratio=1, showgrid=False, zeroline=False),
    )
    if continuous:
        fig.update_layout(coloraxis_colorbar=dict(title="Flux"))
    else:
        fig.update_layout(legend_title_text="Domain")
    display_plotly_with_download(fig, f"umap_{viz_type}", help_text=help_msg)


def _render_static_domain_umap(adata):
    """Draw a matplotlib UMAP colored by ``obs['domain']``; the figure is always closed."""
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        sc.pl.umap(adata, color=["domain"], show=False, ax=ax, size=100)
        display_plot_with_download(
            fig,
            "umap_domain",
            help_text="This static UMAP shows the distribution of metabolic domains in lower-dimensional space. Spots colored by domain help visualize how well-separated the clustered metabolic regions are."
        )
    finally:
        plt.close(fig)


def _render_static_umap_panels(adata, viz_type, selected_items, per_page=8):
    """Tile one matplotlib UMAP panel per selected reaction/pathway, with paging.

    At most *per_page* panels are drawn per page, in a grid of up to two
    columns. The current page is kept in ``st.session_state.umap_page``, and
    Prev/Next buttons below the figure move between pages. The caller
    guarantees *selected_items* is non-empty.
    """
    pages = (len(selected_items) + per_page - 1) // per_page
    page = st.session_state.get("umap_page", 1)
    # Reset an out-of-range page (e.g. the selection shrank since the last
    # rerun) so the slice below is never empty.
    if not 1 <= page <= pages:
        page = 1
    st.session_state.umap_page = page

    page_items = selected_items[(page - 1) * per_page : page * per_page]
    n_cols = min(2, len(page_items))
    n_rows = (len(page_items) + n_cols - 1) // n_cols
    # squeeze=False keeps `axes` 2-D for every grid shape, including 1x1.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4.5 * n_rows), squeeze=False)
    wrapper = textwrap.TextWrapper(width=40)
    has_full_names = "rxn_full_names" in adata.var.columns
    panel_names = []
    try:
        for i, target in enumerate(page_items):
            ax = axes[i // n_cols, i % n_cols]
            if viz_type == "Reaction":
                sc.pl.umap(adata, color=[target], cmap="jet", show=False, ax=ax, size=80)
                if has_full_names:
                    name = str(adata.var.loc[target, "rxn_full_names"])
                    ax.set_title(wrapper.fill(name), fontsize=10)
                else:
                    name = str(target)
            else:
                # sc.pl.umap colors by obs columns or var names, so the pathway
                # mean is staged in a temporary obs column. The finally block
                # guarantees it is removed even when plotting fails.
                adata.obs["tmp_u"] = _pathway_mean_flux(adata, target)
                try:
                    sc.pl.umap(adata, color=["tmp_u"], cmap="jet", show=False, ax=ax, size=80)
                finally:
                    del adata.obs["tmp_u"]
                name = str(target)
                ax.set_title(wrapper.fill(name), fontsize=10)
            panel_names.append(name)
            ax.axis("off")
        # Blank out unused grid cells (odd number of panels on this page).
        for j in range(len(page_items), n_rows * n_cols):
            axes[j // n_cols, j % n_cols].axis("off")
        plt.tight_layout()
        names_str = ", ".join(panel_names)
        display_plot_with_download(
            fig,
            f"umap_p{page}",
            help_text=f"These static UMAP panels show the flux distribution for: **{names_str}**. It helps identify metabolic hotspots for these specific processes within the reduced manifold."
        )
    finally:
        plt.close(fig)

    if pages > 1:
        prev_col, label_col, next_col = st.columns([1, 2, 1])
        # Buttons are disabled at the ends so the page can never leave [1, pages].
        if prev_col.button("Prev UMAP Page", key="u_prev", disabled=page <= 1):
            st.session_state.umap_page = page - 1
            st.rerun()
        label_col.markdown(f"Page {page} / {pages}", unsafe_allow_html=True)
        if next_col.button("Next UMAP Page", key="u_next", disabled=page >= pages):
            st.session_state.umap_page = page + 1
            st.rerun()


def render_umap_embedding(metabolic_adata):
    """Render the UMAP section: controls plus an interactive or static UMAP plot.

    Spots can be colored by ``obs['domain']``, by a reaction's flux, or by the
    mean flux of a pathway (``var['subsystems']``). If ``obsm['X_umap']`` is
    missing it is computed and stored on *metabolic_adata*. Errors during
    computation or plotting are reported with ``st.error`` instead of being
    raised.
    """
    # NOTE(review): this heading was a string literal broken across lines
    # (likely stripped HTML); restore any project-specific styling if needed.
    st.markdown("<h3>UMAP Analysis</h3>", unsafe_allow_html=True)

    # The column layout depends on the "Color By" choice, but that selectbox
    # lives inside the layout, so read the previous run's value from session
    # state first.
    layout_is_domain = st.session_state.get("u_v_t", "Domain") == "Domain"
    cols = st.columns([1.5, 1.5] if layout_is_domain else [1.2, 1.8, 1.2])
    with cols[0]:
        viz_type = st.selectbox("Color By:", options=["Domain", "Reaction", "Pathway"], key="u_v_t")
    with cols[-1]:
        plot_mode = st.radio("Plot Mode:", ["Static", "Interactive"], horizontal=True, key="u_mode")
    selected_items = []
    if viz_type != "Domain":
        with cols[1]:
            selected_items = _select_umap_items(metabolic_adata, viz_type, plot_mode)

    try:
        _ensure_umap(metabolic_adata)
        if viz_type == "Domain":
            if plot_mode == "Interactive":
                _render_interactive_umap(metabolic_adata, viz_type, selected_items)
            else:
                _render_static_domain_umap(metabolic_adata)
        elif selected_items:
            if plot_mode == "Interactive":
                _render_interactive_umap(metabolic_adata, viz_type, selected_items)
            else:
                _render_static_umap_panels(metabolic_adata, viz_type, selected_items)
    except Exception as e:
        # UI boundary: surface the problem in the page instead of crashing it.
        st.error(f"Error during UMAP visualization: {e}")