Paar, F. (Ferdinand) commited on
Commit
42370b7
·
1 Parent(s): 90fb67f

thubnail filled and layer head box next to each other

Browse files
Files changed (1) hide show
  1. frontend/script.js +45 -49
frontend/script.js CHANGED
@@ -20,6 +20,7 @@ document.getElementById('textForm').addEventListener('submit', async (e) => {
20
  displayHoverTokens(data, 0, 0);
21
  // Changed call here to pass the entire data object
22
  renderModelView(data);
 
23
  } catch (error) {
24
  console.error('Error:', error);
25
  document.getElementById('output').innerText = 'Error processing text.';
@@ -51,94 +52,88 @@ function renderModelView(data) {
51
 
52
  container.appendChild(gridContainer);
53
  }
54
- function createFirstThumbnail(data) {
55
  const tokens = data.tokens;
56
  const attention = data.attention;
57
 
58
- const width = 200; // Increased size
59
- const height = 200;
60
- const tokenHeight = 15;
61
- const paddingTop = 20; // Space for text above graph
62
- const paddingBottom = 20; // Space for text below graph
63
- const graphHeight = height - (paddingTop + paddingBottom); // Adjust graph size
64
-
65
- const maxLineWidth = 4;
66
- const maxOpacity = 0.8;
67
- const xRight = width - 10; // Adjust for proper positioning
68
 
69
- // Create the thumbnail container
70
- const thumbnail = d3.select("#thumbnailContainer")
71
  .append("div")
72
- .attr("id", "firstThumbnail")
73
- .style("position", "relative")
74
- .style("height", height + "px")
75
  .style("width", width + "px")
 
76
  .style("border", "2px solid #ddd")
77
  .style("border-radius", "6px")
78
- .style("padding", "5px")
79
  .style("background", "#fff")
80
- .style("cursor", "pointer")
81
  .style("display", "flex")
82
  .style("flex-direction", "column")
83
- .style("align-items", "center")
84
- .style("justify-content", "space-between");
85
 
86
- // Top Token Text (Above Graph)
87
- thumbnail.append("div")
88
- .classed("token-text", true)
89
  .style("width", "100%")
90
  .style("text-align", "center")
91
- .style("font-size", "14px")
92
- .style("padding-bottom", "5px")
93
- .text(tokens.join(" ")); // Display all tokens in a line
94
 
95
- // SVG Graph Container
96
- const svg = thumbnail.append("svg")
97
  .attr("width", width)
98
  .attr("height", graphHeight)
99
  .attr("viewBox", `0 0 ${width} ${graphHeight}`)
100
  .attr("preserveAspectRatio", "none");
101
 
102
- // Header text inside the SVG
103
  svg.append("text")
104
  .attr("x", width / 2)
105
- .attr("y", 15)
106
  .attr("text-anchor", "middle")
107
- .attr("font-size", "12px")
108
- .text(`Layer 1 - Head 1`);
109
 
110
- // Compute dynamic Y positions for graph
111
- const yStart = paddingTop;
 
112
 
113
- // Draw attention lines
 
114
  attention[0][0].forEach((sourceWeights, sourceIdx) => {
115
  const rowMax = Math.max(...sourceWeights) || 1;
116
  sourceWeights.forEach((weight, targetIdx) => {
117
  if (weight > 0.01 && sourceIdx !== targetIdx) {
118
  const normalizedWeight = weight / rowMax;
119
  svg.append("line")
120
- .attr("x1", 10)
121
- .attr("y1", yStart + sourceIdx * tokenHeight)
122
- .attr("x2", xRight)
123
- .attr("y2", yStart + targetIdx * tokenHeight)
124
  .attr("stroke", "#800000")
125
- .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
126
- .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
127
  .attr("stroke-linecap", "round");
128
  }
129
  });
130
  });
131
 
132
- // Bottom Token Text (Below Graph)
133
- thumbnail.append("div")
134
- .classed("token-text", true)
135
  .style("width", "100%")
136
  .style("text-align", "center")
137
- .style("font-size", "14px")
138
- .style("padding-top", "5px")
139
- .text(tokens.join(" ")); // Display all tokens in a line
140
 
141
- return thumbnail.node();
142
  }
143
 
144
  function createAttentionThumbnail(data, layerIdx, headIdx) {
@@ -147,13 +142,14 @@ function createAttentionThumbnail(data, layerIdx, headIdx) {
147
 
148
  const width = 80;
149
  const tokenHeight = 15;
 
150
  const paddingTop = 10; //not sur e
151
  const paddingBottom = 0; //not sur e
152
  const height = paddingTop + tokens.length * tokenHeight; // Adjust height calculation
153
 
154
  const maxLineWidth = 4;
155
  const maxOpacity = 0.8;
156
- const xRight = width - padding;
157
 
158
  // Create a thumbnail container using D3.
159
  const thumbnail = d3.select(document.createElement("div"))
 
20
  displayHoverTokens(data, 0, 0);
21
  // Changed call here to pass the entire data object
22
  renderModelView(data);
23
+ createLargeThumbnail(data);
24
  } catch (error) {
25
  console.error('Error:', error);
26
  document.getElementById('output').innerText = 'Error processing text.';
 
52
 
53
  container.appendChild(gridContainer);
54
  }
55
+ function createLargeThumbnail(data) {
56
  const tokens = data.tokens;
57
  const attention = data.attention;
58
 
59
+ // Dimensions for the thumbnail
60
+ const width = 300;
61
+ const height = 300;
62
+ const textHeight = 30; // Space for token text (top and bottom)
63
+ const graphHeight = height - 2 * textHeight; // Remaining space for the graph
 
 
 
 
 
64
 
65
+ // Create the main container and append it to #thumbnailContainer
66
+ const container = d3.select("#thumbnailContainer")
67
  .append("div")
68
+ .attr("id", "largeThumbnail")
 
 
69
  .style("width", width + "px")
70
+ .style("height", height + "px")
71
  .style("border", "2px solid #ddd")
72
  .style("border-radius", "6px")
 
73
  .style("background", "#fff")
74
+ .style("padding", "10px")
75
  .style("display", "flex")
76
  .style("flex-direction", "column")
77
+ .style("align-items", "center");
 
78
 
79
+ // Add token text above the graph
80
+ container.append("div")
81
+ .classed("thumbnail-token-text", true)
82
  .style("width", "100%")
83
  .style("text-align", "center")
84
+ .style("font-size", "16px")
85
+ .style("margin-bottom", "5px")
86
+ .text(tokens.join(" "));
87
 
88
+ // Create the SVG container for the graph
89
+ const svg = container.append("svg")
90
  .attr("width", width)
91
  .attr("height", graphHeight)
92
  .attr("viewBox", `0 0 ${width} ${graphHeight}`)
93
  .attr("preserveAspectRatio", "none");
94
 
95
+ // Add header text inside the SVG
96
  svg.append("text")
97
  .attr("x", width / 2)
98
+ .attr("y", 20) // adjust as needed
99
  .attr("text-anchor", "middle")
100
+ .attr("font-size", "14px")
101
+ .text("Layer 1 - Head 1");
102
 
103
+ // Calculate spacing for tokens inside the SVG
104
+ const tokenCount = tokens.length;
105
+ const tokenSpacing = graphHeight / (tokenCount + 1);
106
 
107
+ // Draw attention lines for the first layer, first head
108
+ // (Assuming attention[0][0] is a matrix: tokens x tokens)
109
  attention[0][0].forEach((sourceWeights, sourceIdx) => {
110
  const rowMax = Math.max(...sourceWeights) || 1;
111
  sourceWeights.forEach((weight, targetIdx) => {
112
  if (weight > 0.01 && sourceIdx !== targetIdx) {
113
  const normalizedWeight = weight / rowMax;
114
  svg.append("line")
115
+ .attr("x1", 20) // left margin inside the SVG
116
+ .attr("y1", tokenSpacing * (sourceIdx + 1))
117
+ .attr("x2", width - 20) // right margin inside the SVG
118
+ .attr("y2", tokenSpacing * (targetIdx + 1))
119
  .attr("stroke", "#800000")
120
+ .attr("stroke-width", Math.max(0.5, normalizedWeight * 4))
121
+ .attr("opacity", Math.min(0.8, normalizedWeight * 2))
122
  .attr("stroke-linecap", "round");
123
  }
124
  });
125
  });
126
 
127
+ // Add token text below the graph
128
+ container.append("div")
129
+ .classed("thumbnail-token-text", true)
130
  .style("width", "100%")
131
  .style("text-align", "center")
132
+ .style("font-size", "16px")
133
+ .style("margin-top", "5px")
134
+ .text(tokens.join(" "));
135
 
136
+ return container.node();
137
  }
138
 
139
  function createAttentionThumbnail(data, layerIdx, headIdx) {
 
142
 
143
  const width = 80;
144
  const tokenHeight = 15;
145
+ const padding = 10;
146
  const paddingTop = 10; //not sur e
147
  const paddingBottom = 0; //not sur e
148
  const height = paddingTop + tokens.length * tokenHeight; // Adjust height calculation
149
 
150
  const maxLineWidth = 4;
151
  const maxOpacity = 0.8;
152
+ const xRight = width - 10;
153
 
154
  // Create a thumbnail container using D3.
155
  const thumbnail = d3.select(document.createElement("div"))