// LFM2-VL-WebGPU / src/context/VLMProvider.tsx
import {
AutoModelForImageTextToText,
AutoProcessor,
RawImage,
TextStreamer,
type ProgressInfo,
type Tensor,
} from "@huggingface/transformers";
import { useCallback, useRef, useState, type PropsWithChildren } from "react";
import { VLMContext, type LoadState } from "./VLMContext";
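// ONNX export of LFM2-VL 450M, fetched from the Hugging Face Hub on first load.
// MODEL_FILE_COUNT is the number of .onnx_data weight files whose download
// progress is averaged below; MAX_NEW_TOKENS caps the length of each caption.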
const MODEL_ID = "onnx-community/LFM2-VL-450M-ONNX";
const MODEL_FILE_COUNT = 3;
const MAX_NEW_TOKENS = 128;
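// One captioning request: a video frame plus the user prompt, with an optional
// callback that receives the partial caption as tokens stream in.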
type CaptionRequest = {
frame: ImageData;
onStream?: (text: string) => void;
prompt: string;
};
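// Instance types inferred from the transformers.js factory functions.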
type ProcessorType = Awaited<ReturnType<typeof AutoProcessor.from_pretrained>>;
type ModelType = Awaited<
ReturnType<typeof AutoModelForImageTextToText.from_pretrained>
>;
const initialLoadState: LoadState = {
error: null,
message: "Downloading...",
progress: 0,
status: "idle",
};
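// Collapse runs of whitespace so streamed fragments join into clean sentences.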
function normalizeText(text: string) {
return text.replace(/\s+/g, " ").trim();
}
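// Extract a human-readable message from an unknown thrown value.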
function getErrorMessage(error: unknown) {
if (error instanceof Error) {
return error.message;
}
return "The model could not be loaded.";
}
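/**
 * Provides model loading state and caption generation to the component tree.
 * The processor and model live in refs so re-renders never re-download the
 * weights.
 */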
export function VLMProvider({ children }: PropsWithChildren) {
const [loadState, setLoadState] = useState(initialLoadState);
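  // Heavy objects are kept in refs rather than state to avoid re-renders; the
  // promise ref dedupes concurrent load calls and the boolean flag guards
  // against overlapping generation requests.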
const processorRef = useRef<ProcessorType | null>(null);
const modelRef = useRef<ModelType | null>(null);
const loadPromiseRef = useRef<Promise<void> | null>(null);
const generationInFlightRef = useRef(false);
const setLoadProgress = useCallback((state: Partial<LoadState>) => {
setLoadState((current) => ({
...current,
...state,
}));
}, []);
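  // Idempotent loader: resolves immediately when the model is already cached,
  // reuses an in-flight load, and fails fast when WebGPU is unavailable.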
const loadModel = useCallback(async () => {
if (processorRef.current && modelRef.current) {
setLoadProgress({
error: null,
message: "Model ready",
progress: 100,
status: "ready",
});
return;
}
if (loadPromiseRef.current) {
return loadPromiseRef.current;
}
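    // transformers.js is configured for the WebGPU backend below, so bail out
    // early with a clear error when the browser does not expose it.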
if (!("gpu" in navigator)) {
const message = "WebGPU is not available in this browser.";
setLoadProgress({
error: message,
message: "WebGPU unavailable",
progress: 0,
status: "error",
});
throw new Error(message);
}
loadPromiseRef.current = (async () => {
      try {
        setLoadProgress({
          error: null,
          message: "Downloading...",
          progress: 0,
          status: "loading",
        });
        const processor = await AutoProcessor.from_pretrained(MODEL_ID);
        processorRef.current = processor;
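        // Each .onnx_data file reports progress independently; average the
        // per-file fractions over the expected file count for a 0-100 total.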
const progressMap = new Map<string, number>();
const progressCallback = (info: ProgressInfo) => {
if (
info.status !== "progress" ||
!info.file.endsWith(".onnx_data") ||
info.total === 0
) {
return;
}
progressMap.set(info.file, info.loaded / info.total);
          const summed = Array.from(progressMap.values()).reduce(
            (sum, value) => sum + value,
            0,
          );
          const totalProgress = (summed / MODEL_FILE_COUNT) * 100;
setLoadProgress({
message: "Downloading...",
progress: totalProgress,
status: "loading",
});
};
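        // Mixed precision per submodule: fp16 for the vision encoder and token
        // embeddings, 4-bit weights with fp16 activations (q4f16) for the
        // merged decoder.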
modelRef.current = await AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
{
device: "webgpu",
dtype: {
vision_encoder: "fp16",
embed_tokens: "fp16",
decoder_model_merged: "q4f16",
},
progress_callback: progressCallback,
},
);
setLoadProgress({
error: null,
message: "Model ready",
progress: 100,
status: "ready",
});
} catch (error) {
const message = getErrorMessage(error);
setLoadProgress({
error: message,
message: "Unable to load model",
progress: 0,
status: "error",
});
throw error;
} finally {
loadPromiseRef.current = null;
}
})();
return loadPromiseRef.current;
}, [setLoadProgress]);
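  // Runs one frame + prompt through the model. Partial captions are pushed to
  // onStream as they decode; the full normalized caption is returned at the end.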
const generateCaption = useCallback(
async ({ frame, onStream, prompt }: CaptionRequest) => {
const processor = processorRef.current;
const model = modelRef.current;
if (!processor || !model || !processor.tokenizer) {
throw new Error("The model is not ready yet.");
}
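      // Drop overlapping requests instead of queueing them; an empty string
      // tells the caller this frame was skipped.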
if (generationInFlightRef.current) {
return "";
}
generationInFlightRef.current = true;
try {
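        // Build a chat-formatted prompt; the image slot is bound to the actual
        // pixels when the frame is passed to the processor below.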
const messages = [
{
content: [
{ type: "image" },
{ text: normalizeText(prompt), type: "text" },
],
role: "user",
},
];
const chatPrompt = processor.apply_chat_template(messages, {
add_generation_prompt: true,
});
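        // ImageData from a canvas is RGBA, hence 4 channels.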
const rawFrame = new RawImage(frame.data, frame.width, frame.height, 4);
const inputs = await processor(rawFrame, chatPrompt, {
add_special_tokens: false,
});
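        // Push decoded tokens to the UI as they arrive, skipping the prompt
        // and special tokens so only new caption text is surfaced.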
let streamedText = "";
const streamer = new TextStreamer(processor.tokenizer, {
callback_function: (text) => {
streamedText += text;
const normalized = normalizeText(streamedText);
if (normalized.length > 0) {
onStream?.(normalized);
}
},
skip_prompt: true,
skip_special_tokens: true,
});
const outputs = (await model.generate({
...inputs,
do_sample: false,
max_new_tokens: MAX_NEW_TOKENS,
repetition_penalty: 1.08,
streamer,
})) as Tensor;
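        // generate() returns prompt + completion token ids; slice off the
        // prompt so only the newly generated tokens are decoded.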
const inputLength = inputs.input_ids.dims.at(-1) ?? 0;
const generated = outputs.slice(null, [inputLength, null]);
const [decoded] = processor.batch_decode(generated, {
skip_special_tokens: true,
});
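        // Prefer the batch-decoded text over the streamed buffer, and emit the
        // final caption once more so the UI ends on the complete text.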
const finalCaption = normalizeText(decoded ?? streamedText);
if (finalCaption.length > 0) {
onStream?.(finalCaption);
}
return finalCaption;
} finally {
generationInFlightRef.current = false;
}
},
[],
);
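  // Expose the load state plus the two actions through context.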
return (
<VLMContext.Provider
value={{
...loadState,
generateCaption,
loadModel,
}}
>
{children}
</VLMContext.Provider>
);
}