(function(){const e=document.createElement("link").relList;if(e&&e.supports&&e.supports("modulepreload"))return;for(const r of document.querySelectorAll('link[rel="modulepreload"]'))t(r);new MutationObserver(r=>{for(const a of r)if(a.type==="childList")for(const i of a.addedNodes)i.tagName==="LINK"&&i.rel==="modulepreload"&&t(i)}).observe(document,{childList:!0,subtree:!0});function u(r){const a={};return r.integrity&&(a.integrity=r.integrity),r.referrerPolicy&&(a.referrerPolicy=r.referrerPolicy),r.crossOrigin==="use-credentials"?a.credentials="include":r.crossOrigin==="anonymous"?a.credentials="omit":a.credentials="same-origin",a}function t(r){if(r.ep)return;r.ep=!0;const a=u(r);fetch(r.href,a)}})();class b{constructor(){this.device=null,this.adapter=null,this.adapterInfo=null,this.pipelineCache=new Map,this.bufferCache=new Map,this.bindGroupCache=new Map}async init(){if(!navigator.gpu)throw new Error("WebGPU not supported");if(this.adapter=await navigator.gpu.requestAdapter({powerPreference:"high-performance"}),!this.adapter)throw new Error("No WebGPU adapter found");this.adapterInfo=await this.adapter.requestAdapterInfo?.()??{};const e={},u={maxBufferSize:4*1024*1024*1024,maxStorageBufferBindingSize:4*1024*1024*1024,maxComputeWorkgroupStorageSize:32768,maxComputeInvocationsPerWorkgroup:256,maxComputeWorkgroupSizeX:256,maxStorageBuffersPerShaderStage:10};for(const[r,a]of Object.entries(u))this.adapter.limits[r]!==void 0&&(e[r]=Math.min(a,this.adapter.limits[r]));const t=[];return this.adapter.features.has("shader-f16")&&t.push("shader-f16"),this.adapter.features.has("subgroups")&&t.push("subgroups"),this.device=await this.adapter.requestDevice({requiredLimits:e,requiredFeatures:t}),this.hasF16=this.device.features.has("shader-f16"),this.hasSubgroups=this.device.features.has("subgroups"),this.device.lost.then(r=>console.error("WebGPU device lost:",r)),this}createBuffer(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST){const r=Math.ceil(u/4)*4,a=this.device.createBuffer({size:r,usage:t,label:e});return this.bufferCache.set(e,a),a}createBufferFromData(e,u,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC){const r=this.createBuffer(e,u.byteLength,t);return this.device.queue.writeBuffer(r,0,u),r}createReadbackBuffer(e,u){const t=Math.ceil(u/4)*4;return this.device.createBuffer({size:t,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST,label:e+"_readback"})}getOrCreatePipeline(e,u,t="main"){if(this.pipelineCache.has(e))return this.pipelineCache.get(e);const r=this.device.createShaderModule({code:u,label:e}),a=this.device.createComputePipeline({layout:"auto",compute:{module:r,entryPoint:t},label:e});return 
this.pipelineCache.set(e,a),a}initTimestamps(){this.device.features.has("timestamp-query")&&(this._tsQuerySet=this.device.createQuerySet({type:"timestamp",count:2}),this._tsResolveBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.QUERY_RESOLVE|GPUBufferUsage.COPY_SRC}),this._tsReadBuf=this.device.createBuffer({size:16,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST}),this._tsEnabled=!0,this._tsResults=[])}beginBatch(){this._encoder=this.device.createCommandEncoder(),this._passCount=0,this.singlePassMode&&(this._singlePass=this._encoder.beginComputePass()),this._tsEnabled&&this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:0}}).end()}endBatch(){this._singlePass&&(this._singlePass.end(),this._singlePass=null),this._tsEnabled&&this._encoder&&(this._encoder.beginComputePass({timestampWrites:{querySet:this._tsQuerySet,beginningOfPassWriteIndex:1}}).end(),this._encoder.resolveQuerySet(this._tsQuerySet,0,2,this._tsResolveBuf,0),this._encoder.copyBufferToBuffer(this._tsResolveBuf,0,this._tsReadBuf,0,16)),this._encoder&&(this.device.queue.submit([this._encoder.finish()]),this._encoder=null)}async readTimestamp(){if(!this._tsEnabled)return null;await this._tsReadBuf.mapAsync(GPUMapMode.READ);const e=new BigInt64Array(this._tsReadBuf.getMappedRange().slice(0));this._tsReadBuf.unmap();const t=Number(e[1]-e[0])/1e6;return this._tsResults.push(t),t}copyBuffer(e,u,t,r=0,a=0){if(this._singlePass){this._singlePass.end(),this._encoder.copyBufferToBuffer(e,r,u,a,t),this._singlePass=this._encoder.beginComputePass();return}const i=this._encoder||this.device.createCommandEncoder();i.copyBufferToBuffer(e,r,u,a,t),this._encoder||this.device.queue.submit([i.finish()])}startRecording(){this._recording=[]}stopRecording(){const e=this._recording;return this._recording=null,e}replay(e,u){if(u)for(const r of u)this.device.queue.writeBuffer(r.buffer,r.offset,r.data,r.dataOffset,r.size);const t=this._encoder;for(let r=0;r({pipeline:r.pipeline,bindGroup:r.bindGroups[0],wgX:r.workgroupsX,wgY:r.workgroupsY||1}))}),this._singlePass){this._passCount!==void 0&&this._passCount++;for(const r of e){this._singlePass.setPipeline(r.pipeline);for(let a=0;a({binding:a,resource:{buffer:r}}))})}createBindGroupWithOffsets(e,u,t){return this.device.createBindGroup({layout:e.getBindGroupLayout(u),entries:t.map((r,a)=>({binding:a,resource:r.buffer?{buffer:r.buffer,offset:r.offset||0,size:r.size}:{buffer:r}}))})}getCachedBindGroup(e,u,t,r){let a=this.bindGroupCache.get(e);return a||(a=this.createBindGroup(u,t,r),this.bindGroupCache.set(e,a)),a}destroy(){for(const e of this.bufferCache.values())e.destroy();this.bufferCache.clear(),this.bindGroupCache.clear(),this.device?.destroy()}}const Tr=Object.freeze(Object.defineProperty({__proto__:null,GPUContext:b},Symbol.toStringTag,{value:"Module"})),v="modulepreload",m=function(d){return"/"+d},f={},Hr=function(e,u,t){let r=Promise.resolve();if(u&&u.length>0){let n=function(s){return Promise.all(s.map(l=>Promise.resolve(l).then(p=>({status:"fulfilled",value:p}),p=>({status:"rejected",reason:p}))))};document.getElementsByTagName("link");const i=document.querySelector("meta[property=csp-nonce]"),o=i?.nonce||i?.getAttribute("nonce");r=n(u.map(s=>{if(s=m(s),s in f)return;f[s]=!0;const l=s.endsWith(".css"),p=l?'[rel="stylesheet"]':"";if(document.querySelector(`link[href="${s}"]${p}`))return;const 
g=document.createElement("link");if(g.rel=l?"stylesheet":v,l||(g.as="script"),g.crossOrigin="",g.href=s,o&&g.setAttribute("nonce",o),document.head.appendChild(g),l)return new Promise((c,w)=>{g.addEventListener("load",c),g.addEventListener("error",()=>w(new Error(`Unable to preload CSS for ${s}`)))})}))}function a(i){const o=new Event("vite:preloadError",{cancelable:!0});if(o.payload=i,window.dispatchEvent(o),!o.defaultPrevented)throw i}return r.then(i=>{for(const o of i||[])o.status==="rejected"&&a(o.reason);return e().catch(a)})},k=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq(w: u32, bk: u32) -> f32 { return dot(vec4(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq(w0, bk) + dq(w1, bk+8u) + dq(w2, bk+16u) + dq(w3, bk+24u)); } } output[col] = sum; }`,h=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq16(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) + dq16(w2, bk+16u) + dq16(w3, bk+24u)); } } output[col] = sum; }`,x=` struct Params { 
K: u32, N: u32, group_size: u32, num_splits: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var partials: array; @group(0) @binding(4) var params: Params; fn dq(w: u32, bk: u32) -> f32 { return dot(vec4(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(workgroup_id) wgid: vec3u) { let col = gid.x; let split_id = wgid.y; let K = params.K; let N = params.N; let gs = params.group_size; let ns = params.num_splits; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } let groups_per_split = num_groups / ns; let g_start = split_id * groups_per_split; let g_end = select(g_start + groups_per_split, num_groups, split_id >= ns - 1u); var sum: f32 = 0.0; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq(w0, bk) + dq(w1, bk+8u) + dq(w2, bk+16u) + dq(w3, bk+24u)); } } partials[split_id * N + col] = sum; }`,q=` struct Params { N: u32, num_splits: u32, } @group(0) @binding(0) var partials: array; @group(0) @binding(1) var output: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; if (col >= params.N) { return; } var sum: f32 = 0.0; for (var s = 0u; s < params.num_splits; s++) { sum += partials[s * params.N + col]; } output[col] = sum; }`,y=` struct Params { K: u32, N: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var weight: array; @group(0) @binding(2) var output: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; if (col >= params.N) { return; } var sum: f32 = 0.0; let K = params.K; let base = col * K / 2u; // Process 4 BF16 values (2 u32 reads) per iteration for (var k: u32 = 0u; k < K; k += 4u) { let p0 = weight[base + k / 2u]; let p1 = weight[base + k / 2u + 1u]; let w0 = bitcast((p0 & 0xFFFFu) << 16u); let w1 = bitcast(p0 & 0xFFFF0000u); let w2 = bitcast((p1 & 0xFFFFu) << 16u); let w3 = bitcast(p1 & 0xFFFF0000u); sum += input[k] * w0 + input[k + 1u] * w1 + input[k + 2u] * w2 + input[k + 3u] * w3; } output[col] = sum; }`,F=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var weight: array; @group(0) @binding(2) var output: array; @group(0) @binding(3) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; // Parallel sum of squares var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { ss += input[i] * input[i]; } wg_scratch[tid] = ss; 
workgroupBarrier(); // Tree reduction for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); // Parallel normalize for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); output[i] = input[i] * rms * (1.0 + w); } }`,N=` struct Params { N: u32, } @group(0) @binding(0) var gate: array; @group(0) @binding(1) var up: array; @group(0) @binding(2) var output: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.N) { return; } let x = gate[i]; output[i] = (x / (1.0 + exp(-x))) * up[i]; }`,P=` struct Params { N: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.N) { return; } a[i] = a[i] + b[i]; }`,E=` struct Params { N: u32, num_heads: u32, head_dim: u32, } @group(0) @binding(0) var src: array; @group(0) @binding(1) var dst_a: array; @group(0) @binding(2) var dst_b: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.N) { return; } let hd = params.head_dim; let head = i / hd; let d = i % hd; dst_a[i] = src[head * hd * 2u + d]; dst_b[i] = src[head * hd * 2u + hd + d]; }`,S=` struct Params { N: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var output: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.N) { return; } output[i] = x[i] / (1.0 + exp(-gate[i])); }`,R=` struct Params { token_id: u32, dim: u32, } @group(0) @binding(0) var embeddings: array; @group(0) @binding(1) var output: array; @group(0) @binding(2) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.dim) { return; } let flat = params.token_id * params.dim + i; output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u); }`,K=` struct ArgmaxResult { idx: u32, val: f32, } struct Params { dim: u32, } @group(0) @binding(0) var embeddings: array; @group(0) @binding(1) var output: array; @group(0) @binding(2) var argmax_result: ArgmaxResult; @group(0) @binding(3) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.dim) { return; } let flat = argmax_result.idx * params.dim + i; output[i] = unpack_bf16(embeddings[flat / 2u], flat % 2u); }`,B=` struct Params { N: u32, } struct Result { idx: u32, val: f32, } @group(0) @binding(0) var logits: array; @group(0) @binding(1) var result: Result; @group(0) @binding(2) var params: Params; var s_val: array; var s_idx: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; // Each thread scans its portion var best_val: f32 = -1e30; var best_idx: u32 = 0u; for (var i = tid; i < N; i += 256u) { if (logits[i] > best_val) { 
best_val = logits[i]; best_idx = i; } } s_val[tid] = best_val; s_idx[tid] = best_idx; workgroupBarrier(); // Tree reduction for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s && s_val[tid + s] > s_val[tid]) { s_val[tid] = s_val[tid + s]; s_idx[tid] = s_idx[tid + s]; } workgroupBarrier(); } if (tid == 0u) { result.idx = s_idx[0]; result.val = s_val[0]; } }`,D=` struct Params { N: u32, } @group(0) @binding(0) var logits: array; @group(0) @binding(1) var result: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; // Each thread scans its portion, keeping best 1 var best_val: f32 = -1e30; var best_idx: u32 = 0u; for (var i = tid; i < N; i += 256u) { if (logits[i] > best_val) { best_val = logits[i]; best_idx = i; } } // Output (idx, val) pair — 256 candidates total result[tid * 2u] = best_idx; result[tid * 2u + 1u] = bitcast(best_val); }`,A=` struct ArgmaxResult { idx: u32, val: f32, } struct Params { recent_count: u32, history_slot: u32, } @group(0) @binding(0) var argmax_result: ArgmaxResult; @group(0) @binding(1) var recent_tokens: array; @group(0) @binding(2) var token_history: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(1) fn main() { let tok = argmax_result.idx; recent_tokens[params.recent_count] = tok; token_history[params.history_slot] = tok; }`,z=` struct Params { position: u32, num_kv_heads: u32, head_dim: u32, } @group(0) @binding(0) var k_proj: array; @group(0) @binding(1) var v_proj: array; @group(0) @binding(2) var k_cache: array; @group(0) @binding(3) var v_cache: array; @group(0) @binding(4) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; let total = params.num_kv_heads * params.head_dim; if (i >= total) { return; } let offset = params.position * total + i; k_cache[offset] = k_proj[i]; v_cache[offset] = v_proj[i]; }`,T=` struct Params { seq_len: u32, head_dim: u32, num_kv_heads: u32, num_q_heads: u32, heads_per_kv: u32, num_splits: u32, _pad0: u32, _pad1: u32, } @group(0) @binding(0) var q_full: array; @group(0) @binding(1) var k_cache: array; @group(0) @binding(2) var v_cache: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; var s_dot: array; @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let split = wg.y; let tid = lid.x; let hd = params.head_dim; let kv_head = h / params.heads_per_kv; let q_offset = h * hd; let kv_stride = params.num_kv_heads * hd; let kv_offset = kv_head * hd; let scale = 1.0 / sqrt(f32(hd)); let ns = params.num_splits; // Compute this split's range of KV positions let chunk = (params.seq_len + ns - 1u) / ns; let pos_start = split * chunk; let pos_end = min(pos_start + chunk, params.seq_len); // Load Q (already RoPE'd) let q_val = q_full[q_offset + tid]; // Online softmax over this split's positions var running_max: f32 = -1e30; var running_sum: f32 = 0.0; var running_out: f32 = 0.0; for (var pos = pos_start; pos < pos_end; pos++) { let k_base = pos * kv_stride + kv_offset; // K is already RoPE'd in cache let k_val = k_cache[k_base + tid]; s_dot[tid] = q_val * k_val; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { s_dot[tid] += s_dot[tid + s]; } workgroupBarrier(); } let score = s_dot[0] * scale; let v_val = v_cache[pos * kv_stride + kv_offset + tid]; let new_max = max(running_max, score); let 
correction = exp(running_max - new_max); running_sum = running_sum * correction + exp(score - new_max); running_out = running_out * correction + exp(score - new_max) * v_val; running_max = new_max; } if (ns == 1u) { // Single split: write final output directly (no reduction needed) if (running_sum > 0.0) { output[q_offset + tid] = running_out / running_sum; } else { output[q_offset + tid] = 0.0; } } else { // Multi-split: write partial results for reduction kernel // Layout: [num_heads * num_splits * (head_dim + 2)] // Per split: [partial_out[head_dim], log_sum_exp, running_max] let stride = hd + 2u; let base = (h * ns + split) * stride; if (running_sum > 0.0) { output[base + tid] = running_out / running_sum; } else { output[base + tid] = 0.0; } if (tid == 0u) { output[base + hd] = log(running_sum) + running_max; // LSE output[base + hd + 1u] = running_max; } } }`,H=` struct Params { head_dim: u32, num_splits: u32, num_heads: u32, _pad: u32, } @group(0) @binding(0) var partials: array; @group(0) @binding(1) var output: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let hd = params.head_dim; let ns = params.num_splits; let stride = hd + 2u; // Find global max LSE across splits var global_max_lse: f32 = -1e30; for (var s: u32 = 0u; s < ns; s++) { let base = (h * ns + s) * stride; let lse = partials[base + hd]; global_max_lse = max(global_max_lse, lse); } // Weighted combination using LSE rescaling var acc: f32 = 0.0; var weight_sum: f32 = 0.0; for (var s: u32 = 0u; s < ns; s++) { let base = (h * ns + s) * stride; let lse = partials[base + hd]; let w = exp(lse - global_max_lse); acc += w * partials[base + tid]; if (tid == 0u) { weight_sum += w; } } // Broadcast weight_sum from thread 0 // (all threads computed same weight_sum since lse is scalar, but only tid==0 accumulated) // Actually all threads can compute weight_sum independently var ws: f32 = 0.0; for (var s: u32 = 0u; s < ns; s++) { let base = (h * ns + s) * stride; let lse = partials[base + hd]; ws += exp(lse - global_max_lse); } output[h * hd + tid] = acc / ws; }`,M=` struct Params { num_heads: u32, key_dim: u32, value_dim: u32, } @group(0) @binding(0) var qkv: array; @group(0) @binding(1) var state: array; @group(0) @binding(2) var output: array; @group(0) @binding(3) var a_proj: array; @group(0) @binding(4) var b_proj: array; @group(0) @binding(5) var A_log: array; @group(0) @binding(6) var dt_bias: array; @group(0) @binding(7) var params: Params; fn sigmoid(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); } fn softplus(x: f32) -> f32 { if (x > 20.0) { return x; } return log(1.0 + exp(x)); } fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let nh = params.num_heads; let kd = params.key_dim; let vd = params.value_dim; if (h >= nh) { return; } let qh_off = h * kd; let kh_off = nh * kd + h * kd; let vh_off = 2u * nh * kd + h * vd; let state_off = h * kd * vd; // Compute Q/K L2 norms (all threads compute redundantly — cheap) var q_ss: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; } let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6); var k_ss: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off 
+ ki]; } let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6); // Decay and gate (same for all threads in this head) let a_log_val = unpack_bf16(A_log[h / 2u], h % 2u); let dt_bias_val = unpack_bf16(dt_bias[h / 2u], h % 2u); let decay = exp(-exp(a_log_val) * softplus(a_proj[h] + dt_bias_val)); let beta = sigmoid(b_proj[h]); // Each thread handles value dimensions in stride for (var vi = tid; vi < vd; vi += 256u) { // Step 1: Decay state column for (var ki: u32 = 0u; ki < kd; ki++) { let idx = state_off + ki * vd + vi; state[idx] *= decay; } // Step 2: kv_mem = S^T @ k_normalized (for this vi) var kv_mem: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { kv_mem += state[state_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm; } // Step 3: delta rule update let v_val = qkv[vh_off + vi]; let delta = (v_val - kv_mem) * beta; for (var ki: u32 = 0u; ki < kd; ki++) { let idx = state_off + ki * vd + vi; state[idx] += qkv[kh_off + ki] * k_inv_norm * delta; } // Step 4: Output = S^T @ (q_normalized * scale) let scale = 1.0 / sqrt(f32(kd)); var o_val: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { o_val += state[state_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * scale; } output[h * vd + vi] = o_val; } }`,G=` struct Params { num_heads: u32, head_dim: u32, eps: f32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var weight: array; @group(0) @binding(2) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let hd = params.head_dim; if (h >= params.num_heads) { return; } let off = h * hd; // Parallel sum of squares var ss: f32 = 0.0; for (var i = tid; i < hd; i += 256u) { ss += x[off + i] * x[off + i]; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps); // Parallel normalize for (var i = tid; i < hd; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); x[off + i] = x[off + i] * rms * (1.0 + w); } }`,O=` struct Params { num_heads: u32, head_dim: u32, eps: f32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var weight: array; @group(0) @binding(2) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let hd = params.head_dim; if (h >= params.num_heads) { return; } let off = h * hd; var ss: f32 = 0.0; for (var i = tid; i < hd; i += 256u) { ss += x[off + i] * x[off + i]; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps); for (var i = tid; i < hd; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); x[off + i] = x[off + i] * rms * w; } }`,U=` struct Params { channels: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var hist: array; @group(0) @binding(2) var weight: array; @group(0) @binding(3) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; 
return bitcast(bits << 16u); } @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let c = gid.x; if (c >= params.channels) { return; } let ch = params.channels; let h0 = hist[c]; let h1 = hist[ch + c]; let h2 = hist[2u * ch + c]; let cur = x[c]; let w_base = c * 4u; let w0 = unpack_bf16(weight[w_base / 2u], w_base % 2u); let w1 = unpack_bf16(weight[(w_base + 1u) / 2u], (w_base + 1u) % 2u); let w2 = unpack_bf16(weight[(w_base + 2u) / 2u], (w_base + 2u) % 2u); let w3 = unpack_bf16(weight[(w_base + 3u) / 2u], (w_base + 3u) % 2u); let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur; x[c] = conv_out / (1.0 + exp(-conv_out)); hist[c] = h1; hist[ch + c] = h2; hist[2u * ch + c] = cur; }`,I=` struct Params { num_heads: u32, key_dim: u32, value_dim: u32, total_channels: u32, eps: f32, hidden_size: u32, num_value_heads: u32, } @group(0) @binding(0) var qkv: array; @group(0) @binding(1) var hist: array; @group(0) @binding(2) var conv_weight: array; @group(0) @binding(3) var state: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var hidden_input: array; @group(0) @binding(6) var ab_weight: array; @group(0) @binding(7) var A_log: array; @group(0) @binding(8) var dt_bias: array; @group(0) @binding(9) var norm_weight: array; @group(0) @binding(10) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } fn sigmoid_f(x: f32) -> f32 { return 1.0 / (1.0 + exp(-x)); } fn softplus_f(x: f32) -> f32 { if (x > 20.0) { return x; } return log(1.0 + exp(x)); } var wg_scratch: array; @compute @workgroup_size(128) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let nh = params.num_heads; let kd = params.key_dim; let vd = params.value_dim; let ch = params.total_channels; let H = params.hidden_size; let nhv = params.num_value_heads; let vpk = nhv / nh; // value heads per key head (1 for 2B, 2 for 4B) let evd = vpk * vd; // effective value dim per key head if (h >= nh) { return; } // Alpha/beta computed per value head inside the vpk loop (Phase 2). // When H > 0: inline BF16 dot product (a_proj/b_proj are BF16 in merged ab_weight) // When H == 0: pre-computed via GPTQ dispatches, alpha/beta read per-value-head // ── Phase 1: Conv1d for this head's channels ── // Head h has 3 blocks: Q[h*kd..], K[nh*kd + h*kd..], V[2*nh*kd + h*evd..] 
// evd = vpk * vd (effective value dim per key head, handles asymmetric heads) let cpb = kd + kd + evd; // channels per key head block for (var i = tid; i < cpb; i += 128u) { // Map local channel to global channel index var c: u32; if (i < kd) { c = h * kd + i; } else if (i < 2u * kd) { c = nh * kd + h * kd + (i - kd); } else { c = 2u * nh * kd + h * evd + (i - 2u * kd); } // Conv1d kernel=4 with history let h0 = hist[c]; let h1 = hist[ch + c]; let h2 = hist[2u * ch + c]; let cur = qkv[c]; let wb = c * 4u; let w0 = unpack_bf16(conv_weight[wb / 2u], wb % 2u); let w1 = unpack_bf16(conv_weight[(wb + 1u) / 2u], (wb + 1u) % 2u); let w2 = unpack_bf16(conv_weight[(wb + 2u) / 2u], (wb + 2u) % 2u); let w3 = unpack_bf16(conv_weight[(wb + 3u) / 2u], (wb + 3u) % 2u); let conv_out = w0 * h0 + w1 * h1 + w2 * h2 + w3 * cur; qkv[c] = conv_out / (1.0 + exp(-conv_out)); // SiLU hist[c] = h1; hist[ch + c] = h2; hist[2u * ch + c] = cur; } workgroupBarrier(); // ── Phase 2+3: DeltaNet recurrent + RMSNorm, looped over vpk value heads ── // Q/K are shared across value heads within this key head let qh_off = h * kd; let kh_off = nh * kd + h * kd; // Q/K L2 norms (all threads compute redundantly, shared across value heads) var q_ss: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { q_ss += qkv[qh_off + ki] * qkv[qh_off + ki]; } let q_inv_norm = 1.0 / max(sqrt(q_ss), 1e-6); var k_ss: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { k_ss += qkv[kh_off + ki] * qkv[kh_off + ki]; } let k_inv_norm = 1.0 / max(sqrt(k_ss), 1e-6); let qk_scale = 1.0 / sqrt(f32(kd)); // Loop over vpk value heads per key head (1 for 2B, 2 for 4B) for (var vhi: u32 = 0u; vhi < vpk; vhi++) { let vh = h * vpk + vhi; // global value head index // Get per-value-head alpha/beta var alpha: f32; var beta_raw: f32; if (H == 0u) { // Quantized path: alpha/beta pre-computed per value head alpha = hidden_input[vh]; beta_raw = bitcast(ab_weight[vh]); } else { // BF16 inline: dot product per value head let half_H = H / 2u; let a_base = vh * half_H; let b_base = (nhv + vh) * half_H; var a_sum: f32 = 0.0; var b_sum: f32 = 0.0; for (var k = tid * 4u; k < H; k += 512u) { let hv0 = hidden_input[k]; let hv1 = hidden_input[k + 1u]; let hv2 = hidden_input[k + 2u]; let hv3 = hidden_input[k + 3u]; let ap0 = ab_weight[a_base + k / 2u]; let ap1 = ab_weight[a_base + k / 2u + 1u]; a_sum += hv0 * bitcast((ap0 & 0xFFFFu) << 16u) + hv1 * bitcast(ap0 & 0xFFFF0000u) + hv2 * bitcast((ap1 & 0xFFFFu) << 16u) + hv3 * bitcast(ap1 & 0xFFFF0000u); let bp0 = ab_weight[b_base + k / 2u]; let bp1 = ab_weight[b_base + k / 2u + 1u]; b_sum += hv0 * bitcast((bp0 & 0xFFFFu) << 16u) + hv1 * bitcast(bp0 & 0xFFFF0000u) + hv2 * bitcast((bp1 & 0xFFFFu) << 16u) + hv3 * bitcast(bp1 & 0xFFFF0000u); } wg_scratch[tid] = a_sum; workgroupBarrier(); for (var s = 64u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } alpha = wg_scratch[0]; wg_scratch[tid] = b_sum; workgroupBarrier(); for (var s = 64u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } beta_raw = wg_scratch[0]; } // Per-value-head decay and gate let a_log_val = unpack_bf16(A_log[vh / 2u], vh % 2u); let dt_bias_val = unpack_bf16(dt_bias[vh / 2u], vh % 2u); let decay = exp(-exp(a_log_val) * softplus_f(alpha + dt_bias_val)); let beta = sigmoid_f(beta_raw); // Per-value-head V offset, state offset, output offset let v_off = 2u * nh * kd + vh * vd; let s_off = vh * kd * vd; let o_off = vh * vd; // DeltaNet recurrent update for this value head 
for (var vi = tid; vi < vd; vi += 128u) { for (var ki: u32 = 0u; ki < kd; ki++) { state[s_off + ki * vd + vi] *= decay; } var kv_mem: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { kv_mem += state[s_off + ki * vd + vi] * qkv[kh_off + ki] * k_inv_norm; } let v_val = qkv[v_off + vi]; let delta = (v_val - kv_mem) * beta; for (var ki: u32 = 0u; ki < kd; ki++) { state[s_off + ki * vd + vi] += qkv[kh_off + ki] * k_inv_norm * delta; } var o_val: f32 = 0.0; for (var ki: u32 = 0u; ki < kd; ki++) { o_val += state[s_off + ki * vd + vi] * qkv[qh_off + ki] * q_inv_norm * qk_scale; } output[o_off + vi] = o_val; } workgroupBarrier(); // ── RMSNorm for this value head (nogated, w formula) ── // norm_weight is shared across all value heads (just vd weights) var ss: f32 = 0.0; for (var i = tid; i < vd; i += 128u) { ss += output[o_off + i] * output[o_off + i]; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 64u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(vd) + params.eps); for (var i = tid; i < vd; i += 128u) { let w = unpack_bf16(norm_weight[i / 2u], i % 2u); output[o_off + i] = output[o_off + i] * rms * w; } workgroupBarrier(); } }`,C=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight_gate: array; @group(0) @binding(2) var scales_gate: array; @group(0) @binding(3) var qweight_up: array; @group(0) @binding(4) var scales_up: array; @group(0) @binding(5) var output: array; @group(0) @binding(6) var params: Params; fn dq_in(w: u32, bk: u32) -> f32 { return dot(vec4(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var gate_sum: f32 = 0.0; var up_sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u]; let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let gw0 = qweight_gate[off]; let uw0 = qweight_up[off]; let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N]; let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N]; let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N]; let bk = pr * 8u; gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + dq_in(gw3,bk+24u)); up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u)); } } output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum; }`,L=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight_gate: array; @group(0) @binding(2) var scales_gate: array; @group(0) @binding(3) var qweight_up: array; @group(0) @binding(4) var scales_up: array; @group(0) @binding(5) var output: array; @group(0) @binding(6) var params: Params; fn dq16(w: u32, bk: u32) -> f32 { 
let i0 = vec4(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var gate_sum: f32 = 0.0; var up_sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u]; let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let gw0 = qweight_gate[off]; let uw0 = qweight_up[off]; let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N]; let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N]; let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N]; let bk = pr * 8u; gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u)); up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u)); } } output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum; }`,Q=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight_gate: array; @group(0) @binding(2) var scales_gate: array; @group(0) @binding(3) var qweight_up: array; @group(0) @binding(4) var scales_up: array; @group(0) @binding(5) var output: array; @group(0) @binding(6) var params: Params; var partial: array; fn dq_in(w: u32, bk: u32) -> f32 { return dot(vec4(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; var gate_sum: f32 = 0.0; var up_sum: f32 = 0.0; if (col < N) { for (var grp = g_start; grp < g_end; grp++) { let sf = grp * N + col; let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u]; let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u]; let gstart_row = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart_row + r; let off = pr * N + col; let gw0 = qweight_gate[off]; let uw0 = qweight_up[off]; let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N]; let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N]; let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N]; let bk = pr * 8u; gate_sum += sg * (dq_in(gw0,bk) + dq_in(gw1,bk+8u) + dq_in(gw2,bk+16u) + 
dq_in(gw3,bk+24u)); up_sum += su * (dq_in(uw0,bk) + dq_in(uw1,bk+8u) + dq_in(uw2,bk+16u) + dq_in(uw3,bk+24u)); } } } partial[tid] = gate_sum; partial[32u + tid] = up_sum; workgroupBarrier(); if (lane == 0u && col < N) { let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u]; output[col] = (g / (1.0 + exp(-g))) * u; } }`,j=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight_gate: array; @group(0) @binding(2) var scales_gate: array; @group(0) @binding(3) var qweight_up: array; @group(0) @binding(4) var scales_up: array; @group(0) @binding(5) var output: array; @group(0) @binding(6) var params: Params; var partial: array; fn dq16(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; var gate_sum: f32 = 0.0; var up_sum: f32 = 0.0; if (col < N) { for (var grp = g_start; grp < g_end; grp++) { let sf = grp * N + col; let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u]; let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u]; let gstart_row = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart_row + r; let off = pr * N + col; let gw0 = qweight_gate[off]; let uw0 = qweight_up[off]; let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N]; let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N]; let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N]; let bk = pr * 8u; gate_sum += sg * (dq16(gw0,bk) + dq16(gw1,bk+8u) + dq16(gw2,bk+16u) + dq16(gw3,bk+24u)); up_sum += su * (dq16(uw0,bk) + dq16(uw1,bk+8u) + dq16(uw2,bk+16u) + dq16(uw3,bk+24u)); } } } partial[tid] = gate_sum; partial[32u + tid] = up_sum; workgroupBarrier(); if (lane == 0u && col < N) { let g = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; let u = partial[32u+tid] + partial[32u+tid+1u] + partial[32u+tid+2u] + partial[32u+tid+3u]; output[col] = (g / (1.0 + exp(-g))) * u; } }`,V=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var hidden: array; @group(0) @binding(1) var addend: array; @group(0) @binding(2) var weight: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; // Phase 1: Add + compute sum of squares var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = hidden[i] + 
addend[i]; hidden[i] = val; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); // Tree reduction for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); // Phase 2: Normalize for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); output[i] = hidden[i] * rms * (1.0 + w); } }`,W=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var hidden: array; @group(0) @binding(1) var addend: array; @group(0) @binding(2) var weight: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_vals: array; var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = hidden[i] + addend[i]; wg_vals[i] = val; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); output[i] = wg_vals[i] * rms * (1.0 + w); } }`,Y=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var input_a: array; @group(0) @binding(1) var input_b: array; @group(0) @binding(2) var input_c: array; @group(0) @binding(3) var weight: array; @group(0) @binding(4) var hidden_out: array; @group(0) @binding(5) var normed: array; @group(0) @binding(6) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_vals: array; var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let N = params.N; var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = input_a[i] + input_b[i] + input_c[i]; wg_vals[i] = val; hidden_out[i] = val; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); normed[i] = wg_vals[i] * rms * (1.0 + w); } }`,X=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } fn dq_sg(w: u32, bk: u32) -> f32 { let il = vec4(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u])); let ih = vec4(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u])); return dot(il, vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(ih, vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let 
gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u)); } } output[col] = sum; }`,$=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } fn dq16_sg(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u]))); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u]))); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u)); } } output[col] = sum; }`,Z=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } fn dq_sl(w: u32, bk: u32) -> f32 { let il = vec4(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]); let ih = vec4(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]); return dot(il, vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(ih, vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 
= 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u)); } } output[col] = sum; }`,J=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } fn dq16_sl(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u)); } } output[col] = sum; }`,rr=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } fn dq_sl(w: u32, bk: u32) -> f32 { let il = vec4(silu(a[bk])*b[bk], silu(a[bk+1u])*b[bk+1u], silu(a[bk+2u])*b[bk+2u], silu(a[bk+3u])*b[bk+3u]); let ih = vec4(silu(a[bk+4u])*b[bk+4u], silu(a[bk+5u])*b[bk+5u], silu(a[bk+6u])*b[bk+6u], silu(a[bk+7u])*b[bk+7u]); return dot(il, vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(ih, vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } var partial: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; var sum: f32 = 0.0; if (col < N) { let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; for (var g 
= g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_sl(w0,bk) + dq_sl(w1,bk+8u) + dq_sl(w2,bk+16u) + dq_sl(w3,bk+24u)); } } } partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,ar=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } fn dq16_sl(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(silu(a[bk])*b[bk]), f16(silu(a[bk+1u])*b[bk+1u]), f16(silu(a[bk+2u])*b[bk+2u]), f16(silu(a[bk+3u])*b[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(silu(a[bk+4u])*b[bk+4u]), f16(silu(a[bk+5u])*b[bk+5u]), f16(silu(a[bk+6u])*b[bk+6u]), f16(silu(a[bk+7u])*b[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } var partial: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; var sum: f32 = 0.0; if (col < N) { let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq16_sl(w0,bk) + dq16_sl(w1,bk+8u) + dq16_sl(w2,bk+16u) + dq16_sl(w3,bk+24u)); } } } partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,er=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } fn dq_sg(w: u32, bk: u32) -> f32 { let il = vec4(x[bk]*sig(gate[bk]), x[bk+1u]*sig(gate[bk+1u]), x[bk+2u]*sig(gate[bk+2u]), x[bk+3u]*sig(gate[bk+3u])); let ih = vec4(x[bk+4u]*sig(gate[bk+4u]), x[bk+5u]*sig(gate[bk+5u]), x[bk+6u]*sig(gate[bk+6u]), x[bk+7u]*sig(gate[bk+7u])); return dot(il, vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(ih, vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } var partial: array; @compute @workgroup_size(32) fn 
main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; var sum: f32 = 0.0; if (col < N) { let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_sg(w0,bk) + dq_sg(w1,bk+8u) + dq_sg(w2,bk+16u) + dq_sg(w3,bk+24u)); } } } partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,tr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } fn dq16_sg(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(x[bk]*sig(gate[bk])), f16(x[bk+1u]*sig(gate[bk+1u])), f16(x[bk+2u]*sig(gate[bk+2u])), f16(x[bk+3u]*sig(gate[bk+3u]))); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(x[bk+4u]*sig(gate[bk+4u])), f16(x[bk+5u]*sig(gate[bk+5u])), f16(x[bk+6u]*sig(gate[bk+6u])), f16(x[bk+7u]*sig(gate[bk+7u]))); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } var partial: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col = wid.x * 8u + (tid / 4u); let lane = tid & 3u; var sum: f32 = 0.0; if (col < N) { let gpl = num_groups / 4u; let g_start = lane * gpl; let g_end = g_start + gpl; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq16_sg(w0,bk) + dq16_sg(w1,bk+8u) + dq16_sg(w2,bk+16u) + dq16_sg(w3,bk+24u)); } } } partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,ur=` struct Params { K: u32, // hidden_size N: u32, // intermediate_size group_size: u32, eps: f32, // BF16 norm weights packed as vec4 for correct uniform alignment. 
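// Layout: each vec4 element holds 4 u32 words = 8 BF16 halves, so logical weight i
// lives at norm_weight[i / 8][(i / 2) % 4], in the low half of the word when
// (i % 2) == 0 and the high half otherwise (e.g. i = 10 -> norm_weight[1][1], low half).
// unpack_norm_weight() below widens the 16 BF16 bits to f32 by shifting them into
// the high 16 bits of the float.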
norm_weight: array, 640>, } @group(0) @binding(0) var hidden_in: array; @group(0) @binding(1) var mlp_residual: array; @group(0) @binding(2) var attn_residual: array; @group(0) @binding(3) var qweight_merged: array; @group(0) @binding(4) var scales_merged: array; @group(0) @binding(5) var hidden_out: array; @group(0) @binding(6) var output: array; @group(0) @binding(7) var params: Params; fn unpack_norm_weight(i: u32) -> f32 { // Each vec4 holds 8 BF16 values (4 u32 × 2 per u32) let vec_idx = i / 8u; let u32_idx = (i / 2u) % 4u; let bf16_idx = i % 2u; let packed = params.norm_weight[vec_idx][u32_idx]; let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } // On-the-fly normed value: avoids 16KB shared memory (wg_normed) that killed occupancy. // Inputs are L2-cached (~24KB total), so re-reading is cheap. var wg_scratch: array; fn normed_3way(i: u32, rms: f32) -> f32 { let val = hidden_in[i] + mlp_residual[i] + attn_residual[i]; let w = unpack_norm_weight(i); return val * rms * (1.0 + w); } fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { return dot(vec4(normed_3way(bk, rms), normed_3way(bk+1u, rms), normed_3way(bk+2u, rms), normed_3way(bk+3u, rms)), vec4(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) + dot(vec4(normed_3way(bk+4u, rms), normed_3way(bk+5u, rms), normed_3way(bk+6u, rms), normed_3way(bk+7u, rms)), vec4(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; // ── Phase 1: Three-way add → compute RMS scalar (no shared memory for normed values) ── var ss: f32 = 0.0; for (var i = tid; i < K; i += 32u) { let val = hidden_in[i] + mlp_residual[i] + attn_residual[i]; ss += val * val; // WG 0: write hidden_out (residual for next layer) if (wg.x == 0u) { hidden_out[i] = val; } } wg_scratch[tid] = ss; workgroupBarrier(); // Tree reduction for sum of squares for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); // ── Phase 2: Gate + Up GPTQ MatVec + SiLU (on-the-fly normed values from L2 cache) ── let col = gid.x; if (col >= N) { return; } let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let up_qw_base = packed_rows * N; let up_sc_base = num_groups * N; var gate_sum: f32 = 0.0; var up_sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let gsf = grp * N + col; let usf = up_sc_base + grp * N + col; let sg = unpack2x16float(scales_merged[gsf >> 1u])[gsf & 1u]; let su = unpack2x16float(scales_merged[usf >> 1u])[usf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let g_off = pr * N + col; let u_off = up_qw_base + pr * N + col; let gw0 = qweight_merged[g_off]; let uw0 = qweight_merged[u_off]; let gw1 = qweight_merged[g_off+N]; let uw1 = qweight_merged[u_off+N]; let gw2 = qweight_merged[g_off+2u*N]; let uw2 = qweight_merged[u_off+2u*N]; let gw3 = qweight_merged[g_off+3u*N]; let uw3 = qweight_merged[u_off+3u*N]; let bk = pr * 8u; gate_sum += sg * (dq_fly(gw0,bk,rms) + dq_fly(gw1,bk+8u,rms) + dq_fly(gw2,bk+16u,rms) + dq_fly(gw3,bk+24u,rms)); up_sum += su * (dq_fly(uw0,bk,rms) + dq_fly(uw1,bk+8u,rms) 
+ dq_fly(uw2,bk+16u,rms) + dq_fly(uw3,bk+24u,rms)); } } output[col] = (gate_sum / (1.0 + exp(-gate_sum))) * up_sum; }`,ir=` struct Params { K: u32, N: u32, group_size: u32, eps: f32, norm_weight: array, 640>, } @group(0) @binding(0) var hidden_in: array; @group(0) @binding(1) var residual: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn unpack_norm_weight(i: u32) -> f32 { let vec_idx = i / 8u; let u32_idx = (i / 2u) % 4u; let bf16_idx = i % 2u; let packed = params.norm_weight[vec_idx][u32_idx]; let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } // On-the-fly: only 128 bytes shared memory (vs 16KB wg_normed that killed occupancy) var wg_scratch: array; fn normed_add(i: u32, rms: f32) -> f32 { let val = hidden_in[i] + residual[i]; let w = unpack_norm_weight(i); return val * rms * (1.0 + w); } fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { return dot(vec4(normed_add(bk, rms), normed_add(bk+1u, rms), normed_add(bk+2u, rms), normed_add(bk+3u, rms)), vec4(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) + dot(vec4(normed_add(bk+4u, rms), normed_add(bk+5u, rms), normed_add(bk+6u, rms), normed_add(bk+7u, rms)), vec4(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; // ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ── var ss: f32 = 0.0; for (var i = tid; i < K; i += 32u) { let val = hidden_in[i] + residual[i]; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); // ── Phase 2: GPTQ MatVec with on-the-fly normed values (L2-cached inputs) ── let col = gid.x; if (col >= N) { return; } let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; var sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms)); } } output[col] = sum; }`,sr=` struct Params { K: u32, N: u32, group_size: u32, eps: f32, norm_weight: array, 640>, } @group(0) @binding(0) var hidden_in: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_norm_weight(i: u32) -> f32 { let vec_idx = i / 8u; let u32_idx = (i / 2u) % 4u; let bf16_idx = i % 2u; let packed = params.norm_weight[vec_idx][u32_idx]; let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; fn normed_val(i: u32, rms: f32) -> f32 { let w = unpack_norm_weight(i); return hidden_in[i] * rms * (1.0 + w); } fn dq_fly(pw: u32, bk: u32, rms: f32) -> f32 { return 
dot(vec4(normed_val(bk, rms), normed_val(bk+1u, rms), normed_val(bk+2u, rms), normed_val(bk+3u, rms)), vec4(f32(pw&0xFu)-8., f32((pw>>4u)&0xFu)-8., f32((pw>>8u)&0xFu)-8., f32((pw>>12u)&0xFu)-8.)) + dot(vec4(normed_val(bk+4u, rms), normed_val(bk+5u, rms), normed_val(bk+6u, rms), normed_val(bk+7u, rms)), vec4(f32((pw>>16u)&0xFu)-8., f32((pw>>20u)&0xFu)-8., f32((pw>>24u)&0xFu)-8., f32((pw>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; // ── Phase 1: Compute RMS scalar only (inputs re-read on-the-fly in phase 2) ── var ss: f32 = 0.0; for (var i = tid; i < K; i += 32u) { let val = hidden_in[i]; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); // ── Phase 2: GPTQ MatVec with on-the-fly normed values ── let col = gid.x; if (col >= N) { return; } let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; var sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_fly(w0,bk,rms) + dq_fly(w1,bk+8u,rms) + dq_fly(w2,bk+16u,rms) + dq_fly(w3,bk+24u,rms)); } } output[col] = sum; }`,or=` struct Params { K: u32, N: u32, group_size: u32, eps: f32, norm_weight: array, 640>, } @group(0) @binding(0) var hidden_in: array; @group(0) @binding(1) var residual: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn unpack_nw(i: u32) -> f32 { let vec_idx = i / 8u; let u32_idx = (i / 2u) % 4u; let bf16_idx = i % 2u; let packed = params.norm_weight[vec_idx][u32_idx]; let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; var wg_rms: f32; fn nv(i: u32) -> f32 { return (hidden_in[i] + residual[i]) * wg_rms * (1.0 + unpack_nw(i)); } fn dq_nv(w: u32, bk: u32) -> f32 { return dot(vec4(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; // Phase 1: compute RMS (132B shared memory total) var ss: f32 = 0.0; for (var i = tid; i < K; i += 32u) { let val = hidden_in[i] + residual[i]; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); } workgroupBarrier(); // Phase 2: GPTQ matvec with inline norm (re-reads input from L1 cache) let col = gid.x; if (col >= N) { return; } let 
packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; var sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u)); } } output[col] = sum; }`,nr=` struct Params { K: u32, N: u32, group_size: u32, eps: f32, norm_weight: array, 640>, } @group(0) @binding(0) var hidden_in: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_nw(i: u32) -> f32 { let packed = params.norm_weight[i / 8u][(i / 2u) % 4u]; let bits = (packed >> ((i % 2u) * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; var wg_rms: f32; fn nv(i: u32) -> f32 { return hidden_in[i] * wg_rms * (1.0 + unpack_nw(i)); } fn dq_nv(w: u32, bk: u32) -> f32 { return dot(vec4(nv(bk), nv(bk+1u), nv(bk+2u), nv(bk+3u)), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(nv(bk+4u), nv(bk+5u), nv(bk+6u), nv(bk+7u)), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; var ss: f32 = 0.0; for (var i = tid; i < K; i += 32u) { ss += hidden_in[i] * hidden_in[i]; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } if (tid == 0u) { wg_rms = 1.0 / sqrt(wg_scratch[0] / f32(K) + params.eps); } workgroupBarrier(); let col = gid.x; if (col >= N) { return; } let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; var sum: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum += s * (dq_nv(w0,bk) + dq_nv(w1,bk+8u) + dq_nv(w2,bk+16u) + dq_nv(w3,bk+24u)); } } output[col] = sum; }`,gr=` struct Params { num_tokens: u32, penalty: f32, presence: f32, _pad: u32, } @group(0) @binding(0) var logits: array; @group(0) @binding(1) var recent_tokens: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= params.num_tokens) { return; } let tok = recent_tokens[i]; var val = logits[tok]; // Repetition penalty (multiplicative) if (params.penalty > 1.0) { if (val > 0.0) { val = val / params.penalty; } else { val = val * params.penalty; } } // Presence penalty (additive) val -= params.presence; logits[tok] = val; }`,dr=` struct Params { temperature: f32, top_k: u32, top_p: f32, random_seed: u32, } struct Result { idx: u32, val: f32, } @group(0) @binding(0) var candidates: array; // 256 × 
(idx u32, val u32) @group(0) @binding(1) var result: Result; @group(0) @binding(2) var params: Params; var wg_idx: array; var wg_val: array; var wg_prob: array; // PCG random number generator fn pcg(state: u32) -> u32 { let s = state * 747796405u + 2891336453u; let word = ((s >> ((s >> 28u) + 4u)) ^ s) * 277803737u; return (word >> 22u) ^ word; } @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u) { let tid = lid.x; // Load candidate (idx, logit) pair wg_idx[tid] = candidates[tid * 2u]; wg_val[tid] = bitcast(candidates[tid * 2u + 1u]); workgroupBarrier(); // ── Step 1: Parallel find max logit (for numerical stability) ── var local_max = wg_val[tid]; workgroupBarrier(); // Tree reduction for max for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { local_max = max(local_max, wg_val[tid + s]); wg_val[tid] = local_max; } workgroupBarrier(); } let max_val = wg_val[0]; // Reload original values (they were clobbered by reduction) wg_val[tid] = bitcast(candidates[tid * 2u + 1u]); workgroupBarrier(); // ── Step 2: Bitonic sort by logit value (descending) ── // Sort all 256 candidates so top-K and top-P can work on sorted order. for (var k = 2u; k <= 256u; k <<= 1u) { for (var j = k >> 1u; j > 0u; j >>= 1u) { let ixj = tid ^ j; if (ixj > tid) { let ascending = (tid & k) != 0u; let val_tid = wg_val[tid]; let val_ixj = wg_val[ixj]; // We want descending, so swap if tid has smaller value let should_swap = select((val_tid < val_ixj), (val_tid > val_ixj), ascending); if (should_swap) { wg_val[tid] = val_ixj; wg_val[ixj] = val_tid; let tmp_idx = wg_idx[tid]; wg_idx[tid] = wg_idx[ixj]; wg_idx[ixj] = tmp_idx; } } workgroupBarrier(); } } // ── Step 3: Temperature + softmax (parallel exp, parallel prefix sum) ── let temp = max(params.temperature, 0.001); // Avoid division by zero let K = min(params.top_k, 256u); // Each thread computes exp((val - max) / temp) if within top-K var prob: f32 = 0.0; if (tid < K) { prob = exp((wg_val[tid] - max_val) / temp); } wg_prob[tid] = prob; workgroupBarrier(); // ── Step 4: Parallel prefix sum for cumulative probabilities ── // Blelloch scan (up-sweep then down-sweep) // First, compute total sum via tree reduction (we need it for normalization) var sum_val = prob; for (var s = 1u; s < 256u; s <<= 1u) { workgroupBarrier(); if (tid >= s) { sum_val = wg_prob[tid - s] + wg_prob[tid]; } workgroupBarrier(); wg_prob[tid] = sum_val; } // Now wg_prob[i] = cumulative sum up to and including index i // wg_prob[255] = total sum let total_sum = wg_prob[255]; workgroupBarrier(); // ── Step 5: Thread 0 does top-P + sampling ── // This is sequential but operates on sorted, cumulative data — very fast. 
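// wg_prob[] now holds inclusive cumulative sums of the (un-normalized) top-K weights,
// so top-P can be applied without dividing by total_sum: e.g. with top_p = 0.9 and
// total_sum = 2.0 the cutoff is the first index whose cumulative weight reaches 1.8.
// The sample point is then drawn uniformly in [0, cum[cutoff-1]) and located by a
// linear scan, which selects index i with probability proportional to its weight.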
if (tid == 0u) { let topP = params.top_p; let threshold = topP * total_sum; // Un-normalized threshold // Find cutoff: first index where cumulative prob >= threshold var cutoff = K; for (var i = 0u; i < K; i++) { if (wg_prob[i] >= threshold) { cutoff = i + 1u; break; } } // Sample: generate random float in [0, cumProb[cutoff-1]) let rand_u32 = pcg(params.random_seed); let rand_f32 = f32(rand_u32) / 4294967296.0; // [0, 1) let sample_point = rand_f32 * wg_prob[cutoff - 1u]; // Find sampled index via linear scan of cumulative probs var selected = 0u; for (var i = 0u; i < cutoff; i++) { if (wg_prob[i] >= sample_point) { selected = i; break; } } result.idx = wg_idx[selected]; result.val = wg_val[selected]; } }`;function _(d=320,e=1e7,u=33,t=30,r=128){return` const ROPE_THETA: f32 = ${e}; const MROPE_S1_LIMIT: u32 = ${u}u; const MROPE_S2_LIMIT: u32 = ${t}u; const PARTIAL_DIM: u32 = ${r}u; struct Params { num_heads: u32, // Q heads num_kv_heads: u32, // KV heads head_dim: u32, eps: f32, cache_position: u32, // KV cache write index (sequential, updated per token) position: u32, // RoPE temporal/text position (updated per token) position_h: u32, // mRoPE height dimension (= position for text tokens) position_w: u32, // mRoPE width dimension (= position for text tokens) // Packed BF16 norm weights: [Q norm weights | K norm weights] // Q: numHeads*headDim BF16 values, K: numKVHeads*headDim BF16 values qk_norm_weight: array, ${d}>, } @group(0) @binding(0) var q_proj_full: array; @group(0) @binding(1) var k_proj: array; @group(0) @binding(2) var v_proj: array; @group(0) @binding(3) var q_proj: array; @group(0) @binding(4) var q_gate: array; @group(0) @binding(5) var k_cache: array; @group(0) @binding(6) var v_cache: array; @group(0) @binding(7) var params: Params; fn unpack_norm_bf16(flat_idx: u32) -> f32 { let vec_idx = flat_idx / 8u; let u32_idx = (flat_idx / 2u) % 4u; let bf16_idx = flat_idx % 2u; let packed = params.qk_norm_weight[vec_idx][u32_idx]; let bits = (packed >> (bf16_idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_data: array; var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(workgroup_id) wg: vec3u, @builtin(local_invocation_id) lid: vec3u) { let h = wg.x; let tid = lid.x; let hd = params.head_dim; let nh = params.num_heads; let nkv = params.num_kv_heads; if (h < nh) { // ── Q head: split from interleaved + RMSNorm ── let src_off = h * hd * 2u; // Phase 1: Read Q values, compute ss, write gate var ss: f32 = 0.0; for (var i = tid; i < hd; i += 256u) { let q_val = q_proj_full[src_off + i]; wg_data[i] = q_val; ss += q_val * q_val; q_gate[h * hd + i] = q_proj_full[src_off + hd + i]; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps); // Phase 2: Normalize Q and write to shared mem for (var i = tid; i < hd; i += 256u) { let w = unpack_norm_bf16(h * hd + i); wg_data[i] = wg_data[i] * rms * (1.0 + w); } workgroupBarrier(); // Phase 3: Apply mRoPE to Q (rotate_half: dim i pairs with dim i+half_dim) // mRoPE: interleaved position selection based on freq_idx let q_half = PARTIAL_DIM / 2u; for (var i = tid; i < hd; i += 256u) { var q_out = wg_data[i]; if (i < PARTIAL_DIM) { let freq_idx = i % q_half; let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM)); // mRoPE: select position based on interleaved section var pos = params.position; if (freq_idx % 3u == 1u && freq_idx < 
MROPE_S1_LIMIT) { pos = params.position_h; } else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; } let angle = f32(pos) * freq; let cos_a = cos(angle); let sin_a = sin(angle); if (i < q_half) { q_out = wg_data[i] * cos_a - wg_data[i + q_half] * sin_a; } else { q_out = wg_data[i] * cos_a + wg_data[i - q_half] * sin_a; } } q_proj[h * hd + i] = q_out; } } else { // ── KV head: K norm + KV cache store ── let kv_h = h - nh; if (kv_h >= nkv) { return; } let off = kv_h * hd; let q_norm_size = nh * hd; // offset to K norm weights in packed array // Phase 1: Read K, compute ss var ss: f32 = 0.0; for (var i = tid; i < hd; i += 256u) { let k_val = k_proj[off + i]; wg_data[i] = k_val; ss += k_val * k_val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(hd) + params.eps); // Phase 2: Normalize K, write to shared mem for (var i = tid; i < hd; i += 256u) { let w = unpack_norm_bf16(q_norm_size + kv_h * hd + i); wg_data[i] = wg_data[i] * rms * (1.0 + w); } workgroupBarrier(); // Phase 3: Apply mRoPE to K, write to kProj + k_cache, copy V to v_cache let total = nkv * hd; let cache_off = params.cache_position * total + off; let k_half = PARTIAL_DIM / 2u; for (var i = tid; i < hd; i += 256u) { var k_out = wg_data[i]; if (i < PARTIAL_DIM) { let freq_idx = i % k_half; let freq = 1.0 / pow(ROPE_THETA, f32(2u * freq_idx) / f32(PARTIAL_DIM)); // mRoPE: select position based on interleaved section var pos = params.position; if (freq_idx % 3u == 1u && freq_idx < MROPE_S1_LIMIT) { pos = params.position_h; } else if (freq_idx % 3u == 2u && freq_idx < MROPE_S2_LIMIT) { pos = params.position_w; } let angle = f32(pos) * freq; let cos_a = cos(angle); let sin_a = sin(angle); if (i < k_half) { k_out = wg_data[i] * cos_a - wg_data[i + k_half] * sin_a; } else { k_out = wg_data[i] * cos_a + wg_data[i - k_half] * sin_a; } } k_proj[off + i] = k_out; k_cache[cache_off + i] = k_out; v_cache[cache_off + i] = v_proj[off + i]; } } }`}const lr=_(320),pr=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq(w: u32, bk: u32) -> f32 { return dot(vec4(input[bk], input[bk+1u], input[bk+2u], input[bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[bk+4u], input[bk+5u], input[bk+6u], input[bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } var partial: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col_in_wg = tid / 4u; // 0-7: which of 8 columns this thread handles let lane = tid & 3u; // 0-3: which quarter of K this thread computes let col = wid.x * 8u + col_in_wg; var sum: f32 = 0.0; if (col < N) { let groups_per_lane = num_groups / 4u; let g_start = lane * groups_per_lane; let g_end = g_start + groups_per_lane; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; 
r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq(w0, bk) + dq(w1, bk+8u) + dq(w2, bk+16u) + dq(w3, bk+24u)); } } } // Store partial sum and reduce within each 4-thread column group // With @workgroup_size(32) on Apple GPU, all 32 threads are in 1 SIMD group // so workgroupBarrier is effectively free. partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,fr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq16(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } var partial: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = params.group_size; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col_in_wg = tid / 4u; let lane = tid & 3u; let col = wid.x * 8u + col_in_wg; var sum: f32 = 0.0; if (col < N) { let groups_per_lane = num_groups / 4u; let g_start = lane * groups_per_lane; let g_end = g_start + groups_per_lane; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) + dq16(w2, bk+16u) + dq16(w3, bk+24u)); } } } partial[tid] = sum; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial[tid] + partial[tid+1u] + partial[tid+2u] + partial[tid+3u]; } }`,_r=` enable f16; enable subgroups; struct Params { K: u32, N: u32, group_size: u32, num_splits: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var partials: array; @group(0) @binding(4) var params: Params; fn dq16(w: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[bk]), f16(input[bk+1u]), f16(input[bk+2u]), f16(input[bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[bk+4u]), f16(input[bk+5u]), f16(input[bk+6u]), f16(input[bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let gs = 
params.group_size; let ns = params.num_splits; let packed_rows = params.K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col_in_wg = tid / 4u; let lane = tid & 3u; let col = wid.x * 8u + col_in_wg; let split_id = wid.y; var sum: f32 = 0.0; if (col < N) { // Divide groups across splits, then across 4 lanes within each split let groups_per_split = num_groups / ns; let g_base = split_id * groups_per_split; let split_end = select(g_base + groups_per_split, num_groups, split_id >= ns - 1u); let split_groups = split_end - g_base; let groups_per_lane = split_groups / 4u; let g_start = g_base + lane * groups_per_lane; let g_end = select(g_start + groups_per_lane, split_end, lane >= 3u); for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum += s * (dq16(w0, bk) + dq16(w1, bk+8u) + dq16(w2, bk+16u) + dq16(w3, bk+24u)); } } } // Subgroup reduction across 4 lanes (no shared memory, no barrier) var s = sum; s += subgroupShuffleXor(s, 1u); s += subgroupShuffleXor(s, 2u); if (lane == 0u && col < N) { partials[split_id * N + col] = s; } }`,cr=` struct Params { H_half: u32, num_tokens: u32, } @group(0) @binding(0) var src: array; @group(0) @binding(1) var indices: array; @group(0) @binding(2) var dst: array; @group(0) @binding(3) var p: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3) { let col = gid.x; let row_out = gid.y; if (col >= p.H_half || row_out >= p.num_tokens) { return; } let src_row = indices[row_out]; dst[row_out * p.H_half + col] = src[src_row * p.H_half + col]; } `,wr=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var bf16_weight: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales_f32: array; @group(0) @binding(3) var params: Params; var wg_max: array; fn unpack_bf16(idx: u32, base: u32) -> f32 { let packed = bf16_weight[base + idx / 2u]; let bits = (packed >> ((idx % 2u) * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } @compute @workgroup_size(32) fn main(@builtin(workgroup_id) wg_id: vec3u, @builtin(local_invocation_id) lid: vec3u) { let col = wg_id.x + wg_id.y * 65535u; // 2D dispatch for N > 65535 let tid = lid.x; let K = params.K; let N = params.N; if (col >= N) { return; } let gs = params.group_size; let num_groups = K / gs; let base = col * K / 2u; for (var g = 0u; g < num_groups; g++) { let gstart = g * gs; // Find max_abs (parallel reduction across 32 threads) var local_max: f32 = 0.0; for (var i = tid; i < gs; i += 32u) { local_max = max(local_max, abs(unpack_bf16(gstart + i, base))); } wg_max[tid] = local_max; workgroupBarrier(); for (var s = 16u; s > 0u; s >>= 1u) { if (tid < s) { wg_max[tid] = max(wg_max[tid], wg_max[tid + s]); } workgroupBarrier(); } let scale_val = wg_max[0] / 7.0; let inv_scale = select(1.0 / scale_val, 0.0, scale_val == 0.0); // Store scale as f32 (packed to F16 in separate pass) if (tid == 0u) { scales_f32[g * N + col] = scale_val; } // Quantize and pack 8 INT4 values per u32 let rows_per_group = gs / 8u; for (var r = tid; r < rows_per_group; r += 32u) { var packed: u32 = 0u; let bk = gstart + r * 8u; for (var b = 0u; b < 8u; b++) { let val = unpack_bf16(bk + b, base); var q = i32(round(val * 
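// Symmetric int4 quantization: scale = max_abs / 7, values are rounded and clamped
// to [-8, 7], then stored biased by +8 as unsigned nibbles (0..15), 8 per u32,
// matching the (nibble - 8) * scale dequantization used by the matvec kernels.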
inv_scale)) + 8; q = clamp(q, 0, 15); packed |= (u32(q) << (b * 4u)); } let packed_row = g * rows_per_group + r; qweight[packed_row * N + col] = packed; } } }`,br=` struct Params { count: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var output: array; @group(0) @binding(2) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let idx = gid.x; if (idx >= params.count) { return; } let a = input[idx * 2u]; let b = input[idx * 2u + 1u]; output[idx] = pack2x16float(vec2(a, b)); }`,vr=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var hidden: array; @group(0) @binding(1) var addend: array; @group(0) @binding(2) var weight: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let off = wid.x * N; var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = hidden[off + i] + addend[off + i]; hidden[off + i] = val; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); output[off + i] = hidden[off + i] * rms * (1.0 + w); } }`,mr=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var hidden: array; @group(0) @binding(1) var addend: array; @group(0) @binding(2) var weight: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_vals: array; var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let off = wid.x * N; var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = hidden[off + i] + addend[off + i]; wg_vals[i] = val; ss += val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); output[off + i] = wg_vals[i] * rms * (1.0 + w); } }`,kr=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var input_a: array; @group(0) @binding(1) var input_b: array; @group(0) @binding(2) var input_c: array; @group(0) @binding(3) var weight: array; @group(0) @binding(4) var hidden_out: array; @group(0) @binding(5) var normed: array; @group(0) @binding(6) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var wg_vals: array; var wg_scratch: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let N = params.N; let off = wid.x * N; var ss: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let val = input_a[off + i] + input_b[off + i] + input_c[off + i]; wg_vals[i] = val; hidden_out[off + i] = val; ss 
+= val * val; } wg_scratch[tid] = ss; workgroupBarrier(); for (var s = 128u; s > 0u; s >>= 1u) { if (tid < s) { wg_scratch[tid] += wg_scratch[tid + s]; } workgroupBarrier(); } let rms = 1.0 / sqrt(wg_scratch[0] / f32(N) + params.eps); for (var i = tid; i < N; i += 256u) { let w = unpack_bf16(weight[i / 2u], i % 2u); normed[off + i] = wg_vals[i] * rms * (1.0 + w); } }`,hr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight_gate: array; @group(0) @binding(2) var scales_gate: array; @group(0) @binding(3) var qweight_up: array; @group(0) @binding(4) var scales_up: array; @group(0) @binding(5) var output: array; @group(0) @binding(6) var params: Params; fn dq16_at(w: u32, base: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var gate0: f32 = 0.0; var up0: f32 = 0.0; var gate1: f32 = 0.0; var up1: f32 = 0.0; for (var grp: u32 = 0u; grp < num_groups; grp++) { let sf = grp * N + col; let sg = unpack2x16float(scales_gate[sf >> 1u])[sf & 1u]; let su = unpack2x16float(scales_up[sf >> 1u])[sf & 1u]; let gstart = grp * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let gw0 = qweight_gate[off]; let uw0 = qweight_up[off]; let gw1 = qweight_gate[off+N]; let uw1 = qweight_up[off+N]; let gw2 = qweight_gate[off+2u*N]; let uw2 = qweight_up[off+2u*N]; let gw3 = qweight_gate[off+3u*N]; let uw3 = qweight_up[off+3u*N]; let bk = pr * 8u; // Batch 0: input at offset 0 gate0 += sg * (dq16_at(gw0,0u,bk) + dq16_at(gw1,0u,bk+8u) + dq16_at(gw2,0u,bk+16u) + dq16_at(gw3,0u,bk+24u)); up0 += su * (dq16_at(uw0,0u,bk) + dq16_at(uw1,0u,bk+8u) + dq16_at(uw2,0u,bk+16u) + dq16_at(uw3,0u,bk+24u)); // Batch 1: input at offset K gate1 += sg * (dq16_at(gw0,K,bk) + dq16_at(gw1,K,bk+8u) + dq16_at(gw2,K,bk+16u) + dq16_at(gw3,K,bk+24u)); up1 += su * (dq16_at(uw0,K,bk) + dq16_at(uw1,K,bk+8u) + dq16_at(uw2,K,bk+16u) + dq16_at(uw3,K,bk+24u)); } } output[col] = (gate0 / (1.0 + exp(-gate0))) * up0; output[N + col] = (gate1 / (1.0 + exp(-gate1))) * up1; }`,xr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn silu(v: f32) -> f32 { return v / (1.0 + exp(-v)); } fn dq16_sl_at(w: u32, base: u32, bk: u32) -> f32 { let i0 = vec4(f16(silu(a[base+bk])*b[base+bk]), f16(silu(a[base+bk+1u])*b[base+bk+1u]), f16(silu(a[base+bk+2u])*b[base+bk+2u]), f16(silu(a[base+bk+3u])*b[base+bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), 
f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(silu(a[base+bk+4u])*b[base+bk+4u]), f16(silu(a[base+bk+5u])*b[base+bk+5u]), f16(silu(a[base+bk+6u])*b[base+bk+6u]), f16(silu(a[base+bk+7u])*b[base+bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum0: f32 = 0.0; var sum1: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum0 += s * (dq16_sl_at(w0,0u,bk) + dq16_sl_at(w1,0u,bk+8u) + dq16_sl_at(w2,0u,bk+16u) + dq16_sl_at(w3,0u,bk+24u)); sum1 += s * (dq16_sl_at(w0,K,bk) + dq16_sl_at(w1,K,bk+8u) + dq16_sl_at(w2,K,bk+16u) + dq16_sl_at(w3,K,bk+24u)); } } output[col] = sum0; output[N + col] = sum1; }`,qr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var x: array; @group(0) @binding(1) var gate: array; @group(0) @binding(2) var qweight: array; @group(0) @binding(3) var scales: array; @group(0) @binding(4) var output: array; @group(0) @binding(5) var params: Params; fn sig(v: f32) -> f32 { return 1.0 / (1.0 + exp(-v)); } fn dq16_sg_at(w: u32, base: u32, bk: u32) -> f32 { let i0 = vec4(f16(x[base+bk]*sig(gate[base+bk])), f16(x[base+bk+1u]*sig(gate[base+bk+1u])), f16(x[base+bk+2u]*sig(gate[base+bk+2u])), f16(x[base+bk+3u]*sig(gate[base+bk+3u]))); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(x[base+bk+4u]*sig(gate[base+bk+4u])), f16(x[base+bk+5u]*sig(gate[base+bk+5u])), f16(x[base+bk+6u]*sig(gate[base+bk+6u])), f16(x[base+bk+7u]*sig(gate[base+bk+7u]))); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum0: f32 = 0.0; var sum1: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off+N]; let w2 = qweight[off+2u*N]; let w3 = qweight[off+3u*N]; let bk = pr * 8u; sum0 += s * (dq16_sg_at(w0,0u,bk) + dq16_sg_at(w1,0u,bk+8u) + dq16_sg_at(w2,0u,bk+16u) + dq16_sg_at(w3,0u,bk+24u)); sum1 += s * (dq16_sg_at(w0,K,bk) + dq16_sg_at(w1,K,bk+8u) + dq16_sg_at(w2,K,bk+16u) + dq16_sg_at(w3,K,bk+24u)); } } output[col] = sum0; output[N + col] = sum1; }`,yr=` struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; 
@group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq_at(w: u32, base: u32, bk: u32) -> f32 { return dot(vec4(input[base+bk], input[base+bk+1u], input[base+bk+2u], input[base+bk+3u]), vec4(f32(w&0xFu)-8., f32((w>>4u)&0xFu)-8., f32((w>>8u)&0xFu)-8., f32((w>>12u)&0xFu)-8.)) + dot(vec4(input[base+bk+4u], input[base+bk+5u], input[base+bk+6u], input[base+bk+7u]), vec4(f32((w>>16u)&0xFu)-8., f32((w>>20u)&0xFu)-8., f32((w>>24u)&0xFu)-8., f32((w>>28u)&0xFu)-8.)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum0: f32 = 0.0; var sum1: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum0 += s * (dq_at(w0, 0u, bk) + dq_at(w1, 0u, bk+8u) + dq_at(w2, 0u, bk+16u) + dq_at(w3, 0u, bk+24u)); sum1 += s * (dq_at(w0, K, bk) + dq_at(w1, K, bk+8u) + dq_at(w2, K, bk+16u) + dq_at(w3, K, bk+24u)); } } output[col] = sum0; output[N + col] = sum1; }`,Fr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn dq16_at(w: u32, base: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; if (col >= N) { return; } var sum0: f32 = 0.0; var sum1: f32 = 0.0; for (var g: u32 = 0u; g < num_groups; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) + dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u)); sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) + dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u)); } } output[col] = sum0; output[N + col] = sum1; }`,Nr=` enable f16; struct Params { K: u32, N: u32, group_size: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var qweight: array; @group(0) @binding(2) var scales: array; @group(0) @binding(3) var output: 
array; @group(0) @binding(4) var params: Params; fn dq16_at(w: u32, base: u32, bk: u32) -> f32 { let i0 = vec4(f16(input[base+bk]), f16(input[base+bk+1u]), f16(input[base+bk+2u]), f16(input[base+bk+3u])); let w0 = vec4(f16(f32(w&0xFu)-8.), f16(f32((w>>4u)&0xFu)-8.), f16(f32((w>>8u)&0xFu)-8.), f16(f32((w>>12u)&0xFu)-8.)); let i1 = vec4(f16(input[base+bk+4u]), f16(input[base+bk+5u]), f16(input[base+bk+6u]), f16(input[base+bk+7u])); let w1 = vec4(f16(f32((w>>16u)&0xFu)-8.), f16(f32((w>>20u)&0xFu)-8.), f16(f32((w>>24u)&0xFu)-8.), f16(f32((w>>28u)&0xFu)-8.)); return f32(dot(i0, w0)) + f32(dot(i1, w1)); } var partial0: array; var partial1: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let tid = lid.x; let K = params.K; let N = params.N; let gs = params.group_size; let packed_rows = K / 8u; let rows_per_group = gs / 8u; let num_groups = packed_rows / rows_per_group; let col_in_wg = tid / 4u; let lane = tid & 3u; let col = wid.x * 8u + col_in_wg; var sum0: f32 = 0.0; var sum1: f32 = 0.0; if (col < N) { let groups_per_lane = num_groups / 4u; let g_start = lane * groups_per_lane; let g_end = g_start + groups_per_lane; for (var g = g_start; g < g_end; g++) { let sf = g * N + col; let s = unpack2x16float(scales[sf >> 1u])[sf & 1u]; let gstart = g * rows_per_group; for (var r: u32 = 0u; r < rows_per_group; r += 4u) { let pr = gstart + r; let off = pr * N + col; let w0 = qweight[off]; let w1 = qweight[off + N]; let w2 = qweight[off + 2u * N]; let w3 = qweight[off + 3u * N]; let bk = pr * 8u; sum0 += s * (dq16_at(w0, 0u, bk) + dq16_at(w1, 0u, bk+8u) + dq16_at(w2, 0u, bk+16u) + dq16_at(w3, 0u, bk+24u)); sum1 += s * (dq16_at(w0, K, bk) + dq16_at(w1, K, bk+8u) + dq16_at(w2, K, bk+16u) + dq16_at(w3, K, bk+24u)); } } } partial0[tid] = sum0; partial1[tid] = sum1; workgroupBarrier(); if (lane == 0u && col < N) { output[col] = partial0[tid] + partial0[tid+1u] + partial0[tid+2u] + partial0[tid+3u]; output[N + col] = partial1[tid] + partial1[tid+1u] + partial1[tid+2u] + partial1[tid+3u]; } }`,Pr=` struct Params { N: u32, eps: f32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var weight: array; // BF16 packed @group(0) @binding(2) var bias: array; // BF16 packed @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; fn unpack_bf16(packed: u32, idx: u32) -> f32 { let bits = (packed >> (idx * 16u)) & 0xFFFFu; return bitcast(bits << 16u); } var shared_sum: array; var shared_sq: array; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let N = params.N; let token = wid.x; let base = token * N; let tid = lid.x; // Phase 1: compute sum and sum-of-squares var s: f32 = 0.0; var sq: f32 = 0.0; for (var i = tid; i < N; i += 256u) { let v = input[base + i]; s += v; sq += v * v; } shared_sum[tid] = s; shared_sq[tid] = sq; workgroupBarrier(); // Parallel reduction for (var stride: u32 = 128u; stride > 0u; stride >>= 1u) { if (tid < stride) { shared_sum[tid] += shared_sum[tid + stride]; shared_sq[tid] += shared_sq[tid + stride]; } workgroupBarrier(); } let mean = shared_sum[0] / f32(N); let variance = shared_sq[0] / f32(N) - mean * mean; let inv_std = 1.0 / sqrt(variance + params.eps); // Phase 2: normalize and apply weight + bias for (var i = tid; i < N; i += 256u) { let v = input[base + i]; let normalized = (v - mean) * inv_std; let w = unpack_bf16(weight[i >> 1u], i); let b = unpack_bf16(bias[i >> 1u], i); output[base + i] = normalized * w 
+ b; } }`,Er=` struct Params { K: u32, N: u32, } @group(0) @binding(0) var input: array; @group(0) @binding(1) var weight: array; // BF16 packed [N, K/2] @group(0) @binding(2) var bias: array; // BF16 packed [N/2] @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; @compute @workgroup_size(32) fn main(@builtin(global_invocation_id) gid: vec3u) { let col = gid.x; // output feature index let token = gid.y; // batch/token index let K = params.K; let N = params.N; if (col >= N) { return; } let in_base = token * K; let w_base = col * K / 2u; var sum: f32 = 0.0; for (var k: u32 = 0u; k < K; k += 4u) { let p0 = weight[w_base + k / 2u]; let p1 = weight[w_base + k / 2u + 1u]; let w0 = bitcast((p0 & 0xFFFFu) << 16u); let w1 = bitcast(p0 & 0xFFFF0000u); let w2 = bitcast((p1 & 0xFFFFu) << 16u); let w3 = bitcast(p1 & 0xFFFF0000u); sum += input[in_base + k] * w0 + input[in_base + k + 1u] * w1 + input[in_base + k + 2u] * w2 + input[in_base + k + 3u] * w3; } // Add BF16 bias let bp = bias[col >> 1u]; let b = bitcast(select((bp & 0xFFFFu) << 16u, bp & 0xFFFF0000u, (col & 1u) == 1u)); output[token * N + col] = sum + b; }`,Sr=` @group(0) @binding(0) var data: array; @group(0) @binding(1) var len: u32; const SQRT_2_OVER_PI: f32 = 0.7978845608; const COEFF: f32 = 0.044715; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= len) { return; } let x = data[i]; // For |x| > 10, GELU saturates: positive → x, negative → 0 // Avoids intermediate overflow in x^3 for large values if (x > 10.0) { return; } // data[i] = x (already in place) if (x < -10.0) { data[i] = 0.0; return; } let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x); data[i] = 0.5 * x * (1.0 + tanh(inner)); }`,Rr=` @group(0) @binding(0) var data: array; @group(0) @binding(1) var len: u32; const SQRT_2_OVER_PI: f32 = 0.7978845608; const COEFF: f32 = 0.044715; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= len) { return; } let x = data[i]; if (x > 10.0) { return; } if (x < -10.0) { data[i] = 0.0; return; } let inner = SQRT_2_OVER_PI * (x + COEFF * x * x * x); data[i] = 0.5 * x * (1.0 + tanh(inner)); }`,Kr=` struct Params { seq_len: u32, num_heads: u32, head_dim: u32, } @group(0) @binding(0) var q_in: array; @group(0) @binding(1) var k_in: array; @group(0) @binding(2) var q_out: array; @group(0) @binding(3) var k_out: array; @group(0) @binding(4) var cos_buf: array; @group(0) @binding(5) var sin_buf: array; @group(0) @binding(6) var params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let idx = gid.x; let total = params.seq_len * params.num_heads * params.head_dim; if (idx >= total) { return; } let d = idx % params.head_dim; let head_and_seq = idx / params.head_dim; let seq = head_and_seq / params.num_heads; let half_dim = params.head_dim / 2u; // cos/sin are [seq_len, head_dim] — same for all heads let cs_idx = seq * params.head_dim + d; let c = cos_buf[cs_idx]; let s = sin_buf[cs_idx]; // rotate_half: pair d with d+half_dim (or d-half_dim) let partner_d = select(d + half_dim, d - half_dim, d >= half_dim); let partner_idx = head_and_seq * params.head_dim + partner_d; let sign = select(-1.0, 1.0, d >= half_dim); q_out[idx] = q_in[idx] * c + sign * q_in[partner_idx] * s; k_out[idx] = k_in[idx] * c + sign * k_in[partner_idx] * s; }`,Br=` struct Params { seq_len: u32, num_heads: u32, head_dim: u32, scale: f32, } @group(0) @binding(0) var q: array; @group(0) @binding(1) 
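// One workgroup (32 threads) per (query position, head): Q is preloaded into
// registers (up to 4 elements per thread for head_dim <= 128), each key's dot
// product is tree-reduced through wg_dot, and the softmax is computed online,
// rescaling the running sum and the V accumulators whenever the running max grows.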
var k: array; @group(0) @binding(2) var v: array; @group(0) @binding(3) var output: array; @group(0) @binding(4) var params: Params; var wg_dot: array; @compute @workgroup_size(32) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let q_pos = wid.x; let head = wid.y; let tid = lid.x; let S = params.seq_len; let H = params.num_heads; let D = params.head_dim; let scale = params.scale; let q_base = (q_pos * H + head) * D; // Each thread handles D elements with stride 32 (D=64→2, D=128→4 per thread) // Pre-load Q into registers var qc0: f32 = 0.0; var qc1: f32 = 0.0; var qc2: f32 = 0.0; var qc3: f32 = 0.0; qc0 = q[q_base + tid]; if (tid + 32u < D) { qc1 = q[q_base + tid + 32u]; } if (tid + 64u < D) { qc2 = q[q_base + tid + 64u]; } if (tid + 96u < D) { qc3 = q[q_base + tid + 96u]; } // Online softmax + V accumulation, parallelized across D dimension var running_max: f32 = -1e30; var running_sum: f32 = 0.0; var a0: f32 = 0.0; var a1: f32 = 0.0; var a2: f32 = 0.0; var a3: f32 = 0.0; for (var kp: u32 = 0u; kp < S; kp++) { let k_base = (kp * H + head) * D; // Partial dot product: each thread multiplies its D elements var pd: f32 = qc0 * k[k_base + tid]; if (tid + 32u < D) { pd += qc1 * k[k_base + tid + 32u]; } if (tid + 64u < D) { pd += qc2 * k[k_base + tid + 64u]; } if (tid + 96u < D) { pd += qc3 * k[k_base + tid + 96u]; } // Tree reduction: 32 → 1 wg_dot[tid] = pd; workgroupBarrier(); if (tid < 16u) { wg_dot[tid] += wg_dot[tid + 16u]; } workgroupBarrier(); if (tid < 8u) { wg_dot[tid] += wg_dot[tid + 8u]; } workgroupBarrier(); if (tid < 4u) { wg_dot[tid] += wg_dot[tid + 4u]; } workgroupBarrier(); if (tid < 2u) { wg_dot[tid] += wg_dot[tid + 2u]; } workgroupBarrier(); if (tid == 0u) { wg_dot[0] += wg_dot[1]; } workgroupBarrier(); let s = wg_dot[0] * scale; // Online softmax (all threads have same s after reduction) let old_max = running_max; running_max = max(running_max, s); let corr = exp(old_max - running_max); running_sum = running_sum * corr; a0 = a0 * corr; a1 = a1 * corr; a2 = a2 * corr; a3 = a3 * corr; let w = exp(s - running_max); running_sum += w; // Accumulate weighted V for this thread's D elements let v_base = (kp * H + head) * D; a0 += w * v[v_base + tid]; if (tid + 32u < D) { a1 += w * v[v_base + tid + 32u]; } if (tid + 64u < D) { a2 += w * v[v_base + tid + 64u]; } if (tid + 96u < D) { a3 += w * v[v_base + tid + 96u]; } } // Write output let inv_sum = 1.0 / running_sum; let out_base = (q_pos * H + head) * D; output[out_base + tid] = a0 * inv_sum; if (tid + 32u < D) { output[out_base + tid + 32u] = a1 * inv_sum; } if (tid + 64u < D) { output[out_base + tid + 64u] = a2 * inv_sum; } if (tid + 96u < D) { output[out_base + tid + 96u] = a3 * inv_sum; } }`,Dr=` @group(0) @binding(0) var a: array; @group(0) @binding(1) var b: array; @group(0) @binding(2) var len: u32; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3u) { let i = gid.x; if (i >= len) { return; } a[i] += b[i]; }`,Ar=` struct Params { H: u32, } @group(0) @binding(0) var vision: array; @group(0) @binding(1) var scatter_indices: array; @group(0) @binding(2) var embeds: array; @group(0) @binding(3) var params: Params; @compute @workgroup_size(256) fn main(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u) { let vit_idx = wid.x; let pos = scatter_indices[vit_idx]; let H = params.H; let tid = lid.x; for (var i = tid; i < H; i += 256u) { embeds[pos * H + i] = vision[vit_idx * H + i]; } 
}`,zr={gptq_matvec:k,gptq_matvec_f16:h,gptq_matvec_4t:pr,gptq_matvec_4t_f16:fr,gptq_splitk:x,reduce_splitk:q,bf16_matvec:y,rmsnorm:F,silu_mul:N,add:P,embedding:R,embed_from_argmax:K,argmax:B,topk_extract:D,kv_cache_store:z,gqa_attention_head:T,gqa_reduce:H,deltanet_recurrent:M,head_rmsnorm:G,head_rmsnorm_nogated:O,causal_conv1d:U,split:E,sigmoid_mul:S,fused_gate_up_silu:C,fused_gate_up_silu_f16:L,fused_gate_up_silu_4t:Q,fused_gate_up_silu_4t_f16:j,add_rmsnorm:V,add_rmsnorm_ro:W,three_way_add_rmsnorm:Y,norm_gptq_lite:or,norm_gptq_lite_noadd:nr,fused_sigmoid_gptq:X,fused_sigmoid_gptq_f16:$,fused_sigmoid_gptq_4t:er,fused_sigmoid_gptq_4t_f16:tr,fused_silu_gptq:Z,fused_silu_gptq_f16:J,fused_silu_gptq_4t:rr,fused_silu_gptq_4t_f16:ar,fused_addnorm_gate_up_silu:ur,rep_penalty:gr,gpu_sample:dr,append_token:A,fused_norm_gptq:ir,fused_norm_gptq_noadd:sr,fused_conv_deltanet_norm:I,fused_split_qknorm_kvstore:lr,gptq_matvec_4t_f16_sk:_r,gather_rows_bf16:cr,quantize_bf16_to_int4:wr,pack_f32_to_f16_pairs:br,gptq_matmul_b2:yr,gptq_matmul_b2_f16:Fr,gptq_matmul_b2_4t_f16:Nr,add_rmsnorm_b2:vr,add_rmsnorm_ro_b2:mr,three_way_add_rmsnorm_b2:kr,fused_gate_up_silu_b2_f16:hr,fused_silu_gptq_b2_f16:xr,fused_sigmoid_gptq_b2_f16:qr,vit_layernorm:Pr,vit_bf16_matvec_bias:Er,vit_gelu_tanh:Sr,vit_gelu:Rr,vit_rope:Kr,vit_attention:Br,vit_add:Dr,vit_scatter_embed:Ar},Mr=Object.freeze(Object.defineProperty({__proto__:null,SHADERS:zr,SHADER_FUSED_SPLIT_QKNORM_KVSTORE_FN:_},Symbol.toStringTag,{value:"Module"}));export{b as G,zr as S,Hr as _,_ as a,Mr as b,Tr as g};
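/*
 A minimal CPU-side sketch of the int4/GPTQ dequantization scheme the gptq_* kernels
 above implement (function and variable names here are illustrative, not part of the
 module): each u32 in `qweight` packs 8 unsigned nibbles along K, dequantized as
 (nibble - 8) times the per-group scale; the scales buffer itself stores two f16
 values per u32, assumed here to be already unpacked to one f32 per (group, column).

 function dequantColumn(qweight, scalesF32, K, N, col, groupSize) {
   const out = new Float32Array(K);
   const packedRows = K / 8;
   for (let pr = 0; pr < packedRows; pr++) {
     const word = qweight[pr * N + col];          // 8 int4 values for rows pr*8 .. pr*8+7
     const group = Math.floor((pr * 8) / groupSize);
     const scale = scalesF32[group * N + col];
     for (let b = 0; b < 8; b++) {
       const nibble = (word >>> (b * 4)) & 0xf;   // unsigned 4-bit value, 0..15
       out[pr * 8 + b] = (nibble - 8) * scale;    // symmetric quantization, zero-point 8
     }
   }
   return out;
 }
*/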