Paar, F. (Ferdinand) commited on
Commit
d31e7ee
·
1 Parent(s): 4e93d08

limit textarea to 200 chars and adjust thumbnail sizing3

Browse files
Files changed (2) hide show
  1. frontend/index.html +1 -70
  2. frontend/script.js +20 -19
frontend/index.html CHANGED
@@ -85,7 +85,7 @@
85
  id="inputText"
86
  rows="2"
87
  cols="50"
88
- maxlength="200"
89
  placeholder="Enter your text here..."
90
  autofocus
91
  ></textarea>
@@ -330,72 +330,3 @@
330
  </script>
331
  </body>
332
  </html>
333
-
334
- <!DOCTYPE html>
335
- <html lang="en">
336
- <head>
337
- <meta charset="UTF-8" />
338
- <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
339
- <title>DeepGaze</title>
340
- <link rel="stylesheet" href="static/styles.css" />
341
- <script src="https://d3js.org/d3.v6.min.js"></script>
342
- <style>
343
- /* Process button styling */
344
- .text-form button {
345
- background-color: #800000; /* Bordo red */
346
- color: #fff;
347
- border: none;
348
- padding: 8px 12px;
349
- font-size: 14px;
350
- border-radius: 4px;
351
- cursor: pointer;
352
- }
353
- /* Info container styling for head and layer display */
354
- .info-container {
355
- display: flex;
356
- align-items: center;
357
- margin: 5px 0;
358
- font-size: 0.9rem;
359
- }
360
- .info-container .label {
361
- margin-right: 5px;
362
- }
363
- .info-container .number-box {
364
- border: 1px solid #800000;
365
- border-radius: 4px;
366
- padding: 2px 6px;
367
- font-weight: bold;
368
- color: #800000;
369
- min-width: 20px;
370
- text-align: center;
371
- }
372
- .credits-container {
373
- background-color: #800000;
374
- padding: 10px 15px;
375
- border-radius: 8px;
376
- text-align: center;
377
- font-family: Arial, sans-serif;
378
- font-size: 14px;
379
- color: #ffffff;
380
- margin: 20px auto;
381
- width: fit-content;
382
- box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
383
- }
384
- .credits-container a {
385
- color: #ffcccc;
386
- text-decoration: none;
387
- }
388
- .credits-container a:hover {
389
- text-decoration: underline;
390
- }
391
- </style>
392
- </head>
393
- <body>
394
- <!-- Your existing content remains unchanged -->
395
-
396
- <div class="credits-container">
397
- Created by Samu and Ferdi - Credits to <a href="https://github.com/jessevig/bertviz" target="_blank">BertViz</a>
398
- </div>
399
- <script src="static/script.js?v=2025001fe143433"></script>
400
- </body>
401
- </html>
 
85
  id="inputText"
86
  rows="2"
87
  cols="50"
88
+ maxlength="50"
89
  placeholder="Enter your text here..."
90
  autofocus
91
  ></textarea>
 
330
  </script>
331
  </body>
332
  </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/script.js CHANGED
@@ -18,14 +18,19 @@ document.getElementById('textForm').addEventListener('submit', async (e) => {
18
  // Use data.tokens and data.attention from your POST response
19
  displayOutput(data);
20
  displayHoverTokens(data, 0, 0);
21
- renderModelView(data.tokens, data.attention);
 
22
  } catch (error) {
23
  console.error('Error:', error);
24
  document.getElementById('output').innerText = 'Error processing text.';
25
  }
26
  });
27
 
28
- function renderModelView(tokens, attention) {
 
 
 
 
29
  const container = document.getElementById("model_view_container");
30
  if (!container) return;
31
  container.innerHTML = "";
@@ -36,10 +41,10 @@ function renderModelView(tokens, attention) {
36
  gridContainer.style.gridGap = "10px";
37
  gridContainer.style.padding = "20px";
38
 
39
- // Loop over all 12 layers and 12 heads
40
  for (let layerIdx = 0; layerIdx < 12; layerIdx++) {
41
  for (let headIdx = 0; headIdx < 12; headIdx++) {
42
- const thumbnail = createAttentionThumbnail(tokens, attention, layerIdx, headIdx);
43
  gridContainer.appendChild(thumbnail);
44
  }
45
  }
@@ -47,7 +52,11 @@ function renderModelView(tokens, attention) {
47
  container.appendChild(gridContainer);
48
  }
49
 
50
- function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
 
 
 
 
51
  const width = 80;
52
  const tokenHeight = 15;
53
  const padding = 10;
@@ -82,14 +91,10 @@ function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
82
  .attr("font-size", "10")
83
  .text(`L${layerIdx + 1} H${headIdx + 1}`);
84
 
85
- // (Removed the drawing of token text on the left and right to keep it clean.)
86
-
87
  // Draw attention lines with per-row normalization.
88
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
89
- // Compute the maximum weight in this row (or use 1 if all are zero).
90
  const rowMax = Math.max(...sourceWeights) || 1;
91
  sourceWeights.forEach((weight, targetIdx) => {
92
- // Only draw if the weight exceeds a threshold and is not self-attention.
93
  if (weight > 0.01 && sourceIdx !== targetIdx) {
94
  const normalizedWeight = weight / rowMax;
95
  svg.append("line")
@@ -97,7 +102,7 @@ function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
97
  .attr("y1", padding + sourceIdx * tokenHeight - 5)
98
  .attr("x2", xRight)
99
  .attr("y2", padding + targetIdx * tokenHeight - 5)
100
- .attr("stroke", "#800000") // Changed color to bordo red.
101
  .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
102
  .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
103
  .attr("stroke-linecap", "round");
@@ -105,19 +110,19 @@ function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
105
  });
106
  });
107
 
108
- // Click handler: update UI with head and layer numbers.
109
  thumbnail.on("click", function() {
110
- d3.select("#head").text(`Head: ${headIdx + 1}`);
111
- d3.select("#layer").text(`Layer: ${layerIdx + 1}`);
112
  displayHoverTokens(data, layerIdx, headIdx);
113
  });
114
 
115
  return thumbnail.node();
116
  }
 
117
  // Function to display the tokens and attention values
118
  function displayOutput(data) {
119
  const outputDiv = document.getElementById('output');
120
- // Show raw tokens and attention data for debugging or inspection
121
  outputDiv.innerHTML = `
122
  <h2>Tokens</h2>
123
  <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
@@ -168,16 +173,13 @@ function resetTokenSizes() {
168
 
169
  function highlightAttention(index, attentionData, layer_idx, head_idx) {
170
  const container = document.getElementById('tokenContainer');
171
- // Retrieve the attention weights row for the hovered token
172
  const row = attentionData[layer_idx][head_idx][index];
173
  if (!row) {
174
  console.warn(`No attention data for token index ${index}`);
175
  return;
176
  }
177
 
178
- // Consider only the preceding tokens
179
  const weights = row.slice(0, index);
180
- // Normalize using the maximum weight from the current row (or 1 if zero)
181
  const maxWeight = Math.max(...attentionData[layer_idx][head_idx]) || 1;
182
  const baseFontSize = 32;
183
  const maxIncrease = 20;
@@ -185,11 +187,10 @@ function highlightAttention(index, attentionData, layer_idx, head_idx) {
185
  Array.from(container.children).forEach((span, idx) => {
186
  if (idx < index) {
187
  const weight = weights[idx];
188
- // Calculate the new font size based on the attention weight
189
  const newFontSize = baseFontSize + (weight / maxWeight) * maxIncrease;
190
  span.style.fontSize = newFontSize + "px";
191
  } else {
192
  span.style.fontSize = baseFontSize + "px";
193
  }
194
  });
195
- }
 
18
  // Use data.tokens and data.attention from your POST response
19
  displayOutput(data);
20
  displayHoverTokens(data, 0, 0);
21
+ // Changed call here to pass the entire data object
22
+ renderModelView(data);
23
  } catch (error) {
24
  console.error('Error:', error);
25
  document.getElementById('output').innerText = 'Error processing text.';
26
  }
27
  });
28
 
29
+ function renderModelView(data) {
30
+ // Extract tokens and attention from the data object
31
+ const tokens = data.tokens;
32
+ const attention = data.attention;
33
+
34
  const container = document.getElementById("model_view_container");
35
  if (!container) return;
36
  container.innerHTML = "";
 
41
  gridContainer.style.gridGap = "10px";
42
  gridContainer.style.padding = "20px";
43
 
44
+ // Loop over all 12 layers and 12 heads, passing the complete data object
45
  for (let layerIdx = 0; layerIdx < 12; layerIdx++) {
46
  for (let headIdx = 0; headIdx < 12; headIdx++) {
47
+ const thumbnail = createAttentionThumbnail(data, layerIdx, headIdx);
48
  gridContainer.appendChild(thumbnail);
49
  }
50
  }
 
52
  container.appendChild(gridContainer);
53
  }
54
 
55
+ function createAttentionThumbnail(data, layerIdx, headIdx) {
56
+ // Extract tokens and attention from the data object
57
+ const tokens = data.tokens;
58
+ const attention = data.attention;
59
+
60
  const width = 80;
61
  const tokenHeight = 15;
62
  const padding = 10;
 
91
  .attr("font-size", "10")
92
  .text(`L${layerIdx + 1} H${headIdx + 1}`);
93
 
 
 
94
  // Draw attention lines with per-row normalization.
95
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
 
96
  const rowMax = Math.max(...sourceWeights) || 1;
97
  sourceWeights.forEach((weight, targetIdx) => {
 
98
  if (weight > 0.01 && sourceIdx !== targetIdx) {
99
  const normalizedWeight = weight / rowMax;
100
  svg.append("line")
 
102
  .attr("y1", padding + sourceIdx * tokenHeight - 5)
103
  .attr("x2", xRight)
104
  .attr("y2", padding + targetIdx * tokenHeight - 5)
105
+ .attr("stroke", "#800000") // Bordo red
106
  .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
107
  .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
108
  .attr("stroke-linecap", "round");
 
110
  });
111
  });
112
 
113
+ // Click handler remains unchanged, using the passed-in data object.
114
  thumbnail.on("click", function() {
115
+ d3.select("#display_head .number-box").text(headIdx + 1);
116
+ d3.select("#display_layer .number-box").text(layerIdx + 1);
117
  displayHoverTokens(data, layerIdx, headIdx);
118
  });
119
 
120
  return thumbnail.node();
121
  }
122
+
123
  // Function to display the tokens and attention values
124
  function displayOutput(data) {
125
  const outputDiv = document.getElementById('output');
 
126
  outputDiv.innerHTML = `
127
  <h2>Tokens</h2>
128
  <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
 
173
 
174
  function highlightAttention(index, attentionData, layer_idx, head_idx) {
175
  const container = document.getElementById('tokenContainer');
 
176
  const row = attentionData[layer_idx][head_idx][index];
177
  if (!row) {
178
  console.warn(`No attention data for token index ${index}`);
179
  return;
180
  }
181
 
 
182
  const weights = row.slice(0, index);
 
183
  const maxWeight = Math.max(...attentionData[layer_idx][head_idx]) || 1;
184
  const baseFontSize = 32;
185
  const maxIncrease = 20;
 
187
  Array.from(container.children).forEach((span, idx) => {
188
  if (idx < index) {
189
  const weight = weights[idx];
 
190
  const newFontSize = baseFontSize + (weight / maxWeight) * maxIncrease;
191
  span.style.fontSize = newFontSize + "px";
192
  } else {
193
  span.style.fontSize = baseFontSize + "px";
194
  }
195
  });
196
+ }