Paar, F. (Ferdinand) commited on
Commit
00a52d1
·
1 Parent(s): d38275b

mid is cool

Browse files
Files changed (1) hide show
  1. frontend/script.js +35 -24
frontend/script.js CHANGED
@@ -53,54 +53,55 @@ function renderModelView(data) {
53
  }
54
 
55
  function createAttentionThumbnail(data, layerIdx, headIdx) {
56
- // Extract tokens and attention from the data object
57
  const tokens = data.tokens;
58
  const attention = data.attention;
59
-
60
- const width = 80;
61
- const tokenHeight = 15;
62
- const padding = 10;
63
- // Compute the thumbnail height dynamically based on the number of tokens.
64
  const height = padding * 2 + tokens.length * tokenHeight;
65
  const maxLineWidth = 4;
66
  const maxOpacity = 0.8;
67
-
68
- // Compute the right-side x-coordinate numerically.
69
  const xRight = width - padding;
70
-
71
  // Create a thumbnail container using D3.
72
  const thumbnail = d3.select(document.createElement("div"))
 
73
  .style("position", "relative")
74
  .style("height", height + "px")
75
  .style("width", width + "px")
76
  .style("border", "1px solid #ddd")
77
  .style("border-radius", "4px")
78
  .style("padding", "5px")
79
- .style("background", "#fff");
80
-
81
- // Append an SVG container with fixed dimensions.
 
82
  const svg = thumbnail.append("svg")
83
  .attr("width", width)
84
- .attr("height", height);
85
-
86
- // Add header text (e.g., "L4 H4") to show the layer and head number.
 
 
87
  svg.append("text")
88
  .attr("x", width / 2)
89
  .attr("y", 15)
90
- .attr("text-anchor", "middle")
 
91
  .attr("font-size", "10")
92
  .text(`L${layerIdx + 1} H${headIdx + 1}`);
93
-
94
- // Draw attention lines with per-row normalization.
95
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
96
  const rowMax = Math.max(...sourceWeights) || 1;
97
  sourceWeights.forEach((weight, targetIdx) => {
98
  if (weight > 0.01 && sourceIdx !== targetIdx) {
99
  const normalizedWeight = weight / rowMax;
100
  svg.append("line")
101
- .attr("x1", padding)
102
  .attr("y1", padding + sourceIdx * tokenHeight - 5)
103
- .attr("x2", xRight)
104
  .attr("y2", padding + targetIdx * tokenHeight - 5)
105
  .attr("stroke", "#800000") // Bordo red
106
  .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
@@ -109,15 +110,25 @@ function createAttentionThumbnail(data, layerIdx, headIdx) {
109
  }
110
  });
111
  });
112
-
113
- // Click handler remains unchanged, using the passed-in data object.
114
  thumbnail.on("click", function() {
 
 
 
 
 
 
 
 
 
115
  d3.select("#display_head .number-box").text(headIdx + 1);
116
  d3.select("#display_layer .number-box").text(layerIdx + 1);
 
 
117
  displayHoverTokens(data, layerIdx, headIdx);
118
  });
119
-
120
- return thumbnail.node();
121
  }
122
 
123
  // Function to display the tokens and attention values
 
53
  }
54
 
55
  function createAttentionThumbnail(data, layerIdx, headIdx) {
 
56
  const tokens = data.tokens;
57
  const attention = data.attention;
58
+
59
+ const width = 80;
60
+ const tokenHeight = 15;
61
+ const padding = 10;
 
62
  const height = padding * 2 + tokens.length * tokenHeight;
63
  const maxLineWidth = 4;
64
  const maxOpacity = 0.8;
 
 
65
  const xRight = width - padding;
66
+
67
  // Create a thumbnail container using D3.
68
  const thumbnail = d3.select(document.createElement("div"))
69
+ .classed("thumbnail", true) // ✅ Ensure correct class
70
  .style("position", "relative")
71
  .style("height", height + "px")
72
  .style("width", width + "px")
73
  .style("border", "1px solid #ddd")
74
  .style("border-radius", "4px")
75
  .style("padding", "5px")
76
+ .style("background", "#fff")
77
+ .style("cursor", "pointer"); // Indicate clickability
78
+
79
+ // Append an SVG container
80
  const svg = thumbnail.append("svg")
81
  .attr("width", width)
82
+ .attr("height", height)
83
+ .attr("viewBox", `0 0 ${width} ${height}`) // not sure
84
+ .attr("preserveAspectRatio", "xMidYMid meet"); // not sure
85
+
86
+ // Add header text for layer and head
87
  svg.append("text")
88
  .attr("x", width / 2)
89
  .attr("y", 15)
90
+ .attr("text-anchor", "middle") // not sure
91
+ .attr("dominant-baseline", "middle") // not sure
92
  .attr("font-size", "10")
93
  .text(`L${layerIdx + 1} H${headIdx + 1}`);
94
+
95
+ // Draw attention lines
96
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
97
  const rowMax = Math.max(...sourceWeights) || 1;
98
  sourceWeights.forEach((weight, targetIdx) => {
99
  if (weight > 0.01 && sourceIdx !== targetIdx) {
100
  const normalizedWeight = weight / rowMax;
101
  svg.append("line")
102
+ .attr("x1", padding + 10)
103
  .attr("y1", padding + sourceIdx * tokenHeight - 5)
104
+ .attr("x2", xRight - 10)
105
  .attr("y2", padding + targetIdx * tokenHeight - 5)
106
  .attr("stroke", "#800000") // Bordo red
107
  .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
 
110
  }
111
  });
112
  });
113
+
 
114
  thumbnail.on("click", function() {
115
+ console.log(`Clicked: Layer ${layerIdx + 1}, Head ${headIdx + 1}`);
116
+
117
+ // Remove background from all thumbnails
118
+ d3.selectAll(".thumbnail").style("background", "#fff");
119
+
120
+ // Set background for clicked thumbnail
121
+ d3.select(this).style("background", "#ddd");
122
+
123
+ // Update displayed head/layer numbers
124
  d3.select("#display_head .number-box").text(headIdx + 1);
125
  d3.select("#display_layer .number-box").text(layerIdx + 1);
126
+
127
+ // Display updated hover tokens
128
  displayHoverTokens(data, layerIdx, headIdx);
129
  });
130
+
131
+ return thumbnail.node(); // Return raw DOM node for appending
132
  }
133
 
134
  // Function to display the tokens and attention values