"""Plotly helpers for the explorer UI.""" from __future__ import annotations import html from typing import Any import numpy as np import pandas as pd import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots from streamlit_hf.lib.reactions import normalize_reaction_key # Matches Streamlit theme primary + slate text; used across Plotly layouts. PLOT_FONT = dict(family="Inter, system-ui, sans-serif", size=12) # Same as app / plotly_white paper so figures are not tinted vs the page. PAGE_BG = "#ffffff" PALETTE = ( "#2563eb", "#dc2626", "#059669", "#d97706", "#7c3aed", "#db2777", "#0d9488", "#4f46e5", ) MODALITY_COLOR = {"RNA": "#E64B35", "ATAC": "#4DBBD5", "Flux": "#00A087"} # Global modality pie only: edit here to try other hues (bars/scatter use MODALITY_COLOR). MODALITY_PIE_COLOR = dict(MODALITY_COLOR) # Log₂FC heatmaps/sunburst: colours like ggplot2 scale_colour_gradient2 (mid grey at 0). LOG_FC_COLOR_MIN = -0.5 LOG_FC_COLOR_MAX = 0.5 LOG_FC_DIVERGING_SCALE: list[list] = [ [0.0, "#1C86EE"], [0.5, "#FAFAFA"], [1.0, "#FF0000"], ] # Unicode minus (U+2212) and subscript ₁₀ / ₂ for axes/colorbars. LABEL_NEG_LOG10_ADJ_P = "\u2212log\u2081\u2080 adj. p" LABEL_LOG2FC = "Log\u2082FC" # Cached attention dict uses lowercase modality keys. FI_ATT_MOD_KEY = {"RNA": "rna", "ATAC": "atac", "Flux": "flux"} # Model appends one batch-embedding token per modality; hide from attention rankings in the UI. 
BATCH_EMBEDDING_FEATURE_NAMES = frozenset({"batch_rna", "batch_atac", "batch_flux"})


def _attention_pairs_skip_batch(pairs: list) -> list:
    """Return the ``(name, score)`` pairs whose name is not a batch-embedding token."""
    return [(n, s) for n, s in pairs if str(n) not in BATCH_EMBEDDING_FEATURE_NAMES]


def rollout_top_features_table(feature_names, vec, top_n: int) -> pd.DataFrame:
    """Top `top_n` rollout weights per modality slice, excluding batch-embedding tokens.

    Args:
        feature_names: Iterable of feature names; each is converted with ``str``.
        vec: Weights aligned by position with ``feature_names`` (converted to a
            float array).
        top_n: Number of rows to keep after sorting by descending weight.

    Returns:
        DataFrame with columns ``feature`` and ``mean_attention`` ordered from the
        largest weight down. NaN weights are ranked after every real weight.
        When nothing remains, an empty frame with the same two columns.
    """
    names = [str(x) for x in feature_names]
    v = np.asarray(vec, dtype=float)
    rows = [
        (name, float(v[i]))
        for i, name in enumerate(names)
        if name not in BATCH_EMBEDDING_FEATURE_NAMES
    ]
    # NaN compares false against everything, so sorting on the bare weight leaves
    # the order undefined and can hide the real top features. Rank NaN last.
    rows.sort(key=lambda r: (bool(np.isnan(r[1])), -r[1]))
    rows = rows[:top_n]
    if not rows:
        return pd.DataFrame(columns=["feature", "mean_attention"])
    feat, val = zip(*rows)
    return pd.DataFrame({"feature": list(feat), "mean_attention": list(val)})


# Themed continuous scale for dominant-fate % on UMAP (low → high emphasis).
UMAP_PCT_COLORSCALE: list[list] = [
    [0.0, "#eff6ff"],
    [0.25, "#bfdbfe"],
    [0.55, "#3b82f6"],
    [0.82, "#2563eb"],
    [1.0, "#1e3a8a"],
]

# Okabe–Ito–style distinct colours (colourblind-friendly) for categorical UMAP hues.
LATENT_DISCRETE_PALETTE = ( "#0072B2", "#E69F00", "#009E73", "#CC79A7", "#56B4E9", "#D55E00", "#F0E442", "#000000", ) def latent_scatter( df, color_col: str, title: str, width: int = 720, height: int = 520, marker_size: float = 5.0, marker_opacity: float = 0.78, subtitle: str | None = None, ): d = df.copy() hover_spec = { "umap_x": ":.3f", "umap_y": ":.3f", "dataset_idx": True, "fold": True, "batch_no": True, "predicted_class": True, "label": True, "correct": True, "pct": ":.2f", "modality_label": True, "modality": True, "predicted_value": ":.3f", "clone_id": True, "clone_size": True, "cell_type": True, } if "modality_label" in d.columns: hover_spec.pop("modality", None) hover_data = {k: v for k, v in hover_spec.items() if k in d.columns} _disp = { "label": "CellTag-Multi label", "predicted_class": "Predicted fate", "pct": "Dominant fate (%)", "modality_label": "Available modalities", "dataset_idx": "Dataset index", "batch_no": "Batch", "fold": "Cross Validation fold", } labels_map = {c: _disp[c] for c in _disp if c in d.columns} continuous = color_col == "pct" if color_col == "fold": d["_color"] = d["fold"].astype(str) color_arg = "_color" labels_map["_color"] = "Fold" continuous = False elif color_col == "batch_no": d["_color"] = d["batch_no"].astype(str) color_arg = "_color" labels_map["_color"] = "Batch" continuous = False elif color_col == "correct": d["_color"] = d["correct"].map({True: "Correct", False: "Wrong"}) color_arg = "_color" labels_map["_color"] = "Prediction" continuous = False else: color_arg = color_col # Plotly Express turns title="" into a visible "undefined" title in some versions; omit when empty. common = dict( x="umap_x", y="umap_y", hover_data=hover_data, labels=labels_map, width=width, height=height, ) # Title + subtitle are applied via update_layout when `subtitle` is set (Plotly 5+). 
if title and not subtitle: common["title"] = title if continuous: fig = px.scatter( d, color=color_arg, color_continuous_scale=UMAP_PCT_COLORSCALE, **common, ) else: fig = px.scatter( d, color=color_arg, color_discrete_sequence=list(LATENT_DISCRETE_PALETTE), **common, ) fig.update_traces( marker=dict(size=marker_size, opacity=marker_opacity, line=dict(width=0.25, color="rgba(255,255,255,0.4)")) ) if title and subtitle: top_margin = 88 else: top_margin = 56 if title else 28 fig.update_layout( template="plotly_white", font=PLOT_FONT, title_font_size=16, margin=dict(l=28, r=20, t=top_margin, b=28), legend_title_text="", xaxis_title="", yaxis_title="", paper_bgcolor=PAGE_BG, plot_bgcolor=PAGE_BG, ) if title and subtitle: fig.update_layout( title=dict( text=title, x=0.5, xanchor="center", font=dict(size=16, family=PLOT_FONT["family"]), subtitle=dict( text=subtitle, font=dict(size=11, color="#64748b", family=PLOT_FONT["family"]), ), ), ) elif not title: fig.update_layout(title=None) fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False) fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False) return fig def rank_scatter_shift_vs_attention(df_mod, modality: str, width: int = 420, height: int = 440): """Attention rank on x, shift rank on y, least-squares trend, colours by top ~10% within this modality.""" need = ("shift_order_mod", "attention_order_mod") if not all(c in df_mod.columns for c in need): return go.Figure() sub = df_mod.dropna(subset=list(need)).copy() if sub.empty: return go.Figure() n = len(sub) top_k = max(1, int(np.ceil(0.1 * n))) s_ok = sub["shift_order_mod"].astype(int) <= top_k a_ok = sub["attention_order_mod"].astype(int) <= top_k sub["_tier_label"] = np.where( s_ok & a_ok, "Both", np.where(s_ok, "Shift", np.where(a_ok, "Attention", "Neither")), ) x = sub["attention_order_mod"].astype(float).to_numpy() y = sub["shift_order_mod"].astype(float).to_numpy() fig = px.scatter( sub, x="attention_order_mod", y="shift_order_mod", 
color="_tier_label", hover_name="feature", hover_data={ "mean_rank": True, "importance_shift": ":.4f", "importance_att": ":.4f", }, labels={ "attention_order_mod": "Attention rank", "shift_order_mod": "Shift rank", "_tier_label": "Top-10% tier", }, category_orders={"_tier_label": ["Both", "Shift", "Attention", "Neither"]}, width=width, height=height, color_discrete_map={ "Both": PALETTE[0], "Shift": PALETTE[1], "Attention": PALETTE[2], "Neither": "#94a3b8", }, ) fig.update_traces(marker=dict(size=7, opacity=0.62, line=dict(width=0.5, color="rgba(15,23,42,0.28)"))) if len(x) >= 2 and float(np.ptp(x)) > 0: coef = np.polyfit(x, y, 1) poly = np.poly1d(coef) xs = np.linspace(float(np.min(x)), float(np.max(x)), 100) fig.add_trace( go.Scatter( x=xs, y=poly(xs), mode="lines", name=f"y = {coef[0]:.2f}x + {coef[1]:.2f}", line=dict(color="#2563eb", width=2, dash="dash"), showlegend=True, ) ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title=dict( text=f"{modality}: shift vs attention (ranks)", x=0.5, xanchor="center", y=0.98, yanchor="top", font=dict(size=14, family=PLOT_FONT["family"]), ), margin=dict(l=48, r=20, t=52, b=72), legend=dict( title=dict(text="Among top 10% features?"), orientation="h", yanchor="top", y=-0.2, xanchor="center", x=0.5, ), ) return fig def _truncate_label(s: str, max_len: int = 36) -> str: s = str(s) return s if len(s) <= max_len else s[: max_len - 1] + "…" def joint_shift_attention_top_features(df_mod, modality: str, top_n: int): """ Top features by mean_rank (lowest = strongest joint shift+attention ranking). Shift and attention importances are min-max scaled within this top-N slice for side-by-side comparison. 
""" need = ("mean_rank", "importance_shift", "importance_att", "feature") if not all(c in df_mod.columns for c in need): return go.Figure() sub = df_mod.nsmallest(top_n, "mean_rank").copy() if sub.empty: return go.Figure() def _mm(s: pd.Series) -> pd.Series: lo, hi = float(s.min()), float(s.max()) if hi <= lo: return pd.Series(0.5, index=s.index) return (s.astype(float) - lo) / (hi - lo) sub["_zs"] = _mm(sub["importance_shift"]) sub["_za"] = _mm(sub["importance_att"]) # Best (lowest mean_rank) at top of chart; matches shift/attention rows below. sub = sub.sort_values("mean_rank", ascending=True) feats_full = sub["feature"].astype(str) y_disp = feats_full.map(lambda s: _truncate_label(s, 40)) base = MODALITY_COLOR.get(modality, PALETTE[0]) att_c = "#475569" if base != "#475569" else "#64748b" margin_l = int(min(380, 64 + 5.8 * max((len(t) for t in y_disp), default=10))) h = min(720, 52 + 22 * len(sub)) fig = go.Figure() fig.add_trace( go.Bar( name="Shift (scaled)", y=y_disp, x=sub["_zs"], orientation="h", marker_color=base, customdata=feats_full, hovertemplate="%{customdata}
Shift (scaled): %{x:.3f}", ) ) fig.add_trace( go.Bar( name="Attention (scaled)", y=y_disp, x=sub["_za"], orientation="h", marker_color=att_c, customdata=feats_full, hovertemplate="%{customdata}
Attention (scaled): %{x:.3f}", ) ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title=dict( text=f"{modality} · top {top_n}", x=0.5, xanchor="center", y=0.98, yanchor="top", font=dict(size=14, family=PLOT_FONT["family"]), ), barmode="group", bargap=0.15, bargroupgap=0.05, width=680, height=h, margin=dict(l=margin_l, r=12, t=44, b=72), xaxis_title="Scaled 0-1 within selection", yaxis_title="", legend=dict(orientation="h", yanchor="top", y=-0.14, xanchor="center", x=0.5), ) fig.update_yaxes(autorange="reversed", tickfont=dict(size=10)) return fig def modality_shift_attention_rank_stats(df_mod) -> dict[str, Any]: """Pearson / Spearman between per-modality shift and attention ordinal ranks.""" from scipy.stats import pearsonr, spearmanr need = ("shift_order_mod", "attention_order_mod") if not all(c in df_mod.columns for c in need): return {"n": 0} sub = df_mod.dropna(subset=list(need)) n = len(sub) if n < 3: return {"n": n} xs = sub["attention_order_mod"].astype(float) ys = sub["shift_order_mod"].astype(float) pr, pp = pearsonr(xs, ys) sr, sp = spearmanr(xs, ys) return { "n": n, "pearson_r": float(pr), "pearson_p": float(pp), "spearman_r": float(sr), "spearman_p": float(sp), } def rank_bar( df_top, xcol: str, ycol: str, title: str, color: str = PALETTE[0], xaxis_title: str | None = None, ): d = df_top.sort_values(xcol, ascending=True) y_raw = d[ycol].astype(str) y_show = y_raw.map(lambda s: _truncate_label(s, 42)) margin_l = int(min(420, 80 + 5.8 * max((len(s) for s in y_show), default=12))) fig = go.Figure( go.Bar( y=y_show, x=d[xcol], orientation="h", marker_color=color, customdata=y_raw, hovertemplate="%{customdata}
%{x:.4g}", ) ) xt = xaxis_title if xaxis_title is not None else xcol.replace("_", " ") fig.update_layout( template="plotly_white", font=PLOT_FONT, title=title, width=680, height=min(620, 38 + 20 * len(d)), margin=dict(l=margin_l, r=24, t=48, b=40), xaxis_title=xt, yaxis_title="", ) fig.update_yaxes(tickfont=dict(size=10)) return fig def attention_top_comparison(fi_lists: dict, modality: str, top_n: int = 18): """fi_lists: cohort -> {rna|atac|flux: [(name, score), ...]}.""" mk = FI_ATT_MOD_KEY.get(modality, str(modality).lower()) traces = [] for key, name, color in ( ("all", "All validation samples", PALETTE[0]), ("dead_end", "Predicted dead-end", PALETTE[1]), ("reprogramming", "Predicted reprogramming", PALETTE[2]), ): cohort = fi_lists.get(key) or {} items = _attention_pairs_skip_batch(list(cohort.get(mk, [])))[:top_n] if not items: continue feats, scores = zip(*items) traces.append( go.Bar( name=name, x=list(scores), y=[f[:52] + ("…" if len(f) > 52 else "") for f in feats], orientation="h", marker_color=color, ) ) fig = go.Figure(traces) bar_h = max(320, 36 + min(top_n, 20) * 22 * max(1, len(traces))) fig.update_layout( barmode="group", template="plotly_white", font=PLOT_FONT, title=f"Top attention (rollout): {modality}", width=520, height=bar_h, margin=dict(l=220, r=24, t=56, b=40), legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), ) if not traces: fig.update_layout( annotations=[ dict( text="No attention list for this modality (re-run precompute).", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, ) ] ) else: fig.update_yaxes(autorange="reversed") return fig def attention_cohort_view( fi_lists: dict, modality: str, top_n: int, mode: str, ): """ mode: 'compare': grouped bars for all three cohorts; 'all' | 'dead_end' | 'reprogramming': single cohort only. 
""" if mode == "compare": return attention_top_comparison(fi_lists, modality, top_n) mk = FI_ATT_MOD_KEY.get(modality, str(modality).lower()) cohort = fi_lists.get(mode) or {} items = _attention_pairs_skip_batch(list(cohort.get(mk, [])))[:top_n] label = { "all": "All validation samples", "dead_end": "Predicted dead-end", "reprogramming": "Predicted reprogramming", }.get(mode, mode) if not items: fig = go.Figure() fig.update_layout( template="plotly_white", font=PLOT_FONT, title=f"{modality} · {label}", annotations=[ dict( text="No items for this cohort.", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, ) ], ) return fig feats, scores = zip(*items) fig = go.Figure( go.Bar( x=list(scores), y=[f[:52] + ("…" if len(f) > 52 else "") for f in feats], orientation="h", marker_color=PALETTE[0], ) ) h = max(280, 40 + min(top_n, 25) * 20) fig.update_layout( template="plotly_white", font=PLOT_FONT, title=f"{modality} · {label}", width=520, height=h, margin=dict(l=220, r=24, t=56, b=40), xaxis_title="Attention weight", ) fig.update_yaxes(autorange="reversed") return fig def _pie_hover_feature_lines(names: list[str], *, names_per_line: int = 5) -> str: """Join feature names with commas; start a new hover row every ``names_per_line`` items (HTML ``
``).""" if not names: return "—" safe = [html.escape(str(n), quote=False) for n in names] step = max(1, int(names_per_line)) lines: list[str] = [] for i in range(0, len(safe), step): lines.append(", ".join(safe[i : i + step])) return "
".join(lines) def global_rank_triple_panel( df_features, top_n: int = 20, top_n_pie: int = 100, *, chart_outline: bool = True, modality_mix_hole: float = 0.0, modality_mix_hover_feature_list: bool = False, ): """ Global top-N by latent-shift and by attention (min-max scaled), plus pie or donut of modality mix among the top `top_n_pie` features by mean rank. Set ``chart_outline=False`` for a flatter look (e.g. home page); Feature Insights keeps outlines by default. Set ``modality_mix_hole`` in (0, 1), e.g. ``0.66``, for a donut instead of a full pie (e.g. home page). Set ``modality_mix_hover_feature_list=True`` to show comma-separated feature names per donut slice on hover (same pool as the pie: strongest by mean rank within each modality), wrapped every few names for readability. """ d = df_features.copy() for col in ("importance_shift", "importance_att"): min_v, max_v = d[col].min(), d[col].max() if max_v > min_v: d[col + "_norm"] = (d[col] - min_v) / (max_v - min_v) else: d[col + "_norm"] = 0.0 shift_top = d.nlargest(top_n, "importance_shift") att_top = d.nlargest(top_n, "importance_att") pie_pool = d.nsmallest(top_n_pie, "mean_rank") fig = make_subplots( rows=1, cols=3, column_widths=[0.36, 0.36, 0.28], specs=[[{}, {}, {"type": "domain"}]], subplot_titles=( f"Top {top_n} by latent shift (ranked)", f"Top {top_n} by attention (ranked)", f"Top {top_n_pie} by mean rank (modality mix)", ), horizontal_spacing=0.06, ) bar_outline = dict(color="#1e293b", width=1.2) if chart_outline else dict(width=0) pie_line = dict(color="#1e293b", width=1.2) if chart_outline else dict(width=0) leg_line = dict(width=1.2, color="#1e293b") if chart_outline else dict(width=0) fig.add_trace( go.Bar( x=shift_top["importance_shift_norm"], y=shift_top["feature"], orientation="h", marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in shift_top["modality"]], marker_line=bar_outline, showlegend=False, hovertemplate="%{y}
scaled shift: %{x:.3f}", ), row=1, col=1, ) fig.add_trace( go.Bar( x=att_top["importance_att_norm"], y=att_top["feature"], orientation="h", marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in att_top["modality"]], marker_line=bar_outline, showlegend=False, hovertemplate="%{y}
scaled attention: %{x:.3f}", ), row=1, col=2, ) pie_labels = ["RNA", "ATAC", "Flux"] counts = pie_pool["modality"].value_counts() pie_vals = [int(counts.get(lab, 0)) for lab in pie_labels] if sum(pie_vals) == 0: pie_vals = [1, 1, 1] _hole = float(modality_mix_hole) if modality_mix_hole and modality_mix_hole > 0 else 0.0 # Narrow third subplot: "auto" avoids clipped outside labels on donuts. _pie_textpos = "auto" _pie_kwargs: dict = dict( labels=pie_labels, values=pie_vals, marker=dict( colors=[MODALITY_PIE_COLOR.get(l, "#64748b") for l in pie_labels], line=pie_line, ), textinfo="label+percent", textfont_size=12, textposition=_pie_textpos, hole=_hole, showlegend=False, ) if modality_mix_hover_feature_list: _hover_texts: list[str] = [] for lab in pie_labels: sub = pie_pool[pie_pool["modality"] == lab] if sub.empty: _hover_texts.append("—") else: sub = sub.sort_values("mean_rank", ascending=True, kind="mergesort") _per_line = 1 if lab == "Flux" else 5 _hover_texts.append( _pie_hover_feature_lines(sub["feature"].astype(str).tolist(), names_per_line=_per_line) ) _pie_kwargs["hovertext"] = _hover_texts _pie_kwargs["hovertemplate"] = ( "%{label} · %{value} features (%{percent:.1%})

%{hovertext}" ) fig.add_trace( go.Pie(**_pie_kwargs), row=1, col=3, ) # Modality legend (bar colours + pie segment colours): invisible markers in first subplot only. for _name, _col in ( ("RNA (transcriptome)", MODALITY_PIE_COLOR["RNA"]), ("ATAC (chromatin)", MODALITY_PIE_COLOR["ATAC"]), ("Flux (metabolism)", MODALITY_PIE_COLOR["Flux"]), ): fig.add_trace( go.Scatter( x=[None], y=[None], mode="markers", marker=dict(size=12, color=_col, symbol="square", line=leg_line), name=_name, showlegend=True, hoverinfo="skip", ), row=1, col=1, ) fig.update_xaxes(title_text="Min-max scaled shift", row=1, col=1) fig.update_xaxes(title_text="Min-max scaled attention", row=1, col=2) fig.update_yaxes(autorange="reversed", row=1, col=1) fig.update_yaxes(autorange="reversed", row=1, col=2) h = max(480, 40 + top_n * 18) fig.update_layout( template="plotly_white", font=PLOT_FONT, height=h, width=min(1280, 400 + top_n * 14), margin=dict(l=40, r=40, t=80, b=108), paper_bgcolor=PAGE_BG, plot_bgcolor=PAGE_BG, title_text="Global feature ranking (all modalities)", title_x=0.5, legend=dict( title=dict(text="Modality colour key", font=dict(size=11, family=PLOT_FONT["family"])), orientation="h", yanchor="top", y=-0.14, xanchor="center", x=0.5, font=dict(size=11, family=PLOT_FONT["family"]), traceorder="normal", itemsizing="constant", ), ) return fig def _flux_prepare_top_ranked(flux_df: pd.DataFrame, top_n: int, metric: str = "mean_rank") -> pd.DataFrame: sub = flux_df[~flux_df["feature"].astype(str).str.contains("batch", case=False, na=False)].copy() if metric not in sub.columns: metric = "mean_rank" sub = sub.sort_values(metric, ascending=True).head(int(top_n)).copy() if "pathway" in sub.columns: pc = sub["pathway"].value_counts() sub["_pw_n"] = sub["pathway"].map(pc) sub.sort_values(["_pw_n", "pathway"], ascending=[False, True], inplace=True) return sub def flux_pathway_sunburst(flux_df: pd.DataFrame, max_features: int = 55) -> go.Figure: sub = flux_df.dropna(subset=["pathway"]).copy() if 
sub.empty: return go.Figure() sub = sub.nsmallest(int(max_features), "mean_rank") sub["pathway"] = sub["pathway"].astype(str) sub["_uid"] = np.arange(len(sub)) sub["rxn"] = sub.apply( lambda r: f"{_truncate_label(str(r['feature']), 36)} ·{int(r['_uid'])}", axis=1, ) mr = sub["mean_rank"].astype(float) sub["w"] = (mr.max() - mr + 1.0).clip(lower=0.5) color_col = "log_fc" if "log_fc" in sub.columns and sub["log_fc"].notna().any() else "mean_rank" sb_kw: dict[str, Any] = { "path": ["pathway", "rxn"], "values": "w", "color": color_col, "hover_data": {"mean_rank": ":.2f", "pval_adj": ":.2e", "feature": True, "w": False, "_uid": False}, } if color_col == "log_fc": sb_kw["color_continuous_scale"] = LOG_FC_DIVERGING_SCALE sb_kw["range_color"] = [LOG_FC_COLOR_MIN, LOG_FC_COLOR_MAX] else: sb_kw["color_continuous_scale"] = "Viridis_r" fig = px.sunburst(sub, **sb_kw) fig.update_layout( template="plotly_white", font=PLOT_FONT, margin=dict(l=8, r=8, t=100, b=16), height=min(820, 520 + int(max_features) * 5), title=dict( text="Top flux reactions by model rank, nested under pathway", x=0, xanchor="left", y=0.99, yanchor="top", font=dict(size=13, family=PLOT_FONT["family"]), pad=dict(b=16, l=4), ), ) if color_col == "log_fc": fig.update_layout( coloraxis=dict( cmin=LOG_FC_COLOR_MIN, cmax=LOG_FC_COLOR_MAX, colorbar=dict( title=dict(text=LABEL_LOG2FC, side="right"), tickformat=".2f", len=0.38, thickness=12, y=0.52, yanchor="middle", ), ) ) return fig def flux_volcano(flux_df: pd.DataFrame) -> go.Figure: if "log_fc" not in flux_df.columns: return go.Figure() d = flux_df.dropna(subset=["log_fc"]).copy() if d.empty: return go.Figure() # Drop degenerate rows: ~zero fold-change with exactly-zero adjusted p (numeric artifact / noise). 
lf = d["log_fc"].astype(float) if "pval_adj" in d.columns: pa = d["pval_adj"].astype(float) bad = np.isfinite(lf) & np.isfinite(pa) & (np.abs(lf) < 1e-10) & (pa <= 0.0) d = d[~bad] if d.empty: return go.Figure() if "pval_adj_log" in d.columns: y = d["pval_adj_log"].astype(float) else: p = d["pval_adj"].astype(float).clip(lower=1e-300) y = -np.log10(p.to_numpy()) d = d.assign(_neglogp=y) fig = px.scatter( d, x="log_fc", y="_neglogp", color="mean_rank", color_continuous_scale="Viridis_r", hover_name="feature", hover_data=["pathway", "pval_adj", "group"], labels={ "log_fc": LABEL_LOG2FC, "_neglogp": LABEL_NEG_LOG10_ADJ_P, "mean_rank": "Mean rank", }, ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title="Differential flux vs statistical significance", height=520, margin=dict(l=52, r=24, t=52, b=48), coloraxis_colorbar=dict( title=dict(text="Mean rank", side="right"), thickness=12, len=0.55, ), ) return fig def motif_tf_mean_rank_bars(atac_df: pd.DataFrame, top_n: int = 22) -> go.Figure: """Aggregate motif features by TF name (prefix before ``_``); show lowest mean joint rank.""" if atac_df.empty or "feature" not in atac_df.columns: return go.Figure() def _tf_prefix(feat: str) -> str: s = str(feat) if "_" in s: head, tail = s.rsplit("_", 1) if tail.isdigit(): return head return s d = atac_df.copy() d["_tf"] = d["feature"].map(_tf_prefix) agg = d.groupby("_tf", as_index=False)["mean_rank"].mean() agg = agg.nsmallest(int(top_n), "mean_rank").sort_values("mean_rank", ascending=True) if agg.empty: return go.Figure() y_show = agg["_tf"].astype(str).map(lambda s: _truncate_label(s, 36)) fig = go.Figure( go.Bar( y=y_show, x=agg["mean_rank"], orientation="h", marker_color=MODALITY_COLOR.get("ATAC", PALETTE[0]), customdata=agg["_tf"], hovertemplate="%{customdata}
Mean mean_rank (across motifs): %{x:.2f}", ) ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title=f"TFs by average motif rank (top {top_n} by lowest mean rank)", height=min(640, 48 + 22 * len(agg)), margin=dict(l=160, r=24, t=52, b=40), xaxis_title="Mean of mean_rank over motif instances (lower = stronger)", yaxis_title="", ) fig.update_yaxes(autorange="reversed", tickfont=dict(size=10)) return fig def motif_chromvar_volcano(atac_df: pd.DataFrame) -> go.Figure: """Motif differential view: mean activity difference (reprogramming − dead-end) vs significance.""" need = ("mean_diff", "pval_adj") if not all(c in atac_df.columns for c in need): return go.Figure() d = atac_df.dropna(subset=["mean_diff", "pval_adj"]).copy() if d.empty: return go.Figure() md = d["mean_diff"].astype(float) pa = d["pval_adj"].astype(float) bad = np.isfinite(md) & np.isfinite(pa) & (np.abs(md) < 1e-12) & (pa <= 0.0) d = d[~bad] if d.empty: return go.Figure() if "pval_adj_log" in d.columns: y = d["pval_adj_log"].astype(float) else: p = d["pval_adj"].astype(float).clip(lower=1e-300) y = -np.log10(p.to_numpy()) d = d.assign(_y=y) hover_cols = [c for c in ("group", "pval_adj", "mean_rank", "mean_de", "mean_re") if c in d.columns] fig = px.scatter( d, x="mean_diff", y="_y", color="mean_rank", color_continuous_scale="Viridis_r", hover_name="feature", hover_data=hover_cols if hover_cols else None, labels={ "mean_diff": "Mean difference (reprogramming − dead-end)", "_y": LABEL_NEG_LOG10_ADJ_P, "mean_rank": "Mean rank", }, ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title="TF motif differential activity (mean difference vs significance)", height=520, margin=dict(l=52, r=24, t=52, b=48), coloraxis_colorbar=dict(title=dict(text="Mean rank", side="right"), thickness=12, len=0.55), ) return fig def notebook_style_activity_scatter( df: pd.DataFrame, title: str, x_title: str, y_title: str, ) -> go.Figure: """mean_de vs mean_re, colour = pval_adj_log (Reds), marker size ∝ 
inverse mean_rank.""" need = ("mean_de", "mean_re", "mean_rank", "pval_adj_log", "feature", "group") if not all(c in df.columns for c in need): return go.Figure() d = df.dropna(subset=["mean_de", "mean_re", "mean_rank", "pval_adj_log"]).copy() if d.empty: return go.Figure() mx = float(d["mean_rank"].max()) d = d.assign(_inv=(mx - d["mean_rank"].astype(float)).clip(lower=0)) inv = d["_inv"].astype(float) lo, hi = float(inv.min()), float(inv.max()) if hi <= lo: d["_sz"] = 6.0 else: d["_sz"] = 3.5 + (inv - lo) / (hi - lo) * 9.0 fig = px.scatter( d, x="mean_de", y="mean_re", color="pval_adj_log", color_continuous_scale="Reds", size="_sz", size_max=14, hover_name="feature", hover_data={ "mean_rank": ":.2f", "group": True, "pval_adj_log": ":.2f", "_inv": False, "_sz": False, }, labels={ "mean_de": x_title, "mean_re": y_title, "pval_adj_log": "Adj. p-value (log)", }, ) fig.update_traces( marker=dict(line=dict(width=0.45, color="rgba(255,255,255,0.75)"), opacity=0.9), selector=dict(mode="markers"), ) fig.update_layout( template="plotly_white", font=PLOT_FONT, title=title, height=520, margin=dict(l=52, r=24, t=52, b=48), coloraxis_colorbar=dict(title=dict(text="Adj. 
p (log)", side="right"), thickness=12, len=0.55), ) return fig def pathway_bubble_suggested_height(n_paths: int) -> int: """Total figure height for pathway bubble panels (use the max of both cohorts so legends line up).""" n = max(int(n_paths), 1) return max(520, min(1100, 22 * n + 200)) def pathway_enrichment_bubble_panel( df: pd.DataFrame, title: str, *, show_colorbar: bool = True, layout_height: int | None = None, ) -> go.Figure: """Single cohort: Reactome (circle) vs KEGG (square), colour = −log₁₀ Benjamini (scale per panel).""" fig = go.Figure() if df.empty: fig.update_layout( template="plotly_white", font=PLOT_FONT, title=dict(text=title, x=0.5, xanchor="center"), annotations=[ dict( text="No significant pathways (Benjamini-Hochberg q < 0.05)", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=13, color="#64748b"), ) ], height=320, margin=dict(l=40, r=40, t=56, b=40), ) return fig # More genes in the overlap first, then stronger gene ratio (matches enrichment table emphasis). d = df.sort_values(by=["Count", "Gene Ratio"], ascending=[False, False]).reset_index(drop=True) d = d.assign( _neglog=-np.log10(d["Benjamini"].astype(float).clip(lower=1e-300)), _y=np.arange(len(d), dtype=float), ) nl = d["_neglog"].astype(float) cmin = float(nl.min()) cmax = float(nl.max()) if cmax <= cmin: cmax = cmin + 1e-6 # Single trace: per-panel cmin/cmax so Viridis uses the cohort’s range (shared global max clusters at one hue). 
sym_map = {"Reactome": "circle", "KEGG": "square"} symbols = [sym_map.get(str(x), "circle") for x in d["Library"].tolist()] sz = np.sqrt(d["Count"].astype(float).clip(lower=1)) * 4.8 customdata = np.stack( [d["Count"].to_numpy(), d["_neglog"].to_numpy(), d["Library"].astype(str).to_numpy()], axis=1, ) fig.add_trace( go.Scatter( x=d["Gene Ratio"], y=d["_y"], mode="markers", name="Pathways", showlegend=False, marker=dict( size=sz, sizemode="diameter", sizemin=4, symbol=symbols, color=d["_neglog"], cmin=cmin, cmax=cmax, colorscale="Viridis", showscale=bool(show_colorbar), colorbar=dict( title=dict( text="\u2212log\u2081\u2080 q", side="right", ), len=0.72, thickness=12, y=0.45, yanchor="middle", outlinewidth=0, ) if show_colorbar else None, line=dict(width=0.75, color="rgba(0,0,0,0.5)"), opacity=0.92, ), text=d["Term"], customdata=customdata, hovertemplate=( "%{text}
%{customdata[2]}
Gene ratio: %{x:.3f}
Count: %{customdata[0]}" "
def _membership_slots(codes: np.ndarray, slots: dict[int, float], default: float) -> np.ndarray:
    """Map integer cell codes to fixed colour-scale positions; unknown codes get `default`."""
    out = np.full(codes.shape, default, dtype=float)
    for code, slot in slots.items():
        out[codes == code] = slot
    return out


def pathway_gene_membership_heatmap(
    z: np.ndarray, row_labels: list[str], col_labels: list[str]
) -> go.Figure:
    """Pathway × gene grid; empty cells use a light tint vs page white; Reactome/KEGG as a narrow left row spine.

    Args:
        z: 2-D array of cell codes. Codes 0-4 are recognised (0 empty, 1 dead-end
            gene, 2 reprogramming gene, 3 Reactome, 4 KEGG); any other code is
            painted like an empty cell.
        row_labels: y labels, one per row of `z`.
        col_labels: x labels, one per column of `z`. When there are at least two
            columns and the last label is exactly "Library", that column is drawn
            as a separate narrow spine left of the gene grid.

    Returns:
        A Plotly figure; an empty `go.Figure()` when `z` has no elements.
    """
    z = np.asarray(z)
    if z.size == 0:
        return go.Figure()
    # Round before the integer cast: a plain astype(int) truncates, so a float
    # code such as 2.9999 would silently be drawn as code 2.
    z_int = np.rint(z.astype(float)).astype(int)
    n_rows, n_cols = z_int.shape

    hints = {
        1: "Gene enriched in dead-end contrast",
        2: "Gene enriched in reprogramming contrast",
        3: "Reactome pathway set",
        4: "KEGG pathway set",
    }
    # Discrete codes 0–4 must not use z/4 (3→0.75 landed in the KEGG band). Map to fixed slots.
    z_plot = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
    # Slight contrast vs PAGE_BG (#fff) so empty (code 0) cells read as a grid, not “missing” paint.
    empty_cell = "#f1f5f9"
    colorscale_main = [
        [0.0, empty_cell],
        [0.14, empty_cell],
        [0.15, "#e69138"],
        [0.33, "#e69138"],
        [0.34, "#7eb6d9"],
        [0.53, "#7eb6d9"],
        [0.54, "#9ccc65"],
        [0.73, "#9ccc65"],
        [0.74, "#283593"],
        [1.0, "#283593"],
    ]
    spine_plot = {3: 0.22, 4: 0.78}
    spine_names = {3: "Reactome", 4: "KEGG"}
    colorscale_spine = [
        [0.0, "#9ccc65"],
        [0.42, "#9ccc65"],
        [0.58, "#283593"],
        [1.0, "#283593"],
    ]

    use_spine = n_cols >= 2 and str(col_labels[-1]) == "Library"
    if use_spine:
        lib_codes = z_int[:, -1]
        main_codes = z_int[:, :-1]
        x_main = list(col_labels[:-1])
    else:
        main_codes = z_int
        x_main = list(col_labels)
    n_gene_cols = main_codes.shape[1]

    zn_main = _membership_slots(main_codes, z_plot, 0.04)
    text_main = [[hints.get(int(code), "") for code in row] for row in main_codes]
    main_trace = go.Heatmap(
        z=zn_main,
        x=x_main,
        y=row_labels,
        text=text_main,
        colorscale=colorscale_main,
        zmin=0,
        zmax=1,
        showscale=False,
        xgap=1,
        ygap=1,
        hovertemplate="%{y}<br>%{x}<br>%{text}",
    )

    if use_spine:
        # Unknown library codes fall back to the Reactome slot and a generic label.
        spine_zn = _membership_slots(lib_codes.reshape(-1, 1), spine_plot, 0.22)
        spine_text = [[spine_names.get(int(code), "Library")] for code in lib_codes]
        fig = make_subplots(
            rows=1,
            cols=2,
            column_widths=[0.034, 0.966],
            horizontal_spacing=0.006,
            shared_yaxes=True,
        )
        fig.add_trace(
            go.Heatmap(
                z=spine_zn,
                x=[""],
                y=row_labels,
                text=spine_text,
                colorscale=colorscale_spine,
                zmin=0,
                zmax=1,
                showscale=False,
                xgap=0,
                ygap=1,
                hovertemplate="%{y}<br>%{text}",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(main_trace, row=1, col=2)
    else:
        fig = go.Figure(data=[main_trace])

    # Figure size grows with the grid but is clamped to a fixed pixel window.
    cell_w = 10
    cell_h = 20
    w = int(min(1000, max(460, n_gene_cols * cell_w + 300)))
    h = int(min(960, max(460, n_rows * cell_h + 128)))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(text="Pathway-gene membership", x=0.5, xanchor="center"),
        height=h,
        width=w,
        margin=dict(l=4, r=168, t=52, b=108),
        paper_bgcolor=PAGE_BG,
        plot_bgcolor=PAGE_BG,
    )
    if use_spine:
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)
        fig.update_xaxes(side="bottom", tickangle=-50, showgrid=False, zeroline=False, row=1, col=2)
        # Row labels are shown once, on the spine column; the gene grid hides its own.
        fig.update_yaxes(
            tickfont=dict(size=9),
            showgrid=False,
            zeroline=False,
            autorange="reversed",
            showticklabels=True,
            row=1,
            col=1,
        )
        fig.update_yaxes(
            showgrid=False,
            zeroline=False,
            autorange="reversed",
            showticklabels=False,
            row=1,
            col=2,
        )
    else:
        fig.update_layout(
            xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
            yaxis=dict(
                tickfont=dict(size=9),
                showgrid=False,
                zeroline=False,
                autorange="reversed",
            ),
        )

    # Heatmap colour scales are hidden, so the legend is built from placeholder
    # marker traces that carry no data points.
    legend_groups: list[tuple[str, str, str, str | None]] = [
        ("Reactome", "#9ccc65", "pathway_library", "Library"),
        ("KEGG", "#283593", "pathway_library", None),
        ("Dead-end", "#e69138", "gene_contrast", "Contrast"),
        ("Reprogramming", "#7eb6d9", "gene_contrast", None),
    ]
    for name, color, group, group_title in legend_groups:
        marker = dict(size=11, color=color, symbol="square", line=dict(width=1, color="rgba(0,0,0,0.25)"))
        kwargs: dict[str, Any] = dict(
            x=[None],
            y=[None],
            mode="markers",
            name=name,
            legendgroup=group,
            marker=marker,
            showlegend=True,
        )
        if group_title:
            kwargs["legendgrouptitle"] = dict(text=group_title)
        fig.add_trace(go.Scatter(**kwargs))
    fig.update_layout(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.92)",
            bordercolor="rgba(0,0,0,0.08)",
            borderwidth=1,
            font=dict(size=11),
        )
    )
    return fig
def flux_dead_end_vs_reprogram_scatter(flux_df: pd.DataFrame, max_pathway_colors: int = 12) -> go.Figure:
    """Scatter of `mean_de` (x) against `mean_re` (y), one marker per row of `flux_df`.

    Marker size is the geometric mean of the non-negative `importance_shift` and
    `importance_att` values, scaled by its 95th percentile and mapped to 5-25.
    Colour is the pathway; only the `max_pathway_colors` most frequent pathways
    get their own colour, everything else is grouped as "Other".

    Args:
        flux_df: Table with at least `mean_de` and `mean_re`. `importance_shift`,
            `importance_att`, `pathway`, `feature`, `mean_rank` and `log_fc` are
            used when present.
        max_pathway_colors: Number of most frequent pathways coloured individually.

    Returns:
        A Plotly figure; an empty `go.Figure()` when the two required columns are
        missing or no row has both values.
    """
    need = ("mean_de", "mean_re")
    if not all(c in flux_df.columns for c in need):
        return go.Figure()
    d = flux_df.dropna(subset=list(need)).copy()
    if d.empty:
        return go.Figure()

    def _weight(col: str) -> pd.Series:
        # A missing column or NaN cell counts as zero importance, so the marker
        # falls back to the minimum size instead of producing a NaN size.
        if col not in d.columns:
            return pd.Series(0.0, index=d.index)
        return d[col].astype(float).fillna(0.0).clip(lower=0)

    imp = (_weight("importance_shift") * _weight("importance_att")) ** 0.5
    q = float(imp.quantile(0.95))
    if not np.isfinite(q) or q <= 0:
        q = 1.0
    d = d.assign(_s=(imp / q).clip(upper=1) * 20 + 5)

    if "pathway" in d.columns:
        pw = d["pathway"].fillna("Unknown").astype(str)
    else:
        pw = pd.Series(["Unknown"] * len(d), index=d.index)
    top_pw = pw.value_counts().head(int(max_pathway_colors)).index
    d = d.assign(_pw_col=pw.where(pw.isin(top_pw), "Other"))

    # Named pathways take palette colours in alphabetical order; "Other" is a fixed grey.
    groups = sorted(d["_pw_col"].astype(str).unique(), key=lambda x: (x == "Other", x))
    pal = list(LATENT_DISCRETE_PALETTE)
    named = [g for g in groups if g != "Other"]
    pw_cmap: dict[str, str] = {g: pal[i % len(pal)] for i, g in enumerate(named)}
    if "Other" in groups:
        pw_cmap["Other"] = "#94a3b8"

    # Only request hover fields that exist; "pathway" in particular is optional above.
    hover_cols = [c for c in ("mean_rank", "log_fc", "pathway") if c in d.columns]
    fig = px.scatter(
        d,
        x="mean_de",
        y="mean_re",
        color="_pw_col",
        color_discrete_map=pw_cmap,
        size="_s",
        hover_name="feature" if "feature" in d.columns else None,
        hover_data=hover_cols,
        labels={
            "mean_de": "Mean flux · dead-end",
            "mean_re": "Mean flux · reprogramming",
            "_pw_col": "Pathway",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=540,
        margin=dict(l=52, r=20, t=52, b=40),
        title="Average measured flux by fate label (each point is one reaction)",
        legend=dict(orientation="h", yanchor="top", y=-0.28, xanchor="center", x=0.5),
    )
    fig.update_traces(marker=dict(opacity=0.75, line=dict(width=0.35, color="rgba(0,0,0,0.3)")))
    return fig
def flux_pathway_mean_rank_violin(flux_df: pd.DataFrame, top_pathways: int = 12) -> go.Figure:
    """Violin (with inner box) of `mean_rank` per pathway for the most frequent pathways.

    Args:
        flux_df: Table with `pathway` and `mean_rank` columns; rows with a missing
            pathway are dropped.
        top_pathways: Number of pathways, ranked by row count, to keep.

    Returns:
        A Plotly figure; an empty `go.Figure()` when a required column is missing
        or no rows remain.
    """
    need = ("pathway", "mean_rank")
    if not all(c in flux_df.columns for c in need):
        return go.Figure()
    sub = flux_df.dropna(subset=["pathway"]).copy()
    if sub.empty:
        return go.Figure()

    # Plot the stringified pathway so the colour-map keys match the plotted values.
    sub = sub.assign(pathway=sub["pathway"].astype(str))
    top_list = list(sub["pathway"].value_counts().head(int(top_pathways)).index)
    sub = sub[sub["pathway"].isin(top_list)]
    if sub.empty:
        return go.Figure()

    n_colors = len(LATENT_DISCRETE_PALETTE)
    v_cmap = {p: LATENT_DISCRETE_PALETTE[i % n_colors] for i, p in enumerate(top_list)}
    fig = px.violin(
        sub,
        x="pathway",
        y="mean_rank",
        box=True,
        points=False,
        color="pathway",
        color_discrete_map=v_cmap,
        labels={"mean_rank": "Mean rank (lower = stronger model focus)", "pathway": "Pathway"},
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        showlegend=False,
        height=420,
        xaxis_tickangle=-32,
        margin=dict(l=48, r=24, t=48, b=140),
        title="How joint model rank spreads within high-coverage pathways",
    )
    return fig
def flux_reaction_annotation_panel(flux_df: pd.DataFrame, top_n: int = 26, metric: str = "mean_rank") -> go.Figure:
    """Three heatmap columns: pathway (categorical), DE Log₂FC, −log₁₀ adjusted p.

    Rows are the reactions returned by `_flux_prepare_top_ranked(flux_df, top_n, metric)`.
    Each row sits at an integer y position and is labelled through
    `tickvals`/`ticktext`, so two reactions whose display labels come out
    identical still get separate rows.

    Args:
        flux_df: Reaction table passed to `_flux_prepare_top_ranked`.
        top_n: Forwarded to `_flux_prepare_top_ranked`.
        metric: Forwarded to `_flux_prepare_top_ranked`.

    Returns:
        A Plotly figure; an empty `go.Figure()` when no rows are selected.
    """
    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if top.empty:
        return go.Figure()
    n = len(top)

    # Pathway column: integer codes in first-seen order, one colour stop per code.
    if "pathway" in top.columns:
        pathways = top["pathway"].fillna("Unknown").astype(str).tolist()
    else:
        pathways = ["Unknown"] * n
    uniq = list(dict.fromkeys(pathways))
    code_map = {u: i for i, u in enumerate(uniq)}
    codes = np.array([code_map[p] for p in pathways], dtype=float)
    k = max(len(uniq), 1)
    qual = list(px.colors.qualitative.Safe) + list(px.colors.qualitative.Dark24) + list(px.colors.qualitative.Light24)
    if k <= 1:
        disc_scale = [[0, qual[0]], [1, qual[0]]]
    else:
        disc_scale = [[j / (k - 1), qual[j % len(qual)]] for j in range(k)]

    log_fc = top["log_fc"].fillna(0).astype(float).to_numpy() if "log_fc" in top.columns else np.zeros(n)
    # Missing significance is shown as 0 in every branch (p = 1 ⇒ −log₁₀ p = 0).
    if "pval_adj_log" in top.columns:
        pv = top["pval_adj_log"].fillna(0).astype(float).to_numpy()
    elif "pval_adj" in top.columns:
        p_adj = top["pval_adj"].astype(float).fillna(1.0).clip(lower=1e-300).to_numpy()
        # "+ 0.0" turns the -0.0 produced for p = 1 into 0.0 so hover text does not read "-0.00".
        pv = -np.log10(p_adj) + 0.0
    else:
        pv = np.zeros(n)

    full_features = top["feature"].astype(str).tolist()
    y_labels = [_truncate_label(str(f), 44) for f in full_features]
    y_pos = list(range(n))
    z_path = codes.reshape(-1, 1)
    # hovertext (not customdata): subplot heatmaps often render %{customdata[0]} as "-" in the browser.
    hover_path = [[f"{fn}<br>pathway: {pw}"] for fn, pw in zip(full_features, pathways)]
    hover_lfc = [
        [f"{fn}<br>{LABEL_LOG2FC}: {float(log_fc[i]):.4f}"]
        for i, fn in enumerate(full_features)
    ]
    hover_pv = [
        [f"{fn}<br>{LABEL_NEG_LOG10_ADJ_P}: {float(pv[i]):.2f}"]
        for i, fn in enumerate(full_features)
    ]

    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        horizontal_spacing=0.06,
        column_widths=[0.24, 0.24, 0.24],
    )
    fig.add_trace(
        go.Heatmap(
            z=z_path,
            x=[""],
            y=y_pos,
            colorscale=disc_scale,
            zmin=0,
            # Keep zmax > zmin when only one pathway is present; both stops are
            # the same colour in that case, so the rendered colour is unchanged.
            zmax=max(k - 1, 1),
            showscale=False,
            hovertext=hover_path,
            hovertemplate="%{hovertext}",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Heatmap(
            z=log_fc.reshape(-1, 1),
            x=[""],
            y=y_pos,
            colorscale=LOG_FC_DIVERGING_SCALE,
            zmin=LOG_FC_COLOR_MIN,
            zmax=LOG_FC_COLOR_MAX,
            showscale=True,
            colorbar=dict(
                title=dict(text=LABEL_LOG2FC, side="right"),
                tickformat=".2f",
                len=0.22,
                y=0.71,
                yanchor="middle",
                x=1.0,
                xanchor="left",
                xref="paper",
                yref="paper",
                thickness=12,
            ),
            hovertext=hover_lfc,
            hovertemplate="%{hovertext}",
        ),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Heatmap(
            z=pv.reshape(-1, 1),
            x=[""],
            y=y_pos,
            colorscale="Viridis",
            showscale=True,
            colorbar=dict(
                title=dict(text=LABEL_NEG_LOG10_ADJ_P, side="right"),
                len=0.22,
                y=0.29,
                yanchor="middle",
                x=1.0,
                xanchor="left",
                xref="paper",
                yref="paper",
                thickness=12,
            ),
            hovertext=hover_pv,
            hovertemplate="%{hovertext}",
        ),
        row=1,
        col=3,
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(820, 120 + n * 22),
        width=900,
        margin=dict(l=8, r=108, t=56, b=72),
        title=dict(
            text=f"Pathway, {LABEL_LOG2FC}, and significance",
            x=0,
            xanchor="left",
            y=0.995,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=8, l=4),
        ),
    )
    fig.update_xaxes(side="bottom", title_standoff=8)
    fig.update_xaxes(title_text="Pathway", row=1, col=1)
    fig.update_xaxes(title_text=LABEL_LOG2FC, row=1, col=2)
    fig.update_xaxes(title_text=LABEL_NEG_LOG10_ADJ_P, row=1, col=3)
    # Label the integer row positions with the display names.
    fig.update_yaxes(
        tickmode="array",
        tickvals=y_pos,
        ticktext=y_labels,
        autorange="reversed",
    )
    return fig
def flux_model_metric_profile(flux_df: pd.DataFrame, top_n: int = 22, metric: str = "mean_rank") -> go.Figure:
    """Matrix view: scaled shift, attention, model priority, and fate flux contrast.

    Each available metric is min-max scaled to 0-1 across the reactions returned
    by `_flux_prepare_top_ranked(flux_df, top_n, metric)` and drawn as one heatmap
    row. Reactions sit at integer x positions and are labelled through
    `tickvals`/`ticktext`, so two reactions whose display labels come out
    identical still get separate columns.

    Args:
        flux_df: Reaction table passed to `_flux_prepare_top_ranked`.
        top_n: Forwarded to `_flux_prepare_top_ranked`.
        metric: Forwarded to `_flux_prepare_top_ranked`.

    Returns:
        A Plotly figure; an empty `go.Figure()` when no rows are selected or none
        of the metric columns is present.
    """
    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if top.empty:
        return go.Figure()

    def mm(s: pd.Series) -> np.ndarray:
        """Min-max scale to 0-1; constant, all-NaN or non-finite-minimum input gives zeros."""
        v = s.astype(float).to_numpy()
        if np.isnan(v).all():
            # Checked first so np.nanmin / np.nanmax do not warn on an all-NaN column.
            return np.zeros_like(v, dtype=float)
        lo, hi = float(np.nanmin(v)), float(np.nanmax(v))
        if hi <= lo or not np.isfinite(lo):
            return np.zeros_like(v, dtype=float)
        return (v - lo) / (hi - lo)

    cols: list[np.ndarray] = []
    labels: list[str] = []
    for c, lab in (("importance_shift", "Latent shift impact"), ("importance_att", "Attention (rollout)")):
        if c in top.columns:
            cols.append(mm(top[c]))
            labels.append(lab)
    if "mean_rank" in top.columns:
        # Inverted so that a low (better) rank maps to a high value.
        cols.append(1.0 - mm(top["mean_rank"]))
        labels.append("Joint priority (1 - scaled mean rank)")
    if "mean_de" in top.columns and "mean_re" in top.columns:
        # Zero denominators become NaN and the resulting ratio is then set to 0.
        de = top["mean_de"].astype(float).replace(0, np.nan)
        ratio = (top["mean_re"].astype(float) / (de + 1e-12)).fillna(0)
        cols.append(mm(ratio))
        labels.append("RE / DE mean flux (scaled)")
    if not cols:
        return go.Figure()

    z = np.column_stack(cols)
    full_rxn = top["feature"].astype(str).tolist()
    x_labels = [_truncate_label(str(f), 34) for f in full_rxn]
    n_met, n_rxn = z.T.shape
    x_pos = list(range(n_rxn))
    fig = px.imshow(
        z.T,
        x=x_pos,
        y=labels,
        aspect="auto",
        color_continuous_scale="Tealrose",
        labels=dict(x="Reaction", y="Metric", color="Scaled 0-1 per metric"),
    )
    # Hover shows the untruncated reaction name, repeated for every metric row.
    hover_cd = np.broadcast_to(np.array(full_rxn, dtype=object), (n_met, n_rxn))
    fig.update_traces(
        customdata=hover_cd,
        hovertemplate="%{customdata}<br>%{y}<br>scaled: %{z:.3f}",
    )
    fig.update_xaxes(
        tickmode="array",
        tickvals=x_pos,
        ticktext=x_labels,
        tickangle=-50,
        side="bottom",
        title_standoff=12,
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(380, 140 + len(labels) * 36),
        margin=dict(l=200, r=28, t=64, b=200),
        title=dict(
            text="Reaction profile",
            x=0,
            xanchor="left",
            y=0.98,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=10, l=4),
        ),
    )
    return fig