File size: 3,935 Bytes
f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 f500247 e9952d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import {
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
} from "@huggingface/transformers";
/**
 * Feature-detect WebGPU and report the result to the main thread.
 *
 * Nothing is posted when an adapter is available. When WebGPU is missing
 * (no `navigator.gpu`, or `requestAdapter()` yields no adapter) or the
 * request itself fails, a `{ status: "error", data }` message is posted,
 * where `data` is the stringified error.
 *
 * @returns {Promise<void>} Always resolves; failures are reported via postMessage.
 */
async function check() {
  try {
    // Browsers without WebGPU do not expose `navigator.gpu` at all; test for
    // it explicitly so the user sees a meaningful message instead of a raw
    // "Cannot read properties of undefined" TypeError.
    const gpu = globalThis.navigator?.gpu;
    if (!gpu) {
      throw new Error("WebGPU is not supported (navigator.gpu is unavailable)");
    }

    const adapter = await gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: String(e),
    });
  }
}
/**
 * Singleton holder for the tokenizer and model.
 *
 * The in-flight load promises are cached on the class so that concurrent and
 * later callers share a single download. A load that fails is evicted from
 * the cache, so a subsequent `getInstance()` call retries instead of
 * re-throwing the stale rejection forever.
 */
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";
  static tokenizer = null;
  static model = null;

  /**
   * Get (and, if necessary, start loading) the tokenizer and model.
   *
   * @param {Function|null} [progress_callback=null] Forwarded to
   *   `from_pretrained`. Only the call that actually starts a load has its
   *   callback used; calls served from the cache do not register a new one.
   * @returns {Promise<[object, object]>} Resolves to `[tokenizer, model]`;
   *   rejects if either load fails.
   */
  static async getInstance(progress_callback = null) {
    // Return the cached load for `slot`, starting one if none is cached.
    const cacheLoad = (slot, startLoading) => {
      if (this[slot] == null) {
        const pending = startLoading();
        this[slot] = pending;
        // Evict on failure. The identity check avoids clobbering a newer
        // load that may have replaced this one in the meantime.
        Promise.resolve(pending).catch(() => {
          if (this[slot] === pending) {
            this[slot] = null;
          }
        });
      }
      return this[slot];
    };

    const tokenizer = cacheLoad("tokenizer", () =>
      AutoTokenizer.from_pretrained(this.model_id, {
        progress_callback,
      }),
    );
    const model = cacheLoad("model", () =>
      AutoModelForCausalLM.from_pretrained(this.model_id, {
        dtype: "q4f16",
        device: "webgpu",
        progress_callback,
      }),
    );

    return Promise.all([tokenizer, model]);
  }
}
// Shared stopping criteria passed to model.generate() in generate(). The
// message handler calls interrupt() on an "interrupt" message and reset()
// before each "generate" and on "reset".
const stopping_criteria = new InterruptableStoppingCriteria();
// past_key_values returned by the most recent model.generate() call in
// generate(); fed back in as `past_key_values` on the next call and cleared
// (set to null) by a "reset" message.
let past_key_values_cache = null;
/**
 * Run one chat completion and stream the result to the main thread.
 *
 * Messages posted, in order:
 *   - `{ status: "start" }` just before generation begins,
 *   - `{ status: "update", output, tps, numTokens }` for each streamed chunk,
 *   - `{ status: "complete", output }` with the decoded sequences.
 * If anything fails, `{ status: "error", data }` is posted instead (same shape
 * as the error reported by `check()`), so the UI is not left waiting on an
 * unhandled rejection.
 *
 * @param {Array<object>} messages Chat messages to answer; they are passed to
 *   the tokenizer's chat template after the system prompt. Presumably
 *   `{ role, content }` objects — confirm against the caller.
 * @returns {Promise<void>} Always resolves; failures are reported via postMessage.
 */
async function generate(messages) {
  // --- MEDICAL ASSISTANT CONFIGURATION ---
  const systemPrompt = {
    role: "system",
    content: "You are a professional Emergency Room (ER) Medical Assistant. Your role is to assist physicians by providing triage classifications, suggesting immediate clinical actions based on ABCDE/ACLS protocols, and summarizing patient history using medical terminology. Be concise, objective, and prioritize life-threatening conditions. Always maintain a clinical tone."
  };
  // ---------------------------------------

  try {
    // Inject the system prompt at the beginning of the conversation
    const messagesWithSystem = [systemPrompt, ...messages];

    const [tokenizer, model] = await TextGenerationPipeline.getInstance();
    const inputs = tokenizer.apply_chat_template(messagesWithSystem, {
      add_generation_prompt: true,
      return_dict: true,
    });

    // Throughput bookkeeping: the clock starts at the first generated token,
    // and tokens/second is only reported from the second token onwards.
    let startTime;
    let numTokens = 0;
    let tps;
    const token_callback_function = () => {
      startTime ??= performance.now();
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    };
    const callback_function = (output) => {
      self.postMessage({
        status: "update",
        output,
        tps,
        numTokens,
      });
    };
    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function,
      token_callback_function,
    });

    self.postMessage({ status: "start" });

    const { past_key_values, sequences } = await model.generate({
      ...inputs,
      past_key_values: past_key_values_cache,
      max_new_tokens: 1024,
      streamer,
      stopping_criteria,
      return_dict_in_generate: true,
    });
    // Keep the KV cache so the next turn can reuse it.
    past_key_values_cache = past_key_values;

    const decoded = tokenizer.batch_decode(sequences, {
      skip_special_tokens: true,
    });
    self.postMessage({
      status: "complete",
      output: decoded,
    });
  } catch (e) {
    // NOTE(review): after a failed generation the previous KV cache may no
    // longer be usable (the library may have released it) — drop it so the
    // next request starts from a clean state. This only costs recomputation.
    past_key_values_cache = null;
    self.postMessage({
      status: "error",
      data: String(e),
    });
  }
}
/**
 * Load the tokenizer and model, then warm the model up with a one-token
 * generation, reporting progress to the main thread.
 *
 * Posts `{ status: "loading", data }` before loading and before warm-up,
 * forwards every progress event from `getInstance` unchanged, and posts
 * `{ status: "ready" }` on success. If loading or warm-up fails,
 * `{ status: "error", data }` is posted (same shape as in `check()`) so the
 * UI does not stay stuck in the loading state.
 *
 * @returns {Promise<void>} Always resolves; failures are reported via postMessage.
 */
async function load() {
  try {
    self.postMessage({
      status: "loading",
      data: "Loading medical assistant model...",
    });

    const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
      // Forward loader progress events to the main thread as-is.
      self.postMessage(x);
    });

    self.postMessage({
      status: "loading",
      data: "Compiling shaders and warming up...",
    });

    // Run the model on a dummy input, generating a single token, before
    // announcing readiness.
    const inputs = tokenizer("a");
    await model.generate({ ...inputs, max_new_tokens: 1 });

    self.postMessage({ status: "ready" });
  } catch (e) {
    self.postMessage({
      status: "error",
      data: String(e),
    });
  }
}
// Route each message from the main thread to its handler, keyed by `type`.
// Messages with an unrecognised type are ignored.
self.addEventListener("message", async (event) => {
  const { type, data } = event.data;

  const handlers = new Map([
    [
      "check",
      () => {
        check();
      },
    ],
    [
      "load",
      () => {
        load();
      },
    ],
    [
      "generate",
      () => {
        // Clear any earlier interrupt before starting a new generation.
        stopping_criteria.reset();
        generate(data);
      },
    ],
    [
      "interrupt",
      () => {
        stopping_criteria.interrupt();
      },
    ],
    [
      "reset",
      () => {
        // Forget the cached key/values and clear any pending interrupt.
        past_key_values_cache = null;
        stopping_criteria.reset();
      },
    ],
  ]);

  handlers.get(type)?.();
});