Spaces:
Running
Running
<div class="d3-neural" style="width:100%;margin:10px 0;"></div>
<style>
  /* Controls row and slider (not created by the script below; kept for compatibility). */
  .d3-neural .controls {
    margin-top: 12px;
    display: flex;
    gap: 12px;
    align-items: center;
    flex-wrap: wrap;
  }
  .d3-neural .controls label {
    font-size: 12px;
    color: var(--muted-color);
    display: flex;
    align-items: center;
    gap: 8px;
    white-space: nowrap;
    padding: 6px 10px;
  }
  .d3-neural .controls input[type="range"] {
    width: 160px;
  }

  /* Two-column layout: fixed-width drawing column, flexible diagram column. */
  .d3-neural .panel {
    display: flex;
    gap: 16px;
    align-items: center;
  }
  .d3-neural .left {
    flex: 0 0 320px;
    display: flex;
    flex-direction: column;
    gap: 8px;
  }
  .d3-neural .right {
    flex: 1 1 auto;
    min-width: 0;
  }

  /* Drawing surface (CSS-scaled; its backing store is a fixed square). */
  .d3-neural canvas {
    width: 100%;
    height: auto;
    border-radius: 8px;
    border: 1px solid var(--border-color);
    background: var(--surface-bg);
    display: block;
  }

  /* 28x28 preview grid, legend and probability bars (not created by the script below). */
  .d3-neural .preview28 {
    display: grid;
    grid-template-columns: repeat(28, 1fr);
    gap: 1px;
    width: 100%;
  }
  .d3-neural .preview28 span {
    display: block;
    aspect-ratio: 1/1;
    border-radius: 2px;
  }
  .d3-neural .legend {
    font-size: 12px;
    color: var(--text-color);
    line-height: 1.35;
  }
  .d3-neural .probs {
    display: flex;
    gap: 6px;
    align-items: flex-end;
    height: 64px;
  }
  .d3-neural .probs .bar {
    width: 10px;
    border-radius: 2px 2px 0 0;
    background: var(--border-color);
    transition: height .15s ease, background-color .15s ease;
  }
  .d3-neural .probs .bar.active {
    background: var(--primary-color);
  }
  .d3-neural .probs .tick {
    font-size: 10px;
    color: var(--muted-color);
    text-align: center;
    margin-top: 2px;
  }

  /* Overlays positioned over the canvas: erase button (top-right) and hint (top-left). */
  .d3-neural .canvas-wrap {
    position: relative;
  }
  .d3-neural .erase-btn {
    position: absolute;
    top: 8px;
    right: 8px;
    width: 32px;
    height: 32px;
    display: flex;
    align-items: center;
    justify-content: center;
    border: 1px solid var(--border-color);
  }
  .d3-neural .canvas-hint {
    position: absolute;
    top: 8px;
    left: 12px;
    font-size: 12px;
    font-weight: 700;
    color: rgb(156, 156, 156);
    pointer-events: none;
  }
</style>
| <script> | |
| (() => { | |
/**
 * Make a CDN-hosted global library available, then call `cb`.
 *
 * If `isReady()` already holds, `cb` runs synchronously and its return value is returned.
 * Otherwise a single <script id=...> tag is shared by every embed on the page. It is created
 * only if missing, and `cb` runs on its `load` event once `isReady()` holds. A failed load is
 * reported on the console instead of being silently ignored.
 *
 * @param {string} id DOM id used to de-duplicate the script tag across embeds
 * @param {string} src script URL
 * @param {() => boolean} isReady tells whether the global is usable
 * @param {() => *} cb callback to run once ready
 */
const loadScriptOnce = (id, src, isReady, cb) => {
  if (isReady()) return cb();
  let script = document.getElementById(id);
  if (!script) {
    script = document.createElement('script');
    script.id = id;
    script.src = src;
    document.head.appendChild(script);
  }
  script.addEventListener('load', () => { if (isReady()) cb(); }, { once: true });
  script.addEventListener('error', () => { console.warn(`d3-neural: failed to load ${src}`); }, { once: true });
  return undefined;
};
/** Run `cb` once D3 v7 (window.d3 with a `select` function) is available. */
const ensureD3 = (cb) => loadScriptOnce(
  'd3-cdn-script',
  'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js',
  () => Boolean(window.d3 && typeof window.d3.select === 'function'),
  cb
);
/** Run `cb` once TensorFlow.js (window.tf with a `tensor` function) is available. */
const ensureTF = (cb) => loadScriptOnce(
  'tfjs-cdn-script',
  'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/dist/tf.min.js',
  () => Boolean(window.tf && typeof window.tf.tensor === 'function'),
  cb
);
| const bootstrap = () => { | |
| const mount = document.currentScript ? document.currentScript.previousElementSibling : null; | |
| const container = (mount && mount.querySelector && mount.querySelector('.d3-neural')) || document.querySelector('.d3-neural'); | |
| if (!container) return; | |
| if (container.dataset) { if (container.dataset.mounted === 'true') return; container.dataset.mounted = 'true'; } | |
// Two columns inside the mount: drawing canvas on the left, SVG network diagram on the right.
const makeDiv = (cls) => {
  const el = document.createElement('div');
  el.className = cls;
  return el;
};
const panel = makeDiv('panel');
const left = makeDiv('left');
const right = makeDiv('right');
panel.append(left, right);
container.appendChild(panel);
// Drawing surface: a square canvas the user paints digits on.
const CANVAS_PX = 224; // backing-store size in pixels (square); 28 * 8, so it pools evenly to 28x28
const canvas = document.createElement('canvas');
canvas.width = CANVAS_PX;
canvas.height = CANVAS_PX;
// The canvas is read back with getImageData on every stroke. This hint lets the browser keep
// the bitmap CPU-side so those readbacks stay cheap. Browsers without the option ignore it.
const ctx = canvas.getContext('2d', { willReadFrequently: true });
// Start from an opaque white background: downsampling treats dark pixels as ink.
ctx.fillStyle = '#ffffff';
ctx.fillRect(0, 0, CANVAS_PX, CANVAS_PX);
const canvasWrap = document.createElement('div');
canvasWrap.className = 'canvas-wrap';
canvasWrap.appendChild(canvas);
// Trash-can icon button overlaid in the top-right corner; clears the drawing.
const eraseBtn = document.createElement('button');
eraseBtn.className = 'erase-btn button--ghost';
eraseBtn.type = 'button';
eraseBtn.setAttribute('aria-label', 'Clear');
eraseBtn.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="3 6 5 6 21 6"></polyline><path d="M19 6l-1 14a2 2 0 0 1-2 2H8a2 2 0 0 1-2-2L5 6"></path><path d="M10 11v6"></path><path d="M14 11v6"></path><path d="M9 6V4a2 2 0 0 1 2-2h2a2 2 0 0 1 2 2v2"></path></svg>';
eraseBtn.addEventListener('click', () => clearCanvas());
canvasWrap.appendChild(eraseBtn);
// Static hint overlaid in the top-left corner (pointer-events are disabled in CSS).
const hint = document.createElement('div');
hint.className = 'canvas-hint';
hint.textContent = 'Draw a digit here';
canvasWrap.appendChild(hint);
left.appendChild(canvasWrap);
// The earlier preview grid, slider controls and separate prediction panel were removed;
// predictions are drawn next to the output nodes in the SVG instead.
// Network diagram in the right column. Groups are appended in paint order, so later
// groups (nodes, labels, output probabilities) draw on top of earlier ones (grid, links).
const svg = d3.select(right)
  .append('svg')
  .attr('width', '100%')
  .style('display', 'block');
const gRoot = svg.append('g');
const addGroup = (cls) => gRoot.append('g').attr('class', cls);
const gInput = addGroup('input');
const gInputLinks = addGroup('input-links');
const gLinks = addGroup('links');
const gNodes = addGroup('nodes');
const gLabels = addGroup('labels');
const gOutText = addGroup('out-probs');
// Diagram topology (compact 8 -> 8 -> 10): one node object per unit; `a` holds its activation.
const layerSizes = [8, 8, 10];
const layers = layerSizes.map((count, layerIdx) => {
  const nodes = [];
  for (let idx = 0; idx < count; idx++) {
    nodes.push({ id: `L${layerIdx}N${idx}`, layer: layerIdx, index: idx, a: 0 });
  }
  return nodes;
});
// Fully connected edges between consecutive layers (0->1, then 1->2). `w` is a fixed
// trig pattern in [0, 1], not a learned weight.
const edgeWeight = [
  (i, j) => (Math.sin(i * 17 + j * 31) + 1) / 2,
  (i, j) => (Math.cos(i * 7 + j * 13) + 1) / 2,
];
const links = [];
for (let l = 0; l < 2; l++) {
  for (let i = 0; i < layerSizes[l]; i++) {
    for (let j = 0; j < layerSizes[l + 1]; j++) {
      links.push({ s: { l, i }, t: { l: l + 1, j }, w: edgeWeight[l](i, j) });
    }
  }
}
// Hand-tuned linear classifier used by the fallback path: logits[k] = dot(W[k], feats) + b[k],
// where feats (each in [0, 1]) are, by column:
//   0 total ink, 1 centroid x, 2 centroid y, 3 right/left balance, 4 bottom/top balance,
//   5 horizontal transitions, 6 vertical transitions, 7 loopiness.
// One row per digit 0-9.
const W = [
  /* 0 */ [0.3, 0, 0, 0, 0, -0.8, -0.6, 1.2],
  /* 1 */ [-0.2, 0.9, 0.2, 0.8, 0.1, -0.2, 0.2, -1.1],
  /* 2 */ [0.1, 0.4, 0.2, 0.5, 0.2, 0.9, 0.1, -0.6],
  /* 3 */ [0.2, 0.3, 0.2, 0.2, 0.2, 0.9, 0, -0.2],
  /* 4 */ [0, -0.3, 0.2, -0.6, 0.4, 0.2, 0.8, -0.6],
  /* 5 */ [0.1, -0.4, 0.2, -0.5, 0.5, 0.9, 0.1, -0.6],
  /* 6 */ [0.2, -0.2, 0.6, -0.2, 0.8, -0.3, 0.2, 0.6],
  /* 7 */ [0, 0.6, -0.2, 0.6, -0.8, 0.6, 0, -0.8],
  /* 8 */ [0.4, 0, 0, 0.1, 0.1, 0.6, 0.6, 1.0],
  /* 9 */ [0.2, 0.2, -0.6, 0.2, -0.8, 0.2, 0.6, 0.5],
];
// Per-digit bias terms.
const b = [-0.2, -0.1, -0.05, -0.05, -0.05, -0.05, -0.05, -0.1, -0.15, -0.1];
/**
 * Summarise a 28x28 ink image as eight features, each in [0, 1]:
 * [total ink, centroid x, centroid y, right/left balance, bottom/top balance,
 *  horizontal transitions, vertical transitions, loopiness].
 * The centroid defaults to 0.5 for a blank image. Transitions count on/off changes at the
 * 0.25 threshold between neighbouring pixels. Loopiness compares ink inside a 5-pixel border
 * with the total.
 * @param {Float32Array|number[]} x28 row-major 28x28 values, 1 = ink, 0 = background
 * @returns {number[]} the 8 features in the order above
 */
function computeFeatures(x28){
  const SIDE = 28;
  const BORDER = 5;   // width of the outer band for the loopiness proxy
  const INK = 0.25;   // on/off threshold used for transition counting
  let mass = 0, momentX = 0, momentY = 0;
  let hFlips = 0, vFlips = 0;
  let westInk = 0, eastInk = 0, northInk = 0, southInk = 0;
  let rimInk = 0, coreInk = 0;
  for (let row = 0; row < SIDE; row++) {
    for (let col = 0; col < SIDE; col++) {
      const idx = row * SIDE + col;
      const v = x28[idx];
      mass += v;
      momentX += col * v;
      momentY += row * v;
      if (col > 0 && (x28[idx - 1] > INK) !== (v > INK)) hFlips += 1;
      if (row > 0 && (x28[idx - SIDE] > INK) !== (v > INK)) vFlips += 1;
      if (col < SIDE / 2) westInk += v; else eastInk += v;
      if (row < SIDE / 2) northInk += v; else southInk += v;
      const onRim = col < BORDER || col >= SIDE - BORDER || row < BORDER || row >= SIDE - BORDER;
      if (onRim) rimInk += v; else coreInk += v;
    }
  }
  const hasInk = mass > 1e-6;
  return [
    mass / (SIDE * SIDE),
    hasInk ? (momentX / mass) / (SIDE - 1) : 0.5,
    hasInk ? (momentY / mass) / (SIDE - 1) : 0.5,
    eastInk / (eastInk + westInk + 1e-6),
    southInk / (southInk + northInk + 1e-6),
    Math.min(1, hFlips / (SIDE * SIDE * 0.35)),
    Math.min(1, vFlips / (SIDE * SIDE * 0.35)),
    // Ink concentrated away from the border reads as "loopy"; scaled by 1.8 and capped at 1.
    Math.min(1, coreInk / (rimInk + coreInk + 1e-6) * 1.8),
  ];
}
/** Numerically stable softmax: shift by the maximum before exponentiating. */
function softmax(arr){
  const peak = Math.max(...arr);
  const exps = arr.map((v) => Math.exp(v - peak));
  let total = 0;
  for (const e of exps) total += e;
  total += 1e-12; // guards against division by zero
  return exps.map((e) => e / total);
}
/** Euclidean length of a vector; 0 when the result is falsy (empty input or NaN). */
function l2norm(a){
  const len = Math.hypot(...a);
  return len ? len : 0;
}
/** Unit-length copy of `a`; a plain copy when `a` has zero length. */
function normalize(a){
  const len = l2norm(a);
  if (!(len > 0)) return a.slice();
  return a.map((v) => v / len);
}
/** Cosine similarity of a and b; 0 when `a` is the zero vector. */
function cosine(a,b){
  const normA = l2norm(a);
  if (!(normA > 0)) return 0;
  let acc = 0;
  for (let k = 0; k < a.length; k++) acc += a[k] * b[k];
  return acc / (normA * (l2norm(b) || 1));
}
/**
 * MNIST-style normalisation of a 28x28 ink image.
 * The scale is chosen so the longer side of the bounding box of pixels > 0.2 spans 20 cells.
 * The image is then resampled (bilinear) so its centre of mass lands at the grid centre.
 * Returns the input array itself when it is (nearly) blank or has no pixel above the threshold.
 * @param {Float32Array} x28 row-major 28x28 values, 1 = ink
 * @returns {Float32Array}
 */
function normalize28(x28){
  const side = 28;
  const cutoff = 0.2;
  let minX = 29, minY = 29, maxX = -1, maxY = -1;
  let mass = 0, mx = 0, my = 0;
  for (let idx = 0; idx < side * side; idx++) {
    const col = idx % side;
    const row = (idx - col) / side;
    const v = x28[idx];
    if (v > cutoff) {
      minX = Math.min(minX, col);
      maxX = Math.max(maxX, col);
      minY = Math.min(minY, row);
      maxY = Math.max(maxY, row);
    }
    mass += v;
    mx += col * v;
    my += row * v;
  }
  if (mass < 1e-3 || maxX < 0) return x28;
  const boxW = Math.max(1, maxX - minX + 1);
  const boxH = Math.max(1, maxY - minY + 1);
  const scale = 20 / Math.max(boxW, boxH);
  const comX = mx / mass;
  const comY = my / mass;
  const mid = (side - 1) / 2; // centre of the output grid
  const out = new Float32Array(side * side);
  for (let idx = 0; idx < out.length; idx++) {
    const col = idx % side;
    const row = (idx - col) / side;
    // Inverse map: output cell -> source coordinates around the centre of mass.
    out[idx] = bilinearSample(x28, side, side, (col - mid) / scale + comX, (row - mid) / scale + comY);
  }
  return out;
}
/**
 * Bilinear interpolation of `img` (row-major, w x h) at fractional (x, y).
 * Samples outside the image read as 0.
 */
function bilinearSample(img, w, h, x, y){
  const fx = Math.floor(x);
  const fy = Math.floor(y);
  const dx = x - fx;
  const dy = y - fy;
  const px = (cx, cy) => ((cx < 0 || cy < 0 || cx >= w || cy >= h) ? 0 : img[cy * w + cx]);
  const upper = px(fx, fy) * (1 - dx) + px(fx + 1, fy) * dx;
  const lower = px(fx, fy + 1) * (1 - dx) + px(fx + 1, fy + 1) * dx;
  return upper * (1 - dy) + lower * dy;
}
/**
 * 3x3 grayscale dilation (max filter) over a 28x28 image, used to thicken strokes.
 * Neighbours outside the grid are skipped; the running maximum starts at 0.
 * @param {Float32Array} x row-major 28x28 values
 * @returns {Float32Array} a new 28x28 array
 */
function dilate28(x){
  const SIZE = 28;
  const out = new Float32Array(SIZE * SIZE);
  for (let row = 0; row < SIZE; row++) {
    const rFrom = Math.max(0, row - 1);
    const rTo = Math.min(SIZE - 1, row + 1);
    for (let col = 0; col < SIZE; col++) {
      const cFrom = Math.max(0, col - 1);
      const cTo = Math.min(SIZE - 1, col + 1);
      let peak = 0;
      for (let r = rFrom; r <= rTo; r++) {
        for (let c = cFrom; c <= cTo; c++) {
          const v = x[r * SIZE + c];
          if (v > peak) peak = v;
        }
      }
      out[row * SIZE + col] = peak;
    }
  }
  return out;
}
// Reference images for digits 0-9. Each digit is drawn as bold text on an offscreen canvas,
// average-pooled to 28x28 ink values, MNIST-normalised (normalize28) and scaled to unit L2
// length. The fallback classifier in runPipeline compares inputs against them.
const protoGlyphs28 = [];
(function buildGlyphProtos(){
  const board = document.createElement('canvas');
  board.width = CANVAS_PX;
  board.height = CANVAS_PX;
  const g2d = board.getContext('2d');
  const cellPx = board.width / 28;
  // Mean ink (1 - grey level) of the cellPx x cellPx block behind grid cell (gx, gy).
  const pool = (rgba, gx, gy) => {
    const x0 = Math.floor(gx * cellPx);
    const y0 = Math.floor(gy * cellPx);
    let ink = 0, n = 0;
    for (let yy = y0; yy < y0 + cellPx; yy++) {
      for (let xx = x0; xx < x0 + cellPx; xx++) {
        const o = (yy * board.width + xx) * 4;
        ink += 1 - (rgba[o] + rgba[o + 1] + rgba[o + 2]) / 3 / 255;
        n++;
      }
    }
    return ink / (n || 1);
  };
  for (let digit = 0; digit < 10; digit++) {
    g2d.fillStyle = '#ffffff';
    g2d.fillRect(0, 0, board.width, board.height);
    g2d.fillStyle = '#000000';
    g2d.textAlign = 'center';
    g2d.textBaseline = 'middle';
    g2d.font = 'bold 180px system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif';
    g2d.fillText(String(digit), board.width / 2, board.height * 0.56);
    const rgba = g2d.getImageData(0, 0, board.width, board.height).data;
    const vec = new Float32Array(28 * 28);
    for (let gy = 0; gy < 28; gy++) {
      for (let gx = 0; gx < 28; gx++) vec[gy * 28 + gx] = pool(rgba, gx, gy);
    }
    const normed = normalize28(vec);
    const len = l2norm(normed) || 1;
    protoGlyphs28.push(normed.map((v) => v / len));
  }
})();
/** Dot product of a and b over the indices of `a`. */
function dot(a,b){
  let total = 0;
  for (const k of a.keys()) total += a[k] * b[k];
  return total;
}
// Current SVG size (recomputed by layoutNodes) and fixed inner margins.
let width = 800;
let height = 360;
const margin = { top: 16, right: 24, bottom: 24, left: 24 };
// Geometry of the 28x28 input-pixel grid at the left of the diagram (inner coordinates).
let inputGrid = { cell: 0, x: 0, y: 0, width: 0, height: 0 };
/**
 * Size the SVG to the right column (fixed ~0.45 aspect ratio with minimum sizes) and compute
 * positions. The input grid is vertically centred at the left edge. Each node in `layers` gets
 * its x/y (mutated in place), with layers spread to the right of the grid.
 */
function layoutNodes(){
  width = Math.max(300, Math.round(right.clientWidth || 800));
  height = Math.max(260, Math.round(width * 0.45));
  svg.attr('width', width).attr('height', height);
  gRoot.attr('transform', `translate(${margin.left},${margin.top})`);
  const innerW = width - margin.left - margin.right;
  const innerH = height - margin.top - margin.bottom;

  const GRID_GAP = 24;       // space between the grid and the first layer
  const MIN_RIGHT_PAD = 40;  // free space that must remain to the right of the grid
  const gridFor = (cell) => {
    const side = cell * 28;
    return { cell, x: 0, y: Math.floor((innerH - side) / 2), width: side, height: side };
  };
  // Grid cells must fit vertically and use at most 28% of the inner width (>= 3px each).
  const cellByHeight = Math.floor(innerH / 28);
  const cellByWidth = Math.floor((innerW * 0.28) / 28);
  inputGrid = gridFor(Math.max(3, Math.min(cellByHeight, cellByWidth)));
  if (inputGrid.width + GRID_GAP > innerW - MIN_RIGHT_PAD) {
    // Too little room left for the layers: shrink the grid.
    inputGrid = gridFor(Math.max(3, Math.floor((innerW - MIN_RIGHT_PAD - GRID_GAP) / 28)));
  }
  const startX = inputGrid.width + GRID_GAP;

  const LABEL_PAD = 100; // reserved at the far right for output digit labels and bars
  const nLayers = layerSizes.length;
  const spanW = Math.max(100, innerW - startX - LABEL_PAD);
  const stepX = nLayers > 1 ? Math.min(200, Math.max(28, spanW / (nLayers - 1))) : 0;
  layers.forEach((nodes, li) => {
    const colX = startX + stepX * li;
    const rowGap = innerH / (nodes.length + 1);
    nodes.forEach((node, i) => {
      node.x = colX;
      node.y = rowGap * (i + 1);
    });
  });
}
// Latest preprocessed 28x28 input (row-major, 1 = ink), mirrored into the SVG grid.
let lastX28 = new Float32Array(28*28);
/**
 * Draw lastX28 as a 28x28 grid of grey squares (white = 0, black = 1) at inputGrid.
 * Every attribute that depends on inputGrid.cell, size included, is applied to the merged
 * selection. Existing squares are therefore resized, not just moved, when a resize changes the cell size.
 */
function renderInputGrid(){
  if (!inputGrid || inputGrid.cell <= 0) return;
  const { cell, x: gridX, y: gridY } = inputGrid;
  const data = Array.from({ length: 28*28 }, (_, i) => ({ i, v: lastX28[i] || 0 }));
  const gap = Math.max(1, Math.floor(cell * 0.10)); // spacing between squares
  const inner = Math.max(1, cell - gap);
  const offset = Math.floor(gap / 2);
  const sel = gInput.selectAll('rect.input-px').data(data, d => d.i);
  sel.enter().append('rect').attr('class', 'input-px')
    .merge(sel)
    .attr('width', inner)
    .attr('height', inner)
    .attr('x', d => gridX + (d.i % 28) * cell + offset)
    .attr('y', d => gridY + Math.floor(d.i / 28) * cell + offset)
    .attr('fill', d => { const g = 255 - Math.round(d.v * 255); return `rgb(${g},${g},${g})`; })
    .attr('stroke', 'none');
  sel.exit().remove();
}
/**
 * Draw one curved link from the right edge of the input grid to each first-layer node.
 * Each curve starts at the node's height, clamped to the grid's vertical extent. It ends
 * 12px left of the node centre, and its control points sit 35% of the horizontal span in
 * from each end. All links are removed if the layout is not ready.
 */
function renderInputLinks(){
  const firstLayer = layers[0];
  if (!firstLayer || !inputGrid || inputGrid.cell <= 0) {
    gInputLinks.selectAll('path').remove();
    return;
  }
  const x0 = inputGrid.x + inputGrid.width;
  const yTop = inputGrid.y;
  const yBottom = inputGrid.y + inputGrid.height;
  const paths = firstLayer.map((n) => {
    const y0 = Math.max(yTop, Math.min(yBottom, n.y));
    const dx = (n.x - x0) * 0.35;
    return { x0, y0, x1: n.x - 12, y1: n.y, c1x: x0 + dx, c1y: y0, c2x: n.x - dx, c2y: n.y };
  });
  const sel = gInputLinks.selectAll('path.input-link').data(paths);
  sel.enter().append('path').attr('class', 'input-link')
    .attr('fill', 'none')
    .attr('stroke', 'rgba(0,0,0,0.25)')
    .attr('stroke-width', 1)
    .merge(sel)
    .attr('d', d => `M${d.x0},${d.y0} C${d.c1x},${d.c1y} ${d.c2x},${d.c2y} ${d.x1},${d.y1}`);
  sel.exit().remove();
}
/**
 * Lay out and draw the static parts of the diagram: input grid and its links, nodes,
 * f1..f8 labels above the first layer, curved links between layers, and the positions of
 * any existing output labels. Called for the first render and on every resize.
 * @param {boolean} showEdges not read by this function; kept for existing call sites.
 */
function renderGraph(showEdges){
  layoutNodes();
  renderInputGrid();
  renderInputLinks();

  // Nodes: initial appearance is set on enter only. Later calls (resizes) just move existing
  // nodes, so the activation-driven fill/stroke/radius/opacity from setNodeActivations survive.
  const nodeSel = gNodes.selectAll('circle.node').data(layers.flat(), d => d.id);
  nodeSel.enter().append('circle').attr('class', 'node')
    .attr('r', 10)
    .attr('fill', d => d.layer === 2 ? 'var(--primary-color)' : 'var(--surface-bg)')
    .attr('stroke', 'var(--border-color)')
    .attr('stroke-width', 1)
    .attr('opacity', 1)
    .merge(nodeSel)
    .attr('cx', d => d.x)
    .attr('cy', d => d.y);
  nodeSel.exit().remove();

  // Feature labels for the first layer only (the output layer has its own digit labels).
  const labels = layers[0].map((n, i) => ({ x: n.x, y: n.y - 16, txt: `f${i + 1}` }));
  const labSel = gLabels.selectAll('text').data(labels);
  labSel.enter().append('text')
    .style('font-size', '12px')
    .style('fill', 'var(--muted-color)')
    .merge(labSel)
    .attr('x', d => d.x)
    .attr('y', d => d.y)
    .text(d => d.txt);
  labSel.exit().remove();

  // Inter-layer links as cubic curves with horizontal tangents at both ends.
  const pathFor = (d) => {
    const src = layers[d.s.l][d.s.i];
    const dst = layers[d.t.l][d.t.j];
    const dx = (dst.x - src.x) * 0.45;
    return `M${src.x},${src.y} C${src.x + dx},${src.y} ${dst.x - dx},${dst.y} ${dst.x},${dst.y}`;
  };
  const linkSel = gLinks.selectAll('path.link').data(links, d => `${d.s.l}-${d.s.i}-${d.t.l}-${d.t.j}`);
  linkSel.enter().append('path').attr('class', 'link')
    .attr('fill', 'none')
    .attr('stroke', 'rgba(0,0,0,0.25)')
    .merge(linkSel)
    .attr('d', pathFor)
    .attr('stroke-width', d => 0.5 + d.w * 1.2);
  linkSel.exit().remove();

  // Keep output labels aligned with the (possibly moved) output nodes.
  gOutText.selectAll('g.out-label')
    .attr('transform', function (d) {
      const n = (d && typeof d.digit === 'number') ? layers[2][d.digit] : null;
      if (!n) return d3.select(this).attr('transform');
      return `translate(${n.x + 18},${n.y})`;
    });
}
/**
 * Push one inference result into the diagram.
 * Stores activations on the node objects (`a`, missing values become 0). Nodes and links are
 * then styled from them. The output labels are (re)drawn: digit plus a probability bar
 * (BAR_MAX px = 100%), with the single top prediction highlighted and all others ghosted.
 * @param {number[]} h1 first-layer activations (8 features)
 * @param {number[]} h2 second-layer activations; opacity/radius clamp at 1
 * @param {number[]} out class probabilities for digits 0-9
 */
function setNodeActivations(h1, h2, out){
  [h1, h2, out].forEach((acts, li) => {
    layers[li].forEach((n, i) => { n.a = acts[i] || 0; });
  });

  // Index of the single most probable digit, or -1 when the maximum is shared. A tie includes
  // the uniform distribution shown for an empty canvas; in that case no digit is highlighted.
  let argmaxIdx = -1;
  if (Array.isArray(out)) {
    let bestProb = -Infinity;
    let tiedAtBest = 0;
    out.forEach((p, i) => {
      if (p > bestProb) { bestProb = p; argmaxIdx = i; tiedAtBest = 1; }
      else if (p === bestProb) { tiedAtBest += 1; }
    });
    if (tiedAtBest > 1) argmaxIdx = -1;
  }

  // Nodes: darker, larger and more opaque with higher activation (output nodes keep the primary colour).
  gNodes.selectAll('circle.node')
    .attr('fill', d => d.layer === 2 ? 'var(--primary-color)' : `rgba(0,0,0,${0.06 + 0.44 * d.a})`)
    .attr('stroke', d => d.layer === 2 ? 'var(--primary-color)' : 'var(--border-color)')
    .attr('opacity', d => 0.25 + 0.75 * Math.min(1, d.a))
    .attr('r', d => 8 + 6 * Math.min(1, d.a));

  // Links: alpha follows the product of the two endpoint activations. The theme is read once
  // per update, not once per link.
  const isDark = document.documentElement.getAttribute('data-theme') === 'dark';
  const base = isDark ? 255 : 0;
  gLinks.selectAll('path.link')
    .attr('stroke', d => {
      const aS = layers[d.s.l][d.s.i].a || 0;
      const aT = layers[d.t.l][d.t.j].a || 0;
      const alpha = Math.min(1, 0.08 + 0.85 * (aS * aT));
      return `rgba(${base},${base},${base},${alpha})`;
    });

  // Output labels: bold digit plus a small horizontal probability bar.
  const outs = layers[2].map((n, i) => ({ x: n.x + 18, y: n.y, digit: i, prob: (out[i] || 0), isTop: i === argmaxIdx }));
  const gSel = gOutText.selectAll('g.out-label').data(outs, d => d.digit);
  const gEnter = gSel.enter().append('g').attr('class', 'out-label');
  gEnter.append('text').attr('class', 'out-digit')
    .style('font-size', '12px').style('font-weight', '700').style('fill', 'var(--text-color)');
  gEnter.append('rect').attr('class', 'out-bar-bg').attr('rx', 2).attr('ry', 2)
    .attr('height', 4).attr('fill', 'var(--border-color)');
  gEnter.append('rect').attr('class', 'out-bar').attr('rx', 2).attr('ry', 2)
    .attr('height', 4);
  const BAR_MAX = 64;
  gEnter.merge(gSel)
    .attr('transform', d => `translate(${d.x},${d.y})`)
    .each(function (d) {
      const sel = d3.select(this);
      sel.select('text.out-digit')
        .attr('x', 0).attr('y', -2)
        .text(String(d.digit));
      sel.select('rect.out-bar-bg')
        .attr('x', 0).attr('y', 6)
        .attr('width', BAR_MAX);
      sel.select('rect.out-bar')
        .attr('x', 0).attr('y', 6)
        .attr('width', Math.max(1, Math.round(d.prob * BAR_MAX)))
        .attr('fill', d.isTop ? 'var(--primary-color)' : 'var(--border-color)');
      // Ghost every label except the top prediction.
      sel.style('opacity', d.isTop ? 1 : 0.35);
    });
  gSel.exit().remove();
}
/**
 * Run the demo pipeline on the current canvas and push the result into the diagram.
 *
 * Stages:
 *   - preprocessing: 28x28 ink (downsample28) -> MNIST-style centring (normalize28) -> 3x3 dilation;
 *   - input grid update;
 *   - hidden layer 1 = computeFeatures;
 *   - a decorative hidden layer 2;
 *   - 10 class probabilities.
 *
 * Probabilities come from the TFJS model when one is loaded. Otherwise they come from glyph
 * cosine similarity blended with the small linear classifier W/b. Inputs whose total ink
 * (feature 0) is below 0.03 get a uniform distribution.
 */
function runPipeline(){
  const x28 = dilate28(normalize28(downsample28()));
  lastX28 = x28;
  renderInputGrid();

  const h1 = computeFeatures(x28); // 8 features in [0, 1]; h1[0] is total ink
  // Layer 2 is for visualisation only: a fixed trig-weighted mix of the features through tanh.
  const h2 = layers[1].map((_, j) => {
    let acc = 0;
    for (let i = 0; i < layers[0].length; i++) {
      acc += ((Math.sin(i * 17 + j * 31) + 1) / 2 * 0.8 + 0.1) * h1[i];
    }
    return Math.tanh(acc * 0.8);
  });

  // Fallback classifier: mostly glyph similarity, nudged by the linear model.
  const fallbackProbs = () => {
    const unit = normalize(x28);
    return softmax(protoGlyphs28.map((proto, k) => 8.0 * cosine(unit, proto) + 0.2 * (dot(W[k], h1) + b[k])));
  };
  let prob;
  if (h1[0] < 0.03) {
    prob = Array.from({ length: 10 }, () => 1 / 10);
  } else {
    const tfProbs = predictTfjs(x28);
    prob = (tfProbs && tfProbs.length === 10) ? tfProbs : fallbackProbs();
  }
  // tanh output lies in [-1, 1]; rescale to [0, 1] for node styling.
  setNodeActivations(h1, h2.map(v => (v + 1) / 2), prob);
}
/**
 * Average-pool the CANVAS_PX x CANVAS_PX canvas into 28x28 ink values (1 = black, 0 = white),
 * using the mean of R, G and B as the grey level.
 * @returns {Float32Array} row-major 28x28
 */
function downsample28(){
  const cellPx = CANVAS_PX / 28;
  const rgba = ctx.getImageData(0, 0, CANVAS_PX, CANVAS_PX).data;
  const grid = new Float32Array(28 * 28);
  for (let cellIdx = 0; cellIdx < grid.length; cellIdx++) {
    const gx = cellIdx % 28;
    const gy = Math.floor(cellIdx / 28);
    const x0 = Math.floor(gx * cellPx);
    const y0 = Math.floor(gy * cellPx);
    let ink = 0, n = 0;
    for (let y = y0; y < y0 + cellPx; y++) {
      for (let x = x0; x < x0 + cellPx; x++) {
        const o = (y * CANVAS_PX + x) * 4; // RGBA stride
        ink += 1 - (rgba[o] + rgba[o + 1] + rgba[o + 2]) / 3 / 255;
        n++;
      }
    }
    grid[cellIdx] = ink / (n || 1);
  }
  return grid;
}
/** Paint the canvas white again and refresh the diagram for the empty input. */
function clearCanvas(){
  ctx.fillStyle = '#ffffff';
  ctx.fillRect(0, 0, CANVAS_PX, CANVAS_PX);
  runPipeline();
}
// Freehand drawing with mouse or touch.
let drawing = false;
let last = null;        // previous stroke point in canvas pixels; null at the start of a stroke
let pipelineFrame = 0;  // pending requestAnimationFrame handle; 0 when none is queued
// Map a mouse/touch event to canvas backing-store coordinates (the canvas is CSS-scaled).
const getPos = (ev) => {
  const rect = canvas.getBoundingClientRect();
  const sx = CANVAS_PX / rect.width;
  const sy = CANVAS_PX / rect.height;
  const point = ('touches' in ev) ? ev.touches[0] : ev;
  return { x: (point.clientX - rect.left) * sx, y: (point.clientY - rect.top) * sy };
};
// Move events can arrive several times per frame, and each runPipeline() reads the canvas
// back and runs inference. Coalesce them into at most one run per animation frame.
const schedulePipeline = () => {
  if (pipelineFrame) return;
  if (typeof window.requestAnimationFrame !== 'function') { runPipeline(); return; }
  pipelineFrame = window.requestAnimationFrame(() => {
    pipelineFrame = 0;
    runPipeline();
  });
};
// Stroke a round-capped 24px black segment from the previous point to p (a dot on the first call).
function drawTo(p){
  const size = 24;
  ctx.lineCap = 'round';
  ctx.lineJoin = 'round';
  ctx.strokeStyle = '#000000';
  ctx.lineWidth = size;
  if (!last) last = p;
  ctx.beginPath();
  ctx.moveTo(last.x, last.y);
  ctx.lineTo(p.x, p.y);
  ctx.stroke();
  last = p;
  schedulePipeline();
}
function onDown(ev){ drawing = true; last = null; drawTo(getPos(ev)); ev.preventDefault(); }
function onMove(ev){ if (!drawing) return; drawTo(getPos(ev)); ev.preventDefault(); }
function onUp(){ drawing = false; last = null; }
canvas.addEventListener('mousedown', onDown);
canvas.addEventListener('mousemove', onMove);
window.addEventListener('mouseup', onUp);
// Non-passive so preventDefault() can stop the page from scrolling while drawing.
canvas.addEventListener('touchstart', onDown, { passive: false });
canvas.addEventListener('touchmove', onMove, { passive: false });
window.addEventListener('touchend', onUp);
// A touch cancelled by the browser/OS must end the stroke as well.
window.addEventListener('touchcancel', onUp);
// Re-layout the diagram whenever the right column or the canvas changes size. Without
// ResizeObserver, fall back to window resize events.
const rerender = () => { renderGraph(true); };
if (!window.ResizeObserver) {
  window.addEventListener('resize', rerender);
} else {
  const ro = new ResizeObserver(() => rerender());
  [right, canvas].forEach((el) => ro.observe(el));
}
// Optional TFJS classifier; stays null until one of the candidate URLs loads.
let tfModel = null;
/**
 * Load TensorFlow.js if needed, then try each candidate model URL in order and keep the
 * first that loads. On success the pipeline is re-run, so strokes drawn while the model was
 * downloading are re-classified by it. If none loads, the failures are logged and the built-in
 * fallback classifier stays in use.
 */
const tryLoadModel = async () => {
  await new Promise((res) => ensureTF(res));
  const candidates = [
    // Prefer public path via symlink to assets/data
    '/data/mnist-variant-model.json',
    // Fallbacks to relative copies under content assets (shards must be colocated)
    './assets/data/mnist-variant-model.json',
    '../assets/data/mnist-variant-model.json',
    '/assets/data/mnist-variant-model.json',
    // Fallback to public TFJS MNIST
    'https://storage.googleapis.com/tfjs-models/tfjs/mnist/model.json'
  ];
  const failures = [];
  for (const url of candidates) {
    try {
      tfModel = await tf.loadLayersModel(url);
      break;
    } catch (err) {
      // Expected for missing candidates; remember why and try the next one.
      failures.push(`${url}: ${err && err.message ? err.message : err}`);
    }
  }
  if (!tfModel) {
    tfModel = null;
    console.warn('d3-neural: no TFJS model could be loaded; using the fallback classifier.', failures);
    return;
  }
  runPipeline();
};
/**
 * Classify a 28x28 ink image with the loaded TFJS model.
 *
 * The model runs on the image and on its inverse (1 - v), and whichever output is more
 * confident is kept (presumably because a model may expect either polarity).
 *
 * The output is turned into probabilities with softmax only when it is not already a
 * probability vector (entries in [0, 1] summing to ~1). A model that already ends in a
 * softmax layer would otherwise be squashed toward uniform by a second softmax.
 *
 * @param {Float32Array} x28 row-major 28x28 values, 1 = ink
 * @returns {number[]|null} 10 class probabilities (padded/truncated to 10), or null when no
 *   model is loaded or inference fails
 */
function predictTfjs(x28){
  if (!tfModel || !window.tf) return null;
  const isProbVector = (vals) => {
    if (vals.length === 0) return false;
    let total = 0;
    for (const v of vals) {
      if (!(v >= 0 && v <= 1)) return false;
      total += v;
    }
    return Math.abs(total - 1) < 1e-3;
  };
  const run = (arr) => {
    const input = tf.tidy(() => tf.tensor(arr, [28, 28, 1]).expandDims(0));
    let output = null;
    let soft = null;
    try {
      output = tfModel.predict(input);
      const raw = Array.from(output.dataSync());
      if (isProbVector(raw)) return raw;
      soft = output.softmax();
      return Array.from(soft.dataSync());
    } catch (err) {
      // Shape/model mismatch: treat as "no prediction" so the caller falls back.
      return null;
    } finally {
      // Release every tensor created here, including on the error path.
      tf.dispose([input, output, soft].filter(Boolean));
    }
  };
  const p1 = run(x28);
  const p2 = run(x28.map(v => 1 - v));
  let probs = p1 || p2;
  if (p1 && p2) {
    probs = Math.max(...p2) > Math.max(...p1) ? p2 : p1;
  }
  if (!probs) return null;
  // Normalize output size to 10 classes (pad or slice).
  if (probs.length < 10) probs = probs.concat(Array(10 - probs.length).fill(0));
  if (probs.length > 10) probs = probs.slice(0, 10);
  return probs;
}
// Initial render: lay out the diagram, paint a blank canvas (which also runs the pipeline
// once), then load the optional TFJS model in the background. The returned promise is handled
// so an unexpected failure is reported instead of surfacing as an unhandled rejection.
renderGraph(true);
clearCanvas();
tryLoadModel().catch((err) => console.warn('d3-neural: model loading failed', err));
| }; | |
// Start once the DOM is parsed: make sure D3 is loaded, then mount the widget.
const start = () => ensureD3(bootstrap);
if (document.readyState !== 'loading') {
  start();
} else {
  document.addEventListener('DOMContentLoaded', start, { once: true });
}
| })(); | |
| </script> | |