import streamlit as st
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import logging
import textwrap

from .utils import (
    display_plot_with_download,
    display_interactive_spatial_plot,
    display_plotly_with_download,
)

logger = logging.getLogger(__name__)

# Column width used to wrap long reaction / pathway names in plot titles.
_TITLE_WRAP_WIDTH = 40


def init_plot_state():
    """Initialize plot caching state variables."""
    if "plot_cache" not in st.session_state:
        st.session_state.plot_cache = {}
    if "last_viz_params" not in st.session_state:
        st.session_state.last_viz_params = {}


def _detect_viz_change_and_clear():
    """Detect if visualization parameters changed and clear matplotlib cache.

    Returns:
        True when the analysis type or plot mode differs from the previous run
        (all open matplotlib figures are closed in that case), else False.
    """
    current_params = {
        'viz_choice': st.session_state.get('sp_viz_choice', 'Domains'),
        'plot_mode': st.session_state.get('sp_mode', 'Static'),
    }
    last_params = st.session_state.get('sp_last_params', {})
    if current_params != last_params:
        st.session_state.sp_last_params = current_params
        plt.close('all')  # Close all matplotlib figures
        return True
    return False


def _wrap_title(text):
    """Return ``str(text)`` wrapped to ``_TITLE_WRAP_WIDTH`` columns."""
    return textwrap.fill(str(text), width=_TITLE_WRAP_WIDTH)


def _reaction_label(adata, rx):
    """Return the display name of reaction ``rx``.

    Uses ``var["rxn_full_names"]`` when that column exists and holds a
    non-missing value for ``rx``; otherwise falls back to the reaction id.
    """
    if "rxn_full_names" in adata.var.columns and rx in adata.var_names:
        full_name = adata.var.loc[rx, "rxn_full_names"]
        if pd.notna(full_name):
            return str(full_name)
    return str(rx)


def _pathway_average(adata, pathway):
    """Return the per-spot mean flux over reactions whose ``subsystems`` equals ``pathway``."""
    rx_list = adata.var[adata.var["subsystems"] == pathway].index.tolist()
    X_sub = adata[:, rx_list].X
    if hasattr(X_sub, "toarray"):  # densify sparse matrices before averaging
        X_sub = X_sub.toarray()
    return np.asarray(X_sub.mean(axis=1)).ravel()


def _current_page(state_key, n_pages):
    """Return the 1-based page stored in ``st.session_state[state_key]``.

    The stored value is reset to 1 when missing or outside ``[1, n_pages]``.
    This covers a shrunken selection and stepping below page 1; the latter
    previously produced an empty slice and crashed figure creation.
    """
    page = st.session_state.get(state_key, 1)
    if not 1 <= page <= n_pages:
        page = 1
    st.session_state[state_key] = page
    return page


def _subplot_grid(n_items):
    """Create a figure holding ``n_items`` axes in at most two columns.

    Returns ``(fig, axes)``. ``axes`` is always 2-D, and cells beyond
    ``n_items`` are hidden. ``n_items`` must be >= 1.
    """
    n_cols = min(2, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 7 * n_rows), squeeze=False
    )
    for j in range(n_items, n_rows * n_cols):
        axes[j // n_cols, j % n_cols].axis("off")
    return fig, axes


def _render_page_controls(state_key, n_pages, caption,
                          prev_label, prev_key, next_label, next_key):
    """Draw Prev / caption / Next controls that step ``st.session_state[state_key]``.

    The buttons are disabled at the first / last page so the page never leaves
    ``[1, n_pages]``.
    """
    page = st.session_state[state_key]
    left, middle, right = st.columns([1, 2, 1])
    if left.button(prev_label, key=prev_key, disabled=page <= 1):
        st.session_state[state_key] = page - 1
        st.rerun()
    middle.markdown(caption, unsafe_allow_html=True)
    if right.button(next_label, key=next_key, disabled=page >= n_pages):
        st.session_state[state_key] = page + 1
        st.rerun()


def _select_reactions(adata, plot_mode):
    """Draw the reaction picker and return the selected reaction ids (``var_names``)."""
    if "rxn_full_names" in adata.var.columns:
        # Offer each distinct full name once, mapped to the first reaction id
        # that carries it.
        name_to_id = {}
        for rxn_id, full_name in zip(adata.var_names, adata.var["rxn_full_names"]):
            name_to_id.setdefault(str(full_name), rxn_id)
        options = sorted(name_to_id)
        if plot_mode == "Interactive":
            chosen = st.selectbox("Select Reaction:", options=options, key="sp_rx_single")
            return [name_to_id[chosen]] if chosen else []
        chosen = st.multiselect(
            "Select Reactions:", options=options, default=options[:1], key="sp_rx_multi"
        )
        return [name_to_id[n] for n in chosen if n in name_to_id]

    options = adata.var_names.tolist()
    if plot_mode == "Interactive":
        chosen = st.selectbox("Select Reaction:", options=options, key="sp_rx_single")
        return [chosen] if chosen else []
    return st.multiselect(
        "Select Reactions:", options=options, default=options[:1], key="sp_rx_multi"
    )


def _select_pathways(adata, plot_mode):
    """Draw the pathway picker and return the selected ``subsystems`` values."""
    if "subsystems" not in adata.var.columns:
        st.warning("No pathway data.")
        return []
    options = sorted(p for p in adata.var["subsystems"].unique() if pd.notna(p))
    if plot_mode == "Interactive":
        chosen = st.selectbox("Select Pathway:", options=options, key="sp_path_single")
        return [chosen] if chosen else []
    return st.multiselect(
        "Select Pathways:", options=options, default=options[:1], key="sp_path_multi"
    )


def _render_domains(adata, plot_mode, spot_size, img_key):
    """Plot the ``domain`` annotation, either interactively or as a static figure."""
    if plot_mode == "Interactive":
        display_interactive_spatial_plot(
            adata,
            color_key="domain",
            spot_size=spot_size,
            plot_name="spatial_domain_plotly",
            title="Domain Assignment",
            help_text=(
                "This map highlights the spatial domains assigned by clustering spots "
                "with similar metabolic flux patterns. It shows the geographical "
                "organization of the tissue's metabolic environment."
            ),
        )
        return

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sc.pl.spatial(adata, img_key=img_key, color=['domain'], size=spot_size, show=False, ax=ax)
        display_plot_with_download(
            fig,
            "spatial_domain",
            help_text=(
                "This map shows the spatial distribution of metabolic domains across "
                "the tissue. Each domain represents a cluster of spots with similar "
                "metabolic flux profiles."
            ),
        )
    finally:
        plt.close(fig)


def _render_pathways(adata, selected_items, plot_mode, spot_size, img_key):
    """Plot per-spot mean flux for the selected pathways.

    Interactive mode shows the first selection only. Static mode shows a
    paginated grid of 4 pathways per page.
    """
    if not selected_items:
        st.info("Please select a pathway.")
        return

    if plot_mode == "Interactive":
        target = selected_items[0]
        scratch_col = f'temp_{target}'
        adata.obs[scratch_col] = _pathway_average(adata, target)
        try:
            display_interactive_spatial_plot(
                adata,
                color_key=scratch_col,
                spot_size=spot_size,
                plot_name=f"spatial_{target}_avg_plotly",
                title=_wrap_title(f"Pathway: {target}"),
                help_text=(
                    f"This interactive map shows the averaged flux distribution for the "
                    f"**{target}** pathway. High intensity regions highlight where this "
                    f"metabolic process is most active within the tissue."
                ),
            )
        finally:
            # Never leave the scratch column behind on the caller's AnnData.
            if scratch_col in adata.obs:
                del adata.obs[scratch_col]
        return

    per_page = 4
    n_pages = (len(selected_items) + per_page - 1) // per_page
    page = _current_page("sp_path_page", n_pages)
    curr_items = selected_items[(page - 1) * per_page: page * per_page]

    fig, axes = _subplot_grid(len(curr_items))
    try:
        for ax, target in zip(axes.flat, curr_items):
            adata.obs['tmp_avg'] = _pathway_average(adata, target)
            sc.pl.spatial(adata, img_key=img_key, color=['tmp_avg'], size=spot_size,
                          cmap='jet', show=False, ax=ax)
            ax.set_title(_wrap_title(target), fontsize=12)
        plt.tight_layout()
        target_names = ", ".join(str(t) for t in curr_items)
        display_plot_with_download(
            fig,
            f"spatial_pathway_p{page}",
            help_text=(
                f"This spatial flux map visualizes the spatial distribution of averaged "
                f"flux for the pathways: **{target_names}**. It helps localize pathway "
                f"activities within the tissue."
            ),
        )
    finally:
        plt.close(fig)
        if 'tmp_avg' in adata.obs:
            del adata.obs['tmp_avg']

    if n_pages > 1:
        # NOTE(review): this caption's HTML markup appears to have been lost; only its text survives.
        _render_page_controls(
            "sp_path_page", n_pages, f"\nPathway Page {page} / {n_pages}\n",
            "Prev Pathway", "pw_prev", "Next Pathway", "pw_next",
        )


def _render_reactions(adata, selected_items, plot_mode, spot_size, img_key):
    """Plot flux for the selected reactions.

    Interactive mode shows the first selection only. Static mode shows a
    paginated grid of 8 reactions per page.
    """
    if plot_mode == "Interactive":
        target = selected_items[0]
        display_title = _reaction_label(adata, target)
        display_interactive_spatial_plot(
            adata,
            color_key=target,
            spot_size=spot_size,
            plot_name=f"spatial_{target}_plotly",
            title=_wrap_title(f"Reaction: {display_title}"),
            help_text=(
                f"This interactive spatial map visualizes the flux distribution for the "
                f"reaction **{display_title}**. You can explore its metabolic activity "
                f"across different spatial domains."
            ),
        )
        return

    per_page = 8
    n_pages = (len(selected_items) + per_page - 1) // per_page
    page = _current_page("spatial_flux_page", n_pages)
    curr_rx = selected_items[(page - 1) * per_page: page * per_page]
    labels = [_reaction_label(adata, rx) for rx in curr_rx]

    fig, axes = _subplot_grid(len(curr_rx))
    try:
        for ax, rx, label in zip(axes.flat, curr_rx, labels):
            sc.pl.spatial(adata, img_key=img_key, color=[rx], size=spot_size,
                          cmap='jet', show=False, ax=ax)
            ax.set_title(_wrap_title(label), fontsize=10)
            ax.axis('off')
        plt.tight_layout()
        rx_names_str = ", ".join(labels)
        display_plot_with_download(
            fig,
            f"spatial_flux_p{page}",
            help_text=(
                f"These maps show the spatial distribution of flux for: **{rx_names_str}**, "
                f"allowing visualization of where specific metabolic processes are active."
            ),
        )
    finally:
        plt.close(fig)

    if n_pages > 1:
        # NOTE(review): this caption's HTML markup appears to have been lost; only its text survives.
        _render_page_controls(
            "spatial_flux_page", n_pages, f"\nReaction Page {page} of {n_pages}\n",
            "Previous Page", "sf_prev", "Next Page", "sf_next",
        )


def render_spatial_flux_map(metabolic_adata):
    """Render spatial flux maps with Red theme.

    Draws a control row (analysis type, item picker, spot size, plot mode).
    It then plots domains, selected reactions, or per-spot pathway averages on
    the tissue image stored in ``metabolic_adata.uns["spatial"]``.

    Args:
        metabolic_adata: AnnData-like object. Expected contents:
            - a ``domain`` annotation for the Domains view;
            - optionally ``var["rxn_full_names"]`` (reaction display names);
            - optionally ``var["subsystems"]`` (pathway grouping).
            Temporary ``obs`` columns are added while plotting and always removed.

    Any error raised while plotting is logged and shown via ``st.error``.
    """
    init_plot_state()
    _detect_viz_change_and_clear()
    # NOTE(review): this header's HTML markup appears to have been lost; only its text survives.
    st.markdown("\nSpatial Metabolic flux\n", unsafe_allow_html=True)

    # The selectbox is not drawn yet, so size the columns from the last run's value.
    viz_choice = st.session_state.get("sp_viz_choice", "Domains")
    if viz_choice == "Domains":
        c1, c2, c3 = st.columns([1.5, 1.2, 1.3])
        c4 = None
    else:
        c1, c2, c3, c4 = st.columns([1.2, 1.8, 1.0, 1.2])

    with c1:
        viz_choice = st.selectbox(
            "Analysis Type:", options=["Domains", "Reactions", "Pathways"], key="sp_viz_choice"
        )
    mode_col, size_col = (c3, c2) if viz_choice == "Domains" else (c4, c3)
    with mode_col:
        plot_mode = st.radio("Plot Mode:", ["Static", "Interactive"], horizontal=True, key="sp_mode")
    with size_col:
        if plot_mode == "Static":
            spot_size = st.slider("Spot Size:", 0.5, 5.0, 1.2, 0.5, key="sp_spot_size_static")
        else:
            spot_size = st.slider("Spot Size:", 1, 20, 6, key="sp_spot_size_interactive")

    selected_items = []
    if viz_choice != "Domains":
        with c2:
            if viz_choice == "Reactions":
                selected_items = _select_reactions(metabolic_adata, plot_mode)
            elif viz_choice == "Pathways":
                selected_items = _select_pathways(metabolic_adata, plot_mode)

    try:
        library_id = next(iter(metabolic_adata.uns["spatial"]))
        images = metabolic_adata.uns["spatial"][library_id]["images"]
        img_key = "hires" if "hires" in images else "downscaled_fullres"

        if viz_choice == "Domains":
            _render_domains(metabolic_adata, plot_mode, spot_size, img_key)
        elif viz_choice == "Pathways":
            _render_pathways(metabolic_adata, selected_items, plot_mode, spot_size, img_key)
        elif selected_items:
            _render_reactions(metabolic_adata, selected_items, plot_mode, spot_size, img_key)
    except Exception as e:
        logger.exception("Failed to render spatial flux map")
        st.error(f"Error: {e}")