| import { |
| useRef, |
| useState, |
| useEffect, |
| useCallback, |
| type ReactNode, |
| } from "react"; |
| import { |
| pipeline, |
| TextStreamer, |
| InterruptableStoppingCriteria, |
| type TextGenerationPipeline, |
| } from "@huggingface/transformers"; |
| import { |
| LLMContext, |
| createMessageId, |
| type ChatMessage, |
| type LoadingStatus, |
| type ThinkingMode, |
| } from "./LLMContext"; |
| import { ThinkStreamParser, type ThinkDelta } from "../utils/think-parser"; |
|
|
/** Props for the {@link LLMProvider} component. */
interface LLMProviderProps {
  /** Model identifier passed to the transformers.js `pipeline` (e.g. a Hugging Face repo id). */
  modelId: string;
  /** Subtree that consumes the provided LLMContext. */
  children: ReactNode;
  /** Invoked once the pipeline finishes loading (when status becomes "ready"). */
  onReady?: () => void;
}
|
|
| function applyDeltas(msg: ChatMessage, deltas: ThinkDelta[]): ChatMessage { |
| let { content, reasoning = "" } = msg; |
| for (const delta of deltas) { |
| if (delta.type === "reasoning") { |
| reasoning += delta.textDelta; |
| } else { |
| content += delta.textDelta; |
| } |
| } |
| return { ...msg, content, reasoning }; |
| } |
|
|
| export function LLMProvider({ modelId, children, onReady }: LLMProviderProps) { |
| const generatorRef = useRef<Promise<TextGenerationPipeline> | null>(null); |
| const stoppingCriteria = useRef(new InterruptableStoppingCriteria()); |
|
|
| const [status, setStatus] = useState<LoadingStatus>({ state: "idle" }); |
| const [messages, setMessages] = useState<ChatMessage[]>([]); |
| const messagesRef = useRef<ChatMessage[]>([]); |
| const [isGenerating, setIsGenerating] = useState(false); |
| const isGeneratingRef = useRef(false); |
| const [tps, setTps] = useState(0); |
| const [thinkingMode, setThinkingMode] = useState<ThinkingMode>("enabled"); |
| const thinkingModeRef = useRef<ThinkingMode>("enabled"); |
|
|
| useEffect(() => { |
| messagesRef.current = messages; |
| }, [messages]); |
|
|
| useEffect(() => { |
| isGeneratingRef.current = isGenerating; |
| }, [isGenerating]); |
|
|
| useEffect(() => { |
| thinkingModeRef.current = thinkingMode; |
| }, [thinkingMode]); |
|
|
| const onReadyRef = useRef(onReady); |
| onReadyRef.current = onReady; |
|
|
| useEffect(() => { |
| if (status.state === "ready") onReadyRef.current?.(); |
| }, [status.state]); |
|
|
| useEffect(() => { |
| if (generatorRef.current) return; |
|
|
| generatorRef.current = (async () => { |
| setStatus({ state: "loading", message: "Downloading model…" }); |
| try { |
| const gen = await pipeline("text-generation", modelId, { |
| dtype: "q4", |
| device: "webgpu", |
| progress_callback: (progress: Record<string, unknown>) => { |
| if (progress.status !== "progress_total") return; |
| setStatus({ |
| state: "loading", |
| progress: Number(progress.progress ?? 0), |
| message: `Downloading model… ${Math.round( |
| Number(progress.progress ?? 0), |
| )}%`, |
| }); |
| }, |
| }); |
| setStatus({ state: "ready" }); |
| return gen; |
| } catch (err) { |
| const msg = err instanceof Error ? err.message : String(err); |
| setStatus({ state: "error", error: msg }); |
| generatorRef.current = null; |
| throw err; |
| } |
| })(); |
| }, [modelId]); |
|
|
| const runGeneration = useCallback(async (chatHistory: ChatMessage[]) => { |
| const generator = await generatorRef.current!; |
| setIsGenerating(true); |
| setTps(0); |
| stoppingCriteria.current.reset(); |
|
|
| const parser = new ThinkStreamParser(); |
| let tokenCount = 0; |
| let firstTokenTime = 0; |
| let isFirstChunk = true; |
|
|
| const assistantIdx = chatHistory.length; |
| setMessages((prev) => [ |
| ...prev, |
| { id: createMessageId(), role: "assistant", content: "", reasoning: "" }, |
| ]); |
|
|
| const streamer = new TextStreamer(generator.tokenizer, { |
| skip_prompt: true, |
| skip_special_tokens: false, |
| callback_function: (output: string) => { |
| console.log("Streamed output:", output); |
| if (!output || output === "<|im_end|>") return; |
|
|
| let textToPush = output; |
| if (isFirstChunk && thinkingModeRef.current === "enabled") { |
| textToPush = "<think>" + output; |
| } |
| isFirstChunk = false; |
| const deltas = parser.push(textToPush); |
| if (deltas.length === 0) return; |
|
|
| setMessages((prev) => { |
| const updated = [...prev]; |
| updated[assistantIdx] = applyDeltas(updated[assistantIdx], deltas); |
| return updated; |
| }); |
| }, |
| token_callback_function: () => { |
| tokenCount++; |
| if (tokenCount === 1) { |
| firstTokenTime = performance.now(); |
| } else { |
| const elapsed = (performance.now() - firstTokenTime) / 1000; |
| if (elapsed > 0) { |
| setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10); |
| } |
| } |
| }, |
| }); |
|
|
| try { |
| await generator( |
| chatHistory.map((message) => ({ |
| role: message.role, |
| content: message.content, |
| })), |
| { |
| max_new_tokens: 4096, |
| do_sample: true, |
| streamer, |
| stopping_criteria: stoppingCriteria.current, |
| tokenizer_encode_kwargs: { |
| enable_thinking: thinkingModeRef.current === "enabled", |
| }, |
| }, |
| ); |
| } catch (err) { |
| console.error("Generation error:", err); |
| } |
|
|
| const remaining = parser.flush(); |
| if (remaining.length > 0) { |
| setMessages((prev) => { |
| const updated = [...prev]; |
| updated[assistantIdx] = applyDeltas(updated[assistantIdx], remaining); |
| return updated; |
| }); |
| } |
|
|
| setMessages((prev) => { |
| const updated = [...prev]; |
| updated[assistantIdx] = { |
| ...updated[assistantIdx], |
| content: parser.content.trim() || prev[assistantIdx].content, |
| reasoning: parser.reasoning.trim() || prev[assistantIdx].reasoning, |
| }; |
| return updated; |
| }); |
|
|
| setIsGenerating(false); |
| }, []); |
|
|
| const send = useCallback( |
| (text: string) => { |
| if (!generatorRef.current || isGeneratingRef.current) return; |
|
|
| const userMsg: ChatMessage = { |
| id: createMessageId(), |
| role: "user", |
| content: text, |
| }; |
|
|
| setMessages((prev) => [...prev, userMsg]); |
| runGeneration([...messagesRef.current, userMsg]); |
| }, |
| [runGeneration], |
| ); |
|
|
| const stop = useCallback(() => { |
| stoppingCriteria.current.interrupt(); |
| }, []); |
|
|
| const clearChat = useCallback(() => { |
| if (isGeneratingRef.current) return; |
| setMessages([]); |
| }, []); |
|
|
| const editMessage = useCallback( |
| (index: number, newContent: string) => { |
| if (isGeneratingRef.current) return; |
|
|
| const updatedHistory = [ |
| ...messagesRef.current.slice(0, index), |
| { ...messagesRef.current[index], content: newContent }, |
| ]; |
|
|
| setMessages(updatedHistory); |
|
|
| if (messagesRef.current[index]?.role === "user") { |
| setTimeout(() => runGeneration(updatedHistory), 0); |
| } |
| }, |
| [runGeneration], |
| ); |
|
|
| const retryMessage = useCallback( |
| (index: number) => { |
| if (isGeneratingRef.current) return; |
|
|
| const history = messagesRef.current.slice(0, index); |
| setMessages(history); |
| setTimeout(() => runGeneration(history), 0); |
| }, |
| [runGeneration], |
| ); |
|
|
| return ( |
| <LLMContext.Provider |
| value={{ |
| status, |
| messages, |
| isGenerating, |
| tps, |
| thinkingMode, |
| setThinkingMode, |
| send, |
| stop, |
| clearChat, |
| editMessage, |
| retryMessage, |
| }} |
| > |
| {children} |
| </LLMContext.Provider> |
| ); |
| } |
|
|