Spaces:
Running
Running
File size: 10,628 Bytes
740fd55 9c67f19 5918cee c5b0138 1f4a0b6 c5b0138 1f4a0b6 c5b0138 5918cee 90fb67f 4589db1 90fb67f d31e7ee 42370b7 362d318 90fb67f c5b0138 1f4a0b6 c5b0138 d06b7a5 d31e7ee 5918cee c8ab699 5918cee d31e7ee 5918cee d31e7ee 5918cee 42370b7 fc50382 2e082b4 fc50382 2e082b4 42370b7 fc50382 42370b7 2e082b4 fc50382 2e082b4 fc50382 2e082b4 42370b7 2e082b4 fc50382 2e082b4 fc50382 2e082b4 fc50382 2e082b4 fc50382 42370b7 fc50382 2e082b4 42370b7 fc50382 2e082b4 fc50382 2e082b4 42370b7 2e082b4 42370b7 fc50382 42370b7 fc50382 42370b7 fc50382 5918cee d31e7ee 00a52d1 42370b7 fc50382 4207df3 42370b7 00a52d1 71654ae 501ea17 fc50382 501ea17 fc50382 00a52d1 501ea17 5bdfcd7 00a52d1 fc50382 1058f09 00a52d1 501ea17 5bdfcd7 fc50382 00a52d1 501ea17 fc50382 00a52d1 e4a48bc 5918cee 5bdfcd7 e4a48bc 501ea17 1058f09 fc50382 1058f09 fc50382 d31e7ee 501ea17 5918cee 00a52d1 501ea17 00a52d1 2e082b4 00a52d1 2e082b4 501ea17 5918cee 00a52d1 5918cee d31e7ee c5b0138 a86908c 1f4a0b6 d06b7a5 ff1e5bf d06b7a5 ff1e5bf d06b7a5 ff1e5bf d06b7a5 4589db1 d06b7a5 ff1e5bf d06b7a5 d38275b d06b7a5 7e18603 d06b7a5 29c541d d06b7a5 29c541d b63462e 279b7e9 d06b7a5 29c541d 4589db1 29c541d 4589db1 29c541d e7425c5 29c541d 4589db1 29c541d 4589db1 e7425c5 4589db1 29c541d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 |
// Form submit handler 2.1
// Sends the text typed into #inputText to the /process endpoint as JSON and
// hands the parsed reply to the three renderers. Anything that goes wrong
// inside the try block (network failure, non-2xx status, unparsable JSON, or
// an exception thrown while rendering) is logged and reported in #output.
document.getElementById('textForm').addEventListener('submit', async (event) => {
  event.preventDefault();
  const { value: inputText } = document.getElementById('inputText');

  try {
    const requestOptions = {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text: inputText }),
    };
    const response = await fetch('/process', requestOptions);
    if (!response.ok) {
      throw new Error('Network response was not ok');
    }

    // The whole reply object is passed on; each renderer reads `tokens` and
    // `attention` from it.
    const payload = await response.json();
    displayHoverTokens(payload, 0, 0); // start on the first layer / first head
    renderModelView(payload);
    createLargeThumbnail(payload);
  } catch (error) {
    console.error('Error:', error);
    document.getElementById('output').innerText = 'Error processing text.';
  }
});
/**
 * Rebuilds the grid of per-head attention thumbnails inside
 * #model_view_container: one row per layer, one column per head.
 *
 * The layer and head counts are read from `data.attention` rather than being
 * fixed at 12 x 12, so a model with fewer layers or heads no longer indexes
 * past the end of the array, and a larger model is shown in full.
 *
 * @param {{tokens: string[], attention: Array}} data - Response object.
 *   `attention` is indexed as [layer][head][source][target]. The column count
 *   is taken from the first layer (assumes all layers have the same number of
 *   heads).
 * @returns {void} Does nothing when the container element is missing; only
 *   empties the container when `data.attention` holds no layers.
 */
function renderModelView(data) {
  const container = document.getElementById("model_view_container");
  if (!container) return;
  container.innerHTML = "";

  const layers = data && Array.isArray(data.attention) ? data.attention : [];
  if (layers.length === 0) return;
  const headsPerLayer = Array.isArray(layers[0]) ? layers[0].length : 0;

  const gridContainer = document.createElement("div");
  gridContainer.style.display = "grid";
  // One 80px column per head, matching the width of a single thumbnail.
  gridContainer.style.gridTemplateColumns = `repeat(${headsPerLayer}, 80px)`;
  gridContainer.style.gridGap = "10px";
  gridContainer.style.padding = "20px";

  // Row-major order: all heads of layer 0, then all heads of layer 1, ...
  for (let layerIdx = 0; layerIdx < layers.length; layerIdx++) {
    const headCount = Array.isArray(layers[layerIdx]) ? layers[layerIdx].length : 0;
    for (let headIdx = 0; headIdx < headCount; headIdx++) {
      // The complete data object is passed on; the thumbnail picks its own slice.
      gridContainer.appendChild(createAttentionThumbnail(data, layerIdx, headIdx));
    }
  }

  container.appendChild(gridContainer);
}
/**
 * Renders the large attention diagram for one head into #thumbnailContainer:
 * a column of tokens on either side of an SVG whose lines connect each source
 * token (left) to the target tokens it attends to (right).
 *
 * Any previously rendered #largeThumbnail is removed first so only the latest
 * diagram is shown.
 *
 * @param {{tokens: string[], attention: Array}} data - Response object;
 *   `attention` is indexed as [layer][head][source][target].
 * @param {number} [layerIdx=0] - Zero-based layer to draw.
 * @param {number} [headIdx=0] - Zero-based head to draw.
 * @returns {*} Whatever `.node()` yields for the new #largeThumbnail selection.
 * @throws {TypeError} If `data.attention[layerIdx]` is missing.
 */
function createLargeThumbnail(data, layerIdx = 0, headIdx = 0) {
  const tokens = data.tokens;
  // Resolved before any DOM work so a bad index fails without first removing
  // the previous diagram or leaving a half-built one behind.
  const headAttention = data.attention[layerIdx][headIdx];

  // Dimensions (px)
  const graphWidth = 300;
  const graphHeight = 300;
  const tokenAreaWidth = 60;
  const totalWidth = tokenAreaWidth * 2 + graphWidth;
  const lineInset = 20; // horizontal margin of the lines inside the SVG

  // Remove any existing large thumbnail so only the latest shows
  d3.select("#largeThumbnail").remove();

  // Main container: a flex row of [left tokens | svg | right tokens]
  const container = d3.select("#thumbnailContainer")
    .append("div")
    .attr("id", "largeThumbnail")
    .style("width", totalWidth + "px")
    .style("height", graphHeight + "px")
    .style("border", "2px solid #ddd")
    .style("border-radius", "6px")
    .style("background", "#fff")
    .style("display", "flex")
    .style("align-items", "center")
    .style("justify-content", "space-between");

  // Appends one vertical token column; both sides differ only in alignment.
  const appendTokenColumn = (className, crossAxisAlign, textAlign) => {
    const column = container.append("div")
      .attr("class", className)
      .style("width", tokenAreaWidth + "px")
      .style("height", graphHeight + "px")
      .style("display", "flex")
      .style("flex-direction", "column")
      .style("justify-content", "space-around")
      .style("align-items", crossAxisAlign);
    tokens.forEach((token) => {
      column.append("div")
        .style("font-size", "14px")
        .style("text-align", textAlign)
        .text(token);
    });
    return column;
  };

  // Append order defines the left-to-right order inside the flex row.
  appendTokenColumn("tokens-left", "flex-end", "right");
  const svg = container.append("svg")
    .attr("width", graphWidth)
    .attr("height", graphHeight)
    .attr("viewBox", `0 0 ${graphWidth} ${graphHeight}`)
    .attr("preserveAspectRatio", "none");
  appendTokenColumn("tokens-right", "flex-start", "left");

  // Header text inside the SVG (1-based for display)
  svg.append("text")
    .attr("x", graphWidth / 2)
    .attr("y", 20)
    .attr("text-anchor", "middle")
    .attr("font-size", "14px")
    .text(`Layer ${layerIdx + 1} - Head ${headIdx + 1}`);

  // `justify-content: space-around` centres the i-th of n equally tall items
  // at (i + 0.5) * height / n (as long as they fit), so line endpoints use the
  // same formula to land on their tokens.
  const rowHeight = graphHeight / tokens.length;
  const rowCenterY = (tokenIdx) => rowHeight * (tokenIdx + 0.5);

  headAttention.forEach((sourceWeights, sourceIdx) => {
    // Normalise per source row; `|| 1` avoids dividing by zero on an all-zero row.
    const rowMax = Math.max(...sourceWeights) || 1;
    sourceWeights.forEach((weight, targetIdx) => {
      // Skip negligible weights and self-attention.
      if (weight > 0.01 && sourceIdx !== targetIdx) {
        const normalizedWeight = weight / rowMax;
        svg.append("line")
          .attr("x1", lineInset)
          .attr("y1", rowCenterY(sourceIdx))
          .attr("x2", graphWidth - lineInset)
          .attr("y2", rowCenterY(targetIdx))
          .attr("stroke", "#800000")
          .attr("stroke-width", Math.max(0.5, normalizedWeight * 4))
          .attr("opacity", Math.min(0.8, normalizedWeight * 2))
          .attr("stroke-linecap", "round");
      }
    });
  });

  return container.node();
}
/**
 * Builds the small clickable preview for a single attention head.
 *
 * The preview is a bordered <div class="thumbnail"> holding an SVG with an
 * "L<layer> H<head>" caption and one line per (source, target) pair whose
 * weight is above 0.01, excluding self-attention. Line thickness and opacity
 * grow with the weight relative to the largest weight in the same source row.
 *
 * Clicking the preview highlights it, writes the 1-based head and layer
 * numbers into #display_head / #display_layer and refreshes the hover tokens
 * for that head.
 *
 * @param {{tokens: string[], attention: Array}} data - Response object;
 *   `attention` is indexed as [layer][head][source][target].
 * @param {number} layerIdx - Zero-based layer index.
 * @param {number} headIdx - Zero-based head index.
 * @returns {*} The detached node behind the D3 selection, ready for appending.
 */
function createAttentionThumbnail(data, layerIdx, headIdx) {
  const { tokens, attention } = data;

  // Geometry (px)
  const boxWidth = 80;
  const rowHeight = 15;
  const captionHeight = 10;
  const leftX = 10;
  const rightX = boxWidth - 10;
  const boxHeight = captionHeight + tokens.length * rowHeight;
  const firstRowY = captionHeight + 5;

  // Upper bounds for the line styling
  const strokeCap = 4;
  const opacityCap = 0.8;

  const thumbnail = d3.select(document.createElement("div"))
    .classed("thumbnail", true)
    .style("position", "relative")
    .style("height", `${boxHeight}px`)
    .style("width", `${boxWidth}px`)
    .style("border", "1px solid #ddd")
    .style("border-radius", "4px")
    .style("padding", "10px 0px 0px 0px")
    .style("background", "#fff")
    .style("cursor", "pointer"); // signals that the preview is clickable

  const svg = thumbnail.append("svg")
    .attr("width", boxWidth)
    .attr("height", boxHeight)
    .attr("viewBox", `0 0 ${boxWidth} ${boxHeight}`)
    .attr("preserveAspectRatio", "none");

  // Caption with 1-based layer and head numbers
  svg.append("text")
    .attr("x", boxWidth / 2)
    .attr("y", captionHeight - 2)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle")
    .attr("font-size", "10")
    .text(`L${layerIdx + 1} H${headIdx + 1}`);

  const rowY = (tokenIdx) => firstRowY + tokenIdx * rowHeight;

  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
    // `|| 1` keeps the division below safe for an all-zero row.
    const rowPeak = Math.max(...sourceWeights) || 1;

    sourceWeights.forEach((weight, targetIdx) => {
      const isDrawable = weight > 0.01 && sourceIdx !== targetIdx;
      if (!isDrawable) {
        return;
      }
      const relative = weight / rowPeak;
      svg.append("line")
        .attr("x1", leftX)
        .attr("y1", rowY(sourceIdx))
        .attr("x2", rightX)
        .attr("y2", rowY(targetIdx))
        .attr("stroke", "#800000") // bordeaux red
        .attr("stroke-width", Math.max(0.5, relative * strokeCap))
        .attr("opacity", Math.min(opacityCap, relative * 2))
        .attr("stroke-linecap", "round");
    });
  });

  // A regular function (not an arrow) because the handler relies on `this`.
  thumbnail.on("click", function () {
    console.log(`Clicked: Layer ${layerIdx + 1}, Head ${headIdx + 1}`);

    // Clear the highlight on every thumbnail, then mark this one.
    d3.selectAll(".thumbnail").style("background", "#fff");
    d3.select(this).style("background", "#ddd");

    d3.select("#display_head").text(headIdx + 1);
    d3.select("#display_layer").text(layerIdx + 1);

    displayHoverTokens(data, layerIdx, headIdx);
  });

  return thumbnail.node();
}
// Disabled debug helper (commented out): dumps the raw tokens and attention values as JSON into #output.
// function displayOutput(data) {
// const outputDiv = document.getElementById('output');
// outputDiv.innerHTML = `
// <h2>Tokens</h2>
// <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
// <h2>Attention</h2>
// <pre>${JSON.stringify(data.attention, null, 2)}</pre>
// `;
// }
/**
 * Fills #tokenContainer with one <span> per token and wires up the hover
 * behaviour for the given layer/head.
 *
 * Each span shows the token with its first "Ġ" removed, followed by a space,
 * at 32px. Entering a span highlights the attention of that token; leaving it
 * resets all token styles.
 *
 * @param {string[]} tokens - Tokens to display, in order.
 * @param {Array} attentionData - Attention weights, passed through untouched
 *   to highlightAttention.
 * @param {number} layer_idx - Zero-based layer used when highlighting.
 * @param {number} head_idx - Zero-based head used when highlighting.
 * @returns {void}
 */
function renderTokens(tokens, attentionData, layer_idx, head_idx) {
  const tokenHost = document.getElementById('tokenContainer');
  tokenHost.innerHTML = "";

  const buildTokenSpan = (label, position) => {
    const el = document.createElement('span');
    el.textContent = `${label.replace("Ġ", "")} `;
    el.style.fontSize = "32px";
    el.addEventListener('mouseenter', () => {
      highlightAttention(position, attentionData, layer_idx, head_idx);
    });
    el.addEventListener('mouseleave', () => {
      resetTokenSizes();
    });
    return el;
  };

  tokens.forEach((label, position) => {
    tokenHost.appendChild(buildTokenSpan(label, position));
  });
}
/**
 * Renders the hoverable token row for one layer/head, substituting
 * placeholder content when the response lacks tokens or attention.
 *
 * The placeholder attention has the same four-level shape the real data is
 * read with, [layer][head][source][target]: 12 x 12 heads, each an n x n
 * matrix of zeros for the n placeholder tokens. With the previous
 * three-level placeholder every per-token row was the number 0, which
 * highlightAttention treats as missing data and warns about on each hover.
 *
 * @param {?{tokens: string[], attention: Array}} data - Response object; may
 *   be null/undefined or lack either field, in which case the placeholder is
 *   shown.
 * @param {number} layer_idx - Zero-based layer to highlight on hover.
 * @param {number} head_idx - Zero-based head to highlight on hover.
 * @returns {void}
 */
function displayHoverTokens(data, layer_idx, head_idx) {
  if (data && data.tokens && data.attention) {
    renderTokens(data.tokens, data.attention, layer_idx, head_idx);
    return;
  }

  const placeholderTokens = ['This', 'is', 'a', 'test', '.'];
  const tokenCount = placeholderTokens.length;
  // Every row is a fresh array so no two rows share state.
  const makeZeroHead = () =>
    Array.from({ length: tokenCount }, () => Array(tokenCount).fill(0));
  const placeholderAttention = Array.from({ length: 12 }, () =>
    Array.from({ length: 12 }, makeZeroHead));

  renderTokens(placeholderTokens, placeholderAttention, layer_idx, head_idx);
}
/**
 * Restores every child of #tokenContainer to the resting token style:
 * 32px font size and #555 text colour.
 *
 * @returns {void}
 */
function resetTokenSizes() {
  const tokenHost = document.getElementById("tokenContainer");
  for (const tokenSpan of Array.from(tokenHost.children)) {
    Object.assign(tokenSpan.style, { fontSize: "32px", color: "#555" });
  }
}
/**
 * Resizes and recolours the spans in #tokenContainer to reflect how strongly
 * the hovered token attends to each other token.
 *
 * The weight row is `attentionData[layer_idx][head_idx][index]`. A span whose
 * position has a numeric weight grows from 32px by up to 20px in proportion
 * to weight / row maximum; the span holding the row maximum turns #800000 and
 * all others #555. Spans without a numeric weight are reset to 32px / #555.
 *
 * @param {number} index - Position of the hovered (source) token.
 * @param {Array} attentionData - Weights indexed as
 *   [layer][head][source][target].
 * @param {number} layer_idx - Zero-based layer index.
 * @param {number} head_idx - Zero-based head index.
 * @returns {void} Returns early (after a console warning) when the row is
 *   missing, and silently when the row is empty.
 */
function highlightAttention(index, attentionData, layer_idx, head_idx) {
  const tokenHost = document.getElementById('tokenContainer');
  const weights = attentionData[layer_idx][head_idx][index];

  if (!weights) {
    console.warn(`No attention data for token index ${index}`);
    return;
  }
  if (!weights.length) {
    return;
  }

  const baseFontSize = 32;
  const maxIncrease = 20;
  // `|| 1` keeps the ratio finite for an all-zero row; in that case no
  // position equals the peak and nothing is coloured red.
  const peak = Math.max(...weights) || 1;
  const peakPosition = weights.indexOf(peak);

  Array.from(tokenHost.children).forEach((tokenSpan, position) => {
    const weight = weights[position];
    const hasWeight = typeof weight === 'number';

    const fontSize = hasWeight
      ? baseFontSize + (weight / peak) * maxIncrease
      : baseFontSize;
    tokenSpan.style.fontSize = `${fontSize}px`;

    const isPeak = hasWeight && position === peakPosition;
    tokenSpan.style.color = isPeak ? "#800000" : "#555";
  });
}
|