Spaces:
Sleeping
Sleeping
| import { | |
| useRef, | |
| useState, | |
| useEffect, | |
| useCallback, | |
| type ReactNode, | |
| } from "react"; | |
| import { | |
| pipeline, | |
| env, | |
| TextStreamer, | |
| InterruptableStoppingCriteria, | |
| type TextGenerationPipeline, | |
| } from "@huggingface/transformers"; | |
// Resolve models from the remote hub only (local model paths are disabled) and
// keep downloaded ONNX files in the browser Cache API so they persist across
// sessions.
env.allowLocalModels = false;
env.useBrowserCache = true;
| import { | |
| LLMContext, | |
| createMessageId, | |
| type ChatMessage, | |
| type LoadingStatus, | |
| type ThinkingMode, | |
| } from "./LLMContext"; | |
| import { ThinkStreamParser, type ThinkDelta } from "../utils/think-parser"; | |
| import { MODEL_CONFIG } from "../model-config"; | |
/** Props accepted by the LLM context provider component. */
interface LLMProviderProps {
  /** Model id handed to `pipeline("text-generation", …)` when the provider mounts. */
  modelId: string;
  /** Optional callback invoked when the loading status transitions to "ready". */
  onReady?: () => void;
  /** Subtree rendered inside the LLM context provider. */
  children: ReactNode;
}
/**
 * Appends streamed deltas to a chat message without mutating it.
 *
 * Deltas are applied in order. `"reasoning"` deltas extend `reasoning`,
 * starting from "" when the message has none. Every other delta extends
 * `content`. All other fields of `msg` are copied unchanged.
 */
function applyDeltas(msg: ChatMessage, deltas: ThinkDelta[]): ChatMessage {
  let answer = msg.content;
  let thoughts = msg.reasoning === undefined ? "" : msg.reasoning;
  for (const { type, textDelta } of deltas) {
    if (type !== "reasoning") {
      answer += textDelta;
      continue;
    }
    thoughts += textDelta;
  }
  return { ...msg, content: answer, reasoning: thoughts };
}
/** Returns the last path segment of a progress event's `file` field ("" when absent). */
function progressFileName(progress: Record<string, unknown>): string {
  return String(progress.file ?? "").split("/").pop() ?? "";
}

/**
 * Supplies `LLMContext` to `children`.
 *
 * On mount it loads the `modelId` text-generation pipeline (q4 weights,
 * WebGPU device). It exposes the chat state plus actions to send, stop,
 * clear, edit, and retry messages.
 *
 * Only one generation runs at a time. `send`, `clearChat`, `editMessage`, and
 * `retryMessage` are ignored while a reply is in flight. This includes the
 * period where a send is waiting for the model to finish loading.
 */
export function LLMProvider({ modelId, children, onReady }: LLMProviderProps) {
  // Promise for the loaded pipeline; null before loading starts and after a load failure.
  const generatorRef = useRef<Promise<TextGenerationPipeline> | null>(null);
  const stoppingCriteria = useRef(new InterruptableStoppingCriteria());
  const [status, setStatus] = useState<LoadingStatus>({ state: "idle" });
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const [isGenerating, setIsGenerating] = useState(false);
  const [tps, setTps] = useState(0);
  const [thinkingMode, setThinkingMode] = useState<ThinkingMode>("enabled");
  const [systemPrompt, setSystemPrompt] = useState(MODEL_CONFIG.defaultSystemPrompt);

  // Mirrors of state read from the stable callbacks below, which would
  // otherwise close over stale values.
  const messagesRef = useRef<ChatMessage[]>([]);
  const thinkingModeRef = useRef<ThinkingMode>("enabled");
  const systemPromptRef = useRef(MODEL_CONFIG.defaultSystemPrompt);
  // Written synchronously by runGeneration rather than mirrored from state by
  // an effect. That way the "one generation at a time" guard takes effect
  // immediately, and a delayed commit can never clobber it.
  const isGeneratingRef = useRef(false);
  const onReadyRef = useRef(onReady);

  useEffect(() => {
    messagesRef.current = messages;
  }, [messages]);
  useEffect(() => {
    thinkingModeRef.current = thinkingMode;
  }, [thinkingMode]);
  useEffect(() => {
    systemPromptRef.current = systemPrompt;
  }, [systemPrompt]);
  // Declared before the "ready" effect so the latest onReady is in place when it runs.
  useEffect(() => {
    onReadyRef.current = onReady;
  }, [onReady]);
  useEffect(() => {
    if (status.state === "ready") onReadyRef.current?.();
  }, [status.state]);

  useEffect(() => {
    // Load at most once per provider instance. This guard also makes a
    // repeated effect invocation (e.g. React StrictMode) a no-op.
    // NOTE(review): because of this guard, changing `modelId` after the first
    // load does not load the new model. Confirm that is intended.
    if (generatorRef.current) return;
    const loading = (async () => {
      setStatus({ state: "loading", message: "Downloading model…" });
      try {
        const gen = await pipeline("text-generation", modelId, {
          dtype: "q4",
          device: "webgpu",
          progress_callback: (progress: Record<string, unknown>) => {
            const phase = progress.status;
            if (phase === "progress") {
              const loaded = Number(progress.loaded ?? 0);
              const total = Number(progress.total ?? 1);
              // A zero or non-numeric total would otherwise yield NaN/Infinity.
              const pct = total > 0 ? Math.round((loaded / total) * 100) : 0;
              const loadedMB = (loaded / 1048576).toFixed(0);
              const totalMB = (total / 1048576).toFixed(0);
              setStatus({
                state: "loading",
                progress: pct,
                message: `Downloading ${progressFileName(progress)}… ${loadedMB}/${totalMB} MB (${pct}%)`,
              });
            } else if (phase === "ready") {
              console.log(`[cache] ${progressFileName(progress)}: loaded (cached or downloaded)`);
            } else if (phase === "initiate") {
              setStatus({
                state: "loading",
                message: `Loading ${progressFileName(progress)}…`,
              });
            }
          },
        });
        setStatus({ state: "ready" });
        return gen;
      } catch (err) {
        const msg = err instanceof Error ? err.message : String(err);
        setStatus({ state: "error", error: msg });
        generatorRef.current = null;
        // Rethrow so a runGeneration() already awaiting this promise sees the failure.
        throw err;
      }
    })();
    generatorRef.current = loading;
    // The failure is already surfaced through `status`. This handler only keeps
    // the rejection from being reported as unhandled when nothing awaits it.
    loading.catch(() => undefined);
  }, [modelId]);

  const runGeneration = useCallback(async (chatHistory: ChatMessage[]) => {
    // Claim the single generation slot before the first await. While the
    // model is still loading, `generatorRef.current` is a pending promise, so
    // without this a second send() would pass the guard and start a
    // concurrent run against the same pipeline.
    isGeneratingRef.current = true;
    setIsGenerating(true);
    setTps(0);
    // Reset before waiting on the model so a stop() issued during loading
    // still interrupts the upcoming generation.
    stoppingCriteria.current.reset();
    try {
      const pending = generatorRef.current;
      if (!pending) throw new Error("Model is not loaded");
      const generator = await pending;

      // Snapshot settings once so the prompt and the stream handling agree,
      // even if the user toggles them mid-generation.
      const thinkingEnabled = thinkingModeRef.current === "enabled";
      const currentSystemPrompt = systemPromptRef.current;

      // Address the reply by id rather than array index, so updates land on
      // the right message even if messagesRef lagged behind state.
      const assistantId = createMessageId();
      const patchAssistant = (patch: (msg: ChatMessage) => ChatMessage) => {
        setMessages((prev) =>
          prev.map((msg) => (msg.id === assistantId ? patch(msg) : msg)),
        );
      };
      setMessages((prev) => [
        ...prev,
        { id: assistantId, role: "assistant", content: "", reasoning: "" },
      ]);

      const parser = new ThinkStreamParser();
      let isFirstChunk = true;
      let tokenCount = 0;
      let firstTokenTime = 0;
      const streamer = new TextStreamer(generator.tokenizer, {
        skip_prompt: true,
        skip_special_tokens: false,
        callback_function: (output: string) => {
          if (!output || output === "<|im_end|>") return;
          // With thinking enabled, the reply is treated as starting inside a
          // <think> block, so the opening tag is re-inserted for the parser.
          // NOTE(review): presumably the chat template puts "<think>" in the
          // (skipped) prompt. Confirm against the model's template.
          const text = isFirstChunk && thinkingEnabled ? "<think>" + output : output;
          isFirstChunk = false;
          const deltas = parser.push(text);
          if (deltas.length > 0) patchAssistant((msg) => applyDeltas(msg, deltas));
        },
        token_callback_function: () => {
          tokenCount++;
          if (tokenCount === 1) {
            firstTokenTime = performance.now();
            return;
          }
          // The timer starts at the first token, so time spent before it is not counted.
          const elapsed = (performance.now() - firstTokenTime) / 1000;
          if (elapsed > 0) {
            setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10);
          }
        },
      });

      const messagesForModel = [
        ...(currentSystemPrompt
          ? [{ role: "system" as const, content: currentSystemPrompt }]
          : []),
        ...chatHistory.map((message) => ({
          role: message.role,
          content: message.content,
        })),
      ];
      try {
        await generator(messagesForModel, {
          max_new_tokens: 4096,
          do_sample: true,
          streamer,
          stopping_criteria: stoppingCriteria.current,
          tokenizer_encode_kwargs: {
            enable_thinking: thinkingEnabled,
          },
        });
      } catch (err) {
        // Keep whatever streamed before the failure; the reply is finalized below.
        console.error("Generation error:", err);
      }

      const remaining = parser.flush();
      patchAssistant((msg) => {
        const merged = remaining.length > 0 ? applyDeltas(msg, remaining) : msg;
        // Prefer the parser's trimmed totals; fall back to the streamed text when they are empty.
        return {
          ...merged,
          content: parser.content.trim() || merged.content,
          reasoning: parser.reasoning.trim() || merged.reasoning,
        };
      });
    } catch (err) {
      // E.g. the model failed to load, so there is nothing to generate with.
      console.error("Generation error:", err);
    } finally {
      isGeneratingRef.current = false;
      setIsGenerating(false);
    }
  }, []);

  const send = useCallback(
    (text: string) => {
      if (!generatorRef.current || isGeneratingRef.current) return;
      const userMsg: ChatMessage = {
        id: createMessageId(),
        role: "user",
        content: text,
      };
      setMessages((prev) => [...prev, userMsg]);
      void runGeneration([...messagesRef.current, userMsg]);
    },
    [runGeneration],
  );

  const stop = useCallback(() => {
    stoppingCriteria.current.interrupt();
  }, []);

  const clearChat = useCallback(() => {
    if (isGeneratingRef.current) return;
    setMessages([]);
  }, []);

  const editMessage = useCallback(
    (index: number, newContent: string) => {
      if (isGeneratingRef.current) return;
      const target = messagesRef.current[index];
      // An out-of-range index previously inserted a message with no id/role.
      if (!target) return;
      // Drop everything after the edited message; regenerate only after a user turn.
      const updatedHistory = [
        ...messagesRef.current.slice(0, index),
        { ...target, content: newContent },
      ];
      setMessages(updatedHistory);
      if (target.role === "user") {
        // Called directly: the reply placeholder is appended via a functional
        // update queued after this setMessages. runGeneration also claims the
        // generation slot synchronously.
        void runGeneration(updatedHistory);
      }
    },
    [runGeneration],
  );

  const retryMessage = useCallback(
    (index: number) => {
      if (isGeneratingRef.current) return;
      // Discard the message at `index` and everything after it, then regenerate.
      const history = messagesRef.current.slice(0, index);
      setMessages(history);
      void runGeneration(history);
    },
    [runGeneration],
  );

  return (
    <LLMContext.Provider
      value={{
        status,
        messages,
        isGenerating,
        tps,
        thinkingMode,
        setThinkingMode,
        systemPrompt,
        setSystemPrompt,
        send,
        stop,
        clearChat,
        editMessage,
        retryMessage,
      }}
    >
      {children}
    </LLMContext.Provider>
  );
}