| import { |
| pipeline, |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| type FeatureExtractionPipeline, |
| type PreTrainedTokenizer, |
| type PreTrainedModel, |
| type ProgressInfo, |
| } from "@huggingface/transformers"; |
| import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants"; |
| import type { ModelState } from "../types"; |
|
|
/** Consumer of the status updates emitted while models are fetched and initialised. */
type ProgressCallback = (state: ModelState) => void;

// ---- Embedding model --------------------------------------------------------
// The load promise is non-null only while a load is in flight; the loader
// clears it again in its `finally`.
let embeddingLoadPromise: Promise<void> | null = null;
let embeddingPipeline: FeatureExtractionPipeline | null = null;

// ---- Reranker ---------------------------------------------------------------
// Token ids hold the -1 sentinel until a reranker load assigns them.
let rerankerLoadPromise: Promise<void> | null = null;
let rerankerTokenizer: PreTrainedTokenizer | null = null;
let rerankerModel: PreTrainedModel | null = null;
let rerankerTokenYes = -1;
let rerankerTokenNo = -1;

// ---- Expansion model --------------------------------------------------------
let expansionLoadPromise: Promise<void> | null = null;
let expansionTokenizer: PreTrainedTokenizer | null = null;
let expansionModel: PreTrainedModel | null = null;
|
|
| |
/**
 * Report whether WebGPU is usable in the current environment.
 *
 * Resolves `false` (and never rejects) in these cases:
 * - there is no `navigator` global (e.g. server-side rendering or runtimes
 *   without one);
 * - `navigator.gpu` is missing;
 * - `requestAdapter()` resolves to `null`;
 * - `requestAdapter()` throws.
 *
 * @returns `true` only when a GPU adapter could be obtained.
 */
export async function checkWebGPU(): Promise<boolean> {
  // Outside browsers/workers `navigator` may not exist at all; touching it
  // unguarded would throw a ReferenceError and reject this promise instead
  // of reporting "no WebGPU".
  if (typeof navigator === "undefined" || !navigator.gpu) return false;
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter !== null;
  } catch {
    return false;
  }
}
|
|
| |
| |
| |
|
|
/**
 * Build a transformers.js `progress_callback` that turns raw download events
 * into throttled `ModelState` updates for `modelName`.
 *
 * A single load can fetch several files, and callers in this module share
 * one handler across the tokenizer and model `from_pretrained` calls. Each
 * file reports its own 0-100 progress. The reported value is therefore a
 * byte-weighted aggregate over every file seen so far. Tracking only the
 * latest event would pin the value at 1 as soon as a small file finished,
 * which hides the progress of the larger downloads that follow.
 *
 * A "downloading" update is scheduled in two cases:
 * - the aggregate has moved by at least 0.02 in either direction (a newly
 *   discovered file lowers it);
 * - the aggregate first reaches 1.
 * Updates are debounced by 50 ms so bursts collapse into one callback.
 * "done" and "ready" events cancel any pending debounced update.
 *
 * NOTE(review): a debounced update that is still pending when a caller emits
 * its own final "ready"/"error" state can fire after it; callers have no way
 * to cancel it.
 *
 * @param modelName - name copied into every emitted state.
 * @param onProgress - consumer of the states; without it no handler is built.
 * @returns the handler, or `undefined` when `onProgress` is not provided.
 */
function makeProgressHandler(
  modelName: string,
  onProgress?: ProgressCallback,
): ((info: ProgressInfo) => void) | undefined {
  if (!onProgress) return undefined;

  // Bytes loaded / expected, keyed by file. Assumes progress events carry
  // `file`, `loaded` and `total` (as in transformers.js v3); events without
  // byte counts fall back to their 0-100 `progress` value.
  const files = new Map<string, { loaded: number; total: number }>();
  let lastReported = -1;
  let debounceTimer: ReturnType<typeof setTimeout> | null = null;

  // Drop any scheduled-but-not-yet-delivered "downloading" update.
  const cancelPending = (): void => {
    if (debounceTimer !== null) {
      clearTimeout(debounceTimer);
      debounceTimer = null;
    }
  };

  // Overall fraction in [0, 1] across every file that has reported so far.
  const aggregate = (): number => {
    let loaded = 0;
    let total = 0;
    for (const entry of files.values()) {
      loaded += entry.loaded;
      total += entry.total;
    }
    return total > 0 ? Math.min(loaded / total, 1) : 0;
  };

  return (info: ProgressInfo) => {
    const file = "file" in info ? String(info.file) : "";
    switch (info.status) {
      case "initiate":
        // Announce the download once, on the first file of the load.
        if (lastReported < 0) {
          lastReported = 0;
          onProgress({ name: modelName, status: "downloading", progress: 0 });
        }
        break;
      case "download":
        break;
      case "progress": {
        const { progress, loaded, total } = info as {
          progress: number;
          loaded?: number;
          total?: number;
        };
        files.set(
          file,
          typeof loaded === "number" && typeof total === "number" && total > 0
            ? { loaded, total }
            : { loaded: progress, total: 100 },
        );
        const p = aggregate();
        if (Math.abs(p - lastReported) >= 0.02 || (p >= 1 && lastReported < 1)) {
          lastReported = p;
          cancelPending();
          debounceTimer = setTimeout(() => {
            debounceTimer = null;
            onProgress({ name: modelName, status: "downloading", progress: p });
          }, 50);
        }
        break;
      }
      case "done": {
        // Count a finished file as fully loaded even if its last progress
        // event fell short of 100%.
        const entry = files.get(file);
        if (entry) entry.loaded = entry.total;
        cancelPending();
        onProgress({ name: modelName, status: "loading", progress: 1 });
        break;
      }
      case "ready":
        cancelPending();
        lastReported = 1;
        onProgress({ name: modelName, status: "ready", progress: 1 });
        break;
    }
  };
}
|
|
| |
| |
| |
|
|
/**
 * Load the feature-extraction (embedding) pipeline on WebGPU with q4 weights.
 *
 * Calls after a successful load are no-ops. While a load is in flight,
 * further callers await that same load, and their `onProgress` is not
 * attached to it.
 *
 * @param onProgress - optional status callback. It receives "pending" first,
 *   then the download/loading updates, then "ready" or "error".
 * @throws rethrows whatever `pipeline()` rejects with, after reporting it as
 *   an "error" state.
 */
export async function loadEmbeddingModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (embeddingPipeline) return;
  if (embeddingLoadPromise) return await embeddingLoadPromise;

  const modelName = "embedding";
  const describe = (err: unknown): string =>
    err instanceof Error ? err.message : String(err);

  const load = async (): Promise<void> => {
    try {
      const handler = makeProgressHandler(modelName, onProgress);
      embeddingPipeline = await pipeline("feature-extraction", MODEL_EMBEDDING, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: handler,
      });
      onProgress?.({ name: modelName, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name: modelName,
        status: "error",
        progress: 0,
        error: describe(err),
      });
      throw err;
    } finally {
      // Cleared on success and failure alike, so a failed load can be retried.
      embeddingLoadPromise = null;
    }
  };

  onProgress?.({ name: modelName, status: "pending", progress: 0 });
  embeddingLoadPromise = load();
  return await embeddingLoadPromise;
}
|
|
/**
 * Load the reranker tokenizer and causal LM (q4 weights, WebGPU) and resolve
 * the token ids of "yes" and "no" exposed through `getRerankerTokenIds()`.
 *
 * Calls after a successful load are no-ops. While a load is in flight,
 * further callers await that same load, and their `onProgress` is not
 * attached to it.
 *
 * Everything is loaded into locals and published to module state only after
 * the model has loaded too. A failure part-way therefore never leaves the
 * tokenizer/token-id getters reporting values for a model that does not exist.
 *
 * @param onProgress - optional status callback: "pending", download/loading
 *   updates, then "ready" or "error".
 * @throws whatever the tokenizer or model loader rejects with, or an Error
 *   when "yes"/"no" tokenizes to no usable id; the failure is reported as an
 *   "error" state first.
 */
export async function loadRerankerModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (rerankerModel) return;
  if (rerankerLoadPromise) return await rerankerLoadPromise;
  const name = "reranker";
  onProgress?.({ name, status: "pending", progress: 0 });
  rerankerLoadPromise = (async () => {
    try {
      // One handler for both downloads so progress reporting is continuous.
      const progressHandler = makeProgressHandler(name, onProgress);

      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_RERANKER, {
        progress_callback: progressHandler,
      });

      // Take the last id in case a word tokenizes into several pieces. An
      // empty result would otherwise surface as a NaN id much later.
      const lastTokenId = (word: string): number => {
        const ids = tokenizer(word, { add_special_tokens: false }).input_ids.data;
        const id = ids.length > 0 ? Number(ids[ids.length - 1]) : Number.NaN;
        if (!Number.isInteger(id) || id < 0) {
          throw new Error(`Reranker tokenizer produced no token id for "${word}"`);
        }
        return id;
      };
      const tokenYes = lastTokenId("yes");
      const tokenNo = lastTokenId("no");

      const model = await AutoModelForCausalLM.from_pretrained(MODEL_RERANKER, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: progressHandler,
      });

      // Publish all reranker state together, only once every step succeeded.
      rerankerTokenizer = tokenizer;
      rerankerTokenYes = tokenYes;
      rerankerTokenNo = tokenNo;
      rerankerModel = model;

      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    } finally {
      // Cleared on success and failure alike, so a failed load can be retried.
      rerankerLoadPromise = null;
    }
  })();

  return await rerankerLoadPromise;
}
|
|
/**
 * Load the expansion tokenizer and causal LM (q4 weights, WebGPU).
 *
 * Calls after a successful load are no-ops. While a load is in flight,
 * further callers await that same load, and their `onProgress` is not
 * attached to it.
 *
 * The tokenizer and model are loaded into locals and published together,
 * only after both succeeded. A model failure therefore never leaves
 * `getExpansionTokenizer()` returning a tokenizer while
 * `getExpansionModel()` is `null`.
 *
 * @param onProgress - optional status callback: "pending", download/loading
 *   updates, then "ready" or "error".
 * @throws whatever the tokenizer or model loader rejects with, after
 *   reporting it as an "error" state.
 */
export async function loadExpansionModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (expansionModel) return;
  if (expansionLoadPromise) return await expansionLoadPromise;
  const name = "expansion";
  onProgress?.({ name, status: "pending", progress: 0 });
  expansionLoadPromise = (async () => {
    try {
      // One handler for both downloads so progress reporting is continuous.
      const progressHandler = makeProgressHandler(name, onProgress);

      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_EXPANSION, {
        progress_callback: progressHandler,
      });

      // If the tokenizer ships without a chat template, install a
      // ChatML-style one (<|im_start|>role ... <|im_end|>) so chat
      // formatting still works.
      if (!tokenizer.chat_template) {
        tokenizer.chat_template =
          "{% for message in messages %}" +
          "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n" +
          "{% endfor %}" +
          "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";
      }

      const model = await AutoModelForCausalLM.from_pretrained(MODEL_EXPANSION, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: progressHandler,
      });

      // Publish both together, only once every step succeeded.
      expansionTokenizer = tokenizer;
      expansionModel = model;

      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    } finally {
      // Cleared on success and failure alike, so a failed load can be retried.
      expansionLoadPromise = null;
    }
  })();

  return await expansionLoadPromise;
}
|
|
| |
/**
 * Settle with `promise`, or reject with `new Error(message)` if it has not
 * settled within `timeoutMs` milliseconds.
 *
 * The timer is cleared as soon as the race settles. Otherwise a
 * long timeout would linger (and keep a Node process alive) after the
 * wrapped promise had already finished.
 *
 * The wrapped promise is not cancelled; after a timeout it keeps running and
 * its eventual result is ignored.
 *
 * @param promise - the operation to bound.
 * @param timeoutMs - time limit in milliseconds.
 * @param message - message of the Error used for the timeout rejection.
 * @returns a promise with the same outcome as `promise` when it wins the race.
 */
function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  message: string,
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(() => reject(new Error(message)), timeoutMs);
  });
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
}
|
|
| |
| |
|
|
/**
 * Load every model this module manages.
 *
 * - Without WebGPU: reports an "error" state for all three models and throws.
 * - Otherwise starts the embedding, reranker and expansion loads concurrently,
 *   in that order.
 * - The expansion model is optional. A failure, or not finishing within
 *   120 s, is reported through `onProgress` but never rejects this call.
 *   The timeout does not cancel the underlying expansion load.
 * - Rejects with the embedding error if that load failed. Otherwise it
 *   rejects with the reranker error if that one failed.
 *
 * @param onProgress - optional callback forwarded to every loader.
 */
export async function loadAllModels(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (!(await checkWebGPU())) {
    const message = "WebGPU is not available in this browser";
    ["embedding", "reranker", "expansion"].forEach((name) => {
      onProgress?.({ name, status: "error", progress: 0, error: message });
    });
    throw new Error(message);
  }

  const EXPANSION_TIMEOUT_MS = 120_000;

  // Kick the loads off in a fixed order: embedding, reranker, expansion.
  const embedding = loadEmbeddingModel(onProgress);
  const reranker = loadRerankerModel(onProgress);
  // Optional model: report (but swallow) any failure, including the timeout.
  const expansion = withTimeout(
    loadExpansionModel(onProgress),
    EXPANSION_TIMEOUT_MS,
    "Expansion model timed out",
  ).catch((err) => {
    onProgress?.({
      name: "expansion",
      status: "error",
      progress: 0,
      error: err instanceof Error ? err.message : "Failed to load expansion model",
    });
  });

  const outcomes = await Promise.allSettled([embedding, reranker, expansion]);

  // Required models, in priority order: the first failure is the one thrown.
  for (const outcome of outcomes.slice(0, 2)) {
    if (outcome.status === "rejected") throw outcome.reason;
  }
}
|
|
| |
| |
| |
|
|
/** The loaded embedding pipeline, or `null` until `loadEmbeddingModel()` has succeeded. */
export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
  const current = embeddingPipeline;
  return current;
}
|
|
/** The loaded reranker causal LM, or `null` until a reranker load has succeeded. */
export function getRerankerModel(): PreTrainedModel | null {
  const current = rerankerModel;
  return current;
}
|
|
/** The reranker tokenizer as currently stored in module state, or `null` if none is set. */
export function getRerankerTokenizer(): PreTrainedTokenizer | null {
  const current = rerankerTokenizer;
  return current;
}
|
|
/**
 * Token ids of "yes" and "no" for the reranker, returned as a fresh object.
 * Both are -1 until a reranker load has assigned them.
 */
export function getRerankerTokenIds(): { yes: number; no: number } {
  const yes = rerankerTokenYes;
  const no = rerankerTokenNo;
  return { yes, no };
}
|
|
/** The loaded expansion causal LM, or `null` until an expansion load has succeeded. */
export function getExpansionModel(): PreTrainedModel | null {
  const current = expansionModel;
  return current;
}
|
|
/** The expansion tokenizer as currently stored in module state, or `null` if none is set. */
export function getExpansionTokenizer(): PreTrainedTokenizer | null {
  const current = expansionTokenizer;
  return current;
}
|
|
/**
 * Whether the required models (embedding pipeline and reranker model) are
 * loaded. The expansion model is not part of this check.
 */
export function isAllModelsReady(): boolean {
  const missingRequired = embeddingPipeline === null || rerankerModel === null;
  return !missingRequired;
}
|
|
/** Whether the optional expansion model has been loaded. */
export function isExpansionReady(): boolean {
  const missing = expansionModel === null;
  return !missing;
}
|
|