// Vite production bundle: WebGPU GPUContext plus WGSL compute shaders for
// on-device LLM inference (int4/bf16 matvec, RMSNorm, attention, sampling).
| (function(){const e=document.createElement("link").relList;if(e&&e.supports&&e.supports("modulepreload"))return;for(const r of document.querySelectorAll('link[rel="modulepreload"]'))t(r);new MutationObserver(r=>{for(const a of r)if(a.type==="childList")for(const i of a.addedNodes)i.tagName==="LINK"&&i.rel==="modulepreload"&&t(i)}).observe(document,{childList:!0,subtree:!0});function u(r){const a={};return r.integrity&&(a.integrity=r.integrity),r.referrerPolicy&&(a.referrerPolicy=r.referrerPolicy),r.crossOrigin==="use-credentials"?a.credentials="include":r.crossOrigin==="anonymous"?a.credentials="omit":a.credentials="same-origin",a}function t(r){if(r.ep)return;r.ep=!0;const a=u(r);fetch(r.href,a)}})();class b{constructor(){this.device=null,this.adapter=null,this.adapterInfo=null,this.pipelineCache=new Map,this.bufferCache=new Map,this.bindGroupCache=new Map}async init(){if(!navigator.gpu)throw new Error("WebGPU not supported");if(this.adapter=await navigator.gpu.requestAdapter({powerPreference:"high-performance"}),!this.adapter)throw new Error("No WebGPU adapter found");this.adapterInfo=await this.adapter.requestAdapterInfo?.()??{};const e={},u={maxBufferSize:4*1024*1024*1024,maxStorageBufferBindingSize:4*1024*1024*1024,maxComputeWorkgroupStorageSize:32768,maxComputeInvocationsPerWorkgroup:256,maxComputeWorkgroupSizeX:256,maxStorageBuffersPerShaderStage:10};for(const[r,a]of Object.entries(u))this.adapter.limits[r]!==void 0&&(e[r]=Math.min(a,this.adapter.limits[r]));const t=[];return this.adapter.features.has("shader-f16")&&t.push("shader-f16"),this.adapter.features.has("subgroups")&&t.push("subgroups"),this.device=await this.adapter.requestDevice({requiredLimits:e,requiredFeatures:t}),this.hasF16=this.device.features.has("shader-f16"),this.hasSubgroups=this.device.features.has("subgroups"),this.device.lost.then(r=>console.error("WebGPU device lost:",r)),this}createBuffer(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST){const 
r=Math.ceil(u/4)*4,a=this.device.createBuffer({size:r,usage:t,label:e});return this.bufferCache.set(e,a),a}createBufferFromData(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC){const r=this.createBuffer(e,u.byteLength,t);return this.device.queue.writeBuffer(r,0,u),r}createReadbackBuffer(e,u){const t=Math.ceil(u/4)*4;return this.device.createBuffer({size:t,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST,label:e+"_readback"})}getOrCreatePipeline(e,u,t="main"){if(this.pipelineCache.has(e))return this.pipelineCache.get(e);const r=this.device.createShaderModule({code:u,label:e}),a=this.device.createComputePipeline({layout:"auto",compute:{module:r,entryPoint:t},label:e});return this.pipelineCache.set(e,a),a}initTimestamps(){this.device.features.has("timestamp-query")&&(this._tsQuerySet=this.device.createQuerySet({type:"timestamp",count:2}),this._tsResolveBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.QUERY_RESOLVE|GPUBufferUsage.COPY_SRC}),this._tsReadBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST}),this._tsEnabled=!0,this._tsResults=[])}beginBatch(){this._encoder=this.device.createCommandEncoder(),this._passCount=0,this.singlePassMode&&(this._singlePass=this._encoder.beginComputePass()),this._tsEnabled&&this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:0}}).end()}endBatch(){this._singlePass&&(this._singlePass.end(),this._singlePass=null),this._tsEnabled&&this._encoder&&(this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:1}}).end(),this._encoder.resolveQuerySet(this._tsQuerySet,0,2,this._tsResolveBuf,0),this._encoder.copyBufferToBuffer(this._tsResolveBuf,0,this._tsReadBuf,0,16)),this._encoder&&(this.device.queue.submit([this._encoder.finish()]),this._encoder=null)}async readTimestamp(){if(!this._tsEnabled)return null;await this._tsReadBuf.mapAsync(GPUMapMode.READ);const e=new 
BigInt64Array(this._tsReadBuf.getMappedRange().slice(0));this._tsReadBuf.unmap();const t=Number(e[1]-e[0])/1e6;return this._tsResults.push(t),t}copyBuffer(e,u,t,r=0,a=0){if(this._singlePass){this._singlePass.end(),this._encoder.copyBufferToBuffer(e,r,u,a,t),this._singlePass=this._encoder.beginComputePass();return}const i=this._encoder||this.device.createCommandEncoder();i.copyBufferToBuffer(e,r,u,a,t),this._encoder||this.device.queue.submit([i.finish()])}startRecording(){this._recording=[]}stopRecording(){const e=this._recording;return this._recording=null,e}replay(e,u){if(u)for(const r of u)this.device.queue.writeBuffer(r.buffer,r.offset,r.data,r.dataOffset,r.size);const t=this._encoder;for(let r=0;r<e.length;r++){const a=e[r];if(a.multi){const i=t.beginComputePass(),o=a.ops;for(let n=0;n<o.length;n++){const s=o[n];i.setPipeline(s.pipeline),i.setBindGroup(0,s.bindGroup),i.dispatchWorkgroups(s.wgX,s.wgY)}i.end()}else{const i=t.beginComputePass();i.setPipeline(a.pipeline),i.setBindGroup(0,a.bindGroup),i.dispatchWorkgroups(a.wgX,a.wgY),i.end()}}}dispatch(e,u,t,r=1,a=1){if(this._recording&&this._recording.push({pipeline:e,bindGroup:u[0],wgX:t,wgY:r}),this._singlePass){const n=this._singlePass;this._passCount!==void 0&&this._passCount++,n.setPipeline(e);for(let s=0;s<u.length;s++)n.setBindGroup(s,u[s]);n.dispatchWorkgroups(t,r,a);return}const i=this._encoder||this.device.createCommandEncoder(),o=i.beginComputePass();this._passCount!==void 0&&this._passCount++,o.setPipeline(e);for(let n=0;n<u.length;n++)o.setBindGroup(n,u[n]);o.dispatchWorkgroups(t,r,a),o.end(),this._encoder||this.device.queue.submit([i.finish()])}dispatchMulti(e){if(this._recording&&this._recording.push({multi:!0,ops:e.map(r=>({pipeline:r.pipeline,bindGroup:r.bindGroups[0],wgX:r.workgroupsX,wgY:r.workgroupsY||1}))}),this._singlePass){this._passCount!==void 0&&this._passCount++;for(const r of e){this._singlePass.setPipeline(r.pipeline);for(let 
a=0;a<r.bindGroups.length;a++)this._singlePass.setBindGroup(a,r.bindGroups[a]);this._singlePass.dispatchWorkgroups(r.workgroupsX,r.workgroupsY||1,r.workgroupsZ||1)}return}const u=this._encoder||this.device.createCommandEncoder(),t=u.beginComputePass();this._passCount!==void 0&&this._passCount++;for(const r of e){t.setPipeline(r.pipeline);for(let a=0;a<r.bindGroups.length;a++)t.setBindGroup(a,r.bindGroups[a]);t.dispatchWorkgroups(r.workgroupsX,r.workgroupsY||1,r.workgroupsZ||1)}t.end(),this._encoder||this.device.queue.submit([u.finish()])}async readBuffer(e,u){const t=this.createReadbackBuffer("_readback",u),r=this.device.createCommandEncoder();r.copyBufferToBuffer(e,0,t,0,u),this.device.queue.submit([r.finish()]),await t.mapAsync(GPUMapMode.READ);const a=new Float32Array(t.getMappedRange().slice(0));return t.unmap(),t.destroy(),a}createBindGroup(e,u,t){return this.device.createBindGroup({layout:e.getBindGroupLayout(u),entries:t.map((r,a)=>({binding:a,resource:{buffer:r}}))})}createBindGroupWithOffsets(e,u,t){return this.device.createBindGroup({layout:e.getBindGroupLayout(u),entries:t.map((r,a)=>({binding:a,resource:r.buffer?{buffer:r.buffer,offset:r.offset||0,size:r.size}:{buffer:r}}))})}getCachedBindGroup(e,u,t,r){let a=this.bindGroupCache.get(e);return a||(a=this.createBindGroup(u,t,r),this.bindGroupCache.set(e,a)),a}destroy(){for(const e of this.bufferCache.values())e.destroy();this.bufferCache.clear(),this.bindGroupCache.clear(),this.device?.destroy()}}const Tr=Object.freeze(Object.defineProperty({__proto__:null,GPUContext:b},Symbol.toStringTag,{value:"Module"})),v="modulepreload",m=function(d){return"/"+d},f={},Hr=function(e,u,t){let r=Promise.resolve();if(u&&u.length>0){let n=function(s){return Promise.all(s.map(l=>Promise.resolve(l).then(p=>({status:"fulfilled",value:p}),p=>({status:"rejected",reason:p}))))};document.getElementsByTagName("link");const 
i=document.querySelector("meta[property=csp-nonce]"),o=i?.nonce||i?.getAttribute("nonce");r=n(u.map(s=>{if(s=m(s),s in f)return;f[s]=!0;const l=s.endsWith(".css"),p=l?'[rel="stylesheet"]':"";if(document.querySelector(`link[href="${s}"]${p}`))return;const g=document.createElement("link");if(g.rel=l?"stylesheet":v,l||(g.as="script"),g.crossOrigin="",g.href=s,o&&g.setAttribute("nonce",o),document.head.appendChild(g),l)return new Promise((c,w)=>{g.addEventListener("load",c),g.addEventListener("error",()=>w(new Error(`Unable to preload CSS for ${s}`)))})}))}function a(i){const o=new Event("vite:preloadError",{cancelable:!0});if(o.payload=i,window.dispatchEvent(o),!o.defaultPrevented)throw i}return r.then(i=>{for(const o of i||[])o.status==="rejected"&&a(o.reason);return e().catch(a)})},k=` | |
// Matvec y[col] = sum_k input[k] * W[k, col] with 4-bit weights.
// qweight packs 8 rows per u32 (offset-8 nibbles); scales are packed f16,
// one per (group, col). One thread per output column.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dequantize the 8 nibbles of w (value - 8) and dot with input[bk..bk+8).
fn dq(w: u32, bk: u32) -> f32 {
  return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
             vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
       + dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
             vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let col = gid.x;
  let K = params.K;
  let N = params.N;
  let gs = params.group_size;
  let packed_rows = K / 8u;       // each u32 covers 8 input rows
  let rows_per_group = gs / 8u;   // packed rows per quantization group
  let num_groups = packed_rows / rows_per_group;
  if (col >= N) { return; }
  var sum: f32 = 0.0;
  for (var g: u32 = 0u; g < num_groups; g++) {
    // scales hold two f16 per u32; sf indexes the f16 element
    let sf = g * N + col;
    let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
    let gstart = g * rows_per_group;
    // x4 unroll: assumes rows_per_group is a multiple of 4 (group_size % 32 == 0)
    for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
      let pr = gstart + r;
      let off = pr * N + col;
      let w0 = qweight[off];
      let w1 = qweight[off + N];
      let w2 = qweight[off + 2u * N];
      let w3 = qweight[off + 3u * N];
      let bk = pr * 8u;
      sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
                  dq(w2, bk+16u) + dq(w3, bk+24u));
    }
  }
  output[col] = sum;
}`,h=`
enable f16;
// f16 variant of the int4 matvec: per-8-element dot products run in f16,
// accumulation stays in f32. Requires the "shader-f16" device feature.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dequantize 8 nibbles of w (value - 8) and dot with input[bk..bk+8) in f16.
fn dq16(w: u32, bk: u32) -> f32 {
  let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
  let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
  let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
  let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
  return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let col = gid.x;
  let K = params.K;
  let N = params.N;
  let gs = params.group_size;
  let packed_rows = K / 8u;
  let rows_per_group = gs / 8u;
  let num_groups = packed_rows / rows_per_group;
  if (col >= N) { return; }
  var sum: f32 = 0.0;
  for (var g: u32 = 0u; g < num_groups; g++) {
    let sf = g * N + col;
    let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
    let gstart = g * rows_per_group;
    for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
      let pr = gstart + r;
      let off = pr * N + col;
      let w0 = qweight[off];
      let w1 = qweight[off + N];
      let w2 = qweight[off + 2u * N];
      let w3 = qweight[off + 3u * N];
      let bk = pr * 8u;
      sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
                  dq16(w2, bk+16u) + dq16(w3, bk+24u));
    }
  }
  output[col] = sum;
}`,x=`
// Split-K variant of the int4 matvec: workgroup_id.y selects a contiguous
// range of quantization groups; each (col, split) writes one partial sum.
struct Params { K: u32, N: u32, group_size: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> partials: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dequantize the 8 nibbles of w (value - 8) and dot with input[bk..bk+8).
fn dq(w: u32, bk: u32) -> f32 {
  return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
             vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
       + dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
             vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(workgroup_id) wgid: vec3u) {
  let col = gid.x;
  let split_id = wgid.y;
  let K = params.K;
  let N = params.N;
  let gs = params.group_size;
  let ns = params.num_splits;
  let packed_rows = K / 8u;
  let rows_per_group = gs / 8u;
  let num_groups = packed_rows / rows_per_group;
  if (col >= N) { return; }
  let groups_per_split = num_groups / ns;
  let g_start = split_id * groups_per_split;
  // last split absorbs any remainder groups
  let g_end = select(g_start + groups_per_split, num_groups, split_id >= ns - 1u);
  var sum: f32 = 0.0;
  for (var g = g_start; g < g_end; g++) {
    let sf = g * N + col;
    let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
    let gstart = g * rows_per_group;
    for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
      let pr = gstart + r;
      let off = pr * N + col;
      let w0 = qweight[off];
      let w1 = qweight[off + N];
      let w2 = qweight[off + 2u * N];
      let w3 = qweight[off + 3u * N];
      let bk = pr * 8u;
      sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
                  dq(w2, bk+16u) + dq(w3, bk+24u));
    }
  }
  partials[split_id * N + col] = sum;
}`,q=`
// Reduce split-K partials: output[col] = sum over splits of partials[s, col].
struct Params { N: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> partials: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let col = gid.x;
  if (col >= params.N) { return; }
  var sum: f32 = 0.0;
  for (var s = 0u; s < params.num_splits; s++) {
    sum += partials[s * params.N + col];
  }
  output[col] = sum;
}`,y=`
// Matvec with bf16 weights, stored column-major with two bf16 per u32.
// bf16 -> f32 is just a 16-bit left shift of the raw bits.
struct Params { K: u32, N: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let col = gid.x;
  if (col >= params.N) { return; }
  var sum: f32 = 0.0;
  let K = params.K;
  let base = col * K / 2u;
  // Process 4 bf16 values (2 u32 reads) per iteration; assumes K % 4 == 0
  for (var k: u32 = 0u; k < K; k += 4u) {
    let p0 = weight[base + k / 2u];
    let p1 = weight[base + k / 2u + 1u];
    let w0 = bitcast<f32>((p0 & 0xFFFFu) << 16u);
    let w1 = bitcast<f32>(p0 & 0xFFFF0000u);
    let w2 = bitcast<f32>((p1 & 0xFFFFu) << 16u);
    let w3 = bitcast<f32>(p1 & 0xFFFF0000u);
    sum += input[k] * w0 + input[k + 1u] * w1 + input[k + 2u] * w2 + input[k + 3u] * w3;
  }
  output[col] = sum;
}`,F=`
// RMSNorm over the whole vector: out = in * rsqrt(mean(in^2) + eps) * (1 + w),
// with the gain stored as packed bf16. Single workgroup of 256 threads.
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
  let bits = (packed >> (idx * 16u)) & 0xFFFFu;
  return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
  let tid = lid.x;
  let N = params.N;
  // Strided partial sum of squares per thread
  var ss: f32 = 0.0;
  for (var i = tid; i < N; i += 256u) {
    ss += input[i] * input[i];
  }
  wg_scratch[tid] = ss;
  workgroupBarrier();
  // Tree reduction to wg_scratch[0]
  for (var s = 128u; s > 0u; s >>= 1u) {
    if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
    workgroupBarrier();
  }
  let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
  // Normalize; (1.0 + w) means the stored gain is centered at zero
  for (var i = tid; i < N; i += 256u) {
    let w = unpack_bf16(weight[i / 2u], i % 2u);
    output[i] = input[i] * rms * (1.0 + w);
  }
}`,N=`
// SwiGLU activation: output = silu(gate) * up, silu(x) = x / (1 + exp(-x)).
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.N) { return; }
  let x = gate[i];
  output[i] = (x / (1.0 + exp(-x))) * up[i];
}`,P=`
// In-place elementwise add: a += b (residual connection).
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read_write> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.N) { return; }
  a[i] = a[i] + b[i];
}`,E=`
// De-interleave per-head pairs: src is laid out as 2*head_dim values per
// head; the first half of each head goes to dst_a, the second to dst_b.
struct Params { N: u32, num_heads: u32, head_dim: u32, }
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst_a: array<f32>;
@group(0) @binding(2) var<storage, read_write> dst_b: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.N) { return; }
  let hd = params.head_dim;
  let head = i / hd;
  let d = i % hd;
  dst_a[i] = src[head * hd * 2u + d];
  dst_b[i] = src[head * hd * 2u + hd + d];
}`,S=`
// Sigmoid gating: output = x * sigmoid(gate).
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.N) { return; }
  output[i] = x[i] / (1.0 + exp(-gate[i]));
}`,R=`
// Embedding lookup: copy row token_id of the packed-bf16 table into output.
struct Params { token_id: u32, dim: u32, }
@group(0) @binding(0) var<storage, read> embeddings: array<u32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
  let bits = (packed >> (idx * 16u)) & 0xFFFFu;
  return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.dim) { return; }
  let flat = params.token_id * params.dim + i;
  output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u);
}`,K=`
// Embedding lookup driven by the GPU-resident argmax result: feeds the
// sampled token back into the next decode step without a CPU readback.
struct ArgmaxResult { idx: u32, val: f32, }
struct Params { dim: u32, }
@group(0) @binding(0) var<storage, read> embeddings: array<u32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> argmax_result: ArgmaxResult;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
  let bits = (packed >> (idx * 16u)) & 0xFFFFu;
  return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  if (i >= params.dim) { return; }
  let flat = argmax_result.idx * params.dim + i;
  output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u);
}`,B=`
// Single-workgroup argmax over logits: result = (index, value) of the max.
struct Params { N: u32, }
struct Result { idx: u32, val: f32, }
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: Result;
@group(0) @binding(2) var<uniform> params: Params;
var<workgroup> s_val: array<f32, 256>;
var<workgroup> s_idx: array<u32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
  let tid = lid.x;
  let N = params.N;
  // Each thread scans a strided slice
  var best_val: f32 = -1e30;
  var best_idx: u32 = 0u;
  for (var i = tid; i < N; i += 256u) {
    if (logits[i] > best_val) {
      best_val = logits[i];
      best_idx = i;
    }
  }
  s_val[tid] = best_val;
  s_idx[tid] = best_idx;
  workgroupBarrier();
  // Tree reduction, carrying the index alongside the value
  for (var s = 128u; s > 0u; s >>= 1u) {
    if (tid < s && s_val[tid + s] > s_val[tid]) {
      s_val[tid] = s_val[tid + s];
      s_idx[tid] = s_idx[tid + s];
    }
    workgroupBarrier();
  }
  if (tid == 0u) {
    result.idx = s_idx[0];
    result.val = s_val[0];
  }
}`,D=`
// Argmax candidate pass: each of the 256 threads scans a strided slice of
// logits and emits its best (idx, bitcast(val)) pair; a follow-up step
// reduces the 256 candidates.
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
  let tid = lid.x;
  let N = params.N;
  var best_val: f32 = -1e30;
  var best_idx: u32 = 0u;
  for (var i = tid; i < N; i += 256u) {
    if (logits[i] > best_val) {
      best_val = logits[i];
      best_idx = i;
    }
  }
  // (idx, val) pair per thread — 256 candidates total
  result[tid * 2u] = best_idx;
  result[tid * 2u + 1u] = bitcast<u32>(best_val);
}`,A=`
// Append the argmax'd token id to the recent-token and history buffers
// entirely on the GPU (single-thread kernel; slots supplied via uniforms,
// no bounds check — the host is assumed to pass valid indices).
struct ArgmaxResult { idx: u32, val: f32, }
struct Params { recent_count: u32, history_slot: u32, }
@group(0) @binding(0) var<storage, read> argmax_result: ArgmaxResult;
@group(0) @binding(1) var<storage, read_write> recent_tokens: array<u32>;
@group(0) @binding(2) var<storage, read_write> token_history: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(1)
fn main() {
  let tok = argmax_result.idx;
  recent_tokens[params.recent_count] = tok;
  token_history[params.history_slot] = tok;
}`,z=`
// Write this step's K/V projections into the caches at the given position.
// Cache layout: [position][kv_head][head_dim].
struct Params { position: u32, num_kv_heads: u32, head_dim: u32, }
@group(0) @binding(0) var<storage, read> k_proj: array<f32>;
@group(0) @binding(1) var<storage, read> v_proj: array<f32>;
@group(0) @binding(2) var<storage, read_write> k_cache: array<f32>;
@group(0) @binding(3) var<storage, read_write> v_cache: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  let i = gid.x;
  let total = params.num_kv_heads * params.head_dim;
  if (i >= total) { return; }
  let offset = params.position * total + i;
  k_cache[offset] = k_proj[i];
  v_cache[offset] = v_proj[i];
}`,T=`
// Single-token decode attention (GQA) with online softmax and optional
// split-K over sequence positions. One workgroup per (head, split); thread
// tid owns output dimension tid (assumes head_dim == 256 == workgroup size —
// TODO confirm with host code).
struct Params {
  seq_len: u32,
  head_dim: u32,
  num_kv_heads: u32,
  num_q_heads: u32,
  heads_per_kv: u32,
  num_splits: u32,
  _pad0: u32,
  _pad1: u32,
}
@group(0) @binding(0) var<storage, read> q_full: array<f32>;
@group(0) @binding(1) var<storage, read> k_cache: array<f32>;
@group(0) @binding(2) var<storage, read> v_cache: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> s_dot: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
  let h = wg.x;
  let split = wg.y;
  let tid = lid.x;
  let hd = params.head_dim;
  let kv_head = h / params.heads_per_kv;   // GQA: several Q heads share a KV head
  let q_offset = h * hd;
  let kv_stride = params.num_kv_heads * hd;
  let kv_offset = kv_head * hd;
  let scale = 1.0 / sqrt(f32(hd));
  let ns = params.num_splits;
  // This split's range of KV positions (ceil-divided chunks)
  let chunk = (params.seq_len + ns - 1u) / ns;
  let pos_start = split * chunk;
  let pos_end = min(pos_start + chunk, params.seq_len);
  // Load Q (already RoPE'd)
  let q_val = q_full[q_offset + tid];
  // Online softmax accumulators for this split
  var running_max: f32 = -1e30;
  var running_sum: f32 = 0.0;
  var running_out: f32 = 0.0;
  for (var pos = pos_start; pos < pos_end; pos++) {
    let k_base = pos * kv_stride + kv_offset;
    // K is already RoPE'd in cache
    let k_val = k_cache[k_base + tid];
    s_dot[tid] = q_val * k_val;
    workgroupBarrier();
    for (var s = 128u; s > 0u; s >>= 1u) {
      if (tid < s) { s_dot[tid] += s_dot[tid + s]; }
      workgroupBarrier();
    }
    let score = s_dot[0] * scale;
    // BUGFIX: barrier so every thread has read s_dot[0] before the next
    // iteration's s_dot[tid] store can overwrite it (cross-iteration race).
    workgroupBarrier();
    let v_val = v_cache[pos * kv_stride + kv_offset + tid];
    // Rescale previous accumulators when a new max appears
    let new_max = max(running_max, score);
    let correction = exp(running_max - new_max);
    running_sum = running_sum * correction + exp(score - new_max);
    running_out = running_out * correction + exp(score - new_max) * v_val;
    running_max = new_max;
  }
  if (ns == 1u) {
    // Single split: write final output directly (no reduction needed)
    if (running_sum > 0.0) {
      output[q_offset + tid] = running_out / running_sum;
    } else {
      output[q_offset + tid] = 0.0;
    }
  } else {
    // Multi-split: write partial results for the reduction kernel.
    // Layout per (head, split): [partial_out[head_dim], log_sum_exp, running_max]
    let stride = hd + 2u;
    let base = (h * ns + split) * stride;
    if (running_sum > 0.0) {
      output[base + tid] = running_out / running_sum;
    } else {
      output[base + tid] = 0.0;
    }
    if (tid == 0u) {
      output[base + hd] = log(running_sum) + running_max; // LSE
      output[base + hd + 1u] = running_max;
    }
  }
}`,H=`
// Combine attention split partials with log-sum-exp rescaling:
// out = sum_s w_s * partial_s / sum_s w_s, where w_s = exp(lse_s - max lse).
// One workgroup per head; thread tid handles output dimension tid.
// (Cleanup: the original also kept a dead, tid==0-only weight_sum
// accumulator plus a redundant third loop; every thread can compute the
// same weight sum from the per-split scalars, so one pass suffices.)
struct Params {
  head_dim: u32,
  num_splits: u32,
  num_heads: u32,
  _pad: u32,
}
@group(0) @binding(0) var<storage, read> partials: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
  let h = wg.x;
  let tid = lid.x;
  let hd = params.head_dim;
  let ns = params.num_splits;
  let stride = hd + 2u;   // per-split record: [out[hd], lse, max]
  // Pass 1: global max LSE across splits (numerical stability)
  var global_max_lse: f32 = -1e30;
  for (var s: u32 = 0u; s < ns; s++) {
    let base = (h * ns + s) * stride;
    global_max_lse = max(global_max_lse, partials[base + hd]);
  }
  // Pass 2: weighted combination; ws is identical on every thread
  var acc: f32 = 0.0;
  var ws: f32 = 0.0;
  for (var s: u32 = 0u; s < ns; s++) {
    let base = (h * ns + s) * stride;
    let w = exp(partials[base + hd] - global_max_lse);
    acc += w * partials[base + tid];
    ws += w;
  }
  output[h * hd + tid] = acc / ws;
}`,M=`
// Gated delta-rule linear attention, one token step.
// One workgroup per head; per-head state is a key_dim x value_dim matrix
// updated in place. Each thread owns a strided set of value columns, so no
// two threads touch the same state element and no barriers are needed.
struct Params {
  num_heads: u32,
  key_dim: u32,
  value_dim: u32,
}
@group(0) @binding(0) var<storage, read> qkv: array<f32>;
@group(0) @binding(1) var<storage, read_write> state: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<storage, read> a_proj: array<f32>;
@group(0) @binding(4) var<storage, read> b_proj: array<f32>;
@group(0) @binding(5) var<storage, read> A_log: array<u32>;
@group(0) @binding(6) var<storage, read> dt_bias: array<u32>;
@group(0) @binding(7) var<uniform> params: Params;
fn sigmoid(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); }
fn softplus(x: f32) -> f32 {
  if (x > 20.0) { return x; }   // avoid exp overflow; softplus(x) ~ x here
  return log(1.0 + exp(x));
}
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
  let bits = (packed >> (idx * 16u)) & 0xFFFFu;
  return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
  let h = wg.x;
  let tid = lid.x;
  let nh = params.num_heads;
  let kd = params.key_dim;
  let vd = params.value_dim;
  if (h >= nh) { return; }
  // qkv layout: [q (nh*kd) | k (nh*kd) | v (nh*vd)]
  let qh_off = h * kd;
  let kh_off = nh * kd + h * kd;
  let vh_off = 2u * nh * kd + h * vd;
  let state_off = h * kd * vd;
  // Q/K L2 norms (computed redundantly by every thread — cheap vs a reduction)
  var q_ss: f32 = 0.0;
  for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; }
  let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6);
  var k_ss: f32 = 0.0;
  for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off + ki]; }
  let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6);
  // Per-head decay and write gate (same for all threads in this head)
  let a_log_val = unpack_bf16(A_log[h / 2u], h % 2u);
  let dt_bias_val = unpack_bf16(dt_bias[h / 2u], h % 2u);
  let decay = exp(-exp(a_log_val) * softplus(a_proj[h] + dt_bias_val));
  let beta = sigmoid(b_proj[h]);
  for (var vi = tid; vi < vd; vi += 256u) {
    // Step 1: decay this state column
    for (var ki: u32 = 0u; ki < kd; ki++) {
      let idx = state_off + ki * vd + vi;
      state[idx] *= decay;
    }
    // Step 2: kv_mem = (S^T k_normalized)[vi]
    var kv_mem: f32 = 0.0;
    for (var ki: u32 = 0u; ki < kd; ki++) {
      kv_mem += state[state_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm;
    }
    // Step 3: delta-rule update S += k_normalized * (v - kv_mem) * beta
    let v_val = qkv[vh_off + vi];
    let delta = (v_val - kv_mem) * beta;
    for (var ki: u32 = 0u; ki < kd; ki++) {
      let idx = state_off + ki * vd + vi;
      state[idx] += qkv[kh_off + ki] * k_inv_norm * delta;
    }
    // Step 4: output = (S^T q_normalized * scale)[vi]
    let scale = 1.0 / sqrt(f32(kd));
    var o_val: f32 = 0.0;
    for (var ki: u32 = 0u; ki < kd; ki++) {
      o_val += state[state_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * scale;
    }
    output[h * vd + vi] = o_val;
  }
| }`,G=` | |
// RMSNorm applied per head, in place, over x[num_heads * head_dim].
// One 256-thread workgroup per head (wg.x = head index).
// Weight is BF16 packed two-per-u32; this variant applies the Gemma-style
// (1.0 + w) weight formula.
// NOTE(review): weight is indexed by i only (per-dim, shared across heads)
// — confirm per-head weights are not intended.
struct Params { num_heads: u32, head_dim: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
// Extract one BF16 half (idx 0 = low 16 bits, 1 = high) and widen to f32.
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
// wg.x is uniform across the workgroup, so this early return keeps the
// barriers below in uniform control flow.
if (h >= params.num_heads) { return; }
let off = h * hd;
// Parallel sum of squares
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
ss += x[off + i] * x[off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction over the 256 partial sums; wg_scratch[0] holds the total.
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Parallel normalize
for (var i = tid; i < hd; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
x[off + i] = x[off + i] * rms * (1.0 + w);
}
}`,O=`
// RMSNorm applied per head, in place — identical to the (1.0 + w) variant
// elsewhere in this file except the weight is applied directly (x * rms * w),
// i.e. no +1 offset. One 256-thread workgroup per head.
struct Params { num_heads: u32, head_dim: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
// Extract one BF16 half (idx 0 = low 16 bits, 1 = high) and widen to f32.
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
// Uniform early-out (wg.x is uniform), keeps barriers legal.
if (h >= params.num_heads) { return; }
let off = h * hd;
// Strided per-thread sum of squares.
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
ss += x[off + i] * x[off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction; total lands in wg_scratch[0].
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
for (var i = tid; i < hd; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
x[off + i] = x[off + i] * rms * w;
}
}`,U=`
// Depthwise causal conv1d (kernel size 4) + SiLU over a single timestep,
// with a rolling 3-frame history per channel stored in hist
// (layout: hist[frame * channels + c], frame 0 = oldest).
// One thread per channel; x is updated in place and the history is shifted.
// Conv weights are BF16 packed two-per-u32, 4 taps per channel.
struct Params { channels: u32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> hist: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let c = gid.x;
if (c >= params.channels) { return; }
let ch = params.channels;
// Three history frames plus the current input form the 4-tap window.
let h0 = hist[c];
let h1 = hist[ch + c];
let h2 = hist[2u * ch + c];
let cur = x[c];
let w_base = c * 4u;
let w0 = unpack_bf16(weight[w_base / 2u], w_base % 2u);
let w1 = unpack_bf16(weight[(w_base + 1u) / 2u], (w_base + 1u) % 2u);
let w2 = unpack_bf16(weight[(w_base + 2u) / 2u], (w_base + 2u) % 2u);
let w3 = unpack_bf16(weight[(w_base + 3u) / 2u], (w_base + 3u) % 2u);
let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur;
// SiLU activation: v * sigmoid(v).
x[c] = conv_out / (1.0 + exp(-conv_out));
// Shift history window forward by one timestep.
hist[c] = h1;
hist[ch + c] = h2;
hist[2u * ch + c] = cur;
}`,I=`
// Fused single-token recurrent kernel, one 128-thread workgroup per KEY head:
//   Phase 1: causal conv1d(k=4)+SiLU over this head's Q/K/V channels
//            (rolling 3-frame history in hist, layout hist[frame*ch + c]).
//   Phase 2: DeltaNet-style recurrent state update per VALUE head
//            (vpk = num_value_heads / num_heads value heads per key head),
//            with per-value-head decay/gate (alpha/beta).
//   Phase 3: RMSNorm of each value head's output (weight applied as w, no +1).
// All BF16 weight buffers are packed two-per-u32 (see unpack_bf16).
// Bug fix vs previous revision: a workgroupBarrier() was missing between
// reading wg_scratch[0] (alpha) and overwriting wg_scratch with b_sum —
// a write-after-read race on workgroup memory.
struct Params {
num_heads: u32,
key_dim: u32,
value_dim: u32,
total_channels: u32,
eps: f32,
hidden_size: u32,
num_value_heads: u32,
}
@group(0) @binding(0) var<storage, read_write> qkv: array<f32>;
@group(0) @binding(1) var<storage, read_write> hist: array<f32>;
@group(0) @binding(2) var<storage, read> conv_weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> state: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<storage, read> hidden_input: array<f32>;
@group(0) @binding(6) var<storage, read> ab_weight: array<u32>;
@group(0) @binding(7) var<storage, read> A_log: array<u32>;
@group(0) @binding(8) var<storage, read> dt_bias: array<u32>;
@group(0) @binding(9) var<storage, read> norm_weight: array<u32>;
@group(0) @binding(10) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
fn sigmoid_f(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); }
// Numerically safe softplus: for large x, log(1 + exp(x)) ≈ x.
fn softplus_f(x: f32) -> f32 {
if (x > 20.0) { return x; }
return log(1.0 + exp(x));
}
var<workgroup> wg_scratch: array<f32, 128>;
@compute @workgroup_size(128)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let nh = params.num_heads;
let kd = params.key_dim;
let vd = params.value_dim;
let ch = params.total_channels;
let H = params.hidden_size;
let nhv = params.num_value_heads;
let vpk = nhv / nh; // value heads per key head (1 for 2B, 2 for 4B)
let evd = vpk * vd; // effective value dim per key head
// wg.x is uniform, so the barriers below stay in uniform control flow.
if (h >= nh) { return; }
// Alpha/beta computed per value head inside the vpk loop (Phase 2).
// When H > 0: inline BF16 dot product (a_proj/b_proj are BF16 in merged ab_weight)
// When H == 0: pre-computed via GPTQ dispatches, alpha/beta read per-value-head
// ── Phase 1: Conv1d for this head's channels ──
// Head h has 3 blocks: Q[h*kd..], K[nh*kd + h*kd..], V[2*nh*kd + h*evd..]
// evd = vpk * vd (effective value dim per key head, handles asymmetric heads)
let cpb = kd + kd + evd; // channels per key head block
for (var i = tid; i < cpb; i += 128u) {
// Map local channel to global channel index
var c: u32;
if (i < kd) {
c = h * kd + i;
} else if (i < 2u * kd) {
c = nh * kd + h * kd + (i - kd);
} else {
c = 2u * nh * kd + h * evd + (i - 2u * kd);
}
// Conv1d kernel=4 with history
let h0 = hist[c];
let h1 = hist[ch + c];
let h2 = hist[2u * ch + c];
let cur = qkv[c];
let wb = c * 4u;
let w0 = unpack_bf16(conv_weight[wb / 2u], wb % 2u);
let w1 = unpack_bf16(conv_weight[(wb + 1u) / 2u], (wb + 1u) % 2u);
let w2 = unpack_bf16(conv_weight[(wb + 2u) / 2u], (wb + 2u) % 2u);
let w3 = unpack_bf16(conv_weight[(wb + 3u) / 2u], (wb + 3u) % 2u);
let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur;
qkv[c] = conv_out / (1.0 + exp(-conv_out)); // SiLU
hist[c] = h1;
hist[ch + c] = h2;
hist[2u * ch + c] = cur;
}
workgroupBarrier();
// ── Phase 2+3: DeltaNet recurrent + RMSNorm, looped over vpk value heads ──
// Q/K are shared across value heads within this key head
let qh_off = h * kd;
let kh_off = nh * kd + h * kd;
// Q/K L2 norms (all threads compute redundantly, shared across value heads)
var q_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; }
let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6);
var k_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off + ki]; }
let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6);
let qk_scale = 1.0 / sqrt(f32(kd));
// Loop over vpk value heads per key head (1 for 2B, 2 for 4B)
for (var vhi: u32 = 0u; vhi < vpk; vhi++) {
let vh = h * vpk + vhi; // global value head index
// Get per-value-head alpha/beta
var alpha: f32;
var beta_raw: f32;
if (H == 0u) {
// Quantized path: alpha/beta pre-computed per value head
alpha = hidden_input[vh];
beta_raw = bitcast<f32>(ab_weight[vh]);
} else {
// BF16 inline: dot product per value head.
// Each thread handles 4 consecutive elements per stride of 512
// (128 threads * 4). Assumes H is a multiple of 4 — TODO confirm.
let half_H = H / 2u;
let a_base = vh * half_H;
let b_base = (nhv + vh) * half_H;
var a_sum: f32 = 0.0;
var b_sum: f32 = 0.0;
for (var k = tid * 4u; k < H; k += 512u) {
let hv0 = hidden_input[k];
let hv1 = hidden_input[k + 1u];
let hv2 = hidden_input[k + 2u];
let hv3 = hidden_input[k + 3u];
let ap0 = ab_weight[a_base + k / 2u];
let ap1 = ab_weight[a_base + k / 2u + 1u];
a_sum += hv0 * bitcast<f32>((ap0 & 0xFFFFu) << 16u)
+ hv1 * bitcast<f32>(ap0 & 0xFFFF0000u)
+ hv2 * bitcast<f32>((ap1 & 0xFFFFu) << 16u)
+ hv3 * bitcast<f32>(ap1 & 0xFFFF0000u);
let bp0 = ab_weight[b_base + k / 2u];
let bp1 = ab_weight[b_base + k / 2u + 1u];
b_sum += hv0 * bitcast<f32>((bp0 & 0xFFFFu) << 16u)
+ hv1 * bitcast<f32>(bp0 & 0xFFFF0000u)
+ hv2 * bitcast<f32>((bp1 & 0xFFFFu) << 16u)
+ hv3 * bitcast<f32>(bp1 & 0xFFFF0000u);
}
wg_scratch[tid] = a_sum;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
alpha = wg_scratch[0];
// FIX: barrier before reusing wg_scratch. Without it, a thread could
// overwrite wg_scratch[0] with its b_sum while another thread is still
// reading wg_scratch[0] for alpha (write-after-read race).
workgroupBarrier();
wg_scratch[tid] = b_sum;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
beta_raw = wg_scratch[0];
}
// Per-value-head decay and gate
let a_log_val = unpack_bf16(A_log[vh / 2u], vh % 2u);
let dt_bias_val = unpack_bf16(dt_bias[vh / 2u], vh % 2u);
let decay = exp(-exp(a_log_val) * softplus_f(alpha + dt_bias_val));
let beta = sigmoid_f(beta_raw);
// Per-value-head V offset, state offset, output offset
let v_off = 2u * nh * kd + vh * vd;
let s_off = vh * kd * vd;
let o_off = vh * vd;
// DeltaNet recurrent update for this value head: decay the state,
// read the memory for k, apply the delta-rule correction, then project
// with the normalized, scaled query. Statement order matters here.
for (var vi = tid; vi < vd; vi += 128u) {
for (var ki: u32 = 0u; ki < kd; ki++) {
state[s_off + ki * vd + vi] *= decay;
}
var kv_mem: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
kv_mem += state[s_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm;
}
let v_val = qkv[v_off + vi];
let delta = (v_val - kv_mem) * beta;
for (var ki: u32 = 0u; ki < kd; ki++) {
state[s_off + ki * vd + vi] += qkv[kh_off + ki] * k_inv_norm * delta;
}
var o_val: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
o_val += state[s_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * qk_scale;
}
output[o_off + vi] = o_val;
}
workgroupBarrier();
// ── RMSNorm for this value head (nogated, w formula) ──
// norm_weight is shared across all value heads (just vd weights)
var ss: f32 = 0.0;
for (var i = tid; i < vd; i += 128u) {
ss += output[o_off + i] * output[o_off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(vd) + params.eps);
for (var i = tid; i < vd; i += 128u) {
let w = unpack_bf16(norm_weight[i / 2u], i % 2u);
output[o_off + i] = output[o_off + i] * rms * w;
}
// Barrier before the next value-head iteration rewrites wg_scratch.
workgroupBarrier();
}
}`,C=`
// Fused GPTQ int4 gate/up matvec + gated activation (f32 math):
//   output[col] = SiLU(input · W_gate[:,col]) * (input · W_up[:,col])
// Weights: 4-bit values packed 8-per-u32 along K, zero-point 8; one FP16
// scale per (group, col), packed two-per-u32 (unpack2x16float).
// One thread per output column. Assumes K % 8 == 0 and group_size % 32 == 0
// (inner loop steps 4 packed rows = 32 K-elements) — TODO confirm dispatcher
// guarantees this.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// Dequantize one packed u32 (8 nibbles, offset -8) and dot with input[bk..bk+7].
fn dq_in(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
// Per-group FP16 scales, one per column.
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
// Unrolled by 4 packed rows (32 input elements) per iteration.
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + dq_in(gw3,bk+24u));
up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u));
}
}
// SiLU(gate) * up
output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum;
}`,L=`
enable f16;
// f16-math variant of the fused GPTQ int4 gate/up matvec (same contract as
// the f32 version in this file): per-dot products run in f16 for speed,
// accumulation into gate_sum/up_sum stays in f32. Requires the "shader-f16"
// device feature. Same layout assumptions: K % 8 == 0, group_size % 32 == 0.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// Dequantize 8 nibbles (offset -8) and dot with input[bk..bk+7] in f16.
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u));
up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u));
}
}
// SiLU(gate) * up
output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum;
}`,Q=`
// Split-K variant of the GPTQ gate/up matvec (f32): each 32-thread workgroup
// computes 8 output columns; 4 lanes per column each handle num_groups/4
// scale groups, then lane 0 sums the partials from workgroup memory.
// NOTE(review): gpl = num_groups / 4u truncates — tail groups are dropped
// unless num_groups % 4 == 0; presumably the dispatcher guarantees this.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// partial[0..31]  = per-lane gate sums, partial[32..63] = per-lane up sums.
var<workgroup> partial: array<f32, 64>;
// Dequantize one packed u32 (8 nibbles, offset -8) and dot with input[bk..bk+7].
fn dq_in(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
// 8 columns per workgroup; threads tid/4 share a column, tid&3 is the lane.
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
if (col < N) {
for (var grp = g_start; grp < g_end; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart_row = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart_row + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + dq_in(gw3,bk+24u));
up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u));
}
}
}
// All 32 threads write (zeros for out-of-range columns), so the barrier
// stays in uniform control flow.
partial[tid] = gate_sum;
partial[32u + tid] = up_sum;
workgroupBarrier();
if (lane == 0u && col < N) {
// lane 0 has tid % 4 == 0, so partial[tid..tid+3] are this column's lanes.
let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u];
output[col] = (g / (1.0 + exp(-g))) * u;
}
}`,j=`
enable f16;
// f16-math, split-K variant of the GPTQ gate/up matvec: 8 columns per
// 32-thread workgroup, 4 lanes per column over num_groups/4 groups each,
// lane 0 reduces partials. Requires the "shader-f16" device feature.
// NOTE(review): num_groups / 4u truncates — assumes num_groups % 4 == 0.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// partial[0..31] = gate lane sums, partial[32..63] = up lane sums.
var<workgroup> partial: array<f32, 64>;
// Dequantize 8 nibbles (offset -8) and dot with input[bk..bk+7] in f16.
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
if (col < N) {
for (var grp = g_start; grp < g_end; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart_row = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart_row + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u));
up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u));
}
}
}
// Unconditional writes keep the barrier in uniform control flow.
partial[tid] = gate_sum;
partial[32u + tid] = up_sum;
workgroupBarrier();
if (lane == 0u && col < N) {
let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u];
output[col] = (g / (1.0 + exp(-g))) * u;
}
}`,V=`
// Fused residual-add + RMSNorm:
//   hidden[i] += addend[i]  (in place), then
//   output[i]  = hidden[i] * rms * (1 + w[i])
// Single 256-thread workgroup handles the whole vector (no workgroup_id is
// read — must be dispatched with exactly one workgroup).
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
// Phase 1: Add + compute sum of squares
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[i] + addend[i];
hidden[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Phase 2: Normalize. Each thread re-reads only indices it wrote itself
// (same tid stride as Phase 1), so no extra barrier is needed for hidden.
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[i] = hidden[i] * rms * (1.0 + w);
}
}`,W=`
// Non-destructive add + RMSNorm: output[i] = (hidden[i] + addend[i]) * rms * (1 + w[i]).
// hidden is read-only here; the sums are cached in workgroup memory
// (wg_vals) to avoid re-reading storage in the normalize pass.
// Single 256-thread workgroup; wg_vals capacity bounds N at 7680
// (out-of-bounds for larger N — presumably guaranteed by the dispatcher).
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[i] + addend[i];
wg_vals[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction; total sum of squares lands in wg_scratch[0].
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Each thread reads back only the wg_vals slots it wrote (same stride).
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[i] = wg_vals[i] * rms * (1.0 + w);
}
}`,Y=`
// Three-way add + RMSNorm with dual output:
//   hidden_out[i] = a[i] + b[i] + c[i]          (raw residual stream)
//   normed[i]     = hidden_out[i] * rms * (1+w) (normalized copy)
// Single 256-thread workgroup; sums cached in wg_vals (capacity 7680,
// so N must not exceed that — presumably guaranteed by the dispatcher).
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read> input_c: array<f32>;
@group(0) @binding(3) var<storage, read> weight: array<u32>;
@group(0) @binding(4) var<storage, read_write> hidden_out: array<f32>;
@group(0) @binding(5) var<storage, read_write> normed: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = input_a[i] + input_b[i] + input_c[i];
wg_vals[i] = val;
hidden_out[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction to wg_scratch[0].
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Threads only read wg_vals slots they wrote themselves (same stride).
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
normed[i] = wg_vals[i] * rms * (1.0 + w);
}
}`,X=`
// GPTQ int4 matvec with a sigmoid-gated input:
//   output[col] = sum_k (x[k] * sigmoid(gate[k])) * W[k, col]
// Weight layout matches the other GPTQ kernels in this file: 8 nibbles per
// u32 along K (offset -8), FP16 scales packed two-per-u32, one scale per
// (group, col). One thread per output column; assumes K % 8 == 0 and
// group_size % 32 == 0 — TODO confirm dispatcher guarantees.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
// Dot of gated input x[bk..bk+7]*sigmoid(gate) with 8 dequantized nibbles.
fn dq_sg(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u]));
let ih = vec4<f32>(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u]));
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
// Unrolled by 4 packed rows (32 K-elements) per iteration.
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u));
}
}
output[col] = sum;
}`,$=`
enable f16;
// f16-math variant of the sigmoid-gated GPTQ matvec: the gated input and
// dequantized weights are dotted in f16; per-group accumulation stays f32.
// Requires the "shader-f16" device feature. Same layout assumptions as the
// f32 version (K % 8 == 0, group_size % 32 == 0).
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
// f16 dot of gated input with 8 dequantized nibbles (offset -8).
fn dq16_sg(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u])));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u])));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u));
}
}
output[col] = sum;
}`,Z=`
// GPTQ int4 matvec with a fused SwiGLU input:
//   output[col] = sum_k (SiLU(a[k]) * b[k]) * W[k, col]
// i.e. the down-projection applied directly to silu(gate)*up without
// materializing the intermediate. Same GPTQ layout as the other kernels:
// 8 nibbles per u32 along K (offset -8), FP16 scales two-per-u32.
// One thread per output column; assumes K % 8 == 0 and group_size % 32 == 0.
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
// Dot of silu(a)*b over [bk..bk+7] with 8 dequantized nibbles.
fn dq_sl(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]);
let ih = vec4<f32>(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]);
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u));
}
}
output[col] = sum;
}`,J=`
| enable f16; | |
| // ── Fused SiLU-gate × GPTQ int4 matvec, f16 dequant/dot path ── | |
| // Same math and memory layout as the f32 kernel: 8 nibbles per u32, | |
| // zero-point 8, f16 scale pairs. Dot-products run in f16 for speed; the | |
| // per-column accumulator stays f32. Requires the 'shader-f16' feature. | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> a: array<f32>; | |
| @group(0) @binding(1) var<storage, read> b: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } | |
| // Dequantize one packed word (rows bk..bk+7) and dot with silu(a)*b in f16. | |
| fn dq16_sl(w: u32, bk: u32) -> f32 { | |
| let i0 = vec4<f16>(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u])); | |
| let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); | |
| let i1 = vec4<f16>(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u])); | |
| let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); | |
| return f32(dot(i0, w0)) + f32(dot(i1, w1)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let col = gid.x; // one thread per output column | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = K / 8u; // 8 nibbles per u32 along K | |
| let rows_per_group = gs / 8u; // packed rows per scale group | |
| let num_groups = packed_rows / rows_per_group; | |
| if (col >= N) { return; } | |
| var sum: f32 = 0.0; | |
| for (var g: u32 = 0u; g < num_groups; g++) { | |
| let sf = g * N + col; | |
| // f16 scale pair per u32; low bit selects the half-word. | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| // ×4 unroll — assumes rows_per_group is a multiple of 4. | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u)); | |
| } | |
| } | |
| output[col] = sum; | |
| }`,rr=` | |
| // ── Split-K SiLU-gate GPTQ int4 matvec (f32) ── | |
| // Each 32-thread workgroup covers 8 output columns; 4 lanes per column each | |
| // accumulate num_groups/4 scale groups, then lane 0 sums the 4 partials from | |
| // shared memory. Weight/scale packing is the standard GPTQ layout used here: | |
| // 8 nibbles per u32, zero-point 8, f16 scale pairs. | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> a: array<f32>; | |
| @group(0) @binding(1) var<storage, read> b: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } | |
| // Dot one packed weight word (rows bk..bk+7) with the SiLU-gated activations. | |
| fn dq_sl(w: u32, bk: u32) -> f32 { | |
| let il = vec4<f32>(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]); | |
| let ih = vec4<f32>(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]); | |
| return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) | |
| + dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); | |
| } | |
| var<workgroup> partial: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let tid = lid.x; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = params.K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| let col = wid.x * 8u + (tid / 4u); // 8 columns per workgroup | |
| let lane = tid & 3u; // 4 lanes split K per column | |
| var sum: f32 = 0.0; | |
| if (col < N) { | |
| // NOTE(review): if num_groups % 4 != 0 the trailing groups are silently | |
| // skipped (gpl truncates) — confirm host sizes make num_groups a multiple of 4. | |
| let gpl = num_groups / 4u; | |
| let g_start = lane * gpl; | |
| let g_end = g_start + gpl; | |
| for (var g = g_start; g < g_end; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u)); | |
| } | |
| } | |
| } | |
| // All 32 threads reach the barrier (out-of-range cols contribute 0.0). | |
| partial[tid] = sum; | |
| workgroupBarrier(); | |
| if (lane == 0u && col < N) { | |
| output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; | |
| } | |
| }`,ar=` | |
| enable f16; | |
| // ── Split-K SiLU-gate GPTQ int4 matvec, f16 dequant/dot path ── | |
| // 8 columns per 32-thread workgroup, 4 lanes per column splitting the scale | |
| // groups; partials reduced via shared memory. Dot-products in f16, per-lane | |
| // accumulator in f32. Requires the 'shader-f16' feature. | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> a: array<f32>; | |
| @group(0) @binding(1) var<storage, read> b: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } | |
| // Dequantize one packed word (rows bk..bk+7) and dot with silu(a)*b in f16. | |
| fn dq16_sl(w: u32, bk: u32) -> f32 { | |
| let i0 = vec4<f16>(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u])); | |
| let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); | |
| let i1 = vec4<f16>(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u])); | |
| let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); | |
| return f32(dot(i0, w0)) + f32(dot(i1, w1)); | |
| } | |
| var<workgroup> partial: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let tid = lid.x; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = params.K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| let col = wid.x * 8u + (tid / 4u); // 8 columns per workgroup | |
| let lane = tid & 3u; // 4 lanes split K per column | |
| var sum: f32 = 0.0; | |
| if (col < N) { | |
| // NOTE(review): gpl truncates — remainder groups (num_groups % 4) would be | |
| // skipped; confirm host guarantees divisibility. | |
| let gpl = num_groups / 4u; | |
| let g_start = lane * gpl; | |
| let g_end = g_start + gpl; | |
| for (var g = g_start; g < g_end; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u)); | |
| } | |
| } | |
| } | |
| // Barrier is uniform: every thread stores a partial (0.0 when col >= N). | |
| partial[tid] = sum; | |
| workgroupBarrier(); | |
| if (lane == 0u && col < N) { | |
| output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; | |
| } | |
| }`,er=` | |
| // ── Split-K GPTQ int4 matvec fused with x * sigmoid(gate) activation (f32) ── | |
| // Like the silu(a)*b split-K kernel, but the elementwise input is | |
| // x[k] * sigmoid(gate[k]) with x and gate in separate buffers. | |
| // 8 columns per workgroup, 4 lanes per column, shared-memory reduction. | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> x: array<f32>; | |
| @group(0) @binding(1) var<storage, read> gate: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } | |
| // Dot one packed weight word (rows bk..bk+7) with x * sigmoid(gate). | |
| fn dq_sg(w: u32, bk: u32) -> f32 { | |
| let il = vec4<f32>(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u])); | |
| let ih = vec4<f32>(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u])); | |
| return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) | |
| + dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); | |
| } | |
| var<workgroup> partial: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let tid = lid.x; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = params.K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| let col = wid.x * 8u + (tid / 4u); // 8 columns per workgroup | |
| let lane = tid & 3u; // 4 lanes split K per column | |
| var sum: f32 = 0.0; | |
| if (col < N) { | |
| // NOTE(review): assumes num_groups % 4 == 0; remainder groups are skipped. | |
| let gpl = num_groups / 4u; | |
| let g_start = lane * gpl; | |
| let g_end = g_start + gpl; | |
| for (var g = g_start; g < g_end; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u)); | |
| } | |
| } | |
| } | |
| partial[tid] = sum; | |
| workgroupBarrier(); | |
| if (lane == 0u && col < N) { | |
| output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; | |
| } | |
| }`,tr=` | |
| enable f16; | |
| // ── Split-K GPTQ int4 matvec with x * sigmoid(gate), f16 dequant/dot path ── | |
| // Same layout and split-K scheme as the f32 version; dot-products run in f16 | |
| // with an f32 per-lane accumulator. Requires the 'shader-f16' feature. | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> x: array<f32>; | |
| @group(0) @binding(1) var<storage, read> gate: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } | |
| // Dequantize one packed word (rows bk..bk+7) and dot with x*sigmoid(gate) in f16. | |
| fn dq16_sg(w: u32, bk: u32) -> f32 { | |
| let i0 = vec4<f16>(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u]))); | |
| let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); | |
| let i1 = vec4<f16>(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u]))); | |
| let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); | |
| return f32(dot(i0, w0)) + f32(dot(i1, w1)); | |
| } | |
| var<workgroup> partial: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let tid = lid.x; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = params.K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| let col = wid.x * 8u + (tid / 4u); // 8 columns per workgroup | |
| let lane = tid & 3u; // 4 lanes split K per column | |
| var sum: f32 = 0.0; | |
| if (col < N) { | |
| // NOTE(review): assumes num_groups % 4 == 0; remainder groups are skipped. | |
| let gpl = num_groups / 4u; | |
| let g_start = lane * gpl; | |
| let g_end = g_start + gpl; | |
| for (var g = g_start; g < g_end; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u)); | |
| } | |
| } | |
| } | |
| partial[tid] = sum; | |
| workgroupBarrier(); | |
| if (lane == 0u && col < N) { | |
| output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; | |
| } | |
| }`,ur=` | |
| // ── Mega-fused MLP front half: 3-way residual add → RMSNorm → merged | |
| // gate+up GPTQ matvec → SiLU(gate) * up ── | |
| // qweight_merged/scales_merged hold the gate matrix followed by the up | |
| // matrix (offsets up_qw_base / up_sc_base). Workgroup 0 also writes the | |
| // summed residual to hidden_out for the next layer. | |
| struct Params { | |
| K: u32, // hidden_size | |
| N: u32, // intermediate_size | |
| group_size: u32, | |
| eps: f32, | |
| // BF16 norm weights packed as vec4<u32> for correct uniform alignment. | |
| norm_weight: array<vec4<u32>, 640> | |
| } | |
| @group(0) @binding(0) var<storage, read> hidden_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> mlp_residual: array<f32>; | |
| @group(0) @binding(2) var<storage, read> attn_residual: array<f32>; | |
| @group(0) @binding(3) var<storage, read> qweight_merged: array<u32>; | |
| @group(0) @binding(4) var<storage, read> scales_merged: array<u32>; | |
| @group(0) @binding(5) var<storage, read_write> hidden_out: array<f32>; | |
| @group(0) @binding(6) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(7) var<uniform> params: Params; | |
| fn unpack_norm_weight(i: u32) -> f32 { | |
| // Each vec4<u32> holds 8 BF16 values (4 u32 × 2 per u32); BF16 → f32 is a | |
| // 16-bit left shift into the high half of the f32 bit pattern. | |
| let vec_idx = i / 8u; | |
| let u32_idx = (i / 2u) % 4u; | |
| let bf16_idx = i % 2u; | |
| let packed = params.norm_weight[vec_idx][u32_idx]; | |
| let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| // On-the-fly normed value: avoids 16KB shared memory (wg_normed) that killed occupancy. | |
| // Inputs are L2-cached (~24KB total), so re-reading is cheap. | |
| var<workgroup> wg_scratch: array<f32, 32>; | |
| // Normalized input element i: (sum of the three streams) * rms, with the | |
| // norm weight applied as (1.0 + w). | |
| fn normed_3way(i: u32, rms: f32) -> f32 { | |
| let val = hidden_in[i] + mlp_residual[i] + attn_residual[i]; | |
| let w = unpack_norm_weight(i); | |
| return val * rms * (1.0 + w); | |
| } | |
| // Dot one packed int4 word (rows bk..bk+7, zero-point 8) with normed inputs. | |
| fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { | |
| return dot(vec4<f32>(normed_3way(bk, rms), normed_3way(bk+1u, rms), normed_3way(bk+2u, rms), normed_3way(bk+3u, rms)), | |
| vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(normed_3way(bk+4u, rms), normed_3way(bk+5u, rms), normed_3way(bk+6u, rms), normed_3way(bk+7u, rms)), | |
| vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u, | |
| @builtin(workgroup_id) wg: vec3u, | |
| @builtin(local_invocation_id) lid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| // ── Phase 1: Three-way add → compute RMS scalar (no shared memory for normed values) ── | |
| // Every workgroup redundantly computes the same sum of squares. | |
| var ss: f32 = 0.0; | |
| for (var i = tid; i < K; i += 32u) { | |
| let val = hidden_in[i] + mlp_residual[i] + attn_residual[i]; | |
| ss += val * val; | |
| // WG 0: write hidden_out (residual for next layer) | |
| if (wg.x == 0u) { hidden_out[i] = val; } | |
| } | |
| wg_scratch[tid] = ss; | |
| workgroupBarrier(); | |
| // Tree reduction for sum of squares | |
| for (var s = 16u; s > 0u; s >>= 1u) { | |
| if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } | |
| workgroupBarrier(); | |
| } | |
| // rms = 1/sqrt(mean(val²) + eps); the final barrier above makes wg_scratch[0] | |
| // visible to all threads. | |
| let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); | |
| // ── Phase 2: Gate + Up GPTQ MatVec + SiLU (on-the-fly normed values from L2 cache) ── | |
| let col = gid.x; | |
| if (col >= N) { return; } | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| // Up matrix is appended after the gate matrix in the merged buffers. | |
| let up_qw_base = packed_rows * N; | |
| let up_sc_base = num_groups * N; | |
| var gate_sum: f32 = 0.0; | |
| var up_sum: f32 = 0.0; | |
| for (var grp: u32 = 0u; grp < num_groups; grp++) { | |
| let gsf = grp * N + col; | |
| let usf = up_sc_base + grp * N + col; | |
| let sg = unpack2x16float(scales_merged[gsf >> 1u])[gsf & 1u]; | |
| let su = unpack2x16float(scales_merged[usf >> 1u])[usf & 1u]; | |
| let gstart = grp * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let g_off = pr * N + col; | |
| let u_off = up_qw_base + pr * N + col; | |
| let gw0 = qweight_merged[g_off]; let uw0 = qweight_merged[u_off]; | |
| let gw1 = qweight_merged[g_off+N]; let uw1 = qweight_merged[u_off+N]; | |
| let gw2 = qweight_merged[g_off+2u*N]; let uw2 = qweight_merged[u_off+2u*N]; | |
| let gw3 = qweight_merged[g_off+3u*N]; let uw3 = qweight_merged[u_off+3u*N]; | |
| let bk = pr * 8u; | |
| gate_sum += sg * (dq_fly(gw0,bk,rms) + dq_fly(gw1,bk+8u,rms) + dq_fly(gw2,bk+16u,rms) + dq_fly(gw3,bk+24u,rms)); | |
| up_sum += su * (dq_fly(uw0,bk,rms) + dq_fly(uw1,bk+8u,rms) + dq_fly(uw2,bk+16u,rms) + dq_fly(uw3,bk+24u,rms)); | |
| } | |
| } | |
| // SiLU(gate) * up, written as gate * sigmoid(gate). | |
| output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum; | |
| }`,ir=` | |
| // ── Fused residual add + RMSNorm + GPTQ int4 matvec ── | |
| // output[col] = Σ_k norm(hidden_in[k] + residual[k]) * dequant(W[k][col]). | |
| // BF16 norm weights travel in the uniform buffer; every workgroup | |
| // recomputes the same RMS scalar, then each thread produces one column. | |
| struct Params { | |
| K: u32, | |
| N: u32, | |
| group_size: u32, | |
| eps: f32, | |
| norm_weight: array<vec4<u32>, 640> | |
| } | |
| @group(0) @binding(0) var<storage, read> hidden_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> residual: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| // Unpack BF16 norm weight i: each vec4<u32> holds 8 BF16 values; BF16 → f32 | |
| // is a 16-bit left shift into the f32 bit pattern. | |
| fn unpack_norm_weight(i: u32) -> f32 { | |
| let vec_idx = i / 8u; | |
| let u32_idx = (i / 2u) % 4u; | |
| let bf16_idx = i % 2u; | |
| let packed = params.norm_weight[vec_idx][u32_idx]; | |
| let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| // On-the-fly: only 128 bytes shared memory (vs 16KB wg_normed that killed occupancy) | |
| var<workgroup> wg_scratch: array<f32, 32>; | |
| // Normalized element i of (hidden_in + residual), weight applied as (1.0 + w). | |
| fn normed_add(i: u32, rms: f32) -> f32 { | |
| let val = hidden_in[i] + residual[i]; | |
| let w = unpack_norm_weight(i); | |
| return val * rms * (1.0 + w); | |
| } | |
| // Dot one packed int4 word (rows bk..bk+7, zero-point 8) with normed inputs. | |
| fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { | |
| return dot(vec4<f32>(normed_add(bk, rms), normed_add(bk+1u, rms), normed_add(bk+2u, rms), normed_add(bk+3u, rms)), | |
| vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(normed_add(bk+4u, rms), normed_add(bk+5u, rms), normed_add(bk+6u, rms), normed_add(bk+7u, rms)), | |
| vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u, | |
| @builtin(local_invocation_id) lid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| // ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ── | |
| var ss: f32 = 0.0; | |
| for (var i = tid; i < K; i += 32u) { | |
| let val = hidden_in[i] + residual[i]; | |
| ss += val * val; | |
| } | |
| wg_scratch[tid] = ss; | |
| workgroupBarrier(); | |
| // Tree reduction; disjoint read/write halves make each step race-free. | |
| for (var s = 16u; s > 0u; s >>= 1u) { | |
| if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } | |
| workgroupBarrier(); | |
| } | |
| // All threads read the reduced sum after the final barrier. | |
| let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); | |
| // ── Phase 2: GPTQ MatVec with on-the-fly normed values (L2-cached inputs) ── | |
| let col = gid.x; | |
| if (col >= N) { return; } | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| var sum: f32 = 0.0; | |
| for (var grp: u32 = 0u; grp < num_groups; grp++) { | |
| let sf = grp * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = grp * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms)); | |
| } | |
| } | |
| output[col] = sum; | |
| }`,sr=` | |
| // ── Fused RMSNorm (no residual) + GPTQ int4 matvec ── | |
| // Same structure as the residual variant, but normalizes hidden_in alone. | |
| struct Params { | |
| K: u32, | |
| N: u32, | |
| group_size: u32, | |
| eps: f32, | |
| norm_weight: array<vec4<u32>, 640> | |
| } | |
| @group(0) @binding(0) var<storage, read> hidden_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(2) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| // Unpack BF16 norm weight i (8 BF16 per vec4<u32>; shift into f32 high bits). | |
| fn unpack_norm_weight(i: u32) -> f32 { | |
| let vec_idx = i / 8u; | |
| let u32_idx = (i / 2u) % 4u; | |
| let bf16_idx = i % 2u; | |
| let packed = params.norm_weight[vec_idx][u32_idx]; | |
| let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| var<workgroup> wg_scratch: array<f32, 32>; | |
| // Normalized element i of hidden_in, norm weight applied as (1.0 + w). | |
| fn normed_val(i: u32, rms: f32) -> f32 { | |
| let w = unpack_norm_weight(i); | |
| return hidden_in[i] * rms * (1.0 + w); | |
| } | |
| // Dot one packed int4 word (rows bk..bk+7, zero-point 8) with normed inputs. | |
| fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { | |
| return dot(vec4<f32>(normed_val(bk, rms), normed_val(bk+1u, rms), normed_val(bk+2u, rms), normed_val(bk+3u, rms)), | |
| vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(normed_val(bk+4u, rms), normed_val(bk+5u, rms), normed_val(bk+6u, rms), normed_val(bk+7u, rms)), | |
| vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u, | |
| @builtin(local_invocation_id) lid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| // ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ── | |
| var ss: f32 = 0.0; | |
| for (var i = tid; i < K; i += 32u) { | |
| let val = hidden_in[i]; | |
| ss += val * val; | |
| } | |
| wg_scratch[tid] = ss; | |
| workgroupBarrier(); | |
| // Tree reduction over the 32 partial sums. | |
| for (var s = 16u; s > 0u; s >>= 1u) { | |
| if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } | |
| workgroupBarrier(); | |
| } | |
| let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); | |
| // ── Phase 2: GPTQ MatVec with on-the-fly normed values ── | |
| let col = gid.x; | |
| if (col >= N) { return; } | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| var sum: f32 = 0.0; | |
| for (var grp: u32 = 0u; grp < num_groups; grp++) { | |
| let sf = grp * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = grp * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms)); | |
| } | |
| } | |
| output[col] = sum; | |
| }`,or=` | |
| // ── Fused residual RMSNorm + GPTQ int4 matvec, wg_rms-broadcast variant ── | |
| // Differs from the per-thread-rms variant by broadcasting the RMS scalar | |
| // through a workgroup variable (thread 0 writes, barrier, all read). | |
| struct Params { | |
| K: u32, N: u32, group_size: u32, eps: f32, | |
| norm_weight: array<vec4<u32>, 640> | |
| } | |
| @group(0) @binding(0) var<storage, read> hidden_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> residual: array<f32>; | |
| @group(0) @binding(2) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(3) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(4) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(5) var<uniform> params: Params; | |
| // Unpack BF16 norm weight i (8 BF16 per vec4<u32>; shift into f32 high bits). | |
| fn unpack_nw(i: u32) -> f32 { | |
| let vec_idx = i / 8u; | |
| let u32_idx = (i / 2u) % 4u; | |
| let bf16_idx = i % 2u; | |
| let packed = params.norm_weight[vec_idx][u32_idx]; | |
| let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| var<workgroup> wg_scratch: array<f32, 32>; | |
| var<workgroup> wg_rms: f32; | |
| // Normalized element i of (hidden_in + residual); reads the broadcast wg_rms. | |
| fn nv(i: u32) -> f32 { | |
| return (hidden_in[i] + residual[i]) * wg_rms * (1.0 + unpack_nw(i)); | |
| } | |
| // Dot one packed int4 word (rows bk..bk+7, zero-point 8) with normed inputs. | |
| fn dq_nv(w: u32, bk: u32) -> f32 { | |
| return dot(vec4<f32>(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)), | |
| vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)), | |
| vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u, | |
| @builtin(local_invocation_id) lid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| // Phase 1: compute RMS (132B shared memory total) | |
| var ss: f32 = 0.0; | |
| for (var i = tid; i < K; i += 32u) { | |
| let val = hidden_in[i] + residual[i]; | |
| ss += val * val; | |
| } | |
| wg_scratch[tid] = ss; | |
| workgroupBarrier(); | |
| // Tree reduction over the 32 partial sums. | |
| for (var s = 16u; s > 0u; s >>= 1u) { | |
| if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } | |
| workgroupBarrier(); | |
| } | |
| // Broadcast 1/sqrt(mean + eps) to the whole workgroup via wg_rms. | |
| if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); } | |
| workgroupBarrier(); | |
| // Phase 2: GPTQ matvec with inline norm (re-reads input from L1 cache) | |
| let col = gid.x; | |
| if (col >= N) { return; } | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| var sum: f32 = 0.0; | |
| for (var grp: u32 = 0u; grp < num_groups; grp++) { | |
| let sf = grp * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = grp * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u)); | |
| } | |
| } | |
| output[col] = sum; | |
| }`,nr=` | |
| // ── Fused RMSNorm (no residual) + GPTQ int4 matvec, wg_rms-broadcast variant ── | |
| // Normalizes hidden_in alone; RMS scalar broadcast through wg_rms. | |
| struct Params { | |
| K: u32, N: u32, group_size: u32, eps: f32, | |
| norm_weight: array<vec4<u32>, 640> | |
| } | |
| @group(0) @binding(0) var<storage, read> hidden_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(2) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| // Unpack BF16 norm weight i (8 BF16 per vec4<u32>; shift into f32 high bits). | |
| fn unpack_nw(i: u32) -> f32 { | |
| let packed = params.norm_weight[i / 8u][(i / 2u) % 4u]; | |
| let bits = (packed >> ((i % 2u) * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| var<workgroup> wg_scratch: array<f32, 32>; | |
| var<workgroup> wg_rms: f32; | |
| // Normalized element i; norm weight applied as (1.0 + w). | |
| fn nv(i: u32) -> f32 { | |
| return hidden_in[i] * wg_rms * (1.0 + unpack_nw(i)); | |
| } | |
| // Dot one packed int4 word (rows bk..bk+7, zero-point 8) with normed inputs. | |
| fn dq_nv(w: u32, bk: u32) -> f32 { | |
| return dot(vec4<f32>(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)), | |
| vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)), | |
| vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u, | |
| @builtin(local_invocation_id) lid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| // Phase 1: sum of squares, strided across the 32 threads. | |
| var ss: f32 = 0.0; | |
| for (var i = tid; i < K; i += 32u) { | |
| ss += hidden_in[i] * hidden_in[i]; | |
| } | |
| wg_scratch[tid] = ss; | |
| workgroupBarrier(); | |
| for (var s = 16u; s > 0u; s >>= 1u) { | |
| if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } | |
| workgroupBarrier(); | |
| } | |
| // Broadcast 1/sqrt(mean + eps) to the whole workgroup. | |
| if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); } | |
| workgroupBarrier(); | |
| // Phase 2: one output column per thread. | |
| let col = gid.x; | |
| if (col >= N) { return; } | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| var sum: f32 = 0.0; | |
| for (var grp: u32 = 0u; grp < num_groups; grp++) { | |
| let sf = grp * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = grp * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; | |
| let bk = pr * 8u; | |
| sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u)); | |
| } | |
| } | |
| output[col] = sum; | |
| }`,gr=` | |
| // ── In-place repetition + presence penalty over logits ── | |
| // One thread per entry of recent_tokens. | |
| // NOTE(review): if recent_tokens contains duplicate token ids, multiple | |
| // threads read-modify-write the same logits[tok] without synchronization, | |
| // so the penalty may be applied a nondeterministic number of times — | |
| // confirm the host deduplicates the list before dispatch. | |
| struct Params { num_tokens: u32, penalty: f32, presence: f32, _pad: u32, } | |
| @group(0) @binding(0) var<storage, read_write> logits: array<f32>; | |
| @group(0) @binding(1) var<storage, read> recent_tokens: array<u32>; | |
| @group(0) @binding(2) var<uniform> params: Params; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let i = gid.x; | |
| if (i >= params.num_tokens) { return; } | |
| let tok = recent_tokens[i]; | |
| var val = logits[tok]; | |
| // Repetition penalty (multiplicative): shrink positive logits, amplify | |
| // negative ones; skipped entirely when penalty <= 1.0. | |
| if (params.penalty > 1.0) { | |
| if (val > 0.0) { | |
| val = val / params.penalty; | |
| } else { | |
| val = val * params.penalty; | |
| } | |
| } | |
| // Presence penalty (additive); subtracting 0.0 is a no-op when disabled. | |
| val -= params.presence; | |
| logits[tok] = val; | |
| }`,dr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Single-workgroup top-K / top-P (nucleus) sampler over 256 candidates.
// `candidates` holds 256 (token_idx: u32, logit: f32-as-u32) pairs; one
// 256-thread workgroup sorts them by logit, applies a temperature softmax over
// the top-K, then thread 0 picks a token by inverse-CDF sampling into `result`.
// ─────────────────────────────────────────────────────────────────────────────
struct Params {
temperature: f32,
top_k: u32,
top_p: f32,
random_seed: u32,
}
struct Result { idx: u32, val: f32, }
@group(0) @binding(0) var<storage, read> candidates: array<u32>; // 256 × (idx u32, val u32)
@group(0) @binding(1) var<storage, read_write> result: Result;
@group(0) @binding(2) var<uniform> params: Params;
var<workgroup> wg_idx: array<u32, 256>;
var<workgroup> wg_val: array<f32, 256>;
var<workgroup> wg_prob: array<f32, 256>;
// PCG random number generator
fn pcg(state: u32) -> u32 {
let s = state * 747796405u + 2891336453u;
let word = ((s >> ((s >> 28u) + 4u)) ^ s) * 277803737u;
return (word >> 22u) ^ word;
}
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
// Load candidate (idx, logit) pair
wg_idx[tid] = candidates[tid * 2u];
wg_val[tid] = bitcast<f32>(candidates[tid * 2u + 1u]);
workgroupBarrier();
// ── Step 1: Parallel find max logit (for numerical stability) ──
var local_max = wg_val[tid];
workgroupBarrier();
// Tree reduction for max
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) {
local_max = max(local_max, wg_val[tid + s]);
wg_val[tid] = local_max;
}
workgroupBarrier();
}
let max_val = wg_val[0];
// FIX: barrier between reading wg_val[0] and the reload below. Without it,
// thread 0 can overwrite wg_val[0] before slower threads have read max_val
// (write-after-read hazard on workgroup memory).
workgroupBarrier();
// Reload original values (they were clobbered by reduction)
wg_val[tid] = bitcast<f32>(candidates[tid * 2u + 1u]);
workgroupBarrier();
// ── Step 2: Bitonic sort by logit value (descending) ──
// Sort all 256 candidates so top-K and top-P can work on sorted order.
for (var k = 2u; k <= 256u; k <<= 1u) {
for (var j = k >> 1u; j > 0u; j >>= 1u) {
let ixj = tid ^ j;
if (ixj > tid) {
let ascending = (tid & k) != 0u;
let val_tid = wg_val[tid];
let val_ixj = wg_val[ixj];
// We want descending, so swap if tid has smaller value
let should_swap = select((val_tid < val_ixj), (val_tid > val_ixj), ascending);
if (should_swap) {
wg_val[tid] = val_ixj;
wg_val[ixj] = val_tid;
let tmp_idx = wg_idx[tid];
wg_idx[tid] = wg_idx[ixj];
wg_idx[ixj] = tmp_idx;
}
}
workgroupBarrier();
}
}
// ── Step 3: Temperature + softmax (parallel exp, parallel prefix sum) ──
let temp = max(params.temperature, 0.001); // Avoid division by zero
// FIX: clamp top_k to at least 1. A top_k of 0 would make `cutoff` 0 below
// and index wg_prob[cutoff - 1u] with a wrapped (out-of-range) subscript.
let K = clamp(params.top_k, 1u, 256u);
// Each thread computes exp((val - max) / temp) if within top-K
var prob: f32 = 0.0;
if (tid < K) {
prob = exp((wg_val[tid] - max_val) / temp);
}
wg_prob[tid] = prob;
workgroupBarrier();
// ── Step 4: Parallel prefix sum for cumulative probabilities ──
// Hillis–Steele inclusive scan (the original comment said Blelloch; this is
// the read-barrier-write doubling-stride scan, not an up/down-sweep).
var sum_val = prob;
for (var s = 1u; s < 256u; s <<= 1u) {
workgroupBarrier();
if (tid >= s) {
sum_val = wg_prob[tid - s] + wg_prob[tid];
}
workgroupBarrier();
wg_prob[tid] = sum_val;
}
// FIX: barrier BEFORE reading the scan result. The final loop iteration ends
// with an unsynchronized write to wg_prob[tid]; reading wg_prob[255] here (and
// thread 0's per-element reads below) must happen-after those writes. The
// original placed this barrier after the read, which does not order it.
workgroupBarrier();
// Now wg_prob[i] = cumulative sum up to and including index i
// wg_prob[255] = total sum
let total_sum = wg_prob[255];
// ── Step 5: Thread 0 does top-P + sampling ──
// This is sequential but operates on sorted, cumulative data — very fast.
if (tid == 0u) {
let topP = params.top_p;
let threshold = topP * total_sum; // Un-normalized threshold
// Find cutoff: first index where cumulative prob >= threshold
var cutoff = K;
for (var i = 0u; i < K; i++) {
if (wg_prob[i] >= threshold) {
cutoff = i + 1u;
break;
}
}
// Sample: generate random float in [0, cumProb[cutoff-1])
let rand_u32 = pcg(params.random_seed);
let rand_f32 = f32(rand_u32) / 4294967296.0; // [0, 1)
let sample_point = rand_f32 * wg_prob[cutoff - 1u];
// Find sampled index via linear scan of cumulative probs
var selected = 0u;
for (var i = 0u; i < cutoff; i++) {
if (wg_prob[i] >= sample_point) {
selected = i;
break;
}
}
result.idx = wg_idx[selected];
result.val = wg_val[selected];
}
| }`;function _(d=320,e=1e7,u=33,t=30,r=128){return` | |
// ─────────────────────────────────────────────────────────────────────────────
// Fused per-head QK processing for single-token decode:
//   · Q heads (workgroups 0..num_heads-1): de-interleave Q|gate, RMSNorm Q,
//     apply partial interleaved mRoPE, write q_proj and q_gate.
//   · KV heads (workgroups num_heads..num_heads+num_kv_heads-1): RMSNorm K,
//     apply the same mRoPE, write k_proj plus the K/V cache slot.
// Constants below are baked in by the generating JS function (theta, mRoPE
// section limits, rotary partial dimension, packed-norm-weight array length).
// ─────────────────────────────────────────────────────────────────────────────
const ROPE_THETA: f32 = ${e};
const MROPE_S1_LIMIT: u32 = ${u}u;
const MROPE_S2_LIMIT: u32 = ${t}u;
const PARTIAL_DIM: u32 = ${r}u;
struct Params {
num_heads: u32, // Q heads
num_kv_heads: u32, // KV heads
head_dim: u32,
eps: f32,
cache_position: u32, // KV cache write index (sequential, updated per token)
position: u32, // RoPE temporal/text position (updated per token)
position_h: u32, // mRoPE height dimension (= position for text tokens)
position_w: u32, // mRoPE width dimension (= position for text tokens)
// Packed BF16 norm weights: [Q norm weights | K norm weights]
// Q: numHeads*headDim BF16 values, K: numKVHeads*headDim BF16 values
qk_norm_weight: array<vec4<u32>, ${d}>,
}
@group(0) @binding(0) var<storage, read> q_proj_full: array<f32>;
@group(0) @binding(1) var<storage, read_write> k_proj: array<f32>;
@group(0) @binding(2) var<storage, read> v_proj: array<f32>;
@group(0) @binding(3) var<storage, read_write> q_proj: array<f32>;
@group(0) @binding(4) var<storage, read_write> q_gate: array<f32>;
@group(0) @binding(5) var<storage, read_write> k_cache: array<f32>;
@group(0) @binding(6) var<storage, read_write> v_cache: array<f32>;
@group(0) @binding(7) var<uniform> params: Params;
// Expand one BF16 value (two packed per u32, eight per vec4<u32>) to f32 by
// shifting its 16 bits into the high half of an f32 bit pattern.
fn unpack_norm_bf16(flat_idx: u32) -> f32 {
let vec_idx = flat_idx / 8u;
let u32_idx = (flat_idx / 2u) % 4u;
let bf16_idx = flat_idx % 2u;
let packed = params.qk_norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_data: array<f32, 256>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
let nh = params.num_heads;
let nkv = params.num_kv_heads;
if (h < nh) {
// ── Q head: split from interleaved + RMSNorm ──
// q_proj_full interleaves [q (hd) | gate (hd)] per head; gate is copied
// straight out, q is staged in wg_data for normalization.
let src_off = h * hd * 2u;
// Phase 1: Read Q values, compute ss, write gate
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
let q_val = q_proj_full[src_off + i];
wg_data[i] = q_val;
ss += q_val * q_val;
q_gate[h * hd + i] = q_proj_full[src_off + hd + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction of the per-thread sums of squares.
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Phase 2: Normalize Q and write to shared mem
// Each thread rewrites only the wg_data slots it loaded itself, so no
// barrier is needed between the read and write of wg_data[i] here.
for (var i = tid; i < hd; i += 256u) {
let w = unpack_norm_bf16(h * hd + i);
wg_data[i] = wg_data[i] * rms * (1.0 + w);
}
workgroupBarrier();
// Phase 3: Apply mRoPE to Q (rotate_half: dim i pairs with dim i+half_dim)
// mRoPE: interleaved position selection based on freq_idx
// NOTE(review): frequencies with freq_idx % 3 == 1 / == 2 (below the
// section limits) use the H / W positions — assumes the host bakes the
// matching interleaved-mRoPE layout; confirm against the model config.
let q_half = PARTIAL_DIM / 2u;
for (var i = tid; i < hd; i += 256u) {
var q_out = wg_data[i];
if (i < PARTIAL_DIM) {
let freq_idx = i % q_half;
let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM));
// mRoPE: select position based on interleaved section
var pos = params.position;
if (freq_idx % 3u == 1u && freq_idx < MROPE_S1_LIMIT) { pos = params.position_h; }
else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; }
let angle = f32(pos) * freq;
let cos_a = cos(angle);
let sin_a = sin(angle);
if (i < q_half) {
q_out = wg_data[i] * cos_a - wg_data[i + q_half] * sin_a;
} else {
q_out = wg_data[i] * cos_a + wg_data[i - q_half] * sin_a;
}
}
q_proj[h * hd + i] = q_out;
}
} else {
// ── KV head: K norm + KV cache store ──
let kv_h = h - nh;
// Uniform per-workgroup early out for padding workgroups (h is uniform,
// so the barriers below stay in uniform control flow).
if (kv_h >= nkv) { return; }
let off = kv_h * hd;
let q_norm_size = nh * hd; // offset to K norm weights in packed array
// Phase 1: Read K, compute ss
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
let k_val = k_proj[off + i];
wg_data[i] = k_val;
ss += k_val * k_val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Phase 2: Normalize K, write to shared mem
for (var i = tid; i < hd; i += 256u) {
let w = unpack_norm_bf16(q_norm_size + kv_h * hd + i);
wg_data[i] = wg_data[i] * rms * (1.0 + w);
}
workgroupBarrier();
// Phase 3: Apply mRoPE to K, write to kProj + k_cache, copy V to v_cache
let total = nkv * hd;
let cache_off = params.cache_position * total + off;
let k_half = PARTIAL_DIM / 2u;
for (var i = tid; i < hd; i += 256u) {
var k_out = wg_data[i];
if (i < PARTIAL_DIM) {
let freq_idx = i % k_half;
let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM));
// mRoPE: select position based on interleaved section
var pos = params.position;
if (freq_idx % 3u == 1u && freq_idx < MROPE_S1_LIMIT) { pos = params.position_h; }
else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; }
let angle = f32(pos) * freq;
let cos_a = cos(angle);
let sin_a = sin(angle);
if (i < k_half) {
k_out = wg_data[i] * cos_a - wg_data[i + k_half] * sin_a;
} else {
k_out = wg_data[i] * cos_a + wg_data[i - k_half] * sin_a;
}
}
k_proj[off + i] = k_out;
k_cache[cache_off + i] = k_out;
v_cache[cache_off + i] = v_proj[off + i];
}
}
| }`}const lr=_(320),pr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// INT4 group-quantized matrix–vector product, f32 accumulation path.
// Weights: 8 unsigned 4-bit codes per u32 (zero-point 8), per-group f16 scale
// (two scales packed per u32). Packed row `pr`, column `col` lives at
// qweight[pr * N + col]. Each 32-thread workgroup produces 8 output columns;
// 4 lanes split the quantization groups per column, reduced via shared memory.
// ─────────────────────────────────────────────────────────────────────────────
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dot product of 8 consecutive inputs starting at bk with the 8 nibbles of w.
fn dq(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u; // 0-7: which of 8 columns this thread handles
let lane = tid & 3u; // 0-3: which quarter of K this thread computes
let col = wid.x * 8u + col_in_wg;
var sum: f32 = 0.0;
if (col < N) {
let groups_per_lane = num_groups / 4u;
let g_start = lane * groups_per_lane;
// FIX: the last lane absorbs the remainder when num_groups is not a
// multiple of 4 (previously the trailing groups were silently dropped).
// Identical behavior when num_groups % 4 == 0; matches the split variant.
let g_end = select(g_start + groups_per_lane, num_groups, lane >= 3u);
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
dq(w2, bk+16u) + dq(w3, bk+24u));
}
}
}
// Store partial sum and reduce within each 4-thread column group
// With @workgroup_size(32) on Apple GPU, all 32 threads are in 1 SIMD group
// so workgroupBarrier is effectively free.
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
| }`,fr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// INT4 group-quantized matvec, f16 dot-product variant of the f32 kernel
// above (same weight/scale layout). Requires the "shader-f16" device feature.
// ─────────────────────────────────────────────────────────────────────────────
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dot of 8 consecutive inputs with the 8 nibbles of w (zero-point 8),
// computed as two f16 dot products and widened back to f32.
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u;
let lane = tid & 3u;
let col = wid.x * 8u + col_in_wg;
var sum: f32 = 0.0;
if (col < N) {
let groups_per_lane = num_groups / 4u;
let g_start = lane * groups_per_lane;
// FIX: last lane absorbs the remainder groups when num_groups % 4 != 0
// (previously they were dropped). No-op when evenly divisible.
let g_end = select(g_start + groups_per_lane, num_groups, lane >= 3u);
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
dq16(w2, bk+16u) + dq16(w3, bk+24u));
}
}
}
// Reduce the 4 lane partials per column through workgroup memory.
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
| }`,_r=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Split-K variant of the INT4 matvec: the quantization groups of each column
// are divided across `num_splits` workgroups in wid.y; each split writes its
// partial sum to partials[split * N + col] (reduced by a later pass — the
// reduction kernel is not in this file's visible portion).
// Uses a subgroup butterfly instead of shared memory for the 4-lane reduce.
// ─────────────────────────────────────────────────────────────────────────────
enable f16;
enable subgroups;
struct Params { K: u32, N: u32, group_size: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> partials: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Dot of 8 consecutive inputs with the 8 nibbles of w (zero-point 8), in f16.
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let ns = params.num_splits;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u;
let lane = tid & 3u;
let col = wid.x * 8u + col_in_wg;
let split_id = wid.y;
var sum: f32 = 0.0;
if (col < N) {
// Divide groups across splits, then across 4 lanes within each split
// (the last split and the last lane each absorb any remainder).
let groups_per_split = num_groups / ns;
let g_base = split_id * groups_per_split;
let split_end = select(g_base + groups_per_split, num_groups, split_id >= ns - 1u);
let split_groups = split_end - g_base;
let groups_per_lane = split_groups / 4u;
let g_start = g_base + lane * groups_per_lane;
let g_end = select(g_start + groups_per_lane, split_end, lane >= 3u);
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
dq16(w2, bk+16u) + dq16(w3, bk+24u));
}
}
}
// Subgroup reduction across 4 lanes (no shared memory, no barrier)
// NOTE(review): the xor-1/xor-2 butterfly assumes each column's 4 lanes sit
// in the same subgroup at 4-aligned invocation ids — holds when subgroup
// size >= 4 and the 32-wide workgroup maps to whole subgroups (true on the
// Apple GPUs this path targets); verify on other adapters.
var s = sum;
s += subgroupShuffleXor(s, 1u);
s += subgroupShuffleXor(s, 2u);
if (lane == 0u && col < N) {
partials[split_id * N + col] = s;
}
| }`,cr=` | |
// Row gather: output row i is input row indices[i], copied element-wise.
// One thread per (column, output-row) cell of the dispatch grid.
struct Params { H_half: u32, num_tokens: u32, }
@group(0) @binding(0) var<storage, read> src: array<u32>;
@group(0) @binding(1) var<storage, read> indices: array<u32>;
@group(0) @binding(2) var<storage, read_write> dst: array<u32>;
@group(0) @binding(3) var<uniform> p: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let column = gid.x;
let out_row = gid.y;
// Threads padding out the dispatch grid do nothing.
if (column >= p.H_half) { return; }
if (out_row >= p.num_tokens) { return; }
let in_row = indices[out_row];
dst[out_row * p.H_half + column] = src[in_row * p.H_half + column];
}
| `,wr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// INT4 group quantization of BF16 weights, one 32-thread workgroup per output
// column. Per group of `group_size` values: find max|w|, derive a symmetric
// scale (codes 0..15, zero-point 8), pack 8 codes per u32 into `qweight`.
// Scales are stored as f32 here and packed to f16 in a separate pass.
// ─────────────────────────────────────────────────────────────────────────────
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> bf16_weight: array<u32>;
@group(0) @binding(1) var<storage, read_write> qweight: array<u32>;
@group(0) @binding(2) var<storage, read_write> scales_f32: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> wg_max: array<f32, 32>;
// Expand one BF16 value (two packed per u32) to f32 via a high-half shift.
fn unpack_bf16(idx: u32, base: u32) -> f32 {
let packed = bf16_weight[base + idx / 2u];
let bits = (packed >> ((idx % 2u) * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(32)
fn main(@builtin(workgroup_id) wg_id: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let col = wg_id.x + wg_id.y * 65535u; // 2D dispatch for N > 65535
let tid = lid.x;
let K = params.K;
let N = params.N;
if (col >= N) { return; }
let gs = params.group_size;
let num_groups = K / gs;
let base = col * K / 2u;
for (var g = 0u; g < num_groups; g++) {
let gstart = g * gs;
// FIX: synchronize before reusing wg_max for this group. Without this
// barrier a fast thread can overwrite wg_max[tid] (notably wg_max[0])
// for group g while a slower thread is still reading the previous
// group's reduced maximum at `let scale_val = wg_max[0] / 7.0` below
// (write-after-read hazard — the only prior barrier follows the write).
workgroupBarrier();
// Find max_abs (parallel reduction across 32 threads)
var local_max: f32 = 0.0;
for (var i = tid; i < gs; i += 32u) {
local_max = max(local_max, abs(unpack_bf16(gstart + i, base)));
}
wg_max[tid] = local_max;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_max[tid] = max(wg_max[tid], wg_max[tid + s]); }
workgroupBarrier();
}
let scale_val = wg_max[0] / 7.0;
let inv_scale = select(1.0 / scale_val, 0.0, scale_val == 0.0);
// Store scale as f32 (packed to F16 in separate pass)
if (tid == 0u) { scales_f32[g * N + col] = scale_val; }
// Quantize and pack 8 INT4 values per u32
let rows_per_group = gs / 8u;
for (var r = tid; r < rows_per_group; r += 32u) {
var packed: u32 = 0u;
let bk = gstart + r * 8u;
for (var b = 0u; b < 8u; b++) {
let val = unpack_bf16(bk + b, base);
var q = i32(round(val * inv_scale)) + 8;
q = clamp(q, 0, 15);
packed |= (u32(q) << (b * 4u));
}
let packed_row = g * rows_per_group + r;
qweight[packed_row * N + col] = packed;
}
}
| }`,br=` | |
// Pack pairs of f32 values into f16 halves of a u32, one output word per
// thread: output[i] = pack2x16float(input[2i], input[2i+1]).
struct Params { count: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.count) { return; }
let pair = vec2<f32>(input[2u * i], input[2u * i + 1u]);
output[i] = pack2x16float(pair);
| }`,vr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Fused residual-add + RMSNorm, in-place variant. One workgroup per row:
// hidden[row] += addend[row] is written back into `hidden`, then the
// normalized value scaled by (1 + weight) goes to `output`. Norm weights are
// BF16, two packed per u32.
// ─────────────────────────────────────────────────────────────────────────────
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Expand one of the two BF16 halves of `packed` to f32.
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
// Phase 1: residual add (written back to hidden) + per-thread sum of squares.
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[off + i] + addend[off + i];
hidden[off + i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction of the sums of squares; barrier after each step.
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Phase 2: each thread re-reads only hidden elements it wrote itself
// (same i ↔ tid mapping), so no extra synchronization is required.
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[off + i] = hidden[off + i] * rms * (1.0 + w);
}
| }`,mr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Fused residual-add + RMSNorm, read-only variant: unlike the in-place kernel
// above, `hidden` is not modified; the summed row is staged in workgroup
// memory (wg_vals) and only the normalized result is written to `output`.
// wg_vals capacity is 7680 floats, so N must not exceed 7680 — TODO confirm
// the host enforces this.
// ─────────────────────────────────────────────────────────────────────────────
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
// Expand one of the two BF16 halves of `packed` to f32.
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
// Phase 1: stage hidden + addend in workgroup memory, accumulate squares.
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[off + i] + addend[off + i];
wg_vals[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Phase 2: each thread reads back only the wg_vals slots it wrote itself.
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[off + i] = wg_vals[i] * rms * (1.0 + w);
}
| }`,kr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Three-way residual-add + RMSNorm. One workgroup per row: computes
// a + b + c, writes the raw sum to `hidden_out` AND the RMS-normalized,
// (1 + weight)-scaled value to `normed`. Staged through wg_vals (capacity
// 7680 floats, so N must not exceed 7680 — TODO confirm host guarantees).
// ─────────────────────────────────────────────────────────────────────────────
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read> input_c: array<f32>;
@group(0) @binding(3) var<storage, read> weight: array<u32>;
@group(0) @binding(4) var<storage, read_write> hidden_out: array<f32>;
@group(0) @binding(5) var<storage, read_write> normed: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// Expand one of the two BF16 halves of `packed` to f32.
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
// Phase 1: three-way sum, stored both to hidden_out and workgroup memory.
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = input_a[off + i] + input_b[off + i] + input_c[off + i];
wg_vals[i] = val;
hidden_out[off + i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Phase 2: each thread reads back only the wg_vals slots it wrote itself.
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
normed[off + i] = wg_vals[i] * rms * (1.0 + w);
}
| }`,hr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// Fused MLP gate+up INT4 matvec for a batch of 2 input rows (offsets 0 and K
// in `input`). One thread per output column walks the full K dimension,
// computing gate·x and up·x simultaneously (shared qweight index math), then
// applies SiLU gating: out = silu(gate) * up. Outputs the two batch rows at
// output[col] and output[N + col].
// ─────────────────────────────────────────────────────────────────────────────
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
// Dot of 8 inputs (starting at base+bk) with the 8 nibbles of w, in f16.
fn dq16_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate0: f32 = 0.0; var up0: f32 = 0.0;
var gate1: f32 = 0.0; var up1: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
// Batch 0: input at offset 0
gate0 += sg * (dq16_at(gw0,0u,bk) + dq16_at(gw1,0u,bk+8u) + dq16_at(gw2,0u,bk+16u) + dq16_at(gw3,0u,bk+24u));
up0 += su * (dq16_at(uw0,0u,bk) + dq16_at(uw1,0u,bk+8u) + dq16_at(uw2,0u,bk+16u) + dq16_at(uw3,0u,bk+24u));
// Batch 1: input at offset K
gate1 += sg * (dq16_at(gw0,K,bk) + dq16_at(gw1,K,bk+8u) + dq16_at(gw2,K,bk+16u) + dq16_at(gw3,K,bk+24u));
up1 += su * (dq16_at(uw0,K,bk) + dq16_at(uw1,K,bk+8u) + dq16_at(uw2,K,bk+16u) + dq16_at(uw3,K,bk+24u));
}
}
// SwiGLU: silu(gate) * up, per batch row (x/(1+e^-x) is SiLU).
output[col] = (gate0 / (1.0 + exp(-gate0))) * up0;
output[N + col] = (gate1 / (1.0 + exp(-gate1))) * up1;
| }`,xr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// INT4 matvec whose input activation is formed on the fly as silu(a) * b
// (SwiGLU down-projection), for a batch of 2 rows (offsets 0 and K). One
// thread per output column; results at output[col] and output[N + col].
// ─────────────────────────────────────────────────────────────────────────────
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
// Dot of 8 fused activations silu(a)*b (starting at base+bk) with the 8
// nibbles of w (zero-point 8), computed in f16.
fn dq16_sl_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(silu(a[base+bk])*b[base+bk]), f16(silu(a[base+bk+1u])*b[base+bk+1u]), f16(silu(a[base+bk+2u])*b[base+bk+2u]), f16(silu(a[base+bk+3u])*b[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(silu(a[base+bk+4u])*b[base+bk+4u]), f16(silu(a[base+bk+5u])*b[base+bk+5u]), f16(silu(a[base+bk+6u])*b[base+bk+6u]), f16(silu(a[base+bk+7u])*b[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum0 += s * (dq16_sl_at(w0,0u,bk) + dq16_sl_at(w1,0u,bk+8u) + dq16_sl_at(w2,0u,bk+16u) + dq16_sl_at(w3,0u,bk+24u));
sum1 += s * (dq16_sl_at(w0,K,bk) + dq16_sl_at(w1,K,bk+8u) + dq16_sl_at(w2,K,bk+16u) + dq16_sl_at(w3,K,bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
| }`,qr=` | |
// ─────────────────────────────────────────────────────────────────────────────
// INT4 matvec whose input activation is formed on the fly as x * sigmoid(gate)
// (sigmoid-gated variant of the kernel above), for a batch of 2 rows (offsets
// 0 and K). One thread per output column; results at output[col] and
// output[N + col].
// ─────────────────────────────────────────────────────────────────────────────
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
// Dot of 8 fused activations x*sigmoid(gate) (starting at base+bk) with the
// 8 nibbles of w (zero-point 8), computed in f16.
fn dq16_sg_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(x[base+bk]*sig(gate[base+bk])), f16(x[base+bk+1u]*sig(gate[base+bk+1u])), f16(x[base+bk+2u]*sig(gate[base+bk+2u])), f16(x[base+bk+3u]*sig(gate[base+bk+3u])));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(x[base+bk+4u]*sig(gate[base+bk+4u])), f16(x[base+bk+5u]*sig(gate[base+bk+5u])), f16(x[base+bk+6u]*sig(gate[base+bk+6u])), f16(x[base+bk+7u]*sig(gate[base+bk+7u])));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum0 += s * (dq16_sg_at(w0,0u,bk) + dq16_sg_at(w1,0u,bk+8u) + dq16_sg_at(w2,0u,bk+16u) + dq16_sg_at(w3,0u,bk+24u));
sum1 += s * (dq16_sg_at(w0,K,bk) + dq16_sg_at(w1,K,bk+8u) + dq16_sg_at(w2,K,bk+16u) + dq16_sg_at(w3,K,bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
| }`,yr=` | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> input: array<f32>; | |
| @group(0) @binding(1) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(2) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| fn dq_at(w: u32, base: u32, bk: u32) -> f32 { | |
| return dot(vec4<f32>(input[base+bk], input[base+bk+1u], input[base+bk+2u], input[base+bk+3u]), | |
| vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) | |
| + dot(vec4<f32>(input[base+bk+4u], input[base+bk+5u], input[base+bk+6u], input[base+bk+7u]), | |
| vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let col = gid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| if (col >= N) { return; } | |
| var sum0: f32 = 0.0; | |
| var sum1: f32 = 0.0; | |
| for (var g: u32 = 0u; g < num_groups; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; | |
| let w1 = qweight[off + N]; | |
| let w2 = qweight[off + 2u * N]; | |
| let w3 = qweight[off + 3u * N]; | |
| let bk = pr * 8u; | |
| sum0 += s * (dq_at(w0, 0u, bk) + dq_at(w1, 0u, bk+8u) + | |
| dq_at(w2, 0u, bk+16u) + dq_at(w3, 0u, bk+24u)); | |
| sum1 += s * (dq_at(w0, K, bk) + dq_at(w1, K, bk+8u) + | |
| dq_at(w2, K, bk+16u) + dq_at(w3, K, bk+24u)); | |
| } | |
| } | |
| output[col] = sum0; | |
| output[N + col] = sum1; | |
| }`,Fr=` | |
| enable f16; | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> input: array<f32>; | |
| @group(0) @binding(1) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(2) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| fn dq16_at(w: u32, base: u32, bk: u32) -> f32 { | |
| let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u])); | |
| let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); | |
| let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u])); | |
| let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); | |
| return f32(dot(i0, w0)) + f32(dot(i1, w1)); | |
| } | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let col = gid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| if (col >= N) { return; } | |
| var sum0: f32 = 0.0; | |
| var sum1: f32 = 0.0; | |
| for (var g: u32 = 0u; g < num_groups; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; | |
| let w1 = qweight[off + N]; | |
| let w2 = qweight[off + 2u * N]; | |
| let w3 = qweight[off + 3u * N]; | |
| let bk = pr * 8u; | |
| sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) + | |
| dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u)); | |
| sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) + | |
| dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u)); | |
| } | |
| } | |
| output[col] = sum0; | |
| output[N + col] = sum1; | |
| }`,Nr=` | |
| enable f16; | |
| struct Params { K: u32, N: u32, group_size: u32, } | |
| @group(0) @binding(0) var<storage, read> input: array<f32>; | |
| @group(0) @binding(1) var<storage, read> qweight: array<u32>; | |
| @group(0) @binding(2) var<storage, read> scales: array<u32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| fn dq16_at(w: u32, base: u32, bk: u32) -> f32 { | |
| let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u])); | |
| let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); | |
| let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u])); | |
| let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); | |
| return f32(dot(i0, w0)) + f32(dot(i1, w1)); | |
| } | |
| var<workgroup> partial0: array<f32, 32>; | |
| var<workgroup> partial1: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let tid = lid.x; | |
| let K = params.K; | |
| let N = params.N; | |
| let gs = params.group_size; | |
| let packed_rows = K / 8u; | |
| let rows_per_group = gs / 8u; | |
| let num_groups = packed_rows / rows_per_group; | |
| let col_in_wg = tid / 4u; | |
| let lane = tid & 3u; | |
| let col = wid.x * 8u + col_in_wg; | |
| var sum0: f32 = 0.0; | |
| var sum1: f32 = 0.0; | |
| if (col < N) { | |
| let groups_per_lane = num_groups / 4u; | |
| let g_start = lane * groups_per_lane; | |
| let g_end = g_start + groups_per_lane; | |
| for (var g = g_start; g < g_end; g++) { | |
| let sf = g * N + col; | |
| let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; | |
| let gstart = g * rows_per_group; | |
| for (var r: u32 = 0u; r < rows_per_group; r += 4u) { | |
| let pr = gstart + r; | |
| let off = pr * N + col; | |
| let w0 = qweight[off]; | |
| let w1 = qweight[off + N]; | |
| let w2 = qweight[off + 2u * N]; | |
| let w3 = qweight[off + 3u * N]; | |
| let bk = pr * 8u; | |
| sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) + | |
| dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u)); | |
| sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) + | |
| dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u)); | |
| } | |
| } | |
| } | |
| partial0[tid] = sum0; | |
| partial1[tid] = sum1; | |
| workgroupBarrier(); | |
| if (lane == 0u && col < N) { | |
| output[col] = partial0[tid] + partial0[tid+1u] + partial0[tid+2u] + partial0[tid+3u]; | |
| output[N + col] = partial1[tid] + partial1[tid+1u] + partial1[tid+2u] + partial1[tid+3u]; | |
| } | |
| }`,Pr=` | |
| struct Params { N: u32, eps: f32, } | |
| @group(0) @binding(0) var<storage, read> input: array<f32>; | |
| @group(0) @binding(1) var<storage, read> weight: array<u32>; // BF16 packed | |
| @group(0) @binding(2) var<storage, read> bias: array<u32>; // BF16 packed | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| fn unpack_bf16(packed: u32, idx: u32) -> f32 { | |
| let bits = (packed >> (idx * 16u)) & 0xFFFFu; | |
| return bitcast<f32>(bits << 16u); | |
| } | |
| var<workgroup> shared_sum: array<f32, 256>; | |
| var<workgroup> shared_sq: array<f32, 256>; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let N = params.N; | |
| let token = wid.x; | |
| let base = token * N; | |
| let tid = lid.x; | |
| // Phase 1: compute sum and sum-of-squares | |
| var s: f32 = 0.0; | |
| var sq: f32 = 0.0; | |
| for (var i = tid; i < N; i += 256u) { | |
| let v = input[base + i]; | |
| s += v; | |
| sq += v * v; | |
| } | |
| shared_sum[tid] = s; | |
| shared_sq[tid] = sq; | |
| workgroupBarrier(); | |
| // Parallel reduction | |
| for (var stride: u32 = 128u; stride > 0u; stride >>= 1u) { | |
| if (tid < stride) { | |
| shared_sum[tid] += shared_sum[tid + stride]; | |
| shared_sq[tid] += shared_sq[tid + stride]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| let mean = shared_sum[0] / f32(N); | |
| let variance = shared_sq[0] / f32(N) - mean * mean; | |
| let inv_std = 1.0 / sqrt(variance + params.eps); | |
| // Phase 2: normalize and apply weight + bias | |
| for (var i = tid; i < N; i += 256u) { | |
| let v = input[base + i]; | |
| let normalized = (v - mean) * inv_std; | |
| let w = unpack_bf16(weight[i >> 1u], i); | |
| let b = unpack_bf16(bias[i >> 1u], i); | |
| output[base + i] = normalized * w + b; | |
| } | |
| }`,Er=` | |
| struct Params { K: u32, N: u32, } | |
| @group(0) @binding(0) var<storage, read> input: array<f32>; | |
| @group(0) @binding(1) var<storage, read> weight: array<u32>; // BF16 packed [N, K/2] | |
| @group(0) @binding(2) var<storage, read> bias: array<u32>; // BF16 packed [N/2] | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let col = gid.x; // output feature index | |
| let token = gid.y; // batch/token index | |
| let K = params.K; | |
| let N = params.N; | |
| if (col >= N) { return; } | |
| let in_base = token * K; | |
| let w_base = col * K / 2u; | |
| var sum: f32 = 0.0; | |
| for (var k: u32 = 0u; k < K; k += 4u) { | |
| let p0 = weight[w_base + k / 2u]; | |
| let p1 = weight[w_base + k / 2u + 1u]; | |
| let w0 = bitcast<f32>((p0 & 0xFFFFu) << 16u); | |
| let w1 = bitcast<f32>(p0 & 0xFFFF0000u); | |
| let w2 = bitcast<f32>((p1 & 0xFFFFu) << 16u); | |
| let w3 = bitcast<f32>(p1 & 0xFFFF0000u); | |
| sum += input[in_base + k] * w0 + input[in_base + k + 1u] * w1 | |
| + input[in_base + k + 2u] * w2 + input[in_base + k + 3u] * w3; | |
| } | |
| // Add BF16 bias | |
| let bp = bias[col >> 1u]; | |
| let b = bitcast<f32>(select((bp & 0xFFFFu) << 16u, bp & 0xFFFF0000u, (col & 1u) == 1u)); | |
| output[token * N + col] = sum + b; | |
| }`,Sr=` | |
| @group(0) @binding(0) var<storage, read_write> data: array<f32>; | |
| @group(0) @binding(1) var<uniform> len: u32; | |
| const SQRT_2_OVER_PI: f32 = 0.7978845608; | |
| const COEFF: f32 = 0.044715; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let i = gid.x; | |
| if (i >= len) { return; } | |
| let x = data[i]; | |
| // For |x| > 10, GELU saturates: positive → x, negative → 0 | |
| // Avoids intermediate overflow in x^3 for large values | |
| if (x > 10.0) { return; } // data[i] = x (already in place) | |
| if (x < -10.0) { data[i] = 0.0; return; } | |
| let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x); | |
| data[i] = 0.5 * x * (1.0 + tanh(inner)); | |
| }`,Rr=` | |
| @group(0) @binding(0) var<storage, read_write> data: array<f32>; | |
| @group(0) @binding(1) var<uniform> len: u32; | |
| const SQRT_2_OVER_PI: f32 = 0.7978845608; | |
| const COEFF: f32 = 0.044715; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let i = gid.x; | |
| if (i >= len) { return; } | |
| let x = data[i]; | |
| if (x > 10.0) { return; } | |
| if (x < -10.0) { data[i] = 0.0; return; } | |
| let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x); | |
| data[i] = 0.5 * x * (1.0 + tanh(inner)); | |
| }`,Kr=` | |
| struct Params { seq_len: u32, num_heads: u32, head_dim: u32, } | |
| @group(0) @binding(0) var<storage, read> q_in: array<f32>; | |
| @group(0) @binding(1) var<storage, read> k_in: array<f32>; | |
| @group(0) @binding(2) var<storage, read_write> q_out: array<f32>; | |
| @group(0) @binding(3) var<storage, read_write> k_out: array<f32>; | |
| @group(0) @binding(4) var<storage, read> cos_buf: array<f32>; | |
| @group(0) @binding(5) var<storage, read> sin_buf: array<f32>; | |
| @group(0) @binding(6) var<uniform> params: Params; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let idx = gid.x; | |
| let total = params.seq_len * params.num_heads * params.head_dim; | |
| if (idx >= total) { return; } | |
| let d = idx % params.head_dim; | |
| let head_and_seq = idx / params.head_dim; | |
| let seq = head_and_seq / params.num_heads; | |
| let half_dim = params.head_dim / 2u; | |
| // cos/sin are [seq_len, head_dim] — same for all heads | |
| let cs_idx = seq * params.head_dim + d; | |
| let c = cos_buf[cs_idx]; | |
| let s = sin_buf[cs_idx]; | |
| // rotate_half: pair d with d+half_dim (or d-half_dim) | |
| let partner_d = select(d + half_dim, d - half_dim, d >= half_dim); | |
| let partner_idx = head_and_seq * params.head_dim + partner_d; | |
| let sign = select(-1.0, 1.0, d >= half_dim); | |
| q_out[idx] = q_in[idx] * c + sign * q_in[partner_idx] * s; | |
| k_out[idx] = k_in[idx] * c + sign * k_in[partner_idx] * s; | |
| }`,Br=` | |
| struct Params { seq_len: u32, num_heads: u32, head_dim: u32, scale: f32, } | |
| @group(0) @binding(0) var<storage, read> q: array<f32>; | |
| @group(0) @binding(1) var<storage, read> k: array<f32>; | |
| @group(0) @binding(2) var<storage, read> v: array<f32>; | |
| @group(0) @binding(3) var<storage, read_write> output: array<f32>; | |
| @group(0) @binding(4) var<uniform> params: Params; | |
| var<workgroup> wg_dot: array<f32, 32>; | |
| @compute @workgroup_size(32) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let q_pos = wid.x; | |
| let head = wid.y; | |
| let tid = lid.x; | |
| let S = params.seq_len; | |
| let H = params.num_heads; | |
| let D = params.head_dim; | |
| let scale = params.scale; | |
| let q_base = (q_pos * H + head) * D; | |
| // Each thread handles D elements with stride 32 (D=64→2, D=128→4 per thread) | |
| // Pre-load Q into registers | |
| var qc0: f32 = 0.0; var qc1: f32 = 0.0; var qc2: f32 = 0.0; var qc3: f32 = 0.0; | |
| qc0 = q[q_base + tid]; | |
| if (tid + 32u < D) { qc1 = q[q_base + tid + 32u]; } | |
| if (tid + 64u < D) { qc2 = q[q_base + tid + 64u]; } | |
| if (tid + 96u < D) { qc3 = q[q_base + tid + 96u]; } | |
| // Online softmax + V accumulation, parallelized across D dimension | |
| var running_max: f32 = -1e30; | |
| var running_sum: f32 = 0.0; | |
| var a0: f32 = 0.0; var a1: f32 = 0.0; var a2: f32 = 0.0; var a3: f32 = 0.0; | |
| for (var kp: u32 = 0u; kp < S; kp++) { | |
| let k_base = (kp * H + head) * D; | |
| // Partial dot product: each thread multiplies its D elements | |
| var pd: f32 = qc0 * k[k_base + tid]; | |
| if (tid + 32u < D) { pd += qc1 * k[k_base + tid + 32u]; } | |
| if (tid + 64u < D) { pd += qc2 * k[k_base + tid + 64u]; } | |
| if (tid + 96u < D) { pd += qc3 * k[k_base + tid + 96u]; } | |
| // Tree reduction: 32 → 1 | |
| wg_dot[tid] = pd; | |
| workgroupBarrier(); | |
| if (tid < 16u) { wg_dot[tid] += wg_dot[tid + 16u]; } | |
| workgroupBarrier(); | |
| if (tid < 8u) { wg_dot[tid] += wg_dot[tid + 8u]; } | |
| workgroupBarrier(); | |
| if (tid < 4u) { wg_dot[tid] += wg_dot[tid + 4u]; } | |
| workgroupBarrier(); | |
| if (tid < 2u) { wg_dot[tid] += wg_dot[tid + 2u]; } | |
| workgroupBarrier(); | |
| if (tid == 0u) { wg_dot[0] += wg_dot[1]; } | |
| workgroupBarrier(); | |
| let s = wg_dot[0] * scale; | |
| // Online softmax (all threads have same s after reduction) | |
| let old_max = running_max; | |
| running_max = max(running_max, s); | |
| let corr = exp(old_max - running_max); | |
| running_sum = running_sum * corr; | |
| a0 = a0 * corr; a1 = a1 * corr; a2 = a2 * corr; a3 = a3 * corr; | |
| let w = exp(s - running_max); | |
| running_sum += w; | |
| // Accumulate weighted V for this thread's D elements | |
| let v_base = (kp * H + head) * D; | |
| a0 += w * v[v_base + tid]; | |
| if (tid + 32u < D) { a1 += w * v[v_base + tid + 32u]; } | |
| if (tid + 64u < D) { a2 += w * v[v_base + tid + 64u]; } | |
| if (tid + 96u < D) { a3 += w * v[v_base + tid + 96u]; } | |
| } | |
| // Write output | |
| let inv_sum = 1.0 / running_sum; | |
| let out_base = (q_pos * H + head) * D; | |
| output[out_base + tid] = a0 * inv_sum; | |
| if (tid + 32u < D) { output[out_base + tid + 32u] = a1 * inv_sum; } | |
| if (tid + 64u < D) { output[out_base + tid + 64u] = a2 * inv_sum; } | |
| if (tid + 96u < D) { output[out_base + tid + 96u] = a3 * inv_sum; } | |
| }`,Dr=` | |
| @group(0) @binding(0) var<storage, read_write> a: array<f32>; | |
| @group(0) @binding(1) var<storage, read> b: array<f32>; | |
| @group(0) @binding(2) var<uniform> len: u32; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(global_invocation_id) gid: vec3u) { | |
| let i = gid.x; | |
| if (i >= len) { return; } | |
| a[i] += b[i]; | |
| }`,Ar=` | |
| struct Params { H: u32, } | |
| @group(0) @binding(0) var<storage, read> vision: array<f32>; | |
| @group(0) @binding(1) var<storage, read> scatter_indices: array<u32>; | |
| @group(0) @binding(2) var<storage, read_write> embeds: array<f32>; | |
| @group(0) @binding(3) var<uniform> params: Params; | |
| @compute @workgroup_size(256) | |
| fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { | |
| let vit_idx = wid.x; | |
| let pos = scatter_indices[vit_idx]; | |
| let H = params.H; | |
| let tid = lid.x; | |
| for (var i = tid; i < H; i += 256u) { | |
| embeds[pos * H + i] = vision[vit_idx * H + i]; | |
| } | |
| }`,zr={gptq_matvec:k,gptq_matvec_f16:h,gptq_matvec_4t:pr,gptq_matvec_4t_f16:fr,gptq_splitk:x,reduce_splitk:q,bf16_matvec:y,rmsnorm:F,silu_mul:N,add:P,embedding:R,embed_from_argmax:K,argmax:B,topk_extract:D,kv_cache_store:z,gqa_attention_head:T,gqa_reduce:H,deltanet_recurrent:M,head_rmsnorm:G,head_rmsnorm_nogated:O,causal_conv1d:U,split:E,sigmoid_mul:S,fused_gate_up_silu:C,fused_gate_up_silu_f16:L,fused_gate_up_silu_4t:Q,fused_gate_up_silu_4t_f16:j,add_rmsnorm:V,add_rmsnorm_ro:W,three_way_add_rmsnorm:Y,norm_gptq_lite:or,norm_gptq_lite_noadd:nr,fused_sigmoid_gptq:X,fused_sigmoid_gptq_f16:$,fused_sigmoid_gptq_4t:er,fused_sigmoid_gptq_4t_f16:tr,fused_silu_gptq:Z,fused_silu_gptq_f16:J,fused_silu_gptq_4t:rr,fused_silu_gptq_4t_f16:ar,fused_addnorm_gate_up_silu:ur,rep_penalty:gr,gpu_sample:dr,append_token:A,fused_norm_gptq:ir,fused_norm_gptq_noadd:sr,fused_conv_deltanet_norm:I,fused_split_qknorm_kvstore:lr,gptq_matvec_4t_f16_sk:_r,gather_rows_bf16:cr,quantize_bf16_to_int4:wr,pack_f32_to_f16_pairs:br,gptq_matmul_b2:yr,gptq_matmul_b2_f16:Fr,gptq_matmul_b2_4t_f16:Nr,add_rmsnorm_b2:vr,add_rmsnorm_ro_b2:mr,three_way_add_rmsnorm_b2:kr,fused_gate_up_silu_b2_f16:hr,fused_silu_gptq_b2_f16:xr,fused_sigmoid_gptq_b2_f16:qr,vit_layernorm:Pr,vit_bf16_matvec_bias:Er,vit_gelu_tanh:Sr,vit_gelu:Rr,vit_rope:Kr,vit_attention:Br,vit_add:Dr,vit_scatter_embed:Ar},Mr=Object.freeze(Object.defineProperty({__proto__:null,SHADERS:zr,SHADER_FUSED_SPLIT_QKNORM_KVSTORE_FN:_},Symbol.toStringTag,{value:"Module"}));export{b as G,zr as S,Hr as _,_ as a,Mr as b,Tr as g}; | |