// supra-50m-instruct / worker.js
// Author: av-codes
// Commit 875ab20 (verified): "Static space with in-browser Transformers.js inference"
import {
env,
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3";
// Resolve model files from the Hugging Face Hub only; never probe for a local copy.
env.allowLocalModels = false;

/** Hub repository of the ONNX export that this worker loads. */
const MODEL_ID = "av-codes/Supra-50M-Instruct-ONNX";

/**
 * One stopping-criteria instance shared by every generate() call: it is reset
 * when a run starts and interrupted when a "stop" message arrives.
 */
const stopping = new InterruptableStoppingCriteria();

// Assigned by load(); each stays null until its download has completed.
let tokenizer = null;
let model = null;

// True while a generate() call is running; overlapping requests are ignored.
let generating = false;
/**
 * Wraps a user instruction in the Alpaca-style instruction/response template
 * (presumably the format the model was fine-tuned on — confirm against the
 * model card). The returned prompt ends right after "### Response:\n" so the
 * model continues with the answer.
 *
 * @param {string} instruction - Free-form user request.
 * @returns {string} The complete prompt text.
 */
function formatPrompt(instruction) {
  const preamble =
    "Below is an instruction that describes a task. " +
    "Write a response that appropriately completes the request.";
  return `${preamble}\n\n### Instruction:\n${instruction}\n\n### Response:\n`;
}
/**
 * Downloads the tokenizer and the q8 model variant for MODEL_ID, reporting to
 * the main thread with:
 *   { type: "status", message }          before each download step,
 *   { type: "progress", percent, file }  for each "progress" callback event,
 *   { type: "ready" }                    once both have loaded.
 *
 * Any failure is reported as { type: "error", message } instead of rejecting.
 * The message handler does not await this function, so a rejection would
 * otherwise be unhandled and the UI would never learn why "ready" never came.
 */
async function load() {
  try {
    self.postMessage({ type: "status", message: "Loading tokenizer..." });
    tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID);

    self.postMessage({ type: "status", message: "Loading model (50 MB)..." });
    model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
      dtype: "q8",
      // Forward only events whose status is "progress"; other statuses are dropped.
      progress_callback: (progress) => {
        if (progress.status === "progress") {
          self.postMessage({
            type: "progress",
            percent: progress.progress,
            file: progress.file,
          });
        }
      },
    });

    self.postMessage({ type: "ready" });
  } catch (e) {
    self.postMessage({ type: "error", message: e?.message ?? String(e) });
  }
}
/**
 * Generates a response to `instruction`, streaming decoded text to the main
 * thread as { type: "token", text } messages and always finishing with
 * { type: "done" }. A failure is reported as { type: "error", message }
 * (still followed by "done").
 *
 * The call is ignored if the model/tokenizer are not loaded yet or another
 * generation is already running.
 *
 * @param {string} instruction - User request inserted into the prompt template.
 * @param {object} [params] - Optional overrides: max_new_tokens (256),
 *   temperature (0.7), top_k (50), top_p (0.9), repetition_penalty (1.15).
 *   A temperature <= 0 selects greedy decoding (do_sample = false).
 */
async function generate(instruction, params) {
  if (!model || !tokenizer || generating) return;
  generating = true;
  try {
    // Everything that can throw lives inside the try, so `generating` is
    // always cleared in `finally`. Otherwise one failure would block every
    // later request.
    stopping.reset();

    const {
      max_new_tokens,
      temperature: requestedTemperature,
      top_k,
      top_p,
      repetition_penalty,
    } = params ?? {};
    // Derive do_sample from the *effective* temperature. Previously an omitted
    // temperature defaulted to 0.7 but still disabled sampling.
    const temperature = requestedTemperature ?? 0.7;
    const doSample = temperature > 0;

    const inputs = tokenizer(formatPrompt(instruction));
    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (text) => {
        self.postMessage({ type: "token", text });
      },
    });

    await model.generate({
      ...inputs,
      // For these knobs 0 is not a useful setting, so any falsy value falls
      // back to the default.
      max_new_tokens: max_new_tokens || 256,
      // Temperature only matters when sampling. For greedy decoding pass the
      // neutral 1.0 so a zero/negative value never reaches the logits processors.
      temperature: doSample ? temperature : 1.0,
      top_k: top_k || 50,
      top_p: top_p || 0.9,
      repetition_penalty: repetition_penalty || 1.15,
      do_sample: doSample,
      streamer,
      stopping_criteria: [stopping],
    });
  } catch (e) {
    self.postMessage({ type: "error", message: e?.message ?? String(e) });
  } finally {
    generating = false;
    self.postMessage({ type: "done" });
  }
}
/**
 * Dispatches commands from the main thread:
 *   { type: "load" }                          -> load()
 *   { type: "generate", instruction, params } -> generate(instruction, params)
 *   { type: "stop" }                          -> interrupts the running generation
 * Other message types are ignored. The async handlers are not awaited, so any
 * rejection is reported as { type: "error", message } rather than becoming an
 * unhandled promise rejection.
 */
self.onmessage = (e) => {
  // Tolerate empty/null payloads instead of throwing while destructuring.
  const { type, instruction, params } = e.data ?? {};
  const reportError = (err) => {
    self.postMessage({ type: "error", message: err?.message ?? String(err) });
  };

  switch (type) {
    case "load":
      load().catch(reportError);
      break;
    case "generate":
      generate(instruction, params).catch(reportError);
      break;
    case "stop":
      stopping.interrupt();
      break;
    default:
      // Unrecognised messages are ignored.
      break;
  }
};