Spaces:
Running
Running
| import { | |
| AutoModelForImageTextToText, | |
| AutoProcessor, | |
| RawImage, | |
| TextStreamer, | |
| type ProgressInfo, | |
| type Tensor, | |
| } from "@huggingface/transformers"; | |
| import { useCallback, useRef, useState, type PropsWithChildren } from "react"; | |
| import { VLMContext, type LoadState } from "./VLMContext"; | |
// Hugging Face model id for the ONNX-quantized LFM2-VL vision-language model.
const MODEL_ID = "onnx-community/LFM2-VL-450M-ONNX";
// Expected number of large weight shards (files ending in ".onnx_data");
// download progress is averaged over this count.
const MODEL_FILE_COUNT = 3;
// Upper bound on generated caption length, in tokens.
const MAX_NEW_TOKENS = 128;
// One caption request: the RGBA frame to describe, the user prompt, and an
// optional callback that receives normalized partial captions as they stream.
type CaptionRequest = {
  frame: ImageData;
  onStream?: (text: string) => void;
  prompt: string;
};
// Concrete instance types derived from the transformers.js factory helpers,
// so refs below stay typed without redeclaring library internals.
type ProcessorType = Awaited<ReturnType<typeof AutoProcessor.from_pretrained>>;
type ModelType = Awaited<
  ReturnType<typeof AutoModelForImageTextToText.from_pretrained>
>;
// State exposed before loadModel() has been invoked ("idle", 0% progress).
const initialLoadState: LoadState = {
  error: null,
  message: "Downloading...",
  progress: 0,
  status: "idle",
};
| function normalizeText(text: string) { | |
| return text.replace(/\s+/g, " ").trim(); | |
| } | |
| function getErrorMessage(error: unknown) { | |
| if (error instanceof Error) { | |
| return error.message; | |
| } | |
| return "The model could not be loaded."; | |
| } | |
| export function VLMProvider({ children }: PropsWithChildren) { | |
| const [loadState, setLoadState] = useState(initialLoadState); | |
| const processorRef = useRef<ProcessorType | null>(null); | |
| const modelRef = useRef<ModelType | null>(null); | |
| const loadPromiseRef = useRef<Promise<void> | null>(null); | |
| const generationInFlightRef = useRef(false); | |
| const setLoadProgress = useCallback((state: Partial<LoadState>) => { | |
| setLoadState((current) => ({ | |
| ...current, | |
| ...state, | |
| })); | |
| }, []); | |
| const loadModel = useCallback(async () => { | |
| if (processorRef.current && modelRef.current) { | |
| setLoadProgress({ | |
| error: null, | |
| message: "Model ready", | |
| progress: 100, | |
| status: "ready", | |
| }); | |
| return; | |
| } | |
| if (loadPromiseRef.current) { | |
| return loadPromiseRef.current; | |
| } | |
| if (!("gpu" in navigator)) { | |
| const message = "WebGPU is not available in this browser."; | |
| setLoadProgress({ | |
| error: message, | |
| message: "WebGPU unavailable", | |
| progress: 0, | |
| status: "error", | |
| }); | |
| throw new Error(message); | |
| } | |
| loadPromiseRef.current = (async () => { | |
| try { | |
| const processor = await AutoProcessor.from_pretrained(MODEL_ID); | |
| processorRef.current = processor; | |
| setLoadProgress({ | |
| message: "Downloading...", | |
| progress: 0, | |
| status: "loading", | |
| }); | |
| const progressMap = new Map<string, number>(); | |
| const progressCallback = (info: ProgressInfo) => { | |
| if ( | |
| info.status !== "progress" || | |
| !info.file.endsWith(".onnx_data") || | |
| info.total === 0 | |
| ) { | |
| return; | |
| } | |
| progressMap.set(info.file, info.loaded / info.total); | |
| const totalProgress = | |
| (Array.from(progressMap.values()).reduce( | |
| (sum, value) => sum + value, | |
| 0, | |
| ) / | |
| MODEL_FILE_COUNT) * | |
| 100; | |
| setLoadProgress({ | |
| message: "Downloading...", | |
| progress: totalProgress, | |
| status: "loading", | |
| }); | |
| }; | |
| modelRef.current = await AutoModelForImageTextToText.from_pretrained( | |
| MODEL_ID, | |
| { | |
| device: "webgpu", | |
| dtype: { | |
| vision_encoder: "fp16", | |
| embed_tokens: "fp16", | |
| decoder_model_merged: "q4f16", | |
| }, | |
| progress_callback: progressCallback, | |
| }, | |
| ); | |
| setLoadProgress({ | |
| error: null, | |
| message: "Model ready", | |
| progress: 100, | |
| status: "ready", | |
| }); | |
| } catch (error) { | |
| const message = getErrorMessage(error); | |
| setLoadProgress({ | |
| error: message, | |
| message: "Unable to load model", | |
| progress: 0, | |
| status: "error", | |
| }); | |
| throw error; | |
| } finally { | |
| loadPromiseRef.current = null; | |
| } | |
| })(); | |
| return loadPromiseRef.current; | |
| }, [setLoadProgress]); | |
| const generateCaption = useCallback( | |
| async ({ frame, onStream, prompt }: CaptionRequest) => { | |
| const processor = processorRef.current; | |
| const model = modelRef.current; | |
| if (!processor || !model || !processor.tokenizer) { | |
| throw new Error("The model is not ready yet."); | |
| } | |
| if (generationInFlightRef.current) { | |
| return ""; | |
| } | |
| generationInFlightRef.current = true; | |
| try { | |
| const messages = [ | |
| { | |
| content: [ | |
| { type: "image" }, | |
| { text: normalizeText(prompt), type: "text" }, | |
| ], | |
| role: "user", | |
| }, | |
| ]; | |
| const chatPrompt = processor.apply_chat_template(messages, { | |
| add_generation_prompt: true, | |
| }); | |
| const rawFrame = new RawImage(frame.data, frame.width, frame.height, 4); | |
| const inputs = await processor(rawFrame, chatPrompt, { | |
| add_special_tokens: false, | |
| }); | |
| let streamedText = ""; | |
| const streamer = new TextStreamer(processor.tokenizer, { | |
| callback_function: (text) => { | |
| streamedText += text; | |
| const normalized = normalizeText(streamedText); | |
| if (normalized.length > 0) { | |
| onStream?.(normalized); | |
| } | |
| }, | |
| skip_prompt: true, | |
| skip_special_tokens: true, | |
| }); | |
| const outputs = (await model.generate({ | |
| ...inputs, | |
| do_sample: false, | |
| max_new_tokens: MAX_NEW_TOKENS, | |
| repetition_penalty: 1.08, | |
| streamer, | |
| })) as Tensor; | |
| const inputLength = inputs.input_ids.dims.at(-1) ?? 0; | |
| const generated = outputs.slice(null, [inputLength, null]); | |
| const [decoded] = processor.batch_decode(generated, { | |
| skip_special_tokens: true, | |
| }); | |
| const finalCaption = normalizeText(decoded ?? streamedText); | |
| if (finalCaption.length > 0) { | |
| onStream?.(finalCaption); | |
| } | |
| return finalCaption; | |
| } finally { | |
| generationInFlightRef.current = false; | |
| } | |
| }, | |
| [], | |
| ); | |
| return ( | |
| <VLMContext.Provider | |
| value={{ | |
| ...loadState, | |
| generateCaption, | |
| loadModel, | |
| }} | |
| > | |
| {children} | |
| </VLMContext.Provider> | |
| ); | |
| } | |