Alogotron commited on
Commit
80ae797
·
verified ·
1 Parent(s): 8244e01

Upload viz_comparison.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. viz_comparison.py +366 -0
viz_comparison.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NeuroScope — Comparison Mode Visualizations (Phase 2)
3
+
4
+ Side-by-side and overlay visualizations comparing activation patterns
5
+ between two different prompts. Uses joint normalization so the two
6
+ prompts are visually comparable.
7
+
8
+ All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent).
9
+ """
10
+
11
+ import numpy as np
12
+ import plotly.graph_objects as go
13
+ from plotly.subplots import make_subplots
14
+ from extraction import ExtractionResult
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Theme constants
18
+ # ---------------------------------------------------------------------------
19
+ BG_COLOR = "#1a1a2e"
20
+ PAPER_COLOR = "#1a1a2e"
21
+ TEXT_COLOR = "#e0e0e0"
22
+ ACCENT_COLOR = "#e6b800"
23
+ GRID_COLOR = "#2a2a4e"
24
+
25
+ COLOR_A = "#e6b800" # Gold for Prompt A
26
+ COLOR_B = "#4a90d9" # Blue for Prompt B
27
+
28
+ ATTN_COLORSCALE = [
29
+ [0.0, "#0d0d1a"],
30
+ [0.15, "#1a1a3e"],
31
+ [0.3, "#2e1a00"],
32
+ [0.5, "#7a5500"],
33
+ [0.7, "#b38600"],
34
+ [0.85, "#e6b800"],
35
+ [1.0, "#ffd633"],
36
+ ]
37
+
38
+ TOKEN_LAYER_COLORSCALE = [
39
+ [0.0, "#0d0d1a"],
40
+ [0.1, "#1a1040"],
41
+ [0.25, "#2d1b69"],
42
+ [0.4, "#5e2d8e"],
43
+ [0.55, "#8e4585"],
44
+ [0.7, "#c46a3a"],
45
+ [0.85, "#e6b800"],
46
+ [1.0, "#ffd633"],
47
+ ]
48
+
49
+ # Diverging colorscale for difference maps (blue-white-red)
50
+ DIFF_COLORSCALE = [
51
+ [0.0, "#2166ac"],
52
+ [0.25, "#67a9cf"],
53
+ [0.5, "#1a1a2e"],
54
+ [0.75, "#ef8a62"],
55
+ [1.0, "#b2182b"],
56
+ ]
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Attention comparison
61
+ # ---------------------------------------------------------------------------
62
+ def create_attention_comparison(
63
+ result_a: ExtractionResult,
64
+ result_b: ExtractionResult,
65
+ layer: int = 0,
66
+ head: str = "average",
67
+ ) -> go.Figure:
68
+ """Side-by-side attention heatmaps with difference overlay."""
69
+ layer = int(np.clip(layer, 0, min(result_a.num_layers, result_b.num_layers) - 1))
70
+
71
+ def _get_matrix(result, layer, head):
72
+ attn = result.attentions[layer]
73
+ if head == "average":
74
+ return attn.mean(axis=0)
75
+ elif head == "max":
76
+ return attn.max(axis=0)
77
+ else:
78
+ h = int(np.clip(int(head), 0, result.num_heads - 1))
79
+ return attn[h]
80
+
81
+ mat_a = _get_matrix(result_a, layer, head)
82
+ mat_b = _get_matrix(result_b, layer, head)
83
+ labels_a = [t[:10] for t in result_a.tokens]
84
+ labels_b = [t[:10] for t in result_b.tokens]
85
+
86
+ head_label = f"avg" if head == "average" else f"max" if head == "max" else f"h{head}"
87
+
88
+ fig = make_subplots(
89
+ rows=1, cols=2,
90
+ subplot_titles=[f"Prompt A — L{layer} {head_label}", f"Prompt B — L{layer} {head_label}"],
91
+ horizontal_spacing=0.08,
92
+ )
93
+
94
+ # Joint scale
95
+ vmax = max(float(mat_a.max()), float(mat_b.max()))
96
+ if vmax == 0:
97
+ vmax = 1.0
98
+
99
+ def _hover(tokens, mat):
100
+ sl = len(tokens)
101
+ h = np.empty((sl, sl), dtype=object)
102
+ for i in range(sl):
103
+ for j in range(sl):
104
+ h[i, j] = f"From: {tokens[i]}<br>To: {tokens[j]}<br>Wt: {mat[i, j]:.4f}"
105
+ return h
106
+
107
+ fig.add_trace(
108
+ go.Heatmap(
109
+ z=mat_a, x=labels_a, y=labels_a,
110
+ text=_hover(result_a.tokens, mat_a), hoverinfo="text",
111
+ colorscale=ATTN_COLORSCALE, zmin=0, zmax=vmax, showscale=False,
112
+ ), row=1, col=1,
113
+ )
114
+ fig.add_trace(
115
+ go.Heatmap(
116
+ z=mat_b, x=labels_b, y=labels_b,
117
+ text=_hover(result_b.tokens, mat_b), hoverinfo="text",
118
+ colorscale=ATTN_COLORSCALE, zmin=0, zmax=vmax,
119
+ colorbar=dict(
120
+ title=dict(text="Attn", font=dict(color=TEXT_COLOR)),
121
+ tickfont=dict(color=TEXT_COLOR),
122
+ ),
123
+ ), row=1, col=2,
124
+ )
125
+
126
+ fig.update_layout(
127
+ paper_bgcolor=PAPER_COLOR, plot_bgcolor=BG_COLOR,
128
+ height=480, margin=dict(l=60, r=30, t=50, b=60),
129
+ )
130
+ for col in (1, 2):
131
+ fig.update_xaxes(tickfont=dict(color=TEXT_COLOR, size=8), tickangle=45, row=1, col=col)
132
+ fig.update_yaxes(
133
+ tickfont=dict(color=TEXT_COLOR, size=8), autorange="reversed",
134
+ row=1, col=col,
135
+ )
136
+ for ann in fig.layout.annotations:
137
+ ann.font = dict(color=ACCENT_COLOR, size=12)
138
+ return fig
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Magnitude comparison
143
+ # ---------------------------------------------------------------------------
144
+ def create_magnitude_comparison(
145
+ result_a: ExtractionResult,
146
+ result_b: ExtractionResult,
147
+ metric: str = "mean_l2",
148
+ ) -> go.Figure:
149
+ """Overlaid magnitude bar chart for two prompts."""
150
+
151
+ def _compute_mag(hs, metric):
152
+ n = hs.shape[0]
153
+ if metric == "mean_l2":
154
+ return np.array([np.linalg.norm(hs[i], axis=-1).mean() for i in range(n)])
155
+ elif metric == "max_l2":
156
+ return np.array([np.linalg.norm(hs[i], axis=-1).max() for i in range(n)])
157
+ else: # mean_abs
158
+ return np.array([np.abs(hs[i]).mean() for i in range(n)])
159
+
160
+ mag_a = _compute_mag(result_a.hidden_states, metric)
161
+ mag_b = _compute_mag(result_b.hidden_states, metric)
162
+ labels = ["Emb"] + [f"L{i}" for i in range(result_a.num_layers)]
163
+
164
+ fig = go.Figure()
165
+
166
+ fig.add_trace(go.Bar(
167
+ x=labels, y=mag_a,
168
+ name="Prompt A", marker_color=COLOR_A, opacity=0.8,
169
+ hovertext=[f"Prompt A<br>{labels[i]}<br>{metric}: {mag_a[i]:.2f}" for i in range(len(labels))],
170
+ hoverinfo="text",
171
+ ))
172
+ fig.add_trace(go.Bar(
173
+ x=labels, y=mag_b,
174
+ name="Prompt B", marker_color=COLOR_B, opacity=0.8,
175
+ hovertext=[f"Prompt B<br>{labels[i]}<br>{metric}: {mag_b[i]:.2f}" for i in range(len(labels))],
176
+ hoverinfo="text",
177
+ ))
178
+
179
+ # Difference line
180
+ diff = mag_a[:len(mag_b)] - mag_b[:len(mag_a)]
181
+ min_len = min(len(mag_a), len(mag_b))
182
+ fig.add_trace(go.Scatter(
183
+ x=labels[:min_len], y=diff[:min_len],
184
+ name="Δ (A − B)", mode="lines+markers",
185
+ line=dict(color="#e05050", width=2, dash="dot"),
186
+ marker=dict(size=4, color="#e05050"),
187
+ yaxis="y2",
188
+ hovertext=[f"Δ at {labels[i]}: {diff[i]:+.2f}" for i in range(min_len)],
189
+ hoverinfo="text",
190
+ ))
191
+
192
+ fig.update_layout(
193
+ title=dict(text=f"Activation Magnitude Comparison ({metric})", font=dict(color=ACCENT_COLOR, size=14)),
194
+ barmode="group",
195
+ paper_bgcolor=PAPER_COLOR, plot_bgcolor=BG_COLOR,
196
+ height=480, margin=dict(l=60, r=60, t=50, b=60),
197
+ xaxis=dict(tickfont=dict(color=TEXT_COLOR, size=8), gridcolor=GRID_COLOR, tickangle=45),
198
+ yaxis=dict(
199
+ title=dict(text=metric, font=dict(color=TEXT_COLOR, size=11)),
200
+ tickfont=dict(color=TEXT_COLOR, size=9), gridcolor=GRID_COLOR,
201
+ ),
202
+ yaxis2=dict(
203
+ title=dict(text="Δ", font=dict(color="#e05050", size=11)),
204
+ tickfont=dict(color="#e05050", size=9),
205
+ overlaying="y", side="right", zeroline=True,
206
+ zerolinecolor="rgba(224,80,80,0.3)",
207
+ ),
208
+ legend=dict(font=dict(color=TEXT_COLOR, size=10), bgcolor="rgba(26,26,46,0.8)"),
209
+ bargap=0.15,
210
+ )
211
+ return fig
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # Token-Layer grid comparison
216
+ # ---------------------------------------------------------------------------
217
+ def create_token_layer_comparison(
218
+ result_a: ExtractionResult,
219
+ result_b: ExtractionResult,
220
+ normalize: str = "global",
221
+ ) -> go.Figure:
222
+ """Side-by-side token-layer activation grids."""
223
+ norms_a = np.linalg.norm(result_a.hidden_states, axis=-1)
224
+ norms_b = np.linalg.norm(result_b.hidden_states, axis=-1)
225
+
226
+ # Joint normalization
227
+ if normalize == "global":
228
+ vmin = min(float(norms_a.min()), float(norms_b.min()))
229
+ vmax = max(float(norms_a.max()), float(norms_b.max()))
230
+ if vmax > vmin:
231
+ disp_a = (norms_a - vmin) / (vmax - vmin)
232
+ disp_b = (norms_b - vmin) / (vmax - vmin)
233
+ else:
234
+ disp_a, disp_b = norms_a * 0, norms_b * 0
235
+ else:
236
+ disp_a, disp_b = norms_a.copy(), norms_b.copy()
237
+
238
+ labels_a = [t[:10] for t in result_a.tokens]
239
+ labels_b = [t[:10] for t in result_b.tokens]
240
+ y_labels = ["Emb"] + [f"L{i}" for i in range(result_a.num_layers)]
241
+
242
+ fig = make_subplots(
243
+ rows=1, cols=2,
244
+ subplot_titles=["Prompt A — Token×Layer", "Prompt B — Token×Layer"],
245
+ horizontal_spacing=0.08,
246
+ )
247
+
248
+ def _hover(tokens, norms):
249
+ h = np.empty(norms.shape, dtype=object)
250
+ for i in range(norms.shape[0]):
251
+ lname = "Embedding" if i == 0 else f"Layer {i-1}"
252
+ for j in range(len(tokens)):
253
+ h[i, j] = f"{tokens[j]}<br>{lname}<br>L2: {norms[i, j]:.2f}"
254
+ return h
255
+
256
+ fig.add_trace(
257
+ go.Heatmap(
258
+ z=disp_a, x=labels_a, y=y_labels,
259
+ text=_hover(result_a.tokens, norms_a), hoverinfo="text",
260
+ colorscale=TOKEN_LAYER_COLORSCALE, showscale=False,
261
+ ), row=1, col=1,
262
+ )
263
+ fig.add_trace(
264
+ go.Heatmap(
265
+ z=disp_b, x=labels_b, y=y_labels,
266
+ text=_hover(result_b.tokens, norms_b), hoverinfo="text",
267
+ colorscale=TOKEN_LAYER_COLORSCALE, showscale=True,
268
+ colorbar=dict(
269
+ title=dict(text="Norm", font=dict(color=TEXT_COLOR)),
270
+ tickfont=dict(color=TEXT_COLOR),
271
+ ),
272
+ ), row=1, col=2,
273
+ )
274
+
275
+ fig.update_layout(
276
+ paper_bgcolor=PAPER_COLOR, plot_bgcolor=BG_COLOR,
277
+ height=520, margin=dict(l=60, r=30, t=50, b=30),
278
+ )
279
+ for col in (1, 2):
280
+ fig.update_xaxes(tickfont=dict(color=TEXT_COLOR, size=8), tickangle=45, side="top", row=1, col=col)
281
+ fig.update_yaxes(tickfont=dict(color=TEXT_COLOR, size=7), autorange="reversed", row=1, col=col)
282
+ for ann in fig.layout.annotations:
283
+ ann.font = dict(color=ACCENT_COLOR, size=12)
284
+ return fig
285
+
286
+
287
+ # ---------------------------------------------------------------------------
288
+ # Scatter comparison (overlay both prompts)
289
+ # ---------------------------------------------------------------------------
290
+ def create_scatter_comparison(
291
+ result_a: ExtractionResult,
292
+ result_b: ExtractionResult,
293
+ layer: int = 18,
294
+ method: str = "pca",
295
+ ) -> go.Figure:
296
+ """Overlay scatter plot with both prompts' tokens in same reduced space."""
297
+ from viz_scatter import _run_pca, _run_umap
298
+
299
+ reduce_fn = _run_umap if method == "umap" else _run_pca
300
+ hs_idx = int(np.clip(layer + 1, 0, result_a.hidden_states.shape[0] - 1))
301
+
302
+ data_a = result_a.hidden_states[hs_idx].astype(np.float64)
303
+ data_b = result_b.hidden_states[
304
+ int(np.clip(layer + 1, 0, result_b.hidden_states.shape[0] - 1))
305
+ ].astype(np.float64)
306
+
307
+ # Joint reduction for fair comparison
308
+ stacked = np.vstack([data_a, data_b])
309
+ reduced = reduce_fn(stacked)
310
+ n_a = len(result_a.tokens)
311
+ red_a = reduced[:n_a]
312
+ red_b = reduced[n_a:]
313
+
314
+ fig = go.Figure()
315
+
316
+ # Prompt A tokens
317
+ fig.add_trace(go.Scatter(
318
+ x=red_a[:, 0], y=red_a[:, 1],
319
+ mode="markers+text",
320
+ marker=dict(size=12, color=COLOR_A, opacity=0.9, line=dict(width=1, color="white")),
321
+ text=[t[:8] for t in result_a.tokens],
322
+ textposition="top center",
323
+ textfont=dict(color=COLOR_A, size=9),
324
+ name="Prompt A",
325
+ hovertext=[
326
+ f"A: {result_a.tokens[j]}<br>Pos: {j}<br>x: {red_a[j,0]:.3f}, y: {red_a[j,1]:.3f}"
327
+ for j in range(n_a)
328
+ ],
329
+ hoverinfo="text",
330
+ ))
331
+
332
+ # Prompt B tokens
333
+ fig.add_trace(go.Scatter(
334
+ x=red_b[:, 0], y=red_b[:, 1],
335
+ mode="markers+text",
336
+ marker=dict(size=12, color=COLOR_B, opacity=0.9, symbol="diamond",
337
+ line=dict(width=1, color="white")),
338
+ text=[t[:8] for t in result_b.tokens],
339
+ textposition="bottom center",
340
+ textfont=dict(color=COLOR_B, size=9),
341
+ name="Prompt B",
342
+ hovertext=[
343
+ f"B: {result_b.tokens[j]}<br>Pos: {j}<br>x: {red_b[j,0]:.3f}, y: {red_b[j,1]:.3f}"
344
+ for j in range(len(result_b.tokens))
345
+ ],
346
+ hoverinfo="text",
347
+ ))
348
+
349
+ fig.update_layout(
350
+ title=dict(
351
+ text=f"Token Space Comparison ({method.upper()}) — Layer {layer}",
352
+ font=dict(color=ACCENT_COLOR, size=14),
353
+ ),
354
+ xaxis=dict(
355
+ title=dict(text=f"{method.upper()} 1", font=dict(color=TEXT_COLOR, size=11)),
356
+ tickfont=dict(color=TEXT_COLOR, size=9), gridcolor=GRID_COLOR, zeroline=False,
357
+ ),
358
+ yaxis=dict(
359
+ title=dict(text=f"{method.upper()} 2", font=dict(color=TEXT_COLOR, size=11)),
360
+ tickfont=dict(color=TEXT_COLOR, size=9), gridcolor=GRID_COLOR, zeroline=False,
361
+ ),
362
+ paper_bgcolor=PAPER_COLOR, plot_bgcolor=BG_COLOR,
363
+ height=480, margin=dict(l=60, r=30, t=50, b=50),
364
+ legend=dict(font=dict(color=TEXT_COLOR, size=10), bgcolor="rgba(26,26,46,0.8)"),
365
+ )
366
+ return fig