q / gpu-blake3.mjs
Humuhumu33's picture
Upload gpu-blake3.mjs with huggingface_hub
4e8cc03 verified
Raw
History Blame Contribute Delete
9 kB
// gpu-blake3.mjs — BLAKE3 verification that runs ENTIRELY on the GPU (WGSL). The parallel two-phase kernel
// (phase1: one GPU thread per 1024-byte chunk; phase2: stack tree-merge → root) — PROVEN byte-exact vs the CPU
// oracle (holo-blake3.mjs) and measured at 2.74 GB/s on AMD RDNA-3 (138 MB in 49 ms). Used by holo-load2bit's
// getBlock to re-derive each streamed weight block's κ on the GPU (Law L5) — no CPU hash on the hot path.
// Lazy: the device + pipelines are created once on first use; gpuBlake3Hex throws if WebGPU is absent so the
// caller can fall back to the CPU verifier.
import { getDevice } from "./holo-gpu-device.mjs"; // the ONE shared device (max limits) β€” never a second one
// WGSL source for the two compute entry points dispatched by gpuBlake3Hex. The template literal is handed to
// createShaderModule verbatim (see ensure), so all documentation lives here rather than inside the shader text.
//   Bindings (group 0):
//     0 inp  — the input bytes packed 4 per u32 word by gpuBlake3Hex, zero-padded to a whole 64-byte block;
//     1 cvs  — 8 chaining-value words per 1024-byte chunk (written by phase1, read by phase2);
//     2 outp — the 8-word (32-byte) hash written by phase2;
//     3 P    — uniform: P.x = input length in bytes, P.y = chunk count N (P.z / P.w unused).
//   compress  — BLAKE3 compression function: 7 rounds of column + diagonal G mixing, message words permuted
//               between rounds; returns all 16 output words. The block counter is passed as (cl, ch) and every
//               call site passes ch = 0u, so only the low 32 bits of the counter are ever used.
//   Flag values (BLAKE3 spec): 1 CHUNK_START, 2 CHUNK_END, 4 PARENT, 8 ROOT; parentCV uses 12 = PARENT|ROOT.
//   chunkCV   — chaining value of chunk #ct starting at byte offset s with length l (non-root).
//   chunkRoot — the same block walk for a single-chunk message, counter 0, ROOT set on the final block.
//   phase1    — one invocation per chunk (@workgroup_size(64)); when N == 1 it stores the final hash in cvs[0..7].
//   phase2    — a single invocation. When N == 1 it just copies cvs[0..7].
//               Otherwise it pushes the CVs of chunks 0..N-2 on a stack, pop-merging while the completed-chunk
//               count is even. It then folds the last chunk's CV up through the stack, and only the final parent
//               gets ROOT. The stack holds 54 entries, BLAKE3's maximum tree depth.
const WGSL = `
@group(0) @binding(0) var<storage, read> inp : array<u32>;
@group(0) @binding(1) var<storage, read_write> cvs : array<u32>;
@group(0) @binding(2) var<storage, read_write> outp : array<u32>;
@group(0) @binding(3) var<uniform> P : vec4<u32>;
const IV = array<u32,8>(0x6a09e667u,0xbb67ae85u,0x3c6ef372u,0xa54ff53au,0x510e527fu,0x9b05688cu,0x1f83d9abu,0x5be0cd19u);
fn rotr(x:u32,n:u32)->u32 { return (x>>n)|(x<<(32u-n)); }
fn compress(cv:array<u32,8>, m0:array<u32,16>, cl:u32, ch:u32, bl:u32, fl:u32)->array<u32,16>{
var v:array<u32,16>;
v[0]=cv[0];v[1]=cv[1];v[2]=cv[2];v[3]=cv[3];v[4]=cv[4];v[5]=cv[5];v[6]=cv[6];v[7]=cv[7];
v[8]=IV[0];v[9]=IV[1];v[10]=IV[2];v[11]=IV[3];v[12]=cl;v[13]=ch;v[14]=bl;v[15]=fl;
var m:array<u32,16>=m0;
for(var r:u32=0u;r<7u;r=r+1u){
v[0]=v[0]+v[4]+m[0];v[12]=rotr(v[12]^v[0],16u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],12u);v[0]=v[0]+v[4]+m[1];v[12]=rotr(v[12]^v[0],8u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],7u);
v[1]=v[1]+v[5]+m[2];v[13]=rotr(v[13]^v[1],16u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],12u);v[1]=v[1]+v[5]+m[3];v[13]=rotr(v[13]^v[1],8u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],7u);
v[2]=v[2]+v[6]+m[4];v[14]=rotr(v[14]^v[2],16u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],12u);v[2]=v[2]+v[6]+m[5];v[14]=rotr(v[14]^v[2],8u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],7u);
v[3]=v[3]+v[7]+m[6];v[15]=rotr(v[15]^v[3],16u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],12u);v[3]=v[3]+v[7]+m[7];v[15]=rotr(v[15]^v[3],8u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],7u);
v[0]=v[0]+v[5]+m[8];v[15]=rotr(v[15]^v[0],16u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],12u);v[0]=v[0]+v[5]+m[9];v[15]=rotr(v[15]^v[0],8u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],7u);
v[1]=v[1]+v[6]+m[10];v[12]=rotr(v[12]^v[1],16u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],12u);v[1]=v[1]+v[6]+m[11];v[12]=rotr(v[12]^v[1],8u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],7u);
v[2]=v[2]+v[7]+m[12];v[13]=rotr(v[13]^v[2],16u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],12u);v[2]=v[2]+v[7]+m[13];v[13]=rotr(v[13]^v[2],8u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],7u);
v[3]=v[3]+v[4]+m[14];v[14]=rotr(v[14]^v[3],16u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],12u);v[3]=v[3]+v[4]+m[15];v[14]=rotr(v[14]^v[3],8u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],7u);
if(r<6u){var perm=array<u32,16>(2u,6u,3u,10u,7u,0u,4u,13u,1u,11u,12u,5u,9u,14u,15u,8u);var p:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){p[i]=m[perm[i]];}m=p;}
}
var o:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){o[i]=v[i]^v[i+8u];o[i+8u]=v[i+8u]^cv[i];}return o;
}
fn loadBlock(bo:u32)->array<u32,16>{let base=bo>>2u;var m:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){m[i]=inp[base+i];}return m;}
fn chunkCV(s:u32,l:u32,ct:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),ct,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o2=compress(cv,loadBlock(lo),ct,0u,ll,fl|2u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o2[j];}return out;}
fn chunkRoot(s:u32,l:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),0u,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o=compress(cv,loadBlock(lo),0u,0u,ll,fl|2u|8u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o[j];}return out;}
fn parentCV(l:array<u32,8>,r:array<u32,8>,root:bool)->array<u32,8>{var m:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){m[i]=l[i];m[i+8u]=r[i];}var fl:u32=4u;if(root){fl=12u;}let o=compress(IV,m,0u,0u,64u,fl);var out:array<u32,8>;for(var i:u32=0u;i<8u;i=i+1u){out[i]=o[i];}return out;}
fn getCV(i:u32)->array<u32,8>{var c:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){c[j]=cvs[i*8u+j];}return c;}
@compute @workgroup_size(64)
fn phase1(@builtin(global_invocation_id) gid:vec3<u32>){
let i=gid.x; let N=P.y; let len=P.x; if(i>=N){return;}
let off=i*1024u; let ln=min(1024u,len-off);
var cv:array<u32,8>; if(N==1u){cv=chunkRoot(off,ln);}else{cv=chunkCV(off,ln,i);}
for(var j:u32=0u;j<8u;j=j+1u){cvs[i*8u+j]=cv[j];}
}
@compute @workgroup_size(1)
fn phase2(){
let N=P.y; var root:array<u32,8>;
if(N==1u){root=getCV(0u);}
else{
var stack:array<array<u32,8>,54>; var depth:u32=0u;
for(var c:u32=0u;c<N-1u;c=c+1u){var cv=getCV(c);var t=c+1u;loop{if((t&1u)!=0u){break;}depth=depth-1u;cv=parentCV(stack[depth],cv,false);t=t>>1u;}stack[depth]=cv;depth=depth+1u;}
var node=getCV(N-1u);for(var i:u32=0u;i<depth;i=i+1u){let li=depth-1u-i;root=parentCV(stack[li],node,i==depth-1u);node=root;}
}
for(var i:u32=0u;i<8u;i=i+1u){outp[i]=root[i];}
}
`;
// Lazily-initialised module state shared with gpuBlake3Hex. ensure() assigns _dev, _bgl, _p1 and _p2 together,
// and only after the shader compiled and both pipelines were created, so a non-null _dev means all of them are usable.
let _dev = null;
let _p1 = null;
let _p2 = null;
let _bgl = null;
// Pending (or fulfilled) initialisation promise; reset to null on failure so a later call can retry.
let _initP = null;

// ensure() → resolves true once the shared device, bind-group layout and both compute pipelines are ready.
// Concurrent callers await the same in-flight initialisation instead of running ahead of it. Rejects if getDevice()
// rejects or the WGSL reports a compilation error; nothing is committed in that case, so the next call retries.
async function ensure() {
  if (_dev) return true;
  if (_initP) return _initP;
  _initP = (async () => {
    const dev = (await getDevice()).dev; // SHARE the engine's device (max limits) — no second device
    const mod = dev.createShaderModule({ code: WGSL });
    // WGSL errors are reported through getCompilationInfo(); fail loudly so the caller can use the CPU verifier.
    const { messages } = await mod.getCompilationInfo();
    const firstError = messages.find((m) => m.type === "error");
    if (firstError) throw new Error(`gpu-blake3 WGSL: ${firstError.message}`);
    const bgl = dev.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: "read-only-storage" } }, // inp
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } }, // cvs
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } }, // outp
        { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: "uniform" } }, // P
      ],
    });
    const layout = dev.createPipelineLayout({ bindGroupLayouts: [bgl] });
    const p1 = dev.createComputePipeline({ layout, compute: { module: mod, entryPoint: "phase1" } });
    const p2 = dev.createComputePipeline({ layout, compute: { module: mod, entryPoint: "phase2" } });
    // Commit everything in one synchronous step, only after every step above succeeded; _dev is the "ready" flag
    // tested at the top of ensure(), so no caller can observe a half-initialised module.
    _bgl = bgl;
    _p1 = p1;
    _p2 = p2;
    _dev = dev;
    return true;
  })().catch((e) => {
    _initP = null;
    throw e;
  });
  return _initP;
}
// gpuBlake3Hex(bytes) → 64-hex BLAKE3 of `bytes`, computed entirely on the GPU. Throws if WebGPU is unavailable.
// `bytes` is the message as a byte sequence (normally a Uint8Array); it is copied into a zero-padded u32 staging
// array, so the caller's data is not modified. Rejects with a RangeError for inputs of 2^32 bytes or more: the
// kernel receives the length and chunk byte offsets as u32, which would silently wrap and yield a wrong hash.
// Every GPU buffer created here is destroyed before returning, also when a GPU step (e.g. mapAsync) rejects.
export async function gpuBlake3Hex(bytes) {
  const len = bytes.length;
  if (len > 0xffffffff) {
    throw new RangeError(`gpu-blake3: ${len}-byte input exceeds the 32-bit length supported by the kernel`);
  }
  await ensure();
  const dev = _dev;
  // Number of 1024-byte chunks; the empty message still hashes as one (empty) chunk.
  const N = len === 0 ? 1 : Math.ceil(len / 1024);
  // Pad to whole 64-byte blocks (16 words, at least one block). The shader's loadBlock always reads 16 words, and
  // the unused tail of a partial block must be zero (a fresh Uint32Array is zero-filled).
  const padU = Math.max(16, (Math.ceil(len / 64) || 1) * 16);
  const u = new Uint32Array(padU);
  new Uint8Array(u.buffer).set(bytes);

  const owned = [];
  const makeBuffer = (size, usage) => {
    const buf = dev.createBuffer({ size, usage });
    owned.push(buf);
    return buf;
  };
  try {
    const inB = makeBuffer(u.byteLength, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);
    dev.queue.writeBuffer(inB, 0, u);
    const cvB = makeBuffer(Math.max(32, N * 32), GPUBufferUsage.STORAGE); // 8 chaining-value words per chunk
    const oB = makeBuffer(32, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); // 8-word hash written by phase2
    const pB = makeBuffer(16, GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST); // P = (len, N, 0, 0)
    dev.queue.writeBuffer(pB, 0, new Uint32Array([len, N, 0, 0]));
    const sg = makeBuffer(32, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); // mappable copy of oB
    const bg = dev.createBindGroup({
      layout: _bgl,
      entries: [
        { binding: 0, resource: { buffer: inB } },
        { binding: 1, resource: { buffer: cvB } },
        { binding: 2, resource: { buffer: oB } },
        { binding: 3, resource: { buffer: pB } },
      ],
    });
    const enc = dev.createCommandEncoder();
    let pass = enc.beginComputePass();
    pass.setPipeline(_p1);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(Math.ceil(N / 64)); // phase1 runs 64 invocations per workgroup, one per chunk
    pass.end();
    pass = enc.beginComputePass();
    pass.setPipeline(_p2);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(1); // phase2 is a single serial tree merge
    pass.end();
    enc.copyBufferToBuffer(oB, 0, sg, 0, 32);
    dev.queue.submit([enc.finish()]);
    await sg.mapAsync(GPUMapMode.READ);
    const words = new Uint32Array(sg.getMappedRange().slice(0));
    sg.unmap();
    // BLAKE3 serialises each output word little-endian; explicit shifts keep this independent of host endianness.
    let hex = "";
    for (const word of words) {
      for (let shift = 0; shift < 32; shift += 8) {
        hex += ((word >>> shift) & 255).toString(16).padStart(2, "0");
      }
    }
    return hex;
  } finally {
    for (const buf of owned) buf.destroy();
  }
}
// gpuBlake3Available() → true when this JS realm exposes the WebGPU entry point (navigator.gpu).
// It only checks that the API object is present and does not request an adapter or device.
// gpuBlake3Hex can therefore still reject even when this returns true.
export function gpuBlake3Available() {
  if (typeof navigator === "undefined") return false;
  return Boolean(navigator.gpu);
}
// Default export bundling both public entry points.
export default { gpuBlake3Hex, gpuBlake3Available };