# one-for-all / _fig.py
# Author: frankyy03
# Last commit: "feat: Three.js 3D soul space (remove Plotly UMAP)" (956241a, verified)
from __future__ import annotations
import numpy as np
import plotly.graph_objects as go
# Shared dark-theme layout defaults, spread into every figure's layout via **DARK.
DARK = {
    "paper_bgcolor": "#161b22",
    "plot_bgcolor": "#161b22",
    "font": {"color": "#e6edf3", "family": "JetBrains Mono, monospace, sans-serif"},
    "margin": {"t": 48, "r": 16, "b": 48, "l": 56},
}

# Grid-line colour for the 2D chart axes and the 3D scene axes.
GRID_COLOR = "#21262d"

# Palette indexed by model position: index 0 = student, 1-5 = teachers.
MODEL_COLORS = [
    "#e6edf3",  # student
    "#7c3aed",  # teacher 1
    "#06b6d4",  # teacher 2
    "#f59e0b",  # teacher 3
    "#34d399",  # teacher 4
    "#f472b6",  # teacher 5
]
def build_base_traces(viz: dict, coords3d: np.ndarray) -> list:
    """Build one Scatter3d marker trace per model from the 3D embedding.

    viz["labels"] names the model owning each row of coords3d (rows are
    selected with a boolean mask, so both must be the same length).
    viz["model_names"] fixes the trace order, and a model's position in that
    list picks its MODEL_COLORS entry (cycling). The "student" trace gets
    size-6 markers, every other model size 5.

    Returns a plain list — never mutate it.
    """
    row_labels = np.array(viz["labels"])
    palette_size = len(MODEL_COLORS)

    def _trace_for(idx: int, model: str) -> go.Scatter3d:
        # Rows belonging to this model; only the first three columns are plotted.
        pts = coords3d[row_labels == model]
        xs, ys, zs = (pts[:, axis].tolist() for axis in range(3))
        return go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="markers",
            name=model,
            marker=dict(
                color=MODEL_COLORS[idx % palette_size],
                size=6 if model == "student" else 5,
                opacity=0.85,
            ),
        )

    return [_trace_for(idx, model) for idx, model in enumerate(viz["model_names"])]
def rebuild_fig(base_traces: list, probe_points: list[dict]) -> go.Figure:
    """Build a fresh 3D soul-space Figure from base traces plus probe points.

    base_traces: traces from build_base_traces. Never mutated; only a shallow
        copy of the list is handed to go.Figure.
    probe_points: list of {"x": float, "y": float, "z": float, "label": str}.
        When non-empty, the points are drawn as one extra "live probe" trace.
        Each point's "label" becomes its hover text; a missing label shows
        as an empty string.

    The layout uses the shared DARK theme and sets uirevision="soul" so the
    camera position is kept between updates.
    """
    fig = go.Figure(data=list(base_traces))  # copy the list, not a reference
    if probe_points:
        fig.add_trace(go.Scatter3d(
            x=[p["x"] for p in probe_points],
            y=[p["y"] for p in probe_points],
            z=[p["z"] for p in probe_points],
            # mode has no "text" flag, so these labels appear only on hover.
            text=[p.get("label", "") for p in probe_points],
            mode="markers",
            name="live probe",
            marker=dict(
                color="#ffffff", size=9, symbol="diamond", opacity=1.0,
                line=dict(color="#7c3aed", width=1),
            ),
        ))
    fig.update_layout(
        **DARK,
        title=dict(text="Soul space — UMAP 3D", font=dict(size=13)),
        scene=dict(
            bgcolor="#0d1117",
            xaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
            yaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
            zaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
        ),
        legend=dict(
            bgcolor="rgba(22,27,34,0.85)", bordercolor="#30363d", borderwidth=1,
            font=dict(size=11),
        ),
        uirevision="soul",  # keeps camera position between updates
    )
    return fig
def build_cka_fig(cka: dict) -> go.Figure:
    """Annotated heatmap of pairwise CKA alignment between models.

    cka: dict with "matrix" (rows of CKA values) and "models" (labels used
        for both axes; presumably in the same order as the matrix rows and
        columns — confirm with the producer). The colour scale is pinned to
        [0, 1], and each cell shows its value to two decimals. None cells
        are tolerated: Plotly leaves them as gaps and their text is blank.

    Returns a placeholder "No CKA data" figure when cka is empty or has no
    "matrix" key.
    """
    if not cka or "matrix" not in cka:
        return go.Figure(layout={**DARK, "title": "No CKA data"})
    matrix = cka["matrix"]
    models = cka["models"]
    # f"{None:.2f}" raises TypeError, so missing cells get empty text instead.
    cell_text = [["" if v is None else f"{v:.2f}" for v in row] for row in matrix]
    fig = go.Figure(go.Heatmap(
        z=matrix, x=models, y=models,
        colorscale="Viridis", zmin=0, zmax=1,
        colorbar=dict(title="CKA", thickness=14, tickfont=dict(color="#e6edf3", size=11)),
        text=cell_text,
        texttemplate="%{text}", textfont=dict(size=11),
    ))
    fig.update_layout(
        **DARK,
        title=dict(text="CKA geometry alignment — all pairs", font=dict(size=13)),
        xaxis=dict(side="bottom", tickfont=dict(size=11)),
        # Reverse y so row 0 is drawn at the top, matrix-style.
        yaxis=dict(autorange="reversed", tickfont=dict(size=11)),
    )
    return fig
def _ema(vals: list[float], alpha: float = 0.9) -> list[float]:
out, s = [], vals[0]
for v in vals:
s = alpha * s + (1 - alpha) * v
out.append(s)
return out
def build_curves_fig(curves: dict) -> go.Figure:
    """Plot EMA-smoothed (alpha=0.9) training loss curves against step.

    curves: dict with a non-empty "steps" sequence plus optional per-step
        loss series under "task", "kd", "geo" and "total". Any series that
        is missing or empty is skipped rather than raising.

    Returns a placeholder "No training data" figure when curves is empty or
    has no steps.
    """
    if not curves or not curves.get("steps"):
        return go.Figure(layout={**DARK, "title": "No training data"})
    steps = curves["steps"]
    # (key in curves, legend name, line style), in drawing order.
    series = (
        ("task", "task", dict(color="#06b6d4", width=2)),
        ("kd", "kd (qwen)", dict(color="#34d399", width=2)),
        ("geo", "geo", dict(color="#7c3aed", width=2)),
        ("total", "total", dict(color="#f59e0b", width=1.5, dash="dot")),
    )
    fig = go.Figure()
    for key, legend_name, line in series:
        vals = curves.get(key)
        # len() rather than truthiness so numpy arrays are accepted too.
        if vals is None or len(vals) == 0:
            continue
        fig.add_trace(go.Scatter(x=steps, y=_ema(vals), name=legend_name, line=line))
    fig.update_layout(
        **DARK,
        title=dict(text="Loss curves (EMA α=0.9)", font=dict(size=13)),
        xaxis=dict(title="step", gridcolor=GRID_COLOR),
        yaxis=dict(title="loss", gridcolor=GRID_COLOR),
        legend=dict(bgcolor="rgba(22,27,34,0.8)", bordercolor="#30363d", borderwidth=1),
    )
    return fig
def build_gate_area_fig(curves: dict) -> go.Figure:
    """Stacked-area chart of per-teacher gate weights over training steps.

    curves: needs "gate" (one row per step, each row indexed by teacher
        position), "steps" and "teacher_names". Teacher i is coloured by
        cycling through MODEL_COLORS[1:], the teacher slots of the palette.
        The y-axis is fixed to [0, 1].

    Returns a placeholder "No gate data" figure when curves is empty or has
    no gate rows.
    """
    if not curves or not curves.get("gate"):
        return go.Figure(layout={**DARK, "title": "No gate data"})

    palette = MODEL_COLORS[1:]
    gate_rows = curves["gate"]
    steps = curves["steps"]
    # One trace per teacher; all share stackgroup "g" so they stack as areas.
    traces = [
        go.Scatter(
            x=steps,
            y=[row[idx] for row in gate_rows],
            name=teacher,
            mode="lines",
            stackgroup="g",
            line=dict(color=palette[idx % len(palette)], width=0.5),
        )
        for idx, teacher in enumerate(curves["teacher_names"])
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        **DARK,
        title=dict(text="Gate — teacher routing over steps", font=dict(size=13)),
        yaxis=dict(title="weight", range=[0, 1], gridcolor=GRID_COLOR),
        xaxis=dict(title="step", gridcolor=GRID_COLOR),
        legend=dict(bgcolor="rgba(22,27,34,0.8)", bordercolor="#30363d", borderwidth=1),
    )
    return fig