# core/plotter_plotly.py
"""
Native Plotly interactive figures for Wang's Five Laws.
12 subplots stacked vertically (12×1), full browser width.
Fast: data aggregated once, drawn directly — no matplotlib conversion.
Layout (top → bottom):
0 pearson_QK Law 1 Spectral Linear Alignment
1 ssr_QK Law 2 Spectral Shape Fidelity
2 alpha_QK Law 1+2 Scale Factor α
3 sigma_max_Q Law 3 Max Singular Value (Q)
4 sigma_max_K Law 3 Max Singular Value (K)
5 cond_Q + cond_K Law 3 Condition Number κ (dual line)
6 cosU_QK Law 4 Output Subspace Q–K
7 cosU_QV Law 4 Output Subspace Q–V
8 cosU_KV Law 4 Output Subspace K–V
9 cosV_QK Law 5 Input Subspace Q–K
10 cosV_QV Law 5 Input Subspace Q–V
11 cosV_KV Law 5 Input Subspace K–V
"""
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# ── Color palette (identical to plotter.py) ───────────────────────────────────
# Hex colours for every series drawn by this module.
#   "Q" / "K" / "V"      single-projection series
#   "QK" / "QV" / "KV"   pairwise series
#   "ref"                neutral grey used for reference lines
C = dict(
    Q="#2166AC",
    K="#D6604D",
    V="#4DAC26",
    QK="#762A83",
    QV="#01665E",
    KV="#E08214",
    ref="#888888",
)

# Alpha channel applied to a series colour when filling its IQR band.
BAND_ALPHA = 0.15
# ── Panel definitions ─────────────────────────────────────────────────────────
# (col, color_key, y_label, title, ideal_value or None)
# One entry per subplot, in top-to-bottom drawing order.  Each entry is
#   (column, colour key into C or None, y-axis label, subplot title, ideal)
# where ``ideal`` is the y-value of the "Ideal=" reference line, or None when
# the panel has none.  "cond_dual" is not a data column: the plot functions
# treat it as a marker for the panel that overlays cond_Q and cond_K.
PANELS = [
    # Laws 1 / 2 — spectral agreement between Q and K
    (
        "pearson_QK", "QK", "Pearson r",
        "Law 1 — Spectral Linear Alignment (Pearson r Q–K)",
        1.0,
    ),
    (
        "ssr_QK", "QK", "SSR",
        "Law 2 — Spectral Shape Fidelity (SSR Q–K)",
        0.0,
    ),
    (
        "alpha_QK", "QK", "α",
        "Law 1+2 — Scale Factor α (Q–K)",
        1.0,
    ),
    # Law 3 — singular-value extremes and conditioning
    (
        "sigma_max_Q", "Q", "σ_max",
        "Law 3 — Max Singular Value σ_max (Q)",
        None,
    ),
    (
        "sigma_max_K", "K", "σ_max",
        "Law 3 — Max Singular Value σ_max (K)",
        None,
    ),
    (
        "cond_dual", None, "κ",
        "Law 3 — Condition Number κ (Q & K)",
        None,
    ),
    # Law 4 — output subspaces
    (
        "cosU_QK", "QK", "cosU",
        "Law 4 — Output Subspace cosU (Q–K)",
        None,
    ),
    (
        "cosU_QV", "QV", "cosU",
        "Law 4 — Output Subspace cosU (Q–V) [super-orth]",
        None,
    ),
    (
        "cosU_KV", "KV", "cosU",
        "Law 4 — Output Subspace cosU (K–V) [super-orth]",
        None,
    ),
    # Law 5 — input subspaces
    (
        "cosV_QK", "QK", "cosV",
        "Law 5 — Input Subspace cosV (Q–K)",
        None,
    ),
    (
        "cosV_QV", "QV", "cosV",
        "Law 5 — Input Subspace cosV (Q–V)",
        None,
    ),
    (
        "cosV_KV", "KV", "cosV",
        "Law 5 — Input Subspace cosV (K–V)",
        None,
    ),
]

# Pixel height of one subplot.
SUBPLOT_HEIGHT = 280
# Figure height: all subplots plus 120 px reserved for the header.
TOTAL_HEIGHT = len(PANELS) * SUBPLOT_HEIGHT + 120
# ─────────────────────────────────────────────────────────────────────────────
# Data helpers
# ─────────────────────────────────────────────────────────────────────────────
def _agg(df: pd.DataFrame, col: str):
"""
Pseudo-bulk two-step aggregation per layer (Nature Comms 2021).
Step 1: median across Q heads within each (layer, kv_head) group.
Step 2: median / q25 / q75 across kv_head groups per layer.
Avoids pseudoreplication bias in GQA models (e.g. 4Q:1K).
Excludes kv_shared rows for KV metrics (theoretical-value bias).
"""
kv_cols = {"ssr_KV", "pearson_KV", "cosU_KV", "cosV_KV", "alpha_KV"}
if col in kv_cols and "kv_shared" in df.columns:
df = df[df["kv_shared"] == 0]
layers = np.array(sorted(df["layer"].unique()), dtype=int)
med_vals, q25_vals, q75_vals = [], [], []
for layer in layers:
ldf = df[df["layer"] == layer]
# Step 1: median within each kv_head group
if "kv_head" in ldf.columns:
step1 = ldf.groupby("kv_head")[col].median().values
else:
step1 = ldf[col].dropna().values
step1 = step1[~np.isnan(step1.astype(float))] if len(step1) > 0 else step1
# Step 2: statistics across kv_head medians
med_vals.append(float(np.median(step1)) if len(step1) > 0 else np.nan)
q25_vals.append(float(np.percentile(step1, 25)) if len(step1) > 0 else np.nan)
q75_vals.append(float(np.percentile(step1, 75)) if len(step1) > 0 else np.nan)
return (layers,
np.array(med_vals, dtype=float),
np.array(q25_vals, dtype=float),
np.array(q75_vals, dtype=float))
def _global_layers(df: pd.DataFrame) -> list[int]:
if "kv_shared" not in df.columns:
return []
return sorted(df[df["kv_shared"] == 1]["layer"].unique().tolist())
def _infer_dims(df: pd.DataFrame) -> tuple[int, int]:
head_dim = int(df["head_dim"].dropna().median()) if "head_dim" in df.columns and df["head_dim"].notna().any() else 128
d_model = int(df["d_model"].dropna().median()) if "d_model" in df.columns and df["d_model"].notna().any() else 5120
return head_dim, d_model
# ─────────────────────────────────────────────────────────────────────────────
# Trace builders
# ─────────────────────────────────────────────────────────────────────────────
def _band_traces(layers, med, q25, q75, color, name,
                 row, dash="solid", show_legend=True):
    """
    Build the two traces of one series: a filled q25–q75 band and a median line.

    Parameters
    ----------
    layers, med, q25, q75 : array-like
        Equal-length 1-D sequences (e.g. the tuple returned by ``_agg``).
        They are coerced with ``np.asarray`` so plain lists work as well.
    color : str
        Hex colour of the line; the band is filled with the same colour at
        ``BAND_ALPHA`` opacity (converted by ``_hex_to_rgba``).
    name : str
        Legend label of the line.  Also used as ``legendgroup`` of both
        traces so the band is toggled together with its line.
    row : int
        Not used here — the caller places the traces in a subplot.  Kept
        so existing call sites remain valid.
    dash : str
        Plotly dash style of the median line.
    show_legend : bool
        Whether the median line gets a legend entry (the band never does).

    Returns
    -------
    (band, line) : tuple of ``go.Scatter``
    """
    layers = np.asarray(layers)
    med = np.asarray(med)
    q25 = np.asarray(q25)
    q75 = np.asarray(q75)

    band = go.Scatter(
        # Closed polygon: upper edge left→right, then lower edge right→left.
        x=np.concatenate([layers, layers[::-1]]).tolist(),
        y=np.concatenate([q75, q25[::-1]]).tolist(),
        fill="toself",
        fillcolor=_hex_to_rgba(color, BAND_ALPHA),
        line=dict(color="rgba(0,0,0,0)"),  # invisible outline
        hoverinfo="skip",
        showlegend=False,
        legendgroup=name,
    )
    line = go.Scatter(
        x=layers.tolist(),
        y=med.tolist(),
        mode="lines",
        name=name,
        line=dict(color=color, width=2, dash=dash),
        # Plotly hover text is HTML-like: a line break must be written as
        # <br>; a raw newline inside the literal is a syntax error.
        hovertemplate=f"Layer %{{x}}<br>{name}: %{{y:.5f}}",
        showlegend=show_legend,
        legendgroup=name,
    )
    return band, line
def _hline_trace(layers, y_val, label, color=None, row=None):
    """
    Build a dashed horizontal reference line as a two-point ``go.Scatter``.

    The line sits at ``y_val`` and runs from ``layers[0]`` to ``layers[-1]``,
    so ``layers`` must be non-empty.  ``label`` is used both as the legend
    name and as the legend group.  A falsy ``color`` selects ``C["ref"]``.
    ``row`` is accepted but not used; the caller places the trace.
    """
    stroke = color if color else C["ref"]
    x_first, x_last = layers[0], layers[-1]
    return go.Scatter(
        x=[x_first, x_last],
        y=[y_val] * 2,
        mode="lines",
        name=label,
        line={"color": stroke, "width": 1.2, "dash": "dash"},
        hoverinfo="skip",
        showlegend=True,
        legendgroup=label,
    )
def _hex_to_rgba(hex_color: str, alpha: float) -> str:
h = hex_color.lstrip("#")
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
return f"rgba({r},{g},{b},{alpha})"
def _vlines(fig, global_layers, row, x_range):
    """
    Draw one dotted vertical line per entry of ``global_layers`` in subplot ``row``.

    Lines are added to column 1 via ``fig.add_vline``.  Only when ``row`` is
    1 does each line carry a small ``G<layer>`` annotation; for every other
    row the annotation is ``None``.  ``x_range`` is accepted but not used.
    """
    annotate = row == 1
    for layer_id in global_layers:
        label = None
        if annotate:
            label = {
                "text": f"G{layer_id}",
                "font": {"size": 8, "color": "#999999"},
                "showarrow": False,
                "yref": "paper",
            }
        fig.add_vline(
            x=layer_id,
            row=row,
            col=1,
            line={"color": "#AAAAAA", "width": 1, "dash": "dot"},
            annotation=label,
        )
# ─────────────────────────────────────────────────────────────────────────────
# Single-model native Plotly figure
# ─────────────────────────────────────────────────────────────────────────────
def plotly_single(
    df: pd.DataFrame,
    model_name: str,
    show_band: bool = True,
) -> go.Figure:
    """
    Build the single-model figure: one stacked subplot per ``PANELS`` entry.

    Each subplot shows the per-layer median of its metric (from ``_agg``),
    optionally the q25–q75 band, the panel's reference lines, and a dotted
    vertical marker at every layer returned by ``_global_layers``.

    Parameters
    ----------
    df : pd.DataFrame
        Per-head metrics; must contain ``layer`` and every metric column
        named in ``PANELS`` (plus ``cond_Q`` / ``cond_K``).
    model_name : str
        Used in the figure title and as the legend label of the metric lines.
    show_band : bool
        Draw the IQR band behind each median line.

    Returns
    -------
    go.Figure

    Notes
    -----
    * Reference lines with the same label ("Ideal=…", "Random …") share a
      legend group; only the first one gets a legend entry, so the legend
      has no duplicates and one click toggles all of them.
    * A panel whose metric has no layers is left empty and unlabelled.
    """
    head_dim, d_model = _infer_dims(df)
    baseline_U = 1.0 / np.sqrt(head_dim)
    baseline_V = 1.0 / np.sqrt(d_model)
    marked_layers = _global_layers(df)

    fig = make_subplots(
        rows=len(PANELS), cols=1,
        subplot_titles=[panel[3] for panel in PANELS],
        shared_xaxes=False,
        vertical_spacing=0.03,
    )

    legend_labels = set()  # reference-line labels already in the legend

    def add_series(metric, series_color, label, row, show_legend):
        """Add band + median line for ``metric``; return its layer ids."""
        layers, med, q25, q75 = _agg(df, metric)
        if len(layers):
            band, line = _band_traces(
                layers, med, q25, q75, series_color, label,
                row=row, show_legend=show_legend,
            )
            if show_band:
                fig.add_trace(band, row=row, col=1)
            fig.add_trace(line, row=row, col=1)
        return layers

    def add_reference(layers, y_val, label, ref_color, row):
        """Add a horizontal reference line, listing its label only once."""
        trace = _hline_trace(layers, y_val, label, color=ref_color)
        trace.showlegend = label not in legend_labels
        legend_labels.add(label)
        fig.add_trace(trace, row=row, col=1)

    for row_idx, (col, color_key, ylabel, _title, ideal) in enumerate(
            PANELS, start=1):
        if col == "cond_dual":
            # Special panel: κ(Q) and κ(K) overlaid on one axis.
            layers_ref = add_series("cond_Q", C["Q"], "κ(Q)", row_idx, True)
            add_series("cond_K", C["K"], "κ(K)", row_idx, True)
        else:
            color = C[color_key] if color_key else C["Q"]
            layers_ref = add_series(
                col, color, model_name, row_idx, row_idx == 1
            )
            if len(layers_ref) == 0:
                continue

        # ── ideal / random-baseline reference lines ──────────────────────
        if len(layers_ref):
            if ideal is not None:
                add_reference(layers_ref, ideal, f"Ideal={ideal}",
                              C["ref"], row_idx)
            if col.startswith("cosU_"):
                add_reference(layers_ref, baseline_U,
                              f"Random 1/√d_h ≈ {baseline_U:.4f}",
                              "#E07B39", row_idx)
            if col.startswith("cosV_"):
                add_reference(layers_ref, baseline_V,
                              f"Random 1/√D ≈ {baseline_V:.4f}",
                              "#E07B39", row_idx)

        # ── vertical markers ─────────────────────────────────────────────
        for layer_id in marked_layers:
            fig.add_vline(
                x=layer_id, row=row_idx, col=1,
                line=dict(color="#BBBBBB", width=1, dash="dot"),
            )

        # ── axis labels ──────────────────────────────────────────────────
        fig.update_yaxes(title_text=ylabel, row=row_idx, col=1,
                         title_font=dict(size=11))
        fig.update_xaxes(title_text="Layer index", row=row_idx, col=1,
                         title_font=dict(size=11))

        # Condition numbers are drawn on a log axis.
        if col == "cond_dual":
            fig.update_yaxes(type="log", row=row_idx, col=1)

    # ── shared y-range within the cosU and the cosV panel families ───────
    # Rows are looked up in PANELS so they follow any reordering of it.
    for prefix in ("cosU_", "cosV_"):
        family = [
            (idx, panel[0])
            for idx, panel in enumerate(PANELS, start=1)
            if panel[0].startswith(prefix)
        ]
        if family:
            _sync_yrange(
                fig, df,
                [metric for _, metric in family],
                rows=[idx for idx, _ in family],
                pad=0.08,
            )

    # ── layout ───────────────────────────────────────────────────────────
    fig.update_layout(
        title=dict(
            text=f"Wang's Five Laws — {model_name}",
            font=dict(size=16),
            x=0.5, xanchor="center",
        ),
        height=TOTAL_HEIGHT,
        width=None,  # full browser width
        autosize=True,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom", y=1.01,
            xanchor="right", x=1,
            font=dict(size=10),
        ),
        margin=dict(l=70, r=30, t=80, b=40),
        paper_bgcolor="white",
        plot_bgcolor="#FAFAFA",
        font=dict(family="Arial, sans-serif", size=11),
        hovermode="x unified",
    )
    fig.update_annotations(font_size=11)
    return fig
# ─────────────────────────────────────────────────────────────────────────────
# Two-model comparison native Plotly figure
# ─────────────────────────────────────────────────────────────────────────────
def plotly_compare(
    df_a: pd.DataFrame,
    df_b: pd.DataFrame,
    name_a: str,
    name_b: str,
    show_band: bool = True,
    show_delta: bool = True,
) -> go.Figure:
    """
    Build the two-model figure: one stacked subplot per ``PANELS`` entry.

    Model A is drawn with solid lines, model B with dashed lines, both in
    the panel's colour.  With ``show_delta`` the difference of medians
    (B − A) over the layers both models share is filled in light grey
    between the delta curve and zero, on the panel's own y-axis.

    Parameters
    ----------
    df_a, df_b : pd.DataFrame
        Per-head metrics of the two models; same columns as for
        ``plotly_single``.
    name_a, name_b : str
        Labels used in the title and legend.
    show_band : bool
        Draw the IQR band behind each median line.
    show_delta : bool
        Draw the grey Δ fill (not drawn in the condition-number panel).

    Returns
    -------
    go.Figure

    Notes
    -----
    * The "Random" baselines use the integer mean of the two models'
      ``head_dim`` / ``d_model`` as inferred by ``_infer_dims``.
    * Reference lines and the Δ fill get a legend entry only the first
      time their label occurs, so the legend has no duplicates.
    """
    head_dim_a, d_model_a = _infer_dims(df_a)
    head_dim_b, d_model_b = _infer_dims(df_b)
    baseline_U = 1.0 / np.sqrt((head_dim_a + head_dim_b) // 2)
    baseline_V = 1.0 / np.sqrt((d_model_a + d_model_b) // 2)
    marked_layers = sorted(
        set(_global_layers(df_a)) | set(_global_layers(df_b))
    )

    fig = make_subplots(
        rows=len(PANELS), cols=1,
        subplot_titles=[panel[3] for panel in PANELS],
        shared_xaxes=False,
        vertical_spacing=0.03,
    )

    legend_labels = set()  # labels that already own a legend entry

    def first_use(label):
        """Return True only the first time ``label`` is offered."""
        is_new = label not in legend_labels
        legend_labels.add(label)
        return is_new

    def add_series(frame, metric, series_color, label, dash, row,
                   show_legend):
        """Add band + median line; return ``(layers, median)``."""
        layers, med, q25, q75 = _agg(frame, metric)
        if len(layers):
            band, line = _band_traces(
                layers, med, q25, q75, series_color, label,
                row=row, dash=dash, show_legend=show_legend,
            )
            if show_band:
                fig.add_trace(band, row=row, col=1)
            fig.add_trace(line, row=row, col=1)
        return layers, med

    def add_reference(layers, y_val, label, ref_color, row):
        """Add a horizontal reference line, listing its label only once."""
        trace = _hline_trace(layers, y_val, label, ref_color)
        trace.showlegend = first_use(label)
        fig.add_trace(trace, row=row, col=1)

    delta_label = f"Δ ({name_b}−{name_a})"

    for row_idx, (col, color_key, ylabel, _title, ideal) in enumerate(
            PANELS, start=1):
        if col == "cond_dual":
            # Special panel: κ(Q) and κ(K) for both models on one axis.
            layers_ref = ()
            for c_col, c_key, c_name in (
                ("cond_Q", "Q", "κ(Q)"),
                ("cond_K", "K", "κ(K)"),
            ):
                cond_layers_a, _ = add_series(
                    df_a, c_col, C[c_key], f"{c_name} {name_a}",
                    "solid", row_idx, True,
                )
                add_series(
                    df_b, c_col, C[c_key], f"{c_name} {name_b}",
                    "dash", row_idx, True,
                )
                if c_col == "cond_Q":
                    layers_ref = cond_layers_a
        else:
            color = C[color_key] if color_key else C["Q"]
            in_legend = row_idx == 1
            layers_a, med_a = add_series(
                df_a, col, color, name_a, "solid", row_idx, in_legend
            )
            layers_b, med_b = add_series(
                df_b, col, color, name_b, "dash", row_idx, in_legend
            )

            # ── Δ fill over the layers present in both models ────────────
            if show_delta and len(layers_a) and len(layers_b):
                common = np.intersect1d(layers_a, layers_b)
                if len(common) > 1:
                    # Both layer arrays are sorted, so the masked medians
                    # line up element-wise with ``common``.
                    delta = (med_b[np.isin(layers_b, common)]
                             - med_a[np.isin(layers_a, common)])
                    fig.add_trace(go.Scatter(
                        x=np.concatenate([common, common[::-1]]).tolist(),
                        y=np.concatenate(
                            [delta, np.zeros_like(delta)]
                        ).tolist(),
                        fill="toself",
                        fillcolor="rgba(160,160,160,0.20)",
                        line=dict(color="rgba(0,0,0,0)"),
                        hoverinfo="skip",
                        showlegend=first_use(delta_label),
                        name=delta_label,
                        legendgroup="delta",
                    ), row=row_idx, col=1)
            layers_ref = layers_a if len(layers_a) else layers_b

        # ── ideal / random-baseline reference lines ──────────────────────
        if len(layers_ref):
            if ideal is not None:
                add_reference(layers_ref, ideal, f"Ideal={ideal}",
                              C["ref"], row_idx)
            if col.startswith("cosU_"):
                add_reference(layers_ref, baseline_U,
                              f"Random 1/√d_h ≈ {baseline_U:.4f}",
                              "#E07B39", row_idx)
            if col.startswith("cosV_"):
                add_reference(layers_ref, baseline_V,
                              f"Random 1/√D ≈ {baseline_V:.4f}",
                              "#E07B39", row_idx)

        # ── vertical markers ─────────────────────────────────────────────
        for layer_id in marked_layers:
            fig.add_vline(
                x=layer_id, row=row_idx, col=1,
                line=dict(color="#BBBBBB", width=1, dash="dot"),
            )

        # ── axis labels ──────────────────────────────────────────────────
        fig.update_yaxes(title_text=ylabel, row=row_idx, col=1,
                         title_font=dict(size=11))
        fig.update_xaxes(title_text="Layer index", row=row_idx, col=1,
                         title_font=dict(size=11))

        # Condition numbers are drawn on a log axis.
        if col == "cond_dual":
            fig.update_yaxes(type="log", row=row_idx, col=1)

    # ── shared y-range within the cosU and the cosV panel families ───────
    # Rows are looked up in PANELS so they follow any reordering of it.
    for prefix in ("cosU_", "cosV_"):
        family = [
            (idx, panel[0])
            for idx, panel in enumerate(PANELS, start=1)
            if panel[0].startswith(prefix)
        ]
        if family:
            _sync_yrange_compare(
                fig, df_a, df_b,
                [metric for _, metric in family],
                [idx for idx, _ in family],
            )

    # ── layout ───────────────────────────────────────────────────────────
    fig.update_layout(
        title=dict(
            text=f"Wang's Five Laws — {name_a} vs {name_b}",
            font=dict(size=16),
            x=0.5, xanchor="center",
        ),
        height=TOTAL_HEIGHT,
        width=None,
        autosize=True,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom", y=1.01,
            xanchor="right", x=1,
            font=dict(size=10),
        ),
        margin=dict(l=70, r=30, t=80, b=40),
        paper_bgcolor="white",
        plot_bgcolor="#FAFAFA",
        font=dict(family="Arial, sans-serif", size=11),
        hovermode="x unified",
    )
    fig.update_annotations(font_size=11)
    return fig
# ─────────────────────────────────────────────────────────────────────────────
# Shared Y-axis helpers
# ─────────────────────────────────────────────────────────────────────────────
def _apply_shared_yrange(fig, frames, cols, rows, pad):
    """Set one padded y-range, spanning the IQR bounds of ``cols``, on ``rows``.

    The q25 / q75 arrays of every (column, frame) pair are pooled via
    ``_agg``; NaNs are ignored.  Nothing is changed when no finite value is
    found.
    """
    pooled = []
    for col in cols:
        for frame in frames:
            try:
                _, _, q25, q75 = _agg(frame, col)
            except (KeyError, TypeError, ValueError):
                # Best effort: a frame that lacks this metric, or holds it
                # in a non-numeric form, simply contributes nothing.
                continue
            pooled.extend(q25[~np.isnan(q25)].tolist())
            pooled.extend(q75[~np.isnan(q75)].tolist())
    if not pooled:
        return

    vmin, vmax = min(pooled), max(pooled)
    # Pad each bound *away* from the data.  For non-negative values this is
    # the usual ``min*(1-pad)`` / ``max*(1+pad)`` with the lower bound
    # floored at zero; negative bounds are scaled the other way so the
    # range never clips, or inverts around, the values it was built from.
    if vmin >= 0:
        lo = max(0.0, vmin * (1 - pad))
    else:
        lo = vmin * (1 + pad)
    if vmax >= 0:
        hi = vmax * (1 + pad)
    else:
        hi = vmax * (1 - pad)

    for r in rows:
        fig.update_yaxes(range=[lo, hi], row=r, col=1)


def _sync_yrange(fig, df, cols, rows, pad=0.08):
    """Force identical y-range for a set of rows (single model).

    ``cols`` are the metric columns whose q25 / q75 values define the range,
    ``rows`` the 1-based subplot rows (column 1) that receive it, and
    ``pad`` the relative padding added on both sides.
    """
    _apply_shared_yrange(fig, [df], cols, rows, pad)


def _sync_yrange_compare(fig, df_a, df_b, cols, rows, pad=0.08):
    """Force identical y-range for a set of rows (two-model comparison).

    Same as ``_sync_yrange`` but the range spans the values of both
    ``df_a`` and ``df_b``.
    """
    _apply_shared_yrange(fig, [df_a, df_b], cols, rows, pad)