Spaces:
Running
Running
| import streamlit as st | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import io | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
| import scanpy as sc | |
| from itertools import combinations | |
| from typing import Optional | |
| from scipy.sparse import issparse | |
| from scipy.stats import mannwhitneyu | |
| from src.backend.flux_distribution import adata_to_long_df, p_to_star | |
# Hex colour used for each metabolic interaction type when drawing edges.
# "Interaction" is the neutral grey used for types without a dedicated colour.
INTERACTION_COLORS = dict(
    Competition="#d32f2f",   # red
    Release="#1976d2",       # blue
    Cooperation="#388e3c",   # green
    Amensalism="#fbc02d",    # amber
    Neutralism="#7b1fa2",    # purple
    Interaction="#607d8b",   # grey (fallback)
)
| try: | |
| from statsmodels.stats.multitest import multipletests | |
| _HAS_STATSMODELS = True | |
| except ImportError: | |
| _HAS_STATSMODELS = False | |
def display_help_button(help_text, plot_name):
    """Render a compact "help" popover holding *help_text*.

    Nothing is rendered when *help_text* is empty or None. *plot_name* is
    accepted so every display helper shares the same call shape; it is not
    used here.
    """
    if not help_text:
        return
    popover = st.popover(
        "",
        icon=":material/help:",
        help="Click for insights",
        use_container_width=True,
    )
    with popover:
        st.markdown("#### <i class='fas fa-lightbulb'></i> Plot Insights", unsafe_allow_html=True)
        st.markdown(help_text)
def display_plot_with_download(fig, plot_name: str = "plot", help_text: str = None):
    """Show a matplotlib figure under a header row with help and PDF-download buttons.

    The header row is split 0.7 / 0.2 / 0.1 into a spacer, the help popover
    and the download button. The figure is exported to PDF (300 dpi, tight
    bounding box) on every call so the download button has its data ready.
    The same *fig* object is then handed to ``st.pyplot``.

    Args:
        fig: matplotlib Figure to export and display.
        plot_name: prefix of the downloaded file name and of the widget key.
        help_text: markdown for the help popover; nothing is shown when falsy.
    """
    _spacer_col, help_col, download_col = st.columns([0.7, 0.2, 0.1], gap="small")

    with help_col:
        display_help_button(help_text, plot_name)

    with download_col:
        buffer = io.BytesIO()
        fig.savefig(buffer, format='pdf', dpi=300, bbox_inches='tight')
        pdf_bytes = buffer.getvalue()
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        st.download_button(
            label="",
            data=pdf_bytes,
            file_name=f"{plot_name}_{stamp}.pdf",
            mime="application/pdf",
            key=f"download_{plot_name}_{id(fig)}",
            help="Download as PDF",
            icon=":material/download:",
            use_container_width=True,
        )

    st.pyplot(fig, use_container_width=False)
def display_plotly_with_download(fig, plot_name: str = "plot", help_text: str = None):
    """Show a Plotly figure under a header row with a right-aligned help button.

    Despite the name, no download control is rendered. The third header
    column holds only an empty placeholder, and the Plotly mode bar is hidden
    via ``displayModeBar: False``.

    NOTE(review): the chart key embeds ``id(fig)``. If callers build a new
    figure on every rerun, the key will generally change between reruns
    rather than stay stable — confirm this is intended.

    Args:
        fig: plotly Figure to display.
        plot_name: prefix of the chart's widget key.
        help_text: markdown for the help popover; nothing is shown when falsy.
    """
    header = st.columns([0.7, 0.2, 0.1], gap="small")
    help_col, placeholder_col = header[1], header[2]

    with help_col:
        display_help_button(help_text, plot_name)
    with placeholder_col:
        st.empty()

    chart_config = {"displayModeBar": False, "responsive": True, "staticPlot": False}
    st.plotly_chart(
        fig,
        use_container_width=False,
        key=f"plotly_{plot_name}_{id(fig)}",
        config=chart_config,
    )
def display_interactive_spatial_plot(adata, color_key="domain", spot_size=6, plot_name="spatial_plot",
                                     title: Optional[str] = None, help_text: Optional[str] = None):
    """Render spots over the tissue image as an interactive Plotly (WebGL) chart.

    The colour source is resolved in this order:
    1. a feature in ``adata.var_names``, with values read from ``adata.X``
       and drawn on a continuous "Jet" scale;
    2. a column of ``adata.obs``, categorical unless its dtype is numeric;
    3. otherwise every spot is labelled "N/A".

    The background is the "hires" image of the first library in
    ``adata.uns["spatial"]``, falling back to "downscaled_fullres".
    ``adata.obsm["spatial"]`` is multiplied by that image's
    ``tissue_<key>_scalef`` so that spots line up with it.

    Session state written:
    - ``f"{plot_name}_last_key"`` holds the colour key of the last render.
    - ``f"{plot_name}_relayout"`` is dropped when the colour key changes, and
      stored when the chart event carries relayout data.

    Args:
        adata: AnnData-like object providing the fields listed above.
        color_key: feature name or ``obs`` column to colour spots by.
        spot_size: marker size in pixels.
        plot_name: Streamlit widget key of the chart; also prefixes the
            session-state keys above.
        title: optional centred figure title.
        help_text: optional markdown shown in a help popover above the chart.

    Returns:
        True when the chart was rendered. False if any step raised; the error
        is shown with ``st.error`` instead of propagating.
    """
    try:
        if help_text:
            _, col_help, _ = st.columns([5.0, 0.5, 0.5], gap="small")
            with col_help:
                display_help_button(help_text, plot_name)

        spatial = adata.uns["spatial"]
        library = spatial[list(spatial.keys())[0]]
        img_key = "hires" if "hires" in library["images"] else "downscaled_fullres"
        img = library["images"][img_key]
        sf = library["scalefactors"][f"tissue_{img_key}_scalef"]
        coords = adata.obsm["spatial"] * sf

        if color_key in adata.var_names:
            var_idx = adata.var_names.get_loc(color_key)
            raw = adata.X[:, var_idx]
            # adata.X may be sparse; densify the single column if so.
            color_values = raw.toarray().flatten() if hasattr(raw, "toarray") else np.asarray(raw).flatten()
            is_categorical = False
        elif color_key in adata.obs.columns:
            color_values = adata.obs[color_key].values
            is_categorical = not pd.api.types.is_numeric_dtype(adata.obs[color_key])
        else:
            color_values = np.full(len(coords), "N/A")
            is_categorical = True

        df = pd.DataFrame({
            "x": coords[:, 0],
            "y": coords[:, 1],
            "color": color_values.astype(str) if is_categorical else color_values,
            "domain": adata.obs["domain"].values if "domain" in adata.obs.columns else "N/A",
            "spot_id": adata.obs_names.tolist(),
        })

        # Forget any stored relayout once the user switches colour key.
        last_key_name = f"{plot_name}_last_key"
        if st.session_state.get(last_key_name) != color_key:
            st.session_state.pop(f"{plot_name}_relayout", None)
        st.session_state[last_key_name] = color_key

        # Convert the background to 8-bit for PIL. Float images in [0, 1] are
        # rescaled to 0-255. Integer images (e.g. uint8) are used as-is, because
        # multiplying them by 255 would wrap around. The clip guards stray values.
        img_arr = np.asarray(img)
        if np.issubdtype(img_arr.dtype, np.floating):
            img_arr = np.nan_to_num(img_arr)
            if img_arr.size and img_arr.max() <= 1.0:
                img_arr = img_arr * 255
        img_u8 = np.clip(img_arr, 0, 255).astype(np.uint8)
        img_h, img_w = img_arr.shape[0], img_arr.shape[1]

        fig = go.Figure()
        fig.add_layout_image(
            dict(
                source=Image.fromarray(img_u8),
                xref="x", yref="y",
                x=0, y=0,
                sizex=img_w, sizey=img_h,
                sizing="stretch", layer="below",
            )
        )

        if is_categorical:
            palette = px.colors.qualitative.T10
            color_str = df["color"].astype(str)
            for i, val in enumerate(sorted(color_str.unique())):
                sub = df[color_str == val]
                fig.add_trace(go.Scattergl(
                    x=sub["x"],
                    y=sub["y"],
                    customdata=np.stack((sub["spot_id"], sub["domain"]), axis=-1),
                    mode="markers",
                    name=str(val),
                    marker=dict(
                        size=spot_size,
                        color=palette[i % len(palette)],
                        line=dict(width=0.5, color='white'),
                    ),
                    hovertemplate=(
                        "<b>Domain: %{customdata[1]}</b><br>"
                        "<span style='font-size:0.8rem;'>ID: %{customdata[0]}</span>"
                        "<extra></extra>"
                    ),
                ))
        else:
            fig.add_trace(go.Scattergl(
                x=df["x"], y=df["y"],
                customdata=np.stack((df["spot_id"], df["domain"]), axis=-1),
                mode="markers",
                marker=dict(
                    size=spot_size,
                    color=df["color"],
                    colorscale="Jet",
                    showscale=True,
                    colorbar=dict(
                        thickness=8,
                        len=0.75,
                        xref="paper",
                        yref="paper",
                        tickfont=dict(size=10),
                        outlinewidth=0,
                    ),
                    line=dict(width=0.3, color='white'),
                ),
                hovertemplate=(
                    "<b>Domain: %{customdata[1]}</b><br>"
                    f"<b>Flux:</b> %{{marker.color:.3e}}<br>"
                    "<span style='font-size:0.8rem;'>ID: %{customdata[0]}</span>"
                    "<extra></extra>"
                ),
            ))

        # Lock a 1:1 aspect over the image extent; y is reversed so image row 0 is at the top.
        # NOTE: x and y each name the other as scaleanchor. Plotly documents such
        # loops as redundant and ignores one constraint.
        fig.update_xaxes(
            visible=False,
            range=[0, img_w],
            scaleanchor="y",
            scaleratio=1,
        )
        fig.update_yaxes(
            visible=False,
            range=[img_h, 0],
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        )
        fig.update_layout(
            title=dict(
                text=title if title else "",
                x=0.5,
                y=0.98,
                xanchor="center",
                yanchor="top",
                font=dict(size=16),
            ) if title else None,
            margin=dict(l=0, r=0, t=40 if title else 0, b=0),
            legend=dict(
                orientation="v",
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                bgcolor="rgba(255,255,255,0.6)",
            ),
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            dragmode="pan",
            uirevision="constant",  # keep pan/zoom across reruns
        )

        plot_event = st.plotly_chart(
            fig,
            use_container_width=False,
            config={'scrollZoom': True},
            key=plot_name,
            on_select="rerun",
        )
        # NOTE(review): confirm the Streamlit chart event ever exposes
        # "relayout_data"; if it does not, this branch is a no-op.
        if plot_event and hasattr(plot_event, "get"):
            relayout = plot_event.get("relayout_data") or plot_event.get("selection", {}).get("relayout_data")
            if relayout:
                st.session_state[f"{plot_name}_relayout"] = relayout
        return True
    except Exception as e:
        # Best-effort rendering: report the problem in the UI instead of crashing the page.
        st.error(f"Error rendering interactive plot: {e}")
        return False
def display_formatted_table(df: pd.DataFrame, title: Optional[str] = None):
    """Display *df* with ``st.dataframe``, choosing a number format for each float column.

    A float column is shown in scientific notation (``"%.2e"``) in either case:
    - its label contains "p_val" or "pval" (case-insensitive; this covers
      "pvalue" as well as names like "pvals_adj");
    - its largest absolute non-missing value is below 1e-2.

    Other float columns use ``"%.4f"``. Columns of any other dtype keep
    Streamlit's default formatting.

    Args:
        df: table to display.
        title: optional heading rendered above the table.
    """
    if title:
        st.markdown(f"##### <i class='fas fa-table'></i> {title}", unsafe_allow_html=True)
    config = {}
    if not df.empty:
        # .items() yields one Series per column even when labels repeat, whereas
        # df[col] would then return a DataFrame and break the comparison below.
        for col, values in df.select_dtypes(include=['float']).items():
            name = str(col).lower()  # column labels are not necessarily strings
            is_pvalue = 'p_val' in name or 'pval' in name
            fmt = "%.2e" if is_pvalue or values.abs().max() < 1e-2 else "%.4f"
            config[col] = st.column_config.NumberColumn(format=fmt)
    st.dataframe(df, width='stretch', column_config=config)
def _bh_adjust(pvalues):
    """Return Benjamini-Hochberg adjusted p-values (same result as statsmodels' "fdr_bh")."""
    p = np.asarray(pvalues, dtype=float)
    n = p.size
    order = np.argsort(p)
    scaled = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity from the largest p-value downwards, then cap at 1.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.minimum(monotone, 1.0)
    return adjusted


def add_significance_brackets(ax, df, domain_order, y_col="flux"):
    """Draw pairwise significance brackets above a box/boxen plot on *ax*.

    The procedure is:
    1. Each pair of domains in *domain_order* with at least three non-missing
       values per group is compared with a two-sided Mann-Whitney U test.
       Pairs whose p-value is not finite are skipped.
    2. P-values are corrected across all tested pairs with Benjamini-Hochberg
       FDR. statsmodels' ``multipletests`` is used when available, otherwise
       an equivalent NumPy implementation.
    3. Pairs whose corrected p-value maps to "ns" via ``p_to_star`` are not
       drawn. The rest are stacked upwards, starting just above the maximum
       of ``df[y_col]``.

    Args:
        ax: matplotlib Axes whose x positions 0..k-1 follow *domain_order*.
        df: long-format frame with a "domain" column and a *y_col* column.
        domain_order: sequence of domain labels in plotting order (any sequence,
            not only a list).
        y_col: numeric column compared between domains.
    """
    position = {domain: i for i, domain in enumerate(domain_order)}
    pvalues = []
    valid_pairs = []
    for d1, d2 in combinations(domain_order, 2):
        g1 = df.loc[df["domain"] == d1, y_col].dropna()
        g2 = df.loc[df["domain"] == d2, y_col].dropna()
        if len(g1) < 3 or len(g2) < 3:
            continue
        _, p = mannwhitneyu(g1, g2, alternative="two-sided")
        if not np.isfinite(p):
            continue  # a NaN p-value would corrupt the FDR correction of every pair
        pvalues.append(p)
        valid_pairs.append((d1, d2))
    if not valid_pairs:
        return

    if _HAS_STATSMODELS:
        _, p_adj, _, _ = multipletests(pvalues, method="fdr_bh")
    else:
        p_adj = _bh_adjust(pvalues)

    y_max = df[y_col].max()
    y_range = y_max - df[y_col].min()
    step = y_range * 0.08
    if not step > 0:
        # All values identical (zero range): fall back to a step that still
        # separates stacked brackets.
        step = abs(y_max) * 0.08 or 0.08
    bracket_y = y_max + step

    for (d1, d2), p in zip(valid_pairs, p_adj):
        star = p_to_star(p)
        if star == "ns":
            continue
        x1, x2 = position[d1], position[d2]
        top = bracket_y + step * 0.3
        ax.plot([x1, x1, x2, x2], [bracket_y, top, top, bracket_y], lw=1.2, c="black")
        ax.text((x1 + x2) / 2, bracket_y + step * 0.35, star, ha="center", va="bottom", fontsize=9)
        bracket_y += step * 0.9  # stack the next bracket above this one
def create_plotly_tme_plot(adata, interaction_type_df, interaction_score_df, selected_rxn_id,
                           selected_display_name, percentile_threshold=95):
    """Build a spatial network figure of the metabolic interactions for one reaction.

    Edge selection:
    - Edges are the rows of ``interaction_type_df`` whose "Reaction" equals
      *selected_rxn_id* once any trailing "_f"/"_b" is removed.
    - These rows are inner-joined on ("Source", "Target") with
      ``interaction_score_df``.
    - When *percentile_threshold* > 0, only scores at or above that percentile
      of "Interaction score" are kept.
    - Edges whose endpoints are not spots of *adata* are dropped.

    Drawing: edges are coloured by "Interaction type" via
    ``INTERACTION_COLORS``, and an invisible marker at each edge midpoint
    carries the hover text. Both traces share a legend group, so toggling a
    type in the legend hides its lines and its hover targets together.

    Args:
        adata: AnnData-like object; ``obsm["spatial"]`` gives x/y per spot and
            ``obs`` may hold a "domain" column.
        interaction_type_df: frame with "Reaction", "Source", "Target" and
            "Interaction type" columns.
        interaction_score_df: frame with "Source", "Target" and
            "Interaction score" columns.
        selected_rxn_id: reaction ID without the direction suffix.
        selected_display_name: label used in the figure title.
        percentile_threshold: percentile (0-100) of scores to keep; 0 keeps all.

    Returns:
        A ``go.Figure``, or None when no selected interaction connects two
        spots present in *adata*.
    """
    coords_df = pd.DataFrame(adata.obsm["spatial"], index=adata.obs.index, columns=['x', 'y'])
    # Flip y (y_max - y) so larger spatial y values are drawn lower, as in image coordinates.
    coords_df['y_plot'] = coords_df['y'].max() - coords_df['y']
    coords_df['domain'] = adata.obs['domain'] if 'domain' in adata.obs.columns else "N/A"

    if percentile_threshold > 0:
        thresh = interaction_score_df['Interaction score'].quantile(percentile_threshold / 100)
        scores = interaction_score_df[interaction_score_df['Interaction score'] >= thresh]
    else:
        scores = interaction_score_df

    # Reaction IDs may carry an _f/_b direction suffix; match on the base ID.
    base_rxn = interaction_type_df['Reaction'].str.replace(r'_(b|f)$', '', regex=True)
    rxn_data = interaction_type_df[base_rxn == selected_rxn_id]
    merged = pd.merge(rxn_data, scores, on=['Source', 'Target'])
    # Only edges between known spots can be placed; .loc below would raise otherwise.
    known = merged['Source'].isin(coords_df.index) & merged['Target'].isin(coords_df.index)
    merged = merged[known]
    if merged.empty:
        return None

    fig = go.Figure()
    fig.add_trace(go.Scattergl(
        x=coords_df['x'], y=coords_df['y_plot'],
        mode='markers',
        marker=dict(size=4, color='#bdbdbd', opacity=0.5),  # all spots in background
        name='Tissue Background',
        customdata=np.stack((coords_df.index, coords_df['domain']), axis=-1),
        hovertemplate="<b>Spot ID: %{customdata[0]}</b><br>Domain: %{customdata[1]}<extra></extra>",
        showlegend=False,
    ))

    for t in merged['Interaction type'].unique():
        sub = merged[merged['Interaction type'] == t]
        group = str(t)
        s_coords = coords_df.loc[sub['Source'], ['x', 'y_plot']].values
        t_coords = coords_df.loc[sub['Target'], ['x', 'y_plot']].values
        n = len(sub)
        # Each edge is (source, target, NaN), so one trace draws many disjoint segments.
        edge_x = np.full(n * 3, np.nan)
        edge_y = np.full(n * 3, np.nan)
        edge_x[0::3], edge_x[1::3] = s_coords[:, 0], t_coords[:, 0]
        edge_y[0::3], edge_y[1::3] = s_coords[:, 1], t_coords[:, 1]
        fig.add_trace(go.Scattergl(
            x=edge_x, y=edge_y,
            mode='lines',
            line=dict(width=3, color=INTERACTION_COLORS.get(t, "#607d8b")),
            name=group,
            legendgroup=group,
            hoverinfo='none',  # hover is handled by the midpoint markers
            connectgaps=False,
        ))
        # Invisible, large midpoint markers give a reliable hover target on each line.
        fig.add_trace(go.Scattergl(
            x=(s_coords[:, 0] + t_coords[:, 0]) / 2,
            y=(s_coords[:, 1] + t_coords[:, 1]) / 2,
            mode='markers',
            marker=dict(size=12, opacity=0),
            name=group,
            legendgroup=group,
            hovertemplate=f"<b>Interaction: {t}</b><br>Score: %{{customdata:.4f}}<extra></extra>",
            customdata=sub['Interaction score'].values,
            showlegend=False,
        ))

    active_spots = sorted(set(merged['Source']).union(merged['Target']))
    active_df = coords_df.loc[active_spots]
    fig.add_trace(go.Scattergl(
        x=active_df['x'], y=active_df['y_plot'],
        mode='markers',
        marker=dict(size=5, color='#424242', opacity=0.9, line=dict(width=1, color='white')),
        name='Interacting Spots',
        customdata=np.stack((active_df.index, active_df['domain']), axis=-1),
        hovertemplate="<b>Spot ID: %{customdata[0]}</b><br>Domain: %{customdata[1]}<extra></extra>",
        showlegend=True,
    ))

    fig.update_layout(
        title=dict(
            text=f"Metabolic Interactions: {selected_display_name}",
        ),
        xaxis=dict(visible=False), yaxis=dict(visible=False, scaleanchor="x"),
        plot_bgcolor='#fcfcfc', paper_bgcolor='white',
        width=850, height=850, margin=dict(l=10, r=10, t=60, b=10),
        legend=dict(orientation="h", y=1.02, x=0, xanchor="left", title="Interaction Type:"),
        hovermode='closest',
        hoverdistance=30,  # makes it easier to hover on lines
    )
    return fig
def create_plotly_comm_plot(interaction_scores, adata, percentile_threshold=80, n_bins=5):
    """Build a WebGL figure of cell-cell communication strengths between spots.

    Edge selection: edges are the rows of *interaction_scores* whose "Source"
    and "Target" are both spots of *adata*. When *percentile_threshold* > 0,
    only scores at or above that percentile of "Interaction score" are kept.

    Binning and drawing:
    - Kept edges are split into *n_bins* quantile bins of their score.
      Duplicate bin edges are dropped, so fewer bins may result.
    - When every kept score is identical, all edges go into a single bin.
    - Each bin is drawn as one line trace. Higher bins get wider lines and
      later colours of the Viridis scale.

    Args:
        interaction_scores: frame with "Source", "Target" and
            "Interaction score" columns.
        adata: AnnData-like object; ``obsm["spatial"]`` gives x/y per spot and
            ``obs`` may hold a "domain" column.
        percentile_threshold: percentile (0-100) of scores to keep; 0 keeps all.
        n_bins: number of quantile bins (at least 1).

    Returns:
        A ``go.Figure``, or None when no edge connects two known spots.
    """
    n_bins = max(int(n_bins), 1)

    coords_df = pd.DataFrame(adata.obsm["spatial"], index=adata.obs.index, columns=['x', 'y'])
    # Flip y (y_max - y) so larger spatial y values are drawn lower, as in image coordinates.
    coords_df['y_plot'] = coords_df['y'].max() - coords_df['y']
    coords_df['domain'] = adata.obs['domain'] if 'domain' in adata.obs.columns else "N/A"

    if percentile_threshold > 0:
        thresh = interaction_scores['Interaction score'].quantile(percentile_threshold / 100)
        interaction_scores = interaction_scores[interaction_scores['Interaction score'] >= thresh]
    known = (
        interaction_scores['Source'].isin(coords_df.index)
        & interaction_scores['Target'].isin(coords_df.index)
    )
    valid = interaction_scores[known]
    if valid.empty:
        return None

    fig = go.Figure()
    fig.add_trace(go.Scattergl(
        x=coords_df['x'], y=coords_df['y_plot'],
        mode='markers',
        marker=dict(size=4, color='#bdbdbd', opacity=0.3),  # all spots in background
        name='Tissue Background',
        customdata=np.stack((coords_df.index, coords_df['domain']), axis=-1),
        hovertemplate="<b>Spot ID: %{customdata[0]}</b><br>Domain: %{customdata[1]}<extra></extra>",
        showlegend=False,
    ))

    valid = valid.copy()
    score = valid['Interaction score']
    if score.nunique() > 1:
        valid['bin'] = pd.qcut(score, n_bins, labels=False, duplicates='drop')
    else:
        # With a single distinct score, qcut cannot form any bin after dropping
        # duplicate edges, and would leave every edge unbinned (NaN). Use one
        # bin instead; missing scores stay unbinned and are not drawn.
        valid['bin'] = np.where(score.notna(), 0, np.nan)

    color_denominator = max(n_bins - 1, 1)  # avoids division by zero when n_bins == 1
    for b in range(n_bins):
        sub = valid[valid['bin'] == b]
        if sub.empty:
            continue
        s_coords = coords_df.loc[sub['Source'], ['x', 'y_plot']].values
        t_coords = coords_df.loc[sub['Target'], ['x', 'y_plot']].values
        n = len(sub)
        # Each edge is (source, target, NaN), so one trace draws many disjoint segments.
        edge_x = np.full(n * 3, np.nan)
        edge_y = np.full(n * 3, np.nan)
        edge_x[0::3], edge_x[1::3] = s_coords[:, 0], t_coords[:, 0]
        edge_y[0::3], edge_y[1::3] = s_coords[:, 1], t_coords[:, 1]
        fig.add_trace(go.Scattergl(
            x=edge_x, y=edge_y,
            mode='lines',
            line=dict(
                width=0.5 + b * 1.5,
                color=px.colors.sample_colorscale("Viridis", b / color_denominator)[0],
            ),
            name=f"Level {b+1}",
            hoverinfo='none',
        ))

    fig.update_layout(
        title="Cell-Cell Metabolic Communication Strengths",
        xaxis=dict(visible=False), yaxis=dict(visible=False, scaleanchor="x"),
        plot_bgcolor='#fcfcfc', width=850, height=850,
        legend=dict(title="Score Bin:", orientation="v", x=1.02, y=1),
    )
    return fig