| from __future__ import annotations |
| import numpy as np |
| import plotly.graph_objects as go |
|
|
# Shared dark-theme layout settings; every figure builder below splats these
# into its layout.
DARK = {
    "paper_bgcolor": "#161b22",
    "plot_bgcolor": "#161b22",
    "font": {"color": "#e6edf3", "family": "JetBrains Mono, monospace, sans-serif"},
    "margin": {"t": 48, "r": 16, "b": 48, "l": 56},
}
# Gridline colour applied to the axes of the 2D and 3D plots.
GRID_COLOR = "#21262d"

# Trace colour palette; callers cycle through it by index when there are more
# series than colours.
MODEL_COLORS = ["#e6edf3", "#7c3aed", "#06b6d4", "#f59e0b", "#34d399", "#f472b6"]
|
|
|
|
def build_base_traces(viz: dict, coords3d: np.ndarray) -> list:
    """Build one ``Scatter3d`` marker trace per model in ``viz["model_names"]``.

    Rows of ``coords3d`` are selected by comparing ``viz["labels"]`` (one
    label per row) against each model name. Columns 0/1/2 become x/y/z.
    Colours cycle through ``MODEL_COLORS`` by model index. The model named
    "student" gets slightly larger markers.

    Returns a plain list that callers treat as an immutable base — never
    mutate it.
    """
    point_labels = np.array(viz["labels"])
    n_colors = len(MODEL_COLORS)

    def _trace_for(idx, model):
        # Boolean mask picks this model's rows; assumes coords3d is (N, 3)
        # with N == len(viz["labels"]).
        pts = coords3d[point_labels == model]
        xs, ys, zs = (pts[:, axis].tolist() for axis in range(3))
        return go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="markers",
            name=model,
            marker=dict(
                color=MODEL_COLORS[idx % n_colors],
                size=6 if model == "student" else 5,
                opacity=0.85,
            ),
        )

    return [_trace_for(idx, model) for idx, model in enumerate(viz["model_names"])]
|
|
|
|
def rebuild_fig(base_traces: list, probe_points: list[dict]) -> go.Figure:
    """Build a fresh 3D soul-space figure from base traces plus probe points.

    Args:
        base_traces: Per-model traces (e.g. from ``build_base_traces``). The
            list is shallow-copied before use and never mutated.
        probe_points: Accumulated live probes, each
            ``{"x": float, "y": float, "z": float, "label": str}``. When
            non-empty they are drawn as one extra "live probe" trace of
            white diamond markers. Each probe's ``"label"`` is shown as its
            hover text; a missing label shows as empty.

    Returns:
        A new ``go.Figure`` with the dark theme applied. A constant
        ``uirevision`` lets Plotly preserve UI state (e.g. the 3D camera)
        when this figure replaces a previous one.
    """
    # Copy so building the figure cannot touch the caller's list.
    fig = go.Figure(data=list(base_traces))
    if probe_points:
        fig.add_trace(go.Scatter3d(
            x=[p["x"] for p in probe_points],
            y=[p["y"] for p in probe_points],
            z=[p["z"] for p in probe_points],
            # Surface each probe's label on hover so individual probes can be
            # told apart; .get tolerates probes recorded without a label.
            text=[p.get("label", "") for p in probe_points],
            mode="markers",
            name="live probe",
            marker=dict(
                color="#ffffff",
                size=9,
                symbol="diamond",
                opacity=1.0,
                line=dict(color="#7c3aed", width=1),
            ),
        ))
    fig.update_layout(
        **DARK,
        title=dict(text="Soul space — UMAP 3D", font=dict(size=13)),
        scene=dict(
            bgcolor="#0d1117",
            xaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
            yaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
            zaxis=dict(showgrid=True, gridcolor=GRID_COLOR, showticklabels=False, title=""),
        ),
        legend=dict(
            bgcolor="rgba(22,27,34,0.85)",
            bordercolor="#30363d",
            borderwidth=1,
            font=dict(size=11),
        ),
        uirevision="soul",
    )
    return fig
|
|
|
|
def build_cka_fig(cka: dict) -> go.Figure:
    """Render the pairwise CKA matrix as an annotated heatmap.

    ``cka["matrix"]`` supplies the cell values and ``cka["models"]`` labels
    both axes (so the matrix is presumably N x N). The colour scale is pinned
    to [0, 1], and every cell is annotated with its value to two decimals.
    The y-axis is reversed so row 0 appears at the top.

    Returns a placeholder figure titled "No CKA data" when ``cka`` is empty
    or lacks a ``"matrix"`` entry.
    """
    if not (cka and "matrix" in cka):
        return go.Figure(layout={**DARK, "title": "No CKA data"})

    scores = cka["matrix"]
    model_names = cka["models"]
    cell_labels = [[f"{score:.2f}" for score in row] for row in scores]

    heatmap = go.Heatmap(
        z=scores,
        x=model_names,
        y=model_names,
        colorscale="Viridis",
        zmin=0,
        zmax=1,
        colorbar=dict(
            title="CKA",
            thickness=14,
            tickfont=dict(color="#e6edf3", size=11),
        ),
        text=cell_labels,
        texttemplate="%{text}",
        textfont=dict(size=11),
    )
    fig = go.Figure(heatmap)
    fig.update_layout(
        **DARK,
        title=dict(text="CKA geometry alignment — all pairs", font=dict(size=13)),
        xaxis=dict(side="bottom", tickfont=dict(size=11)),
        yaxis=dict(autorange="reversed", tickfont=dict(size=11)),
    )
    return fig
|
|
|
|
| def _ema(vals: list[float], alpha: float = 0.9) -> list[float]: |
| out, s = [], vals[0] |
| for v in vals: |
| s = alpha * s + (1 - alpha) * v |
| out.append(s) |
| return out |
|
|
|
|
def build_curves_fig(curves: dict) -> go.Figure:
    """Plot EMA-smoothed loss curves against training step.

    Args:
        curves: Mapping with ``"steps"`` (x values) and optional per-step loss
            series ``"task"``, ``"kd"``, ``"geo"`` and ``"total"``. Each series
            is presumably the same length as ``"steps"`` — not checked here.

    Returns:
        A dark-themed ``go.Figure`` with one line per available series, drawn
        in the order task, kd, geo, total. Any series that is missing or
        empty is skipped rather than raising, so all of them are optional.
        Returns a placeholder figure titled "No training data" when
        ``curves`` is empty or has no steps.
    """
    if not curves or not curves.get("steps"):
        return go.Figure(layout={**DARK, "title": "No training data"})

    steps = curves["steps"]
    # One smoothing factor shared by every series and the title, so the
    # label always matches the smoothing actually applied.
    ema_alpha = 0.9
    # (key in ``curves``, legend name, line style), in draw order.
    series_specs = (
        ("task", "task", dict(color="#06b6d4", width=2)),
        ("kd", "kd (qwen)", dict(color="#34d399", width=2)),
        ("geo", "geo", dict(color="#7c3aed", width=2)),
        ("total", "total", dict(color="#f59e0b", width=1.5, dash="dot")),
    )

    fig = go.Figure()
    for key, legend_name, line_style in series_specs:
        values = curves.get(key)
        # len() rather than truthiness so NumPy arrays are accepted too.
        if values is None or len(values) == 0:
            continue
        fig.add_trace(go.Scatter(
            x=steps,
            y=_ema(values, alpha=ema_alpha),
            name=legend_name,
            line=line_style,
        ))
    fig.update_layout(
        **DARK,
        title=dict(text=f"Loss curves (EMA α={ema_alpha})", font=dict(size=13)),
        xaxis=dict(title="step", gridcolor=GRID_COLOR),
        yaxis=dict(title="loss", gridcolor=GRID_COLOR),
        legend=dict(bgcolor="rgba(22,27,34,0.8)", bordercolor="#30363d", borderwidth=1),
    )
    return fig
|
|
|
|
def build_gate_area_fig(curves: dict) -> go.Figure:
    """Stacked-area chart of the gate's per-teacher weights over training steps.

    ``curves["gate"]`` holds one row per step. Column ``i`` of each row is
    the weight plotted for ``curves["teacher_names"][i]``. All traces share a
    single stack group, so the areas pile on top of each other. Teacher
    colours cycle through ``MODEL_COLORS`` with its first entry skipped.

    NOTE(review): the y-axis is fixed to [0, 1], which presumes the weights at
    each step sum to at most 1 — confirm against the gate's output.

    Returns a placeholder figure titled "No gate data" when ``curves`` is
    empty or has no gate rows.
    """
    if curves and curves.get("gate"):
        x_steps = curves["steps"]
        teachers = curves["teacher_names"]
        palette = MODEL_COLORS[1:]
        fig = go.Figure()
        for col, teacher in enumerate(teachers):
            weights = [row[col] for row in curves["gate"]]
            fig.add_trace(go.Scatter(
                x=x_steps,
                y=weights,
                name=teacher,
                mode="lines",
                stackgroup="g",
                line=dict(color=palette[col % len(palette)], width=0.5),
            ))
        fig.update_layout(
            **DARK,
            title=dict(text="Gate — teacher routing over steps", font=dict(size=13)),
            yaxis=dict(title="weight", range=[0, 1], gridcolor=GRID_COLOR),
            xaxis=dict(title="step", gridcolor=GRID_COLOR),
            legend=dict(bgcolor="rgba(22,27,34,0.8)", bordercolor="#30363d", borderwidth=1),
        )
        return fig
    return go.Figure(layout={**DARK, "title": "No gate data"})
|
|