|
|
|
|
|
|
|
|
async () => { |
|
|
|
|
|
|
|
|
|
|
|
globalThis.testFn = () => {
  // Smoke-test hook: put a fixed probe string into the #demo element.
  const demoEl = document.getElementById('demo');
  demoEl.innerHTML = "Hello?";
};
|
|
|
|
|
// The two d3 major versions and jQuery do not depend on each other, so load them in parallel.
// `d37` is d3 v7 and `d3` is d3 v5; different helpers below rely on different versions.
const [d37, d3, jqueryModule] = await Promise.all([
  import("https://cdn.jsdelivr.net/npm/d3@7/+esm"),
  import("https://cdn.jsdelivr.net/npm/d3@5/+esm"),
  import("https://cdn.jsdelivr.net/npm/jquery@3.7.1/dist/jquery.min.js"),
]);

// jquery.min.js is a classic UMD script, not an ES module. Importing it runs it for its side
// effect of defining the global `jQuery`/`$`, and the returned module namespace has no exports.
// Publish the real jQuery function rather than overwriting the global `$` with that empty
// namespace; fall back to the namespace only if no global was installed.
const $ = globalThis.jQuery ?? jqueryModule;
globalThis.$ = $;

// Expose d3 v5 globally so code outside this wrapper can reach it.
globalThis.d3 = d3;
|
|
|
|
|
globalThis.d3Fn = () => {
  // Append a 50x50 black square to #viz that turns red while the pointer is over it.
  const canvas = d3.select('#viz').append('svg');
  const square = canvas.append('rect');
  square.attr('width', 50).attr('height', 50).attr('fill', 'black');
  square.on('mouseover', function () {
    d3.select(this).attr('fill', 'red');
  });
  square.on('mouseout', function () {
    d3.select(this).attr('fill', 'black');
  });
};
|
|
|
|
|
globalThis.testFn_out = (val, radio_c) => {
  // Echo callback: log the first input, then hand both inputs back as a pair.
  console.log(val);
  const pair = [val, radio_c];
  return pair;
};
|
|
|
|
|
|
|
|
globalThis.testFn_out_json = (data) => {
  // Render every visualisation panel from one backend payload.
  //   data[1] -> [beam-search records, top-k probability columns, input token HTML, target token HTML]
  //   data[2] -> embedding projections, indexed as data[2][input|output][tokens|words]
  //   data[3..5] -> parameter objects for three attention views (attViz)
  // Returns a fixed ['string', {}] pair.
  console.log(data);

  // Locals rather than implicit globals: nothing outside this function needs them.
  const data_beam = data[1][0];
  const data_probs = data[1][1];
  const data_html_inputs = data[1][2];
  const data_html_target = data[1][3];
  const data_embds = data[2];

  attViz(data[3]);
  attViz(data[4]);
  attViz(data[5]);

  console.log(data_beam);

  // Turn the flat beam records ({id, parentId, ...}) into a nested tree by attaching each record
  // to its parent's `children` array. Keys go through String() to match the property-key
  // coercion the previous object-based lookup relied on.
  const beamById = new Map(data_beam.map((el) => [String(el.id), el]));
  let root;
  data_beam.forEach((el) => {
    if (el.parentId === null) {
      root = el;
      return;
    }
    const parentEl = beamById.get(String(el.parentId));
    if (parentEl === undefined) {
      // Previously this crashed with a TypeError; skip the orphan instead.
      console.warn("beam node has unknown parentId; skipping", el);
      return;
    }
    if (!parentEl.children) parentEl.children = [];
    parentEl.children.push(el);
  });

  const beamContainer = d3.select('#d3_beam_search');
  beamContainer.html("");
  if (root !== undefined) {
    beamContainer.append(() => Tree(root));
  } else {
    console.warn("beam search data has no root (parentId === null); tree not drawn");
  }

  const gridContainer = d3.select('#d3_text_grid');
  gridContainer.html("");
  gridContainer.append(() => TextGrid(data_probs));

  d3.select('#d3_tok').html(data_html_inputs);
  d3.select('#d3_tok_target').html(data_html_target);
  d3.select("#d3_embeds_source").html("here");

  // Reveal the panel selector and sync panel visibility with its current value.
  const selectType = d3.select("#select_type");
  console.log(selectType.node().value);
  selectType.attr("hidden", null);
  selectType.on("change", change);
  change();

  ['input', 'output'].forEach((text_type) => {
    ['tokens', 'words'].forEach((text_key) => {
      const data_i = data_embds[text_type][text_key];
      const divType = text_type + "_" + text_key;
      // NOTE(review): deliberately keeps the global `type` that earlier versions set implicitly
      // here; graph click handlers have historically read it. Remove once nothing depends on it.
      globalThis.type = divType;
      embeddings_network([], data_i['tnse'], data_i['similar_queries'], divType);
    });
  });

  return ['string', {}];
};
|
|
|
|
|
function change() {
  // Show only the embedding and graph panels matching the value chosen in #select_type
  // (e.g. "tokens" or "words"), for both the input and the output side.
  // show_type is a local now; it used to leak as an implicit global.
  const show_type = d3.select("#select_type").node().value;

  d3.selectAll(".d3_embed").attr("hidden", '');
  d3.selectAll(".d3_graph").attr("hidden", '');

  for (const panel of ['d3_embeds', 'd3_graph']) {
    for (const side of ['input', 'output']) {
      d3.select("#" + panel + "_" + side + "_" + show_type).attr("hidden", null);
    }
  }
}
|
|
|
|
|
function embeddings_network(tokens_text, dict_projected_embds, similar_vocab_queries, type = "source") {
  // Build a token graph and render it into #d3_graph_<type>.
  //   - every key of similar_vocab_queries becomes a "sentence" node (type_i 0);
  //   - every entry of its `similar_topk` list becomes a "similar" node (type_i 1), linked to it
  //     with the matching `distance` value as the edge weight;
  //   - consecutive sentence tokens are chained with weight-1 edges.
  // Node labels come from dict_projected_embds[id][3].
  // tokens_text is only logged.
  console.log("Each token is a node; distance if in similar list", type);
  console.log(tokens_text, dict_projected_embds, similar_vocab_queries);

  // Use a Map, not a plain object, so tokens such as "constructor" or "toString" do not collide
  // with Object.prototype members. Keys use String() to keep the old property-key coercion
  // (numeric and string ids of the same token share one node).
  const nodeHash = new Map();
  const nodes = [];
  const edges = [];
  const edges_ids = [];

  // Return the node for `id`, creating it on first sight; later sightings reuse that object.
  const ensureNode = (id, nodeType, typeIndex) => {
    const key = String(id);
    let node = nodeHash.get(key);
    if (node === undefined) {
      node = {id: id, label: dict_projected_embds[id][3], type: nodeType, type_i: typeIndex};
      nodeHash.set(key, node);
      nodes.push(node);
    }
    return node;
  };

  console.log('similar_vocab_queries', similar_vocab_queries);

  let prevToken = null;
  for (const [sentToken, value] of Object.entries(similar_vocab_queries)) {
    const sentNode = ensureNode(sentToken, 'sentence', 0);
    const simTokens = value['similar_topk'];
    const distTokens = value['distance'];

    for (let index = 0; index < simTokens.length; index++) {
      const sim = simTokens[index];
      const dist = distTokens[index];
      const simNode = ensureNode(sim, 'similar', 1);
      edges.push({source: sentNode, target: simNode, weight: dist});
      edges_ids.push({source: sentToken, target: sim, weight: dist});
    }

    // Chain consecutive sentence tokens so sentence order influences the layout.
    if (prevToken !== null) {
      edges.push({source: nodeHash.get(prevToken), target: sentNode, weight: 1});
      edges_ids.push({source: prevToken, target: sentToken, weight: 1});
    }
    prevToken = sentToken;
  }

  console.log("TYPE", type, edges, nodes, edges_ids, similar_vocab_queries);

  const container = d3.select('#d3_graph_' + type);
  container.html("");
  container.append(() => networkPlot({nodes: nodes, links: edges}, similar_vocab_queries, dict_projected_embds, type));
}
|
|
|
|
|
function networkPlot(data, similar_vocab_queries, dict_proj, div_type = "source", {
  width = 400, // width of the plotting area, in pixels
  height, // height of the plotting area, in pixels (400 when omitted)
  r = 3, // currently unused: circles are drawn with a fixed radius of 6
  padding = 1, // currently unused
} = {}) {
  // Draw a d3 v7 force-directed graph of `data` ({nodes, links}) and return the detached <svg>.
  //   - sentence nodes (type_i 0) are blue; hovering one enlarges it and its similar tokens;
  //   - clicking a node lists its similar tokens and distances in #similar_<div_type>.
  // dict_proj[id][4] is used as a CSS-class-safe token key, and dict_proj[id][3] as display text.
  console.log("data, similar_vocab_queries, div_type");
  console.log(data, similar_vocab_queries, div_type);

  const margin = {top: 10, right: 10, bottom: 30, left: 50};
  const plotHeight = height === undefined ? 400 : height;

  const svg = d37.create("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", plotHeight + margin.top + margin.bottom);

  // NOTE(review): `fill` has no visible effect on <line>; only the grey stroke shows.
  // Perhaps the weight-based colour was meant for the stroke — confirm before changing.
  const link = svg.selectAll("line")
    .data(data.links)
    .enter()
    .append("line")
    .style("fill", (d) => (d.weight === 1 ? "#dfd5d5" : "#000000"))
    .style("stroke", "#aaa");

  const text = svg.selectAll("text")
    .data(data.nodes)
    .enter()
    .append("text")
    .style("text-anchor", "middle")
    .attr("y", 15)
    .attr("class", (d) => 'text_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .text((d) => d.label);

  const node = svg.selectAll("circle")
    .data(data.nodes)
    .enter()
    .append("circle")
    .attr("r", 6)
    .attr("class", (d) => 'node_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .style("fill", (d) => (d.type_i ? "#e85252" : "#6689c6"))
    .on('mouseover', highlight_mouseover)
    .on('mouseout', highlight_mouseout)
    .on('click', change_legend);

  // Positions are written once, when the simulation settles ("end"), not on every tick.
  // forceManyBody() takes no arguments: the -400 previously passed here was silently ignored,
  // so the default strength has always been used. Tune it with .strength(...) if needed.
  d37.forceSimulation(data.nodes)
    .force("link", d37.forceLink().id((d) => d.id).links(data.links))
    .force("charge", d37.forceManyBody())
    .force("center", d37.forceCenter(width / 2, plotHeight / 2))
    .on("end", ticked);

  function ticked() {
    link
      .attr("x1", (d) => d.source.x)
      .attr("y1", (d) => d.source.y)
      .attr("x2", (d) => d.target.x)
      .attr("y2", (d) => d.target.y);

    node
      .attr("cx", (d) => d.x + 3)
      .attr("cy", (d) => d.y - 3);

    text
      .attr("transform", (d) => "translate(" + d.x + "," + d.y + ")");
  }

  // d3 v7 handlers receive (event, datum); `this` is the hovered circle.
  function highlight_mouseover(event, d) {
    console.log("highlight_mouseover", event, d, d37.select(this).attr("div-type"));
    if (d.type_i != 0) return; // only sentence nodes have a similar-token list
    const similarIds = similar_vocab_queries[d.id]['similar_topk'];
    d37.select(this).transition()
      .duration('50')
      .style('opacity', '1')
      .attr("r", 12);
    similarIds.forEach((similarToken) => {
      const nodeIdName = dict_proj[similarToken][4];
      d37.selectAll('.node_token-' + nodeIdName + div_type).attr("r", 12).style('opacity', '1');
    });
  }

  function highlight_mouseout(event, d) {
    if (d.type_i != 0) return;
    console.log("similar_vocab_queries", similar_vocab_queries, "this type:", d37.select(this).attr("div-type"));
    const similarIds = similar_vocab_queries[d.id]['similar_topk'];
    d37.select(this).transition()
      .duration('50')
      .style('opacity', '.7')
      .attr("r", 6);
    similarIds.forEach((similarToken) => {
      const nodeIdName = dict_proj[similarToken][4];
      d37.selectAll('.node_token-' + nodeIdName + div_type).attr("r", 6).style('opacity', '.7');
    });
    // raise() is idempotent, so one call matches the previous once-per-token calls.
    if (similarIds.length > 0) d37.selectAll("circle").raise();
  }

  function change_legend(event, d) {
    console.log(event, d, dict_proj);
    if (d['id'] in dict_proj) {
      // Target this plot's own legend. The old code used a global `type` that is only set on
      // hover and may belong to a different graph.
      show_similar_tokens(d['id'], '#similar_' + div_type);
      console.log(dict_proj[d['id']]);
    } else {
      console.log("no sentence");
    }
  }

  // Fill `div_name_similar` with one line per similar token ("<text> <distance> "), highlighting
  // the token itself.
  function show_similar_tokens(token, div_name_similar = '#similar_input_tokens') {
    d37.select(div_name_similar).html("");
    console.log("token", token);
    console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
    const token_data = similar_vocab_queries[token];
    console.log(token, token_data);
    // "Similar" nodes have no neighbour list of their own; previously this threw a TypeError.
    if (token_data == null) return;
    const decForm = d37.format(".3f");

    function do_text(d, i) {
      console.log("do_text d,i");
      console.log(d, i);
      console.log("data_dict[d], data_dict");
      console.log(dict_proj[d], dict_proj);
      return dict_proj[d][3] + " " + decForm(token_data['distance'][i]) + " ";
    }

    // The container was just emptied, so every datum enters.
    d37.select(div_name_similar)
      .selectAll("p")
      .data(token_data['similar_topk'])
      .enter()
      .append("p").append('text')
      .attr('class_id', (d) => d)
      // Loose equality on purpose: ids may arrive as numbers while `token` is a string key.
      .style("background", (d) => (d == token ? "yellow" : null))
      .text((d, i) => do_text(d, i));
  }

  return svg.node();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function Tree(data, { // data is either tabular (array of objects) or hierarchy (nested objects)
  path, // as an alternative to id and parentId, returns an array identifier, imputing internal nodes
  id = Array.isArray(data) ? d => d.id : null, // if tabular data, given a d in data, returns a unique identifier (string)
  parentId = Array.isArray(data) ? d => d.parentId : null, // if tabular data, given a node d, returns its parent's identifier
  children, // if hierarchical data, given a d in data, returns its children
  tree = d3.tree, // layout algorithm (typically d3.tree or d3.cluster)
  sort, // how to sort nodes prior to layout (e.g., (a, b) => d3.descending(a.height, b.height))
  label = d => d.name, // given a node d, returns the display name
  title = d => d.name + d.prob, // given a node d, returns its hover text (name followed by probability)
  link, // given a node d, its link (if any)
  linkTarget = "_blank", // the target attribute for links (if any)
  width = 800, // minimum outer width, in pixels
  height, // outer height, in pixels (derived from the layout when omitted)
  r = 3, // radius of nodes
  padding = 1, // horizontal padding for first and last column
  fill = "#999", // fill for leaf nodes
  fillOpacity, // fill opacity for nodes (currently unused)
  stroke = "#555", // stroke for links; also the fill of internal nodes
  strokeWidth = 2, // stroke width for links
  strokeOpacity = 0.4, // stroke opacity for links
  strokeLinejoin, // stroke line join for links
  strokeLinecap, // stroke line cap for links
  halo = "#fff", // color of label halo
  haloWidth = 3, // padding around the labels
  curve = d37.curveBumpX, // curve for the link
} = {}) {
  // Draw a left-to-right tidy tree of the beam-search hierarchy and return its <svg> node.
  // Labels whose datum has prob == 1 are drawn in red.

  // NOTE(review): stratify().path() may not exist in d3 v5; that branch only runs when `path` is given.
  const root = path != null ? d3.stratify().path(path)(data)
    : id != null || parentId != null ? d3.stratify().id(id).parentId(parentId)(data)
    : d3.hierarchy(data, children);

  if (sort != null) root.sort(sort);

  const descendants = root.descendants();
  const L = label == null ? null : descendants.map(d => label(d.data, d));

  // Grow the drawing with the number of nodes (10px each), but never below `width`.
  const descWidth = 10;
  const realWidth = descWidth * descendants.length;
  const totalWidth = (realWidth > width) ? realWidth : width;

  // Fixed spacing between siblings (dx); the depth levels share the total width (dy).
  const dx = 25;
  const dy = totalWidth / (root.height + padding);
  tree().nodeSize([dx, dy])(root);

  // Vertical extent of the layout: x is the vertical axis because the tree is drawn sideways.
  let x0 = Infinity;
  let x1 = -x0;
  root.each(d => {
    if (d.x > x1) x1 = d.x;
    if (d.x < x0) x0 = d.x;
  });

  const svgHeight = height === undefined ? x1 - x0 + dx * 2 : height;

  if (typeof curve !== "function") throw new Error(`Unsupported curve`);

  const parent = d3.create("div");
  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");

  const svg = body.append("svg")
    .attr("viewBox", [-dy * padding / 2, x0 - dx, totalWidth, svgHeight])
    .attr("width", totalWidth)
    .attr("height", svgHeight)
    .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
    .attr("font-family", "sans-serif")
    .attr("font-size", 12);

  svg.append("g")
    .attr("fill", "none")
    .attr("stroke", stroke)
    .attr("stroke-opacity", strokeOpacity)
    .attr("stroke-linecap", strokeLinecap)
    .attr("stroke-linejoin", strokeLinejoin)
    .attr("stroke-width", strokeWidth)
    .selectAll("path")
    .data(root.links())
    .join("path")
    .attr("d", d37.link(curve)
      .x(d => d.y)
      .y(d => d.x));

  const node = svg.append("g")
    .selectAll("a")
    .data(root.descendants())
    .join("a")
    .attr("xlink:href", link == null ? null : d => link(d.data, d))
    .attr("target", link == null ? null : linkTarget)
    .attr("transform", d => `translate(${d.y},${d.x})`);

  node.append("circle")
    .attr("fill", d => d.children ? stroke : fill)
    .attr("r", r);

  // The title option is honoured now; it used to be unconditionally overwritten here.
  if (title != null) node.append("title")
    .text(d => title(d.data, d));

  if (L) node.append("text")
    .attr("dy", "0.32em")
    .attr("x", d => d.children ? -6 : 6)
    .attr("text-anchor", d => d.children ? "end" : "start")
    .attr("paint-order", "stroke")
    .attr("stroke", halo)
    .attr("fill", d => d.data.prob == 1 ? 'red' : 'black')
    .attr("stroke-width", haloWidth)
    .text((d, i) => L[i]);

  // NOTE(review): the wrapper is not attached to the document yet, so this scroll likely has no
  // effect; also only the <svg> (not the scrolling wrapper) is returned.
  body.node().scrollBy(totalWidth, 0);

  return svg.node();
}
|
|
|
|
|
function TextGrid(data, div_name, {
  width = 640, // minimum outer width, in pixels
  height, // outer height, in pixels (derived from topk rows when omitted)
  r = 3, // currently unused
  padding = 1, // currently unused
} = {}) {
  // Draw one column per entry of `data`. Each word is a grey box whose opacity is its score,
  // with the word written on top. Returns a scrollable <div> wrapper.
  // Assumes each entry is [_, scores, words]; only indices 1 and 2 are read.
  // `div_name` is accepted for compatibility but not used.
  const topk = 10; // rows reserved when computing the default height
  const rowPitch = 20; // vertical distance between rows, in pixels
  const rectWidth = 60;
  const rectTotal = 70; // column pitch: box width plus a 10px gap

  const realWidth = rectTotal * data.length;
  const totalWidth = (realWidth > width) ? realWidth : width;
  const svgHeight = height === undefined ? topk * rowPitch + 10 : height;

  const parent = d3.create("div");
  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");

  const svg = body.append("svg")
    .attr("width", totalWidth)
    .attr("height", svgHeight)
    .style("display", "block")
    .attr("font-family", "sans-serif")
    .attr("font-size", 10);

  data.forEach((words_list, column) => {
    // Locals now; these used to leak as implicit globals.
    const words = words_list[2];
    const scores = words_list[1];
    const words_score = words.map((t, i) => ({t: t, p: scores[i]}));
    const x = column * rectTotal;

    // selectAll(null) makes every datum enter, so each column appends its own <g> group.
    const probs = svg.selectAll(null)
      .data(words_score)
      .enter()
      .append('g');

    probs.append("rect")
      .attr("x", x)
      .attr("y", (d, i) => 10 + i * rowPitch)
      .attr('width', rectWidth)
      .attr('height', 15)
      .attr("color", 'gray')
      .attr("fill", "gray")
      .attr("fill-opacity", (d) => d.p)
      .attr("stroke-opacity", 0.8)
      .append("svg:title")
      .text((d) => d.t + ":" + d.p);

    probs.append("text")
      .text((d) => d.t)
      .attr("x", x)
      .attr("y", (d, i) => 20 + i * rowPitch)
      .attr("font-weight", 700);
  });

  // NOTE(review): runs before the wrapper is attached to the document, so it likely has no effect.
  body.node().scrollBy(totalWidth, 0);

  return parent.node();
}
|
|
|
|
|
|
|
|
function attViz(PYTHON_PARAMS) {
  // Render an interactive attention "head view" into the element whose id is
  // PYTHON_PARAMS.root_div_id. The chart draws into its '#vis' child, and the '#layer' and
  // '#filter' selects drive re-rendering. Reads the params keys attention, default_filter,
  // root_div_id, include_layers, and optionally heads and layer.
  // Uses the d3 v5 API (handlers receive (d, i)) and the global jQuery.
  const $ = jQuery;
  const params = PYTHON_PARAMS;
  const TEXT_SIZE = 15;
  const BOXWIDTH = 110;
  const BOXHEIGHT = 22.5;
  const MATRIX_WIDTH = 115;
  const CHECKBOX_SIZE = 20;
  const TEXT_TOP = 30;
  // Namespace for our jQuery handlers, so a later attViz call on the same root replaces them
  // instead of stacking another copy.
  const CHANGE_EVENT = 'change.attViz';

  console.log("d3 version in ffuntions", d3.version);

  let headColors;
  try {
    headColors = d3.scaleOrdinal(d3.schemeCategory10);
  } catch (err) {
    console.log('Older d3 version');
    headColors = d3.scale.category10();
  }
  const config = {};

  initialize();
  renderVis();

  // Copy params into config, fill the layer selector and wire both selects.
  function initialize() {
    console.log("init");
    config.attention = params['attention'];
    config.filter = params['default_filter'];
    config.rootDivId = params['root_div_id'];
    config.nLayers = config.attention[config.filter]['attn'].length;
    config.nHeads = config.attention[config.filter]['attn'][0].length;
    config.layers = params['include_layers'];

    if (params['heads']) {
      config.headVis = new Array(config.nHeads).fill(false);
      params['heads'].forEach(x => config.headVis[x] = true);
    } else {
      config.headVis = new Array(config.nHeads).fill(true);
    }
    config.initialTextLength = config.attention[config.filter].right_text.length;
    config.layer_seq = (params['layer'] == null ? 0 : config.layers.findIndex(layer => params['layer'] === layer));
    config.layer = config.layers[config.layer_seq];

    const rootSelector = '#' + config.rootDivId;
    const layerEl = $(rootSelector + ' #layer');
    layerEl.empty();
    console.log(layerEl);

    // Previously every call added one more handler, so stale handlers from earlier payloads
    // kept re-rendering old data. Remove ours before rebinding.
    layerEl.off(CHANGE_EVENT);
    for (const layer of config.layers) {
      layerEl.append($("<option />").val(layer).text(layer));
    }
    layerEl.val(config.layer).trigger('change');
    layerEl.on(CHANGE_EVENT, function (e) {
      config.layer = +e.currentTarget.value;
      config.layer_seq = config.layers.findIndex(layer => config.layer === layer);
      renderVis();
    });

    $(rootSelector + ' #filter')
      .off(CHANGE_EVENT)
      .on(CHANGE_EVENT, function (e) {
        config.filter = e.currentTarget.value;
        renderVis();
      });
  }

  // Redraw the whole view for the current filter and layer.
  function renderVis() {
    const attnData = config.attention[config.filter];
    const leftText = attnData.left_text;
    const rightText = attnData.right_text;
    const layerAttention = attnData.attn[config.layer_seq];

    $('#' + config.rootDivId + ' #vis').empty();

    const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;
    const svg = d3.select('#' + config.rootDivId + ' #vis')
      .append('svg')
      .attr("width", "100%")
      .attr("height", height + "px");

    renderText(svg, leftText, true, layerAttention, 0);
    renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);
    renderAttention(svg, layerAttention);
    drawCheckboxes(0, svg);
  }

  // Draw one column of tokens plus its hidden per-head attention boxes and hover behaviour.
  function renderText(svg, text, isLeft, attention, leftPos) {
    const textContainer = svg.append("svg:g")
      .attr("id", isLeft ? "left" : "right");

    // Per-head coloured boxes, invisible until a token on the opposite side is hovered.
    textContainer.append("g")
      .classed("attentionBoxes", true)
      .selectAll("g")
      .data(attention)
      .enter()
      .append("g")
      .attr("head-index", (d, i) => i)
      .selectAll("rect")
      .data(d => isLeft ? d : transpose(d))
      .enter()
      .append("rect")
      .attr("x", function () {
        const headIndex = +this.parentNode.getAttribute("head-index");
        return leftPos + boxOffsets(headIndex);
      })
      .attr("y", (+1) * BOXHEIGHT)
      .attr("width", BOXWIDTH / activeHeads())
      .attr("height", BOXHEIGHT)
      .attr("fill", function () {
        return headColors(+this.parentNode.getAttribute("head-index"));
      })
      .style("opacity", 0.0);

    const tokenContainer = textContainer.append("g").selectAll("g")
      .data(text)
      .enter()
      .append("g");

    tokenContainer.append("rect")
      .classed("background", true)
      .style("opacity", 0.0)
      .attr("fill", "lightgray")
      .attr("x", leftPos)
      .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
      .attr("width", BOXWIDTH)
      .attr("height", BOXHEIGHT);

    const textEl = tokenContainer.append("text")
      .text(d => d)
      .attr("font-size", TEXT_SIZE + "px")
      .style("cursor", "default")
      .style("-webkit-user-select", "none")
      .attr("x", leftPos)
      .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);

    if (isLeft) {
      textEl.style("text-anchor", "end")
        .attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
        .attr("dy", TEXT_SIZE);
    } else {
      textEl.style("text-anchor", "start")
        .attr("dx", +0.5 * TEXT_SIZE)
        .attr("dy", TEXT_SIZE);
    }

    tokenContainer.on("mouseover", function (d, index) {
      // Highlight the hovered token.
      textContainer.selectAll(".background")
        .style("opacity", (d, i) => i === index ? 1.0 : 0.0);

      // Hide all attention lines except those touching the hovered token.
      svg.select("#attention")
        .selectAll("line[visibility='visible']")
        .attr("visibility", null);
      svg.select("#attention").attr("visibility", "hidden");
      const indexAttr = isLeft ? "left-token-index" : "right-token-index";
      svg.select("#attention")
        .selectAll("line[" + indexAttr + "='" + index + "']")
        .attr("visibility", "visible");

      // Show the hovered token's attention weights as boxes next to the opposite column.
      const id = isLeft ? "right" : "left";
      const oppositeLeftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
      svg.select("#" + id)
        .selectAll(".attentionBoxes")
        .selectAll("g")
        .attr("head-index", (d, i) => i)
        .selectAll("rect")
        .attr("x", function () {
          const headIndex = +this.parentNode.getAttribute("head-index");
          return oppositeLeftPos + boxOffsets(headIndex);
        })
        .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
        .attr("width", BOXWIDTH / activeHeads())
        .attr("height", BOXHEIGHT)
        .style("opacity", function (d) {
          const headIndex = +this.parentNode.getAttribute("head-index");
          if (!config.headVis[headIndex] || !d) return 0.0;
          return d[index];
        });
    });

    textContainer.on("mouseleave", function () {
      // Restore the resting state: no highlight, all lines shown, boxes hidden.
      d3.select(this).selectAll(".background")
        .style("opacity", 0.0);
      svg.select("#attention")
        .selectAll("line[visibility='visible']")
        .attr("visibility", null);
      svg.select("#attention").attr("visibility", "visible");
      svg.selectAll(".attentionBoxes")
        .selectAll("g")
        .selectAll("rect")
        .style("opacity", 0.0);
    });
  }

  // One line per (head, left token, right token); the stroke opacity encodes the weight.
  function renderAttention(svg, attention) {
    svg.select("#attention").remove();

    svg.append("g")
      .attr("id", "attention")
      .selectAll(".headAttention")
      .data(attention)
      .enter()
      .append("g")
      .classed("headAttention", true)
      .attr("head-index", (d, i) => i)
      .selectAll(".tokenAttention")
      .data(d => d)
      .enter()
      .append("g")
      .classed("tokenAttention", true)
      .attr("left-token-index", (d, i) => i)
      .selectAll("line")
      .data(d => d)
      .enter()
      .append("line")
      .attr("x1", BOXWIDTH)
      .attr("y1", function () {
        const leftTokenIndex = +this.parentNode.getAttribute("left-token-index");
        return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2);
      })
      .attr("x2", BOXWIDTH + MATRIX_WIDTH)
      .attr("y2", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))
      .attr("stroke-width", 2)
      .attr("stroke", function () {
        const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
        return headColors(headIndex);
      })
      .attr("left-token-index", function () {
        return +this.parentNode.getAttribute("left-token-index");
      })
      .attr("right-token-index", (d, i) => i);

    updateAttention(svg);
  }

  // Hidden heads get opacity 0; visible heads share opacity by the number of active heads.
  function updateAttention(svg) {
    svg.select("#attention")
      .selectAll("line")
      .attr("stroke-opacity", function (d) {
        const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
        return config.headVis[headIndex] ? d / activeHeads() : 0.0;
      });
  }

  // Horizontal offset of head i's box: one slot per visible head with a lower index.
  function boxOffsets(i) {
    const numHeadsAbove = config.headVis.reduce(
      (acc, val, cur) => (val && cur < i ? acc + 1 : acc), 0);
    return numHeadsAbove * (BOXWIDTH / activeHeads());
  }

  function activeHeads() {
    return config.headVis.reduce((acc, val) => (val ? acc + 1 : acc), 0);
  }

  // One coloured square per head. Click toggles a head (at least one stays on);
  // double-click isolates a head, or restores all heads if it was already isolated.
  function drawCheckboxes(top, svg) {
    const checkboxContainer = svg.append("g");
    const checkbox = checkboxContainer.selectAll("rect")
      .data(config.headVis)
      .enter()
      .append("rect")
      .attr("fill", (d, i) => headColors(i))
      .attr("x", (d, i) => i * CHECKBOX_SIZE)
      .attr("y", top)
      .attr("width", CHECKBOX_SIZE)
      .attr("height", CHECKBOX_SIZE);

    function updateCheckboxes() {
      checkboxContainer.selectAll("rect")
        .data(config.headVis)
        .attr("fill", (d, i) => d ? headColors(i) : lighten(headColors(i)));
    }

    updateCheckboxes();

    checkbox.on("click", function (d, i) {
      if (config.headVis[i] && activeHeads() === 1) return;
      config.headVis[i] = !config.headVis[i];
      updateCheckboxes();
      updateAttention(svg);
    });

    checkbox.on("dblclick", function (d, i) {
      if (config.headVis[i] && activeHeads() === 1) {
        config.headVis = new Array(config.nHeads).fill(true);
      } else {
        config.headVis = new Array(config.nHeads).fill(false);
        config.headVis[i] = true;
      }
      updateCheckboxes();
      updateAttention(svg);
    });
  }

  // Washed-out version of a colour, used for disabled heads.
  function lighten(color) {
    const c = d3.hsl(color);
    const increment = (1 - c.l) * 0.6;
    c.l += increment;
    c.s -= increment;
    return c;
  }

  function transpose(mat) {
    return mat[0].map((col, i) => mat.map(row => row[i]));
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|