q / rb-bench.html
Humuhumu33's picture
register-blocked GEMV lab: B=1/2/4/8/16 rows per group
3653faa verified
Raw
History Blame Contribute Delete
9.92 kB
<!doctype html><html><head><meta charset=utf8><meta name=viewport content="width=device-width,initial-scale=1">
<title>Register-blocked GEMV — your GPU</title>
<style>
:root{--bg:#0a0d13;--panel:#111825;--ink:#e8ebf1;--dim:#8b95a7;--ac:#7c5cff;--ok:#48c26c;--no:#f0616d;--warn:#e0a94a;--line:#1e2836}
*{box-sizing:border-box}body{margin:0;background:var(--bg);color:var(--ink);font:15px/1.6 -apple-system,Segoe UI,Roboto,monospace;padding:24px;max-width:860px;margin:0 auto}
h1{font-size:20px;margin:0 0 4px}.sub{color:var(--dim);font-size:13px;margin:0 0 18px}
.card{font-family:ui-monospace,monospace;padding:14px 16px;border:1px solid var(--line);border-radius:10px;background:var(--panel);margin-bottom:12px;word-break:break-word}
table{width:100%;border-collapse:collapse;font-family:ui-monospace,monospace;font-size:13px}
td,th{padding:6px 8px;border-bottom:1px solid var(--line);text-align:left}th{color:var(--dim);font-weight:600}
td.n{text-align:right;font-variant-numeric:tabular-nums}
.verdict{font-size:16px;font-weight:600;margin-top:14px}.bar{display:inline-block;height:9px;background:var(--ac);border-radius:2px;vertical-align:middle}
/* Status colours. The script tags the status card, table cells and verdict with ok/warn/no;
   without these rules the --ok/--warn/--no variables above were never applied. */
.ok{color:var(--ok)}.warn{color:var(--warn)}.no{color:var(--no)}
</style></head><body>
<h1>Register-blocked <span style="color:var(--ac)">ternary GEMV</span> — your GPU</h1>
<p class="sub">Today's kernel is memory-latency-bound: it reaches only about 45 GB/s of a 152 GB/s roofline. Each 64-thread group here computes <b>B output rows</b>, so each thread issues B independent weight loads per step (more loads in flight to hide latency) while reusing the same activation reads for all B rows. B=1 is today's kernel. Higher B should climb toward the 152 GB/s roofline — until register pressure starts to reduce occupancy.</p>
<div id="status" class="card">starting…</div>
<div id="out"></div>
<script type="module">
// Tiny DOM helpers: `$` wraps querySelector; `st` is the status card, `out` the results container.
const $ = (sel) => document.querySelector(sel);
const st = $("#status");
const out = $("#out");
// MODEL_GB: bytes streamed per token, in GB (tok/s = GB/s ÷ MODEL_GB).
// ROOF: roofline bandwidth in GB/s. ROOF_TOK: the matching tok/s figure shown in the header.
const MODEL_GB = 0.69;
const ROOF = 152;
const ROOF_TOK = 220;
// Show `t` in the status card; `cls` (e.g. "ok" / "no") is appended as an extra class when given.
const say = (t, cls) => {
  st.textContent = t;
  st.className = cls ? "card " + cls : "card";
};
// Surface any promise rejection that escapes the main flow in the status card.
window.addEventListener("unhandledrejection", (ev) => {
  const why = ev.reason && (ev.reason.message || ev.reason);
  say("✗ unhandled: " + why, "no");
});
// WGSL helper that is inlined into every kern(B) variant.
// dot16(word, v): `word` packs 16 two-bit weight codes, lowest bits first. Each code c is decoded
// as c-1, so 0,1,2 → -1,0,+1 (a code of 3 would decode to +2). The 16 decoded weights multiply
// the 16 activations held in the four vec4s x[v], x[v+1], x[v+2], x[v+3], and the scalar sum is
// returned. It reads the module-scope binding `x: array<vec4<f32>>` that kern() declares.
const DOT16=`
fn dot16(word:u32, v:u32) -> f32 {
var s4=vec4<f32>(0.0);
s4=s4+x[v] *(vec4<f32>(f32(word&3u),f32((word>>2u)&3u),f32((word>>4u)&3u),f32((word>>6u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+1u]*(vec4<f32>(f32((word>>8u)&3u),f32((word>>10u)&3u),f32((word>>12u)&3u),f32((word>>14u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+2u]*(vec4<f32>(f32((word>>16u)&3u),f32((word>>18u)&3u),f32((word>>20u)&3u),f32((word>>22u)&3u))-vec4<f32>(1.0));
s4=s4+x[v+3u]*(vec4<f32>(f32((word>>24u)&3u),f32((word>>26u)&3u),f32((word>>28u)&3u),f32((word>>30u)&3u))-vec4<f32>(1.0));
return s4.x+s4.y+s4.z+s4.w;
}`;
// Build the WGSL source of the GEMV kernel with B output rows per 64-thread group.
// A 256-thread workgroup holds 4 such groups; the workgroup id is flattened as wg.y*65535+wg.x
// (the host caps the x-dimension at 65535 to match). In the main loop each thread issues B
// independent weight loads per iteration (memory-level parallelism), and all B of them reuse the
// same activation index v. Afterwards each of the B partial sums is tree-reduced over its 64-lane
// group through workgroup memory, and lane 0 stores the row. Rows past P.y (N) clamp their loads
// to row N-1 and skip the store.
const kern=(B)=>{
  // One snippet per row index 0..B-1, joined with `sep`.
  const perRow=(emit,sep)=>Array.from({length:B},(_,b)=>emit(b)).join(sep);
  // Reduction + store for row b; the leading "" yields the newline that precedes each snippet.
  const reduceRow=(b)=>[
    "",
    `red[lid.x]=acc[${b}]; workgroupBarrier();`,
    "{ var s=32u; loop{ if(s==0u){break;} if(t<s){ red[grp*64u+t]=red[grp*64u+t]+red[grp*64u+t+s]; } workgroupBarrier(); s=s/2u; } }",
    `if(t==0u){ let n=base+${b}u; if(n<P.y){ o[n]=red[grp*64u]; } }`,
    "workgroupBarrier();",
  ].join("\n");
  return [
    "",
    "@group(0) @binding(0) var<storage,read> x: array<vec4<f32>>;",
    "@group(0) @binding(1) var<storage,read> qw: array<u32>;",
    "@group(0) @binding(2) var<storage,read_write> o: array<f32>;",
    "@group(0) @binding(3) var<uniform> P: vec4<u32>; // K, N",
    "var<workgroup> red: array<f32, 256>;",
    DOT16,
    "@compute @workgroup_size(256)",
    "fn main(@builtin(workgroup_id) wg:vec3<u32>, @builtin(local_invocation_id) lid:vec3<u32>){",
    "let K=P.x; let nw=K>>4u; let grp=lid.x/64u; let t=lid.x%64u;",
    `let base=((wg.y*65535u+wg.x)*4u + grp)*${B}u;`,
    `var acc: array<f32, ${B}>;`,
    perRow((b)=>`acc[${b}]=0.0;`," "),
    "var w=t;",
    "loop{ if(w>=nw){break;} let v=w<<2u;",
    perRow((b)=>`let word${b}=qw[min(base+${b}u,P.y-1u)*nw + w];`,"\n "),
    perRow((b)=>`acc[${b}]=acc[${b}]+dot16(word${b}, v);`,"\n "),
    "w=w+64u; }",
    perRow(reduceRow,""),
    "}",
  ].join("\n");
};
// WGSL fill kernel: writes an integer hash of each index i into d[i] for i < P.x, giving the
// weight buffer deterministic pseudo-random bit patterns. The buffer is filled on the GPU
// because it is created with STORAGE usage only. Each thread starts at its global id and
// advances by P.y; the host sets P.y to the total number of threads it dispatches.
const FILL=`
@group(0) @binding(0) var<storage,read_write> d: array<u32>;
@group(0) @binding(1) var<uniform> P: vec4<u32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid:vec3<u32>){ let n=P.x; var i=gid.x; loop{ if(i>=n){break;} d[i]=(i*2654435761u+1u)^((i<<7u)*40503u); i=i+P.y; } }`;
// Rows-per-group variants to benchmark: powers of two from 1 (today's kernel) through 16.
const BS=Array.from({length:5},(_,i)=>2**i);
// Bytes per GB for every size/bandwidth figure below. Decimal (1e9), so measured GB/s is in the
// same unit as MODEL_GB (the weight buffer is sized as MODEL_GB*1e9) and as ROOF/ROOF_TOK
// (ROOF_TOK = 220 ≈ ROOF/MODEL_GB = 152/0.69). Dividing by 2^30 instead mixed GiB with GB and
// understated GB/s, tok/s and % of roofline by about 7%.
const GB_BYTES=1e9;

// Escape text before interpolating it into innerHTML. WGSL/WebGPU error messages routinely
// contain type names such as `vec4<f32>`, which the HTML parser would otherwise eat as tags.
const escapeHtml=s=>String(s).replace(/[&<>"']/g,c=>({"&":"&amp;","<":"&lt;",">":"&gt;",'"':"&quot;","'":"&#39;"})[c]);

// Render the results table and verdict into #out.
// rows: one entry per B, either {B, err} (failed variant) or {B, gbps, tok, pct};
// N×K is the weight matrix shape and wBytes its size in bytes.
function renderResults(rows,N,K,wBytes){
  const ok=rows.filter(r=>r.gbps);
  // Fastest variant; on ties the earlier (smaller) B wins.
  const best=ok.reduce((a,r)=>(!a||r.gbps>a.gbps)?r:a,null);
  const b1=ok.find(r=>r.B===1);
  const maxG=Math.max(...ok.map(r=>r.gbps));
  const pctClass=p=>p<40?'no':p<70?'warn':'ok';
  const rowHtml=r=>r.err
    ? `<tr><td>B=${r.B}</td><td colspan=4 class="no" title="${escapeHtml(r.err)}">failed: ${escapeHtml(r.err.slice(0,50))}</td></tr>`
    : `<tr><td>B=${r.B}${r.B===1?" (today)":""}</td><td class="n">${r.gbps.toFixed(0)}</td><td class="n ${r===best?'ok':''}"><b>${r.tok.toFixed(0)}</b></td><td class="n ${pctClass(r.pct)}">${r.pct.toFixed(0)}%</td><td><span class="bar" style="width:${Math.round(150*r.gbps/maxG)}px"></span></td></tr>`;
  let verdict=`<div class="verdict no">All variants failed.</div>`;
  if(best){
    const cls=best.pct>=70?'ok':best.pct>=45?'warn':'no';
    // Without a working B=1 run there is no baseline; say so instead of printing "0.00×".
    const vsToday=b1?`<b>${(best.gbps/b1.gbps).toFixed(2)}×</b> today's kernel`:`no B=1 baseline`;
    const advice=best.pct>=70?"Memory-latency hidden — port B="+best.B+" into mmT2/mmT2KK."
      :best.pct>=45?"Real gain; register pressure likely caps it here — port B="+best.B+"."
      :"Little movement — the pattern needs a deeper restructure (shared-mem staging / wider loads).";
    verdict=`<div class="verdict ${cls}">Best: <b>B=${best.B}</b> → ${best.gbps.toFixed(0)} GB/s (${best.tok.toFixed(0)} tok/s, ${best.pct.toFixed(0)}% of roofline, ${vsToday}). ${advice}</div>`;
  }
  out.innerHTML=`
<div class="card"><span style="color:var(--dim)">Test matrix</span> &nbsp; ${N.toLocaleString()} × ${K} · ${(wBytes/GB_BYTES).toFixed(2)} GB · roofline ${ROOF} GB/s (${ROOF_TOK} tok/s)</div>
<div class="card"><table>
<tr><th>rows/group (B)</th><th class="n">GB/s</th><th class="n">tok/s-equiv</th><th class="n">% roofline</th><th></th></tr>
${rows.map(rowHtml).join("")}
</table></div>
${verdict}`;
}

// Entry point: acquire a device, fill a ~MODEL_GB packed ternary weight matrix on the GPU, time
// kern(B) for every B in BS, then render the results. Any thrown error ends up in the status card.
(async()=>{
  try{
    if(!navigator.gpu){ say("✗ No WebGPU. Open in Chrome/Edge.","no"); return; }
    say("requesting adapter…");
    // Prefer the high-performance adapter; fall back to whatever the browser offers.
    const ad=(await navigator.gpu.requestAdapter({powerPreference:"high-performance"}))||(await navigator.gpu.requestAdapter());
    if(!ad){ say("✗ no GPU adapter.","no"); return; }
    const L=ad.limits, info=ad.info||{};
    // Request the adapter's maxima: WebGPU's default buffer limits are far below MODEL_GB.
    const dev=await ad.requestDevice({requiredLimits:{maxStorageBufferBindingSize:L.maxStorageBufferBindingSize, maxBufferSize:L.maxBufferSize, maxComputeWorkgroupsPerDimension:L.maxComputeWorkgroupsPerDimension}});
    dev.lost.then(i=>say("✗ device lost: "+((i&&(i.message||i.reason))||""),"no"));

    // Weight matrix: N rows × K weights, 16 two-bit codes per u32 (nw words per row), sized to
    // about MODEL_GB but clamped to what one buffer / storage binding may hold.
    const K=4096, nw=K/16;
    const wantBytes=Math.min(L.maxStorageBufferBindingSize, L.maxBufferSize, Math.round(MODEL_GB*GB_BYTES));
    const N=Math.floor((wantBytes/4)/nw), qU32=N*nw, wBytes=qU32*4;
    say("allocating "+(wBytes/1e6).toFixed(0)+" MB…");
    const qw=dev.createBuffer({size:wBytes, usage:GPUBufferUsage.STORAGE});
    const xbuf=dev.createBuffer({size:K*4, usage:GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST});
    const obuf=dev.createBuffer({size:N*4, usage:GPUBufferUsage.STORAGE});
    const P=dev.createBuffer({size:16, usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
    const xa=new Float32Array(K); for(let i=0;i<K;i++) xa[i]=(Math.sin(i*12.9898)*43758.5453)%1;
    dev.queue.writeBuffer(xbuf,0,xa); dev.queue.writeBuffer(P,0,new Uint32Array([K,N,0,0]));

    say("filling weights…");
    {
      const fpipe=dev.createComputePipeline({layout:"auto",compute:{module:dev.createShaderModule({code:FILL}),entryPoint:"main"}});
      const fp=dev.createBuffer({size:16,usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST});
      const wgF=Math.min(L.maxComputeWorkgroupsPerDimension,65535);
      // FILL params: P.x = words to write, P.y = total threads dispatched (the loop stride).
      dev.queue.writeBuffer(fp,0,new Uint32Array([qU32, wgF*256, 0, 0]));
      const fbg=dev.createBindGroup({layout:fpipe.getBindGroupLayout(0),entries:[{binding:0,resource:{buffer:qw}},{binding:1,resource:{buffer:fp}}]});
      const e=dev.createCommandEncoder(); const p=e.beginComputePass(); p.setPipeline(fpipe); p.setBindGroup(0,fbg); p.dispatchWorkgroups(wgF); p.end();
      dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone();
      fp.destroy(); // one-shot uniform; the fill has completed
    }

    // Compile and bind kern(B). Resolves to {pipe,bg}, or {err} holding the first error message.
    // An "internal" scope sits under the "validation" one, so an implementation-side pipeline
    // failure is reported on this variant's row instead of escaping as an uncaptured error.
    const build=async B=>{
      dev.pushErrorScope("internal");
      dev.pushErrorScope("validation");
      const mod=dev.createShaderModule({code:kern(B)});
      const compileErr=(await mod.getCompilationInfo()).messages.find(m=>m.type==="error");
      let pipe=null, bg=null;
      if(!compileErr){
        pipe=dev.createComputePipeline({layout:"auto",compute:{module:mod,entryPoint:"main"}});
        bg=dev.createBindGroup({layout:pipe.getBindGroupLayout(0),entries:[{binding:0,resource:{buffer:xbuf}},{binding:1,resource:{buffer:qw}},{binding:2,resource:{buffer:obuf}},{binding:3,resource:{buffer:P}}]});
      }
      // Scopes pop innermost first (validation, then internal); always pop both to stay balanced.
      const validationErr=await dev.popErrorScope();
      const internalErr=await dev.popErrorScope();
      const err=compileErr||validationErr||internalErr;
      // A non-empty message keeps the row classified as failed when rendering.
      return err? {err:err.message||"unknown error"} : {pipe,bg};
    };

    const PASSES=10, ITERS=5, rows=[];
    for(const B of BS){
      say("compiling B="+B+"…");
      const built=await build(B);
      if(built.err){ rows.push({B,err:built.err}); continue; }
      const {pipe,bg}=built;
      // 4 groups × B rows per workgroup. The grid is split over x/y with x capped at 65535,
      // the same constant kern() uses to flatten workgroup_id.
      const groups=Math.ceil(N/(4*B)), wgx=Math.min(groups,65535), wgy=Math.ceil(groups/65535);
      // Record `passes` dispatches in one submission; returns wall-clock ms until the queue drains.
      const run=async passes=>{
        const e=dev.createCommandEncoder();
        for(let k=0;k<passes;k++){ const p=e.beginComputePass(); p.setPipeline(pipe); p.setBindGroup(0,bg); p.dispatchWorkgroups(wgx,wgy); p.end(); }
        const t0=performance.now(); dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone();
        return performance.now()-t0;
      };
      say("timing B="+B+"…");
      await run(2); // warm-up, discarded
      let best=Infinity;
      for(let k=0;k<ITERS;k++) best=Math.min(best, await run(PASSES));
      // Only the weight bytes are counted toward bandwidth.
      const gbps=(wBytes/GB_BYTES)/((best/PASSES)/1000);
      rows.push({B, gbps, tok:gbps/MODEL_GB, pct:100*gbps/ROOF});
    }
    say("done · adapter: "+((info.vendor||"?")+" "+(info.architecture||"")+" "+(info.device||"")).trim(),"ok");
    renderResults(rows,N,K,wBytes);
  }catch(e){ say("✗ "+(e&&(e.message||e)), "no"); }
})();
</script></body></html>