Spaces:
Running
Running
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width,initial-scale=1">
<title>Ternary GEMV kernel lab — your GPU</title>
<!-- The doctype above keeps the page in standards mode (without it browsers use quirks mode). -->
<style>
:root{--bg:#0a0d13;--panel:#111825;--ink:#e8ebf1;--dim:#8b95a7;--ac:#7c5cff;--ok:#48c26c;--no:#f0616d;--warn:#e0a94a;--line:#1e2836}
*{box-sizing:border-box}body{background:var(--bg);color:var(--ink);font:15px/1.6 -apple-system,Segoe UI,Roboto,monospace;padding:24px;max-width:820px;margin:0 auto}
h1{font-size:20px;margin:0 0 4px}.sub{color:var(--dim);font-size:13px;margin:0 0 18px}
.card{font-family:ui-monospace,monospace;padding:14px 16px;border:1px solid var(--line);border-radius:10px;background:var(--panel);margin-bottom:12px;word-break:break-word}
.big{font-size:22px;font-weight:700}.k{color:var(--dim)}
.ok{color:var(--ok)}.no{color:var(--no)}.warn{color:var(--warn)}
table{width:100%;border-collapse:collapse;font-family:ui-monospace,monospace;font-size:13px}
td,th{padding:6px 8px;border-bottom:1px solid var(--line);text-align:left}th{color:var(--dim);font-weight:600}
td.n{text-align:right;font-variant-numeric:tabular-nums}
.verdict{font-size:16px;font-weight:600;margin-top:14px}
</style></head><body>
<h1>Ternary GEMV <span style="color:var(--ac)">kernel lab</span> — your GPU</h1>
<p class="sub">Runs the batch-1 BitNet decode kernel (weights read once/token) in several variants against a ~0.69 GB ternary matrix on your real GPU, and reports the decode rate each reaches vs the 152 GB/s roofline (220 tok/s).</p>
<!-- #status: one-line progress/error text written by the script; #out: results card(s). -->
<div id="status" class="card">starting…</div>
<div id="out"></div>
<script type="module">
// ---- page plumbing: DOM handles, headline numbers, status line ----
const $ = (selector) => document.querySelector(selector);
const st = $("#status");
const out = $("#out");

const MODEL_GB = 0.69; // weight bytes read per decoded token, in GB (tok/s = GB/s ÷ MODEL_GB)
const ROOF_TOK = 220;  // tok/s ceiling quoted for the 152 GB/s roofline
const CUR_TOK = 54;    // tok/s of the currently shipped kernel, shown as the comparison row

/** Replace the status card text; `cls` (optional) appends a colour class such as "ok" or "no". */
function say(text, cls) {
  st.textContent = text;
  st.className = cls ? `card ${cls}` : "card";
}

// Last-resort reporter for promise rejections that escape every try/catch.
window.addEventListener("unhandledrejection", (event) => {
  const reason = event.reason;
  say("✗ unhandled: " + (reason && (reason.message || reason)), "no");
});
// ---- kernel body templates (KX=1 decode). Same reduction structure; vary the inner math. ----
// Every variant defines WGSL `dot16(word, v)`: unpack the 16 two-bit codes of one u32
// (lowest bits first) and dot them with x[v], x[v+1u], x[v+2u], x[v+3u] (x is vec4<f32>).
// The text is assembled from small pieces so variants differ only where they are meant to.

/** WGSL list of the 4 codes held in byte `q` (0..3) of `word`, each converted to f32. */
const codesOfByte = (q) =>
  [0, 1, 2, 3]
    .map((j) => 8 * q + 2 * j)
    .map((shift) => (shift === 0 ? "f32(word&3u)" : `f32((word>>${shift}u)&3u)`))
    .join(",");

/** WGSL operand x[v+q]; the q=0 form keeps a trailing space so the four lines align. */
const xOperand = (q) => (q === 0 ? "x[v] " : `x[v+${q}u]`);

/**
 * Source of the WGSL `dot16` function.
 * @param {boolean} centered true: subtract 1 from every code (codes 0,1,2 -> -1,0,+1);
 *                           false: use the raw codes (the −1 is folded in elsewhere).
 */
function dot16Source(centered) {
  const accumulate = [0, 1, 2, 3].map((q) => {
    const codes = `vec4<f32>(${codesOfByte(q)})`;
    const factor = centered ? `(${codes}-vec4<f32>(1.0))` : codes;
    return `s4=s4+${xOperand(q)}*${factor};`;
  });
  return [
    "",
    "fn dot16(word:u32, v:u32) -> f32 {",
    "var s4=vec4<f32>(0.0);",
    ...accumulate,
    "return s4.x+s4.y+s4.z+s4.w;",
    "}",
  ].join("\n");
}

/** Loop shape handed to kernelWGSL: weight element type, lane stride, elements per row, bias-fold flag. */
const loopShape = (arrQw, nwExpr, asum) => ({ arrQw, stride: "64u", nwExpr, asum });

// Wide-load body shared by V1 and V2: vec4<u32> element w = 4 words = 64 codes, dotted
// against x vec4 slots 16w .. 16w+15 (v = 16w, then +4u/+8u/+12u per word).
const BODY_WIDE = [
  "let qv=qw[rowW+w]; let v=(w<<2u)<<2u;",
  "acc=acc+dot16(qv.x,v)+dot16(qv.y,v+4u)+dot16(qv.z,v+8u)+dot16(qv.w,v+12u);",
].join("\n");

// V0: scalar u32 load, per-code f32 convert + (-1) bias (the current engine kernel)
const DOT_V0 = dot16Source(true);
const BODY_V0 = "let word=qw[rowW+w]; acc=acc+dot16(word, w<<2u);";
const LOOP_V0 = loopShape("array<u32>", "nw", false);

// V1: wide vec4<u32> load (64 codes / 16 bytes per load), same math as V0
const BODY_V1 = BODY_WIDE;
const LOOP_V1 = loopShape("array<vec4<u32>>", "(nw>>2u)", false);

// V2: wide load + BIAS FOLD — drop the 16 per-code (-1); subtract sum(activations) once per row
const DOT_V2 = dot16Source(false);
const BODY_V2 = BODY_WIDE;
const LOOP_V2 = loopShape("array<vec4<u32>>", "(nw>>2u)", true);
/**
 * Assemble the WGSL compute shader for one batch-1 GEMV variant (one output per weight row).
 *
 * A 256-thread workgroup handles 4 consecutive rows: rr = lid.x/64 picks the row and
 * t = lid.x%64 the lane. Each lane strides over the row's packed weights, accumulating a
 * partial dot in `acc`. The 64 partials per row are then tree-reduced in workgroup memory `red`.
 *
 * @param {string} dot  WGSL defining `fn dot16(word:u32, v:u32) -> f32`.
 * @param {string} body WGSL statements run once per loop step; may read `rowW`/`w` and update `acc`.
 * @param {{arrQw:string, stride:string, nwExpr:string, asum:boolean}} L loop shape:
 *   arrQw  – type of the weight binding `qw`;
 *   stride – WGSL literal added to `w` each step (64u matches the 64 lanes per row);
 *   nwExpr – WGSL expression for weight elements per row, in arrQw element units;
 *   asum   – when true, subtract bitcast<f32>(P.w) (host-supplied bits of sum(x)) from each row result.
 * @returns {string} WGSL source.
 */
function kernelWGSL(dot, body, L){
  // Bias fold: row result = dot(x, codes) - sum(x). Apply -sum(x) exactly once, to the reduced
  // row total. Subtracting it from every lane's partial `acc` before the 64-lane reduction
  // would remove it 64 times.
  const rowResult = L.asum ? "red[rr*64u] - bitcast<f32>(P.w)" : "red[rr*64u]";
  return `
@group(0) @binding(0) var<storage,read> x: array<vec4<f32>>;
@group(0) @binding(1) var<storage,read> qw: ${L.arrQw};
@group(0) @binding(2) var<storage,read_write> o: array<f32>;
@group(0) @binding(3) var<uniform> P: vec4<u32>; // x=K, y=N, z=0, w=asumBits
var<workgroup> red: array<f32, 256>;
${dot}
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg:vec3<u32>, @builtin(local_invocation_id) lid:vec3<u32>){
  let K=P.x; let nw=K>>4u;             // u32 words per row (16 two-bit codes per word)
  let rr=lid.x/64u; let t=lid.x%64u;   // row within the workgroup, lane within the row
  // 2-D dispatch: linear group index = wg.y*65535 + wg.x; 4 rows per group.
  // Tail rows past N are clamped for loads and skipped at the store.
  let n0=(wg.y*65535u+wg.x)*4u+rr; let n=min(n0, P.y-1u);
  let rowW=n*${L.nwExpr}; var acc=0.0;
  var w=t;
  loop{ if(w>=${L.nwExpr}){break;}
    ${body}
    w=w+${L.stride}; }
  red[lid.x]=acc; workgroupBarrier();
  var s=32u; loop{ if(s==0u){break;} if(t<s){ red[rr*64u+t]=red[rr*64u+t]+red[rr*64u+t+s]; } workgroupBarrier(); s=s/2u; }
  if(t==0u && n0<P.y){ o[n0]=${rowResult}; }
}`;
}
// Pseudo-random fill for the packed ternary weights (one u32 = 16 two-bit codes).
// A hash of the word index gives varied, mostly non-zero words so reads hit real VRAM.
// Code 3 (bits 11) is not ternary: the kernels decode code c as c-1, so 3 would become +2.
// The mask clears the low bit of every 11 lane, turning it into 10 (code 2 -> +1); lanes
// 00/01/10 are unchanged. Threads walk the buffer with stride P.y.
const FILL = `
@group(0) @binding(0) var<storage,read_write> d: array<u32>;
@group(0) @binding(1) var<uniform> P: vec4<u32>; // x=#u32, y=stride
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid:vec3<u32>){
  let n=P.x; var i=gid.x;
  loop{ if(i>=n){break;}
    let h=(i*2654435761u+1u)^((i<<7u)*40503u);
    d[i]=h & ~((h>>1u) & h & 0x55555555u);
    i=i+P.y; }
}`;
/**
 * Kernel variants to benchmark, in table order.
 * Each entry: short id, table label, and the dot/body/loop pieces passed to kernelWGSL.
 */
const VARIANTS = [
  ["V0", "V0 baseline (scalar load, per-code −1)", DOT_V0, BODY_V0, LOOP_V0],
  ["V1", "V1 wide vec4 load", DOT_V0, BODY_V1, LOOP_V1],
  ["V2", "V2 wide load + bias-fold", DOT_V2, BODY_V2, LOOP_V2],
].map(([id, name, dot, body, L]) => ({ id, name, dot, body, L }));
(async () => {
  // Decimal gigabyte. MODEL_GB is turned into bytes with 1e9 when sizing the weights below,
  // so throughput must use the same unit for tok/s = (GB/s) / MODEL_GB to hold. Dividing by
  // 2^30 reported GiB/s labelled "GB/s" and understated tok/s and % roofline by ~7%.
  const GB = 1e9;

  // Escape text for innerHTML. WGSL compiler messages routinely contain "<"/">" (e.g. vec4<f32>),
  // which would otherwise be parsed as markup and mangle or hide the error text.
  const HTML_ESC = { "&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;", "'": "&#39;" };
  const escapeHtml = (s) => String(s).replace(/[&<>"']/g, (c) => HTML_ESC[c]);

  try {
    if (!navigator.gpu) { say("✗ No WebGPU. Open in Chrome/Edge.", "no"); return; }
    say("requesting adapter…");
    let ad = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
    if (!ad) ad = await navigator.gpu.requestAdapter();
    if (!ad) { say("✗ no GPU adapter.", "no"); return; }
    const L = ad.limits, info = ad.info || {};
    // Request the adapter's maxima: the weight buffer below is sized up to these limits.
    const dev = await ad.requestDevice({ requiredLimits: {
      maxStorageBufferBindingSize: L.maxStorageBufferBindingSize, maxBufferSize: L.maxBufferSize,
      maxComputeWorkgroupsPerDimension: L.maxComputeWorkgroupsPerDimension } });
    dev.lost.then(i => say("✗ device lost: " + (i && i.message || i.reason || ""), "no"));

    const K = 4096, nw = K / 16; // 256 words/row (16 two-bit codes per u32)
    // weight bytes: as close to MODEL_GB as the device allows, rounded down to whole rows
    const wantBytes = Math.min(L.maxStorageBufferBindingSize, L.maxBufferSize, Math.round(MODEL_GB * GB));
    const N = Math.floor((wantBytes / 4) / nw);
    const qU32 = N * nw, wBytes = qU32 * 4;
    say("allocating " + (wBytes / 1e6).toFixed(0) + " MB ternary weights (" + N.toLocaleString() + " rows × " + K + ")…");
    const qw = dev.createBuffer({ size: wBytes, usage: GPUBufferUsage.STORAGE });
    const xbuf = dev.createBuffer({ size: K * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST });
    const obuf = dev.createBuffer({ size: N * 4, usage: GPUBufferUsage.STORAGE });
    const P = dev.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });

    // activations: deterministic pseudo-random f32 in (-1, 1), and their sum (for V2's bias fold)
    const xa = new Float32Array(K); let asum = 0;
    for (let i = 0; i < K; i++) { xa[i] = (Math.sin(i * 12.9898) * 43758.5453) % 1; asum += xa[i]; }
    dev.queue.writeBuffer(xbuf, 0, xa);
    // The kernel reads the sum from the u32 lane P.w via bitcast, so pass the f32 bit pattern.
    const asumBits = new Uint32Array(new Float32Array([asum]).buffer)[0];
    dev.queue.writeBuffer(P, 0, new Uint32Array([K, N, 0, asumBits]));

    // fill weights (non-zero) so reads hit real VRAM
    say("filling weights…");
    {
      const fmod = dev.createShaderModule({ code: FILL });
      const fpipe = dev.createComputePipeline({ layout: "auto", compute: { module: fmod, entryPoint: "main" } });
      const fp = dev.createBuffer({ size: 16, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST });
      const wgF = Math.min(L.maxComputeWorkgroupsPerDimension, 65535);
      // fill params: x = number of u32 words, y = stride = total threads dispatched
      dev.queue.writeBuffer(fp, 0, new Uint32Array([qU32, wgF * 256, 0, 0]));
      const fbg = dev.createBindGroup({ layout: fpipe.getBindGroupLayout(0), entries: [
        { binding: 0, resource: { buffer: qw } }, { binding: 1, resource: { buffer: fp } }] });
      const e = dev.createCommandEncoder(); const p = e.beginComputePass();
      p.setPipeline(fpipe); p.setBindGroup(0, fbg); p.dispatchWorkgroups(wgF); p.end();
      dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone();
    }

    // dispatch dims: workgroup = 4 rows; the kernel rebuilds the group index as wg.y*65535 + wg.x
    const groups = Math.ceil(N / 4), wgx = Math.min(groups, 65535), wgy = Math.ceil(groups / 65535);
    const PASSES = 8, ITERS = 5;
    const rows = [];
    for (const V of VARIANTS) {
      say("compiling " + V.id + "…");
      dev.pushErrorScope("validation");
      const mod = dev.createShaderModule({ code: kernelWGSL(V.dot, V.body, V.L) });
      const ci = await mod.getCompilationInfo();
      const er = ci.messages.filter(m => m.type === "error");
      if (er.length) { await dev.popErrorScope(); rows.push({ V, err: er[0].message }); continue; }
      const pipe = dev.createComputePipeline({ layout: "auto", compute: { module: mod, entryPoint: "main" } });
      const bg = dev.createBindGroup({ layout: pipe.getBindGroupLayout(0), entries: [
        { binding: 0, resource: { buffer: xbuf } }, { binding: 1, resource: { buffer: qw } },
        { binding: 2, resource: { buffer: obuf } }, { binding: 3, resource: { buffer: P } }] });
      const scopeErr = await dev.popErrorScope();
      if (scopeErr) { rows.push({ V, err: scopeErr.message }); continue; }

      // Wall-clock ms for `passes` back-to-back dispatches in one submit
      // (includes submit and completion-signal latency).
      const run = async (passes) => {
        const e = dev.createCommandEncoder();
        for (let k = 0; k < passes; k++) {
          const p = e.beginComputePass(); p.setPipeline(pipe); p.setBindGroup(0, bg); p.dispatchWorkgroups(wgx, wgy); p.end();
        }
        const t0 = performance.now(); dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone();
        return performance.now() - t0;
      };
      say("timing " + V.id + "…");
      await run(2); // warm
      let bestMs = 1e9;
      for (let k = 0; k < ITERS; k++) bestMs = Math.min(bestMs, await run(PASSES));
      const msPer = bestMs / PASSES;
      const gbps = (wBytes / GB) / (msPer / 1000);
      const tok = gbps / MODEL_GB, pct = 100 * tok / ROOF_TOK;
      rows.push({ V, msPer, gbps, tok, pct });
    }

    say("done · adapter: " + ((info.vendor || "?") + " " + (info.architecture || "") + " " + (info.device || "")).trim(), "ok");
    const best = rows.filter(r => r.tok).sort((a, b) => b.tok - a.tok)[0];
    const pctClass = (pct) => pct < 30 ? "no" : pct < 70 ? "warn" : "ok";
    // A row is a failure whenever it carries an `err` field, even an empty message.
    const rowHtml = (r) => r.err != null
      ? `<tr><td>${r.V.name}</td><td colspan=4 class="no">compile/validate failed: ${escapeHtml(String(r.err).slice(0, 60))}</td></tr>`
      : `<tr><td>${r.V.name}</td><td class="n">${r.msPer.toFixed(2)}</td><td class="n">${r.gbps.toFixed(0)}</td><td class="n ${r === best ? 'ok' : ''}"><b>${r.tok.toFixed(0)}</b></td><td class="n ${pctClass(r.pct)}">${r.pct.toFixed(0)}%</td></tr>`;
    out.innerHTML = `
<div class="card"><span class="k">Test matrix</span><br><span class="big">${N.toLocaleString()} × ${K}</span> · ${(wBytes / GB).toFixed(2)} GB ternary weights · roofline ${ROOF_TOK} tok/s</div>
<div class="card"><table>
<tr><th>variant</th><th class="n">ms/token</th><th class="n">GB/s</th><th class="n">tok/s</th><th class="n">% roofline</th></tr>
${rows.map(rowHtml).join("")}
<tr><td class="k">current shipped kernel (measured)</td><td class="n k">18.4</td><td class="n k">38</td><td class="n k">${CUR_TOK}</td><td class="n k">${(100 * CUR_TOK / ROOF_TOK).toFixed(0)}%</td></tr>
</table></div>
${best ? `<div class="verdict ok">✓ Best: <b>${best.V.name}</b> → ${best.tok.toFixed(0)} tok/s (${best.pct.toFixed(0)}% of roofline, ${(best.tok / CUR_TOK).toFixed(1)}× the shipped kernel). ${best.pct >= 70 ? "Bandwidth-bound — ship it." : "Still ALU-bound; next lever = integer dot / fewer converts."}</div>`
  : `<div class="verdict no">All variants failed to run.</div>`}`;
  } catch (e) { say("✗ " + (e && (e.message || e)), "no"); }
})();
| </script></body></html> | |