q / gemv-bench.html
Humuhumu33's picture
ternary GEMV kernel lab: V0/V1/V2 variants
bbf8a0a verified
Raw
History Blame Contribute Delete
11.7 kB
<!doctype html><html><head><meta charset=utf-8><meta name=viewport content="width=device-width,initial-scale=1">
<title>Ternary GEMV kernel lab — your GPU</title>
<style>
/* palette: page background, card panel, text, dimmed text, accent, status colours, rule lines */
:root{--bg:#0a0d13;--panel:#111825;--ink:#e8ebf1;--dim:#8b95a7;--ac:#7c5cff;--ok:#48c26c;--no:#f0616d;--warn:#e0a94a;--line:#1e2836}
*{box-sizing:border-box}body{background:var(--bg);color:var(--ink);font:15px/1.6 -apple-system,Segoe UI,Roboto,monospace;padding:24px;max-width:820px;margin:0 auto}
h1{font-size:20px;margin:0 0 4px}.sub{color:var(--dim);font-size:13px;margin:0 0 18px}
.card{font-family:ui-monospace,monospace;padding:14px 16px;border:1px solid var(--line);border-radius:10px;background:var(--panel);margin-bottom:12px;word-break:break-word}
.big{font-size:22px;font-weight:700}.k{color:var(--dim)}
/* status tints used by the script: ok / no (failure) / warn */
.ok{color:var(--ok)}.no{color:var(--no)}.warn{color:var(--warn)}
table{width:100%;border-collapse:collapse;font-family:ui-monospace,monospace;font-size:13px}
td,th{padding:6px 8px;border-bottom:1px solid var(--line);text-align:left}th{color:var(--dim);font-weight:600}
/* numeric columns: right-align header cells too, so each label sits over its figures */
td.n,th.n{text-align:right;font-variant-numeric:tabular-nums}
.verdict{font-size:16px;font-weight:600;margin-top:14px}
</style></head><body>
<!-- #status is the single progress/error line the script rewrites; #out receives the results report -->
<h1>Ternary GEMV <span style="color:var(--ac)">kernel lab</span> — your GPU</h1>
<p class="sub">Runs the batch-1 BitNet decode kernel (weights read once/token) in several variants against a ~0.69&nbsp;GB ternary matrix on your real GPU, and reports the decode rate each reaches vs the 152&nbsp;GB/s roofline (220&nbsp;tok/s).</p>
<div id="status" class="card">starting…</div>
<div id="out"></div>
<script type="module">
// DOM handles for the status line and the results container.
const $ = (selector) => document.querySelector(selector);
const st = $("#status");
const out = $("#out");
// Report constants: MODEL_GB is the weight footprint (GB) streamed per token,
// ROOF_TOK the roofline decode rate, CUR_TOK the measured rate of the shipped kernel.
const MODEL_GB = 0.69;
const ROOF_TOK = 220;
const CUR_TOK = 54;
/**
 * Replace the status card's text, optionally tinting it with a CSS class.
 * @param {string} t - message to show.
 * @param {string} [cls] - extra class ("ok" / "no" / "warn"); falsy means plain card.
 */
const say = (t, cls) => {
  st.className = cls ? `card ${cls}` : "card";
  st.textContent = t;
};
// Surface any promise rejection that nothing awaited in the status card.
window.addEventListener("unhandledrejection", (ev) => {
  const { reason } = ev;
  say("✗ unhandled: " + (reason && (reason.message || reason)), "no");
});
// ---- kernel body templates (KX=1 decode). Same reduction structure; vary the inner math. ----
// Each variant supplies three pieces that kernelWGSL() splices into one shader:
//   DOT_*  - WGSL helper dot16(word, v): unpacks the 16 2-bit codes of one u32 and
//            multiplies them with the 4 consecutive vec4<f32> activations x[v..v+3];
//   BODY_* - the statement executed once per iteration of the per-row loop;
//   LOOP_* - template knobs: qw element type, loop stride, qw elements per row
//            expression, and whether sum(activations) must be subtracted (asum).

// V0: scalar u32 load, per-code f32 convert + (-1) bias (the current engine kernel)
const DOT_V0 = `
fn dot16(word:u32, v:u32) -> f32 {
var s4=vec4<f32>(0.0);
s4=s4+x[v] *(vec4<f32>(f32(word&3u),f32((word>>2u)&3u),f32((word>>4u)&3u),f32((word>>6u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+1u]*(vec4<f32>(f32((word>>8u)&3u),f32((word>>10u)&3u),f32((word>>12u)&3u),f32((word>>14u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+2u]*(vec4<f32>(f32((word>>16u)&3u),f32((word>>18u)&3u),f32((word>>20u)&3u),f32((word>>22u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+3u]*(vec4<f32>(f32((word>>24u)&3u),f32((word>>26u)&3u),f32((word>>28u)&3u),f32((word>>30u)&3u))-vec4<f32>(1.0));
return s4.x+s4.y+s4.z+s4.w;
}`;
const BODY_V0 = `let word=qw[rowW+w]; acc=acc+dot16(word, w<<2u);`;
const LOOP_V0 = { arrQw: "array<u32>", stride: "64u", nwExpr: "nw", asum: false };

// V1 and V2 both read one vec4<u32> (4 words = 64 codes = 16 bytes) per iteration,
// so they share the qw element type, the stride and the per-row element count.
const WIDE_LOOP = { arrQw: "array<vec4<u32>>", stride: "64u", nwExpr: "(nw>>2u)" };

// V1: wide vec4<u32> load (64 codes / 16 bytes per load), same math as V0
const BODY_V1 = `let qv=qw[rowW+w]; let v=(w<<2u)<<2u;
acc=acc+dot16(qv.x,v)+dot16(qv.y,v+4u)+dot16(qv.z,v+8u)+dot16(qv.w,v+12u);`;
const LOOP_V1 = { ...WIDE_LOOP, asum: false };

// V2: wide load + BIAS FOLD — drop the 16 per-code (-1); subtract sum(activations) once per row
const DOT_V2 = `
fn dot16(word:u32, v:u32) -> f32 {
var s4=vec4<f32>(0.0);
s4=s4+x[v] *vec4<f32>(f32(word&3u),f32((word>>2u)&3u),f32((word>>4u)&3u),f32((word>>6u)&3u));
s4=s4+x[v+1u]*vec4<f32>(f32((word>>8u)&3u),f32((word>>10u)&3u),f32((word>>12u)&3u),f32((word>>14u)&3u));
s4=s4+x[v+2u]*vec4<f32>(f32((word>>16u)&3u),f32((word>>18u)&3u),f32((word>>20u)&3u),f32((word>>22u)&3u));
s4=s4+x[v+3u]*vec4<f32>(f32((word>>24u)&3u),f32((word>>26u)&3u),f32((word>>28u)&3u),f32((word>>30u)&3u));
return s4.x+s4.y+s4.z+s4.w;
}`;
// The loop body is exactly V1's; only the dot16 helper and the asum flag differ.
const BODY_V2 = BODY_V1;
const LOOP_V2 = { ...WIDE_LOOP, asum: true };
/**
 * Assemble the WGSL GEMV compute shader for one kernel variant.
 *
 * Layout: each 256-thread workgroup produces 4 output rows. Threads are split into
 * 4 groups of 64 (rr = lid/64, t = lid%64); group rr walks row n's packed weights
 * starting at element t with step `L.stride`, then the 64 partial sums are
 * tree-reduced in workgroup memory and lane t==0 writes the row result.
 * Rows are linearized over a 2-D dispatch as (wg.y*65535 + wg.x)*4 + rr; rows past
 * N read row N-1 (clamped) so every thread takes part in the barriers, but only
 * rows < N are written.
 *
 * @param {string} dot  WGSL source of the dot16(word, v) helper.
 * @param {string} body WGSL statement(s) executed once per loop iteration.
 * @param {{arrQw:string, stride:string, nwExpr:string, asum:boolean}} L
 *   arrQw  - element type of the `qw` weight binding;
 *   stride - loop step in `qw` elements;
 *   nwExpr - WGSL expression for the number of `qw` elements per row;
 *   asum   - when true, subtract sum(x) (passed as f32 bits in P.w) from each row,
 *            the bias fold for dot16 helpers that skip the per-code (-1).
 * @returns {string} WGSL source with compute entry point `main`.
 */
function kernelWGSL(dot, body, L){
  // Bias fold: sum_k x_k*(c_k-1) = sum_k x_k*c_k - sum_k x_k. The correction belongs
  // to the whole row, so apply it once to the reduced value. Subtracting it from
  // each of the 64 per-thread partials would remove 64*sum(x) after the reduction.
  const fold = L.asum ? " - bitcast<f32>(P.w)" : "";
  return `
@group(0) @binding(0) var<storage,read> x: array<vec4<f32>>;
@group(0) @binding(1) var<storage,read> qw: ${L.arrQw};
@group(0) @binding(2) var<storage,read_write> o: array<f32>;
@group(0) @binding(3) var<uniform> P: vec4<u32>; // x=K, y=N, z=0, w=asumBits
var<workgroup> red: array<f32, 256>;
${dot}
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg:vec3<u32>, @builtin(local_invocation_id) lid:vec3<u32>){
let K=P.x; let nw=K>>4u; let rr=lid.x/64u; let t=lid.x%64u;
let n0=(wg.y*65535u+wg.x)*4u+rr; let n=min(n0, P.y-1u);
let rowW=n*${L.nwExpr}; var acc=0.0;
var w=t;
loop{ if(w>=${L.nwExpr}){break;}
${body}
w=w+${L.stride}; }
red[lid.x]=acc; workgroupBarrier();
var s=32u; loop{ if(s==0u){break;} if(t<s){ red[rr*64u+t]=red[rr*64u+t]+red[rr*64u+t+s]; } workgroupBarrier(); s=s/2u; }
if(t==0u && n0<P.y){ o[n0]=red[rr*64u]${fold}; }
}`;
}
// pseudo-random 2-bit codes fill (non-zero, non-compressible)
// WGSL kernel that overwrites the first P.x u32 words of `d` with a hash of each
// word's index, so the benchmark streams varied, written data from the weight buffer.
// Uniform P: x = number of u32 words to write, y = step between the words a single
// thread writes. In this file the caller passes the total dispatched thread count
// as the step, so thread g covers indices g, g+P.y, g+2*P.y, ...
// NOTE(review): the hash is not strictly guaranteed non-zero for every index; it
// only needs to keep the buffer from being trivially zero/compressible — confirm.
const FILL = `
@group(0) @binding(0) var<storage,read_write> d: array<u32>;
@group(0) @binding(1) var<uniform> P: vec4<u32>; // x=#u32, y=stride
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid:vec3<u32>){
let n=P.x; var i=gid.x;
loop{ if(i>=n){break;} d[i]=(i*2654435761u+1u)^((i<<7u)*40503u); i=i+P.y; }
}`;
// Kernel variants, benchmarked in this order; each becomes one row of the report.
// Tuple fields: [id (status tag), name (table label), dot, body, L] — the last three
// are the inputs passed to kernelWGSL().
const VARIANTS = [
  ["V0", "V0 baseline (scalar load, per-code −1)", DOT_V0, BODY_V0, LOOP_V0],
  ["V1", "V1 wide vec4 load", DOT_V0, BODY_V1, LOOP_V1],
  ["V2", "V2 wide load + bias-fold", DOT_V2, BODY_V2, LOOP_V2],
].map(([id, name, dot, body, L]) => ({ id, name, dot, body, L }));
// Main driver: acquire a WebGPU device, allocate and fill a ternary weight matrix
// sized to ~MODEL_GB, time every entry of VARIANTS, and render the report into #out.
// Progress and fatal errors are reported through say().
(async()=>{
  // Escape text before interpolating it into innerHTML: WGSL compiler messages
  // routinely contain "<...>" (e.g. "vec4<f32>"), which would otherwise be parsed as tags.
  const esc=(s)=>String(s).replace(/[&<>"']/g,(c)=>({"&":"&amp;","<":"&lt;",">":"&gt;",'"':"&quot;","'":"&#39;"})[c]);
  // Decimal gigabyte: the unit of MODEL_GB (0.69 GB) and of the 152 GB/s roofline
  // (152 / 0.69 ≈ 220 tok/s), so all bandwidth and tok/s math below uses it too.
  const GB=1e9;
  try{
    if(!navigator.gpu){ say("✗ No WebGPU. Open in Chrome/Edge.","no"); return; }
    say("requesting adapter…");
    let ad=await navigator.gpu.requestAdapter({powerPreference:"high-performance"});
    if(!ad) ad=await navigator.gpu.requestAdapter();
    if(!ad){ say("✗ no GPU adapter.","no"); return; }
    const L=ad.limits, info=ad.info||{};
    // Request the adapter's own maxima: the WebGPU default binding/buffer limits are
    // far below a ~0.69 GB weight buffer.
    const dev=await ad.requestDevice({requiredLimits:{
      maxStorageBufferBindingSize:L.maxStorageBufferBindingSize, maxBufferSize:L.maxBufferSize,
      maxComputeWorkgroupsPerDimension:L.maxComputeWorkgroupsPerDimension }});
    // Optional chaining so a missing lost-info object cannot throw inside the handler.
    dev.lost.then((i)=>say("✗ device lost: "+(i?.message||i?.reason||""),"no"));
    const K=4096, nw=K/16; // 16 two-bit codes per u32 -> 256 words/row
    // weight bytes: as close to MODEL_GB as the device allows, rounded down to whole rows
    const wantBytes=Math.min(L.maxStorageBufferBindingSize, L.maxBufferSize, Math.round(MODEL_GB*GB));
    const N=Math.floor((wantBytes/4)/nw);
    const qU32=N*nw, wBytes=qU32*4;
    say("allocating "+(wBytes/1e6).toFixed(0)+" MB ternary weights ("+N.toLocaleString()+" rows × "+K+")…");
    const qw=dev.createBuffer({size:wBytes, usage:GPUBufferUsage.STORAGE});
    const xbuf=dev.createBuffer({size:K*4, usage:GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST});
    const obuf=dev.createBuffer({size:N*4, usage:GPUBufferUsage.STORAGE});
    const P=dev.createBuffer({size:16, usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
    // activations: deterministic pseudo-random f32 in (-1, 1), plus their sum for the V2 bias fold
    const xa=new Float32Array(K); let asum=0;
    for(let i=0;i<K;i++){ xa[i]=(Math.sin(i*12.9898)*43758.5453)%1; asum+=xa[i]; }
    dev.queue.writeBuffer(xbuf,0,xa);
    // P = (K, N, 0, bits of f32(asum)); the shader bitcasts P.w back to f32.
    const asumBits=new Uint32Array(new Float32Array([asum]).buffer)[0];
    dev.queue.writeBuffer(P,0,new Uint32Array([K, N, 0, asumBits]));
    // fill weights (non-zero) so reads hit real VRAM
    say("filling weights…");
    {
      const fmod=dev.createShaderModule({code:FILL});
      const fpipe=dev.createComputePipeline({layout:"auto",compute:{module:fmod,entryPoint:"main"}});
      const fp=dev.createBuffer({size:16,usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
      // One dispatch of wgF*256 threads; each thread steps by the total thread count.
      const wgF=Math.min(L.maxComputeWorkgroupsPerDimension,65535);
      dev.queue.writeBuffer(fp,0,new Uint32Array([qU32, wgF*256, 0, 0]));
      const fbg=dev.createBindGroup({layout:fpipe.getBindGroupLayout(0),entries:[{binding:0,resource:{buffer:qw}},{binding:1,resource:{buffer:fp}}]});
      const e=dev.createCommandEncoder(); const p=e.beginComputePass(); p.setPipeline(fpipe); p.setBindGroup(0,fbg); p.dispatchWorkgroups(wgF); p.end();
      dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone();
    }
    // dispatch dims: workgroup = 4 rows; groups beyond 65535 spill into the y dimension
    const groups=Math.ceil(N/4), wgx=Math.min(groups,65535), wgy=Math.ceil(groups/65535);
    const PASSES=8, ITERS=5;
    const rows=[];
    for(const V of VARIANTS){
      say("compiling "+V.id+"…");
      dev.pushErrorScope("validation");
      const mod=dev.createShaderModule({code:kernelWGSL(V.dot,V.body,V.L)});
      const ci=await mod.getCompilationInfo(); const er=ci.messages.filter(m=>m.type==="error");
      if(er.length){ await dev.popErrorScope(); rows.push({V, err:er[0].message}); continue; }
      const pipe=dev.createComputePipeline({layout:"auto",compute:{module:mod,entryPoint:"main"}});
      const bg=dev.createBindGroup({layout:pipe.getBindGroupLayout(0),entries:[
        {binding:0,resource:{buffer:xbuf}},{binding:1,resource:{buffer:qw}},{binding:2,resource:{buffer:obuf}},{binding:3,resource:{buffer:P}}]});
      const scopeErr=await dev.popErrorScope();
      if(scopeErr){ rows.push({V, err:scopeErr.message}); continue; }
      // Encode `passes` back-to-back dispatches in one submission; return wall-clock ms
      // from submit until the queue reports the work done.
      const run=async(passes)=>{
        const e=dev.createCommandEncoder();
        for(let k=0;k<passes;k++){ const p=e.beginComputePass(); p.setPipeline(pipe); p.setBindGroup(0,bg); p.dispatchWorkgroups(wgx,wgy); p.end(); }
        const t0=performance.now(); dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone(); return performance.now()-t0;
      };
      say("timing "+V.id+"…");
      await run(2); // warm
      let bestMs=Infinity;
      for(let k=0;k<ITERS;k++){ bestMs=Math.min(bestMs, await run(PASSES)); }
      const msPass=bestMs/PASSES;                // one full read of the wBytes matrix
      const gbps=(wBytes/GB)/(msPass/1000);      // decimal GB/s, same unit as the roofline
      const tok=gbps/MODEL_GB;                   // tokens/s if each token streams MODEL_GB
      const msPer=1000/tok;                      // ms per token (≈ msPass when wBytes ≈ MODEL_GB)
      const pct=100*tok/ROOF_TOK;
      rows.push({V, msPer, gbps, tok, pct});
    }
    say("done · adapter: "+((info.vendor||"?")+" "+(info.architecture||"")+" "+(info.device||"")).trim(),"ok");
    const best=rows.filter(r=>r.tok).sort((a,b)=>b.tok-a.tok)[0];
    // One table row per variant: either its escaped failure message or its measurements.
    const rowHtml=(r)=>r.err
      ? `<tr><td>${esc(r.V.name)}</td><td colspan=4 class="no">compile/validate failed: ${esc(String(r.err).slice(0,60))}</td></tr>`
      : `<tr><td>${esc(r.V.name)}</td><td class="n">${r.msPer.toFixed(2)}</td><td class="n">${r.gbps.toFixed(0)}</td><td class="n ${r===best?'ok':''}"><b>${r.tok.toFixed(0)}</b></td><td class="n ${r.pct<30?'no':r.pct<70?'warn':'ok'}">${r.pct.toFixed(0)}%</td></tr>`;
    out.innerHTML=`
<div class="card"><span class="k">Test matrix</span><br><span class="big">${N.toLocaleString()} × ${K}</span> &nbsp;·&nbsp; ${(wBytes/GB).toFixed(2)} GB ternary weights · roofline ${ROOF_TOK} tok/s</div>
<div class="card"><table>
<tr><th>variant</th><th class="n">ms/token</th><th class="n">GB/s</th><th class="n">tok/s</th><th class="n">% roofline</th></tr>
${rows.map(rowHtml).join("")}
<tr><td class="k">current shipped kernel (measured)</td><td class="n k">18.4</td><td class="n k">38</td><td class="n k">${CUR_TOK}</td><td class="n k">${(100*CUR_TOK/ROOF_TOK).toFixed(0)}%</td></tr>
</table></div>
${best? `<div class="verdict ok">✓ Best: <b>${esc(best.V.name)}</b> → ${best.tok.toFixed(0)} tok/s (${best.pct.toFixed(0)}% of roofline, ${(best.tok/CUR_TOK).toFixed(1)}× the shipped kernel). ${best.pct>=70? "Bandwidth-bound — ship it." : "Still ALU-bound; next lever = integer dot / fewer converts."}</div>`
  : `<div class="verdict no">All variants failed to run.</div>`}`;
  }catch(e){ say("✗ "+(e&&(e.message||e)), "no"); }
})();
</script></body></html>