import {
  pipeline,
  AutoTokenizer,
  AutoModelForCausalLM,
  type FeatureExtractionPipeline,
  type PreTrainedTokenizer,
  type PreTrainedModel,
  type ProgressInfo,
} from "@huggingface/transformers";
import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
import type { ModelState } from "../types";

/** Callback that receives a `ModelState` update while a model loads. */
type ProgressCallback = (state: ModelState) => void;

// Singleton model instances
let embeddingPipeline: FeatureExtractionPipeline | null = null;
let embeddingLoadPromise: Promise<void> | null = null;

// Reranker uses AutoModel + AutoTokenizer (not a pipeline)
let rerankerModel: PreTrainedModel | null = null;
let rerankerTokenizer: PreTrainedTokenizer | null = null;
let rerankerTokenYes = -1;
let rerankerTokenNo = -1;
let rerankerLoadPromise: Promise<void> | null = null;

// Expansion uses AutoModel + AutoTokenizer (model was exported without KV cache)
let expansionModel: PreTrainedModel | null = null;
let expansionTokenizer: PreTrainedTokenizer | null = null;
let expansionLoadPromise: Promise<void> | null = null;

/**
 * Check whether WebGPU is available in this browser.
 *
 * @returns `true` when `navigator.gpu` exists and `requestAdapter()` yields a
 *   non-null adapter; `false` otherwise, including when `navigator` itself is
 *   not defined or the adapter request throws.
 */
export async function checkWebGPU(): Promise<boolean> {
  // `navigator` is not defined outside a browser-like environment; report
  // "unavailable" instead of throwing a ReferenceError.
  if (typeof navigator === "undefined" || !navigator.gpu) return false;
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter !== null;
  } catch {
    return false;
  }
}

// ---------------------------------------------------------------------------
// Internal: translate Transformers.js ProgressInfo → our ModelState
// ---------------------------------------------------------------------------

/**
 * Build a `progress_callback` that converts Transformers.js progress events
 * into `ModelState` updates for `modelName`.
 *
 * Event mapping:
 * - `"initiate"`: emits `downloading` at 0 once; later initiations are ignored.
 * - `"download"`: ignored (progress events follow).
 * - `"progress"`: emits `downloading` after a 50 ms delay, but only when the
 *   fraction advanced by at least 0.02 past the last recorded value or
 *   reached 1. Each accepted event replaces the previously scheduled emit.
 * - `"done"`: cancels any scheduled emit and emits `loading` at 1.
 * - `"ready"`: cancels any scheduled emit and emits `ready` at 1.
 *
 * @param modelName Name placed in every emitted `ModelState`.
 * @param onProgress Consumer of the updates.
 * @returns The handler, or `undefined` when no `onProgress` was supplied.
 */
function makeProgressHandler(
  modelName: string,
  onProgress?: ProgressCallback,
): ((info: ProgressInfo) => void) | undefined {
  if (!onProgress) return undefined;

  // Capture in a const so the non-undefined narrowing holds inside closures.
  const report: ProgressCallback = onProgress;

  // Debounce progress updates to avoid flickering (multiple files fire rapidly)
  let lastProgress = -1;
  let debounceTimer: ReturnType<typeof setTimeout> | null = null;

  const cancelScheduledUpdate = (): void => {
    if (debounceTimer !== null) {
      clearTimeout(debounceTimer);
      debounceTimer = null;
    }
  };

  return (info: ProgressInfo) => {
    switch (info.status) {
      case "initiate":
        // Only fire once at start, ignore subsequent file initiations
        if (lastProgress < 0) {
          lastProgress = 0;
          report({ name: modelName, status: "downloading", progress: 0 });
        }
        break;

      case "download":
        break; // skip, we'll get progress events

      case "progress": {
        const p = (info as { progress: number }).progress / 100;
        // Only update if progress moved forward by at least 2% or hit 100%
        if (p - lastProgress >= 0.02 || p >= 1) {
          lastProgress = p;
          cancelScheduledUpdate();
          debounceTimer = setTimeout(() => {
            report({ name: modelName, status: "downloading", progress: p });
          }, 50);
        }
        break;
      }

      case "done":
        cancelScheduledUpdate();
        report({ name: modelName, status: "loading", progress: 1 });
        break;

      case "ready":
        cancelScheduledUpdate();
        lastProgress = 1;
        report({ name: modelName, status: "ready", progress: 1 });
        break;
    }
  };
}

// ---------------------------------------------------------------------------
// Individual model loaders
// ---------------------------------------------------------------------------

/**
 * Load the embedding pipeline (q4, WebGPU) into the module-level singleton.
 *
 * Returns immediately when the pipeline is already loaded. When a load is
 * already in flight, awaits that load instead of starting another; in that
 * case this call's `onProgress` is not invoked.
 *
 * @param onProgress Optional receiver of `pending` / download / `ready` /
 *   `error` updates for the `"embedding"` model.
 * @throws Rethrows whatever the underlying load rejected with, after
 *   reporting an `error` state.
 */
export async function loadEmbeddingModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (embeddingPipeline) return;
  if (embeddingLoadPromise) return await embeddingLoadPromise;

  const name = "embedding";
  onProgress?.({ name, status: "pending", progress: 0 });

  embeddingLoadPromise = (async () => {
    try {
      embeddingPipeline = await pipeline("feature-extraction", MODEL_EMBEDDING, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: makeProgressHandler(name, onProgress),
      });
      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    } finally {
      // Allow a later call to retry after a failure.
      embeddingLoadPromise = null;
    }
  })();

  return await embeddingLoadPromise;
}

/**
 * Load the reranker tokenizer and causal-LM model (q4, WebGPU) into the
 * module-level singletons, and pre-compute the "yes" / "no" token IDs.
 *
 * Returns immediately when the model is already loaded. When a load is
 * already in flight, awaits that load instead of starting another; in that
 * case this call's `onProgress` is not invoked.
 *
 * The tokenizer, token IDs and model are published together only after every
 * step succeeded, so a failed load never leaves a tokenizer without a model.
 *
 * @param onProgress Optional receiver of `pending` / download / `ready` /
 *   `error` updates for the `"reranker"` model.
 * @throws Rethrows whatever the underlying load rejected with, after
 *   reporting an `error` state.
 */
export async function loadRerankerModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (rerankerModel) return;
  if (rerankerLoadPromise) return await rerankerLoadPromise;

  const name = "reranker";
  onProgress?.({ name, status: "pending", progress: 0 });

  rerankerLoadPromise = (async () => {
    try {
      const progressHandler = makeProgressHandler(name, onProgress);

      // Load tokenizer and model separately (cross-encoder pattern)
      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_RERANKER, {
        progress_callback: progressHandler,
      });

      // Pre-compute "yes" and "no" token IDs for scoring
      const yesIds = tokenizer("yes", { add_special_tokens: false }).input_ids.data;
      const noIds = tokenizer("no", { add_special_tokens: false }).input_ids.data;

      const model = await AutoModelForCausalLM.from_pretrained(MODEL_RERANKER, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: progressHandler,
      });

      // Commit all reranker state at once, only after everything loaded.
      rerankerTokenizer = tokenizer;
      rerankerTokenYes = Number(yesIds[yesIds.length - 1]);
      rerankerTokenNo = Number(noIds[noIds.length - 1]);
      rerankerModel = model;

      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    } finally {
      // Allow a later call to retry after a failure.
      rerankerLoadPromise = null;
    }
  })();

  return await rerankerLoadPromise;
}

/**
 * Load the query-expansion tokenizer and causal-LM model (q4, WebGPU) into
 * the module-level singletons.
 *
 * Returns immediately when the model is already loaded. When a load is
 * already in flight, awaits that load instead of starting another; in that
 * case this call's `onProgress` is not invoked.
 *
 * The tokenizer and model are published together only after both loaded, so
 * a failed load never leaves a tokenizer without a model.
 *
 * @param onProgress Optional receiver of `pending` / download / `ready` /
 *   `error` updates for the `"expansion"` model.
 * @throws Rethrows whatever the underlying load rejected with, after
 *   reporting an `error` state.
 */
export async function loadExpansionModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (expansionModel) return;
  if (expansionLoadPromise) return await expansionLoadPromise;

  const name = "expansion";
  onProgress?.({ name, status: "pending", progress: 0 });

  expansionLoadPromise = (async () => {
    try {
      const progressHandler = makeProgressHandler(name, onProgress);

      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_EXPANSION, {
        progress_callback: progressHandler,
      });

      // The HF repo has chat_template.jinja but it's not in tokenizer_config.json,
      // so set the Qwen ChatML template manually.
      if (!tokenizer.chat_template) {
        tokenizer.chat_template =
          "{% for message in messages %}" +
          "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n" +
          "{% endfor %}" +
          "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";
      }

      const model = await AutoModelForCausalLM.from_pretrained(MODEL_EXPANSION, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: progressHandler,
      });

      // Commit tokenizer and model together, only after both loaded.
      expansionTokenizer = tokenizer;
      expansionModel = model;

      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      onProgress?.({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    } finally {
      // Allow a later call to retry after a failure.
      expansionLoadPromise = null;
    }
  })();

  return await expansionLoadPromise;
}

// ---------------------------------------------------------------------------
// Internal: timeout helper
// ---------------------------------------------------------------------------

/**
 * Race `promise` against a timer.
 *
 * The wrapped operation is not cancelled on timeout; only the returned
 * promise rejects. The timer is cleared as soon as the race settles so it
 * does not stay pending after `promise` has already finished.
 *
 * @param promise Operation to wait for.
 * @param timeoutMs Milliseconds to wait before rejecting.
 * @param message Message of the `Error` used for the timeout rejection.
 * @returns A promise that settles like `promise`, or rejects with
 *   `new Error(message)` if `timeoutMs` elapses first.
 */
function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  message: string,
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(() => reject(new Error(message)), timeoutMs);
  });
  return Promise.race([promise, timeout]).finally(() => {
    clearTimeout(timer);
  });
}

// ---------------------------------------------------------------------------
// Load all models in parallel
// ---------------------------------------------------------------------------

/**
 * Load the embedding, reranker and expansion models concurrently.
 *
 * Embedding and reranker are required: if either fails, its rejection reason
 * is rethrown (embedding checked first). Expansion is optional: a failure or
 * a two-minute timeout is reported through `onProgress` as an `error` state
 * and otherwise ignored.
 *
 * @param onProgress Optional receiver of per-model `ModelState` updates.
 * @throws `Error("WebGPU is not available in this browser")` when
 *   `checkWebGPU()` resolves to `false` (after reporting an `error` state for
 *   all three models), or the rejection reason of the embedding / reranker
 *   load.
 */
export async function loadAllModels(
  onProgress?: ProgressCallback,
): Promise<void> {
  const hasWebGPU = await checkWebGPU();
  if (!hasWebGPU) {
    const err = "WebGPU is not available in this browser";
    for (const name of ["embedding", "reranker", "expansion"]) {
      onProgress?.({ name, status: "error", progress: 0, error: err });
    }
    throw new Error(err);
  }

  const EXPANSION_TIMEOUT_MS = 120_000; // 2 minutes

  // The third entry always fulfils because its rejection is handled in the
  // `.catch` below, so only the first two results need inspecting.
  const [embeddingResult, rerankerResult] = await Promise.allSettled([
    loadEmbeddingModel(onProgress),
    loadRerankerModel(onProgress),
    withTimeout(
      loadExpansionModel(onProgress),
      EXPANSION_TIMEOUT_MS,
      "Expansion model timed out",
    ).catch((err) => {
      onProgress?.({
        name: "expansion",
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : "Failed to load expansion model",
      });
      // The pipeline will fall back to using the original query without expansion.
    }),
  ]);

  if (embeddingResult.status === "rejected") {
    throw embeddingResult.reason;
  }
  if (rerankerResult.status === "rejected") {
    throw rerankerResult.reason;
  }
}

// ---------------------------------------------------------------------------
// Getters
// ---------------------------------------------------------------------------

/** Returns the loaded embedding pipeline, or `null` if it is not loaded. */
export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
  return embeddingPipeline;
}

/** Returns the loaded reranker model, or `null` if it is not loaded. */
export function getRerankerModel(): PreTrainedModel | null {
  return rerankerModel;
}

/** Returns the loaded reranker tokenizer, or `null` if it is not loaded. */
export function getRerankerTokenizer(): PreTrainedTokenizer | null {
  return rerankerTokenizer;
}

/**
 * Returns the pre-computed "yes" / "no" token IDs used for reranker scoring.
 * Both are `-1` until the reranker has been loaded.
 */
export function getRerankerTokenIds(): { yes: number; no: number } {
  return { yes: rerankerTokenYes, no: rerankerTokenNo };
}

/** Returns the loaded expansion model, or `null` if it is not loaded. */
export function getExpansionModel(): PreTrainedModel | null {
  return expansionModel;
}

/** Returns the loaded expansion tokenizer, or `null` if it is not loaded. */
export function getExpansionTokenizer(): PreTrainedTokenizer | null {
  return expansionTokenizer;
}

/** Returns `true` once both required models (embedding and reranker) are loaded. */
export function isAllModelsReady(): boolean {
  // Embedding + reranker are required; expansion is optional
  return embeddingPipeline !== null && rerankerModel !== null;
}

/** Returns `true` once the optional expansion model is loaded. */
export function isExpansionReady(): boolean {
  return expansionModel !== null;
}