ryandt commited on
Commit
2b523e0
·
1 Parent(s): 57e2155

First push

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +625 -0
  3. model.py +220 -0
  4. requirements.txt +26 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ __pycache__
app.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logit Lens Explorer - Gradio Application.
3
+
4
+ Interactive text generation tool that surfaces the logit lens for each
5
+ generated token. Users input a prompt, the model generates text, and
6
+ clicking any token reveals what the model was predicting at each
7
+ intermediate layer.
8
+
9
+ Part of E02: Logit Lens Explorer.
10
+ """
11
+
12
+ import html as html_lib
13
+ import json
14
+ from typing import Generator
15
+
16
+ import gradio as gr
17
+
18
+ try:
19
+ import spaces
20
+ SPACES_AVAILABLE = True
21
+ except ImportError:
22
+ SPACES_AVAILABLE = False
23
+
24
+ from model import generate_with_logit_lens, load_model, TokenData
25
+
26
+
27
+ def gpu_decorator(duration: int = 120):
28
+ """Return @spaces.GPU decorator if available, otherwise a no-op."""
29
+ if SPACES_AVAILABLE:
30
+ return spaces.GPU(duration=duration)
31
+ return lambda fn: fn
32
+
33
+
34
+ def build_token_html(tokens: list[TokenData]) -> str:
35
+ """Build HTML output from accumulated tokens as plain clickable spans.
36
+
37
+ Each token span carries all data needed for client-side logit lens
38
+ rendering: the token text, probability, and per-layer predictions
39
+ as JSON data attributes.
40
+
41
+ Args:
42
+ tokens: List of TokenData objects.
43
+
44
+ Returns:
45
+ HTML string with clickable token spans.
46
+ """
47
+ font_family = "'Cascadia Code', 'Fira Code', Consolas, monospace"
48
+ style_tag = "<style>.token-span:hover { text-decoration: underline !important; }</style>"
49
+
50
+ if not tokens:
51
+ return (
52
+ f'{style_tag}<div class="token-container" '
53
+ f'style="font-family: {font_family}; line-height: 1.8; padding: 10px;"></div>'
54
+ )
55
+
56
+ spans = []
57
+ for i, token_data in enumerate(tokens):
58
+ token_text = html_lib.escape(token_data.token)
59
+ if "\n" in token_text:
60
+ token_text = token_text.replace("\n", "<br>")
61
+ spans.append(token_text)
62
+ else:
63
+ # Serialize layer predictions as JSON for client-side rendering
64
+ layers_json = html_lib.escape(json.dumps([
65
+ {
66
+ "layer_index": lp.layer_index,
67
+ "top_tokens": lp.top_tokens,
68
+ }
69
+ for lp in token_data.layer_predictions
70
+ ]))
71
+
72
+ span = (
73
+ f'<span class="token-span" data-token-index="{i}"'
74
+ f' data-token="{html_lib.escape(token_data.token)}"'
75
+ f' data-prob="{token_data.probability}"'
76
+ f' data-layers="{layers_json}"'
77
+ f' style="cursor: pointer;">{token_text}</span>'
78
+ )
79
+ spans.append(span)
80
+
81
+ html_content = "".join(spans)
82
+ return (
83
+ f'{style_tag}<div class="token-container" style="font-family: {font_family};'
84
+ f' line-height: 1.8; padding: 10px; white-space: pre-wrap;">{html_content}</div>'
85
+ )
86
+
87
+
88
+ @gpu_decorator(duration=120)
89
+ def run_inference(prompt: str) -> list[TokenData]:
90
+ """Run full text generation on GPU and return all tokens.
91
+
92
+ On HuggingFace Spaces with ZeroGPU, this function is decorated with
93
+ @spaces.GPU to allocate GPU resources for the duration of inference.
94
+
95
+ Args:
96
+ prompt: User prompt text.
97
+
98
+ Returns:
99
+ List of TokenData with token strings, IDs, probabilities,
100
+ and per-layer logit lens predictions.
101
+ """
102
+ return list(generate_with_logit_lens(prompt))
103
+
104
+
105
+ def generate_streaming(prompt: str) -> Generator[str, None, None]:
106
+ """Stream token generation with progressive HTML output.
107
+
108
+ Runs full inference first (GPU-bound), then streams HTML rendering
109
+ from pre-computed tokens (no GPU needed). This architecture is
110
+ required for HuggingFace ZeroGPU compatibility.
111
+
112
+ Args:
113
+ prompt: User prompt text.
114
+
115
+ Yields:
116
+ HTML string with accumulated tokens.
117
+ """
118
+ if not prompt or not prompt.strip():
119
+ yield '<div style="color: #666; padding: 10px;">Please enter a prompt.</div>'
120
+ return
121
+
122
+ # Show loading indicator during GPU inference
123
+ loading = """<div style="color: #60a5fa; padding: 10px; display: flex; align-items: center; gap: 10px;">
124
+ <div style="width: 20px; height: 20px; border: 2px solid #60a5fa;
125
+ border-top-color: transparent; border-radius: 50%;
126
+ animation: spin 1s linear infinite;"></div>
127
+ <style>@keyframes spin { to { transform: rotate(360deg); } }</style>
128
+ Generating...
129
+ </div>"""
130
+ yield loading
131
+
132
+ # Full inference (GPU allocated here on ZeroGPU)
133
+ tokens = run_inference(prompt)
134
+
135
+ if not tokens:
136
+ yield '<div style="color: #666; padding: 10px;">No tokens generated.</div>'
137
+ return
138
+
139
+ # Stream HTML rendering (no GPU needed)
140
+ accumulated: list[TokenData] = []
141
+ for token_data in tokens:
142
+ accumulated.append(token_data)
143
+ yield build_token_html(accumulated)
144
+
145
+
146
+ # JavaScript for token click handling -- reads layer data from span attributes
147
+ # and renders the logit lens panel entirely client-side (no server round-trip).
148
+ # Matches the pattern from the OCR app's alternatives panel.
149
+ TOKEN_CLICK_JS = """
150
+ (function() {
151
+ console.log('[logit-lens] Click handler installed');
152
+
153
+ var CARD_TOP_K = 5; // Show top 5 in each layer card
154
+ var CHART_TOP_N = 20; // Track top 20 most recurring tokens in chart
155
+
156
+ // 20 distinct colors for chart lines
157
+ var LINE_COLORS = [
158
+ '#60a5fa','#f87171','#34d399','#fbbf24','#a78bfa',
159
+ '#fb923c','#2dd4bf','#f472b6','#818cf8','#4ade80',
160
+ '#e879f9','#38bdf8','#facc15','#fb7185','#a3e635',
161
+ '#c084fc','#22d3ee','#fdba74','#86efac','#fca5a5'
162
+ ];
163
+
164
+ function escapeHtml(text) {
165
+ var div = document.createElement('div');
166
+ div.textContent = text;
167
+ return div.innerHTML;
168
+ }
169
+
170
+ function renderLayerCard(layer, finalToken, nLayers, layerIdx) {
171
+ var lastLayer = nLayers - 1;
172
+ var label;
173
+ if (layerIdx === 0) {
174
+ label = 'Layer 0 (embed)';
175
+ } else if (layerIdx === lastLayer) {
176
+ label = 'Layer ' + layerIdx + ' (final)';
177
+ } else {
178
+ label = 'Layer ' + layerIdx;
179
+ }
180
+
181
+ var tokenCells = '';
182
+ var displayCount = Math.min(layer.top_tokens.length, CARD_TOP_K);
183
+ for (var i = 0; i < displayCount; i++) {
184
+ var entry = layer.top_tokens[i];
185
+ var tok = escapeHtml(entry.token);
186
+ var pct = (entry.probability * 100);
187
+ var barWidth = Math.max(pct, 0.5);
188
+ var isMatch = entry.token === finalToken;
189
+ var tokColor = isMatch ? '#60a5fa' : '#e5e7eb';
190
+ var barColor = isMatch ? '#60a5fa' : '#4b5563';
191
+ var fontWeight = isMatch ? '700' : '400';
192
+
193
+ tokenCells +=
194
+ '<div style="display:flex;align-items:center;gap:6px;margin:2px 0;">' +
195
+ '<span style="width:80px;overflow:hidden;text-overflow:ellipsis;' +
196
+ 'white-space:nowrap;font-family:monospace;font-size:12px;' +
197
+ 'color:' + tokColor + ';font-weight:' + fontWeight + ';">' + tok + '</span>' +
198
+ '<span style="width:44px;text-align:right;color:#9ca3af;' +
199
+ 'font-size:11px;flex-shrink:0;">' + pct.toFixed(1) + '%</span>' +
200
+ '<div style="flex:1;height:8px;background:#1f2937;' +
201
+ 'border-radius:4px;overflow:hidden;min-width:30px;">' +
202
+ '<div style="width:' + barWidth + '%;height:100%;' +
203
+ 'background:' + barColor + ';border-radius:4px;"></div>' +
204
+ '</div></div>';
205
+ }
206
+
207
+ var cardBg = (layerIdx % 2 === 0) ? '#111827' : '#0d1117';
208
+ return '<div style="background:' + cardBg + ';border-radius:6px;padding:8px 10px;">' +
209
+ '<div style="color:#9ca3af;font-size:11px;font-family:monospace;' +
210
+ 'margin-bottom:4px;font-weight:600;">' + label + '</div>' +
211
+ tokenCells +
212
+ '</div>';
213
+ }
214
+
215
+ function renderLineChart(layersData, finalToken, nLayers) {
216
+ // Collect frequency counts: how many layers each token appears in
217
+ var tokenFreq = {}; // token -> count of layers it appears in
218
+ var tokenProbs = {}; // token -> array of {layer, prob}
219
+ for (var li = 0; li < nLayers; li++) {
220
+ var tops = layersData[li].top_tokens;
221
+ for (var ti = 0; ti < tops.length; ti++) {
222
+ var tok = tops[ti].token;
223
+ var prob = tops[ti].probability;
224
+ if (!tokenFreq[tok]) {
225
+ tokenFreq[tok] = 0;
226
+ tokenProbs[tok] = [];
227
+ }
228
+ tokenFreq[tok]++;
229
+ tokenProbs[tok].push({layer: li, prob: prob});
230
+ }
231
+ }
232
+
233
+ // Sort by frequency descending, take top N
234
+ var allTokens = Object.keys(tokenFreq);
235
+ allTokens.sort(function(a, b) { return tokenFreq[b] - tokenFreq[a]; });
236
+ var chartTokens = allTokens.slice(0, CHART_TOP_N);
237
+
238
+ // Ensure the final token is always included
239
+ if (chartTokens.indexOf(finalToken) === -1 && tokenFreq[finalToken]) {
240
+ chartTokens.pop();
241
+ chartTokens.push(finalToken);
242
+ }
243
+
244
+ console.log('[logit-lens] Chart tokens:', chartTokens.length, chartTokens);
245
+
246
+ if (chartTokens.length === 0) return null;
247
+
248
+ // Build lookup: token -> layer -> probability (0 if absent)
249
+ var data = {}; // token -> array of length nLayers
250
+ var maxProb = 0;
251
+ for (var ci = 0; ci < chartTokens.length; ci++) {
252
+ var t = chartTokens[ci];
253
+ data[t] = new Array(nLayers);
254
+ for (var l = 0; l < nLayers; l++) { data[t][l] = 0; }
255
+ var entries = tokenProbs[t];
256
+ for (var ei = 0; ei < entries.length; ei++) {
257
+ var p = entries[ei].prob * 100;
258
+ data[t][entries[ei].layer] = p;
259
+ if (p > maxProb) maxProb = p;
260
+ }
261
+ }
262
+
263
+ // Build color map for each token
264
+ var colorMap = {};
265
+ for (var ci = 0; ci < chartTokens.length; ci++) {
266
+ var tok = chartTokens[ci];
267
+ colorMap[tok] = (tok === finalToken) ? '#60a5fa' : LINE_COLORS[ci % LINE_COLORS.length];
268
+ }
269
+
270
+ // SVG dimensions
271
+ var W = 700, H = 300;
272
+ var padL = 45, padR = 20, padT = 20, padB = 30;
273
+ var plotW = W - padL - padR;
274
+ var plotH = H - padT - padB;
275
+ var yMax = Math.ceil(maxProb / 10) * 10;
276
+ if (yMax < 10) yMax = 10;
277
+
278
+ function xPos(layer) { return padL + (layer / (nLayers - 1)) * plotW; }
279
+ function yPos(pct) { return padT + plotH - (pct / yMax) * plotH; }
280
+
281
+ // Start SVG
282
+ var svg = '<svg class="logit-chart-svg" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 ' + W + ' ' + H +
283
+ '" style="width:100%;max-width:' + W + 'px;height:auto;background:#111827;border-radius:8px;display:block;">';
284
+
285
+ // Y-axis gridlines and labels
286
+ var yTicks = 5;
287
+ for (var yi = 0; yi <= yTicks; yi++) {
288
+ var yVal = (yMax / yTicks) * yi;
289
+ var y = yPos(yVal);
290
+ svg += '<line x1="' + padL + '" y1="' + y + '" x2="' + (W - padR) + '" y2="' + y +
291
+ '" stroke="#374151" stroke-width="1"/>';
292
+ svg += '<text x="' + (padL - 6) + '" y="' + (y + 4) +
293
+ '" text-anchor="end" fill="#9ca3af" font-size="10" font-family="monospace">' +
294
+ yVal.toFixed(0) + '%</text>';
295
+ }
296
+
297
+ // X-axis labels (every 4 layers + first and last)
298
+ for (var xi = 0; xi < nLayers; xi++) {
299
+ if (xi === 0 || xi === nLayers - 1 || xi % 4 === 0) {
300
+ var x = xPos(xi);
301
+ svg += '<text x="' + x + '" y="' + (H - 8) +
302
+ '" text-anchor="middle" fill="#9ca3af" font-size="10" font-family="monospace">' +
303
+ xi + '</text>';
304
+ }
305
+ }
306
+
307
+ // Draw lines for each token
308
+ for (var ci = 0; ci < chartTokens.length; ci++) {
309
+ var tok = chartTokens[ci];
310
+ var color = colorMap[tok];
311
+ var strokeW = (tok === finalToken) ? '2.5' : '1.5';
312
+ var opacity = (tok === finalToken) ? '1' : '0.7';
313
+
314
+ var points = '';
315
+ for (var l = 0; l < nLayers; l++) {
316
+ if (l > 0) points += ' ';
317
+ points += xPos(l).toFixed(1) + ',' + yPos(data[tok][l]).toFixed(1);
318
+ }
319
+ svg += '<polyline points="' + points + '" fill="none" stroke="' + color +
320
+ '" stroke-width="' + strokeW + '" opacity="' + opacity + '"/>';
321
+ }
322
+
323
+ // Invisible overlay rect to capture mouse events across the full plot area
324
+ svg += '<rect class="logit-chart-overlay" x="' + padL + '" y="' + padT +
325
+ '" width="' + plotW + '" height="' + plotH + '" fill="transparent" pointer-events="all" style="cursor:crosshair;"/>';
326
+
327
+ // Vertical crosshair line (hidden initially)
328
+ svg += '<line class="logit-chart-crosshair" x1="0" y1="' + padT + '" x2="0" y2="' + (padT + plotH) +
329
+ '" stroke="#9ca3af" stroke-width="1" stroke-dasharray="4,3" visibility="hidden"/>';
330
+
331
+ svg += '</svg>';
332
+
333
+ // Tooltip div (hidden, positioned absolutely over the chart)
334
+ var tooltip = '<div class="logit-chart-tooltip" style="' +
335
+ 'display:none;position:absolute;pointer-events:none;z-index:10;' +
336
+ 'background:#1e293b;border:1px solid #475569;border-radius:6px;padding:8px 10px;' +
337
+ 'font-family:monospace;font-size:11px;color:#e5e7eb;' +
338
+ 'box-shadow:0 4px 12px rgba(0,0,0,0.4);max-width:220px;' +
339
+ '"></div>';
340
+
341
+ // Legend (horizontal wrapping)
342
+ var legend = '<div style="display:flex;flex-wrap:wrap;gap:8px 14px;margin-top:8px;">';
343
+ for (var ci = 0; ci < chartTokens.length; ci++) {
344
+ var tok = chartTokens[ci];
345
+ var color = colorMap[tok];
346
+ var weight = (tok === finalToken) ? '700' : '400';
347
+ legend += '<div style="display:flex;align-items:center;gap:4px;">' +
348
+ '<div style="width:12px;height:3px;background:' + color + ';border-radius:2px;"></div>' +
349
+ '<span style="font-family:monospace;font-size:11px;color:' + color +
350
+ ';font-weight:' + weight + ';">' + escapeHtml(tok) + '</span>' +
351
+ '</div>';
352
+ }
353
+ legend += '</div>';
354
+
355
+ // Return HTML + metadata object (avoids DOM attribute serialization issues)
356
+ var chartMeta = {
357
+ tokens: chartTokens,
358
+ data: data,
359
+ colors: colorMap,
360
+ nLayers: nLayers,
361
+ padL: padL,
362
+ padR: padR,
363
+ padT: padT,
364
+ plotW: plotW,
365
+ plotH: plotH,
366
+ W: W,
367
+ finalToken: finalToken
368
+ };
369
+
370
+ var html = '<div class="logit-chart-wrapper" style="position:relative;margin-bottom:16px;">' +
371
+ '<div style="color:#9ca3af;font-size:11px;font-family:monospace;margin-bottom:6px;font-weight:600;">' +
372
+ 'Probability by Layer (top ' + chartTokens.length + ' recurring tokens)</div>' +
373
+ svg + tooltip + legend + '</div>';
374
+
375
+ return { html: html, meta: chartMeta };
376
+ }
377
+
378
+ function attachChartHover(meta) {
379
+ var panel = document.getElementById('logit-lens-panel');
380
+ if (!panel) { console.error('[logit-lens] hover: panel not found'); return; }
381
+
382
+ var wrapper = panel.querySelector('.logit-chart-wrapper');
383
+ if (!wrapper) { console.error('[logit-lens] hover: wrapper not found'); return; }
384
+
385
+ var svgEl = wrapper.querySelector('.logit-chart-svg');
386
+ var crosshair = wrapper.querySelector('.logit-chart-crosshair');
387
+ var tooltipEl = wrapper.querySelector('.logit-chart-tooltip');
388
+ if (!svgEl || !crosshair || !tooltipEl) { console.error('[logit-lens] hover: SVG elements not found', !!svgEl, !!crosshair, !!tooltipEl); return; }
389
+
390
+ // Sort tokens by probability descending at each layer for tooltip display
391
+ function getLayerEntries(layerIdx) {
392
+ var entries = [];
393
+ for (var i = 0; i < meta.tokens.length; i++) {
394
+ var tok = meta.tokens[i];
395
+ var pct = meta.data[tok][layerIdx];
396
+ if (pct > 0) {
397
+ entries.push({token: tok, pct: pct, color: meta.colors[tok]});
398
+ }
399
+ }
400
+ entries.sort(function(a, b) { return b.pct - a.pct; });
401
+ return entries;
402
+ }
403
+
404
+ function mouseToLayer(e) {
405
+ var rect = svgEl.getBoundingClientRect();
406
+ // Map pixel position to SVG viewBox coordinates
407
+ var scaleX = meta.W / rect.width;
408
+ var svgX = (e.clientX - rect.left) * scaleX;
409
+ // Convert SVG X to layer index
410
+ var layerFrac = (svgX - meta.padL) / meta.plotW;
411
+ var layer = Math.round(layerFrac * (meta.nLayers - 1));
412
+ return Math.max(0, Math.min(meta.nLayers - 1, layer));
413
+ }
414
+
415
+ function svgXForLayer(layer) {
416
+ return meta.padL + (layer / (meta.nLayers - 1)) * meta.plotW;
417
+ }
418
+
419
+ svgEl.addEventListener('mousemove', function(e) {
420
+ var layer = mouseToLayer(e);
421
+ var x = svgXForLayer(layer);
422
+
423
+ // Update crosshair position
424
+ crosshair.setAttribute('x1', x);
425
+ crosshair.setAttribute('x2', x);
426
+ crosshair.setAttribute('visibility', 'visible');
427
+
428
+ // Build tooltip content
429
+ var entries = getLayerEntries(layer);
430
+ var label = 'Layer ' + layer;
431
+ if (layer === 0) label += ' (embed)';
432
+ else if (layer === meta.nLayers - 1) label += ' (final)';
433
+
434
+ var html = '<div style="font-weight:600;margin-bottom:4px;color:#9ca3af;">' + label + '</div>';
435
+ for (var i = 0; i < entries.length; i++) {
436
+ var entry = entries[i];
437
+ var isFinal = (entry.token === meta.finalToken);
438
+ var w = isFinal ? '700' : '400';
439
+ html += '<div style="display:flex;align-items:center;gap:5px;margin:1px 0;">' +
440
+ '<div style="width:8px;height:8px;border-radius:50%;background:' + entry.color +
441
+ ';flex-shrink:0;"></div>' +
442
+ '<span style="color:' + entry.color + ';font-weight:' + w + ';overflow:hidden;' +
443
+ 'text-overflow:ellipsis;white-space:nowrap;max-width:120px;">' +
444
+ escapeHtml(entry.token) + '</span>' +
445
+ '<span style="color:#9ca3af;margin-left:auto;">' + entry.pct.toFixed(1) + '%</span></div>';
446
+ }
447
+ if (entries.length === 0) {
448
+ html += '<div style="color:#6b7280;font-style:italic;">No tracked tokens at this layer</div>';
449
+ }
450
+
451
+ tooltipEl.innerHTML = html;
452
+ tooltipEl.style.display = 'block';
453
+
454
+ // Position tooltip relative to wrapper
455
+ var wrapperRect = wrapper.getBoundingClientRect();
456
+ var svgRect = svgEl.getBoundingClientRect();
457
+ var pixelX = (x / meta.W) * svgRect.width + svgRect.left - wrapperRect.left;
458
+ var tooltipW = tooltipEl.offsetWidth;
459
+
460
+ // Flip to left side if tooltip would overflow right edge
461
+ if (pixelX + tooltipW + 12 > wrapperRect.width) {
462
+ tooltipEl.style.left = (pixelX - tooltipW - 12) + 'px';
463
+ } else {
464
+ tooltipEl.style.left = (pixelX + 12) + 'px';
465
+ }
466
+ tooltipEl.style.top = (svgRect.top - wrapperRect.top + meta.padT) + 'px';
467
+ });
468
+
469
+ svgEl.addEventListener('mouseleave', function() {
470
+ crosshair.setAttribute('visibility', 'hidden');
471
+ tooltipEl.style.display = 'none';
472
+ });
473
+
474
+ console.log('[logit-lens] Chart hover attached');
475
+ }
476
+
477
+ document.addEventListener('click', function(e) {
478
+ var token = e.target.closest('.token-span[data-token-index]');
479
+ if (!token) return;
480
+
481
+ console.log('[logit-lens] Token clicked:', token.textContent, 'index:', token.dataset.tokenIndex);
482
+
483
+ // Highlight selected token, clear previous
484
+ document.querySelectorAll('.token-span').forEach(function(s) {
485
+ s.style.background = '';
486
+ });
487
+ token.style.background = 'rgba(96, 165, 250, 0.2)';
488
+
489
+ // Read data from span attributes
490
+ var finalToken = token.dataset.token;
491
+ var prob = parseFloat(token.dataset.prob) || 0;
492
+ var idx = parseInt(token.dataset.tokenIndex);
493
+
494
+ var layersData;
495
+ try {
496
+ layersData = JSON.parse(token.dataset.layers);
497
+ } catch (err) {
498
+ console.error('[logit-lens] Failed to parse layers data:', err);
499
+ return;
500
+ }
501
+
502
+ console.log('[logit-lens] Layers:', layersData.length, 'Final token:', JSON.stringify(finalToken));
503
+
504
+ var nLayers = layersData.length;
505
+
506
+ // Find first layer where final token appears in top-k
507
+ var firstAppearance = -1;
508
+ for (var li = 0; li < nLayers; li++) {
509
+ var tops = layersData[li].top_tokens;
510
+ for (var ti = 0; ti < tops.length; ti++) {
511
+ if (tops[ti].token === finalToken) {
512
+ firstAppearance = li;
513
+ break;
514
+ }
515
+ }
516
+ if (firstAppearance >= 0) break;
517
+ }
518
+
519
+ // Build header
520
+ var appearanceNote = '';
521
+ if (firstAppearance >= 0) {
522
+ appearanceNote = ' &middot; first in top-k at layer ' + firstAppearance;
523
+ } else if (nLayers > 0) {
524
+ appearanceNote = ' &middot; <span style="color:#f87171;">never in top-k</span>';
525
+ }
526
+
527
+ var header = '<div style="font-weight:600;margin-bottom:12px;padding-bottom:8px;' +
528
+ 'border-bottom:1px solid #374151;">' +
529
+ 'Selected: "<span style="color:#60a5fa;">' + escapeHtml(finalToken) + '</span>"' +
530
+ ' (token ' + idx + ', ' + (prob * 100).toFixed(2) + '%)' +
531
+ appearanceNote + '</div>';
532
+
533
+ // Build line chart (returns {html, meta})
534
+ var chartResult = renderLineChart(layersData, finalToken, nLayers);
535
+ var chartHtml = chartResult ? chartResult.html : '';
536
+ var chartMeta = chartResult ? chartResult.meta : null;
537
+
538
+ // Build layer cards (reversed: final layer at top, embedding at bottom)
539
+ var cards = '';
540
+ for (var i = nLayers - 1; i >= 0; i--) {
541
+ cards += renderLayerCard(layersData[i], finalToken, nLayers, i);
542
+ }
543
+
544
+ var grid = '<div style="display:grid;grid-template-columns:repeat(auto-fill,minmax(200px,1fr));gap:6px;">' +
545
+ cards + '</div>';
546
+
547
+ // Update panel: header -> chart -> grid
548
+ var panel = document.getElementById('logit-lens-panel');
549
+ if (panel) {
550
+ panel.innerHTML = header + chartHtml + grid;
551
+ if (chartMeta) attachChartHover(chartMeta);
552
+ console.log('[logit-lens] Panel updated with chart +', nLayers, 'layers');
553
+ } else {
554
+ console.error('[logit-lens] Panel element #logit-lens-panel not found');
555
+ }
556
+ });
557
+ })();
558
+ """
559
+
560
+ # Initial HTML for the logit lens panel
561
+ LOGIT_LENS_PANEL_INITIAL = """
562
+ <div id="logit-lens-panel" style="
563
+ padding: 16px;
564
+ background: #1f2937;
565
+ border-radius: 8px;
566
+ color: #e5e7eb;
567
+ font-family: system-ui, -apple-system, sans-serif;
568
+ font-size: 14px;
569
+ min-height: 100px;
570
+ max-height: 600px;
571
+ overflow-y: auto;
572
+ ">
573
+ <div style="color: #9ca3af; font-style: italic;">
574
+ Click on any generated token to see per-layer predictions.
575
+ </div>
576
+ </div>
577
+ """
578
+
579
+
580
+ # Build Gradio interface
581
+ with gr.Blocks(title="Logit Lens Explorer") as demo:
582
+ gr.Markdown("# Logit Lens Explorer")
583
+ gr.Markdown(
584
+ "Enter a prompt to generate text. Click any token to see per-layer predictions."
585
+ )
586
+
587
+ prompt_input = gr.Textbox(
588
+ label="Prompt",
589
+ placeholder="Enter a prompt...",
590
+ lines=2,
591
+ )
592
+ submit_btn = gr.Button("Generate", variant="primary")
593
+
594
+ gr.Markdown("### Generated Tokens")
595
+ gr.Markdown("*Click any token to inspect its per-layer predictions.*")
596
+ token_display = gr.HTML(
597
+ value='<div style="color: #666; padding: 10px;">Enter a prompt and click Generate to start.</div>',
598
+ )
599
+
600
+ gr.Markdown("### Logit Lens Panel")
601
+ logit_lens_panel = gr.HTML(
602
+ value=LOGIT_LENS_PANEL_INITIAL,
603
+ )
604
+
605
+ # Wire up generation: button click and Enter key in textbox
606
+ submit_btn.click(
607
+ fn=generate_streaming,
608
+ inputs=[prompt_input],
609
+ outputs=[token_display],
610
+ )
611
+ prompt_input.submit(
612
+ fn=generate_streaming,
613
+ inputs=[prompt_input],
614
+ outputs=[token_display],
615
+ )
616
+
617
+
618
+ if __name__ == "__main__":
619
+ if not SPACES_AVAILABLE:
620
+ print("Preloading model (local development)...")
621
+ load_model()
622
+ else:
623
+ print("ZeroGPU detected - model will load on first inference request")
624
+ print("Starting Gradio server...")
625
+ demo.launch(server_port=7861, js=TOKEN_CLICK_JS)
model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loading and inference for Logit Lens Explorer.
3
+
4
+ Loads Llama-3.2-3B-Instruct and provides inference with hidden state
5
+ capture for logit lens visualization.
6
+
7
+ Part of E02: Logit Lens Explorer.
8
+ """
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Generator
12
+
13
+ import torch
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
15
+
16
+ MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
17
+
18
+ _model = None
19
+ _tokenizer = None
20
+ _device = None
21
+
22
+
23
+ @dataclass
24
+ class LayerPrediction:
25
+ """Top-k token predictions from a single transformer layer."""
26
+
27
+ layer_index: int # 0 = embedding, 1-28 = transformer layers
28
+ top_tokens: list[dict] # [{"token": str, "probability": float}, ...]
29
+
30
+
31
+ @dataclass
32
+ class TokenData:
33
+ """Data for a single generated token with per-layer logit lens predictions."""
34
+
35
+ token: str
36
+ token_id: int
37
+ probability: float
38
+ layer_predictions: list[LayerPrediction] # len = 29 (embedding + 28 layers)
39
+
40
+
41
+ def load_model():
42
+ """Load the Llama model and tokenizer. Uses cached singleton."""
43
+ global _model, _tokenizer, _device
44
+
45
+ if _model is not None:
46
+ return _model, _tokenizer
47
+
48
+ _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
+ print(f"Using device: {_device}")
50
+ print(f"Loading model: {MODEL_ID}...")
51
+
52
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
53
+ _model = AutoModelForCausalLM.from_pretrained(
54
+ MODEL_ID,
55
+ attn_implementation="flash_attention_2",
56
+ torch_dtype=torch.float16,
57
+ ).to(_device).eval()
58
+
59
+ print("Model loaded successfully")
60
+ return _model, _tokenizer
61
+
62
+
63
+ def project_hidden_states(
64
+ hidden_states: torch.Tensor,
65
+ model,
66
+ tokenizer,
67
+ top_k: int = 20,
68
+ ) -> list[LayerPrediction]:
69
+ """Batch-project hidden states through RMSNorm + lm_head.
70
+
71
+ Takes stacked hidden states from all layers and projects them through
72
+ the model's final normalization and unembedding head in a single
73
+ batched operation.
74
+
75
+ Args:
76
+ hidden_states: Stacked hidden states, shape (n_layers, 1, hidden_dim).
77
+ model: The causal LM model with .model.norm and .lm_head.
78
+ tokenizer: Tokenizer for decoding token IDs.
79
+ top_k: Number of top predictions per layer.
80
+
81
+ Returns:
82
+ List of LayerPrediction, one per layer.
83
+ """
84
+ # Reshape to (n_layers, hidden_dim), removing any size-1 middle dims, upcast to float32
85
+ n_layers = hidden_states.shape[0]
86
+ hidden_dim = hidden_states.shape[-1]
87
+ hs = hidden_states.reshape(n_layers, hidden_dim).float()
88
+
89
+ # Apply final RMSNorm (float32 for numerical stability)
90
+ normed = model.model.norm(hs)
91
+ # Cast back to model weight dtype for lm_head linear projection
92
+ logits = model.lm_head(normed.to(model.lm_head.weight.dtype))
93
+
94
+ # Softmax in float32 to avoid overflow
95
+ probs = torch.softmax(logits.float(), dim=-1)
96
+ top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)
97
+
98
+ # Move to CPU once for all layers
99
+ top_probs_cpu = top_probs.cpu().tolist()
100
+ top_indices_cpu = top_indices.cpu().tolist()
101
+
102
+ predictions = []
103
+ for layer_idx in range(len(top_probs_cpu)):
104
+ top_tokens = [
105
+ {"token": tokenizer.decode([int(idx)]), "probability": prob}
106
+ for prob, idx in zip(top_probs_cpu[layer_idx], top_indices_cpu[layer_idx])
107
+ ]
108
+ predictions.append(LayerPrediction(
109
+ layer_index=layer_idx,
110
+ top_tokens=top_tokens,
111
+ ))
112
+ return predictions
113
+
114
+
115
+ def generate_with_logit_lens(
116
+ prompt: str,
117
+ max_new_tokens: int = 512,
118
+ top_k: int = 20,
119
+ ) -> Generator[TokenData, None, None]:
120
+ """Generate text token-by-token with per-layer logit lens predictions.
121
+
122
+ Uses greedy decoding (argmax) for deterministic text generation, but
123
+ records the natural softmax probabilities (temperature=1) for the logit
124
+ lens visualization so layer predictions reflect the model's true
125
+ confidence distribution.
126
+
127
+ Args:
128
+ prompt: User prompt text.
129
+ max_new_tokens: Maximum tokens to generate.
130
+ top_k: Number of top predictions per layer for logit lens.
131
+
132
+ Yields:
133
+ TokenData with token string, ID, probability, and per-layer predictions.
134
+ """
135
+ model, tokenizer = load_model()
136
+
137
+ messages = [{"role": "user", "content": prompt}]
138
+ prompt_full = tokenizer.apply_chat_template(
139
+ messages, tokenize=False, add_generation_prompt=True
140
+ )
141
+
142
+ inputs = tokenizer(prompt_full, return_tensors="pt").to(_device)
143
+ input_ids = inputs.input_ids
144
+ attention_mask = inputs.attention_mask
145
+
146
+ # EOS token IDs for stopping
147
+ eos_token_id = model.config.eos_token_id
148
+ if isinstance(eos_token_id, int):
149
+ eos_token_id = [eos_token_id]
150
+ elif eos_token_id is None:
151
+ eos_token_id = []
152
+
153
+ generated_ids = input_ids.clone()
154
+ past_key_values = DynamicCache()
155
+ seq_length = input_ids.shape[1]
156
+
157
+ with torch.no_grad():
158
+ for step in range(max_new_tokens):
159
+ if step == 0:
160
+ cache_position = torch.arange(seq_length, device=_device)
161
+ outputs = model(
162
+ input_ids=generated_ids,
163
+ attention_mask=attention_mask,
164
+ cache_position=cache_position,
165
+ past_key_values=past_key_values,
166
+ output_hidden_states=True,
167
+ return_dict=True,
168
+ use_cache=True,
169
+ )
170
+ else:
171
+ cache_position = torch.tensor([seq_length], device=_device)
172
+ outputs = model(
173
+ input_ids=generated_ids[:, -1:],
174
+ attention_mask=attention_mask,
175
+ cache_position=cache_position,
176
+ past_key_values=past_key_values,
177
+ output_hidden_states=True,
178
+ return_dict=True,
179
+ use_cache=True,
180
+ )
181
+
182
+ past_key_values = outputs.past_key_values
183
+
184
+ # Greedy decoding with natural probability recording
185
+ next_token_logits = outputs.logits[:, -1, :].float()
186
+ probs = torch.softmax(next_token_logits, dim=-1)
187
+ next_token_id = torch.argmax(probs, dim=-1).item()
188
+ next_token_prob = probs[0, next_token_id].item()
189
+
190
+ if next_token_id in eos_token_id:
191
+ break
192
+
193
+ # Eager logit lens: stack last-position hidden state from each layer
194
+ # outputs.hidden_states is a tuple of (n_layers+1) tensors,
195
+ # each shape (batch, seq_len, hidden_dim)
196
+ hidden_states = torch.stack([
197
+ hs[:, -1:, :] for hs in outputs.hidden_states
198
+ ]) # (n_layers, 1, hidden_dim)
199
+
200
+ layer_predictions = project_hidden_states(
201
+ hidden_states, model, tokenizer, top_k=top_k
202
+ )
203
+
204
+ token_str = tokenizer.decode([next_token_id])
205
+
206
+ yield TokenData(
207
+ token=token_str,
208
+ token_id=next_token_id,
209
+ probability=next_token_prob,
210
+ layer_predictions=layer_predictions,
211
+ )
212
+
213
+ # Update for next iteration
214
+ next_token_tensor = torch.tensor([[next_token_id]], device=_device)
215
+ generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
216
+ attention_mask = torch.cat(
217
+ [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
218
+ dim=-1,
219
+ )
220
+ seq_length += 1
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logit Lens - Gradio Dependencies
2
+
3
+ # Gradio UI
4
+ gradio>=6.4.0
5
+
6
+ # HuggingFace Spaces (ZeroGPU support)
7
+ spaces
8
+
9
+ # PyTorch + CUDA
10
+ torch==2.6.0
11
+ torchvision
12
+
13
+ # Transformers + Qwen VL
14
+ transformers==4.57.3
15
+ qwen-vl-utils
16
+ huggingface_hub
17
+
18
+ # Attention + Acceleration
19
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
20
+ git+https://github.com/huggingface/accelerate.git
21
+ git+https://github.com/huggingface/peft.git
22
+ transformers-stream-generator
23
+
24
+ # Image processing
25
+ Pillow
26
+ sentencepiece