File size: 5,390 Bytes
db22e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
<!doctype html>
<meta charset="utf-8">
<title>Magenta RT sink-dot WebGPU PoC</title>
<style>
body{font:15px system-ui;max-width:760px;margin:40px auto;padding:0 20px}
button{padding:10px 16px;font:inherit} pre{white-space:pre-wrap;background:#f5f5f5;padding:16px}
</style>
<h1>Magenta RT custom WebGPU kernel PoC</h1>
<p>Fuses the 24 attention sink-dot Einsums into a single WGSL compute dispatch and benchmarks it against ONNX Runtime Web's WebGPU backend.</p>
<button id="run">Run benchmark</button>
<pre id="out">idle</pre>
<script type="module">
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.26.0/dist/ort.webgpu.min.mjs";
// ONNX Runtime configuration: prefer the high-performance GPU adapter and keep the
// WASM side single-threaded.
ort.env.webgpu.powerPreference = "high-performance";
ort.env.wasm.numThreads = 1;

// Problem size: COUNT sink-dot Einsums, each over HEADS heads of DIM-long vectors.
const COUNT = 24;
const HEADS = 8;
const DIM = 128;
const ITEMS = COUNT * HEADS; // one shader invocation per (einsum, head) pair

const out = document.querySelector("#out");

/** Append `line` to the #out panel, separated by a newline unless the panel is empty. */
const log = (line) => {
  const separator = out.textContent ? "\n" : "";
  return (out.textContent += separator + line);
};

/** Round a byte count up to the next multiple of 16, never returning less than 16. */
function align(bytes) {
  const rounded = Math.ceil(bytes / 16) * 16;
  return rounded < 16 ? 16 : rounded;
}

// WGSL source for the fused kernel. Invocation `item` (0 <= item < ITEMS) computes the dot
// product of queries[item*DIM .. item*DIM+DIM) with sinks over the same index range and
// writes the scalar to outputs[item]; invocations past ITEMS return early because the
// dispatch is rounded up to whole 64-wide workgroups.
// ITEMS and DIM are baked in as u32 literals through template interpolation, so the
// shader must be rebuilt if those constants change.
// NOTE(review): both inputs share the same flat indexing, so the sinks buffer presumably
// holds one DIM-vector per (einsum, head) in the same order as the packed queries —
// confirm against whatever produces sink_dot_24.onnx.sinks.bin.
const shader=`
struct Data { values: array<f32> }
@group(0) @binding(0) var<storage, read> queries: Data;
@group(0) @binding(1) var<storage, read> sinks: Data;
@group(0) @binding(2) var<storage, read_write> outputs: Data;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let item=gid.x;
  if (item>=${ITEMS}u) { return; }
  let base=item*${DIM}u;
  var sum=0.0;
  for (var d=0u; d<${DIM}u; d=d+1u) {
    sum=sum+queries.values[base+d]*sinks.values[base+d];
  }
  outputs.values[item]=sum;
}`;

/**
 * Read back the first `bytes` bytes of a GPU buffer as float32 values.
 *
 * The data is copied into a temporary MAP_READ staging buffer, which is then mapped. The
 * mapped bytes are copied out before unmapping, because the mapped ArrayBuffer becomes
 * unusable once the buffer is unmapped. The staging buffer is destroyed in a `finally`,
 * so it is released even when the copy or `mapAsync` fails (e.g. on device loss).
 *
 * @param {GPUDevice} device device that owns `buffer`
 * @param {GPUBuffer} buffer source buffer; WebGPU requires COPY_SRC usage for the copy
 * @param {number} bytes number of bytes to read; must be a multiple of 4 (copy size rule
 *   and Float32Array element size)
 * @returns {Promise<Float32Array>} an independent copy of the data
 */
async function download(device,buffer,bytes){
  const staging=device.createBuffer({size:align(bytes),usage:GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});
  try{
    const enc=device.createCommandEncoder();
    enc.copyBufferToBuffer(buffer,0,staging,0,bytes);
    device.queue.submit([enc.finish()]);
    await staging.mapAsync(GPUMapMode.READ);
    // slice() copies the bytes out of the mapped range before unmap() invalidates it.
    const data=new Float32Array(staging.getMappedRange().slice(0,bytes));
    staging.unmap();
    return data;
  }finally{
    // Runs on success and on failure so the staging buffer never leaks.
    staging.destroy();
  }
}

/**
 * Fetch `url` as an ArrayBuffer, rejecting on HTTP errors so an error page is never
 * parsed as model or weight bytes.
 * @param {string} url
 * @returns {Promise<ArrayBuffer>}
 */
async function fetchArrayBuffer(url){
  const response=await fetch(url);
  if(!response.ok)throw new Error(`fetch ${url} failed: HTTP ${response.status}`);
  return response.arrayBuffer();
}

/**
 * Deterministic pseudo-random query data in [-0.1, 0.1] from a 32-bit xorshift generator,
 * so every run feeds identical inputs to both implementations.
 * @param {number} length number of float32 values
 * @returns {Float32Array}
 */
function makeQueries(length){
  const values=new Float32Array(length);
  let state=123456789;
  for(let i=0;i<length;i++){
    state^=state<<13;state^=state>>>17;state^=state<<5;
    values[i]=((state>>>0)/4294967295-.5)*.2;
  }
  return values;
}

/** Dispose every tensor in an ORT `session.run` result map. */
function disposeOutputs(result){
  for(const tensor of Object.values(result))tensor.dispose();
}

const runButton=document.querySelector("#run");

/**
 * Benchmark handler. It loads the ORT reference model and the real sink weights, then
 * checks the fused WGSL kernel against ORT's 24 per-Einsum outputs (max abs error). Finally
 * it times both paths and logs per-call milliseconds and the speedup.
 *
 * Every GPU buffer and the InferenceSession created here are released when the run
 * finishes or fails. Errors are reported in the #out panel. The button stays disabled
 * during a run so clicks cannot start overlapping runs.
 */
runButton.onclick=async()=>{
  runButton.disabled=true;
  const gpuBuffers=[];
  let session;
  try{
    out.textContent="loading ONNX reference and real sink weights...";
    const [model,sinkBytes]=await Promise.all([
      fetchArrayBuffer("sink_dot_24.onnx"),
      fetchArrayBuffer("sink_dot_24.onnx.sinks.bin")
    ]);
    // The kernel reads sinks[item*DIM + d] for every item < ITEMS. The file must therefore
    // hold exactly ITEMS*DIM float32s; otherwise the kernel would silently read out of
    // bounds.
    const expectedSinkBytes=ITEMS*DIM*Float32Array.BYTES_PER_ELEMENT;
    if(sinkBytes.byteLength!==expectedSinkBytes){
      throw new Error(`sink_dot_24.onnx.sinks.bin: expected ${expectedSinkBytes} bytes, got ${sinkBytes.byteLength}`);
    }
    session=await ort.InferenceSession.create(new Uint8Array(model),{
      executionProviders:["webgpu"],preferredOutputLocation:"gpu-buffer"
    });
    // Recent onnxruntime-web documents this getter as Promise<GPUDevice>. Awaiting also
    // works if it is a plain GPUDevice.
    const device=await ort.env.webgpu.device;
    if(!device)throw new Error("onnxruntime-web did not expose a WebGPU device");

    const usage=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
    const makeBuffer=(bytes,data)=>{
      const b=device.createBuffer({size:align(bytes),usage});
      gpuBuffers.push(b); // tracked so the finally block can destroy it
      if(data)device.queue.writeBuffer(b,0,data);
      return b;
    };
    const queries=makeQueries(ITEMS*DIM);
    const sinks=new Float32Array(sinkBytes);
    const qPacked=makeBuffer(queries.byteLength,queries);
    const sinkPacked=makeBuffer(sinks.byteLength,sinks);
    const customOut=makeBuffer(ITEMS*4);

    const module=device.createShaderModule({code:shader});
    const pipeline=await device.createComputePipelineAsync({layout:"auto",compute:{module,entryPoint:"main"}});
    const bindGroup=device.createBindGroup({layout:pipeline.getBindGroupLayout(0),entries:[
      {binding:0,resource:{buffer:qPacked}},
      {binding:1,resource:{buffer:sinkPacked}},
      {binding:2,resource:{buffer:customOut}}
    ]});
    // One invocation per item; 64 matches @workgroup_size in the shader.
    const runCustom=()=>{
      const enc=device.createCommandEncoder();
      const pass=enc.beginComputePass();
      pass.setPipeline(pipeline);pass.setBindGroup(0,bindGroup);
      pass.dispatchWorkgroups(Math.ceil(ITEMS/64));pass.end();
      device.queue.submit([enc.finish()]);
    };

    // ORT takes each Einsum's queries as its own [1,HEADS,1,DIM] GPU tensor. Each slice
    // comes from the same packed data that the custom kernel reads.
    const feeds={};
    for(let i=0;i<COUNT;i++){
      const slice=queries.subarray(i*HEADS*DIM,(i+1)*HEADS*DIM);
      feeds[`q_${i}`]=ort.Tensor.fromGpuBuffer(makeBuffer(slice.byteLength,slice),{dataType:"float32",dims:[1,HEADS,1,DIM]});
    }
    const runOrt=()=>session.run(feeds);

    // Correctness: compare the fused kernel with ORT's per-Einsum outputs.
    runCustom();
    const customValues=await download(device,customOut,ITEMS*4);
    const reference=await runOrt();
    const ortValues=new Float32Array(ITEMS);
    try{
      for(let i=0;i<COUNT;i++){
        const values=await reference[`out_${i}`].getData();
        if(values.length!==HEADS)throw new Error(`out_${i}: expected ${HEADS} values, got ${values.length}`);
        ortValues.set(values,i*HEADS);
      }
    }finally{
      disposeOutputs(reference);
    }
    let maxError=0;
    for(let i=0;i<ITEMS;i++)maxError=Math.max(maxError,Math.abs(customValues[i]-ortValues[i]));
    log(`correctness max abs error: ${maxError.toExponential(3)}`);

    // Warm up both paths before timing.
    for(let i=0;i<20;i++){
      runCustom();
      disposeOutputs(await runOrt());
    }
    await device.queue.onSubmittedWorkDone();
    const N=500;
    // Custom dispatches are queued back to back and timed until the GPU queue drains.
    let t0=performance.now();
    for(let i=0;i<N;i++)runCustom();
    await device.queue.onSubmittedWorkDone();
    const customMs=(performance.now()-t0)/N;

    // Each session.run is awaited before the next one starts. ORT's per-iteration time
    // therefore also includes any per-call async overhead that the back-to-back custom
    // submissions do not pay.
    t0=performance.now();
    for(let i=0;i<N;i++)disposeOutputs(await runOrt());
    await device.queue.onSubmittedWorkDone();
    const ortMs=(performance.now()-t0)/N;

    log(`custom fused WGSL: ${customMs.toFixed(4)} ms / 24 sink dots`);
    log(`ORT 24 Einsums:    ${ortMs.toFixed(4)} ms / 24 sink dots`);
    log(`speedup:           ${(ortMs/customMs).toFixed(2)}x`);
    log(`whole-frame ceiling: sink dots are tiny; integration requires a fused custom attention operator.`);
  }catch(err){
    log(`error: ${err instanceof Error?err.message:String(err)}`);
    console.error(err);
  }finally{
    // Release the session before destroying the GPU buffers its input tensors wrap.
    if(session)await session.release().catch(err=>console.warn("InferenceSession.release failed",err));
    for(const b of gpuBuffers)b.destroy();
    runButton.disabled=false;
  }
};
</script>