// browser-speak / workers/llm-worker.js
// Last change by Mike0021: "Add worker network telemetry to browser evidence"
// (commit d2ae80e, verified). Size: 6.54 kB.
import {
AutoModelForCausalLM,
AutoTokenizer,
InterruptableStoppingCriteria,
TextStreamer,
env,
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.2.0";
// transformers.js environment: do not resolve model ids as local files, and
// let downloaded model files be cached by the browser.
// NOTE(review): exact semantics per the transformers.js `env` docs — confirm.
env.allowLocalModels = false;
env.useBrowserCache = true;
// Wrap this worker's fetch so every request is reported to the main thread as
// a "network" message tagged with scope "llm" (presumably this also covers the
// model/tokenizer downloads made by transformers.js — confirm).
installFetchTelemetry("llm");
// Worker-wide state: load() assigns tokenizer/model/modelId, generate() reads them.
let tokenizer = null;
let model = null;
let modelId = ""; // also selects the token budget in maxNewTokens()
// Criteria of the most recently started generation; an "interrupt" message stops it.
let stoppingCriteria = new InterruptableStoppingCriteria();
// turnId of the newest "generate" request; older turns stop posting "token" messages.
let currentTurnId = 0;
/**
 * Replace `globalThis.fetch` with a wrapper that reports each request to the
 * main thread as a `{ type: "network", scope, ... }` message.
 *
 * Installs at most once per global (guarded by a flag on `globalThis`) and
 * does nothing when no native fetch is available. Successful requests report
 * `responseUrl`, `status`, `ok` and `durationMs`. Failed requests report
 * `status: null`, `ok: false`, `durationMs` and `error`, and the original
 * error is rethrown unchanged.
 *
 * @param {string} scope - Label attached to every telemetry message.
 */
function installFetchTelemetry(scope) {
  const nativeFetch = globalThis.fetch?.bind(globalThis);
  if (!nativeFetch || globalThis.__browserSpeakFetchTelemetryInstalled) return;
  globalThis.__browserSpeakFetchTelemetryInstalled = true;

  globalThis.fetch = async (input, init) => {
    const t0 = performance.now();
    const url = fetchUrl(input);
    // An explicit init.method wins over a Request's own method.
    const method = String(init?.method || input?.method || "GET").toUpperCase();
    // Fields shared by the success and failure reports (kept first, in this order).
    const base = { type: "network", scope, method, url };

    try {
      const response = await nativeFetch(input, init);
      self.postMessage({
        ...base,
        responseUrl: response.url || url,
        status: response.status,
        ok: response.ok,
        durationMs: performance.now() - t0,
      });
      return response;
    } catch (failure) {
      self.postMessage({
        ...base,
        status: null,
        ok: false,
        durationMs: performance.now() - t0,
        error: failure.message ?? String(failure),
      });
      throw failure;
    }
  };
}
/**
 * Extract a URL string from any value accepted as fetch's first argument.
 *
 * @param {string|URL|{url?: string}|null|undefined} input - String, URL, or a
 *   Request-like object.
 * @returns {string} The URL text, or "" when none can be found.
 */
function fetchUrl(input) {
  if (typeof input === "string") {
    return input;
  }
  // URL objects expose `href`; Request-like objects expose `url`.
  return input instanceof URL ? input.href : (input?.url ?? "");
}
/**
 * Worker entry point. Dispatches on `message.type`:
 * - "load": load a model (see load()).
 * - "generate": mark `turnId` as current and run a turn. `sentenceLimit`
 *   defaults to 1, or to 0 when `stopAfterFirstSentence === false`.
 * - "interrupt": stop the most recently started generation.
 * Unknown types are ignored. Any error thrown by a handler is reported to the
 * main thread as `{ type: "error", message }`.
 */
self.onmessage = async ({ data: message }) => {
  try {
    switch (message.type) {
      case "load":
        await load(message);
        break;
      case "generate":
        currentTurnId = message.turnId;
        await generate(message.messages, message.turnId, {
          sentenceLimit: message.sentenceLimit ?? (message.stopAfterFirstSentence === false ? 0 : 1),
        });
        break;
      case "interrupt":
        stoppingCriteria.interrupt();
        break;
      default:
        // Unrecognised message types are ignored.
        break;
    }
  } catch (error) {
    self.postMessage({ type: "error", message: error.message ?? String(error) });
  }
};
/**
 * Load a tokenizer/model pair, warm it up, and make it the worker's active model.
 *
 * The module-level `tokenizer`, `model` and `modelId` are replaced together,
 * and only after both downloads and the warm-up generation succeed. A failed
 * load therefore leaves the previously loaded model fully usable. A turn that
 * runs while loading is in progress never sees a new tokenizer paired with an
 * old model.
 *
 * Posts "status" messages ("Loading", download progress, "Warming") and
 * finally `{ type: "ready" }`.
 *
 * @param {{model: string, device: string}} request - `model` is the model id
 *   passed to from_pretrained. `device` is forwarded as-is; "webgpu" selects
 *   "q4f16" weights, anything else "q4".
 */
async function load({ model: requestedModelId, device }) {
  self.postMessage({ type: "status", message: "Loading", mode: "warn" });
  const nextTokenizer = await AutoTokenizer.from_pretrained(requestedModelId, {
    progress_callback: reportProgress("LLM tokenizer"),
  });
  const nextModel = await AutoModelForCausalLM.from_pretrained(requestedModelId, {
    device,
    dtype: device === "webgpu" ? "q4f16" : "q4",
    progress_callback: reportProgress("LLM"),
  });

  self.postMessage({ type: "status", message: "Warming", mode: "warn" });
  // One-token run, presumably so first-run setup cost is paid here rather
  // than on the user's first turn.
  const warmupInputs = nextTokenizer("hello");
  await nextModel.generate({ ...warmupInputs, max_new_tokens: 1 });

  // Publish the new pair only once it is known to work.
  tokenizer = nextTokenizer;
  model = nextModel;
  modelId = requestedModelId;
  self.postMessage({ type: "ready" });
}
/**
 * Build a transformers.js `progress_callback` that forwards download progress
 * to the main thread as "status" messages.
 *
 * Only updates whose `status` is "progress" are forwarded. The percentage is
 * rounded to a whole number and omitted when `progress` is not a finite number.
 *
 * @param {string} label - Prefix for the status text, e.g. "LLM".
 * @returns {(update: {status: string, progress?: number}) => void}
 */
function reportProgress(label) {
  return function onProgress(update) {
    if (update.status !== "progress") return;
    const { progress } = update;
    const suffix = Number.isFinite(progress) ? ` ${progress.toFixed(0)}%` : "";
    self.postMessage({ type: "status", message: `${label}${suffix}`, mode: "warn" });
  };
}
/**
 * Generate a reply for one chat turn and stream it to the main thread.
 *
 * Posts, in order: `start`, `prompt` (input token count and prompt build
 * time), any number of `token` messages (text delta, running tokens/sec and
 * token count), and finally `complete`. Every message carries `turnId`.
 *
 * @param {Array<{role: string, content: string}>} messages - Chat history;
 *   the shape is assumed from normalizeMessages().
 * @param {*} turnId - Turn identifier. Once `currentTurnId` moves on to a
 *   newer turn, this turn stops emitting text and interrupts its own generation.
 * @param {{sentenceLimit?: number}} [options]
 * @param {number} [options.sentenceLimit=1] - Stop after this many sentences
 *   (see sentenceBoundary). 0 or less streams until max_new_tokens.
 * @throws {Error} If called before load() has produced a tokenizer and model.
 */
async function generate(messages, turnId, { sentenceLimit = 1 } = {}) {
  // Per-turn criteria. The module-level reference lets an "interrupt" message
  // reach the newest turn. The local one lets this turn stop itself even after
  // a newer turn has replaced `stoppingCriteria`.
  const criteria = new InterruptableStoppingCriteria();
  stoppingCriteria = criteria;
  self.postMessage({ type: "start", turnId });
  if (!tokenizer || !model) {
    throw new Error('LLM is not loaded; send a "load" message first');
  }

  const promptBuildStartedAt = performance.now();
  const promptMessages = normalizeMessages(messages);
  const inputs = tokenizer.apply_chat_template(promptMessages, {
    add_generation_prompt: true,
    // Chat-template variable; presumably disables "thinking" output on
    // templates that support it (e.g. Qwen3) — confirm per model.
    enable_thinking: false,
    return_dict: true,
  });
  self.postMessage({
    type: "prompt",
    turnId,
    inputTokens: inputTokenCount(inputs),
    promptBuildMs: performance.now() - promptBuildStartedAt,
  });

  let firstTokenAt = 0;
  let tokenCount = 0;
  let tps = 0;
  let decodedText = "";
  let emittedChars = 0;
  let sentenceStopped = false;
  const startedAt = performance.now();
  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    token_callback_function: () => {
      tokenCount += 1;
      firstTokenAt ||= performance.now();
      const elapsed = performance.now() - startedAt;
      if (elapsed > 0) tps = (tokenCount / elapsed) * 1000;
      // A newer turn has started: stop spending compute on this one instead of
      // running silently to max_new_tokens.
      if (turnId !== currentTurnId) criteria.interrupt();
    },
    callback_function: (text) => {
      if (turnId !== currentTurnId || sentenceStopped) return;
      decodedText += text;
      // With a sentence limit, emit only up to the limiting sentence end;
      // otherwise emit everything decoded so far.
      const boundary = sentenceLimit > 0 ? sentenceBoundary(decodedText, sentenceLimit) : -1;
      const emitUntil = boundary > 0 ? boundary : decodedText.length;
      const emitText = decodedText.slice(emittedChars, emitUntil);
      emittedChars = emitUntil;
      if (emitText) {
        self.postMessage({ type: "token", turnId, text: emitText, tps, tokenCount });
      }
      if (boundary > 0) {
        sentenceStopped = true;
        criteria.interrupt();
      }
    },
  });

  await model.generate({
    ...inputs,
    max_new_tokens: maxNewTokens(),
    do_sample: false,
    repetition_penalty: 1.08,
    streamer,
    stopping_criteria: criteria,
  });
  self.postMessage({
    type: "complete",
    turnId,
    tokenCount,
    // null when no token was produced (e.g. interrupted before the first one);
    // previously this was the meaningless value `0 - startedAt`.
    firstTokenMs: firstTokenAt > 0 ? firstTokenAt - startedAt : null,
  });
}
/**
 * Token budget for one reply, chosen from the loaded model id.
 *
 * Rules are checked in order and the first substring match wins. Model ids
 * that match no rule get 48.
 *
 * @returns {number} Value passed as `max_new_tokens` to model.generate().
 */
function maxNewTokens() {
  const budgets = [
    ["Qwen3", 40],
    ["135M", 64],
  ];
  const rule = budgets.find(([needle]) => modelId.includes(needle));
  return rule ? rule[1] : 48;
}
/**
 * Find where the `targetCount`-th sentence of `text` ends.
 *
 * A sentence ends at `.`, `!` or `?`, optionally followed by one closing
 * quote, parenthesis or bracket, when that is followed by whitespace or the
 * end of the text. Short openers such as "Okay." still count as sentences, but
 * are never used as the stopping point. A boundary is returned only when the
 * text up to it, with whitespace collapsed, is at least `minChars` long;
 * otherwise the next boundary is tried.
 *
 * @param {string} text - Decoded text streamed so far.
 * @param {number} [targetCount=1] - Number of sentences to include.
 * @param {number} [minChars=12] - Minimum whitespace-collapsed length of the
 *   prefix ending at the returned boundary.
 * @returns {number} Index in `text` just past the closing punctuation/quote
 *   (trailing whitespace excluded), or -1 if no qualifying boundary exists yet.
 */
function sentenceBoundary(text, targetCount = 1, minChars = 12) {
  const normalizedLength = (s) => s.replace(/\s+/g, " ").trim().length;
  // Cheap early exit: no prefix can qualify until the whole text is long enough.
  if (normalizedLength(text) < minChars) return -1;

  // matchAll works on an internal copy of the regex, so no lastIndex state
  // leaks between calls.
  const sentenceEnd = /[.!?]["')\]]?(?:\s|$)/g;
  let count = 0;
  for (const match of text.matchAll(sentenceEnd)) {
    count += 1;
    if (count < targetCount) continue;
    const end = match.index + match[0].trimEnd().length;
    // Check the prefix, not the whole text. Otherwise a short opener that was
    // already streamed past (while the text was under minChars) would be
    // picked later, and the caller would stop the reply mid-sentence.
    if (normalizedLength(text.slice(0, end)) >= minChars) return end;
  }
  return -1;
}
/**
 * Prepare chat messages for the tokenizer.
 *
 * When the tokenizer exposes `apply_chat_template`, the messages are returned
 * unchanged (same array). Otherwise they are flattened into a single user
 * message with one "role: content" line per input message.
 *
 * NOTE(review): generate() calls tokenizer.apply_chat_template right after
 * this, so the fallback branch cannot rescue it — confirm whether the
 * fallback is still needed.
 *
 * @param {Array<{role: string, content: string}>} messages
 * @returns {Array<{role: string, content: string}>}
 */
function normalizeMessages(messages) {
  const hasChatTemplate = typeof tokenizer.apply_chat_template === "function";
  if (hasChatTemplate) {
    return messages;
  }
  const transcript = messages
    .map(({ role, content }) => `${role}: ${content}`)
    .join("\n");
  return [{ role: "user", content: transcript }];
}
/**
 * Best-effort count of prompt tokens in tokenizer output.
 *
 * Accepts the several shapes `input_ids` may take:
 * - a nested JS array (length of the first row);
 * - a flat JS array (its length);
 * - an object with a non-empty `dims` array (its last entry);
 * - otherwise the first finite number among `size`, `length`, `data.length`.
 *
 * @param {{input_ids?: *}|null|undefined} inputs
 * @returns {number|null} Token count, or null when it cannot be determined.
 */
function inputTokenCount(inputs) {
  const ids = inputs?.input_ids;
  if (!ids) return null;

  if (Array.isArray(ids)) {
    const firstRow = ids[0];
    return Array.isArray(firstRow) ? firstRow.length : ids.length;
  }

  const { dims } = ids;
  if (Array.isArray(dims) && dims.length > 0) return dims[dims.length - 1];

  // Fall back through looser size hints, in priority order; each is read only
  // if the previous one was not a finite number.
  const probes = [() => ids.size, () => ids.length, () => ids.data?.length];
  for (const probe of probes) {
    const value = probe();
    if (Number.isFinite(value)) return value;
  }
  return null;
}