Spaces:
Running
Running
// Form submit handler 2.1
// Posts the #inputText value as JSON to /process and, on success, renders every
// visualization (hover tokens, thumbnail grid, large thumbnail) from the returned payload.
document.getElementById('textForm').addEventListener('submit', async (e) => {
  e.preventDefault();
  const inputText = document.getElementById('inputText').value;
  try {
    const response = await fetch('/process', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text: inputText }),
    });
    if (!response.ok) {
      // Carry the HTTP status into the logged error so failures are diagnosable.
      throw new Error(`Network response was not ok (status ${response.status})`);
    }
    const data = await response.json();
    // The payload is consumed as data.tokens and data.attention by the renderers below.
    // The hover view starts on layer 0 / head 0, the same head the large thumbnail draws.
    displayHoverTokens(data, 0, 0);
    renderModelView(data);
    createLargeThumbnail(data);
    // The thumbnail click handler writes the selected (1-based) layer/head into these
    // labels; reset them so they match the layer 0 / head 0 view shown after a new run.
    ['display_layer', 'display_head'].forEach((id) => {
      const label = document.getElementById(id);
      if (label) {
        label.textContent = '1';
      }
    });
  } catch (error) {
    console.error('Error:', error);
    // #output may not exist on the page; writing to null here would throw a second,
    // unhandled error from inside this catch block.
    const output = document.getElementById('output');
    if (output) {
      output.innerText = 'Error processing text.';
    }
  }
});
/**
 * Renders the grid of per-head attention thumbnails into #model_view_container.
 * One grid row per layer and one column per head. The layer and head counts are read
 * from `data.attention` rather than fixed at 12 x 12, so models of other sizes render
 * instead of crashing on a missing layer.
 *
 * @param {{tokens: string[], attention: Array}} data - Payload from /process;
 *   `attention` is indexed [layer][head][source][target] (as createAttentionThumbnail reads it).
 */
function renderModelView(data) {
  const container = document.getElementById("model_view_container");
  if (!container) return;
  container.innerHTML = "";

  const layers = data.attention || [];
  // Column count follows the first layer's head count; 12 was the previous fixed width.
  const headCount = (layers[0] && layers[0].length) || 12;

  const gridContainer = document.createElement("div");
  gridContainer.style.display = "grid";
  gridContainer.style.gridTemplateColumns = `repeat(${headCount}, 80px)`;
  gridContainer.style.gap = "10px";
  gridContainer.style.padding = "20px";

  // One thumbnail per (layer, head); each receives the full payload so it can drive the hover view.
  layers.forEach((heads, layerIdx) => {
    heads.forEach((_, headIdx) => {
      gridContainer.appendChild(createAttentionThumbnail(data, layerIdx, headIdx));
    });
  });

  container.appendChild(gridContainer);
}
/**
 * Draws the enlarged attention view for one head inside #thumbnailContainer.
 * Token labels run down the left and right edges. Between them an SVG draws a line from
 * source token i (left) to target token j (right) for every off-diagonal weight > 0.01.
 * Line width and opacity scale with the weight divided by its row's maximum.
 * Any previously rendered #largeThumbnail is removed first, so only the latest is shown.
 *
 * @param {{tokens: string[], attention: Array}} data - Payload from /process;
 *   `attention` is indexed [layer][head][source][target].
 * @param {number} [layerIdx=0] - 0-based layer to draw (defaults to the first layer).
 * @param {number} [headIdx=0] - 0-based head to draw (defaults to the first head).
 * @returns {HTMLElement} The new #largeThumbnail container node.
 */
function createLargeThumbnail(data, layerIdx = 0, headIdx = 0) {
  const tokens = data.tokens;
  const attention = data.attention;
  // Dimensions
  const graphWidth = 300;
  const tokenAreaWidth = 60;
  const totalWidth = tokenAreaWidth * 2 + graphWidth;
  const graphHeight = 300;

  // Remove any existing large thumbnail so only the latest shows
  d3.select("#largeThumbnail").remove();

  // Main container: [left tokens | graph | right tokens] as a flex row
  const container = d3.select("#thumbnailContainer")
    .append("div")
    .attr("id", "largeThumbnail")
    .style("width", totalWidth + "px")
    .style("height", graphHeight + "px")
    .style("border", "2px solid #ddd")
    .style("border-radius", "6px")
    .style("background", "#fff")
    .style("display", "flex")
    .style("align-items", "center")
    .style("justify-content", "space-between");

  // Appends one full-height column of token labels; the two columns differ only in alignment.
  function appendTokenColumn(className, alignItems, textAlign) {
    const column = container.append("div")
      .attr("class", className)
      .style("width", tokenAreaWidth + "px")
      .style("height", graphHeight + "px")
      .style("display", "flex")
      .style("flex-direction", "column")
      .style("justify-content", "space-around")
      .style("align-items", alignItems);
    tokens.forEach(token => {
      column.append("div")
        .style("font-size", "14px")
        .style("text-align", textAlign)
        .text(token);
    });
  }

  appendTokenColumn("tokens-left", "flex-end", "right");

  // Center SVG for the graph
  const svg = container.append("svg")
    .attr("width", graphWidth)
    .attr("height", graphHeight)
    .attr("viewBox", `0 0 ${graphWidth} ${graphHeight}`)
    .attr("preserveAspectRatio", "none");

  appendTokenColumn("tokens-right", "flex-start", "left");

  // Header text inside the SVG (1-based for display)
  svg.append("text")
    .attr("x", graphWidth / 2)
    .attr("y", 20)
    .attr("text-anchor", "middle")
    .attr("font-size", "14px")
    .text(`Layer ${layerIdx + 1} - Head ${headIdx + 1}`);

  // The token columns are graphHeight tall and use `justify-content: space-around`.
  // With equal-height labels, that centres label i at (i + 0.5) / tokenCount of the height.
  // Use the same formula for line endpoints so each line meets its token label.
  const tokenCount = tokens.length;
  const rowY = (idx) => (graphHeight * (idx + 0.5)) / Math.max(tokenCount, 1);

  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
    const rowMax = Math.max(...sourceWeights) || 1;
    sourceWeights.forEach((weight, targetIdx) => {
      // Skip near-zero weights and self-attention.
      if (weight > 0.01 && sourceIdx !== targetIdx) {
        const normalizedWeight = weight / rowMax;
        svg.append("line")
          .attr("x1", 20) // left margin within the SVG
          .attr("y1", rowY(sourceIdx))
          .attr("x2", graphWidth - 20) // right margin within the SVG
          .attr("y2", rowY(targetIdx))
          .attr("stroke", "#800000")
          .attr("stroke-width", Math.max(0.5, normalizedWeight * 4))
          .attr("opacity", Math.min(0.8, normalizedWeight * 2))
          .attr("stroke-linecap", "round");
      }
    });
  });

  return container.node();
}
/**
 * Builds one clickable 80px-wide thumbnail for a single (layer, head) pair.
 * Below a "L<layer> H<head>" header, each token gets a 15px row. A line joins row i
 * (left) to row j (right) for every off-diagonal weight above 0.01, with width and
 * opacity scaled by the weight relative to its row's maximum.
 * Clicking the thumbnail marks it selected (grey background, others reset to white),
 * writes the 1-based head/layer into #display_head / #display_layer, and re-renders
 * the hover tokens for this head via displayHoverTokens.
 *
 * @param {{tokens: string[], attention: Array}} data - Payload from /process;
 *   `attention` is indexed [layer][head][source][target].
 * @param {number} layerIdx - 0-based layer index.
 * @param {number} headIdx - 0-based head index.
 * @returns {HTMLElement} The thumbnail's root <div>, ready to append.
 */
function createAttentionThumbnail(data, layerIdx, headIdx) {
  const { tokens, attention } = data;

  // Geometry in px: a 10px header strip, then one 15px row per token.
  const boxWidth = 80;
  const rowHeight = 15;
  const headerHeight = 10;
  const leftX = 10;
  const rightX = boxWidth - 10;
  const boxHeight = headerHeight + tokens.length * rowHeight;
  const firstRowY = headerHeight + 5;
  // Upper bounds for line styling, reached when a weight equals its row maximum.
  const widest = 4;
  const mostOpaque = 0.8;

  const thumb = d3.select(document.createElement("div"));
  thumb
    .classed("thumbnail", true)
    .style("position", "relative")
    .style("height", `${boxHeight}px`)
    .style("width", `${boxWidth}px`)
    .style("border", "1px solid #ddd")
    .style("border-radius", "4px")
    .style("padding", "10px 0px 0px 0px")
    .style("background", "#fff")
    .style("cursor", "pointer"); // signals the thumbnail is clickable

  const svg = thumb.append("svg")
    .attr("width", boxWidth)
    .attr("height", boxHeight)
    .attr("viewBox", `0 0 ${boxWidth} ${boxHeight}`)
    .attr("preserveAspectRatio", "none");

  // Header label, 1-based for display.
  svg.append("text")
    .attr("x", boxWidth / 2)
    .attr("y", headerHeight - 2)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle")
    .attr("font-size", "10")
    .text(`L${layerIdx + 1} H${headIdx + 1}`);

  const rowY = (i) => firstRowY + i * rowHeight;

  attention[layerIdx][headIdx].forEach((row, from) => {
    const peak = Math.max(...row) || 1;
    row.forEach((w, to) => {
      // Skip near-zero weights (including NaN) and self-attention.
      if (!(w > 0.01) || from === to) {
        return;
      }
      const strength = w / peak;
      svg.append("line")
        .attr("x1", leftX)
        .attr("y1", rowY(from))
        .attr("x2", rightX)
        .attr("y2", rowY(to))
        .attr("stroke", "#800000") // Bordo red
        .attr("stroke-width", Math.max(0.5, strength * widest))
        .attr("opacity", Math.min(mostOpaque, strength * 2))
        .attr("stroke-linecap", "round");
    });
  });

  // `function` (not an arrow) so `this` is the clicked element.
  thumb.on("click", function () {
    console.log(`Clicked: Layer ${layerIdx + 1}, Head ${headIdx + 1}`);
    d3.selectAll(".thumbnail").style("background", "#fff");
    d3.select(this).style("background", "#ddd");
    d3.select("#display_head").text(headIdx + 1);
    d3.select("#display_layer").text(layerIdx + 1);
    displayHoverTokens(data, layerIdx, headIdx);
  });

  return thumb.node();
}
| // Function to display the tokens and attention values | |
| // function displayOutput(data) { | |
| // const outputDiv = document.getElementById('output'); | |
| // outputDiv.innerHTML = ` | |
| // <h2>Tokens</h2> | |
| // <pre>${JSON.stringify(data.tokens, null, 2)}</pre> | |
| // <h2>Attention</h2> | |
| // <pre>${JSON.stringify(data.attention, null, 2)}</pre> | |
| // `; | |
| // } | |
/**
 * Replaces the contents of #tokenContainer with one 32px <span> per token.
 * Hovering a span calls highlightAttention for that token's position in the given
 * layer/head; leaving it calls resetTokenSizes.
 *
 * @param {string[]} tokens - Token strings to display.
 * @param {Array} attentionData - Attention tensor passed through to highlightAttention.
 * @param {number} layer_idx - 0-based layer used for highlighting.
 * @param {number} head_idx - 0-based head used for highlighting.
 */
function renderTokens(tokens, attentionData, layer_idx, head_idx) {
  const host = document.getElementById('tokenContainer');
  host.innerHTML = "";

  const makeSpan = (label, position) => {
    const el = document.createElement('span');
    // Strip the first "Ġ" (presumably the BPE leading-space marker from the backend's
    // tokenizer — confirm) and separate tokens with a trailing space.
    el.textContent = `${label.replace("Ġ", "")} `;
    el.style.fontSize = "32px";
    el.addEventListener('mouseenter', () => highlightAttention(position, attentionData, layer_idx, head_idx));
    el.addEventListener('mouseleave', () => resetTokenSizes());
    return el;
  };

  tokens.forEach((label, position) => host.appendChild(makeSpan(label, position)));
}
/**
 * Populates the hover-token view for one (layer, head) by delegating to renderTokens.
 * If `data` is missing, or lacks `tokens` or `attention`, a placeholder sentence is shown
 * with an all-zero attention tensor of 12 layers x 12 heads x n x n.
 *
 * @param {{tokens?: string[], attention?: Array}|null|undefined} data - Payload from /process.
 * @param {number} layer_idx - 0-based layer to use for hover highlighting.
 * @param {number} head_idx - 0-based head to use for hover highlighting.
 */
function displayHoverTokens(data, layer_idx, head_idx) {
  if (data && data.tokens && data.attention) {
    renderTokens(data.tokens, data.attention, layer_idx, head_idx);
    return;
  }

  const tokens = ['This', 'is', 'a', 'test', '.'];
  const n = tokens.length;
  // highlightAttention reads attention[layer][head][token] as a full row of weights,
  // so each head must be an n x n matrix. A head that is only a flat row of zeros
  // would make every hover hit the "No attention data" branch.
  // Array.from builds a fresh array per row, so rows are not shared.
  const attentionMatrix = Array.from({ length: 12 }, () =>
    Array.from({ length: 12 }, () =>
      Array.from({ length: n }, () => Array(n).fill(0))));
  renderTokens(tokens, attentionMatrix, layer_idx, head_idx);
}
/**
 * Restores every span in #tokenContainer to its resting style:
 * 32px font size and #555 grey text.
 */
function resetTokenSizes() {
  const spans = Array.from(document.getElementById("tokenContainer").children);
  for (const span of spans) {
    span.style.color = "#555";
    span.style.fontSize = "32px";
  }
}
/**
 * Styles the spans in #tokenContainer by how strongly the hovered token attends to each one.
 * Each span's font grows from 32px by up to 20px in proportion to weight / maxWeight.
 * The span with the largest weight is coloured #800000; the rest get #555. Spans with no
 * numeric weight are reset to 32px / #555.
 *
 * @param {number} index - Position of the hovered (source) token.
 * @param {Array} attentionData - Attention tensor indexed [layer][head][source][target].
 * @param {number} layer_idx - 0-based layer.
 * @param {number} head_idx - 0-based head.
 */
function highlightAttention(index, attentionData, layer_idx, head_idx) {
  const container = document.getElementById('tokenContainer');
  // Guard every level of the lookup: a missing layer or head previously threw a
  // TypeError before the missing-row check below could run.
  const layer = attentionData ? attentionData[layer_idx] : undefined;
  const head = layer ? layer[head_idx] : undefined;
  const weights = head ? head[index] : undefined;
  if (!weights) {
    console.warn(`No attention data for token index ${index}`);
    return;
  }
  if (!weights.length) {
    return;
  }

  const baseFontSize = 32;
  const maxIncrease = 20;
  // `|| 1` avoids dividing by zero when the row is all zeros; indexOf then finds no
  // match (-1), so no token is coloured.
  const maxWeight = Math.max(...weights) || 1;
  const maxIndex = weights.indexOf(maxWeight);

  Array.from(container.children).forEach((span, idx) => {
    const weight = weights[idx];
    if (typeof weight !== 'number') {
      // For tokens without a corresponding weight, reset styles.
      span.style.fontSize = baseFontSize + "px";
      span.style.color = "#555";
      return;
    }
    const newFontSize = baseFontSize + (weight / maxWeight) * maxIncrease;
    span.style.fontSize = newFontSize + "px";
    span.style.color = idx === maxIndex ? "#800000" : "#555"; // Bordo red for the top target
  });
}