import streamlit as st
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import logging
import textwrap
from .utils import display_plot_with_download, display_interactive_spatial_plot, display_plotly_with_download
logger = logging.getLogger(__name__)
def init_plot_state():
    """Make sure the plot-related session-state entries exist.

    Seeds ``plot_cache`` and ``last_viz_params`` with an empty dict each when
    they are missing from ``st.session_state``; values that are already
    present are left untouched.
    """
    for state_key in ("plot_cache", "last_viz_params"):
        if state_key not in st.session_state:
            # A fresh dict per key so the two entries never share one object.
            st.session_state[state_key] = {}
def _detect_viz_change_and_clear():
    """Close all matplotlib figures when the visualization selection changes.

    The current ``sp_viz_choice`` / ``sp_mode`` session values (falling back
    to ``'Domains'`` / ``'Static'``) are compared with the snapshot stored
    under ``sp_last_params``.

    Returns:
        bool: ``True`` when the snapshot differed (it is then replaced and
        every open matplotlib figure is closed), ``False`` otherwise.
    """
    state = st.session_state
    snapshot = {
        'viz_choice': state.get('sp_viz_choice', 'Domains'),
        'plot_mode': state.get('sp_mode', 'Static'),
    }
    if snapshot == state.get('sp_last_params', {}):
        return False

    state.sp_last_params = snapshot
    plt.close('all')
    return True
def _first_index_by_label(labels):
    """Map each distinct label (as ``str``) to the first index it occurs at.

    Args:
        labels: pandas Series, iterated in order via ``.items()``.

    Returns:
        dict: ``str(label) -> index``; later duplicates of a label are ignored.
    """
    first_seen = {}
    for idx, label in labels.items():
        first_seen.setdefault(str(label), idx)
    return first_seen


def _choose_options(noun, options, plot_mode, key_prefix):
    """Render a single-select (Interactive) or multi-select (Static) widget.

    Args:
        noun: "Reaction" or "Pathway"; used to build the widget label.
        options: list of selectable values.
        plot_mode: "Interactive" selects one item, anything else many.
        key_prefix: widget keys become ``<prefix>_single`` / ``<prefix>_multi``.

    Returns:
        list: the chosen options (possibly empty).
    """
    if plot_mode == "Interactive":
        choice = st.selectbox(f"Select {noun}:", options=options, key=f"{key_prefix}_single")
        return [choice] if choice else []
    return list(
        st.multiselect(
            f"Select {noun}s:",
            options=options,
            default=options[:1],
            key=f"{key_prefix}_multi",
        )
    )


def _select_targets(metabolic_adata, viz_choice, plot_mode):
    """Render the reaction/pathway picker and return the selected targets.

    For reactions the returned items are ``var`` index labels (the picker
    shows ``rxn_full_names`` when that column exists); for pathways they are
    values of the ``subsystems`` column.
    """
    var = metabolic_adata.var
    if viz_choice == "Reactions":
        if 'rxn_full_names' in var.columns:
            name_to_id = _first_index_by_label(var['rxn_full_names'])
            chosen = _choose_options("Reaction", sorted(name_to_id), plot_mode, "sp_rx")
            return [name_to_id[name] for name in chosen if name in name_to_id]
        return _choose_options("Reaction", metabolic_adata.var_names.tolist(), plot_mode, "sp_rx")
    if viz_choice == "Pathways":
        if 'subsystems' in var.columns:
            path_options = sorted(p for p in var['subsystems'].unique() if pd.notna(p))
            return _choose_options("Pathway", path_options, plot_mode, "sp_path")
        st.warning("No pathway data.")
    return []


def _pathway_mean_flux(metabolic_adata, target):
    """Return the per-observation mean of ``X`` over the reactions of one pathway.

    Reactions are those whose ``var['subsystems']`` equals ``target``. The mean
    is taken along axis 1 without densifying sparse input; ``np.asarray`` +
    ``ravel`` flattens the ``np.matrix`` that sparse matrices return.
    """
    rx_list = metabolic_adata.var[metabolic_adata.var['subsystems'] == target].index.tolist()
    x_sub = metabolic_adata[:, rx_list].X
    return np.asarray(x_sub.mean(axis=1)).ravel()


def _reaction_display_name(metabolic_adata, rx):
    """Return the ``rxn_full_names`` entry for ``rx`` if available, else ``rx``.

    The result is always a ``str`` so it can be wrapped and joined safely even
    when the column holds non-string values (e.g. NaN).
    """
    if 'rxn_full_names' in metabolic_adata.var.columns and rx in metabolic_adata.var_names:
        return str(metabolic_adata.var.loc[rx, 'rxn_full_names'])
    return str(rx)


def _subplot_grid(n_panels, max_cols=2):
    """Create a figure whose axes array is always 2-D and fits ``n_panels``.

    Returns:
        tuple: ``(fig, axes)`` with ``axes.shape == (n_rows, n_cols)``.
    """
    n_cols = min(max_cols, n_panels)
    n_rows = (n_panels + n_cols - 1) // n_cols
    # squeeze=False keeps a 2-D array even for a single row/column/panel.
    return plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 7 * n_rows), squeeze=False)


def _resolve_page(state_key, n_pages):
    """Return the page stored under ``state_key``, reset to 1 when out of range.

    Guards both ends: a stale page beyond ``n_pages`` (selection shrank) and a
    page below 1, which would otherwise yield an empty slice and a zero-sized
    subplot grid.
    """
    page = st.session_state.get(state_key, 1)
    if not 1 <= page <= n_pages:
        page = 1
    st.session_state[state_key] = page
    return page


def _page_controls(state_key, page, n_pages, *, prev_label, next_label, prev_key, next_key, caption):
    """Render Prev/Next buttons and a caption; no-op when there is one page.

    Buttons are disabled at the first/last page so the stored page can never
    leave ``[1, n_pages]``. A click stores the new page and reruns the script.
    """
    if n_pages <= 1:
        return
    left, middle, right = st.columns([1, 2, 1])
    if left.button(prev_label, key=prev_key, disabled=page <= 1):
        st.session_state[state_key] = page - 1
        st.rerun()
    middle.markdown(caption, unsafe_allow_html=True)
    if right.button(next_label, key=next_key, disabled=page >= n_pages):
        st.session_state[state_key] = page + 1
        st.rerun()


def _render_pathway_interactive(metabolic_adata, target, spot_size):
    """Show the interactive map of one pathway's averaged flux.

    The average is written to a temporary ``obs`` column that is removed again
    even if plotting raises, so the caller's AnnData is left as it was.
    """
    temp_col = f'temp_{target}'
    metabolic_adata.obs[temp_col] = _pathway_mean_flux(metabolic_adata, target)
    try:
        display_interactive_spatial_plot(
            metabolic_adata,
            color_key=temp_col,
            spot_size=spot_size,
            plot_name=f"spatial_{target}_avg_plotly",
            title=textwrap.TextWrapper(width=40).fill(text=f"Pathway: {target}"),
            help_text=f"This interactive map shows the averaged flux distribution for the **{target}** pathway. High intensity regions highlight where this metabolic process is most active within the tissue."
        )
    finally:
        del metabolic_adata.obs[temp_col]


def _render_pathways_static(metabolic_adata, selected_items, img_key, spot_size):
    """Draw the current page (up to 4) of pathway-average maps plus paging."""
    per_page = 4
    pages = (len(selected_items) + per_page - 1) // per_page
    page = _resolve_page("sp_path_page", pages)
    curr_items = selected_items[(page - 1) * per_page: page * per_page]

    wrapper = textwrap.TextWrapper(width=40)
    fig, axes = _subplot_grid(len(curr_items))
    flat_axes = axes.ravel()
    try:
        for ax, target in zip(flat_axes, curr_items):
            # sc.pl.spatial colours by an obs column, so stage the average there.
            metabolic_adata.obs['tmp_avg'] = _pathway_mean_flux(metabolic_adata, target)
            sc.pl.spatial(metabolic_adata, img_key=img_key, color=['tmp_avg'], size=spot_size, cmap='jet', show=False, ax=ax)
            ax.set_title(wrapper.fill(text=str(target)), fontsize=12)
        for ax in flat_axes[len(curr_items):]:
            ax.axis('off')
        fig.tight_layout()
        target_names = ", ".join(str(t) for t in curr_items)
        display_plot_with_download(
            fig,
            f"spatial_pathway_p{page}",
            help_text=f"This spatial flux map visualizes the spatial distribution of averaged flux for the pathways: **{target_names}**. It helps localize pathway activities within the tissue."
        )
    finally:
        # Always release the figure and the staging column, even on failure.
        plt.close(fig)
        if 'tmp_avg' in metabolic_adata.obs:
            del metabolic_adata.obs['tmp_avg']

    _page_controls(
        "sp_path_page", page, pages,
        prev_label="Prev Pathway", next_label="Next Pathway",
        prev_key="pw_prev", next_key="pw_next",
        caption=f"Pathway Page {page} / {pages}",
    )


def _render_reactions_static(metabolic_adata, selected_items, img_key, spot_size):
    """Draw the current page (up to 8) of per-reaction flux maps plus paging."""
    per_page = 8
    pages = (len(selected_items) + per_page - 1) // per_page
    page = _resolve_page("spatial_flux_page", pages)
    curr_rx = selected_items[(page - 1) * per_page: page * per_page]
    titles = [_reaction_display_name(metabolic_adata, rx) for rx in curr_rx]

    wrapper = textwrap.TextWrapper(width=40)
    fig, axes = _subplot_grid(len(curr_rx))
    flat_axes = axes.ravel()
    try:
        for ax, rx, display_title in zip(flat_axes, curr_rx, titles):
            sc.pl.spatial(metabolic_adata, img_key=img_key, color=[rx], size=spot_size, cmap='jet', show=False, ax=ax)
            ax.set_title(wrapper.fill(text=display_title), fontsize=10)
            ax.axis('off')
        for ax in flat_axes[len(curr_rx):]:
            ax.axis('off')
        fig.tight_layout()
        rx_names_str = ", ".join(titles)
        display_plot_with_download(
            fig,
            f"spatial_flux_p{page}",
            help_text=f"These maps show the spatial distribution of flux for: **{rx_names_str}**, allowing visualization of where specific metabolic processes are active."
        )
    finally:
        plt.close(fig)

    _page_controls(
        "spatial_flux_page", page, pages,
        prev_label="Previous Page", next_label="Next Page",
        prev_key="sf_prev", next_key="sf_next",
        caption=f"Reaction Page {page} of {pages}",
    )


def render_spatial_flux_map(metabolic_adata):
    """Render the spatial flux section: controls plus the selected map(s).

    Draws the analysis-type, plot-mode, spot-size and (for reactions/pathways)
    item pickers, then plots either the ``domain`` assignment, per-reaction
    flux, or per-pathway averaged flux, as a static scanpy figure (``jet``
    colormap for flux maps) or an interactive plot.

    Args:
        metabolic_adata: AnnData-like object. The code reads
            ``uns["spatial"][<first library>]["images"]``, ``var`` (optional
            ``rxn_full_names`` and ``subsystems`` columns), ``var_names``,
            ``X`` and ``obs``. Temporary ``obs`` columns used for pathway
            averages are removed before returning.

    Returns:
        None. Any ``Exception`` raised while plotting is logged with its
        traceback and shown to the user via ``st.error`` instead of propagating.
    """
    init_plot_state()
    _detect_viz_change_and_clear()
    # NOTE(review): in the source as received this literal was split across
    # three physical lines inside a plain double-quoted string (a syntax
    # error); the line breaks are kept here as explicit "\n". Any HTML wrapper
    # that motivated unsafe_allow_html may have been lost - confirm upstream.
    st.markdown("\nSpatial Metabolic flux\n", unsafe_allow_html=True)

    # The column layout must be fixed before the selectbox is drawn, so it is
    # derived from the value already stored under the selectbox's key.
    if st.session_state.get("sp_viz_choice", "Domains") == "Domains":
        type_col, size_col, mode_col = st.columns([1.5, 1.2, 1.3])
        picker_col = None
    else:
        type_col, picker_col, size_col, mode_col = st.columns([1.2, 1.8, 1.0, 1.2])

    with type_col:
        viz_choice = st.selectbox("Analysis Type:", options=["Domains", "Reactions", "Pathways"], key="sp_viz_choice")
    with mode_col:
        plot_mode = st.radio("Plot Mode:", ["Static", "Interactive"], horizontal=True, key="sp_mode")
    with size_col:
        # Static and interactive plots use different size scales, hence two
        # sliders with separate keys.
        if plot_mode == "Static":
            spot_size = st.slider("Spot Size:", 0.5, 5.0, 1.2, 0.5, key="sp_spot_size_static")
        else:
            spot_size = st.slider("Spot Size:", 1, 20, 6, key="sp_spot_size_interactive")

    selected_items = []
    if viz_choice != "Domains":
        # Fall back to a plain container should the stored and returned
        # selectbox values ever disagree (no picker column was allocated).
        picker_host = picker_col if picker_col is not None else st.container()
        with picker_host:
            selected_items = _select_targets(metabolic_adata, viz_choice, plot_mode)

    try:
        library_id = next(iter(metabolic_adata.uns["spatial"]))
        images = metabolic_adata.uns["spatial"][library_id]["images"]
        img_key = "hires" if "hires" in images else "downscaled_fullres"

        if viz_choice == "Domains":
            if plot_mode == "Interactive":
                display_interactive_spatial_plot(
                    metabolic_adata,
                    color_key="domain",
                    spot_size=spot_size,
                    plot_name="spatial_domain_plotly",
                    title="Domain Assignment",
                    help_text="This map highlights the spatial domains assigned byclustering spots with similar metabolic flux patterns. It shows the geographical organization of the tissue's metabolic environment."
                )
            else:
                fig, ax = plt.subplots(figsize=(10, 8))
                try:
                    sc.pl.spatial(metabolic_adata, img_key=img_key, color=['domain'], size=spot_size, show=False, ax=ax)
                    display_plot_with_download(
                        fig,
                        "spatial_domain",
                        help_text="This map shows the spatial distribution of metabolic domains across the tissue. Each domain represents a cluster of spots with similar metabolic flux profiles."
                    )
                finally:
                    plt.close(fig)
        elif viz_choice == "Pathways":
            if not selected_items:
                st.info("Please select a pathway.")
                return
            if plot_mode == "Interactive":
                _render_pathway_interactive(metabolic_adata, selected_items[0], spot_size)
            else:
                _render_pathways_static(metabolic_adata, selected_items, img_key, spot_size)
        elif selected_items:
            if plot_mode == "Interactive":
                target = selected_items[0]
                display_title = _reaction_display_name(metabolic_adata, target)
                display_interactive_spatial_plot(
                    metabolic_adata,
                    color_key=target,
                    spot_size=spot_size,
                    plot_name=f"spatial_{target}_plotly",
                    title=textwrap.TextWrapper(width=40).fill(text=f"Reaction: {display_title}"),
                    help_text=f"This interactive spatial map visualizes the flux distribution for the reaction **{display_title}**. You can explore its metabolic activity across different spatial domains."
                )
            else:
                _render_reactions_static(metabolic_adata, selected_items, img_key, spot_size)
    except Exception as e:
        # Top-level UI boundary: keep the page alive, but record the traceback
        # so the failure can be diagnosed from the logs.
        logging.getLogger(__name__).exception("Failed to render spatial flux map")
        st.error(f"Error: {e}")