tensorbend / assets /gpu-ops-flxI8RuZ.js
Ex0bit's picture
Fix vision image freeze: reduce maxPixels, parallelize ViT attention
f265480
(function(){const e=document.createElement("link").relList;if(e&&e.supports&&e.supports("modulepreload"))return;for(const r of document.querySelectorAll('link[rel="modulepreload"]'))t(r);new MutationObserver(r=>{for(const a of r)if(a.type==="childList")for(const i of a.addedNodes)i.tagName==="LINK"&&i.rel==="modulepreload"&&t(i)}).observe(document,{childList:!0,subtree:!0});function u(r){const a={};return r.integrity&&(a.integrity=r.integrity),r.referrerPolicy&&(a.referrerPolicy=r.referrerPolicy),r.crossOrigin==="use-credentials"?a.credentials="include":r.crossOrigin==="anonymous"?a.credentials="omit":a.credentials="same-origin",a}function t(r){if(r.ep)return;r.ep=!0;const a=u(r);fetch(r.href,a)}})();class b{constructor(){this.device=null,this.adapter=null,this.adapterInfo=null,this.pipelineCache=new Map,this.bufferCache=new Map,this.bindGroupCache=new Map}async init(){if(!navigator.gpu)throw new Error("WebGPU not supported");if(this.adapter=await navigator.gpu.requestAdapter({powerPreference:"high-performance"}),!this.adapter)throw new Error("No WebGPU adapter found");this.adapterInfo=await this.adapter.requestAdapterInfo?.()??{};const e={},u={maxBufferSize:4*1024*1024*1024,maxStorageBufferBindingSize:4*1024*1024*1024,maxComputeWorkgroupStorageSize:32768,maxComputeInvocationsPerWorkgroup:256,maxComputeWorkgroupSizeX:256,maxStorageBuffersPerShaderStage:10};for(const[r,a]of Object.entries(u))this.adapter.limits[r]!==void 0&&(e[r]=Math.min(a,this.adapter.limits[r]));const t=[];return this.adapter.features.has("shader-f16")&&t.push("shader-f16"),this.adapter.features.has("subgroups")&&t.push("subgroups"),this.device=await this.adapter.requestDevice({requiredLimits:e,requiredFeatures:t}),this.hasF16=this.device.features.has("shader-f16"),this.hasSubgroups=this.device.features.has("subgroups"),this.device.lost.then(r=>console.error("WebGPU device lost:",r)),this}createBuffer(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST){const r=Math.ceil(u/4)*4,a=this.device.createBuffer({size:r,usage:t,label:e});return this.bufferCache.set(e,a),a}createBufferFromData(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC){const r=this.createBuffer(e,u.byteLength,t);return this.device.queue.writeBuffer(r,0,u),r}createReadbackBuffer(e,u){const t=Math.ceil(u/4)*4;return this.device.createBuffer({size:t,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST,label:e+"_readback"})}getOrCreatePipeline(e,u,t="main"){if(this.pipelineCache.has(e))return this.pipelineCache.get(e);const r=this.device.createShaderModule({code:u,label:e}),a=this.device.createComputePipeline({layout:"auto",compute:{module:r,entryPoint:t},label:e});return this.pipelineCache.set(e,a),a}initTimestamps(){this.device.features.has("timestamp-query")&&(this._tsQuerySet=this.device.createQuerySet({type:"timestamp",count:2}),this._tsResolveBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.QUERY_RESOLVE|GPUBufferUsage.COPY_SRC}),this._tsReadBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST}),this._tsEnabled=!0,this._tsResults=[])}beginBatch(){this._encoder=this.device.createCommandEncoder(),this._passCount=0,this.singlePassMode&&(this._singlePass=this._encoder.beginComputePass()),this._tsEnabled&&this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:0}}).end()}endBatch(){this._singlePass&&(this._singlePass.end(),this._singlePass=null),this._tsEnabled&&this._encoder&&(this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:1}}).end(),this._encoder.resolveQuerySet(this._tsQuerySet,0,2,this._tsResolveBuf,0),this._encoder.copyBufferToBuffer(this._tsResolveBuf,0,this._tsReadBuf,0,16)),this._encoder&&(this.device.queue.submit([this._encoder.finish()]),this._encoder=null)}async readTimestamp(){if(!this._tsEnabled)return null;await this._tsReadBuf.mapAsync(GPUMapMode.READ);const e=new BigInt64Array(this._tsReadBuf.getMappedRange().slice(0));this._tsReadBuf.unmap();const t=Number(e[1]-e[0])/1e6;return this._tsResults.push(t),t}copyBuffer(e,u,t,r=0,a=0){if(this._singlePass){this._singlePass.end(),this._encoder.copyBufferToBuffer(e,r,u,a,t),this._singlePass=this._encoder.beginComputePass();return}const i=this._encoder||this.device.createCommandEncoder();i.copyBufferToBuffer(e,r,u,a,t),this._encoder||this.device.queue.submit([i.finish()])}startRecording(){this._recording=[]}stopRecording(){const e=this._recording;return this._recording=null,e}replay(e,u){if(u)for(const r of u)this.device.queue.writeBuffer(r.buffer,r.offset,r.data,r.dataOffset,r.size);const t=this._encoder;for(let r=0;r<e.length;r++){const a=e[r];if(a.multi){const i=t.beginComputePass(),o=a.ops;for(let n=0;n<o.length;n++){const s=o[n];i.setPipeline(s.pipeline),i.setBindGroup(0,s.bindGroup),i.dispatchWorkgroups(s.wgX,s.wgY)}i.end()}else{const i=t.beginComputePass();i.setPipeline(a.pipeline),i.setBindGroup(0,a.bindGroup),i.dispatchWorkgroups(a.wgX,a.wgY),i.end()}}}dispatch(e,u,t,r=1,a=1){if(this._recording&&this._recording.push({pipeline:e,bindGroup:u[0],wgX:t,wgY:r}),this._singlePass){const n=this._singlePass;this._passCount!==void 0&&this._passCount++,n.setPipeline(e);for(let s=0;s<u.length;s++)n.setBindGroup(s,u[s]);n.dispatchWorkgroups(t,r,a);return}const i=this._encoder||this.device.createCommandEncoder(),o=i.beginComputePass();this._passCount!==void 0&&this._passCount++,o.setPipeline(e);for(let n=0;n<u.length;n++)o.setBindGroup(n,u[n]);o.dispatchWorkgroups(t,r,a),o.end(),this._encoder||this.device.queue.submit([i.finish()])}dispatchMulti(e){if(this._recording&&this._recording.push({multi:!0,ops:e.map(r=>({pipeline:r.pipeline,bindGroup:r.bindGroups[0],wgX:r.workgroupsX,wgY:r.workgroupsY||1}))}),this._singlePass){this._passCount!==void 0&&this._passCount++;for(const r of e){this._singlePass.setPipeline(r.pipeline);for(let a=0;a<r.bindGroups.length;a++)this._singlePass.setBindGroup(a,r.bindGroups[a]);this._singlePass.dispatchWorkgroups(r.workgroupsX,r.workgroupsY||1,r.workgroupsZ||1)}return}const u=this._encoder||this.device.createCommandEncoder(),t=u.beginComputePass();this._passCount!==void 0&&this._passCount++;for(const r of e){t.setPipeline(r.pipeline);for(let a=0;a<r.bindGroups.length;a++)t.setBindGroup(a,r.bindGroups[a]);t.dispatchWorkgroups(r.workgroupsX,r.workgroupsY||1,r.workgroupsZ||1)}t.end(),this._encoder||this.device.queue.submit([u.finish()])}async readBuffer(e,u){const t=this.createReadbackBuffer("_readback",u),r=this.device.createCommandEncoder();r.copyBufferToBuffer(e,0,t,0,u),this.device.queue.submit([r.finish()]),await t.mapAsync(GPUMapMode.READ);const a=new Float32Array(t.getMappedRange().slice(0));return t.unmap(),t.destroy(),a}createBindGroup(e,u,t){return this.device.createBindGroup({layout:e.getBindGroupLayout(u),entries:t.map((r,a)=>({binding:a,resource:{buffer:r}}))})}createBindGroupWithOffsets(e,u,t){return this.device.createBindGroup({layout:e.getBindGroupLayout(u),entries:t.map((r,a)=>({binding:a,resource:r.buffer?{buffer:r.buffer,offset:r.offset||0,size:r.size}:{buffer:r}}))})}getCachedBindGroup(e,u,t,r){let a=this.bindGroupCache.get(e);return a||(a=this.createBindGroup(u,t,r),this.bindGroupCache.set(e,a)),a}destroy(){for(const e of this.bufferCache.values())e.destroy();this.bufferCache.clear(),this.bindGroupCache.clear(),this.device?.destroy()}}const Tr=Object.freeze(Object.defineProperty({__proto__:null,GPUContext:b},Symbol.toStringTag,{value:"Module"})),v="modulepreload",m=function(d){return"/"+d},f={},Hr=function(e,u,t){let r=Promise.resolve();if(u&&u.length>0){let n=function(s){return Promise.all(s.map(l=>Promise.resolve(l).then(p=>({status:"fulfilled",value:p}),p=>({status:"rejected",reason:p}))))};document.getElementsByTagName("link");const i=document.querySelector("meta[property=csp-nonce]"),o=i?.nonce||i?.getAttribute("nonce");r=n(u.map(s=>{if(s=m(s),s in f)return;f[s]=!0;const l=s.endsWith(".css"),p=l?'[rel="stylesheet"]':"";if(document.querySelector(`link[href="${s}"]${p}`))return;const g=document.createElement("link");if(g.rel=l?"stylesheet":v,l||(g.as="script"),g.crossOrigin="",g.href=s,o&&g.setAttribute("nonce",o),document.head.appendChild(g),l)return new Promise((c,w)=>{g.addEventListener("load",c),g.addEventListener("error",()=>w(new Error(`Unable to preload CSS for ${s}`)))})}))}function a(i){const o=new Event("vite:preloadError",{cancelable:!0});if(o.payload=i,window.dispatchEvent(o),!o.defaultPrevented)throw i}return r.then(i=>{for(const o of i||[])o.status==="rejected"&&a(o.reason);return e().catch(a)})},k=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
dq(w2, bk+16u) + dq(w3, bk+24u));
}
}
output[col] = sum;
}`,h=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
dq16(w2, bk+16u) + dq16(w3, bk+24u));
}
}
output[col] = sum;
}`,x=`
struct Params { K: u32, N: u32, group_size: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> partials: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(workgroup_id) wgid: vec3u) {
let col = gid.x;
let split_id = wgid.y;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let ns = params.num_splits;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
let groups_per_split = num_groups / ns;
let g_start = split_id * groups_per_split;
let g_end = select(g_start + groups_per_split, num_groups, split_id >= ns - 1u);
var sum: f32 = 0.0;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
dq(w2, bk+16u) + dq(w3, bk+24u));
}
}
partials[split_id * N + col] = sum;
}`,q=`
struct Params { N: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> partials: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
if (col >= params.N) { return; }
var sum: f32 = 0.0;
for (var s = 0u; s < params.num_splits; s++) {
sum += partials[s * params.N + col];
}
output[col] = sum;
}`,y=`
struct Params { K: u32, N: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
if (col >= params.N) { return; }
var sum: f32 = 0.0;
let K = params.K;
let base = col * K / 2u;
// Process 4 BF16 values (2 u32 reads) per iteration
for (var k: u32 = 0u; k < K; k += 4u) {
let p0 = weight[base + k / 2u];
let p1 = weight[base + k / 2u + 1u];
let w0 = bitcast<f32>((p0 & 0xFFFFu) << 16u);
let w1 = bitcast<f32>(p0 & 0xFFFF0000u);
let w2 = bitcast<f32>((p1 & 0xFFFFu) << 16u);
let w3 = bitcast<f32>(p1 & 0xFFFF0000u);
sum += input[k] * w0 + input[k + 1u] * w1 + input[k + 2u] * w2 + input[k + 3u] * w3;
}
output[col] = sum;
}`,F=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
// Parallel sum of squares
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
ss += input[i] * input[i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Parallel normalize
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[i] = input[i] * rms * (1.0 + w);
}
}`,N=`
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.N) { return; }
let x = gate[i];
output[i] = (x / (1.0 + exp(-x))) * up[i];
}`,P=`
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read_write> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.N) { return; }
a[i] = a[i] + b[i];
}`,E=`
struct Params { N: u32, num_heads: u32, head_dim: u32, }
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst_a: array<f32>;
@group(0) @binding(2) var<storage, read_write> dst_b: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.N) { return; }
let hd = params.head_dim;
let head = i / hd;
let d = i % hd;
dst_a[i] = src[head * hd * 2u + d];
dst_b[i] = src[head * hd * 2u + hd + d];
}`,S=`
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.N) { return; }
output[i] = x[i] / (1.0 + exp(-gate[i]));
}`,R=`
struct Params { token_id: u32, dim: u32, }
@group(0) @binding(0) var<storage, read> embeddings: array<u32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.dim) { return; }
let flat = params.token_id * params.dim + i;
output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u);
}`,K=`
struct ArgmaxResult { idx: u32, val: f32, }
struct Params { dim: u32, }
@group(0) @binding(0) var<storage, read> embeddings: array<u32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> argmax_result: ArgmaxResult;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.dim) { return; }
let flat = argmax_result.idx * params.dim + i;
output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u);
}`,B=`
struct Params { N: u32, }
struct Result { idx: u32, val: f32, }
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: Result;
@group(0) @binding(2) var<uniform> params: Params;
var<workgroup> s_val: array<f32, 256>;
var<workgroup> s_idx: array<u32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
// Each thread scans its portion
var best_val: f32 = -1e30;
var best_idx: u32 = 0u;
for (var i = tid; i < N; i += 256u) {
if (logits[i] > best_val) {
best_val = logits[i];
best_idx = i;
}
}
s_val[tid] = best_val;
s_idx[tid] = best_idx;
workgroupBarrier();
// Tree reduction
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s && s_val[tid + s] > s_val[tid]) {
s_val[tid] = s_val[tid + s];
s_idx[tid] = s_idx[tid + s];
}
workgroupBarrier();
}
if (tid == 0u) {
result.idx = s_idx[0];
result.val = s_val[0];
}
}`,D=`
struct Params { N: u32, }
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
// Each thread scans its portion, keeping best 1
var best_val: f32 = -1e30;
var best_idx: u32 = 0u;
for (var i = tid; i < N; i += 256u) {
if (logits[i] > best_val) {
best_val = logits[i];
best_idx = i;
}
}
// Output (idx, val) pair — 256 candidates total
result[tid * 2u] = best_idx;
result[tid * 2u + 1u] = bitcast<u32>(best_val);
}`,A=`
struct ArgmaxResult { idx: u32, val: f32, }
struct Params { recent_count: u32, history_slot: u32, }
@group(0) @binding(0) var<storage, read> argmax_result: ArgmaxResult;
@group(0) @binding(1) var<storage, read_write> recent_tokens: array<u32>;
@group(0) @binding(2) var<storage, read_write> token_history: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(1)
fn main() {
let tok = argmax_result.idx;
recent_tokens[params.recent_count] = tok;
token_history[params.history_slot] = tok;
}`,z=`
struct Params { position: u32, num_kv_heads: u32, head_dim: u32, }
@group(0) @binding(0) var<storage, read> k_proj: array<f32>;
@group(0) @binding(1) var<storage, read> v_proj: array<f32>;
@group(0) @binding(2) var<storage, read_write> k_cache: array<f32>;
@group(0) @binding(3) var<storage, read_write> v_cache: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
let total = params.num_kv_heads * params.head_dim;
if (i >= total) { return; }
let offset = params.position * total + i;
k_cache[offset] = k_proj[i];
v_cache[offset] = v_proj[i];
}`,T=`
struct Params {
seq_len: u32,
head_dim: u32,
num_kv_heads: u32,
num_q_heads: u32,
heads_per_kv: u32,
num_splits: u32,
_pad0: u32,
_pad1: u32,
}
@group(0) @binding(0) var<storage, read> q_full: array<f32>;
@group(0) @binding(1) var<storage, read> k_cache: array<f32>;
@group(0) @binding(2) var<storage, read> v_cache: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> s_dot: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let split = wg.y;
let tid = lid.x;
let hd = params.head_dim;
let kv_head = h / params.heads_per_kv;
let q_offset = h * hd;
let kv_stride = params.num_kv_heads * hd;
let kv_offset = kv_head * hd;
let scale = 1.0 / sqrt(f32(hd));
let ns = params.num_splits;
// Compute this split's range of KV positions
let chunk = (params.seq_len + ns - 1u) / ns;
let pos_start = split * chunk;
let pos_end = min(pos_start + chunk, params.seq_len);
// Load Q (already RoPE'd)
let q_val = q_full[q_offset + tid];
// Online softmax over this split's positions
var running_max: f32 = -1e30;
var running_sum: f32 = 0.0;
var running_out: f32 = 0.0;
for (var pos = pos_start; pos < pos_end; pos++) {
let k_base = pos * kv_stride + kv_offset;
// K is already RoPE'd in cache
let k_val = k_cache[k_base + tid];
s_dot[tid] = q_val * k_val;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { s_dot[tid] += s_dot[tid + s]; }
workgroupBarrier();
}
let score = s_dot[0] * scale;
let v_val = v_cache[pos * kv_stride + kv_offset + tid];
let new_max = max(running_max, score);
let correction = exp(running_max - new_max);
running_sum = running_sum * correction + exp(score - new_max);
running_out = running_out * correction + exp(score - new_max) * v_val;
running_max = new_max;
}
if (ns == 1u) {
// Single split: write final output directly (no reduction needed)
if (running_sum > 0.0) {
output[q_offset + tid] = running_out / running_sum;
} else {
output[q_offset + tid] = 0.0;
}
} else {
// Multi-split: write partial results for reduction kernel
// Layout: [num_heads * num_splits * (head_dim + 2)]
// Per split: [partial_out[head_dim], log_sum_exp, running_max]
let stride = hd + 2u;
let base = (h * ns + split) * stride;
if (running_sum > 0.0) {
output[base + tid] = running_out / running_sum;
} else {
output[base + tid] = 0.0;
}
if (tid == 0u) {
output[base + hd] = log(running_sum) + running_max; // LSE
output[base + hd + 1u] = running_max;
}
}
}`,H=`
struct Params {
head_dim: u32,
num_splits: u32,
num_heads: u32,
_pad: u32,
}
@group(0) @binding(0) var<storage, read> partials: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
let ns = params.num_splits;
let stride = hd + 2u;
// Find global max LSE across splits
var global_max_lse: f32 = -1e30;
for (var s: u32 = 0u; s < ns; s++) {
let base = (h * ns + s) * stride;
let lse = partials[base + hd];
global_max_lse = max(global_max_lse, lse);
}
// Weighted combination using LSE rescaling
var acc: f32 = 0.0;
var weight_sum: f32 = 0.0;
for (var s: u32 = 0u; s < ns; s++) {
let base = (h * ns + s) * stride;
let lse = partials[base + hd];
let w = exp(lse - global_max_lse);
acc += w * partials[base + tid];
if (tid == 0u) {
weight_sum += w;
}
}
// Broadcast weight_sum from thread 0
// (all threads computed same weight_sum since lse is scalar, but only tid==0 accumulated)
// Actually all threads can compute weight_sum independently
var ws: f32 = 0.0;
for (var s: u32 = 0u; s < ns; s++) {
let base = (h * ns + s) * stride;
let lse = partials[base + hd];
ws += exp(lse - global_max_lse);
}
output[h * hd + tid] = acc / ws;
}`,M=`
struct Params {
num_heads: u32,
key_dim: u32,
value_dim: u32,
}
@group(0) @binding(0) var<storage, read> qkv: array<f32>;
@group(0) @binding(1) var<storage, read_write> state: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<storage, read> a_proj: array<f32>;
@group(0) @binding(4) var<storage, read> b_proj: array<f32>;
@group(0) @binding(5) var<storage, read> A_log: array<u32>;
@group(0) @binding(6) var<storage, read> dt_bias: array<u32>;
@group(0) @binding(7) var<uniform> params: Params;
fn sigmoid(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); }
fn softplus(x: f32) -> f32 {
if (x > 20.0) { return x; }
return log(1.0 + exp(x));
}
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let nh = params.num_heads;
let kd = params.key_dim;
let vd = params.value_dim;
if (h >= nh) { return; }
let qh_off = h * kd;
let kh_off = nh * kd + h * kd;
let vh_off = 2u * nh * kd + h * vd;
let state_off = h * kd * vd;
// Compute Q/K L2 norms (all threads compute redundantly — cheap)
var q_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; }
let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6);
var k_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off + ki]; }
let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6);
// Decay and gate (same for all threads in this head)
let a_log_val = unpack_bf16(A_log[h / 2u], h % 2u);
let dt_bias_val = unpack_bf16(dt_bias[h / 2u], h % 2u);
let decay = exp(-exp(a_log_val) * softplus(a_proj[h] + dt_bias_val));
let beta = sigmoid(b_proj[h]);
// Each thread handles value dimensions in stride
for (var vi = tid; vi < vd; vi += 256u) {
// Step 1: Decay state column
for (var ki: u32 = 0u; ki < kd; ki++) {
let idx = state_off + ki * vd + vi;
state[idx] *= decay;
}
// Step 2: kv_mem = S^T @ k_normalized (for this vi)
var kv_mem: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
kv_mem += state[state_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm;
}
// Step 3: delta rule update
let v_val = qkv[vh_off + vi];
let delta = (v_val - kv_mem) * beta;
for (var ki: u32 = 0u; ki < kd; ki++) {
let idx = state_off + ki * vd + vi;
state[idx] += qkv[kh_off + ki] * k_inv_norm * delta;
}
// Step 4: Output = S^T @ (q_normalized * scale)
let scale = 1.0 / sqrt(f32(kd));
var o_val: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
o_val += state[state_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * scale;
}
output[h * vd + vi] = o_val;
}
}`,G=`
struct Params { num_heads: u32, head_dim: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
if (h >= params.num_heads) { return; }
let off = h * hd;
// Parallel sum of squares
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
ss += x[off + i] * x[off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Parallel normalize
for (var i = tid; i < hd; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
x[off + i] = x[off + i] * rms * (1.0 + w);
}
}`,O=`
struct Params { num_heads: u32, head_dim: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
if (h >= params.num_heads) { return; }
let off = h * hd;
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
ss += x[off + i] * x[off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
for (var i = tid; i < hd; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
x[off + i] = x[off + i] * rms * w;
}
}`,U=`
struct Params { channels: u32, }
@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> hist: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let c = gid.x;
if (c >= params.channels) { return; }
let ch = params.channels;
let h0 = hist[c];
let h1 = hist[ch + c];
let h2 = hist[2u * ch + c];
let cur = x[c];
let w_base = c * 4u;
let w0 = unpack_bf16(weight[w_base / 2u], w_base % 2u);
let w1 = unpack_bf16(weight[(w_base + 1u) / 2u], (w_base + 1u) % 2u);
let w2 = unpack_bf16(weight[(w_base + 2u) / 2u], (w_base + 2u) % 2u);
let w3 = unpack_bf16(weight[(w_base + 3u) / 2u], (w_base + 3u) % 2u);
let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur;
x[c] = conv_out / (1.0 + exp(-conv_out));
hist[c] = h1;
hist[ch + c] = h2;
hist[2u * ch + c] = cur;
}`,I=`
struct Params {
num_heads: u32,
key_dim: u32,
value_dim: u32,
total_channels: u32,
eps: f32,
hidden_size: u32,
num_value_heads: u32,
}
@group(0) @binding(0) var<storage, read_write> qkv: array<f32>;
@group(0) @binding(1) var<storage, read_write> hist: array<f32>;
@group(0) @binding(2) var<storage, read> conv_weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> state: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<storage, read> hidden_input: array<f32>;
@group(0) @binding(6) var<storage, read> ab_weight: array<u32>;
@group(0) @binding(7) var<storage, read> A_log: array<u32>;
@group(0) @binding(8) var<storage, read> dt_bias: array<u32>;
@group(0) @binding(9) var<storage, read> norm_weight: array<u32>;
@group(0) @binding(10) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
fn sigmoid_f(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); }
fn softplus_f(x: f32) -> f32 {
if (x > 20.0) { return x; }
return log(1.0 + exp(x));
}
var<workgroup> wg_scratch: array<f32, 128>;
@compute @workgroup_size(128)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let nh = params.num_heads;
let kd = params.key_dim;
let vd = params.value_dim;
let ch = params.total_channels;
let H = params.hidden_size;
let nhv = params.num_value_heads;
let vpk = nhv / nh; // value heads per key head (1 for 2B, 2 for 4B)
let evd = vpk * vd; // effective value dim per key head
if (h >= nh) { return; }
// Alpha/beta computed per value head inside the vpk loop (Phase 2).
// When H > 0: inline BF16 dot product (a_proj/b_proj are BF16 in merged ab_weight)
// When H == 0: pre-computed via GPTQ dispatches, alpha/beta read per-value-head
// ── Phase 1: Conv1d for this head's channels ──
// Head h has 3 blocks: Q[h*kd..], K[nh*kd + h*kd..], V[2*nh*kd + h*evd..]
// evd = vpk * vd (effective value dim per key head, handles asymmetric heads)
let cpb = kd + kd + evd; // channels per key head block
for (var i = tid; i < cpb; i += 128u) {
// Map local channel to global channel index
var c: u32;
if (i < kd) {
c = h * kd + i;
} else if (i < 2u * kd) {
c = nh * kd + h * kd + (i - kd);
} else {
c = 2u * nh * kd + h * evd + (i - 2u * kd);
}
// Conv1d kernel=4 with history
let h0 = hist[c];
let h1 = hist[ch + c];
let h2 = hist[2u * ch + c];
let cur = qkv[c];
let wb = c * 4u;
let w0 = unpack_bf16(conv_weight[wb / 2u], wb % 2u);
let w1 = unpack_bf16(conv_weight[(wb + 1u) / 2u], (wb + 1u) % 2u);
let w2 = unpack_bf16(conv_weight[(wb + 2u) / 2u], (wb + 2u) % 2u);
let w3 = unpack_bf16(conv_weight[(wb + 3u) / 2u], (wb + 3u) % 2u);
let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur;
qkv[c] = conv_out / (1.0 + exp(-conv_out)); // SiLU
hist[c] = h1;
hist[ch + c] = h2;
hist[2u * ch + c] = cur;
}
workgroupBarrier();
// ── Phase 2+3: DeltaNet recurrent + RMSNorm, looped over vpk value heads ──
// Q/K are shared across value heads within this key head
let qh_off = h * kd;
let kh_off = nh * kd + h * kd;
// Q/K L2 norms (all threads compute redundantly, shared across value heads)
var q_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; }
let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6);
var k_ss: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off + ki]; }
let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6);
let qk_scale = 1.0 / sqrt(f32(kd));
// Loop over vpk value heads per key head (1 for 2B, 2 for 4B)
for (var vhi: u32 = 0u; vhi < vpk; vhi++) {
let vh = h * vpk + vhi; // global value head index
// Get per-value-head alpha/beta
var alpha: f32;
var beta_raw: f32;
if (H == 0u) {
// Quantized path: alpha/beta pre-computed per value head
alpha = hidden_input[vh];
beta_raw = bitcast<f32>(ab_weight[vh]);
} else {
// BF16 inline: dot product per value head
let half_H = H / 2u;
let a_base = vh * half_H;
let b_base = (nhv + vh) * half_H;
var a_sum: f32 = 0.0;
var b_sum: f32 = 0.0;
for (var k = tid * 4u; k < H; k += 512u) {
let hv0 = hidden_input[k];
let hv1 = hidden_input[k + 1u];
let hv2 = hidden_input[k + 2u];
let hv3 = hidden_input[k + 3u];
let ap0 = ab_weight[a_base + k / 2u];
let ap1 = ab_weight[a_base + k / 2u + 1u];
a_sum += hv0 * bitcast<f32>((ap0 & 0xFFFFu) << 16u)
+ hv1 * bitcast<f32>(ap0 & 0xFFFF0000u)
+ hv2 * bitcast<f32>((ap1 & 0xFFFFu) << 16u)
+ hv3 * bitcast<f32>(ap1 & 0xFFFF0000u);
let bp0 = ab_weight[b_base + k / 2u];
let bp1 = ab_weight[b_base + k / 2u + 1u];
b_sum += hv0 * bitcast<f32>((bp0 & 0xFFFFu) << 16u)
+ hv1 * bitcast<f32>(bp0 & 0xFFFF0000u)
+ hv2 * bitcast<f32>((bp1 & 0xFFFFu) << 16u)
+ hv3 * bitcast<f32>(bp1 & 0xFFFF0000u);
}
wg_scratch[tid] = a_sum;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
alpha = wg_scratch[0];
wg_scratch[tid] = b_sum;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
beta_raw = wg_scratch[0];
}
// Per-value-head decay and gate
let a_log_val = unpack_bf16(A_log[vh / 2u], vh % 2u);
let dt_bias_val = unpack_bf16(dt_bias[vh / 2u], vh % 2u);
let decay = exp(-exp(a_log_val) * softplus_f(alpha + dt_bias_val));
let beta = sigmoid_f(beta_raw);
// Per-value-head V offset, state offset, output offset
let v_off = 2u * nh * kd + vh * vd;
let s_off = vh * kd * vd;
let o_off = vh * vd;
// DeltaNet recurrent update for this value head
for (var vi = tid; vi < vd; vi += 128u) {
for (var ki: u32 = 0u; ki < kd; ki++) {
state[s_off + ki * vd + vi] *= decay;
}
var kv_mem: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
kv_mem += state[s_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm;
}
let v_val = qkv[v_off + vi];
let delta = (v_val - kv_mem) * beta;
for (var ki: u32 = 0u; ki < kd; ki++) {
state[s_off + ki * vd + vi] += qkv[kh_off + ki] * k_inv_norm * delta;
}
var o_val: f32 = 0.0;
for (var ki: u32 = 0u; ki < kd; ki++) {
o_val += state[s_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * qk_scale;
}
output[o_off + vi] = o_val;
}
workgroupBarrier();
// ── RMSNorm for this value head (nogated, w formula) ──
// norm_weight is shared across all value heads (just vd weights)
var ss: f32 = 0.0;
for (var i = tid; i < vd; i += 128u) {
ss += output[o_off + i] * output[o_off + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 64u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(vd) + params.eps);
for (var i = tid; i < vd; i += 128u) {
let w = unpack_bf16(norm_weight[i / 2u], i % 2u);
output[o_off + i] = output[o_off + i] * rms * w;
}
workgroupBarrier();
}
}`,C=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn dq_in(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + dq_in(gw3,bk+24u));
up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u));
}
}
output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum;
}`,L=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u));
up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u));
}
}
output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum;
}`,Q=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
var<workgroup> partial: array<f32, 64>;
fn dq_in(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
if (col < N) {
for (var grp = g_start; grp < g_end; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart_row = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart_row + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + dq_in(gw3,bk+24u));
up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u));
}
}
}
partial[tid] = gate_sum;
partial[32u + tid] = up_sum;
workgroupBarrier();
if (lane == 0u && col < N) {
let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u];
output[col] = (g / (1.0 + exp(-g))) * u;
}
}`,j=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
var<workgroup> partial: array<f32, 64>;
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
if (col < N) {
for (var grp = g_start; grp < g_end; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart_row = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart_row + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u));
up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u));
}
}
}
partial[tid] = gate_sum;
partial[32u + tid] = up_sum;
workgroupBarrier();
if (lane == 0u && col < N) {
let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u];
output[col] = (g / (1.0 + exp(-g))) * u;
}
}`,V=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
// Phase 1: Add + compute sum of squares
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[i] + addend[i];
hidden[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
// Phase 2: Normalize
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[i] = hidden[i] * rms * (1.0 + w);
}
}`,W=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[i] + addend[i];
wg_vals[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[i] = wg_vals[i] * rms * (1.0 + w);
}
}`,Y=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read> input_c: array<f32>;
@group(0) @binding(3) var<storage, read> weight: array<u32>;
@group(0) @binding(4) var<storage, read_write> hidden_out: array<f32>;
@group(0) @binding(5) var<storage, read_write> normed: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let N = params.N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = input_a[i] + input_b[i] + input_c[i];
wg_vals[i] = val;
hidden_out[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
normed[i] = wg_vals[i] * rms * (1.0 + w);
}
}`,X=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn dq_sg(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u]));
let ih = vec4<f32>(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u]));
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u));
}
}
output[col] = sum;
}`,$=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn dq16_sg(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u])));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u])));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u));
}
}
output[col] = sum;
}`,Z=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
fn dq_sl(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]);
let ih = vec4<f32>(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]);
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u));
}
}
output[col] = sum;
}`,J=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
fn dq16_sl(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u));
}
}
output[col] = sum;
}`,rr=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
fn dq_sl(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]);
let ih = vec4<f32>(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]);
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
var sum: f32 = 0.0;
if (col < N) {
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u));
}
}
}
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,ar=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
fn dq16_sl(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
var sum: f32 = 0.0;
if (col < N) {
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u));
}
}
}
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,er=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn dq_sg(w: u32, bk: u32) -> f32 {
let il = vec4<f32>(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u]));
let ih = vec4<f32>(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u]));
return dot(il, vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(ih, vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
var sum: f32 = 0.0;
if (col < N) {
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u));
}
}
}
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,tr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn dq16_sg(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u])));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u])));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col = wid.x * 8u + (tid / 4u);
let lane = tid & 3u;
var sum: f32 = 0.0;
if (col < N) {
let gpl = num_groups / 4u;
let g_start = lane * gpl;
let g_end = g_start + gpl;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u));
}
}
}
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,ur=`
struct Params {
K: u32, // hidden_size
N: u32, // intermediate_size
group_size: u32,
eps: f32,
// BF16 norm weights packed as vec4<u32> for correct uniform alignment.
norm_weight: array<vec4<u32>, 640>,
}
@group(0) @binding(0) var<storage, read> hidden_in: array<f32>;
@group(0) @binding(1) var<storage, read> mlp_residual: array<f32>;
@group(0) @binding(2) var<storage, read> attn_residual: array<f32>;
@group(0) @binding(3) var<storage, read> qweight_merged: array<u32>;
@group(0) @binding(4) var<storage, read> scales_merged: array<u32>;
@group(0) @binding(5) var<storage, read_write> hidden_out: array<f32>;
@group(0) @binding(6) var<storage, read_write> output: array<f32>;
@group(0) @binding(7) var<uniform> params: Params;
fn unpack_norm_weight(i: u32) -> f32 {
// Each vec4<u32> holds 8 BF16 values (4 u32 × 2 per u32)
let vec_idx = i / 8u;
let u32_idx = (i / 2u) % 4u;
let bf16_idx = i % 2u;
let packed = params.norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
// On-the-fly normed value: avoids 16KB shared memory (wg_normed) that killed occupancy.
// Inputs are L2-cached (~24KB total), so re-reading is cheap.
var<workgroup> wg_scratch: array<f32, 32>;
fn normed_3way(i: u32, rms: f32) -> f32 {
let val = hidden_in[i] + mlp_residual[i] + attn_residual[i];
let w = unpack_norm_weight(i);
return val * rms * (1.0 + w);
}
fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 {
return dot(vec4<f32>(normed_3way(bk, rms), normed_3way(bk+1u, rms), normed_3way(bk+2u, rms), normed_3way(bk+3u, rms)),
vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.))
+ dot(vec4<f32>(normed_3way(bk+4u, rms), normed_3way(bk+5u, rms), normed_3way(bk+6u, rms), normed_3way(bk+7u, rms)),
vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u,
@builtin(workgroup_id) wg: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
// ── Phase 1: Three-way add → compute RMS scalar (no shared memory for normed values) ──
var ss: f32 = 0.0;
for (var i = tid; i < K; i += 32u) {
let val = hidden_in[i] + mlp_residual[i] + attn_residual[i];
ss += val * val;
// WG 0: write hidden_out (residual for next layer)
if (wg.x == 0u) { hidden_out[i] = val; }
}
wg_scratch[tid] = ss;
workgroupBarrier();
// Tree reduction for sum of squares
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps);
// ── Phase 2: Gate + Up GPTQ MatVec + SiLU (on-the-fly normed values from L2 cache) ──
let col = gid.x;
if (col >= N) { return; }
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let up_qw_base = packed_rows * N;
let up_sc_base = num_groups * N;
var gate_sum: f32 = 0.0;
var up_sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let gsf = grp * N + col;
let usf = up_sc_base + grp * N + col;
let sg = unpack2x16float(scales_merged[gsf >> 1u])[gsf & 1u];
let su = unpack2x16float(scales_merged[usf >> 1u])[usf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let g_off = pr * N + col;
let u_off = up_qw_base + pr * N + col;
let gw0 = qweight_merged[g_off]; let uw0 = qweight_merged[u_off];
let gw1 = qweight_merged[g_off+N]; let uw1 = qweight_merged[u_off+N];
let gw2 = qweight_merged[g_off+2u*N]; let uw2 = qweight_merged[u_off+2u*N];
let gw3 = qweight_merged[g_off+3u*N]; let uw3 = qweight_merged[u_off+3u*N];
let bk = pr * 8u;
gate_sum += sg * (dq_fly(gw0,bk,rms) + dq_fly(gw1,bk+8u,rms) + dq_fly(gw2,bk+16u,rms) + dq_fly(gw3,bk+24u,rms));
up_sum += su * (dq_fly(uw0,bk,rms) + dq_fly(uw1,bk+8u,rms) + dq_fly(uw2,bk+16u,rms) + dq_fly(uw3,bk+24u,rms));
}
}
output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum;
}`,ir=`
struct Params {
K: u32,
N: u32,
group_size: u32,
eps: f32,
norm_weight: array<vec4<u32>, 640>,
}
@group(0) @binding(0) var<storage, read> hidden_in: array<f32>;
@group(0) @binding(1) var<storage, read> residual: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn unpack_norm_weight(i: u32) -> f32 {
let vec_idx = i / 8u;
let u32_idx = (i / 2u) % 4u;
let bf16_idx = i % 2u;
let packed = params.norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
// On-the-fly: only 128 bytes shared memory (vs 16KB wg_normed that killed occupancy)
var<workgroup> wg_scratch: array<f32, 32>;
fn normed_add(i: u32, rms: f32) -> f32 {
let val = hidden_in[i] + residual[i];
let w = unpack_norm_weight(i);
return val * rms * (1.0 + w);
}
fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 {
return dot(vec4<f32>(normed_add(bk, rms), normed_add(bk+1u, rms), normed_add(bk+2u, rms), normed_add(bk+3u, rms)),
vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.))
+ dot(vec4<f32>(normed_add(bk+4u, rms), normed_add(bk+5u, rms), normed_add(bk+6u, rms), normed_add(bk+7u, rms)),
vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
// ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ──
var ss: f32 = 0.0;
for (var i = tid; i < K; i += 32u) {
let val = hidden_in[i] + residual[i];
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps);
// ── Phase 2: GPTQ MatVec with on-the-fly normed values (L2-cached inputs) ──
let col = gid.x;
if (col >= N) { return; }
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
var sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms));
}
}
output[col] = sum;
}`,sr=`
struct Params {
K: u32,
N: u32,
group_size: u32,
eps: f32,
norm_weight: array<vec4<u32>, 640>,
}
@group(0) @binding(0) var<storage, read> hidden_in: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_norm_weight(i: u32) -> f32 {
let vec_idx = i / 8u;
let u32_idx = (i / 2u) % 4u;
let bf16_idx = i % 2u;
let packed = params.norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 32>;
fn normed_val(i: u32, rms: f32) -> f32 {
let w = unpack_norm_weight(i);
return hidden_in[i] * rms * (1.0 + w);
}
fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 {
return dot(vec4<f32>(normed_val(bk, rms), normed_val(bk+1u, rms), normed_val(bk+2u, rms), normed_val(bk+3u, rms)),
vec4<f32>(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.))
+ dot(vec4<f32>(normed_val(bk+4u, rms), normed_val(bk+5u, rms), normed_val(bk+6u, rms), normed_val(bk+7u, rms)),
vec4<f32>(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
// ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ──
var ss: f32 = 0.0;
for (var i = tid; i < K; i += 32u) {
let val = hidden_in[i];
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps);
// ── Phase 2: GPTQ MatVec with on-the-fly normed values ──
let col = gid.x;
if (col >= N) { return; }
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
var sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms));
}
}
output[col] = sum;
}`,or=`
struct Params {
K: u32, N: u32, group_size: u32, eps: f32,
norm_weight: array<vec4<u32>, 640>,
}
@group(0) @binding(0) var<storage, read> hidden_in: array<f32>;
@group(0) @binding(1) var<storage, read> residual: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn unpack_nw(i: u32) -> f32 {
let vec_idx = i / 8u;
let u32_idx = (i / 2u) % 4u;
let bf16_idx = i % 2u;
let packed = params.norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 32>;
var<workgroup> wg_rms: f32;
fn nv(i: u32) -> f32 {
return (hidden_in[i] + residual[i]) * wg_rms * (1.0 + unpack_nw(i));
}
fn dq_nv(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
// Phase 1: compute RMS (132B shared memory total)
var ss: f32 = 0.0;
for (var i = tid; i < K; i += 32u) {
let val = hidden_in[i] + residual[i];
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); }
workgroupBarrier();
// Phase 2: GPTQ matvec with inline norm (re-reads input from L1 cache)
let col = gid.x;
if (col >= N) { return; }
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
var sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u));
}
}
output[col] = sum;
}`,nr=`
struct Params {
K: u32, N: u32, group_size: u32, eps: f32,
norm_weight: array<vec4<u32>, 640>,
}
@group(0) @binding(0) var<storage, read> hidden_in: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_nw(i: u32) -> f32 {
let packed = params.norm_weight[i / 8u][(i / 2u) % 4u];
let bits = (packed >> ((i % 2u) * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 32>;
var<workgroup> wg_rms: f32;
fn nv(i: u32) -> f32 {
return hidden_in[i] * wg_rms * (1.0 + unpack_nw(i));
}
fn dq_nv(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
var ss: f32 = 0.0;
for (var i = tid; i < K; i += 32u) {
ss += hidden_in[i] * hidden_in[i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); }
workgroupBarrier();
let col = gid.x;
if (col >= N) { return; }
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
var sum: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u));
}
}
output[col] = sum;
}`,gr=`
struct Params { num_tokens: u32, penalty: f32, presence: f32, _pad: u32, }
@group(0) @binding(0) var<storage, read_write> logits: array<f32>;
@group(0) @binding(1) var<storage, read> recent_tokens: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.num_tokens) { return; }
let tok = recent_tokens[i];
var val = logits[tok];
// Repetition penalty (multiplicative)
if (params.penalty > 1.0) {
if (val > 0.0) {
val = val / params.penalty;
} else {
val = val * params.penalty;
}
}
// Presence penalty (additive)
val -= params.presence;
logits[tok] = val;
}`,dr=`
struct Params {
temperature: f32,
top_k: u32,
top_p: f32,
random_seed: u32,
}
struct Result { idx: u32, val: f32, }
@group(0) @binding(0) var<storage, read> candidates: array<u32>; // 256 × (idx u32, val u32)
@group(0) @binding(1) var<storage, read_write> result: Result;
@group(0) @binding(2) var<uniform> params: Params;
var<workgroup> wg_idx: array<u32, 256>;
var<workgroup> wg_val: array<f32, 256>;
var<workgroup> wg_prob: array<f32, 256>;
// PCG random number generator
fn pcg(state: u32) -> u32 {
let s = state * 747796405u + 2891336453u;
let word = ((s >> ((s >> 28u) + 4u)) ^ s) * 277803737u;
return (word >> 22u) ^ word;
}
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u) {
let tid = lid.x;
// Load candidate (idx, logit) pair
wg_idx[tid] = candidates[tid * 2u];
wg_val[tid] = bitcast<f32>(candidates[tid * 2u + 1u]);
workgroupBarrier();
// ── Step 1: Parallel find max logit (for numerical stability) ──
var local_max = wg_val[tid];
workgroupBarrier();
// Tree reduction for max
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) {
local_max = max(local_max, wg_val[tid + s]);
wg_val[tid] = local_max;
}
workgroupBarrier();
}
let max_val = wg_val[0];
// Reload original values (they were clobbered by reduction)
wg_val[tid] = bitcast<f32>(candidates[tid * 2u + 1u]);
workgroupBarrier();
// ── Step 2: Bitonic sort by logit value (descending) ──
// Sort all 256 candidates so top-K and top-P can work on sorted order.
for (var k = 2u; k <= 256u; k <<= 1u) {
for (var j = k >> 1u; j > 0u; j >>= 1u) {
let ixj = tid ^ j;
if (ixj > tid) {
let ascending = (tid & k) != 0u;
let val_tid = wg_val[tid];
let val_ixj = wg_val[ixj];
// We want descending, so swap if tid has smaller value
let should_swap = select((val_tid < val_ixj), (val_tid > val_ixj), ascending);
if (should_swap) {
wg_val[tid] = val_ixj;
wg_val[ixj] = val_tid;
let tmp_idx = wg_idx[tid];
wg_idx[tid] = wg_idx[ixj];
wg_idx[ixj] = tmp_idx;
}
}
workgroupBarrier();
}
}
// ── Step 3: Temperature + softmax (parallel exp, parallel prefix sum) ──
let temp = max(params.temperature, 0.001); // Avoid division by zero
let K = min(params.top_k, 256u);
// Each thread computes exp((val - max) / temp) if within top-K
var prob: f32 = 0.0;
if (tid < K) {
prob = exp((wg_val[tid] - max_val) / temp);
}
wg_prob[tid] = prob;
workgroupBarrier();
// ── Step 4: Parallel prefix sum for cumulative probabilities ──
// Blelloch scan (up-sweep then down-sweep)
// First, compute total sum via tree reduction (we need it for normalization)
var sum_val = prob;
for (var s = 1u; s < 256u; s <<= 1u) {
workgroupBarrier();
if (tid >= s) {
sum_val = wg_prob[tid - s] + wg_prob[tid];
}
workgroupBarrier();
wg_prob[tid] = sum_val;
}
// Now wg_prob[i] = cumulative sum up to and including index i
// wg_prob[255] = total sum
let total_sum = wg_prob[255];
workgroupBarrier();
// ── Step 5: Thread 0 does top-P + sampling ──
// This is sequential but operates on sorted, cumulative data — very fast.
if (tid == 0u) {
let topP = params.top_p;
let threshold = topP * total_sum; // Un-normalized threshold
// Find cutoff: first index where cumulative prob >= threshold
var cutoff = K;
for (var i = 0u; i < K; i++) {
if (wg_prob[i] >= threshold) {
cutoff = i + 1u;
break;
}
}
// Sample: generate random float in [0, cumProb[cutoff-1])
let rand_u32 = pcg(params.random_seed);
let rand_f32 = f32(rand_u32) / 4294967296.0; // [0, 1)
let sample_point = rand_f32 * wg_prob[cutoff - 1u];
// Find sampled index via linear scan of cumulative probs
var selected = 0u;
for (var i = 0u; i < cutoff; i++) {
if (wg_prob[i] >= sample_point) {
selected = i;
break;
}
}
result.idx = wg_idx[selected];
result.val = wg_val[selected];
}
}`;function _(d=320,e=1e7,u=33,t=30,r=128){return`
const ROPE_THETA: f32 = ${e};
const MROPE_S1_LIMIT: u32 = ${u}u;
const MROPE_S2_LIMIT: u32 = ${t}u;
const PARTIAL_DIM: u32 = ${r}u;
struct Params {
num_heads: u32, // Q heads
num_kv_heads: u32, // KV heads
head_dim: u32,
eps: f32,
cache_position: u32, // KV cache write index (sequential, updated per token)
position: u32, // RoPE temporal/text position (updated per token)
position_h: u32, // mRoPE height dimension (= position for text tokens)
position_w: u32, // mRoPE width dimension (= position for text tokens)
// Packed BF16 norm weights: [Q norm weights | K norm weights]
// Q: numHeads*headDim BF16 values, K: numKVHeads*headDim BF16 values
qk_norm_weight: array<vec4<u32>, ${d}>,
}
@group(0) @binding(0) var<storage, read> q_proj_full: array<f32>;
@group(0) @binding(1) var<storage, read_write> k_proj: array<f32>;
@group(0) @binding(2) var<storage, read> v_proj: array<f32>;
@group(0) @binding(3) var<storage, read_write> q_proj: array<f32>;
@group(0) @binding(4) var<storage, read_write> q_gate: array<f32>;
@group(0) @binding(5) var<storage, read_write> k_cache: array<f32>;
@group(0) @binding(6) var<storage, read_write> v_cache: array<f32>;
@group(0) @binding(7) var<uniform> params: Params;
fn unpack_norm_bf16(flat_idx: u32) -> f32 {
let vec_idx = flat_idx / 8u;
let u32_idx = (flat_idx / 2u) % 4u;
let bf16_idx = flat_idx % 2u;
let packed = params.qk_norm_weight[vec_idx][u32_idx];
let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_data: array<f32, 256>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) {
let h = wg.x;
let tid = lid.x;
let hd = params.head_dim;
let nh = params.num_heads;
let nkv = params.num_kv_heads;
if (h < nh) {
// ── Q head: split from interleaved + RMSNorm ──
let src_off = h * hd * 2u;
// Phase 1: Read Q values, compute ss, write gate
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
let q_val = q_proj_full[src_off + i];
wg_data[i] = q_val;
ss += q_val * q_val;
q_gate[h * hd + i] = q_proj_full[src_off + hd + i];
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Phase 2: Normalize Q and write to shared mem
for (var i = tid; i < hd; i += 256u) {
let w = unpack_norm_bf16(h * hd + i);
wg_data[i] = wg_data[i] * rms * (1.0 + w);
}
workgroupBarrier();
// Phase 3: Apply mRoPE to Q (rotate_half: dim i pairs with dim i+half_dim)
// mRoPE: interleaved position selection based on freq_idx
let q_half = PARTIAL_DIM / 2u;
for (var i = tid; i < hd; i += 256u) {
var q_out = wg_data[i];
if (i < PARTIAL_DIM) {
let freq_idx = i % q_half;
let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM));
// mRoPE: select position based on interleaved section
var pos = params.position;
if (freq_idx % 3u == 1u && freq_idx < MROPE_S1_LIMIT) { pos = params.position_h; }
else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; }
let angle = f32(pos) * freq;
let cos_a = cos(angle);
let sin_a = sin(angle);
if (i < q_half) {
q_out = wg_data[i] * cos_a - wg_data[i + q_half] * sin_a;
} else {
q_out = wg_data[i] * cos_a + wg_data[i - q_half] * sin_a;
}
}
q_proj[h * hd + i] = q_out;
}
} else {
// ── KV head: K norm + KV cache store ──
let kv_h = h - nh;
if (kv_h >= nkv) { return; }
let off = kv_h * hd;
let q_norm_size = nh * hd; // offset to K norm weights in packed array
// Phase 1: Read K, compute ss
var ss: f32 = 0.0;
for (var i = tid; i < hd; i += 256u) {
let k_val = k_proj[off + i];
wg_data[i] = k_val;
ss += k_val * k_val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps);
// Phase 2: Normalize K, write to shared mem
for (var i = tid; i < hd; i += 256u) {
let w = unpack_norm_bf16(q_norm_size + kv_h * hd + i);
wg_data[i] = wg_data[i] * rms * (1.0 + w);
}
workgroupBarrier();
// Phase 3: Apply mRoPE to K, write to kProj + k_cache, copy V to v_cache
let total = nkv * hd;
let cache_off = params.cache_position * total + off;
let k_half = PARTIAL_DIM / 2u;
for (var i = tid; i < hd; i += 256u) {
var k_out = wg_data[i];
if (i < PARTIAL_DIM) {
let freq_idx = i % k_half;
let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM));
// mRoPE: select position based on interleaved section
var pos = params.position;
if (freq_idx % 3u == 1u && freq_idx < MROPE_S1_LIMIT) { pos = params.position_h; }
else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; }
let angle = f32(pos) * freq;
let cos_a = cos(angle);
let sin_a = sin(angle);
if (i < k_half) {
k_out = wg_data[i] * cos_a - wg_data[i + k_half] * sin_a;
} else {
k_out = wg_data[i] * cos_a + wg_data[i - k_half] * sin_a;
}
}
k_proj[off + i] = k_out;
k_cache[cache_off + i] = k_out;
v_cache[cache_off + i] = v_proj[off + i];
}
}
}`}const lr=_(320),pr=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq(w: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u; // 0-7: which of 8 columns this thread handles
let lane = tid & 3u; // 0-3: which quarter of K this thread computes
let col = wid.x * 8u + col_in_wg;
var sum: f32 = 0.0;
if (col < N) {
let groups_per_lane = num_groups / 4u;
let g_start = lane * groups_per_lane;
let g_end = g_start + groups_per_lane;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq(w0, bk) + dq(w1, bk+8u) +
dq(w2, bk+16u) + dq(w3, bk+24u));
}
}
}
// Store partial sum and reduce within each 4-thread column group
// With @workgroup_size(32) on Apple GPU, all 32 threads are in 1 SIMD group
// so workgroupBarrier is effectively free.
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,fr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
var<workgroup> partial: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u;
let lane = tid & 3u;
let col = wid.x * 8u + col_in_wg;
var sum: f32 = 0.0;
if (col < N) {
let groups_per_lane = num_groups / 4u;
let g_start = lane * groups_per_lane;
let g_end = g_start + groups_per_lane;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
dq16(w2, bk+16u) + dq16(w3, bk+24u));
}
}
}
partial[tid] = sum;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u];
}
}`,_r=`
enable f16;
enable subgroups;
struct Params { K: u32, N: u32, group_size: u32, num_splits: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> partials: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq16(w: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let gs = params.group_size;
let ns = params.num_splits;
let packed_rows = params.K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u;
let lane = tid & 3u;
let col = wid.x * 8u + col_in_wg;
let split_id = wid.y;
var sum: f32 = 0.0;
if (col < N) {
// Divide groups across splits, then across 4 lanes within each split
let groups_per_split = num_groups / ns;
let g_base = split_id * groups_per_split;
let split_end = select(g_base + groups_per_split, num_groups, split_id >= ns - 1u);
let split_groups = split_end - g_base;
let groups_per_lane = split_groups / 4u;
let g_start = g_base + lane * groups_per_lane;
let g_end = select(g_start + groups_per_lane, split_end, lane >= 3u);
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) +
dq16(w2, bk+16u) + dq16(w3, bk+24u));
}
}
}
// Subgroup reduction across 4 lanes (no shared memory, no barrier)
var s = sum;
s += subgroupShuffleXor(s, 1u);
s += subgroupShuffleXor(s, 2u);
if (lane == 0u && col < N) {
partials[split_id * N + col] = s;
}
}`,cr=`
struct Params { H_half: u32, num_tokens: u32, }
@group(0) @binding(0) var<storage, read> src: array<u32>;
@group(0) @binding(1) var<storage, read> indices: array<u32>;
@group(0) @binding(2) var<storage, read_write> dst: array<u32>;
@group(0) @binding(3) var<uniform> p: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let col = gid.x;
let row_out = gid.y;
if (col >= p.H_half || row_out >= p.num_tokens) { return; }
let src_row = indices[row_out];
dst[row_out * p.H_half + col] = src[src_row * p.H_half + col];
}
`,wr=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> bf16_weight: array<u32>;
@group(0) @binding(1) var<storage, read_write> qweight: array<u32>;
@group(0) @binding(2) var<storage, read_write> scales_f32: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> wg_max: array<f32, 32>;
fn unpack_bf16(idx: u32, base: u32) -> f32 {
let packed = bf16_weight[base + idx / 2u];
let bits = (packed >> ((idx % 2u) * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
@compute @workgroup_size(32)
fn main(@builtin(workgroup_id) wg_id: vec3u,
@builtin(local_invocation_id) lid: vec3u) {
let col = wg_id.x + wg_id.y * 65535u; // 2D dispatch for N > 65535
let tid = lid.x;
let K = params.K;
let N = params.N;
if (col >= N) { return; }
let gs = params.group_size;
let num_groups = K / gs;
let base = col * K / 2u;
for (var g = 0u; g < num_groups; g++) {
let gstart = g * gs;
// Find max_abs (parallel reduction across 32 threads)
var local_max: f32 = 0.0;
for (var i = tid; i < gs; i += 32u) {
local_max = max(local_max, abs(unpack_bf16(gstart + i, base)));
}
wg_max[tid] = local_max;
workgroupBarrier();
for (var s = 16u; s > 0u; s >>= 1u) {
if (tid < s) { wg_max[tid] = max(wg_max[tid], wg_max[tid + s]); }
workgroupBarrier();
}
let scale_val = wg_max[0] / 7.0;
let inv_scale = select(1.0 / scale_val, 0.0, scale_val == 0.0);
// Store scale as f32 (packed to F16 in separate pass)
if (tid == 0u) { scales_f32[g * N + col] = scale_val; }
// Quantize and pack 8 INT4 values per u32
let rows_per_group = gs / 8u;
for (var r = tid; r < rows_per_group; r += 32u) {
var packed: u32 = 0u;
let bk = gstart + r * 8u;
for (var b = 0u; b < 8u; b++) {
let val = unpack_bf16(bk + b, base);
var q = i32(round(val * inv_scale)) + 8;
q = clamp(q, 0, 15);
packed |= (u32(q) << (b * 4u));
}
let packed_row = g * rows_per_group + r;
qweight[packed_row * N + col] = packed;
}
}
}`,br=`
struct Params { count: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let idx = gid.x;
if (idx >= params.count) { return; }
let a = input[idx * 2u];
let b = input[idx * 2u + 1u];
output[idx] = pack2x16float(vec2<f32>(a, b));
}`,vr=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read_write> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[off + i] + addend[off + i];
hidden[off + i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[off + i] = hidden[off + i] * rms * (1.0 + w);
}
}`,mr=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> hidden: array<f32>;
@group(0) @binding(1) var<storage, read> addend: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = hidden[off + i] + addend[off + i];
wg_vals[i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
output[off + i] = wg_vals[i] * rms * (1.0 + w);
}
}`,kr=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read> input_c: array<f32>;
@group(0) @binding(3) var<storage, read> weight: array<u32>;
@group(0) @binding(4) var<storage, read_write> hidden_out: array<f32>;
@group(0) @binding(5) var<storage, read_write> normed: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> wg_vals: array<f32, 7680>;
var<workgroup> wg_scratch: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let N = params.N;
let off = wid.x * N;
var ss: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let val = input_a[off + i] + input_b[off + i] + input_c[off + i];
wg_vals[i] = val;
hidden_out[off + i] = val;
ss += val * val;
}
wg_scratch[tid] = ss;
workgroupBarrier();
for (var s = 128u; s > 0u; s >>= 1u) {
if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; }
workgroupBarrier();
}
let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps);
for (var i = tid; i < N; i += 256u) {
let w = unpack_bf16(weight[i / 2u], i % 2u);
normed[off + i] = wg_vals[i] * rms * (1.0 + w);
}
}`,hr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight_gate: array<u32>;
@group(0) @binding(2) var<storage, read> scales_gate: array<u32>;
@group(0) @binding(3) var<storage, read> qweight_up: array<u32>;
@group(0) @binding(4) var<storage, read> scales_up: array<u32>;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
fn dq16_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var gate0: f32 = 0.0; var up0: f32 = 0.0;
var gate1: f32 = 0.0; var up1: f32 = 0.0;
for (var grp: u32 = 0u; grp < num_groups; grp++) {
let sf = grp * N + col;
let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u];
let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u];
let gstart = grp * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let gw0 = qweight_gate[off]; let uw0 = qweight_up[off];
let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N];
let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N];
let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N];
let bk = pr * 8u;
// Batch 0: input at offset 0
gate0 += sg * (dq16_at(gw0,0u,bk) + dq16_at(gw1,0u,bk+8u) + dq16_at(gw2,0u,bk+16u) + dq16_at(gw3,0u,bk+24u));
up0 += su * (dq16_at(uw0,0u,bk) + dq16_at(uw1,0u,bk+8u) + dq16_at(uw2,0u,bk+16u) + dq16_at(uw3,0u,bk+24u));
// Batch 1: input at offset K
gate1 += sg * (dq16_at(gw0,K,bk) + dq16_at(gw1,K,bk+8u) + dq16_at(gw2,K,bk+16u) + dq16_at(gw3,K,bk+24u));
up1 += su * (dq16_at(uw0,K,bk) + dq16_at(uw1,K,bk+8u) + dq16_at(uw2,K,bk+16u) + dq16_at(uw3,K,bk+24u));
}
}
output[col] = (gate0 / (1.0 + exp(-gate0))) * up0;
output[N + col] = (gate1 / (1.0 + exp(-gate1))) * up1;
}`,xr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); }
fn dq16_sl_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(silu(a[base+bk])*b[base+bk]), f16(silu(a[base+bk+1u])*b[base+bk+1u]), f16(silu(a[base+bk+2u])*b[base+bk+2u]), f16(silu(a[base+bk+3u])*b[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(silu(a[base+bk+4u])*b[base+bk+4u]), f16(silu(a[base+bk+5u])*b[base+bk+5u]), f16(silu(a[base+bk+6u])*b[base+bk+6u]), f16(silu(a[base+bk+7u])*b[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum0 += s * (dq16_sl_at(w0,0u,bk) + dq16_sl_at(w1,0u,bk+8u) + dq16_sl_at(w2,0u,bk+16u) + dq16_sl_at(w3,0u,bk+24u));
sum1 += s * (dq16_sl_at(w0,K,bk) + dq16_sl_at(w1,K,bk+8u) + dq16_sl_at(w2,K,bk+16u) + dq16_sl_at(w3,K,bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
}`,qr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read> gate: array<f32>;
@group(0) @binding(2) var<storage, read> qweight: array<u32>;
@group(0) @binding(3) var<storage, read> scales: array<u32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); }
fn dq16_sg_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(x[base+bk]*sig(gate[base+bk])), f16(x[base+bk+1u]*sig(gate[base+bk+1u])), f16(x[base+bk+2u]*sig(gate[base+bk+2u])), f16(x[base+bk+3u]*sig(gate[base+bk+3u])));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(x[base+bk+4u]*sig(gate[base+bk+4u])), f16(x[base+bk+5u]*sig(gate[base+bk+5u])), f16(x[base+bk+6u]*sig(gate[base+bk+6u])), f16(x[base+bk+7u]*sig(gate[base+bk+7u])));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N];
let bk = pr * 8u;
sum0 += s * (dq16_sg_at(w0,0u,bk) + dq16_sg_at(w1,0u,bk+8u) + dq16_sg_at(w2,0u,bk+16u) + dq16_sg_at(w3,0u,bk+24u));
sum1 += s * (dq16_sg_at(w0,K,bk) + dq16_sg_at(w1,K,bk+8u) + dq16_sg_at(w2,K,bk+16u) + dq16_sg_at(w3,K,bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
}`,yr=`
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq_at(w: u32, base: u32, bk: u32) -> f32 {
return dot(vec4<f32>(input[base+bk], input[base+bk+1u], input[base+bk+2u], input[base+bk+3u]),
vec4<f32>(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.))
+ dot(vec4<f32>(input[base+bk+4u], input[base+bk+5u], input[base+bk+6u], input[base+bk+7u]),
vec4<f32>(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum0 += s * (dq_at(w0, 0u, bk) + dq_at(w1, 0u, bk+8u) +
dq_at(w2, 0u, bk+16u) + dq_at(w3, 0u, bk+24u));
sum1 += s * (dq_at(w0, K, bk) + dq_at(w1, K, bk+8u) +
dq_at(w2, K, bk+16u) + dq_at(w3, K, bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
}`,Fr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq16_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
if (col >= N) { return; }
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
for (var g: u32 = 0u; g < num_groups; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) +
dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u));
sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) +
dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u));
}
}
output[col] = sum0;
output[N + col] = sum1;
}`,Nr=`
enable f16;
struct Params { K: u32, N: u32, group_size: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> qweight: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<u32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn dq16_at(w: u32, base: u32, bk: u32) -> f32 {
let i0 = vec4<f16>(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u]));
let w0 = vec4<f16>(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.));
let i1 = vec4<f16>(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u]));
let w1 = vec4<f16>(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.));
return f32(dot(i0, w0)) + f32(dot(i1, w1));
}
var<workgroup> partial0: array<f32, 32>;
var<workgroup> partial1: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let tid = lid.x;
let K = params.K;
let N = params.N;
let gs = params.group_size;
let packed_rows = K / 8u;
let rows_per_group = gs / 8u;
let num_groups = packed_rows / rows_per_group;
let col_in_wg = tid / 4u;
let lane = tid & 3u;
let col = wid.x * 8u + col_in_wg;
var sum0: f32 = 0.0;
var sum1: f32 = 0.0;
if (col < N) {
let groups_per_lane = num_groups / 4u;
let g_start = lane * groups_per_lane;
let g_end = g_start + groups_per_lane;
for (var g = g_start; g < g_end; g++) {
let sf = g * N + col;
let s = unpack2x16float(scales[sf >> 1u])[sf & 1u];
let gstart = g * rows_per_group;
for (var r: u32 = 0u; r < rows_per_group; r += 4u) {
let pr = gstart + r;
let off = pr * N + col;
let w0 = qweight[off];
let w1 = qweight[off + N];
let w2 = qweight[off + 2u * N];
let w3 = qweight[off + 3u * N];
let bk = pr * 8u;
sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) +
dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u));
sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) +
dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u));
}
}
}
partial0[tid] = sum0;
partial1[tid] = sum1;
workgroupBarrier();
if (lane == 0u && col < N) {
output[col] = partial0[tid] + partial0[tid+1u] + partial0[tid+2u] + partial0[tid+3u];
output[N + col] = partial1[tid] + partial1[tid+1u] + partial1[tid+2u] + partial1[tid+3u];
}
}`,Pr=`
struct Params { N: u32, eps: f32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>; // BF16 packed
@group(0) @binding(2) var<storage, read> bias: array<u32>; // BF16 packed
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
fn unpack_bf16(packed: u32, idx: u32) -> f32 {
let bits = (packed >> (idx * 16u)) & 0xFFFFu;
return bitcast<f32>(bits << 16u);
}
var<workgroup> shared_sum: array<f32, 256>;
var<workgroup> shared_sq: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let N = params.N;
let token = wid.x;
let base = token * N;
let tid = lid.x;
// Phase 1: compute sum and sum-of-squares
var s: f32 = 0.0;
var sq: f32 = 0.0;
for (var i = tid; i < N; i += 256u) {
let v = input[base + i];
s += v;
sq += v * v;
}
shared_sum[tid] = s;
shared_sq[tid] = sq;
workgroupBarrier();
// Parallel reduction
for (var stride: u32 = 128u; stride > 0u; stride >>= 1u) {
if (tid < stride) {
shared_sum[tid] += shared_sum[tid + stride];
shared_sq[tid] += shared_sq[tid + stride];
}
workgroupBarrier();
}
let mean = shared_sum[0] / f32(N);
let variance = shared_sq[0] / f32(N) - mean * mean;
let inv_std = 1.0 / sqrt(variance + params.eps);
// Phase 2: normalize and apply weight + bias
for (var i = tid; i < N; i += 256u) {
let v = input[base + i];
let normalized = (v - mean) * inv_std;
let w = unpack_bf16(weight[i >> 1u], i);
let b = unpack_bf16(bias[i >> 1u], i);
output[base + i] = normalized * w + b;
}
}`,Er=`
struct Params { K: u32, N: u32, }
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<u32>; // BF16 packed [N, K/2]
@group(0) @binding(2) var<storage, read> bias: array<u32>; // BF16 packed [N/2]
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let col = gid.x; // output feature index
let token = gid.y; // batch/token index
let K = params.K;
let N = params.N;
if (col >= N) { return; }
let in_base = token * K;
let w_base = col * K / 2u;
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < K; k += 4u) {
let p0 = weight[w_base + k / 2u];
let p1 = weight[w_base + k / 2u + 1u];
let w0 = bitcast<f32>((p0 & 0xFFFFu) << 16u);
let w1 = bitcast<f32>(p0 & 0xFFFF0000u);
let w2 = bitcast<f32>((p1 & 0xFFFFu) << 16u);
let w3 = bitcast<f32>(p1 & 0xFFFF0000u);
sum += input[in_base + k] * w0 + input[in_base + k + 1u] * w1
+ input[in_base + k + 2u] * w2 + input[in_base + k + 3u] * w3;
}
// Add BF16 bias
let bp = bias[col >> 1u];
let b = bitcast<f32>(select((bp & 0xFFFFu) << 16u, bp & 0xFFFF0000u, (col & 1u) == 1u));
output[token * N + col] = sum + b;
}`,Sr=`
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
@group(0) @binding(1) var<uniform> len: u32;
const SQRT_2_OVER_PI: f32 = 0.7978845608;
const COEFF: f32 = 0.044715;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= len) { return; }
let x = data[i];
// For |x| > 10, GELU saturates: positive → x, negative → 0
// Avoids intermediate overflow in x^3 for large values
if (x > 10.0) { return; } // data[i] = x (already in place)
if (x < -10.0) { data[i] = 0.0; return; }
let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x);
data[i] = 0.5 * x * (1.0 + tanh(inner));
}`,Rr=`
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
@group(0) @binding(1) var<uniform> len: u32;
const SQRT_2_OVER_PI: f32 = 0.7978845608;
const COEFF: f32 = 0.044715;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= len) { return; }
let x = data[i];
if (x > 10.0) { return; }
if (x < -10.0) { data[i] = 0.0; return; }
let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x);
data[i] = 0.5 * x * (1.0 + tanh(inner));
}`,Kr=`
struct Params { seq_len: u32, num_heads: u32, head_dim: u32, }
@group(0) @binding(0) var<storage, read> q_in: array<f32>;
@group(0) @binding(1) var<storage, read> k_in: array<f32>;
@group(0) @binding(2) var<storage, read_write> q_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> k_out: array<f32>;
@group(0) @binding(4) var<storage, read> cos_buf: array<f32>;
@group(0) @binding(5) var<storage, read> sin_buf: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let idx = gid.x;
let total = params.seq_len * params.num_heads * params.head_dim;
if (idx >= total) { return; }
let d = idx % params.head_dim;
let head_and_seq = idx / params.head_dim;
let seq = head_and_seq / params.num_heads;
let half_dim = params.head_dim / 2u;
// cos/sin are [seq_len, head_dim] — same for all heads
let cs_idx = seq * params.head_dim + d;
let c = cos_buf[cs_idx];
let s = sin_buf[cs_idx];
// rotate_half: pair d with d+half_dim (or d-half_dim)
let partner_d = select(d + half_dim, d - half_dim, d >= half_dim);
let partner_idx = head_and_seq * params.head_dim + partner_d;
let sign = select(-1.0, 1.0, d >= half_dim);
q_out[idx] = q_in[idx] * c + sign * q_in[partner_idx] * s;
k_out[idx] = k_in[idx] * c + sign * k_in[partner_idx] * s;
}`,Br=`
struct Params { seq_len: u32, num_heads: u32, head_dim: u32, scale: f32, }
@group(0) @binding(0) var<storage, read> q: array<f32>;
@group(0) @binding(1) var<storage, read> k: array<f32>;
@group(0) @binding(2) var<storage, read> v: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> wg_dot: array<f32, 32>;
@compute @workgroup_size(32)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let q_pos = wid.x;
let head = wid.y;
let tid = lid.x;
let S = params.seq_len;
let H = params.num_heads;
let D = params.head_dim;
let scale = params.scale;
let q_base = (q_pos * H + head) * D;
// Each thread handles D elements with stride 32 (D=64→2, D=128→4 per thread)
// Pre-load Q into registers
var qc0: f32 = 0.0; var qc1: f32 = 0.0; var qc2: f32 = 0.0; var qc3: f32 = 0.0;
qc0 = q[q_base + tid];
if (tid + 32u < D) { qc1 = q[q_base + tid + 32u]; }
if (tid + 64u < D) { qc2 = q[q_base + tid + 64u]; }
if (tid + 96u < D) { qc3 = q[q_base + tid + 96u]; }
// Online softmax + V accumulation, parallelized across D dimension
var running_max: f32 = -1e30;
var running_sum: f32 = 0.0;
var a0: f32 = 0.0; var a1: f32 = 0.0; var a2: f32 = 0.0; var a3: f32 = 0.0;
for (var kp: u32 = 0u; kp < S; kp++) {
let k_base = (kp * H + head) * D;
// Partial dot product: each thread multiplies its D elements
var pd: f32 = qc0 * k[k_base + tid];
if (tid + 32u < D) { pd += qc1 * k[k_base + tid + 32u]; }
if (tid + 64u < D) { pd += qc2 * k[k_base + tid + 64u]; }
if (tid + 96u < D) { pd += qc3 * k[k_base + tid + 96u]; }
// Tree reduction: 32 → 1
wg_dot[tid] = pd;
workgroupBarrier();
if (tid < 16u) { wg_dot[tid] += wg_dot[tid + 16u]; }
workgroupBarrier();
if (tid < 8u) { wg_dot[tid] += wg_dot[tid + 8u]; }
workgroupBarrier();
if (tid < 4u) { wg_dot[tid] += wg_dot[tid + 4u]; }
workgroupBarrier();
if (tid < 2u) { wg_dot[tid] += wg_dot[tid + 2u]; }
workgroupBarrier();
if (tid == 0u) { wg_dot[0] += wg_dot[1]; }
workgroupBarrier();
let s = wg_dot[0] * scale;
// Online softmax (all threads have same s after reduction)
let old_max = running_max;
running_max = max(running_max, s);
let corr = exp(old_max - running_max);
running_sum = running_sum * corr;
a0 = a0 * corr; a1 = a1 * corr; a2 = a2 * corr; a3 = a3 * corr;
let w = exp(s - running_max);
running_sum += w;
// Accumulate weighted V for this thread's D elements
let v_base = (kp * H + head) * D;
a0 += w * v[v_base + tid];
if (tid + 32u < D) { a1 += w * v[v_base + tid + 32u]; }
if (tid + 64u < D) { a2 += w * v[v_base + tid + 64u]; }
if (tid + 96u < D) { a3 += w * v[v_base + tid + 96u]; }
}
// Write output
let inv_sum = 1.0 / running_sum;
let out_base = (q_pos * H + head) * D;
output[out_base + tid] = a0 * inv_sum;
if (tid + 32u < D) { output[out_base + tid + 32u] = a1 * inv_sum; }
if (tid + 64u < D) { output[out_base + tid + 64u] = a2 * inv_sum; }
if (tid + 96u < D) { output[out_base + tid + 96u] = a3 * inv_sum; }
}`,Dr=`
@group(0) @binding(0) var<storage, read_write> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<uniform> len: u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= len) { return; }
a[i] += b[i];
}`,Ar=`
struct Params { H: u32, }
@group(0) @binding(0) var<storage, read> vision: array<f32>;
@group(0) @binding(1) var<storage, read> scatter_indices: array<u32>;
@group(0) @binding(2) var<storage, read_write> embeds: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) {
let vit_idx = wid.x;
let pos = scatter_indices[vit_idx];
let H = params.H;
let tid = lid.x;
for (var i = tid; i < H; i += 256u) {
embeds[pos * H + i] = vision[vit_idx * H + i];
}
}`,zr={gptq_matvec:k,gptq_matvec_f16:h,gptq_matvec_4t:pr,gptq_matvec_4t_f16:fr,gptq_splitk:x,reduce_splitk:q,bf16_matvec:y,rmsnorm:F,silu_mul:N,add:P,embedding:R,embed_from_argmax:K,argmax:B,topk_extract:D,kv_cache_store:z,gqa_attention_head:T,gqa_reduce:H,deltanet_recurrent:M,head_rmsnorm:G,head_rmsnorm_nogated:O,causal_conv1d:U,split:E,sigmoid_mul:S,fused_gate_up_silu:C,fused_gate_up_silu_f16:L,fused_gate_up_silu_4t:Q,fused_gate_up_silu_4t_f16:j,add_rmsnorm:V,add_rmsnorm_ro:W,three_way_add_rmsnorm:Y,norm_gptq_lite:or,norm_gptq_lite_noadd:nr,fused_sigmoid_gptq:X,fused_sigmoid_gptq_f16:$,fused_sigmoid_gptq_4t:er,fused_sigmoid_gptq_4t_f16:tr,fused_silu_gptq:Z,fused_silu_gptq_f16:J,fused_silu_gptq_4t:rr,fused_silu_gptq_4t_f16:ar,fused_addnorm_gate_up_silu:ur,rep_penalty:gr,gpu_sample:dr,append_token:A,fused_norm_gptq:ir,fused_norm_gptq_noadd:sr,fused_conv_deltanet_norm:I,fused_split_qknorm_kvstore:lr,gptq_matvec_4t_f16_sk:_r,gather_rows_bf16:cr,quantize_bf16_to_int4:wr,pack_f32_to_f16_pairs:br,gptq_matmul_b2:yr,gptq_matmul_b2_f16:Fr,gptq_matmul_b2_4t_f16:Nr,add_rmsnorm_b2:vr,add_rmsnorm_ro_b2:mr,three_way_add_rmsnorm_b2:kr,fused_gate_up_silu_b2_f16:hr,fused_silu_gptq_b2_f16:xr,fused_sigmoid_gptq_b2_f16:qr,vit_layernorm:Pr,vit_bf16_matvec_bias:Er,vit_gelu_tanh:Sr,vit_gelu:Rr,vit_rope:Kr,vit_attention:Br,vit_add:Dr,vit_scatter_embed:Ar},Mr=Object.freeze(Object.defineProperty({__proto__:null,SHADERS:zr,SHADER_FUSED_SPLIT_QKNORM_KVSTORE_FN:_},Symbol.toStringTag,{value:"Module"}));export{b as G,zr as S,Hr as _,_ as a,Mr as b,Tr as g};