Spaces:
Running
Running
| ; | |
// ---------------------------------------------------------------------------
// Module-level state shared by the helpers below.
// ---------------------------------------------------------------------------

// Resolved URL string -> promise of the fetched ArrayBuffer (see fetchArrayBuffer).
const arrayBufferFetchCache = new Map();
// Not read in this part of the file; presumably bundle id -> bundle (verify at use site).
const dynamicI8BundleCache = new Map();

// Per-device caches keyed by the GPUDevice object, so entries become collectable
// together with the device. The visible helpers store a Map of string keys per device.
const immutableGpuBufferCache = new WeakMap();
const scratchGpuBufferCache = new WeakMap();
const computePipelineCache = new WeakMap();
const stageObjectCache = new WeakMap();
const textProjectionGpuCache = new WeakMap();

// Lazily opened IndexedDB connection: a promise of an IDBDatabase or null.
let browserCacheDbPromise = null;

// Shared WebGPU device (managed by requestCustomWebGpuDevice and
// resetCustomLowbitWebGpuDevice) and the descriptor it was created with.
let customWebGpuDevicePromise = null;
let customWebGpuDevice = null;
let customWebGpuDeviceDescriptorKey = "";
/**
 * Forget the shared WebGPU device state (promise, device, descriptor key) and
 * destroy the device if one had been created. Destruction is best-effort: any
 * exception thrown by destroy() is ignored.
 */
function resetCustomLowbitWebGpuDevice() {
  const previous = customWebGpuDevice;
  customWebGpuDeviceDescriptorKey = "";
  customWebGpuDevice = null;
  customWebGpuDevicePromise = null;
  if (!previous || typeof previous.destroy !== "function") {
    return;
  }
  try {
    previous.destroy();
  } catch (_) {
    // Best-effort cleanup after large benchmark sweeps.
  }
}
/**
 * Median of an iterable of numbers. For an even count it returns the mean of the
 * two middle values. Works on a sorted copy, so the input is left untouched.
 * Empty input yields NaN.
 */
function median(values) {
  const ordered = [...values];
  ordered.sort((left, right) => left - right);
  const middle = Math.floor(ordered.length / 2);
  if (ordered.length % 2 !== 0) {
    return ordered[middle];
  }
  return (ordered[middle - 1] + ordered[middle]) / 2;
}
/**
 * Round to the nearest integer, with ties rounded away from zero
 * (2.5 -> 3, -2.5 -> -3).
 *
 * Math.round is applied to the magnitude instead of computing floor(value + 0.5).
 * That addition can itself round in binary floating point:
 * - 0.49999999999999994 + 0.5 === 1;
 * - odd integers above 2^52 become the next even integer.
 * Math.round is exact.
 *
 * @param {number} value
 * @returns {number} NaN for NaN input.
 */
function roundToNearest(value) {
  const magnitude = Math.round(Math.abs(value));
  return value < 0 ? -magnitude : magnitude;
}
// Scratch views reused by float32ToFloat16Bits, so bulk conversions such as
// buildInputF16 do not allocate two typed arrays per element.
const F32_TO_F16_SCRATCH_F32 = new Float32Array(1);
const F32_TO_F16_SCRATCH_U32 = new Uint32Array(F32_TO_F16_SCRATCH_F32.buffer);

/**
 * Encode a number as IEEE-754 binary16 bits (an integer in 0..0xffff).
 *
 * The value is first rounded to float32. The mantissa is then rounded by adding half
 * an f16 ulp before truncating, so ties round up in magnitude.
 * - Overflow produces a signed infinity (0x7c00 | sign).
 * - NaN produces the quiet NaN 0x7e00 | sign.
 * - Magnitudes below 2^-25 flush to a signed zero.
 *
 * @param {number} value
 * @returns {number} the binary16 bit pattern.
 */
function float32ToFloat16Bits(value) {
  F32_TO_F16_SCRATCH_F32[0] = value;
  const x = F32_TO_F16_SCRATCH_U32[0];
  const sign = (x >>> 16) & 0x8000;
  let mantissa = x & 0x7fffff;
  let exponent = (x >>> 23) & 0xff;
  if (exponent === 0xff) {
    if (mantissa !== 0) return sign | 0x7e00;
    return sign | 0x7c00;
  }
  exponent = exponent - 127 + 15;
  if (exponent >= 0x1f) return sign | 0x7c00;
  if (exponent <= 0) {
    if (exponent < -10) return sign;
    mantissa = (mantissa | 0x800000) >>> (1 - exponent);
    return sign | ((mantissa + 0x1000) >>> 13);
  }
  // Add (not OR) the rounded mantissa: rounding can carry into the exponent field
  // (e.g. 1.9999 -> 2.0). With "|" that carry is lost whenever the exponent is odd.
  // A carry out of exponent 30 correctly yields infinity (0x7c00).
  return sign | ((exponent << 10) + ((mantissa + 0x1000) >>> 13));
}
/**
 * Decode IEEE-754 binary16 bits into a JS number.
 * - Subnormals (exponent 0) decode to 2^-14 * fraction.
 * - Exponent 31 decodes to ±Infinity, or NaN when the fraction is non-zero.
 * - 0x8000 decodes to -0.
 *
 * @param {number} bits binary16 pattern in the low 16 bits.
 * @returns {number}
 */
function float16BitsToFloat32(bits) {
  const negative = (bits & 0x8000) !== 0;
  const biasedExponent = (bits >>> 10) & 0x1f;
  const fraction = bits & 0x03ff;
  let magnitude;
  if (biasedExponent === 31) {
    if (fraction !== 0) return NaN;
    magnitude = Infinity;
  } else if (biasedExponent === 0) {
    magnitude = Math.pow(2, -14) * (fraction / 1024);
  } else {
    magnitude = Math.pow(2, biasedExponent - 15) * (1 + fraction / 1024);
  }
  return (negative ? -1 : 1) * magnitude;
}
/**
 * Deterministic synthetic activation for position (row, kIndex).
 * For non-negative indices the result is (w + f) / 16 rounded to float32, where:
 * - w is an integer in [-125, 125];
 * - f is a multiple of 1/32 in [-8/32, 8/32].
 */
function deterministicInput(row, kIndex) {
  const whole = ((17 + row * 37 + kIndex * 13) % 251) - 125;
  const offset = (((5 + row * 17 + kIndex * 29) % 17) - 8) / 32;
  return Math.fround((whole + offset) / 16);
}
/**
 * Row-major n×k matrix of deterministicInput values, stored as binary16 bit
 * patterns in a Uint16Array (element [row, col] at row * k + col).
 */
function buildInputF16(n, k) {
  const encoded = new Uint16Array(n * k);
  let flatIndex = 0;
  for (let r = 0; r < n; r += 1) {
    for (let c = 0; c < k; c += 1) {
      encoded[flatIndex] = float32ToFloat16Bits(deterministicInput(r, c));
      flatIndex += 1;
    }
  }
  return encoded;
}
/**
 * Row-major n×k Float32Array of deterministicInput values
 * (element [row, col] at row * k + col).
 */
function buildInputF32(n, k) {
  const matrix = new Float32Array(n * k);
  for (let r = 0; r < n; r += 1) {
    const rowStart = r * k;
    for (let c = 0; c < k; c += 1) {
      matrix[rowStart + c] = deterministicInput(r, c);
    }
  }
  return matrix;
}
| function buildFluxRopeFrequencies(axisDim = 32, theta = 2000) { | |
| const count = axisDim / 2; | |
| const values = new Float32Array(count); | |
| for (let i = 0; i < count; ++i) { | |
| values[i] = 1 / Math.pow(theta, (i * 2) / axisDim); | |
| } | |
| return values; | |
| } | |
/**
 * Build interleaved (cos, sin) rotary tables for a sequence of `textTokens` text
 * rows followed by image rows laid out row-major with width `imageWidth`.
 *
 * Layout: Float32Array of length n * pairs * 2.
 * - pairs = 4 * (axisDim / 2); this is 64 with the default axisDim = 32.
 * - The cos for (row, pair) is at (row * pairs + pair) * 2; the sin immediately follows.
 *
 * Pairs form four consecutive axes of axisDim / 2 pairs each. Position per axis:
 *   axis 0: always 0
 *   axis 1: image y (imageRow / width, floored); 0 for text rows
 *   axis 2: image x (imageRow - y * width);      0 for text rows
 *   axis 3: text row index;                      0 for image rows
 *
 * @param {number} n total rows (text + image).
 * @param {number} textTokens number of leading text rows.
 * @param {number} imageWidth image width in tokens. Non-positive, missing or
 *   non-numeric values are treated as 1.
 * @param {number} axisDim per-axis rotary dimension (must be even).
 * @param {number} theta rotary base.
 * @returns {Float32Array}
 */
function buildFluxRopeSinCos(n, textTokens, imageWidth, axisDim = 32, theta = 2000) {
  const AXIS_COUNT = 4;
  const freqs = buildFluxRopeFrequencies(axisDim, theta);
  // Derive the per-axis pair count from the frequency table rather than hard-coding
  // 16, so non-default axisDim values never index past `freqs` (which produced NaN).
  const pairsPerAxis = freqs.length;
  const pairs = AXIS_COUNT * pairsPerAxis;
  const values = new Float32Array(n * pairs * 2);
  // `Number(x) || 1` also maps NaN (e.g. a non-numeric string) to 1.
  const safeWidth = Math.max(1, Number(imageWidth) || 1);
  for (let row = 0; row < n; ++row) {
    const isText = row < textTokens;
    const imageRow = row - textTokens;
    const y = isText ? 0 : Math.floor(imageRow / safeWidth);
    const x = isText ? 0 : imageRow - y * safeWidth;
    const axisPositions = isText ? [0, 0, 0, row] : [0, y, x, 0];
    for (let axis = 0; axis < AXIS_COUNT; ++axis) {
      const position = axisPositions[axis];
      for (let freqIndex = 0; freqIndex < pairsPerAxis; ++freqIndex) {
        const angle = position * freqs[freqIndex];
        const base = (row * pairs + axis * pairsPerAxis + freqIndex) * 2;
        values[base] = Math.cos(angle);
        values[base + 1] = Math.sin(angle);
      }
    }
  }
  return values;
}
/** Return a promise that resolves (with undefined) after roughly `ms` milliseconds. */
function sleepMs(ms) {
  return new Promise((done) => {
    setTimeout(done, ms);
  });
}
/**
 * Fetch `url` (resolved against window.location.href) as an ArrayBuffer.
 *
 * Retries: up to 4 attempts in total, sleeping 100/200/300 ms between them. Both
 * network errors and non-OK HTTP statuses are retried.
 *
 * Caching: results are memoized per resolved URL in arrayBufferFetchCache.
 * - Concurrent callers share one in-flight request and receive the same ArrayBuffer.
 * - A failed request is evicted, so a later call starts a new fetch.
 *
 * @param {string} url
 * @returns {Promise<ArrayBuffer>}
 * @throws the last fetch/HTTP error once every attempt has failed.
 */
async function fetchArrayBuffer(url) {
  const key = new URL(url, window.location.href).toString();
  if (!arrayBufferFetchCache.has(key)) {
    arrayBufferFetchCache.set(key, (async () => {
      let lastError = null;
      for (let attempt = 0; attempt < 4; attempt += 1) {
        try {
          const response = await fetch(key);
          if (!response.ok) {
            throw new Error(`fetch failed ${response.status}: ${key}`);
          }
          return await response.arrayBuffer();
        } catch (err) {
          lastError = err;
          if (attempt < 3) await sleepMs(100 * (attempt + 1));
        }
      }
      throw lastError || new Error(`fetch failed: ${key}`);
    })());
  }
  const pending = arrayBufferFetchCache.get(key);
  try {
    return await pending;
  } catch (err) {
    // Every waiter of a failed request lands here. Evict only if the cache still
    // holds *this* promise; otherwise we would drop a retry that another caller
    // started after the first eviction.
    if (arrayBufferFetchCache.get(key) === pending) {
      arrayBufferFetchCache.delete(key);
    }
    throw err;
  }
}
/**
 * Open (once) IndexedDB database "flux2-browser-cache" at version 2. During an
 * upgrade it ensures the "text-contexts" object store (keyPath "key") and its
 * "savedAt" index exist.
 *
 * Resolves to the IDBDatabase, or to null in three cases:
 * - IndexedDB is unavailable (this null is not cached);
 * - the open fails;
 * - the open is blocked by another connection.
 *
 * Caching and reconnection:
 * - A null from an error or block stays cached.
 * - If a blocked open succeeds later, the connection is published for later callers
 *   instead of being leaked.
 * - When another context requests a schema upgrade (versionchange), the connection
 *   is closed and forgotten so it does not block that upgrade. The next call reopens.
 *
 * @returns {Promise<IDBDatabase|null>}
 */
function openBrowserCacheDb() {
  if (typeof indexedDB === "undefined") return Promise.resolve(null);
  if (browserCacheDbPromise) return browserCacheDbPromise;
  browserCacheDbPromise = new Promise((resolve) => {
    let settled = false;
    // Resolve at most once; returns false if a result was already delivered.
    const settle = (db) => {
      if (settled) return false;
      settled = true;
      resolve(db);
      return true;
    };
    const request = indexedDB.open("flux2-browser-cache", 2);
    request.onupgradeneeded = () => {
      const db = request.result;
      const store = db.objectStoreNames.contains("text-contexts")
        ? request.transaction.objectStore("text-contexts")
        : db.createObjectStore("text-contexts", {keyPath: "key"});
      if (!store.indexNames.contains("savedAt")) {
        store.createIndex("savedAt", "savedAt");
      }
    };
    request.onsuccess = () => {
      const db = request.result;
      db.onversionchange = () => {
        db.close();
        browserCacheDbPromise = null;
      };
      if (!settle(db)) {
        // The open was reported as blocked and callers already got null; make the
        // now-open connection available to later callers instead of leaking it.
        browserCacheDbPromise = Promise.resolve(db);
      }
    };
    request.onerror = () => {
      console.warn("[custom-lowbit] IndexedDB open failed", request.error);
      settle(null);
    };
    request.onblocked = () => {
      console.warn("[custom-lowbit] IndexedDB open blocked");
      settle(null);
    };
  });
  return browserCacheDbPromise;
}
/**
 * Read a cached Float32Array from the "text-contexts" store.
 *
 * The entry's `data` may be a Float32Array or an ArrayBuffer.
 *
 * @param {string} key store key; a falsy key returns null immediately.
 * @param {number} expectedValues required element count.
 * @returns {Promise<Float32Array|null>} null in any of these cases:
 *   - IndexedDB is unavailable;
 *   - the entry is missing or has an unsupported/mis-sized payload;
 *   - the read fails.
 *   Never rejects because of a bad entry.
 */
async function loadPersistentFloat32(key, expectedValues) {
  if (!key) return null;
  const db = await openBrowserCacheDb();
  if (!db) return null;
  return await new Promise((resolve) => {
    let request;
    try {
      request = db.transaction("text-contexts", "readonly").objectStore("text-contexts").get(key);
    } catch (_) {
      // e.g. the connection is closing; treat like a cache miss.
      resolve(null);
      return;
    }
    request.onsuccess = () => {
      const data = request.result ? request.result.data : undefined;
      if (data instanceof Float32Array) {
        resolve(data.length === expectedValues ? data : null);
      } else if (data instanceof ArrayBuffer) {
        // Validate the size before building a view: new Float32Array(buffer) throws
        // for a byteLength that is not a multiple of 4. A throw here (inside an event
        // handler) would leave this promise pending forever.
        const sizeMatches = data.byteLength % 4 === 0 && data.byteLength / 4 === expectedValues;
        resolve(sizeMatches ? new Float32Array(data) : null);
      } else {
        resolve(null);
      }
    };
    request.onerror = () => resolve(null);
  });
}
/**
 * Load a raw float32 blob from `url` via fetchArrayBuffer.
 *
 * @returns {Promise<Float32Array|null>} a Float32Array view over the (shared,
 *   cached) fetched buffer. Returns null in these cases:
 *   - `url` is empty;
 *   - the byte size is not expectedValues * 4;
 *   - the fetch or view construction fails (a warning is logged).
 */
async function loadStaticFloat32(url, expectedValues) {
  if (!url) return null;
  try {
    const bytes = await fetchArrayBuffer(url);
    const sizeMatches = bytes.byteLength === expectedValues * 4;
    return sizeMatches ? new Float32Array(bytes) : null;
  } catch (err) {
    console.warn("[custom-lowbit] static f32 load failed", url, err);
    return null;
  }
}
/**
 * Persist a Float32Array into the "text-contexts" store.
 *
 * The stored entry is {key, data: ArrayBuffer, bytes, savedAt}. `data` is a copy of
 * exactly the viewed bytes.
 *
 * @param {string} key store key; a falsy key is rejected.
 * @param {Float32Array} values data to store; any other type is rejected.
 * @returns {Promise<boolean>} true only once the transaction has committed. False
 *   in any of these cases:
 *   - invalid arguments;
 *   - IndexedDB is unavailable;
 *   - the write fails or the transaction aborts (e.g. QuotaExceededError).
 */
async function savePersistentFloat32(key, values) {
  if (!key || !(values instanceof Float32Array)) return false;
  const db = await openBrowserCacheDb();
  if (!db) return false;
  // Copy only the viewed range so a subarray does not persist its whole backing buffer.
  const buffer = values.buffer.slice(values.byteOffset, values.byteOffset + values.byteLength);
  return await new Promise((resolve) => {
    let tx;
    try {
      tx = db.transaction("text-contexts", "readwrite");
    } catch (_) {
      resolve(false);
      return;
    }
    // A put request can succeed while its transaction still aborts afterwards.
    // Quota errors, for example, surface as an abort. So report the transaction outcome.
    tx.oncomplete = () => resolve(true);
    tx.onabort = () => resolve(false);
    tx.onerror = () => resolve(false);
    try {
      tx.objectStore("text-contexts").put({
        key,
        data: buffer,
        bytes: values.byteLength,
        savedAt: Date.now(),
      });
    } catch (_) {
      // put() throws synchronously for invalid keys or uncloneable values.
      resolve(false);
      try {
        tx.abort();
      } catch (_) {
        // Transaction already finished.
      }
    }
  });
}
/**
 * Create a GPU buffer initialised with a copy of the bytes viewed by `data`
 * (a typed array). The size is rounded up to a multiple of 4, minimum 4 bytes.
 * The buffer is created mapped, filled and unmapped before being returned.
 */
function createBuffer(device, data, usage) {
  const alignedSize = Math.max(4, 4 * Math.ceil(data.byteLength / 4));
  const gpuBuffer = device.createBuffer({size: alignedSize, usage, mappedAtCreation: true});
  const mappedBytes = new Uint8Array(gpuBuffer.getMappedRange());
  const sourceBytes = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
  mappedBytes.set(sourceBytes);
  gpuBuffer.unmap();
  return gpuBuffer;
}
/**
 * Like createBuffer, but memoized per device under `${usage}:${cacheKey}`.
 * The first call for a key uploads `data`; later calls return the same buffer and
 * ignore their `data`. Without a cacheKey a fresh buffer is created every time.
 */
function createImmutableBuffer(device, data, usage, cacheKey = null) {
  if (!cacheKey) {
    return createBuffer(device, data, usage);
  }
  let perDevice = immutableGpuBufferCache.get(device);
  if (!perDevice) {
    perDevice = new Map();
    immutableGpuBufferCache.set(device, perDevice);
  }
  const slot = `${usage}:${cacheKey}`;
  if (!perDevice.has(slot)) {
    perDevice.set(slot, createBuffer(device, data, usage));
  }
  return perDevice.get(slot);
}
/**
 * Create an uninitialised GPU buffer of at least `size` bytes, rounded up to a
 * multiple of 4 with a minimum of 4 bytes.
 */
function createEmptyBuffer(device, size, usage) {
  const paddedSize = Math.max(4, 4 * Math.ceil(size / 4));
  return device.createBuffer({size: paddedSize, usage});
}
/**
 * Return a scratch GPU buffer that is reused across calls.
 * The reuse key is (device, usage, 4-byte-aligned size, cacheKey); contents are
 * not cleared between uses. Without a cacheKey a fresh buffer is created via
 * createEmptyBuffer.
 */
function createReusableBuffer(device, cacheKey, size, usage) {
  if (!cacheKey) {
    return createEmptyBuffer(device, size, usage);
  }
  const alignedSize = Math.max(4, 4 * Math.ceil(size / 4));
  let buffersForDevice = scratchGpuBufferCache.get(device);
  if (!buffersForDevice) {
    buffersForDevice = new Map();
    scratchGpuBufferCache.set(device, buffersForDevice);
  }
  const slot = `${usage}:${alignedSize}:${cacheKey}`;
  if (!buffersForDevice.has(slot)) {
    buffersForDevice.set(slot, device.createBuffer({size: alignedSize, usage}));
  }
  return buffersForDevice.get(slot);
}
/**
 * Compile a compute pipeline for WGSL `code` once per (device, cacheKey).
 * The pipeline uses entry point "main" and layout "auto".
 *
 * Concurrent callers share the pending compilation. A failed compilation is evicted
 * so a later call can retry, matching getCachedStageObject. Without this, the rejected
 * promise would be cached forever.
 *
 * Note: `code` is only used on a cache miss; callers must keep cacheKey unique
 * per shader source.
 *
 * @returns {Promise<GPUComputePipeline>}
 */
async function getCachedComputePipeline(device, cacheKey, code) {
  let deviceCache = computePipelineCache.get(device);
  if (!deviceCache) {
    deviceCache = new Map();
    computePipelineCache.set(device, deviceCache);
  }
  if (!deviceCache.has(cacheKey)) {
    const module = device.createShaderModule({code});
    deviceCache.set(cacheKey, device.createComputePipelineAsync({
      layout: "auto",
      compute: {module, entryPoint: "main"},
    }));
  }
  const pending = deviceCache.get(cacheKey);
  try {
    return await pending;
  } catch (err) {
    // Only evict our own failed entry, never a retry another caller already started.
    if (deviceCache.get(cacheKey) === pending) {
      deviceCache.delete(cacheKey);
    }
    throw err;
  }
}
/**
 * Memoize the result of `factory()` per (device, cacheKey).
 *
 * Behaviour:
 * - Without a cacheKey the factory runs on every call.
 * - Concurrent callers share one pending result.
 * - A synchronous throw from the factory is turned into a rejection.
 * - Failures are evicted so a later call can retry.
 *
 * @param {object} device cache owner (WeakMap key).
 * @param {string} cacheKey
 * @param {() => any} factory sync or async producer.
 * @returns {Promise<any>}
 */
async function getCachedStageObject(device, cacheKey, factory) {
  if (!cacheKey) return await factory();
  let deviceCache = stageObjectCache.get(device);
  if (!deviceCache) {
    deviceCache = new Map();
    stageObjectCache.set(device, deviceCache);
  }
  if (!deviceCache.has(cacheKey)) {
    // The async IIFE converts a synchronous throw inside factory() into a rejection.
    deviceCache.set(cacheKey, (async () => factory())());
  }
  const pending = deviceCache.get(cacheKey);
  try {
    return await pending;
  } catch (err) {
    // Every waiter of a failed entry lands here. Evict only if the cache still holds
    // that entry, so a retry started by another caller in between is kept.
    if (deviceCache.get(cacheKey) === pending) {
      deviceCache.delete(cacheKey);
    }
    throw err;
  }
}
/**
 * Request, or reuse, the shared WebGPU device for the custom low-bit kernels.
 *
 * @param {string[]} requiredFeatures device features to require. The WGSL language
 *   feature "packed_4x8_integer_dot_product" is checked against
 *   navigator.gpu.wgslLanguageFeatures instead of being passed to requestDevice.
 * @param {{maxComputeInvocationsPerWorkgroup?: number}} requiredLimitHints extra
 *   limits, requested only when the adapter supports them.
 * @returns {Promise<{adapter: GPUAdapter, device: GPUDevice}>} a freshly requested
 *   adapter and the cached device.
 *
 * Device caching:
 * - The device is reused while the effective descriptor (features + limits) is unchanged.
 * - A different descriptor resets (destroys) the shared device and requests a new one.
 * - A request superseded while still pending is returned to its caller but does not
 *   overwrite the shared state.
 *
 * @throws Error when any of the following holds:
 *   - WebGPU is unavailable;
 *   - no adapter is available;
 *   - a required feature is missing;
 *   - requestDevice rejects.
 */
async function requestCustomWebGpuDevice(requiredFeatures = ["shader-f16"], requiredLimitHints = {}) {
  // WGSL language feature (not a device feature); validated separately below.
  const dot4Feature = "packed_4x8_integer_dot_product";
  if (!navigator.gpu) {
    throw new Error("navigator.gpu is not available");
  }
  const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"});
  if (!adapter) {
    throw new Error("WebGPU adapter is not available");
  }
  const requiredDeviceFeatures = requiredFeatures.filter((feature) => feature !== dot4Feature);
  for (const feature of requiredDeviceFeatures) {
    if (!adapter.features.has(feature)) {
      throw new Error(`WebGPU adapter does not expose ${feature}`);
    }
  }
  if (requiredFeatures.includes(dot4Feature)) {
    const wgslLanguageFeatures = navigator.gpu.wgslLanguageFeatures || new Set();
    if (!wgslLanguageFeatures.has(dot4Feature)) {
      throw new Error("WGSL packed_4x8_integer_dot_product is not available");
    }
  }

  // Only request limits the adapter supports; buffer-size limits are capped at 1 GiB.
  const adapterLimits = adapter.limits;
  const requiredLimits = {};
  if (adapterLimits && adapterLimits.maxComputeWorkgroupStorageSize >= 32768) {
    requiredLimits.maxComputeWorkgroupStorageSize = 32768;
  }
  if (
    requiredLimitHints.maxComputeInvocationsPerWorkgroup &&
    adapterLimits &&
    adapterLimits.maxComputeInvocationsPerWorkgroup >= requiredLimitHints.maxComputeInvocationsPerWorkgroup
  ) {
    requiredLimits.maxComputeInvocationsPerWorkgroup = requiredLimitHints.maxComputeInvocationsPerWorkgroup;
  }
  if (adapterLimits && adapterLimits.maxStorageBufferBindingSize > 128 * 1024 * 1024) {
    requiredLimits.maxStorageBufferBindingSize = Math.min(adapterLimits.maxStorageBufferBindingSize, 1024 * 1024 * 1024);
  }
  if (adapterLimits && adapterLimits.maxBufferSize > 256 * 1024 * 1024) {
    requiredLimits.maxBufferSize = Math.min(adapterLimits.maxBufferSize, 1024 * 1024 * 1024);
  }
  const descriptor = {requiredFeatures: requiredDeviceFeatures};
  if (Object.keys(requiredLimits).length) {
    descriptor.requiredLimits = requiredLimits;
  }

  // Order-independent identity of the descriptor, used to decide whether the cached
  // device can be reused.
  const descriptorKey = JSON.stringify({
    features: [...requiredDeviceFeatures].sort(),
    limits: Object.fromEntries(Object.entries(requiredLimits).sort(([a], [b]) => a.localeCompare(b))),
  });
  if (customWebGpuDevicePromise && customWebGpuDeviceDescriptorKey !== descriptorKey) {
    resetCustomLowbitWebGpuDevice();
  }
  if (!customWebGpuDevicePromise) {
    customWebGpuDeviceDescriptorKey = descriptorKey;
    const promise = adapter.requestDevice(descriptor).then((device) => {
      // If another descriptor replaced this request while it was pending, do not
      // clobber the newer shared device. Doing so would make reset() destroy the
      // wrong device, and the newer device's `lost` handler would never clear the
      // cache. The caller of this request still receives its device.
      if (customWebGpuDevicePromise === promise) {
        customWebGpuDevice = device;
      }
      device.lost.then(() => {
        if (customWebGpuDevice === device) {
          customWebGpuDevice = null;
          customWebGpuDevicePromise = null;
          customWebGpuDeviceDescriptorKey = "";
        }
      });
      return device;
    }).catch((err) => {
      if (customWebGpuDevicePromise === promise) {
        customWebGpuDevicePromise = null;
        customWebGpuDeviceDescriptorKey = "";
      }
      throw err;
    });
    customWebGpuDevicePromise = promise;
  }
  return {adapter, device: await customWebGpuDevicePromise};
}
/**
 * Copy `count` float32 values from GPU buffer `source` (starting at
 * `sourceOffsetBytes`) back to the CPU.
 *
 * Uses a temporary MAP_READ staging buffer, which is always destroyed, including when
 * mapAsync rejects (e.g. device lost).
 *
 * @returns {Promise<Float32Array>} a CPU-owned copy of the values.
 */
async function readFloat32Buffer(device, source, count, sourceOffsetBytes = 0) {
  const size = count * 4;
  const readback = device.createBuffer({
    size,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  try {
    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(source, sourceOffsetBytes, readback, 0, size);
    device.queue.submit([encoder.finish()]);
    await readback.mapAsync(GPUMapMode.READ);
    // slice() copies out of the mapped range, which is detached by unmap().
    const values = new Float32Array(readback.getMappedRange()).slice();
    readback.unmap();
    return values;
  } finally {
    readback.destroy();
  }
}
/**
 * Summary statistics over the finite entries of `values`. NaN and ±Infinity are
 * counted in `nonFinite` and otherwise ignored.
 *
 * `std` is the population standard deviation (divides by the finite count). It is
 * computed with Welford's running update. The previous E[x²] - E[x]² form loses all
 * precision when |mean| is large relative to the spread (e.g. values around 1e8).
 *
 * @param {Iterable<number>} values array or typed array.
 * @returns {{count:number, finite:number, nonFinite:number, min:number, max:number, mean:number, std:number}}
 *   min/max/mean/std are NaN when there are no finite values.
 */
function floatArrayStats(values) {
  let finite = 0;
  let nonFinite = 0;
  let min = Infinity;
  let max = -Infinity;
  let sum = 0;
  let runningMean = 0;
  let m2 = 0;
  for (const value of values) {
    if (!Number.isFinite(value)) {
      nonFinite += 1;
      continue;
    }
    finite += 1;
    min = Math.min(min, value);
    max = Math.max(max, value);
    sum += value;
    const delta = value - runningMean;
    runningMean += delta / finite;
    m2 += delta * (value - runningMean);
  }
  if (!finite) {
    return {count: values.length, finite, nonFinite, min: NaN, max: NaN, mean: NaN, std: NaN};
  }
  const mean = sum / finite;
  const variance = Math.max(0, m2 / finite);
  return {count: values.length, finite, nonFinite, min, max, mean, std: Math.sqrt(variance)};
}
/**
 * Generate a WGSL compute shader computing y = x · dequant(W).
 *
 * Inputs and output:
 * - x is an f16 matrix (n × k, with a row offset).
 * - W is stored as 4-bit values, 8 per u32 word.
 * - dequant(w) = (nibble - zero_point) * scale, with per-group f16 scales and 4-bit
 *   zero points.
 * - Output elements go to y[row * output_stride + col_offset + col].
 *
 * Tiling:
 * - Each 16×16 workgroup computes a 32-row × `tileCols`-column tile.
 * - Each invocation owns rows (row0, row0 + 16) and `tileCols / 16` columns spaced
 *   16 apart.
 * - x and the dequantized weights are staged through workgroup memory in chunks of
 *   `kChunk` along k.
 * - Products are accumulated in f16.
 *
 * @param {number} tileCols output columns per workgroup; positive multiple of 16.
 * @param {boolean} outputF16 write y as f16 (clamped to ±65504) instead of f32.
 * @param {number} kChunk k values staged per iteration; positive multiple of 8.
 * @param {boolean} unrollK fully unroll the inner k loop in the generated source.
 * @returns {string} WGSL source. Bindings 0-5 are x, packed_w, scales, zero_points,
 *   y and params.
 * @throws {Error} on invalid tiling, or when the two workgroup tiles exceed 32 KiB.
 */
function makeQ4ZpShader32xWide(tileCols, outputF16 = false, kChunk = 64, unrollK = false) {
  const isPositiveMultiple = (value, step) => Number.isInteger(value) && value > 0 && value % step === 0;
  // Non-positive values used to pass the bare modulo checks: -16 % 16 is -0, which
  // passes `!== 0`. And kChunk = 0 would produce a k loop that never advances.
  if (!isPositiveMultiple(tileCols, 16) || !isPositiveMultiple(kChunk, 8)) {
    throw new Error(`invalid q4 tiling: tileCols=${tileCols}, kChunk=${kChunk}`);
  }
  // x_tile (kChunk × 32 rows) and w_tile (kChunk × tileCols) are both f16 (2 bytes).
  const workgroupStorageBytes = (32 * kChunk + tileCols * kChunk) * 2;
  if (workgroupStorageBytes > 32768) {
    throw new Error(`q4 tiling exceeds WebGPU workgroup storage: tileCols=${tileCols}, kChunk=${kChunk}, bytes=${workgroupStorageBytes}`);
  }
  const colsPerThread = tileCols / 16;
  const xItems = kChunk * 32;
  const wItems = kChunk * tileCols;
  const outputType = outputF16 ? "f16" : "f32";
  const storeValue = (value) => outputF16
    ? `f16(clamp(${value}, -65504.0, 65504.0))`
    : `f32(${value})`;
  const colDecls = [];
  const accDecls = [];
  const wLoads = [];
  const accUpdates = [];
  const stores = [];
  for (let c = 0; c < colsPerThread; ++c) {
    const colOffset = c * 16;
    colDecls.push(`  let col${c} = wid.x * ${tileCols}u + lid.x + ${colOffset}u;`);
    accDecls.push(`  var acc0${c} = f16(0.0);`);
    accDecls.push(`  var acc1${c} = f16(0.0);`);
    wLoads.push(`      let w${c} = w_tile[tile_k * ${tileCols}u + lid.x + ${colOffset}u];`);
    accUpdates.push(`      acc0${c} = acc0${c} + x0 * w${c};`);
    accUpdates.push(`      acc1${c} = acc1${c} + x1 * w${c};`);
    // Emit the final output addressing directly (row * output_stride + col_offset + col)
    // instead of rewriting a placeholder expression afterwards.
    for (const [row, acc] of [["row0", `acc0${c}`], ["row1", `acc1${c}`]]) {
      stores.push(`  if (${row} < params.n && col${c} < params.m) { y[${row} * params.output_stride + params.col_offset + col${c}] = ${storeValue(acc)}; }`);
    }
  }
  const computeBody = unrollK
    ? Array.from({length: kChunk}, (_, kk) => {
      const unrolledLoads = [];
      const unrolledAccUpdates = [];
      for (let c = 0; c < colsPerThread; ++c) {
        const colOffset = c * 16;
        unrolledLoads.push(`      let w${c}_${kk} = w_tile[${kk}u * ${tileCols}u + lid.x + ${colOffset}u];`);
        unrolledAccUpdates.push(`      acc0${c} = acc0${c} + x0_${kk} * w${c}_${kk};`);
        unrolledAccUpdates.push(`      acc1${c} = acc1${c} + x1_${kk} * w${c}_${kk};`);
      }
      return `      let x0_${kk} = x_tile[${kk}u * 32u + lid.y];
      let x1_${kk} = x_tile[${kk}u * 32u + lid.y + 16u];
${unrolledLoads.join("\n")}
${unrolledAccUpdates.join("\n")}`;
    }).join("\n")
    : `    for (var tile_k = 0u; tile_k < ${kChunk}u; tile_k = tile_k + 1u) {
      let x0 = x_tile[tile_k * 32u + lid.y];
      let x1 = x_tile[tile_k * 32u + lid.y + 16u];
${wLoads.join("\n")}
${accUpdates.join("\n")}
    }`;
  const wTileStore = `      let tile_k = tile_index / ${tileCols}u;
      let tile_lane = tile_index - tile_k * ${tileCols}u;
      let src_k = base_k + tile_k;
      let src_col = wid.x * ${tileCols}u + tile_lane;
      if (src_col < params.m && src_k < params.k) {
        w_tile[tile_index] = load_q_scaled(src_col, src_k);
      } else {
        w_tile[tile_index] = f16(0.0);
      }`;
  // The w-tile fill strides by the 256 invocations of the workgroup. When wItems is
  // not a multiple of 256, the last pass indexes past w_tile. Robustness may clamp
  // such stores onto valid elements, racing with the legitimate writer, so guard them.
  // (xItems = 32 * kChunk is always a multiple of 256 because kChunk % 8 === 0.)
  const wTileFill = wItems % 256 === 0
    ? wTileStore
    : `      if (tile_index < ${wItems}u) {
${wTileStore}
      }`;
  return `
enable f16;
struct Params {
  n: u32,
  k: u32,
  m: u32,
  group_size: u32,
  groups_per_col: u32,
  packed_group_words: u32,
  row_offset: u32,
  col_offset: u32,
  weight_m: u32,
  output_stride: u32,
  _pad0: u32,
  _pad1: u32,
};
@group(0) @binding(0) var<storage, read> x: array<f16>;
@group(0) @binding(1) var<storage, read> packed_w: array<u32>;
@group(0) @binding(2) var<storage, read> scales: array<f16>;
@group(0) @binding(3) var<storage, read> zero_points: array<u32>;
@group(0) @binding(4) var<storage, read_write> y: array<${outputType}>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> x_tile: array<f16, ${xItems}>;
var<workgroup> w_tile: array<f16, ${wItems}>;
fn load_q_scaled(col: u32, k_index: u32) -> f16 {
  let group = k_index / params.group_size;
  let inner = k_index - group * params.group_size;
  let word = inner / 8u;
  let weight_col = params.col_offset + col;
  let word_index = (group * params.packed_group_words + word) * params.weight_m + weight_col;
  let shift = (inner & 7u) * 4u;
  let nibble = (packed_w[word_index] >> shift) & 15u;
  let zp = zero_points[group * params.weight_m + weight_col] & 15u;
  let q = f16(f32(i32(nibble) - i32(zp)));
  return q * scales[group * params.weight_m + weight_col];
}
@compute @workgroup_size(16, 16, 1)
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wid: vec3<u32>) {
  let row0 = wid.y * 32u + lid.y;
  let row1 = row0 + 16u;
${colDecls.join("\n")}
  let local_linear = lid.y * 16u + lid.x;
${accDecls.join("\n")}
  for (var base_k = 0u; base_k < params.k; base_k = base_k + ${kChunk}u) {
    for (var offset = 0u; offset < ${xItems}u; offset = offset + 256u) {
      let tile_index = local_linear + offset;
      let tile_k = tile_index / 32u;
      let tile_lane = tile_index - tile_k * 32u;
      let src_k = base_k + tile_k;
      let src_row = wid.y * 32u + tile_lane;
      if (src_row < params.n && src_k < params.k) {
        x_tile[tile_index] = x[(params.row_offset + src_row) * params.k + src_k];
      } else {
        x_tile[tile_index] = f16(0.0);
      }
    }
    for (var offset = 0u; offset < ${wItems}u; offset = offset + 256u) {
      let tile_index = local_linear + offset;
${wTileFill}
    }
    workgroupBarrier();
    if (row0 < params.n || row1 < params.n) {
${computeBody}
    }
    workgroupBarrier();
  }
${stores.join("\n")}
}
`;
}
/**
 * Generate a WGSL kernel (requires packed_4x8_integer_dot_product) computing
 * y = dequant(x_i8) · dequant(W_q4).
 *
 * Inputs:
 * - x is pre-packed int8 (4 values per u32) with f32 scales: per row, or per group
 *   when x_scale_groups_per_row > 1.
 * - Weights are 4-bit with per-group zero points and f16 scales. They are unpacked to
 *   signed int8 (nibble - zero_point).
 *
 * Computation:
 * - Products use dot4I8Packed and are accumulated in f32.
 * - Each 16×16 workgroup covers 32 rows × `tileCols` columns.
 * - It walks k in steps of 64 packed words (256 k values).
 * - Weights are staged `wChunkCols` columns at a time.
 * - Output elements go to y[row * output_stride + col_offset + col].
 *
 * @param {number} tileCols output columns per workgroup; positive multiple of wChunkCols.
 * @param {number} wChunkCols weight columns staged per pass; positive multiple of 16.
 * @param {boolean} outputF16 write y as f16 (clamped to ±65504) instead of f32.
 * @returns {string} WGSL source. Bindings 0-6 are x_packed, packed_w, x_scales,
 *   scales, zero_points, y and params.
 * @throws {Error} on invalid tiling, or when the workgroup tiles exceed 32 KiB.
 */
function makeQ4ZpDp4aShader32xWide(tileCols, wChunkCols = 64, outputF16 = true) {
  const isPositiveMultiple = (value, step) => Number.isInteger(value) && value > 0 && value % step === 0;
  if (!isPositiveMultiple(wChunkCols, 16) || !isPositiveMultiple(tileCols, wChunkCols)) {
    throw new Error(`invalid q4 dp4a tiling: tileCols=${tileCols}, wChunkCols=${wChunkCols}`);
  }
  // Rows per workgroup: each invocation handles rows lid.y and lid.y + 16 of a
  // 16-tall workgroup. Must agree with the x-tile fill (src_row) and row0 below.
  const rowBlock = 32;
  const colsPerThread = tileCols / 16;
  const colsPerChunk = wChunkCols / 16;
  // x_tile: rowBlock rows × 64 packed words; w_tile / scale_tile: 64 words × wChunkCols.
  const xItems = rowBlock * 64;
  const wItems = 64 * wChunkCols;
  // x_tile and w_tile hold u32 (4 bytes), scale_tile holds f16 (2 bytes).
  const workgroupStorageBytes = xItems * 4 + wItems * 4 + wItems * 2;
  if (workgroupStorageBytes > 32768) {
    throw new Error(`q4 dp4a tiling exceeds WebGPU workgroup storage: tileCols=${tileCols}, wChunkCols=${wChunkCols}, bytes=${workgroupStorageBytes}`);
  }
  const outputType = outputF16 ? "f16" : "f32";
  const storeValue = (value) => outputF16
    ? `f16(clamp(${value}, -65504.0, 65504.0))`
    : `f32(${value})`;
  const colDecls = [];
  const accDecls = [];
  const stores = [];
  for (let c = 0; c < colsPerThread; ++c) {
    const colOffset = c * 16;
    colDecls.push(`  let col${c} = wid.x * ${tileCols}u + lid.x + ${colOffset}u;`);
    accDecls.push(`  var acc0${c} = 0.0;`);
    accDecls.push(`  var acc1${c} = 0.0;`);
    stores.push(`  if (row0 < params.n && col${c} < params.m) { y[row0 * params.output_stride + params.col_offset + col${c}] = ${storeValue(`acc0${c}`)}; }`);
    stores.push(`  if (row1 < params.n && col${c} < params.m) { y[row1 * params.output_stride + params.col_offset + col${c}] = ${storeValue(`acc1${c}`)}; }`);
  }
  const wChunkBlocks = [];
  for (let chunk = 0; chunk < tileCols / wChunkCols; ++chunk) {
    const chunkColOffset = chunk * wChunkCols;
    const wLoads = [];
    const accUpdates = [];
    for (let c = 0; c < colsPerChunk; ++c) {
      const accIndex = chunk * colsPerChunk + c;
      const localColOffset = c * 16;
      wLoads.push(`        let w${accIndex} = w_tile[tile_kw * ${wChunkCols}u + lid.x + ${localColOffset}u];`);
      wLoads.push(`        let s${accIndex} = f32(scale_tile[tile_kw * ${wChunkCols}u + lid.x + ${localColOffset}u]);`);
      accUpdates.push(`        acc0${accIndex} = acc0${accIndex} + f32(dot4I8Packed(x0, w${accIndex})) * x_scale0 * s${accIndex};`);
      accUpdates.push(`        acc1${accIndex} = acc1${accIndex} + f32(dot4I8Packed(x1, w${accIndex})) * x_scale1 * s${accIndex};`);
    }
    wChunkBlocks.push(`
    for (var offset = 0u; offset < ${wItems}u; offset = offset + 256u) {
      let tile_index = local_linear + offset;
      let tile_kw = tile_index / ${wChunkCols}u;
      let tile_col = tile_index - tile_kw * ${wChunkCols}u;
      let src_kw = base_kw + tile_kw;
      let src_col = wid.x * ${tileCols}u + ${chunkColOffset}u + tile_col;
      if (src_col < params.m && src_kw < params.k_words) {
        w_tile[tile_index] = load_q4_i8_packed(src_col, src_kw);
        let group = (src_kw * 4u) / params.group_size;
        scale_tile[tile_index] = scales[group * params.weight_m + params.col_offset + src_col];
      } else {
        w_tile[tile_index] = 0u;
        scale_tile[tile_index] = f16(0.0);
      }
    }
    workgroupBarrier();
    if (row0 < params.n || row1 < params.n) {
      for (var tile_kw = 0u; tile_kw < 64u; tile_kw = tile_kw + 1u) {
        let x0 = x_tile[lid.y * 64u + tile_kw];
        let x1 = x_tile[(lid.y + 16u) * 64u + tile_kw];
${wLoads.join("\n")}
        let x_group = ((base_kw + tile_kw) * 4u) / x_group_size;
        let x_scale0 = load_x_scale(row0, x_group);
        let x_scale1 = load_x_scale(row1, x_group);
${accUpdates.join("\n")}
      }
    }
    workgroupBarrier();`);
  }
  return `
requires packed_4x8_integer_dot_product;
enable f16;
struct Params {
  n: u32,
  k: u32,
  m: u32,
  group_size: u32,
  groups_per_col: u32,
  packed_group_words: u32,
  row_offset: u32,
  col_offset: u32,
  weight_m: u32,
  output_stride: u32,
  k_words: u32,
  x_scale_groups_per_row: u32,
};
@group(0) @binding(0) var<storage, read> x_packed: array<u32>;
@group(0) @binding(1) var<storage, read> packed_w: array<u32>;
@group(0) @binding(2) var<storage, read> x_scales: array<f32>;
@group(0) @binding(3) var<storage, read> scales: array<f16>;
@group(0) @binding(4) var<storage, read> zero_points: array<u32>;
@group(0) @binding(5) var<storage, read_write> y: array<${outputType}>;
@group(0) @binding(6) var<uniform> params: Params;
var<workgroup> x_tile: array<u32, ${xItems}>;
var<workgroup> w_tile: array<u32, ${wItems}>;
var<workgroup> scale_tile: array<f16, ${wItems}>;
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
  return (u32(value) & 0xffu) << (lane * 8u);
}
fn load_q4_i8_packed(col: u32, kw: u32) -> u32 {
  let k_index = kw * 4u;
  let group = k_index / params.group_size;
  let inner = k_index - group * params.group_size;
  let q_word = inner / 8u;
  let shift_base = (inner & 7u) * 4u;
  let weight_col = params.col_offset + col;
  let packed = packed_w[(group * params.packed_group_words + q_word) * params.weight_m + weight_col];
  let zp = i32(zero_points[group * params.weight_m + weight_col] & 15u);
  var out = 0u;
  for (var lane = 0u; lane < 4u; lane = lane + 1u) {
    let nibble = i32((packed >> (shift_base + lane * 4u)) & 15u);
    out = out | pack_i8_lane(nibble - zp, lane);
  }
  return out;
}
fn load_x_scale(row: u32, group: u32) -> f32 {
  if (row >= params.n) {
    return 0.0;
  }
  let source_row = params.row_offset + row;
  if (params.x_scale_groups_per_row > 1u) {
    return x_scales[source_row * params.x_scale_groups_per_row + group];
  }
  return x_scales[source_row];
}
@compute @workgroup_size(16, 16, 1)
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wid: vec3<u32>) {
  let row0 = wid.y * ${rowBlock}u + lid.y;
  let row1 = row0 + 16u;
${colDecls.join("\n")}
  let local_linear = lid.y * 16u + lid.x;
  let x_group_size = max(1u, params.k / max(1u, params.x_scale_groups_per_row));
${accDecls.join("\n")}
  for (var base_kw = 0u; base_kw < params.k_words; base_kw = base_kw + 64u) {
    for (var offset = 0u; offset < ${xItems}u; offset = offset + 256u) {
      let tile_index = local_linear + offset;
      let tile_row = tile_index / 64u;
      let tile_kw = tile_index - tile_row * 64u;
      let src_row = wid.y * ${rowBlock}u + tile_row;
      let src_kw = base_kw + tile_kw;
      if (src_row < params.n && src_kw < params.k_words) {
        x_tile[tile_index] = x_packed[(params.row_offset + src_row) * params.k_words + src_kw];
      } else {
        x_tile[tile_index] = 0u;
      }
    }
    workgroupBarrier();
${wChunkBlocks.join("\n")}
  }
${stores.join("\n")}
}
`;
}
// WGSL kernel: per-row symmetric int8 quantization of a dense f32 matrix
// (element (row, col) at x_f32[row * k + col]).
// Launch shape: one 256-invocation workgroup per row (row = workgroup_id.x).
// Per row it:
//   1. tree-reduces max |x| over the row in workgroup memory;
//   2. writes x_scales[row] = absmax / 127 (1.0 for an all-zero row);
//   3. packs round(clamp(x * 127 / absmax, -127, 127)) as four int8 lanes per u32
//      into x_packed[row * k_words + word], zero-filling lanes at or past k.
// Params.n and Params.m are declared but not read by this kernel.
// NOTE(review): unlike QUANTIZE_X_F32_OFFSET_SHADER there is no `row >= params.n`
// guard, so this assumes exactly n workgroups are dispatched — confirm at the call site.
const QUANTIZE_X_F32_SHADER = `
struct Params {
n: u32,
k: u32,
m: u32,
k_words: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read_write> x_packed: array<u32>;
@group(0) @binding(2) var<storage, read_write> x_scales: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> row_absmax: array<f32, 256>;
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
var absmax = 0.0;
for (var k_index = local_id; k_index < params.k; k_index = k_index + 256u) {
absmax = max(absmax, abs(x_f32[row * params.k + k_index]));
}
row_absmax[local_id] = absmax;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
row_absmax[local_id] = max(row_absmax[local_id], row_absmax[local_id + stride]);
}
workgroupBarrier();
}
let row_scale = row_absmax[0];
let quant_scale = select(0.0, 127.0 / row_scale, row_scale > 0.0);
if (local_id == 0u) {
x_scales[row] = select(1.0, row_scale / 127.0, row_scale > 0.0);
}
for (var word = local_id; word < params.k_words; word = word + 256u) {
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let k_index = word * 4u + lane;
var q = 0i;
if (k_index < params.k) {
let scaled = round(clamp(x_f32[row * params.k + k_index] * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
x_packed[row * params.k_words + word] = packed_word;
}
}
`;
// WGSL kernel: the same per-row symmetric int8 quantization as QUANTIZE_X_F32_SHADER,
// but reading from a strided, row-offset source.
// - Source element (row, col) is read from
//   x_f32[(row + source_row_offset) * source_stride + col].
// - Outputs x_packed[row * k_words + word] and x_scales[row] are indexed from row 0,
//   with no offset.
// - One 256-invocation workgroup per row (row = workgroup_id.x). Workgroups with
//   row >= n return before any barrier; `row` is uniform across the workgroup.
const QUANTIZE_X_F32_OFFSET_SHADER = `
struct Params {
n: u32,
k: u32,
k_words: u32,
source_row_offset: u32,
source_stride: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read_write> x_packed: array<u32>;
@group(0) @binding(2) var<storage, read_write> x_scales: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> row_absmax: array<f32, 256>;
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
fn source_value(row: u32, col: u32) -> f32 {
return x_f32[(row + params.source_row_offset) * params.source_stride + col];
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
if (row >= params.n) {
return;
}
var absmax = 0.0;
for (var k_index = local_id; k_index < params.k; k_index = k_index + 256u) {
absmax = max(absmax, abs(source_value(row, k_index)));
}
row_absmax[local_id] = absmax;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
row_absmax[local_id] = max(row_absmax[local_id], row_absmax[local_id + stride]);
}
workgroupBarrier();
}
let row_scale = row_absmax[0];
let quant_scale = select(0.0, 127.0 / row_scale, row_scale > 0.0);
if (local_id == 0u) {
x_scales[row] = select(1.0, row_scale / 127.0, row_scale > 0.0);
}
for (var word = local_id; word < params.k_words; word = word + 256u) {
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let k_index = word * 4u + lane;
var q = 0i;
if (k_index < params.k) {
let scaled = round(clamp(source_value(row, k_index) * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
x_packed[row * params.k_words + word] = packed_word;
}
}
`;
// WGSL compute shader: per-row normalization, modulation, and int8 quantization.
// - One 256-invocation workgroup per row (row = workgroup_id.x). The shader does nothing when
//   row >= params.n or params.k > 3072, because row_values is a fixed 3072-entry array.
// - Normalization: mean and (population) variance over k, inv_std = inverseSqrt(var + 1e-6).
// - Modulation: value = normed * (1 + scale[col]) + shift[col]. Intermediates are rounded
//   through f16 and clamped to +-65504.
// - Quantization: x_scales[row] = absmax / 127 (1.0 for an all-zero row). Packed output is
//   x_packed[row * k_words + word], four signed bytes per u32 with lane 0 in the low byte.
// The barriers after the `mean` and `inv_std` reads keep invocation 0 from overwriting
// scratch[0] while other invocations are still reading it.
const SINGLE_PRENORM_MOD_QUANT_SHADER = `
enable f16;
struct Params {
n: u32,
k: u32,
m: u32,
k_words: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read> shift: array<f32>;
@group(0) @binding(2) var<storage, read> scale: array<f32>;
@group(0) @binding(3) var<storage, read_write> x_packed: array<u32>;
@group(0) @binding(4) var<storage, read_write> x_scales: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> scratch: array<f32, 256>;
var<workgroup> row_values: array<f32, 3072>;
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
if (row >= params.n || params.k > 3072u) {
return;
}
var local_sum = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
local_sum = local_sum + x_f32[row * params.k + col];
}
scratch[local_id] = local_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let mean = scratch[0] / f32(params.k);
// Every invocation must read scratch[0] before the variance pass overwrites it.
workgroupBarrier();
var local_var = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
let centered = x_f32[row * params.k + col] - mean;
local_var = local_var + centered * centered;
}
scratch[local_id] = local_var;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let inv_std = inverseSqrt(scratch[0] / f32(params.k) + 0.000001);
// Same hazard: scratch[0] is reused by the absmax reduction below.
workgroupBarrier();
var local_absmax = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
let normed = f32(f16((x_f32[row * params.k + col] - mean) * inv_std));
let value = f32(f16(clamp(normed * (1.0 + f32(f16(scale[col]))) + f32(f16(shift[col])), -65504.0, 65504.0)));
row_values[col] = value;
local_absmax = max(local_absmax, abs(value));
}
scratch[local_id] = local_absmax;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = max(scratch[local_id], scratch[local_id + stride]);
}
workgroupBarrier();
}
let row_scale = scratch[0];
let quant_scale = select(0.0, 127.0 / row_scale, row_scale > 0.0);
if (local_id == 0u) {
x_scales[row] = select(1.0, row_scale / 127.0, row_scale > 0.0);
}
for (var word = local_id; word < params.k_words; word = word + 256u) {
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let col = word * 4u + lane;
var q = 0i;
if (col < params.k) {
let scaled = round(clamp(row_values[col] * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
x_packed[row * params.k_words + word] = packed_word;
}
}
`;
// WGSL compute shader: per-row normalization and modulation followed by group-wise int8
// quantization (one scale per group_size consecutive columns).
// - One 256-invocation workgroup per row. The shader does nothing when row >= params.n or
//   params.k > 3072, because row_values is a fixed 3072-entry array.
// - Normalization and modulation match SINGLE_PRENORM_MOD_QUANT_SHADER.
// - x_scales[row * groups_per_row + g] = group absmax / 127 (1.0 for an all-zero group).
// - NOTE(review): each packed word uses the scale of the group containing its first lane,
//   so group_size appears to be assumed a multiple of 4. Confirm on the host side.
// Fixes over the previous version:
// - A barrier after the `mean` read of scratch[0], which is overwritten next.
// - A strided group loop, so rows with more than 256 groups still get every scale.
// - storageBarrier(), because the scales are written to and read back from storage memory.
const SINGLE_PRENORM_MOD_GROUP_QUANT_SHADER = `
enable f16;
struct Params {
n: u32,
k: u32,
k_words: u32,
group_size: u32,
groups_per_row: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read> shift: array<f32>;
@group(0) @binding(2) var<storage, read> scale: array<f32>;
@group(0) @binding(3) var<storage, read_write> x_packed: array<u32>;
@group(0) @binding(4) var<storage, read_write> x_scales: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> scratch: array<f32, 256>;
var<workgroup> row_values: array<f32, 3072>;
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
if (row >= params.n || params.k > 3072u) {
return;
}
var local_sum = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
local_sum = local_sum + x_f32[row * params.k + col];
}
scratch[local_id] = local_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let mean = scratch[0] / f32(params.k);
// Every invocation must read scratch[0] before the variance pass overwrites it.
workgroupBarrier();
var local_var = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
let centered = x_f32[row * params.k + col] - mean;
local_var = local_var + centered * centered;
}
scratch[local_id] = local_var;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let inv_std = inverseSqrt(scratch[0] / f32(params.k) + 0.000001);
for (var col = local_id; col < params.k; col = col + 256u) {
let normed = f32(f16((x_f32[row * params.k + col] - mean) * inv_std));
row_values[col] = f32(f16(clamp(normed * (1.0 + f32(f16(scale[col]))) + f32(f16(shift[col])), -65504.0, 65504.0)));
}
workgroupBarrier();
// Stride over groups so rows with more than 256 groups still get every scale.
for (var group_index = local_id; group_index < params.groups_per_row; group_index = group_index + 256u) {
let group_start = group_index * params.group_size;
var group_absmax = 0.0;
for (var offset = 0u; offset < params.group_size; offset = offset + 1u) {
let col = group_start + offset;
if (col < params.k) {
group_absmax = max(group_absmax, abs(row_values[col]));
}
}
x_scales[row * params.groups_per_row + group_index] = select(1.0, group_absmax / 127.0, group_absmax > 0.0);
}
// The scales are read back from storage below by other invocations; workgroupBarrier()
// only orders workgroup memory, so storageBarrier() is required here.
storageBarrier();
for (var word = local_id; word < params.k_words; word = word + 256u) {
let group = (word * 4u) / params.group_size;
let dequant_scale = x_scales[row * params.groups_per_row + group];
let quant_scale = select(0.0, 1.0 / dequant_scale, dequant_scale > 0.0);
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let col = word * 4u + lane;
var q = 0i;
if (col < params.k) {
let scaled = round(clamp(row_values[col] * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
x_packed[row * params.k_words + word] = packed_word;
}
}
`;
// WGSL compute shader: per-row normalization plus modulation, written out as f16.
// - One 256-invocation workgroup per row (row = workgroup_id.x); rows >= params.n exit early.
// - mean and (population) variance over k, inv_std = inverseSqrt(var + 1e-6).
// - y_f16[row * k + col] = f16(clamp(normed * (1 + scale[col]) + shift[col], +-65504)).
//   Unlike the quantizing variants, intermediates stay in f32 and are not rounded through f16.
// The barrier after the `mean` read keeps invocation 0 from overwriting scratch[0] while
// other invocations are still reading it.
const SINGLE_PRENORM_MOD_F16_SHADER = `
enable f16;
struct Params {
n: u32,
k: u32,
_pad0: u32,
_pad1: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read> shift: array<f32>;
@group(0) @binding(2) var<storage, read> scale: array<f32>;
@group(0) @binding(3) var<storage, read_write> y_f16: array<f16>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> scratch: array<f32, 256>;
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
if (row >= params.n) {
return;
}
var local_sum = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
local_sum = local_sum + x_f32[row * params.k + col];
}
scratch[local_id] = local_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let mean = scratch[0] / f32(params.k);
// Every invocation must read scratch[0] before the variance pass overwrites it.
workgroupBarrier();
var local_var = 0.0;
for (var col = local_id; col < params.k; col = col + 256u) {
let centered = x_f32[row * params.k + col] - mean;
local_var = local_var + centered * centered;
}
scratch[local_id] = local_var;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
scratch[local_id] = scratch[local_id] + scratch[local_id + stride];
}
workgroupBarrier();
}
let inv_std = inverseSqrt(scratch[0] / f32(params.k) + 0.000001);
for (var col = local_id; col < params.k; col = col + 256u) {
let normed = (x_f32[row * params.k + col] - mean) * inv_std;
let value = normed * (1.0 + scale[col]) + shift[col];
y_f16[row * params.k + col] = f16(clamp(value, -65504.0, 65504.0));
}
}
`;
// WGSL compute shader: gated residual add.
//   y[row, col] = f16round(clamp(x + f16round(clamp(gate[col] * delta))))
// where clamps are to +-65504 and f16round means a round trip through f16. The result is
// stored as f32. Rounding through f16 presumably matches an f16 reference path; confirm.
// Dispatch: workgroup_id.x covers columns in chunks of 256 and workgroup_id.y is the row.
// gate is indexed by column only, so one gate vector is shared by all rows.
const SINGLE_RESIDUAL_GATE_SHADER = `
enable f16;
struct Params {
n: u32,
k: u32,
_pad0: u32,
_pad1: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read> delta_f32: array<f32>;
@group(0) @binding(2) var<storage, read> gate: array<f32>;
@group(0) @binding(3) var<storage, read_write> y_f32: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let col = wid.x * 256u + lid.x;
let row = wid.y;
if (row >= params.n || col >= params.k) {
return;
}
let index = row * params.k + col;
let gated = f32(f16(clamp(gate[col] * delta_f32[index], -65504.0, 65504.0)));
y_f32[index] = f32(f16(clamp(x_f32[index] + gated, -65504.0, 65504.0)));
}
`;
// WGSL compute shader: element-wise f32 -> f16 conversion of params.count values.
// When params.apply_silu != 0, SiLU (x / (1 + exp(-x))) is applied first.
// Values are clamped to +-65504 (the f16 finite range) before conversion.
// Dispatch: one invocation per element, in 256-wide workgroups along x.
const F32_TO_F16_SHADER = `
enable f16;
struct Params {
count: u32,
apply_silu: u32,
_pad0: u32,
_pad1: u32,
};
@group(0) @binding(0) var<storage, read> x_f32: array<f32>;
@group(0) @binding(1) var<storage, read_write> y_f16: array<f16>;
@group(0) @binding(2) var<uniform> params: Params;
fn silu(value: f32) -> f32 {
return value / (1.0 + exp(-value));
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let index = wid.x * 256u + lid.x;
if (index >= params.count) {
return;
}
var value = x_f32[index];
if (params.apply_silu != 0u) {
value = silu(value);
}
y_f16[index] = f16(clamp(value, -65504.0, 65504.0));
}
`;
// WGSL compute shader: in-place latent update, latent += params[1] * f16round(pred).
// The result is clamped to +-65504; pred is clamped and rounded through f16 first.
// `params` is a storage array of f32 (not a uniform struct):
//   params[0] = element count, stored as f32 and truncated with u32(). Counts above 2^24
//               are not exactly representable in f32.
//   params[1] = step multiplier applied to pred.
// Dispatch: one invocation per element, in 256-wide workgroups along x.
const LATENT_UPDATE_F32_SHADER = `
enable f16;
@group(0) @binding(0) var<storage, read_write> latent: array<f32>;
@group(0) @binding(1) var<storage, read> pred: array<f32>;
@group(0) @binding(2) var<storage, read> params: array<f32>;
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let index = wid.x * 256u + lid.x;
let count = u32(params[0]);
if (index >= count) {
return;
}
let pred_value = f32(f16(clamp(pred[index], -65504.0, 65504.0)));
latent[index] = clamp(latent[index] + params[1] * pred_value, -65504.0, 65504.0);
}
`;
// WGSL compute shader: in-place latent update using the current and previous predictions.
//   current      = f16round(pred), previous = f16round(previous_pred)
//   extrapolated = f16round(current + params[3] * (current - previous))
//   latent      += params[1] * current + params[2] * extrapolated   (clamped to +-65504)
// `params` is a storage array of f32: params[0] = element count (truncated with u32(); exact
// only up to 2^24), params[1..3] = the coefficients above.
// NOTE(review): the name suggests an Adams-Bashforth-2 style step; confirm the coefficient
// meaning against the host code that fills params.
const LATENT_UPDATE_AB2_F32_SHADER = `
enable f16;
@group(0) @binding(0) var<storage, read_write> latent: array<f32>;
@group(0) @binding(1) var<storage, read> pred: array<f32>;
@group(0) @binding(2) var<storage, read> previous_pred: array<f32>;
@group(0) @binding(3) var<storage, read> params: array<f32>;
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let index = wid.x * 256u + lid.x;
let count = u32(params[0]);
if (index >= count) {
return;
}
let current = f32(f16(clamp(pred[index], -65504.0, 65504.0)));
let previous = f32(f16(clamp(previous_pred[index], -65504.0, 65504.0)));
let extrapolated = f32(f16(clamp(current + params[3] * (current - previous), -65504.0, 65504.0)));
latent[index] = clamp(latent[index] + params[1] * current + params[2] * extrapolated, -65504.0, 65504.0);
}
`;
// WGSL compute shader: delta_state[i] = final_state[i] - base_state[i] for i < params.count.
// The delta can later be re-applied with SINGLE_TAIL_DELTA_APPLY_SHADER.
// Dispatch: one invocation per element, in 256-wide workgroups along x.
const SINGLE_TAIL_DELTA_CACHE_SHADER = `
struct Params {
count: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
};
@group(0) @binding(0) var<storage, read> base_state: array<f32>;
@group(0) @binding(1) var<storage, read> final_state: array<f32>;
@group(0) @binding(2) var<storage, read_write> delta_state: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {
let index = wid.x * 256u + lid.x;
if (index >= params.count) {
return;
}
delta_state[index] = final_state[index] - base_state[index];
}
`;
// WGSL compute shader: in-place state[i] += delta_state[i] for i < params.count.
// It is the counterpart of SINGLE_TAIL_DELTA_CACHE_SHADER.
// Dispatch: one invocation per element, in 256-wide workgroups along x.
const SINGLE_TAIL_DELTA_APPLY_SHADER = `
struct Params {
count: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
};
@group(0) @binding(0) var<storage, read_write> state: array<f32>;
@group(0) @binding(1) var<storage, read> delta_state: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {
let index = wid.x * 256u + lid.x;
if (index >= params.count) {
return;
}
state[index] = state[index] + delta_state[index];
}
`;
/**
 * Generates a WGSL compute kernel for an int8 x int8 dot product with per-row and
 * per-column dequantization:
 *   y[row, col] = dot(x_q[row_offset + row], w_q[:, col]) * x_scales[row_offset + row] * w_scales[col]
 *
 * Layouts, as read by the generated loads (four signed int8 values per u32, consumed by
 * dot4I8Packed):
 * - x_packed[(row_offset + row) * k_words + kw]
 * - w_packed[kw * m + col]
 *
 * Tiling: the workgroup is 16 x (rowBlock / 2) invocations. Each invocation accumulates two
 * rows (row0 and row0 + rowBlock / 2) for tileCols / 16 columns spaced 16 apart. K is
 * processed in slices of 64 packed words staged in x_tile. The matching weight slice is
 * staged wChunkCols columns at a time in w_tile. The shader declares
 * `requires packed_4x8_integer_dot_product`.
 *
 * @param {number} tileCols - Output columns per workgroup; must be a multiple of wChunkCols.
 * @param {number} [wChunkCols=32] - Columns staged per w_tile load; must be a multiple of 16.
 * @param {boolean} [outputF16=false] - Emit y as f16 (clamped to +-65504) instead of f32.
 * @param {number} [rowBlock=32] - Rows per workgroup; only 32 and 64 are accepted.
 * @returns {string} WGSL source.
 * @throws {Error} On an unsupported tiling or row block.
 */
function makeI8ScaledDotShader32xWide(tileCols, wChunkCols = 32, outputF16 = false, rowBlock = 32) {
  if (tileCols % wChunkCols !== 0 || wChunkCols % 16 !== 0) {
    throw new Error(`invalid i8 dot tiling: tileCols=${tileCols}, wChunkCols=${wChunkCols}`);
  }
  if (rowBlock !== 32 && rowBlock !== 64) {
    throw new Error(`invalid i8 dot row block: ${rowBlock}`);
  }
  // Each invocation owns two rows: row0 and row0 + rowsPerHalf.
  const rowsPerHalf = rowBlock / 2;
  const invocationCount = 16 * rowsPerHalf;
  const colsPerThread = tileCols / 16;
  const colsPerChunk = wChunkCols / 16;
  // Tile sizes in u32 words: 64 packed K-words per row / per column.
  const xItems = rowBlock * 64;
  const wItems = wChunkCols * 64;
  const colDecls = [];
  const accDecls = [];
  const stores = [];
  const outputType = outputF16 ? "f16" : "f32";
  // Per owned column: its index, two i32 accumulators (one per row), and bounds-checked stores.
  for (let c = 0; c < colsPerThread; ++c) {
    const colOffset = c * 16;
    colDecls.push(` let col${c} = wid.x * ${tileCols}u + lid.x + ${colOffset}u;`);
    accDecls.push(` var acc0${c} = 0i;`);
    accDecls.push(` var acc1${c} = 0i;`);
    const value0 = `f32(acc0${c}) * x_scales[params.row_offset + row0] * w_scales[col${c}]`;
    const value1 = `f32(acc1${c}) * x_scales[params.row_offset + row1] * w_scales[col${c}]`;
    stores.push(` if (row0 < params.n && col${c} < params.m) { y[row0 * params.m + col${c}] = ${outputF16 ? `f16(clamp(${value0}, -65504.0, 65504.0))` : value0}; }`);
    stores.push(` if (row1 < params.n && col${c} < params.m) { y[row1 * params.m + col${c}] = ${outputF16 ? `f16(clamp(${value1}, -65504.0, 65504.0))` : value1}; }`);
  }
  // One staged-weight block per wChunkCols-wide chunk of the tile. Each block loads w_tile,
  // barriers, accumulates, and barriers again because the next chunk reuses w_tile.
  const wChunkBlocks = [];
  for (let chunk = 0; chunk < tileCols / wChunkCols; ++chunk) {
    const chunkColOffset = chunk * wChunkCols;
    const wLoads = [];
    const accUpdates = [];
    for (let c = 0; c < colsPerChunk; ++c) {
      const accIndex = chunk * colsPerChunk + c;
      const localColOffset = c * 16;
      wLoads.push(` let w${accIndex} = w_tile[tile_kw * ${wChunkCols}u + lid.x + ${localColOffset}u];`);
      accUpdates.push(` acc0${accIndex} = acc0${accIndex} + dot4I8Packed(x0, w${accIndex});`);
      accUpdates.push(` acc1${accIndex} = acc1${accIndex} + dot4I8Packed(x1, w${accIndex});`);
    }
    wChunkBlocks.push(`
for (var offset = 0u; offset < ${wItems}u; offset = offset + ${invocationCount}u) {
let tile_index = local_linear + offset;
let tile_kw = tile_index / ${wChunkCols}u;
let tile_col = tile_index - tile_kw * ${wChunkCols}u;
let src_kw = base_kw + tile_kw;
let src_col = wid.x * ${tileCols}u + ${chunkColOffset}u + tile_col;
if (src_col < params.m && src_kw < params.k_words) {
w_tile[tile_index] = w_packed[src_kw * params.m + src_col];
} else {
w_tile[tile_index] = 0u;
}
}
workgroupBarrier();
if (row0 < params.n || row1 < params.n) {
for (var tile_kw = 0u; tile_kw < 64u; tile_kw = tile_kw + 1u) {
let x0 = x_tile[lid.y * 64u + tile_kw];
let x1 = x_tile[(lid.y + ${rowsPerHalf}u) * 64u + tile_kw];
${wLoads.join("\n")}
${accUpdates.join("\n")}
}
}
workgroupBarrier();`);
  }
  return `
requires packed_4x8_integer_dot_product;
${outputF16 ? "enable f16;" : ""}
struct Params {
n: u32,
k: u32,
m: u32,
k_words: u32,
row_offset: u32,
};
@group(0) @binding(0) var<storage, read> x_packed: array<u32>;
@group(0) @binding(1) var<storage, read> w_packed: array<u32>;
@group(0) @binding(2) var<storage, read> x_scales: array<f32>;
@group(0) @binding(3) var<storage, read> w_scales: array<f32>;
@group(0) @binding(4) var<storage, read_write> y: array<${outputType}>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> x_tile: array<u32, ${xItems}>;
var<workgroup> w_tile: array<u32, ${wItems}>;
@compute @workgroup_size(16, ${rowsPerHalf}, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row0 = wid.y * ${rowBlock}u + lid.y;
let row1 = row0 + ${rowsPerHalf}u;
${colDecls.join("\n")}
let local_linear = lid.y * 16u + lid.x;
${accDecls.join("\n")}
for (var base_kw = 0u; base_kw < params.k_words; base_kw = base_kw + 64u) {
for (var offset = 0u; offset < ${xItems}u; offset = offset + ${invocationCount}u) {
let tile_index = local_linear + offset;
let tile_row = tile_index / 64u;
let tile_kw = tile_index - tile_row * 64u;
let src_row = wid.y * ${rowBlock}u + tile_row;
let src_kw = base_kw + tile_kw;
if (src_row < params.n && src_kw < params.k_words) {
x_tile[tile_index] = x_packed[(params.row_offset + src_row) * params.k_words + src_kw];
} else {
x_tile[tile_index] = 0u;
}
}
workgroupBarrier();
${wChunkBlocks.join("\n")}
}
${stores.join("\n")}
}
`;
}
/**
 * Generates a WGSL int8 x int8 dot-product kernel fused with a gated residual add:
 *   dot = dot(x_q[row_offset + row], w_q[:, col]) * x_scales[row_offset + row] * w_scales[col]
 *   y[row, col] = f16round(clamp(residual_x[row, col] + mod_gate[col] * dot, +-65504))
 * The result is stored as f32. Tiling and operand layouts match makeI8ScaledDotShader32xWide:
 * - x_packed[(row_offset + row) * k_words + kw]
 * - w_packed[kw * m + col]
 * The workgroup is 16 x (rowBlock / 2) invocations, and each invocation owns two rows and
 * tileCols / 16 columns.
 *
 * Fix: the x tile is now loaded starting at row `wid.y * rowBlock`, the same base the
 * accumulators and stores use. It was previously hard-coded to 32, so with rowBlock = 64
 * every workgroup past the first multiplied against the wrong activation rows.
 *
 * @param {number} tileCols - Output columns per workgroup; must be a multiple of wChunkCols.
 * @param {number} [wChunkCols=32] - Columns staged per w_tile load; must be a multiple of 16.
 * @param {number} [rowBlock=32] - Rows per workgroup; only 32 and 64 are accepted.
 * @returns {string} WGSL source.
 * @throws {Error} On an unsupported tiling or row block.
 */
function makeI8ScaledDotResidualShader32xWide(tileCols, wChunkCols = 32, rowBlock = 32) {
  if (tileCols % wChunkCols !== 0 || wChunkCols % 16 !== 0) {
    throw new Error(`invalid residual i8 dot tiling: tileCols=${tileCols}, wChunkCols=${wChunkCols}`);
  }
  if (rowBlock !== 32 && rowBlock !== 64) {
    throw new Error(`invalid residual i8 dot row block: ${rowBlock}`);
  }
  // Each invocation owns two rows: row0 and row0 + rowsPerHalf.
  const rowsPerHalf = rowBlock / 2;
  const invocationCount = 16 * rowsPerHalf;
  const colsPerThread = tileCols / 16;
  const colsPerChunk = wChunkCols / 16;
  const xItems = rowBlock * 64;
  const wItems = wChunkCols * 64;
  const colDecls = [];
  const accDecls = [];
  const stores = [];
  for (let c = 0; c < colsPerThread; ++c) {
    const colOffset = c * 16;
    colDecls.push(` let col${c} = wid.x * ${tileCols}u + lid.x + ${colOffset}u;`);
    accDecls.push(` var acc0${c} = 0i;`);
    accDecls.push(` var acc1${c} = 0i;`);
    stores.push(` if (row0 < params.n && col${c} < params.m) {
let dot = f32(acc0${c}) * x_scales[params.row_offset + row0] * w_scales[col${c}];
y[row0 * params.m + col${c}] = f32(f16(clamp(residual_x[row0 * params.m + col${c}] + mod_gate[col${c}] * dot, -65504.0, 65504.0)));
}`);
    stores.push(` if (row1 < params.n && col${c} < params.m) {
let dot = f32(acc1${c}) * x_scales[params.row_offset + row1] * w_scales[col${c}];
y[row1 * params.m + col${c}] = f32(f16(clamp(residual_x[row1 * params.m + col${c}] + mod_gate[col${c}] * dot, -65504.0, 65504.0)));
}`);
  }
  // One staged-weight block per wChunkCols-wide chunk; w_tile is reused between chunks,
  // hence the trailing barrier in each block.
  const wChunkBlocks = [];
  for (let chunk = 0; chunk < tileCols / wChunkCols; ++chunk) {
    const chunkColOffset = chunk * wChunkCols;
    const wLoads = [];
    const accUpdates = [];
    for (let c = 0; c < colsPerChunk; ++c) {
      const accIndex = chunk * colsPerChunk + c;
      const localColOffset = c * 16;
      wLoads.push(` let w${accIndex} = w_tile[tile_kw * ${wChunkCols}u + lid.x + ${localColOffset}u];`);
      accUpdates.push(` acc0${accIndex} = acc0${accIndex} + dot4I8Packed(x0, w${accIndex});`);
      accUpdates.push(` acc1${accIndex} = acc1${accIndex} + dot4I8Packed(x1, w${accIndex});`);
    }
    wChunkBlocks.push(`
for (var offset = 0u; offset < ${wItems}u; offset = offset + ${invocationCount}u) {
let tile_index = local_linear + offset;
let tile_kw = tile_index / ${wChunkCols}u;
let tile_col = tile_index - tile_kw * ${wChunkCols}u;
let src_kw = base_kw + tile_kw;
let src_col = wid.x * ${tileCols}u + ${chunkColOffset}u + tile_col;
if (src_col < params.m && src_kw < params.k_words) {
w_tile[tile_index] = w_packed[src_kw * params.m + src_col];
} else {
w_tile[tile_index] = 0u;
}
}
workgroupBarrier();
if (row0 < params.n || row1 < params.n) {
for (var tile_kw = 0u; tile_kw < 64u; tile_kw = tile_kw + 1u) {
let x0 = x_tile[lid.y * 64u + tile_kw];
let x1 = x_tile[(lid.y + ${rowsPerHalf}u) * 64u + tile_kw];
${wLoads.join("\n")}
${accUpdates.join("\n")}
}
}
workgroupBarrier();`);
  }
  return `
requires packed_4x8_integer_dot_product;
enable f16;
struct Params {
n: u32,
k: u32,
m: u32,
k_words: u32,
row_offset: u32,
};
@group(0) @binding(0) var<storage, read> x_packed: array<u32>;
@group(0) @binding(1) var<storage, read> w_packed: array<u32>;
@group(0) @binding(2) var<storage, read> x_scales: array<f32>;
@group(0) @binding(3) var<storage, read> w_scales: array<f32>;
@group(0) @binding(4) var<storage, read_write> y: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
@group(0) @binding(6) var<storage, read> residual_x: array<f32>;
@group(0) @binding(7) var<storage, read> mod_gate: array<f32>;
var<workgroup> x_tile: array<u32, ${xItems}>;
var<workgroup> w_tile: array<u32, ${wItems}>;
@compute @workgroup_size(16, ${rowsPerHalf}, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row0 = wid.y * ${rowBlock}u + lid.y;
let row1 = row0 + ${rowsPerHalf}u;
${colDecls.join("\n")}
let local_linear = lid.y * 16u + lid.x;
${accDecls.join("\n")}
for (var base_kw = 0u; base_kw < params.k_words; base_kw = base_kw + 64u) {
for (var offset = 0u; offset < ${xItems}u; offset = offset + ${invocationCount}u) {
let tile_index = local_linear + offset;
let tile_row = tile_index / 64u;
let tile_kw = tile_index - tile_row * 64u;
let src_row = wid.y * ${rowBlock}u + tile_row;
let src_kw = base_kw + tile_kw;
if (src_row < params.n && src_kw < params.k_words) {
x_tile[tile_index] = x_packed[(params.row_offset + src_row) * params.k_words + src_kw];
} else {
x_tile[tile_index] = 0u;
}
}
workgroupBarrier();
${wChunkBlocks.join("\n")}
}
${stores.join("\n")}
}
`;
}
/**
 * Generates a WGSL int8 x int8 dot-product kernel with one output row per invocation:
 *   y[row, col] = dot(x_q[row_offset + row], w_q[:, col]) * x_scales[row_offset + row] * w_scales[col]
 * The workgroup is 16 x 16 and covers 16 rows and `tileCols` columns; each invocation owns
 * tileCols / 16 columns spaced 16 apart. K is processed in slices of 64 packed words.
 * Layouts match the 32x-wide variants:
 * - x_packed[(row_offset + row) * k_words + kw]
 * - w_packed[kw * m + col]
 *
 * @param {number} tileCols - Output columns per workgroup; must be a positive multiple of 16.
 * @returns {string} WGSL source (f32 output).
 * @throws {Error} If tileCols is not a positive integer multiple of 16.
 */
function makeI8ScaledDotShader16xWide(tileCols) {
  // Each invocation owns tileCols / 16 columns, so any other width would yield a fractional
  // column count and a w_tile row stride that does not match the generated loads.
  if (!Number.isInteger(tileCols) || tileCols <= 0 || tileCols % 16 !== 0) {
    throw new Error(`invalid i8 dot tile width: tileCols=${tileCols}`);
  }
  const colsPerThread = tileCols / 16;
  const xItems = 16 * 64;
  const wItems = tileCols * 64;
  const colDecls = [];
  const accDecls = [];
  const wLoads = [];
  const accUpdates = [];
  const stores = [];
  for (let c = 0; c < colsPerThread; ++c) {
    const colOffset = c * 16;
    colDecls.push(` let col${c} = wid.x * ${tileCols}u + lid.x + ${colOffset}u;`);
    accDecls.push(` var acc${c} = 0i;`);
    wLoads.push(` let w${c} = w_tile[tile_kw * ${tileCols}u + lid.x + ${colOffset}u];`);
    accUpdates.push(` acc${c} = acc${c} + dot4I8Packed(x, w${c});`);
    stores.push(` if (row < params.n && col${c} < params.m) { y[row * params.m + col${c}] = f32(acc${c}) * x_scales[params.row_offset + row] * w_scales[col${c}]; }`);
  }
  return `
requires packed_4x8_integer_dot_product;
struct Params {
n: u32,
k: u32,
m: u32,
k_words: u32,
row_offset: u32,
};
@group(0) @binding(0) var<storage, read> x_packed: array<u32>;
@group(0) @binding(1) var<storage, read> w_packed: array<u32>;
@group(0) @binding(2) var<storage, read> x_scales: array<f32>;
@group(0) @binding(3) var<storage, read> w_scales: array<f32>;
@group(0) @binding(4) var<storage, read_write> y: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> x_tile: array<u32, ${xItems}>;
var<workgroup> w_tile: array<u32, ${wItems}>;
@compute @workgroup_size(16, 16, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.y * 16u + lid.y;
${colDecls.join("\n")}
let local_linear = lid.y * 16u + lid.x;
${accDecls.join("\n")}
for (var base_kw = 0u; base_kw < params.k_words; base_kw = base_kw + 64u) {
for (var offset = 0u; offset < ${xItems}u; offset = offset + 256u) {
let tile_index = local_linear + offset;
let tile_row = tile_index / 64u;
let tile_kw = tile_index - tile_row * 64u;
let src_row = wid.y * 16u + tile_row;
let src_kw = base_kw + tile_kw;
if (src_row < params.n && src_kw < params.k_words) {
x_tile[tile_index] = x_packed[(params.row_offset + src_row) * params.k_words + src_kw];
} else {
x_tile[tile_index] = 0u;
}
}
for (var offset = 0u; offset < ${wItems}u; offset = offset + 256u) {
let tile_index = local_linear + offset;
let tile_kw = tile_index / ${tileCols}u;
let tile_col = tile_index - tile_kw * ${tileCols}u;
let src_kw = base_kw + tile_kw;
let src_col = wid.x * ${tileCols}u + tile_col;
if (src_col < params.m && src_kw < params.k_words) {
w_tile[tile_index] = w_packed[src_kw * params.m + src_col];
} else {
w_tile[tile_index] = 0u;
}
}
workgroupBarrier();
if (row < params.n) {
for (var tile_kw = 0u; tile_kw < 64u; tile_kw = tile_kw + 1u) {
let x = x_tile[lid.y * 64u + tile_kw];
${wLoads.join("\n")}
${accUpdates.join("\n")}
}
}
workgroupBarrier();
}
${stores.join("\n")}
}
`;
}
/**
 * CPU reference for one output element of a 4-bit, group-quantized matmul:
 *   sum over kIndex of x[row, kIndex] * (nibble - zeroPoint) * scale
 * Layouts, as indexed below:
 * - xF16[row * k + kIndex] and scales[group * m + col] hold f16 bit patterns.
 * - q4[(group * packedWords + word) * m + col] packs eight nibbles per word, lowest first.
 * - zeroPoints[group * m + col] is masked to its low 4 bits.
 * @returns {number} The un-rounded float64 accumulation.
 */
function q4CpuOutput(row, col, k, m, groupSize, packedWords, xF16, q4, scales, zeroPoints) {
  const rowBase = row * k;
  let total = 0;
  for (let index = 0; index < k; index += 1) {
    const groupIndex = Math.floor(index / groupSize);
    const withinGroup = index - groupIndex * groupSize;
    const packedIndex = (groupIndex * packedWords + Math.floor(withinGroup / 8)) * m + col;
    const bitShift = (withinGroup & 7) * 4;
    const quantized = (q4[packedIndex] >>> bitShift) & 15;
    const paramIndex = groupIndex * m + col;
    const zeroPoint = zeroPoints[paramIndex] & 15;
    const groupScale = float16BitsToFloat32(scales[paramIndex]);
    const activation = float16BitsToFloat32(xF16[rowBase + index]);
    total += activation * (quantized - zeroPoint) * groupScale;
  }
  return total;
}
/**
 * Extracts byte `lane` (0 = least significant) from a packed 32-bit word and returns it
 * as a signed int8 value in [-128, 127].
 */
function signedI8Lane(packed, lane) {
  // Move the selected byte into the top 8 bits, then arithmetic-shift back down to sign-extend.
  const byteAtTop = (packed >>> (lane * 8)) << 24;
  return byteAtTop >> 24;
}
/**
 * CPU reference for per-row symmetric int8 quantization of xF32[row * k .. row * k + k).
 * Values are scaled by 127 / max|x|, rounded with roundToNearest (half away from zero), and
 * clamped to [-127, 127].
 * NOTE: WGSL round() rounds exact halves to even, so exact ties may differ from the GPU
 * quantizers by one step.
 * @returns {{q: Int8Array, scale: number}} Quantized row and its dequant scale
 *   (max|x| / 127, or 1 for an all-zero row).
 */
function quantizedInputRow(row, k, xF32) {
  const base = row * k;
  let peak = 0;
  for (let i = 0; i < k; i += 1) {
    peak = Math.max(peak, Math.abs(xF32[base + i]));
  }
  const hasSignal = peak > 0;
  const q = new Int8Array(k);
  if (hasSignal) {
    for (let i = 0; i < k; i += 1) {
      const rounded = roundToNearest(xF32[base + i] * 127 / peak);
      q[i] = Math.min(127, Math.max(-127, rounded));
    }
  }
  return {q, scale: hasSignal ? peak / 127 : 1};
}
/**
 * CPU reference for one element of the int8 dot-product kernels:
 * the row of xF32 is quantized with quantizedInputRow, dotted with column `col` of the
 * packed int8 weights (wPacked[word * m + col], four lanes per word, lane 0 lowest),
 * and dequantized by the row scale times wScales[col].
 * @param {number} kWords - Unused here; kept for call-site compatibility.
 * @returns {number} Dequantized output value.
 */
function i8CpuOutput(row, col, k, m, kWords, xF32, wPacked, wScales) {
  const {q: xQuant, scale: xScale} = quantizedInputRow(row, k, xF32);
  let integerDot = 0;
  for (let index = 0; index < k; index += 1) {
    const packedWeight = wPacked[Math.floor(index / 4) * m + col];
    integerDot += xQuant[index] * signedI8Lane(packedWeight, index % 4);
  }
  return integerDot * xScale * wScales[col];
}
/**
 * Loads an int8 (dp4a) weight bundle described by `${baseUrl}manifest.json`.
 * The manifest's `i8_dp4a.packed_u32` and `i8_dp4a.scales_f32` entries name files relative
 * to baseUrl; they are fetched in parallel through fetchArrayBuffer.
 * @param {string} baseUrl - Directory URL, expected to end with "/".
 * @returns {Promise<{manifest: object, weight: Uint32Array, scales: Float32Array}>}
 * @throws {Error} If the manifest request fails, or the manifest lacks the i8_dp4a entries.
 */
async function loadI8Bundle(baseUrl) {
  const manifestUrl = `${baseUrl}manifest.json`;
  const response = await fetch(manifestUrl);
  // Without this check a 404/500 page surfaces as an opaque JSON parse error.
  if (!response.ok) {
    throw new Error(`failed to fetch ${manifestUrl}: HTTP ${response.status} ${response.statusText}`);
  }
  const manifest = await response.json();
  const i8 = manifest?.i8_dp4a;
  if (!i8) {
    throw new Error(`${baseUrl} does not contain an i8_dp4a bundle`);
  }
  if (!i8.packed_u32 || !i8.scales_f32) {
    throw new Error(`${manifestUrl} i8_dp4a entry is missing packed_u32 or scales_f32`);
  }
  const [weightBuf, scaleBuf] = await Promise.all([
    fetchArrayBuffer(`${baseUrl}${i8.packed_u32}`),
    fetchArrayBuffer(`${baseUrl}${i8.scales_f32}`),
  ]);
  return {
    manifest,
    weight: new Uint32Array(weightBuf),
    scales: new Float32Array(scaleBuf),
  };
}
/**
 * Loads the per-block norm scale tables described by `${baseUrl}manifest.json`.
 * The manifest's `query_norm_scale_f16` and `key_norm_scale_f16` entries name files relative
 * to baseUrl; both are fetched in parallel and exposed as raw f16 bit patterns.
 * @param {string} baseUrl - Directory URL, expected to end with "/".
 * @returns {Promise<{manifest: object, queryScale: Uint16Array, keyScale: Uint16Array}>}
 * @throws {Error} If the manifest request fails or a required manifest entry is missing.
 */
async function loadSingleBlockBundle(baseUrl) {
  const manifestUrl = `${baseUrl}manifest.json`;
  const response = await fetch(manifestUrl);
  // Without this check a 404/500 page surfaces as an opaque JSON parse error.
  if (!response.ok) {
    throw new Error(`failed to fetch ${manifestUrl}: HTTP ${response.status} ${response.statusText}`);
  }
  const manifest = await response.json();
  // A missing entry would otherwise silently request `${baseUrl}undefined`.
  for (const field of ["query_norm_scale_f16", "key_norm_scale_f16"]) {
    if (!manifest?.[field]) {
      throw new Error(`${manifestUrl} is missing ${field}`);
    }
  }
  const [queryBuf, keyBuf] = await Promise.all([
    fetchArrayBuffer(`${baseUrl}${manifest.query_norm_scale_f16}`),
    fetchArrayBuffer(`${baseUrl}${manifest.key_norm_scale_f16}`),
  ]);
  return {
    manifest,
    queryScale: new Uint16Array(queryBuf),
    keyScale: new Uint16Array(keyBuf),
  };
}
// WGSL compute shader: builds linear2_in[row, col] for col < linear2_k from linear1_out.
// - Columns below attn_dim get synthetic_attention(row, col), a deterministic pattern in
//   roughly [-1.95, 1.95] derived from (row * 37 + col * 13 + 17) % 251. It presumably stands
//   in for a real attention output when benchmarking; confirm with the host code.
// - The remaining columns (mlp_col = col - attn_dim) get gate * sigmoid(gate) * up, i.e.
//   SiLU(gate) * up. The values are read from linear1_out[row, qkv_dim + mlp_col] (gate) and
//   linear1_out[row, qkv_dim + mlp_half_dim + mlp_col] (up), with linear1_m as row stride.
// Dispatch: workgroup_id.x covers columns in chunks of 256 and workgroup_id.y is the row.
const SINGLE_MLP_ACTIVATION_SHADER = `
struct Params {
n: u32,
linear1_m: u32,
linear2_k: u32,
attn_dim: u32,
qkv_dim: u32,
mlp_half_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read_write> linear2_in: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
fn sigmoid(value: f32) -> f32 {
return 1.0 / (1.0 + exp(-value));
}
fn synthetic_attention(row: u32, col: u32) -> f32 {
let code = (row * 37u + col * 13u + 17u) % 251u;
return (f32(code) - 125.0) / 64.0;
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let col = wid.x * 256u + lid.x;
let row = wid.y;
if (row >= params.n || col >= params.linear2_k) {
return;
}
var value = 0.0;
if (col < params.attn_dim) {
value = synthetic_attention(row, col);
} else {
let mlp_col = col - params.attn_dim;
let gate = linear1_out[row * params.linear1_m + params.qkv_dim + mlp_col];
let up = linear1_out[row * params.linear1_m + params.qkv_dim + params.mlp_half_dim + mlp_col];
value = gate * sigmoid(gate) * up;
}
linear2_in[row * params.linear2_k + col] = value;
}
`;
/**
 * WGSL kernel: fused per-head RMS q/k normalization plus full softmax
 * attention, one query row per workgroup.
 *
 * Layout:
 * - One workgroup of 128 invocations per (query row = wid.x, head = wid.y).
 * - Invocation `dim` owns one head dimension.
 * - q, k and v are read from `linear1_out` at column offsets 0,
 *   attention_dim and 2 * attention_dim, plus head * head_dim.
 * - Scores for every key are kept in a 4608-entry workgroup array, so the
 *   kernel writes nothing when params.n > 4608.
 *
 * Hard-coded assumptions:
 * - head_dim == 128: workgroup size, 64-stride tree reductions, and the
 *   0.08838835 (1/sqrt(128)) score scale all depend on it.
 * - Softmax weights, v, and the stored output are rounded through f16.
 *
 * Fix: a barrier now follows every all-invocation read of scratch[0] before
 * scratch is overwritten. This applies to `k_inv` and `max_score`; without it,
 * invocation 0 could clobber the value before other invocations read it.
 */
const SINGLE_ATTENTION_SHADER = `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> query_scale: array<f16>;
@group(0) @binding(2) var<storage, read> key_scale: array<f16>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> q_norm: array<f32, 128>;
var<workgroup> scores: array<f32, 4608>;
var<workgroup> scratch: array<f32, 128>;
fn q_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + head * params.head_dim + dim];
}
fn k_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim * 2u + head * params.head_dim + dim];
}
@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim = lid.x;
let row = wid.x;
let head = wid.y;
if (row >= params.n || head >= params.heads || params.n > 4608u) {
return;
}
let q = q_value(row, head, dim);
scratch[dim] = q * q;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let q_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
q_norm[dim] = q * q_inv * f32(query_scale[dim]);
workgroupBarrier();
for (var key_row = 0u; key_row < params.n; key_row = key_row + 1u) {
let k = k_value(key_row, head, dim);
scratch[dim] = k * k;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
// Every invocation reads scratch[0] here; barrier before scratch is reused.
let k_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
let k_norm = k * k_inv * f32(key_scale[dim]);
workgroupBarrier();
scratch[dim] = q_norm[dim] * k_norm;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
if (dim == 0u) {
scores[key_row] = scratch[0] * 0.08838835;
}
workgroupBarrier();
}
var local_max = -3.402823e38;
for (var key_row = dim; key_row < params.n; key_row = key_row + params.head_dim) {
local_max = max(local_max, scores[key_row]);
}
scratch[dim] = local_max;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = max(scratch[dim], scratch[dim + stride]);
}
workgroupBarrier();
}
// Every invocation reads scratch[0] here; barrier before local_sum reuses it.
let max_score = scratch[0];
workgroupBarrier();
var local_sum = 0.0;
for (var key_row = dim; key_row < params.n; key_row = key_row + params.head_dim) {
local_sum = local_sum + exp(scores[key_row] - max_score);
}
scratch[dim] = local_sum;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let inv_sum = 1.0 / scratch[0];
var out = 0.0;
for (var key_row = 0u; key_row < params.n; key_row = key_row + 1u) {
let weight = f32(f16(exp(scores[key_row] - max_score) * inv_sum));
out = out + weight * f32(f16(v_value(key_row, head, dim)));
}
attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(out, -65504.0, 65504.0)));
}
`;
/**
 * WGSL kernel: per-head RMS normalization of q and k, written out for a later
 * attention pass.
 *
 * Layout and reads:
 * - One workgroup of 128 invocations per (row = wid.x, head = wid.y).
 * - Invocation `dim` owns one head dimension.
 * - q is read at column head * head_dim of `linear1_out`; k is read
 *   attention_dim columns later.
 *
 * Each value is scaled by inverseSqrt(mean of squares + 1e-6) and by the
 * per-dim f16 `query_scale` / `key_scale`. Results go to `q_norm_out` /
 * `k_norm_out` at [row * attention_dim + head * head_dim + dim].
 *
 * The barrier after the q store keeps the k sum-of-squares write to `scratch`
 * from racing with invocations still reading scratch[0] for `q_inv`.
 * The 128-wide scratch and 64-stride reduction assume head_dim == 128.
 */
const SINGLE_QK_NORM_SHADER = `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> query_scale: array<f16>;
@group(0) @binding(2) var<storage, read> key_scale: array<f16>;
@group(0) @binding(3) var<storage, read_write> q_norm_out: array<f32>;
@group(0) @binding(4) var<storage, read_write> k_norm_out: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
var<workgroup> scratch: array<f32, 128>;
fn q_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + head * params.head_dim + dim];
}
fn k_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim + head * params.head_dim + dim];
}
@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim = lid.x;
let row = wid.x;
let head = wid.y;
if (row >= params.n || head >= params.heads) {
return;
}
let q = q_value(row, head, dim);
scratch[dim] = q * q;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let q_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
q_norm_out[row * params.attention_dim + head * params.head_dim + dim] =
 q * q_inv * f32(query_scale[dim]);
workgroupBarrier();
let k = k_value(row, head, dim);
scratch[dim] = k * k;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let k_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
k_norm_out[row * params.attention_dim + head * params.head_dim + dim] =
 k * k_inv * f32(key_scale[dim]);
}
`;
/**
 * WGSL kernel: per-head RMS normalization of q and k followed by a rotary
 * position embedding taken from a precomputed table.
 *
 * Layout and normalization:
 * - One workgroup of 128 invocations per (row = wid.x, head = wid.y).
 * - Invocation `dim` owns one head dimension.
 * - The q and k sums of squares are reduced side by side in `scratch` /
 *   `scratch2`.
 * - Normalized values are rounded through f16 twice: after the RMS scale, and
 *   after the per-dim `query_scale` / `key_scale` multiply. They are staged in
 *   q_temp / k_temp so each invocation can read its pair partner (dim ^ 1).
 *
 * Rotation (`apply_rope`):
 * - Rotates adjacent pairs using `rope_sincos`.
 * - Table index is (row * 64 + dim / 2) * 2; that entry holds cos, the next
 *   holds sin.
 * - This gives 64 pairs per row, i.e. head_dim == 128 is assumed.
 *
 * `text_tokens` and `image_width` are declared in Params but not read here.
 * NOTE(review): presumably they keep the uniform layout shared with other
 * kernels; confirm.
 *
 * The two output store statements are pattern-matched by
 * makeSingleQkNormF16StorageShader, so keep their form when editing.
 */
const SINGLE_QK_NORM_ROPE_SHADER = `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
text_tokens: u32,
image_width: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> query_scale: array<f16>;
@group(0) @binding(2) var<storage, read> key_scale: array<f16>;
@group(0) @binding(3) var<storage, read_write> q_norm_out: array<f32>;
@group(0) @binding(4) var<storage, read_write> k_norm_out: array<f32>;
@group(0) @binding(5) var<uniform> params: Params;
@group(0) @binding(6) var<storage, read> rope_sincos: array<f32>;
var<workgroup> scratch: array<f32, 128>;
var<workgroup> scratch2: array<f32, 128>;
var<workgroup> q_temp: array<f32, 128>;
var<workgroup> k_temp: array<f32, 128>;
fn q_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + head * params.head_dim + dim];
}
fn k_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim + head * params.head_dim + dim];
}
fn apply_rope(value: f32, mate: f32, dim: u32, row: u32) -> f32 {
let pair = dim / 2u;
let lane = dim & 1u;
let table_index = (row * 64u + pair) * 2u;
let c = rope_sincos[table_index];
let s = rope_sincos[table_index + 1u];
if (lane == 0u) {
return c * value - s * mate;
}
return s * mate + c * value;
}
@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim = lid.x;
let row = wid.x;
let head = wid.y;
if (row >= params.n || head >= params.heads) {
return;
}
let q = q_value(row, head, dim);
let k = k_value(row, head, dim);
scratch[dim] = q * q;
scratch2[dim] = k * k;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
scratch2[dim] = scratch2[dim] + scratch2[dim + stride];
}
workgroupBarrier();
}
let q_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
let k_inv = inverseSqrt(scratch2[0] / f32(params.head_dim) + 0.000001);
q_temp[dim] = f32(f16(f32(f16(q * q_inv)) * f32(query_scale[dim])));
k_temp[dim] = f32(f16(f32(f16(k * k_inv)) * f32(key_scale[dim])));
workgroupBarrier();
q_norm_out[row * params.attention_dim + head * params.head_dim + dim] =
 apply_rope(q_temp[dim], q_temp[dim ^ 1u], dim, row);
k_norm_out[row * params.attention_dim + head * params.head_dim + dim] =
 apply_rope(k_temp[dim], k_temp[dim ^ 1u], dim, row);
}
`;
/**
 * WGSL kernel: full softmax attention over q/k that were already normalized
 * (for example by SINGLE_QK_NORM_SHADER).
 *
 * Layout:
 * - One workgroup of 128 invocations per (query row = wid.x, head = wid.y).
 * - `q_norm` / `k_norm` are indexed [row * attention_dim + head * head_dim + dim].
 * - v is read from `linear1_out` at column 2 * attention_dim + head * head_dim.
 * - All scores are held in a 4608-entry workgroup array, so nothing is written
 *   when params.n > 4608.
 * - head_dim == 128 is assumed (workgroup size, reductions, and the
 *   0.08838835 = 1/sqrt(128) scale).
 *
 * Fix: a barrier now separates the all-invocation read of scratch[0]
 * (`max_score`) from the subsequent `local_sum` writes into scratch. Previously
 * invocation 0 could overwrite the max before others read it.
 */
const SINGLE_ATTENTION_PRENORM_SHADER = `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> q_norm: array<f32>;
@group(0) @binding(2) var<storage, read> k_norm: array<f32>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> scores: array<f32, 4608>;
var<workgroup> scratch: array<f32, 128>;
fn qn(row: u32, head: u32, dim: u32) -> f32 {
return q_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn kn(row: u32, head: u32, dim: u32) -> f32 {
return k_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim * 2u + head * params.head_dim + dim];
}
@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim = lid.x;
let row = wid.x;
let head = wid.y;
if (row >= params.n || head >= params.heads || params.n > 4608u) {
return;
}
let q = qn(row, head, dim);
for (var key_row = 0u; key_row < params.n; key_row = key_row + 1u) {
scratch[dim] = q * kn(key_row, head, dim);
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
if (dim == 0u) {
scores[key_row] = scratch[0] * 0.08838835;
}
workgroupBarrier();
}
var local_max = -3.402823e38;
for (var key_row = dim; key_row < params.n; key_row = key_row + params.head_dim) {
local_max = max(local_max, scores[key_row]);
}
scratch[dim] = local_max;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = max(scratch[dim], scratch[dim + stride]);
}
workgroupBarrier();
}
// Every invocation reads scratch[0] here; barrier before local_sum reuses it.
let max_score = scratch[0];
workgroupBarrier();
var local_sum = 0.0;
for (var key_row = dim; key_row < params.n; key_row = key_row + params.head_dim) {
local_sum = local_sum + exp(scores[key_row] - max_score);
}
scratch[dim] = local_sum;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let inv_sum = 1.0 / scratch[0];
var out = 0.0;
for (var key_row = 0u; key_row < params.n; key_row = key_row + 1u) {
let weight = f32(f16(exp(scores[key_row] - max_score) * inv_sum));
out = out + weight * f32(f16(v_value(key_row, head, dim)));
}
attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(out, -65504.0, 65504.0)));
}
`;
/**
 * WGSL kernel: tiled attention over pre-normalized q/k with an online
 * (streaming) softmax.
 *
 * Layout:
 * - Workgroups are 16 x 4 invocations, covering 4 query rows
 *   (row = wid.x * 4 + lid.y) for head wid.y.
 * - Each of the 16 lanes of a row owns 8 head dimensions (dim0 = lane * 8),
 *   so head_dim == 128 is assumed.
 *
 * Per tile of 8 keys:
 * - k/v rows are staged in workgroup memory.
 * - Per-row dot products are reduced across the 16 lanes in `reduce`.
 * - The running max (m_state) and running sum (l_state) are updated, and the
 *   accumulator is rescaled by alpha = exp(m_state - m_new).
 *
 * Out-of-range keys and rows get score -3.402823e38 and contribute exp(...) = 0.
 * Softmax weights, v, and the output are rounded through f16. Returns without
 * writing when params.n > 4608.
 *
 * Fix: a barrier now follows `tile_max = reduce[q_lane * 16u]`. Without it,
 * lane 0 could overwrite that slot with its `local_sum` before the other 15
 * lanes of the row had read the max.
 */
const SINGLE_ATTENTION_TILED_SHADER = `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> q_norm: array<f32>;
@group(0) @binding(2) var<storage, read> k_norm: array<f32>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> k_tile: array<f32, 1024>;
var<workgroup> v_tile: array<f32, 1024>;
var<workgroup> scores: array<f32, 32>;
var<workgroup> reduce: array<f32, 64>;
fn qn(row: u32, head: u32, dim: u32) -> f32 {
return q_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn kn(row: u32, head: u32, dim: u32) -> f32 {
return k_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim * 2u + head * params.head_dim + dim];
}
@compute @workgroup_size(16, 4, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim_lane = lid.x;
let q_lane = lid.y;
let local_linear = q_lane * 16u + dim_lane;
let row = wid.x * 4u + q_lane;
let head = wid.y;
if (head >= params.heads || params.n > 4608u) {
return;
}
let row_valid = row < params.n;
let dim0 = dim_lane * 8u;
var qv: array<f32, 8>;
var acc: array<f32, 8>;
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
if (row_valid) {
qv[i] = qn(row, head, dim);
} else {
qv[i] = 0.0;
}
acc[i] = 0.0;
}
var m_state = -3.402823e38;
var l_state = 0.0;
for (var key_base = 0u; key_base < params.n; key_base = key_base + 8u) {
for (var tile_index = local_linear; tile_index < 1024u; tile_index = tile_index + 64u) {
let key_offset = tile_index / 128u;
let dim = tile_index - key_offset * 128u;
let key_row = key_base + key_offset;
if (key_row < params.n) {
k_tile[tile_index] = kn(key_row, head, dim);
v_tile[tile_index] = v_value(key_row, head, dim);
} else {
k_tile[tile_index] = 0.0;
v_tile[tile_index] = 0.0;
}
}
workgroupBarrier();
for (var key_offset = 0u; key_offset < 8u; key_offset = key_offset + 1u) {
var partial = 0.0;
for (var i = 0u; i < 8u; i = i + 1u) {
partial = partial + qv[i] * k_tile[key_offset * 128u + dim0 + i];
}
reduce[q_lane * 16u + dim_lane] = partial;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
}
workgroupBarrier();
}
if (dim_lane == 0u) {
let key_row = key_base + key_offset;
scores[q_lane * 8u + key_offset] =
 select(-3.402823e38, reduce[q_lane * 16u] * 0.08838835, row_valid && key_row < params.n);
}
workgroupBarrier();
}
var local_max = -3.402823e38;
if (dim_lane < 8u) {
local_max = scores[q_lane * 8u + dim_lane];
}
reduce[q_lane * 16u + dim_lane] = local_max;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 max(reduce[q_lane * 16u + dim_lane], reduce[q_lane * 16u + dim_lane + stride]);
}
workgroupBarrier();
}
// All 16 lanes of the row read this slot; barrier before local_sum reuses it.
let tile_max = reduce[q_lane * 16u];
workgroupBarrier();
let m_new = max(m_state, tile_max);
let alpha = exp(m_state - m_new);
var local_sum = 0.0;
if (dim_lane < 8u) {
local_sum = exp(scores[q_lane * 8u + dim_lane] - m_new);
}
reduce[q_lane * 16u + dim_lane] = local_sum;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
}
workgroupBarrier();
}
let l_new = l_state * alpha + reduce[q_lane * 16u];
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
var weighted_v = 0.0;
for (var key_offset = 0u; key_offset < 8u; key_offset = key_offset + 1u) {
let weight = f32(f16(exp(scores[q_lane * 8u + key_offset] - m_new)));
weighted_v = weighted_v + weight * f32(f16(v_tile[key_offset * 128u + dim]));
}
acc[i] = acc[i] * alpha + weighted_v;
}
m_state = m_new;
l_state = l_new;
workgroupBarrier();
}
if (row_valid) {
let inv_l = 1.0 / l_state;
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(acc[i] * inv_l, -65504.0, 65504.0)));
}
}
}
`;
/**
 * Build a tiled online-softmax attention kernel (pre-normalized q/k) for a
 * given key-tile size and number of query rows per workgroup.
 *
 * The generated kernel uses workgroups of 16 x `queryRows` invocations. Each
 * of the 16 lanes of a query row owns 8 head dimensions (head_dim 128
 * assumed). Keys are streamed through workgroup memory `tileKeys` at a time.
 *
 * The combination (8, 4) returns the shared SINGLE_ATTENTION_TILED_SHADER
 * constant. All other combinations are generated from the template below.
 *
 * Fix: the generated code now has a barrier after the per-row `tile_max` read.
 * Without it, lane 0 could overwrite `reduce[q_lane * 16u]` with `local_sum`
 * while other lanes were still reading the max.
 *
 * @param {number|string} tileKeys Keys per tile, 8 or 16. Falsy values fall back to 8.
 * @param {number|string} [queryRows=4] Query rows per workgroup, 4, 8 or 16.
 *   Falsy values fall back to 4.
 * @returns {string} WGSL source (requires the `f16` feature).
 * @throws {Error} On an unsupported tile key count or query row count.
 */
function makeSingleAttentionTiledShader(tileKeys, queryRows = 4) {
  const keys = Number(tileKeys || 8);
  const rows = Number(queryRows || 4);
  if (![8, 16].includes(keys)) {
    throw new Error(`unsupported single attention tile key count: ${tileKeys}`);
  }
  if (![4, 8, 16].includes(rows)) {
    throw new Error(`unsupported single attention query row count: ${queryRows}`);
  }
  if (keys === 8 && rows === 4) return SINGLE_ATTENTION_TILED_SHADER;
  return `
enable f16;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> q_norm: array<f32>;
@group(0) @binding(2) var<storage, read> k_norm: array<f32>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> k_tile: array<f32, ${keys * 128}>;
var<workgroup> v_tile: array<f32, ${keys * 128}>;
var<workgroup> scores: array<f32, ${rows * keys}>;
var<workgroup> reduce: array<f32, ${rows * 16}>;
fn qn(row: u32, head: u32, dim: u32) -> f32 {
return q_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn kn(row: u32, head: u32, dim: u32) -> f32 {
return k_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim * 2u + head * params.head_dim + dim];
}
@compute @workgroup_size(16, ${rows}, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim_lane = lid.x;
let q_lane = lid.y;
let local_linear = q_lane * 16u + dim_lane;
let row = wid.x * ${rows}u + q_lane;
let head = wid.y;
if (head >= params.heads || params.n > 4608u) {
return;
}
let row_valid = row < params.n;
let dim0 = dim_lane * 8u;
var qv: array<f32, 8>;
var acc: array<f32, 8>;
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
if (row_valid) {
qv[i] = qn(row, head, dim);
} else {
qv[i] = 0.0;
}
acc[i] = 0.0;
}
var m_state = -3.402823e38;
var l_state = 0.0;
for (var key_base = 0u; key_base < params.n; key_base = key_base + ${keys}u) {
for (var tile_index = local_linear; tile_index < ${keys * 128}u; tile_index = tile_index + ${rows * 16}u) {
let key_offset = tile_index / 128u;
let dim = tile_index - key_offset * 128u;
let key_row = key_base + key_offset;
if (key_row < params.n) {
k_tile[tile_index] = kn(key_row, head, dim);
v_tile[tile_index] = v_value(key_row, head, dim);
} else {
k_tile[tile_index] = 0.0;
v_tile[tile_index] = 0.0;
}
}
workgroupBarrier();
for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
var partial = 0.0;
for (var i = 0u; i < 8u; i = i + 1u) {
partial = partial + qv[i] * k_tile[key_offset * 128u + dim0 + i];
}
reduce[q_lane * 16u + dim_lane] = partial;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
}
workgroupBarrier();
}
if (dim_lane == 0u) {
let key_row = key_base + key_offset;
scores[q_lane * ${keys}u + key_offset] =
 select(-3.402823e38, reduce[q_lane * 16u] * 0.08838835, row_valid && key_row < params.n);
}
workgroupBarrier();
}
var local_max = -3.402823e38;
if (dim_lane < ${keys}u) {
local_max = scores[q_lane * ${keys}u + dim_lane];
}
reduce[q_lane * 16u + dim_lane] = local_max;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 max(reduce[q_lane * 16u + dim_lane], reduce[q_lane * 16u + dim_lane + stride]);
}
workgroupBarrier();
}
// All 16 lanes of the row read this slot; barrier before local_sum reuses it.
let tile_max = reduce[q_lane * 16u];
workgroupBarrier();
let m_new = max(m_state, tile_max);
let alpha = exp(m_state - m_new);
var local_sum = 0.0;
if (dim_lane < ${keys}u) {
local_sum = exp(scores[q_lane * ${keys}u + dim_lane] - m_new);
}
reduce[q_lane * 16u + dim_lane] = local_sum;
workgroupBarrier();
for (var stride = 8u; stride > 0u; stride = stride / 2u) {
if (dim_lane < stride) {
reduce[q_lane * 16u + dim_lane] =
 reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
}
workgroupBarrier();
}
let l_new = l_state * alpha + reduce[q_lane * 16u];
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
var weighted_v = 0.0;
for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
let weight = f32(f16(exp(scores[q_lane * ${keys}u + key_offset] - m_new)));
weighted_v = weighted_v + weight * f32(f16(v_tile[key_offset * 128u + dim]));
}
acc[i] = acc[i] * alpha + weighted_v;
}
m_state = m_new;
l_state = l_new;
workgroupBarrier();
}
if (row_valid) {
let inv_l = 1.0 / l_state;
for (var i = 0u; i < 8u; i = i + 1u) {
let dim = dim0 + i;
attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(acc[i] * inv_l, -65504.0, 65504.0)));
}
}
}
`;
}
/**
 * Build a WGSL attention kernel (pre-normalized q/k) that handles one query
 * row per 32-invocation workgroup.
 *
 * It uses subgroup reductions (`subgroupAdd` / `subgroupMax`) instead of
 * workgroup-memory tree reductions. Each invocation owns 4 head dimensions
 * (dim0 = lane * 4), so head_dim == 128 is assumed. Keys and values are staged
 * through workgroup memory `keys` rows at a time and combined with an online
 * softmax: running max, running sum, and an accumulator rescaled by
 * exp(old_max - new_max).
 *
 * NOTE(review): the reductions only cover the whole workgroup if a single
 * subgroup spans all 32 invocations (device subgroup size >= 32). Confirm the
 * caller checks this before selecting this variant.
 *
 * @param {number|string} [tileKeys] Keys per tile, 8 or 16. Falsy values fall back to 8.
 * @returns {string} WGSL source requiring the `f16` and `subgroups` features.
 * @throws {Error} When the resolved key count is not 8 or 16.
 */
function makeSingleAttentionSubgroupShader(tileKeys) {
  const keys = Number(tileKeys || 8);
  const isSupported = keys === 8 || keys === 16;
  if (!isSupported) {
    throw new Error(`unsupported subgroup single attention tile key count: ${tileKeys}`);
  }
  // Floats per staged k (or v) tile: `keys` rows of 128 dimensions.
  const tileFloats = keys * 128;
  return `
enable f16;
enable subgroups;
struct Params {
n: u32,
linear1_m: u32,
heads: u32,
head_dim: u32,
qkv_dim: u32,
attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> q_norm: array<f32>;
@group(0) @binding(2) var<storage, read> k_norm: array<f32>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> k_tile: array<f32, ${tileFloats}>;
var<workgroup> v_tile: array<f32, ${tileFloats}>;
var<workgroup> scores: array<f32, ${keys}>;
fn qn(row: u32, head: u32, dim: u32) -> f32 {
return q_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn kn(row: u32, head: u32, dim: u32) -> f32 {
return k_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return linear1_out[row * params.linear1_m + params.attention_dim * 2u + head * params.head_dim + dim];
}
@compute @workgroup_size(32, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let lane = lid.x;
let row = wid.x;
let head = wid.y;
if (head >= params.heads || params.n > 4608u) {
return;
}
let row_valid = row < params.n;
let dim0 = lane * 4u;
var qv: array<f32, 4>;
var acc: array<f32, 4>;
for (var i = 0u; i < 4u; i = i + 1u) {
let dim = dim0 + i;
qv[i] = select(0.0, qn(row, head, dim), row_valid);
acc[i] = 0.0;
}
var m_state = -3.402823e38;
var l_state = 0.0;
for (var key_base = 0u; key_base < params.n; key_base = key_base + ${keys}u) {
for (var tile_index = lane; tile_index < ${tileFloats}u; tile_index = tile_index + 32u) {
let key_offset = tile_index / 128u;
let dim = tile_index - key_offset * 128u;
let key_row = key_base + key_offset;
if (key_row < params.n) {
k_tile[tile_index] = kn(key_row, head, dim);
v_tile[tile_index] = v_value(key_row, head, dim);
} else {
k_tile[tile_index] = 0.0;
v_tile[tile_index] = 0.0;
}
}
workgroupBarrier();
for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
var partial = 0.0;
for (var i = 0u; i < 4u; i = i + 1u) {
partial = partial + qv[i] * k_tile[key_offset * 128u + dim0 + i];
}
let dot = subgroupAdd(partial);
if (lane == 0u) {
let key_row = key_base + key_offset;
scores[key_offset] = select(-3.402823e38, dot * 0.08838835, row_valid && key_row < params.n);
}
}
workgroupBarrier();
var local_max = -3.402823e38;
if (lane < ${keys}u) {
local_max = scores[lane];
}
let tile_max = subgroupMax(local_max);
let m_new = max(m_state, tile_max);
let alpha = exp(m_state - m_new);
var local_sum = 0.0;
if (lane < ${keys}u) {
local_sum = exp(scores[lane] - m_new);
}
let l_new = l_state * alpha + subgroupAdd(local_sum);
for (var i = 0u; i < 4u; i = i + 1u) {
let dim = dim0 + i;
var weighted_v = 0.0;
for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
let weight = f32(f16(exp(scores[key_offset] - m_new)));
weighted_v = weighted_v + weight * f32(f16(v_tile[key_offset * 128u + dim]));
}
acc[i] = acc[i] * alpha + weighted_v;
}
m_state = m_new;
l_state = l_new;
workgroupBarrier();
}
if (row_valid) {
let inv_l = 1.0 / l_state;
for (var i = 0u; i < 4u; i = i + 1u) {
let dim = dim0 + i;
attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(acc[i] * inv_l, -65504.0, 65504.0)));
}
}
}
`;
}
/**
 * Adapt a shader that reads `linear1_out` as f32 so it reads an f16 buffer.
 *
 * - Prepends `enable f16;` unless the source already starts with it.
 * - Retypes the binding-0 `linear1_out` declaration to array<f16>.
 * - Wraps `return linear1_out[...]`, `let gate = linear1_out[...]` and
 *   `let up = linear1_out[...]` reads in f32(...), so the surrounding f32
 *   arithmetic is unchanged.
 *
 * @param {string} shader WGSL source.
 * @returns {string} Rewritten WGSL source.
 */
function makeLinear1F16ConsumerShader(shader) {
  const prefixed = /^\s*enable f16;/.test(shader) ? shader : `enable f16;\n${shader}`;
  const retyped = prefixed.replace(
    "@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;",
    "@group(0) @binding(0) var<storage, read> linear1_out: array<f16>;",
  );
  // Read sites to widen back to f32; patterns are rebuilt per call.
  const readRewrites = [
    [/return linear1_out\[([^\]]+)\];/g, "return f32(linear1_out[$1]);"],
    [/let gate = linear1_out\[([^\]]+)\];/g, "let gate = f32(linear1_out[$1]);"],
    [/let up = linear1_out\[([^\]]+)\];/g, "let up = f32(linear1_out[$1]);"],
  ];
  return readRewrites.reduce(
    (source, [pattern, replacement]) => source.replace(pattern, replacement),
    retyped,
  );
}
/**
 * Rewrite a QK-norm shader so its `q_norm_out` / `k_norm_out` storage buffers
 * hold f16 instead of f32.
 *
 * - Prepends `enable f16;` unless the source already starts with it.
 * - Retypes the `q_norm_out` / `k_norm_out` read_write storage declarations
 *   from array<f32> to array<f16>.
 * - Wraps the right-hand side of every `q_norm_out[...] = ...;` and
 *   `k_norm_out[...] = ...;` store in `f16(...)`.
 *
 * Matching is whitespace-tolerant and accepts any store expression. The
 * previous exact-string version silently did nothing when formatting drifted,
 * and it left the plain (non-RoPE) QK-norm shader storing f32 values into
 * array<f16>. Stores already wrapped in `f16(...)` are left alone, so the
 * transform is idempotent. For the RoPE shader's stores the output matches the
 * previous implementation exactly.
 *
 * @param {string} shader WGSL source.
 * @returns {string} Rewritten WGSL source.
 */
function makeSingleQkNormF16StorageShader(shader) {
  let code = /^\s*enable f16;/.test(shader) ? shader : `enable f16;\n${shader}`;
  for (const name of ["q_norm_out", "k_norm_out"]) {
    // Declaration: `var<storage, read_write> NAME: array<f32>;`
    const declaration = new RegExp(
      `(var<storage,\\s*read_write>\\s+${name}\\s*:\\s*)array<f32>\\s*;`,
    );
    code = code.replace(declaration, "$1array<f16>;");
    // Store: `NAME[index] = expr;` with a single `=` (never `==`). The
    // whitespace after `=` is captured and kept so line layout is unchanged.
    const store = new RegExp(`(\\b${name}\\[[^\\]]+\\]\\s*=(?!=))(\\s*)([^;]+);`, "g");
    code = code.replace(store, (match, target, gap, value) => {
      const expr = value.trimEnd();
      const alreadyF16 = expr.startsWith("f16(") && expr.endsWith(")");
      return alreadyF16 ? match : `${target}${gap}f16(${expr});`;
    });
  }
  return code;
}
/**
 * Adapt an attention shader that reads `q_norm` / `k_norm` as f32 so it reads
 * f16 buffers instead.
 *
 * - Prepends `enable f16;` unless the source already starts with it.
 * - Retypes the binding-1 `q_norm` and binding-2 `k_norm` declarations to
 *   array<f16>.
 * - Wraps every `return q_norm[...]` / `return k_norm[...]` in f32(...), so
 *   callers of those accessors still receive f32.
 *
 * @param {string} shader WGSL source.
 * @returns {string} Rewritten WGSL source.
 */
function makeAttentionQkF16ConsumerShader(shader) {
  let result = /^\s*enable f16;/.test(shader) ? shader : `enable f16;\n${shader}`;
  const slots = [
    { binding: 1, name: "q_norm" },
    { binding: 2, name: "k_norm" },
  ];
  for (const { binding, name } of slots) {
    result = result.replace(
      `@group(0) @binding(${binding}) var<storage, read> ${name}: array<f32>;`,
      `@group(0) @binding(${binding}) var<storage, read> ${name}: array<f16>;`,
    );
  }
  result = result.replace(/return q_norm\[([^\]]+)\];/g, "return f32(q_norm[$1]);");
  result = result.replace(/return k_norm\[([^\]]+)\];/g, "return f32(k_norm[$1]);");
  return result;
}
/**
 * WGSL kernel: split a fused f16 QKV projection into RMS-normalized, rotary-embedded Q and K plus a
 * copied V, all written as f32.
 *
 * Dispatch: one workgroup per (row = wid.x, head = wid.y). Each of the 128 invocations handles one
 * head dim (`dim = lid.x`). The 128-entry `scratch`/`q_temp`/`k_temp` arrays and the reduction
 * starting at stride 64 always sum 128 lanes, while the mean divides by `params.head_dim`.
 * NOTE(review): this appears to assume head_dim == 128; confirm at the call site.
 *
 * Layout:
 * - Input row `row` of `qkv` (stride `qkv_m`) holds Q at column 0, K at `attention_dim`, and V at
 *   `2 * attention_dim`, each head at `head * head_dim`.
 * - Outputs go to joint row `output_row_offset + row`, with stride `attention_dim`.
 *
 * Q and K:
 * - RMS-normalized with eps 1e-6, multiplied by `query_scale` / `key_scale`, with the intermediates
 *   rounded through f16.
 * - Then rotated pairwise (`dim`, `dim ^ 1`) by `apply_rope`.
 *
 * Rotary positions:
 * - Each group of 16 dim-pairs forms one axis, and `rope_freq[pair % 16]` is its frequency.
 * - Text rows (`joint_row < text_tokens`) use position `joint_row` on axis 3.
 * - Image rows use y on axis 1 and x on axis 2 of an `image_width`-wide grid.
 * - All other axes use position 0.
 *
 * V is converted f16 -> f32 without further processing.
 */
const DOUBLE_QKV_NORM_ROPE_SHADER = `
enable f16;
struct Params {
rows: u32,
qkv_m: u32,
heads: u32,
head_dim: u32,
attention_dim: u32,
output_row_offset: u32,
text_tokens: u32,
image_width: u32,
};
@group(0) @binding(0) var<storage, read> qkv: array<f16>;
@group(0) @binding(1) var<storage, read> query_scale: array<f16>;
@group(0) @binding(2) var<storage, read> key_scale: array<f16>;
@group(0) @binding(3) var<storage, read_write> q_norm_out: array<f32>;
@group(0) @binding(4) var<storage, read_write> k_norm_out: array<f32>;
@group(0) @binding(5) var<storage, read_write> v_out: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
@group(0) @binding(7) var<storage, read> rope_freq: array<f32>;
var<workgroup> scratch: array<f32, 128>;
var<workgroup> q_temp: array<f32, 128>;
var<workgroup> k_temp: array<f32, 128>;
fn q_value(row: u32, head: u32, dim: u32) -> f32 {
return f32(qkv[row * params.qkv_m + head * params.head_dim + dim]);
}
fn k_value(row: u32, head: u32, dim: u32) -> f32 {
return f32(qkv[row * params.qkv_m + params.attention_dim + head * params.head_dim + dim]);
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
return f32(qkv[row * params.qkv_m + params.attention_dim * 2u + head * params.head_dim + dim]);
}
fn rope_position(joint_row: u32, axis: u32) -> f32 {
if (joint_row < params.text_tokens) {
if (axis == 3u) {
return f32(joint_row);
}
return 0.0;
}
let image_row = joint_row - params.text_tokens;
let safe_width = max(params.image_width, 1u);
let y = image_row / safe_width;
let x = image_row - y * safe_width;
if (axis == 1u) {
return f32(y);
}
if (axis == 2u) {
return f32(x);
}
return 0.0;
}
fn apply_rope(value: f32, mate: f32, dim: u32, joint_row: u32) -> f32 {
let pair = dim / 2u;
let lane = dim & 1u;
let axis = pair / 16u;
let freq_index = pair - axis * 16u;
let angle = rope_position(joint_row, axis) * rope_freq[freq_index];
let c = cos(angle);
let s = sin(angle);
if (lane == 0u) {
return c * value - s * mate;
}
return s * mate + c * value;
}
@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let dim = lid.x;
let row = wid.x;
let head = wid.y;
if (row >= params.rows || head >= params.heads) {
return;
}
let joint_row = params.output_row_offset + row;
let q = q_value(row, head, dim);
scratch[dim] = q * q;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let q_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
q_temp[dim] = f32(f16(f32(f16(q * q_inv)) * f32(query_scale[dim])));
workgroupBarrier();
q_norm_out[joint_row * params.attention_dim + head * params.head_dim + dim] =
apply_rope(q_temp[dim], q_temp[dim ^ 1u], dim, joint_row);
workgroupBarrier();
let k = k_value(row, head, dim);
scratch[dim] = k * k;
workgroupBarrier();
for (var stride = 64u; stride > 0u; stride = stride / 2u) {
if (dim < stride) {
scratch[dim] = scratch[dim] + scratch[dim + stride];
}
workgroupBarrier();
}
let k_inv = inverseSqrt(scratch[0] / f32(params.head_dim) + 0.000001);
k_temp[dim] = f32(f16(f32(f16(k * k_inv)) * f32(key_scale[dim])));
workgroupBarrier();
k_norm_out[joint_row * params.attention_dim + head * params.head_dim + dim] =
apply_rope(k_temp[dim], k_temp[dim ^ 1u], dim, joint_row);
v_out[joint_row * params.attention_dim + head * params.head_dim + dim] = v_value(row, head, dim);
}
`;
/**
 * Build a WGSL tiled joint-attention kernel over f32 Q/K/V buffers.
 *
 * Workgroup shape:
 * - Each workgroup has 16 x `queryRows` invocations. `lid.y` selects a query row
 *   (`row = wid.x * queryRows + lid.y`) and `wid.y` selects the head.
 * - Each of the 16 `lid.x` lanes owns 8 consecutive head dims, so the kernel hard-codes a
 *   128-wide head. The score scale 0.08838835 (= 1 / sqrt(128)) assumes the same width.
 *
 * Algorithm:
 * - Keys and values are staged through workgroup memory `tileKeys` rows at a time.
 * - They are combined with an online softmax: running max `m_state`, running denominator
 *   `l_state`, and rescale factor `alpha`.
 * - Softmax weights and V are rounded through f16 before accumulation.
 * - The final output is clamped to the f16 range and rounded through f16.
 * - The kernel exits early when `params.n > 4608`.
 *
 * Buffer layout: Q, K, V and the output are indexed as
 * `[row * attention_dim + head * head_dim + dim]`.
 *
 * Expected dispatch: x covers ceil(n / queryRows) row blocks, and y covers the heads.
 *
 * @param {number} [tileKeys] - Keys staged per tile: 8 or 16. Falsy values fall back to 8.
 * @param {number} [queryRows=8] - Query rows per workgroup: 4, 8, or 16. Falsy values fall back to 8.
 * @returns {string} WGSL source.
 * @throws {Error} For an unsupported tile key count or query row count.
 */
function makeJointAttentionTiledShader(tileKeys, queryRows = 8) {
  const keys = Number(tileKeys || 8);
  const rows = Number(queryRows || 8);
  if (![8, 16].includes(keys)) {
    throw new Error(`unsupported joint attention tile key count: ${tileKeys}`);
  }
  if (![4, 8, 16].includes(rows)) {
    throw new Error(`unsupported joint attention query row count: ${queryRows}`);
  }
  return `
enable f16;
struct Params {
  n: u32,
  heads: u32,
  head_dim: u32,
  attention_dim: u32,
};
@group(0) @binding(0) var<storage, read> v_in: array<f32>;
@group(0) @binding(1) var<storage, read> q_norm: array<f32>;
@group(0) @binding(2) var<storage, read> k_norm: array<f32>;
@group(0) @binding(3) var<storage, read_write> attention_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> k_tile: array<f32, ${keys * 128}>;
var<workgroup> v_tile: array<f32, ${keys * 128}>;
var<workgroup> scores: array<f32, ${rows * keys}>;
var<workgroup> reduce: array<f32, ${rows * 16}>;
fn qn(row: u32, head: u32, dim: u32) -> f32 {
  return q_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn kn(row: u32, head: u32, dim: u32) -> f32 {
  return k_norm[row * params.attention_dim + head * params.head_dim + dim];
}
fn v_value(row: u32, head: u32, dim: u32) -> f32 {
  return v_in[row * params.attention_dim + head * params.head_dim + dim];
}
@compute @workgroup_size(16, ${rows}, 1)
fn main(
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wid: vec3<u32>) {
  let dim_lane = lid.x;
  let q_lane = lid.y;
  let local_linear = q_lane * 16u + dim_lane;
  let row = wid.x * ${rows}u + q_lane;
  let head = wid.y;
  if (head >= params.heads || params.n > 4608u) {
    return;
  }
  let row_valid = row < params.n;
  let dim0 = dim_lane * 8u;
  var qv: array<f32, 8>;
  var acc: array<f32, 8>;
  for (var i = 0u; i < 8u; i = i + 1u) {
    let dim = dim0 + i;
    qv[i] = select(0.0, qn(row, head, dim), row_valid);
    acc[i] = 0.0;
  }
  var m_state = -3.402823e38;
  var l_state = 0.0;
  for (var key_base = 0u; key_base < params.n; key_base = key_base + ${keys}u) {
    for (var tile_index = local_linear; tile_index < ${keys * 128}u; tile_index = tile_index + ${rows * 16}u) {
      let key_offset = tile_index / 128u;
      let dim = tile_index - key_offset * 128u;
      let key_row = key_base + key_offset;
      if (key_row < params.n) {
        k_tile[tile_index] = kn(key_row, head, dim);
        v_tile[tile_index] = v_value(key_row, head, dim);
      } else {
        k_tile[tile_index] = 0.0;
        v_tile[tile_index] = 0.0;
      }
    }
    workgroupBarrier();
    for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
      var partial = 0.0;
      for (var i = 0u; i < 8u; i = i + 1u) {
        partial = partial + qv[i] * k_tile[key_offset * 128u + dim0 + i];
      }
      reduce[q_lane * 16u + dim_lane] = partial;
      workgroupBarrier();
      for (var stride = 8u; stride > 0u; stride = stride / 2u) {
        if (dim_lane < stride) {
          reduce[q_lane * 16u + dim_lane] =
            reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
        }
        workgroupBarrier();
      }
      if (dim_lane == 0u) {
        let key_row = key_base + key_offset;
        scores[q_lane * ${keys}u + key_offset] =
          select(-3.402823e38, reduce[q_lane * 16u] * 0.08838835, row_valid && key_row < params.n);
      }
      workgroupBarrier();
    }
    var local_max = -3.402823e38;
    if (dim_lane < ${keys}u) {
      local_max = scores[q_lane * ${keys}u + dim_lane];
    }
    reduce[q_lane * 16u + dim_lane] = local_max;
    workgroupBarrier();
    for (var stride = 8u; stride > 0u; stride = stride / 2u) {
      if (dim_lane < stride) {
        reduce[q_lane * 16u + dim_lane] =
          max(reduce[q_lane * 16u + dim_lane], reduce[q_lane * 16u + dim_lane + stride]);
      }
      workgroupBarrier();
    }
    let tile_max = reduce[q_lane * 16u];
    // All lanes must finish reading tile_max before reduce[] is overwritten with the exp-sum
    // partials below; without this barrier the write by dim_lane 0 races with the other lanes' reads.
    workgroupBarrier();
    let m_new = max(m_state, tile_max);
    let alpha = exp(m_state - m_new);
    var local_sum = 0.0;
    if (dim_lane < ${keys}u) {
      local_sum = exp(scores[q_lane * ${keys}u + dim_lane] - m_new);
    }
    reduce[q_lane * 16u + dim_lane] = local_sum;
    workgroupBarrier();
    for (var stride = 8u; stride > 0u; stride = stride / 2u) {
      if (dim_lane < stride) {
        reduce[q_lane * 16u + dim_lane] =
          reduce[q_lane * 16u + dim_lane] + reduce[q_lane * 16u + dim_lane + stride];
      }
      workgroupBarrier();
    }
    let l_new = l_state * alpha + reduce[q_lane * 16u];
    for (var i = 0u; i < 8u; i = i + 1u) {
      let dim = dim0 + i;
      var weighted_v = 0.0;
      for (var key_offset = 0u; key_offset < ${keys}u; key_offset = key_offset + 1u) {
        let weight = f32(f16(exp(scores[q_lane * ${keys}u + key_offset] - m_new)));
        weighted_v = weighted_v + weight * f32(f16(v_tile[key_offset * 128u + dim]));
      }
      acc[i] = acc[i] * alpha + weighted_v;
    }
    m_state = m_new;
    l_state = l_new;
    workgroupBarrier();
  }
  if (row_valid) {
    let inv_l = 1.0 / l_state;
    for (var i = 0u; i < 8u; i = i + 1u) {
      let dim = dim0 + i;
      attention_out[row * params.attention_dim + head * params.head_dim + dim] = f32(f16(clamp(acc[i] * inv_l, -65504.0, 65504.0)));
    }
  }
}
`;
}
/**
 * WGSL kernel: assemble the f32 `linear2_in` matrix from attention output and the gated MLP half.
 *
 * Indexing and guard:
 * - Each invocation handles `row = wid.y`, `col = wid.x * 256 + lid.x`.
 * - It returns early when `row >= n` or `col >= linear2_k`.
 *
 * Per-column value:
 * - `col < attn_dim`: copies `attention_out[row * attn_dim + col]`.
 * - Otherwise: writes `gate * sigmoid(gate) * up` (SiLU(gate) * up), with `mlp_col = col - attn_dim`.
 *   - gate is `linear1_out[row * linear1_m + qkv_dim + mlp_col]`.
 *   - up is `linear1_out[row * linear1_m + qkv_dim + mlp_half_dim + mlp_col]`.
 *
 * The result is stored at `linear2_in[row * linear2_k + col]`.
 */
const SINGLE_MLP_ACTIVATION_ATTENTION_SHADER = `
struct Params {
n: u32,
linear1_m: u32,
linear2_k: u32,
attn_dim: u32,
qkv_dim: u32,
mlp_half_dim: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> attention_out: array<f32>;
@group(0) @binding(2) var<storage, read_write> linear2_in: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
fn sigmoid(value: f32) -> f32 {
return 1.0 / (1.0 + exp(-value));
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let col = wid.x * 256u + lid.x;
let row = wid.y;
if (row >= params.n || col >= params.linear2_k) {
return;
}
var value = 0.0;
if (col < params.attn_dim) {
value = attention_out[row * params.attn_dim + col];
} else {
let mlp_col = col - params.attn_dim;
let gate = linear1_out[row * params.linear1_m + params.qkv_dim + mlp_col];
let up = linear1_out[row * params.linear1_m + params.qkv_dim + params.mlp_half_dim + mlp_col];
value = gate * sigmoid(gate) * up;
}
linear2_in[row * params.linear2_k + col] = value;
}
`;
/**
 * WGSL kernel: build one `linear2` input row and quantize it to packed int8 in the same pass,
 * using synthetic values for the attention columns.
 *
 * Dispatch:
 * - One workgroup of 256 invocations per row (`row = wid.x`).
 * - `params.n` is declared but never read, so there is no row guard.
 *   NOTE(review): presumably the dispatch count must equal the row count; confirm at the call site.
 *
 * Column values:
 * - `col < attn_dim`: `synthetic_attention(row, col) = (((row*37 + col*13 + 17) % 251) - 125) / 64`,
 *   a deterministic placeholder rather than real attention output.
 * - Otherwise: `gate * sigmoid(gate) * up`, with gate/up read from `linear1_out` row `row`
 *   (stride `linear1_m`) at `qkv_dim + mlp_col` and `qkv_dim + mlp_half_dim + mlp_col`.
 *
 * Quantization (symmetric, per row):
 * - The workgroup max-reduces |value| into `row_absmax[0]`.
 * - It stores `absmax / 127` in `scales_out[row]`, or 1.0 for an all-zero row.
 * - Each value becomes `round(clamp(value * 127 / absmax, -127, 127))`.
 *
 * Packing:
 * - Four int8 values per u32 in `packed_out[row * k_words + word]`.
 * - Column `word * 4 + lane` goes to byte `lane` (lane 0 = low byte).
 * - Columns >= `linear2_k` pack as 0.
 *
 * `activation_value` is evaluated twice per column: once for the absmax and once for quantization.
 */
const SINGLE_MLP_ACTIVATE_QUANT_SHADER = `
struct Params {
n: u32,
linear1_m: u32,
linear2_k: u32,
attn_dim: u32,
qkv_dim: u32,
mlp_half_dim: u32,
k_words: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read_write> packed_out: array<u32>;
@group(0) @binding(2) var<storage, read_write> scales_out: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> row_absmax: array<f32, 256>;
fn sigmoid(value: f32) -> f32 {
return 1.0 / (1.0 + exp(-value));
}
fn synthetic_attention(row: u32, col: u32) -> f32 {
let code = (row * 37u + col * 13u + 17u) % 251u;
return (f32(code) - 125.0) / 64.0;
}
fn activation_value(row: u32, col: u32) -> f32 {
if (col < params.attn_dim) {
return synthetic_attention(row, col);
}
let mlp_col = col - params.attn_dim;
let gate = linear1_out[row * params.linear1_m + params.qkv_dim + mlp_col];
let up = linear1_out[row * params.linear1_m + params.qkv_dim + params.mlp_half_dim + mlp_col];
return gate * sigmoid(gate) * up;
}
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
var absmax = 0.0;
for (var col = local_id; col < params.linear2_k; col = col + 256u) {
absmax = max(absmax, abs(activation_value(row, col)));
}
row_absmax[local_id] = absmax;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
row_absmax[local_id] = max(row_absmax[local_id], row_absmax[local_id + stride]);
}
workgroupBarrier();
}
let row_scale = row_absmax[0];
let quant_scale = select(0.0, 127.0 / row_scale, row_scale > 0.0);
if (local_id == 0u) {
scales_out[row] = select(1.0, row_scale / 127.0, row_scale > 0.0);
}
for (var word = local_id; word < params.k_words; word = word + 256u) {
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let col = word * 4u + lane;
var q = 0i;
if (col < params.linear2_k) {
let scaled = round(clamp(activation_value(row, col) * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
packed_out[row * params.k_words + word] = packed_word;
}
}
`;
/**
 * WGSL kernel: build one `linear2` input row from real attention output plus the gated MLP half,
 * and quantize it to packed int8 in the same pass.
 *
 * Dispatch:
 * - One workgroup of 256 invocations per row (`row = wid.x`).
 * - `params.n` is declared but never read, so there is no row guard.
 *   NOTE(review): presumably the dispatch count must equal the row count; confirm at the call site.
 *
 * Column values:
 * - `col < attn_dim`: `attention_out[row * attn_dim + col]`.
 * - Otherwise: `gate * sigmoid(gate) * up`, with gate/up read from `linear1_out` row `row`
 *   (stride `linear1_m`) at `qkv_dim + mlp_col` and `qkv_dim + mlp_half_dim + mlp_col`.
 *
 * Quantization (symmetric, per row):
 * - The workgroup max-reduces |value| into `row_absmax[0]`.
 * - It stores `absmax / 127` in `scales_out[row]`, or 1.0 for an all-zero row.
 * - Each value becomes `round(clamp(value * 127 / absmax, -127, 127))`.
 *
 * Packing:
 * - Four int8 values per u32 in `packed_out[row * k_words + word]`.
 * - Column `word * 4 + lane` goes to byte `lane` (lane 0 = low byte).
 * - Columns >= `linear2_k` pack as 0.
 */
const SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER = `
struct Params {
n: u32,
linear1_m: u32,
linear2_k: u32,
attn_dim: u32,
qkv_dim: u32,
mlp_half_dim: u32,
k_words: u32,
};
@group(0) @binding(0) var<storage, read> linear1_out: array<f32>;
@group(0) @binding(1) var<storage, read> attention_out: array<f32>;
@group(0) @binding(2) var<storage, read_write> packed_out: array<u32>;
@group(0) @binding(3) var<storage, read_write> scales_out: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
var<workgroup> row_absmax: array<f32, 256>;
fn sigmoid(value: f32) -> f32 {
return 1.0 / (1.0 + exp(-value));
}
fn activation_value(row: u32, col: u32) -> f32 {
if (col < params.attn_dim) {
return attention_out[row * params.attn_dim + col];
}
let mlp_col = col - params.attn_dim;
let gate = linear1_out[row * params.linear1_m + params.qkv_dim + mlp_col];
let up = linear1_out[row * params.linear1_m + params.qkv_dim + params.mlp_half_dim + mlp_col];
return gate * sigmoid(gate) * up;
}
fn pack_i8_lane(value: i32, lane: u32) -> u32 {
return (u32(value) & 0xffu) << (lane * 8u);
}
@compute @workgroup_size(256, 1, 1)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>) {
let row = wid.x;
let local_id = lid.x;
var absmax = 0.0;
for (var col = local_id; col < params.linear2_k; col = col + 256u) {
absmax = max(absmax, abs(activation_value(row, col)));
}
row_absmax[local_id] = absmax;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride = stride / 2u) {
if (local_id < stride) {
row_absmax[local_id] = max(row_absmax[local_id], row_absmax[local_id + stride]);
}
workgroupBarrier();
}
let row_scale = row_absmax[0];
let quant_scale = select(0.0, 127.0 / row_scale, row_scale > 0.0);
if (local_id == 0u) {
scales_out[row] = select(1.0, row_scale / 127.0, row_scale > 0.0);
}
for (var word = local_id; word < params.k_words; word = word + 256u) {
var packed_word = 0u;
for (var lane = 0u; lane < 4u; lane = lane + 1u) {
let col = word * 4u + lane;
var q = 0i;
if (col < params.linear2_k) {
let scaled = round(clamp(activation_value(row, col) * quant_scale, -127.0, 127.0));
q = i32(scaled);
}
packed_word = packed_word | pack_i8_lane(q, lane);
}
packed_out[row * params.k_words + word] = packed_word;
}
}
`;
/**
 * Compare sampled GPU outputs against CPU reference values.
 *
 * @param {ArrayLike<number>} actual - Flat row-major output; element (row, col) is at `row * m + col`.
 * @param {Array<{row: number, col: number, m: number, expected: number}>} expectedItems - Reference points.
 * @returns {{checked: number, max_abs_error: number, max_rel_error: number, samples: object[]}}
 *   Worst absolute and relative error over all items (relative error uses a 1e-6 floor on
 *   |expected|), plus up to the first 8 comparisons as samples.
 */
function verificationSummary(actual, expectedItems) {
  const samples = [];
  let worstAbs = 0;
  let worstRel = 0;
  for (const {row, col, m, expected} of expectedItems) {
    const observed = actual[row * m + col];
    const absError = Math.abs(observed - expected);
    const relError = absError / Math.max(1e-6, Math.abs(expected));
    worstAbs = Math.max(worstAbs, absError);
    worstRel = Math.max(worstRel, relError);
    if (samples.length < 8) {
      samples.push({row, col, expected, actual: observed, abs: absError});
    }
  }
  return {
    checked: expectedItems.length,
    max_abs_error: worstAbs,
    max_rel_error: worstRel,
    samples,
  };
}
/**
 * Time the q4 zero-point matmul kernel for one linear layer.
 *
 * Setup:
 * - Fetches the packed 4-bit weights, f16 scales and packed zero points named in `manifest.q4`.
 * - Uploads them together with `inputF16`.
 * - Dispatches `makeQ4ZpShader32xWide(tileCols)` over row blocks of up to `rowBlock` rows
 *   (32 rows per workgroup in y).
 *
 * Row blocking:
 * - Unless `config.rowBlockSize` overrides it, a block is the largest row count whose f32 output
 *   fits in `maxStorageBufferBindingSize` (128 MiB assumed when unreported).
 * - All blocks bind the same `rowBlock * m` f32 output buffer.
 *
 * @param {GPUDevice} device - Device used for every buffer and pipeline.
 * @param {object} manifest - Linear entry with `shape.{K,N}`, `matmul_nbits`, and `q4` file names.
 * @param {string} bundleBaseUrl - Prefix prepended to each `q4` file name.
 * @param {object} config - Reads n, q4TileCols, rowBlockSize, warmupRuns, timedRuns, verify,
 *   verifyRows, prefixCount.
 * @param {*} inputF16 - Activation data for `createBuffer`.
 *   NOTE(review): presumably n x K f16 bits; confirm against the caller.
 * @returns {Promise<object>} Per-run wall times (submit through onSubmittedWorkDone), median and
 *   effective TMAC/s summary, and an optional verification summary.
 * @throws {Error} When verification is requested for a run that needs more than one row block.
 */
async function runQ4Bench(device, manifest, bundleBaseUrl, config, inputF16) {
  const {K: k, N: m} = manifest.shape;
  const n = Number(config.n || 256);
  const {
    block_size: groupSize,
    groups_per_col: groups,
    packed_group_words: packedWords,
  } = manifest.matmul_nbits;
  const tileCols = Number(config.q4TileCols || 96);
  // Largest row count whose f32 output (rows * m * 4 bytes) fits in one storage binding.
  const bindingBytes = device.limits.maxStorageBufferBindingSize || 134217728;
  const maxOutputRows = Math.max(1, Math.floor(bindingBytes / (m * 4)));
  const rowBlock = Math.max(1, Math.min(n, Number(config.rowBlockSize || maxOutputRows)));

  const q4Files = [manifest.q4.packed_u32, manifest.q4.scales_f16, manifest.q4.zero_points_u32];
  const [q4Buf, scalesBuf, zpBuf] = await Promise.all(
    q4Files.map((file) => fetchArrayBuffer(`${bundleBaseUrl}${file}`)),
  );
  const q4 = new Uint32Array(q4Buf);
  const scales = new Uint16Array(scalesBuf);
  const zeroPoints = new Uint32Array(zpBuf);

  const storageUsage = GPUBufferUsage.STORAGE;
  const xBuffer = createBuffer(device, inputF16, storageUsage);
  const q4Buffer = createBuffer(device, q4, storageUsage);
  const scaleBuffer = createBuffer(device, scales, storageUsage);
  const zpBuffer = createBuffer(device, zeroPoints, storageUsage);
  const yBuffer = createEmptyBuffer(device, rowBlock * m * 4, storageUsage | GPUBufferUsage.COPY_SRC);
  const shaderModule = device.createShaderModule({code: makeQ4ZpShader32xWide(tileCols)});
  const pipeline = await device.createComputePipelineAsync({
    layout: "auto",
    compute: {module: shaderModule, entryPoint: "main"},
  });

  const chunks = [];
  for (let rowOffset = 0; rowOffset < n; rowOffset += rowBlock) {
    const chunkRows = Math.min(rowBlock, n - rowOffset);
    const paramsBuffer = createBuffer(
      device,
      new Uint32Array([chunkRows, k, m, groupSize, groups, packedWords, rowOffset, 0]),
      GPUBufferUsage.UNIFORM,
    );
    const boundBuffers = [xBuffer, q4Buffer, scaleBuffer, zpBuffer, yBuffer, paramsBuffer];
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: boundBuffers.map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    chunks.push({rowOffset, chunkRows, bindGroup});
  }

  const workgroupsX = Math.ceil(m / tileCols);
  const dispatch = async () => {
    const encoder = device.createCommandEncoder();
    for (const {chunkRows, bindGroup} of chunks) {
      const pass = encoder.beginComputePass();
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bindGroup);
      pass.dispatchWorkgroups(workgroupsX, Math.ceil(chunkRows / 32));
      pass.end();
    }
    device.queue.submit([encoder.finish()]);
    await device.queue.onSubmittedWorkDone();
  };

  const warmupRuns = Number(config.warmupRuns ?? 1);
  for (let run = 0; run < warmupRuns; run += 1) {
    await dispatch();
  }
  const timedRuns = Number(config.timedRuns ?? 3);
  const times = [];
  while (times.length < timedRuns) {
    const startedAt = performance.now();
    await dispatch();
    times.push(performance.now() - startedAt);
  }

  let verification = null;
  if (config.verify) {
    if (chunks.length > 1) {
      throw new Error("verification for row-blocked q4 runs is not implemented");
    }
    const output = await readFloat32Buffer(device, yBuffer, rowBlock * m);
    const verifyRows = Math.min(n, Number(config.verifyRows || 2));
    const verifyCols = Math.min(m, Number(config.prefixCount || 8));
    const expected = [];
    for (let row = 0; row < verifyRows; row += 1) {
      for (let col = 0; col < verifyCols; col += 1) {
        const reference = q4CpuOutput(row, col, k, m, groupSize, packedWords, inputF16, q4, scales, zeroPoints);
        expected.push({row, col, m, expected: reference});
      }
    }
    verification = verificationSummary(output, expected);
  }

  const medianMs = median(times);
  return {
    mode: "q4_zp",
    shape: {n, k, m},
    tile_cols: tileCols,
    row_block: rowBlock,
    chunks: chunks.length,
    timed_ms: times,
    summary: {
      median_dispatch_ms: medianMs,
      effective_tmacs: (n * k * m) / (medianMs / 1000) / 1e12,
    },
    verification,
  };
}
/**
 * Time the int8 dp4a path for one linear layer: quantize f32 activations to packed int8, then run
 * the scaled int8 dot-product kernel.
 *
 * Behaviour:
 * - Returns a `skipped` record when the manifest has no `i8_dp4a` section.
 * - Weights come from `manifest.i8_dp4a.packed_u32` / `scales_f32`.
 * - Each dispatch first runs QUANTIZE_X_F32_SHADER with one workgroup per row.
 * - It then runs the dot kernel: `makeI8ScaledDotShader16xWide` when `config.i8Kernel === "16xwide"`
 *   (16 rows per workgroup in y), otherwise `makeI8ScaledDotShader32xWide` (32 rows per workgroup).
 * - Row blocking follows the same storage-binding-size rule as the q4 bench.
 * - All blocks bind one `rowBlock * m` f32 output buffer.
 *
 * @param {GPUDevice} device - Device used for every buffer and pipeline.
 * @param {object} manifest - Linear entry with `shape.{K,N}` and an optional `i8_dp4a` section.
 * @param {string} bundleBaseUrl - Prefix prepended to each `i8_dp4a` file name.
 * @param {object} config - Reads n, i8TileCols, i8Kernel, rowBlockSize, warmupRuns, timedRuns,
 *   verify, verifyRows, prefixCount.
 * @param {*} inputF32 - Activation data for `createBuffer`.
 *   NOTE(review): presumably n x K f32; confirm against the caller.
 * @returns {Promise<object>} Timings, median/TMAC summary and optional verification, or a skipped record.
 * @throws {Error} When verification is requested for a run that needs more than one row block.
 */
async function runI8Bench(device, manifest, bundleBaseUrl, config, inputF32) {
  const i8Section = manifest.i8_dp4a;
  if (!i8Section) {
    return {mode: "i8_dp4a", skipped: "bundle has no i8_dp4a section"};
  }
  const {K: k, N: m} = manifest.shape;
  const n = Number(config.n || 256);
  const kWords = i8Section.k_words;
  const tileCols = Number(config.i8TileCols || 96);
  const i8Kernel = config.i8Kernel || "32xwide";
  const isSixteenWide = i8Kernel === "16xwide";
  const i8RowsPerWorkgroup = isSixteenWide ? 16 : 32;
  // Largest row count whose f32 output (rows * m * 4 bytes) fits in one storage binding.
  const bindingBytes = device.limits.maxStorageBufferBindingSize || 134217728;
  const maxOutputRows = Math.max(1, Math.floor(bindingBytes / (m * 4)));
  const rowBlock = Math.max(1, Math.min(n, Number(config.rowBlockSize || maxOutputRows)));

  const [wBuf, scaleBuf] = await Promise.all(
    [i8Section.packed_u32, i8Section.scales_f32].map((file) => fetchArrayBuffer(`${bundleBaseUrl}${file}`)),
  );
  const wPacked = new Uint32Array(wBuf);
  const wScales = new Float32Array(scaleBuf);

  const {STORAGE, COPY_SRC, UNIFORM} = GPUBufferUsage;
  const xF32Buffer = createBuffer(device, inputF32, STORAGE);
  const xPackedBuffer = createEmptyBuffer(device, n * kWords * 4, STORAGE);
  const xScalesBuffer = createEmptyBuffer(device, n * 4, STORAGE);
  const wBuffer = createBuffer(device, wPacked, STORAGE);
  const wScaleBuffer = createBuffer(device, wScales, STORAGE);
  const yBuffer = createEmptyBuffer(device, rowBlock * m * 4, STORAGE | COPY_SRC);
  const quantParamsBuffer = createBuffer(device, new Uint32Array([n, k, m, kWords]), UNIFORM);

  const quantModule = device.createShaderModule({code: QUANTIZE_X_F32_SHADER});
  const dotShaderCode = isSixteenWide
    ? makeI8ScaledDotShader16xWide(tileCols)
    : makeI8ScaledDotShader32xWide(tileCols);
  const dotModule = device.createShaderModule({code: dotShaderCode});
  const quantPipeline = await device.createComputePipelineAsync({
    layout: "auto",
    compute: {module: quantModule, entryPoint: "main"},
  });
  const dotPipeline = await device.createComputePipelineAsync({
    layout: "auto",
    compute: {module: dotModule, entryPoint: "main"},
  });
  const quantBindGroup = device.createBindGroup({
    layout: quantPipeline.getBindGroupLayout(0),
    entries: [xF32Buffer, xPackedBuffer, xScalesBuffer, quantParamsBuffer]
      .map((buffer, binding) => ({binding, resource: {buffer}})),
  });

  const chunks = [];
  for (let rowOffset = 0; rowOffset < n; rowOffset += rowBlock) {
    const chunkRows = Math.min(rowBlock, n - rowOffset);
    const paramsBuffer = createBuffer(
      device,
      new Uint32Array([chunkRows, k, m, kWords, rowOffset, 0, 0, 0]),
      UNIFORM,
    );
    const boundBuffers = [xPackedBuffer, wBuffer, xScalesBuffer, wScaleBuffer, yBuffer, paramsBuffer];
    const dotBindGroup = device.createBindGroup({
      layout: dotPipeline.getBindGroupLayout(0),
      entries: boundBuffers.map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    chunks.push({rowOffset, chunkRows, dotBindGroup});
  }

  const workgroupsX = Math.ceil(m / tileCols);
  const dispatch = async () => {
    const encoder = device.createCommandEncoder();
    const quantPass = encoder.beginComputePass();
    quantPass.setPipeline(quantPipeline);
    quantPass.setBindGroup(0, quantBindGroup);
    quantPass.dispatchWorkgroups(n);
    quantPass.end();
    for (const {chunkRows, dotBindGroup} of chunks) {
      const dotPass = encoder.beginComputePass();
      dotPass.setPipeline(dotPipeline);
      dotPass.setBindGroup(0, dotBindGroup);
      dotPass.dispatchWorkgroups(workgroupsX, Math.ceil(chunkRows / i8RowsPerWorkgroup));
      dotPass.end();
    }
    device.queue.submit([encoder.finish()]);
    await device.queue.onSubmittedWorkDone();
  };

  const warmupRuns = Number(config.warmupRuns ?? 1);
  for (let run = 0; run < warmupRuns; run += 1) {
    await dispatch();
  }
  const timedRuns = Number(config.timedRuns ?? 3);
  const times = [];
  while (times.length < timedRuns) {
    const startedAt = performance.now();
    await dispatch();
    times.push(performance.now() - startedAt);
  }

  let verification = null;
  if (config.verify) {
    if (chunks.length > 1) {
      throw new Error("verification for row-blocked i8 runs is not implemented");
    }
    const output = await readFloat32Buffer(device, yBuffer, rowBlock * m);
    const verifyRows = Math.min(n, Number(config.verifyRows || 2));
    const verifyCols = Math.min(m, Number(config.prefixCount || 8));
    const expected = [];
    for (let row = 0; row < verifyRows; row += 1) {
      for (let col = 0; col < verifyCols; col += 1) {
        const reference = i8CpuOutput(row, col, k, m, kWords, inputF32, wPacked, wScales);
        expected.push({row, col, m, expected: reference});
      }
    }
    verification = verificationSummary(output, expected);
  }

  const medianMs = median(times);
  return {
    mode: "i8_dp4a",
    shape: {n, k, m},
    kernel: i8Kernel,
    tile_cols: tileCols,
    row_block: rowBlock,
    chunks: chunks.length,
    timed_ms: times,
    summary: {
      median_dispatch_ms: medianMs,
      effective_tmacs: (n * k * m) / (medianMs / 1000) / 1e12,
    },
    verification,
  };
}
/**
 * Benchmark the standalone custom low-bit linear bundle (q4 zero-point and/or i8 dp4a kernels).
 *
 * Steps:
 * - Fetches `${bundleBaseUrl}manifest.json`.
 * - Requests a high-performance adapter and a device with `shader-f16`.
 * - Runs the selected benches.
 * - Destroys the device afterwards, whether the benches succeed or throw, so repeated sweeps do
 *   not accumulate GPU devices.
 *
 * The i8 bench is reported as skipped when `packed_4x8_integer_dot_product` is not listed in
 * `navigator.gpu.wgslLanguageFeatures`.
 *
 * @param {object} [config] - Reads bundleBaseUrl (default "/bundle/"), mode ("q4" | "i8" | "both",
 *   default "both"), n (default 256), warmupRuns, timedRuns, verify. The whole object is forwarded
 *   to runQ4Bench / runI8Bench.
 * @returns {Promise<object>} Manifest summary, adapter and WGSL features, config echo, per-mode results.
 * @throws {Error} When WebGPU, the adapter or shader-f16 is unavailable, or the manifest fetch fails.
 */
async function runCustomLowbitLinearBench(config = {}) {
  if (!navigator.gpu) {
    throw new Error("navigator.gpu is not available");
  }
  const bundleBaseUrl = config.bundleBaseUrl || "/bundle/";
  const manifestUrl = `${bundleBaseUrl}manifest.json`;
  const manifestResponse = await fetch(manifestUrl);
  // Report the URL and HTTP status instead of an opaque JSON parse error on a failed fetch.
  if (!manifestResponse.ok) {
    throw new Error(`failed to fetch ${manifestUrl}: HTTP ${manifestResponse.status}`);
  }
  const manifest = await manifestResponse.json();
  const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"});
  if (!adapter) {
    throw new Error("WebGPU adapter is not available");
  }
  if (!adapter.features.has("shader-f16")) {
    throw new Error("WebGPU adapter does not expose shader-f16");
  }
  const wgslLanguageFeatures = navigator.gpu.wgslLanguageFeatures || new Set();
  const hasDp4a = wgslLanguageFeatures.has("packed_4x8_integer_dot_product");
  const mode = config.mode || "both";
  const runQ4 = mode === "q4" || mode === "both";
  const runI8 = mode === "i8" || mode === "both";
  const n = Number(config.n || 256);
  const k = manifest.shape.K;
  const device = await adapter.requestDevice({requiredFeatures: ["shader-f16"]});
  const results = [];
  try {
    // Inputs are built lazily so a q4-only run never allocates the f32 activations.
    if (runQ4) {
      results.push(await runQ4Bench(device, manifest, bundleBaseUrl, config, buildInputF16(n, k)));
    }
    if (runI8) {
      if (hasDp4a) {
        results.push(await runI8Bench(device, manifest, bundleBaseUrl, config, buildInputF32(n, k)));
      } else {
        results.push({mode: "i8_dp4a", skipped: "adapter lacks packed-4x8-integer-dot-product"});
      }
    }
  } finally {
    try {
      device.destroy?.();
    } catch (_) {
      // Best-effort teardown: a destroy failure must not mask a bench error.
    }
  }
  return {
    verdict: "custom-lowbit-linear-bench-completed",
    manifest: {
      name: manifest.name,
      source_node: manifest.source_node,
      shape: manifest.shape,
      matmul_nbits: manifest.matmul_nbits,
    },
    adapter_features: Array.from(adapter.features).sort(),
    wgsl_language_features: Array.from(wgslLanguageFeatures).sort(),
    config: {
      n,
      mode,
      warmupRuns: Number(config.warmupRuns ?? 1),
      timedRuns: Number(config.timedRuns ?? 3),
      verify: Boolean(config.verify),
    },
    results,
  };
}
// Publish the bench entry point on `window` (presumably so a page script or automation driver can call it).
Object.assign(window, {runCustomLowbitLinearBench});
/**
 * Benchmark one linear layer described by a full-transformer manifest entry.
 *
 * Steps:
 * - The entry is normalized with `normalizeFullLinearEntry`, which maps the `qweight_u32*` names
 *   onto the `packed_u32*` names the q4 bench reads.
 * - A device is created for this call with `shader-f16`. It also requests a 32768-byte
 *   `maxComputeWorkgroupStorageSize` when the adapter supports it.
 * - The device is destroyed when the benches finish or throw.
 *
 * @param {object} [config] - Reads the following; the object is also forwarded to runQ4Bench / runI8Bench:
 *   - manifest (required): linear entry with shape, q4 and matmul_nbits.
 *   - bundleBaseUrl: default "/runtime/custom_lowbit/full_transformer/".
 *   - mode: "q4" | "i8" | "both", default "q4".
 *   - n: default 16.
 *   - warmupRuns, timedRuns, verify.
 * @returns {Promise<object>} Manifest summary, adapter and WGSL features, config echo, per-mode results.
 * @throws {Error} When WebGPU, the adapter, shader-f16, or a valid manifest entry is missing.
 */
async function runCustomLowbitLinearManifestBench(config = {}) {
  if (!navigator.gpu) {
    throw new Error("navigator.gpu is not available");
  }
  const manifest = config.manifest;
  if (!manifest || !manifest.shape || !manifest.q4 || !manifest.matmul_nbits) {
    throw new Error("runCustomLowbitLinearManifestBench requires a manifest entry with shape, q4, and matmul_nbits");
  }
  // Same q4 key normalization the full-transformer loader uses.
  const linearManifest = normalizeFullLinearEntry(manifest);
  const bundleBaseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
  const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"});
  if (!adapter) {
    throw new Error("WebGPU adapter is not available");
  }
  if (!adapter.features.has("shader-f16")) {
    throw new Error("WebGPU adapter does not expose shader-f16");
  }
  const wgslLanguageFeatures = navigator.gpu.wgslLanguageFeatures || new Set();
  const deviceDescriptor = {requiredFeatures: ["shader-f16"]};
  if (adapter.limits && adapter.limits.maxComputeWorkgroupStorageSize >= 32768) {
    deviceDescriptor.requiredLimits = {maxComputeWorkgroupStorageSize: 32768};
  }
  const mode = config.mode || "q4";
  const runQ4 = mode === "q4" || mode === "both";
  const runI8 = mode === "i8" || mode === "both";
  const n = Number(config.n || 16);
  const k = manifest.shape.K;
  const device = await adapter.requestDevice(deviceDescriptor);
  const results = [];
  try {
    // Inputs are built only for the benches that actually run.
    if (runQ4) {
      results.push(await runQ4Bench(device, linearManifest, bundleBaseUrl, config, buildInputF16(n, k)));
    }
    if (runI8) {
      if (!linearManifest.i8_dp4a) {
        results.push({mode: "i8_dp4a", skipped: "manifest entry does not contain i8_dp4a weights"});
      } else if (!wgslLanguageFeatures.has("packed_4x8_integer_dot_product")) {
        results.push({mode: "i8_dp4a", skipped: "adapter lacks packed-4x8-integer-dot-product"});
      } else {
        results.push(await runI8Bench(device, linearManifest, bundleBaseUrl, config, buildInputF32(n, k)));
      }
    }
  } finally {
    try {
      device.destroy?.();
    } catch (_) {
      // Best-effort teardown: a destroy failure must not mask a bench error.
    }
  }
  return {
    verdict: "custom-lowbit-linear-manifest-bench-completed",
    manifest: {
      id: manifest.id || "",
      name: manifest.name || "",
      source_node: manifest.source_node,
      shape: manifest.shape,
      matmul_nbits: manifest.matmul_nbits,
    },
    adapter_features: Array.from(adapter.features).sort(),
    wgsl_language_features: Array.from(wgslLanguageFeatures).sort(),
    config: {
      n,
      mode,
      warmupRuns: Number(config.warmupRuns ?? 1),
      timedRuns: Number(config.timedRuns ?? 3),
      verify: Boolean(config.verify),
    },
    results,
  };
}
// Publish the manifest-driven bench on `window` (presumably so a page script or automation driver can call it).
Object.assign(window, {runCustomLowbitLinearManifestBench});
/**
 * Return a shallow copy of a full-transformer linear entry whose `q4` section uses the key names
 * the q4 bench reads.
 *
 * Key mapping:
 * - `packed_u32` falls back to `qweight_u32`.
 * - `packed_u32_count` falls back to `qweight_u32_count`.
 * - `scales_f16` and `zero_points_u32` are always present as keys, possibly `undefined`.
 *
 * The input object is not modified.
 *
 * @param {object} entry - Linear entry with `shape`, `q4`, and `matmul_nbits`.
 * @returns {object} The normalized entry.
 * @throws {Error} When any of `shape`, `q4`, or `matmul_nbits` is missing.
 */
function normalizeFullLinearEntry(entry) {
  const hasRequiredSections = entry && entry.shape && entry.q4 && entry.matmul_nbits;
  if (!hasRequiredSections) {
    throw new Error("full transformer linear entry requires shape, q4, and matmul_nbits");
  }
  const sourceQ4 = entry.q4;
  const q4 = {
    ...sourceQ4,
    packed_u32: sourceQ4.packed_u32 || sourceQ4.qweight_u32,
    scales_f16: sourceQ4.scales_f16,
    zero_points_u32: sourceQ4.zero_points_u32,
    packed_u32_count: sourceQ4.packed_u32_count || sourceQ4.qweight_u32_count,
  };
  return {...entry, q4};
}
/**
 * Find a linear layer in a full-transformer manifest and return it normalized.
 *
 * An item matches when its `id`, `name`, or `source_node` strictly equals `id`;
 * the first match in `fullManifest.linears` wins.
 *
 * @param {object} fullManifest - manifest whose `linears` array is searched.
 * @param {string} id - layer id, name, or source node to look for.
 * @returns {object} the matching entry passed through normalizeFullLinearEntry.
 * @throws {Error} when no entry matches (including a missing/non-array `linears`).
 */
function findFullManifestLinear(fullManifest, id) {
  const hasLinears = fullManifest && Array.isArray(fullManifest.linears);
  const candidates = hasLinears ? fullManifest.linears : [];
  let match;
  for (const item of candidates) {
    if (item.id === id || item.name === id || item.source_node === id) {
      match = item;
      break;
    }
  }
  if (!match) {
    throw new Error(`full transformer manifest does not contain linear: ${id}`);
  }
  return normalizeFullLinearEntry(match);
}
/**
 * Build a sinusoidal timestep embedding encoded as f16 bit patterns.
 *
 * The timestep is first multiplied by `timeFactor`. For i in [0, floor(dim/2)),
 * the angle is `t * exp(-ln(maxPeriod) * i / half)`; cos values fill the first
 * half and sin values the second. An odd `dim` leaves the last slot at 0.
 *
 * @param {number} timestep - raw timestep value.
 * @param {number} [dim=256] - embedding length.
 * @param {number} [maxPeriod=10000] - longest sinusoid period.
 * @param {number} [timeFactor=1000] - scale applied to `timestep` before embedding.
 * @returns {Uint16Array} `dim` f16 bit patterns produced by float32ToFloat16Bits.
 */
function makeTimestepEmbeddingF16(timestep, dim = 256, maxPeriod = 10000, timeFactor = 1000) {
  const halfDim = Math.floor(dim / 2);
  const embedding = new Uint16Array(dim);
  const t = timestep * timeFactor;
  const negLogPeriod = -Math.log(maxPeriod);
  for (let i = 0; i < halfDim; i += 1) {
    const angle = t * Math.exp(negLogPeriod * i / halfDim);
    embedding[i] = float32ToFloat16Bits(Math.cos(angle));
    embedding[halfDim + i] = float32ToFloat16Bits(Math.sin(angle));
  }
  if (dim % 2) {
    // Odd length: the trailing slot has no cos/sin partner and stays zero.
    embedding[dim - 1] = 0;
  }
  return embedding;
}
/**
 * Create a compute stage that multiplies an activation buffer by q4 weights
 * (per-group f16 scales + zero points) using the 32xwide q4 shader.
 *
 * When `options.cacheKey` is set, the stage object is requested from
 * getCachedStageObject under a key built from the layer id and every
 * shape/tiling parameter; otherwise an empty key is passed (presumably meaning
 * "do not cache" — depends on getCachedStageObject).
 *
 * @param {GPUDevice} device
 * @param {object} entry - full-transformer linear manifest entry.
 * @param {string} bundleBaseUrl - base URL the q4 file paths are resolved against.
 * @param {GPUBuffer} inputBuffer - activations, bound at binding 0.
 * @param {GPUBuffer} outputBuffer - results, bound at binding 4.
 * @param {number} rows - activation rows; workgroupsY = ceil(rows / 32).
 * @param {number} tileCols - output columns per workgroup; workgroupsX = ceil(m / tileCols).
 * @param {boolean} [outputF16=false] - selects the f16-output shader variant.
 * @param {object} [options] - colOffset, outputCols, outputStride, kChunk, unrollK, cacheKey.
 * @returns {Promise<object>} `{id, shape, pipeline, bindGroup, buffers, workgroupsX, workgroupsY, kChunk, unrollK}`.
 */
async function createQ4LinearStage(device, entry, bundleBaseUrl, inputBuffer, outputBuffer, rows, tileCols, outputF16 = false, options = {}) {
  const manifest = normalizeFullLinearEntry(entry);
  const {K: k, N: weightM} = manifest.shape;
  const colOffset = Number(options.colOffset ?? 0);
  const m = Number(options.outputCols ?? (weightM - colOffset));
  const outputStride = Number(options.outputStride ?? weightM);
  const {
    block_size: groupSize,
    groups_per_col: groups,
    packed_group_words: packedWords,
  } = manifest.matmul_nbits;
  // Wide tiles use a smaller K chunk by default.
  const kChunk = Number(options.kChunk ?? (tileCols >= 256 ? 32 : 64));
  const unrollK = options.unrollK ?? false;
  const layerLabel = manifest.id || manifest.name || manifest.source_node;
  const precisionTag = outputF16 ? "f16" : "f32";
  const unrollTag = unrollK ? 1 : 0;
  let stageCacheKey = "";
  if (options.cacheKey) {
    stageCacheKey = `q4-stage:${options.cacheKey}:${layerLabel}:${rows}:${k}:${m}:${tileCols}:${kChunk}:${unrollTag}:${precisionTag}:${colOffset}:${weightM}:${outputStride}`;
  }

  const buildStage = async () => {
    const baseUrl = new URL(bundleBaseUrl, window.location.href);
    const resolveUrl = (path) => new URL(path, baseUrl).toString();
    const q4Url = resolveUrl(manifest.q4.packed_u32);
    const scalesUrl = resolveUrl(manifest.q4.scales_f16);
    const zpUrl = resolveUrl(manifest.q4.zero_points_u32);
    const [q4Buf, scalesBuf, zpBuf] = await Promise.all(
      [q4Url, scalesUrl, zpUrl].map((url) => fetchArrayBuffer(url)),
    );
    const q4Buffer = createImmutableBuffer(device, new Uint32Array(q4Buf), GPUBufferUsage.STORAGE, q4Url);
    const scaleBuffer = createImmutableBuffer(device, new Uint16Array(scalesBuf), GPUBufferUsage.STORAGE, scalesUrl);
    const zpBuffer = createImmutableBuffer(device, new Uint32Array(zpBuf), GPUBufferUsage.STORAGE, zpUrl);
    // Uniform layout: rows, K, M, group size, groups, packed words, pad,
    // column offset, full weight width, output stride, pad, pad.
    const uniformWords = [rows, k, m, groupSize, groups, packedWords, 0, colOffset, weightM, outputStride, 0, 0];
    const paramsBuffer = createBuffer(device, new Uint32Array(uniformWords), GPUBufferUsage.UNIFORM);
    const pipeline = await getCachedComputePipeline(
      device,
      `q4-zp-32xwide:${tileCols}:${kChunk}:${unrollTag}:${precisionTag}`,
      makeQ4ZpShader32xWide(tileCols, outputF16, kChunk, unrollK),
    );
    const boundBuffers = [inputBuffer, q4Buffer, scaleBuffer, zpBuffer, outputBuffer, paramsBuffer];
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: boundBuffers.map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    return {
      id: layerLabel,
      shape: manifest.shape,
      pipeline,
      bindGroup,
      buffers: [q4Buffer, scaleBuffer, zpBuffer, paramsBuffer],
      workgroupsX: Math.ceil(m / tileCols),
      workgroupsY: Math.ceil(rows / 32),
      kChunk,
      unrollK,
    };
  };

  return await getCachedStageObject(device, stageCacheKey, buildStage);
}
/**
 * Create a compute stage for q4 weights against int8-packed activations using
 * the dp4a 32xwide shader (activations and their scales come in as separate
 * buffers).
 *
 * When `options.cacheKey` is set, the stage object is requested from
 * getCachedStageObject under a key built from the layer id and every
 * shape/tiling parameter; otherwise an empty key is passed.
 *
 * @param {GPUDevice} device
 * @param {object} entry - full-transformer linear manifest entry.
 * @param {string} bundleBaseUrl - base URL the q4 file paths are resolved against.
 * @param {GPUBuffer} inputPackedBuffer - packed activations, binding 0.
 * @param {GPUBuffer} inputScalesBuffer - activation scales, binding 2.
 * @param {GPUBuffer} outputBuffer - results, binding 5.
 * @param {number} rows - activation rows; workgroupsY = ceil(rows / 32).
 * @param {number} tileCols - output columns per workgroup.
 * @param {number} wChunkCols - weight-chunk width forwarded to the shader generator.
 * @param {boolean} [outputF16=true] - selects the f16-output shader variant.
 * @param {object} [options] - colOffset, outputCols, outputStride, xScaleGroupsPerRow, cacheKey.
 * @returns {Promise<object>} `{id, shape, pipeline, bindGroup, buffers, workgroupsX, workgroupsY, q4Dp4a: true}`.
 * @throws {Error} (as a rejection) when K is not a multiple of 4.
 */
async function createQ4Dp4aLinearStage(
  device,
  entry,
  bundleBaseUrl,
  inputPackedBuffer,
  inputScalesBuffer,
  outputBuffer,
  rows,
  tileCols,
  wChunkCols,
  outputF16 = true,
  options = {},
) {
  const manifest = normalizeFullLinearEntry(entry);
  const {K: k, N: weightM} = manifest.shape;
  const colOffset = Number(options.colOffset ?? 0);
  const m = Number(options.outputCols ?? (weightM - colOffset));
  const outputStride = Number(options.outputStride ?? weightM);
  const xScaleGroupsPerRow = Number(options.xScaleGroupsPerRow ?? 1);
  const {
    block_size: groupSize,
    groups_per_col: groups,
    packed_group_words: packedWords,
  } = manifest.matmul_nbits;
  const layerLabel = manifest.id || manifest.name || manifest.source_node;
  // dp4a consumes four int8 lanes per u32 word, so K must split evenly.
  if (k % 4 !== 0) {
    throw new Error(`q4 dp4a stage requires K divisible by 4; got K=${k} for ${layerLabel}`);
  }
  const kWords = k / 4;
  const precisionTag = outputF16 ? "f16" : "f32";
  const stageCacheKey = options.cacheKey
    ? [
      "q4-dp4a-stage", options.cacheKey, layerLabel, rows, k, m, tileCols, wChunkCols,
      precisionTag, colOffset, weightM, outputStride, xScaleGroupsPerRow,
    ].map(String).join(":")
    : "";

  const buildStage = async () => {
    const baseUrl = new URL(bundleBaseUrl, window.location.href);
    const resolveUrl = (path) => new URL(path, baseUrl).toString();
    const q4Url = resolveUrl(manifest.q4.packed_u32);
    const scalesUrl = resolveUrl(manifest.q4.scales_f16);
    const zpUrl = resolveUrl(manifest.q4.zero_points_u32);
    const [q4Buf, scalesBuf, zpBuf] = await Promise.all(
      [q4Url, scalesUrl, zpUrl].map((url) => fetchArrayBuffer(url)),
    );
    const q4Buffer = createImmutableBuffer(device, new Uint32Array(q4Buf), GPUBufferUsage.STORAGE, q4Url);
    const scaleBuffer = createImmutableBuffer(device, new Uint16Array(scalesBuf), GPUBufferUsage.STORAGE, scalesUrl);
    const zpBuffer = createImmutableBuffer(device, new Uint32Array(zpBuf), GPUBufferUsage.STORAGE, zpUrl);
    // Uniform layout: rows, K, M, group size, groups, packed words, pad,
    // column offset, full weight width, output stride, K words, x-scale groups/row.
    const uniformWords = [rows, k, m, groupSize, groups, packedWords, 0, colOffset, weightM, outputStride, kWords, xScaleGroupsPerRow];
    const paramsBuffer = createBuffer(device, new Uint32Array(uniformWords), GPUBufferUsage.UNIFORM);
    const pipeline = await getCachedComputePipeline(
      device,
      `q4-zp-dp4a-32xwide:${tileCols}:${wChunkCols}:${precisionTag}`,
      makeQ4ZpDp4aShader32xWide(tileCols, wChunkCols, outputF16),
    );
    const boundBuffers = [inputPackedBuffer, q4Buffer, inputScalesBuffer, scaleBuffer, zpBuffer, outputBuffer, paramsBuffer];
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: boundBuffers.map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    return {
      id: layerLabel,
      shape: manifest.shape,
      pipeline,
      bindGroup,
      buffers: [q4Buffer, scaleBuffer, zpBuffer, paramsBuffer],
      workgroupsX: Math.ceil(m / tileCols),
      workgroupsY: Math.ceil(rows / 32),
      q4Dp4a: true,
    };
  };

  return await getCachedStageObject(device, stageCacheKey, buildStage);
}
/**
 * Create a compute stage that converts `count` f32 values to f16, optionally
 * applying SiLU (the shader receives the flag in its uniform block).
 *
 * @param {GPUDevice} device
 * @param {GPUBuffer} inputBuffer - f32 source, binding 0.
 * @param {GPUBuffer} outputBuffer - f16 destination, binding 1.
 * @param {number} count - number of elements; dispatched as ceil(count / 256) workgroups.
 * @param {boolean} applySilu - forwarded to the shader as 1/0 and echoed back on the stage.
 * @param {string} [cacheKey=""] - when non-empty, scopes the getCachedStageObject key.
 * @returns {Promise<object>} `{pipeline, bindGroup, buffers, workgroupsX, applySilu}`.
 */
async function createF32ToF16Stage(device, inputBuffer, outputBuffer, count, applySilu, cacheKey = "") {
  const siluFlag = applySilu ? 1 : 0;
  const stageKey = cacheKey ? `f32-to-f16-stage:${cacheKey}:${count}:${siluFlag}` : "";
  const buildStage = async () => {
    const pipeline = await getCachedComputePipeline(device, "f32-to-f16", F32_TO_F16_SHADER);
    const paramsBuffer = createBuffer(device, new Uint32Array([count, siluFlag, 0, 0]), GPUBufferUsage.UNIFORM);
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [inputBuffer, outputBuffer, paramsBuffer].map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    return {
      pipeline,
      bindGroup,
      buffers: [paramsBuffer],
      workgroupsX: Math.ceil(count / 256),
      applySilu,
    };
  };
  return await getCachedStageObject(device, stageKey, buildStage);
}
/**
 * Map a modulation kind alias to the manifest id of its modulation linear.
 *
 * Matching is case-insensitive; a falsy `kind` means "single". Unknown kinds
 * are returned lower-cased so callers can pass a raw linear id through.
 *
 * @param {string} [kind] - "single"/"single_stream", "double_img"/"img", "double_txt"/"txt", or an id.
 * @returns {string} manifest linear id.
 */
function modulationOutputId(kind) {
  const key = String(kind || "single").toLowerCase();
  switch (key) {
    case "single":
    case "single_stream":
      return "single_stream_modulation.lin";
    case "double_img":
    case "img":
      return "double_stream_modulation_img.lin";
    case "double_txt":
    case "txt":
      return "double_stream_modulation_txt.lin";
    default:
      return key;
  }
}
/**
 * Find the q/k norm constants entry for `id` in `fullManifest.constants.qk_norms`.
 *
 * An item matches on `id` or `name`; the match must also carry a `files` object.
 *
 * @param {object} fullManifest - full transformer manifest.
 * @param {string} id - norm id (e.g. a block's `...norm` name).
 * @returns {object} the matching entry.
 * @throws {Error} when no entry matches or the match has no `files`.
 */
function findQkNormEntry(fullManifest, id) {
  const constants = fullManifest && fullManifest.constants;
  const entries = constants && Array.isArray(constants.qk_norms) ? constants.qk_norms : [];
  const entry = entries.find((item) => item.id === id || item.name === id);
  if (entry && entry.files) {
    return entry;
  }
  throw new Error(`full transformer manifest does not contain q/k norm constants: ${id}`);
}
/**
 * Fetch the query/key norm scale files for a q/k norm entry.
 *
 * File paths from the entry are resolved against `bundleBaseUrl`, which is
 * itself resolved against the page URL.
 *
 * @param {object} fullManifest - full transformer manifest.
 * @param {string} bundleBaseUrl - base URL for the bundle files.
 * @param {string} id - q/k norm id passed to findQkNormEntry.
 * @returns {Promise<object>} `{entry, queryScale, keyScale, queryScaleUrl, keyScaleUrl}`;
 *   the scales are Uint16Array views (f16 bit patterns, per the manifest field names).
 */
async function loadQkNormScales(fullManifest, bundleBaseUrl, id) {
  const entry = findQkNormEntry(fullManifest, id);
  const bundleRoot = new URL(bundleBaseUrl, window.location.href);
  const queryScaleUrl = new URL(entry.files.query_norm_scale_f16, bundleRoot).toString();
  const keyScaleUrl = new URL(entry.files.key_norm_scale_f16, bundleRoot).toString();
  const buffers = await Promise.all([queryScaleUrl, keyScaleUrl].map((url) => fetchArrayBuffer(url)));
  return {
    entry,
    queryScale: new Uint16Array(buffers[0]),
    keyScale: new Uint16Array(buffers[1]),
    queryScaleUrl,
    keyScaleUrl,
  };
}
/**
 * Build int8 dp4a weights for a linear layer by dequantizing its q4 weights on
 * the CPU and re-quantizing each output column symmetrically to [-127, 127].
 *
 * Layout produced: `weight[word * N + col]` packs four int8 lanes (K indices
 * word*4 .. word*4+3, lane 0 in the low byte); K is padded with zeros up to a
 * multiple of 4. `scales[col]` is absmax/127 for that column (1 if all zero).
 *
 * Results are memoized in `dynamicI8BundleCache` under `${bundleBaseUrl}|${id}`.
 * Concurrent callers share the in-flight promise, and a failed load is evicted
 * so it can be retried. Cache hits report `cacheHit: true` and zero timings.
 *
 * @param {object} entry - full-transformer linear manifest entry (q4 + matmul_nbits).
 * @param {string} bundleBaseUrl - base URL the q4 file paths are resolved against.
 * @returns {Promise<object>} `{manifest, weight, scales, weightCacheKey, scalesCacheKey, loadMs, convertMs, cacheHit}`.
 */
async function loadDynamicI8BundleFromQ4(entry, bundleBaseUrl) {
  const manifest = normalizeFullLinearEntry(entry);
  const cacheKey = `${bundleBaseUrl}|${manifest.id || manifest.name || manifest.source_node}`;
  if (dynamicI8BundleCache.has(cacheKey)) {
    const cached = await dynamicI8BundleCache.get(cacheKey);
    return {
      ...cached,
      cacheHit: true,
      loadMs: 0,
      convertMs: 0,
    };
  }
  const k = manifest.shape.K;
  const m = manifest.shape.N;
  const groupSize = manifest.matmul_nbits.block_size;
  const packedWords = manifest.matmul_nbits.packed_group_words;
  const kWords = Math.ceil(k / 4);
  const promise = (async () => {
    const start = performance.now();
    // Resolve file paths like the other loaders in this file
    // (createQ4LinearStage, loadI8BundleFromFullEntry): plain string
    // concatenation mishandled absolute paths and produced different URL
    // strings for the same file than those loaders use.
    const bundleRoot = new URL(bundleBaseUrl, window.location.href);
    const [q4Buf, scalesBuf, zpBuf] = await Promise.all(
      [manifest.q4.packed_u32, manifest.q4.scales_f16, manifest.q4.zero_points_u32]
        .map((path) => fetchArrayBuffer(new URL(path, bundleRoot).toString())),
    );
    const fetchMs = performance.now() - start;
    const q4 = new Uint32Array(q4Buf);
    const scales = new Uint16Array(scalesBuf);
    const zeroPoints = new Uint32Array(zpBuf);
    const wPacked = new Uint32Array(kWords * m);
    const wScales = new Float32Array(m);
    // Dequantized values of the current column. Kept in float64 so the
    // quantization pass sees exactly the values the absmax pass computed,
    // while dequantizing each weight only once instead of twice.
    const column = new Float64Array(k);
    function dequant(col, kIndex) {
      const group = Math.floor(kIndex / groupSize);
      const inner = kIndex - group * groupSize;
      const word = Math.floor(inner / 8);
      const shift = (inner & 7) * 4;
      const packed = q4[(group * packedWords + word) * m + col];
      const nibble = (packed >>> shift) & 15;
      const zp = zeroPoints[group * m + col] & 15;
      return (nibble - zp) * float16BitsToFloat32(scales[group * m + col]);
    }
    const convertStart = performance.now();
    for (let col = 0; col < m; ++col) {
      let absmax = 0;
      for (let kIndex = 0; kIndex < k; ++kIndex) {
        const value = dequant(col, kIndex);
        column[kIndex] = value;
        absmax = Math.max(absmax, Math.abs(value));
      }
      const scale = absmax > 0 ? absmax / 127 : 1;
      wScales[col] = scale;
      for (let word = 0; word < kWords; ++word) {
        let packed = 0;
        for (let lane = 0; lane < 4; ++lane) {
          const kIndex = word * 4 + lane;
          let q = 0;
          if (kIndex < k) {
            q = roundToNearest(column[kIndex] / scale);
            q = Math.max(-127, Math.min(127, q));
          }
          // `>>> 0` keeps the packed word unsigned once lane 3 sets bit 31.
          packed = (packed | ((q & 0xff) << (lane * 8))) >>> 0;
        }
        wPacked[word * m + col] = packed;
      }
    }
    const convertMs = performance.now() - convertStart;
    return {
      manifest: {
        ...manifest,
        i8_dp4a: {
          k_words: kWords,
          packed_u32_count: wPacked.length,
          scale_count: wScales.length,
          generated_from_q4: true,
        },
      },
      weight: wPacked,
      scales: wScales,
      weightCacheKey: `${cacheKey}:dynamic-weight`,
      scalesCacheKey: `${cacheKey}:dynamic-scales`,
      loadMs: fetchMs,
      convertMs,
      cacheHit: false,
    };
  })();
  // Publish the in-flight promise so concurrent callers share one conversion.
  dynamicI8BundleCache.set(cacheKey, promise);
  try {
    const result = await promise;
    dynamicI8BundleCache.set(cacheKey, result);
    return result;
  } catch (err) {
    // Evict failures so a later call can retry.
    dynamicI8BundleCache.delete(cacheKey);
    throw err;
  }
}
/**
 * Load int8 dp4a weights for a linear layer.
 *
 * If the entry ships prepacked `i8_dp4a` files, they are fetched directly.
 * Otherwise the call is delegated to loadDynamicI8BundleFromQ4, which derives
 * them from the q4 weights. Prepacked results are memoized in
 * `dynamicI8BundleCache` under `${bundleBaseUrl}|${id}|prepacked-i8`, with the
 * same single-flight/evict-on-failure behavior as the dynamic path.
 *
 * @param {object} entry - full-transformer linear manifest entry.
 * @param {string} bundleBaseUrl - base URL the weight file paths are resolved against.
 * @returns {Promise<object>} bundle with `manifest`, `weight` (Uint32Array),
 *   `scales` (Float32Array), cache keys, timings, and `cacheHit`.
 */
async function loadI8BundleFromFullEntry(entry, bundleBaseUrl) {
  const manifest = normalizeFullLinearEntry(entry);
  if (!manifest.i8_dp4a) {
    return loadDynamicI8BundleFromQ4(manifest, bundleBaseUrl);
  }
  const layerLabel = manifest.id || manifest.name || manifest.source_node;
  const cacheKey = `${bundleBaseUrl}|${layerLabel}|prepacked-i8`;
  if (dynamicI8BundleCache.has(cacheKey)) {
    const hit = await dynamicI8BundleCache.get(cacheKey);
    return {...hit, cacheHit: true, loadMs: 0, convertMs: 0};
  }

  const fetchPrepacked = async () => {
    const startedAt = performance.now();
    const bundleRoot = new URL(bundleBaseUrl, window.location.href);
    const weightUrl = new URL(manifest.i8_dp4a.packed_u32, bundleRoot).toString();
    const scalesUrl = new URL(manifest.i8_dp4a.scales_f32, bundleRoot).toString();
    const [wBuf, scaleBuf] = await Promise.all([
      fetchArrayBuffer(weightUrl),
      fetchArrayBuffer(scalesUrl),
    ]);
    const manifestSummary = {
      id: manifest.id,
      name: manifest.name,
      source_node: manifest.source_node,
      shape: manifest.shape,
      i8_dp4a: manifest.i8_dp4a,
    };
    return {
      manifest: manifestSummary,
      weight: new Uint32Array(wBuf),
      scales: new Float32Array(scaleBuf),
      weightCacheKey: weightUrl,
      scalesCacheKey: scalesUrl,
      loadMs: performance.now() - startedAt,
      convertMs: 0,
      cacheHit: false,
      prepackedI8: true,
    };
  };

  const pending = fetchPrepacked();
  // Concurrent callers await this same promise until it settles.
  dynamicI8BundleCache.set(cacheKey, pending);
  let loaded;
  try {
    loaded = await pending;
  } catch (err) {
    dynamicI8BundleCache.delete(cacheKey);
    throw err;
  }
  dynamicI8BundleCache.set(cacheKey, loaded);
  return loaded;
}
/**
 * Load every int8 linear bundle and q/k norm scale set needed by the full
 * transformer: img_in, txt_in, final_layer.linear, the first
 * `doubleBlockCount` double blocks, and the first `singleBlockCount` single
 * blocks (plus their norms).
 *
 * All loads run concurrently, including linear1/linear2 within each single
 * block; they were previously awaited one after the other.
 *
 * @param {object} fullManifest - full transformer manifest.
 * @param {string} baseUrl - bundle base URL forwarded to the loaders.
 * @param {number} [doubleBlockCount=5] - number of double blocks to load.
 * @param {number} [singleBlockCount=20] - number of single blocks to load.
 * @returns {Promise<object>} `{imgIn, txtIn, finalLinear, doubleBlocks, singleBlocks, singleNorms}`.
 * @throws {Error} (as a rejection) if any linear or norm entry is missing or fails to load.
 */
async function loadFullTransformerAssetSet(fullManifest, baseUrl, doubleBlockCount = 5, singleBlockCount = 20) {
  // Async so a missing manifest entry becomes a rejection that Promise.all
  // tracks, rather than a synchronous throw that leaves sibling loads floating.
  const loadLinear = async (id) => loadI8BundleFromFullEntry(findFullManifestLinear(fullManifest, id), baseUrl);
  const loadNorm = (id) => loadQkNormScales(fullManifest, baseUrl, id);

  const loadDoubleBlock = async (blockIndex) => {
    const prefix = `double_blocks.${blockIndex}`;
    const [
      imgQkv, txtQkv, imgProj, txtProj, imgMlp0, imgMlp2, txtMlp0, txtMlp2,
      imgNorm, txtNorm,
    ] = await Promise.all([
      loadLinear(`${prefix}.img_attn.qkv`),
      loadLinear(`${prefix}.txt_attn.qkv`),
      loadLinear(`${prefix}.img_attn.proj`),
      loadLinear(`${prefix}.txt_attn.proj`),
      loadLinear(`${prefix}.img_mlp.0`),
      loadLinear(`${prefix}.img_mlp.2`),
      loadLinear(`${prefix}.txt_mlp.0`),
      loadLinear(`${prefix}.txt_mlp.2`),
      loadNorm(`${prefix}.img_attn.norm`),
      loadNorm(`${prefix}.txt_attn.norm`),
    ]);
    return {blockIndex, imgQkv, txtQkv, imgProj, txtProj, imgMlp0, imgMlp2, txtMlp0, txtMlp2, imgNorm, txtNorm};
  };

  const loadSingleBlock = async (blockIndex) => {
    const prefix = `single_blocks.${blockIndex}`;
    // linear1 and linear2 are independent, so fetch/convert them concurrently.
    const [linear1, linear2] = await Promise.all([
      loadLinear(`${prefix}.linear1`),
      loadLinear(`${prefix}.linear2`),
    ]);
    return {blockIndex, linear1, linear2};
  };

  const [imgIn, txtIn, finalLinear, doubleBlocks, singleBlocks, singleNorms] = await Promise.all([
    loadLinear("img_in"),
    loadLinear("txt_in"),
    loadLinear("final_layer.linear"),
    Promise.all(Array.from({length: doubleBlockCount}, (_, blockIndex) => loadDoubleBlock(blockIndex))),
    Promise.all(Array.from({length: singleBlockCount}, (_, blockIndex) => loadSingleBlock(blockIndex))),
    Promise.all(Array.from({length: singleBlockCount}, (_, blockIndex) => loadNorm(`single_blocks.${blockIndex}.norm`))),
  ]);
  return {imgIn, txtIn, finalLinear, doubleBlocks, singleBlocks, singleNorms};
}
/**
 * Walk a loaded asset set and tally linear bundles, cache hits, q/k norm
 * sets, and total byte size.
 *
 * A node with truthy `weight`, `scales`, and `manifest` counts as a linear
 * bundle (its weight + scales byteLength are added). A node with truthy
 * `queryScale` and `keyScale` counts as a q/k norm set. Neither kind is
 * descended into; arrays and other objects are traversed recursively.
 *
 * @param {*} assetSet - typically the result of loadFullTransformerAssetSet.
 * @returns {{linears: number, cacheHits: number, cacheMisses: number, qkNorms: number, bytes: number}}
 */
function summarizeFullTransformerAssetSet(assetSet) {
  const tally = {linears: 0, cacheHits: 0, qkNorms: 0, bytes: 0};

  function walk(node) {
    if (!node || typeof node !== "object") {
      return;
    }
    if (Array.isArray(node)) {
      node.forEach((child) => walk(child));
      return;
    }
    const isLinearBundle = node.weight && node.scales && node.manifest;
    if (isLinearBundle) {
      tally.linears += 1;
      if (node.cacheHit) {
        tally.cacheHits += 1;
      }
      tally.bytes += (node.weight.byteLength || 0) + (node.scales.byteLength || 0);
      return;
    }
    const isQkNormSet = node.queryScale && node.keyScale;
    if (isQkNormSet) {
      tally.qkNorms += 1;
      tally.bytes += (node.queryScale.byteLength || 0) + (node.keyScale.byteLength || 0);
      return;
    }
    Object.values(node).forEach((child) => walk(child));
  }

  walk(assetSet);
  const {linears, cacheHits, qkNorms, bytes} = tally;
  return {
    linears,
    cacheHits,
    cacheMisses: Math.max(0, linears - cacheHits),
    qkNorms,
    bytes,
  };
}
/**
 * Preload (and cache, via the loaders) the full transformer's int8 linear
 * bundles and q/k norm scales, and report timing plus a size summary.
 *
 * Block counts come from `maxDoubleBlocks ?? doubleBlockCount` (default 5,
 * clamped to [1, 5]) and `maxSingleBlocks ?? singleBlockCount` (default 20,
 * clamped to [0, 20]). An explicit 0 is honored; previously `||` turned
 * `maxSingleBlocks: 0` into 20. Non-numeric values fall back to the default.
 *
 * @param {object} [config]
 * @param {object} config.fullManifest - full transformer manifest (required).
 * @param {string} [config.bundleBaseUrl] - defaults to "/runtime/custom_lowbit/full_transformer/".
 * @returns {Promise<object>} `{verdict, load: {total_ms}, summary}`.
 * @throws {Error} (as a rejection) when `fullManifest` is missing.
 */
async function prepareCustomFluxTransformerAssets(config = {}) {
  if (!config.fullManifest) {
    throw new Error("custom full transformer asset prepare requires fullManifest");
  }
  const baseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
  // Truncate and clamp a requested count; NaN (non-numeric input) uses the default.
  const resolveCount = (requested, fallback, min, max) => {
    const parsed = Math.trunc(Number(requested ?? fallback));
    return Math.max(min, Math.min(max, Number.isNaN(parsed) ? fallback : parsed));
  };
  const doubleBlockCount = resolveCount(config.maxDoubleBlocks ?? config.doubleBlockCount, 5, 1, 5);
  const singleBlockCount = resolveCount(config.maxSingleBlocks ?? config.singleBlockCount, 20, 0, 20);
  const start = performance.now();
  const assets = await loadFullTransformerAssetSet(config.fullManifest, baseUrl, doubleBlockCount, singleBlockCount);
  const loadMs = performance.now() - start;
  return {
    verdict: "custom-flux-transformer-assets-prepared",
    load: {total_ms: loadMs},
    summary: {
      ...summarizeFullTransformerAssetSet(assets),
      doubleBlockCount,
      singleBlockCount,
    },
  };
}
/**
 * Benchmark the timestep → modulation chain on a dedicated WebGPU device:
 * time_in.in_layer (q4) → f32→f16 cast with SiLU → time_in.out_layer (q4) →
 * f32→f16 cast with SiLU → the selected modulation linear (q4).
 *
 * The device created here is destroyed when the bench finishes (success or
 * failure), so repeated runs do not accumulate GPU allocations.
 *
 * @param {object} [config]
 * @param {object} config.manifest - full transformer manifest containing the linears.
 * @param {string} [config.bundleBaseUrl] - defaults to "/runtime/custom_lowbit/full_transformer/".
 * @param {number} [config.q4TileCols=96] - output tile width for the q4 kernels.
 * @param {number} [config.timestep=1.0] - timestep fed to makeTimestepEmbeddingF16.
 * @param {string} [config.modulation] - alias or id resolved by modulationOutputId (also `modulationKind`).
 * @param {number} [config.warmupRuns=1] - untimed dispatches before measuring.
 * @param {number} [config.timedRuns=3] - timed dispatches (submit + onSubmittedWorkDone each).
 * @param {number} [config.readbackSample=24] - leading outputs read back for a sanity sample (0 disables).
 * @returns {Promise<object>} verdict, config echo, stage list, per-run times, median, and sample.
 * @throws {Error} when WebGPU/shader-f16 is unavailable or the layer shapes are unexpected.
 */
async function runCustomTimestepModulationBench(config = {}) {
  if (!navigator.gpu) {
    throw new Error("navigator.gpu is not available");
  }
  const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"});
  if (!adapter) {
    throw new Error("WebGPU adapter is not available");
  }
  if (!adapter.features.has("shader-f16")) {
    throw new Error("WebGPU adapter does not expose shader-f16");
  }
  const deviceDescriptor = {requiredFeatures: ["shader-f16"]};
  if (adapter.limits && adapter.limits.maxComputeWorkgroupStorageSize >= 32768) {
    deviceDescriptor.requiredLimits = {maxComputeWorkgroupStorageSize: 32768};
  }
  const device = await adapter.requestDevice(deviceDescriptor);
  try {
    const fullManifest = config.manifest;
    const bundleBaseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
    const tileCols = Number(config.q4TileCols || 96);
    const timestep = Number(config.timestep ?? 1.0);
    const warmupRuns = Number(config.warmupRuns ?? 1);
    const timedRuns = Number(config.timedRuns ?? 3);
    const modId = modulationOutputId(config.modulation || config.modulationKind || "single");
    const timeInIn = findFullManifestLinear(fullManifest, "time_in.in_layer");
    const timeInOut = findFullManifestLinear(fullManifest, "time_in.out_layer");
    const modulation = findFullManifestLinear(fullManifest, modId);
    if (timeInIn.shape.K !== 256 || timeInIn.shape.N !== 3072 ||
        timeInOut.shape.K !== 3072 || timeInOut.shape.N !== 3072 ||
        modulation.shape.K !== 3072) {
      throw new Error(`unexpected modulation chain shapes: ${timeInIn.shape.K}x${timeInIn.shape.N}, ${timeInOut.shape.K}x${timeInOut.shape.N}, ${modulation.shape.K}x${modulation.shape.N}`);
    }
    const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(timestep), GPUBufferUsage.STORAGE);
    const hiddenF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
    const hiddenF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
    const vecF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
    const vecF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
    const modF32Buffer = createEmptyBuffer(device, modulation.shape.N * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
    const stage0 = await createQ4LinearStage(device, timeInIn, bundleBaseUrl, timestepBuffer, hiddenF32Buffer, 1, tileCols);
    const siluStage = await createF32ToF16Stage(device, hiddenF32Buffer, hiddenF16Buffer, 3072, true);
    const stage1 = await createQ4LinearStage(device, timeInOut, bundleBaseUrl, hiddenF16Buffer, vecF32Buffer, 1, tileCols);
    const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, 3072, true);
    const stage2 = await createQ4LinearStage(device, modulation, bundleBaseUrl, vecF16Buffer, modF32Buffer, 1, tileCols);
    const pipelineStages = [stage0, siluStage, stage1, vecCastStage, stage2];

    // Encode each stage in its own compute pass, submit once, and wait for completion.
    async function dispatch() {
      const encoder = device.createCommandEncoder();
      for (const stage of pipelineStages) {
        const pass = encoder.beginComputePass();
        pass.setPipeline(stage.pipeline);
        pass.setBindGroup(0, stage.bindGroup);
        // Cast stages only report workgroupsX; Y then defaults to 1.
        pass.dispatchWorkgroups(stage.workgroupsX, stage.workgroupsY ?? 1);
        pass.end();
      }
      device.queue.submit([encoder.finish()]);
      await device.queue.onSubmittedWorkDone();
    }

    for (let i = 0; i < warmupRuns; ++i) {
      await dispatch();
    }
    const times = [];
    for (let i = 0; i < timedRuns; ++i) {
      const start = performance.now();
      await dispatch();
      times.push(performance.now() - start);
    }
    let sample = null;
    const readbackSample = Number(config.readbackSample ?? 24);
    if (readbackSample > 0) {
      const values = await readFloat32Buffer(device, modF32Buffer, Math.min(readbackSample, modulation.shape.N));
      let finite = 0;
      let maxAbs = 0;
      for (const value of values) {
        if (Number.isFinite(value)) finite += 1;
        maxAbs = Math.max(maxAbs, Math.abs(value));
      }
      sample = {count: values.length, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(12, values.length)))};
    }
    const medianMs = median(times);
    return {
      verdict: "custom-timestep-modulation-bench-completed",
      config: {
        timestep,
        modulation: modId,
        warmupRuns,
        timedRuns,
        q4TileCols: tileCols,
      },
      stages: [
        {id: stage0.id, shape: stage0.shape},
        {id: "silu"},
        {id: stage1.id, shape: stage1.shape},
        {id: "cast_vec_f32_to_f16"},
        {id: stage2.id, shape: stage2.shape},
      ],
      timed_ms: times,
      summary: {
        median_dispatch_ms: medianMs,
      },
      sample,
    };
  } finally {
    // This bench owns the device and returns only plain data, so release the
    // device's GPU memory now instead of waiting for garbage collection.
    if (typeof device.destroy === "function") {
      try {
        device.destroy();
      } catch (_) {
        // Best-effort cleanup, same policy as resetCustomLowbitWebGpuDevice.
      }
    }
  }
}
// Expose the timestep/modulation benchmark on `window` so it can be called
// from outside this script (for example from the devtools console or automation).
Object.assign(window, {runCustomTimestepModulationBench});
| async function runCustomSingleBlockQ4Bench(config = {}) { | |
| if (!navigator.gpu) { | |
| throw new Error("navigator.gpu is not available"); | |
| } | |
| const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"}); | |
| if (!adapter) { | |
| throw new Error("WebGPU adapter is not available"); | |
| } | |
| if (!adapter.features.has("shader-f16")) { | |
| throw new Error("WebGPU adapter does not expose shader-f16"); | |
| } | |
| const deviceDescriptor = {requiredFeatures: ["shader-f16"]}; | |
| if (adapter.limits && adapter.limits.maxComputeWorkgroupStorageSize >= 32768) { | |
| deviceDescriptor.requiredLimits = {maxComputeWorkgroupStorageSize: 32768}; | |
| } | |
| const device = await adapter.requestDevice(deviceDescriptor); | |
| const fullManifest = config.manifest; | |
| const bundleBaseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/"; | |
| const blockIndex = Number(config.blockIndex ?? 0); | |
| const n = Number(config.n ?? 768); | |
| const inputK = 3072; | |
| const linear1M = 27648; | |
| const linear2K = 12288; | |
| const linear2M = 3072; | |
| const tileCols = Number(config.q4TileCols || 96); | |
| const textTokens = Math.max(0, Math.min(n, Number(config.textTokens ?? 512))); | |
| const imageTokens = Math.max(0, n - textTokens); | |
| const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imageTokens))))); | |
| const useRope = config.rope !== false; | |
| const linear1 = findFullManifestLinear(fullManifest, `single_blocks.${blockIndex}.linear1`); | |
| const linear2 = findFullManifestLinear(fullManifest, `single_blocks.${blockIndex}.linear2`); | |
| const modulation = findFullManifestLinear(fullManifest, "single_stream_modulation.lin"); | |
| if (linear1.shape.K !== inputK || linear1.shape.N !== linear1M || linear2.shape.K !== linear2K || linear2.shape.N !== linear2M) { | |
| throw new Error(`unexpected single block ${blockIndex} q4 shapes: linear1 ${linear1.shape.K}x${linear1.shape.N}, linear2 ${linear2.shape.K}x${linear2.shape.N}`); | |
| } | |
| const qk = await loadQkNormScales(fullManifest, bundleBaseUrl, `single_blocks.${blockIndex}.norm`); | |
| const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(Number(config.timestep ?? 1.0)), GPUBufferUsage.STORAGE); | |
| const timeHiddenF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE); | |
| const timeHiddenF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE); | |
| const vecF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE); | |
| const vecF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE); | |
| const modF32Buffer = createEmptyBuffer(device, 9216 * 4, GPUBufferUsage.STORAGE); | |
| const xF32Buffer = createBuffer(device, buildInputF32(n, inputK), GPUBufferUsage.STORAGE); | |
| const xNormF16Buffer = createEmptyBuffer(device, n * inputK * 2, GPUBufferUsage.STORAGE); | |
| const linear1OutBuffer = createEmptyBuffer(device, n * linear1M * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const qNormBuffer = createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE); | |
| const kNormBuffer = createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE); | |
| const attentionOutBuffer = createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE); | |
| const linear2InF32Buffer = createEmptyBuffer(device, n * linear2K * 4, GPUBufferUsage.STORAGE); | |
| const linear2InF16Buffer = createEmptyBuffer(device, n * linear2K * 2, GPUBufferUsage.STORAGE); | |
| const linear2OutBuffer = createEmptyBuffer(device, n * linear2M * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const residualOutputBuffer = createEmptyBuffer(device, n * linear2M * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const queryScaleBuffer = createBuffer(device, qk.queryScale, GPUBufferUsage.STORAGE); | |
| const keyScaleBuffer = createBuffer(device, qk.keyScale, GPUBufferUsage.STORAGE); | |
| const ropeFreqBuffer = useRope ? createBuffer(device, buildFluxRopeSinCos(n, textTokens, imageWidth), GPUBufferUsage.STORAGE) : null; | |
| const timeIn0 = findFullManifestLinear(fullManifest, "time_in.in_layer"); | |
| const timeIn1 = findFullManifestLinear(fullManifest, "time_in.out_layer"); | |
| const timeStage0 = await createQ4LinearStage(device, timeIn0, bundleBaseUrl, timestepBuffer, timeHiddenF32Buffer, 1, tileCols); | |
| const timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, 3072, true); | |
| const timeStage1 = await createQ4LinearStage(device, timeIn1, bundleBaseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, tileCols); | |
| const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, 3072, true); | |
| const modStage = await createQ4LinearStage(device, modulation, bundleBaseUrl, vecF16Buffer, modF32Buffer, 1, tileCols); | |
| const preNormModule = device.createShaderModule({code: SINGLE_PRENORM_MOD_F16_SHADER}); | |
| const qkNormModule = device.createShaderModule({code: useRope ? SINGLE_QK_NORM_ROPE_SHADER : SINGLE_QK_NORM_SHADER}); | |
| const attentionModule = device.createShaderModule({code: SINGLE_ATTENTION_PRENORM_SHADER}); | |
| const activationAttentionModule = device.createShaderModule({code: SINGLE_MLP_ACTIVATION_ATTENTION_SHADER}); | |
| const residualModule = device.createShaderModule({code: SINGLE_RESIDUAL_GATE_SHADER}); | |
| const preNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: preNormModule, entryPoint: "main"}}); | |
| const qkNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: qkNormModule, entryPoint: "main"}}); | |
| const attentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: attentionModule, entryPoint: "main"}}); | |
| const activationPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activationAttentionModule, entryPoint: "main"}}); | |
| const linear2CastStage = await createF32ToF16Stage(device, linear2InF32Buffer, linear2InF16Buffer, n * linear2K, false); | |
| const linear1Stage = await createQ4LinearStage(device, linear1, bundleBaseUrl, xNormF16Buffer, linear1OutBuffer, n, tileCols); | |
| const linear2Stage = await createQ4LinearStage(device, linear2, bundleBaseUrl, linear2InF16Buffer, linear2OutBuffer, n, tileCols); | |
| const residualPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: residualModule, entryPoint: "main"}}); | |
| const preNormParamsBuffer = createBuffer(device, new Uint32Array([n, inputK, 0, 0]), GPUBufferUsage.UNIFORM); | |
| const attentionParamsBuffer = createBuffer(device, new Uint32Array([n, linear1M, 24, 128, 9216, 3072, textTokens, imageWidth]), GPUBufferUsage.UNIFORM); | |
| const activationParamsBuffer = createBuffer(device, new Uint32Array([n, linear1M, linear2K, 3072, 9216, 9216, 0, 0]), GPUBufferUsage.UNIFORM); | |
| const residualParamsBuffer = createBuffer(device, new Uint32Array([n, linear2M, 0, 0]), GPUBufferUsage.UNIFORM); | |
| const modStride = 3072 * 4; | |
| const preNormBindGroup = device.createBindGroup({ | |
| layout: preNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: xF32Buffer}}, | |
| {binding: 1, resource: {buffer: modF32Buffer, offset: 0}}, | |
| {binding: 2, resource: {buffer: modF32Buffer, offset: modStride}}, | |
| {binding: 3, resource: {buffer: xNormF16Buffer}}, | |
| {binding: 4, resource: {buffer: preNormParamsBuffer}}, | |
| ], | |
| }); | |
| const qkNormBindGroup = device.createBindGroup({ | |
| layout: qkNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: linear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: queryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: keyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: qNormBuffer}}, | |
| {binding: 4, resource: {buffer: kNormBuffer}}, | |
| {binding: 5, resource: {buffer: attentionParamsBuffer}}, | |
| ...(useRope ? [{binding: 6, resource: {buffer: ropeFreqBuffer}}] : []), | |
| ], | |
| }); | |
| const attentionBindGroup = device.createBindGroup({ | |
| layout: attentionPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: linear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: qNormBuffer}}, | |
| {binding: 2, resource: {buffer: kNormBuffer}}, | |
| {binding: 3, resource: {buffer: attentionOutBuffer}}, | |
| {binding: 4, resource: {buffer: attentionParamsBuffer}}, | |
| ], | |
| }); | |
| const activationBindGroup = device.createBindGroup({ | |
| layout: activationPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: linear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: attentionOutBuffer}}, | |
| {binding: 2, resource: {buffer: linear2InF32Buffer}}, | |
| {binding: 3, resource: {buffer: activationParamsBuffer}}, | |
| ], | |
| }); | |
| const residualBindGroup = device.createBindGroup({ | |
| layout: residualPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: xF32Buffer}}, | |
| {binding: 1, resource: {buffer: linear2OutBuffer}}, | |
| {binding: 2, resource: {buffer: modF32Buffer, offset: modStride * 2}}, | |
| {binding: 3, resource: {buffer: residualOutputBuffer}}, | |
| {binding: 4, resource: {buffer: residualParamsBuffer}}, | |
| ], | |
| }); | |
| async function dispatch() { | |
| const encoder = device.createCommandEncoder(); | |
| let pass = encoder.beginComputePass(); | |
| pass.setPipeline(timeStage0.pipeline); | |
| pass.setBindGroup(0, timeStage0.bindGroup); | |
| pass.dispatchWorkgroups(timeStage0.workgroupsX, timeStage0.workgroupsY); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(timeSiluStage.pipeline); | |
| pass.setBindGroup(0, timeSiluStage.bindGroup); | |
| pass.dispatchWorkgroups(timeSiluStage.workgroupsX); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(timeStage1.pipeline); | |
| pass.setBindGroup(0, timeStage1.bindGroup); | |
| pass.dispatchWorkgroups(timeStage1.workgroupsX, timeStage1.workgroupsY); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(vecCastStage.pipeline); | |
| pass.setBindGroup(0, vecCastStage.bindGroup); | |
| pass.dispatchWorkgroups(vecCastStage.workgroupsX); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(modStage.pipeline); | |
| pass.setBindGroup(0, modStage.bindGroup); | |
| pass.dispatchWorkgroups(modStage.workgroupsX, modStage.workgroupsY); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(preNormPipeline); | |
| pass.setBindGroup(0, preNormBindGroup); | |
| pass.dispatchWorkgroups(n); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(linear1Stage.pipeline); | |
| pass.setBindGroup(0, linear1Stage.bindGroup); | |
| pass.dispatchWorkgroups(linear1Stage.workgroupsX, linear1Stage.workgroupsY); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(qkNormPipeline); | |
| pass.setBindGroup(0, qkNormBindGroup); | |
| pass.dispatchWorkgroups(n, 24); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(attentionPipeline); | |
| pass.setBindGroup(0, attentionBindGroup); | |
| pass.dispatchWorkgroups(n, 24); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(activationPipeline); | |
| pass.setBindGroup(0, activationBindGroup); | |
| pass.dispatchWorkgroups(Math.ceil(linear2K / 256), n); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(linear2CastStage.pipeline); | |
| pass.setBindGroup(0, linear2CastStage.bindGroup); | |
| pass.dispatchWorkgroups(linear2CastStage.workgroupsX); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(linear2Stage.pipeline); | |
| pass.setBindGroup(0, linear2Stage.bindGroup); | |
| pass.dispatchWorkgroups(linear2Stage.workgroupsX, linear2Stage.workgroupsY); | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| pass.setPipeline(residualPipeline); | |
| pass.setBindGroup(0, residualBindGroup); | |
| pass.dispatchWorkgroups(Math.ceil(linear2M / 256), n); | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| } | |
| for (let i = 0; i < Number(config.warmupRuns ?? 1); ++i) { | |
| await dispatch(); | |
| } | |
| const times = []; | |
| for (let i = 0; i < Number(config.timedRuns ?? 3); ++i) { | |
| const start = performance.now(); | |
| await dispatch(); | |
| times.push(performance.now() - start); | |
| } | |
| let sample = null; | |
| if (config.readbackSample) { | |
| const count = Math.min(Number(config.readbackSample), n * linear2M); | |
| const values = await readFloat32Buffer(device, residualOutputBuffer, count); | |
| let finite = 0; | |
| let maxAbs = 0; | |
| for (const value of values) { | |
| if (Number.isFinite(value)) finite += 1; | |
| maxAbs = Math.max(maxAbs, Math.abs(value)); | |
| } | |
| sample = {count, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(8, values.length)))}; | |
| } | |
| const medianMs = median(times); | |
| const macs = n * inputK * linear1M + n * linear2K * linear2M; | |
| return { | |
| verdict: "custom-single-block-q4-bench-completed", | |
| config: { | |
| blockIndex, | |
| n, | |
| timestep: Number(config.timestep ?? 1.0), | |
| warmupRuns: Number(config.warmupRuns ?? 1), | |
| timedRuns: Number(config.timedRuns ?? 3), | |
| q4TileCols: tileCols, | |
| rope: useRope, | |
| textTokens, | |
| imageWidth, | |
| }, | |
| layers: { | |
| modulation: {source_node: modulation.source_node, shape: modulation.shape}, | |
| qkNorm: {id: qk.entry.id, shape: qk.entry.shape}, | |
| linear1: {source_node: linear1.source_node, shape: linear1.shape}, | |
| linear2: {source_node: linear2.source_node, shape: linear2.shape}, | |
| }, | |
| timed_ms: times, | |
| summary: { | |
| median_dispatch_ms: medianMs, | |
| effective_tmacs: macs / (medianMs / 1000) / 1e12, | |
| }, | |
| sample, | |
| }; | |
| } | |
// Publish the single-block Q4 benchmark on the page's global object so it can be invoked from outside this script.
Object.assign(window, {runCustomSingleBlockQ4Bench});
/**
 * Benchmarks the int8 (dp4a) MLP path of one single-stream block (`single_blocks.<blockIndex>`)
 * on a dedicated WebGPU device that is destroyed again before this function settles.
 *
 * Each timed dispatch encodes, in order: optional real modulation (time_in + single_stream_modulation
 * Q4 linears), pre-norm/quantization of x, linear1, optional QK-norm + attention, the activation
 * (optionally fused with re-quantization), linear2, and an optional gated residual.
 *
 * @param {object} [config] Benchmark options. Keys read here (defaults in parentheses):
 *   n (768) token rows; warmupRuns (1); timedRuns (3); realAttention (false; requires n <= 4608);
 *   fullManifest + dynamicI8FromQ4 to derive int8 weights from the Q4 full-transformer bundle at
 *   bundleBaseUrl ("/runtime/custom_lowbit/full_transformer/"), otherwise linear1BaseUrl ("/linear1/"),
 *   linear2BaseUrl ("/linear2/") and blockBaseUrl ("/block/"); realModulation (needs fullManifest,
 *   implies correctSingleBlock); correctSingleBlock; fuseResidualDot (defaults to correctSingleBlock);
 *   fuseActivationQuant (false); linear1OutputF16 (true); singleComputePass (true); rope (false);
 *   tiledAttention (true, only with realAttention); attentionTileKeys (8); attentionQueryRows (16);
 *   textTokens (512, clamped to [0, n]); imageWidth (~sqrt of image tokens); linear1TileCols /
 *   linear2TileCols (256); linear1WChunkCols / linear2WChunkCols (64); blockIndex (0);
 *   timestep (1.0, real modulation only); readbackSample (number of output floats to read back).
 * @returns {Promise<object>} The resolved run configuration, layer and load metadata, per-run
 *   timings, the median dispatch time, effective TMAC/s (linear1 + linear2 MACs only) and the
 *   optional readback sample.
 * @throws {Error} When WebGPU, shader-f16 or packed_4x8_integer_dot_product is unavailable, when
 *   n exceeds the real-attention limit, or when the loaded weight shapes are not the expected ones.
 */
async function runCustomSingleStreamMlpBench(config = {}) {
  if (!navigator.gpu) {
    throw new Error("navigator.gpu is not available");
  }
  const adapter = await navigator.gpu.requestAdapter({powerPreference: "high-performance"});
  if (!adapter) {
    throw new Error("WebGPU adapter is not available");
  }
  if (!adapter.features.has("shader-f16")) {
    throw new Error("WebGPU adapter does not expose shader-f16");
  }
  const wgslLanguageFeatures = navigator.gpu.wgslLanguageFeatures || new Set();
  if (!wgslLanguageFeatures.has("packed_4x8_integer_dot_product")) {
    throw new Error("WGSL packed_4x8_integer_dot_product is not available");
  }
  const useRealAttention = Boolean(config.realAttention);
  const n = Number(config.n ?? 768);
  // Checked before any device is created or weights are fetched/converted, so an unsupported size fails fast.
  if (useRealAttention && n > 4608) {
    throw new Error("single-stream real-attention benchmark currently supports n <= 4608");
  }
  const deviceDescriptor = {requiredFeatures: ["shader-f16"]};
  const requiredLimits = {};
  const adapterLimits = adapter.limits;
  if (adapterLimits && adapterLimits.maxComputeWorkgroupStorageSize >= 32768) {
    requiredLimits.maxComputeWorkgroupStorageSize = 32768;
  }
  // Default device limits cap a storage binding at 128 MiB and a buffer at 256 MiB, but linear1's
  // output alone is n * 27648 * 2 bytes (x4 for f32 output). Without raising them, n above ~2427
  // rows (f16) would fail WebGPU validation. Request whatever the adapter supports.
  for (const limitName of ["maxStorageBufferBindingSize", "maxBufferSize"]) {
    const adapterValue = adapterLimits ? adapterLimits[limitName] : undefined;
    if (Number.isFinite(adapterValue)) {
      requiredLimits[limitName] = adapterValue;
    }
  }
  if (Object.keys(requiredLimits).length > 0) {
    deviceDescriptor.requiredLimits = requiredLimits;
  }
  const device = await adapter.requestDevice(deviceDescriptor);
  try {
    const linear1BaseUrl = config.linear1BaseUrl || "/linear1/";
    const linear2BaseUrl = config.linear2BaseUrl || "/linear2/";
    const blockBaseUrl = config.blockBaseUrl || "/block/";
    const fullBundleBaseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
    const useDynamicI8FromQ4 = Boolean(config.fullManifest && config.dynamicI8FromQ4);
    const blockIndex = Number(config.blockIndex ?? 0);
    const warmupRuns = Number(config.warmupRuns ?? 1);
    const timedRuns = Number(config.timedRuns ?? 3);
    const loadStart = performance.now();
    let linear1;
    let linear2;
    let block;
    if (useDynamicI8FromQ4) {
      [linear1, linear2, block] = await Promise.all([
        loadI8BundleFromFullEntry(
          findFullManifestLinear(config.fullManifest, `single_blocks.${blockIndex}.linear1`),
          fullBundleBaseUrl,
        ),
        loadI8BundleFromFullEntry(
          findFullManifestLinear(config.fullManifest, `single_blocks.${blockIndex}.linear2`),
          fullBundleBaseUrl,
        ),
        useRealAttention
          ? loadQkNormScales(config.fullManifest, fullBundleBaseUrl, `single_blocks.${blockIndex}.norm`)
          : Promise.resolve(null),
      ]);
    } else {
      [linear1, linear2, block] = await Promise.all([
        loadI8Bundle(linear1BaseUrl),
        loadI8Bundle(linear2BaseUrl),
        useRealAttention ? loadSingleBlockBundle(blockBaseUrl) : Promise.resolve(null),
      ]);
    }
    const loadMs = performance.now() - loadStart;
    const inputK = linear1.manifest.shape.K;
    const linear1M = linear1.manifest.shape.N;
    const linear2K = linear2.manifest.shape.K;
    const linear2M = linear2.manifest.shape.N;
    const linear1KWords = linear1.manifest.i8_dp4a.k_words;
    const linear2KWords = linear2.manifest.i8_dp4a.k_words;
    const linear1TileCols = Number(config.linear1TileCols ?? 256);
    const linear2TileCols = Number(config.linear2TileCols ?? 256);
    const linear1WChunkCols = Number(config.linear1WChunkCols ?? 64);
    const linear2WChunkCols = Number(config.linear2WChunkCols ?? 64);
    const fuseActivationQuant = Boolean(config.fuseActivationQuant);
    const linear1OutputF16 = config.linear1OutputF16 !== false;
    const useRealModulation = Boolean(config.fullManifest && config.realModulation);
    const correctSingleBlock = Boolean(config.correctSingleBlock || useRealModulation);
    const fuseResidualDot = correctSingleBlock && config.fuseResidualDot !== false;
    const singleComputePass = config.singleComputePass !== false;
    const useRope = Boolean(config.rope);
    const useTiledAttention = useRealAttention && config.tiledAttention !== false;
    const attentionTileKeys = Number(config.attentionTileKeys ?? 8);
    const attentionQueryRows = Number(config.attentionQueryRows ?? 16);
    const textTokens = Math.max(0, Math.min(n, Number(config.textTokens ?? 512)));
    const imageTokens = Math.max(0, n - textTokens);
    const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imageTokens)))));
    const rowsPerWorkgroup = 32;
    if (inputK !== 3072 || linear1M !== 27648 || linear2K !== 12288 || linear2M !== 3072) {
      throw new Error(`unexpected single-stream MLP shapes: linear1 K=${inputK} M=${linear1M}, linear2 K=${linear2K} M=${linear2M}`);
    }

    // Host-side inputs: synthetic x plus identity modulation (shift = 0, scale = 0, gate = 1) used
    // when real modulation is not computed on the GPU.
    const inputF32 = buildInputF32(n, inputK);
    const modShift = new Float32Array(inputK);
    const modScale = new Float32Array(inputK);
    const modGate = new Float32Array(inputK);
    modGate.fill(1);

    // Buffers.
    const xF32Buffer = createBuffer(device, inputF32, GPUBufferUsage.STORAGE);
    // Holds shift | scale | gate (3 x 3072 f32) produced by the modulation linear.
    const modF32Buffer = useRealModulation
      ? createEmptyBuffer(device, 9216 * 4, GPUBufferUsage.STORAGE)
      : null;
    const modShiftBuffer = correctSingleBlock && !useRealModulation ? createBuffer(device, modShift, GPUBufferUsage.STORAGE) : null;
    const modScaleBuffer = correctSingleBlock && !useRealModulation ? createBuffer(device, modScale, GPUBufferUsage.STORAGE) : null;
    const modGateBuffer = correctSingleBlock && !useRealModulation ? createBuffer(device, modGate, GPUBufferUsage.STORAGE) : null;
    const xPackedBuffer = createEmptyBuffer(device, n * linear1KWords * 4, GPUBufferUsage.STORAGE);
    const xScalesBuffer = createEmptyBuffer(device, n * 4, GPUBufferUsage.STORAGE);
    const linear1WBuffer = createBuffer(device, linear1.weight, GPUBufferUsage.STORAGE);
    const linear1ScaleBuffer = createBuffer(device, linear1.scales, GPUBufferUsage.STORAGE);
    const linear1OutBuffer = createEmptyBuffer(
      device,
      n * linear1M * (linear1OutputF16 ? 2 : 4),
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
    );
    const attentionOutBuffer = useRealAttention
      ? createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE)
      : null;
    const qNormBuffer = useRealAttention
      ? createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE)
      : null;
    const kNormBuffer = useRealAttention
      ? createEmptyBuffer(device, n * 3072 * 4, GPUBufferUsage.STORAGE)
      : null;
    const queryScaleBuffer = useRealAttention
      ? createBuffer(device, block.queryScale, GPUBufferUsage.STORAGE)
      : null;
    const keyScaleBuffer = useRealAttention
      ? createBuffer(device, block.keyScale, GPUBufferUsage.STORAGE)
      : null;
    const ropeFreqBuffer = useRealAttention && useRope
      ? createBuffer(device, buildFluxRopeSinCos(n, textTokens, imageWidth), GPUBufferUsage.STORAGE)
      : null;
    const linear2InBuffer = fuseActivationQuant
      ? null
      : createEmptyBuffer(device, n * linear2K * 4, GPUBufferUsage.STORAGE);
    const linear2InPackedBuffer = createEmptyBuffer(device, n * linear2KWords * 4, GPUBufferUsage.STORAGE);
    const linear2InScalesBuffer = createEmptyBuffer(device, n * 4, GPUBufferUsage.STORAGE);
    const linear2WBuffer = createBuffer(device, linear2.weight, GPUBufferUsage.STORAGE);
    const linear2ScaleBuffer = createBuffer(device, linear2.scales, GPUBufferUsage.STORAGE);
    // With the fused residual dot, linear2 writes straight into residualOutputBuffer.
    const outputBuffer = fuseResidualDot
      ? null
      : createEmptyBuffer(device, n * linear2M * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
    const residualOutputBuffer = correctSingleBlock
      ? createEmptyBuffer(device, n * linear2M * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)
      : null;

    // Shader modules.
    // NOTE(review): linear1OutBuffer holds f16 by default, but activationModule,
    // activationAttentionModule and activateQuantModule are not passed through
    // makeLinear1F16ConsumerShader (unlike qkNorm/attention/activateAttentionQuant). Confirm those
    // shaders already read f16, or that their paths are only exercised with linear1OutputF16: false.
    const quantModule = device.createShaderModule({code: QUANTIZE_X_F32_SHADER});
    const preNormQuantModule = device.createShaderModule({code: SINGLE_PRENORM_MOD_QUANT_SHADER});
    const dot1Module = device.createShaderModule({
      code: makeI8ScaledDotShader32xWide(linear1TileCols, linear1WChunkCols, linear1OutputF16),
    });
    const activationModule = device.createShaderModule({code: SINGLE_MLP_ACTIVATION_SHADER});
    const activationAttentionModule = device.createShaderModule({code: SINGLE_MLP_ACTIVATION_ATTENTION_SHADER});
    const qkNormCode = useRope ? SINGLE_QK_NORM_ROPE_SHADER : SINGLE_QK_NORM_SHADER;
    const attentionCode = useTiledAttention
      ? makeSingleAttentionTiledShader(attentionTileKeys, attentionQueryRows)
      : SINGLE_ATTENTION_PRENORM_SHADER;
    const qkNormModule = device.createShaderModule({code: linear1OutputF16 ? makeLinear1F16ConsumerShader(qkNormCode) : qkNormCode});
    const attentionModule = device.createShaderModule({code: linear1OutputF16 ? makeLinear1F16ConsumerShader(attentionCode) : attentionCode});
    const activateQuantModule = device.createShaderModule({code: SINGLE_MLP_ACTIVATE_QUANT_SHADER});
    const activateAttentionQuantModule = device.createShaderModule({
      code: linear1OutputF16
        ? makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER)
        : SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER,
    });
    const dot2Module = device.createShaderModule({
      code: fuseResidualDot
        ? makeI8ScaledDotResidualShader32xWide(linear2TileCols, linear2WChunkCols)
        : makeI8ScaledDotShader32xWide(linear2TileCols, linear2WChunkCols),
    });
    const residualModule = device.createShaderModule({code: SINGLE_RESIDUAL_GATE_SHADER});

    // Pipelines (all variants are compiled; only the ones selected by the flags are dispatched).
    const quantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: quantModule, entryPoint: "main"}});
    const preNormQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: preNormQuantModule, entryPoint: "main"}});
    const dot1Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dot1Module, entryPoint: "main"}});
    const activationPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activationModule, entryPoint: "main"}});
    const activationAttentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activationAttentionModule, entryPoint: "main"}});
    const qkNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: qkNormModule, entryPoint: "main"}});
    const attentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: attentionModule, entryPoint: "main"}});
    const activateQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activateQuantModule, entryPoint: "main"}});
    const activateAttentionQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activateAttentionQuantModule, entryPoint: "main"}});
    const dot2Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dot2Module, entryPoint: "main"}});
    const residualPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: residualModule, entryPoint: "main"}});

    // Optional real modulation chain: timestep embedding -> time_in (SiLU between layers) -> vec -> modulation.
    let timeStage0 = null;
    let timeSiluStage = null;
    let timeStage1 = null;
    let vecCastStage = null;
    let modStage = null;
    if (useRealModulation) {
      const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(Number(config.timestep ?? 1.0)), GPUBufferUsage.STORAGE);
      const timeHiddenF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
      const timeHiddenF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
      const vecF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
      const vecF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
      timeStage0 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.in_layer"), fullBundleBaseUrl, timestepBuffer, timeHiddenF32Buffer, 1, 96);
      timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, 3072, true);
      timeStage1 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.out_layer"), fullBundleBaseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, 96);
      vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, 3072, true);
      modStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "single_stream_modulation.lin"), fullBundleBaseUrl, vecF16Buffer, modF32Buffer, 1, 96);
    }

    // Uniform parameter blocks.
    const quantXParamsBuffer = createBuffer(device, new Uint32Array([n, inputK, linear1M, linear1KWords]), GPUBufferUsage.UNIFORM);
    const dot1ParamsBuffer = createBuffer(device, new Uint32Array([n, inputK, linear1M, linear1KWords, 0, 0, 0, 0]), GPUBufferUsage.UNIFORM);
    const activationParamsBuffer = createBuffer(
      device,
      new Uint32Array([n, linear1M, linear2K, 3072, 9216, 9216, 0, 0]),
      GPUBufferUsage.UNIFORM,
    );
    const attentionParamsBuffer = useRealAttention
      ? createBuffer(device, new Uint32Array([n, linear1M, 24, 128, 9216, 3072, textTokens, imageWidth]), GPUBufferUsage.UNIFORM)
      : null;
    const activateQuantParamsBuffer = createBuffer(
      device,
      new Uint32Array([n, linear1M, linear2K, 3072, 9216, 9216, linear2KWords, 0]),
      GPUBufferUsage.UNIFORM,
    );
    const quantHiddenParamsBuffer = createBuffer(device, new Uint32Array([n, linear2K, linear2M, linear2KWords]), GPUBufferUsage.UNIFORM);
    const dot2ParamsBuffer = createBuffer(device, new Uint32Array([n, linear2K, linear2M, linear2KWords, 0, 0, 0, 0]), GPUBufferUsage.UNIFORM);

    // Shift/scale/gate are consecutive inputK-wide f32 slices of modF32Buffer when modulation is real.
    const modStride = inputK * 4;
    const modShiftResource = useRealModulation ? {buffer: modF32Buffer, offset: 0} : {buffer: modShiftBuffer};
    const modScaleResource = useRealModulation ? {buffer: modF32Buffer, offset: modStride} : {buffer: modScaleBuffer};
    const modGateResource = useRealModulation ? {buffer: modF32Buffer, offset: modStride * 2} : {buffer: modGateBuffer};

    // Bind groups.
    const quantXBindGroup = device.createBindGroup({
      layout: quantPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: xF32Buffer}},
        {binding: 1, resource: {buffer: xPackedBuffer}},
        {binding: 2, resource: {buffer: xScalesBuffer}},
        {binding: 3, resource: {buffer: quantXParamsBuffer}},
      ],
    });
    const preNormQuantBindGroup = correctSingleBlock
      ? device.createBindGroup({
        layout: preNormQuantPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: xF32Buffer}},
          {binding: 1, resource: modShiftResource},
          {binding: 2, resource: modScaleResource},
          {binding: 3, resource: {buffer: xPackedBuffer}},
          {binding: 4, resource: {buffer: xScalesBuffer}},
          {binding: 5, resource: {buffer: quantXParamsBuffer}},
        ],
      })
      : null;
    const dot1BindGroup = device.createBindGroup({
      layout: dot1Pipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: xPackedBuffer}},
        {binding: 1, resource: {buffer: linear1WBuffer}},
        {binding: 2, resource: {buffer: xScalesBuffer}},
        {binding: 3, resource: {buffer: linear1ScaleBuffer}},
        {binding: 4, resource: {buffer: linear1OutBuffer}},
        {binding: 5, resource: {buffer: dot1ParamsBuffer}},
      ],
    });
    const attentionBindGroup = useRealAttention
      ? device.createBindGroup({
        layout: attentionPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: linear1OutBuffer}},
          {binding: 1, resource: {buffer: qNormBuffer}},
          {binding: 2, resource: {buffer: kNormBuffer}},
          {binding: 3, resource: {buffer: attentionOutBuffer}},
          {binding: 4, resource: {buffer: attentionParamsBuffer}},
        ],
      })
      : null;
    const qkNormBindGroup = useRealAttention
      ? device.createBindGroup({
        layout: qkNormPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: linear1OutBuffer}},
          {binding: 1, resource: {buffer: queryScaleBuffer}},
          {binding: 2, resource: {buffer: keyScaleBuffer}},
          {binding: 3, resource: {buffer: qNormBuffer}},
          {binding: 4, resource: {buffer: kNormBuffer}},
          {binding: 5, resource: {buffer: attentionParamsBuffer}},
          ...(useRope ? [{binding: 6, resource: {buffer: ropeFreqBuffer}}] : []),
        ],
      })
      : null;
    const activationBindGroup = fuseActivationQuant
      ? null
      : device.createBindGroup({
        layout: useRealAttention
          ? activationAttentionPipeline.getBindGroupLayout(0)
          : activationPipeline.getBindGroupLayout(0),
        entries: useRealAttention
          ? [
            {binding: 0, resource: {buffer: linear1OutBuffer}},
            {binding: 1, resource: {buffer: attentionOutBuffer}},
            {binding: 2, resource: {buffer: linear2InBuffer}},
            {binding: 3, resource: {buffer: activationParamsBuffer}},
          ]
          : [
            {binding: 0, resource: {buffer: linear1OutBuffer}},
            {binding: 1, resource: {buffer: linear2InBuffer}},
            {binding: 2, resource: {buffer: activationParamsBuffer}},
          ],
      });
    const activateQuantBindGroup = fuseActivationQuant
      ? (useRealAttention ? null : device.createBindGroup({
        layout: activateQuantPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: linear1OutBuffer}},
          {binding: 1, resource: {buffer: linear2InPackedBuffer}},
          {binding: 2, resource: {buffer: linear2InScalesBuffer}},
          {binding: 3, resource: {buffer: activateQuantParamsBuffer}},
        ],
      }))
      : null;
    const activateAttentionQuantBindGroup = fuseActivationQuant && useRealAttention
      ? device.createBindGroup({
        layout: activateAttentionQuantPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: linear1OutBuffer}},
          {binding: 1, resource: {buffer: attentionOutBuffer}},
          {binding: 2, resource: {buffer: linear2InPackedBuffer}},
          {binding: 3, resource: {buffer: linear2InScalesBuffer}},
          {binding: 4, resource: {buffer: activateQuantParamsBuffer}},
        ],
      })
      : null;
    const quantHiddenBindGroup = fuseActivationQuant
      ? null
      : device.createBindGroup({
        layout: quantPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: linear2InBuffer}},
          {binding: 1, resource: {buffer: linear2InPackedBuffer}},
          {binding: 2, resource: {buffer: linear2InScalesBuffer}},
          {binding: 3, resource: {buffer: quantHiddenParamsBuffer}},
        ],
      });
    const dot2BindGroup = device.createBindGroup({
      layout: dot2Pipeline.getBindGroupLayout(0),
      entries: fuseResidualDot ? [
        {binding: 0, resource: {buffer: linear2InPackedBuffer}},
        {binding: 1, resource: {buffer: linear2WBuffer}},
        {binding: 2, resource: {buffer: linear2InScalesBuffer}},
        {binding: 3, resource: {buffer: linear2ScaleBuffer}},
        {binding: 4, resource: {buffer: residualOutputBuffer}},
        {binding: 5, resource: {buffer: dot2ParamsBuffer}},
        {binding: 6, resource: {buffer: xF32Buffer}},
        {binding: 7, resource: modGateResource},
      ] : [
        {binding: 0, resource: {buffer: linear2InPackedBuffer}},
        {binding: 1, resource: {buffer: linear2WBuffer}},
        {binding: 2, resource: {buffer: linear2InScalesBuffer}},
        {binding: 3, resource: {buffer: linear2ScaleBuffer}},
        {binding: 4, resource: {buffer: outputBuffer}},
        {binding: 5, resource: {buffer: dot2ParamsBuffer}},
      ],
    });
    const residualParamsBuffer = correctSingleBlock && !fuseResidualDot
      ? createBuffer(device, new Uint32Array([n, linear2M, 0, 0]), GPUBufferUsage.UNIFORM)
      : null;
    const residualBindGroup = correctSingleBlock && !fuseResidualDot
      ? device.createBindGroup({
        layout: residualPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: xF32Buffer}},
          {binding: 1, resource: {buffer: outputBuffer}},
          {binding: 2, resource: modGateResource},
          {binding: 3, resource: {buffer: residualOutputBuffer}},
          {binding: 4, resource: {buffer: residualParamsBuffer}},
        ],
      })
      : null;

    // Encodes and submits one full block, then waits for the GPU to finish it.
    async function dispatch() {
      const encoder = device.createCommandEncoder();
      let pass = encoder.beginComputePass();
      const finishPass = () => {
        if (pass) {
          pass.end();
          pass = null;
        }
      };
      const beginPass = () => {
        if (!pass) pass = encoder.beginComputePass();
        return pass;
      };
      // With singleComputePass, every stage shares one pass; otherwise each stage gets its own pass.
      const runStage = (pipeline, bindGroup, x, y = undefined) => {
        const stagePass = beginPass();
        stagePass.setPipeline(pipeline);
        stagePass.setBindGroup(0, bindGroup);
        if (y === undefined) stagePass.dispatchWorkgroups(x);
        else stagePass.dispatchWorkgroups(x, y);
        if (!singleComputePass) finishPass();
      };
      if (useRealModulation) {
        runStage(timeStage0.pipeline, timeStage0.bindGroup, timeStage0.workgroupsX, timeStage0.workgroupsY);
        runStage(timeSiluStage.pipeline, timeSiluStage.bindGroup, timeSiluStage.workgroupsX);
        runStage(timeStage1.pipeline, timeStage1.bindGroup, timeStage1.workgroupsX, timeStage1.workgroupsY);
        runStage(vecCastStage.pipeline, vecCastStage.bindGroup, vecCastStage.workgroupsX);
        runStage(modStage.pipeline, modStage.bindGroup, modStage.workgroupsX, modStage.workgroupsY);
      }
      runStage(correctSingleBlock ? preNormQuantPipeline : quantPipeline, correctSingleBlock ? preNormQuantBindGroup : quantXBindGroup, n);
      runStage(dot1Pipeline, dot1BindGroup, Math.ceil(linear1M / linear1TileCols), Math.ceil(n / rowsPerWorkgroup));
      if (useRealAttention) {
        runStage(qkNormPipeline, qkNormBindGroup, n, 24);
        runStage(attentionPipeline, attentionBindGroup, useTiledAttention ? Math.ceil(n / attentionQueryRows) : n, 24);
      }
      if (fuseActivationQuant) {
        runStage(useRealAttention ? activateAttentionQuantPipeline : activateQuantPipeline, useRealAttention ? activateAttentionQuantBindGroup : activateQuantBindGroup, n);
      } else {
        runStage(useRealAttention ? activationAttentionPipeline : activationPipeline, activationBindGroup, Math.ceil(linear2K / 256), n);
        runStage(quantPipeline, quantHiddenBindGroup, n);
      }
      runStage(dot2Pipeline, dot2BindGroup, Math.ceil(linear2M / linear2TileCols), Math.ceil(n / rowsPerWorkgroup));
      if (correctSingleBlock && !fuseResidualDot) {
        runStage(residualPipeline, residualBindGroup, Math.ceil(linear2M / 256), n);
      }
      finishPass();
      device.queue.submit([encoder.finish()]);
      await device.queue.onSubmittedWorkDone();
    }

    for (let i = 0; i < warmupRuns; ++i) {
      await dispatch();
    }
    const times = [];
    for (let i = 0; i < timedRuns; ++i) {
      const start = performance.now();
      await dispatch();
      times.push(performance.now() - start);
    }
    let sample = null;
    if (config.readbackSample) {
      const count = Math.min(Number(config.readbackSample), n * linear2M);
      const values = await readFloat32Buffer(device, correctSingleBlock ? residualOutputBuffer : outputBuffer, count);
      let finite = 0;
      let maxAbs = 0;
      for (const value of values) {
        if (Number.isFinite(value)) finite += 1;
        maxAbs = Math.max(maxAbs, Math.abs(value));
      }
      sample = {count, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(8, values.length)))};
    }
    const medianMs = median(times);
    // Only the two large linears are counted; attention and elementwise work are excluded.
    const macs = n * inputK * linear1M + n * linear2K * linear2M;
    return {
      verdict: "custom-single-stream-mlp-bench-completed",
      config: {
        n,
        warmupRuns,
        timedRuns,
        linear1TileCols,
        linear2TileCols,
        linear1WChunkCols,
        linear2WChunkCols,
        linear1OutputF16,
        singleComputePass,
        fuseActivationQuant,
        fuseResidualDot,
        correctSingleBlock,
        dynamicI8FromQ4: useDynamicI8FromQ4,
        realModulation: useRealModulation,
        blockIndex,
        realAttention: useRealAttention,
        rope: useRope,
        tiledAttention: useTiledAttention,
        attentionTileKeys,
        attentionQueryRows,
        textTokens,
        imageWidth,
      },
      layers: {
        linear1: {source_node: linear1.manifest.source_node, shape: linear1.manifest.shape},
        linear2: {source_node: linear2.manifest.source_node, shape: linear2.manifest.shape},
      },
      load: {
        total_ms: loadMs,
        linear1_convert_ms: linear1.convertMs ?? null,
        linear2_convert_ms: linear2.convertMs ?? null,
        linear1_fetch_ms: linear1.loadMs ?? null,
        linear2_fetch_ms: linear2.loadMs ?? null,
        linear1_cache_hit: linear1.cacheHit ?? null,
        linear2_cache_hit: linear2.cacheHit ?? null,
        linear1_prepacked_i8: linear1.prepackedI8 ?? false,
        linear2_prepacked_i8: linear2.prepackedI8 ?? false,
      },
      adapter_features: Array.from(adapter.features).sort(),
      wgsl_language_features: Array.from(wgslLanguageFeatures).sort(),
      timed_ms: times,
      summary: {
        median_dispatch_ms: medianMs,
        effective_tmacs: macs / (medianMs / 1000) / 1e12,
      },
      sample,
    };
  } finally {
    // This device is private to the call. Destroy it (and every buffer/pipeline created on it) on both
    // success and error paths, so benchmark sweeps do not keep each run's weight buffers alive until GC.
    if (typeof device.destroy === "function") {
      try {
        device.destroy();
      } catch (_) {
        // Best-effort cleanup: never mask the benchmark result or the original error.
      }
    }
  }
}
// Publish the single-stream MLP benchmark entry point as a global on `window`.
window.runCustomSingleStreamMlpBench = runCustomSingleStreamMlpBench;
/**
 * Benchmarks a contiguous run of single-stream blocks encoded into one compute pass.
 *
 * Each dispatch first runs the timestep embedding (`time_in.in_layer` -> SiLU cast ->
 * `time_in.out_layer` -> cast) and the `single_stream_modulation.lin` projection into a
 * shared modulation buffer (shift / scale / gate slices at offsets 0, 1x and 2x
 * `inputK * 4` bytes). Then, for every selected block, it runs:
 *   - pre-norm + modulation + int8 quantization;
 *   - linear1 (int8 dot, optionally f16 output);
 *   - QK-norm + RoPE;
 *   - tiled attention;
 *   - fused activation/attention quantization;
 *   - linear2 (residual variant, bound to the incoming state and the modulation gate).
 *
 * The residual stream ping-pongs between `stateA` and `stateB`. Block i reads stateA and
 * writes stateB when i is even, and the reverse when i is odd.
 *
 * @param {object} [config]
 * @param {object} config.fullManifest - Full-transformer manifest (required).
 * @param {string} [config.bundleBaseUrl] - Base URL of the weight bundles.
 * @param {number} [config.blockIndex=0] - First single-stream block (clamped to 0..19).
 * @param {number} [config.blockCount=4] - Consecutive blocks to run (clamped so the run ends by block 19).
 * @param {number} [config.n=768] - Token rows; must be an integer in 1..4608.
 * @param {number} [config.textTokens=512] - Leading text rows (clamped to 0..n); the rest are image rows.
 * @param {number} [config.imageWidth] - Image-token grid width for RoPE; defaults to round(sqrt(imageTokens)).
 * @param {number} [config.timestep=1.0] - Timestep fed to the time embedding.
 * @param {number} [config.warmupRuns=1] - Untimed dispatches before measuring.
 * @param {number} [config.timedRuns=3] - Timed dispatches; the summary uses their median.
 * @param {number} [config.readbackSample] - If set, read back up to this many f32 values of the final state.
 *   Kernel tiling knobs: linear1TileCols / linear2TileCols (default 256),
 *   linear1WChunkCols / linear2WChunkCols (default 64), attentionTileKeys (default 8),
 *   attentionQueryRows (default 16), linear1OutputF16 (default true).
 * @returns {Promise<object>} Resolved config, load time, per-run timings, summary and optional sample.
 * @throws {Error} When `fullManifest` is missing or `n` is not an integer in 1..4608.
 */
async function runCustomSingleStreamBlocksLoopBench(config = {}) {
  if (!config.fullManifest) {
    throw new Error("single blocks loop benchmark requires fullManifest");
  }
  // Validate the row count before fetching anything: a run loads the weights of every
  // selected block, which is wasted work if `n` is rejected afterwards.
  const n = Number(config.n ?? 768);
  if (!Number.isInteger(n) || n < 1) {
    throw new Error("single blocks loop benchmark requires a positive integer n");
  }
  if (n > 4608) {
    throw new Error("single blocks loop benchmark currently supports n <= 4608");
  }
  const singleBlockTotal = 20;
  const baseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
  // Clamp the start index (as the double-stream bench clamps its block index) so an
  // out-of-range request cannot produce a nonexistent `single_blocks.<i>` entry name.
  const startBlockIndex = Math.max(0, Math.min(singleBlockTotal - 1, Math.floor(Number(config.blockIndex ?? 0))));
  const blockCount = Math.max(1, Math.min(singleBlockTotal - startBlockIndex, Math.floor(Number(config.blockCount ?? 4))));
  const blockIndices = Array.from({length: blockCount}, (_, offset) => startBlockIndex + offset);
  const loadLinear = (name) => loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, name), baseUrl);

  const loadStart = performance.now();
  const [linears, norms] = await Promise.all([
    Promise.all(blockIndices.map(async (blockIndex) => {
      // linear1 and linear2 are independent, so fetch them concurrently.
      const [linear1, linear2] = await Promise.all([
        loadLinear(`single_blocks.${blockIndex}.linear1`),
        loadLinear(`single_blocks.${blockIndex}.linear2`),
      ]);
      return {blockIndex, linear1, linear2};
    })),
    Promise.all(blockIndices.map((blockIndex) => loadQkNormScales(config.fullManifest, baseUrl, `single_blocks.${blockIndex}.norm`))),
  ]);
  const loadMs = performance.now() - loadStart;

  // Fixed single-stream layer geometry targeted by this bench.
  const inputK = 3072;
  const linear1M = 27648;
  const linear2K = 12288;
  const linear2M = 3072;
  // NOTE(review): the shared activation buffers are sized from block 0's k_words; this
  // assumes every selected block uses the same packing — confirm.
  const linear1KWords = linears[0].linear1.manifest.i8_dp4a.k_words;
  const linear2KWords = linears[0].linear2.manifest.i8_dp4a.k_words;
  const linear1TileCols = Number(config.linear1TileCols ?? 256);
  const linear2TileCols = Number(config.linear2TileCols ?? 256);
  const linear1WChunkCols = Number(config.linear1WChunkCols ?? 64);
  const linear2WChunkCols = Number(config.linear2WChunkCols ?? 64);
  const linear1OutputF16 = config.linear1OutputF16 !== false;
  const attentionTileKeys = Number(config.attentionTileKeys ?? 8);
  const attentionQueryRows = Number(config.attentionQueryRows ?? 16);
  const textTokens = Math.max(0, Math.min(n, Number(config.textTokens ?? 512)));
  const imageTokens = Math.max(0, n - textTokens);
  const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imageTokens)))));
  const rowsPerWorkgroup = 32;
  const warmupRuns = Number(config.warmupRuns ?? 1);
  const timedRuns = Number(config.timedRuns ?? 3);

  // NOTE(review): `device` is not declared in this function; it must resolve to a
  // GPUDevice from an enclosing scope that is not visible here — confirm it is set.
  // NOTE(review): none of the GPU buffers created below are destroyed explicitly.
  const inputF32 = buildInputF32(n, inputK);
  const stateBytes = n * inputK * 4;
  const stateA = createBuffer(device, inputF32, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  const stateB = createEmptyBuffer(device, stateBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  const modF32Buffer = createEmptyBuffer(device, 9216 * 4, GPUBufferUsage.STORAGE);
  const xPackedBuffer = createEmptyBuffer(device, n * linear1KWords * 4, GPUBufferUsage.STORAGE);
  const xScalesBuffer = createEmptyBuffer(device, n * 4, GPUBufferUsage.STORAGE);
  const linear1OutBuffer = createEmptyBuffer(
    device,
    n * linear1M * (linear1OutputF16 ? 2 : 4),
    GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  );
  const attentionOutBuffer = createEmptyBuffer(device, n * inputK * 4, GPUBufferUsage.STORAGE);
  const qNormBuffer = createEmptyBuffer(device, n * inputK * 4, GPUBufferUsage.STORAGE);
  const kNormBuffer = createEmptyBuffer(device, n * inputK * 4, GPUBufferUsage.STORAGE);
  const ropeFreqBuffer = createBuffer(device, buildFluxRopeSinCos(n, textTokens, imageWidth), GPUBufferUsage.STORAGE);
  const linear2InPackedBuffer = createEmptyBuffer(device, n * linear2KWords * 4, GPUBufferUsage.STORAGE);
  const linear2InScalesBuffer = createEmptyBuffer(device, n * 4, GPUBufferUsage.STORAGE);

  const preNormQuantModule = device.createShaderModule({code: SINGLE_PRENORM_MOD_QUANT_SHADER});
  const dot1Module = device.createShaderModule({
    code: makeI8ScaledDotShader32xWide(linear1TileCols, linear1WChunkCols, linear1OutputF16),
  });
  const qkNormCode = SINGLE_QK_NORM_ROPE_SHADER;
  const attentionCode = makeSingleAttentionTiledShader(attentionTileKeys, attentionQueryRows);
  const qkNormModule = device.createShaderModule({code: linear1OutputF16 ? makeLinear1F16ConsumerShader(qkNormCode) : qkNormCode});
  const attentionModule = device.createShaderModule({code: linear1OutputF16 ? makeLinear1F16ConsumerShader(attentionCode) : attentionCode});
  const activateAttentionQuantModule = device.createShaderModule({
    code: linear1OutputF16
      ? makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER)
      : SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER,
  });
  const dot2Module = device.createShaderModule({
    code: makeI8ScaledDotResidualShader32xWide(linear2TileCols, linear2WChunkCols),
  });
  const preNormQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: preNormQuantModule, entryPoint: "main"}});
  const dot1Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dot1Module, entryPoint: "main"}});
  const qkNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: qkNormModule, entryPoint: "main"}});
  const attentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: attentionModule, entryPoint: "main"}});
  const activateAttentionQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activateAttentionQuantModule, entryPoint: "main"}});
  const dot2Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dot2Module, entryPoint: "main"}});

  // Timestep embedding and modulation stages (run once per dispatch, shared by all blocks).
  const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(Number(config.timestep ?? 1.0)), GPUBufferUsage.STORAGE);
  const timeHiddenF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
  const timeHiddenF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
  const vecF32Buffer = createEmptyBuffer(device, 3072 * 4, GPUBufferUsage.STORAGE);
  const vecF16Buffer = createEmptyBuffer(device, 3072 * 2, GPUBufferUsage.STORAGE);
  const timeStage0 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.in_layer"), baseUrl, timestepBuffer, timeHiddenF32Buffer, 1, 96);
  const timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, 3072, true);
  const timeStage1 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.out_layer"), baseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, 96);
  const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, 3072, true);
  const modStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "single_stream_modulation.lin"), baseUrl, vecF16Buffer, modF32Buffer, 1, 96);

  const quantXParamsBuffer = createBuffer(device, new Uint32Array([n, inputK, linear1M, linear1KWords]), GPUBufferUsage.UNIFORM);
  const dot1ParamsBuffer = createBuffer(device, new Uint32Array([n, inputK, linear1M, linear1KWords, 0, 0, 0, 0]), GPUBufferUsage.UNIFORM);
  const attentionParamsBuffer = createBuffer(device, new Uint32Array([n, linear1M, 24, 128, 9216, inputK, textTokens, imageWidth]), GPUBufferUsage.UNIFORM);
  const activateQuantParamsBuffer = createBuffer(device, new Uint32Array([n, linear1M, linear2K, inputK, 9216, 9216, linear2KWords, 0]), GPUBufferUsage.UNIFORM);
  const dot2ParamsBuffer = createBuffer(device, new Uint32Array([n, linear2K, linear2M, linear2KWords, 0, 0, 0, 0]), GPUBufferUsage.UNIFORM);

  // The modulation buffer holds three consecutive inputK-wide f32 slices: shift, scale, gate.
  const modStride = inputK * 4;
  const modShiftResource = {buffer: modF32Buffer, offset: 0};
  const modScaleResource = {buffer: modF32Buffer, offset: modStride};
  const modGateResource = {buffer: modF32Buffer, offset: modStride * 2};
  const makePreNormBindGroup = (stateBuffer) => device.createBindGroup({
    layout: preNormQuantPipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: stateBuffer}},
      {binding: 1, resource: modShiftResource},
      {binding: 2, resource: modScaleResource},
      {binding: 3, resource: {buffer: xPackedBuffer}},
      {binding: 4, resource: {buffer: xScalesBuffer}},
      {binding: 5, resource: {buffer: quantXParamsBuffer}},
    ],
  });
  const preNormBindGroups = [makePreNormBindGroup(stateA), makePreNormBindGroup(stateB)];
  const attentionBindGroup = device.createBindGroup({
    layout: attentionPipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: linear1OutBuffer}},
      {binding: 1, resource: {buffer: qNormBuffer}},
      {binding: 2, resource: {buffer: kNormBuffer}},
      {binding: 3, resource: {buffer: attentionOutBuffer}},
      {binding: 4, resource: {buffer: attentionParamsBuffer}},
    ],
  });
  const activateAttentionQuantBindGroup = device.createBindGroup({
    layout: activateAttentionQuantPipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: linear1OutBuffer}},
      {binding: 1, resource: {buffer: attentionOutBuffer}},
      {binding: 2, resource: {buffer: linear2InPackedBuffer}},
      {binding: 3, resource: {buffer: linear2InScalesBuffer}},
      {binding: 4, resource: {buffer: activateQuantParamsBuffer}},
    ],
  });

  // Per-block weights and bind groups; activation scratch buffers are shared by all blocks.
  const blockStages = [];
  for (let i = 0; i < blockCount; ++i) {
    const {blockIndex, linear1, linear2} = linears[i];
    const norm = norms[i];
    const linear1WBuffer = createBuffer(device, linear1.weight, GPUBufferUsage.STORAGE);
    const linear1ScaleBuffer = createBuffer(device, linear1.scales, GPUBufferUsage.STORAGE);
    const linear2WBuffer = createBuffer(device, linear2.weight, GPUBufferUsage.STORAGE);
    const linear2ScaleBuffer = createBuffer(device, linear2.scales, GPUBufferUsage.STORAGE);
    const queryScaleBuffer = createBuffer(device, norm.queryScale, GPUBufferUsage.STORAGE);
    const keyScaleBuffer = createBuffer(device, norm.keyScale, GPUBufferUsage.STORAGE);
    const dot1BindGroup = device.createBindGroup({
      layout: dot1Pipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: xPackedBuffer}},
        {binding: 1, resource: {buffer: linear1WBuffer}},
        {binding: 2, resource: {buffer: xScalesBuffer}},
        {binding: 3, resource: {buffer: linear1ScaleBuffer}},
        {binding: 4, resource: {buffer: linear1OutBuffer}},
        {binding: 5, resource: {buffer: dot1ParamsBuffer}},
      ],
    });
    const qkNormBindGroup = device.createBindGroup({
      layout: qkNormPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: linear1OutBuffer}},
        {binding: 1, resource: {buffer: queryScaleBuffer}},
        {binding: 2, resource: {buffer: keyScaleBuffer}},
        {binding: 3, resource: {buffer: qNormBuffer}},
        {binding: 4, resource: {buffer: kNormBuffer}},
        {binding: 5, resource: {buffer: attentionParamsBuffer}},
        {binding: 6, resource: {buffer: ropeFreqBuffer}},
      ],
    });
    const makeDot2BindGroup = (inputBuffer, outputBuffer) => device.createBindGroup({
      layout: dot2Pipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: linear2InPackedBuffer}},
        {binding: 1, resource: {buffer: linear2WBuffer}},
        {binding: 2, resource: {buffer: linear2InScalesBuffer}},
        {binding: 3, resource: {buffer: linear2ScaleBuffer}},
        {binding: 4, resource: {buffer: outputBuffer}},
        {binding: 5, resource: {buffer: dot2ParamsBuffer}},
        {binding: 6, resource: {buffer: inputBuffer}},
        {binding: 7, resource: modGateResource},
      ],
    });
    blockStages.push({
      blockIndex,
      dot1BindGroup,
      qkNormBindGroup,
      // Index 0: stateA -> stateB (even blocks); index 1: stateB -> stateA (odd blocks).
      dot2BindGroups: [makeDot2BindGroup(stateA, stateB), makeDot2BindGroup(stateB, stateA)],
    });
  }

  const runStage = (pass, pipeline, bindGroup, x, y = undefined) => {
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    if (y === undefined) pass.dispatchWorkgroups(x);
    else pass.dispatchWorkgroups(x, y);
  };
  // Encodes every stage of every block into a single compute pass and waits for completion.
  async function dispatch() {
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    runStage(pass, timeStage0.pipeline, timeStage0.bindGroup, timeStage0.workgroupsX, timeStage0.workgroupsY);
    runStage(pass, timeSiluStage.pipeline, timeSiluStage.bindGroup, timeSiluStage.workgroupsX);
    runStage(pass, timeStage1.pipeline, timeStage1.bindGroup, timeStage1.workgroupsX, timeStage1.workgroupsY);
    runStage(pass, vecCastStage.pipeline, vecCastStage.bindGroup, vecCastStage.workgroupsX);
    runStage(pass, modStage.pipeline, modStage.bindGroup, modStage.workgroupsX, modStage.workgroupsY);
    for (let i = 0; i < blockStages.length; ++i) {
      const stage = blockStages[i];
      const ping = i & 1;
      runStage(pass, preNormQuantPipeline, preNormBindGroups[ping], n);
      runStage(pass, dot1Pipeline, stage.dot1BindGroup, Math.ceil(linear1M / linear1TileCols), Math.ceil(n / rowsPerWorkgroup));
      runStage(pass, qkNormPipeline, stage.qkNormBindGroup, n, 24);
      runStage(pass, attentionPipeline, attentionBindGroup, Math.ceil(n / attentionQueryRows), 24);
      runStage(pass, activateAttentionQuantPipeline, activateAttentionQuantBindGroup, n);
      runStage(pass, dot2Pipeline, stage.dot2BindGroups[ping], Math.ceil(linear2M / linear2TileCols), Math.ceil(n / rowsPerWorkgroup));
    }
    pass.end();
    device.queue.submit([encoder.finish()]);
    await device.queue.onSubmittedWorkDone();
  }
  // With blockCount >= 2 a dispatch overwrites stateA (the first block's input). Restore
  // the original input before every run, and wait for the upload outside the timed
  // window. Otherwise each run would start from the previous run's output, and the
  // readback sample would reflect (warmupRuns + timedRuns) compounded passes, not one.
  const resetState = async () => {
    device.queue.writeBuffer(stateA, 0, inputF32);
    await device.queue.onSubmittedWorkDone();
  };

  for (let i = 0; i < warmupRuns; ++i) {
    await resetState();
    await dispatch();
  }
  const times = [];
  for (let i = 0; i < timedRuns; ++i) {
    await resetState();
    const start = performance.now();
    await dispatch();
    times.push(performance.now() - start);
  }

  let sample = null;
  if (config.readbackSample) {
    // The last block writes stateA when blockCount is even and stateB when it is odd.
    const finalBuffer = blockCount % 2 === 0 ? stateA : stateB;
    const count = Math.min(Number(config.readbackSample), n * inputK);
    const values = await readFloat32Buffer(device, finalBuffer, count);
    let finite = 0;
    let maxAbs = 0;
    for (const value of values) {
      if (Number.isFinite(value)) finite += 1;
      maxAbs = Math.max(maxAbs, Math.abs(value));
    }
    sample = {count, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(8, values.length)))};
  }
  const medianMs = median(times);
  // Counts only the linear1 + linear2 matmul MACs; attention work is not included.
  const macsPerBlock = n * inputK * linear1M + n * linear2K * linear2M;
  return {
    verdict: "custom-single-stream-blocks-loop-completed",
    config: {
      startBlockIndex,
      blockCount,
      blockIndices,
      n,
      warmupRuns,
      timedRuns,
      linear1TileCols,
      linear2TileCols,
      linear1WChunkCols,
      linear2WChunkCols,
      linear1OutputF16,
      attentionTileKeys,
      attentionQueryRows,
      textTokens,
      imageWidth,
    },
    load: {total_ms: loadMs},
    summary: {
      median_dispatch_ms: medianMs,
      per_block_median_ms: medianMs / blockCount,
      effective_tmacs: (macsPerBlock * blockCount) / (medianMs / 1000) / 1e12,
    },
    timed_ms: times,
    sample,
  };
}
// Publish the single-stream blocks-loop benchmark entry point as a global on `window`.
window.runCustomSingleStreamBlocksLoopBench = runCustomSingleStreamBlocksLoopBench;
| async function runCustomDoubleStreamBlockBench(config = {}) { | |
| if (!config.fullManifest) { | |
| throw new Error("double-stream block benchmark requires fullManifest"); | |
| } | |
| const baseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/"; | |
| const blockIndex = Math.max(0, Math.min(4, Number(config.blockIndex ?? 0))); | |
| const imgRows = Number(config.imageTokens ?? 256); | |
| const txtRows = Number(config.textTokens ?? 512); | |
| const jointRows = imgRows + txtRows; | |
| const hidden = 3072; | |
| const qkvM = 9216; | |
| const mlp1M = 18432; | |
| const mlpK = 9216; | |
| const kWordsHidden = hidden / 4; | |
| const kWordsMlp = mlpK / 4; | |
| const maxRows = Math.max(imgRows, txtRows); | |
| const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imgRows))))); | |
| const linearTileCols = Number(config.linearTileCols ?? 128); | |
| const projTileCols = Number(config.projTileCols ?? 128); | |
| const wChunkCols = Number(config.wChunkCols ?? 64); | |
| const attentionTileKeys = Number(config.attentionTileKeys ?? 8); | |
| const attentionQueryRows = Number(config.attentionQueryRows ?? 16); | |
| const loadStart = performance.now(); | |
| const [ | |
| imgQkv, txtQkv, imgProj, txtProj, imgMlp0, imgMlp2, txtMlp0, txtMlp2, | |
| imgNorm, txtNorm, | |
| ] = await Promise.all([ | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_attn.qkv`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_attn.qkv`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_attn.proj`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_attn.proj`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_mlp.0`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_mlp.2`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_mlp.0`), baseUrl), | |
| loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_mlp.2`), baseUrl), | |
| loadQkNormScales(config.fullManifest, baseUrl, `double_blocks.${blockIndex}.img_attn.norm`), | |
| loadQkNormScales(config.fullManifest, baseUrl, `double_blocks.${blockIndex}.txt_attn.norm`), | |
| ]); | |
| const loadMs = performance.now() - loadStart; | |
| const imgInput = buildInputF32(imgRows, hidden); | |
| const txtInput = buildInputF32(txtRows, hidden); | |
| const imgState0 = createBuffer(device, imgInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const txtState0 = createBuffer(device, txtInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const imgState1 = createEmptyBuffer(device, imgRows * hidden * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const txtState1 = createEmptyBuffer(device, txtRows * hidden * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const imgState2 = createEmptyBuffer(device, imgRows * hidden * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const txtState2 = createEmptyBuffer(device, txtRows * hidden * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const modImgBuffer = createEmptyBuffer(device, 18432 * 4, GPUBufferUsage.STORAGE); | |
| const modTxtBuffer = createEmptyBuffer(device, 18432 * 4, GPUBufferUsage.STORAGE); | |
| const packedHiddenBuffer = createEmptyBuffer(device, maxRows * kWordsHidden * 4, GPUBufferUsage.STORAGE); | |
| const scalesBuffer = createEmptyBuffer(device, maxRows * 4, GPUBufferUsage.STORAGE); | |
| const packedMlpBuffer = createEmptyBuffer(device, maxRows * kWordsMlp * 4, GPUBufferUsage.STORAGE); | |
| const mlpScalesBuffer = createEmptyBuffer(device, maxRows * 4, GPUBufferUsage.STORAGE); | |
| const imgQkvBuffer = createEmptyBuffer(device, imgRows * qkvM * 2, GPUBufferUsage.STORAGE); | |
| const txtQkvBuffer = createEmptyBuffer(device, txtRows * qkvM * 2, GPUBufferUsage.STORAGE); | |
| const qNormBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const kNormBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const vBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const attentionOutBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const mlp0Buffer = createEmptyBuffer(device, maxRows * mlp1M * 2, GPUBufferUsage.STORAGE); | |
| const ropeFreqBuffer = createBuffer(device, buildFluxRopeFrequencies(), GPUBufferUsage.STORAGE); | |
| const preNormQuantModule = device.createShaderModule({code: SINGLE_PRENORM_MOD_QUANT_SHADER}); | |
| const quantOffsetModule = device.createShaderModule({code: QUANTIZE_X_F32_OFFSET_SHADER}); | |
| const dotHiddenToQkvModule = device.createShaderModule({code: makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, true)}); | |
| const qkvNormModule = device.createShaderModule({code: DOUBLE_QKV_NORM_ROPE_SHADER}); | |
| const attentionModule = device.createShaderModule({code: makeJointAttentionTiledShader(attentionTileKeys, attentionQueryRows)}); | |
| const dotProjModule = device.createShaderModule({code: makeI8ScaledDotResidualShader32xWide(projTileCols, wChunkCols)}); | |
| const dotMlp0Module = device.createShaderModule({code: makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, true)}); | |
| const activateMlpModule = device.createShaderModule({code: makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_QUANT_SHADER)}); | |
| const dotMlp2Module = device.createShaderModule({code: makeI8ScaledDotResidualShader32xWide(projTileCols, wChunkCols)}); | |
| const preNormQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: preNormQuantModule, entryPoint: "main"}}); | |
| const quantOffsetPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: quantOffsetModule, entryPoint: "main"}}); | |
| const dotHiddenToQkvPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotHiddenToQkvModule, entryPoint: "main"}}); | |
| const qkvNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: qkvNormModule, entryPoint: "main"}}); | |
| const attentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: attentionModule, entryPoint: "main"}}); | |
| const dotProjPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotProjModule, entryPoint: "main"}}); | |
| const dotMlp0Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotMlp0Module, entryPoint: "main"}}); | |
| const activateMlpPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activateMlpModule, entryPoint: "main"}}); | |
| const dotMlp2Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotMlp2Module, entryPoint: "main"}}); | |
| const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(Number(config.timestep ?? 1.0)), GPUBufferUsage.STORAGE); | |
| const timeHiddenF32Buffer = createEmptyBuffer(device, hidden * 4, GPUBufferUsage.STORAGE); | |
| const timeHiddenF16Buffer = createEmptyBuffer(device, hidden * 2, GPUBufferUsage.STORAGE); | |
| const vecF32Buffer = createEmptyBuffer(device, hidden * 4, GPUBufferUsage.STORAGE); | |
| const vecF16Buffer = createEmptyBuffer(device, hidden * 2, GPUBufferUsage.STORAGE); | |
| const timeStage0 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.in_layer"), baseUrl, timestepBuffer, timeHiddenF32Buffer, 1, 96); | |
| const timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, hidden, true); | |
| const timeStage1 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.out_layer"), baseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, 96); | |
| const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, hidden, true); | |
| const modImgStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_img.lin"), baseUrl, vecF16Buffer, modImgBuffer, 1, 96); | |
| const modTxtStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_txt.lin"), baseUrl, vecF16Buffer, modTxtBuffer, 1, 96); | |
| const makeLinearBuffers = (bundle) => ({ | |
| w: createImmutableBuffer(device, bundle.weight, GPUBufferUsage.STORAGE, bundle.weightCacheKey), | |
| s: createImmutableBuffer(device, bundle.scales, GPUBufferUsage.STORAGE, bundle.scalesCacheKey), | |
| }); | |
| const weights = { | |
| imgQkv: makeLinearBuffers(imgQkv), | |
| txtQkv: makeLinearBuffers(txtQkv), | |
| imgProj: makeLinearBuffers(imgProj), | |
| txtProj: makeLinearBuffers(txtProj), | |
| imgMlp0: makeLinearBuffers(imgMlp0), | |
| imgMlp2: makeLinearBuffers(imgMlp2), | |
| txtMlp0: makeLinearBuffers(txtMlp0), | |
| txtMlp2: makeLinearBuffers(txtMlp2), | |
| }; | |
| const imgQueryScaleBuffer = createBuffer(device, imgNorm.queryScale, GPUBufferUsage.STORAGE); | |
| const imgKeyScaleBuffer = createBuffer(device, imgNorm.keyScale, GPUBufferUsage.STORAGE); | |
| const txtQueryScaleBuffer = createBuffer(device, txtNorm.queryScale, GPUBufferUsage.STORAGE); | |
| const txtKeyScaleBuffer = createBuffer(device, txtNorm.keyScale, GPUBufferUsage.STORAGE); | |
| const makeParams = (values) => { | |
| const typed = new Uint32Array(values); | |
| return createImmutableBuffer(device, typed, GPUBufferUsage.UNIFORM, `${scratchKey}|params|${Array.from(typed).join(",")}`); | |
| }; | |
| const params = { | |
| imgPre: makeParams([imgRows, hidden, qkvM, kWordsHidden]), | |
| txtPre: makeParams([txtRows, hidden, qkvM, kWordsHidden]), | |
| imgDotQkv: makeParams([imgRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]), | |
| txtDotQkv: makeParams([txtRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]), | |
| imgQkvNorm: makeParams([imgRows, qkvM, 24, 128, hidden, txtRows, txtRows, imageWidth]), | |
| txtQkvNorm: makeParams([txtRows, qkvM, 24, 128, hidden, 0, txtRows, imageWidth]), | |
| attention: makeParams([jointRows, 24, 128, hidden]), | |
| imgQuantAttn: makeParams([imgRows, hidden, kWordsHidden, txtRows, hidden]), | |
| txtQuantAttn: makeParams([txtRows, hidden, kWordsHidden, 0, hidden]), | |
| imgProj: makeParams([imgRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]), | |
| txtProj: makeParams([txtRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]), | |
| imgMlpPre: makeParams([imgRows, hidden, mlp1M, kWordsHidden]), | |
| txtMlpPre: makeParams([txtRows, hidden, mlp1M, kWordsHidden]), | |
| imgMlp0: makeParams([imgRows, hidden, mlp1M, kWordsHidden, 0, 0, 0, 0]), | |
| txtMlp0: makeParams([txtRows, hidden, mlp1M, kWordsHidden, 0, 0, 0, 0]), | |
| imgAct: makeParams([imgRows, mlp1M, mlpK, 0, 0, mlpK, kWordsMlp, 0]), | |
| txtAct: makeParams([txtRows, mlp1M, mlpK, 0, 0, mlpK, kWordsMlp, 0]), | |
| imgMlp2: makeParams([imgRows, mlpK, hidden, kWordsMlp, 0, 0, 0, 0]), | |
| txtMlp2: makeParams([txtRows, mlpK, hidden, kWordsMlp, 0, 0, 0, 0]), | |
| }; | |
| const hiddenModOffsets = { | |
| mod1Shift: 0, | |
| mod1Scale: hidden * 4, | |
| mod1Gate: hidden * 8, | |
| mod2Shift: hidden * 12, | |
| mod2Scale: hidden * 16, | |
| mod2Gate: hidden * 20, | |
| }; | |
| const res = (buffer, offset = 0) => offset ? {buffer, offset} : {buffer}; | |
| const makePreNormBind = (state, modBuffer, shiftOffset, scaleOffset, paramsBuffer) => device.createBindGroup({ | |
| layout: preNormQuantPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: state}}, | |
| {binding: 1, resource: res(modBuffer, shiftOffset)}, | |
| {binding: 2, resource: res(modBuffer, scaleOffset)}, | |
| {binding: 3, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 4, resource: {buffer: scalesBuffer}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makeDotBind = (pipeline, packed, packedScales, w, wScales, out, paramsBuffer) => device.createBindGroup({ | |
| layout: pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: packed}}, | |
| {binding: 1, resource: {buffer: w}}, | |
| {binding: 2, resource: {buffer: packedScales}}, | |
| {binding: 3, resource: {buffer: wScales}}, | |
| {binding: 4, resource: {buffer: out}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makeResidualDotBind = (pipeline, packed, packedScales, w, wScales, out, paramsBuffer, residual, gate) => device.createBindGroup({ | |
| layout: pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: packed}}, | |
| {binding: 1, resource: {buffer: w}}, | |
| {binding: 2, resource: {buffer: packedScales}}, | |
| {binding: 3, resource: {buffer: wScales}}, | |
| {binding: 4, resource: {buffer: out}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| {binding: 6, resource: {buffer: residual}}, | |
| {binding: 7, resource: gate}, | |
| ], | |
| }); | |
| const bind = { | |
| imgQkvPre: makePreNormBind(imgState0, modImgBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.imgPre), | |
| txtQkvPre: makePreNormBind(txtState0, modTxtBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.txtPre), | |
| imgQkvDot: makeDotBind(dotHiddenToQkvPipeline, packedHiddenBuffer, scalesBuffer, weights.imgQkv.w, weights.imgQkv.s, imgQkvBuffer, params.imgDotQkv), | |
| txtQkvDot: makeDotBind(dotHiddenToQkvPipeline, packedHiddenBuffer, scalesBuffer, weights.txtQkv.w, weights.txtQkv.s, txtQkvBuffer, params.txtDotQkv), | |
| imgQkvNorm: device.createBindGroup({ | |
| layout: qkvNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: imgQkvBuffer}}, | |
| {binding: 1, resource: {buffer: imgQueryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: imgKeyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: qNormBuffer}}, | |
| {binding: 4, resource: {buffer: kNormBuffer}}, | |
| {binding: 5, resource: {buffer: vBuffer}}, | |
| {binding: 6, resource: {buffer: params.imgQkvNorm}}, | |
| {binding: 7, resource: {buffer: ropeFreqBuffer}}, | |
| ], | |
| }), | |
| txtQkvNorm: device.createBindGroup({ | |
| layout: qkvNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: txtQkvBuffer}}, | |
| {binding: 1, resource: {buffer: txtQueryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: txtKeyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: qNormBuffer}}, | |
| {binding: 4, resource: {buffer: kNormBuffer}}, | |
| {binding: 5, resource: {buffer: vBuffer}}, | |
| {binding: 6, resource: {buffer: params.txtQkvNorm}}, | |
| {binding: 7, resource: {buffer: ropeFreqBuffer}}, | |
| ], | |
| }), | |
| attention: device.createBindGroup({ | |
| layout: attentionPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: vBuffer}}, | |
| {binding: 1, resource: {buffer: qNormBuffer}}, | |
| {binding: 2, resource: {buffer: kNormBuffer}}, | |
| {binding: 3, resource: {buffer: attentionOutBuffer}}, | |
| {binding: 4, resource: {buffer: params.attention}}, | |
| ], | |
| }), | |
| }; | |
| bind.imgQuantAttn = device.createBindGroup({ | |
| layout: quantOffsetPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: attentionOutBuffer}}, | |
| {binding: 1, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 2, resource: {buffer: scalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.imgQuantAttn}}, | |
| ], | |
| }); | |
| bind.txtQuantAttn = device.createBindGroup({ | |
| layout: quantOffsetPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: attentionOutBuffer}}, | |
| {binding: 1, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 2, resource: {buffer: scalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.txtQuantAttn}}, | |
| ], | |
| }); | |
| bind.imgProj = makeResidualDotBind(dotProjPipeline, packedHiddenBuffer, scalesBuffer, weights.imgProj.w, weights.imgProj.s, imgState1, params.imgProj, imgState0, res(modImgBuffer, hiddenModOffsets.mod1Gate)); | |
| bind.txtProj = makeResidualDotBind(dotProjPipeline, packedHiddenBuffer, scalesBuffer, weights.txtProj.w, weights.txtProj.s, txtState1, params.txtProj, txtState0, res(modTxtBuffer, hiddenModOffsets.mod1Gate)); | |
| bind.imgMlpPre = makePreNormBind(imgState1, modImgBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.imgMlpPre); | |
| bind.txtMlpPre = makePreNormBind(txtState1, modTxtBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.txtMlpPre); | |
| bind.imgMlp0 = makeDotBind(dotMlp0Pipeline, packedHiddenBuffer, scalesBuffer, weights.imgMlp0.w, weights.imgMlp0.s, mlp0Buffer, params.imgMlp0); | |
| bind.txtMlp0 = makeDotBind(dotMlp0Pipeline, packedHiddenBuffer, scalesBuffer, weights.txtMlp0.w, weights.txtMlp0.s, mlp0Buffer, params.txtMlp0); | |
| bind.imgAct = device.createBindGroup({ | |
| layout: activateMlpPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: mlp0Buffer}}, | |
| {binding: 1, resource: {buffer: packedMlpBuffer}}, | |
| {binding: 2, resource: {buffer: mlpScalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.imgAct}}, | |
| ], | |
| }); | |
| bind.txtAct = device.createBindGroup({ | |
| layout: activateMlpPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: mlp0Buffer}}, | |
| {binding: 1, resource: {buffer: packedMlpBuffer}}, | |
| {binding: 2, resource: {buffer: mlpScalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.txtAct}}, | |
| ], | |
| }); | |
| bind.imgMlp2 = makeResidualDotBind(dotMlp2Pipeline, packedMlpBuffer, mlpScalesBuffer, weights.imgMlp2.w, weights.imgMlp2.s, imgState2, params.imgMlp2, imgState1, res(modImgBuffer, hiddenModOffsets.mod2Gate)); | |
| bind.txtMlp2 = makeResidualDotBind(dotMlp2Pipeline, packedMlpBuffer, mlpScalesBuffer, weights.txtMlp2.w, weights.txtMlp2.s, txtState2, params.txtMlp2, txtState1, res(modTxtBuffer, hiddenModOffsets.mod2Gate)); | |
| const runStage = (pass, pipeline, bindGroup, x, y = undefined) => { | |
| pass.setPipeline(pipeline); | |
| pass.setBindGroup(0, bindGroup); | |
| if (y === undefined) pass.dispatchWorkgroups(x); | |
| else pass.dispatchWorkgroups(x, y); | |
| }; | |
| async function dispatch() { | |
| const encoder = device.createCommandEncoder(); | |
| const pass = encoder.beginComputePass(); | |
| runStage(pass, timeStage0.pipeline, timeStage0.bindGroup, timeStage0.workgroupsX, timeStage0.workgroupsY); | |
| runStage(pass, timeSiluStage.pipeline, timeSiluStage.bindGroup, timeSiluStage.workgroupsX); | |
| runStage(pass, timeStage1.pipeline, timeStage1.bindGroup, timeStage1.workgroupsX, timeStage1.workgroupsY); | |
| runStage(pass, vecCastStage.pipeline, vecCastStage.bindGroup, vecCastStage.workgroupsX); | |
| runStage(pass, modImgStage.pipeline, modImgStage.bindGroup, modImgStage.workgroupsX, modImgStage.workgroupsY); | |
| runStage(pass, modTxtStage.pipeline, modTxtStage.bindGroup, modTxtStage.workgroupsX, modTxtStage.workgroupsY); | |
| runStage(pass, preNormQuantPipeline, bind.imgQkvPre, imgRows); | |
| runStage(pass, dotHiddenToQkvPipeline, bind.imgQkvDot, Math.ceil(qkvM / linearTileCols), Math.ceil(imgRows / 32)); | |
| runStage(pass, preNormQuantPipeline, bind.txtQkvPre, txtRows); | |
| runStage(pass, dotHiddenToQkvPipeline, bind.txtQkvDot, Math.ceil(qkvM / linearTileCols), Math.ceil(txtRows / 32)); | |
| runStage(pass, qkvNormPipeline, bind.txtQkvNorm, txtRows, 24); | |
| runStage(pass, qkvNormPipeline, bind.imgQkvNorm, imgRows, 24); | |
| runStage(pass, attentionPipeline, bind.attention, Math.ceil(jointRows / attentionQueryRows), 24); | |
| runStage(pass, quantOffsetPipeline, bind.imgQuantAttn, imgRows); | |
| runStage(pass, dotProjPipeline, bind.imgProj, Math.ceil(hidden / projTileCols), Math.ceil(imgRows / 32)); | |
| runStage(pass, quantOffsetPipeline, bind.txtQuantAttn, txtRows); | |
| runStage(pass, dotProjPipeline, bind.txtProj, Math.ceil(hidden / projTileCols), Math.ceil(txtRows / 32)); | |
| runStage(pass, preNormQuantPipeline, bind.imgMlpPre, imgRows); | |
| runStage(pass, dotMlp0Pipeline, bind.imgMlp0, Math.ceil(mlp1M / linearTileCols), Math.ceil(imgRows / 32)); | |
| runStage(pass, activateMlpPipeline, bind.imgAct, imgRows); | |
| runStage(pass, dotMlp2Pipeline, bind.imgMlp2, Math.ceil(hidden / projTileCols), Math.ceil(imgRows / 32)); | |
| runStage(pass, preNormQuantPipeline, bind.txtMlpPre, txtRows); | |
| runStage(pass, dotMlp0Pipeline, bind.txtMlp0, Math.ceil(mlp1M / linearTileCols), Math.ceil(txtRows / 32)); | |
| runStage(pass, activateMlpPipeline, bind.txtAct, txtRows); | |
| runStage(pass, dotMlp2Pipeline, bind.txtMlp2, Math.ceil(hidden / projTileCols), Math.ceil(txtRows / 32)); | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| } | |
| for (let i = 0; i < Number(config.warmupRuns ?? 1); ++i) { | |
| await dispatch(); | |
| } | |
| const times = []; | |
| for (let i = 0; i < Number(config.timedRuns ?? 3); ++i) { | |
| const start = performance.now(); | |
| await dispatch(); | |
| times.push(performance.now() - start); | |
| } | |
| let sample = null; | |
| if (config.readbackSample) { | |
| const count = Math.min(Number(config.readbackSample), imgRows * hidden); | |
| const values = await readFloat32Buffer(device, imgState2, count); | |
| let finite = 0; | |
| let maxAbs = 0; | |
| for (const value of values) { | |
| if (Number.isFinite(value)) finite += 1; | |
| maxAbs = Math.max(maxAbs, Math.abs(value)); | |
| } | |
| sample = {count, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(8, values.length)))}; | |
| } | |
| const medianMs = median(times); | |
| const macs = | |
| imgRows * hidden * qkvM + txtRows * hidden * qkvM + | |
| imgRows * hidden * hidden + txtRows * hidden * hidden + | |
| imgRows * hidden * mlp1M + txtRows * hidden * mlp1M + | |
| imgRows * mlpK * hidden + txtRows * mlpK * hidden; | |
| return { | |
| verdict: "custom-double-stream-block-completed", | |
| config: { | |
| blockIndex, | |
| imgRows, | |
| txtRows, | |
| jointRows, | |
| imageWidth, | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| warmupRuns: Number(config.warmupRuns ?? 1), | |
| timedRuns: Number(config.timedRuns ?? 3), | |
| }, | |
| load: {total_ms: loadMs}, | |
| summary: { | |
| median_dispatch_ms: medianMs, | |
| effective_tmacs: macs / (medianMs / 1000) / 1e12, | |
| }, | |
| timed_ms: times, | |
| sample, | |
| }; | |
| } | |
// Publish the single-block benchmark entry point as a property of `window`.
Object.assign(window, {runCustomDoubleStreamBlockBench});
/**
 * Benchmarks a run of consecutive double-stream transformer blocks on the
 * custom WebGPU device. The blocks are chained through ping-pong state buffers
 * (A -> B -> A ...), all inside one compute pass per dispatch.
 *
 * @param {object} [config]
 * @param {object} config.fullManifest - Required; used to locate every block's
 *   linear bundles and QK-norm scales.
 * @param {string} [config.bundleBaseUrl] - Base URL for bundle downloads.
 * @param {number} [config.blockIndex=0] - First block to run, clamped to 0..4.
 * @param {number} [config.blockCount=5] - Number of consecutive blocks, clamped
 *   to at least 1 and so that the last index stays <= 4.
 * @param {number} [config.imageTokens=256]
 * @param {number} [config.textTokens=512]
 * @param {number} [config.imageWidth] - Defaults to round(sqrt(imageTokens)).
 * @param {number} [config.linearTileCols=128]
 * @param {number} [config.projTileCols=128]
 * @param {number} [config.wChunkCols=64]
 * @param {number} [config.attentionTileKeys=8]
 * @param {number} [config.attentionQueryRows=16]
 * @param {number} [config.timestep=1.0]
 * @param {number} [config.warmupRuns=1]
 * @param {number} [config.timedRuns=3]
 * @param {number} [config.readbackSample] - When set, reads back up to this many
 *   floats of the final image-stream state and summarizes them.
 * @returns {Promise<object>} Verdict, echoed config, load time, timing summary,
 *   raw timed runs, and the optional readback sample.
 * @throws {Error} If `config.fullManifest` is missing.
 */
async function runCustomDoubleStreamBlocksLoopBench(config = {}) {
  if (!config.fullManifest) {
    throw new Error("double-stream blocks loop benchmark requires fullManifest");
  }
  const {device} = await requestCustomWebGpuDevice(["shader-f16", "packed_4x8_integer_dot_product"]);
  const baseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/";
  // Integer coercion: a NaN config value (e.g. a non-numeric string) falls back
  // to the default instead of producing NaN/fractional block indices.
  const toInteger = (value, fallback) => {
    const n = Math.trunc(Number(value));
    return Number.isNaN(n) ? fallback : n;
  };
  // Block indices addressed by this benchmark are 0..doubleBlockLimit-1.
  const doubleBlockLimit = 5;
  const startBlockIndex = Math.max(0, Math.min(doubleBlockLimit - 1, toInteger(config.blockIndex ?? 0, 0)));
  const blockCount = Math.max(1, Math.min(doubleBlockLimit - startBlockIndex, toInteger(config.blockCount ?? doubleBlockLimit, doubleBlockLimit)));
  const blockIndices = Array.from({length: blockCount}, (_, offset) => startBlockIndex + offset);
  const warmupRuns = Number(config.warmupRuns ?? 1);
  const timedRuns = Number(config.timedRuns ?? 3);
  const imgRows = Number(config.imageTokens ?? 256);
  const txtRows = Number(config.textTokens ?? 512);
  const jointRows = imgRows + txtRows;
  const hidden = 3072;
  const qkvM = 9216;
  const mlp1M = 18432;
  const mlpK = 9216;
  const kWordsHidden = hidden / 4;
  const kWordsMlp = mlpK / 4;
  const maxRows = Math.max(imgRows, txtRows);
  const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imgRows)))));
  const linearTileCols = Number(config.linearTileCols ?? 128);
  const projTileCols = Number(config.projTileCols ?? 128);
  const wChunkCols = Number(config.wChunkCols ?? 64);
  const attentionTileKeys = Number(config.attentionTileKeys ?? 8);
  const attentionQueryRows = Number(config.attentionQueryRows ?? 16);
  const loadStart = performance.now();
  const blocks = await Promise.all(blockIndices.map(async (blockIndex) => {
    const [
      imgQkv, txtQkv, imgProj, txtProj, imgMlp0, imgMlp2, txtMlp0, txtMlp2,
      imgNorm, txtNorm,
    ] = await Promise.all([
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_attn.qkv`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_attn.qkv`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_attn.proj`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_attn.proj`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_mlp.0`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.img_mlp.2`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_mlp.0`), baseUrl),
      loadI8BundleFromFullEntry(findFullManifestLinear(config.fullManifest, `double_blocks.${blockIndex}.txt_mlp.2`), baseUrl),
      loadQkNormScales(config.fullManifest, baseUrl, `double_blocks.${blockIndex}.img_attn.norm`),
      loadQkNormScales(config.fullManifest, baseUrl, `double_blocks.${blockIndex}.txt_attn.norm`),
    ]);
    return {blockIndex, imgQkv, txtQkv, imgProj, txtProj, imgMlp0, imgMlp2, txtMlp0, txtMlp2, imgNorm, txtNorm};
  }));
  const loadMs = performance.now() - loadStart;
  const imgStateBytes = imgRows * hidden * 4;
  const txtStateBytes = txtRows * hidden * 4;
  const imgInput = buildInputF32(imgRows, hidden);
  const txtInput = buildInputF32(txtRows, hidden);
  const imgStateA = createBuffer(device, imgInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  const txtStateA = createBuffer(device, txtInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  const imgStateB = createEmptyBuffer(device, imgStateBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  const txtStateB = createEmptyBuffer(device, txtStateBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);
  // With two or more blocks the ping-pong chain writes back into the A buffers
  // (odd offsets output to A). Keep untouched copies of the inputs and restore
  // them at the start of every dispatch. Otherwise each warmup/timed run would
  // start from the previous run's output, and the readback sample would not
  // correspond to a single pass over the input.
  const restoreInputs = blockCount > 1;
  const imgInputSource = restoreInputs ? createBuffer(device, imgInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC) : null;
  const txtInputSource = restoreInputs ? createBuffer(device, txtInput, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC) : null;
  const imgMidState = createEmptyBuffer(device, imgStateBytes, GPUBufferUsage.STORAGE);
  const txtMidState = createEmptyBuffer(device, txtStateBytes, GPUBufferUsage.STORAGE);
  const modImgBuffer = createEmptyBuffer(device, 18432 * 4, GPUBufferUsage.STORAGE);
  const modTxtBuffer = createEmptyBuffer(device, 18432 * 4, GPUBufferUsage.STORAGE);
  const packedHiddenBuffer = createEmptyBuffer(device, maxRows * kWordsHidden * 4, GPUBufferUsage.STORAGE);
  const scalesBuffer = createEmptyBuffer(device, maxRows * 4, GPUBufferUsage.STORAGE);
  const packedMlpBuffer = createEmptyBuffer(device, maxRows * kWordsMlp * 4, GPUBufferUsage.STORAGE);
  const mlpScalesBuffer = createEmptyBuffer(device, maxRows * 4, GPUBufferUsage.STORAGE);
  const imgQkvBuffer = createEmptyBuffer(device, imgRows * qkvM * 2, GPUBufferUsage.STORAGE);
  const txtQkvBuffer = createEmptyBuffer(device, txtRows * qkvM * 2, GPUBufferUsage.STORAGE);
  const qNormBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE);
  const kNormBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE);
  const vBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE);
  const attentionOutBuffer = createEmptyBuffer(device, jointRows * hidden * 4, GPUBufferUsage.STORAGE);
  const mlp0Buffer = createEmptyBuffer(device, maxRows * mlp1M * 2, GPUBufferUsage.STORAGE);
  const ropeFreqBuffer = createBuffer(device, buildFluxRopeFrequencies(), GPUBufferUsage.STORAGE);
  const preNormQuantModule = device.createShaderModule({code: SINGLE_PRENORM_MOD_QUANT_SHADER});
  const quantOffsetModule = device.createShaderModule({code: QUANTIZE_X_F32_OFFSET_SHADER});
  const dotHiddenToQkvModule = device.createShaderModule({code: makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, true)});
  const qkvNormModule = device.createShaderModule({code: DOUBLE_QKV_NORM_ROPE_SHADER});
  const attentionModule = device.createShaderModule({code: makeJointAttentionTiledShader(attentionTileKeys, attentionQueryRows)});
  const dotProjModule = device.createShaderModule({code: makeI8ScaledDotResidualShader32xWide(projTileCols, wChunkCols)});
  const dotMlp0Module = device.createShaderModule({code: makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, true)});
  const activateMlpModule = device.createShaderModule({code: makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_QUANT_SHADER)});
  const dotMlp2Module = device.createShaderModule({code: makeI8ScaledDotResidualShader32xWide(projTileCols, wChunkCols)});
  const preNormQuantPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: preNormQuantModule, entryPoint: "main"}});
  const quantOffsetPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: quantOffsetModule, entryPoint: "main"}});
  const dotHiddenToQkvPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotHiddenToQkvModule, entryPoint: "main"}});
  const qkvNormPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: qkvNormModule, entryPoint: "main"}});
  const attentionPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: attentionModule, entryPoint: "main"}});
  const dotProjPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotProjModule, entryPoint: "main"}});
  const dotMlp0Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotMlp0Module, entryPoint: "main"}});
  const activateMlpPipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: activateMlpModule, entryPoint: "main"}});
  const dotMlp2Pipeline = await device.createComputePipelineAsync({layout: "auto", compute: {module: dotMlp2Module, entryPoint: "main"}});
  const timestepBuffer = createBuffer(device, makeTimestepEmbeddingF16(Number(config.timestep ?? 1.0)), GPUBufferUsage.STORAGE);
  const timeHiddenF32Buffer = createEmptyBuffer(device, hidden * 4, GPUBufferUsage.STORAGE);
  const timeHiddenF16Buffer = createEmptyBuffer(device, hidden * 2, GPUBufferUsage.STORAGE);
  const vecF32Buffer = createEmptyBuffer(device, hidden * 4, GPUBufferUsage.STORAGE);
  const vecF16Buffer = createEmptyBuffer(device, hidden * 2, GPUBufferUsage.STORAGE);
  const timeStage0 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.in_layer"), baseUrl, timestepBuffer, timeHiddenF32Buffer, 1, 96);
  const timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, hidden, true);
  const timeStage1 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.out_layer"), baseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, 96);
  const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, hidden, true);
  const modImgStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_img.lin"), baseUrl, vecF16Buffer, modImgBuffer, 1, 96);
  const modTxtStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_txt.lin"), baseUrl, vecF16Buffer, modTxtBuffer, 1, 96);
  const makeLinearBuffers = (bundle) => ({
    w: createImmutableBuffer(device, bundle.weight, GPUBufferUsage.STORAGE, bundle.weightCacheKey),
    s: createImmutableBuffer(device, bundle.scales, GPUBufferUsage.STORAGE, bundle.scalesCacheKey),
  });
  const makeParams = (values) => createBuffer(device, new Uint32Array(values), GPUBufferUsage.UNIFORM);
  const params = {
    imgPre: makeParams([imgRows, hidden, qkvM, kWordsHidden]),
    txtPre: makeParams([txtRows, hidden, qkvM, kWordsHidden]),
    imgDotQkv: makeParams([imgRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]),
    txtDotQkv: makeParams([txtRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]),
    imgQkvNorm: makeParams([imgRows, qkvM, 24, 128, hidden, txtRows, txtRows, imageWidth]),
    txtQkvNorm: makeParams([txtRows, qkvM, 24, 128, hidden, 0, txtRows, imageWidth]),
    attention: makeParams([jointRows, 24, 128, hidden]),
    imgQuantAttn: makeParams([imgRows, hidden, kWordsHidden, txtRows, hidden]),
    txtQuantAttn: makeParams([txtRows, hidden, kWordsHidden, 0, hidden]),
    imgProj: makeParams([imgRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]),
    txtProj: makeParams([txtRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]),
    imgMlpPre: makeParams([imgRows, hidden, mlp1M, kWordsHidden]),
    txtMlpPre: makeParams([txtRows, hidden, mlp1M, kWordsHidden]),
    imgMlp0: makeParams([imgRows, hidden, mlp1M, kWordsHidden, 0, 0, 0, 0]),
    txtMlp0: makeParams([txtRows, hidden, mlp1M, kWordsHidden, 0, 0, 0, 0]),
    imgAct: makeParams([imgRows, mlp1M, mlpK, 0, 0, mlpK, kWordsMlp, 0]),
    txtAct: makeParams([txtRows, mlp1M, mlpK, 0, 0, mlpK, kWordsMlp, 0]),
    imgMlp2: makeParams([imgRows, mlpK, hidden, kWordsMlp, 0, 0, 0, 0]),
    txtMlp2: makeParams([txtRows, mlpK, hidden, kWordsMlp, 0, 0, 0, 0]),
  };
  // Byte offsets of the six hidden-sized f32 slices inside each modulation buffer.
  const hiddenModOffsets = {
    mod1Shift: 0,
    mod1Scale: hidden * 4,
    mod1Gate: hidden * 8,
    mod2Shift: hidden * 12,
    mod2Scale: hidden * 16,
    mod2Gate: hidden * 20,
  };
  const res = (buffer, offset = 0) => offset ? {buffer, offset} : {buffer};
  const makePreNormBind = (state, modBuffer, shiftOffset, scaleOffset, paramsBuffer) => device.createBindGroup({
    layout: preNormQuantPipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: state}},
      {binding: 1, resource: res(modBuffer, shiftOffset)},
      {binding: 2, resource: res(modBuffer, scaleOffset)},
      {binding: 3, resource: {buffer: packedHiddenBuffer}},
      {binding: 4, resource: {buffer: scalesBuffer}},
      {binding: 5, resource: {buffer: paramsBuffer}},
    ],
  });
  const makeDotBind = (pipeline, packed, packedScales, w, wScales, out, paramsBuffer) => device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: packed}},
      {binding: 1, resource: {buffer: w}},
      {binding: 2, resource: {buffer: packedScales}},
      {binding: 3, resource: {buffer: wScales}},
      {binding: 4, resource: {buffer: out}},
      {binding: 5, resource: {buffer: paramsBuffer}},
    ],
  });
  const makeResidualDotBind = (pipeline, packed, packedScales, w, wScales, out, paramsBuffer, residual, gate) => device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      {binding: 0, resource: {buffer: packed}},
      {binding: 1, resource: {buffer: w}},
      {binding: 2, resource: {buffer: packedScales}},
      {binding: 3, resource: {buffer: wScales}},
      {binding: 4, resource: {buffer: out}},
      {binding: 5, resource: {buffer: paramsBuffer}},
      {binding: 6, resource: {buffer: residual}},
      {binding: 7, resource: gate},
    ],
  });
  // Bind groups that reference only scratch buffers and are therefore shared by every block.
  const commonBind = {
    attention: device.createBindGroup({
      layout: attentionPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: vBuffer}},
        {binding: 1, resource: {buffer: qNormBuffer}},
        {binding: 2, resource: {buffer: kNormBuffer}},
        {binding: 3, resource: {buffer: attentionOutBuffer}},
        {binding: 4, resource: {buffer: params.attention}},
      ],
    }),
    imgQuantAttn: device.createBindGroup({
      layout: quantOffsetPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: attentionOutBuffer}},
        {binding: 1, resource: {buffer: packedHiddenBuffer}},
        {binding: 2, resource: {buffer: scalesBuffer}},
        {binding: 3, resource: {buffer: params.imgQuantAttn}},
      ],
    }),
    txtQuantAttn: device.createBindGroup({
      layout: quantOffsetPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: attentionOutBuffer}},
        {binding: 1, resource: {buffer: packedHiddenBuffer}},
        {binding: 2, resource: {buffer: scalesBuffer}},
        {binding: 3, resource: {buffer: params.txtQuantAttn}},
      ],
    }),
    imgAct: device.createBindGroup({
      layout: activateMlpPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: mlp0Buffer}},
        {binding: 1, resource: {buffer: packedMlpBuffer}},
        {binding: 2, resource: {buffer: mlpScalesBuffer}},
        {binding: 3, resource: {buffer: params.imgAct}},
      ],
    }),
    txtAct: device.createBindGroup({
      layout: activateMlpPipeline.getBindGroupLayout(0),
      entries: [
        {binding: 0, resource: {buffer: mlp0Buffer}},
        {binding: 1, resource: {buffer: packedMlpBuffer}},
        {binding: 2, resource: {buffer: mlpScalesBuffer}},
        {binding: 3, resource: {buffer: params.txtAct}},
      ],
    }),
  };
  const blockStages = blocks.map((block, offset) => {
    // Even offsets read A and write B; odd offsets read B and write A.
    const ping = offset & 1;
    const imgInputState = ping ? imgStateB : imgStateA;
    const txtInputState = ping ? txtStateB : txtStateA;
    const imgOutputState = ping ? imgStateA : imgStateB;
    const txtOutputState = ping ? txtStateA : txtStateB;
    const weights = {
      imgQkv: makeLinearBuffers(block.imgQkv),
      txtQkv: makeLinearBuffers(block.txtQkv),
      imgProj: makeLinearBuffers(block.imgProj),
      txtProj: makeLinearBuffers(block.txtProj),
      imgMlp0: makeLinearBuffers(block.imgMlp0),
      imgMlp2: makeLinearBuffers(block.imgMlp2),
      txtMlp0: makeLinearBuffers(block.txtMlp0),
      txtMlp2: makeLinearBuffers(block.txtMlp2),
    };
    const imgQueryScaleBuffer = createImmutableBuffer(device, block.imgNorm.queryScale, GPUBufferUsage.STORAGE, block.imgNorm.queryScaleUrl);
    const imgKeyScaleBuffer = createImmutableBuffer(device, block.imgNorm.keyScale, GPUBufferUsage.STORAGE, block.imgNorm.keyScaleUrl);
    const txtQueryScaleBuffer = createImmutableBuffer(device, block.txtNorm.queryScale, GPUBufferUsage.STORAGE, block.txtNorm.queryScaleUrl);
    const txtKeyScaleBuffer = createImmutableBuffer(device, block.txtNorm.keyScale, GPUBufferUsage.STORAGE, block.txtNorm.keyScaleUrl);
    return {
      blockIndex: block.blockIndex,
      imgQkvPre: makePreNormBind(imgInputState, modImgBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.imgPre),
      txtQkvPre: makePreNormBind(txtInputState, modTxtBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.txtPre),
      imgQkvDot: makeDotBind(dotHiddenToQkvPipeline, packedHiddenBuffer, scalesBuffer, weights.imgQkv.w, weights.imgQkv.s, imgQkvBuffer, params.imgDotQkv),
      txtQkvDot: makeDotBind(dotHiddenToQkvPipeline, packedHiddenBuffer, scalesBuffer, weights.txtQkv.w, weights.txtQkv.s, txtQkvBuffer, params.txtDotQkv),
      imgQkvNorm: device.createBindGroup({
        layout: qkvNormPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: imgQkvBuffer}},
          {binding: 1, resource: {buffer: imgQueryScaleBuffer}},
          {binding: 2, resource: {buffer: imgKeyScaleBuffer}},
          {binding: 3, resource: {buffer: qNormBuffer}},
          {binding: 4, resource: {buffer: kNormBuffer}},
          {binding: 5, resource: {buffer: vBuffer}},
          {binding: 6, resource: {buffer: params.imgQkvNorm}},
          {binding: 7, resource: {buffer: ropeFreqBuffer}},
        ],
      }),
      txtQkvNorm: device.createBindGroup({
        layout: qkvNormPipeline.getBindGroupLayout(0),
        entries: [
          {binding: 0, resource: {buffer: txtQkvBuffer}},
          {binding: 1, resource: {buffer: txtQueryScaleBuffer}},
          {binding: 2, resource: {buffer: txtKeyScaleBuffer}},
          {binding: 3, resource: {buffer: qNormBuffer}},
          {binding: 4, resource: {buffer: kNormBuffer}},
          {binding: 5, resource: {buffer: vBuffer}},
          {binding: 6, resource: {buffer: params.txtQkvNorm}},
          {binding: 7, resource: {buffer: ropeFreqBuffer}},
        ],
      }),
      imgProj: makeResidualDotBind(dotProjPipeline, packedHiddenBuffer, scalesBuffer, weights.imgProj.w, weights.imgProj.s, imgMidState, params.imgProj, imgInputState, res(modImgBuffer, hiddenModOffsets.mod1Gate)),
      txtProj: makeResidualDotBind(dotProjPipeline, packedHiddenBuffer, scalesBuffer, weights.txtProj.w, weights.txtProj.s, txtMidState, params.txtProj, txtInputState, res(modTxtBuffer, hiddenModOffsets.mod1Gate)),
      imgMlpPre: makePreNormBind(imgMidState, modImgBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.imgMlpPre),
      txtMlpPre: makePreNormBind(txtMidState, modTxtBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.txtMlpPre),
      imgMlp0: makeDotBind(dotMlp0Pipeline, packedHiddenBuffer, scalesBuffer, weights.imgMlp0.w, weights.imgMlp0.s, mlp0Buffer, params.imgMlp0),
      txtMlp0: makeDotBind(dotMlp0Pipeline, packedHiddenBuffer, scalesBuffer, weights.txtMlp0.w, weights.txtMlp0.s, mlp0Buffer, params.txtMlp0),
      imgMlp2: makeResidualDotBind(dotMlp2Pipeline, packedMlpBuffer, mlpScalesBuffer, weights.imgMlp2.w, weights.imgMlp2.s, imgOutputState, params.imgMlp2, imgMidState, res(modImgBuffer, hiddenModOffsets.mod2Gate)),
      txtMlp2: makeResidualDotBind(dotMlp2Pipeline, packedMlpBuffer, mlpScalesBuffer, weights.txtMlp2.w, weights.txtMlp2.s, txtOutputState, params.txtMlp2, txtMidState, res(modTxtBuffer, hiddenModOffsets.mod2Gate)),
    };
  });
  const runStage = (pass, pipeline, bindGroup, x, y = undefined) => {
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    if (y === undefined) pass.dispatchWorkgroups(x);
    else pass.dispatchWorkgroups(x, y);
  };
  // Encodes and submits the full chain (time/modulation stages, then every
  // block), and waits for the GPU to finish.
  async function dispatch() {
    const encoder = device.createCommandEncoder();
    if (restoreInputs) {
      // Reset the ping-pong inputs so every run starts from the same state.
      // The copy is encoded before the compute pass and is counted in the timed region.
      encoder.copyBufferToBuffer(imgInputSource, 0, imgStateA, 0, imgStateBytes);
      encoder.copyBufferToBuffer(txtInputSource, 0, txtStateA, 0, txtStateBytes);
    }
    const pass = encoder.beginComputePass();
    runStage(pass, timeStage0.pipeline, timeStage0.bindGroup, timeStage0.workgroupsX, timeStage0.workgroupsY);
    runStage(pass, timeSiluStage.pipeline, timeSiluStage.bindGroup, timeSiluStage.workgroupsX);
    runStage(pass, timeStage1.pipeline, timeStage1.bindGroup, timeStage1.workgroupsX, timeStage1.workgroupsY);
    runStage(pass, vecCastStage.pipeline, vecCastStage.bindGroup, vecCastStage.workgroupsX);
    runStage(pass, modImgStage.pipeline, modImgStage.bindGroup, modImgStage.workgroupsX, modImgStage.workgroupsY);
    runStage(pass, modTxtStage.pipeline, modTxtStage.bindGroup, modTxtStage.workgroupsX, modTxtStage.workgroupsY);
    for (const stage of blockStages) {
      runStage(pass, preNormQuantPipeline, stage.imgQkvPre, imgRows);
      runStage(pass, dotHiddenToQkvPipeline, stage.imgQkvDot, Math.ceil(qkvM / linearTileCols), Math.ceil(imgRows / 32));
      runStage(pass, preNormQuantPipeline, stage.txtQkvPre, txtRows);
      runStage(pass, dotHiddenToQkvPipeline, stage.txtQkvDot, Math.ceil(qkvM / linearTileCols), Math.ceil(txtRows / 32));
      runStage(pass, qkvNormPipeline, stage.txtQkvNorm, txtRows, 24);
      runStage(pass, qkvNormPipeline, stage.imgQkvNorm, imgRows, 24);
      runStage(pass, attentionPipeline, commonBind.attention, Math.ceil(jointRows / attentionQueryRows), 24);
      runStage(pass, quantOffsetPipeline, commonBind.imgQuantAttn, imgRows);
      runStage(pass, dotProjPipeline, stage.imgProj, Math.ceil(hidden / projTileCols), Math.ceil(imgRows / 32));
      runStage(pass, quantOffsetPipeline, commonBind.txtQuantAttn, txtRows);
      runStage(pass, dotProjPipeline, stage.txtProj, Math.ceil(hidden / projTileCols), Math.ceil(txtRows / 32));
      runStage(pass, preNormQuantPipeline, stage.imgMlpPre, imgRows);
      runStage(pass, dotMlp0Pipeline, stage.imgMlp0, Math.ceil(mlp1M / linearTileCols), Math.ceil(imgRows / 32));
      runStage(pass, activateMlpPipeline, commonBind.imgAct, imgRows);
      runStage(pass, dotMlp2Pipeline, stage.imgMlp2, Math.ceil(hidden / projTileCols), Math.ceil(imgRows / 32));
      runStage(pass, preNormQuantPipeline, stage.txtMlpPre, txtRows);
      runStage(pass, dotMlp0Pipeline, stage.txtMlp0, Math.ceil(mlp1M / linearTileCols), Math.ceil(txtRows / 32));
      runStage(pass, activateMlpPipeline, commonBind.txtAct, txtRows);
      runStage(pass, dotMlp2Pipeline, stage.txtMlp2, Math.ceil(hidden / projTileCols), Math.ceil(txtRows / 32));
    }
    pass.end();
    device.queue.submit([encoder.finish()]);
    await device.queue.onSubmittedWorkDone();
  }
  for (let i = 0; i < warmupRuns; ++i) {
    await dispatch();
  }
  const times = [];
  for (let i = 0; i < timedRuns; ++i) {
    const start = performance.now();
    await dispatch();
    times.push(performance.now() - start);
  }
  let sample = null;
  if (config.readbackSample) {
    // The last block writes B when blockCount is odd and A when it is even.
    const finalImgBuffer = blockCount % 2 === 0 ? imgStateA : imgStateB;
    const count = Math.min(Number(config.readbackSample), imgRows * hidden);
    const values = await readFloat32Buffer(device, finalImgBuffer, count);
    let finite = 0;
    let maxAbs = 0;
    for (const value of values) {
      if (Number.isFinite(value)) finite += 1;
      maxAbs = Math.max(maxAbs, Math.abs(value));
    }
    sample = {count, finite, max_abs: maxAbs, values: Array.from(values.slice(0, Math.min(8, values.length)))};
  }
  const medianMs = median(times);
  // Multiply-accumulates of the eight int8 linears in one block (both streams).
  const macsPerBlock =
    imgRows * hidden * qkvM + txtRows * hidden * qkvM +
    imgRows * hidden * hidden + txtRows * hidden * hidden +
    imgRows * hidden * mlp1M + txtRows * hidden * mlp1M +
    imgRows * mlpK * hidden + txtRows * mlpK * hidden;
  return {
    verdict: "custom-double-stream-blocks-loop-completed",
    config: {
      startBlockIndex,
      blockCount,
      blockIndices,
      imgRows,
      txtRows,
      jointRows,
      imageWidth,
      linearTileCols,
      projTileCols,
      wChunkCols,
      attentionTileKeys,
      attentionQueryRows,
      warmupRuns,
      timedRuns,
    },
    load: {total_ms: loadMs},
    summary: {
      median_dispatch_ms: medianMs,
      per_block_median_ms: medianMs / blockCount,
      effective_tmacs: (macsPerBlock * blockCount) / (medianMs / 1000) / 1e12,
    },
    timed_ms: times,
    sample,
  };
}
| async function runCustomFluxTransformerDenoise(config = {}) { | |
| if (!config.fullManifest) { | |
| throw new Error("custom full denoise requires fullManifest"); | |
| } | |
| const baseUrl = config.bundleBaseUrl || "/runtime/custom_lowbit/full_transformer/"; | |
| const imgRows = Number(config.imageTokens ?? 256); | |
| const txtRows = Number(config.textTokens ?? 512); | |
| const jointRows = imgRows + txtRows; | |
| const hidden = 3072; | |
| const contextDim = 7680; | |
| const latentChannels = Number(config.latentChannels ?? 128); | |
| const imageWidth = Math.max(1, Number(config.imageWidth ?? Math.round(Math.sqrt(Math.max(1, imgRows))))); | |
| const qkvM = 9216; | |
| const doubleMlp1M = 18432; | |
| const doubleMlpK = 9216; | |
| const singleLinear1M = 27648; | |
| const singleLinear2K = 12288; | |
| const kWordsLatent = latentChannels / 4; | |
| const kWordsContext = contextDim / 4; | |
| const kWordsHidden = hidden / 4; | |
| const hiddenQ4Groups = hidden / 16; | |
| const kWordsDoubleMlp = doubleMlpK / 4; | |
| const kWordsSingleMlp = singleLinear2K / 4; | |
| const maxRows = Math.max(imgRows, txtRows); | |
| const linearTileCols = Number(config.linearTileCols ?? 256); | |
| const projTileCols = Number(config.projTileCols ?? 128); | |
| const finalTileCols = Number(config.finalTileCols ?? 256); | |
| const wChunkCols = Number(config.wChunkCols ?? 64); | |
| const linearRowBlock = Number(config.linearRowBlock ?? 32); | |
| const singleAttentionKernel = String(config.singleAttentionKernel || config.attentionKernel || "tiled").toLowerCase(); | |
| const useSubgroupSingleAttention = singleAttentionKernel === "subgroup" || singleAttentionKernel === "subgroups"; | |
| const {device} = await requestCustomWebGpuDevice( | |
| ["shader-f16", "packed_4x8_integer_dot_product", ...(useSubgroupSingleAttention ? ["subgroups"] : [])], | |
| linearRowBlock > 32 ? {maxComputeInvocationsPerWorkgroup: 512} : {}, | |
| ); | |
| const attentionTileKeys = Number(config.attentionTileKeys ?? 8); | |
| const attentionQueryRows = Number(config.attentionQueryRows ?? 16); | |
| const singleQkNormStorage = String(config.singleQkNormStorage || "f32").toLowerCase(); | |
| const useSingleQkNormStorageF16 = singleQkNormStorage === "f16" || singleQkNormStorage === "fp16" || singleQkNormStorage === "half"; | |
| const singleLinear2Backend = String(config.singleLinear2Backend || "dp4a").toLowerCase(); | |
| const useSingleLinear2Dp4a = singleLinear2Backend === "dp4a" || singleLinear2Backend === "i8" || singleLinear2Backend === "int8"; | |
| const useSingleLinear2Q4 = !useSingleLinear2Dp4a; | |
| const singleLinear1Output = String(config.singleLinear1Output || "f16").toLowerCase(); | |
| const useSingleLinear1OutputF16 = singleLinear1Output === "f16" || singleLinear1Output === "fp16" || singleLinear1Output === "half"; | |
| const singleLinear1QkvBackend = String(config.singleLinear1QkvBackend || "q4").toLowerCase(); | |
| const useSingleLinear1QkvDp4a = useSingleLinear1OutputF16 && ( | |
| singleLinear1QkvBackend === "dp4a" || singleLinear1QkvBackend === "i8" || singleLinear1QkvBackend === "int8" | |
| ); | |
| const singleLinear1MlpBackend = String(config.singleLinear1MlpBackend || "q4").toLowerCase(); | |
| const useSingleLinear1MlpDp4a = useSingleLinear1OutputF16 && ( | |
| singleLinear1MlpBackend === "dp4a" || singleLinear1MlpBackend === "i8" || singleLinear1MlpBackend === "int8" | |
| ); | |
| const useSingleLinear1Dp4a = useSingleLinear1QkvDp4a || useSingleLinear1MlpDp4a; | |
| const useSingleLinear1Q4 = !(useSingleLinear1QkvDp4a && useSingleLinear1MlpDp4a); | |
| const singleLinear1Q4Kernel = String(config.singleLinear1Q4Kernel || "f16").toLowerCase(); | |
| const useSingleLinear1Q4Dp4aQkvOnly = ( | |
| singleLinear1Q4Kernel === "dp4a-qkv" || | |
| singleLinear1Q4Kernel === "qkv-dp4a" || | |
| singleLinear1Q4Kernel === "q4-dp4a-qkv" | |
| ); | |
| const useSingleLinear1Q4Dp4a = useSingleLinear1OutputF16 && ( | |
| singleLinear1Q4Kernel === "dp4a" || | |
| singleLinear1Q4Kernel === "q4-dp4a" || | |
| singleLinear1Q4Kernel === "int8" || | |
| useSingleLinear1Q4Dp4aQkvOnly | |
| ); | |
| const singleLinear1Q4ActivationScale = String( | |
| config.singleLinear1Q4ActivationScale ?? (useSingleLinear1Q4Dp4a ? "group16" : "row") | |
| ).toLowerCase(); | |
| const useSingleLinear1Q4GroupedActivation = useSingleLinear1Q4Dp4a && ( | |
| singleLinear1Q4ActivationScale === "group" || | |
| singleLinear1Q4ActivationScale === "group16" || | |
| singleLinear1Q4ActivationScale === "per-group" | |
| ); | |
| const needsSingleLinear1F16PreNorm = useSingleLinear1Q4 && (!useSingleLinear1Q4Dp4a || useSingleLinear1Q4Dp4aQkvOnly); | |
| const needsSingleLinear1GroupedPreNorm = useSingleLinear1Q4GroupedActivation; | |
| const approxReusePredEvery = Math.max(0, Math.trunc(Number(config.approxReusePredEvery || 0))); | |
| const approxPredictionMode = String(config.approxPredictionMode || config.approxMode || "raw").toLowerCase(); | |
| const approxAbScale = Number.isFinite(Number(config.approxAbScale)) | |
| ? Number(config.approxAbScale) | |
| : 0.5; | |
| const useApproxAb2 = approxReusePredEvery > 1 && (approxPredictionMode === "ab2" || approxPredictionMode === "adams-bashforth"); | |
| const skipSingleBlocksFromStep = Math.max(0, Math.trunc(Number(config.skipSingleBlocksFromStep || 0))); | |
| const skipSingleBlocksFromBlock = Math.max(0, Math.trunc(Number(config.skipSingleBlocksFromBlock || 0))); | |
| const reuseSingleTailFromStep = Math.max(0, Math.trunc(Number(config.reuseSingleTailFromStep || 0))); | |
| const reuseSingleTailFromBlock = Math.max(0, Math.trunc(Number(config.reuseSingleTailFromBlock || 0))); | |
| const profilePhases = config.profilePhases === true && !config.debugStopAfter; | |
| const profileSingleParts = profilePhases && config.profileSingleParts === true; | |
| const stepSubmitFusion = config.stepSubmitFusion === true && approxReusePredEvery <= 1 && !config.debugStopAfter && !profilePhases; | |
| const timesteps = Array.isArray(config.timesteps) && config.timesteps.length > 1 | |
| ? config.timesteps.map(Number) | |
| : [1.0, 0.75, 0.5, 0.25, 0.0]; | |
| const denoiseStepCount = timesteps.length - 1; | |
| const doubleBlockCount = Math.max(1, Math.min(5, Math.trunc(Number(config.maxDoubleBlocks || config.doubleBlockCount || 5)))); | |
| const singleBlockCount = Math.max(0, Math.min(20, Math.trunc(Number(config.maxSingleBlocks || config.singleBlockCount || 20)))); | |
| if (jointRows > 4608) { | |
| throw new Error(`custom WebGPU transformer currently supports joint tokens <= 4608; got ${jointRows}`); | |
| } | |
| if (latentChannels !== 128) { | |
| throw new Error(`custom WebGPU transformer currently expects 128 latent channels; got ${latentChannels}`); | |
| } | |
| const initialLatent = config.initialLatentF32 instanceof Float32Array | |
| ? config.initialLatentF32 | |
| : new Float32Array(config.initialLatentF32 || []); | |
| if (initialLatent.length !== imgRows * latentChannels) { | |
| throw new Error(`initialLatentF32 length ${initialLatent.length} does not match ${imgRows * latentChannels}`); | |
| } | |
| const contextF16 = config.contextF16 instanceof Uint16Array | |
| ? config.contextF16 | |
| : new Uint16Array(config.contextF16 || []); | |
| if (contextF16.length !== txtRows * contextDim) { | |
| throw new Error(`contextF16 length ${contextF16.length} does not match ${txtRows * contextDim}`); | |
| } | |
| const textProjectionCacheKey = config.textContextCacheKey | |
| ? `${baseUrl}|txt-in|${txtRows}|${contextDim}|${hidden}|${config.textContextCacheKey}` | |
| : ""; | |
| const persistentTextProjectionKey = textProjectionCacheKey && config.persistentTextProjectionCache !== false | |
| ? `txtproj-v1|${textProjectionCacheKey}` | |
| : ""; | |
| let textProjectionDeviceCache = null; | |
| let cachedTxtBaseState = null; | |
| let persistentTxtProjectionHit = false; | |
| let staticTxtProjectionHit = false; | |
| let persistentTxtProjectionMs = 0; | |
| if (textProjectionCacheKey) { | |
| textProjectionDeviceCache = textProjectionGpuCache.get(device); | |
| if (!textProjectionDeviceCache) { | |
| textProjectionDeviceCache = new Map(); | |
| textProjectionGpuCache.set(device, textProjectionDeviceCache); | |
| } | |
| cachedTxtBaseState = textProjectionDeviceCache.get(textProjectionCacheKey) || null; | |
| } | |
| let contextF32 = null; | |
| if (!cachedTxtBaseState) { | |
| contextF32 = new Float32Array(contextF16.length); | |
| for (let i = 0; i < contextF16.length; ++i) { | |
| contextF32[i] = float16BitsToFloat32(contextF16[i]); | |
| } | |
| } | |
| const loadStart = performance.now(); | |
| const { | |
| imgIn, | |
| txtIn, | |
| finalLinear, | |
| doubleBlocks, | |
| singleBlocks, | |
| singleNorms, | |
| } = await loadFullTransformerAssetSet(config.fullManifest, baseUrl, doubleBlockCount, singleBlockCount); | |
| const loadMs = performance.now() - loadStart; | |
| const stageSetupStart = performance.now(); | |
| const latentBytes = imgRows * latentChannels * 4; | |
| const imgBytes = imgRows * hidden * 4; | |
| const txtBytes = txtRows * hidden * 4; | |
| const jointBytes = jointRows * hidden * 4; | |
| const scratchKey = [ | |
| baseUrl, | |
| `img=${imgRows}`, | |
| `txt=${txtRows}`, | |
| `latent=${latentChannels}`, | |
| `l1=${singleLinear1Output}`, | |
| `qkv=${singleLinear1QkvBackend}`, | |
| `mlp=${singleLinear1MlpBackend}`, | |
| `q4=${singleLinear1Q4Kernel}`, | |
| `q4act=${singleLinear1Q4ActivationScale}`, | |
| `l2=${singleLinear2Backend}`, | |
| ].join("|"); | |
| const scratchBuffer = (name, size, usage) => createReusableBuffer(device, `${scratchKey}|${name}`, size, usage); | |
| const uploadScratchBuffer = (name, data, usage) => { | |
| const buffer = scratchBuffer(name, data.byteLength, usage | GPUBufferUsage.COPY_DST); | |
| device.queue.writeBuffer(buffer, 0, data); | |
| return buffer; | |
| }; | |
| const latentF32Buffer = uploadScratchBuffer("latent-f32", initialLatent, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const contextF32Buffer = contextF32 | |
| ? uploadScratchBuffer("context-f32", contextF32, GPUBufferUsage.STORAGE) | |
| : scratchBuffer("context-empty", 4, GPUBufferUsage.STORAGE); | |
| let txtBaseState = cachedTxtBaseState; | |
| if (!txtBaseState && persistentTextProjectionKey) { | |
| const persistentStart = performance.now(); | |
| const persistedTxtBase = await loadPersistentFloat32(persistentTextProjectionKey, txtRows * hidden); | |
| persistentTxtProjectionMs = performance.now() - persistentStart; | |
| if (persistedTxtBase) { | |
| txtBaseState = uploadScratchBuffer("txt-base-persistent", persistedTxtBase, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| persistentTxtProjectionHit = true; | |
| if (textProjectionDeviceCache && textProjectionCacheKey) { | |
| textProjectionDeviceCache.set(textProjectionCacheKey, txtBaseState); | |
| } | |
| } | |
| } | |
| if (!txtBaseState && config.textProjectionUrl) { | |
| const staticStart = performance.now(); | |
| const staticTxtBase = await loadStaticFloat32(config.textProjectionUrl, txtRows * hidden); | |
| persistentTxtProjectionMs += performance.now() - staticStart; | |
| if (staticTxtBase) { | |
| txtBaseState = uploadScratchBuffer("txt-base-static", staticTxtBase, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| staticTxtProjectionHit = true; | |
| if (textProjectionDeviceCache && textProjectionCacheKey) { | |
| textProjectionDeviceCache.set(textProjectionCacheKey, txtBaseState); | |
| } | |
| if (persistentTextProjectionKey) { | |
| savePersistentFloat32(persistentTextProjectionKey, staticTxtBase).catch((err) => { | |
| console.warn("[custom-lowbit] could not persist static text projection", err); | |
| }); | |
| } | |
| } | |
| } | |
| if (!txtBaseState) { | |
| txtBaseState = createEmptyBuffer(device, txtBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| } | |
| const imgStateA = scratchBuffer("img-state-a", imgBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const txtStateA = scratchBuffer("txt-state-a", txtBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const imgStateB = scratchBuffer("img-state-b", imgBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const txtStateB = scratchBuffer("txt-state-b", txtBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const imgMidState = scratchBuffer("img-mid-state", imgBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const txtMidState = scratchBuffer("txt-mid-state", txtBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const jointStateA = scratchBuffer("joint-state-a", jointBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const jointStateB = scratchBuffer("joint-state-b", jointBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST); | |
| const predBuffer = scratchBuffer("pred", latentBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const previousPredBuffer = useApproxAb2 | |
| ? scratchBuffer("previous-pred", latentBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST) | |
| : scratchBuffer("previous-pred-empty", 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST); | |
| const stepResourceCount = stepSubmitFusion ? denoiseStepCount : 1; | |
| const updateParamsBuffers = Array.from({length: stepResourceCount}, (_, index) => ( | |
| uploadScratchBuffer( | |
| `update-params-${index}`, | |
| new Float32Array([ | |
| imgRows * latentChannels, | |
| index < denoiseStepCount ? timesteps[index + 1] - timesteps[index] : 0, | |
| 0, | |
| 0, | |
| ]), | |
| GPUBufferUsage.STORAGE, | |
| ) | |
| )); | |
| const inputPackedWords = Math.max(imgRows * kWordsLatent, txtRows * kWordsContext); | |
| const inputPackedBuffer = scratchBuffer("input-packed", inputPackedWords * 4, GPUBufferUsage.STORAGE); | |
| const inputScalesBuffer = scratchBuffer("input-scales", Math.max(imgRows, txtRows) * 4, GPUBufferUsage.STORAGE); | |
| const packedHiddenBuffer = scratchBuffer("packed-hidden", maxRows * kWordsHidden * 4, GPUBufferUsage.STORAGE); | |
| const scalesBuffer = scratchBuffer("hidden-scales", maxRows * 4, GPUBufferUsage.STORAGE); | |
| const groupedHiddenScalesBuffer = scratchBuffer("grouped-hidden-scales", maxRows * hiddenQ4Groups * 4, GPUBufferUsage.STORAGE); | |
| const packedDoubleMlpBuffer = scratchBuffer("packed-double-mlp", maxRows * kWordsDoubleMlp * 4, GPUBufferUsage.STORAGE); | |
| const doubleMlpScalesBuffer = scratchBuffer("double-mlp-scales", maxRows * 4, GPUBufferUsage.STORAGE); | |
| const packedSingleMlpBuffer = scratchBuffer("packed-single-mlp", jointRows * kWordsSingleMlp * 4, GPUBufferUsage.STORAGE); | |
| const singleMlpScalesBuffer = scratchBuffer("single-mlp-scales", jointRows * 4, GPUBufferUsage.STORAGE); | |
| const singlePreF16Buffer = needsSingleLinear1F16PreNorm | |
| ? scratchBuffer("single-pre-f16", jointRows * hidden * 2, GPUBufferUsage.STORAGE) | |
| : scratchBuffer("single-pre-f16-empty", 4, GPUBufferUsage.STORAGE); | |
| const singleLinear2InF32Buffer = useSingleLinear2Q4 | |
| ? scratchBuffer("single-linear2-in-f32", jointRows * singleLinear2K * 4, GPUBufferUsage.STORAGE) | |
| : scratchBuffer("single-linear2-in-f32-empty", 4, GPUBufferUsage.STORAGE); | |
| const singleLinear2InF16Buffer = useSingleLinear2Q4 | |
| ? scratchBuffer("single-linear2-in-f16", jointRows * singleLinear2K * 2, GPUBufferUsage.STORAGE) | |
| : scratchBuffer("single-linear2-in-f16-empty", 4, GPUBufferUsage.STORAGE); | |
| const singleLinear2OutBuffer = scratchBuffer("single-linear2-out", jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const maxStorageBindingBytes = Math.max( | |
| jointRows * singleLinear1M * (useSingleLinear1OutputF16 ? 2 : 4), | |
| jointRows * singleLinear2K * 4, | |
| jointRows * singleLinear2K * 2, | |
| jointRows * hidden * 4, | |
| maxRows * doubleMlp1M * 2, | |
| ); | |
| const storageBindingLimit = Number(device.limits?.maxStorageBufferBindingSize || 128 * 1024 * 1024); | |
| if (maxStorageBindingBytes > storageBindingLimit) { | |
| throw new Error( | |
| `CUSTOM_TRANSFORMER_WEBGPU_LIMIT: required storage buffer binding ${maxStorageBindingBytes} bytes exceeds device limit ${storageBindingLimit} bytes`, | |
| ); | |
| } | |
| const imgQkvBuffer = scratchBuffer("img-qkv", imgRows * qkvM * 2, GPUBufferUsage.STORAGE); | |
| const txtQkvBuffer = scratchBuffer("txt-qkv", txtRows * qkvM * 2, GPUBufferUsage.STORAGE); | |
| const doubleQNormBuffer = scratchBuffer("double-q-norm", jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const doubleKNormBuffer = scratchBuffer("double-k-norm", jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const doubleVBuffer = scratchBuffer("double-v", jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const doubleAttentionOutBuffer = scratchBuffer("double-attention-out", jointRows * hidden * 4, GPUBufferUsage.STORAGE); | |
| const doubleMlp0Buffer = scratchBuffer("double-mlp0", maxRows * doubleMlp1M * 2, GPUBufferUsage.STORAGE); | |
| const doubleRopeFreqBuffer = createImmutableBuffer(device, buildFluxRopeFrequencies(), GPUBufferUsage.STORAGE, `${scratchKey}|double-rope-freq`); | |
| const singleLinear1OutBuffer = scratchBuffer("single-linear1-out", jointRows * singleLinear1M * (useSingleLinear1OutputF16 ? 2 : 4), GPUBufferUsage.STORAGE); | |
| const singleAttentionOutBuffer = scratchBuffer("single-attention-out", jointRows * hidden * 4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC); | |
| const singleTailDeltaEnabled = reuseSingleTailFromStep > 0 && reuseSingleTailFromBlock > 0 && reuseSingleTailFromBlock < singleBlockCount; | |
| const skippedOptionalSetupBytes = | |
| (needsSingleLinear1F16PreNorm ? 0 : jointRows * hidden * 2) + | |
| (useSingleLinear2Q4 ? 0 : jointRows * singleLinear2K * 6); | |
| const skippedOptionalPipelines = [ | |
| !needsSingleLinear1GroupedPreNorm, | |
| !needsSingleLinear1F16PreNorm, | |
| !useSingleLinear2Q4, | |
| !useSingleLinear2Q4, | |
| !useApproxAb2, | |
| !singleTailDeltaEnabled, | |
| !singleTailDeltaEnabled, | |
| ].filter(Boolean).length; | |
| const singleTailBaseCacheBuffer = singleTailDeltaEnabled | |
| ? scratchBuffer("single-tail-base-cache", jointBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC) | |
| : null; | |
| const singleTailDeltaCacheBuffer = singleTailDeltaEnabled | |
| ? scratchBuffer("single-tail-delta-cache", jointBytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC) | |
| : null; | |
| const singleQkNormBytes = jointRows * hidden * (useSingleQkNormStorageF16 ? 2 : 4); | |
| const singleQNormBuffer = scratchBuffer("single-q-norm", singleQkNormBytes, GPUBufferUsage.STORAGE); | |
| const singleKNormBuffer = scratchBuffer("single-k-norm", singleQkNormBytes, GPUBufferUsage.STORAGE); | |
| const singleRopeSinCosBuffer = createImmutableBuffer(device, buildFluxRopeSinCos(jointRows, txtRows, imageWidth), GPUBufferUsage.STORAGE, `${scratchKey}|single-rope-sincos|w=${imageWidth}`); | |
| const modImgBuffer = scratchBuffer("mod-img", 18432 * 4, GPUBufferUsage.STORAGE); | |
| const modTxtBuffer = scratchBuffer("mod-txt", 18432 * 4, GPUBufferUsage.STORAGE); | |
| const modSingleBuffer = scratchBuffer("mod-single", 9216 * 4, GPUBufferUsage.STORAGE); | |
| const finalModBuffer = scratchBuffer("final-mod", 6144 * 4, GPUBufferUsage.STORAGE); | |
| const timestepBuffers = Array.from({length: stepResourceCount}, (_, index) => ( | |
| uploadScratchBuffer( | |
| `timestep-${index}`, | |
| makeTimestepEmbeddingF16(timesteps[Math.min(index, Math.max(0, denoiseStepCount - 1))]), | |
| GPUBufferUsage.STORAGE, | |
| ) | |
| )); | |
| const timeHiddenF32Buffer = scratchBuffer("time-hidden-f32", hidden * 4, GPUBufferUsage.STORAGE); | |
| const timeHiddenF16Buffer = scratchBuffer("time-hidden-f16", hidden * 2, GPUBufferUsage.STORAGE); | |
| const vecF32Buffer = scratchBuffer("vec-f32", hidden * 4, GPUBufferUsage.STORAGE); | |
| const vecF16Buffer = scratchBuffer("vec-f16", hidden * 2, GPUBufferUsage.STORAGE); | |
| const dotF32Code = makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, false, linearRowBlock); | |
| const dotF16Code = makeI8ScaledDotShader32xWide(linearTileCols, wChunkCols, true, linearRowBlock); | |
| const finalDotCode = makeI8ScaledDotShader32xWide(finalTileCols, wChunkCols, false, linearRowBlock); | |
| const dotResidualCode = makeI8ScaledDotResidualShader32xWide(projTileCols, wChunkCols, linearRowBlock); | |
| const jointAttentionCode = makeJointAttentionTiledShader(attentionTileKeys, attentionQueryRows); | |
| let singleQkNormCode = useSingleLinear1OutputF16 ? makeLinear1F16ConsumerShader(SINGLE_QK_NORM_ROPE_SHADER) : SINGLE_QK_NORM_ROPE_SHADER; | |
| if (useSingleQkNormStorageF16) { | |
| singleQkNormCode = makeSingleQkNormF16StorageShader(singleQkNormCode); | |
| } | |
| let singleAttentionCode = useSubgroupSingleAttention | |
| ? makeSingleAttentionSubgroupShader(attentionTileKeys) | |
| : makeSingleAttentionTiledShader(attentionTileKeys, attentionQueryRows); | |
| if (useSingleQkNormStorageF16) { | |
| singleAttentionCode = makeAttentionQkF16ConsumerShader(singleAttentionCode); | |
| } | |
| if (useSingleLinear1OutputF16) { | |
| singleAttentionCode = makeLinear1F16ConsumerShader(singleAttentionCode); | |
| } | |
| const singleActivateCode = useSingleLinear1OutputF16 | |
| ? makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER) | |
| : SINGLE_MLP_ACTIVATE_ATTENTION_QUANT_SHADER; | |
| const singleActivateF32Code = useSingleLinear1OutputF16 | |
| ? makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATION_ATTENTION_SHADER) | |
| : SINGLE_MLP_ACTIVATION_ATTENTION_SHADER; | |
| const [ | |
| quantF32Pipeline, | |
| preNormQuantPipeline, | |
| preNormGroupQuantPipeline, | |
| preNormF16Pipeline, | |
| quantOffsetPipeline, | |
| dotF32Pipeline, | |
| dotF16Pipeline, | |
| finalDotPipeline, | |
| dotResidualPipeline, | |
| doubleQkvNormPipeline, | |
| jointAttentionPipeline, | |
| doubleActivateMlpPipeline, | |
| singleQkNormPipeline, | |
| singleAttentionPipeline, | |
| singleActivatePipeline, | |
| singleActivateF32Pipeline, | |
| singleResidualPipeline, | |
| latentUpdatePipeline, | |
| latentUpdateAb2Pipeline, | |
| singleTailDeltaCachePipeline, | |
| singleTailDeltaApplyPipeline, | |
| ] = await Promise.all([ | |
| getCachedComputePipeline(device, "quant-f32", QUANTIZE_X_F32_SHADER), | |
| getCachedComputePipeline(device, "single-prenorm-mod-quant", SINGLE_PRENORM_MOD_QUANT_SHADER), | |
| needsSingleLinear1GroupedPreNorm | |
| ? getCachedComputePipeline(device, "single-prenorm-mod-group-quant", SINGLE_PRENORM_MOD_GROUP_QUANT_SHADER) | |
| : Promise.resolve(null), | |
| needsSingleLinear1F16PreNorm | |
| ? getCachedComputePipeline(device, "single-prenorm-mod-f16", SINGLE_PRENORM_MOD_F16_SHADER) | |
| : Promise.resolve(null), | |
| getCachedComputePipeline(device, "quant-f32-offset", QUANTIZE_X_F32_OFFSET_SHADER), | |
| getCachedComputePipeline(device, `i8-dot:${linearTileCols}:${wChunkCols}:rows${linearRowBlock}:f32`, dotF32Code), | |
| getCachedComputePipeline(device, `i8-dot:${linearTileCols}:${wChunkCols}:rows${linearRowBlock}:f16`, dotF16Code), | |
| getCachedComputePipeline(device, `i8-dot:${finalTileCols}:${wChunkCols}:rows${linearRowBlock}:final-f32`, finalDotCode), | |
| getCachedComputePipeline(device, `i8-residual-dot:${projTileCols}:${wChunkCols}:rows${linearRowBlock}`, dotResidualCode), | |
| getCachedComputePipeline(device, "double-qkv-norm-rope", DOUBLE_QKV_NORM_ROPE_SHADER), | |
| getCachedComputePipeline(device, `joint-attention:${attentionTileKeys}:${attentionQueryRows}`, jointAttentionCode), | |
| getCachedComputePipeline(device, "double-activate-mlp-f16", makeLinear1F16ConsumerShader(SINGLE_MLP_ACTIVATE_QUANT_SHADER)), | |
| getCachedComputePipeline(device, `single-qk-norm-rope:${useSingleLinear1OutputF16 ? "f16" : "f32"}:qk-${useSingleQkNormStorageF16 ? "f16" : "f32"}`, singleQkNormCode), | |
| getCachedComputePipeline(device, `single-attention:${useSubgroupSingleAttention ? "subgroup" : "tiled"}:${attentionTileKeys}:${attentionQueryRows}:${useSingleLinear1OutputF16 ? "f16" : "f32"}:qk-${useSingleQkNormStorageF16 ? "f16" : "f32"}`, singleAttentionCode), | |
| getCachedComputePipeline(device, `single-activate-quant:${useSingleLinear1OutputF16 ? "f16" : "f32"}`, singleActivateCode), | |
| useSingleLinear2Q4 | |
| ? getCachedComputePipeline(device, `single-activate-f32:${useSingleLinear1OutputF16 ? "f16" : "f32"}`, singleActivateF32Code) | |
| : Promise.resolve(null), | |
| useSingleLinear2Q4 | |
| ? getCachedComputePipeline(device, "single-residual-gate", SINGLE_RESIDUAL_GATE_SHADER) | |
| : Promise.resolve(null), | |
| getCachedComputePipeline(device, "latent-update-f32", LATENT_UPDATE_F32_SHADER), | |
| useApproxAb2 | |
| ? getCachedComputePipeline(device, "latent-update-ab2-f32", LATENT_UPDATE_AB2_F32_SHADER) | |
| : Promise.resolve(null), | |
| singleTailDeltaEnabled | |
| ? getCachedComputePipeline(device, "single-tail-delta-cache", SINGLE_TAIL_DELTA_CACHE_SHADER) | |
| : Promise.resolve(null), | |
| singleTailDeltaEnabled | |
| ? getCachedComputePipeline(device, "single-tail-delta-apply", SINGLE_TAIL_DELTA_APPLY_SHADER) | |
| : Promise.resolve(null), | |
| ]); | |
| const singleLinear2CastStage = useSingleLinear2Q4 | |
| ? await createF32ToF16Stage(device, singleLinear2InF32Buffer, singleLinear2InF16Buffer, jointRows * singleLinear2K, false) | |
| : null; | |
| const singleLinear1Q4Options = useSingleLinear1QkvDp4a | |
| ? {colOffset: qkvM, outputCols: singleLinear1M - qkvM, outputStride: singleLinear1M} | |
| : useSingleLinear1MlpDp4a | |
| ? {colOffset: 0, outputCols: qkvM, outputStride: singleLinear1M} | |
| : {}; | |
| const singleQ4TileCols = Number(config.singleQ4TileCols ?? 256); | |
| const singleQ4KChunk = Number(config.singleQ4KChunk ?? (singleQ4TileCols >= 256 ? 32 : 64)); | |
| const singleQ4Dp4aTileCols = Number(config.singleQ4Dp4aTileCols ?? 128); | |
| const makeSingleLinear1Q4StagesForBlock = async (blockIndex) => { | |
| const linear = findFullManifestLinear(config.fullManifest, `single_blocks.${blockIndex}.linear1`); | |
| if (useSingleLinear1Q4Dp4aQkvOnly) { | |
| return Promise.all([ | |
| createQ4Dp4aLinearStage( | |
| device, | |
| linear, | |
| baseUrl, | |
| packedHiddenBuffer, | |
| useSingleLinear1Q4GroupedActivation ? groupedHiddenScalesBuffer : scalesBuffer, | |
| singleLinear1OutBuffer, | |
| jointRows, | |
| singleQ4Dp4aTileCols, | |
| wChunkCols, | |
| useSingleLinear1OutputF16, | |
| { | |
| colOffset: 0, | |
| outputCols: qkvM, | |
| outputStride: singleLinear1M, | |
| xScaleGroupsPerRow: useSingleLinear1Q4GroupedActivation ? hiddenQ4Groups : 1, | |
| kChunk: singleQ4KChunk, | |
| cacheKey: `${scratchKey}|single-linear1-q4-dp4a-qkv|${blockIndex}`, | |
| }, | |
| ), | |
| createQ4LinearStage( | |
| device, | |
| linear, | |
| baseUrl, | |
| singlePreF16Buffer, | |
| singleLinear1OutBuffer, | |
| jointRows, | |
| singleQ4TileCols, | |
| useSingleLinear1OutputF16, | |
| { | |
| colOffset: qkvM, | |
| outputCols: singleLinear1M - qkvM, | |
| outputStride: singleLinear1M, | |
| kChunk: singleQ4KChunk, | |
| cacheKey: `${scratchKey}|single-linear1-q4-tail|${blockIndex}`, | |
| }, | |
| ), | |
| ]); | |
| } | |
| const stage = useSingleLinear1Q4Dp4a | |
| ? await createQ4Dp4aLinearStage( | |
| device, | |
| linear, | |
| baseUrl, | |
| packedHiddenBuffer, | |
| useSingleLinear1Q4GroupedActivation ? groupedHiddenScalesBuffer : scalesBuffer, | |
| singleLinear1OutBuffer, | |
| jointRows, | |
| singleQ4Dp4aTileCols, | |
| wChunkCols, | |
| useSingleLinear1OutputF16, | |
| { | |
| ...singleLinear1Q4Options, | |
| xScaleGroupsPerRow: useSingleLinear1Q4GroupedActivation ? hiddenQ4Groups : 1, | |
| kChunk: singleQ4KChunk, | |
| cacheKey: `${scratchKey}|single-linear1-q4|${blockIndex}`, | |
| }, | |
| ) | |
| : await createQ4LinearStage( | |
| device, | |
| linear, | |
| baseUrl, | |
| singlePreF16Buffer, | |
| singleLinear1OutBuffer, | |
| jointRows, | |
| singleQ4TileCols, | |
| useSingleLinear1OutputF16, | |
| { | |
| ...singleLinear1Q4Options, | |
| kChunk: singleQ4KChunk, | |
| cacheKey: `${scratchKey}|single-linear1-q4|${blockIndex}`, | |
| }, | |
| ); | |
| return [stage]; | |
| }; | |
| const singleLinear1Q4Stages = useSingleLinear1Q4 | |
| ? await Promise.all(Array.from({length: singleBlockCount}, (_, blockIndex) => makeSingleLinear1Q4StagesForBlock(blockIndex))) | |
| : []; | |
| const singleLinear2Q4Stages = useSingleLinear2Dp4a ? [] : await Promise.all(Array.from({length: singleBlockCount}, (_, blockIndex) => ( | |
| createQ4LinearStage( | |
| device, | |
| findFullManifestLinear(config.fullManifest, `single_blocks.${blockIndex}.linear2`), | |
| baseUrl, | |
| singleLinear2InF16Buffer, | |
| singleLinear2OutBuffer, | |
| jointRows, | |
| singleQ4TileCols, | |
| false, | |
| {kChunk: singleQ4KChunk, cacheKey: `${scratchKey}|single-linear2-q4|${blockIndex}`}, | |
| ) | |
| ))); | |
| const timeStage0Stages = await Promise.all(timestepBuffers.map((buffer, index) => ( | |
| createQ4LinearStage( | |
| device, | |
| findFullManifestLinear(config.fullManifest, "time_in.in_layer"), | |
| baseUrl, | |
| buffer, | |
| timeHiddenF32Buffer, | |
| 1, | |
| 96, | |
| false, | |
| {cacheKey: `${scratchKey}|time-in-0|step-${index}`}, | |
| ) | |
| ))); | |
| const timeSiluStage = await createF32ToF16Stage(device, timeHiddenF32Buffer, timeHiddenF16Buffer, hidden, true, `${scratchKey}|time-silu`); | |
| const timeStage1 = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "time_in.out_layer"), baseUrl, timeHiddenF16Buffer, vecF32Buffer, 1, 96, false, {cacheKey: `${scratchKey}|time-in-1`}); | |
| const vecCastStage = await createF32ToF16Stage(device, vecF32Buffer, vecF16Buffer, hidden, true, `${scratchKey}|vec-cast`); | |
| const modImgStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_img.lin"), baseUrl, vecF16Buffer, modImgBuffer, 1, 96, false, {cacheKey: `${scratchKey}|mod-img`}); | |
| const modTxtStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "double_stream_modulation_txt.lin"), baseUrl, vecF16Buffer, modTxtBuffer, 1, 96, false, {cacheKey: `${scratchKey}|mod-txt`}); | |
| const modSingleStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "single_stream_modulation.lin"), baseUrl, vecF16Buffer, modSingleBuffer, 1, 96, false, {cacheKey: `${scratchKey}|mod-single`}); | |
| const finalModStage = await createQ4LinearStage(device, findFullManifestLinear(config.fullManifest, "final_layer.adaLN_modulation.1"), baseUrl, vecF16Buffer, finalModBuffer, 1, 96, false, {cacheKey: `${scratchKey}|final-mod`}); | |
| const makeLinearBuffers = (bundle) => ({ | |
| w: createImmutableBuffer(device, bundle.weight, GPUBufferUsage.STORAGE, bundle.weightCacheKey), | |
| s: createImmutableBuffer(device, bundle.scales, GPUBufferUsage.STORAGE, bundle.scalesCacheKey), | |
| }); | |
| const makeParams = (values) => createBuffer(device, new Uint32Array(values), GPUBufferUsage.UNIFORM); | |
| const res = (buffer, offset = 0) => offset ? {buffer, offset} : {buffer}; | |
| const runStage = (pass, pipeline, bindGroup, x, y = undefined) => { | |
| pass.setPipeline(pipeline); | |
| pass.setBindGroup(0, bindGroup); | |
| if (y === undefined) pass.dispatchWorkgroups(x); | |
| else pass.dispatchWorkgroups(x, y); | |
| }; | |
| const dotWorkgroupsY = (rows) => Math.ceil(rows / linearRowBlock); | |
| // Weight/scale GPU buffers for the image input, text input, and final output projections. | |
| const imgInWeights = makeLinearBuffers(imgIn); | |
| const txtInWeights = makeLinearBuffers(txtIn); | |
| const finalWeights = makeLinearBuffers(finalLinear); | |
| // Uniform parameter buffers, one per dispatch kind, uploaded as u32 words via makeParams. Field | |
| // meanings are defined by the WGSL kernels (not visible here); many entries start with dimension-like | |
| // values (rows, input width, output width, packed word count) and the 8-word variants end in zeros. | |
| // NOTE(review): 24 and 128 look like attention head count and head dim, and 9216 like the fused | |
| // QKV width - confirm against the model config. | |
| const params = { | |
| quantLatent: makeParams([imgRows, latentChannels, hidden, kWordsLatent]), | |
| quantContext: makeParams([txtRows, contextDim, hidden, kWordsContext]), | |
| imgIn: makeParams([imgRows, latentChannels, hidden, kWordsLatent, 0, 0, 0, 0]), | |
| txtIn: makeParams([txtRows, contextDim, hidden, kWordsContext, 0, 0, 0, 0]), | |
| imgPre: makeParams([imgRows, hidden, qkvM, kWordsHidden]), | |
| txtPre: makeParams([txtRows, hidden, qkvM, kWordsHidden]), | |
| imgDotQkv: makeParams([imgRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]), | |
| txtDotQkv: makeParams([txtRows, hidden, qkvM, kWordsHidden, 0, 0, 0, 0]), | |
| imgQkvNorm: makeParams([imgRows, qkvM, 24, 128, hidden, txtRows, txtRows, imageWidth]), | |
| txtQkvNorm: makeParams([txtRows, qkvM, 24, 128, hidden, 0, txtRows, imageWidth]), | |
| doubleAttention: makeParams([jointRows, 24, 128, hidden]), | |
| imgQuantAttn: makeParams([imgRows, hidden, kWordsHidden, txtRows, hidden]), | |
| txtQuantAttn: makeParams([txtRows, hidden, kWordsHidden, 0, hidden]), | |
| imgProj: makeParams([imgRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]), | |
| txtProj: makeParams([txtRows, hidden, hidden, kWordsHidden, 0, 0, 0, 0]), | |
| imgMlpPre: makeParams([imgRows, hidden, doubleMlp1M, kWordsHidden]), | |
| txtMlpPre: makeParams([txtRows, hidden, doubleMlp1M, kWordsHidden]), | |
| imgMlp0: makeParams([imgRows, hidden, doubleMlp1M, kWordsHidden, 0, 0, 0, 0]), | |
| txtMlp0: makeParams([txtRows, hidden, doubleMlp1M, kWordsHidden, 0, 0, 0, 0]), | |
| imgAct: makeParams([imgRows, doubleMlp1M, doubleMlpK, 0, 0, doubleMlpK, kWordsDoubleMlp, 0]), | |
| txtAct: makeParams([txtRows, doubleMlp1M, doubleMlpK, 0, 0, doubleMlpK, kWordsDoubleMlp, 0]), | |
| imgMlp2: makeParams([imgRows, doubleMlpK, hidden, kWordsDoubleMlp, 0, 0, 0, 0]), | |
| txtMlp2: makeParams([txtRows, doubleMlpK, hidden, kWordsDoubleMlp, 0, 0, 0, 0]), | |
| singlePre: makeParams([jointRows, hidden, singleLinear1M, kWordsHidden]), | |
| singlePreGroup: makeParams([jointRows, hidden, kWordsHidden, 16, hiddenQ4Groups, 0, 0, 0]), | |
| singleDot1: makeParams([jointRows, hidden, singleLinear1M, kWordsHidden, 0, 0, 0, 0]), | |
| singleAttention: makeParams([jointRows, singleLinear1M, 24, 128, 9216, hidden, txtRows, imageWidth]), | |
| singleAct: makeParams([jointRows, singleLinear1M, singleLinear2K, hidden, 9216, 9216, kWordsSingleMlp, 0]), | |
| singleDot2: makeParams([jointRows, singleLinear2K, hidden, kWordsSingleMlp, 0, 0, 0, 0]), | |
| singleResidual: makeParams([jointRows, hidden, 0, 0]), | |
| singleTailDelta: makeParams([jointRows * hidden, 0, 0, 0]), | |
| finalPre: makeParams([imgRows, hidden, latentChannels, kWordsHidden]), | |
| finalDot: makeParams([imgRows, hidden, latentChannels, kWordsHidden, 0, 0, 0, 0]), | |
| }; | |
| const makeQuantF32Bind = (input, packed, scales, paramsBuffer) => device.createBindGroup({ | |
| layout: quantF32Pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: input}}, | |
| {binding: 1, resource: {buffer: packed}}, | |
| {binding: 2, resource: {buffer: scales}}, | |
| {binding: 3, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makeDotBind = (pipeline, packed, packedScales, w, wScales, out, paramsBuffer) => device.createBindGroup({ | |
| layout: pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: packed}}, | |
| {binding: 1, resource: {buffer: w}}, | |
| {binding: 2, resource: {buffer: packedScales}}, | |
| {binding: 3, resource: {buffer: wScales}}, | |
| {binding: 4, resource: {buffer: out}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makePreNormBind = (state, modBuffer, shiftOffset, scaleOffset, paramsBuffer, packed = packedHiddenBuffer, scales = scalesBuffer, stateOffset = 0) => device.createBindGroup({ | |
| layout: preNormQuantPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: res(state, stateOffset)}, | |
| {binding: 1, resource: res(modBuffer, shiftOffset)}, | |
| {binding: 2, resource: res(modBuffer, scaleOffset)}, | |
| {binding: 3, resource: {buffer: packed}}, | |
| {binding: 4, resource: {buffer: scales}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makePreNormF16Bind = (state, modBuffer, shiftOffset, scaleOffset, paramsBuffer, output = singlePreF16Buffer) => device.createBindGroup({ | |
| layout: preNormF16Pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: state}}, | |
| {binding: 1, resource: res(modBuffer, shiftOffset)}, | |
| {binding: 2, resource: res(modBuffer, scaleOffset)}, | |
| {binding: 3, resource: {buffer: output}}, | |
| {binding: 4, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makePreNormGroupBind = (state, modBuffer, shiftOffset, scaleOffset, paramsBuffer) => device.createBindGroup({ | |
| layout: preNormGroupQuantPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: state}}, | |
| {binding: 1, resource: res(modBuffer, shiftOffset)}, | |
| {binding: 2, resource: res(modBuffer, scaleOffset)}, | |
| {binding: 3, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 4, resource: {buffer: groupedHiddenScalesBuffer}}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| ], | |
| }); | |
| const makeResidualDotBind = ( | |
| packed, | |
| packedScales, | |
| w, | |
| wScales, | |
| out, | |
| paramsBuffer, | |
| residual, | |
| gate, | |
| outOffset = 0, | |
| residualOffset = 0, | |
| ) => device.createBindGroup({ | |
| layout: dotResidualPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: packed}}, | |
| {binding: 1, resource: {buffer: w}}, | |
| {binding: 2, resource: {buffer: packedScales}}, | |
| {binding: 3, resource: {buffer: wScales}}, | |
| {binding: 4, resource: res(out, outOffset)}, | |
| {binding: 5, resource: {buffer: paramsBuffer}}, | |
| {binding: 6, resource: res(residual, residualOffset)}, | |
| {binding: 7, resource: gate}, | |
| ], | |
| }); | |
| // Input-projection bind groups. Latent/context rows are quantized into the shared inputPacked / | |
| // inputScales buffers, then projected to `hidden` with dotF32Pipeline. The image projection writes | |
| // imgStateA (the first double block's image input); the text projection writes txtBaseState. | |
| const inputProjectionBinds = { | |
| quantLatent: makeQuantF32Bind(latentF32Buffer, inputPackedBuffer, inputScalesBuffer, params.quantLatent), | |
| imgIn: makeDotBind(dotF32Pipeline, inputPackedBuffer, inputScalesBuffer, imgInWeights.w, imgInWeights.s, imgStateA, params.imgIn), | |
| quantContext: makeQuantF32Bind(contextF32Buffer, inputPackedBuffer, inputScalesBuffer, params.quantContext), | |
| txtIn: makeDotBind(dotF32Pipeline, inputPackedBuffer, inputScalesBuffer, txtInWeights.w, txtInWeights.s, txtBaseState, params.txtIn), | |
| }; | |
| // Byte offsets of the chunks inside a modulation output buffer; chunks step by hidden * 4 bytes. | |
| // Double blocks use six chunks: mod1 shift/scale feed the QKV pre-norm and mod1 gate the attention | |
| // projection residual; mod2 shift/scale feed the MLP pre-norm and mod2 gate the MLP2 residual. | |
| const hiddenModOffsets = { | |
| mod1Shift: 0, | |
| mod1Scale: hidden * 4, | |
| mod1Gate: hidden * 8, | |
| mod2Shift: hidden * 12, | |
| mod2Scale: hidden * 16, | |
| mod2Gate: hidden * 20, | |
| }; | |
| // Single blocks use three chunks: shift/scale for the pre-norm and gate for the residual. | |
| const singleMod = { | |
| shift: 0, | |
| scale: hidden * 4, | |
| gate: hidden * 8, | |
| }; | |
| // Bind groups shared by every double block. They cover: | |
| // - joint attention over the q/k-norm and v staging buffers that the per-block QKV-norm stages bind; | |
| // - re-quantization of the attention output for the image and text projections (the two differ | |
| //   only in their params buffers); | |
| // - the MLP activation/quantize stage for each stream. | |
| const doubleCommonBind = { | |
| attention: device.createBindGroup({ | |
| layout: jointAttentionPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: doubleVBuffer}}, | |
| {binding: 1, resource: {buffer: doubleQNormBuffer}}, | |
| {binding: 2, resource: {buffer: doubleKNormBuffer}}, | |
| {binding: 3, resource: {buffer: doubleAttentionOutBuffer}}, | |
| {binding: 4, resource: {buffer: params.doubleAttention}}, | |
| ], | |
| }), | |
| imgQuantAttn: device.createBindGroup({ | |
| layout: quantOffsetPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: doubleAttentionOutBuffer}}, | |
| {binding: 1, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 2, resource: {buffer: scalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.imgQuantAttn}}, | |
| ], | |
| }), | |
| txtQuantAttn: device.createBindGroup({ | |
| layout: quantOffsetPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: doubleAttentionOutBuffer}}, | |
| {binding: 1, resource: {buffer: packedHiddenBuffer}}, | |
| {binding: 2, resource: {buffer: scalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.txtQuantAttn}}, | |
| ], | |
| }), | |
| imgAct: device.createBindGroup({ | |
| layout: doubleActivateMlpPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: doubleMlp0Buffer}}, | |
| {binding: 1, resource: {buffer: packedDoubleMlpBuffer}}, | |
| {binding: 2, resource: {buffer: doubleMlpScalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.imgAct}}, | |
| ], | |
| }), | |
| txtAct: device.createBindGroup({ | |
| layout: doubleActivateMlpPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: doubleMlp0Buffer}}, | |
| {binding: 1, resource: {buffer: packedDoubleMlpBuffer}}, | |
| {binding: 2, resource: {buffer: doubleMlpScalesBuffer}}, | |
| {binding: 3, resource: {buffer: params.txtAct}}, | |
| ], | |
| }), | |
| }; | |
| // Per-double-block bind groups. Image and text states ping-pong between their A/B buffers by block | |
| // parity. The first block reads txtBaseState (the text projection) for the text stream. The last | |
| // block writes both streams into jointStateA (text at offset 0, image at txtBytes), which is where the | |
| // first single block reads its input. | |
| const doubleStages = doubleBlocks.map((block, offset) => { | |
| const ping = offset & 1; | |
| const isLastDoubleBlock = offset === doubleBlocks.length - 1; | |
| const imgInputState = ping ? imgStateB : imgStateA; | |
| const txtInputState = offset === 0 ? txtBaseState : (ping ? txtStateB : txtStateA); | |
| const imgOutputState = isLastDoubleBlock ? jointStateA : (ping ? imgStateA : imgStateB); | |
| const txtOutputState = isLastDoubleBlock ? jointStateA : (ping ? txtStateA : txtStateB); | |
| const imgOutputOffset = isLastDoubleBlock ? txtBytes : 0; | |
| const txtOutputOffset = 0; | |
| const weights = { | |
| imgQkv: makeLinearBuffers(block.imgQkv), | |
| txtQkv: makeLinearBuffers(block.txtQkv), | |
| imgProj: makeLinearBuffers(block.imgProj), | |
| txtProj: makeLinearBuffers(block.txtProj), | |
| imgMlp0: makeLinearBuffers(block.imgMlp0), | |
| imgMlp2: makeLinearBuffers(block.imgMlp2), | |
| txtMlp0: makeLinearBuffers(block.txtMlp0), | |
| txtMlp2: makeLinearBuffers(block.txtMlp2), | |
| }; | |
| // NOTE(review): these QK-norm scales use createBuffer (no cache key), whereas single-block norm | |
| // scales use createImmutableBuffer - confirm this difference is intentional. | |
| const imgQueryScaleBuffer = createBuffer(device, block.imgNorm.queryScale, GPUBufferUsage.STORAGE); | |
| const imgKeyScaleBuffer = createBuffer(device, block.imgNorm.keyScale, GPUBufferUsage.STORAGE); | |
| const txtQueryScaleBuffer = createBuffer(device, block.txtNorm.queryScale, GPUBufferUsage.STORAGE); | |
| const txtKeyScaleBuffer = createBuffer(device, block.txtNorm.keyScale, GPUBufferUsage.STORAGE); | |
| return { | |
| // Attention half: modulated pre-norm + quantize, QKV dot, then QK norm with RoPE frequencies. | |
| imgQkvPre: makePreNormBind(imgInputState, modImgBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.imgPre), | |
| txtQkvPre: makePreNormBind(txtInputState, modTxtBuffer, hiddenModOffsets.mod1Shift, hiddenModOffsets.mod1Scale, params.txtPre), | |
| imgQkvDot: makeDotBind(dotF16Pipeline, packedHiddenBuffer, scalesBuffer, weights.imgQkv.w, weights.imgQkv.s, imgQkvBuffer, params.imgDotQkv), | |
| txtQkvDot: makeDotBind(dotF16Pipeline, packedHiddenBuffer, scalesBuffer, weights.txtQkv.w, weights.txtQkv.s, txtQkvBuffer, params.txtDotQkv), | |
| imgQkvNorm: device.createBindGroup({ | |
| layout: doubleQkvNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: imgQkvBuffer}}, | |
| {binding: 1, resource: {buffer: imgQueryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: imgKeyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: doubleQNormBuffer}}, | |
| {binding: 4, resource: {buffer: doubleKNormBuffer}}, | |
| {binding: 5, resource: {buffer: doubleVBuffer}}, | |
| {binding: 6, resource: {buffer: params.imgQkvNorm}}, | |
| {binding: 7, resource: {buffer: doubleRopeFreqBuffer}}, | |
| ], | |
| }), | |
| txtQkvNorm: device.createBindGroup({ | |
| layout: doubleQkvNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: txtQkvBuffer}}, | |
| {binding: 1, resource: {buffer: txtQueryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: txtKeyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: doubleQNormBuffer}}, | |
| {binding: 4, resource: {buffer: doubleKNormBuffer}}, | |
| {binding: 5, resource: {buffer: doubleVBuffer}}, | |
| {binding: 6, resource: {buffer: params.txtQkvNorm}}, | |
| {binding: 7, resource: {buffer: doubleRopeFreqBuffer}}, | |
| ], | |
| }), | |
| // Attention output projection with the mod1-gated residual from the block input into *MidState. | |
| imgProj: makeResidualDotBind(packedHiddenBuffer, scalesBuffer, weights.imgProj.w, weights.imgProj.s, imgMidState, params.imgProj, imgInputState, res(modImgBuffer, hiddenModOffsets.mod1Gate)), | |
| txtProj: makeResidualDotBind(packedHiddenBuffer, scalesBuffer, weights.txtProj.w, weights.txtProj.s, txtMidState, params.txtProj, txtInputState, res(modTxtBuffer, hiddenModOffsets.mod1Gate)), | |
| // MLP half: pre-norm from *MidState, MLP0 dot, then MLP2 with the mod2-gated residual into the output state. | |
| imgMlpPre: makePreNormBind(imgMidState, modImgBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.imgMlpPre), | |
| txtMlpPre: makePreNormBind(txtMidState, modTxtBuffer, hiddenModOffsets.mod2Shift, hiddenModOffsets.mod2Scale, params.txtMlpPre), | |
| imgMlp0: makeDotBind(dotF16Pipeline, packedHiddenBuffer, scalesBuffer, weights.imgMlp0.w, weights.imgMlp0.s, doubleMlp0Buffer, params.imgMlp0), | |
| txtMlp0: makeDotBind(dotF16Pipeline, packedHiddenBuffer, scalesBuffer, weights.txtMlp0.w, weights.txtMlp0.s, doubleMlp0Buffer, params.txtMlp0), | |
| imgMlp2: makeResidualDotBind(packedDoubleMlpBuffer, doubleMlpScalesBuffer, weights.imgMlp2.w, weights.imgMlp2.s, imgOutputState, params.imgMlp2, imgMidState, res(modImgBuffer, hiddenModOffsets.mod2Gate), imgOutputOffset), | |
| txtMlp2: makeResidualDotBind(packedDoubleMlpBuffer, doubleMlpScalesBuffer, weights.txtMlp2.w, weights.txtMlp2.s, txtOutputState, params.txtMlp2, txtMidState, res(modTxtBuffer, hiddenModOffsets.mod2Gate), txtOutputOffset), | |
| imgOutputState, | |
| txtOutputState, | |
| imgOutputOffset, | |
| txtOutputOffset, | |
| }; | |
| }); | |
| // Bind groups shared by every single block. The attention stage reads linear1's output and the q/k | |
| // norm buffers and writes singleAttentionOutBuffer. The activation stage combines linear1's output | |
| // with the attention output into the packed/scales inputs that linear2's dot stage consumes. | |
| const singleAttentionBind = device.createBindGroup({ | |
| layout: singleAttentionPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: singleLinear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: singleQNormBuffer}}, | |
| {binding: 2, resource: {buffer: singleKNormBuffer}}, | |
| {binding: 3, resource: {buffer: singleAttentionOutBuffer}}, | |
| {binding: 4, resource: {buffer: params.singleAttention}}, | |
| ], | |
| }); | |
| const singleActivateBind = device.createBindGroup({ | |
| layout: singleActivatePipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: singleLinear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: singleAttentionOutBuffer}}, | |
| {binding: 2, resource: {buffer: packedSingleMlpBuffer}}, | |
| {binding: 3, resource: {buffer: singleMlpScalesBuffer}}, | |
| {binding: 4, resource: {buffer: params.singleAct}}, | |
| ], | |
| }); | |
| // Only built for the Q4 linear2 path: writes an unpacked f32 linear2 input instead of packed values. | |
| const singleActivateF32Bind = useSingleLinear2Q4 ? device.createBindGroup({ | |
| layout: singleActivateF32Pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: singleLinear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: singleAttentionOutBuffer}}, | |
| {binding: 2, resource: {buffer: singleLinear2InF32Buffer}}, | |
| {binding: 3, resource: {buffer: params.singleAct}}, | |
| ], | |
| }) : null; | |
| // Per-single-block bind groups. The joint state ping-pongs between jointStateA and jointStateB by block | |
| // parity; block 0 reads jointStateA, where the last double block wrote. The optional pre-norm variants | |
| // (f16 and grouped) and the Q4 linear2 residual are built only when the selected kernels need them. | |
| const singleStages = singleBlocks.map((block, offset) => { | |
| const ping = offset & 1; | |
| const inputState = ping ? jointStateB : jointStateA; | |
| const outputState = ping ? jointStateA : jointStateB; | |
| const weights = { | |
| linear1: makeLinearBuffers(block.linear1), | |
| linear2: makeLinearBuffers(block.linear2), | |
| }; | |
| const norm = singleNorms[offset]; | |
| const queryScaleBuffer = createImmutableBuffer(device, norm.queryScale, GPUBufferUsage.STORAGE, norm.queryScaleUrl); | |
| const keyScaleBuffer = createImmutableBuffer(device, norm.keyScale, GPUBufferUsage.STORAGE, norm.keyScaleUrl); | |
| return { | |
| preNorm: makePreNormBind(inputState, modSingleBuffer, singleMod.shift, singleMod.scale, params.singlePre), | |
| preNormF16: needsSingleLinear1F16PreNorm | |
| ? makePreNormF16Bind(inputState, modSingleBuffer, singleMod.shift, singleMod.scale, params.singlePre) | |
| : null, | |
| preNormGroup: needsSingleLinear1GroupedPreNorm | |
| ? makePreNormGroupBind(inputState, modSingleBuffer, singleMod.shift, singleMod.scale, params.singlePreGroup) | |
| : null, | |
| dot1Q4Stages: singleLinear1Q4Stages[offset] || [], | |
| dot1: makeDotBind(dotF16Pipeline, packedHiddenBuffer, scalesBuffer, weights.linear1.w, weights.linear1.s, singleLinear1OutBuffer, params.singleDot1), | |
| qkNorm: device.createBindGroup({ | |
| layout: singleQkNormPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: singleLinear1OutBuffer}}, | |
| {binding: 1, resource: {buffer: queryScaleBuffer}}, | |
| {binding: 2, resource: {buffer: keyScaleBuffer}}, | |
| {binding: 3, resource: {buffer: singleQNormBuffer}}, | |
| {binding: 4, resource: {buffer: singleKNormBuffer}}, | |
| {binding: 5, resource: {buffer: params.singleAttention}}, | |
| {binding: 6, resource: {buffer: singleRopeSinCosBuffer}}, | |
| ], | |
| }), | |
| // Q4 linear2 path: the dot writes singleLinear2OutBuffer and residualQ4 applies the gated | |
| // residual from inputState into outputState. | |
| dot2Q4: useSingleLinear2Dp4a ? null : singleLinear2Q4Stages[offset], | |
| residualQ4: useSingleLinear2Q4 ? device.createBindGroup({ | |
| layout: singleResidualPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: inputState}}, | |
| {binding: 1, resource: {buffer: singleLinear2OutBuffer}}, | |
| {binding: 2, resource: res(modSingleBuffer, singleMod.gate)}, | |
| {binding: 3, resource: {buffer: outputState}}, | |
| {binding: 4, resource: {buffer: params.singleResidual}}, | |
| ], | |
| }) : null, | |
| // Fused path: linear2 dot with the gated residual written directly into outputState. | |
| dot2: makeResidualDotBind(packedSingleMlpBuffer, singleMlpScalesBuffer, weights.linear2.w, weights.linear2.s, outputState, params.singleDot2, inputState, res(modSingleBuffer, singleMod.gate)), | |
| }; | |
| }); | |
| // The last single block writes jointStateB for an odd block count and jointStateA otherwise. | |
| const finalInputState = (singleStages.length & 1) ? jointStateB : jointStateA; | |
| const singleAttentionWorkgroupsX = useSubgroupSingleAttention | |
| ? jointRows | |
| : Math.ceil(jointRows / attentionQueryRows); | |
| // Optional single-tail reuse. The cache stage binds the base cache, the final joint state, and the | |
| // delta cache; the apply stage binds the final joint state and the delta cache. NOTE(review): | |
| // presumably it records and re-applies a delta for skipped tail blocks - confirm in the WGSL. | |
| const singleTailDeltaCacheBind = singleTailDeltaEnabled ? device.createBindGroup({ | |
| layout: singleTailDeltaCachePipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: singleTailBaseCacheBuffer}}, | |
| {binding: 1, resource: {buffer: finalInputState}}, | |
| {binding: 2, resource: {buffer: singleTailDeltaCacheBuffer}}, | |
| {binding: 3, resource: {buffer: params.singleTailDelta}}, | |
| ], | |
| }) : null; | |
| const singleTailDeltaApplyBind = singleTailDeltaEnabled ? device.createBindGroup({ | |
| layout: singleTailDeltaApplyPipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: finalInputState}}, | |
| {binding: 1, resource: {buffer: singleTailDeltaCacheBuffer}}, | |
| {binding: 2, resource: {buffer: params.singleTailDelta}}, | |
| ], | |
| }) : null; | |
| // Final layer reads only the image rows of the joint state (byte offset txtBytes) and writes predBuffer. | |
| const finalNormBind = makePreNormBind(finalInputState, finalModBuffer, 0, hidden * 4, params.finalPre, packedHiddenBuffer, scalesBuffer, txtBytes); | |
| const finalDotBind = makeDotBind(finalDotPipeline, packedHiddenBuffer, scalesBuffer, finalWeights.w, finalWeights.s, predBuffer, params.finalDot); | |
| // Latent update bind groups, one per update-params buffer; the AB2 variant also reads previousPredBuffer. | |
| const latentUpdateBinds = updateParamsBuffers.map((buffer) => device.createBindGroup({ | |
| layout: latentUpdatePipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: latentF32Buffer}}, | |
| {binding: 1, resource: {buffer: predBuffer}}, | |
| {binding: 2, resource: {buffer}}, | |
| ], | |
| })); | |
| const latentUpdateAb2Binds = useApproxAb2 ? updateParamsBuffers.map((buffer) => device.createBindGroup({ | |
| layout: latentUpdateAb2Pipeline.getBindGroupLayout(0), | |
| entries: [ | |
| {binding: 0, resource: {buffer: latentF32Buffer}}, | |
| {binding: 1, resource: {buffer: predBuffer}}, | |
| {binding: 2, resource: {buffer: previousPredBuffer}}, | |
| {binding: 3, resource: {buffer}}, | |
| ], | |
| })) : []; | |
| const stageSetupMs = performance.now() - stageSetupStart; | |
| // Setup-only mode: report the resolved configuration and load timings, then return before any text | |
| // projection or denoising work is encoded (hence the zeroed summary and empty outputs). | |
| if (config.prepareOnly === true) { | |
| return { | |
| verdict: "custom-flux-transformer-stage-prepared", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| doubleBlockCount, | |
| singleBlockCount, | |
| steps: timesteps.length - 1, | |
| linearTileCols, | |
| projTileCols, | |
| finalTileCols, | |
| wChunkCols, | |
| linearRowBlock, | |
| singleQ4KChunk, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| singleQkNormStorage: useSingleQkNormStorageF16 ? "f16" : "f32", | |
| singleLinear1Output: useSingleLinear1OutputF16 ? "f16" : "f32", | |
| singleLinear1Q4Kernel: useSingleLinear1Q4Dp4aQkvOnly ? "dp4a-qkv" : (useSingleLinear1Q4Dp4a ? "dp4a" : "f16"), | |
| singleLinear1Q4ActivationScale: useSingleLinear1Q4GroupedActivation ? "group16" : "row", | |
| singleLinear1QkvBackend: useSingleLinear1QkvDp4a ? "dp4a" : "q4", | |
| singleLinear1MlpBackend: useSingleLinear1MlpDp4a ? "dp4a" : "q4", | |
| singleLinear2Backend: useSingleLinear2Dp4a ? "dp4a" : "q4", | |
| singleAttentionKernel: useSubgroupSingleAttention ? "subgroup" : "tiled", | |
| stepSubmitFusion, | |
| singleComputePass: false, | |
| skippedOptionalSetupBytes, | |
| skippedOptionalPipelines, | |
| approxReusePredEvery, | |
| approxPredictionMode: useApproxAb2 ? "ab2" : "raw", | |
| approxAbScale: useApproxAb2 ? approxAbScale : 0, | |
| skipSingleBlocksFromStep, | |
| skipSingleBlocksFromBlock, | |
| reuseSingleTailFromStep, | |
| reuseSingleTailFromBlock, | |
| }, | |
| load: { | |
| total_ms: loadMs, | |
| stage_setup_ms: stageSetupMs, | |
| text_projection_init_ms: 0, | |
| text_projection_cache_hit: Boolean(cachedTxtBaseState || persistentTxtProjectionHit || staticTxtProjectionHit), | |
| text_projection_cache_source: cachedTxtBaseState ? "gpu" : (persistentTxtProjectionHit ? "indexeddb" : (staticTxtProjectionHit ? "static-file" : "not-initialized")), | |
| text_projection_persistent_ms: persistentTxtProjectionMs, | |
| readback_ms: 0, | |
| }, | |
| summary: { | |
| total_step_ms: 0, | |
| median_step_ms: 0, | |
| finite: 0, | |
| max_abs: 0, | |
| }, | |
| step_ms: [], | |
| sample: [], | |
| }; | |
| } | |
| // Compute the text input projection only when no cached copy exists (GPU cache, persistent store, or | |
| // static file). When a persistent key is configured, the result is read back and saved; a failed save | |
| // is signalled by making persistentTxtProjectionMs negative. The GPU buffer is then recorded in the | |
| // per-device text projection cache. | |
| const initStart = performance.now(); | |
| if (!cachedTxtBaseState && !persistentTxtProjectionHit && !staticTxtProjectionHit) { | |
| const encoder = device.createCommandEncoder(); | |
| const pass = encoder.beginComputePass(); | |
| runStage(pass, quantF32Pipeline, inputProjectionBinds.quantContext, txtRows); | |
| runStage(pass, dotF32Pipeline, inputProjectionBinds.txtIn, Math.ceil(hidden / linearTileCols), dotWorkgroupsY(txtRows)); | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| if (persistentTextProjectionKey) { | |
| const saveStart = performance.now(); | |
| const txtProjection = await readFloat32Buffer(device, txtBaseState, txtRows * hidden); | |
| const saved = await savePersistentFloat32(persistentTextProjectionKey, txtProjection); | |
| persistentTxtProjectionMs += performance.now() - saveStart; | |
| // Negative timing doubles as a "save failed" flag (at least -1 so it is never -0). | |
| if (!saved) persistentTxtProjectionMs = -Math.abs(persistentTxtProjectionMs || 1); | |
| } | |
| if (textProjectionDeviceCache && textProjectionCacheKey) { | |
| textProjectionDeviceCache.set(textProjectionCacheKey, txtBaseState); | |
| } | |
| } | |
| const initMs = performance.now() - initStart; | |
| // Text-projection-only mode: return once the text projection is available (computed above or served | |
| // from a cache); no denoising work is encoded. | |
| if (config.textProjectionOnly === true) { | |
| return { | |
| verdict: "custom-flux-transformer-text-projection-prepared", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| linearTileCols, | |
| }, | |
| load: { | |
| total_ms: loadMs, | |
| stage_setup_ms: stageSetupMs, | |
| text_projection_init_ms: initMs, | |
| text_projection_cache_hit: Boolean(cachedTxtBaseState || persistentTxtProjectionHit || staticTxtProjectionHit), | |
| text_projection_cache_source: cachedTxtBaseState ? "gpu" : (persistentTxtProjectionHit ? "indexeddb" : (staticTxtProjectionHit ? "static-file" : "computed")), | |
| text_projection_persistent_ms: persistentTxtProjectionMs, | |
| readback_ms: 0, | |
| }, | |
| summary: { | |
| total_step_ms: 0, | |
| median_step_ms: 0, | |
| finite: 0, | |
| max_abs: 0, | |
| }, | |
| step_ms: [], | |
| sample: [], | |
| }; | |
| } | |
| if (String(config.debugStopAfter || "").toLowerCase() === "input_projection") { | |
| const encoder = device.createCommandEncoder(); | |
| const pass = encoder.beginComputePass(); | |
| runStage(pass, quantF32Pipeline, inputProjectionBinds.quantLatent, imgRows); | |
| runStage(pass, dotF32Pipeline, inputProjectionBinds.imgIn, Math.ceil(hidden / linearTileCols), dotWorkgroupsY(imgRows)); | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const [imgProjection, txtProjection] = await Promise.all([ | |
| readFloat32Buffer(device, imgStateA, imgCount), | |
| readFloat32Buffer(device, txtBaseState, txtCount), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-input-projection", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: "input_projection", | |
| linearTileCols, | |
| wChunkCols, | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| input_projection: { | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgProjection.length, | |
| stats: floatArrayStats(imgProjection), | |
| values: Array.from(imgProjection), | |
| }, | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtProjection.length, | |
| stats: floatArrayStats(txtProjection), | |
| values: Array.from(txtProjection), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgProjection.length + txtProjection.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgProjection, (value) => Math.abs(value)), | |
| ...Array.from(txtProjection, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| // Per-run accumulators: per-step wall times, approximation/skip counters, and per-phase timings | |
| // (filled only when profilePhases is on). | |
| const stepTimes = []; | |
| let approxReusedSteps = 0; | |
| let approxCollapsedUpdateDispatches = 0; | |
| let skippedSingleBlocks = 0; | |
| let reusedSingleTailBlocks = 0; | |
| const phaseProfile = { | |
| time_input_ms: 0, | |
| double_blocks_ms: 0, | |
| single_blocks_ms: 0, | |
| single_linear1_ms: 0, | |
| single_attention_ms: 0, | |
| single_mlp_ms: 0, | |
| final_update_ms: 0, | |
| }; | |
| // Both flags are forced off while phase profiling or a debug stop is active. | |
| const deferStepWait = !profilePhases && config.deferStepWait !== false && !config.debugStopAfter; | |
| const singleComputePass = config.singleComputePass === true && !config.debugStopAfter && !profilePhases; | |
| const allStepsStart = performance.now(); | |
| // With stepSubmitFusion, every step is encoded into this one shared command encoder. | |
| const fusedEncoder = stepSubmitFusion ? device.createCommandEncoder() : null; | |
| let encodedStepCount = 0; | |
| let previousPredAvailable = false; | |
| for (let step = 0; step < denoiseStepCount; ++step) { | |
| const tCurr = timesteps[step]; | |
| const tPrev = timesteps[step + 1]; | |
| const start = performance.now(); | |
| const reusePreviousPrediction = approxReusePredEvery > 1 && step > 0 && (step % approxReusePredEvery) !== 0; | |
| if (reusePreviousPrediction) { | |
| approxReusedSteps += 1; | |
| stepTimes.push(0); | |
| continue; | |
| } | |
| const currentDt = tPrev - tCurr; | |
| let combinedDt = currentDt; | |
| let combinedReuseSteps = 0; | |
| if (approxReusePredEvery > 1) { | |
| for (let lookahead = step + 1; lookahead < timesteps.length - 1; lookahead += 1) { | |
| const lookaheadReusesPrediction = lookahead > 0 && (lookahead % approxReusePredEvery) !== 0; | |
| if (!lookaheadReusesPrediction) { | |
| break; | |
| } | |
| combinedDt += timesteps[lookahead + 1] - timesteps[lookahead]; | |
| combinedReuseSteps += 1; | |
| } | |
| } | |
| approxCollapsedUpdateDispatches += combinedReuseSteps; | |
| const reuseDt = combinedDt - currentDt; | |
| const useAb2LatentUpdate = useApproxAb2 && previousPredAvailable && combinedReuseSteps > 0 && Math.abs(reuseDt) > 0; | |
| const stepResourceIndex = stepSubmitFusion ? step : 0; | |
| if (!stepSubmitFusion) { | |
| device.queue.writeBuffer(updateParamsBuffers[0], 0, new Float32Array([ | |
| imgRows * latentChannels, | |
| useAb2LatentUpdate ? currentDt : combinedDt, | |
| useAb2LatentUpdate ? reuseDt : 0, | |
| useAb2LatentUpdate ? approxAbScale : 0, | |
| ])); | |
| device.queue.writeBuffer(timestepBuffers[0], 0, makeTimestepEmbeddingF16(tCurr)); | |
| } | |
| let encoder = stepSubmitFusion ? fusedEncoder : device.createCommandEncoder(); | |
| encodedStepCount += 1; | |
| let pass = encoder.beginComputePass(); | |
| let phaseStart = performance.now(); | |
| const submitProfilePhase = async (phaseName) => { | |
| if (!profilePhases) return; | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| phaseProfile[phaseName] += performance.now() - phaseStart; | |
| encoder = device.createCommandEncoder(); | |
| pass = encoder.beginComputePass(); | |
| phaseStart = performance.now(); | |
| }; | |
| const timeStage0 = timeStage0Stages[stepResourceIndex]; | |
| runStage(pass, timeStage0.pipeline, timeStage0.bindGroup, timeStage0.workgroupsX, timeStage0.workgroupsY); | |
| runStage(pass, timeSiluStage.pipeline, timeSiluStage.bindGroup, timeSiluStage.workgroupsX); | |
| runStage(pass, timeStage1.pipeline, timeStage1.bindGroup, timeStage1.workgroupsX, timeStage1.workgroupsY); | |
| runStage(pass, vecCastStage.pipeline, vecCastStage.bindGroup, vecCastStage.workgroupsX); | |
| runStage(pass, modImgStage.pipeline, modImgStage.bindGroup, modImgStage.workgroupsX, modImgStage.workgroupsY); | |
| runStage(pass, modTxtStage.pipeline, modTxtStage.bindGroup, modTxtStage.workgroupsX, modTxtStage.workgroupsY); | |
| runStage(pass, modSingleStage.pipeline, modSingleStage.bindGroup, modSingleStage.workgroupsX, modSingleStage.workgroupsY); | |
| runStage(pass, finalModStage.pipeline, finalModStage.bindGroup, finalModStage.workgroupsX, finalModStage.workgroupsY); | |
| runStage(pass, quantF32Pipeline, inputProjectionBinds.quantLatent, imgRows); | |
| runStage(pass, dotF32Pipeline, inputProjectionBinds.imgIn, Math.ceil(hidden / linearTileCols), dotWorkgroupsY(imgRows)); | |
| await submitProfilePhase("time_input_ms"); | |
| for (let doubleIndex = 0; doubleIndex < doubleStages.length; ++doubleIndex) { | |
| const stage = doubleStages[doubleIndex]; | |
| runStage(pass, preNormQuantPipeline, stage.imgQkvPre, imgRows); | |
| runStage(pass, dotF16Pipeline, stage.imgQkvDot, Math.ceil(qkvM / linearTileCols), dotWorkgroupsY(imgRows)); | |
| runStage(pass, preNormQuantPipeline, stage.txtQkvPre, txtRows); | |
| runStage(pass, dotF16Pipeline, stage.txtQkvDot, Math.ceil(qkvM / linearTileCols), dotWorkgroupsY(txtRows)); | |
| runStage(pass, doubleQkvNormPipeline, stage.txtQkvNorm, txtRows, 24); | |
| runStage(pass, doubleQkvNormPipeline, stage.imgQkvNorm, imgRows, 24); | |
| runStage(pass, jointAttentionPipeline, doubleCommonBind.attention, Math.ceil(jointRows / attentionQueryRows), 24); | |
| runStage(pass, quantOffsetPipeline, doubleCommonBind.imgQuantAttn, imgRows); | |
| runStage(pass, dotResidualPipeline, stage.imgProj, Math.ceil(hidden / projTileCols), dotWorkgroupsY(imgRows)); | |
| runStage(pass, quantOffsetPipeline, doubleCommonBind.txtQuantAttn, txtRows); | |
| runStage(pass, dotResidualPipeline, stage.txtProj, Math.ceil(hidden / projTileCols), dotWorkgroupsY(txtRows)); | |
| if (String(config.debugStopAfter || "").toLowerCase() === `double_block_${doubleIndex}_attn`) { | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const [imgValues, txtValues] = await Promise.all([ | |
| readFloat32Buffer(device, imgMidState, imgCount), | |
| readFloat32Buffer(device, txtMidState, txtCount), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-double-block-attention", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: `double_block_${doubleIndex}_attn`, | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| singleLinear1Output: useSingleLinear1OutputF16 ? "f16" : "f32", | |
| singleLinear1Q4Kernel: useSingleLinear1Q4Dp4aQkvOnly ? "dp4a-qkv" : (useSingleLinear1Q4Dp4a ? "dp4a" : "f16"), | |
| singleLinear1Q4ActivationScale: useSingleLinear1Q4GroupedActivation ? "group16" : "row", | |
| singleLinear1QkvBackend: useSingleLinear1QkvDp4a ? "dp4a" : "q4", | |
| singleLinear1MlpBackend: useSingleLinear1MlpDp4a ? "dp4a" : "q4", | |
| singleLinear2Backend: useSingleLinear2Dp4a ? "dp4a" : "q4", | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| [`double_block_${doubleIndex}_attn`]: { | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgValues.length, | |
| stats: floatArrayStats(imgValues), | |
| values: Array.from(imgValues), | |
| }, | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtValues.length, | |
| stats: floatArrayStats(txtValues), | |
| values: Array.from(txtValues), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgValues.length + txtValues.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgValues, (value) => Math.abs(value)), | |
| ...Array.from(txtValues, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| runStage(pass, preNormQuantPipeline, stage.imgMlpPre, imgRows); | |
| runStage(pass, dotF16Pipeline, stage.imgMlp0, Math.ceil(doubleMlp1M / linearTileCols), dotWorkgroupsY(imgRows)); | |
| runStage(pass, doubleActivateMlpPipeline, doubleCommonBind.imgAct, imgRows); | |
| runStage(pass, dotResidualPipeline, stage.imgMlp2, Math.ceil(hidden / projTileCols), dotWorkgroupsY(imgRows)); | |
| runStage(pass, preNormQuantPipeline, stage.txtMlpPre, txtRows); | |
| runStage(pass, dotF16Pipeline, stage.txtMlp0, Math.ceil(doubleMlp1M / linearTileCols), dotWorkgroupsY(txtRows)); | |
| runStage(pass, doubleActivateMlpPipeline, doubleCommonBind.txtAct, txtRows); | |
| runStage(pass, dotResidualPipeline, stage.txtMlp2, Math.ceil(hidden / projTileCols), dotWorkgroupsY(txtRows)); | |
| if (String(config.debugStopAfter || "").toLowerCase() === `double_block_${doubleIndex}`) { | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const [imgValues, txtValues] = await Promise.all([ | |
| readFloat32Buffer(device, stage.imgOutputState, imgCount, stage.imgOutputOffset), | |
| readFloat32Buffer(device, stage.txtOutputState, txtCount, stage.txtOutputOffset), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-double-block", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: `double_block_${doubleIndex}`, | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| [`double_block_${doubleIndex}`]: { | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgValues.length, | |
| stats: floatArrayStats(imgValues), | |
| values: Array.from(imgValues), | |
| }, | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtValues.length, | |
| stats: floatArrayStats(txtValues), | |
| values: Array.from(txtValues), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgValues.length + txtValues.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgValues, (value) => Math.abs(value)), | |
| ...Array.from(txtValues, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| } | |
| if (profilePhases) { | |
| await submitProfilePhase("double_blocks_ms"); | |
| } else if (!singleComputePass) { | |
| pass.end(); | |
| } | |
| if (String(config.debugStopAfter || "").toLowerCase() === "double_blocks") { | |
| if (singleComputePass) pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const [txtValues, imgValues] = await Promise.all([ | |
| readFloat32Buffer(device, jointStateA, txtCount), | |
| readFloat32Buffer(device, jointStateA, imgCount, txtBytes), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-double-blocks", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: "double_blocks", | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| double_blocks: { | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtValues.length, | |
| stats: floatArrayStats(txtValues), | |
| values: Array.from(txtValues), | |
| }, | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgValues.length, | |
| stats: floatArrayStats(imgValues), | |
| values: Array.from(imgValues), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgValues.length + txtValues.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgValues, (value) => Math.abs(value)), | |
| ...Array.from(txtValues, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| if (!profilePhases && !singleComputePass) { | |
| pass = encoder.beginComputePass(); | |
| } | |
| const reuseCachedSingleTail = singleTailDeltaEnabled && step >= reuseSingleTailFromStep; | |
| const skipLateSingleBlocks = !reuseCachedSingleTail && skipSingleBlocksFromStep > 0 && | |
| step >= skipSingleBlocksFromStep && | |
| skipSingleBlocksFromBlock < singleStages.length; | |
| const singleStopIndex = reuseCachedSingleTail | |
| ? Math.max(0, Math.min(singleStages.length, reuseSingleTailFromBlock)) | |
| : skipLateSingleBlocks | |
| ? Math.max(0, Math.min(singleStages.length, skipSingleBlocksFromBlock)) | |
| : singleStages.length; | |
| for (let singleIndex = 0; singleIndex < singleStopIndex; ++singleIndex) { | |
| const stage = singleStages[singleIndex]; | |
| const q4Stages = stage.dot1Q4Stages || []; | |
| const hasQ4Dp4aStage = q4Stages.some((item) => item && item.q4Dp4a); | |
| const hasQ4F16Stage = q4Stages.some((item) => item && !item.q4Dp4a); | |
| if (hasQ4F16Stage) { | |
| runStage(pass, preNormF16Pipeline, stage.preNormF16, jointRows); | |
| } | |
| if (hasQ4Dp4aStage && useSingleLinear1Q4GroupedActivation) { | |
| runStage(pass, preNormGroupQuantPipeline, stage.preNormGroup, jointRows); | |
| } else if (useSingleLinear1Dp4a || hasQ4Dp4aStage) { | |
| runStage(pass, preNormQuantPipeline, stage.preNorm, jointRows); | |
| } | |
| if (useSingleLinear1Dp4a) { | |
| runStage(pass, dotF16Pipeline, stage.dot1, Math.ceil(singleLinear1M / linearTileCols), dotWorkgroupsY(jointRows)); | |
| } | |
| for (const q4Stage of q4Stages) { | |
| runStage(pass, q4Stage.pipeline, q4Stage.bindGroup, q4Stage.workgroupsX, q4Stage.workgroupsY); | |
| } | |
| if (profileSingleParts) { | |
| await submitProfilePhase("single_linear1_ms"); | |
| } | |
| runStage(pass, singleQkNormPipeline, stage.qkNorm, jointRows, 24); | |
| runStage(pass, singleAttentionPipeline, singleAttentionBind, singleAttentionWorkgroupsX, 24); | |
| if (profileSingleParts) { | |
| await submitProfilePhase("single_attention_ms"); | |
| } | |
| if (String(config.debugStopAfter || "").toLowerCase() === `single_block_${singleIndex}_attn`) { | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const [txtValues, imgValues] = await Promise.all([ | |
| readFloat32Buffer(device, singleAttentionOutBuffer, txtCount), | |
| readFloat32Buffer(device, singleAttentionOutBuffer, imgCount, txtBytes), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-single-block-attention", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: `single_block_${singleIndex}_attn`, | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| [`single_block_${singleIndex}_attn`]: { | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtValues.length, | |
| stats: floatArrayStats(txtValues), | |
| values: Array.from(txtValues), | |
| }, | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgValues.length, | |
| stats: floatArrayStats(imgValues), | |
| values: Array.from(imgValues), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgValues.length + txtValues.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgValues, (value) => Math.abs(value)), | |
| ...Array.from(txtValues, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| if (useSingleLinear2Dp4a) { | |
| runStage(pass, singleActivatePipeline, singleActivateBind, jointRows); | |
| runStage(pass, dotResidualPipeline, stage.dot2, Math.ceil(hidden / projTileCols), dotWorkgroupsY(jointRows)); | |
| } else { | |
| runStage(pass, singleActivateF32Pipeline, singleActivateF32Bind, Math.ceil(singleLinear2K / 256), jointRows); | |
| runStage(pass, singleLinear2CastStage.pipeline, singleLinear2CastStage.bindGroup, singleLinear2CastStage.workgroupsX); | |
| runStage(pass, stage.dot2Q4.pipeline, stage.dot2Q4.bindGroup, stage.dot2Q4.workgroupsX, stage.dot2Q4.workgroupsY); | |
| runStage(pass, singleResidualPipeline, stage.residualQ4, Math.ceil(hidden / 256), jointRows); | |
| } | |
| if (profileSingleParts) { | |
| await submitProfilePhase("single_mlp_ms"); | |
| } | |
| if (String(config.debugStopAfter || "").toLowerCase() === `single_block_${singleIndex}`) { | |
| pass.end(); | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const out = (singleIndex & 1) ? jointStateA : jointStateB; | |
| const txtCount = Math.min( | |
| txtRows * hidden, | |
| Number(config.debugTxtReadbackCount ?? Math.min(4096, txtRows * hidden)), | |
| ); | |
| const imgCount = Math.min( | |
| imgRows * hidden, | |
| Number(config.debugImgReadbackCount ?? config.debugReadbackCount ?? imgRows * hidden), | |
| ); | |
| const [txtValues, imgValues] = await Promise.all([ | |
| readFloat32Buffer(device, out, txtCount), | |
| readFloat32Buffer(device, out, imgCount, txtBytes), | |
| ]); | |
| return { | |
| verdict: "custom-flux-transformer-debug-single-block", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| debugStopAfter: `single_block_${singleIndex}`, | |
| linearTileCols, | |
| projTileCols, | |
| wChunkCols, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| singleLinear1Output: useSingleLinear1OutputF16 ? "f16" : "f32", | |
| singleLinear1Q4Kernel: useSingleLinear1Q4Dp4aQkvOnly ? "dp4a-qkv" : (useSingleLinear1Q4Dp4a ? "dp4a" : "f16"), | |
| singleLinear1Q4ActivationScale: useSingleLinear1Q4GroupedActivation ? "group16" : "row", | |
| singleLinear1QkvBackend: useSingleLinear1QkvDp4a ? "dp4a" : "q4", | |
| singleLinear1MlpBackend: useSingleLinear1MlpDp4a ? "dp4a" : "q4", | |
| singleLinear2Backend: useSingleLinear2Dp4a ? "dp4a" : "q4", | |
| }, | |
| load: {total_ms: loadMs, text_projection_init_ms: initMs}, | |
| debug: { | |
| [`single_block_${singleIndex}`]: { | |
| txt: { | |
| shape: [txtRows, hidden], | |
| count: txtValues.length, | |
| stats: floatArrayStats(txtValues), | |
| values: Array.from(txtValues), | |
| }, | |
| img: { | |
| shape: [imgRows, hidden], | |
| count: imgValues.length, | |
| stats: floatArrayStats(imgValues), | |
| values: Array.from(imgValues), | |
| }, | |
| }, | |
| }, | |
| summary: { | |
| finite: imgValues.length + txtValues.length, | |
| max_abs: Math.max( | |
| ...Array.from(imgValues, (value) => Math.abs(value)), | |
| ...Array.from(txtValues, (value) => Math.abs(value)), | |
| 0, | |
| ), | |
| }, | |
| }; | |
| } | |
| if ( | |
| singleTailDeltaEnabled && | |
| step < reuseSingleTailFromStep && | |
| singleIndex + 1 === reuseSingleTailFromBlock | |
| ) { | |
| const currentBlockOutputState = (singleIndex & 1) ? jointStateA : jointStateB; | |
| pass.end(); | |
| encoder.copyBufferToBuffer(currentBlockOutputState, 0, singleTailBaseCacheBuffer, 0, jointBytes); | |
| pass = encoder.beginComputePass(); | |
| } | |
| } | |
| if (reuseCachedSingleTail) { | |
| reusedSingleTailBlocks += singleStages.length - singleStopIndex; | |
| const currentSingleState = (singleStopIndex & 1) ? jointStateB : jointStateA; | |
| if (currentSingleState !== finalInputState) { | |
| pass.end(); | |
| encoder.copyBufferToBuffer(currentSingleState, 0, finalInputState, 0, jointBytes); | |
| pass = encoder.beginComputePass(); | |
| } | |
| runStage(pass, singleTailDeltaApplyPipeline, singleTailDeltaApplyBind, Math.ceil((jointRows * hidden) / 256)); | |
| } else if (skipLateSingleBlocks) { | |
| skippedSingleBlocks += singleStages.length - singleStopIndex; | |
| const currentSingleState = (singleStopIndex & 1) ? jointStateB : jointStateA; | |
| if (currentSingleState !== finalInputState) { | |
| pass.end(); | |
| encoder.copyBufferToBuffer(currentSingleState, 0, finalInputState, 0, jointBytes); | |
| pass = encoder.beginComputePass(); | |
| } | |
| } else if (singleTailDeltaEnabled && step < reuseSingleTailFromStep) { | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| runStage(pass, singleTailDeltaCachePipeline, singleTailDeltaCacheBind, Math.ceil((jointRows * hidden) / 256)); | |
| } | |
| if (profilePhases && !profileSingleParts) { | |
| await submitProfilePhase("single_blocks_ms"); | |
| } else if (!singleComputePass) { | |
| pass.end(); | |
| pass = encoder.beginComputePass(); | |
| } | |
| runStage(pass, preNormQuantPipeline, finalNormBind, imgRows); | |
| runStage(pass, finalDotPipeline, finalDotBind, Math.ceil(latentChannels / finalTileCols), dotWorkgroupsY(imgRows)); | |
| runStage( | |
| pass, | |
| useAb2LatentUpdate ? latentUpdateAb2Pipeline : latentUpdatePipeline, | |
| useAb2LatentUpdate ? latentUpdateAb2Binds[stepResourceIndex] : latentUpdateBinds[stepResourceIndex], | |
| Math.ceil((imgRows * latentChannels) / 256), | |
| ); | |
| pass.end(); | |
| if (useApproxAb2) { | |
| encoder.copyBufferToBuffer(predBuffer, 0, previousPredBuffer, 0, latentBytes); | |
| previousPredAvailable = true; | |
| } | |
| if (profilePhases) { | |
| device.queue.submit([encoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| phaseProfile.final_update_ms += performance.now() - phaseStart; | |
| stepTimes.push(performance.now() - start); | |
| } else if (stepSubmitFusion) { | |
| stepTimes.push(0); | |
| } else { | |
| device.queue.submit([encoder.finish()]); | |
| if (deferStepWait) { | |
| stepTimes.push(0); | |
| } else { | |
| await device.queue.onSubmittedWorkDone(); | |
| stepTimes.push(performance.now() - start); | |
| } | |
| } | |
| } | |
| if (stepSubmitFusion && encodedStepCount) { | |
| device.queue.submit([fusedEncoder.finish()]); | |
| await device.queue.onSubmittedWorkDone(); | |
| const avgStepMs = (performance.now() - allStepsStart) / encodedStepCount; | |
| for (let i = 0; i < stepTimes.length; ++i) { | |
| if (stepTimes[i] === 0) stepTimes[i] = avgStepMs; | |
| } | |
| } else if (deferStepWait && stepTimes.length) { | |
| await device.queue.onSubmittedWorkDone(); | |
| const avgStepMs = (performance.now() - allStepsStart) / stepTimes.length; | |
| for (let i = 0; i < stepTimes.length; ++i) { | |
| if (stepTimes[i] === 0) stepTimes[i] = avgStepMs; | |
| } | |
| } | |
| if (config.warmDispatchOnly === true) { | |
| return { | |
| verdict: "custom-flux-transformer-dispatch-warmed", | |
| latentF16: new Uint16Array(0), | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| doubleBlockCount, | |
| singleBlockCount, | |
| steps: timesteps.length - 1, | |
| linearTileCols, | |
| projTileCols, | |
| finalTileCols, | |
| wChunkCols, | |
| linearRowBlock, | |
| singleQ4KChunk, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| singleQkNormStorage: useSingleQkNormStorageF16 ? "f16" : "f32", | |
| singleLinear1Output: useSingleLinear1OutputF16 ? "f16" : "f32", | |
| singleLinear1Q4Kernel: useSingleLinear1Q4Dp4aQkvOnly ? "dp4a-qkv" : (useSingleLinear1Q4Dp4a ? "dp4a" : "f16"), | |
| singleLinear1Q4ActivationScale: useSingleLinear1Q4GroupedActivation ? "group16" : "row", | |
| singleLinear1QkvBackend: useSingleLinear1QkvDp4a ? "dp4a" : "q4", | |
| singleLinear1MlpBackend: useSingleLinear1MlpDp4a ? "dp4a" : "q4", | |
| singleLinear2Backend: useSingleLinear2Dp4a ? "dp4a" : "q4", | |
| singleAttentionKernel: useSubgroupSingleAttention ? "subgroup" : "tiled", | |
| profilePhases, | |
| profileSingleParts, | |
| stepSubmitFusion, | |
| singleComputePass, | |
| skippedOptionalSetupBytes, | |
| skippedOptionalPipelines, | |
| approxReusePredEvery, | |
| approxPredictionMode: useApproxAb2 ? "ab2" : "raw", | |
| approxAbScale: useApproxAb2 ? approxAbScale : 0, | |
| approxReusedSteps, | |
| approxCollapsedUpdateDispatches, | |
| skipSingleBlocksFromStep, | |
| skipSingleBlocksFromBlock, | |
| skippedSingleBlocks, | |
| reuseSingleTailFromStep, | |
| reuseSingleTailFromBlock, | |
| reusedSingleTailBlocks, | |
| }, | |
| load: { | |
| total_ms: loadMs, | |
| stage_setup_ms: stageSetupMs, | |
| text_projection_init_ms: initMs, | |
| text_projection_cache_hit: Boolean(cachedTxtBaseState || persistentTxtProjectionHit || staticTxtProjectionHit), | |
| text_projection_cache_source: cachedTxtBaseState ? "gpu" : (persistentTxtProjectionHit ? "indexeddb" : (staticTxtProjectionHit ? "static-file" : "computed")), | |
| text_projection_persistent_ms: persistentTxtProjectionMs, | |
| readback_ms: 0, | |
| }, | |
| summary: { | |
| total_step_ms: stepTimes.reduce((sum, value) => sum + value, 0), | |
| median_step_ms: median(stepTimes), | |
| phase_profile_ms: profilePhases ? {...phaseProfile} : undefined, | |
| finite: 0, | |
| max_abs: 0, | |
| }, | |
| step_ms: stepTimes, | |
| sample: [], | |
| }; | |
| } | |
| const readStart = performance.now(); | |
| const latentF32 = await readFloat32Buffer(device, latentF32Buffer, imgRows * latentChannels); | |
| const readMs = performance.now() - readStart; | |
| const latentF16 = new Uint16Array(latentF32.length); | |
| let finite = 0; | |
| let maxAbs = 0; | |
| for (let i = 0; i < latentF32.length; ++i) { | |
| const value = latentF32[i]; | |
| if (Number.isFinite(value)) { | |
| finite += 1; | |
| maxAbs = Math.max(maxAbs, Math.abs(value)); | |
| latentF16[i] = float32ToFloat16Bits(Math.max(-65504, Math.min(65504, value))); | |
| } else { | |
| latentF16[i] = 0; | |
| } | |
| } | |
| return { | |
| verdict: "custom-flux-transformer-denoise-completed", | |
| latentF16, | |
| config: { | |
| imageTokens: imgRows, | |
| textTokens: txtRows, | |
| jointRows, | |
| latentChannels, | |
| imageWidth, | |
| doubleBlockCount, | |
| singleBlockCount, | |
| steps: timesteps.length - 1, | |
| linearTileCols, | |
| projTileCols, | |
| finalTileCols, | |
| wChunkCols, | |
| linearRowBlock, | |
| singleQ4KChunk, | |
| attentionTileKeys, | |
| attentionQueryRows, | |
| singleQkNormStorage: useSingleQkNormStorageF16 ? "f16" : "f32", | |
| singleLinear1Output: useSingleLinear1OutputF16 ? "f16" : "f32", | |
| singleLinear1Q4Kernel: useSingleLinear1Q4Dp4aQkvOnly ? "dp4a-qkv" : (useSingleLinear1Q4Dp4a ? "dp4a" : "f16"), | |
| singleLinear1Q4ActivationScale: useSingleLinear1Q4GroupedActivation ? "group16" : "row", | |
| singleLinear1QkvBackend: useSingleLinear1QkvDp4a ? "dp4a" : "q4", | |
| singleLinear1MlpBackend: useSingleLinear1MlpDp4a ? "dp4a" : "q4", | |
| singleLinear2Backend: useSingleLinear2Dp4a ? "dp4a" : "q4", | |
| singleAttentionKernel: useSubgroupSingleAttention ? "subgroup" : "tiled", | |
| profilePhases, | |
| profileSingleParts, | |
| stepSubmitFusion, | |
| singleComputePass, | |
| skippedOptionalSetupBytes, | |
| skippedOptionalPipelines, | |
| approxReusePredEvery, | |
| approxPredictionMode: useApproxAb2 ? "ab2" : "raw", | |
| approxAbScale: useApproxAb2 ? approxAbScale : 0, | |
| approxReusedSteps, | |
| approxCollapsedUpdateDispatches, | |
| skipSingleBlocksFromStep, | |
| skipSingleBlocksFromBlock, | |
| skippedSingleBlocks, | |
| reuseSingleTailFromStep, | |
| reuseSingleTailFromBlock, | |
| reusedSingleTailBlocks, | |
| }, | |
| load: { | |
| total_ms: loadMs, | |
| stage_setup_ms: stageSetupMs, | |
| text_projection_init_ms: initMs, | |
| text_projection_cache_hit: Boolean(cachedTxtBaseState || persistentTxtProjectionHit || staticTxtProjectionHit), | |
| text_projection_cache_source: cachedTxtBaseState ? "gpu" : (persistentTxtProjectionHit ? "indexeddb" : (staticTxtProjectionHit ? "static-file" : "computed")), | |
| text_projection_persistent_ms: persistentTxtProjectionMs, | |
| readback_ms: readMs, | |
| }, | |
| summary: { | |
| total_step_ms: stepTimes.reduce((sum, value) => sum + value, 0), | |
| median_step_ms: median(stepTimes), | |
| phase_profile_ms: profilePhases ? {...phaseProfile} : undefined, | |
| finite, | |
| max_abs: maxAbs, | |
| }, | |
| step_ms: stepTimes, | |
| sample: Array.from(latentF32.slice(0, Math.min(8, latentF32.length))), | |
| }; | |
| } | |
| window.runCustomFluxTransformerDenoise = runCustomFluxTransformerDenoise; | |
| window.prepareCustomFluxTransformerAssets = prepareCustomFluxTransformerAssets; | |
| window.runCustomDoubleStreamBlocksLoopBench = runCustomDoubleStreamBlocksLoopBench; | |
| window.resetCustomLowbitWebGpuDevice = resetCustomLowbitWebGpuDevice; | |