DeeperGaze / frontend /script.js
Paar, F. (Ferdinand)
thumbnail filled and layer/head box next to each other
2e082b4
// Form submit handler 2.1
// Posts the entered text to /process and renders the returned data into the
// hover-token strip, the layer/head thumbnail grid and the large thumbnail.
// NOTE(review): the response is assumed to be { tokens, attention } — confirm
// against the backend.
document.getElementById('textForm').addEventListener('submit', async (e) => {
  e.preventDefault();
  const inputText = document.getElementById('inputText').value;
  // #output is only used for error messages here and may be absent from the page;
  // without this guard the error path itself would throw.
  const output = document.getElementById('output');
  if (output) {
    // Clear an error left over from a previous failed attempt.
    output.innerText = '';
  }
  try {
    const response = await fetch('/process', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text: inputText }),
    });
    if (!response.ok) {
      throw new Error(`Network response was not ok (HTTP ${response.status})`);
    }
    const data = await response.json();
    // Hover tokens start on the first layer/head; the grid and the large
    // thumbnail receive the whole response object.
    displayHoverTokens(data, 0, 0);
    renderModelView(data);
    createLargeThumbnail(data);
  } catch (error) {
    console.error('Error:', error);
    if (output) {
      output.innerText = 'Error processing text.';
    }
  }
});
/**
 * Render one attention thumbnail per (layer, head) into #model_view_container,
 * laid out as a grid with one row per layer and one 80px column per head.
 *
 * Layer and head counts are read from `data.attention` instead of being fixed
 * at 12 × 12. A model with fewer layers or heads would otherwise index past the
 * end of the array inside createAttentionThumbnail.
 *
 * Does nothing if #model_view_container is not on the page.
 *
 * @param {{tokens: string[], attention: Array}} data - Backend response;
 *   attention is indexed [layer][head][source][target].
 */
function renderModelView(data) {
  const container = document.getElementById("model_view_container");
  if (!container) return;
  container.innerHTML = "";

  const layers = Array.isArray(data.attention) ? data.attention : [];
  // The widest layer sets the column count (12 for a 12-head model).
  const headCount = layers.reduce((max, heads) => Math.max(max, heads.length), 0);

  const gridContainer = document.createElement("div");
  gridContainer.style.display = "grid";
  // repeat(0, ...) is invalid CSS, so keep at least one column.
  gridContainer.style.gridTemplateColumns = `repeat(${Math.max(headCount, 1)}, 80px)`;
  gridContainer.style.gridGap = "10px";
  gridContainer.style.padding = "20px";

  // One thumbnail per head, layer by layer, each receiving the complete data object.
  layers.forEach((heads, layerIdx) => {
    heads.forEach((_, headIdx) => {
      gridContainer.appendChild(createAttentionThumbnail(data, layerIdx, headIdx));
    });
  });

  container.appendChild(gridContainer);
}
/**
 * Draw an enlarged attention graph for one layer/head into #thumbnailContainer.
 * Any previous element with id "largeThumbnail" is removed first, so only the
 * latest graph is shown.
 *
 * Layout is a flex row: token labels on the left, an SVG of attention lines in
 * the middle and token labels again on the right. Each token gets an
 * equal-height slot with its label vertically centred. Line endpoints use the
 * same slot centres, so every line starts and ends level with its labels.
 *
 * A line is drawn for each weight above 0.01, skipping a token's attention to
 * itself. Width and opacity scale with the weight relative to the largest
 * weight in the source token's row.
 *
 * @param {{tokens: string[], attention: Array}} data - Backend response;
 *   attention is indexed [layer][head][source][target].
 * @param {number} [layerIdx=0] - Zero-based layer to draw.
 * @param {number} [headIdx=0] - Zero-based head to draw.
 * @returns {HTMLElement} The newly created #largeThumbnail element.
 */
function createLargeThumbnail(data, layerIdx = 0, headIdx = 0) {
  const tokens = data.tokens;
  const attention = data.attention;

  // Dimensions
  const graphWidth = 300;
  const tokenAreaWidth = 60;
  const totalWidth = tokenAreaWidth * 2 + graphWidth;
  const graphHeight = 300;
  const lineInset = 20; // horizontal inset of line endpoints inside the SVG

  // Labels and lines share one vertical layout: flex `space-around` for the
  // labels combined with H/(n+1) spacing for the lines would put them at
  // different heights.
  const slotHeight = graphHeight / Math.max(tokens.length, 1);
  const slotCenter = (idx) => slotHeight * (idx + 0.5);

  // Remove any existing large thumbnail so only the latest shows
  d3.select("#largeThumbnail").remove();

  // Main container: [token labels | SVG graph | token labels]
  const container = d3.select("#thumbnailContainer")
    .append("div")
    .attr("id", "largeThumbnail")
    .style("width", totalWidth + "px")
    .style("height", graphHeight + "px")
    .style("border", "2px solid #ddd")
    .style("border-radius", "6px")
    .style("background", "#fff")
    .style("display", "flex")
    .style("align-items", "center")
    .style("justify-content", "space-between");

  // Append a column of token labels, one fixed-height slot per token.
  // `justify` aligns each label horizontally within its slot.
  const appendTokenColumn = (className, justify, textAlign) => {
    const column = container.append("div")
      .attr("class", className)
      .style("width", tokenAreaWidth + "px")
      .style("height", graphHeight + "px")
      .style("display", "flex")
      .style("flex-direction", "column");
    tokens.forEach((token) => {
      column.append("div")
        .style("height", slotHeight + "px")
        .style("flex-shrink", "0")
        .style("display", "flex")
        .style("align-items", "center")
        .style("justify-content", justify)
        .style("font-size", "14px")
        .style("text-align", textAlign)
        // Same cleanup renderTokens applies: drop the first "Ġ"
        // (presumably the tokenizer's leading-space marker).
        .text(token.replace("Ġ", ""));
    });
  };

  // Left token column (labels right-aligned towards the graph)
  appendTokenColumn("tokens-left", "flex-end", "right");

  // Center SVG for the graph
  const svg = container.append("svg")
    .attr("width", graphWidth)
    .attr("height", graphHeight)
    .attr("viewBox", `0 0 ${graphWidth} ${graphHeight}`)
    .attr("preserveAspectRatio", "none");

  // Right token column (labels left-aligned towards the graph)
  appendTokenColumn("tokens-right", "flex-start", "left");

  // Title inside the SVG (1-based for display)
  svg.append("text")
    .attr("x", graphWidth / 2)
    .attr("y", 20)
    .attr("text-anchor", "middle")
    .attr("font-size", "14px")
    .text(`Layer ${layerIdx + 1} - Head ${headIdx + 1}`);

  // Draw attention lines for the requested layer/head
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
    // `|| 1` avoids dividing by zero when a whole row is 0.
    const rowMax = Math.max(...sourceWeights) || 1;
    sourceWeights.forEach((weight, targetIdx) => {
      if (weight > 0.01 && sourceIdx !== targetIdx) {
        const normalizedWeight = weight / rowMax;
        svg.append("line")
          .attr("x1", lineInset)
          .attr("y1", slotCenter(sourceIdx))
          .attr("x2", graphWidth - lineInset)
          .attr("y2", slotCenter(targetIdx))
          .attr("stroke", "#800000")
          .attr("stroke-width", Math.max(0.5, normalizedWeight * 4))
          .attr("opacity", Math.min(0.8, normalizedWeight * 2))
          .attr("stroke-linecap", "round");
      }
    });
  });

  return container.node();
}
/**
 * Build a small clickable thumbnail of one layer/head's attention pattern.
 * Each line runs from a source token's row (left edge) to a target token's row
 * (right edge).
 *
 * A line is drawn for each weight above 0.01, skipping a token's attention to
 * itself. Stroke width and opacity scale with the weight relative to the
 * largest weight in that source token's row.
 *
 * Clicking the thumbnail:
 * - resets every `.thumbnail` background to white and marks this one grey;
 * - writes the 1-based head/layer numbers into #display_head / #display_layer;
 * - re-renders the hover tokens for this layer/head.
 *
 * @param {{tokens: string[], attention: Array}} data - Backend response;
 *   attention is indexed [layer][head][source][target].
 * @param {number} layerIdx - Zero-based layer index.
 * @param {number} headIdx - Zero-based head index.
 * @returns {HTMLDivElement} Detached thumbnail element for the caller to append.
 */
function createAttentionThumbnail(data, layerIdx, headIdx) {
  const tokens = data.tokens;
  const attention = data.attention;
  const width = 80;
  const tokenHeight = 15;   // vertical distance between consecutive token rows
  const padding = 10;       // horizontal inset of both line endpoints
  const paddingTop = 10;    // space above the lines for the "L# H#" label
  const height = paddingTop + tokens.length * tokenHeight;
  const maxLineWidth = 4;
  const maxOpacity = 0.8;
  const xRight = width - padding;
  const yStart = paddingTop + 5; // y of the first token row

  // Create a thumbnail container using D3 (detached until the caller appends it).
  const thumbnail = d3.select(document.createElement("div"))
    .classed("thumbnail", true)
    .style("position", "relative")
    .style("height", height + "px")
    .style("width", width + "px")
    .style("border", "1px solid #ddd")
    .style("border-radius", "4px")
    .style("padding", `${paddingTop}px 0px 0px 0px`)
    .style("background", "#fff")
    .style("cursor", "pointer"); // Indicate clickability

  // The viewBox equals the SVG's pixel size, so user units map 1:1 to pixels.
  const svg = thumbnail.append("svg")
    .attr("width", width)
    .attr("height", height)
    .attr("viewBox", `0 0 ${width} ${height}`)
    .attr("preserveAspectRatio", "none");

  // "L# H#" label (1-based), horizontally centred with its vertical middle at y.
  svg.append("text")
    .attr("x", width / 2)
    .attr("y", paddingTop - 2)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle")
    .attr("font-size", "10")
    .text(`L${layerIdx + 1} H${headIdx + 1}`);

  // Draw attention lines
  attention[layerIdx][headIdx].forEach((sourceWeights, sourceIdx) => {
    // `|| 1` avoids dividing by zero when a whole row is 0.
    const rowMax = Math.max(...sourceWeights) || 1;
    sourceWeights.forEach((weight, targetIdx) => {
      if (weight > 0.01 && sourceIdx !== targetIdx) {
        const normalizedWeight = weight / rowMax;
        svg.append("line")
          .attr("x1", padding)
          .attr("y1", yStart + sourceIdx * tokenHeight)
          .attr("x2", xRight)
          .attr("y2", yStart + targetIdx * tokenHeight)
          .attr("stroke", "#800000") // Bordo red
          .attr("stroke-width", Math.max(0.5, normalizedWeight * maxLineWidth))
          .attr("opacity", Math.min(maxOpacity, normalizedWeight * 2))
          .attr("stroke-linecap", "round");
      }
    });
  });

  // A `function` (not an arrow) so D3 binds `this` to the clicked element.
  thumbnail.on("click", function() {
    console.log(`Clicked: Layer ${layerIdx + 1}, Head ${headIdx + 1}`);
    d3.selectAll(".thumbnail").style("background", "#fff");
    d3.select(this).style("background", "#ddd");
    // Update displayed head and layer numbers
    d3.select("#display_head").text(headIdx + 1);
    d3.select("#display_layer").text(layerIdx + 1);
    // Update hover tokens
    displayHoverTokens(data, layerIdx, headIdx);
  });

  return thumbnail.node(); // Return raw DOM node for appending
}
// Function to display the tokens and attention values
// function displayOutput(data) {
// const outputDiv = document.getElementById('output');
// outputDiv.innerHTML = `
// <h2>Tokens</h2>
// <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
// <h2>Attention</h2>
// <pre>${JSON.stringify(data.attention, null, 2)}</pre>
// `;
// }
/**
 * Render each token as a hoverable 32px <span> inside #tokenContainer,
 * replacing its previous contents.
 *
 * - mouseenter: highlightAttention(index, attentionData, layer_idx, head_idx)
 * - mouseleave: resetTokenSizes()
 *
 * Does nothing if #tokenContainer is not on the page. Without this guard the
 * submit handler would throw here and skip rendering the grid and the large
 * thumbnail.
 *
 * @param {string[]} tokens - Token strings. The first "Ġ" in each is removed
 *   for display (presumably the tokenizer's leading-space marker).
 * @param {Array} attentionData - Attention indexed [layer][head][source][target].
 * @param {number} layer_idx - Zero-based layer used on hover.
 * @param {number} head_idx - Zero-based head used on hover.
 */
function renderTokens(tokens, attentionData, layer_idx, head_idx) {
  const container = document.getElementById('tokenContainer');
  if (!container) return;
  container.innerHTML = "";
  tokens.forEach((token, index) => {
    const span = document.createElement('span');
    // Trailing space keeps adjacent tokens visually separated.
    span.textContent = token.replace("Ġ", "") + " ";
    span.style.fontSize = "32px";
    span.addEventListener('mouseenter', () => {
      highlightAttention(index, attentionData, layer_idx, head_idx);
    });
    span.addEventListener('mouseleave', () => {
      resetTokenSizes();
    });
    container.appendChild(span);
  });
}
/**
 * Render the hoverable token strip for one layer/head.
 *
 * If `data` is missing, or lacks `tokens` or `attention`, a placeholder
 * sentence is rendered instead, with an all-zero 12-layer × 12-head attention
 * tensor.
 *
 * @param {{tokens?: string[], attention?: Array}|null|undefined} data - Backend response.
 * @param {number} layer_idx - Zero-based layer index.
 * @param {number} head_idx - Zero-based head index.
 */
function displayHoverTokens(data, layer_idx, head_idx) {
  if (data && data.tokens && data.attention) {
    renderTokens(data.tokens, data.attention, layer_idx, head_idx);
    return;
  }

  const tokens = ['This', 'is', 'a', 'test', '.'];
  const n = tokens.length;
  // The placeholder must be 4-D, [layer][head][source][target]:
  // highlightAttention reads attentionData[layer][head][index] and needs an
  // array of per-target weights there, not a single number. Every row is a
  // distinct array.
  const zeroHead = () => Array.from({ length: n }, () => Array(n).fill(0));
  const attentionMatrix = Array.from({ length: 12 }, () =>
    Array.from({ length: 12 }, zeroHead));
  renderTokens(tokens, attentionMatrix, layer_idx, head_idx);
}
/**
 * Return every token span in #tokenContainer to its resting look:
 * 32px font, grey (#555) text.
 */
function resetTokenSizes() {
  const { children } = document.getElementById("tokenContainer");
  for (const tokenSpan of Array.from(children)) {
    const { style } = tokenSpan;
    style.fontSize = "32px";
    style.color = "#555";
  }
}
/**
 * Resize every token span in #tokenContainer by the hovered token's attention
 * weights, and colour the most-attended token Bordo red.
 *
 * Font size grows linearly from 32px (weight 0) to 52px (the row's largest
 * weight). Spans with no numeric weight at their position get the default
 * 32px / #555 style.
 *
 * Returns without changes, after a console warning, when the attention row for
 * this layer/head/token is missing. Also returns early when the row is empty
 * or #tokenContainer is absent.
 *
 * @param {number} index - Position of the hovered (source) token.
 * @param {Array} attentionData - Attention indexed [layer][head][source][target].
 * @param {number} layer_idx - Zero-based layer index.
 * @param {number} head_idx - Zero-based head index.
 */
function highlightAttention(index, attentionData, layer_idx, head_idx) {
  const container = document.getElementById('tokenContainer');
  if (!container) return;

  // Walk the nested arrays defensively: a missing layer or head would
  // otherwise throw a TypeError from inside the mouseenter handler.
  const layer = attentionData ? attentionData[layer_idx] : undefined;
  const head = layer ? layer[head_idx] : undefined;
  const weights = head ? head[index] : undefined;
  if (!Array.isArray(weights)) {
    console.warn(`No attention data for token index ${index}`);
    return;
  }
  if (weights.length === 0) {
    return;
  }

  const baseFontSize = 32;
  const maxIncrease = 20;
  const defaultColor = "#555";
  const highlightColor = "#800000"; // Bordo red
  // `|| 1` avoids dividing by zero when every weight is 0.
  const maxWeight = Math.max(...weights) || 1;
  const maxIndex = weights.indexOf(maxWeight);

  Array.from(container.children).forEach((span, idx) => {
    const weight = weights[idx];
    if (typeof weight !== 'number') {
      // For tokens without a corresponding weight, reset styles.
      span.style.fontSize = baseFontSize + "px";
      span.style.color = defaultColor;
      return;
    }
    const newFontSize = baseFontSize + (weight / maxWeight) * maxIncrease;
    span.style.fontSize = newFontSize + "px";
    span.style.color = idx === maxIndex ? highlightColor : defaultColor;
  });
}