import { useRef, useState, useEffect, useCallback, type ReactNode, } from "react";
import { pipeline, env, TextStreamer, InterruptableStoppingCriteria, type TextGenerationPipeline, } from "@huggingface/transformers";

// Enable browser Cache API so ONNX files persist across sessions
env.useBrowserCache = true;
// NOTE(review): presumably disables local-path model lookup so `modelId` is
// always fetched remotely — confirm against the transformers.js `env` docs.
env.allowLocalModels = false;

import { LLMContext, createMessageId, type ChatMessage, type LoadingStatus, type ThinkingMode, } from "./LLMContext";
import { ThinkStreamParser, type ThinkDelta } from "../utils/think-parser";
import { MODEL_CONFIG } from "../model-config";

/** Props for `LLMProvider`. */
interface LLMProviderProps {
  /** Model identifier passed to `pipeline("text-generation", ...)`. */
  modelId: string;
  children: ReactNode;
  /** Called each time the loading status state becomes `"ready"`. */
  onReady?: () => void;
}

/**
 * Returns a copy of `msg` with each parser delta appended in order: deltas
 * whose `type` is `"reasoning"` extend `reasoning` (treated as `""` when the
 * message has none); every other delta extends `content`. `msg` itself is
 * not mutated.
 */
function applyDeltas(msg: ChatMessage, deltas: ThinkDelta[]): ChatMessage {
  let { content, reasoning = "" } = msg;
  for (const delta of deltas) {
    if (delta.type === "reasoning") {
      reasoning += delta.textDelta;
    } else {
      content += delta.textDelta;
    }
  }
  return { ...msg, content, reasoning };
}

/**
 * Loads a text-generation pipeline for `modelId` (dtype "q4", device
 * "webgpu") and manages a streaming chat session on top of it.
 *
 * State managed here: message history, loading status, tokens-per-second,
 * thinking mode and system prompt.
 *
 * Actions defined here: `send`, `stop`, `clearChat`, `editMessage` and
 * `retryMessage`.
 *
 * NOTE(review): this copy of the file appears to have lost several `<...>`
 * spans: generic type arguments, the JSX returned at the end, and
 * presumably one string literal. This looks like the result of HTML-tag
 * stripping. Affected spots are marked below; restore them from version
 * control before relying on this file.
 */
export function LLMProvider({ modelId, children, onReady }: LLMProviderProps) {
  // Holds the *promise* of the loaded pipeline (assigned synchronously in the
  // load effect), so `runGeneration` can await it even mid-load.
  // NOTE(review): the type argument looks stripped — presumably
  // `useRef<Promise<TextGenerationPipeline> | null>(null)`. As written this
  // evaluates `useRef | (null > null)` instead of calling `useRef`. TODO restore.
  const generatorRef = useRef | null>(null);
  // Shared interrupt flag handed to the generator. `stop()` trips it and
  // `runGeneration` resets it before each run.
  // (The constructor argument is evaluated on every render; only the first
  // instance is kept by `useRef`.)
  const stoppingCriteria = useRef(new InterruptableStoppingCriteria());
  // NOTE(review): the useState/useRef calls below presumably had type
  // arguments (e.g. LoadingStatus, ChatMessage[], ThinkingMode) that were
  // stripped. Without them, `useState([])` infers `never[]`.
  const [status, setStatus] = useState({ state: "idle" });
  const [messages, setMessages] = useState([]);
  const messagesRef = useRef([]);
  const [isGenerating, setIsGenerating] = useState(false);
  const isGeneratingRef = useRef(false);
  const [tps, setTps] = useState(0);
  const [thinkingMode, setThinkingMode] = useState("enabled");
  const thinkingModeRef = useRef("enabled");
  const [systemPrompt, setSystemPrompt] = useState(MODEL_CONFIG.defaultSystemPrompt);
  const systemPromptRef = useRef(MODEL_CONFIG.defaultSystemPrompt);

  // Mirror state into refs so the memoized callbacks below can read current
  // values without being recreated.
  // NOTE(review): these effects run after commit, so each ref lags its state
  // until then. For example, two `send` calls issued before the
  // `isGenerating` effect runs both pass the `isGeneratingRef` guard.
  useEffect(() => { messagesRef.current = messages; }, [messages]);
  useEffect(() => { isGeneratingRef.current = isGenerating; }, [isGenerating]);
  useEffect(() => { thinkingModeRef.current = thinkingMode; }, [thinkingMode]);
  useEffect(() => { systemPromptRef.current = systemPrompt; }, [systemPrompt]);

  // Latest `onReady` prop, kept in a ref so the effect below need not depend on it.
  const onReadyRef = useRef(onReady);
  onReadyRef.current = onReady;
  // Fire `onReady` whenever the status state changes to "ready".
  useEffect(() => { if (status.state === "ready") onReadyRef.current?.(); }, [status.state]);

  // Start loading the pipeline. The guard skips the load while a load promise
  // is already stored (presumably so StrictMode's double effect run does not
  // load twice).
  // NOTE(review): because of that guard, a later change to `modelId` does not
  // trigger a reload while the earlier promise is still stored.
  useEffect(() => {
    if (generatorRef.current) return;
    generatorRef.current = (async () => {
      setStatus({ state: "loading", message: "Downloading model…" });
      try {
        const gen = await pipeline("text-generation", modelId, {
          dtype: "q4",
          device: "webgpu",
          // Translates loader progress events into `status` updates.
          // NOTE(review): bare `Record` presumably lost its type arguments
          // (e.g. `Record<string, unknown>`); as written it does not type-check.
          progress_callback: (progress: Record) => {
            // Shadows the outer `status` state; only the event's status string is used here.
            const status = progress.status as string;
            if (status === "progress") {
              const loaded = Number(progress.loaded ?? 0);
              // NOTE(review): `??` only replaces null/undefined. A reported
              // total of 0 still divides by zero and yields NaN/Infinity.
              const total = Number(progress.total ?? 1);
              const pct = Math.round((loaded / total) * 100);
              // Basename of the file being fetched.
              const file = String(progress.file ?? "").split("/").pop() ?? "";
              // 1048576 = 1024 * 1024 bytes.
              const loadedMB = (loaded / 1048576).toFixed(0);
              const totalMB = (total / 1048576).toFixed(0);
              setStatus({ state: "loading", progress: pct, message: `Downloading ${file}… ${loadedMB}/${totalMB} MB (${pct}%)`, });
            } else if (status === "ready") {
              const file = String(progress.file ?? "").split("/").pop() ?? "";
              console.log(`[cache] ${file}: loaded (cached or downloaded)`);
            } else if (status === "initiate") {
              const file = String(progress.file ?? "").split("/").pop() ?? "";
              setStatus({ state: "loading", message: `Loading ${file}…`, });
            }
          },
        });
        setStatus({ state: "ready" });
        return gen;
      } catch (err) {
        const msg = err instanceof Error ? err.message : String(err);
        setStatus({ state: "error", error: msg });
        // Clear the stored promise. This lets a later effect run retry the
        // load; until then, `send` is a no-op.
        generatorRef.current = null;
        // NOTE(review): rethrowing rejects the stored promise; if nothing is
        // awaiting it, this surfaces as an unhandled rejection.
        throw err;
      }
    })();
  }, [modelId]);

  /**
   * Streams one assistant reply for `chatHistory` into a new trailing message.
   *
   * Steps:
   * - Appends an empty assistant placeholder at index `chatHistory.length`.
   *   This assumes `messages` matches `chatHistory` when the placeholder is
   *   appended.
   * - Pushes streamed text through `ThinkStreamParser` into that message's
   *   `reasoning` / `content`.
   * - Updates `tps` as tokens arrive.
   * - Clears `isGenerating` at the end.
   *
   * Errors thrown by the generator call are logged, not rethrown.
   *
   * NOTE(review): the `await` on the load promise is outside any try/catch.
   * If loading failed, this rejects before `isGenerating` is set, and no
   * caller handles the rejection.
   */
  const runGeneration = useCallback(async (chatHistory: ChatMessage[]) => {
    const generator = await generatorRef.current!;
    setIsGenerating(true);
    setTps(0);
    stoppingCriteria.current.reset();
    const parser = new ThinkStreamParser();
    let tokenCount = 0;
    let firstTokenTime = 0;
    let isFirstChunk = true;
    // Index that the placeholder appended below will occupy.
    const assistantIdx = chatHistory.length;
    setMessages((prev) => [
      ...prev,
      { id: createMessageId(), role: "assistant", content: "", reasoning: "" },
    ]);
    const streamer = new TextStreamer(generator.tokenizer, {
      skip_prompt: true,
      // Special tokens stay in the stream; a chunk that is exactly the
      // "<|im_end|>" marker is dropped in the callback.
      skip_special_tokens: false,
      callback_function: (output: string) => {
        // NOTE(review): logs every streamed chunk — looks like leftover debugging.
        console.log("Streamed output:", output);
        if (!output || output === "<|im_end|>") return;
        let textToPush = output;
        // NOTE(review): the prefix literal looks stripped. It was presumably
        // an opening think tag (e.g. "<think>"), needed because the chat
        // template already opened the block inside the skipped prompt. As
        // written, `""` prepends nothing. TODO restore and confirm against
        // ThinkStreamParser.
        if (isFirstChunk && thinkingModeRef.current === "enabled") {
          textToPush = "" + output;
        }
        isFirstChunk = false;
        const deltas = parser.push(textToPush);
        if (deltas.length === 0) return;
        setMessages((prev) => {
          const updated = [...prev];
          updated[assistantIdx] = applyDeltas(updated[assistantIdx], deltas);
          return updated;
        });
      },
      // Tokens per second, measured from the first token onward: the first
      // token only starts the clock. Rounded to one decimal place.
      token_callback_function: () => {
        tokenCount++;
        if (tokenCount === 1) {
          firstTokenTime = performance.now();
        } else {
          const elapsed = (performance.now() - firstTokenTime) / 1000;
          if (elapsed > 0) {
            setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10);
          }
        }
      },
    });
    try {
      const currentSystemPrompt = systemPromptRef.current;
      // Optional system prompt, then role/content pairs from the history.
      // Stored `reasoning` is not sent back to the model.
      const messagesForModel = [
        ...(currentSystemPrompt ? [{ role: "system" as const, content: currentSystemPrompt }] : []),
        ...chatHistory.map((message) => ({ role: message.role, content: message.content, })),
      ];
      await generator(
        messagesForModel,
        {
          max_new_tokens: 4096,
          do_sample: true,
          streamer,
          // Lets `stop()` interrupt generation.
          stopping_criteria: stoppingCriteria.current,
          // Presumably forwarded to the chat template to toggle thinking — confirm.
          tokenizer_encode_kwargs: { enable_thinking: thinkingModeRef.current === "enabled", },
        },
      );
    } catch (err) {
      console.error("Generation error:", err);
    }
    // Apply any deltas `flush()` still returns (presumably text the parser
    // was holding back).
    const remaining = parser.flush();
    if (remaining.length > 0) {
      setMessages((prev) => {
        const updated = [...prev];
        updated[assistantIdx] = applyDeltas(updated[assistantIdx], remaining);
        return updated;
      });
    }
    // Final reconcile: use the parser's trimmed `content` / `reasoning`,
    // falling back to the streamed values when the parser's value is empty.
    setMessages((prev) => {
      const updated = [...prev];
      updated[assistantIdx] = { ...updated[assistantIdx], content: parser.content.trim() || prev[assistantIdx].content, reasoning: parser.reasoning.trim() || prev[assistantIdx].reasoning, };
      return updated;
    });
    setIsGenerating(false);
  }, []);

  /**
   * Appends a user message with `text` and starts generating a reply.
   * Ignored while no load promise is stored (before the load effect runs,
   * or after a failed load) or while a generation is running.
   * NOTE(review): the promise from `runGeneration` is neither awaited nor caught.
   */
  const send = useCallback(
    (text: string) => {
      if (!generatorRef.current || isGeneratingRef.current) return;
      const userMsg: ChatMessage = { id: createMessageId(), role: "user", content: text, };
      setMessages((prev) => [...prev, userMsg]);
      runGeneration([...messagesRef.current, userMsg]);
    },
    [runGeneration],
  );

  /** Requests interruption of the in-flight generation via the shared stopping criteria. */
  const stop = useCallback(() => {
    stoppingCriteria.current.interrupt();
  }, []);

  /** Empties the conversation; no-op while generating. */
  const clearChat = useCallback(() => {
    if (isGeneratingRef.current) return;
    setMessages([]);
  }, []);

  /**
   * Replaces the content of the message at `index` and drops every later
   * message. If that message was a user message, a new reply is generated.
   * No-op while generating. Regeneration is deferred with `setTimeout(..., 0)`,
   * presumably so the truncated history is committed first.
   * NOTE(review): an out-of-range `index` spreads `undefined`, appending a
   * message that has only `content` (no id or role).
   */
  const editMessage = useCallback(
    (index: number, newContent: string) => {
      if (isGeneratingRef.current) return;
      const updatedHistory = [
        ...messagesRef.current.slice(0, index),
        { ...messagesRef.current[index], content: newContent },
      ];
      setMessages(updatedHistory);
      if (messagesRef.current[index]?.role === "user") {
        setTimeout(() => runGeneration(updatedHistory), 0);
      }
    },
    [runGeneration],
  );

  /**
   * Drops the message at `index` and everything after it, then regenerates
   * from the remaining history, whatever role the dropped message had.
   * No-op while generating.
   */
  const retryMessage = useCallback(
    (index: number) => {
      if (isGeneratingRef.current) return;
      const history = messagesRef.current.slice(0, index);
      setMessages(history);
      setTimeout(() => runGeneration(history), 0);
    },
    [runGeneration],
  );

  // NOTE(review): the JSX wrapper appears stripped. It was presumably
  // `<LLMContext.Provider value={{ ... }}>{children}</LLMContext.Provider>`,
  // exposing the state and callbacks above (`LLMContext` is otherwise unused).
  // As written, this returns a plain `{ children }` object, so no context
  // value is ever provided. TODO restore.
  return (
    {children}
  );
}