tfrere's picture
tfrere HF Staff
update
f7b880e
raw
history blame
38.1 kB
<div class="d3-neural"></div>
<style>
/* Root container of the widget; every rule below is scoped under .d3-neural. */
.d3-neural { position: relative; width:100%;margin:0;}
/* Optional controls row (labels + range inputs). */
.d3-neural .controls { margin-top: 12px; display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
.d3-neural .controls label { font-size: 12px; color: var(--muted-color); display: flex; align-items: center; gap: 8px; white-space: nowrap; padding: 6px 10px; }
.d3-neural .controls input[type="range"]{ width: 160px; }
/* Two-column panel: .left takes roughly one third, .right roughly two thirds, with a narrow arrow column between them. */
.d3-neural .panel { display:flex; gap:8px; align-items:stretch; flex-wrap: nowrap; }
.d3-neural .left { flex: 0 0 33.333%; max-width: 33.333%; min-width: 160px; display:flex; flex-direction:column; gap:8px; }
.d3-neural .right { flex: 1 1 66.666%; max-width: 66.666%; min-width: 280px; display:flex; }
.d3-neural .right > svg { flex: 1 1 auto; height: 100%; }
.d3-neural .arrow-sep { flex: 0 0 18px; max-width: 18px; display:flex; align-items:center; justify-content:center; color: var(--muted-color); }
.d3-neural .arrow-sep svg { display:block; width: 16px; height: 16px; }
/* Narrow viewports: stack the two columns vertically and hide the arrow separator. */
@media (max-width: 800px) {
.d3-neural .panel { flex-direction: column; }
.d3-neural .left,
.d3-neural .right { flex: 0 0 100%; max-width: 100%; min-width: 0; }
.d3-neural .arrow-sep { display: none; }
}
/* Drawing canvas. */
.d3-neural canvas { width: 100%; height: auto; border-radius: 8px; border: 1px solid var(--border-color); background: var(--surface-bg); display:block; }
/* 28-column preview grid of square cells. */
.d3-neural .preview28 { display:grid; grid-template-columns: repeat(28, 1fr); gap: 1px; width: 100%; }
.d3-neural .preview28 span { display:block; aspect-ratio:1/1; border-radius:2px; }
.d3-neural .legend { font-size: 12px; color: var(--text-color); line-height:1.35; }
/* Probability bar chart: bars are bottom-aligned and animate height/colour; .active marks a highlighted bar. */
.d3-neural .probs { display:flex; gap:6px; align-items:flex-end; height: 64px; }
.d3-neural .probs .bar { width: 10px; border-radius:2px 2px 0 0; background: var(--border-color); transition: height .15s ease, background-color .15s ease; }
.d3-neural .probs .bar.active { background: var(--primary-color); }
.d3-neural .probs .tick { font-size: 10px; color: var(--muted-color); text-align:center; margin-top: 2px; }
/* Positioning context for the overlays placed on top of the canvas. */
.d3-neural .canvas-wrap { position: relative; }
/* Erase button pinned to the top-right corner of the canvas wrapper. */
.d3-neural .erase-btn { position: absolute; top: 8px; right: 8px; width: 32px; height: 32px; display:flex; align-items:center; justify-content:center; border: 1px solid var(--border-color); }
/* Hint text pinned to the top-left corner; pointer-events:none lets presses pass through to the canvas. */
.d3-neural .canvas-hint { position: absolute; top: 8px; left: 12px; font-size: 12px; font-weight: 700; color: rgba(0,0,0,.9); pointer-events: none; transition: opacity .12s ease; }
</style>
<script>
(() => {
/**
 * Invoke `cb` once D3 is usable on `window`.
 *
 * If `window.d3.select` is already a function, `cb` runs synchronously and its
 * return value is returned. Otherwise a shared `<script id="d3-cdn-script">` tag
 * is reused (or created and appended to `<head>`), and `cb` is called from that
 * tag's `load` event, provided D3 is usable by then.
 *
 * @param {Function} cb - Callback to run when D3 is available.
 * @returns {*} Result of `cb()` on the synchronous path, otherwise `undefined`.
 */
const ensureD3 = (cb) => {
  const SCRIPT_ID = 'd3-cdn-script';
  const isD3Usable = () => Boolean(window.d3) && typeof window.d3.select === 'function';
  if (isD3Usable()) {
    return cb();
  }
  let tag = document.getElementById(SCRIPT_ID);
  if (!tag) {
    tag = document.createElement('script');
    tag.id = SCRIPT_ID;
    tag.src = 'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js';
    document.head.appendChild(tag);
  }
  const fireWhenUsable = () => {
    if (isD3Usable()) cb();
  };
  tag.addEventListener('load', fireWhenUsable, { once: true });
  if (window.d3) fireWhenUsable();
};
/**
 * Invoke `cb` once TensorFlow.js is usable on `window`.
 *
 * If `window.tf.tensor` is already a function, `cb` runs synchronously and its
 * return value is returned. Otherwise a shared `<script id="tfjs-cdn-script">`
 * tag is reused (or created and appended to `<head>`), and `cb` is called from
 * that tag's `load` event, provided TensorFlow.js is usable by then.
 *
 * @param {Function} cb - Callback to run when TensorFlow.js is available.
 * @returns {*} Result of `cb()` on the synchronous path, otherwise `undefined`.
 */
const ensureTF = (cb) => {
  const SCRIPT_ID = 'tfjs-cdn-script';
  const isTfUsable = () => Boolean(window.tf) && typeof window.tf.tensor === 'function';
  if (isTfUsable()) {
    return cb();
  }
  let tag = document.getElementById(SCRIPT_ID);
  if (!tag) {
    tag = document.createElement('script');
    tag.id = SCRIPT_ID;
    tag.src = 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/dist/tf.min.js';
    document.head.appendChild(tag);
  }
  const fireWhenUsable = () => {
    if (isTfUsable()) cb();
  };
  tag.addEventListener('load', fireWhenUsable, { once: true });
  if (window.tf) fireWhenUsable();
};
const bootstrap = () => {
const mount = document.currentScript ? document.currentScript.previousElementSibling : null;
const container = (mount && mount.querySelector && mount.querySelector('.d3-neural')) || document.querySelector('.d3-neural');
if (!container) return;
if (container.dataset) { if (container.dataset.mounted === 'true') return; container.dataset.mounted = 'true'; }
// Widget skeleton: left column (drawing canvas + overlays), arrow separator, right column (SVG network).
const panel = Object.assign(document.createElement('div'), { className: 'panel' });
const left = Object.assign(document.createElement('div'), { className: 'left' });
const arrowSep = Object.assign(document.createElement('div'), {
  className: 'arrow-sep',
  innerHTML: '<svg viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true"><line x1="3" y1="12" x2="19" y2="12" stroke="currentColor" stroke-width="2" stroke-linecap="round"/><polyline points="17,7 22,12 17,17" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>',
});
const right = Object.assign(document.createElement('div'), { className: 'right' });
[left, arrowSep, right].forEach((column) => panel.appendChild(column));
container.appendChild(panel);

// Square drawing surface, initialised to a white background.
const CANVAS_PX = 224;
const canvas = Object.assign(document.createElement('canvas'), { width: CANVAS_PX, height: CANVAS_PX });
const ctx = canvas.getContext('2d');
ctx.fillStyle = '#ffffff';
ctx.fillRect(0, 0, CANVAS_PX, CANVAS_PX);
const canvasWrap = Object.assign(document.createElement('div'), { className: 'canvas-wrap' });
canvasWrap.appendChild(canvas);

// Erase button overlay (top-right); starts hidden and is revealed by the first press on the canvas.
const eraseBtn = Object.assign(document.createElement('button'), {
  className: 'erase-btn button--ghost',
  type: 'button',
});
eraseBtn.setAttribute('aria-label', 'Clear');
eraseBtn.style.display = 'none';
eraseBtn.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="3 6 5 6 21 6"></polyline><path d="M19 6l-1 14a2 2 0 0 1-2 2H8a2 2 0 0 1-2-2L5 6"></path><path d="M10 11v6"></path><path d="M14 11v6"></path><path d="M9 6V4a2 2 0 0 1 2-2h2a2 2 0 0 1 2 2v2"></path></svg>';
eraseBtn.addEventListener('click', () => clearCanvas());
canvasWrap.appendChild(eraseBtn);

// Hint overlay (top-left); its opacity is toggled by runPipeline.
const hint = Object.assign(document.createElement('div'), {
  className: 'canvas-hint',
  textContent: 'Draw a digit here',
});
canvasWrap.appendChild(hint);
left.appendChild(canvasWrap);
// SVG for the network diagram, hosted in the right column.
const svg = d3.select(right).append('svg').attr('width','100%').style('display','block');
const defs = svg.append('defs');
const gRoot = svg.append('g');
// Child groups are appended in paint order (earlier groups render underneath later ones).
const [gInput, gInputLinks, gLinks, gNodes, gLabels, gOutText] =
  ['input', 'input-links', 'links', 'nodes', 'labels', 'out-probs']
    .map((groupClass) => gRoot.append('g').attr('class', groupClass));
// Displayed network: 8 feature nodes -> 8 hidden nodes -> 10 output nodes.
const layerSizes = [8, 8, 10];
// One node record per unit: stable id, layer index, index within the layer, activation `a` (starts at 0).
const layers = layerSizes.map((count, layerIdx) =>
  Array.from({ length: count }, (_, unitIdx) => ({
    id: `L${layerIdx}N${unitIdx}`,
    layer: layerIdx,
    index: unitIdx,
    a: 0,
  })));
// Fully connect `fromLayer` to `toLayer`; `weightOf(i, j)` supplies the fixed display weight in [0,1].
const fullyConnect = (fromLayer, toLayer, weightOf) =>
  Array.from({ length: layerSizes[fromLayer] }, (_, i) =>
    Array.from({ length: layerSizes[toLayer] }, (_, j) => ({
      s: { l: fromLayer, i },
      t: { l: toLayer, j },
      w: weightOf(i, j),
    }))).flat();
// Links between consecutive layers only; weights are deterministic sin/cos patterns.
const links = [
  ...fullyConnect(0, 1, (i, j) => (Math.sin(i*17+j*31)+1)/2),
  ...fullyConnect(1, 2, (i, j) => (Math.cos(i*7+j*13)+1)/2),
];
// Linear classifier: logits = W * feats + b, feats in [0,1]
// features: [total, cx, cy, lr, tb, htrans, vtrans, loopiness]
// W has one row per digit class (0-9) and one column per feature, in the order
// returned by computeFeatures. runPipeline only uses these weights on its fallback
// path (no usable TFJS prediction), scaled by 0.2 and added to the glyph-similarity logits.
const W = [
// 0 1 2 3 4 5 6 7
[ 0.3, 0.0, 0.0, 0.0, 0.0, -0.8, -0.6, 1.2], // 0
[-0.2, 0.9, 0.2, 0.8, 0.1, -0.2, 0.2, -1.1], // 1
[ 0.1, 0.4, 0.2, 0.5, 0.2, 0.9, 0.1, -0.6], // 2
[ 0.2, 0.3, 0.2, 0.2, 0.2, 0.9, 0.0, -0.2], // 3
[ 0.0,-0.3, 0.2,-0.6, 0.4, 0.2, 0.8, -0.6], // 4
[ 0.1,-0.4, 0.2,-0.5, 0.5, 0.9, 0.1, -0.6], // 5
[ 0.2,-0.2, 0.6,-0.2, 0.8, -0.3, 0.2, 0.6], // 6
[ 0.0, 0.6,-0.2, 0.6,-0.8, 0.6, 0.0, -0.8], // 7
[ 0.4, 0.0, 0.0, 0.1, 0.1, 0.6, 0.6, 1.0], // 8
[ 0.2, 0.2,-0.6, 0.2,-0.8, 0.2, 0.6, 0.5], // 9
];
// Per-class bias added to the corresponding row's dot product (one entry per digit 0-9).
const b = [-0.2, -0.1, -0.05, -0.05, -0.05, -0.05, -0.05, -0.1, -0.15, -0.1];
/**
 * Summarise a 28×28 ink image as eight hand-crafted features.
 *
 * @param {ArrayLike<number>} x28 - 784 row-major values, 1 = ink, 0 = background.
 * @returns {number[]} `[total, cx, cy, lr, tb, htrans, vtrans, loopiness]`:
 *   mean ink; normalised centre of mass (x, y; 0.5 when there is no ink);
 *   share of ink in the right half and in the bottom half; capped counts of
 *   horizontal and vertical ink/background flips (threshold 0.25); and the
 *   share of ink outside a 5-pixel border, scaled by 1.8 and capped at 1.
 */
function computeFeatures(x28){
  const SIDE = 28;
  const HALF = SIDE / 2;
  const BORDER = 5;
  const INK_THR = 0.25;
  const EPS = 1e-6;
  let mass = 0, momentX = 0, momentY = 0;
  let flipsH = 0, flipsV = 0;
  let inkLeft = 0, inkRight = 0, inkTop = 0, inkBottom = 0;
  let inkRim = 0, inkCore = 0;
  // Single row-major sweep; every accumulator receives its terms in the same order as separate sweeps would.
  for (let row = 0; row < SIDE; row++) {
    for (let col = 0; col < SIDE; col++) {
      const ink = x28[row*SIDE + col];
      const isInk = ink > INK_THR;
      mass += ink;
      momentX += col*ink;
      momentY += row*ink;
      // Count a flip whenever this pixel and its left / upper neighbour fall on different sides of the threshold.
      if (col > 0 && (x28[row*SIDE + (col-1)] > INK_THR) !== isInk) flipsH += 1;
      if (row > 0 && (x28[(row-1)*SIDE + col] > INK_THR) !== isInk) flipsV += 1;
      if (col < HALF) inkLeft += ink; else inkRight += ink;
      if (row < HALF) inkTop += ink; else inkBottom += ink;
      const onRim = col < BORDER || col >= SIDE-BORDER || row < BORDER || row >= SIDE-BORDER;
      if (onRim) inkRim += ink; else inkCore += ink;
    }
  }
  const hasInk = mass > EPS;
  const flipScale = SIDE*SIDE*0.35;
  return [
    mass/(SIDE*SIDE),
    hasInk ? (momentX/mass)/(SIDE-1) : 0.5,
    hasInk ? (momentY/mass)/(SIDE-1) : 0.5,
    inkRight/(inkRight+inkLeft+EPS),
    inkBottom/(inkBottom+inkTop+EPS),
    Math.min(1, flipsH/flipScale),
    Math.min(1, flipsV/flipScale),
    Math.min(1, inkCore/(inkRim+inkCore+EPS)*1.8),
  ];
}
/**
 * Numerically stable softmax: subtracts the maximum before exponentiating and
 * adds 1e-12 to the denominator.
 * @param {number[]} arr - Input scores.
 * @returns {number[]} Non-negative values of the same length as `arr`.
 */
function softmax(arr){
  const peak = Math.max(...arr);
  const exps = arr.map((score) => Math.exp(score - peak));
  let denom = 0;
  for (const e of exps) denom += e;
  denom += 1e-12;
  return exps.map((e) => e / denom);
}
/**
 * Euclidean length of a vector.
 * @param {Iterable<number>} a - Vector components.
 * @returns {number} `Math.hypot` of the components, or 0 when that result is falsy (0 or NaN).
 */
function l2norm(a){
  const length = Math.hypot(...a);
  return length ? length : 0;
}
/**
 * Scale a vector to unit Euclidean length.
 * @param {number[]|Float32Array} a - Input vector (not modified).
 * @returns {number[]|Float32Array} A new vector `a / |a|`, or a copy of `a` when its length is not positive.
 */
function normalize(a){
  const length = l2norm(a);
  if (length > 0) {
    return a.map((component) => component / length);
  }
  return a.slice();
}
/**
 * Cosine similarity of two vectors, iterating over the length of `a`.
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} `a·b / (|a| |b|)`; a zero-length `b` is treated as length 1, and 0 is returned when `|a|` is not positive.
 */
function cosine(a,b){
  let inner = 0;
  for (let k = 0; k < a.length; k += 1) {
    inner += a[k]*b[k];
  }
  const lenA = l2norm(a);
  const lenB = l2norm(b) || 1;
  if (lenA > 0) {
    return inner/(lenA*lenB);
  }
  return 0;
}
/**
 * MNIST-like normalisation of a 28×28 ink image: measure the tight bounding box
 * of pixels above 0.2, pick the scale that makes its longer side 20 pixels, and
 * resample so that the ink's centre of mass lands on the centre of the output.
 *
 * @param {Float32Array} x28 - 784 row-major ink values.
 * @returns {Float32Array} A new 784-value image, or `x28` itself (same object)
 *   when the total ink is below 1e-3 or no pixel exceeds the threshold.
 */
function normalize28(x28){
  const SIDE = 28;
  const INK_THR = 0.2;
  let boxLeft = 29, boxTop = 29, boxRight = -1, boxBottom = -1;
  let mass = 0, momentX = 0, momentY = 0;
  for (let idx = 0; idx < SIDE*SIDE; idx++) {
    const col = idx % SIDE;
    const row = (idx - col) / SIDE;
    const ink = x28[idx];
    if (ink > INK_THR) {
      boxLeft = Math.min(boxLeft, col);
      boxRight = Math.max(boxRight, col);
      boxTop = Math.min(boxTop, row);
      boxBottom = Math.max(boxBottom, row);
    }
    mass += ink;
    momentX += col*ink;
    momentY += row*ink;
  }
  // Nothing meaningful to normalise: hand back the input untouched.
  if (mass < 1e-3 || boxRight < 0) return x28;
  const comX = momentX/mass;
  const comY = momentY/mass;
  const boxW = Math.max(1, boxRight-boxLeft+1);
  const boxH = Math.max(1, boxBottom-boxTop+1);
  const scale = 20/Math.max(boxW, boxH);
  const mid = (SIDE-1)/2;
  const result = new Float32Array(SIDE*SIDE);
  for (let row = 0; row < SIDE; row++) {
    // Inverse mapping: each output pixel samples the source at an offset from the centre of mass.
    const srcY = (row - mid)/scale + comY;
    for (let col = 0; col < SIDE; col++) {
      const srcX = (col - mid)/scale + comX;
      result[row*SIDE + col] = bilinearSample(x28, SIDE, SIDE, srcX, srcY);
    }
  }
  return result;
}
/**
 * Bilinearly interpolate a row-major image at a fractional position.
 * Texels outside the image count as 0.
 *
 * @param {ArrayLike<number>} img - Row-major pixel values.
 * @param {number} w - Image width in pixels.
 * @param {number} h - Image height in pixels.
 * @param {number} x - Sample column (may be fractional or out of range).
 * @param {number} y - Sample row (may be fractional or out of range).
 * @returns {number} Interpolated value.
 */
function bilinearSample(img, w, h, x, y){
  const left = Math.floor(x);
  const top = Math.floor(y);
  const fracX = x - left;
  const fracY = y - top;
  const texel = (col, row) => {
    const inside = col >= 0 && row >= 0 && col < w && row < h;
    return inside ? img[row*w + col] : 0;
  };
  const blend = (p, q, t) => p*(1-t) + q*t;
  const upper = blend(texel(left, top), texel(left+1, top), fracX);
  const lower = blend(texel(left, top+1), texel(left+1, top+1), fracX);
  return blend(upper, lower, fracY);
}
/**
 * Thicken strokes in a 28×28 image: each output pixel is the largest value in
 * its 3×3 neighbourhood (clipped at the image edges), never less than 0.
 *
 * @param {ArrayLike<number>} x - 784 row-major values.
 * @returns {Float32Array} New dilated image; the input is not modified.
 */
function dilate28(x){
  const SIDE = 28;
  const LAST = SIDE - 1;
  const grown = new Float32Array(SIDE*SIDE);
  for (let row = 0; row < SIDE; row++) {
    const rowLo = Math.max(0, row-1);
    const rowHi = Math.min(LAST, row+1);
    for (let col = 0; col < SIDE; col++) {
      const colLo = Math.max(0, col-1);
      const colHi = Math.min(LAST, col+1);
      let best = 0;
      for (let ny = rowLo; ny <= rowHi; ny++) {
        for (let nx = colLo; nx <= colHi; nx++) {
          const candidate = x[ny*SIDE + nx];
          if (candidate > best) best = candidate;
        }
      }
      grown[row*SIDE + col] = best;
    }
  }
  return grown;
}
// Reference vectors for digits 0-9, one 784-value unit-length vector per digit.
// Each is produced by rendering the digit as bold sans-serif text on an offscreen
// canvas, average-pooling to 28×28, running normalize28, and dividing by the L2 norm.
const protoGlyphs28 = [];
(() => {
  const GRID = 28;
  const scratch = document.createElement('canvas');
  scratch.width = CANVAS_PX;
  scratch.height = CANVAS_PX;
  const pen = scratch.getContext('2d');
  const cellPx = scratch.width/GRID;
  pen.textAlign = 'center';
  pen.textBaseline = 'middle';
  pen.font = 'bold 180px system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif';
  for (let digit = 0; digit < 10; digit++) {
    // Fresh white page, then the digit in black slightly below the vertical centre.
    pen.fillStyle = '#ffffff';
    pen.fillRect(0, 0, scratch.width, scratch.height);
    pen.fillStyle = '#000000';
    pen.fillText(String(digit), scratch.width/2, scratch.height*0.56);
    const rgba = pen.getImageData(0, 0, scratch.width, scratch.height).data;
    const pooled = new Float32Array(GRID*GRID);
    for (let cell = 0; cell < GRID*GRID; cell++) {
      const cellCol = cell % GRID;
      const cellRow = (cell - cellCol) / GRID;
      const startX = Math.floor(cellCol*cellPx);
      const startY = Math.floor(cellRow*cellPx);
      let inkSum = 0;
      let samples = 0;
      for (let py = startY; py < startY+cellPx; py++) {
        for (let px = startX; px < startX+cellPx; px++) {
          const at = (py*scratch.width + px)*4;
          const gray = (rgba[at] + rgba[at+1] + rgba[at+2])/3/255;
          inkSum += (1-gray);
          samples++;
        }
      }
      pooled[cell] = inkSum/(samples||1);
    }
    const centred = normalize28(pooled);
    const length = l2norm(centred)||1;
    protoGlyphs28.push(centred.map((value) => value/length));
  }
})();
/**
 * Dot product of two vectors, iterating over the length of `a`.
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} Sum of `a[i]*b[i]` for every index of `a`.
 */
function dot(a,b){
  let total = 0;
  for (let k = 0, n = a.length; k < n; k += 1) {
    total += a[k]*b[k];
  }
  return total;
}
// Layout state shared by the rendering functions: current SVG size, fixed margins,
// and the geometry of the 28×28 input grid (all zero until the first layout).
let width=640, height=360; const margin = { top: 16, right: 8, bottom: 24, left: 8 };
let inputGrid = { cell: 0, x: 0, y: 0, width: 0, height: 0 };
/**
 * Recompute the diagram geometry from the right column's current width.
 *
 * Updates `width`, `height` and `inputGrid`, sizes the SVG and the canvas
 * wrapper, translates the root group by the margins, and writes `x`/`y` onto
 * every node in `layers`. Draws nothing itself.
 */
function layoutNodes(){
  const GRID = 28;
  const MIN_CELL = 3;
  // SVG is at least 280×260 and uses a 0.56 height/width ratio.
  width = Math.max(280, Math.round(right.clientWidth || 640));
  height = Math.max(260, Math.round(width * 0.56));
  svg.attr('width', width).attr('height', height);
  // Give the canvas wrapper the SVG's height so both columns line up.
  try { canvas.style.height = '100%'; canvasWrap.style.height = `${height}px`; } catch(_) {}
  const innerW = width - margin.left - margin.right;
  const innerH = height - margin.top - margin.bottom;
  gRoot.attr('transform', `translate(${margin.left},${margin.top})`);

  // Grid geometry for a given cell size: flush left, vertically centred.
  const gridFor = (cell) => {
    const side = cell * GRID;
    return { cell, x: 0, y: Math.floor((innerH - side)/2), width: side, height: side };
  };
  // Largest cell that fits the height and at most 28% of the width, but never below 3px.
  const cellByHeight = Math.floor(innerH / GRID);
  const cellByWidth = Math.floor((innerW * 0.28) / GRID);
  inputGrid = gridFor(Math.max(MIN_CELL, Math.min(cellByHeight, cellByWidth)));

  // Horizontal budget: one gap before each layer plus padding for the output labels.
  const nLayers = layerSizes.length;
  const rightLabelPad = 36;
  const minGap = 28;
  const maxGap = 260;
  const desiredMinFree = rightLabelPad + nLayers * minGap;
  if (inputGrid.width + desiredMinFree > innerW) {
    // Not enough room for the minimum gaps: shrink the grid.
    inputGrid = gridFor(Math.max(MIN_CELL, Math.floor((innerW - desiredMinFree) / GRID)));
  }
  const gridRight = inputGrid.x + inputGrid.width;
  const freeW = Math.max(nLayers * minGap, innerW - gridRight - rightLabelPad);
  const gapX = Math.min(maxGap, Math.max(minGap, Math.floor(freeW / nLayers)));

  // Vertical placement: nodes of a layer are spread evenly over 90% of the inner height.
  const span = innerH * 0.9;
  const topPad = (innerH - span) / 2;
  layers.forEach((nodes, li) => {
    const columnX = gridRight + gapX * (li + 1);
    const count = nodes.length;
    const spacing = count > 1 ? span / (count - 1) : 0;
    nodes.forEach((nd, i) => {
      nd.x = columnX;
      nd.y = count <= 1 ? innerH/2 : topPad + i*spacing;
    });
  });
}
// Most recent preprocessed 28×28 input: assigned by runPipeline and read by renderInputGrid; starts as all zeros.
let lastX28 = new Float32Array(28*28);
/**
 * Circle radius for a node, driven by its activation clamped to [0,1].
 * A missing node or a non-numeric activation counts as activation 0.
 *
 * @param {{layer?: number, a?: number}|null|undefined} n - Node record.
 * @returns {number} 8–18 for output nodes (layer 2), 5–10 for all other nodes.
 */
function nodeRadiusForNode(n){
  const rawLevel = (n && typeof n.a === 'number') ? n.a : 0;
  const level = Math.max(0, Math.min(1, rawLevel));
  const isOutputNode = Boolean(n) && n.layer === 2;
  return isOutputNode ? 8 + 10 * level : 5 + 5 * level;
}
/**
 * Draw (or refresh) the 28×28 input preview inside `gInput`: one rect per
 * pixel of `lastX28`, a border around the grid, and a centred "Input 28×28"
 * label above it. Does nothing until `inputGrid` has a positive cell size.
 *
 * Pixel size is applied to new AND existing rects, so a re-layout with a
 * different cell size resizes the pixels as well as repositioning them.
 */
function renderInputGrid(){
  if (!inputGrid || inputGrid.cell <= 0) return;
  const data = Array.from({ length: 28*28 }, (_, i) => ({ i, v: lastX28[i] || 0 }));
  const sel = gInput.selectAll('rect.input-px').data(data, d=>d.i);
  // Leave roughly 10% of each cell (at least 1px) as a gap between pixels.
  const gap = Math.max(1, Math.floor(inputGrid.cell * 0.10));
  const inner = Math.max(1, inputGrid.cell - gap);
  const offset = Math.floor(gap / 2);
  sel.enter().append('rect').attr('class','input-px')
    .merge(sel)
    // Size depends on the current cell size, so it must be set after the merge, not only on enter.
    .attr('width', inner).attr('height', inner)
    .attr('x', d => inputGrid.x + (d.i % 28) * inputGrid.cell + offset)
    .attr('y', d => inputGrid.y + Math.floor(d.i / 28) * inputGrid.cell + offset)
    .attr('fill', d => {
      // Gamma curve (exponent 0.6 < 1) darkens faint ink to raise perceived contrast.
      const k = Math.pow(Math.max(0, Math.min(1, d.v)), 0.6);
      const g = 255 - Math.round(k * 255);
      return `rgb(${g},${g},${g})`;
    })
    .attr('stroke', 'none');
  sel.exit().remove();
  // Border around the input grid area; lowered on creation so it sits behind the pixels.
  const borderSel = gInput.selectAll('rect.input-border').data([0]);
  borderSel.enter().append('rect').attr('class','input-border')
    .attr('fill','none')
    .attr('rx', 0).attr('ry', 0)
    .attr('stroke','var(--text-color)')
    .attr('stroke-opacity', 0.25)
    .attr('stroke-width', 1)
    .lower()
    .merge(borderSel)
    .attr('x', inputGrid.x-1)
    .attr('y', inputGrid.y-1)
    .attr('width', inputGrid.width+1)
    .attr('height', inputGrid.height+1);
  // Centered label above the input grid, kept at least 12px from the top.
  const labelSel = gInput.selectAll('text.input-label').data([0]);
  labelSel.enter().append('text').attr('class','input-label')
    .attr('text-anchor','middle')
    .style('font-size','12px')
    .style('font-weight','700')
    .style('fill','var(--muted-color)')
    .merge(labelSel)
    .attr('x', inputGrid.x + inputGrid.width / 2)
    .attr('y', Math.max(12, inputGrid.y - 10))
    .text('Input 28×28');
}
/**
 * SVG path for a link between two nodes of adjacent layers.
 * The curve leaves the right edge of the source circle and enters the left
 * edge of the target circle, using each node's current radius.
 *
 * @param {{s:{l:number,i:number}, t:{l:number,j:number}}} d - Link datum.
 * @returns {string} Cubic Bézier path data, or '' if either node is missing.
 */
function computeLinkD(d){
  const from = layers[d.s.l][d.s.i];
  const to = layers[d.t.l][d.t.j];
  if (!(from && to)) return '';
  const startX = from.x + nodeRadiusForNode(from);
  const endX = to.x - nodeRadiusForNode(to);
  // Control points sit 45% of the horizontal distance in from each end, at the end's own height.
  const bend = (endX - startX) * 0.45;
  return `M${startX},${from.y} C${startX+bend},${from.y} ${endX-bend},${to.y} ${endX},${to.y}`;
}
/**
 * Draw one link from the right edge of the input grid to each first-layer
 * (feature) node. Removes every path in `gInputLinks` when there is no first
 * layer or the grid has not been laid out yet.
 *
 * Path geometry is delegated to `computeInputLinkD`, the same function
 * `setNodeActivations` uses when it restyles these links, so a re-layout and
 * an activation update produce the same curve for the same state.
 */
function renderInputLinks(){
  const firstLayer = layers[0];
  if (!firstLayer || !inputGrid || inputGrid.cell <= 0) { gInputLinks.selectAll('path').remove(); return; }
  // One datum per feature node; `idx` is what setNodeActivations reads back from each path.
  const paths = firstLayer.map((_, idx) => ({ idx }));
  const sel = gInputLinks.selectAll('path.input-link').data(paths);
  sel.enter().append('path').attr('class','input-link')
    .attr('fill','none')
    .attr('stroke-opacity', 0.25)
    .attr('stroke-width', 1)
    .attr('stroke-linecap','round')
    .merge(sel)
    .attr('d', d => computeInputLinkD(d.idx))
    .attr('stroke','var(--text-color)');
  sel.exit().remove();
}
/**
 * SVG path for the link from the input grid to first-layer node `idx`.
 *
 * Sources are spread evenly over a vertical band half the grid's height,
 * centred on the grid's right edge. The curve ends on the node's circle at the
 * point facing the source, using the node's current radius.
 *
 * @param {number} idx - Index of the node in the first layer.
 * @returns {string} Cubic Bézier path data, or '' if the node does not exist.
 */
function computeInputLinkD(idx){
  const featureNodes = layers[0];
  const node = featureNodes[idx];
  if (!node) return '';
  const count = featureNodes.length;
  const startX = inputGrid.x + inputGrid.width;
  const midY = inputGrid.y + inputGrid.height / 2;
  const bandH = inputGrid.height * 0.5;
  let startY = midY;
  if (count > 1) {
    const bandTop = midY - bandH / 2;
    const step = bandH / (count - 1);
    startY = bandTop + idx * step;
  }
  // Pull the end point back from the node centre by one radius along the source→centre direction.
  const towardX = node.x - startX;
  const towardY = node.y - startY;
  const dist = Math.hypot(towardX, towardY) || 1;
  const radius = nodeRadiusForNode(node);
  const endX = node.x - (towardX / dist) * radius;
  const endY = node.y - (towardY / dist) * radius;
  const bend = (endX - startX) * 0.35;
  return `M${startX},${startY} C${startX + bend},${startY} ${endX - bend},${endY} ${endX},${endY}`;
}
// Recompute the layout and (re)draw the static parts of the diagram: input grid,
// input links, node circles, first-layer labels and inter-layer links. Also re-aligns
// the output labels and clip-path circles with the freshly laid-out output nodes.
// Called for the initial render and from the resize handler.
// `showEdges` is accepted but not read anywhere in this function.
function renderGraph(showEdges){
layoutNodes();
renderInputGrid();
renderInputLinks();
// Nodes
const allNodes = layers.flat();
const nodeSel = gNodes.selectAll('circle.node').data(allNodes, d=>d.id);
// New circles start with a fixed radius of 10; setNodeActivations later sets activation-based radii.
nodeSel.enter().append('circle').attr('class','node')
.attr('r', 10)
.attr('cx', d=>d.x).attr('cy', d=>d.y)
.attr('fill', d=> d.layer===2 ? 'var(--page-bg)' : 'var(--primary-color)')
.attr('fill-opacity', d=> d.layer===2 ? 1 : 0.12)
// Both branches of this ternary yield the same colour.
.attr('stroke', d=> d.layer===2 ? 'var(--border-color)' : 'var(--border-color)')
.attr('stroke-width',1)
.attr('stroke-linejoin','round')
.merge(nodeSel)
.attr('cx', d=>d.x).attr('cy', d=>d.y)
// NOTE(review): this also runs for existing circles, so a resize resets the
// activation-based opacity applied by setNodeActivations until its next call; confirm intended.
.attr('opacity', 1);
nodeSel.exit().remove();
// Labels for first hidden layer only (avoid stacking with output probs)
const labels = [];
layers[0].forEach((n,i)=> labels.push({ x:n.x-30, y:n.y+4, txt:`f${i+1}` }));
const labSel = gLabels.selectAll('text').data(labels);
// The values set on enter (including the 3px stroke) are set again after the merge, which applies a 5px stroke.
labSel.enter().append('text')
.style('font-size','10px')
.style('fill','var(--muted-color)')
.style('paint-order','stroke')
.style('stroke','var(--page-bg)')
.style('stroke-width','3px')
.attr('x', d=>d.x)
.attr('y', d=>d.y)
.text(d=>d.txt)
.merge(labSel)
.style('paint-order','stroke')
.style('stroke','var(--page-bg)')
.style('stroke-width','5px')
.attr('x', d=>d.x)
.attr('y', d=>d.y)
.text(d=>d.txt);
labSel.exit().remove();
// Links as smooth curves
const linkSel = gLinks.selectAll('path.link').data(links, d=> `${d.s.l}-${d.s.i}-${d.t.l}-${d.t.j}`);
linkSel.enter().append('path').attr('class','link')
.attr('d', computeLinkD)
.attr('fill','none')
.attr('stroke','var(--text-color)')
.attr('stroke-opacity', 0.25)
.attr('stroke-width', d=> 0.5 + d.w*1.2)
.attr('stroke-linecap','round')
.merge(linkSel)
.attr('d', computeLinkD)
.attr('stroke','var(--text-color)')
// NOTE(review): existing links get the static weight-based width back here, replacing the
// activation-based width from setNodeActivations until its next call; confirm intended.
.attr('stroke-width', d=> 0.5 + d.w*1.2);
linkSel.exit().remove();
// Ensure output labels remain aligned with the last layer on resize
gOutText.selectAll('g.out-label')
.attr('transform', function(d){
// Keep the existing transform when the datum or its node cannot be resolved.
if (!d || typeof d.digit !== 'number') return d3.select(this).attr('transform');
const n = layers[2][d.digit];
if (!n) return d3.select(this).attr('transform');
const offset = nodeRadiusForNode(n) + 8;
return `translate(${n.x + offset},${n.y})`;
});
// Ensure clip-path circles are updated on resize
// One clipPath per output node, with id `clip-<node id>`, matching the url(#clip-…) references used for the output fill rects.
if (defs) {
const clips = defs.selectAll('clipPath.clip-node').data(layers[2], d=>d.id);
const ce = clips.enter().append('clipPath').attr('class','clip-node').attr('clipPathUnits','userSpaceOnUse').attr('id', d=>`clip-${d.id}`);
ce.append('circle');
clips.merge(ce).select('circle').attr('cx', d=>d.x).attr('cy', d=>d.y).attr('r', d=>nodeRadiusForNode(d));
clips.exit().remove();
}
}
// Store new activations on the node records and animate the diagram to match.
//   h1  - activations for layer 0 (feature nodes)
//   h2  - activations for layer 1
//   out - activations for layer 2 (one value per digit); the largest entry is treated as the top prediction
// Missing or falsy entries are stored as 0. All visual changes use 180ms cubic-out transitions.
function setNodeActivations(h1, h2, out){
layers[0].forEach((n,i)=> n.a = h1[i] || 0);
layers[1].forEach((n,i)=> n.a = h2[i] || 0);
layers[2].forEach((n,i)=> n.a = out[i] || 0);
// Determine top prediction (for ghosting others)
// Ties keep the lowest index because the comparison is strict.
let argmaxIdx = 0; let bestProb = -1;
if (Array.isArray(out)) {
for (let i=0;i<out.length;i++){ if (out[i] > bestProb){ bestProb = out[i]; argmaxIdx = i; } }
}
// Color/size by activation with smooth transitions
gNodes.selectAll('circle.node')
.transition().duration(180).ease(d3.easeCubicOut)
.attr('fill', d=> d.layer===2 ? 'var(--page-bg)' : 'var(--primary-color)')
.attr('fill-opacity', d=> d.layer===2 ? 1 : (0.12 + 0.58*Math.min(1, d.a||0)))
.attr('stroke', 'var(--primary-color)')
.attr('stroke-opacity', d=> (d.layer===2 ? 0.9 : (0.45 + 0.45*Math.min(1, d.a||0))))
.attr('opacity', d=> 0.55 + 0.45*Math.min(1, d.a||0))
.attr('r', d=> nodeRadiusForNode(d));
// Link opacity by activation flow
// Both opacity and width grow with the product of the source and target activations.
gLinks.selectAll('path.link')
.transition().duration(180).ease(d3.easeCubicOut)
.attr('d', computeLinkD)
.attr('stroke','var(--text-color)')
.attr('stroke-opacity', d=>{
const aS = layers[d.s.l][d.s.i].a || 0; const aT = layers[d.t.l][d.t.j].a || 0;
return Math.min(1, 0.15 + 0.85 * (aS * aT));
})
.attr('stroke-width', d=>{
const aS = layers[d.s.l][d.s.i].a || 0; const aT = layers[d.t.l][d.t.j].a || 0;
return 0.6 + 2.2*(aS*aT);
});
// Theme-aware and activation-aware input links
// Paths are recomputed because the target node radius changes with activation; width follows the feature node's activation.
gInputLinks.selectAll('path.input-link')
.transition().duration(180).ease(d3.easeCubicOut)
.attr('d', (d)=> computeInputLinkD(d.idx))
.attr('stroke','var(--text-color)')
.attr('stroke-opacity', 0.25)
.attr('stroke-width', d=> 0.6 + 2.0*(layers[0][d.idx] ? (layers[0][d.idx].a||0) : 0));
// Update clip-path circles to match new radii/positions of output nodes
// Only existing clipPaths are updated here; they are created in renderGraph.
if (defs) {
const clips = defs.selectAll('clipPath.clip-node').data(layers[2], d=>d.id);
clips.select('circle')
.transition().duration(180).ease(d3.easeCubicOut)
.attr('cx', d=>d.x)
.attr('cy', d=>d.y)
.attr('r', d=> nodeRadiusForNode(d));
}
// Theme-aware input links on updates handled above via transition
// Output labels: digit placed to the right of the node
const outs = layers[2].map((n,i)=>({ x:n.x + nodeRadiusForNode(n) + 8, y:n.y, digit: i, prob: (out[i]||0), isTop: i===argmaxIdx }));
const gSel = gOutText.selectAll('g.out-label').data(outs, d=>d.digit);
const gEnter = gSel.enter().append('g').attr('class','out-label');
gEnter.append('text').attr('class','out-digit')
.style('font-size','12px').style('font-weight','800').style('fill','var(--text-color)')
.attr('text-anchor','start').attr('dominant-baseline','middle')
.style('paint-order','stroke').style('stroke','var(--transparent-page-contrast)').style('stroke-width','3px');
// `merged` is not read after this statement; the chain is used for its side effects.
const merged = gEnter.merge(gSel)
.attr('transform', d=>`translate(${d.x},${d.y})`)
.each(function(d){
const sel = d3.select(this);
sel.select('text.out-digit')
.attr('x', 0).attr('y', 0)
.text(String(d.digit));
// Ghost non-top predictions
sel.style('opacity', d.isTop ? 1 : 0.35);
});
// Remove any previous decorative rings (no highlight ring desired)
gRoot.selectAll('circle.top-ring').remove();
// (tooltip interactions removed)
gSel.exit().remove();
// Output liquid fill using clipPath + rect from bottom
// Each rect spans the node's bounding box horizontally and rises from the bottom of the circle
// by a fraction equal to the clamped activation; the clip path trims it to the circle.
const rects = gNodes.selectAll('rect.out-liquid').data(layers[2], d=>d.id);
const rectEnter = rects.enter().append('rect').attr('class','out-liquid')
.attr('fill','var(--primary-color)')
.attr('fill-opacity', 0.55)
.attr('clip-path', d => `url(#clip-${d.id})`);
rectEnter.merge(rects)
.transition().duration(180).ease(d3.easeCubicOut)
.attr('x', d=> d.x - nodeRadiusForNode(d))
.attr('width', d=> 2 * nodeRadiusForNode(d))
.attr('y', d=> {
const r = nodeRadiusForNode(d);
const h = 2 * r * Math.max(0, Math.min(1, d.a||0));
return d.y + r - h;
})
.attr('height', d=> 2 * nodeRadiusForNode(d) * Math.max(0, Math.min(1, d.a||0)))
.attr('fill-opacity', 0.55);
rects.exit().remove();
}
/**
 * Run one full update from the current canvas contents:
 * downsample → normalise → dilate → features → class probabilities, then push
 * the resulting activations into the diagram.
 *
 * Side effects: replaces `lastX28`, redraws the input grid, shows the canvas
 * hint only while the ink mass is below 0.01, and calls `setNodeActivations`.
 */
function runPipeline(){
  const x28 = dilate28(normalize28(downsample28()));
  lastX28 = x28;
  renderInputGrid();

  // Layer 0 activations are the eight raw features; feature 0 is the mean ink.
  const h1 = computeFeatures(x28);
  const inkMass = h1[0];
  if (hint) { hint.style.opacity = inkMass < 0.01 ? 1 : 0; }

  // Layer 1 is a fixed, deterministic non-linear mix of the features, for visualisation only.
  const mixWeight = (i, j) => (Math.sin(i*17+j*31)+1)/2 * 0.8 + 0.1;
  const h2 = layers[1].map((_, j) => {
    let acc = 0;
    for (let i = 0; i < layers[0].length; i++) {
      acc += mixWeight(i, j)*h1[i];
    }
    return Math.tanh(acc*0.8);
  });

  const classify = () => {
    // Too little ink to say anything: uniform distribution.
    if (inkMass < 0.03) {
      return Array.from({length:10}, ()=> 1/10);
    }
    // Prefer the TFJS model when it yields exactly ten values.
    const tfProbs = predictTfjs(x28);
    if (tfProbs && tfProbs.length === 10) {
      return tfProbs;
    }
    // Fallback: glyph cosine similarity (×8) plus a small (×0.2) linear-classifier term.
    const x28n = normalize(x28);
    const logitsGlyph = protoGlyphs28.map((proto) => 8.0 * cosine(x28n, proto));
    const logitsLinear = W.map((row, k) => dot(row, h1) + b[k]);
    return softmax(logitsGlyph.map((v, k) => v + 0.2*logitsLinear[k]));
  };
  const prob = classify();

  // tanh output is rescaled from [-1,1] to [0,1] before display.
  setNodeActivations(h1, h2.map((v) => (v+1)/2), prob);
}
/**
 * Average-pool the drawing canvas down to 28×28 ink values.
 * Each output cell is the mean of `1 - gray` over its CANVAS_PX/28-pixel square
 * block, where gray is the mean of the R, G and B channels scaled to [0,1].
 *
 * @returns {Float32Array} 784 row-major values, 1 = fully black, 0 = fully white.
 */
function downsample28(){
  const GRID = 28;
  const cellPx = CANVAS_PX/GRID;
  const rgba = ctx.getImageData(0,0,CANVAS_PX,CANVAS_PX).data;
  const pooled = new Float32Array(GRID*GRID);
  for (let cell = 0; cell < GRID*GRID; cell++) {
    const cellCol = cell % GRID;
    const cellRow = (cell - cellCol) / GRID;
    const startX = Math.floor(cellCol*cellPx);
    const startY = Math.floor(cellRow*cellPx);
    let inkSum = 0;
    let samples = 0;
    for (let py = startY; py < startY+cellPx; py++) {
      for (let px = startX; px < startX+cellPx; px++) {
        const at = (py*CANVAS_PX + px)*4; // 4 bytes per pixel (RGBA); alpha is ignored
        const gray = (rgba[at] + rgba[at+1] + rgba[at+2])/3/255;
        inkSum += 1-gray;
        samples++;
      }
    }
    pooled[cell] = inkSum/(samples||1);
  }
  return pooled;
}
/**
 * Wipe the drawing canvas to white and re-run the pipeline so the diagram
 * reflects the empty input.
 */
function clearCanvas(){
  ctx.fillStyle = '#ffffff';
  ctx.fillRect(0, 0, CANVAS_PX, CANVAS_PX);
  runPipeline();
}
// Drawing interactions
// `drawing` is true between a press (onDown) and a release (onUp).
// `last` is the previous stroke point in canvas pixels, or null at the start of a stroke.
let drawing=false; let last=null;
// Set to true on the first press; onDown uses it to reveal the erase button once.
let hasInteracted=false;
/**
 * Convert a mouse or touch event into canvas pixel coordinates.
 * Uses the first touch point when the event has a `touches` list, and scales
 * from the canvas's on-screen size to its CANVAS_PX backing size.
 *
 * @param {MouseEvent|TouchEvent} ev - Pointer event on the canvas.
 * @returns {{x: number, y: number}} Position in canvas pixels.
 */
const getPos = (ev) => {
  const box = canvas.getBoundingClientRect();
  const pointer = ('touches' in ev) ? ev.touches[0] : ev;
  return {
    x: (pointer.clientX - box.left) * (CANVAS_PX/box.width),
    y: (pointer.clientY - box.top) * (CANVAS_PX/box.height),
  };
};
/**
 * Extend the current stroke to point `p` with a round, 24px-wide black pen,
 * then re-run the pipeline. The first point of a stroke draws a dot (a
 * zero-length segment with round caps).
 *
 * @param {{x: number, y: number}} p - Target position in canvas pixels.
 */
function drawTo(p){
  const PEN_WIDTH = 24;
  Object.assign(ctx, {
    lineCap: 'round',
    lineJoin: 'round',
    strokeStyle: '#000000',
    lineWidth: PEN_WIDTH,
  });
  // No previous point yet: start the segment at the target itself.
  const from = last || p;
  last = from;
  ctx.beginPath();
  ctx.moveTo(from.x, from.y);
  ctx.lineTo(p.x, p.y);
  ctx.stroke();
  last = p;
  runPipeline();
}
/**
 * Start a stroke at the event position. The first press ever also reveals the
 * erase button. Calls preventDefault on the event.
 * @param {MouseEvent|TouchEvent} ev - mousedown / touchstart event on the canvas.
 */
function onDown(ev){
  drawing = true;
  last = null;
  if (!hasInteracted) {
    hasInteracted = true;
    eraseBtn.style.display = 'flex';
  }
  drawTo(getPos(ev));
  ev.preventDefault();
}
/**
 * Extend the stroke while a press is active; ignored otherwise.
 * @param {MouseEvent|TouchEvent} ev - mousemove / touchmove event on the canvas.
 */
function onMove(ev){
  if (!drawing) return;
  drawTo(getPos(ev));
  ev.preventDefault();
}
/** End the current stroke and forget its last point. */
function onUp(){
  drawing = false;
  last = null;
}
// Presses and moves are tracked on the canvas; releases are tracked on window so a
// release outside the canvas still ends the stroke.
canvas.addEventListener('mousedown', onDown);
canvas.addEventListener('mousemove', onMove);
window.addEventListener('mouseup', onUp);
// Touch listeners are non-passive because the handlers call preventDefault().
canvas.addEventListener('touchstart', onDown, { passive:false });
canvas.addEventListener('touchmove', onMove, { passive:false });
window.addEventListener('touchend', onUp);
// A cancelled touch never fires touchend; without this the stroke would stay
// active and later mouse moves would keep drawing.
window.addEventListener('touchcancel', onUp);
// Redraw the diagram whenever the available space changes: observe the right column and
// the canvas when ResizeObserver exists, otherwise fall back to the window resize event.
const rerender = () => { renderGraph(true); };
if (!window.ResizeObserver) {
  window.addEventListener('resize', rerender);
} else {
  const sizeObserver = new ResizeObserver(() => rerender());
  [right, canvas].forEach((target) => sizeObserver.observe(target));
}
// Optional TFJS classifier; stays null until a model has loaded, and is reset to null when none can be.
let tfModel = null;
/**
 * Wait for TensorFlow.js, then try each candidate model URL in order and keep
 * the first one that loads. Load failures are ignored so the next candidate is
 * tried; if every candidate fails, `tfModel` is set to null.
 *
 * @returns {Promise<void>} Resolves once a model has loaded or all candidates failed.
 */
const tryLoadModel = async () => {
  await new Promise((resolve) => { ensureTF(resolve); });
  const modelUrls = [
    // Preferred: public path (symlink to assets/data)
    '/data/mnist-variant-model.json',
    // Copies under the content assets (weight shards must sit next to the JSON)
    './assets/data/mnist-variant-model.json',
    '../assets/data/mnist-variant-model.json',
    '/assets/data/mnist-variant-model.json',
    // Last resort: public TFJS MNIST model
    'https://storage.googleapis.com/tfjs-models/tfjs/mnist/model.json'
  ];
  for (let attempt = 0; attempt < modelUrls.length; attempt += 1) {
    try {
      tfModel = await tf.loadLayersModel(modelUrls[attempt]);
      return;
    } catch (_) {
      // Best effort: fall through to the next candidate.
    }
  }
  tfModel = null;
};
/**
 * Classify a 28×28 ink image with the loaded TFJS model, if there is one.
 *
 * The image is evaluated twice, as given and inverted (1 - v), and the result
 * with the higher top value is kept. The model output is used as-is when it
 * already looks like a probability distribution (values in [0,1] summing to
 * about 1); otherwise it is treated as logits and passed through `softmax`.
 * Applying softmax to probabilities would flatten them towards uniform.
 *
 * @param {Float32Array} x28 - 784 row-major ink values in [0,1].
 * @returns {number[]|null} Exactly 10 values (zero-padded or truncated), or
 *   null when no model is loaded or inference failed for both orientations.
 */
function predictTfjs(x28){
  if (!tfModel || !window.tf) return null;
  const looksLikeProbabilities = (vals) => {
    let sum = 0;
    for (const v of vals) {
      if (!Number.isFinite(v) || v < -1e-6 || v > 1 + 1e-6) return false;
      sum += v;
    }
    return Math.abs(sum - 1) < 1e-3;
  };
  const run = (arr) => {
    try {
      let raw = null;
      // tf.tidy disposes every tensor created inside the callback, including when it throws.
      tf.tidy(() => {
        const input = tf.tensor(arr, [28,28,1]).expandDims(0);
        const y = tfModel.predict(input);
        // predict() may return a list of tensors; use the first output in that case.
        raw = Array.from((Array.isArray(y) ? y[0] : y).dataSync());
      });
      if (!raw || raw.length === 0) return null;
      return looksLikeProbabilities(raw) ? raw : softmax(raw);
    } catch (e) {
      // Best effort: the caller falls back to the glyph/linear classifier on null.
      return null;
    }
  };
  // Try both orientations and keep the one with higher confidence
  const p1 = run(x28);
  const inv = x28.map(v=>1-v);
  const p2 = run(inv);
  let probs = p1 || p2;
  if (p1 && p2){
    const m1 = Math.max(...p1), m2 = Math.max(...p2);
    probs = m2>m1 ? p2 : p1;
  }
  if (!probs) return null;
  // Normalize output size to 10 classes (pad or slice)
  if (probs.length < 10){ probs = probs.concat(Array(10 - probs.length).fill(0)); }
  if (probs.length > 10){ probs = probs.slice(0,10); }
  return probs;
}
// Initial render
// Lay out and draw the diagram once, then reset the canvas; clearCanvas() also runs the
// pipeline, so the diagram starts from the empty-input state.
renderGraph(true);
clearCanvas();
// Start loading the optional TFJS model; the returned promise is not awaited here.
tryLoadModel();
};
// Start-up: if the document is still loading, wait for DOMContentLoaded (once); otherwise start
// immediately. Either way, bootstrap runs through ensureD3 so D3 is available first.
if (document.readyState === 'loading') { document.addEventListener('DOMContentLoaded', () => ensureD3(bootstrap), { once: true }); } else { ensureD3(bootstrap); }
})();
</script>