Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import scanpy as sc | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import logging | |
| import textwrap | |
| from .utils import display_plot_with_download, display_interactive_spatial_plot, display_plotly_with_download | |
| logger = logging.getLogger(__name__) | |
def init_plot_state():
    """Ensure the session-state dictionaries used for plot caching exist.

    Creates ``st.session_state.plot_cache`` and
    ``st.session_state.last_viz_params`` as empty dicts the first time this
    runs in a session; values that already exist are left untouched.
    """
    for state_key in ("plot_cache", "last_viz_params"):
        if state_key not in st.session_state:
            st.session_state[state_key] = {}
def _detect_viz_change_and_clear():
    """Close stale matplotlib figures when the visualization settings change.

    Compares the current analysis type (``sp_viz_choice``) and plot mode
    (``sp_mode``) held in session state with the snapshot saved under
    ``sp_last_params``. On a mismatch the snapshot is replaced and every
    open matplotlib figure is closed.

    Returns:
        bool: True if the settings differ from the previous snapshot,
        False otherwise.
    """
    state = st.session_state
    snapshot = {
        'viz_choice': state.get('sp_viz_choice', 'Domains'),
        'plot_mode': state.get('sp_mode', 'Static'),
    }
    if snapshot == state.get('sp_last_params', {}):
        return False
    state.sp_last_params = snapshot
    # Figures from the previous configuration are no longer needed.
    plt.close('all')
    return True
_PATHWAYS_PER_PAGE = 4  # static pathway maps shown per page
_REACTIONS_PER_PAGE = 8  # static reaction maps shown per page
_TITLE_WIDTH = 40  # character width used to wrap plot titles


def render_spatial_flux_map(metabolic_adata):
    """Render spatial flux maps with Red theme.

    Draws the control row (analysis type, item picker, spot size, plot mode)
    and then the requested spatial map(s) of ``metabolic_adata``:

    * ``Domains``   - spots coloured by ``obs['domain']``.
    * ``Reactions`` - flux of individual reactions (``var_names``); static
      mode pages through up to 8 maps at a time.
    * ``Pathways``  - per-spot mean flux over all reactions sharing a
      ``var['subsystems']`` value; static mode pages through up to 4 maps.

    Args:
        metabolic_adata: AnnData-like object with spatial metadata under
            ``uns['spatial']``. Temporary ``obs`` columns created for pathway
            plots are always removed again, even when plotting fails.

    Any exception raised while plotting is logged and reported in the UI via
    ``st.error`` instead of being propagated.
    """
    init_plot_state()
    _detect_viz_change_and_clear()
    st.markdown("<h2 style='color: #d32f2f;'><i class='fas fa-map-location-dot'></i> Spatial Metabolic flux</h2>", unsafe_allow_html=True)

    viz_choice, plot_mode, spot_size, selected_items = _render_controls(metabolic_adata)

    try:
        library_id = next(iter(metabolic_adata.uns["spatial"]))
        images = metabolic_adata.uns["spatial"][library_id]["images"]
        img_key = "hires" if "hires" in images else "downscaled_fullres"

        if viz_choice == "Domains":
            _render_domains(metabolic_adata, plot_mode, spot_size, img_key)
        elif viz_choice == "Pathways":
            _render_pathways(metabolic_adata, selected_items, plot_mode, spot_size, img_key)
        else:
            _render_reactions(metabolic_adata, selected_items, plot_mode, spot_size, img_key)
    except Exception as e:
        # UI boundary: keep the app alive, but do not lose the traceback.
        logger.exception("Failed to render spatial flux map")
        st.error(f"Error: {e}")


def _render_controls(metabolic_adata):
    """Draw the control row; return ``(viz_choice, plot_mode, spot_size, selected_items)``."""
    # The layout must be chosen before the selectbox exists, so it follows the
    # value Streamlit stored for the widget on the previous interaction. All
    # placements below use this one flag, so the column unpacking can never
    # disagree with the columns that were created.
    wide_layout = st.session_state.get("sp_viz_choice", "Domains") != "Domains"
    if wide_layout:
        c1, c2, c3, c4 = st.columns([1.2, 1.8, 1.0, 1.2])
        mode_col, size_col = c4, c3
    else:
        c1, c2, c3 = st.columns([1.5, 1.2, 1.3])
        mode_col, size_col = c3, c2

    with c1:
        viz_choice = st.selectbox("Analysis Type:", options=["Domains", "Reactions", "Pathways"], key="sp_viz_choice")
    with mode_col:
        plot_mode = st.radio("Plot Mode:", ["Static", "Interactive"], horizontal=True, key="sp_mode")
    with size_col:
        if plot_mode == "Static":
            spot_size = st.slider("Spot Size:", 0.5, 5.0, 1.2, 0.5, key="sp_spot_size_static")
        else:
            spot_size = st.slider("Spot Size:", 1, 20, 6, key="sp_spot_size_interactive")

    selected_items = []
    if viz_choice == "Reactions":
        with c2:
            selected_items = _select_reactions(metabolic_adata, plot_mode)
    elif viz_choice == "Pathways":
        with c2:
            selected_items = _select_pathways(metabolic_adata, plot_mode)
    return viz_choice, plot_mode, spot_size, selected_items


def _select_reactions(metabolic_adata, plot_mode):
    """Draw the reaction picker and return the selected reaction ids (``var_names``).

    When ``var['rxn_full_names']`` exists the picker lists the (sorted)
    human-readable names; if several reactions share a name, the first one in
    ``var`` order is used. Interactive mode allows one reaction, static mode
    several.
    """
    var = metabolic_adata.var
    if 'rxn_full_names' in var.columns:
        name_to_id = {}
        for rx_id, full_name in zip(var.index, var['rxn_full_names']):
            name_to_id.setdefault(str(full_name), rx_id)
        options = sorted(name_to_id)
    else:
        name_to_id = None
        options = metabolic_adata.var_names.tolist()

    if plot_mode == "Interactive":
        choice = st.selectbox("Select Reaction:", options=options, key="sp_rx_single")
        chosen = [choice] if choice else []
    else:
        chosen = st.multiselect("Select Reactions:", options=options, default=options[:1], key="sp_rx_multi")

    if name_to_id is None:
        return list(chosen)
    return [name_to_id[name] for name in chosen if name in name_to_id]


def _select_pathways(metabolic_adata, plot_mode):
    """Draw the pathway picker and return the selected ``var['subsystems']`` values."""
    if 'subsystems' not in metabolic_adata.var.columns:
        st.warning("No pathway data.")
        return []
    options = sorted(p for p in metabolic_adata.var['subsystems'].unique() if pd.notna(p))
    if plot_mode == "Interactive":
        choice = st.selectbox("Select Pathway:", options=options, key="sp_path_single")
        return [choice] if choice else []
    return st.multiselect("Select Pathways:", options=options, default=options[:1], key="sp_path_multi")


def _wrap_title(text):
    """Wrap ``text`` (converted to str, so NaN/non-str names are safe) for a plot title."""
    return textwrap.fill(str(text), width=_TITLE_WIDTH)


def _pathway_mean_flux(metabolic_adata, pathway):
    """Return the per-spot mean flux over all reactions of ``pathway`` as a 1-D array."""
    subsystems = metabolic_adata.var['subsystems']
    rx_ids = metabolic_adata.var[subsystems == pathway].index.tolist()
    X_sub = metabolic_adata[:, rx_ids].X
    # ``mean(axis=1)`` works for dense arrays and scipy sparse matrices alike
    # (sparse returns an (n, 1) np.matrix), so no densifying copy is needed.
    return np.asarray(X_sub.mean(axis=1)).ravel()


def _reaction_label(metabolic_adata, rx):
    """Return the human-readable name of reaction ``rx`` as str, falling back to the id."""
    if 'rxn_full_names' in metabolic_adata.var.columns and rx in metabolic_adata.var_names:
        return str(metabolic_adata.var.loc[rx, 'rxn_full_names'])
    return str(rx)


def _make_grid(n_items, max_cols=2):
    """Create a figure holding ``n_items`` (>= 1) subplots in at most ``max_cols`` columns.

    Returns ``(fig, axes)`` where ``axes`` is a flat, row-major list of Axes;
    cells beyond ``n_items`` are hidden.
    """
    n_cols = min(max_cols, n_items)
    n_rows = (n_items + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 7 * n_rows), squeeze=False)
    cells = list(axes.flat)
    for ax in cells[n_items:]:
        ax.axis('off')
    return fig, cells


def _current_page(state_key, n_items, per_page):
    """Return ``(page, n_pages)`` for a 1-based page kept in ``st.session_state[state_key]``.

    A missing or out-of-range page (e.g. after the selection shrank, or
    "previous" was pressed on page 1) is reset to 1.
    """
    n_pages = (n_items + per_page - 1) // per_page
    page = st.session_state.get(state_key, 1)
    if not 1 <= page <= n_pages:
        page = 1
    st.session_state[state_key] = page
    return page, n_pages


def _step_page(state_key, delta):
    """Button callback: shift the page stored in session state by ``delta``."""
    st.session_state[state_key] = st.session_state.get(state_key, 1) + delta


def _render_pager(state_key, n_pages, prev_label, next_label, caption, key_prefix):
    """Draw previous / caption / next controls when there is more than one page.

    The buttons use ``on_click`` callbacks, which run before the next script
    pass. This avoids rendering the old page's figures once more and then
    calling ``st.rerun()``.
    """
    if n_pages <= 1:
        return
    left, middle, right = st.columns([1, 2, 1])
    left.button(prev_label, key=f"{key_prefix}_prev", on_click=_step_page, args=(state_key, -1))
    middle.markdown(f"<center>{caption}</center>", unsafe_allow_html=True)
    right.button(next_label, key=f"{key_prefix}_next", on_click=_step_page, args=(state_key, 1))


def _render_domains(metabolic_adata, plot_mode, spot_size, img_key):
    """Draw the spatial map coloured by ``obs['domain']``."""
    if plot_mode == "Interactive":
        display_interactive_spatial_plot(
            metabolic_adata,
            color_key="domain",
            spot_size=spot_size,
            plot_name="spatial_domain_plotly",
            title="Domain Assignment",
            help_text="This map highlights the spatial domains assigned by clustering spots with similar metabolic flux patterns. It shows the geographical organization of the tissue's metabolic environment."
        )
        return
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sc.pl.spatial(metabolic_adata, img_key=img_key, color=['domain'], size=spot_size, show=False, ax=ax)
        display_plot_with_download(
            fig,
            "spatial_domain",
            help_text="This map shows the spatial distribution of metabolic domains across the tissue. Each domain represents a cluster of spots with similar metabolic flux profiles."
        )
    finally:
        plt.close(fig)


def _render_pathways(metabolic_adata, selected_items, plot_mode, spot_size, img_key):
    """Draw mean-flux maps for the selected pathways (one interactive, or a static page)."""
    if not selected_items:
        st.info("Please select a pathway.")
        return

    if plot_mode == "Interactive":
        target = selected_items[0]
        tmp_key = f'temp_{target}'
        metabolic_adata.obs[tmp_key] = _pathway_mean_flux(metabolic_adata, target)
        try:
            display_interactive_spatial_plot(
                metabolic_adata,
                color_key=tmp_key,
                spot_size=spot_size,
                plot_name=f"spatial_{target}_avg_plotly",
                title=_wrap_title(f"Pathway: {target}"),
                help_text=f"This interactive map shows the averaged flux distribution for the **{target}** pathway. High intensity regions highlight where this metabolic process is most active within the tissue."
            )
        finally:
            # Never leave the helper column behind in the caller's AnnData.
            if tmp_key in metabolic_adata.obs:
                del metabolic_adata.obs[tmp_key]
        return

    page, n_pages = _current_page("sp_path_page", len(selected_items), _PATHWAYS_PER_PAGE)
    start = (page - 1) * _PATHWAYS_PER_PAGE
    page_items = selected_items[start:start + _PATHWAYS_PER_PAGE]

    fig, axes = _make_grid(len(page_items), max_cols=2)
    try:
        for ax, target in zip(axes, page_items):
            metabolic_adata.obs['tmp_avg'] = _pathway_mean_flux(metabolic_adata, target)
            sc.pl.spatial(metabolic_adata, img_key=img_key, color=['tmp_avg'], size=spot_size, cmap='jet', show=False, ax=ax)
            ax.set_title(_wrap_title(target), fontsize=12)
        fig.tight_layout()
        target_names = ", ".join(str(t) for t in page_items)
        display_plot_with_download(
            fig,
            f"spatial_pathway_p{page}",
            help_text=f"This spatial flux map visualizes the spatial distribution of averaged flux for the pathways: **{target_names}**. It helps localize pathway activities within the tissue."
        )
    finally:
        plt.close(fig)
        if 'tmp_avg' in metabolic_adata.obs:
            del metabolic_adata.obs['tmp_avg']

    _render_pager("sp_path_page", n_pages, "Prev Pathway", "Next Pathway",
                  f"Pathway Page {page} / {n_pages}", "pw")


def _render_reactions(metabolic_adata, selected_items, plot_mode, spot_size, img_key):
    """Draw flux maps for the selected reactions (one interactive, or a static page)."""
    if not selected_items:
        # Match the Pathways view instead of rendering an empty area.
        st.info("Please select a reaction.")
        return

    if plot_mode == "Interactive":
        target = selected_items[0]
        label = _reaction_label(metabolic_adata, target)
        display_interactive_spatial_plot(
            metabolic_adata,
            color_key=target,
            spot_size=spot_size,
            plot_name=f"spatial_{target}_plotly",
            title=_wrap_title(f"Reaction: {label}"),
            help_text=f"This interactive spatial map visualizes the flux distribution for the reaction **{label}**. You can explore its metabolic activity across different spatial domains."
        )
        return

    page, n_pages = _current_page("spatial_flux_page", len(selected_items), _REACTIONS_PER_PAGE)
    start = (page - 1) * _REACTIONS_PER_PAGE
    page_rx = selected_items[start:start + _REACTIONS_PER_PAGE]
    labels = [_reaction_label(metabolic_adata, rx) for rx in page_rx]

    fig, axes = _make_grid(len(page_rx), max_cols=2)
    try:
        for ax, rx, label in zip(axes, page_rx, labels):
            sc.pl.spatial(metabolic_adata, img_key=img_key, color=[rx], size=spot_size, cmap='jet', show=False, ax=ax)
            ax.set_title(_wrap_title(label), fontsize=10)
            ax.axis('off')
        fig.tight_layout()
        rx_names_str = ", ".join(labels)
        display_plot_with_download(
            fig,
            f"spatial_flux_p{page}",
            help_text=f"These maps show the spatial distribution of flux for: **{rx_names_str}**, allowing visualization of where specific metabolic processes are active."
        )
    finally:
        plt.close(fig)

    _render_pager("spatial_flux_page", n_pages, "Previous Page", "Next Page",
                  f"Reaction Page {page} of {n_pages}", "sf")