q / blake3-gpu-parallel.html
Humuhumu33's picture
Upload blake3-gpu-parallel.html with huggingface_hub
a3c9d6f verified
Raw
History Blame Contribute Delete
13.4 kB
<!doctype html><html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width,initial-scale=1">
<title>BLAKE3 — parallel, on your GPU</title>
<style>
/* Palette: --bg page, --panel boxes, --ink text, --dim secondary text, --ac accent, --ok / --no result colours, --line borders. */
:root{--bg:#0a0d13;--panel:#111825;--ink:#e8ebf1;--dim:#8b95a7;--ac:#7c5cff;--ok:#48c26c;--no:#f0616d;--line:#1e2836}
*{box-sizing:border-box}
/* Single centred column; one margin declaration (the earlier duplicate was overridden anyway). */
body{background:var(--bg);color:var(--ink);font:15px/1.55 -apple-system,Segoe UI,Roboto,monospace;padding:24px;max-width:820px;margin:0 auto}
h1{font-size:20px;margin:0 0 4px}
.sub{color:var(--dim);font-size:13px;margin:0 0 18px}
/* Status and throughput boxes share the same panel look. */
#status,#speed{font-family:ui-monospace,monospace;padding:12px 14px;border:1px solid var(--line);border-radius:10px;background:var(--panel);margin-bottom:12px}
#speed{font-size:16px}
.big{font-size:17px;font-weight:600;margin:14px 0 6px}
/* Results table. */
table{border-collapse:collapse;width:100%;font-family:ui-monospace,monospace;font-size:12.5px}
th,td{text-align:left;padding:6px 10px;border-bottom:1px solid var(--line)}
th{color:var(--dim);text-transform:uppercase;font-size:10.5px;letter-spacing:.08em}
td.n{text-align:right;font-variant-numeric:tabular-nums}
/* State classes toggled from the script. */
.ok{color:var(--ok)}
.no{color:var(--no)}
.hx{color:var(--dim)}
</style></head><body>
<!-- Page heading and a one-paragraph description of the two-phase demo. -->
<h1>BLAKE3 — <span style="color:var(--ac)">parallel</span>, on your GPU</h1>
<p class="sub">Phase 1: one GPU thread per 1024-byte chunk (parallel). Phase 2: stack tree-merge → root. Correctness vs the CPU oracle, plus throughput on a big block (the size of the 145 MB embed).</p>
<!-- Placeholders filled in by the module script below (looked up with querySelector):
     #status  - progress / correctness summary (script handle: st_)
     #speed   - throughput line (script handle: sp)
     #result  - final verdict and adapter info (script handle: res) -->
<div id="status">initializing…</div>
<div id="speed"></div>
<div id="result"></div>
<!-- One row per tested input size is appended to this tbody (script handle: tb). -->
<table id="tbl"><thead><tr><th>bytes</th><th>match</th></tr></thead><tbody></tbody></table>
<script type="module">
// CPU reference ("oracle") implementation, inlined from holo-blake3.mjs.
// Initial chaining value: eight 32-bit words.
const IV = [
  0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
  0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
// Message-word permutation applied between compression rounds.
const MSG = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8];
// Flag bits passed to the compression function.
const CS = 1; // set on the first block of a chunk
const CE = 2; // set on the last block of a chunk
const PA = 4; // set on parent (merge) nodes
const RT = 8; // set on the node that produces the final digest
// Sizes in bytes.
const BL = 64; // one compression block
const CH = 1024; // one chunk
/**
 * Rotate a 32-bit word right by `n` bits.
 * @param {number} x - word to rotate (treated as 32 bits)
 * @param {number} n - rotation distance in bits
 * @returns {number} the rotated word as an unsigned 32-bit number
 */
const rr = (x, n) => {
  const shiftedDown = x >>> n;
  const wrappedAround = x << (32 - n);
  // `>>> 0` turns the signed result of `|` back into an unsigned value.
  return (shiftedDown | wrappedAround) >>> 0;
};
/**
 * Quarter-round mixing step: updates the four state words v[a], v[b], v[c], v[d]
 * in place, folding in the two message words mx and my.
 * The callers in this file always pass four distinct indices.
 * @param {number[]} v - 16-word working state (mutated)
 * @param {number} a
 * @param {number} b
 * @param {number} c
 * @param {number} d - indices into v
 * @param {number} mx - message word for the first half
 * @param {number} my - message word for the second half
 */
function g(v, a, b, c, d, mx, my) {
  let wa = v[a];
  let wb = v[b];
  let wc = v[c];
  let wd4 = v[d];

  // First half: rotations by 16 and 12.
  wa = (wa + wb + mx) >>> 0;
  wd4 = rr(wd4 ^ wa, 16);
  wc = (wc + wd4) >>> 0;
  wb = rr(wb ^ wc, 12);

  // Second half: rotations by 8 and 7.
  wa = (wa + wb + my) >>> 0;
  wd4 = rr(wd4 ^ wa, 8);
  wc = (wc + wd4) >>> 0;
  wb = rr(wb ^ wc, 7);

  v[a] = wa;
  v[b] = wb;
  v[c] = wc;
  v[d] = wd4;
}
/**
 * One full round over the 4x4 state: four column mixes, then four diagonal mixes.
 * @param {number[]} v - 16-word working state (mutated)
 * @param {number[]} m - 16 message words in the order for this round
 */
function rnd(v, m) {
  // Columns: (0,4,8,12), (1,5,9,13), (2,6,10,14), (3,7,11,15) with m[0..7].
  for (let col = 0; col < 4; col++) {
    g(v, col, 4 + col, 8 + col, 12 + col, m[2 * col], m[2 * col + 1]);
  }
  // Diagonals: (0,5,10,15), (1,6,11,12), (2,7,8,13), (3,4,9,14) with m[8..15].
  for (let col = 0; col < 4; col++) {
    g(
      v,
      col,
      4 + ((col + 1) & 3),
      8 + ((col + 2) & 3),
      12 + ((col + 3) & 3),
      m[8 + 2 * col],
      m[9 + 2 * col],
    );
  }
}
/**
 * Compression function: seven rounds over a 16-word state built from the
 * chaining value, the first four IV words, the counter, block length and flags.
 * @param {number[]} cv - 8-word input chaining value (not mutated)
 * @param {number[]} m0 - 16 message words (not mutated)
 * @param {number} ct - counter; split into low/high 32-bit words
 * @param {number} bl - number of bytes in this block
 * @param {number} fl - flag bits
 * @returns {number[]} 16 unsigned output words; the first 8 are the new chaining value
 */
function cmp(cv, m0, ct, bl, fl) {
  const counterLo = ct >>> 0;
  const counterHi = Math.floor(ct / 4294967296) >>> 0;

  const state = new Array(16);
  for (let i = 0; i < 8; i++) {
    state[i] = cv[i];
  }
  for (let i = 0; i < 4; i++) {
    state[8 + i] = IV[i];
  }
  state[12] = counterLo;
  state[13] = counterHi;
  state[14] = bl >>> 0;
  state[15] = fl >>> 0;

  let words = m0.slice();
  for (let round = 0; round < 7; round++) {
    rnd(state, words);
    // The message schedule is permuted between rounds, not after the last one.
    if (round < 6) {
      const current = words;
      words = MSG.map((from) => current[from]);
    }
  }

  const out = new Array(16);
  for (let i = 0; i < 8; i++) {
    out[i] = (state[i] ^ state[i + 8]) >>> 0;
    out[i + 8] = (state[i + 8] ^ cv[i]) >>> 0;
  }
  return out;
}
/**
 * Pack `l` bytes of `b`, starting at offset `o`, into 16 little-endian
 * 32-bit words. Words not reached by the input stay zero.
 * @param {ArrayLike<number>} b - source bytes
 * @param {number} o - offset of the first byte to read
 * @param {number} l - number of bytes to read
 * @returns {number[]} 16 unsigned 32-bit words
 */
function wd(b, o, l) {
  const words = new Array(16).fill(0);
  for (let k = 0; k < l; k++) {
    const slot = k >> 2;
    const shift = (k & 3) * 8;
    words[slot] |= b[o + k] << shift;
  }
  // `|=` leaves signed 32-bit values; normalise each word to unsigned.
  for (let slot = 0; slot < 16; slot++) {
    words[slot] >>>= 0;
  }
  return words;
}
/**
 * Chaining value of a node: compress it with its own counter and flags and
 * keep the first eight output words.
 * @param {{cv:number[],m:number[],counter:number,blockLen:number,flags:number}} o
 * @returns {number[]} 8-word chaining value
 */
function ncv(o) {
  const { cv, m, counter, blockLen, flags } = o;
  const full = cmp(cv, m, counter, blockLen, flags);
  return full.slice(0, 8);
}
/**
 * Root bytes of a node: compress it with counter 0 and the RT flag added,
 * then serialise the first eight output words little-endian.
 * @param {{cv:number[],m:number[],blockLen:number,flags:number}} o
 * @returns {Uint8Array} 32 digest bytes
 */
function nrb(o) {
  const words = cmp(o.cv, o.m, 0, o.blockLen, o.flags | RT);
  const digest = new Uint8Array(32);
  const view = new DataView(digest.buffer);
  for (let i = 0; i < 8; i++) {
    view.setUint32(i * 4, words[i], true);
  }
  return digest;
}
/**
 * Build the node for one chunk: compress every block except the last, and
 * return the last block un-compressed so the caller can choose how to finish it.
 * @param {ArrayLike<number>} b - input bytes
 * @param {number} s - offset of the chunk within b
 * @param {number} l - chunk length in bytes (callers pass at most CH)
 * @param {number} ct - chunk counter
 * @returns {{cv:number[],m:number[],counter:number,blockLen:number,flags:number}}
 */
function cn(b, s, l, ct) {
  // An empty chunk still has one (empty) block.
  const blockCount = Math.max(1, Math.ceil(l / BL));
  const lastIndex = blockCount - 1;

  let cv = IV.slice();
  let flags = CS; // only the first block carries the chunk-start flag
  let index = 0;
  while (index < lastIndex) {
    cv = cmp(cv, wd(b, s + index * BL, BL), ct, BL, flags).slice(0, 8);
    flags = 0;
    index++;
  }

  const lastOffset = s + lastIndex * BL;
  const lastLen = l - lastIndex * BL;
  return {
    cv,
    m: wd(b, lastOffset, lastLen),
    counter: ct,
    blockLen: lastLen,
    flags: flags | CE,
  };
}
/**
 * Build the parent node of two child chaining values.
 * @param {number[]} l - left child's 8-word chaining value
 * @param {number[]} r - right child's 8-word chaining value
 * @returns {{cv:number[],m:number[],counter:number,blockLen:number,flags:number}}
 */
function pn(l, r) {
  // The two child values, left then right, form the 16-word message.
  const message = [...l, ...r];
  return {
    cv: IV.slice(),
    m: message,
    counter: 0,
    blockLen: BL,
    flags: PA,
  };
}
/**
 * Recursively build the tree node covering `l` bytes of `b` starting at `s`.
 * @param {ArrayLike<number>} b - input bytes
 * @param {number} s - start offset
 * @param {number} l - number of bytes covered
 * @param {number} ct - counter of the first chunk in this range
 * @returns {{cv:number[],m:number[],counter:number,blockLen:number,flags:number}}
 */
function st(b, s, l, ct) {
  if (l <= CH) {
    return cn(b, s, l, ct);
  }
  // Left subtree: the largest power-of-two number of chunks strictly shorter than l.
  let leftLen = CH;
  for (; leftLen * 2 < l; leftLen *= 2);

  const leftCv = ncv(st(b, s, leftLen, ct));
  const rightCv = ncv(st(b, s + leftLen, l - leftLen, ct + leftLen / CH));
  return pn(leftCv, rightCv);
}
/**
 * Hash `b` with the CPU oracle and return the 32-byte digest as lowercase hex.
 * @param {ArrayLike<number>} b - input bytes
 * @returns {string} 64 hex characters
 */
function b3hex(b) {
  const digest = nrb(st(b, 0, b.length, 0));
  return Array.from(digest, (byte) => byte.toString(16).padStart(2, '0')).join('');
}
// ── parallel WGSL: reuses the GPU-proven compress/chunkCV/chunkRoot/parentCV; adds two kernels ──
// Shader source handed to createShaderModule. The string is left untouched; notes for readers:
//   bindings (group 0): 0 = inp, input as u32 words (read-only); 1 = cvs, 8 u32 per chunk,
//     written by phase1 and read by phase2; 2 = outp, the 8 digest words; 3 = P, uniform with x=len, y=N.
//   loadBlock always reads 16 whole words, so the host must zero-pad the input to a 64-byte multiple.
//   chunkCV/chunkRoot/parentCV pass 0u as the high counter word; only the low 32 bits of the chunk index are used.
//   phase1: workgroup size 64, one invocation per chunk; invocations with index >= N exit early.
//     When N == 1 the single chunk is finished with chunkRoot (flags 2u|8u) instead of chunkCV.
//   phase2: a single invocation merges the per-chunk values with a stack of up to 54 entries,
//     sets the root flag (fl=12u) only on the final merge, and writes 8 words to outp.
const WGSL=`
@group(0) @binding(0) var<storage, read> inp : array<u32>;
@group(0) @binding(1) var<storage, read_write> cvs : array<u32>; // phase1 out / phase2 in: 8 u32 per chunk
@group(0) @binding(2) var<storage, read_write> outp : array<u32>;
@group(0) @binding(3) var<uniform> P : vec4<u32>; // x=len y=N
const IV = array<u32,8>(0x6a09e667u,0xbb67ae85u,0x3c6ef372u,0xa54ff53au,0x510e527fu,0x9b05688cu,0x1f83d9abu,0x5be0cd19u);
fn rotr(x:u32,n:u32)->u32 { return (x>>n)|(x<<(32u-n)); }
fn compress(cv:array<u32,8>, m0:array<u32,16>, cl:u32, ch:u32, bl:u32, fl:u32)->array<u32,16>{
var v:array<u32,16>;
v[0]=cv[0];v[1]=cv[1];v[2]=cv[2];v[3]=cv[3];v[4]=cv[4];v[5]=cv[5];v[6]=cv[6];v[7]=cv[7];
v[8]=IV[0];v[9]=IV[1];v[10]=IV[2];v[11]=IV[3];v[12]=cl;v[13]=ch;v[14]=bl;v[15]=fl;
var m:array<u32,16>=m0;
for(var r:u32=0u;r<7u;r=r+1u){
v[0]=v[0]+v[4]+m[0];v[12]=rotr(v[12]^v[0],16u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],12u);v[0]=v[0]+v[4]+m[1];v[12]=rotr(v[12]^v[0],8u);v[8]=v[8]+v[12];v[4]=rotr(v[4]^v[8],7u);
v[1]=v[1]+v[5]+m[2];v[13]=rotr(v[13]^v[1],16u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],12u);v[1]=v[1]+v[5]+m[3];v[13]=rotr(v[13]^v[1],8u);v[9]=v[9]+v[13];v[5]=rotr(v[5]^v[9],7u);
v[2]=v[2]+v[6]+m[4];v[14]=rotr(v[14]^v[2],16u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],12u);v[2]=v[2]+v[6]+m[5];v[14]=rotr(v[14]^v[2],8u);v[10]=v[10]+v[14];v[6]=rotr(v[6]^v[10],7u);
v[3]=v[3]+v[7]+m[6];v[15]=rotr(v[15]^v[3],16u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],12u);v[3]=v[3]+v[7]+m[7];v[15]=rotr(v[15]^v[3],8u);v[11]=v[11]+v[15];v[7]=rotr(v[7]^v[11],7u);
v[0]=v[0]+v[5]+m[8];v[15]=rotr(v[15]^v[0],16u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],12u);v[0]=v[0]+v[5]+m[9];v[15]=rotr(v[15]^v[0],8u);v[10]=v[10]+v[15];v[5]=rotr(v[5]^v[10],7u);
v[1]=v[1]+v[6]+m[10];v[12]=rotr(v[12]^v[1],16u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],12u);v[1]=v[1]+v[6]+m[11];v[12]=rotr(v[12]^v[1],8u);v[11]=v[11]+v[12];v[6]=rotr(v[6]^v[11],7u);
v[2]=v[2]+v[7]+m[12];v[13]=rotr(v[13]^v[2],16u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],12u);v[2]=v[2]+v[7]+m[13];v[13]=rotr(v[13]^v[2],8u);v[8]=v[8]+v[13];v[7]=rotr(v[7]^v[8],7u);
v[3]=v[3]+v[4]+m[14];v[14]=rotr(v[14]^v[3],16u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],12u);v[3]=v[3]+v[4]+m[15];v[14]=rotr(v[14]^v[3],8u);v[9]=v[9]+v[14];v[4]=rotr(v[4]^v[9],7u);
if(r<6u){var perm=array<u32,16>(2u,6u,3u,10u,7u,0u,4u,13u,1u,11u,12u,5u,9u,14u,15u,8u);var p:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){p[i]=m[perm[i]];}m=p;}
}
var o:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){o[i]=v[i]^v[i+8u];o[i+8u]=v[i+8u]^cv[i];}return o;
}
fn loadBlock(bo:u32)->array<u32,16>{let base=bo>>2u;var m:array<u32,16>;for(var i:u32=0u;i<16u;i=i+1u){m[i]=inp[base+i];}return m;}
fn chunkCV(s:u32,l:u32,ct:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),ct,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o2=compress(cv,loadBlock(lo),ct,0u,ll,fl|2u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o2[j];}return out;}
fn chunkRoot(s:u32,l:u32)->array<u32,8>{var cv:array<u32,8>=IV;let nB=max(1u,(l+63u)/64u);var fl:u32=1u;for(var i:u32=0u;i<nB-1u;i=i+1u){let o=compress(cv,loadBlock(s+i*64u),0u,0u,64u,fl);for(var j:u32=0u;j<8u;j=j+1u){cv[j]=o[j];}fl=0u;}let lo=s+(nB-1u)*64u;let ll=l-(nB-1u)*64u;let o=compress(cv,loadBlock(lo),0u,0u,ll,fl|2u|8u);var out:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){out[j]=o[j];}return out;}
fn parentCV(l:array<u32,8>,r:array<u32,8>,root:bool)->array<u32,8>{var m:array<u32,16>;for(var i:u32=0u;i<8u;i=i+1u){m[i]=l[i];m[i+8u]=r[i];}var fl:u32=4u;if(root){fl=12u;}let o=compress(IV,m,0u,0u,64u,fl);var out:array<u32,8>;for(var i:u32=0u;i<8u;i=i+1u){out[i]=o[i];}return out;}
fn getCV(i:u32)->array<u32,8>{var c:array<u32,8>;for(var j:u32=0u;j<8u;j=j+1u){c[j]=cvs[i*8u+j];}return c;}
// PHASE 1 — one invocation per chunk (parallel)
@compute @workgroup_size(64)
fn phase1(@builtin(global_invocation_id) gid:vec3<u32>){
let i=gid.x; let N=P.y; let len=P.x;
if(i>=N){return;}
let off=i*1024u; let ln=min(1024u,len-off);
var cv:array<u32,8>;
if(N==1u){cv=chunkRoot(off,ln);}else{cv=chunkCV(off,ln,i);}
for(var j:u32=0u;j<8u;j=j+1u){cvs[i*8u+j]=cv[j];}
}
// PHASE 2 — single invocation: stack tree-merge over the CVs → root
@compute @workgroup_size(1)
fn phase2(){
let N=P.y;
var root:array<u32,8>;
if(N==1u){root=getCV(0u);}
else{
var stack:array<array<u32,8>,54>; var depth:u32=0u;
for(var c:u32=0u;c<N-1u;c=c+1u){
var cv=getCV(c); var t=c+1u;
loop{ if((t&1u)!=0u){break;} depth=depth-1u; cv=parentCV(stack[depth],cv,false); t=t>>1u; }
stack[depth]=cv; depth=depth+1u;
}
var node=getCV(N-1u);
for(var i:u32=0u;i<depth;i=i+1u){ let li=depth-1u-i; root=parentCV(stack[li],node,i==depth-1u); node=root; }
}
for(var i:u32=0u;i<8u;i=i+1u){outp[i]=root[i];}
}
`;
// Shorthand for document.querySelector, plus handles to the four page placeholders.
const $ = (selector) => document.querySelector(selector);
const st_ = $("#status"); // progress / correctness summary
const sp = $("#speed"); // throughput line
const tb = $("#tbl tbody"); // one row per tested size
const res = $("#result"); // final verdict
/**
 * Format eight 32-bit words as hex, each word written least-significant byte first.
 * @param {ArrayLike<number>} w - at least eight words; a missing entry is treated as 0
 * @returns {string} 64 lowercase hex characters
 */
const hexLE = (w) => {
  const pieces = [];
  for (let index = 0; index < 8; index++) {
    const word = w[index] >>> 0;
    for (let shift = 0; shift < 32; shift += 8) {
      pieces.push(((word >>> shift) & 255).toString(16).padStart(2, '0'));
    }
  }
  return pieces.join('');
};
(async()=>{
if(!navigator.gpu){st_.textContent="✗ No WebGPU.";st_.classList.add("no");return;}
const ad=await navigator.gpu.requestAdapter(),dev=await ad.requestDevice();
const mod=dev.createShaderModule({code:WGSL});const ci=await mod.getCompilationInfo();const er=ci.messages.filter(m=>m.type==="error");
if(er.length){st_.textContent="✗ WGSL: "+er[0].message+" (line "+er[0].lineNum+")";st_.classList.add("no");return;}
// explicit shared layout so ONE bind group works for both phases (auto-layout would infer different layouts)
const bgl=dev.createBindGroupLayout({entries:[
{binding:0,visibility:GPUShaderStage.COMPUTE,buffer:{type:"read-only-storage"}},
{binding:1,visibility:GPUShaderStage.COMPUTE,buffer:{type:"storage"}},
{binding:2,visibility:GPUShaderStage.COMPUTE,buffer:{type:"storage"}},
{binding:3,visibility:GPUShaderStage.COMPUTE,buffer:{type:"uniform"}}]});
const pll=dev.createPipelineLayout({bindGroupLayouts:[bgl]});
const p1=dev.createComputePipeline({layout:pll,compute:{module:mod,entryPoint:"phase1"}});
const p2=dev.createComputePipeline({layout:pll,compute:{module:mod,entryPoint:"phase2"}});
/**
 * Hash `bytes` on the GPU: phase1 computes one chaining value per 1024-byte chunk,
 * phase2 merges them, and the 8 result words are read back through a staging buffer.
 * @param {Uint8Array} bytes - input to hash
 * @returns {Promise<{hex:string, ms:number}>} digest as hex, and wall time from command
 *   encoding to read-back (the input upload is queued before the timer starts)
 */
async function gpu(bytes){
  const len=bytes.length;
  // An empty input is still hashed as one (empty) chunk.
  const N=len===0?1:Math.ceil(len/1024);
  // The shader loads whole 16-word blocks, so zero-pad to a 64-byte multiple (at least one block).
  const padU=Math.max(16,(Math.ceil(len/64)||1)*16);
  const u=new Uint32Array(padU);
  new Uint8Array(u.buffer).set(bytes);

  // Every buffer is tracked so it is released even when a later step throws.
  const owned=[];
  const alloc=(size,usage)=>{const buf=dev.createBuffer({size,usage});owned.push(buf);return buf;};
  try{
    const inB=alloc(u.byteLength,GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST);
    dev.queue.writeBuffer(inB,0,u);
    const cvB=alloc(Math.max(32,N*32),GPUBufferUsage.STORAGE); // 8 u32 per chunk
    const oB=alloc(32,GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC); // 8 digest words
    const pB=alloc(16,GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST); // P = (len, N, 0, 0)
    dev.queue.writeBuffer(pB,0,new Uint32Array([len,N,0,0]));
    const sg=alloc(32,GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ); // CPU-readable staging copy
    const bg=dev.createBindGroup({layout:bgl,entries:[
      {binding:0,resource:{buffer:inB}},
      {binding:1,resource:{buffer:cvB}},
      {binding:2,resource:{buffer:oB}},
      {binding:3,resource:{buffer:pB}}]});

    const t0=performance.now();
    const e=dev.createCommandEncoder();
    // phase1: 64 chunks per workgroup; phase2: a single workgroup.
    for(const [pipeline,groups] of [[p1,Math.ceil(N/64)],[p2,1]]){
      const pass=e.beginComputePass();
      pass.setPipeline(pipeline);
      pass.setBindGroup(0,bg);
      pass.dispatchWorkgroups(groups);
      pass.end();
    }
    e.copyBufferToBuffer(oB,0,sg,0,32);
    dev.queue.submit([e.finish()]);
    await sg.mapAsync(GPUMapMode.READ);
    // Copy out of the mapped range before unmapping; the mapped view is not usable afterwards.
    const w=new Uint32Array(sg.getMappedRange().slice(0));
    sg.unmap();
    const ms=performance.now()-t0;
    return {hex:hexLE(w),ms};
  }finally{
    for(const buf of owned){buf.destroy();}
  }
}
// correctness: hash a ramp pattern (byte i = i & 255) of each size on the GPU and compare with the CPU oracle
st_.textContent="checking correctness…";
let pass=0;
const sizes=[0,1,64,1023,1024,1025,2048,4096,100000,1048576];
for(const n of sizes){
  const b=new Uint8Array(n);
  for(let i=0;i<n;i++)b[i]=i&255;
  const {hex}=await gpu(b);
  const ok=hex===b3hex(b);
  if(ok)pass++;
  const tr=document.createElement("tr");
  tr.innerHTML=`<td class="n">${n.toLocaleString()}</td><td class="${ok?'ok':'no'}">${ok?'✓ match':'✗ MISMATCH'}</td>`;
  tb.appendChild(tr);
}
st_.textContent=`correctness: ${pass}/${sizes.length}`;
st_.classList.add(pass===sizes.length?"ok":"no");
// throughput — 138 MB (the embed block size), clamped to what this device can bind as one
// storage buffer; a larger input would fail bind-group validation and time nothing useful.
const WANT=138*1024*1024;
const bindCap=Math.min(dev.limits.maxStorageBufferBindingSize,dev.limits.maxBufferSize??Infinity);
const BIG=Math.min(WANT,Math.floor(bindCap/CH)*CH);
const mb=(BIG/1048576).toFixed(0);
const what=BIG===WANT?"embed-sized block":"adapter-limited block";
sp.textContent=`timing ${mb} MB (${what})…`;
const big=new Uint8Array(BIG);
for(let i=0;i<BIG;i++)big[i]=i&255;
// Best of three runs.
let best=Infinity;
for(let k=0;k<3;k++){const {ms}=await gpu(big);best=Math.min(best,ms);}
const gbps=(BIG/1073741824)/(best/1000);
// NOTE(review): the digest of the big block is not compared against the oracle here;
// correctness is established only by the table above.
sp.innerHTML=`<b>${mb} MB verified on GPU in ${best.toFixed(0)} ms → ${gbps.toFixed(2)} GB/s</b>`;
// classList.add("") throws a SyntaxError, so only add the class when there is one to add.
if(gbps>1){sp.classList.add("ok");}
res.innerHTML=(pass===sizes.length)?`<div class="big ok">✓ Parallel BLAKE3 on your GPU — byte-exact, and ${gbps.toFixed(1)} GB/s on an ${what}.</div><div class="hx">Adapter: ${ad.info?ad.info.vendor+" "+ad.info.architecture:"?"}</div>`:`<div class="big no">✗ correctness failed — see table</div>`;
})();
</script></body></html>