NeuroScope / viz_token_layer.py
Alogotron's picture
Upload viz_token_layer.py with huggingface_hub
17c5f1d verified
"""
NeuroScope — Token-Layer Activation Grid
Heatmap with tokens as columns and layers as rows.
Color encodes activation magnitude (L2 norm) per token per layer,
revealing how each token's representation evolves through the network.
All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent).
"""
import numpy as np
import plotly.graph_objects as go
from extraction import ExtractionResult
# ---------------------------------------------------------------------------
# Theme constants (dark palette shared by the NeuroScope charts)
# ---------------------------------------------------------------------------
BG_COLOR = "#1a1a2e"  # plot-area background
PAPER_COLOR = "#1a1a2e"  # figure (paper) background, same shade as the plot area
TEXT_COLOR = "#e0e0e0"  # axis titles, tick labels and colorbar text
ACCENT_COLOR = "#e6b800"  # chart title colour
GRID_COLOR = "#2a2a4e"  # NOTE(review): not used by create_token_layer_grid; presumably for grid lines elsewhere
# Purple-to-gold colorscale for activation intensity, expressed in Plotly's
# colorscale format: a list of [stop, color] pairs with stops from 0.0 to 1.0.
TOKEN_LAYER_COLORSCALE = [
    [stop, color]
    for stop, color in zip(
        (0.0, 0.1, 0.25, 0.4, 0.55, 0.7, 0.85, 1.0),
        (
            "#0d0d1a",
            "#1a1040",
            "#2d1b69",
            "#5e2d8e",
            "#8e4585",
            "#c46a3a",
            "#e6b800",
            "#ffd633",
        ),
    )
]
# Axis along which each normalization mode takes its min/max (None = the whole
# matrix). "none" is handled separately because it applies no scaling at all.
_NORMALIZE_AXES = {"global": None, "per_layer": 1, "per_token": 0}


def _normalize_magnitudes(magnitudes: np.ndarray, normalize: str) -> np.ndarray:
    """Return a min-max scaled copy of a (layers, tokens) magnitude matrix.

    Args:
        magnitudes: 2-D array of L2 norms; rows are layers, columns are tokens.
        normalize: ``"none"`` or one of the keys of ``_NORMALIZE_AXES``.

    Returns:
        A new array. For ``"none"`` (or an empty matrix) it is an unscaled
        copy. Otherwise every value lies in [0, 1]; a span whose min equals its
        max maps to 0.0 instead of keeping its raw magnitude, so a constant
        row/column (e.g. every row when there is a single token) cannot mix raw
        and normalized values on one colour scale.
    """
    if normalize == "none" or magnitudes.size == 0:
        return magnitudes.copy()
    axis = _NORMALIZE_AXES[normalize]
    lo = magnitudes.min(axis=axis, keepdims=True)
    span = magnitudes.max(axis=axis, keepdims=True) - lo
    scaled = np.zeros(magnitudes.shape, dtype=np.float64)
    # Divide only where the span is non-zero; zero-range spans stay at 0.0.
    np.divide(magnitudes - lo, span, out=scaled, where=span > 0)
    return scaled


def create_token_layer_grid(
    result: ExtractionResult,
    normalize: str = "global",
) -> go.Figure:
    """Create a token × layer activation magnitude heatmap.

    Args:
        result: Extraction output. Reads ``hidden_states`` (laid out as
            (num_layers + 1, seq_len, hidden_dim), row 0 being the embedding
            output) and ``tokens`` (one label per column).
        normalize: Normalization strategy:
            - "global": Scale to global min/max across all layers and tokens.
            - "per_layer": Normalize each row independently (highlights
              within-layer variation).
            - "per_token": Normalize each column independently (highlights
              depth evolution per token).
            - "none": Raw L2 norms.
            In the normalized modes, spans with no variation are shown as 0.

    Returns:
        Plotly Figure with interactive heatmap.

    Raises:
        ValueError: If ``normalize`` is not one of the strategies above.
    """
    import html  # escapes token text for Plotly's HTML-like label rendering

    if normalize != "none" and normalize not in _NORMALIZE_AXES:
        raise ValueError(
            "normalize must be one of 'global', 'per_layer', 'per_token', "
            f"'none'; got {normalize!r}"
        )

    # hidden_states: (num_layers+1, seq_len, hidden_dim)
    hs = np.asarray(result.hidden_states)
    # Half-precision values above ~256 overflow to inf when squared inside the
    # norm, so compute in at least float32.
    if np.issubdtype(hs.dtype, np.floating) and hs.dtype.itemsize < 4:
        hs = hs.astype(np.float32)
    tokens = result.tokens
    num_layers_total = hs.shape[0]  # includes embedding layer
    seq_len = len(tokens)

    # L2 norm per token per layer -> (num_layers+1, seq_len)
    magnitudes = np.linalg.norm(hs, axis=-1)
    display = _normalize_magnitudes(magnitudes, normalize)
    is_normalized = normalize != "none"

    # Derive row names from the actual row count so labels, hover text and
    # data can never disagree.
    y_labels = ["Embed"] + [f"L{i}" for i in range(num_layers_total - 1)]
    layer_names = ["Embedding"] + [f"Layer {i}" for i in range(num_layers_total - 1)]

    # Plotly renders label/hover text as HTML-like markup, so tokens such as
    # "<s>" would vanish or restyle text. Escape only &, <, > (Plotly does not
    # decode &quot;). Truncate before escaping so an entity is never cut.
    x_labels = [html.escape(t[:12], quote=False) for t in tokens]
    safe_tokens = [html.escape(t, quote=False) for t in tokens]

    hover = [
        [
            f"Token: {token}<br>{layer_name}<br>L2 Norm: {magnitudes[i, j]:.2f}"
            + (f"<br>Normalized: {display[i, j]:.3f}" if is_normalized else "")
            for j, token in enumerate(safe_tokens)
        ]
        for i, layer_name in enumerate(layer_names)
    ]

    # Pin the colour range for normalized data so a partially-constant grid
    # still maps 0 -> bottom and 1 -> top of the colorscale.
    z_range = {"zmin": 0.0, "zmax": 1.0} if is_normalized else {}
    # Numeric x positions: repeated tokens would otherwise be identical
    # categories and Plotly would draw them in the same column.
    x_positions = list(range(seq_len))

    fig = go.Figure(
        data=go.Heatmap(
            z=display,
            x=x_positions,
            y=y_labels,
            text=hover,
            hoverinfo="text",
            colorscale=TOKEN_LAYER_COLORSCALE,
            colorbar=dict(
                title=dict(
                    text="Norm. Activation" if is_normalized else "Activation",
                    font=dict(color=TEXT_COLOR),
                ),
                tickfont=dict(color=TEXT_COLOR),
            ),
            **z_range,
        )
    )
    fig.update_layout(
        title=dict(
            text=f"Token × Layer Activation Grid (norm: {normalize})",
            font=dict(color=ACCENT_COLOR, size=14),
        ),
        xaxis=dict(
            title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=9),
            tickmode="array",
            tickvals=x_positions,
            ticktext=x_labels,
            side="top",
            tickangle=45,
        ),
        yaxis=dict(
            title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=8),
            autorange="reversed",
        ),
        paper_bgcolor=PAPER_COLOR,
        plot_bgcolor=BG_COLOR,
        margin=dict(l=60, r=30, t=80, b=30),
        height=520,
    )
    return fig