File size: 5,390 Bytes
db22e73 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | <!doctype html>
<meta charset="utf-8">
<title>Magenta RT sink-dot WebGPU PoC</title>
<style>
body{font:15px system-ui;max-width:760px;margin:40px auto;padding:0 20px}
button{padding:10px 16px;font:inherit} pre{white-space:pre-wrap;background:#f5f5f5;padding:16px}
</style>
<h1>Magenta RT custom WebGPU kernel PoC</h1>
<p>Fuses the 24 attention sink-dot Einsums in one WGSL dispatch.</p>
<button id="run">Run benchmark</button>
<pre id="out">idle</pre>
<script type="module">
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.26.0/dist/ort.webgpu.min.mjs";
// ORT runtime configuration: run the WASM backend single-threaded and ask
// WebGPU for a high-performance adapter.
ort.env.wasm.numThreads=1;
ort.env.webgpu.powerPreference="high-performance";

// Problem size: COUNT sink-dot einsums x HEADS heads; each item is a DIM-long dot product.
const COUNT=24;
const HEADS=8;
const DIM=128;
const ITEMS=COUNT*HEADS;

const out=document.querySelector("#out");

/**
 * Append one line of text to the #out panel, inserting a newline separator
 * only when the panel already has content.
 * @param {string} line
 * @returns {string} the panel's new text content.
 */
const log=(line)=>{
  const separator=out.textContent?"\n":"";
  return out.textContent+=separator+line;
};

/**
 * Round a byte count up to the next multiple of 16, never returning less than 16
 * (so zero-byte requests still get a valid buffer size).
 * @param {number} bytes
 * @returns {number}
 */
const align=(bytes)=>Math.max(16,16*Math.ceil(bytes/16));
// WGSL compute kernel, fused over all sink-dot items.
// One invocation per item (item = einsum*HEADS + head, matching how ORT outputs
// are packed below). Each invocation computes the DIM-long dot product of
// queries[item*DIM .. item*DIM+DIM) with the same slice of sinks and writes it
// to outputs[item].
// The dispatch is rounded up to whole 64-wide workgroups, so invocations with
// id >= ITEMS return early.
// ITEMS and DIM are baked into the source via template interpolation.
// NOTE(review): assumes both input buffers are packed item-major (ITEMS rows of
// DIM f32 values) and that the .sinks.bin export uses that layout; confirm.
const shader=`
struct Data { values: array<f32> }
@group(0) @binding(0) var<storage, read> queries: Data;
@group(0) @binding(1) var<storage, read> sinks: Data;
@group(0) @binding(2) var<storage, read_write> outputs: Data;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let item=gid.x;
if (item>=${ITEMS}u) { return; }
let base=item*${DIM}u;
var sum=0.0;
for (var d=0u; d<${DIM}u; d=d+1u) {
sum=sum+queries.values[base+d]*sinks.values[base+d];
}
outputs.values[item]=sum;
}`;
/**
 * Read back the first `bytes` bytes of a GPU buffer as a Float32Array.
 *
 * Copies into a temporary MAP_READ staging buffer, maps it, and copies the data
 * out. The staging buffer is always destroyed, even when the copy or map fails.
 *
 * @param {GPUDevice} device - device that owns `buffer`.
 * @param {GPUBuffer} buffer - source buffer; must have COPY_SRC usage.
 * @param {number} bytes - number of bytes to read; must be a non-negative
 *   multiple of 4 (the WebGPU copy-size rule and the f32 element size).
 * @returns {Promise<Float32Array>} a detached copy that remains valid after the
 *   staging buffer is destroyed.
 * @throws {RangeError} if `bytes` is not a non-negative multiple of 4.
 */
async function download(device,buffer,bytes){
  if(!Number.isInteger(bytes)||bytes<0||bytes%4!==0){
    throw new RangeError(`download: byte count must be a non-negative multiple of 4, got ${bytes}`);
  }
  const staging=device.createBuffer({size:align(bytes),usage:GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});
  try{
    const enc=device.createCommandEncoder();
    enc.copyBufferToBuffer(buffer,0,staging,0,bytes);
    device.queue.submit([enc.finish()]);
    await staging.mapAsync(GPUMapMode.READ);
    // slice() copies out of the mapped range, which unmap() detaches.
    const data=new Float32Array(staging.getMappedRange().slice(0,bytes));
    staging.unmap();
    return data;
  }finally{
    // Release the staging buffer on both the success and the failure path.
    staging.destroy();
  }
}
/**
 * Fetch `url` and return the response body as an ArrayBuffer.
 * Non-2xx responses throw, instead of handing an HTML error page to the caller
 * as if it were binary model or weight data.
 * @param {string} url
 * @returns {Promise<ArrayBuffer>}
 * @throws {Error} on a non-OK HTTP status.
 */
async function fetchArrayBuffer(url){
  const response=await fetch(url);
  if(!response.ok)throw new Error(`failed to fetch ${url}: HTTP ${response.status}`);
  return response.arrayBuffer();
}

/**
 * Build `length` deterministic pseudo-random floats in [-0.1, 0.1].
 * Uses xorshift32 with a fixed seed, so every run benchmarks identical inputs.
 * @param {number} length
 * @returns {Float32Array}
 */
function makeQueries(length){
  const queries=new Float32Array(length);
  let state=123456789;
  for(let i=0;i<length;i++){
    state^=state<<13;state^=state>>>17;state^=state<<5;
    queries[i]=((state>>>0)/4294967295-.5)*.2;
  }
  return queries;
}

const runButton=document.querySelector("#run");

/**
 * Benchmark click handler.
 * 1. Loads the reference ONNX model and the sink weights.
 * 2. Checks the fused WGSL kernel against ORT's 24 Einsums (max abs error).
 * 3. Times both paths.
 * Errors are reported in #out. GPU buffers and the ORT session are released on
 * every exit path, and the button is disabled while a run is in progress.
 */
runButton.onclick=async()=>{
  // Block re-entry: a second click mid-run would overlap GPU work and interleave the log.
  runButton.disabled=true;
  out.textContent="loading ONNX reference and real sink weights...";
  const gpuBuffers=[];
  let session=null;
  try{
    const [model,sinkBytes]=await Promise.all([
      fetchArrayBuffer("sink_dot_24.onnx"),
      fetchArrayBuffer("sink_dot_24.onnx.sinks.bin")
    ]);
    // The shader reads ITEMS*DIM floats of sinks. A size mismatch would make the
    // comparison meaningless rather than fail, so reject it up front.
    const expectedSinkBytes=ITEMS*DIM*Float32Array.BYTES_PER_ELEMENT;
    if(sinkBytes.byteLength!==expectedSinkBytes){
      throw new Error(`sinks.bin is ${sinkBytes.byteLength} bytes, expected ${expectedSinkBytes}`);
    }
    const sinks=new Float32Array(sinkBytes);
    session=await ort.InferenceSession.create(new Uint8Array(model),{
      executionProviders:["webgpu"],preferredOutputLocation:"gpu-buffer"
    });
    // Reuse ORT's GPUDevice so our buffers can be fed to ORT directly.
    // NOTE(review): newer onnxruntime-web versions expose this getter as a
    // Promise<GPUDevice>; awaiting works for both a Promise and a plain device.
    const device=await ort.env.webgpu.device;
    const usage=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
    // Every buffer created here is tracked so the finally block can destroy it.
    const makeBuffer=(bytes,data)=>{
      const b=device.createBuffer({size:align(bytes),usage});
      gpuBuffers.push(b);
      if(data)device.queue.writeBuffer(b,0,data);
      return b;
    };
    const queries=makeQueries(ITEMS*DIM);
    const qPacked=makeBuffer(queries.byteLength,queries);
    const sinkPacked=makeBuffer(sinks.byteLength,sinks);
    const customOut=makeBuffer(ITEMS*4);
    const module=device.createShaderModule({code:shader});
    const pipeline=await device.createComputePipelineAsync({layout:"auto",compute:{module,entryPoint:"main"}});
    const bindGroup=device.createBindGroup({layout:pipeline.getBindGroupLayout(0),entries:[
      {binding:0,resource:{buffer:qPacked}},
      {binding:1,resource:{buffer:sinkPacked}},
      {binding:2,resource:{buffer:customOut}}
    ]});
    // Submit one dispatch of the fused kernel (one thread per item, 64 per workgroup).
    const runCustom=()=>{
      const enc=device.createCommandEncoder();
      const pass=enc.beginComputePass();
      pass.setPipeline(pipeline);pass.setBindGroup(0,bindGroup);
      pass.dispatchWorkgroups(Math.ceil(ITEMS/64));pass.end();
      device.queue.submit([enc.finish()]);
    };
    // ORT gets the same queries, as one [1,HEADS,1,DIM] GPU tensor per einsum.
    const feeds={};
    for(let i=0;i<COUNT;i++){
      const slice=queries.subarray(i*HEADS*DIM,(i+1)*HEADS*DIM);
      const b=makeBuffer(slice.byteLength,slice);
      feeds[`q_${i}`]=ort.Tensor.fromGpuBuffer(b,{dataType:"float32",dims:[1,HEADS,1,DIM]});
    }
    const runOrt=()=>session.run(feeds);
    const disposeAll=(result)=>{for(const value of Object.values(result))value.dispose();};

    // Correctness: one custom dispatch versus one ORT run.
    runCustom();
    const customValues=await download(device,customOut,ITEMS*4);
    const reference=await runOrt();
    const ortValues=new Float32Array(ITEMS);
    try{
      for(let i=0;i<COUNT;i++){
        const values=await reference[`out_${i}`].getData();
        if(values.length!==HEADS){
          throw new Error(`out_${i} has ${values.length} values, expected ${HEADS}`);
        }
        ortValues.set(values,i*HEADS);
      }
    }finally{
      disposeAll(reference);
    }
    let maxError=0;
    for(let i=0;i<ITEMS;i++)maxError=Math.max(maxError,Math.abs(customValues[i]-ortValues[i]));
    log(`correctness max abs error: ${maxError.toExponential(3)}`);

    // Untimed warm-up iterations for both paths.
    for(let i=0;i<20;i++){
      runCustom();
      disposeAll(await runOrt());
    }
    await device.queue.onSubmittedWorkDone();
    const N=500;
    // NOTE(review): the custom loop enqueues all N submits back-to-back, while each
    // ORT run is awaited before the next one starts. ORT's per-run JS/dispatch
    // overhead is therefore part of its figure. Both loops end by waiting for the
    // GPU queue to drain.
    let t0=performance.now();
    for(let i=0;i<N;i++)runCustom();
    await device.queue.onSubmittedWorkDone();
    const customMs=(performance.now()-t0)/N;
    t0=performance.now();
    for(let i=0;i<N;i++)disposeAll(await runOrt());
    await device.queue.onSubmittedWorkDone();
    const ortMs=(performance.now()-t0)/N;
    log(`custom fused WGSL: ${customMs.toFixed(4)} ms / 24 sink dots`);
    log(`ORT 24 Einsums: ${ortMs.toFixed(4)} ms / 24 sink dots`);
    log(`speedup: ${(ortMs/customMs).toFixed(2)}x`);
    log(`whole-frame ceiling: sink dots are tiny; integration requires a fused custom attention operator.`);
  }catch(err){
    console.error(err);
    log(`error: ${err instanceof Error?err.message:String(err)}`);
  }finally{
    for(const b of gpuBuffers)b.destroy();
    if(session){
      try{await session.release();}catch(err){console.error("session.release failed",err);}
    }
    runButton.disabled=false;
  }
};
</script>
|