Alogotron commited on
Commit
17c5f1d
·
verified ·
1 Parent(s): 09d20cb

Upload viz_token_layer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. viz_token_layer.py +141 -0
viz_token_layer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NeuroScope — Token-Layer Activation Grid
3
+
4
+ Heatmap with tokens as columns and layers as rows.
5
+ Color encodes activation magnitude (L2 norm) per token per layer,
6
+ revealing how each token's representation evolves through the network.
7
+
8
+ All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent).
9
+ """
10
+
11
+ import numpy as np
12
+ import plotly.graph_objects as go
13
+ from extraction import ExtractionResult
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Theme constants
17
+ # ---------------------------------------------------------------------------
18
+ BG_COLOR = "#1a1a2e"
19
+ PAPER_COLOR = "#1a1a2e"
20
+ TEXT_COLOR = "#e0e0e0"
21
+ ACCENT_COLOR = "#e6b800"
22
+ GRID_COLOR = "#2a2a4e"
23
+
24
+ # Custom purple-to-gold heatmap colorscale for activation intensity
25
+ TOKEN_LAYER_COLORSCALE = [
26
+ [0.0, "#0d0d1a"],
27
+ [0.1, "#1a1040"],
28
+ [0.25, "#2d1b69"],
29
+ [0.4, "#5e2d8e"],
30
+ [0.55, "#8e4585"],
31
+ [0.7, "#c46a3a"],
32
+ [0.85, "#e6b800"],
33
+ [1.0, "#ffd633"],
34
+ ]
35
+
36
+
37
+ def create_token_layer_grid(
38
+ result: ExtractionResult,
39
+ normalize: str = "global",
40
+ ) -> go.Figure:
41
+ """Create a token × layer activation magnitude heatmap.
42
+
43
+ Args:
44
+ result: Extraction output containing hidden states.
45
+ normalize: Normalization strategy:
46
+ - "global": Scale to global min/max across all layers and tokens.
47
+ - "per_layer": Normalize each row independently (highlights
48
+ within-layer variation).
49
+ - "per_token": Normalize each column independently (highlights
50
+ depth evolution per token).
51
+ - "none": Raw L2 norms.
52
+
53
+ Returns:
54
+ Plotly Figure with interactive heatmap.
55
+ """
56
+ # hidden_states: (num_layers+1, seq_len, hidden_dim)
57
+ hs = result.hidden_states
58
+ tokens = result.tokens
59
+ num_layers_total = hs.shape[0] # includes embedding layer
60
+ seq_len = len(tokens)
61
+
62
+ # Compute L2 norm per token per layer → (num_layers+1, seq_len)
63
+ magnitudes = np.linalg.norm(hs, axis=-1)
64
+
65
+ # Apply normalization
66
+ display = magnitudes.copy()
67
+ if normalize == "global":
68
+ vmin, vmax = display.min(), display.max()
69
+ if vmax > vmin:
70
+ display = (display - vmin) / (vmax - vmin)
71
+ elif normalize == "per_layer":
72
+ for i in range(num_layers_total):
73
+ row = display[i]
74
+ rmin, rmax = row.min(), row.max()
75
+ if rmax > rmin:
76
+ display[i] = (row - rmin) / (rmax - rmin)
77
+ elif normalize == "per_token":
78
+ for j in range(seq_len):
79
+ col = display[:, j]
80
+ cmin, cmax = col.min(), col.max()
81
+ if cmax > cmin:
82
+ display[:, j] = (col - cmin) / (cmax - cmin)
83
+ # else: "none" — use raw values
84
+
85
+ # Build axis labels
86
+ x_labels = [t[:12] for t in tokens]
87
+ y_labels = ["Embed"] + [f"L{i}" for i in range(result.num_layers)]
88
+
89
+ # Build hover text with raw values
90
+ hover = np.empty((num_layers_total, seq_len), dtype=object)
91
+ for i in range(num_layers_total):
92
+ layer_name = "Embedding" if i == 0 else f"Layer {i - 1}"
93
+ for j in range(seq_len):
94
+ hover[i, j] = (
95
+ f"Token: {tokens[j]}<br>"
96
+ f"{layer_name}<br>"
97
+ f"L2 Norm: {magnitudes[i, j]:.2f}<br>"
98
+ f"Normalized: {display[i, j]:.3f}"
99
+ )
100
+
101
+ fig = go.Figure(
102
+ data=go.Heatmap(
103
+ z=display,
104
+ x=x_labels,
105
+ y=y_labels,
106
+ text=hover,
107
+ hoverinfo="text",
108
+ colorscale=TOKEN_LAYER_COLORSCALE,
109
+ colorbar=dict(
110
+ title=dict(
111
+ text="Activation" if normalize == "none" else "Norm. Activation",
112
+ font=dict(color=TEXT_COLOR),
113
+ ),
114
+ tickfont=dict(color=TEXT_COLOR),
115
+ ),
116
+ )
117
+ )
118
+
119
+ fig.update_layout(
120
+ title=dict(
121
+ text=f"Token × Layer Activation Grid (norm: {normalize})",
122
+ font=dict(color=ACCENT_COLOR, size=14),
123
+ ),
124
+ xaxis=dict(
125
+ title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)),
126
+ tickfont=dict(color=TEXT_COLOR, size=9),
127
+ side="top",
128
+ tickangle=45,
129
+ ),
130
+ yaxis=dict(
131
+ title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)),
132
+ tickfont=dict(color=TEXT_COLOR, size=8),
133
+ autorange="reversed",
134
+ ),
135
+ paper_bgcolor=PAPER_COLOR,
136
+ plot_bgcolor=BG_COLOR,
137
+ margin=dict(l=60, r=30, t=80, b=30),
138
+ height=520,
139
+ )
140
+
141
+ return fig