Spaces:
Configuration error
Configuration error
| /** | |
| * Wrapper to handle PNDM scheduler | |
| */ | |
| class TVMPNDMScheduler { | |
| constructor(schedulerConsts, latentShape, tvm, device, vm) { | |
| this.timestep = []; | |
| this.sampleCoeff = []; | |
| this.alphaDiff = []; | |
| this.modelOutputDenomCoeff = []; | |
| this.ets = []; | |
| this.schedulerFunc = []; | |
| this.currSample = undefined; | |
| this.tvm = tvm; | |
| // prebuild constants | |
| // principle: always detach for class members | |
| // to avoid recycling output scope. | |
| function loadConsts(output, dtype, input) { | |
| for (let t = 0; t < input.length; ++t) { | |
| output.push( | |
| tvm.detachFromCurrentScope( | |
| tvm.empty([], dtype, device).copyFrom([input[t]]) | |
| ) | |
| ); | |
| } | |
| } | |
| loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]); | |
| loadConsts(this.sampleCoeff, "float32", schedulerConsts["sample_coeff"]); | |
| loadConsts(this.alphaDiff, "float32", schedulerConsts["alpha_diff"]); | |
| loadConsts( | |
| this.modelOutputDenomCoeff, "float32", | |
| schedulerConsts["model_output_denom_coeff"]); | |
| for (let i = 0; i < 4; ++i) { | |
| this.ets.push( | |
| this.tvm.detachFromCurrentScope( | |
| this.tvm.empty(latentShape, "float32", device) | |
| ) | |
| ); | |
| } | |
| for (let i = 0; i < 5; ++i) { | |
| this.schedulerFunc.push( | |
| tvm.detachFromCurrentScope( | |
| vm.getFunction("pndm_scheduler_step_" + i.toString()) | |
| ) | |
| ); | |
| } | |
| } | |
| dispose() { | |
| for (let t = 0; t < this.timestep.length; ++t) { | |
| this.timestep[t].dispose(); | |
| this.sampleCoeff[t].dispose(); | |
| this.alphaDiff[t].dispose(); | |
| this.modelOutputDenomCoeff[t].dispose(); | |
| } | |
| for (let i = 0; i < this.schedulerFunc.length; ++i) { | |
| this.schedulerFunc[i].dispose(); | |
| } | |
| if (this.currSample) { | |
| this.currSample.dispose(); | |
| } | |
| for (let i = 0; i < this.ets.length; ++i) { | |
| this.ets[i].dispose(); | |
| } | |
| } | |
| step(modelOutput, sample, counter) { | |
| // keep running history of last four inputs | |
| if (counter != 1) { | |
| this.ets.shift(); | |
| this.ets.push(this.tvm.detachFromCurrentScope( | |
| modelOutput | |
| )); | |
| } | |
| if (counter == 0) { | |
| this.currSample = this.tvm.detachFromCurrentScope( | |
| sample | |
| ); | |
| } else if (counter == 1) { | |
| sample = this.tvm.attachToCurrentScope(this.currSample); | |
| this.currSample = undefined; | |
| } | |
| const findex = counter < 4 ? counter : 4; | |
| const prevLatents = this.schedulerFunc[findex]( | |
| sample, | |
| modelOutput, | |
| this.sampleCoeff[counter], | |
| this.alphaDiff[counter], | |
| this.modelOutputDenomCoeff[counter], | |
| this.ets[0], | |
| this.ets[1], | |
| this.ets[2], | |
| this.ets[3] | |
| ); | |
| return prevLatents; | |
| } | |
| } | |
| /** | |
| * Wrapper to handle multistep DPM-solver scheduler | |
| */ | |
| class TVMDPMSolverMultistepScheduler { | |
| constructor(schedulerConsts, latentShape, tvm, device, vm) { | |
| this.timestep = []; | |
| this.alpha = []; | |
| this.sigma = []; | |
| this.c0 = []; | |
| this.c1 = []; | |
| this.c2 = []; | |
| this.lastModelOutput = undefined; | |
| this.convertModelOutputFunc = undefined; | |
| this.stepFunc = undefined; | |
| this.tvm = tvm; | |
| // prebuild constants | |
| // principle: always detach for class members | |
| // to avoid recycling output scope. | |
| function loadConsts(output, dtype, input) { | |
| for (let t = 0; t < input.length; ++t) { | |
| output.push( | |
| tvm.detachFromCurrentScope( | |
| tvm.empty([], dtype, device).copyFrom([input[t]]) | |
| ) | |
| ); | |
| } | |
| } | |
| loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]); | |
| loadConsts(this.alpha, "float32", schedulerConsts["alpha"]); | |
| loadConsts(this.sigma, "float32", schedulerConsts["sigma"]); | |
| loadConsts(this.c0, "float32", schedulerConsts["c0"]); | |
| loadConsts(this.c1, "float32", schedulerConsts["c1"]); | |
| loadConsts(this.c2, "float32", schedulerConsts["c2"]); | |
| this.lastModelOutput = this.tvm.detachFromCurrentScope( | |
| this.tvm.empty(latentShape, "float32", device) | |
| ) | |
| this.convertModelOutputFunc = tvm.detachFromCurrentScope( | |
| vm.getFunction("dpm_solver_multistep_scheduler_convert_model_output") | |
| ) | |
| this.stepFunc = tvm.detachFromCurrentScope( | |
| vm.getFunction("dpm_solver_multistep_scheduler_step") | |
| ) | |
| } | |
| dispose() { | |
| for (let t = 0; t < this.timestep.length; ++t) { | |
| this.timestep[t].dispose(); | |
| this.alpha[t].dispose(); | |
| this.sigma[t].dispose(); | |
| this.c0[t].dispose(); | |
| this.c1[t].dispose(); | |
| this.c2[t].dispose(); | |
| } | |
| this.lastModelOutput.dispose(); | |
| this.convertModelOutputFunc.dispose(); | |
| this.stepFunc.dispose(); | |
| } | |
| step(modelOutput, sample, counter) { | |
| modelOutput = this.convertModelOutputFunc(sample, modelOutput, this.alpha[counter], this.sigma[counter]) | |
| const prevLatents = this.stepFunc( | |
| sample, | |
| modelOutput, | |
| this.lastModelOutput, | |
| this.c0[counter], | |
| this.c1[counter], | |
| this.c2[counter], | |
| ); | |
| this.lastModelOutput = this.tvm.detachFromCurrentScope( | |
| modelOutput | |
| ); | |
| return prevLatents; | |
| } | |
| } | |
| class StableDiffusionPipeline { | |
| constructor(tvm, tokenizer, schedulerConsts, cacheMetadata) { | |
| if (cacheMetadata == undefined) { | |
| throw Error("Expect cacheMetadata"); | |
| } | |
| this.tvm = tvm; | |
| this.tokenizer = tokenizer; | |
| this.maxTokenLength = 77; | |
| this.device = this.tvm.webgpu(); | |
| this.tvm.bindCanvas(document.getElementById("canvas")); | |
| // VM functions | |
| this.vm = this.tvm.detachFromCurrentScope( | |
| this.tvm.createVirtualMachine(this.device) | |
| ); | |
| this.schedulerConsts = schedulerConsts; | |
| this.clipToTextEmbeddings = this.tvm.detachFromCurrentScope( | |
| this.vm.getFunction("clip") | |
| ); | |
| this.clipParams = this.tvm.detachFromCurrentScope( | |
| this.tvm.getParamsFromCache("clip", cacheMetadata.clipParamSize) | |
| ); | |
| this.unetLatentsToNoisePred = this.tvm.detachFromCurrentScope( | |
| this.vm.getFunction("unet") | |
| ); | |
| this.unetParams = this.tvm.detachFromCurrentScope( | |
| this.tvm.getParamsFromCache("unet", cacheMetadata.unetParamSize) | |
| ); | |
| this.vaeToImage = this.tvm.detachFromCurrentScope( | |
| this.vm.getFunction("vae") | |
| ); | |
| this.vaeParams = this.tvm.detachFromCurrentScope( | |
| this.tvm.getParamsFromCache("vae", cacheMetadata.vaeParamSize) | |
| ); | |
| this.imageToRGBA = this.tvm.detachFromCurrentScope( | |
| this.vm.getFunction("image_to_rgba") | |
| ); | |
| this.concatEmbeddings = this.tvm.detachFromCurrentScope( | |
| this.vm.getFunction("concat_embeddings") | |
| ); | |
| } | |
| dispose() { | |
| // note: tvm instance is not owned by this class | |
| this.concatEmbeddings.dispose(); | |
| this.imageToRGBA.dispose() | |
| this.vaeParams.dispose(); | |
| this.vaeToImage.dispose(); | |
| this.unetParams.dispose(); | |
| this.unetLatentsToNoisePred.dispose(); | |
| this.clipParams.dispose(); | |
| this.clipToTextEmbeddings.dispose(); | |
| this.vm.dispose(); | |
| } | |
| /** | |
| * Tokenize the prompt to TVMNDArray. | |
| * @param prompt Input prompt | |
| * @returns The text id NDArray. | |
| */ | |
| tokenize(prompt) { | |
| const encoded = this.tokenizer.encode(prompt, true).input_ids; | |
| const inputIDs = new Int32Array(this.maxTokenLength); | |
| if (encoded.length < this.maxTokenLength) { | |
| inputIDs.set(encoded); | |
| const lastTok = encoded[encoded.length - 1]; | |
| inputIDs.fill(lastTok, encoded.length, inputIDs.length); | |
| } else { | |
| inputIDs.set(encoded.slice(0, this.maxTokenLength)); | |
| } | |
| return this.tvm.empty([1, this.maxTokenLength], "int32", this.device).copyFrom(inputIDs); | |
| } | |
| /** | |
| * async preload webgpu pipelines when possible. | |
| */ | |
| async asyncLoadWebGPUPiplines() { | |
| await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule()); | |
| } | |
| /** | |
| * Run generation pipeline. | |
| * | |
| * @param prompt Input prompt. | |
| * @param negPrompt Input negative prompt. | |
| * @param progressCallback Callback to check progress. | |
| * @param schedulerId The integer ID of the scheduler to use. | |
| * - 0 for multi-step DPM solver, | |
| * - 1 for PNDM solver. | |
| * @param vaeCycle optionally draw VAE result every cycle iterations. | |
| * @param beginRenderVae Begin rendering VAE after skipping these warmup runs. | |
| */ | |
| async generate( | |
| prompt, | |
| negPrompt = "", | |
| progressCallback = undefined, | |
| schedulerId = 0, | |
| vaeCycle = -1, | |
| beginRenderVae = 10 | |
| ) { | |
| // Principle: beginScope/endScope in synchronized blocks, | |
| // this helps to recycle intermediate memories | |
| // detach states that needs to go across async boundaries. | |
| //-------------------------- | |
| // Stage 0: CLIP | |
| //-------------------------- | |
| this.tvm.beginScope(); | |
| // get latents | |
| const latentShape = [1, 4, 64, 64]; | |
| var unetNumSteps; | |
| if (schedulerId == 0) { | |
| scheduler = new TVMDPMSolverMultistepScheduler( | |
| this.schedulerConsts[0], latentShape, this.tvm, this.device, this.vm); | |
| unetNumSteps = this.schedulerConsts[0]["num_steps"]; | |
| } else { | |
| scheduler = new TVMPNDMScheduler( | |
| this.schedulerConsts[1], latentShape, this.tvm, this.device, this.vm); | |
| unetNumSteps = this.schedulerConsts[1]["num_steps"]; | |
| } | |
| const totalNumSteps = unetNumSteps + 2; | |
| if (progressCallback !== undefined) { | |
| progressCallback("clip", 0, 1, totalNumSteps); | |
| } | |
| const embeddings = this.tvm.withNewScope(() => { | |
| let posInputIDs = this.tokenize(prompt); | |
| let negInputIDs = this.tokenize(negPrompt); | |
| const posEmbeddings = this.clipToTextEmbeddings( | |
| posInputIDs, this.clipParams); | |
| const negEmbeddings = this.clipToTextEmbeddings( | |
| negInputIDs, this.clipParams); | |
| // maintain new latents | |
| return this.tvm.detachFromCurrentScope( | |
| this.concatEmbeddings(negEmbeddings, posEmbeddings) | |
| ); | |
| }); | |
| // use uniform distribution with same variance as normal(0, 1) | |
| const scale = Math.sqrt(12) / 2; | |
| let latents = this.tvm.detachFromCurrentScope( | |
| this.tvm.uniform(latentShape, -scale, scale, this.tvm.webgpu()) | |
| ); | |
| this.tvm.endScope(); | |
| //--------------------------- | |
| // Stage 1: UNet + Scheduler | |
| //--------------------------- | |
| if (vaeCycle != -1) { | |
| // show first frame | |
| this.tvm.withNewScope(() => { | |
| const image = this.vaeToImage(latents, this.vaeParams); | |
| this.tvm.showImage(this.imageToRGBA(image)); | |
| }); | |
| await this.device.sync(); | |
| } | |
| vaeCycle = vaeCycle == -1 ? unetNumSteps : vaeCycle; | |
| let lastSync = undefined; | |
| for (let counter = 0; counter < unetNumSteps; ++counter) { | |
| if (progressCallback !== undefined) { | |
| progressCallback("unet", counter, unetNumSteps, totalNumSteps); | |
| } | |
| const timestep = scheduler.timestep[counter]; | |
| // recycle noisePred, track latents manually | |
| const newLatents = this.tvm.withNewScope(() => { | |
| this.tvm.attachToCurrentScope(latents); | |
| const noisePred = this.unetLatentsToNoisePred( | |
| latents, timestep, embeddings, this.unetParams); | |
| // maintain new latents | |
| return this.tvm.detachFromCurrentScope( | |
| scheduler.step(noisePred, latents, counter) | |
| ); | |
| }); | |
| latents = newLatents; | |
| // use skip one sync, although likely not as useful. | |
| if (lastSync !== undefined) { | |
| await lastSync; | |
| } | |
| // async event checker | |
| lastSync = this.device.sync(); | |
| // Optionally, we can draw intermediate result of VAE. | |
| if ((counter + 1) % vaeCycle == 0 && | |
| (counter + 1) != unetNumSteps && | |
| counter >= beginRenderVae) { | |
| this.tvm.withNewScope(() => { | |
| const image = this.vaeToImage(latents, this.vaeParams); | |
| this.tvm.showImage(this.imageToRGBA(image)); | |
| }); | |
| await this.device.sync(); | |
| } | |
| } | |
| scheduler.dispose(); | |
| embeddings.dispose(); | |
| //----------------------------- | |
| // Stage 2: VAE and draw image | |
| //----------------------------- | |
| if (progressCallback !== undefined) { | |
| progressCallback("vae", 0, 1, totalNumSteps); | |
| } | |
| this.tvm.withNewScope(() => { | |
| const image = this.vaeToImage(latents, this.vaeParams); | |
| this.tvm.showImage(this.imageToRGBA(image)); | |
| }); | |
| latents.dispose(); | |
| await this.device.sync(); | |
| if (progressCallback !== undefined) { | |
| progressCallback("vae", 1, 1, totalNumSteps); | |
| } | |
| } | |
| clearCanvas() { | |
| this.tvm.clearCanvas(); | |
| } | |
| }; | |
| /** | |
| * A instance that can be used to facilitate deployment. | |
| */ | |
| class StableDiffusionInstance { | |
| constructor() { | |
| this.tvm = undefined; | |
| this.pipeline = undefined; | |
| this.config = undefined; | |
| this.generateInProgress = false; | |
| this.logger = console.log; | |
| } | |
| /** | |
| * Initialize TVM | |
| * @param wasmUrl URL to wasm source. | |
| * @param cacheUrl URL to NDArray cache. | |
| * @param logger Custom logger. | |
| */ | |
| async #asyncInitTVM(wasmUrl, cacheUrl) { | |
| if (this.tvm !== undefined) { | |
| return; | |
| } | |
| if (document.getElementById("log") !== undefined) { | |
| this.logger = function (message) { | |
| console.log(message); | |
| const d = document.createElement("div"); | |
| d.innerHTML = message; | |
| document.getElementById("log").appendChild(d); | |
| }; | |
| } | |
| const wasmSource = await ( | |
| await fetch(wasmUrl) | |
| ).arrayBuffer(); | |
| const tvm = await tvmjs.instantiate( | |
| new Uint8Array(wasmSource), | |
| new EmccWASI(), | |
| this.logger | |
| ); | |
| // initialize WebGPU | |
| try { | |
| const output = await tvmjs.detectGPUDevice(); | |
| if (output !== undefined) { | |
| var label = "WebGPU"; | |
| if (output.adapterInfo.description.length != 0) { | |
| label += " - " + output.adapterInfo.description; | |
| } else { | |
| label += " - " + output.adapterInfo.vendor; | |
| } | |
| document.getElementById( | |
| "gpu-tracker-label").innerHTML = ("Initialize GPU device: " + label); | |
| tvm.initWebGPU(output.device); | |
| } else { | |
| document.getElementById( | |
| "gpu-tracker-label").innerHTML = "This browser env do not support WebGPU"; | |
| this.reset(); | |
| throw Error("This browser env do not support WebGPU"); | |
| } | |
| } catch (err) { | |
| document.getElementById("gpu-tracker-label").innerHTML = ( | |
| "Find an error initializing the WebGPU device " + err.toString() | |
| ); | |
| console.log(err.stack); | |
| this.reset(); | |
| throw Error("Find an error initializing WebGPU: " + err.toString()); | |
| } | |
| this.tvm = tvm; | |
| function initProgressCallback(report) { | |
| document.getElementById("progress-tracker-label").innerHTML = report.text; | |
| document.getElementById("progress-tracker-progress").value = report.progress * 100; | |
| } | |
| tvm.registerInitProgressCallback(initProgressCallback); | |
| if (!cacheUrl.startsWith("http")) { | |
| cacheUrl = new URL(cacheUrl, document.URL).href; | |
| } | |
| await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu()); | |
| } | |
| /** | |
| * Initialize the pipeline | |
| * | |
| * @param schedulerConstUrl The scheduler constant. | |
| * @param tokenizerName The name of the tokenizer. | |
| */ | |
| async #asyncInitPipeline(schedulerConstUrl, tokenizerName) { | |
| if (this.tvm == undefined) { | |
| throw Error("asyncInitTVM is not called"); | |
| } | |
| if (this.pipeline !== undefined) return; | |
| var schedulerConst = [] | |
| for (let i = 0; i < schedulerConstUrl.length; ++i) { | |
| schedulerConst.push(await (await fetch(schedulerConstUrl[i])).json()) | |
| } | |
| const tokenizer = await tvmjsGlobalEnv.getTokenizer(tokenizerName); | |
| this.pipeline = this.tvm.withNewScope(() => { | |
| return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata); | |
| }); | |
| await this.pipeline.asyncLoadWebGPUPiplines(); | |
| } | |
| /** | |
| * Async initialize config | |
| */ | |
| async #asyncInitConfig() { | |
| if (this.config !== undefined) return; | |
| this.config = await (await fetch("stable-diffusion-config.json")).json(); | |
| } | |
| /** | |
| * Function to create progress callback tracker. | |
| * @returns A progress callback tracker. | |
| */ | |
| #getProgressCallback() { | |
| const tstart = performance.now(); | |
| function progressCallback(stage, counter, numSteps, totalNumSteps) { | |
| const timeElapsed = (performance.now() - tstart) / 1000; | |
| let text = "Generating ... at stage " + stage; | |
| if (stage == "unet") { | |
| counter += 1; | |
| text += " step [" + counter + "/" + numSteps + "]" | |
| } | |
| if (stage == "vae") { | |
| counter = totalNumSteps; | |
| } | |
| text += ", " + Math.ceil(timeElapsed) + " secs elapsed."; | |
| document.getElementById("progress-tracker-label").innerHTML = text; | |
| document.getElementById("progress-tracker-progress").value = (counter / totalNumSteps) * 100; | |
| } | |
| return progressCallback; | |
| } | |
| /** | |
| * Async initialize instance. | |
| */ | |
| async asyncInit() { | |
| if (this.pipeline !== undefined) return; | |
| await this.#asyncInitConfig(); | |
| await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl); | |
| await this.#asyncInitPipeline(this.config.schedulerConstUrl, this.config.tokenizer); | |
| } | |
| /** | |
| * Async initialize | |
| * | |
| * @param tvm The tvm instance. | |
| */ | |
| async asyncInitOnRPCServerLoad(tvmInstance) { | |
| if (this.tvm !== undefined) { | |
| throw Error("Cannot reuse a loaded instance for rpc"); | |
| } | |
| this.tvm = tvmInstance; | |
| this.tvm.beginScope(); | |
| this.tvm.registerAsyncServerFunc("generate", async (prompt, schedulerId, vaeCycle) => { | |
| document.getElementById("inputPrompt").value = prompt; | |
| const negPrompt = ""; | |
| document.getElementById("negativePrompt").value = ""; | |
| await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle); | |
| }); | |
| this.tvm.registerAsyncServerFunc("clearCanvas", async () => { | |
| this.tvm.clearCanvas(); | |
| }); | |
| this.tvm.registerAsyncServerFunc("showImage", async (data) => { | |
| this.tvm.showImage(data); | |
| }); | |
| this.tvm.endScope(); | |
| } | |
| /** | |
| * Run generate | |
| */ | |
| async generate() { | |
| if (this.requestInProgress) { | |
| this.logger("Request in progress, generate request ignored"); | |
| return; | |
| } | |
| this.requestInProgress = true; | |
| try { | |
| await this.asyncInit(); | |
| const prompt = document.getElementById("inputPrompt").value; | |
| const negPrompt = document.getElementById("negativePrompt").value; | |
| const schedulerId = document.getElementById("schedulerId").value; | |
| const vaeCycle = document.getElementById("vaeCycle").value; | |
| await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle); | |
| } catch (err) { | |
| this.logger("Generate error, " + err.toString()); | |
| console.log(err.stack); | |
| this.reset(); | |
| } | |
| this.requestInProgress = false; | |
| } | |
| /** | |
| * Reset the instance; | |
| */ | |
| reset() { | |
| this.tvm = undefined; | |
| if (this.pipeline !== undefined) { | |
| this.pipeline.dispose(); | |
| } | |
| this.pipeline = undefined; | |
| } | |
| } | |
| localStableDiffusionInst = new StableDiffusionInstance(); | |
| tvmjsGlobalEnv.asyncOnGenerate = async function () { | |
| await localStableDiffusionInst.generate(); | |
| }; | |
| tvmjsGlobalEnv.asyncOnRPCServerLoad = async function (tvm) { | |
| const inst = new StableDiffusionInstance(); | |
| await inst.asyncInitOnRPCServerLoad(tvm); | |
| }; | |