magenta-rt-stack / webgpu-kernel-poc.html
multimodalart's picture
multimodalart HF Staff
Add custom WGSL sink-dot benchmark
db22e73 verified
<!doctype html>
<meta charset="utf-8">
<title>Magenta RT sink-dot WebGPU PoC</title>
<style>
body{font:15px system-ui;max-width:760px;margin:40px auto;padding:0 20px}
button{padding:10px 16px;font:inherit} pre{white-space:pre-wrap;background:#f5f5f5;padding:16px}
</style>
<h1>Magenta RT custom WebGPU kernel PoC</h1>
<p>Fuses the 24 attention sink-dot Einsums in one WGSL dispatch.</p>
<button id="run">Run benchmark</button>
<pre id="out">idle</pre>
<script type="module">
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.26.0/dist/ort.webgpu.min.mjs";
// onnxruntime-web runtime settings: a single WASM thread and a request for
// the high-performance GPU adapter.
Object.assign(ort.env.wasm,{numThreads:1});
Object.assign(ort.env.webgpu,{powerPreference:"high-performance"});

// Problem shape. COUNT is the number of q_<i> inputs fed to the ONNX model;
// HEADS and DIM are presumably attention heads and per-head width (confirm
// against the exported model). ITEMS is the total number of dot products.
const COUNT=24;
const HEADS=8;
const DIM=128;
const ITEMS=COUNT*HEADS;

// Output pane plus a helper that appends one line of text to it.
const out=document.getElementById("out");
const log=(line)=>{
  // Only insert a newline when the pane already holds text.
  const separator=out.textContent?"\n":"";
  return (out.textContent=out.textContent+separator+line);
};
/**
 * Rounds a byte count up to the next multiple of 16, with a floor of 16.
 * Used when sizing GPU buffers in this file.
 * @param {number} byteCount - requested size in bytes
 * @returns {number} padded size in bytes (never less than 16)
 */
const align=(byteCount)=>{
  const rounded=16*Math.ceil(byteCount/16);
  return rounded<16?16:rounded;
};
// WGSL source for the fused kernel. ITEMS and DIM are interpolated into the
// text as u32 literals when this template is evaluated, so the shader is
// specialised to the constants defined above.
//
// Bindings (group 0): 0 = queries (read), 1 = sinks (read), 2 = outputs
// (read_write); each is a storage buffer wrapping a runtime-sized f32 array.
//
// Each invocation handles one item (gid.x): it accumulates the dot product of
// DIM consecutive floats starting at item*DIM in `queries` and `sinks`, and
// stores the scalar at outputs[item]. Invocations with item >= ITEMS return
// immediately, which covers the tail of the last 64-wide workgroup.
//
// NOTE: no WGSL comments are placed inside the template on purpose — they
// would change the string handed to createShaderModule.
const shader=`
struct Data { values: array<f32> }
@group(0) @binding(0) var<storage, read> queries: Data;
@group(0) @binding(1) var<storage, read> sinks: Data;
@group(0) @binding(2) var<storage, read_write> outputs: Data;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let item=gid.x;
if (item>=${ITEMS}u) { return; }
let base=item*${DIM}u;
var sum=0.0;
for (var d=0u; d<${DIM}u; d=d+1u) {
sum=sum+queries.values[base+d]*sinks.values[base+d];
}
outputs.values[item]=sum;
}`;
/**
 * Copies the first `bytes` bytes of a GPU buffer back to the CPU and returns
 * them as float32 values.
 *
 * A temporary staging buffer (COPY_DST | MAP_READ) is created, filled with a
 * buffer-to-buffer copy, mapped for reading, and its contents are copied into
 * a fresh ArrayBuffer so the result stays valid after the staging buffer is
 * unmapped.
 *
 * @param {GPUDevice} device - device that owns `buffer`
 * @param {GPUBuffer} buffer - source buffer; the caller must have created it with COPY_SRC usage
 * @param {number} bytes - number of bytes to read; must be a multiple of 4,
 *   otherwise the Float32Array construction throws a RangeError
 * @returns {Promise<Float32Array>} a CPU-side copy of the data
 * @throws whatever the copy/submit/mapAsync steps throw or reject with; the
 *   staging buffer is destroyed in that case too
 */
async function download(device,buffer,bytes){
  const staging=device.createBuffer({
    size:align(bytes),
    usage:GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ
  });
  try{
    const enc=device.createCommandEncoder();
    enc.copyBufferToBuffer(buffer,0,staging,0,bytes);
    device.queue.submit([enc.finish()]);
    await staging.mapAsync(GPUMapMode.READ);
    // slice() detaches the result from the mapped range, which becomes
    // unusable once the buffer is unmapped.
    const data=new Float32Array(staging.getMappedRange().slice(0,bytes));
    staging.unmap();
    return data;
  }finally{
    // Runs on both success and failure so a rejected mapAsync (or a throwing
    // copy/submit) cannot leak the staging allocation.
    staging.destroy();
  }
}
/**
 * Click handler for the "Run benchmark" button.
 *
 * Steps:
 *  1. Fetch the ONNX reference model and the raw sink weights.
 *  2. Create an onnxruntime-web WebGPU session and reuse its GPUDevice.
 *  3. Upload deterministic pseudo-random queries plus the sink weights, then
 *     build the custom compute pipeline from `shader`.
 *  4. Compare the custom kernel's output with the ORT outputs (max abs error).
 *  5. Warm up both paths, time N iterations of each, and log the results.
 *
 * Any failure is reported in the output pane and on the console. The ORT
 * session and every GPU buffer created here are released when the run ends,
 * and the button is disabled for the duration so runs cannot overlap.
 */
document.querySelector("#run").onclick=async()=>{
  const button=document.querySelector("#run");
  // A second click mid-run would interleave GPU work and skew the timings.
  button.disabled=true;
  // Every buffer allocated through makeBuffer is tracked for cleanup.
  const gpuBuffers=[];
  let session;
  try{
    out.textContent="loading ONNX reference and real sink weights...";
    if(!navigator.gpu)throw new Error("WebGPU is not available in this browser");

    // fetch() resolves on HTTP errors too, so reject explicitly; otherwise an
    // error page would be handed to ORT / uploaded as weights.
    const fetchBytes=async(url)=>{
      const response=await fetch(url);
      if(!response.ok)throw new Error(`failed to fetch ${url}: HTTP ${response.status}`);
      return response.arrayBuffer();
    };
    const [model,sinkBytes]=await Promise.all([
      fetchBytes("sink_dot_24.onnx"),
      fetchBytes("sink_dot_24.onnx.sinks.bin")
    ]);
    // The shader reads ITEMS*DIM floats from the sink buffer; a shorter file
    // would cause out-of-bounds reads and meaningless results.
    const requiredSinkBytes=ITEMS*DIM*Float32Array.BYTES_PER_ELEMENT;
    if(sinkBytes.byteLength<requiredSinkBytes){
      throw new Error(`sink weights too small: got ${sinkBytes.byteLength} bytes, need ${requiredSinkBytes}`);
    }

    session=await ort.InferenceSession.create(new Uint8Array(model),{
      executionProviders:["webgpu"],preferredOutputLocation:"gpu-buffer"
    });
    // Share ORT's device so the custom buffers can be fed to the session.
    const device=ort.env.webgpu.device;
    const usage=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
    const makeBuffer=(bytes,data)=>{
      const b=device.createBuffer({size:align(bytes),usage});
      gpuBuffers.push(b);
      if(data)device.queue.writeBuffer(b,0,data);
      return b;
    };

    // Deterministic xorshift32 stream mapped to [-0.1, 0.1] so every run
    // benchmarks the same inputs.
    const queries=new Float32Array(ITEMS*DIM);
    let state=123456789;
    for(let i=0;i<queries.length;i++){
      state^=state<<13;state^=state>>>17;state^=state<<5;
      queries[i]=((state>>>0)/4294967295-.5)*.2;
    }
    const sinks=new Float32Array(sinkBytes);

    // Custom path: one packed query buffer, one packed sink buffer, one output.
    const qPacked=makeBuffer(queries.byteLength,queries);
    const sinkPacked=makeBuffer(sinks.byteLength,sinks);
    const customOut=makeBuffer(ITEMS*4);
    const module=device.createShaderModule({code:shader});
    const pipeline=await device.createComputePipelineAsync({layout:"auto",compute:{module,entryPoint:"main"}});
    const bindGroup=device.createBindGroup({layout:pipeline.getBindGroupLayout(0),entries:[
      {binding:0,resource:{buffer:qPacked}},
      {binding:1,resource:{buffer:sinkPacked}},
      {binding:2,resource:{buffer:customOut}}
    ]});
    const runCustom=()=>{
      const enc=device.createCommandEncoder();
      const pass=enc.beginComputePass();
      pass.setPipeline(pipeline);pass.setBindGroup(0,bindGroup);
      // 64 must match @workgroup_size in the shader.
      pass.dispatchWorkgroups(Math.ceil(ITEMS/64));pass.end();
      device.queue.submit([enc.finish()]);
    };

    // ORT path: the same query data split into COUNT inputs of [1,HEADS,1,DIM].
    const feeds={};
    for(let i=0;i<COUNT;i++){
      const slice=queries.subarray(i*HEADS*DIM,(i+1)*HEADS*DIM);
      const b=makeBuffer(slice.byteLength,slice);
      feeds[`q_${i}`]=ort.Tensor.fromGpuBuffer(b,{dataType:"float32",dims:[1,HEADS,1,DIM]});
    }
    const runOrt=()=>session.run(feeds);

    // Correctness: both paths must agree element-wise.
    runCustom();
    const customValues=await download(device,customOut,ITEMS*4);
    const reference=await runOrt();
    const ortValues=new Float32Array(ITEMS);
    for(let i=0;i<COUNT;i++){
      const values=await reference[`out_${i}`].getData();
      ortValues.set(values,i*HEADS);
      reference[`out_${i}`].dispose();
    }
    let maxError=0;
    for(let i=0;i<ITEMS;i++)maxError=Math.max(maxError,Math.abs(customValues[i]-ortValues[i]));
    log(`correctness max abs error: ${maxError.toExponential(3)}`);

    // Warm-up so pipeline compilation and first-run costs are not timed.
    for(let i=0;i<20;i++){
      runCustom();
      const result=await runOrt();
      for(const value of Object.values(result))value.dispose();
    }
    await device.queue.onSubmittedWorkDone();

    const N=500;
    // Custom path: submit all N dispatches, then wait once for the queue.
    let t0=performance.now();
    for(let i=0;i<N;i++)runCustom();
    await device.queue.onSubmittedWorkDone();
    const customMs=(performance.now()-t0)/N;
    // ORT path: each run is awaited, so this figure includes per-run
    // JS/session overhead in addition to GPU time.
    t0=performance.now();
    for(let i=0;i<N;i++){
      const result=await runOrt();
      for(const value of Object.values(result))value.dispose();
    }
    await device.queue.onSubmittedWorkDone();
    const ortMs=(performance.now()-t0)/N;

    log(`custom fused WGSL: ${customMs.toFixed(4)} ms / 24 sink dots`);
    log(`ORT 24 Einsums: ${ortMs.toFixed(4)} ms / 24 sink dots`);
    log(`speedup: ${(ortMs/customMs).toFixed(2)}x`);
    log(`whole-frame ceiling: sink dots are tiny; integration requires a fused custom attention operator.`);
  }catch(err){
    // Surface the failure on the page instead of leaving "loading..." forever.
    console.error(err);
    log(`error: ${err instanceof Error?err.message:String(err)}`);
  }finally{
    try{
      if(session)await session.release();
    }finally{
      // Free GPU memory so repeated clicks do not accumulate allocations.
      for(const buffer of gpuBuffers)buffer.destroy();
      button.disabled=false;
    }
  }
};
</script>