Spaces:
Running
Running
| // gpu-blake3.mjs — BLAKE3 verification that runs ENTIRELY on the GPU (WGSL). The parallel two-phase kernel | |
| // (phase1: one GPU thread per 1024-byte chunk; phase2: stack tree-merge → root) — PROVEN byte-exact vs the CPU | |
| // oracle (holo-blake3.mjs) and measured at 2.74 GB/s on AMD RDNA-3 (138 MB in 49 ms). Used by holo-load2bit's | |
| // getBlock to re-derive each streamed weight block's κ on the GPU (Law L5) — no CPU hash on the hot path. | |
| // Lazy: the device + pipelines are created once on first use; gpuBlake3Hex throws if WebGPU is absent so the | |
| // caller can fall back to the CPU verifier. | |
| import { getDevice } from "./holo-gpu-device.mjs"; // the ONE shared device (max limits) β never a second one | |
// WGSL source for the two compute entry points compiled by ensure(): phase1 and phase2.
//
// Bindings (group 0):
//   0  inp  — read-only input, addressed as u32 words; loadBlock() always reads 16 words from
//             byte offset `bo`, so the host must pad the input to whole 64-byte blocks.
//   1  cvs  — read/write scratch; phase1 stores 8 words per chunk at index i*8.
//   2  outp — read/write; phase2 stores the 8 root words at outp[0..7].
//   3  P    — uniform vec4<u32>; the shader reads P.x as the byte length and P.y as the chunk count N.
//
// phase1 (workgroup_size 64): invocation i handles the chunk at byte offset i*1024 with length
//   min(1024, len - off). When N == 1 it uses chunkRoot (final block flagged fl|2|8); otherwise
//   chunkCV with the chunk index passed as the low counter word.
// phase2 (workgroup_size 1): for N == 1 it copies cvs[0..7]; otherwise it folds chunks 0..N-2 through
//   a 54-entry stack (merging once per trailing zero bit of c+1), then combines the last chunk with
//   the remaining stack entries top-down, using flags 12 instead of 4 on the final merge.
//
// Flag values as used below: 1 on the first block of a chunk, 2 on its last block, 4 for parent
// nodes, 8 added for the root compression.
| const WGSL = ` | |
| @group(0) @binding(0) var<storage, read> inp : array<u32>; | |
| @group(0) @binding(1) var<storage, read_write> cvs : array<u32>; | |
| @group(0) @binding(2) var<storage, read_write> outp : array<u32>; | |
| @group(0) @binding(3) var<uniform> P : vec4<u32>; | |
| const IV = array<u32,8>(0x6a09e667u,0xbb67ae85u,0x3c6ef372u,0xa54ff53au,0x510e527fu,0x9b05688cu,0x1f83d9abu,0x5be0cd19u); | |
| fn rotr(x:u32,n:u32)->u32 { return (x>>n)|(x<<(32u-n)); } | |
| fn compress(cv:array<u32,8>, m0:array<u32,16>, cl:u32, ch:u32, bl:u32, fl:u32)->array<u32,16>{ | |
| var v:array<u32,16>; | |
| v[0]=cv[0];v[1]=cv[1];v[2]=cv[2];v[3]=cv[3];v[4]=cv[4];v[5]=cv[5];v[6]=cv[6];v[7]=cv[7]; | |
| v[8]=IV[0];v[9]=IV[1];v[10]=IV[2];v[11]=IV[3];v[12]=cl;v[13]=ch;v[14]=bl;v[15]=fl; | |
| var m:array<u32,16>=m0; | |
| for(var r:u32=0u;r<7u;r=r+1u){ | |
| v[0]=v[0]+v[4]+m[0];v[12]=rotr(v[12]^v[0],16u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],12u);v[0]=v[0]+v[4]+m[1];v[12]=rotr(v[12]^v[0],8u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],7u); | |
| v[1]=v[1]+v[5]+m[2];v[13]=rotr(v[13]^v[1],16u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],12u);v[1]=v[1]+v[5]+m[3];v[13]=rotr(v[13]^v[1],8u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],7u); | |
| v[2]=v[2]+v[6]+m[4];v[14]=rotr(v[14]^v[2],16u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],12u);v[2]=v[2]+v[6]+m[5];v[14]=rotr(v[14]^v[2],8u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],7u); | |
| v[3]=v[3]+v[7]+m[6];v[15]=rotr(v[15]^v[3],16u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],12u);v[3]=v[3]+v[7]+m[7];v[15]=rotr(v[15]^v[3],8u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],7u); | |
| v[0]=v[0]+v[5]+m[8];v[15]=rotr(v[15]^v[0],16u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],12u);v[0]=v[0]+v[5]+m[9];v[15]=rotr(v[15]^v[0],8u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],7u); | |
| v[1]=v[1]+v[6]+m[10];v[12]=rotr(v[12]^v[1],16u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],12u);v[1]=v[1]+v[6]+m[11];v[12]=rotr(v[12]^v[1],8u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],7u); | |
| v[2]=v[2]+v[7]+m[12];v[13]=rotr(v[13]^v[2],16u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],12u);v[2]=v[2]+v[7]+m[13];v[13]=rotr(v[13]^v[2],8u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],7u); | |
| v[3]=v[3]+v[4]+m[14];v[14]=rotr(v[14]^v[3],16u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],12u);v[3]=v[3]+v[4]+m[15];v[14]=rotr(v[14]^v[3],8u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],7u); | |
| if(r<6u){var perm=array<u32,16>(2u,6u,3u,10u,7u,0u,4u,13u,1u,11u,12u,5u,9u,14u,15u,8u);var p:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){p[i]=m[perm[i]];}m=p;} | |
| } | |
| var o:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){o[i]=v[i]^v[i+8u];o[i+8u]=v[i+8u]^cv[i];}return o; | |
| } | |
| fn loadBlock(bo:u32)->array<u32,16>{let base=bo>>2u;var m:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){m[i]=inp[base+i];}return m;} | |
| fn chunkCV(s:u32,l:u32,ct:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),ct,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o2=compress(cv,loadBlock(lo),ct,0u,ll,fl|2u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o2[j];}return out;} | |
| fn chunkRoot(s:u32,l:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),0u,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o=compress(cv,loadBlock(lo),0u,0u,ll,fl|2u|8u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o[j];}return out;} | |
| fn parentCV(l:array<u32,8>,r:array<u32,8>,root:bool)->array<u32,8>{var m:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){m[i]=l[i];m[i+8u]=r[i];}var fl:u32=4u;if(root){fl=12u;}let o=compress(IV,m,0u,0u,64u,fl);var out:array<u32,8>;for(var i:u32=0u;i<8u;i=i+1u){out[i]=o[i];}return out;} | |
| fn getCV(i:u32)->array<u32,8>{var c:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){c[j]=cvs[i*8u+j];}return c;} | |
| @compute @workgroup_size(64) | |
| fn phase1(@builtin(global_invocation_id) gid:vec3<u32>){ | |
| let i=gid.x; let N=P.y; let len=P.x; if(i>=N){return;} | |
| let off=i*1024u; let ln=min(1024u,len-off); | |
| var cv:array<u32,8>; if(N==1u){cv=chunkRoot(off,ln);}else{cv=chunkCV(off,ln,i);} | |
| for(var j:u32=0u;j<8u;j=j+1u){cvs[i*8u+j]=cv[j];} | |
| } | |
| @compute @workgroup_size(1) | |
| fn phase2(){ | |
| let N=P.y; var root:array<u32,8>; | |
| if(N==1u){root=getCV(0u);} | |
| else{ | |
| var stack:array<array<u32,8>,54>; var depth:u32=0u; | |
| for(var c:u32=0u;c<N-1u;c=c+1u){var cv=getCV(c);var t=c+1u;loop{if((t&1u)!=0u){break;}depth=depth-1u;cv=parentCV(stack[depth],cv,false);t=t>>1u;}stack[depth]=cv;depth=depth+1u;} | |
| var node=getCV(N-1u);for(var i:u32=0u;i<depth;i=i+1u){let li=depth-1u-i;root=parentCV(stack[li],node,i==depth-1u);node=root;} | |
| } | |
| for(var i:u32=0u;i<8u;i=i+1u){outp[i]=root[i];} | |
| } | |
| `; | |
// Lazily-created GPU state shared by every gpuBlake3Hex call. All four handles stay null until
// initialisation has fully succeeded; they are then published together (see ensure()).
let _dev = null;   // the shared device returned by getDevice()
let _p1 = null;    // compute pipeline for the "phase1" entry point
let _p2 = null;    // compute pipeline for the "phase2" entry point
let _bgl = null;   // bind-group layout shared by both pipelines
let _initP = null; // in-flight initialisation promise, so concurrent callers share one init

/**
 * Creates the shader module, bind-group layout and both compute pipelines once, on first use.
 *
 * Concurrent callers share a single in-flight initialisation. If initialisation fails, no state
 * is published and the next call retries from scratch.
 *
 * @returns {Promise<boolean>} resolves to `true` once `_dev`, `_bgl`, `_p1` and `_p2` are all set.
 * @throws {Error} "gpu-blake3 WGSL: …" if the shader reports a compilation error; any rejection
 *   from getDevice() or the device calls is propagated unchanged.
 */
async function ensure() {
  // `_dev` is published LAST and synchronously with the pipelines, so a truthy `_dev` means the
  // whole set is ready. (Previously `_dev` was assigned before the first await inside init, so a
  // caller arriving mid-init — or after a failed init — was told "ready" with null pipelines.)
  if (_dev) return true;
  if (_initP) return _initP;

  _initP = (async () => {
    // SHARE the engine's device (max limits) — never create a second device.
    const dev = (await getDevice()).dev;
    const mod = dev.createShaderModule({ code: WGSL });

    const ci = await mod.getCompilationInfo();
    const er = ci.messages.filter((m) => m.type === "error");
    if (er.length) throw new Error("gpu-blake3 WGSL: " + er[0].message);

    const bgl = dev.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: "read-only-storage" } },
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
        { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: "uniform" } },
      ],
    });
    const pll = dev.createPipelineLayout({ bindGroupLayouts: [bgl] });
    const p1 = dev.createComputePipeline({ layout: pll, compute: { module: mod, entryPoint: "phase1" } });
    const p2 = dev.createComputePipeline({ layout: pll, compute: { module: mod, entryPoint: "phase2" } });

    // Publish everything in one synchronous step; no await may be inserted between these lines.
    _bgl = bgl;
    _p1 = p1;
    _p2 = p2;
    _dev = dev;
    return true;
  })().catch((e) => {
    // Forget the failed attempt so a later call can retry, then surface the original error.
    _initP = null;
    throw e;
  });

  return _initP;
}
// Formats the eight u32 digest words as 64 lowercase hex characters, emitting each word's bytes
// least-significant first.
function digestWordsToHex(words) {
  let hex = "";
  for (let i = 0; i < 8; i++) {
    const word = words[i] >>> 0;
    for (let shift = 0; shift < 32; shift += 8) {
      hex += ((word >>> shift) & 255).toString(16).padStart(2, "0");
    }
  }
  return hex;
}

/**
 * gpuBlake3Hex(bytes) → 64-hex BLAKE3 of `bytes`, computed entirely on the GPU.
 *
 * @param {Uint8Array} bytes - data to hash; only `length` and the element values are read.
 * @returns {Promise<string>} 64 lowercase hex characters.
 * @throws {RangeError} if the input is too large for this device's buffer / dispatch limits or
 *   for the shader's 32-bit length field.
 * @throws whatever ensure() rejects with (e.g. WebGPU unavailable), so the caller can fall back
 *   to the CPU verifier.
 */
export async function gpuBlake3Hex(bytes) {
  await ensure();
  const dev = _dev;

  const len = bytes.length;
  const N = len === 0 ? 1 : Math.ceil(len / 1024);            // number of 1024-byte chunks
  const padU = Math.max(16, (Math.ceil(len / 64) || 1) * 16); // input words, padded to whole 64-byte blocks
  const groups = Math.ceil(N / 64);                            // phase1 uses workgroup_size(64)

  // Fail loudly up front: an oversized buffer or dispatch only produces asynchronous validation
  // errors, after which the staging buffer would be read back without ever being written.
  const inputBytes = padU * 4;
  const lim = dev.limits;
  if (
    len > 0xffffffff ||
    inputBytes > lim.maxStorageBufferBindingSize ||
    inputBytes > lim.maxBufferSize ||
    groups > lim.maxComputeWorkgroupsPerDimension
  ) {
    throw new RangeError(`gpu-blake3: input of ${len} bytes exceeds this device's WebGPU limits`);
  }

  const u = new Uint32Array(padU);
  new Uint8Array(u.buffer).set(bytes); // the tail stays zero, as loadBlock() reads full blocks

  // Every buffer is tracked so the finally block releases it even when a later step throws.
  const owned = [];
  const alloc = (size, usage) => {
    const buf = dev.createBuffer({ size, usage });
    owned.push(buf);
    return buf;
  };

  try {
    const inB = alloc(u.byteLength, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);
    dev.queue.writeBuffer(inB, 0, u);
    const cvB = alloc(Math.max(32, N * 32), GPUBufferUsage.STORAGE); // 8 words per chunk
    const oB = alloc(32, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
    const pB = alloc(16, GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST);
    dev.queue.writeBuffer(pB, 0, new Uint32Array([len, N, 0, 0]));
    const sg = alloc(32, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); // CPU-readable staging copy

    const bg = dev.createBindGroup({
      layout: _bgl,
      entries: [
        { binding: 0, resource: { buffer: inB } },
        { binding: 1, resource: { buffer: cvB } },
        { binding: 2, resource: { buffer: oB } },
        { binding: 3, resource: { buffer: pB } },
      ],
    });

    const e = dev.createCommandEncoder();
    let pass = e.beginComputePass();
    pass.setPipeline(_p1);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(groups);
    pass.end();

    pass = e.beginComputePass();
    pass.setPipeline(_p2);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(1);
    pass.end();

    e.copyBufferToBuffer(oB, 0, sg, 0, 32);
    dev.queue.submit([e.finish()]);

    await sg.mapAsync(GPUMapMode.READ);
    const w = new Uint32Array(sg.getMappedRange().slice(0)); // copy out before unmapping
    sg.unmap();
    return digestWordsToHex(w);
  } finally {
    for (const buf of owned) buf.destroy();
  }
}
/**
 * Reports whether this environment exposes the WebGPU entry point.
 *
 * @returns {boolean} `true` only when a global `navigator` exists and its `gpu` property is truthy.
 */
export function gpuBlake3Available() {
  // Non-browser runtimes may have no `navigator` global at all; `typeof` avoids a ReferenceError.
  if (typeof navigator === "undefined") {
    return false;
  }
  return Boolean(navigator.gpu);
}
// Default export: the same two functions that are also exported by name, bundled in one object.
| export default { gpuBlake3Hex, gpuBlake3Available }; | |