Spaces:
Configuration error
Configuration error
| import { | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| InterruptableStoppingCriteria, | |
| TextStreamer, | |
| env, | |
| } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.2.0"; | |
// Configure the transformers.js environment: local model lookup is switched
// off and the browser cache flag is switched on.
Object.assign(env, { allowLocalModels: false, useBrowserCache: true });

// Wrap the global fetch up front so later requests are reported to the host.
installFetchTelemetry("llm");

// Mutable worker state, (re)assigned by the load/generate handlers.
let tokenizer = null;
let model = null;
let modelId = "";
let stoppingCriteria = new InterruptableStoppingCriteria();
let currentTurnId = 0;
/**
 * Replaces `globalThis.fetch` with a wrapper that reports every request to the
 * worker's host through `self.postMessage({ type: "network", ... })`.
 *
 * The wrapper is installed at most once per global scope; the
 * `__browserSpeakFetchTelemetryInstalled` flag guards against double wrapping.
 * Nothing happens when the environment has no `fetch`.
 *
 * Reporting is best-effort: a failure while posting telemetry is logged and
 * never changes the outcome of the underlying request.
 *
 * @param {string} scope - Label copied into every telemetry message.
 * @returns {void}
 */
function installFetchTelemetry(scope) {
  const originalFetch = globalThis.fetch?.bind(globalThis);
  if (!originalFetch || globalThis.__browserSpeakFetchTelemetryInstalled) return;
  globalThis.__browserSpeakFetchTelemetryInstalled = true;

  // Posting is isolated in its own try/catch so that a telemetry problem can
  // neither reject a successful fetch nor mask the original fetch error.
  const report = (details) => {
    try {
      self.postMessage({ type: "network", scope, ...details });
    } catch (telemetryError) {
      console.warn("fetch telemetry could not be posted", telemetryError);
    }
  };

  globalThis.fetch = async (input, init) => {
    const startedAt = performance.now();
    const url = fetchUrl(input);
    const method = String(init?.method || input?.method || "GET").toUpperCase();

    let response;
    try {
      response = await originalFetch(input, init);
    } catch (error) {
      report({
        method,
        url,
        status: null,
        ok: false,
        durationMs: performance.now() - startedAt,
        // `error` may be any thrown value, including null/undefined.
        error: error?.message ?? String(error),
      });
      throw error;
    }

    report({
      method,
      url,
      responseUrl: response.url || url,
      status: response.status,
      ok: response.ok,
      durationMs: performance.now() - startedAt,
    });
    return response;
  };
}
/**
 * Derives a URL string from the first argument passed to `fetch`.
 *
 * @param {*} input - A string, a `URL`, or an object exposing a `url` property.
 * @returns {string} The URL text, or "" when none can be read from `input`.
 */
function fetchUrl(input) {
  if (input instanceof URL) {
    return input.href;
  }
  return typeof input === "string" ? input : (input?.url ?? "");
}
/**
 * Worker message entry point. Dispatches on `message.type`:
 *   - "load":      awaits `load(message)`.
 *   - "generate":  records the turn id, then awaits `generate(...)`.
 *   - "interrupt": interrupts the active stopping criteria.
 * Any other type is ignored. An error thrown while handling a message is
 * reported to the host as `{ type: "error", message }`.
 */
self.onmessage = async ({ data: message }) => {
  try {
    switch (message.type) {
      case "load":
        await load(message);
        break;

      case "generate": {
        currentTurnId = message.turnId;
        // An explicit sentenceLimit wins; otherwise stop after one sentence
        // unless the caller opted out with stopAfterFirstSentence === false.
        const fallbackLimit = message.stopAfterFirstSentence === false ? 0 : 1;
        await generate(message.messages, message.turnId, {
          sentenceLimit: message.sentenceLimit ?? fallbackLimit,
        });
        break;
      }

      case "interrupt":
        stoppingCriteria.interrupt();
        break;

      default:
        break;
    }
  } catch (error) {
    self.postMessage({ type: "error", message: error.message ?? String(error) });
  }
};
/**
 * Loads the tokenizer and causal LM for the requested model, runs a one-token
 * warm-up generation, and then posts `{ type: "ready" }`.
 *
 * Status messages ("Loading", "Warming", plus download progress) are posted
 * to the host along the way. The model dtype is "q4f16" when `device` is
 * "webgpu" and "q4" otherwise.
 *
 * @param {{ model: string, device: string }} request - The "load" message;
 *   `model` is the model id and `device` is forwarded to `from_pretrained`.
 * @returns {Promise<void>}
 */
async function load({ model: requestedModelId, device }) {
  const postStatus = (message) => self.postMessage({ type: "status", message, mode: "warn" });

  modelId = requestedModelId;
  postStatus("Loading");

  const tokenizerOptions = { progress_callback: reportProgress("LLM tokenizer") };
  tokenizer = await AutoTokenizer.from_pretrained(modelId, tokenizerOptions);

  const dtype = device === "webgpu" ? "q4f16" : "q4";
  model = await AutoModelForCausalLM.from_pretrained(modelId, {
    device,
    dtype,
    progress_callback: reportProgress("LLM"),
  });

  postStatus("Warming");
  // Generate a single token once before announcing readiness.
  const warmupInputs = tokenizer("hello");
  await model.generate({ ...warmupInputs, max_new_tokens: 1 });

  self.postMessage({ type: "ready" });
}
/**
 * Builds a progress callback that forwards "progress" events to the host as
 * status messages of the form `"<label> <pct>%"`.
 *
 * Events whose `status` is not "progress" are ignored. The percentage suffix
 * is omitted when `progress.progress` is not a finite number.
 *
 * @param {string} label - Prefix for the status text.
 * @returns {(progress: { status: string, progress?: number }) => void}
 */
function reportProgress(label) {
  return (progress) => {
    if (progress.status !== "progress") return;

    let suffix = "";
    if (Number.isFinite(progress.progress)) {
      suffix = ` ${progress.progress.toFixed(0)}%`;
    }
    self.postMessage({ type: "status", message: `${label}${suffix}`, mode: "warn" });
  };
}
/**
 * Runs one streamed generation turn and reports it to the host.
 *
 * Messages posted, in order:
 *   - `{ type: "start", turnId }`
 *   - `{ type: "prompt", turnId, inputTokens, promptBuildMs }`
 *   - zero or more `{ type: "token", turnId, text, tps, tokenCount }`
 *   - `{ type: "complete", turnId, tokenCount, firstTokenMs }`
 *
 * Text is only emitted while `turnId` is still the current turn. When
 * `sentenceLimit` is positive, emission stops at the boundary returned by
 * `sentenceBoundary` and the generation is interrupted.
 *
 * @param {Array<{ role: string, content: string }>} messages - Chat history.
 * @param {*} turnId - Identifier echoed back in every posted message.
 * @param {{ sentenceLimit?: number }} [options] - `sentenceLimit` is the number
 *   of sentences after which to stop; 0 disables sentence-based stopping.
 * @returns {Promise<void>}
 * @throws {Error} When called before the tokenizer and model are loaded.
 */
async function generate(messages, turnId, { sentenceLimit = 1 } = {}) {
  stoppingCriteria = new InterruptableStoppingCriteria();
  self.postMessage({ type: "start", turnId });

  // Checked after "start" so the host still sees start -> error, as it did
  // when this case surfaced as a null-dereference TypeError.
  if (!tokenizer || !model) {
    throw new Error('Model is not loaded; send a "load" message before "generate".');
  }

  const promptBuildStartedAt = performance.now();
  const promptMessages = normalizeMessages(messages);
  const inputs = tokenizer.apply_chat_template(promptMessages, {
    add_generation_prompt: true,
    enable_thinking: false,
    return_dict: true,
  });
  self.postMessage({
    type: "prompt",
    turnId,
    inputTokens: inputTokenCount(inputs),
    promptBuildMs: performance.now() - promptBuildStartedAt,
  });

  // null until the first token callback fires, so "no token" is distinguishable.
  let firstTokenAt = null;
  let tokenCount = 0;
  let tps = 0;
  let decodedText = "";
  let emittedChars = 0;
  let sentenceStopped = false;
  const startedAt = performance.now();

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    token_callback_function: () => {
      tokenCount += 1;
      firstTokenAt ??= performance.now();
      const elapsed = performance.now() - startedAt;
      if (elapsed > 0) tps = (tokenCount / elapsed) * 1000;
    },
    callback_function: (text) => {
      // Drop text for superseded turns and anything past the sentence cut-off.
      if (turnId !== currentTurnId || sentenceStopped) return;
      decodedText += text;

      const boundary = sentenceLimit > 0 ? sentenceBoundary(decodedText, sentenceLimit) : -1;
      const emitUntil = boundary > 0 ? boundary : decodedText.length;
      const emitText = decodedText.slice(emittedChars, emitUntil);
      emittedChars = emitUntil;
      if (emitText) {
        self.postMessage({ type: "token", turnId, text: emitText, tps, tokenCount });
      }
      if (boundary > 0) {
        sentenceStopped = true;
        stoppingCriteria.interrupt();
      }
    },
  });

  await model.generate({
    ...inputs,
    max_new_tokens: maxNewTokens(),
    do_sample: false,
    repetition_penalty: 1.08,
    streamer,
    stopping_criteria: stoppingCriteria,
  });

  self.postMessage({
    type: "complete",
    turnId,
    tokenCount,
    // Without any token there is no latency to report; previously this was
    // `0 - startedAt`, a meaningless negative number.
    firstTokenMs: firstTokenAt === null ? null : firstTokenAt - startedAt,
  });
}
/**
 * Picks the `max_new_tokens` value for the currently selected model id.
 *
 * Rules are checked in order by substring match on `modelId`; the first hit
 * wins and 48 is used when none matches.
 *
 * @returns {number}
 */
function maxNewTokens() {
  const budgets = [
    ["Qwen3", 40],
    ["135M", 64],
  ];
  const hit = budgets.find(([marker]) => modelId.includes(marker));
  return hit ? hit[1] : 48;
}
/**
 * Finds the character offset just past the `targetCount`-th sentence end.
 *
 * A sentence end is `.`, `!` or `?`, optionally followed by one closing
 * quote/bracket character, and then whitespace or the end of the text.
 * Texts whose whitespace-collapsed length is below 12 are never split.
 *
 * @param {string} text - Text decoded so far.
 * @param {number} [targetCount=1] - Which sentence end to locate.
 * @returns {number} Offset into `text` (exclusive end of the sentence,
 *   trailing whitespace not included), or -1 when no such boundary exists.
 */
function sentenceBoundary(text, targetCount = 1) {
  const collapsedLength = text.replace(/\s+/g, " ").trim().length;
  if (collapsedLength < 12) return -1;

  let seen = 0;
  for (const hit of text.matchAll(/[.!?]["')\]]?(?:\s|$)/g)) {
    seen += 1;
    if (seen >= targetCount) {
      // Strip the whitespace the pattern consumed after the punctuation.
      return hit.index + hit[0].trimEnd().length;
    }
  }
  return -1;
}
/**
 * Prepares chat messages for prompt construction.
 *
 * When the tokenizer exposes `apply_chat_template`, the messages are returned
 * untouched. Otherwise they are flattened into a single user message whose
 * content is one `"role: content"` line per original message.
 *
 * @param {Array<{ role: string, content: string }>} messages
 * @returns {Array<{ role: string, content: string }>}
 */
function normalizeMessages(messages) {
  const hasChatTemplate = typeof tokenizer.apply_chat_template === "function";
  if (hasChatTemplate) {
    return messages;
  }

  const transcriptLines = messages.map((entry) => `${entry.role}: ${entry.content}`);
  return [{ role: "user", content: transcriptLines.join("\n") }];
}
/**
 * Best-effort count of prompt tokens in a tokenizer output.
 *
 * Looks at `inputs.input_ids` and tries, in order: a plain array (the first
 * row's length when nested), the last entry of a non-empty `dims` array, then
 * a finite `size`, `length`, or `data.length`.
 *
 * @param {*} inputs - Tokenizer output; may be nullish.
 * @returns {number | null} The token count, or null when it cannot be read.
 */
function inputTokenCount(inputs) {
  const ids = inputs?.input_ids;
  if (!ids) return null;

  if (Array.isArray(ids)) {
    const row = Array.isArray(ids[0]) ? ids[0] : ids;
    return row.length;
  }

  const { dims } = ids;
  if (Array.isArray(dims) && dims.length > 0) {
    return dims[dims.length - 1];
  }

  // Probed lazily and in this order; the first finite value wins.
  const probes = [() => ids.size, () => ids.length, () => ids.data?.length];
  for (const probe of probes) {
    const value = probe();
    if (Number.isFinite(value)) return value;
  }
  return null;
}