Paar, F. (Ferdinand) commited on
Commit
501ea17
·
1 Parent(s): 5fd77e7

D3 libary for better vis 2.4

Browse files
frontend/ model_view.js ADDED
File without changes
frontend/index.html CHANGED
@@ -28,7 +28,7 @@
28
  Etiam vel dui consequat ante lobortis facilisis. Curabitur facilisis enim sed massa
29
  consectetur, et mollis magna suscipit. Vestibulum odio ante, laoreet at mauris eu,
30
  feugiat pretium tortor. Ut condimentum laoreet felis quis rutrum. Proin vel elit a
31
- augue ornare venenatis id nec neque.Ferdi 2.3
32
  </p>
33
  </section>
34
 
@@ -116,6 +116,6 @@
116
  </aside>
117
  </div>
118
 
119
- <script src="static/script.js?v=20281533"></script>
120
  </body>
121
  </html>
 
28
  Etiam vel dui consequat ante lobortis facilisis. Curabitur facilisis enim sed massa
29
  consectetur, et mollis magna suscipit. Vestibulum odio ante, laoreet at mauris eu,
30
  feugiat pretium tortor. Ut condimentum laoreet felis quis rutrum. Proin vel elit a
31
+ augue ornare venenatis id nec neque.Ferdi 2.4
32
  </p>
33
  </section>
34
 
 
116
  </aside>
117
  </div>
118
 
119
+ <script src="static/script.js?v=202815533"></script>
120
  </body>
121
  </html>
frontend/script.js CHANGED
@@ -48,93 +48,86 @@ function renderModelView(tokens, attention) {
48
  }
49
 
50
  function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
51
- const padding = 20; // Define padding before using it
52
- const thumbnail = document.createElement("div");
53
- thumbnail.style.position = "relative";
54
- // Dynamic height: top and bottom padding plus 20px per token.
55
- thumbnail.style.height = (padding * 2 + tokens.length * 20) + "px";
56
- thumbnail.style.width = "120px";
57
- thumbnail.style.border = "1px solid #ddd";
58
- thumbnail.style.borderRadius = "4px";
59
- thumbnail.style.padding = "5px";
60
- thumbnail.style.background = "#fff";
61
-
62
- // Create SVG container with 100% width/height of the thumbnail
63
- const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
64
- svg.setAttribute("width", "100%");
65
- svg.setAttribute("height", "100%");
66
-
67
- // Add header text for layer and head using template literals
68
- const header = document.createElementNS("http://www.w3.org/2000/svg", "text");
69
- header.setAttribute("x", "50%");
70
- header.setAttribute("y", "15");
71
- header.setAttribute("text-anchor", "middle");
72
- header.setAttribute("font-size", "10");
73
- header.textContent = `L${layerIdx + 1} H${headIdx + 1}`;
74
- svg.appendChild(header);
75
-
76
- const attentionWeights = attention[layerIdx][headIdx];
77
  const maxLineWidth = 4;
78
  const maxOpacity = 0.8;
79
 
80
- // Normalize weights per head: Compute maximum weight in this head.
 
81
  const headMax = Math.max(...attentionWeights.flat()) || 1;
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  // Draw tokens on both sides (vertically aligned)
84
  tokens.forEach((token, idx) => {
85
  const cleanToken = token.replace("Ġ", "");
86
 
87
- // Left side: fixed x position to 'padding'
88
- const leftText = document.createElementNS("http://www.w3.org/2000/svg", "text");
89
- leftText.setAttribute("x", padding);
90
- leftText.setAttribute("y", padding + (idx * 20));
91
- leftText.setAttribute("font-size", "8");
92
- leftText.textContent = cleanToken;
93
- svg.appendChild(leftText);
94
 
95
- // Right side: x is set to 90% of the SVG width
96
- const rightText = document.createElementNS("http://www.w3.org/2000/svg", "text");
97
- rightText.setAttribute("x", "90%");
98
- rightText.setAttribute("y", padding + (idx * 20));
99
- rightText.setAttribute("text-anchor", "end");
100
- rightText.setAttribute("font-size", "8");
101
- rightText.textContent = cleanToken;
102
- svg.appendChild(rightText);
103
  });
104
 
105
- // Draw attention lines between tokens with normalized weights
106
  attentionWeights.forEach((sourceWeights, sourceIdx) => {
107
  sourceWeights.forEach((weight, targetIdx) => {
108
  if (weight > 0.05 && sourceIdx !== targetIdx) {
109
- // Normalize the weight relative to the maximum weight in this head
110
  const normalizedWeight = weight / headMax;
111
-
112
- const line = document.createElementNS("http://www.w3.org/2000/svg", "line");
113
- // Align lines with tokens using the same padding and vertical spacing
114
- line.setAttribute("x1", padding);
115
- line.setAttribute("y1", padding + (sourceIdx * 20) - 5);
116
- line.setAttribute("x2", "90%");
117
- line.setAttribute("y2", padding + (targetIdx * 20) - 5);
118
- line.setAttribute("stroke", "#2ecc71");
119
- line.setAttribute("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth));
120
- line.setAttribute("opacity", Math.min(maxOpacity, normalizedWeight * 2));
121
- line.setAttribute("stroke-linecap", "round");
122
- svg.appendChild(line);
123
  }
124
  });
125
  });
126
 
127
- thumbnail.appendChild(svg);
128
-
129
- // On click, update UI elements (for example, display the selected layer and head)
130
- thumbnail.addEventListener("click", () => {
131
- document.getElementById("head").textContent = `Head: ${headIdx + 1}`;
132
- document.getElementById("layer").textContent = `Layer: ${layerIdx + 1}`;
133
- // Dynamically update displayHoverTokens with the selected layer and head
134
- displayHoverTokens(data, layerIdx, headIdx);
135
  });
136
 
137
- return thumbnail;
138
  }
139
  // Function to display the tokens and attention values
140
  function displayOutput(data) {
 
48
  }
49
 
50
  function createAttentionThumbnail(tokens, attention, layerIdx, headIdx) {
51
+ const padding = 20;
52
+ const tokenHeight = 20;
53
+ const width = 120;
54
+ const height = padding * 2 + tokens.length * tokenHeight;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  const maxLineWidth = 4;
56
  const maxOpacity = 0.8;
57
 
58
+ // Extract attention weights for this head and compute the normalization factor.
59
+ const attentionWeights = attention[layerIdx][headIdx];
60
  const headMax = Math.max(...attentionWeights.flat()) || 1;
61
 
62
+ // Create a thumbnail container using D3
63
+ const thumbnail = d3.select(document.createElement("div"))
64
+ .style("position", "relative")
65
+ .style("height", height + "px")
66
+ .style("width", width + "px")
67
+ .style("border", "1px solid #ddd")
68
+ .style("border-radius", "4px")
69
+ .style("padding", "5px")
70
+ .style("background", "#fff");
71
+
72
+ // Append an SVG container inside the thumbnail
73
+ const svg = thumbnail.append("svg")
74
+ .attr("width", "100%")
75
+ .attr("height", "100%");
76
+
77
+ // Add header text using template literals
78
+ svg.append("text")
79
+ .attr("x", "50%")
80
+ .attr("y", 15)
81
+ .attr("text-anchor", "middle")
82
+ .attr("font-size", "10")
83
+ .text(`L${layerIdx + 1} H${headIdx + 1}`);
84
+
85
  // Draw tokens on both sides (vertically aligned)
86
  tokens.forEach((token, idx) => {
87
  const cleanToken = token.replace("Ġ", "");
88
 
89
+ // Left side token
90
+ svg.append("text")
91
+ .attr("x", padding)
92
+ .attr("y", padding + idx * tokenHeight)
93
+ .attr("font-size", "8")
94
+ .text(cleanToken);
 
95
 
96
+ // Right side token
97
+ svg.append("text")
98
+ .attr("x", "90%")
99
+ .attr("y", padding + idx * tokenHeight)
100
+ .attr("text-anchor", "end")
101
+ .attr("font-size", "8")
102
+ .text(cleanToken);
 
103
  });
104
 
105
+ // Draw attention lines with normalized weights
106
  attentionWeights.forEach((sourceWeights, sourceIdx) => {
107
  sourceWeights.forEach((weight, targetIdx) => {
108
  if (weight > 0.05 && sourceIdx !== targetIdx) {
 
109
  const normalizedWeight = weight / headMax;
110
+ svg.append("line")
111
+ .attr("x1", padding)
112
+ .attr("y1", padding + sourceIdx * tokenHeight - 5)
113
+ .attr("x2", "90%")
114
+ .attr("y2", padding + targetIdx * tokenHeight - 5)
115
+ .attr("stroke", "#2ecc71")
116
+ .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
117
+ .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
118
+ .attr("stroke-linecap", "round");
 
 
 
119
  }
120
  });
121
  });
122
 
123
+ // On click, update UI elements (assumes 'data' is globally available)
124
+ thumbnail.on("click", function() {
125
+ d3.select("#head").text(`Head: ${headIdx + 1}`);
126
+ d3.select("#layer").text(`Layer: ${layerIdx + 1}`);
127
+ displayHoverTokens(data, layerIdx, headIdx);
 
 
 
128
  });
129
 
130
+ return thumbnail.node();
131
  }
132
  // Function to display the tokens and attention values
133
  function displayOutput(data) {