Spaces:
Running
Running
| <html><head><meta charset=utf8><meta name=viewport content="width=device-width,initial-scale=1"> | |
| <title>Register-blocked GEMV — your GPU</title> | |
| <style> | |
| /* Palette: --ok / --warn / --no are the status colours consumed by the .ok / .warn / .no rules at the end of this sheet. */ | |
| :root{--bg:#0a0d13;--panel:#111825;--ink:#e8ebf1;--dim:#8b95a7;--ac:#7c5cff;--ok:#48c26c;--no:#f0616d;--warn:#e0a94a;--line:#1e2836} | |
| *{box-sizing:border-box}body{margin:0;background:var(--bg);color:var(--ink);font:15px/1.6 -apple-system,Segoe UI,Roboto,monospace;padding:24px;max-width:860px;margin:0 auto} | |
| h1{font-size:20px;margin:0 0 4px}.sub{color:var(--dim);font-size:13px;margin:0 0 18px} | |
| .card{font-family:ui-monospace,monospace;padding:14px 16px;border:1px solid var(--line);border-radius:10px;background:var(--panel);margin-bottom:12px;word-break:break-word} | |
| table{width:100%;border-collapse:collapse;font-family:ui-monospace,monospace;font-size:13px} | |
| td,th{padding:6px 8px;border-bottom:1px solid var(--line);text-align:left}th{color:var(--dim);font-weight:600} | |
| /* Numeric columns: header cells carry class "n" too, so align them with their data cells. */ | |
| td.n,th.n{text-align:right;font-variant-numeric:tabular-nums} | |
| .verdict{font-size:16px;font-weight:600;margin-top:14px}.bar{display:inline-block;height:9px;background:var(--ac);border-radius:2px;vertical-align:middle} | |
| /* Status colours: the script puts ok / warn / no on the status card, on table cells and on the verdict line. */ | |
| .ok{color:var(--ok)}.warn{color:var(--warn)}.no{color:var(--no)} | |
| </style></head><body> | |
| <h1>Register-blocked <span style="color:var(--ac)">ternary GEMV</span> — your GPU</h1> | |
| <p class="sub">The kernel is memory-latency-bound (45 of 152 GB/s). Each 64-thread group here computes <b>B output rows</b>, so each thread issues B independent weight loads per step (more loads in flight to hide latency) while sharing the activation reads. B=1 is today's kernel. Higher B should climb toward the 152 GB/s roofline — until register pressure drops occupancy.</p> | |
| <!-- #status receives progress and error text from the script; #out receives the results markup once the run finishes. --> | |
| <div id="status" class="card">starting…</div> | |
| <div id="out"></div> | |
| <script type="module"> | |
| // DOM shorthands: $ wraps querySelector; st is the #status card, out is the #out results container. | |
| const $=s=>document.querySelector(s), st=$("#status"), out=$("#out"); | |
| // MODEL_GB sizes the weight buffer and converts GB/s into tok/s; ROOF (GB/s) is the denominator of "% roofline"; ROOF_TOK is only printed. | |
| const MODEL_GB=0.69, ROOF=152, ROOF_TOK=220; | |
| // Replace the status card's text with t; cls, when given, is added as an extra class next to "card". | |
| const say=(t,cls)=>{ st.textContent=t; st.className="card"+(cls?" "+cls:""); }; | |
| // Report promise rejections that nothing else handled in the status card. | |
| window.addEventListener("unhandledrejection",e=>say("✗ unhandled: "+(e.reason&&(e.reason.message||e.reason)),"no")); | |
| // WGSL helper source, spliced into each kernel built by kern(). | |
| // dot16 reads `word` as sixteen 2-bit fields, lowest bits first, and maps each field value f to f-1. | |
| // Fields 0-3 multiply x[v], 4-7 x[v+1u], 8-11 x[v+2u], 12-15 x[v+3u] (x is the vec4 array the kernel declares); | |
| // the return value is the sum of all sixteen products. | |
| const DOT16=` | |
| fn dot16(word:u32, v:u32) -> f32 { | |
| var s4=vec4<f32>(0.0); | |
| s4=s4+x[v] *(vec4<f32>(f32(word&3u),f32((word>>2u)&3u),f32((word>>4u)&3u),f32((word>>6u)&3u))-vec4<f32>(1.0)); | |
| s4=s4+x[v+1u]*(vec4<f32>(f32((word>>8u)&3u),f32((word>>10u)&3u),f32((word>>12u)&3u),f32((word>>14u)&3u))-vec4<f32>(1.0)); | |
| s4=s4+x[v+2u]*(vec4<f32>(f32((word>>16u)&3u),f32((word>>18u)&3u),f32((word>>20u)&3u),f32((word>>22u)&3u))-vec4<f32>(1.0)); | |
| s4=s4+x[v+3u]*(vec4<f32>(f32((word>>24u)&3u),f32((word>>26u)&3u),f32((word>>28u)&3u),f32((word>>30u)&3u))-vec4<f32>(1.0)); | |
| return s4.x+s4.y+s4.z+s4.w; | |
| }`; | |
| // kern(B) returns the WGSL source of the GEMV kernel for blocking factor B. | |
| // B output rows per 64-thread group, 4 groups per workgroup (256 threads). Each thread does B independent | |
| // weight loads/iteration (memory-level parallelism), reusing the shared activation reads. | |
| // Uniform P carries K in .x and N in .y. Row indices past N are clamped for the weight reads | |
| // (min(..., P.y-1u)) and skipped for the output write (n<P.y), so surplus groups are harmless. | |
| // Each of the B accumulators is reduced separately through `red`, 64 slots per group, halving from s=32. | |
| // The 65535u in `base` flattens (wg.x, wg.y) into one group index; it has to equal the x-dimension cap | |
| // the host uses when it splits the dispatch across two dimensions. | |
| const kern=(B)=>` | |
| @group(0) @binding(0) var<storage,read> x: array<vec4<f32>>; | |
| @group(0) @binding(1) var<storage,read> qw: array<u32>; | |
| @group(0) @binding(2) var<storage,read_write> o: array<f32>; | |
| @group(0) @binding(3) var<uniform> P: vec4<u32>; // K, N | |
| var<workgroup> red: array<f32, 256>; | |
| ${DOT16} | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(workgroup_id) wg:vec3<u32>, @builtin(local_invocation_id) lid:vec3<u32>){ | |
| let K=P.x; let nw=K>>4u; let grp=lid.x/64u; let t=lid.x%64u; | |
| let base=((wg.y*65535u+wg.x)*4u + grp)*${B}u; | |
| var acc: array<f32, ${B}>; | |
| ${Array.from({length:B},(_,b)=>`acc[${b}]=0.0;`).join(" ")} | |
| var w=t; | |
| loop{ if(w>=nw){break;} let v=w<<2u; | |
| ${Array.from({length:B},(_,b)=>`let word${b}=qw[min(base+${b}u,P.y-1u)*nw + w];`).join("\n ")} | |
| ${Array.from({length:B},(_,b)=>`acc[${b}]=acc[${b}]+dot16(word${b}, v);`).join("\n ")} | |
| w=w+64u; } | |
| ${Array.from({length:B},(_,b)=>` | |
| red[lid.x]=acc[${b}]; workgroupBarrier(); | |
| { var s=32u; loop{ if(s==0u){break;} if(t<s){ red[grp*64u+t]=red[grp*64u+t]+red[grp*64u+t+s]; } workgroupBarrier(); s=s/2u; } } | |
| if(t==0u){ let n=base+${b}u; if(n<P.y){ o[n]=red[grp*64u]; } } | |
| workgroupBarrier();`).join("")} | |
| }`; | |
| // WGSL source of the weight-fill kernel: writes an integer hash of the index into d[i] for every i < P.x. | |
| // Each invocation starts at its global x id and advances by P.y, so P.y must be the total number of invocations launched. | |
| const FILL=` | |
| @group(0) @binding(0) var<storage,read_write> d: array<u32>; | |
| @group(0) @binding(1) var<uniform> P: vec4<u32>; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid:vec3<u32>){ let n=P.x; var i=gid.x; loop{ if(i>=n){break;} d[i]=(i*2654435761u+1u)^((i<<7u)*40503u); i=i+P.y; } }`; | |
| const BS=[1,2,4,8,16]; | |
| (async()=>{ | |
| try{ | |
| if(!navigator.gpu){ say("✗ No WebGPU. Open in Chrome/Edge.","no"); return; } | |
| say("requesting adapter…"); | |
| let ad=await navigator.gpu.requestAdapter({powerPreference:"high-performance"}); if(!ad) ad=await navigator.gpu.requestAdapter(); | |
| if(!ad){ say("✗ no GPU adapter.","no"); return; } | |
| const L=ad.limits, info=ad.info||{}; | |
| const dev=await ad.requestDevice({requiredLimits:{maxStorageBufferBindingSize:L.maxStorageBufferBindingSize, maxBufferSize:L.maxBufferSize, maxComputeWorkgroupsPerDimension:L.maxComputeWorkgroupsPerDimension}}); | |
| dev.lost.then(i=>say("✗ device lost: "+(i&&i.message||i.reason||""),"no")); | |
| const K=4096, nw=K/16; | |
| const wantBytes=Math.min(L.maxStorageBufferBindingSize, L.maxBufferSize, Math.round(MODEL_GB*1e9)); | |
| const N=Math.floor((wantBytes/4)/nw), qU32=N*nw, wBytes=qU32*4; | |
| say("allocating "+(wBytes/1048576).toFixed(0)+" MB…"); | |
| const qw=dev.createBuffer({size:wBytes, usage:GPUBufferUsage.STORAGE}); | |
| const xbuf=dev.createBuffer({size:K*4, usage:GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST}); | |
| const obuf=dev.createBuffer({size:N*4, usage:GPUBufferUsage.STORAGE}); | |
| const P=dev.createBuffer({size:16, usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST}); | |
| const xa=new Float32Array(K); for(let i=0;i<K;i++) xa[i]=(Math.sin(i*12.9898)*43758.5453)%1; | |
| dev.queue.writeBuffer(xbuf,0,xa); dev.queue.writeBuffer(P,0,new Uint32Array([K,N,0,0])); | |
| say("filling weights…"); | |
| { const fmod=dev.createShaderModule({code:FILL}); const fpipe=dev.createComputePipeline({layout:"auto",compute:{module:fmod,entryPoint:"main"}}); | |
| const fp=dev.createBuffer({size:16,usage:GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST}); const wgF=Math.min(L.maxComputeWorkgroupsPerDimension,65535); | |
| dev.queue.writeBuffer(fp,0,new Uint32Array([qU32, wgF*256, 0, 0])); | |
| const fbg=dev.createBindGroup({layout:fpipe.getBindGroupLayout(0),entries:[{binding:0,resource:{buffer:qw}},{binding:1,resource:{buffer:fp}}]}); | |
| const e=dev.createCommandEncoder(); const p=e.beginComputePass(); p.setPipeline(fpipe); p.setBindGroup(0,fbg); p.dispatchWorkgroups(wgF); p.end(); | |
| dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone(); } | |
| const PASSES=10, ITERS=5, rows=[]; | |
| for(const B of BS){ | |
| say("compiling B="+B+"…"); | |
| dev.pushErrorScope("validation"); | |
| const mod=dev.createShaderModule({code:kern(B)}); | |
| const ci=await mod.getCompilationInfo(); const er=ci.messages.filter(m=>m.type==="error"); | |
| if(er.length){ await dev.popErrorScope(); rows.push({B,err:er[0].message}); continue; } | |
| const pipe=dev.createComputePipeline({layout:"auto",compute:{module:mod,entryPoint:"main"}}); | |
| const bg=dev.createBindGroup({layout:pipe.getBindGroupLayout(0),entries:[{binding:0,resource:{buffer:xbuf}},{binding:1,resource:{buffer:qw}},{binding:2,resource:{buffer:obuf}},{binding:3,resource:{buffer:P}}]}); | |
| const se=await dev.popErrorScope(); if(se){ rows.push({B,err:se.message}); continue; } | |
| const groups=Math.ceil(N/(4*B)), wgx=Math.min(groups,65535), wgy=Math.ceil(groups/65535); | |
| async function run(passes){ const e=dev.createCommandEncoder(); for(let k=0;k<passes;k++){ const p=e.beginComputePass(); p.setPipeline(pipe); p.setBindGroup(0,bg); p.dispatchWorkgroups(wgx,wgy); p.end(); } const t0=performance.now(); dev.queue.submit([e.finish()]); await dev.queue.onSubmittedWorkDone(); return performance.now()-t0; } | |
| say("timing B="+B+"…"); | |
| await run(2); let best=1e9; for(let k=0;k<ITERS;k++){ best=Math.min(best, await run(PASSES)); } | |
| const gbps=(wBytes/1073741824)/((best/PASSES)/1000); | |
| rows.push({B, gbps, tok:gbps/MODEL_GB, pct:100*gbps/ROOF}); | |
| } | |
| say("done · adapter: "+((info.vendor||"?")+" "+(info.architecture||"")+" "+(info.device||"")).trim(),"ok"); | |
| const ok=rows.filter(r=>r.gbps); const best=ok.slice().sort((a,b)=>b.gbps-a.gbps)[0]; const b1=ok.find(r=>r.B===1); | |
| const maxG=Math.max(...ok.map(r=>r.gbps)); | |
| const gain=(best&&b1)? best.gbps/b1.gbps : 0; | |
| out.innerHTML=` | |
| <div class="card"><span style="color:var(--dim)">Test matrix</span> ${N.toLocaleString()} × ${K} · ${(wBytes/1073741824).toFixed(2)} GB · roofline ${ROOF} GB/s (${ROOF_TOK} tok/s)</div> | |
| <div class="card"><table> | |
| <tr><th>rows/group (B)</th><th class="n">GB/s</th><th class="n">tok/s-equiv</th><th class="n">% roofline</th><th></th></tr> | |
| ${rows.map(r=>r.err | |
| ? `<tr><td>B=${r.B}</td><td colspan=4 class="no">failed: ${r.err.slice(0,50)}</td></tr>` | |
| : `<tr><td>B=${r.B}${r.B===1?" (today)":""}</td><td class="n">${r.gbps.toFixed(0)}</td><td class="n ${r===best?'ok':''}"><b>${r.tok.toFixed(0)}</b></td><td class="n ${r.pct<40?'no':r.pct<70?'warn':'ok'}">${r.pct.toFixed(0)}%</td><td><span class="bar" style="width:${Math.round(150*r.gbps/maxG)}px"></span></td></tr>`).join("")} | |
| </table></div> | |
| ${best? `<div class="verdict ${best.pct>=70?'ok':best.pct>=45?'warn':'no'}">Best: <b>B=${best.B}</b> → ${best.gbps.toFixed(0)} GB/s (${best.tok.toFixed(0)} tok/s, ${best.pct.toFixed(0)}% of roofline, <b>${gain.toFixed(2)}×</b> today's kernel). ${best.pct>=70?"Memory-latency hidden — port B="+best.B+" into mmT2/mmT2KK.":best.pct>=45?"Real gain; register pressure likely caps it here — port B="+best.B+".":"Little movement — the pattern needs a deeper restructure (shared-mem staging / wider loads)."}</div>` : `<div class="verdict no">All variants failed.</div>`}`; | |
| }catch(e){ say("✗ "+(e&&(e.message||e)), "no"); } | |
| })(); | |
| </script></body></html> | |