File size: 4,572 Bytes
17c5f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
NeuroScope — Token-Layer Activation Grid

Heatmap with tokens as columns and layers as rows.
Color encodes activation magnitude (L2 norm) per token per layer,
revealing how each token's representation evolves through the network.

All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent).
"""

import numpy as np
import plotly.graph_objects as go
from extraction import ExtractionResult

# ---------------------------------------------------------------------------
# Theme constants
# ---------------------------------------------------------------------------
# Dark-theme palette shared by the figure layout below.
BG_COLOR = "#1a1a2e"
PAPER_COLOR = BG_COLOR  # paper and plot area use the same background
TEXT_COLOR = "#e0e0e0"
ACCENT_COLOR = "#e6b800"
GRID_COLOR = "#2a2a4e"

# Custom purple-to-gold heatmap colorscale for activation intensity.
# Each entry is a [position, hex color] stop with position in [0, 1];
# stops are listed in ascending position order.
TOKEN_LAYER_COLORSCALE = [
    [position, color]
    for position, color in (
        (0.0, "#0d0d1a"),
        (0.1, "#1a1040"),
        (0.25, "#2d1b69"),
        (0.4, "#5e2d8e"),
        (0.55, "#8e4585"),
        (0.7, "#c46a3a"),
        (0.85, "#e6b800"),
        (1.0, "#ffd633"),
    )
]


def _min_max_scale(values, axis=None):
    """Min-max scale ``values`` to [0, 1] along ``axis``.

    Args:
        values: 2-D array of magnitudes.
        axis: Axis to reduce over when computing min/max; ``None`` uses the
            min/max of the whole array.

    Returns:
        A new array of the same shape. Slices whose max does not exceed
        their min (constant slices, or slices containing NaN) are returned
        unscaled rather than divided by zero.
    """
    lo = values.min(axis=axis, keepdims=True)
    hi = values.max(axis=axis, keepdims=True)
    scalable = hi > lo
    # Use a dummy divisor of 1 for degenerate slices so the division below
    # never divides by zero; those results are discarded by the final where().
    safe_span = np.where(scalable, hi - lo, 1.0)
    return np.where(scalable, (values - lo) / safe_span, values)


def create_token_layer_grid(
    result: ExtractionResult,
    normalize: str = "global",
) -> go.Figure:
    """Create a token × layer activation magnitude heatmap.

    Args:
        result: Extraction output. ``result.hidden_states`` is expected to be
            shaped (num_layers+1, seq_len, hidden_dim) with the embedding
            layer first, and ``result.tokens`` to hold one string per
            sequence position.
        normalize: Normalization strategy:
            - "global": Scale to global min/max across all layers and tokens.
            - "per_layer": Normalize each row independently (highlights
              within-layer variation).
            - "per_token": Normalize each column independently (highlights
              depth evolution per token).
            - "none": Raw L2 norms.
            Any other value also plots raw L2 norms.

    Returns:
        Plotly Figure with interactive heatmap.

    Raises:
        ValueError: If the hidden states do not reduce to a 2-D
            (layer, token) grid, or if the number of tokens does not match
            the sequence axis of the hidden states.
    """
    tokens = list(result.tokens)

    # Compute L2 norm per token per layer → (num_layers+1, seq_len)
    magnitudes = np.linalg.norm(result.hidden_states, axis=-1)
    if magnitudes.ndim != 2:
        raise ValueError(
            "hidden_states must be 3-D (layers, tokens, hidden_dim); "
            f"got {magnitudes.ndim + 1} dimension(s)"
        )
    num_layers_total, seq_len = magnitudes.shape  # rows include embedding layer
    if seq_len != len(tokens):
        raise ValueError(
            f"hidden_states has {seq_len} token position(s) but "
            f"{len(tokens)} token(s) were provided"
        )

    # Apply normalization. Each mode maps to the axis reduced for min/max:
    # per_layer reduces across tokens (axis 1), per_token across layers (axis 0).
    reduce_axis_by_mode = {"global": None, "per_layer": 1, "per_token": 0}
    is_normalized = normalize in reduce_axis_by_mode
    if is_normalized and magnitudes.size:
        display = _min_max_scale(magnitudes, reduce_axis_by_mode[normalize])
    else:
        # "none", an unrecognised mode, or an empty grid: show raw values.
        display = magnitudes

    # Build axis labels. Row labels are derived from the data itself so they
    # always match the number of heatmap rows.
    x_labels = [t[:12] for t in tokens]
    y_labels = ["Embed"] + [f"L{i}" for i in range(num_layers_total - 1)]

    # Tokens are plotted at integer positions and labelled through ticktext.
    # Passing the token strings directly as x values would make the axis
    # categorical, and repeated (or identically truncated) tokens would be
    # merged into a single column.
    x_positions = list(range(seq_len))

    # Build hover text with raw values
    hover = []
    for i in range(num_layers_total):
        layer_name = "Embedding" if i == 0 else f"Layer {i - 1}"
        hover.append(
            [
                f"Token: {token}<br>"
                f"{layer_name}<br>"
                f"L2 Norm: {magnitudes[i, j]:.2f}<br>"
                f"Normalized: {display[i, j]:.3f}"
                for j, token in enumerate(tokens)
            ]
        )

    fig = go.Figure(
        data=go.Heatmap(
            z=display,
            x=x_positions,
            y=y_labels,
            text=hover,
            hoverinfo="text",
            colorscale=TOKEN_LAYER_COLORSCALE,
            colorbar=dict(
                title=dict(
                    text="Norm. Activation" if is_normalized else "Activation",
                    font=dict(color=TEXT_COLOR),
                ),
                tickfont=dict(color=TEXT_COLOR),
            ),
        )
    )

    fig.update_layout(
        title=dict(
            text=f"Token × Layer Activation Grid (norm: {normalize})",
            font=dict(color=ACCENT_COLOR, size=14),
        ),
        xaxis=dict(
            title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=9),
            side="top",
            tickangle=45,
            tickmode="array",
            tickvals=x_positions,
            ticktext=x_labels,
        ),
        yaxis=dict(
            title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=8),
            # Reversed so the embedding row is drawn at the top.
            autorange="reversed",
        ),
        paper_bgcolor=PAPER_COLOR,
        plot_bgcolor=BG_COLOR,
        margin=dict(l=60, r=30, t=80, b=30),
        height=520,
    )

    return fig