| <!doctype html> |
| <meta charset="utf-8"> |
| <title>Magenta RT sink-dot WebGPU PoC</title> |
| <style> |
| body{font:15px system-ui;max-width:760px;margin:40px auto;padding:0 20px} |
| button{padding:10px 16px;font:inherit} pre{white-space:pre-wrap;background:#f5f5f5;padding:16px} |
| </style> |
| <h1>Magenta RT custom WebGPU kernel PoC</h1> |
| <p>Fuses the 24 attention sink-dot Einsums in one WGSL dispatch.</p> |
| <button id="run">Run benchmark</button> |
| <pre id="out">idle</pre> |
| <script type="module"> |
| import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.26.0/dist/ort.webgpu.min.mjs"; |
// Keep ORT's WASM runtime single-threaded and ask its WebGPU backend for a
// high-performance adapter.
const { wasm: wasmEnv, webgpu: webgpuEnv } = ort.env;
wasmEnv.numThreads = 1;
webgpuEnv.powerPreference = "high-performance";
| |
// Problem size: COUNT sink-dot Einsums, each producing HEADS dot products over
// DIM-wide vectors.
const COUNT = 24;
const HEADS = 8;
const DIM = 128;
const ITEMS = COUNT * HEADS; // total dot products: one per (einsum, head) pair

const out = document.querySelector("#out");

/** Append `s` to the output panel, on its own line unless the panel is empty. */
const log = (s) => {
  const prefix = out.textContent ? "\n" : "";
  return (out.textContent += prefix + s);
};
/** Round a byte count up to the next multiple of 16, with a floor of 16 (used for GPU buffer sizes). */
const align = (n) => {
  const rounded = Math.ceil(n / 16) * 16;
  return rounded < 16 ? 16 : rounded;
};
| |
// WGSL kernel that computes all ITEMS sink dots in a single dispatch.
// Bindings: 0 = queries (read-only), 1 = sinks (read-only), 2 = outputs (read_write).
// Invocation `item` reads the DIM contiguous floats starting at item*DIM from both
// `queries` and `sinks`, accumulates their products serially, and writes the scalar
// result to outputs[item]. So both inputs must hold at least ITEMS*DIM floats and
// `outputs` at least ITEMS.
// ITEMS and DIM are interpolated as u32 literals when this string is built.
// The `item >= ITEMS` check discards invocations past the end when the dispatch is
// rounded up to whole 64-wide workgroups.
// NOTE(review): each invocation walks its own 512-byte row (DIM*4 bytes), so
// neighbouring invocations do not read adjacent addresses. Fine at this size, but
// profile before reusing this layout for larger problems.
// The template text below is the shader source itself; edit it deliberately.
const shader=`
struct Data { values: array<f32> }
@group(0) @binding(0) var<storage, read> queries: Data;
@group(0) @binding(1) var<storage, read> sinks: Data;
@group(0) @binding(2) var<storage, read_write> outputs: Data;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let item=gid.x;
if (item>=${ITEMS}u) { return; }
let base=item*${DIM}u;
var sum=0.0;
for (var d=0u; d<${DIM}u; d=d+1u) {
sum=sum+queries.values[base+d]*sinks.values[base+d];
}
outputs.values[item]=sum;
}`;
| |
/**
 * Read back the first `bytes` bytes of a GPU buffer as float32 values.
 *
 * The data is copied into a temporary MAP_READ staging buffer, which is then
 * mapped. The mapped range is copied into a CPU-owned ArrayBuffer before
 * unmapping. The staging buffer is destroyed on both success and failure.
 *
 * @param {GPUDevice} device - Device that owns `buffer`.
 * @param {GPUBuffer} buffer - Source buffer; must have COPY_SRC usage.
 * @param {number} bytes - Number of bytes to read. Must be a non-negative
 *   multiple of 4 and no larger than `buffer.size`.
 * @returns {Promise<Float32Array>} `bytes / 4` float32 values.
 * @throws {RangeError} If `bytes` is misaligned, negative, non-integral, or
 *   exceeds `buffer.size`. Checked up front because WebGPU reports these as
 *   asynchronous validation errors, which would otherwise yield silent zeros.
 */
async function download(device, buffer, bytes) {
  if (!Number.isInteger(bytes) || bytes < 0 || bytes % 4 !== 0) {
    throw new RangeError(`download: byte count ${bytes} must be a non-negative multiple of 4`);
  }
  if (bytes > buffer.size) {
    throw new RangeError(`download: requested ${bytes} bytes from a ${buffer.size}-byte buffer`);
  }
  const staging = device.createBuffer({
    size: align(bytes),
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  try {
    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(buffer, 0, staging, 0, bytes);
    device.queue.submit([encoder.finish()]);
    await staging.mapAsync(GPUMapMode.READ);
    // slice() copies out of the mapping, which becomes detached on unmap().
    const data = new Float32Array(staging.getMappedRange().slice(0, bytes));
    staging.unmap();
    return data;
  } finally {
    // Also runs when mapAsync rejects, so a failed read does not leak the staging buffer.
    staging.destroy();
  }
}
| |
/**
 * Fetch `url` (relative to this page) and return its body as an ArrayBuffer.
 * Throws on non-2xx responses, so a missing asset is reported as such. Otherwise
 * it would surface later as an ONNX parse error or a mis-sized weight buffer.
 * @param {string} url
 * @returns {Promise<ArrayBuffer>}
 */
async function fetchArrayBuffer(url) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`failed to fetch ${url}: HTTP ${response.status}`);
  }
  return response.arrayBuffer();
}

/**
 * Deterministic pseudo-random query data, so both implementations see identical
 * inputs on every run. Uses an xorshift (13, 17, 5) generator with a fixed seed
 * and maps each 32-bit state to roughly [-0.1, 0.1].
 * @param {number} length - Number of float32 values to generate.
 * @returns {Float32Array}
 */
function makeQueries(length) {
  const values = new Float32Array(length);
  let state = 123456789;
  for (let i = 0; i < length; i++) {
    state ^= state << 13;
    state ^= state >>> 17;
    state ^= state << 5;
    values[i] = ((state >>> 0) / 4294967295 - 0.5) * 0.2;
  }
  return values;
}

const runButton = document.querySelector("#run");

/**
 * Click handler. It loads the ONNX reference model and the real sink weights,
 * runs the fused WGSL kernel and ORT's 24 Einsums on identical inputs, and logs
 * the max absolute difference. It then times both paths and logs the results.
 *
 * The button is disabled while a run is in progress. Errors are written to the
 * output panel. Every GPU buffer and the ORT session created by the run are
 * released afterwards, so repeated clicks do not accumulate GPU memory.
 */
runButton.onclick = async () => {
  runButton.disabled = true;
  const gpuBuffers = [];
  let session;
  try {
    out.textContent = "loading ONNX reference and real sink weights...";
    const [model, sinkBytes] = await Promise.all([
      fetchArrayBuffer("sink_dot_24.onnx"),
      fetchArrayBuffer("sink_dot_24.onnx.sinks.bin"),
    ]);
    // The kernel indexes sinks exactly like queries (ITEMS rows of DIM floats).
    // A mismatched file would otherwise be read out of bounds and report garbage.
    const expectedSinkBytes = ITEMS * DIM * Float32Array.BYTES_PER_ELEMENT;
    if (sinkBytes.byteLength !== expectedSinkBytes) {
      throw new Error(
        `sink weights are ${sinkBytes.byteLength} bytes; expected ${expectedSinkBytes}`,
      );
    }

    session = await ort.InferenceSession.create(new Uint8Array(model), {
      executionProviders: ["webgpu"],
      preferredOutputLocation: "gpu-buffer",
    });
    // Recent onnxruntime-web releases expose the device as a Promise. Awaiting
    // also works if a plain GPUDevice is returned.
    const device = await ort.env.webgpu.device;
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    const makeBuffer = (bytes, data) => {
      const buffer = device.createBuffer({ size: align(bytes), usage });
      gpuBuffers.push(buffer); // destroyed in `finally`
      if (data) device.queue.writeBuffer(buffer, 0, data);
      return buffer;
    };

    const queries = makeQueries(ITEMS * DIM);
    const sinks = new Float32Array(sinkBytes);
    const qPacked = makeBuffer(queries.byteLength, queries);
    const sinkPacked = makeBuffer(sinks.byteLength, sinks);
    const customOut = makeBuffer(ITEMS * 4);

    const shaderModule = device.createShaderModule({ code: shader });
    const pipeline = await device.createComputePipelineAsync({
      layout: "auto",
      compute: { module: shaderModule, entryPoint: "main" },
    });
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: qPacked } },
        { binding: 1, resource: { buffer: sinkPacked } },
        { binding: 2, resource: { buffer: customOut } },
      ],
    });
    // Each call encodes and submits its own command buffer.
    const runCustom = () => {
      const encoder = device.createCommandEncoder();
      const pass = encoder.beginComputePass();
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bindGroup);
      pass.dispatchWorkgroups(Math.ceil(ITEMS / 64)); // 64 matches @workgroup_size in the shader
      pass.end();
      device.queue.submit([encoder.finish()]);
    };

    // ORT takes one [1, HEADS, 1, DIM] query tensor per Einsum. Each tensor is
    // backed by its own GPU buffer holding that Einsum's slice of `queries`.
    const feeds = {};
    for (let i = 0; i < COUNT; i++) {
      const slice = queries.subarray(i * HEADS * DIM, (i + 1) * HEADS * DIM);
      const buffer = makeBuffer(slice.byteLength, slice);
      feeds[`q_${i}`] = ort.Tensor.fromGpuBuffer(buffer, {
        dataType: "float32",
        dims: [1, HEADS, 1, DIM],
      });
    }
    const runOrt = () => session.run(feeds);

    // Correctness: compare the fused kernel against ORT on the same inputs.
    runCustom();
    const customValues = await download(device, customOut, ITEMS * 4);
    const reference = await runOrt();
    const ortValues = new Float32Array(ITEMS);
    for (let i = 0; i < COUNT; i++) {
      const output = reference[`out_${i}`];
      ortValues.set(await output.getData(), i * HEADS);
      output.dispose();
    }
    let maxError = 0;
    for (let i = 0; i < ITEMS; i++) {
      maxError = Math.max(maxError, Math.abs(customValues[i] - ortValues[i]));
    }
    log(`correctness max abs error: ${maxError.toExponential(3)}`);

    // Untimed warm-up iterations before measuring.
    for (let i = 0; i < 20; i++) {
      runCustom();
      const result = await runOrt();
      for (const value of Object.values(result)) value.dispose();
    }
    await device.queue.onSubmittedWorkDone();

    const N = 500;
    let t0 = performance.now();
    for (let i = 0; i < N; i++) runCustom();
    await device.queue.onSubmittedWorkDone();
    const customMs = (performance.now() - t0) / N;

    t0 = performance.now();
    for (let i = 0; i < N; i++) {
      const result = await runOrt();
      for (const value of Object.values(result)) value.dispose();
    }
    await device.queue.onSubmittedWorkDone();
    const ortMs = (performance.now() - t0) / N;

    log(`custom fused WGSL: ${customMs.toFixed(4)} ms / 24 sink dots`);
    log(`ORT 24 Einsums: ${ortMs.toFixed(4)} ms / 24 sink dots`);
    log(`speedup: ${(ortMs / customMs).toFixed(2)}x`);
    log(`whole-frame ceiling: sink dots are tiny; integration requires a fused custom attention operator.`);
  } catch (err) {
    console.error(err);
    log(`error: ${err}`);
  } finally {
    // Release the session before destroying the buffers its feeds reference.
    try {
      await session?.release();
    } catch (err) {
      console.warn("failed to release ORT session", err);
    }
    for (const buffer of gpuBuffers) buffer.destroy();
    runButton.disabled = false;
  }
};
| </script> |
|
|