// qmd-web — src/pipeline/models.ts
// Model loading (embedding, reranker, expansion) for the qmd-web pipeline.
import {
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
type FeatureExtractionPipeline,
type PreTrainedTokenizer,
type PreTrainedModel,
type ProgressInfo,
} from "@huggingface/transformers";
import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
import type { ModelState } from "../types";
// Callback used by all loaders to report per-model load state to the UI.
type ProgressCallback = (state: ModelState) => void;
// Singleton model instances — each loader is idempotent: the model var is
// the "loaded" flag, and the *LoadPromise var de-duplicates in-flight loads.
let embeddingPipeline: FeatureExtractionPipeline | null = null;
let embeddingLoadPromise: Promise<void> | null = null;
// Reranker uses AutoModel + AutoTokenizer (not a pipeline)
let rerankerModel: PreTrainedModel | null = null;
let rerankerTokenizer: PreTrainedTokenizer | null = null;
// Token ids of the final "yes"/"no" tokens used for reranker scoring;
// -1 until loadRerankerModel() computes them.
let rerankerTokenYes = -1;
let rerankerTokenNo = -1;
let rerankerLoadPromise: Promise<void> | null = null;
// Expansion uses AutoModel + AutoTokenizer (model was exported without KV cache)
let expansionModel: PreTrainedModel | null = null;
let expansionTokenizer: PreTrainedTokenizer | null = null;
let expansionLoadPromise: Promise<void> | null = null;
/** Check whether WebGPU is available in this browser. */
/**
 * Detect WebGPU support by probing for a GPU adapter.
 *
 * @returns true when `navigator.gpu` exists and an adapter can be acquired;
 *   false otherwise (including when `requestAdapter` throws).
 */
export async function checkWebGPU(): Promise<boolean> {
  const gpu = navigator.gpu;
  if (!gpu) return false;
  try {
    return (await gpu.requestAdapter()) !== null;
  } catch {
    // Some browsers throw instead of returning null when GPU access is denied.
    return false;
  }
}
// ---------------------------------------------------------------------------
// Internal: translate Transformers.js ProgressInfo → our ModelState
// ---------------------------------------------------------------------------
/**
 * Build a Transformers.js `progress_callback` that forwards download/load
 * progress to our ModelState callback, throttled so the UI does not flicker.
 *
 * Throttling: a "downloading" update is emitted only when progress advances
 * by at least 2% (or reaches 100%), and each emission is deferred by 50ms so
 * bursts of per-file events collapse into a single update.
 *
 * @param modelName - value used as `ModelState.name` in every update
 * @param onProgress - consumer callback; when omitted, returns undefined so
 *   no handler is installed at all
 */
function makeProgressHandler(
  modelName: string,
  onProgress?: ProgressCallback,
): ((info: ProgressInfo) => void) | undefined {
  if (!onProgress) return undefined;
  // -1 means "no initiate event seen yet"; lets us emit the start update once.
  let reported = -1;
  let pending: ReturnType<typeof setTimeout> | null = null;
  const cancelPending = () => {
    if (pending) clearTimeout(pending);
  };
  return (info: ProgressInfo) => {
    if (info.status === "initiate") {
      // Multiple files initiate; only the first produces an update.
      if (reported < 0) {
        reported = 0;
        onProgress({ name: modelName, status: "downloading", progress: 0 });
      }
    } else if (info.status === "progress") {
      const fraction = (info as { progress: number }).progress / 100;
      // Skip tiny movements (<2%) unless we just hit 100%.
      if (fraction - reported >= 0.02 || fraction >= 1) {
        reported = fraction;
        cancelPending();
        pending = setTimeout(() => {
          onProgress({ name: modelName, status: "downloading", progress: fraction });
        }, 50);
      }
    } else if (info.status === "done") {
      // A file finished downloading; model weights still need to initialize.
      cancelPending();
      onProgress({ name: modelName, status: "loading", progress: 1 });
    } else if (info.status === "ready") {
      cancelPending();
      reported = 1;
      onProgress({ name: modelName, status: "ready", progress: 1 });
    }
    // "download" events are ignored — the "progress" events carry the data.
  };
}
// ---------------------------------------------------------------------------
// Individual model loaders
// ---------------------------------------------------------------------------
/**
 * Load the embedding model as a feature-extraction pipeline (singleton).
 *
 * Safe to call repeatedly: resolves immediately when already loaded, and
 * joins the in-flight load when one is underway. On failure the in-flight
 * promise is cleared so a later call can retry.
 *
 * @param onProgress - optional callback receiving load-state updates
 */
export async function loadEmbeddingModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (embeddingPipeline) return;
  if (embeddingLoadPromise) return await embeddingLoadPromise;
  const name = "embedding";
  onProgress?.({ name, status: "pending", progress: 0 });
  const load = async (): Promise<void> => {
    try {
      embeddingPipeline = await pipeline("feature-extraction", MODEL_EMBEDDING, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: makeProgressHandler(name, onProgress),
      });
      onProgress?.({ name, status: "ready", progress: 1 });
    } catch (err) {
      const message = err instanceof Error ? err.message : String(err);
      onProgress?.({ name, status: "error", progress: 0, error: message });
      throw err;
    } finally {
      // Clear so a failed load can be retried; harmless after success since
      // embeddingPipeline short-circuits future calls.
      embeddingLoadPromise = null;
    }
  };
  embeddingLoadPromise = load();
  return await embeddingLoadPromise;
}
export async function loadRerankerModel(
onProgress?: ProgressCallback,
): Promise<void> {
if (rerankerModel) return;
if (rerankerLoadPromise) return await rerankerLoadPromise;
const name = "reranker";
onProgress?.({ name, status: "pending", progress: 0 });
rerankerLoadPromise = (async () => {
try {
const progressHandler = makeProgressHandler(name, onProgress);
// Load tokenizer and model separately (cross-encoder pattern)
rerankerTokenizer = await AutoTokenizer.from_pretrained(MODEL_RERANKER, {
progress_callback: progressHandler,
});
// Pre-compute "yes" and "no" token IDs for scoring
const yesIds = rerankerTokenizer("yes", { add_special_tokens: false }).input_ids.data;
const noIds = rerankerTokenizer("no", { add_special_tokens: false }).input_ids.data;
rerankerTokenYes = Number(yesIds[yesIds.length - 1]);
rerankerTokenNo = Number(noIds[noIds.length - 1]);
rerankerModel = await AutoModelForCausalLM.from_pretrained(MODEL_RERANKER, {
dtype: "q4",
device: "webgpu",
progress_callback: progressHandler,
});
onProgress?.({ name, status: "ready", progress: 1 });
} catch (err) {
onProgress?.({
name,
status: "error",
progress: 0,
error: err instanceof Error ? err.message : String(err),
});
throw err;
} finally {
rerankerLoadPromise = null;
}
})();
return await rerankerLoadPromise;
}
export async function loadExpansionModel(
onProgress?: ProgressCallback,
): Promise<void> {
if (expansionModel) return;
if (expansionLoadPromise) return await expansionLoadPromise;
const name = "expansion";
onProgress?.({ name, status: "pending", progress: 0 });
expansionLoadPromise = (async () => {
try {
const progressHandler = makeProgressHandler(name, onProgress);
expansionTokenizer = await AutoTokenizer.from_pretrained(MODEL_EXPANSION, {
progress_callback: progressHandler,
});
// The HF repo has chat_template.jinja but it's not in tokenizer_config.json,
// so set the Qwen ChatML template manually.
if (!expansionTokenizer.chat_template) {
expansionTokenizer.chat_template =
"{% for message in messages %}" +
"<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n" +
"{% endfor %}" +
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";
}
expansionModel = await AutoModelForCausalLM.from_pretrained(MODEL_EXPANSION, {
dtype: "q4",
device: "webgpu",
progress_callback: progressHandler,
});
onProgress?.({ name, status: "ready", progress: 1 });
} catch (err) {
onProgress?.({
name,
status: "error",
progress: 0,
error: err instanceof Error ? err.message : String(err),
});
throw err;
} finally {
expansionLoadPromise = null;
}
})();
return await expansionLoadPromise;
}
// ---------------------------------------------------------------------------
/**
 * Race a promise against a timeout.
 *
 * Fix: the original never cleared the timeout's `setTimeout`, so a pending
 * timer leaked after the raced promise settled (e.g. a 2-minute timer kept
 * running after a successful expansion-model load, and in Node it keeps the
 * event loop alive). The handle is now cleared once the race settles.
 *
 * @param promise - the operation to race
 * @param timeoutMs - milliseconds before rejecting
 * @param message - Error message used on timeout
 * @returns the promise's value, or rejection with `new Error(message)` if
 *   the timeout elapses first
 */
function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  message: string,
): Promise<T> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never>((_, reject) => {
    timer = setTimeout(() => reject(new Error(message)), timeoutMs);
  });
  // Clear the timer regardless of which side wins the race.
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
}
// Load all models in parallel
// ---------------------------------------------------------------------------
/**
 * Load all three models concurrently.
 *
 * Embedding and reranker are required: a failure in either rejects this
 * promise. The expansion model is optional — it gets a 2-minute budget and
 * any failure or timeout is reported via onProgress only, after which the
 * pipeline falls back to the unexpanded query.
 *
 * @param onProgress - optional callback receiving per-model state updates
 * @throws when WebGPU is unavailable, or a required model fails to load
 */
export async function loadAllModels(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (!(await checkWebGPU())) {
    const message = "WebGPU is not available in this browser";
    for (const name of ["embedding", "reranker", "expansion"]) {
      onProgress?.({ name, status: "error", progress: 0, error: message });
    }
    throw new Error(message);
  }
  const EXPANSION_TIMEOUT_MS = 120_000; // 2 minutes
  // Expansion errors are swallowed here (surfaced only through onProgress),
  // so this promise always fulfills.
  const optionalExpansion = withTimeout(
    loadExpansionModel(onProgress),
    EXPANSION_TIMEOUT_MS,
    "Expansion model timed out",
  ).catch((err) => {
    onProgress?.({
      name: "expansion",
      status: "error",
      progress: 0,
      error: err instanceof Error ? err.message : "Failed to load expansion model",
    });
    // The pipeline will fall back to using the original query without expansion.
  });
  const [embeddingResult, rerankerResult] = await Promise.allSettled([
    loadEmbeddingModel(onProgress),
    loadRerankerModel(onProgress),
    optionalExpansion,
  ]);
  if (embeddingResult.status === "rejected") throw embeddingResult.reason;
  if (rerankerResult.status === "rejected") throw rerankerResult.reason;
}
// ---------------------------------------------------------------------------
// Getters
// ---------------------------------------------------------------------------
/** Embedding pipeline singleton, or null until loadEmbeddingModel() succeeds. */
export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
  return embeddingPipeline;
}
/** Reranker model singleton, or null until loadRerankerModel() succeeds. */
export function getRerankerModel(): PreTrainedModel | null {
  return rerankerModel;
}
/** Reranker tokenizer singleton, or null until loadRerankerModel() loads it. */
export function getRerankerTokenizer(): PreTrainedTokenizer | null {
  return rerankerTokenizer;
}
/**
 * Token ids of the final "yes"/"no" tokens used for reranker scoring.
 * Both are -1 until loadRerankerModel() has computed them.
 */
export function getRerankerTokenIds(): { yes: number; no: number } {
  return { yes: rerankerTokenYes, no: rerankerTokenNo };
}
/** Expansion model singleton, or null until loadExpansionModel() succeeds. */
export function getExpansionModel(): PreTrainedModel | null {
  return expansionModel;
}
/** Expansion tokenizer singleton, or null until loadExpansionModel() loads it. */
export function getExpansionTokenizer(): PreTrainedTokenizer | null {
  return expansionTokenizer;
}
/** True when the pipeline can run searches (expansion is not required). */
export function isAllModelsReady(): boolean {
  // Embedding + reranker are required; expansion is optional
  return embeddingPipeline !== null && rerankerModel !== null;
}
/** True when the optional query-expansion model has finished loading. */
export function isExpansionReady(): boolean {
  return expansionModel !== null;
}