// LFM2-MoE-WebGPU: src/hooks/LLMProvider.tsx
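
/**
 * LLMProvider loads a transformers.js text-generation pipeline on WebGPU and
 * exposes chat state and actions (send, stop, clearChat, editMessage,
 * retryMessage) through LLMContext. A minimal usage sketch, where MODEL_ID
 * stands in for whatever model id the app supplies and <Chat /> is a
 * placeholder consumer:
 *
 *   <LLMProvider modelId={MODEL_ID} onReady={() => console.log("model ready")}>
 *     <Chat />
 *   </LLMProvider>
 */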
import { useRef, useState, useEffect, useCallback, type ReactNode } from "react";
import {
  pipeline,
  TextStreamer,
  InterruptableStoppingCriteria,
  type TextGenerationPipeline,
} from "@huggingface/transformers";
import { LLMContext, createMessageId, type ChatMessage, type LoadingStatus } from "./LLMContext";
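
// Props: the Hugging Face model id to load, plus an optional callback fired
// once the model is ready.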
interface LLMProviderProps {
  modelId: string;
  children: ReactNode;
  onReady?: () => void;
}

export function LLMProvider({ modelId, children, onReady }: LLMProviderProps) {
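  // The pipeline promise is cached in a ref so the model is only loaded once,
  // and the stopping-criteria instance persists across renders so in-flight
  // generations can be interrupted.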
  const generatorRef = useRef<Promise<TextGenerationPipeline> | null>(null);
  const stoppingCriteria = useRef(new InterruptableStoppingCriteria());
  const [status, setStatus] = useState<LoadingStatus>({ state: "idle" });
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const messagesRef = useRef<ChatMessage[]>([]);
  const [isGenerating, setIsGenerating] = useState(false);
  const isGeneratingRef = useRef(false);
  const [tps, setTps] = useState(0);
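
  // Mirror changing values into refs so the stable callbacks below can read
  // the latest state without being re-created on every render.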
  useEffect(() => {
    messagesRef.current = messages;
  }, [messages]);
  useEffect(() => {
    isGeneratingRef.current = isGenerating;
  }, [isGenerating]);
  const onReadyRef = useRef(onReady);
  onReadyRef.current = onReady;
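
  // Invoke onReady when loading completes.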
  useEffect(() => {
    if (status.state === "ready") onReadyRef.current?.();
  }, [status.state]);
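
  // Load the model once, with 4-bit weights (q4f16) on WebGPU. Caching the
  // promise in a ref means re-renders (and React Strict Mode's double effect
  // invocation in development) do not start a second load.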
  useEffect(() => {
    if (generatorRef.current) return;
    generatorRef.current = (async () => {
      setStatus({ state: "loading", message: "Downloading model…" });
      try {
        const gen = await pipeline("text-generation", modelId, {
          dtype: "q4f16",
          device: "webgpu",
          progress_callback: (info) => {
            if (info.status !== "progress_total") return;
            const loaded = Number(info.loaded ?? 0);
            const total = Number(info.total ?? 0);
            const pct = Number(info.progress ?? 0);
            const toGB = (b: number) => (b / 1e9).toFixed(2);
            setStatus({
              state: "loading",
              progress: pct,
              message:
                total > 0 ? `${toGB(loaded)} GB of ${toGB(total)} GB (${Math.round(pct)}%)` : `Downloading model…`,
            });
          },
        });
        setStatus({ state: "ready" });
        return gen;
      } catch (err) {
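        // Surface the failure and clear the cached promise so a later attempt
        // can reload the model.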
        const msg = err instanceof Error ? err.message : String(err);
        setStatus({ state: "error", error: msg });
        generatorRef.current = null;
        throw err;
      }
    })();
  }, [modelId]);
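
  // Stream a completion for the given history into a newly appended assistant
  // message, tracking tokens/sec along the way.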
  const runGeneration = useCallback(async (chatHistory: ChatMessage[]) => {
    const generator = await generatorRef.current!;
    setIsGenerating(true);
    setTps(0);
    stoppingCriteria.current.reset();
    let tokenCount = 0;
    let firstTokenTime = 0;
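    // chatHistory already includes the latest user turn, so its length is the
    // index at which the placeholder assistant message lands.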
    const assistantIdx = chatHistory.length;
    setMessages((prev) => [...prev, { id: createMessageId(), role: "assistant", content: "" }]);
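    // Append each decoded text chunk to the assistant message as it streams in.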
    const streamer = new TextStreamer(generator.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (output: string) => {
        if (!output) return;
        setMessages((prev) => {
          const updated = [...prev];
          updated[assistantIdx] = {
            ...updated[assistantIdx],
            content: updated[assistantIdx].content + output,
          };
          return updated;
        });
      },
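      // Track a rolling tokens/sec, timed from the first generated token so
      // that prompt-processing latency does not skew the rate.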
      token_callback_function: () => {
        tokenCount++;
        if (tokenCount === 1) {
          firstTokenTime = performance.now();
        } else {
          const elapsed = (performance.now() - firstTokenTime) / 1000;
          if (elapsed > 0) {
            setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10);
          }
        }
      },
    });
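    // Run generation: greedy decoding (do_sample: false), interruptible via
    // the shared stopping criteria.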
    try {
      await generator(
        chatHistory.map((message) => ({
          role: message.role,
          content: message.content,
        })),
        {
          max_new_tokens: 4096,
          do_sample: false,
          streamer,
          stopping_criteria: stoppingCriteria.current,
        },
      );
    } catch (err) {
      console.error("Generation error:", err);
    }
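    // Finalize: trim trailing whitespace and attach the measured tokens/sec.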
    const finalTps =
      tokenCount > 1 ? Math.round(((tokenCount - 1) / ((performance.now() - firstTokenTime) / 1000)) * 10) / 10 : 0;
    setMessages((prev) => {
      const updated = [...prev];
      updated[assistantIdx] = {
        ...updated[assistantIdx],
        content: updated[assistantIdx].content.trim(),
        tps: finalTps > 0 ? finalTps : undefined,
      };
      return updated;
    });
    setIsGenerating(false);
  }, []);
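
  // Append a user message and generate a reply. No-op while the model is
  // still loading or a generation is already in flight.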
  const send = useCallback(
    (text: string) => {
      if (!generatorRef.current || isGeneratingRef.current) return;
      const userMsg: ChatMessage = {
        id: createMessageId(),
        role: "user",
        content: text,
      };
      setMessages((prev) => [...prev, userMsg]);
      runGeneration([...messagesRef.current, userMsg]);
    },
    [runGeneration],
  );
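
  // Request interruption; the stopping criteria halts the generate loop.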
  const stop = useCallback(() => {
    stoppingCriteria.current.interrupt();
  }, []);
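
  // Reset the conversation (disabled while a generation is in flight).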
  const clearChat = useCallback(() => {
    if (isGeneratingRef.current) return;
    setMessages([]);
  }, []);
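
  // Replace a message's content and discard everything after it; if a user
  // message was edited, regenerate from that point. The call is deferred with
  // setTimeout so the truncated history is committed before generation starts.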
  const editMessage = useCallback(
    (index: number, newContent: string) => {
      if (isGeneratingRef.current) return;
      const updatedHistory = [
        ...messagesRef.current.slice(0, index),
        { ...messagesRef.current[index], content: newContent },
      ];
      setMessages(updatedHistory);
      if (messagesRef.current[index]?.role === "user") {
        setTimeout(() => runGeneration(updatedHistory), 0);
      }
    },
    [runGeneration],
  );
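
  // Drop the message at `index` and everything after it, then regenerate from
  // the remaining history.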
  const retryMessage = useCallback(
    (index: number) => {
      if (isGeneratingRef.current) return;
      const history = messagesRef.current.slice(0, index);
      setMessages(history);
      setTimeout(() => runGeneration(history), 0);
    },
    [runGeneration],
  );
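
  // Expose chat state and actions to descendants via context.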
  return (
    <LLMContext.Provider
      value={{
        status,
        messages,
        isGenerating,
        tps,
        send,
        stop,
        clearChat,
        editMessage,
        retryMessage,
      }}
    >
      {children}
    </LLMContext.Provider>
  );
}