Spaces:
Runtime error
Runtime error
File size: 4,572 Bytes
17c5f1d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | """
NeuroScope — Token-Layer Activation Grid
Heatmap with tokens as columns and layers as rows.
Color encodes activation magnitude (L2 norm) per token per layer,
revealing how each token's representation evolves through the network.
All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent).
"""
import html

import numpy as np
import plotly.graph_objects as go

from extraction import ExtractionResult
# ---------------------------------------------------------------------------
# Theme constants (project dark palette)
# ---------------------------------------------------------------------------
BG_COLOR = "#1a1a2e"  # plot-area background
PAPER_COLOR = BG_COLOR  # figure background: same shade as the plot area
TEXT_COLOR = "#e0e0e0"  # axis titles, tick labels and colorbar text
ACCENT_COLOR = "#e6b800"  # figure title
GRID_COLOR = "#2a2a4e"  # NOTE(review): not referenced in this chunk; presumably used elsewhere

# Purple-to-gold heatmap colorscale for activation intensity: near-black at 0,
# through purple and copper, to gold (ACCENT_COLOR at 0.85) and bright gold at 1.
# Plotly colorscales are [position, color] pairs with positions in [0, 1].
TOKEN_LAYER_COLORSCALE = [
    [position, shade]
    for position, shade in zip(
        (0.0, 0.1, 0.25, 0.4, 0.55, 0.7, 0.85, 1.0),
        (
            "#0d0d1a",
            "#1a1040",
            "#2d1b69",
            "#5e2d8e",
            "#8e4585",
            "#c46a3a",
            "#e6b800",
            "#ffd633",
        ),
    )
]
def _minmax_scale(values: np.ndarray, axis: "int | None") -> np.ndarray:
    """Min-max scale ``values`` into [0, 1] along ``axis`` (``None`` = whole array).

    A slice whose max equals its min has no spread to scale; it is mapped to
    0.0 so every cell of the result stays inside [0, 1] and one flat slice
    cannot stretch the shared color range.
    """
    vmin = values.min(axis=axis, keepdims=True)
    span = values.max(axis=axis, keepdims=True) - vmin
    flat = span <= 0
    # Divide by 1 on constant slices to avoid 0/0 warnings; they are zeroed next.
    scaled = (values - vmin) / np.where(flat, 1.0, span)
    return np.where(flat, 0.0, scaled)


# Reduction axis on the (layers, tokens) magnitude grid for each scaled mode:
# None = one min/max over the whole grid, 1 = per row (layer), 0 = per column (token).
_NORMALIZE_AXES = {"global": None, "per_layer": 1, "per_token": 0}


def create_token_layer_grid(
    result: ExtractionResult,
    normalize: str = "global",
) -> go.Figure:
    """Create a token × layer activation magnitude heatmap.

    Each cell is the L2 norm of one token's hidden state at one layer; row 0
    is the embedding output. Columns sit at integer positions and are
    labelled through tick text, so repeated tokens (e.g. two "the") keep
    separate columns instead of collapsing into one categorical slot.

    Args:
        result: Extraction output. ``result.hidden_states`` must be
            array-like of shape (num_layers + 1, seq_len, hidden_dim), and
            ``result.tokens`` must hold ``seq_len`` token strings.
        normalize: Normalization strategy:
            - "global": Scale to global min/max across all layers and tokens.
            - "per_layer": Normalize each row independently (highlights
              within-layer variation).
            - "per_token": Normalize each column independently (highlights
              depth evolution per token).
            - "none": Raw L2 norms.
            In the scaled modes, a grid/row/column with no spread maps to 0.

    Returns:
        Plotly Figure with interactive heatmap.

    Raises:
        ValueError: If ``normalize`` is not one of the modes above, if the
            hidden states are not 3-D, or if their token axis length differs
            from ``len(result.tokens)``.
    """
    if normalize != "none" and normalize not in _NORMALIZE_AXES:
        raise ValueError(
            "normalize must be one of 'global', 'per_layer', 'per_token', "
            f"'none'; got {normalize!r}"
        )

    # hidden_states: (num_layers+1, seq_len, hidden_dim)
    hs = np.asarray(result.hidden_states)
    if hs.ndim != 3:
        raise ValueError(
            "hidden_states must have shape (num_layers+1, seq_len, hidden_dim); "
            f"got shape {hs.shape}"
        )
    # Upcast half precision: the sum of squares inside the norm overflows
    # float16 (max ~65504) for norms above ~256. No copy if already >= float32.
    hs = hs.astype(np.promote_types(hs.dtype, np.float32), copy=False)

    tokens = list(result.tokens)
    num_layers_total, seq_len = hs.shape[0], hs.shape[1]  # layer count includes embedding
    if len(tokens) != seq_len:
        raise ValueError(
            f"got {len(tokens)} tokens but hidden_states has {seq_len} positions"
        )

    # L2 norm per token per layer -> (num_layers+1, seq_len)
    magnitudes = np.linalg.norm(hs, axis=-1)

    # Min/max of an empty grid is undefined, so empty input is left as-is.
    if normalize == "none" or magnitudes.size == 0:
        display = magnitudes.copy()
    else:
        display = _minmax_scale(magnitudes, _NORMALIZE_AXES[normalize])

    # Tokens such as "<s>" would otherwise be parsed as Plotly text markup;
    # escape &, < and > only (quote=False), which Plotly decodes back for display.
    x_labels = [html.escape(str(tok)[:12], quote=False) for tok in tokens]
    safe_tokens = [html.escape(str(tok), quote=False) for tok in tokens]
    x_positions = list(range(seq_len))

    # Row labels derive from the data itself so they always match the row count.
    y_labels = ["Embed" if i == 0 else f"L{i - 1}" for i in range(num_layers_total)]
    layer_names = [
        "Embedding" if i == 0 else f"Layer {i - 1}" for i in range(num_layers_total)
    ]

    # Hover text shows the raw norm alongside the plotted (possibly scaled) value.
    hover = [
        [
            f"Token: {tok}<br>"
            f"{layer_name}<br>"
            f"L2 Norm: {raw:.2f}<br>"
            f"Normalized: {shown:.3f}"
            for tok, raw, shown in zip(safe_tokens, raw_row, shown_row)
        ]
        for layer_name, raw_row, shown_row in zip(layer_names, magnitudes, display)
    ]

    fig = go.Figure(
        data=go.Heatmap(
            z=display,
            x=x_positions,
            y=y_labels,
            text=hover,
            hoverinfo="text",
            colorscale=TOKEN_LAYER_COLORSCALE,
            colorbar=dict(
                title=dict(
                    text="Activation" if normalize == "none" else "Norm. Activation",
                    font=dict(color=TEXT_COLOR),
                ),
                tickfont=dict(color=TEXT_COLOR),
            ),
        )
    )
    fig.update_layout(
        title=dict(
            text=f"Token × Layer Activation Grid (norm: {normalize})",
            font=dict(color=ACCENT_COLOR, size=14),
        ),
        xaxis=dict(
            title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=9),
            side="top",
            tickangle=45,
            # One tick per column at its integer position, labelled with the token.
            tickmode="array",
            tickvals=x_positions,
            ticktext=x_labels,
        ),
        yaxis=dict(
            title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=8),
            autorange="reversed",  # embedding row at the top
        ),
        paper_bgcolor=PAPER_COLOR,
        plot_bgcolor=BG_COLOR,
        margin=dict(l=60, r=30, t=80, b=30),
        height=520,
    )
    return fig
|