kaveh's picture
changed CV to cross validation and increased font size in first page
34cacad
"""Plotly helpers for the explorer UI."""
from __future__ import annotations
import html
from typing import Any
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from streamlit_hf.lib.reactions import normalize_reaction_key
# Shared Plotly font (Inter with system fallbacks); reused by the layouts in this module.
PLOT_FONT = dict(family="Inter, system-ui, sans-serif", size=12)

# Paper/plot background: identical to the app page and the plotly_white paper, so figures
# are not tinted relative to the page.
PAGE_BG = "#ffffff"

# General-purpose categorical palette (blue, red, green, amber, violet, pink, teal, indigo).
PALETTE = (
    "#2563eb", "#dc2626", "#059669", "#d97706",
    "#7c3aed", "#db2777", "#0d9488", "#4f46e5",
)

# Per-modality colours used by bars and scatter plots.
MODALITY_COLOR = {"RNA": "#E64B35", "ATAC": "#4DBBD5", "Flux": "#00A087"}

# Colours for the global modality pie only: an independent copy, so pie hues can be tried
# here without touching bars/scatter (which read MODALITY_COLOR).
MODALITY_PIE_COLOR = {**MODALITY_COLOR}

# Log₂FC heatmaps/sunburst: diverging scale in the spirit of ggplot2's
# scale_colour_gradient2 — blue below zero, near-white (#FAFAFA) at zero, red above.
LOG_FC_COLOR_MIN = -0.5
LOG_FC_COLOR_MAX = 0.5
LOG_FC_DIVERGING_SCALE: list[list] = [
    [0.0, "#1C86EE"],
    [0.5, "#FAFAFA"],
    [1.0, "#FF0000"],
]

# Axis/colorbar labels spelled with a true minus sign (U+2212) and subscript digits ₁₀ / ₂.
LABEL_NEG_LOG10_ADJ_P = "\u2212log\u2081\u2080 adj. p"
LABEL_LOG2FC = "Log\u2082FC"

# Display modality -> key used by the cached attention dict (which is lowercase).
FI_ATT_MOD_KEY = {"RNA": "rna", "ATAC": "atac", "Flux": "flux"}

# The model appends one batch-embedding token per modality; these names are hidden from
# attention rankings in the UI.
BATCH_EMBEDDING_FEATURE_NAMES = frozenset(("batch_rna", "batch_atac", "batch_flux"))
def _attention_pairs_skip_batch(pairs: list) -> list:
    """Remove (name, score) pairs whose name is a batch-embedding token.

    Names are compared as ``str(name)`` against BATCH_EMBEDDING_FEATURE_NAMES; the kept pairs
    retain their original order and their original (un-stringified) name objects.
    """
    kept = []
    for name, score in pairs:
        if str(name) in BATCH_EMBEDDING_FEATURE_NAMES:
            continue
        kept.append((name, score))
    return kept
def rollout_top_features_table(feature_names, vec, top_n: int) -> pd.DataFrame:
    """Top `top_n` rollout weights per modality slice, excluding batch-embedding tokens.

    ``feature_names`` and ``vec`` are read in parallel (name ``i`` pairs with ``vec[i]``).
    Returns a frame with columns ``feature`` (str) and ``mean_attention`` (float), sorted by
    weight from highest to lowest; an empty frame with those columns when nothing is left.
    """
    labels = list(map(str, feature_names))
    weights = np.asarray(vec, dtype=float)
    candidates = []
    for idx, label in enumerate(labels):
        if label in BATCH_EMBEDDING_FEATURE_NAMES:
            continue
        candidates.append((label, float(weights[idx])))
    # Negated key (not reverse=True) keeps the exact tie behaviour of an ascending stable sort.
    ranked = sorted(candidates, key=lambda item: -item[1])[:top_n]
    if not ranked:
        return pd.DataFrame(columns=["feature", "mean_attention"])
    return pd.DataFrame(
        {
            "feature": [label for label, _ in ranked],
            "mean_attention": [weight for _, weight in ranked],
        }
    )
# Continuous scale for the dominant-fate % on the UMAP: pale blue (low) through to navy (high).
UMAP_PCT_COLORSCALE: list[list] = [
    [0.0, "#eff6ff"],
    [0.25, "#bfdbfe"],
    [0.55, "#3b82f6"],
    [0.82, "#2563eb"],
    [1.0, "#1e3a8a"],
]

# Okabe–Ito colour-blind-friendly hues (eight distinct colours) for categorical UMAP colourings.
LATENT_DISCRETE_PALETTE = (
    "#0072B2", "#E69F00", "#009E73", "#CC79A7",
    "#56B4E9", "#D55E00", "#F0E442", "#000000",
)
def latent_scatter(
    df,
    color_col: str,
    title: str,
    width: int = 720,
    height: int = 520,
    marker_size: float = 5.0,
    marker_opacity: float = 0.78,
    subtitle: str | None = None,
):
    """Scatter the UMAP embedding (``umap_x`` / ``umap_y``) coloured by ``color_col``.

    Known metadata columns that exist in ``df`` are added to the hover box (the raw
    ``modality`` code is hidden whenever ``modality_label`` is available) and get friendly
    labels. ``fold`` and ``batch_no`` are recoloured as string categories and ``correct`` as
    "Correct"/"Wrong"; ``pct`` uses UMAP_PCT_COLORSCALE; any other column is passed to Plotly
    Express together with LATENT_DISCRETE_PALETTE. ``subtitle`` is only rendered alongside a
    non-empty ``title``; an empty ``title`` clears the figure title. Axes show no ticks, grid
    or titles. Works on a copy, so ``df`` is not modified.
    """
    frame = df.copy()
    hide_raw_modality = "modality_label" in frame.columns

    # Hover fields in display order; each value is True or a d3 number format.
    hover_formats = {
        "umap_x": ":.3f",
        "umap_y": ":.3f",
        "dataset_idx": True,
        "fold": True,
        "batch_no": True,
        "predicted_class": True,
        "label": True,
        "correct": True,
        "pct": ":.2f",
        "modality_label": True,
        "modality": True,
        "predicted_value": ":.3f",
        "clone_id": True,
        "clone_size": True,
        "cell_type": True,
    }
    hover_data = {
        key: fmt
        for key, fmt in hover_formats.items()
        if key in frame.columns and not (hide_raw_modality and key == "modality")
    }

    display_names = {
        "label": "CellTag-Multi label",
        "predicted_class": "Predicted fate",
        "pct": "Dominant fate (%)",
        "modality_label": "Available modalities",
        "dataset_idx": "Dataset index",
        "batch_no": "Batch",
        "fold": "Cross Validation fold",
    }
    labels_map = {col: name for col, name in display_names.items() if col in frame.columns}

    # Columns recoloured as discrete categories: (legend title, recoding of the raw values).
    recolourings = {
        "fold": ("Fold", lambda col: col.astype(str)),
        "batch_no": ("Batch", lambda col: col.astype(str)),
        "correct": ("Prediction", lambda col: col.map({True: "Correct", False: "Wrong"})),
    }
    if color_col in recolourings:
        legend_title, recode = recolourings[color_col]
        frame["_color"] = recode(frame[color_col])
        labels_map["_color"] = legend_title
        color_arg = "_color"
        use_continuous = False
    else:
        color_arg = color_col
        use_continuous = color_col == "pct"

    has_title = bool(title)
    with_subtitle = has_title and bool(subtitle)

    scatter_kwargs = dict(
        x="umap_x",
        y="umap_y",
        hover_data=hover_data,
        labels=labels_map,
        width=width,
        height=height,
    )
    # Plotly Express turns title="" into a visible "undefined" title in some versions, so a title
    # is only passed here when present and without a subtitle (title+subtitle go via update_layout).
    if has_title and not subtitle:
        scatter_kwargs["title"] = title
    if use_continuous:
        scatter_kwargs["color_continuous_scale"] = UMAP_PCT_COLORSCALE
    else:
        scatter_kwargs["color_discrete_sequence"] = list(LATENT_DISCRETE_PALETTE)
    fig = px.scatter(frame, color=color_arg, **scatter_kwargs)
    fig.update_traces(
        marker=dict(
            size=marker_size,
            opacity=marker_opacity,
            line=dict(width=0.25, color="rgba(255,255,255,0.4)"),
        )
    )

    if with_subtitle:
        top_margin = 88
    elif has_title:
        top_margin = 56
    else:
        top_margin = 28
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title_font_size=16,
        margin=dict(l=28, r=20, t=top_margin, b=28),
        legend_title_text="",
        xaxis_title="",
        yaxis_title="",
        paper_bgcolor=PAGE_BG,
        plot_bgcolor=PAGE_BG,
    )
    if with_subtitle:
        # Title + subtitle need layout.title.subtitle (recent Plotly releases).
        fig.update_layout(
            title=dict(
                text=title,
                x=0.5,
                xanchor="center",
                font=dict(size=16, family=PLOT_FONT["family"]),
                subtitle=dict(
                    text=subtitle,
                    font=dict(size=11, color="#64748b", family=PLOT_FONT["family"]),
                ),
            ),
        )
    elif not has_title:
        fig.update_layout(title=None)
    for update_axis in (fig.update_xaxes, fig.update_yaxes):
        update_axis(showticklabels=False, showgrid=False, zeroline=False)
    return fig
def rank_scatter_shift_vs_attention(df_mod, modality: str, width: int = 420, height: int = 440):
    """Attention rank on x, shift rank on y, least-squares trend, colours by top ~10% within this modality.

    A feature counts as "top" on an axis when its rank is <= ceil(0.1 * n) (at least 1), where
    n is the number of rows having both ranks; points are tiered Both / Shift / Attention / Neither.
    A dashed least-squares line is added when there are >= 2 points and the attention ranks vary.
    ``feature`` (hover title) and ``mean_rank`` / ``importance_shift`` / ``importance_att``
    (hover fields) are used when present and skipped when absent.

    Returns an empty figure if either rank column is missing or no row has both ranks.
    """
    need = ("shift_order_mod", "attention_order_mod")
    if not all(c in df_mod.columns for c in need):
        return go.Figure()
    sub = df_mod.dropna(subset=list(need)).copy()
    if sub.empty:
        return go.Figure()
    n = len(sub)
    top_k = max(1, int(np.ceil(0.1 * n)))
    s_ok = sub["shift_order_mod"].astype(int) <= top_k
    a_ok = sub["attention_order_mod"].astype(int) <= top_k
    sub["_tier_label"] = np.where(
        s_ok & a_ok,
        "Both",
        np.where(s_ok, "Shift", np.where(a_ok, "Attention", "Neither")),
    )
    x = sub["attention_order_mod"].astype(float).to_numpy()
    y = sub["shift_order_mod"].astype(float).to_numpy()
    # Plotly Express raises for hover columns that are not in the frame, so only request
    # the optional ones this table actually carries.
    hover_fields = {"mean_rank": True, "importance_shift": ":.4f", "importance_att": ":.4f"}
    hover_data = {k: v for k, v in hover_fields.items() if k in sub.columns}
    hover_name = "feature" if "feature" in sub.columns else None
    fig = px.scatter(
        sub,
        x="attention_order_mod",
        y="shift_order_mod",
        color="_tier_label",
        hover_name=hover_name,
        hover_data=hover_data or None,
        labels={
            "attention_order_mod": "Attention rank",
            "shift_order_mod": "Shift rank",
            "_tier_label": "Top-10% tier",
        },
        category_orders={"_tier_label": ["Both", "Shift", "Attention", "Neither"]},
        width=width,
        height=height,
        color_discrete_map={
            "Both": PALETTE[0],
            "Shift": PALETTE[1],
            "Attention": PALETTE[2],
            "Neither": "#94a3b8",
        },
    )
    fig.update_traces(marker=dict(size=7, opacity=0.62, line=dict(width=0.5, color="rgba(15,23,42,0.28)")))
    # polyfit needs at least two distinct x values; skip the trend otherwise.
    if len(x) >= 2 and float(np.ptp(x)) > 0:
        coef = np.polyfit(x, y, 1)
        poly = np.poly1d(coef)
        xs = np.linspace(float(np.min(x)), float(np.max(x)), 100)
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=poly(xs),
                mode="lines",
                name=f"y = {coef[0]:.2f}x + {coef[1]:.2f}",
                line=dict(color="#2563eb", width=2, dash="dash"),
                showlegend=True,
            )
        )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(
            text=f"{modality}: shift vs attention (ranks)",
            x=0.5,
            xanchor="center",
            y=0.98,
            yanchor="top",
            font=dict(size=14, family=PLOT_FONT["family"]),
        ),
        margin=dict(l=48, r=20, t=52, b=72),
        legend=dict(
            title=dict(text="Among top 10% features?"),
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5,
        ),
    )
    return fig
def _truncate_label(s: str, max_len: int = 36) -> str:
s = str(s)
return s if len(s) <= max_len else s[: max_len - 1] + "…"
def joint_shift_attention_top_features(df_mod, modality: str, top_n: int):
    """
    Top features by mean_rank (lowest = strongest joint shift+attention ranking).
    Shift and attention importances are min-max scaled within this top-N slice for side-by-side comparison.

    Draws grouped horizontal bars (shift in the modality colour, attention in slate) with the
    strongest feature at the top; an importance that is constant within the slice is drawn at 0.5.
    The title reports how many features are actually shown, which can be fewer than ``top_n``.
    Returns an empty figure if a required column is missing or no rows are selected.
    """
    need = ("mean_rank", "importance_shift", "importance_att", "feature")
    if not all(c in df_mod.columns for c in need):
        return go.Figure()
    sub = df_mod.nsmallest(top_n, "mean_rank").copy()
    if sub.empty:
        return go.Figure()

    def _mm(s: pd.Series) -> pd.Series:
        # Min-max scale to [0, 1]; a constant series maps to 0.5 instead of dividing by zero.
        lo, hi = float(s.min()), float(s.max())
        if hi <= lo:
            return pd.Series(0.5, index=s.index)
        return (s.astype(float) - lo) / (hi - lo)

    sub["_zs"] = _mm(sub["importance_shift"])
    sub["_za"] = _mm(sub["importance_att"])
    # Best (lowest mean_rank) at top of chart; matches shift/attention rows below.
    sub = sub.sort_values("mean_rank", ascending=True)
    feats_full = sub["feature"].astype(str)
    y_disp = feats_full.map(lambda s: _truncate_label(s, 40))
    base = MODALITY_COLOR.get(modality, PALETTE[0])
    att_c = "#475569" if base != "#475569" else "#64748b"
    # Left margin grows with the longest (truncated) label.
    margin_l = int(min(380, 64 + 5.8 * max((len(t) for t in y_disp), default=10)))
    # Floor keeps a usable plot area for a handful of rows: top+bottom margins alone are 116 px.
    h = max(240, min(720, 52 + 22 * len(sub)))
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            name="Shift (scaled)",
            y=y_disp,
            x=sub["_zs"],
            orientation="h",
            marker_color=base,
            customdata=feats_full,
            hovertemplate="<b>%{customdata}</b><br>Shift (scaled): %{x:.3f}<extra></extra>",
        )
    )
    fig.add_trace(
        go.Bar(
            name="Attention (scaled)",
            y=y_disp,
            x=sub["_za"],
            orientation="h",
            marker_color=att_c,
            customdata=feats_full,
            hovertemplate="<b>%{customdata}</b><br>Attention (scaled): %{x:.3f}<extra></extra>",
        )
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(
            text=f"{modality} · top {len(sub)}",
            x=0.5,
            xanchor="center",
            y=0.98,
            yanchor="top",
            font=dict(size=14, family=PLOT_FONT["family"]),
        ),
        barmode="group",
        bargap=0.15,
        bargroupgap=0.05,
        width=680,
        height=h,
        margin=dict(l=margin_l, r=12, t=44, b=72),
        xaxis_title="Scaled 0-1 within selection",
        yaxis_title="",
        legend=dict(orientation="h", yanchor="top", y=-0.14, xanchor="center", x=0.5),
    )
    fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
    return fig
def modality_shift_attention_rank_stats(df_mod) -> dict[str, Any]:
    """Pearson / Spearman between per-modality shift and attention ordinal ranks.

    Rows missing ``shift_order_mod`` or ``attention_order_mod`` are ignored. Returns
    ``{"n": 0}`` if either column is absent and ``{"n": n}`` when fewer than three complete
    rows remain. Otherwise returns ``n`` plus ``pearson_r``, ``pearson_p``, ``spearman_r`` and
    ``spearman_p`` as floats, with attention rank as the first variable.
    """
    from scipy.stats import pearsonr, spearmanr

    rank_cols = ["shift_order_mod", "attention_order_mod"]
    if any(col not in df_mod.columns for col in rank_cols):
        return {"n": 0}
    complete = df_mod.dropna(subset=rank_cols)
    n = len(complete)
    if n < 3:
        return {"n": n}
    attention = complete["attention_order_mod"].astype(float)
    shift = complete["shift_order_mod"].astype(float)
    stats: dict[str, Any] = {"n": n}
    for prefix, corr in (("pearson", pearsonr), ("spearman", spearmanr)):
        coef, pval = corr(attention, shift)
        stats[f"{prefix}_r"] = float(coef)
        stats[f"{prefix}_p"] = float(pval)
    return stats
def rank_bar(
    df_top,
    xcol: str,
    ycol: str,
    title: str,
    color: str = PALETTE[0],
    xaxis_title: str | None = None,
):
    """Horizontal bar chart of ``df_top[xcol]`` against the labels in ``df_top[ycol]``.

    Rows are sorted ascending by ``xcol``; with Plotly's default category axis the first row is
    drawn at the bottom, so the largest value ends up on top. Labels are truncated to 42
    characters (full text shown on hover) and the left margin grows with label length.
    ``xaxis_title`` defaults to ``xcol`` with underscores replaced by spaces.
    """
    d = df_top.sort_values(xcol, ascending=True)
    y_raw = d[ycol].astype(str)
    y_show = y_raw.map(lambda s: _truncate_label(s, 42))
    margin_l = int(min(420, 80 + 5.8 * max((len(s) for s in y_show), default=12)))
    fig = go.Figure(
        go.Bar(
            y=y_show,
            x=d[xcol],
            orientation="h",
            marker_color=color,
            customdata=y_raw,
            hovertemplate="<b>%{customdata}</b><br>%{x:.4g}<extra></extra>",
        )
    )
    xt = xaxis_title if xaxis_title is not None else xcol.replace("_", " ")
    # Floor keeps a usable plot area for very few rows: top+bottom margins alone are 88 px.
    fig_height = max(180, min(620, 38 + 20 * len(d)))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=title,
        width=680,
        height=fig_height,
        margin=dict(l=margin_l, r=24, t=48, b=40),
        xaxis_title=xt,
        yaxis_title="",
    )
    fig.update_yaxes(tickfont=dict(size=10))
    return fig
def attention_top_comparison(fi_lists: dict, modality: str, top_n: int = 18):
    """fi_lists: cohort -> {rna|atac|flux: [(name, score), ...]}.

    Draws grouped horizontal bars of the first ``top_n`` (name, score) pairs per cohort
    ("all", "dead_end", "reprogramming") after removing batch-embedding tokens. Cohorts with
    nothing to show are omitted; if none remain, the figure carries a "re-run precompute" note.
    Labels longer than 52 characters are cut to 52 plus "…". The modality is mapped to its
    cache key via FI_ATT_MOD_KEY, falling back to ``str(modality).lower()``.
    """
    cache_key = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
    cohort_specs = (
        ("all", "All validation samples", PALETTE[0]),
        ("dead_end", "Predicted dead-end", PALETTE[1]),
        ("reprogramming", "Predicted reprogramming", PALETTE[2]),
    )
    bars = []
    for cohort_key, legend_name, bar_colour in cohort_specs:
        per_modality = fi_lists.get(cohort_key) or {}
        pairs = _attention_pairs_skip_batch(list(per_modality.get(cache_key, [])))[:top_n]
        if not pairs:
            continue
        weights = [score for _, score in pairs]
        labels = [name[:52] + ("…" if len(name) > 52 else "") for name, _ in pairs]
        bars.append(
            go.Bar(
                name=legend_name,
                x=weights,
                y=labels,
                orientation="h",
                marker_color=bar_colour,
            )
        )
    fig = go.Figure(bars)
    # Height budget: 22 px per row (capped at 20 rows) for each cohort trace, floor 320 px.
    row_budget = min(top_n, 20) * 22 * max(1, len(bars))
    fig.update_layout(
        barmode="group",
        template="plotly_white",
        font=PLOT_FONT,
        title=f"Top attention (rollout): {modality}",
        width=520,
        height=max(320, 36 + row_budget),
        margin=dict(l=220, r=24, t=56, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    if bars:
        # First (highest-ranked) pair at the top.
        fig.update_yaxes(autorange="reversed")
    else:
        fig.update_layout(
            annotations=[
                dict(
                    text="No attention list for this modality (re-run precompute).",
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=0.5,
                    showarrow=False,
                )
            ]
        )
    return fig
def attention_cohort_view(
    fi_lists: dict,
    modality: str,
    top_n: int,
    mode: str,
):
    """
    mode: 'compare': grouped bars for all three cohorts;
    'all' | 'dead_end' | 'reprogramming': single cohort only.

    In single-cohort mode, the first ``top_n`` non-batch (name, score) pairs for the modality
    are drawn as horizontal bars (first pair on top; labels cut at 52 characters plus "…").
    Any other ``mode`` is looked up as a cohort key and shown verbatim in the title. A cohort
    without pairs yields a figure carrying a "No items for this cohort." note.
    """
    if mode == "compare":
        return attention_top_comparison(fi_lists, modality, top_n)
    cohort_titles = {
        "all": "All validation samples",
        "dead_end": "Predicted dead-end",
        "reprogramming": "Predicted reprogramming",
    }
    cache_key = FI_ATT_MOD_KEY.get(modality, str(modality).lower())
    per_modality = fi_lists.get(mode) or {}
    pairs = _attention_pairs_skip_batch(list(per_modality.get(cache_key, [])))[:top_n]
    heading = f"{modality} · {cohort_titles.get(mode, mode)}"

    if not pairs:
        empty = go.Figure()
        empty.update_layout(
            template="plotly_white",
            font=PLOT_FONT,
            title=heading,
            annotations=[
                dict(
                    text="No items for this cohort.",
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=0.5,
                    showarrow=False,
                )
            ],
        )
        return empty

    weights = [score for _, score in pairs]
    labels = [name[:52] + ("…" if len(name) > 52 else "") for name, _ in pairs]
    fig = go.Figure(
        go.Bar(
            x=weights,
            y=labels,
            orientation="h",
            marker_color=PALETTE[0],
        )
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=heading,
        width=520,
        # 20 px per row (capped at 25 rows) with a 280 px floor.
        height=max(280, 40 + min(top_n, 25) * 20),
        margin=dict(l=220, r=24, t=56, b=40),
        xaxis_title="Attention weight",
    )
    fig.update_yaxes(autorange="reversed")
    return fig
def _pie_hover_feature_lines(names: list[str], *, names_per_line: int = 5) -> str:
"""Join feature names with commas; start a new hover row every ``names_per_line`` items (HTML ``<br>``)."""
if not names:
return "—"
safe = [html.escape(str(n), quote=False) for n in names]
step = max(1, int(names_per_line))
lines: list[str] = []
for i in range(0, len(safe), step):
lines.append(", ".join(safe[i : i + step]))
return "<br>".join(lines)
def global_rank_triple_panel(
    df_features,
    top_n: int = 20,
    top_n_pie: int = 100,
    *,
    chart_outline: bool = True,
    modality_mix_hole: float = 0.0,
    modality_mix_hover_feature_list: bool = False,
):
    """
    Global top-N by latent-shift and by attention (min-max scaled), plus pie or donut of modality mix
    among the top `top_n_pie` features by mean rank.
    Set ``chart_outline=False`` for a flatter look (e.g. home page); Feature Insights keeps outlines by default.
    Set ``modality_mix_hole`` in (0, 1), e.g. ``0.66``, for a donut instead of a full pie (e.g. home page).
    Set ``modality_mix_hover_feature_list=True`` to show comma-separated feature names per donut slice on hover
    (same pool as the pie: strongest by mean rank within each modality), wrapped every few names for readability.

    Columns read from ``df_features``: ``feature``, ``modality``, ``importance_shift``,
    ``importance_att`` and ``mean_rank`` (lower = stronger). Works on a copy, so
    ``df_features`` is not modified. Only the modalities "RNA", "ATAC" and "Flux" are counted
    in the pie.

    Returns a 1x3 subplot figure (shift bars, attention bars, modality pie/donut) with a
    horizontal modality colour key below the panels.
    """
    d = df_features.copy()
    # Min-max scale each importance over the whole table (not just the top-N rows);
    # a constant column scales to 0.0.
    for col in ("importance_shift", "importance_att"):
        min_v, max_v = d[col].min(), d[col].max()
        if max_v > min_v:
            d[col + "_norm"] = (d[col] - min_v) / (max_v - min_v)
        else:
            d[col + "_norm"] = 0.0
    # Bar rows are picked by raw importance (nlargest returns them in descending order).
    shift_top = d.nlargest(top_n, "importance_shift")
    att_top = d.nlargest(top_n, "importance_att")
    # Pie pool: the `top_n_pie` strongest features overall (smallest mean_rank).
    pie_pool = d.nsmallest(top_n_pie, "mean_rank")
    fig = make_subplots(
        rows=1,
        cols=3,
        column_widths=[0.36, 0.36, 0.28],
        specs=[[{}, {}, {"type": "domain"}]],
        subplot_titles=(
            f"Top {top_n} by latent shift (ranked)",
            f"Top {top_n} by attention (ranked)",
            f"Top {top_n_pie} by mean rank (modality mix)",
        ),
        horizontal_spacing=0.06,
    )
    # Outline styles for bars, pie slices and legend swatches; width=0 hides the outline.
    bar_outline = dict(color="#1e293b", width=1.2) if chart_outline else dict(width=0)
    pie_line = dict(color="#1e293b", width=1.2) if chart_outline else dict(width=0)
    leg_line = dict(width=1.2, color="#1e293b") if chart_outline else dict(width=0)
    # Panel 1: scaled shift, bars coloured by modality (unknown modality -> slate).
    fig.add_trace(
        go.Bar(
            x=shift_top["importance_shift_norm"],
            y=shift_top["feature"],
            orientation="h",
            marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in shift_top["modality"]],
            marker_line=bar_outline,
            showlegend=False,
            hovertemplate="%{y}<br>scaled shift: %{x:.3f}<extra></extra>",
        ),
        row=1,
        col=1,
    )
    # Panel 2: scaled attention, same colouring.
    fig.add_trace(
        go.Bar(
            x=att_top["importance_att_norm"],
            y=att_top["feature"],
            orientation="h",
            marker_color=[MODALITY_COLOR.get(m, "#64748b") for m in att_top["modality"]],
            marker_line=bar_outline,
            showlegend=False,
            hovertemplate="%{y}<br>scaled attention: %{x:.3f}<extra></extra>",
        ),
        row=1,
        col=2,
    )
    # Panel 3: slice values follow this fixed modality order; other modality values are not counted.
    pie_labels = ["RNA", "ATAC", "Flux"]
    counts = pie_pool["modality"].value_counts()
    pie_vals = [int(counts.get(lab, 0)) for lab in pie_labels]
    # NOTE(review): with no RNA/ATAC/Flux rows in the pool this falls back to an equal
    # three-way split so the pie still renders; the slices then do not reflect data.
    if sum(pie_vals) == 0:
        pie_vals = [1, 1, 1]
    # Falsy or non-positive hole values mean a full pie.
    _hole = float(modality_mix_hole) if modality_mix_hole and modality_mix_hole > 0 else 0.0
    # Narrow third subplot: "auto" avoids clipped outside labels on donuts.
    _pie_textpos = "auto"
    _pie_kwargs: dict = dict(
        labels=pie_labels,
        values=pie_vals,
        marker=dict(
            colors=[MODALITY_PIE_COLOR.get(l, "#64748b") for l in pie_labels],
            line=pie_line,
        ),
        textinfo="label+percent",
        textfont_size=12,
        textposition=_pie_textpos,
        hole=_hole,
        showlegend=False,
    )
    if modality_mix_hover_feature_list:
        # One hover string per slice (in pie_labels order) listing that modality's pool
        # features from strongest to weakest; "—" when the modality has none.
        _hover_texts: list[str] = []
        for lab in pie_labels:
            sub = pie_pool[pie_pool["modality"] == lab]
            if sub.empty:
                _hover_texts.append("—")
            else:
                # Stable sort keeps input order among equal mean_rank values.
                sub = sub.sort_values("mean_rank", ascending=True, kind="mergesort")
                # Flux lists one name per hover row, other modalities five.
                # NOTE(review): presumably because flux reaction names are long; confirm.
                _per_line = 1 if lab == "Flux" else 5
                _hover_texts.append(
                    _pie_hover_feature_lines(sub["feature"].astype(str).tolist(), names_per_line=_per_line)
                )
        _pie_kwargs["hovertext"] = _hover_texts
        _pie_kwargs["hovertemplate"] = (
            "<b>%{label}</b> · %{value} features (%{percent:.1%})<br><br>%{hovertext}<extra></extra>"
        )
    fig.add_trace(
        go.Pie(**_pie_kwargs),
        row=1,
        col=3,
    )
    # Modality legend (bar colours + pie segment colours): invisible markers in first subplot only.
    # NOTE(review): swatches use MODALITY_PIE_COLOR while bars use MODALITY_COLOR; they agree
    # only while the pie palette is left identical to MODALITY_COLOR.
    for _name, _col in (
        ("RNA (transcriptome)", MODALITY_PIE_COLOR["RNA"]),
        ("ATAC (chromatin)", MODALITY_PIE_COLOR["ATAC"]),
        ("Flux (metabolism)", MODALITY_PIE_COLOR["Flux"]),
    ):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=12, color=_col, symbol="square", line=leg_line),
                name=_name,
                showlegend=True,
                hoverinfo="skip",
            ),
            row=1,
            col=1,
        )
    fig.update_xaxes(title_text="Min-max scaled shift", row=1, col=1)
    fig.update_xaxes(title_text="Min-max scaled attention", row=1, col=2)
    # Reversed y-axes put the first (largest) row of each bar panel at the top.
    fig.update_yaxes(autorange="reversed", row=1, col=1)
    fig.update_yaxes(autorange="reversed", row=1, col=2)
    # Height grows 18 px per bar (floor 480); width grows with top_n, capped at 1280.
    h = max(480, 40 + top_n * 18)
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=h,
        width=min(1280, 400 + top_n * 14),
        margin=dict(l=40, r=40, t=80, b=108),
        paper_bgcolor=PAGE_BG,
        plot_bgcolor=PAGE_BG,
        title_text="Global feature ranking (all modalities)",
        title_x=0.5,
        legend=dict(
            title=dict(text="Modality colour key", font=dict(size=11, family=PLOT_FONT["family"])),
            orientation="h",
            yanchor="top",
            y=-0.14,
            xanchor="center",
            x=0.5,
            font=dict(size=11, family=PLOT_FONT["family"]),
            traceorder="normal",
            itemsizing="constant",
        ),
    )
    return fig
def _flux_prepare_top_ranked(flux_df: pd.DataFrame, top_n: int, metric: str = "mean_rank") -> pd.DataFrame:
sub = flux_df[~flux_df["feature"].astype(str).str.contains("batch", case=False, na=False)].copy()
if metric not in sub.columns:
metric = "mean_rank"
sub = sub.sort_values(metric, ascending=True).head(int(top_n)).copy()
if "pathway" in sub.columns:
pc = sub["pathway"].value_counts()
sub["_pw_n"] = sub["pathway"].map(pc)
sub.sort_values(["_pw_n", "pathway"], ascending=[False, True], inplace=True)
return sub
def flux_pathway_sunburst(flux_df: pd.DataFrame, max_features: int = 55) -> go.Figure:
    """Sunburst of the ``max_features`` best-ranked flux reactions nested under their pathway.

    Rows without a pathway are dropped and the lowest ``mean_rank`` values kept. Each reaction
    wedge is weighted by ``max(mean_rank) - mean_rank + 1`` (clipped at 0.5), so stronger
    reactions get larger wedges. It is labelled with its truncated name plus a running index,
    so duplicate names stay distinct. Colour is ``log_fc`` on the shared diverging scale when
    that column has any values, otherwise ``mean_rank``. ``pval_adj`` is added to the hover box
    only when present.

    Returns an empty figure when ``pathway``, ``mean_rank`` or ``feature`` is missing, or when
    no rows remain.
    """
    if not {"pathway", "mean_rank", "feature"}.issubset(flux_df.columns):
        return go.Figure()
    sub = flux_df.dropna(subset=["pathway"]).copy()
    if sub.empty:
        return go.Figure()
    sub = sub.nsmallest(int(max_features), "mean_rank")
    sub["pathway"] = sub["pathway"].astype(str)
    sub["_uid"] = np.arange(len(sub))
    sub["rxn"] = sub.apply(
        lambda r: f"{_truncate_label(str(r['feature']), 36)} ·{int(r['_uid'])}",
        axis=1,
    )
    mr = sub["mean_rank"].astype(float)
    sub["w"] = (mr.max() - mr + 1.0).clip(lower=0.5)
    color_col = "log_fc" if "log_fc" in sub.columns and sub["log_fc"].notna().any() else "mean_rank"
    # Plotly Express raises for hover columns that are not in the frame, so pval_adj is optional.
    hover: dict[str, Any] = {"mean_rank": ":.2f"}
    if "pval_adj" in sub.columns:
        hover["pval_adj"] = ":.2e"
    hover.update({"feature": True, "w": False, "_uid": False})
    sb_kw: dict[str, Any] = {
        "path": ["pathway", "rxn"],
        "values": "w",
        "color": color_col,
        "hover_data": hover,
    }
    if color_col == "log_fc":
        sb_kw["color_continuous_scale"] = LOG_FC_DIVERGING_SCALE
        sb_kw["range_color"] = [LOG_FC_COLOR_MIN, LOG_FC_COLOR_MAX]
    else:
        sb_kw["color_continuous_scale"] = "Viridis_r"
    fig = px.sunburst(sub, **sb_kw)
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        margin=dict(l=8, r=8, t=100, b=16),
        height=min(820, 520 + int(max_features) * 5),
        title=dict(
            text="Top flux reactions by model rank, nested under pathway",
            x=0,
            xanchor="left",
            y=0.99,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=16, l=4),
        ),
    )
    if color_col == "log_fc":
        fig.update_layout(
            coloraxis=dict(
                cmin=LOG_FC_COLOR_MIN,
                cmax=LOG_FC_COLOR_MAX,
                colorbar=dict(
                    title=dict(text=LABEL_LOG2FC, side="right"),
                    tickformat=".2f",
                    len=0.38,
                    thickness=12,
                    y=0.52,
                    yanchor="middle",
                ),
            )
        )
    return fig
def flux_volcano(flux_df: pd.DataFrame) -> go.Figure:
    """Volcano plot of flux reactions: Log₂FC (x) vs −log₁₀ adjusted p (y), coloured by mean rank.

    y is taken as-is from ``pval_adj_log`` when that column exists; otherwise it is
    −log₁₀(``pval_adj``) with p floored at 1e-300. Rows with |log_fc| < 1e-10 and
    ``pval_adj`` <= 0 are dropped as degenerate. ``pathway``, ``pval_adj`` and ``group`` are
    shown on hover when present, matching motif_chromvar_volcano.

    Returns an empty figure if ``log_fc``, ``mean_rank`` or ``feature`` is missing, if neither
    p-value column exists, or if no rows remain.
    """
    cols = flux_df.columns
    if not {"log_fc", "mean_rank", "feature"}.issubset(cols):
        return go.Figure()
    if "pval_adj_log" not in cols and "pval_adj" not in cols:
        return go.Figure()
    d = flux_df.dropna(subset=["log_fc"]).copy()
    if d.empty:
        return go.Figure()
    # Drop degenerate rows: ~zero fold-change with a non-positive adjusted p (numeric artifact / noise).
    lf = d["log_fc"].astype(float)
    if "pval_adj" in d.columns:
        pa = d["pval_adj"].astype(float)
        bad = np.isfinite(lf) & np.isfinite(pa) & (np.abs(lf) < 1e-10) & (pa <= 0.0)
        d = d[~bad]
        if d.empty:
            return go.Figure()
    if "pval_adj_log" in d.columns:
        y = d["pval_adj_log"].astype(float)
    else:
        p = d["pval_adj"].astype(float).clip(lower=1e-300)
        y = -np.log10(p.to_numpy())
    d = d.assign(_neglogp=y)
    # Only request hover columns the frame has; Plotly Express raises on unknown columns.
    hover_cols = [c for c in ("pathway", "pval_adj", "group") if c in d.columns]
    fig = px.scatter(
        d,
        x="log_fc",
        y="_neglogp",
        color="mean_rank",
        color_continuous_scale="Viridis_r",
        hover_name="feature",
        hover_data=hover_cols if hover_cols else None,
        labels={
            "log_fc": LABEL_LOG2FC,
            "_neglogp": LABEL_NEG_LOG10_ADJ_P,
            "mean_rank": "Mean rank",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title="Differential flux vs statistical significance",
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=dict(
            title=dict(text="Mean rank", side="right"),
            thickness=12,
            len=0.55,
        ),
    )
    return fig
def motif_tf_mean_rank_bars(atac_df: pd.DataFrame, top_n: int = 22) -> go.Figure:
    """Aggregate motif features by TF name (prefix before ``_<motif_id>``); show lowest mean joint rank.

    A feature ``NAME_123`` (all-digit suffix after the last underscore) is grouped under
    ``NAME``; any other feature forms its own group. Each group's ``mean_rank`` values are
    averaged, and the ``top_n`` lowest are drawn as horizontal bars, strongest at the top.

    Returns an empty figure when ``atac_df`` is empty or lacks ``feature`` / ``mean_rank``.
    """
    if atac_df.empty or not {"feature", "mean_rank"}.issubset(atac_df.columns):
        return go.Figure()

    def _tf_prefix(feat: str) -> str:
        # Strip a trailing "_<digits>" motif id; leave other names untouched.
        s = str(feat)
        if "_" in s:
            head, tail = s.rsplit("_", 1)
            if tail.isdigit():
                return head
        return s

    d = atac_df.copy()
    d["_tf"] = d["feature"].map(_tf_prefix)
    agg = d.groupby("_tf", as_index=False)["mean_rank"].mean()
    agg = agg.nsmallest(int(top_n), "mean_rank").sort_values("mean_rank", ascending=True)
    if agg.empty:
        return go.Figure()
    y_show = agg["_tf"].astype(str).map(lambda s: _truncate_label(s, 36))
    fig = go.Figure(
        go.Bar(
            y=y_show,
            x=agg["mean_rank"],
            orientation="h",
            marker_color=MODALITY_COLOR.get("ATAC", PALETTE[0]),
            customdata=agg["_tf"],
            hovertemplate="<b>%{customdata}</b><br>Mean mean_rank (across motifs): %{x:.2f}<extra></extra>",
        )
    )
    # Floor keeps a usable plot area for very few TFs: top+bottom margins alone are 92 px.
    fig_height = max(200, min(640, 48 + 22 * len(agg)))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=f"TFs by average motif rank (top {top_n} by lowest mean rank)",
        height=fig_height,
        margin=dict(l=160, r=24, t=52, b=40),
        xaxis_title="Mean of mean_rank over motif instances (lower = stronger)",
        yaxis_title="",
    )
    fig.update_yaxes(autorange="reversed", tickfont=dict(size=10))
    return fig
def motif_chromvar_volcano(atac_df: pd.DataFrame) -> go.Figure:
    """Motif differential view: mean activity difference (reprogramming − dead-end) vs significance.

    Requires ``mean_diff`` and ``pval_adj``; rows missing either are dropped, as are rows with
    |mean_diff| < 1e-12 together with ``pval_adj`` <= 0 (treated as degenerate). y is taken
    as-is from ``pval_adj_log`` when that column exists, else −log₁₀(``pval_adj``) with p
    floored at 1e-300. Points are coloured by ``mean_rank`` (reversed Viridis); group,
    p-value, rank and per-cohort means are added to the hover box when present.
    Returns an empty figure when the required columns are missing or nothing survives.
    """
    required = ["mean_diff", "pval_adj"]
    if any(col not in atac_df.columns for col in required):
        return go.Figure()
    pts = atac_df.dropna(subset=required).copy()
    if pts.empty:
        return go.Figure()

    diff = pts["mean_diff"].astype(float)
    padj = pts["pval_adj"].astype(float)
    degenerate = np.isfinite(diff) & np.isfinite(padj) & (np.abs(diff) < 1e-12) & (padj <= 0.0)
    pts = pts[~degenerate]
    if pts.empty:
        return go.Figure()

    if "pval_adj_log" in pts.columns:
        neg_log_p = pts["pval_adj_log"].astype(float)
    else:
        neg_log_p = -np.log10(pts["pval_adj"].astype(float).clip(lower=1e-300).to_numpy())
    pts = pts.assign(_y=neg_log_p)

    extra_hover = [c for c in ("group", "pval_adj", "mean_rank", "mean_de", "mean_re") if c in pts.columns]
    fig = px.scatter(
        pts,
        x="mean_diff",
        y="_y",
        color="mean_rank",
        color_continuous_scale="Viridis_r",
        hover_name="feature",
        hover_data=extra_hover or None,
        labels={
            "mean_diff": "Mean difference (reprogramming − dead-end)",
            "_y": LABEL_NEG_LOG10_ADJ_P,
            "mean_rank": "Mean rank",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title="TF motif differential activity (mean difference vs significance)",
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=dict(title=dict(text="Mean rank", side="right"), thickness=12, len=0.55),
    )
    return fig
def notebook_style_activity_scatter(
    df: pd.DataFrame,
    title: str,
    x_title: str,
    y_title: str,
) -> go.Figure:
    """mean_de vs mean_re, colour = pval_adj_log (Reds), marker size ∝ inverse mean_rank.

    Marker size grows linearly with ``max(mean_rank) - mean_rank``, mapped onto 3.5–12.5 before
    Plotly Express scaling (a uniform 6.0 when all ranks are equal). Requires ``mean_de``,
    ``mean_re``, ``mean_rank``, ``pval_adj_log``, ``feature`` and ``group``; rows missing any
    of the first four are dropped. Returns an empty figure when columns are missing or no
    rows remain.
    """
    required = ("mean_de", "mean_re", "mean_rank", "pval_adj_log", "feature", "group")
    if any(col not in df.columns for col in required):
        return go.Figure()
    pts = df.dropna(subset=list(required[:4])).copy()
    if pts.empty:
        return go.Figure()

    worst_rank = float(pts["mean_rank"].max())
    pts = pts.assign(_inv=(worst_rank - pts["mean_rank"].astype(float)).clip(lower=0))
    strength = pts["_inv"].astype(float)
    s_min, s_max = float(strength.min()), float(strength.max())
    pts["_sz"] = 6.0 if s_max <= s_min else 3.5 + (strength - s_min) / (s_max - s_min) * 9.0

    fig = px.scatter(
        pts,
        x="mean_de",
        y="mean_re",
        color="pval_adj_log",
        color_continuous_scale="Reds",
        size="_sz",
        size_max=14,
        hover_name="feature",
        hover_data={
            "mean_rank": ":.2f",
            "group": True,
            "pval_adj_log": ":.2f",
            "_inv": False,
            "_sz": False,
        },
        labels={
            "mean_de": x_title,
            "mean_re": y_title,
            "pval_adj_log": "Adj. p-value (log)",
        },
    )
    fig.update_traces(
        marker=dict(line=dict(width=0.45, color="rgba(255,255,255,0.75)"), opacity=0.9),
        selector=dict(mode="markers"),
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=title,
        height=520,
        margin=dict(l=52, r=24, t=52, b=48),
        coloraxis_colorbar=dict(title=dict(text="Adj. p (log)", side="right"), thickness=12, len=0.55),
    )
    return fig
def pathway_bubble_suggested_height(n_paths: int) -> int:
    """Figure height (px) for a pathway bubble panel showing ``n_paths`` rows.

    22 px per row plus 200 px of chrome, clamped to [520, 1100]; ``n_paths`` is coerced to
    int and values below 1 count as 1. Pass the larger cohort's count for both panels so
    their legends line up.
    """
    rows = int(n_paths)
    if rows < 1:
        rows = 1
    return min(max(22 * rows + 200, 520), 1100)
def pathway_enrichment_bubble_panel(
    df: pd.DataFrame,
    title: str,
    *,
    show_colorbar: bool = True,
    layout_height: int | None = None,
) -> go.Figure:
    """Single cohort: Reactome (circle) vs KEGG (square), colour = −log₁₀ Benjamini (scale per panel).

    Expects columns ``Term``, ``Library``, ``Count``, ``Gene Ratio`` and ``Benjamini``. Rows are
    listed top to bottom by ``Count`` then ``Gene Ratio`` (both descending). x is the gene ratio,
    marker diameter grows with sqrt(Count), and colour is −log₁₀(Benjamini), with q floored at
    1e-300 and scaled to this panel's own min/max. Libraries other than Reactome/KEGG are drawn
    as circles. ``layout_height`` overrides the height from pathway_bubble_suggested_height
    (pass the same value to side-by-side panels so they align). An empty ``df`` yields a short
    placeholder figure with a "No significant pathways" note.
    """
    fig = go.Figure()
    if df.empty:
        fig.update_layout(
            template="plotly_white",
            font=PLOT_FONT,
            title=dict(text=title, x=0.5, xanchor="center"),
            annotations=[
                dict(
                    text="No significant pathways (Benjamini-Hochberg q < 0.05)",
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=0.5,
                    showarrow=False,
                    font=dict(size=13, color="#64748b"),
                )
            ],
            height=320,
            margin=dict(l=40, r=40, t=56, b=40),
        )
        return fig
    # More genes in the overlap first, then stronger gene ratio (matches enrichment table emphasis).
    d = df.sort_values(by=["Count", "Gene Ratio"], ascending=[False, False]).reset_index(drop=True)
    d = d.assign(
        _neglog=-np.log10(d["Benjamini"].astype(float).clip(lower=1e-300)),
        _y=np.arange(len(d), dtype=float),
    )
    nl = d["_neglog"].astype(float)
    cmin = float(nl.min())
    cmax = float(nl.max())
    if cmax <= cmin:
        cmax = cmin + 1e-6
    # Single trace: per-panel cmin/cmax so Viridis uses the cohort's range (shared global max clusters at one hue).
    sym_map = {"Reactome": "circle", "KEGG": "square"}
    symbols = [sym_map.get(str(x), "circle") for x in d["Library"].tolist()]
    sz = np.sqrt(d["Count"].astype(float).clip(lower=1)) * 4.8
    # Keep native per-column types (count, float −log₁₀ q, library name): np.stack over mixed
    # numeric/str arrays promotes everything to one unicode dtype, so the hovertemplate number
    # formats would be applied to strings.
    customdata = list(
        zip(
            d["Count"].tolist(),
            d["_neglog"].tolist(),
            d["Library"].astype(str).tolist(),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=d["Gene Ratio"],
            y=d["_y"],
            mode="markers",
            name="Pathways",
            showlegend=False,
            marker=dict(
                size=sz,
                sizemode="diameter",
                sizemin=4,
                symbol=symbols,
                color=d["_neglog"],
                cmin=cmin,
                cmax=cmax,
                colorscale="Viridis",
                showscale=bool(show_colorbar),
                colorbar=dict(
                    title=dict(
                        text="\u2212log\u2081\u2080 q",
                        side="right",
                    ),
                    len=0.72,
                    thickness=12,
                    y=0.45,
                    yanchor="middle",
                    outlinewidth=0,
                )
                if show_colorbar
                else None,
                line=dict(width=0.75, color="rgba(0,0,0,0.5)"),
                opacity=0.92,
            ),
            text=d["Term"],
            customdata=customdata,
            hovertemplate=(
                "<b>%{text}</b><br>%{customdata[2]}<br>Gene ratio: %{x:.3f}<br>Count: %{customdata[0]}"
                "<br>\u2212log\u2081\u2080 Benjamini: %{customdata[1]:.2f}<extra></extra>"
            ),
        )
    )
    # Legend-only proxy traces for the marker shapes of libraries actually present.
    libraries_present = set(d["Library"].astype(str))
    for lib, sym in (("Reactome", "circle"), ("KEGG", "square")):
        if lib not in libraries_present:
            continue
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=lib,
                marker=dict(
                    symbol=sym,
                    size=11,
                    color="#475569",
                    line=dict(width=1, color="rgba(0,0,0,0.45)"),
                ),
                showlegend=True,
            )
        )
    ticktext = [_truncate_label(str(t), 52) for t in d["Term"]]
    h = int(layout_height) if layout_height is not None else pathway_bubble_suggested_height(len(d))
    # Numeric y positions with term labels as ticks; reversed so row 0 sits at the top.
    fig.update_yaxes(
        tickmode="array",
        tickvals=d["_y"].tolist(),
        ticktext=ticktext,
        autorange="reversed",
        title="",
    )
    fig.update_xaxes(title_text="Gene ratio (count ÷ list total)")
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            y=0.985,
            pad=dict(b=0),
        ),
        height=h,
        margin=dict(l=215, r=132, t=48, b=108),
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.11,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(255,255,255,0.92)",
            bordercolor="rgba(0,0,0,0.08)",
            borderwidth=1,
        ),
        showlegend=True,
    )
    return fig
def pathway_gene_membership_heatmap(
    z: np.ndarray, row_labels: list[str], col_labels: list[str]
) -> go.Figure:
    """Render a pathway × gene membership grid as a Plotly heatmap.

    ``z`` holds per-cell codes (converted with ``astype(int)``):
    0 = empty, 1 = gene enriched in the dead-end contrast, 2 = gene enriched in
    the reprogramming contrast, 3 = Reactome pathway set, 4 = KEGG pathway set.

    When there are at least two columns and the last column label is
    ``"Library"``, that column is split off and drawn as a narrow spine to the
    left of the gene grid (Reactome green / KEGG navy) instead of as a gene
    column. Empty cells use a light slate tint rather than the white page
    background so the grid stays readable. A hand-built legend (library and
    contrast groups) is placed to the right of the plot.

    Returns an empty ``go.Figure`` when ``z`` has no cells.
    """
    if z.size == 0:
        return go.Figure()

    codes = z.astype(int)
    n_rows, n_cols = z.shape

    # Each discrete code sits at a fixed position inside its own colour band of
    # ``main_scale``; scaling by z/4 would push code 3 (0.75) into the KEGG band.
    main_level = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
    hover_hint = {
        1: "Gene enriched in dead-end contrast",
        2: "Gene enriched in reprogramming contrast",
        3: "Reactome pathway set",
        4: "KEGG pathway set",
    }
    # Slightly off the white page colour so empty (code 0) cells read as grid.
    empty_tint = "#f1f5f9"
    main_scale = [
        [0.0, empty_tint],
        [0.14, empty_tint],
        [0.15, "#e69138"],
        [0.33, "#e69138"],
        [0.34, "#7eb6d9"],
        [0.53, "#7eb6d9"],
        [0.54, "#9ccc65"],
        [0.73, "#9ccc65"],
        [0.74, "#283593"],
        [1.0, "#283593"],
    ]
    spine_level = {3: 0.22, 4: 0.78}
    spine_scale = [
        [0.0, "#9ccc65"],
        [0.42, "#9ccc65"],
        [0.58, "#283593"],
        [1.0, "#283593"],
    ]

    def levels(block: np.ndarray) -> np.ndarray:
        """Map integer codes to colour-scale positions (unknown codes → empty slot)."""
        flat = [main_level.get(int(code), 0.04) for code in block.ravel()]
        return np.array(flat, dtype=float).reshape(block.shape)

    def hints(block: np.ndarray) -> list[list[str]]:
        """Per-cell hover hint strings, one inner list per grid row."""
        return [[hover_hint.get(int(round(float(code))), "") for code in row] for row in block]

    use_spine = n_cols >= 2 and str(col_labels[-1]) == "Library"
    if use_spine:
        library_codes = codes[:, -1]
        gene_codes = codes[:, :-1]
        x_main = list(col_labels[:-1])
    else:
        gene_codes = codes
        x_main = list(col_labels)
    n_gene_cols = gene_codes.shape[1]

    gene_grid = go.Heatmap(
        z=levels(gene_codes),
        x=x_main,
        y=row_labels,
        text=hints(gene_codes),
        colorscale=main_scale,
        zmin=0,
        zmax=1,
        showscale=False,
        xgap=1,
        ygap=1,
        hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
    )

    if use_spine:
        library_name = {3: "Reactome", 4: "KEGG"}
        spine = go.Heatmap(
            z=np.array([[spine_level.get(int(code), 0.22)] for code in library_codes], dtype=float),
            x=[""],
            y=row_labels,
            text=[[library_name.get(int(code), "Library")] for code in library_codes],
            colorscale=spine_scale,
            zmin=0,
            zmax=1,
            showscale=False,
            xgap=0,
            ygap=1,
            hovertemplate="%{y}<br><b>%{text}</b><extra></extra>",
        )
        fig = make_subplots(
            rows=1,
            cols=2,
            column_widths=[0.034, 0.966],
            horizontal_spacing=0.006,
            shared_yaxes=True,
        )
        # Spine first so it occupies the narrow left column.
        fig.add_trace(spine, row=1, col=1)
        fig.add_trace(gene_grid, row=1, col=2)
    else:
        fig = go.Figure(data=[gene_grid])

    # Roughly 10 px per gene column and 20 px per pathway row, clamped.
    fig_width = int(min(1000, max(460, 10 * n_gene_cols + 300)))
    fig_height = int(min(960, max(460, 20 * n_rows + 128)))
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        title=dict(text="Pathway-gene membership", x=0.5, xanchor="center"),
        height=fig_height,
        width=fig_width,
        margin=dict(l=4, r=168, t=52, b=108),
        paper_bgcolor=PAGE_BG,
        plot_bgcolor=PAGE_BG,
        legend=dict(
            orientation="v",
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.92)",
            bordercolor="rgba(0,0,0,0.08)",
            borderwidth=1,
            font=dict(size=11),
        ),
    )

    if use_spine:
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)
        fig.update_xaxes(side="bottom", tickangle=-50, showgrid=False, zeroline=False, row=1, col=2)
        # Pathway names are shown once, next to the spine; the gene grid hides its own.
        fig.update_yaxes(
            tickfont=dict(size=9),
            showgrid=False,
            zeroline=False,
            autorange="reversed",
            showticklabels=True,
            row=1,
            col=1,
        )
        fig.update_yaxes(
            showgrid=False,
            zeroline=False,
            autorange="reversed",
            showticklabels=False,
            row=1,
            col=2,
        )
    else:
        fig.update_layout(
            xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
            yaxis=dict(
                tickfont=dict(size=9),
                showgrid=False,
                zeroline=False,
                autorange="reversed",
            ),
        )

    # Heatmaps have no discrete legend, so add invisible marker-only traces as swatches.
    legend_entries = (
        ("Reactome", "#9ccc65", "pathway_library", "Library"),
        ("KEGG", "#283593", "pathway_library", None),
        ("Dead-end", "#e69138", "gene_contrast", "Contrast"),
        ("Reprogramming", "#7eb6d9", "gene_contrast", None),
    )
    for label, colour, group, group_title in legend_entries:
        group_title_kw = {"legendgrouptitle": dict(text=group_title)} if group_title else {}
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=label,
                legendgroup=group,
                marker=dict(
                    size=11,
                    color=colour,
                    symbol="square",
                    line=dict(width=1, color="rgba(0,0,0,0.25)"),
                ),
                showlegend=True,
                **group_title_kw,
            )
        )
    return fig
def flux_dead_end_vs_reprogram_scatter(flux_df: pd.DataFrame, max_pathway_colors: int = 12) -> go.Figure:
    """Scatter each reaction's mean flux under the dead-end vs reprogramming label.

    Required columns: ``mean_de`` (x axis) and ``mean_re`` (y axis). Rows missing
    either are dropped. An empty figure is returned if the columns are absent or
    no row remains.

    Optional columns:
      * ``importance_shift`` and ``importance_att``: marker size is
        sqrt(shift * attention), with negatives clipped to 0, divided by its 95th
        percentile, capped at 1 and mapped to 5-25. Rows with a missing
        importance get the smallest marker. Without both columns, every marker
        has the same size.
      * ``pathway``: point colour. Missing values count as "Unknown". The
        ``max_pathway_colors`` most frequent pathways get palette colours; the
        rest are grouped into a grey "Other" entry, listed last in the legend.
      * ``feature`` (hover title), ``mean_rank``, ``log_fc``, ``pathway``: shown
        on hover only when present.
    """
    need = ("mean_de", "mean_re")
    if not all(c in flux_df.columns for c in need):
        return go.Figure()
    d = flux_df.dropna(subset=list(need)).copy()
    if d.empty:
        return go.Figure()

    # Marker size: geometric mean of latent-shift and attention importance.
    if "importance_shift" in d.columns and "importance_att" in d.columns:
        imp = (
            d["importance_shift"].astype(float).clip(lower=0)
            * d["importance_att"].astype(float).clip(lower=0)
        ) ** 0.5
    else:
        imp = pd.Series(1.0, index=d.index)
    q = float(imp.quantile(0.95))
    # quantile() skips NaN but returns NaN when every value is NaN. NaN is truthy,
    # so a plain `q or 1.0` would not catch it and would turn every size into NaN.
    if not np.isfinite(q) or q <= 0:
        q = 1.0
    d = d.assign(_s=(imp / q).clip(upper=1).fillna(0) * 20 + 5)

    if "pathway" in d.columns:
        pw = d["pathway"].fillna("Unknown").astype(str)
    else:
        pw = pd.Series(["Unknown"] * len(d), index=d.index)
    top_pw = pw.value_counts().head(int(max_pathway_colors)).index
    d = d.assign(_pw_col=pw.where(pw.isin(top_pw), "Other"))
    # Alphabetical with "Other" last; used both for colour assignment and legend order.
    uniq = sorted(d["_pw_col"].astype(str).unique(), key=lambda x: (x == "Other", x))
    pal = list(LATENT_DISCRETE_PALETTE)
    pw_cmap: dict[str, str] = {}
    j = 0
    for name in uniq:
        if name == "Other":
            pw_cmap[name] = "#94a3b8"
        else:
            pw_cmap[name] = pal[j % len(pal)]
            j += 1

    # px rejects hover columns that do not exist, so only pass the ones present.
    hover_cols = [c for c in ("mean_rank", "log_fc", "pathway") if c in d.columns]
    fig = px.scatter(
        d,
        x="mean_de",
        y="mean_re",
        color="_pw_col",
        color_discrete_map=pw_cmap,
        category_orders={"_pw_col": uniq},
        size="_s",
        hover_name="feature" if "feature" in d.columns else None,
        hover_data=hover_cols or None,
        labels={
            "mean_de": "Mean flux · dead-end",
            "mean_re": "Mean flux · reprogramming",
            "_pw_col": "Pathway",
        },
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=540,
        margin=dict(l=52, r=20, t=52, b=40),
        title="Average measured flux by fate label (each point is one reaction)",
        legend=dict(orientation="h", yanchor="top", y=-0.28, xanchor="center", x=0.5),
    )
    fig.update_traces(marker=dict(opacity=0.75, line=dict(width=0.35, color="rgba(0,0,0,0.3)")))
    return fig
def flux_pathway_mean_rank_violin(flux_df: pd.DataFrame, top_pathways: int = 12) -> go.Figure:
    """Violin (with inner box) of ``mean_rank`` per pathway for the most frequent pathways.

    Only rows with a non-missing ``pathway`` are used. The ``top_pathways``
    pathways with the most rows are kept, and each gets its own palette colour.
    Returns an empty figure when ``pathway`` or ``mean_rank`` is missing, or when
    no row survives the filtering.
    """
    if "pathway" not in flux_df.columns or "mean_rank" not in flux_df.columns:
        return go.Figure()
    sub = flux_df.dropna(subset=["pathway"]).copy()
    # px colours by the raw column values, while the colour map below is keyed by
    # str. Normalise once so they always agree (a no-op for string pathways).
    sub["pathway"] = sub["pathway"].astype(str)
    top_list = list(sub["pathway"].value_counts().head(int(top_pathways)).index)
    sub = sub[sub["pathway"].isin(top_list)]
    if sub.empty:
        return go.Figure()
    v_cmap = {p: LATENT_DISCRETE_PALETTE[i % len(LATENT_DISCRETE_PALETTE)] for i, p in enumerate(top_list)}
    fig = px.violin(
        sub,
        x="pathway",
        y="mean_rank",
        box=True,
        points=False,
        color="pathway",
        color_discrete_map=v_cmap,
        labels={"mean_rank": "Mean rank (lower = stronger model focus)", "pathway": "Pathway"},
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        showlegend=False,
        height=420,
        xaxis_tickangle=-32,
        margin=dict(l=48, r=24, t=48, b=140),
        title="How joint model rank spreads within high-coverage pathways",
    )
    return fig
def flux_reaction_annotation_panel(flux_df: pd.DataFrame, top_n: int = 26, metric: str = "mean_rank") -> go.Figure:
    """Three heatmap columns: pathway (categorical), DE Log₂FC, −log₁₀ adjusted p.

    Rows are the reactions returned by ``_flux_prepare_top_ranked(flux_df, top_n,
    metric)``, labelled by ``feature`` (truncated; the full name is shown on hover).

    Columns:
      * Pathway: one colour per distinct pathway, in first-seen order. A missing
        column or missing values show as "Unknown".
      * Log₂FC: ``log_fc`` on a fixed diverging scale. Missing values are shown
        as 0.
      * Significance: ``pval_adj_log`` if present, otherwise −log₁₀(``pval_adj``)
        with p clipped at 1e-300. A missing p-value is shown as 0, as is a
        missing ``pval_adj_log``. If neither column exists, the column is all 0.

    Returns an empty figure when no reaction is selected.
    """
    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if top.empty:
        return go.Figure()
    n = len(top)
    pathways = top["pathway"].fillna("Unknown").astype(str).tolist() if "pathway" in top.columns else ["Unknown"] * n
    uniq = list(dict.fromkeys(pathways))
    code_map = {u: i for i, u in enumerate(uniq)}
    codes = np.array([code_map[p] for p in pathways], dtype=float)
    k = max(len(uniq), 1)
    qual = list(px.colors.qualitative.Safe) + list(px.colors.qualitative.Dark24) + list(px.colors.qualitative.Light24)
    # Integer codes 0..k-1 with zmin=0, zmax=k-1 land exactly on these stops.
    if k <= 1:
        disc_scale = [[0, qual[0]], [1, qual[0]]]
    else:
        disc_scale = [[j / (k - 1), qual[j % len(qual)]] for j in range(k)]
    log_fc = top["log_fc"].fillna(0).astype(float).to_numpy() if "log_fc" in top.columns else np.zeros(n)
    if "pval_adj_log" in top.columns:
        pv = top["pval_adj_log"].fillna(0).astype(float).to_numpy()
    elif "pval_adj" in top.columns:
        p_adj = top["pval_adj"].astype(float).clip(lower=1e-300).to_numpy()
        # NaN p would otherwise become a NaN cell and a "nan" hover. Map it to 0
        # like the pval_adj_log branch. "+ 0.0" turns the -0.0 from p == 1 into
        # 0.0, so the hover does not print "-0.00".
        pv = np.nan_to_num(-np.log10(p_adj), nan=0.0) + 0.0
    else:
        pv = np.zeros(n)
    full_features = top["feature"].astype(str).tolist()
    y_labels = [_truncate_label(str(f), 44) for f in full_features]
    z_path = codes.reshape(-1, 1)
    # hovertext (not customdata): subplot heatmaps often render %{customdata[0]} as "-" in the browser.
    hover_path = [[f"<b>{fn}</b><br>pathway: {pw}"] for fn, pw in zip(full_features, pathways)]
    hover_lfc = [
        [f"<b>{fn}</b><br>{LABEL_LOG2FC}: {float(log_fc[i]):.4f}"]
        for i, fn in enumerate(full_features)
    ]
    hover_pv = [
        [f"<b>{fn}</b><br>{LABEL_NEG_LOG10_ADJ_P}: {float(pv[i]):.2f}"]
        for i, fn in enumerate(full_features)
    ]
    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        horizontal_spacing=0.06,
        column_widths=[0.24, 0.24, 0.24],
    )
    fig.add_trace(
        go.Heatmap(
            z=z_path,
            x=[""],
            y=y_labels,
            colorscale=disc_scale,
            zmin=0,
            zmax=max(k - 1, 0),
            showscale=False,
            hovertext=hover_path,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Heatmap(
            z=log_fc.reshape(-1, 1),
            x=[""],
            y=y_labels,
            colorscale=LOG_FC_DIVERGING_SCALE,
            zmin=LOG_FC_COLOR_MIN,
            zmax=LOG_FC_COLOR_MAX,
            showscale=True,
            colorbar=dict(
                title=dict(text=LABEL_LOG2FC, side="right"),
                tickformat=".2f",
                len=0.22,
                y=0.71,
                yanchor="middle",
                x=1.0,
                xanchor="left",
                xref="paper",
                yref="paper",
                thickness=12,
            ),
            hovertext=hover_lfc,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Heatmap(
            z=pv.reshape(-1, 1),
            x=[""],
            y=y_labels,
            colorscale="Viridis",
            showscale=True,
            colorbar=dict(
                title=dict(text=LABEL_NEG_LOG10_ADJ_P, side="right"),
                len=0.22,
                y=0.29,
                yanchor="middle",
                x=1.0,
                xanchor="left",
                xref="paper",
                yref="paper",
                thickness=12,
            ),
            hovertext=hover_pv,
            hovertemplate="%{hovertext}<extra></extra>",
        ),
        row=1,
        col=3,
    )
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(820, 120 + n * 22),
        width=900,
        margin=dict(l=8, r=108, t=56, b=72),
        title=dict(
            text=f"Pathway, {LABEL_LOG2FC}, and significance",
            x=0,
            xanchor="left",
            y=0.995,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=8, l=4),
        ),
    )
    fig.update_xaxes(side="bottom", title_standoff=8)
    fig.update_xaxes(title_text="Pathway", row=1, col=1)
    fig.update_xaxes(title_text=LABEL_LOG2FC, row=1, col=2)
    fig.update_xaxes(title_text=LABEL_NEG_LOG10_ADJ_P, row=1, col=3)
    fig.update_yaxes(autorange="reversed")
    return fig
def flux_model_metric_profile(flux_df: pd.DataFrame, top_n: int = 22, metric: str = "mean_rank") -> go.Figure:
    """Matrix view of the top reactions: one row per model metric, each scaled to 0-1.

    Reactions (columns) come from ``_flux_prepare_top_ranked(flux_df, top_n, metric)``.
    Metric rows, in this order, are included only when their source column(s) exist:

      * "Latent shift impact": ``importance_shift``
      * "Attention (rollout)": ``importance_att``
      * "Joint priority": 1 - scaled ``mean_rank`` (a lower rank gives a higher priority)
      * "RE / DE mean flux": ``mean_re / mean_de``; a zero or missing value yields 0

    Each row is min-max scaled independently. Returns an empty figure when no
    reaction is selected or none of the metric columns is present.
    """
    top = _flux_prepare_top_ranked(flux_df, top_n, metric)
    if top.empty:
        return go.Figure()

    def mm(s: pd.Series) -> np.ndarray:
        """Min-max scale ``s`` to [0, 1] using only its finite values.

        NaN entries stay NaN, and +inf / -inf saturate to 1 / 0. A column with no
        finite values, or with a single distinct finite value, maps to all zeros.
        """
        v = s.astype(float).to_numpy()
        finite = v[np.isfinite(v)]
        # Computing min/max over finite values avoids numpy's All-NaN warnings and
        # keeps an infinite extreme from collapsing every finite entry to 0.
        if finite.size == 0:
            return np.zeros_like(v, dtype=float)
        lo, hi = float(finite.min()), float(finite.max())
        if hi <= lo:
            return np.zeros_like(v, dtype=float)
        return np.clip((v - lo) / (hi - lo), 0.0, 1.0)

    cols: list[np.ndarray] = []
    labels: list[str] = []
    for c, lab in (("importance_shift", "Latent shift impact"), ("importance_att", "Attention (rollout)")):
        if c in top.columns:
            cols.append(mm(top[c]))
            labels.append(lab)
    if "mean_rank" in top.columns:
        cols.append(1.0 - mm(top["mean_rank"]))
        labels.append("Joint priority (1 - scaled mean rank)")
    if "mean_de" in top.columns and "mean_re" in top.columns:
        # A zero denominator becomes NaN, so its ratio falls through to fillna(0).
        de = top["mean_de"].astype(float).replace(0, np.nan)
        ratio = (top["mean_re"].astype(float) / (de + 1e-12)).fillna(0)
        cols.append(mm(ratio))
        labels.append("RE / DE mean flux (scaled)")
    if not cols:
        # np.column_stack([]) raises; there is nothing to draw.
        return go.Figure()
    z = np.column_stack(cols)
    full_rxn = top["feature"].astype(str).tolist()
    x_labels = [_truncate_label(str(f), 34) for f in full_rxn]
    fig = px.imshow(
        z.T,
        x=x_labels,
        y=labels,
        aspect="auto",
        color_continuous_scale="Tealrose",
        labels=dict(x="Reaction", y="Metric", color="Scaled 0-1 per metric"),
    )
    # Hover shows the full (untruncated) reaction name for every cell of its column.
    n_met, n_rxn = z.T.shape
    hover_cd = np.broadcast_to(np.array(full_rxn, dtype=object), (n_met, n_rxn))
    fig.update_traces(
        customdata=hover_cd,
        hovertemplate="<b>%{customdata}</b><br>%{y}<br>scaled: %{z:.3f}<extra></extra>",
    )
    fig.update_xaxes(tickangle=-50, side="bottom", title_standoff=12)
    fig.update_layout(
        template="plotly_white",
        font=PLOT_FONT,
        height=min(380, 140 + len(labels) * 36),
        margin=dict(l=200, r=28, t=64, b=200),
        title=dict(
            text="Reaction profile",
            x=0,
            xanchor="left",
            y=0.98,
            yanchor="top",
            font=dict(size=13, family=PLOT_FONT["family"]),
            pad=dict(b=10, l=4),
        ),
    )
    return fig