| import { |
| pipeline, |
| TextStreamer, |
| DynamicCache, |
| InterruptableStoppingCriteria, |
| } from "@huggingface/transformers"; |
|
|
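// Map short, UI-facing model keys to their Hugging Face model repos.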
| const MODEL_IDS = { |
| "1.7b": "onnx-community/Bonsai-1.7B-ONNX", |
| }; |
|
|
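// Verify that WebGPU is available before attempting to load the model;
// on failure, report the error back to the main thread.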
| async function check() { |
| try { |
| const adapter = await navigator.gpu?.requestAdapter(); |
| if (!adapter) throw new Error("WebGPU is not supported (no adapter found)"); |
| } catch (e) { |
| self.postMessage({ status: "error", data: e.toString() }); |
| } |
| } |
|
|
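// Lazily creates one text-generation pipeline per model key and caches the
// resulting promise, so concurrent callers share a single load.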
| class TextGenerationPipeline { |
| static instances = new Map(); |
|
|
| static getInstance(modelKey, progress_callback = null) { |
| const modelId = MODEL_IDS[modelKey]; |
| if (!modelId) throw new Error(`Unknown model: ${modelKey}`); |
| if (!this.instances.has(modelKey)) { |
| this.instances.set( |
| modelKey, |
| pipeline("text-generation", modelId, { |
| device: "webgpu", |
          dtype: "q4f16", // 4-bit quantized weights with fp16 activations (a WebGPU-supported dtype)
| progress_callback, |
| }), |
| ); |
| } |
| return this.instances.get(modelKey); |
| } |
| } |
|
|
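// State shared across requests: a stop signal the UI can trigger, the KV
// cache reused between chat turns, and the key of the currently loaded model.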
| const stopping_criteria = new InterruptableStoppingCriteria(); |
| let past_key_values_cache = null; |
| let current_model_key = null; |
|
|
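// Free the cached attention key/values (if the cache exposes dispose) and
// clear the reference so the next generation starts fresh.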
| function disposePastKeyValues() { |
| past_key_values_cache?.dispose?.(); |
| past_key_values_cache = null; |
| } |
|
|
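// Download (or retrieve) the requested model, report progress to the UI, and
// warm it up. Switching models invalidates the previous model's KV cache.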
| async function load(modelKey) { |
| if (current_model_key && current_model_key !== modelKey) { |
| disposePastKeyValues(); |
| } |
| current_model_key = modelKey; |
|
|
| self.postMessage({ status: "loading", data: "Loading model..." }); |
|
|
| const generator = await TextGenerationPipeline.getInstance( |
| modelKey, |
    (info) => {
      // transformers.js reports per-file download progress with
      // status "progress"; re-post it for the UI's progress bar.
      if (info.status === "progress") {
        self.postMessage({
          status: "progress_total",
          file: info.file,
          progress: Number(info.progress ?? 0),
          loaded: Number(info.loaded ?? 0),
          total: Number(info.total ?? 0),
        });
      }
    },
| ); |
|
|
| self.postMessage({ |
| status: "loading", |
    data: "Optimizing model for 1-bit execution...",
| }); |
|
|
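  // Warm-up: generating a single token forces WebGPU shader compilation now,
  // so the first real request responds faster.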
| const inputs = generator.tokenizer("a"); |
| await generator.model.generate({ ...inputs, max_new_tokens: 1 }); |
|
|
| self.postMessage({ status: "ready" }); |
| } |
|
|
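// Run chat generation for the given messages, streaming tokens to the main
// thread as they arrive.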
| async function generate(messages) { |
| const generator = await TextGenerationPipeline.getInstance(current_model_key); |
|
|
| let startTime; |
| let numTokens = 0; |
| let tps; |
|
|
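  // Stream decoded text back to the main thread as tokens are produced.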
| const streamer = new TextStreamer(generator.tokenizer, { |
| skip_prompt: true, |
| skip_special_tokens: true, |
| callback_function: (output) => { |
| self.postMessage({ status: "update", output, tps, numTokens }); |
| }, |
    token_callback_function: () => {
      startTime ??= performance.now();
      // Exclude the first token (which includes prompt prefill) from the
      // tokens-per-second measurement.
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    },
| }); |
|
|
| self.postMessage({ status: "start" }); |
|
|
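  // Create the KV cache on first use; passing it to the pipeline lets later
  // turns reuse past attention states instead of re-processing the history.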
| past_key_values_cache ??= new DynamicCache(); |
|
|
| try { |
| const output = await generator(messages, { |
| max_new_tokens: 1024, |
| do_sample: false, |
| streamer, |
| stopping_criteria, |
| past_key_values: past_key_values_cache, |
| }); |
|
|
    self.postMessage({
      status: "complete",
      // The pipeline returns the full chat; the last message holds the
      // assistant's newly generated reply.
      output: output[0].generated_text.at(-1).content,
    });
| } catch (e) { |
| self.postMessage({ status: "error", data: e.toString() }); |
| } |
| } |
|
|
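// Dispatch commands sent from the main thread.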
| self.addEventListener("message", async (e) => { |
| const { type, data } = e.data; |
| switch (type) { |
| case "check": |
| check(); |
| break; |
| case "load": |
| load(data); |
| break; |
| case "generate": |
| stopping_criteria.reset(); |
| generate(data); |
| break; |
| case "interrupt": |
| stopping_criteria.interrupt(); |
| break; |
| case "reset": |
| disposePastKeyValues(); |
| stopping_criteria.reset(); |
| break; |
| } |
| }); |
|
|