// Form submit handler 2.1
//
// Browser script for the attention visualisation page. It relies on the page
// providing a global `d3` and the element ids referenced below (#textForm,
// #inputText, #output, #model_view_container, #thumbnailContainer,
// #tokenContainer, #display_head, #display_layer).
//
// Attention data is indexed as attention[layer][head][row][col]; every view
// below treats a row as the "source" token and a column as the "target" token.

// Shared styling so every view draws attention the same way.
const ATTENTION_COLOR = "#800000"; // Bordo red
const TOKEN_COLOR = "#555";
const TOKEN_BASE_FONT_SIZE = 32;
const MIN_ATTENTION_WEIGHT = 0.01; // links at or below this weight are not drawn
const THUMBNAIL_WIDTH = 80;

document.getElementById('textForm').addEventListener('submit', async (e) => {
  e.preventDefault();
  const inputText = document.getElementById('inputText').value;

  try {
    const response = await fetch('/process', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text: inputText }),
    });
    if (!response.ok) {
      throw new Error('Network response was not ok');
    }
    // NOTE(review): expected to contain `tokens` and `attention`; the schema is
    // defined by the /process endpoint, not by this file.
    const data = await response.json();

    // Fresh data is always shown at layer 1 / head 1, so keep the labels that
    // the thumbnail click handler maintains in sync with the hover view.
    d3.select("#display_layer").text(1);
    d3.select("#display_head").text(1);

    displayHoverTokens(data, 0, 0);
    renderModelView(data);
    createLargeThumbnail(data);
  } catch (error) {
    console.error('Error:', error);
    // Guard the lookup so a missing #output element cannot throw from the catch block.
    const output = document.getElementById('output');
    if (output) {
      output.innerText = 'Error processing text.';
    }
  }
});

/**
 * Draws one line per (source, target) token pair of a single head's attention
 * matrix. Each row is normalised by its own maximum, so every source token's
 * strongest link is drawn at full strength. Self-links and weights at or below
 * MIN_ATTENTION_WEIGHT are skipped to reduce clutter.
 *
 * @param {object} svg - D3 selection of the <svg> element to draw into.
 * @param {number[][]} headAttention - attention[layer][head]; rows are source tokens.
 * @param {object} layout
 * @param {number} layout.x1 - x coordinate of every line's source end.
 * @param {number} layout.x2 - x coordinate of every line's target end.
 * @param {(tokenIdx: number) => number} layout.y - maps a token index to its y coordinate.
 * @param {number} [layout.maxLineWidth=4] - stroke width of a full-strength link.
 * @param {number} [layout.maxOpacity=0.8] - opacity cap for strong links.
 */
function drawAttentionLines(svg, headAttention, { x1, x2, y, maxLineWidth = 4, maxOpacity = 0.8 }) {
  headAttention.forEach((sourceWeights, sourceIdx) => {
    // `|| 1` keeps an all-zero row from dividing by zero.
    const rowMax = Math.max(...sourceWeights) || 1;
    sourceWeights.forEach((weight, targetIdx) => {
      if (weight > MIN_ATTENTION_WEIGHT && sourceIdx !== targetIdx) {
        const normalizedWeight = weight / rowMax;
        svg.append("line")
          .attr("x1", x1)
          .attr("y1", y(sourceIdx))
          .attr("x2", x2)
          .attr("y2", y(targetIdx))
          .attr("stroke", ATTENTION_COLOR)
          .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
          .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
          .attr("stroke-linecap", "round");
      }
    });
  });
}

/**
 * Renders a grid of attention thumbnails into #model_view_container: one row
 * per layer, one column per head. The grid size comes from the shape of
 * `data.attention`, so models of any depth and width are supported.
 *
 * @param {{tokens: string[], attention: number[][][][]}} data - /process response.
 */
function renderModelView(data) {
  const container = document.getElementById("model_view_container");
  if (!container) return;
  container.innerHTML = "";

  const attention = data.attention;
  if (!Array.isArray(attention) || attention.length === 0) return;
  // Column count follows the first layer's head count.
  const headCount = attention[0].length;

  const gridContainer = document.createElement("div");
  gridContainer.style.display = "grid";
  gridContainer.style.gridTemplateColumns = `repeat(${headCount}, ${THUMBNAIL_WIDTH}px)`;
  gridContainer.style.gridGap = "10px";
  gridContainer.style.padding = "20px";

  attention.forEach((layer, layerIdx) => {
    layer.forEach((_, headIdx) => {
      gridContainer.appendChild(createAttentionThumbnail(data, layerIdx, headIdx));
    });
  });

  container.appendChild(gridContainer);
}

/**
 * Appends a fixed-size column of token labels to a D3 container. Each label's
 * vertical centre is placed at `yFor(index)`, so it lines up exactly with SVG
 * line endpoints drawn at the same y.
 *
 * @param {object} container - D3 selection to append the column to.
 * @param {string[]} tokens - labels, in token order.
 * @param {"left"|"right"} side - which side of the graph the column sits on.
 * @param {number} width - column width in px.
 * @param {number} height - column height in px (must match the SVG height).
 * @param {(tokenIdx: number) => number} yFor - y coordinate of token `idx`.
 */
function appendTokenColumn(container, tokens, side, width, height, yFor) {
  // Left labels hug their right edge (next to the graph); right labels hug their left edge.
  const anchor = side === "left" ? "right" : "left";
  const column = container.append("div")
    .attr("class", `tokens-${side}`)
    .style("position", "relative")
    .style("width", width + "px")
    .style("height", height + "px");

  tokens.forEach((token, idx) => {
    column.append("div")
      .style("position", "absolute")
      .style("top", yFor(idx) + "px")
      .style(anchor, "0")
      .style("transform", "translateY(-50%)") // centre the label on its y
      .style("font-size", "14px")
      .style("text-align", anchor)
      .text(token);
  });
}

/**
 * Draws an enlarged view of one head's attention into #thumbnailContainer:
 * token labels on both sides with the attention lines between them. Any
 * previously drawn large view (#largeThumbnail) is removed first.
 *
 * @param {{tokens: string[], attention: number[][][][]}} data - /process response.
 * @param {number} [layerIdx=0] - zero-based layer to show.
 * @param {number} [headIdx=0] - zero-based head to show.
 * @returns {Element|null} the new #largeThumbnail element, or null if
 *   #thumbnailContainer does not exist.
 */
function createLargeThumbnail(data, layerIdx = 0, headIdx = 0) {
  const tokens = data.tokens;
  const headAttention = data.attention[layerIdx][headIdx];

  const graphWidth = 300;
  const graphHeight = 300;
  const tokenAreaWidth = 60;
  const totalWidth = tokenAreaWidth * 2 + graphWidth;
  const lineInset = 20; // horizontal margin of line endpoints inside the SVG

  // Labels and line endpoints share this mapping, which keeps them aligned.
  const tokenSpacing = graphHeight / (tokens.length + 1);
  const tokenY = (idx) => tokenSpacing * (idx + 1);

  // Only the latest large view is shown.
  d3.select("#largeThumbnail").remove();

  const container = d3.select("#thumbnailContainer")
    .append("div")
    .attr("id", "largeThumbnail")
    .style("width", totalWidth + "px")
    .style("height", graphHeight + "px")
    .style("border", "2px solid #ddd")
    .style("border-radius", "6px")
    .style("background", "#fff")
    .style("display", "flex")
    .style("align-items", "center")
    .style("justify-content", "space-between");

  // DOM order (left labels, graph, right labels) defines the flex layout.
  appendTokenColumn(container, tokens, "left", tokenAreaWidth, graphHeight, tokenY);

  const svg = container.append("svg")
    .attr("width", graphWidth)
    .attr("height", graphHeight)
    .attr("viewBox", `0 0 ${graphWidth} ${graphHeight}`)
    .attr("preserveAspectRatio", "none");

  appendTokenColumn(container, tokens, "right", tokenAreaWidth, graphHeight, tokenY);

  svg.append("text")
    .attr("x", graphWidth / 2)
    .attr("y", 20)
    .attr("text-anchor", "middle")
    .attr("font-size", "14px")
    .text(`Layer ${layerIdx + 1} - Head ${headIdx + 1}`);

  drawAttentionLines(svg, headAttention, {
    x1: lineInset,
    x2: graphWidth - lineInset,
    y: tokenY,
  });

  return container.node();
}

/**
 * Builds a small clickable thumbnail of one head's attention pattern. Clicking
 * it highlights the thumbnail, updates #display_layer / #display_head, and
 * re-renders the hover tokens for that layer and head.
 *
 * @param {{tokens: string[], attention: number[][][][]}} data - /process response.
 * @param {number} layerIdx - zero-based layer index.
 * @param {number} headIdx - zero-based head index.
 * @returns {Element} the detached thumbnail <div>, ready to append.
 */
function createAttentionThumbnail(data, layerIdx, headIdx) {
  const tokens = data.tokens;
  const width = THUMBNAIL_WIDTH;
  const tokenHeight = 15;   // vertical pitch between token rows
  const padding = 10;       // x of the lines' source ends
  const paddingTop = 10;    // room at the top of the SVG for the caption
  const height = paddingTop + tokens.length * tokenHeight;
  const xRight = width - 10; // x of the lines' target ends
  const yStart = paddingTop + 5; // y of token 0, just below the caption

  const thumbnail = d3.select(document.createElement("div"))
    .classed("thumbnail", true)
    .style("height", height + "px")
    .style("width", width + "px")
    .style("border", "1px solid #ddd")
    .style("border-radius", "4px")
    .style("padding", "10px 0px 0px 0px")
    .style("background", "#fff")
    .style("cursor", "pointer"); // indicate clickability

  // The viewBox equals the rendered size, so "none" never actually stretches the drawing.
  const svg = thumbnail.append("svg")
    .attr("width", width)
    .attr("height", height)
    .attr("viewBox", `0 0 ${width} ${height}`)
    .attr("preserveAspectRatio", "none");

  // Caption centred horizontally and vertically around (width / 2, paddingTop - 2).
  svg.append("text")
    .attr("x", width / 2)
    .attr("y", paddingTop - 2)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle")
    .attr("font-size", "10")
    .text(`L${layerIdx + 1} H${headIdx + 1}`);

  drawAttentionLines(svg, data.attention[layerIdx][headIdx], {
    x1: padding,
    x2: xRight,
    y: (idx) => yStart + idx * tokenHeight,
  });

  // A regular function (not an arrow) so `this` is the clicked element.
  thumbnail.on("click", function () {
    console.log(`Clicked: Layer ${layerIdx + 1}, Head ${headIdx + 1}`);
    d3.selectAll(".thumbnail").style("background", "#fff");
    d3.select(this).style("background", "#ddd");

    d3.select("#display_head").text(headIdx + 1);
    d3.select("#display_layer").text(layerIdx + 1);

    displayHoverTokens(data, layerIdx, headIdx);
  });

  return thumbnail.node();
}

/**
 * Renders `tokens` as hoverable spans in #tokenContainer. Hovering a token
 * resizes and recolours the other tokens by that token's attention row
 * (see highlightAttention); leaving restores the default look.
 *
 * @param {string[]} tokens - token strings to display.
 * @param {number[][][][]} attentionData - attention[layer][head][source][target].
 * @param {number} layer_idx - zero-based layer used for highlighting.
 * @param {number} head_idx - zero-based head used for highlighting.
 */
function renderTokens(tokens, attentionData, layer_idx, head_idx) {
  const container = document.getElementById('tokenContainer');
  if (!container) return;
  container.innerHTML = "";

  tokens.forEach((token, index) => {
    const span = document.createElement('span');
    // Drops the first "Ġ" (presumably the tokenizer's word-boundary marker —
    // confirm) and separates tokens with a plain space instead.
    span.textContent = token.replace("Ġ", "") + " ";
    span.style.fontSize = `${TOKEN_BASE_FONT_SIZE}px`;
    span.addEventListener('mouseenter', () => {
      highlightAttention(index, attentionData, layer_idx, head_idx);
    });
    span.addEventListener('mouseleave', () => {
      resetTokenSizes();
    });
    container.appendChild(span);
  });
}

/**
 * Shows the hoverable token strip for `data`. If the response lacks tokens or
 * attention, a fixed placeholder sentence is shown with an all-zero 12×12
 * attention tensor, so hovering does not enlarge any token.
 *
 * @param {{tokens?: string[], attention?: number[][][][]}|null} data - /process response.
 * @param {number} layer_idx - zero-based layer used for highlighting.
 * @param {number} head_idx - zero-based head used for highlighting.
 */
function displayHoverTokens(data, layer_idx, head_idx) {
  if (data && data.tokens && data.attention) {
    renderTokens(data.tokens, data.attention, layer_idx, head_idx);
    return;
  }

  const tokens = ['This', 'is', 'a', 'test', '.'];
  // Must be 4-D ([layer][head][source][target]) to match what highlightAttention indexes.
  const zeroHead = () => tokens.map(() => tokens.map(() => 0));
  const attentionMatrix = Array.from({ length: 12 }, () =>
    Array.from({ length: 12 }, zeroHead));
  renderTokens(tokens, attentionMatrix, layer_idx, head_idx);
}

/** Restores every token in #tokenContainer to the default size and colour. */
function resetTokenSizes() {
  const container = document.getElementById("tokenContainer");
  if (!container) return;
  Array.from(container.children).forEach((span) => {
    span.style.fontSize = `${TOKEN_BASE_FONT_SIZE}px`;
    span.style.color = TOKEN_COLOR;
  });
}

/**
 * Scales every token in #tokenContainer by how much token `index` attends to
 * it in the given layer/head. Sizes are relative to the row maximum, and the
 * most-attended token is coloured ATTENTION_COLOR. Tokens without a weight are
 * reset to the default look. If the layer, head or row is missing, a warning
 * is logged and nothing changes.
 *
 * @param {number} index - zero-based index of the hovered (source) token.
 * @param {number[][][][]} attentionData - attention[layer][head][source][target].
 * @param {number} layer_idx - zero-based layer.
 * @param {number} head_idx - zero-based head.
 */
function highlightAttention(index, attentionData, layer_idx, head_idx) {
  const container = document.getElementById('tokenContainer');
  if (!container) return;

  const layer = attentionData[layer_idx];
  const head = layer ? layer[head_idx] : undefined;
  const weights = head ? head[index] : undefined;
  if (!Array.isArray(weights)) {
    console.warn(`No attention data for token index ${index}`);
    return;
  }
  if (weights.length === 0) return;

  // `|| 1` avoids dividing by zero for an all-zero row; indexOf then finds no
  // match, so no token gets the highlight colour.
  const maxWeight = Math.max(...weights) || 1;
  const maxIndex = weights.indexOf(maxWeight);
  const maxIncrease = 20;

  Array.from(container.children).forEach((span, idx) => {
    const weight = weights[idx];
    if (typeof weight !== 'number') {
      span.style.fontSize = `${TOKEN_BASE_FONT_SIZE}px`;
      span.style.color = TOKEN_COLOR;
      return;
    }
    span.style.fontSize = `${TOKEN_BASE_FONT_SIZE + (weight / maxWeight) * maxIncrease}px`;
    span.style.color = idx === maxIndex ? ATTENTION_COLOR : TOKEN_COLOR;
  });
}