import {
  AutoModelForImageTextToText,
  AutoProcessor,
  RawImage,
  TextStreamer,
  type ProgressInfo,
  type Tensor,
} from "@huggingface/transformers";
import { useCallback, useRef, useState, type PropsWithChildren } from "react";
import { VLMContext, type LoadState } from "./VLMContext";

const MODEL_ID = "onnx-community/LFM2-VL-450M-ONNX";
// Number of .onnx_data weight shards the model downloads; used to scale the
// aggregate download progress to a 0-100 percentage.
const MODEL_FILE_COUNT = 3;
const MAX_NEW_TOKENS = 128;

type CaptionRequest = {
  /** Raw RGBA frame to caption. */
  frame: ImageData;
  /** Invoked with the normalized partial text while streaming, then once more with the final caption. */
  onStream?: (text: string) => void;
  /** User prompt describing what to ask about the frame. */
  prompt: string;
};

// NOTE(review): the generic arguments below were reconstructed from usage —
// the original text had its type parameters stripped in transit. Confirm
// against version history.
type ProcessorType = Awaited<ReturnType<typeof AutoProcessor.from_pretrained>>;
type ModelType = Awaited<
  ReturnType<typeof AutoModelForImageTextToText.from_pretrained>
>;

const initialLoadState: LoadState = {
  error: null,
  message: "Downloading...",
  progress: 0,
  status: "idle",
};

/** Collapses runs of whitespace to single spaces and trims both ends. */
function normalizeText(text: string) {
  return text.replace(/\s+/g, " ").trim();
}

/** Extracts a human-readable message from an unknown thrown value. */
function getErrorMessage(error: unknown) {
  if (error instanceof Error) {
    return error.message;
  }
  return "The model could not be loaded.";
}

/**
 * Provides WebGPU-backed vision-language-model loading and per-frame caption
 * generation to descendants via {@link VLMContext}.
 */
export function VLMProvider({ children }: PropsWithChildren) {
  const [loadState, setLoadState] = useState(initialLoadState);
  const processorRef = useRef<ProcessorType | null>(null);
  const modelRef = useRef<ModelType | null>(null);
  // De-duplicates concurrent loadModel() calls: every caller awaits the same promise.
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  // Drops caption requests that arrive while a generation is still running.
  const generationInFlightRef = useRef(false);

  /** Shallow-merges a partial state into the current load state. */
  const setLoadProgress = useCallback((state: Partial<LoadState>) => {
    setLoadState((current) => ({
      ...current,
      ...state,
    }));
  }, []);

  /**
   * Downloads the processor and model (once), reporting download progress
   * through the load state. Throws if WebGPU is unavailable or loading fails.
   */
  const loadModel = useCallback(async () => {
    // Already loaded: just re-assert the ready state.
    if (processorRef.current && modelRef.current) {
      setLoadProgress({
        error: null,
        message: "Model ready",
        progress: 100,
        status: "ready",
      });
      return;
    }
    // A load is already in flight; share its promise.
    if (loadPromiseRef.current) {
      return loadPromiseRef.current;
    }
    if (!("gpu" in navigator)) {
      const message = "WebGPU is not available in this browser.";
      setLoadProgress({
        error: message,
        message: "WebGPU unavailable",
        progress: 0,
        status: "error",
      });
      throw new Error(message);
    }
    loadPromiseRef.current = (async () => {
      try {
        const processor = await AutoProcessor.from_pretrained(MODEL_ID);
        processorRef.current = processor;
        setLoadProgress({
          message: "Downloading...",
          progress: 0,
          status: "loading",
        });
        // Per-file download fraction (0..1), keyed by file name.
        const progressMap = new Map<string, number>();
        const progressCallback = (info: ProgressInfo) => {
          // Only the large .onnx_data weight shards contribute to the bar;
          // info.total === 0 guards against division by zero.
          if (
            info.status !== "progress" ||
            !info.file.endsWith(".onnx_data") ||
            info.total === 0
          ) {
            return;
          }
          progressMap.set(info.file, info.loaded / info.total);
          const totalProgress =
            (Array.from(progressMap.values()).reduce(
              (sum, value) => sum + value,
              0,
            ) /
              MODEL_FILE_COUNT) *
            100;
          setLoadProgress({
            message: "Downloading...",
            progress: totalProgress,
            status: "loading",
          });
        };
        modelRef.current = await AutoModelForImageTextToText.from_pretrained(
          MODEL_ID,
          {
            device: "webgpu",
            dtype: {
              vision_encoder: "fp16",
              embed_tokens: "fp16",
              decoder_model_merged: "q4f16",
            },
            progress_callback: progressCallback,
          },
        );
        setLoadProgress({
          error: null,
          message: "Model ready",
          progress: 100,
          status: "ready",
        });
      } catch (error) {
        const message = getErrorMessage(error);
        setLoadProgress({
          error: message,
          message: "Unable to load model",
          progress: 0,
          status: "error",
        });
        throw error;
      } finally {
        // Allow a retry after failure; on success the refs short-circuit.
        loadPromiseRef.current = null;
      }
    })();
    return loadPromiseRef.current;
  }, [setLoadProgress]);

  /**
   * Runs one caption generation for a frame. Streams normalized partial text
   * to onStream and resolves with the final caption. Returns "" when a
   * generation is already in flight; throws if the model is not loaded.
   */
  const generateCaption = useCallback(
    async ({ frame, onStream, prompt }: CaptionRequest) => {
      const processor = processorRef.current;
      const model = modelRef.current;
      if (!processor || !model || !processor.tokenizer) {
        throw new Error("The model is not ready yet.");
      }
      // The model is not reentrant — silently drop overlapping requests.
      if (generationInFlightRef.current) {
        return "";
      }
      generationInFlightRef.current = true;
      try {
        const messages = [
          {
            content: [
              { type: "image" },
              { text: normalizeText(prompt), type: "text" },
            ],
            role: "user",
          },
        ];
        const chatPrompt = processor.apply_chat_template(messages, {
          add_generation_prompt: true,
        });
        // ImageData is RGBA, hence 4 channels.
        const rawFrame = new RawImage(frame.data, frame.width, frame.height, 4);
        const inputs = await processor(rawFrame, chatPrompt, {
          add_special_tokens: false,
        });
        let streamedText = "";
        const streamer = new TextStreamer(processor.tokenizer, {
          callback_function: (text) => {
            streamedText += text;
            const normalized = normalizeText(streamedText);
            if (normalized.length > 0) {
              onStream?.(normalized);
            }
          },
          skip_prompt: true,
          skip_special_tokens: true,
        });
        const outputs = (await model.generate({
          ...inputs,
          do_sample: false,
          max_new_tokens: MAX_NEW_TOKENS,
          repetition_penalty: 1.08,
          streamer,
        })) as Tensor;
        // Slice off the prompt tokens so only newly generated ones are decoded.
        const inputLength = inputs.input_ids.dims.at(-1) ?? 0;
        const generated = outputs.slice(null, [inputLength, null]);
        const [decoded] = processor.batch_decode(generated, {
          skip_special_tokens: true,
        });
        const finalCaption = normalizeText(decoded ?? streamedText);
        if (finalCaption.length > 0) {
          onStream?.(finalCaption);
        }
        return finalCaption;
      } finally {
        generationInFlightRef.current = false;
      }
    },
    [],
  );

  // NOTE(review): the provider element was reconstructed — the original JSX
  // tags were stripped in transit. The value shape
  // { generateCaption, loadModel, loadState } is inferred from what this
  // component defines; confirm against ./VLMContext.
  return (
    <VLMContext.Provider value={{ generateCaption, loadModel, loadState }}>
      {children}
    </VLMContext.Provider>
  );
}