Spaces:
Running
Running
| import { useState, useRef, useCallback, useLayoutEffect } from "react"; | |
| import { Send, Paperclip, Brain, ChevronDown, X, Plus } from "lucide-react"; | |
| import { | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| InterruptableStoppingCriteria, | |
| TextStreamer, | |
| } from "@huggingface/transformers"; | |
| import { Streamdown } from "streamdown"; | |
| import type { | |
| PreTrainedTokenizer, | |
| LlamaForCausalLM, | |
| } from "@huggingface/transformers"; | |
| import type React from "react"; | |
/** Hugging Face Hub repository of the ONNX export that is loaded in the browser. */
const MODEL_ID = "onnx-community/Baguettotron-ONNX";

/**
 * Selectable weight precisions.
 *
 * Keys are the `dtype` values handed to `from_pretrained`; values are the
 * labels shown in the precision dropdown (listed in this order).
 */
const DTYPES = {
  fp32: "FP32 (~1.28 GB)",
  fp16: "FP16 (~642 MB)",
  q4: "Q4 (~329 MB)",
  q4f16: "Q4F16 (~235 MB)",
} as const;

/** One of the supported precision keys of {@link DTYPES}. */
type Dtype = keyof typeof DTYPES;

/** Author of a chat message. */
type Role = "user" | "assistant";

/** Two or more consecutive newlines separate individual RAG sources. */
const SOURCE_SEPARATOR_REGEX = /\n{2,}/g;
/**
 * Splits raw RAG context into individual sources and wraps each one in
 * numbered `<source_N>…</source_N>` tags for the model prompt.
 *
 * Sources are separated by blank lines (`SOURCE_SEPARATOR_REGEX`); every
 * piece is trimmed and empty pieces are dropped.
 *
 * @param rawContext Text entered in the context box.
 * @returns `payload`: the tagged sources joined by newlines and prefixed with
 *   a blank line, or `""` when there are no sources. `count`: the number of
 *   sources. `segments`: the trimmed source texts, in order.
 */
const buildSourcesPayload = (rawContext: string) => {
  const segments: string[] = [];
  for (const piece of rawContext.trim().split(SOURCE_SEPARATOR_REGEX)) {
    const cleaned = piece.trim();
    if (cleaned) {
      segments.push(cleaned);
    }
  }

  if (!segments.length) {
    return { payload: "", count: 0, segments: [] };
  }

  const tagged = segments.map((text, i) => {
    const tag = `source_${i + 1}`;
    return `<${tag}>${text}</${tag}>`;
  });

  return { payload: `\n\n${tagged.join("\n")}`, count: segments.length, segments };
};
/**
 * Converts `<ref name="…">…</ref>` citation tags in model output into numbered
 * superscript markers such as `[1]`.
 *
 * Each distinct source name gets one label, assigned in order of first
 * appearance. The ref body becomes the marker's hover `title`, escaped for a
 * double-quoted HTML attribute.
 *
 * Text that is still streaming may end in an unfinished ref: either an opening
 * tag with no closing `</ref>` yet, or an opening tag that is itself still
 * partial (`<ref name="s`). Everything from that tag to the end is dropped so
 * partial markup is never shown.
 *
 * @param content Raw assistant answer text.
 * @returns The content with complete refs replaced by `<sup>` HTML snippets.
 */
const convertRefsToSuperscript = (content: string) => {
  const refRegex = /<ref name="([^"]+)">([\s\S]*?)<\/ref>/g;
  const refLabelMap = new Map<string, number>();

  const withMarkers = content.replace(
    refRegex,
    (_match: string, sourceName: string, refBody: string) => {
      let label = refLabelMap.get(sourceName);
      if (label === undefined) {
        // Labels are 1-based and increase only for previously unseen sources.
        label = refLabelMap.size + 1;
        refLabelMap.set(sourceName, label);
      }
      // Escape `&` first so the entities produced for `"` are not re-escaped.
      const escapedRefBody = refBody
        .replace(/&/g, "&amp;")
        .replace(/"/g, "&quot;");
      return `<sup className="cursor-pointer" title="${escapedRefBody}">[${label}]</sup>`;
    },
  );

  // Remove a trailing incomplete ref: `<ref` followed either by a finished
  // opening tag and an unterminated body, or by a still-partial opening tag.
  return withMarkers.replace(/<ref\b[^>]*(?:>[\s\S]*)?$/, "");
};
/**
 * Escapes angle brackets so user-typed markup is shown as text instead of
 * being interpreted as HTML.
 *
 * Only `<` and `>` are replaced; every other character, including `&`, passes
 * through unchanged.
 *
 * @param text Raw user input.
 * @returns The text with `<` replaced by `&lt;` and `>` replaced by `&gt;`.
 */
const sanitizeInput = (text: string) => {
  return text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
};
/**
 * One entry in the chat history.
 *
 * Assistant replies start as a loading placeholder that is filled in as tokens
 * stream in. `rawStream` and `isLoading` are removed once generation ends.
 */
interface Message {
  /** Used as the React key and exposed as `data-message-id` on the bubble. */
  id: number;
  /** Author of the message. */
  role: Role;
  /** Text shown in the bubble; for assistant replies, the part after `</think>`. */
  content: string;
  /** Assistant reasoning text: everything before `</think>`. */
  thinkTrace?: string;
  /** Epoch milliseconds at which the assistant reply was started. */
  timestamp?: number;
  /** Epoch milliseconds at which `</think>` was first seen in the stream. */
  thinkEndTime?: number;
  /** True while the assistant reply is still being generated. */
  isLoading?: boolean;
  /** Accumulated raw generated text; present only while streaming. */
  rawStream?: string;
}
/**
 * A simple, self-contained collapsible section.
 *
 * Clicking the title toggles the panel. The panel animates through a
 * `max-height` transition between 0 and the measured height of its content.
 *
 * @param title Content of the toggle button, rendered before the chevron.
 * @param children Panel content, shown in a dashed amber box when expanded.
 */
const Collapsible: React.FC<{
  title: React.ReactNode;
  children: React.ReactNode;
}> = ({ title, children }) => {
  const [isOpen, setIsOpen] = useState(false);
  const [contentHeight, setContentHeight] = useState(0);
  const contentRef = useRef<HTMLDivElement>(null);

  // Re-measure after every commit (no dependency array on purpose).
  // Reading scrollHeight during render returns the height of the *previous*
  // content, so a growing child such as streamed text was clipped by one
  // update. scrollHeight includes content hidden by max-height/overflow.
  // setState bails out when the value is unchanged, so re-measuring does not
  // re-render indefinitely.
  useLayoutEffect(() => {
    const el = contentRef.current;
    if (el) {
      setContentHeight(el.scrollHeight);
    }
  });

  return (
    <div className="collapsible mt-2">
      <button
        type="button"
        aria-expanded={isOpen}
        onClick={() => setIsOpen((open) => !open)}
        className="flex items-center space-x-1 text-xs font-medium text-amber-700 hover:text-amber-900 transition-colors"
      >
        {title}
        <ChevronDown
          size={14}
          className={`transform transition-transform ${isOpen ? "rotate-180" : "rotate-0"}`}
        />
      </button>
      <div
        ref={contentRef}
        style={{
          maxHeight: isOpen ? `${contentHeight}px` : "0px",
        }}
        className="overflow-hidden transition-all duration-300 ease-in-out"
      >
        <div className="mt-2 p-2 bg-amber-50 border border-dashed border-amber-200 rounded-md text-xs text-stone-600 prose-sm">
          {children}
        </div>
      </div>
    </div>
  );
};
/**
 * A single chat message bubble.
 *
 * User messages are right-aligned; assistant messages are left-aligned. When
 * an assistant reply has a thought trace, or is still loading, the trace is
 * shown in a collapsible section above the answer. The answer text is passed
 * through `convertRefsToSuperscript`, so `<ref>` citations become numbered
 * markers.
 *
 * @param message The message to render.
 * @param minHeight Optional min-height (px) for the message row.
 */
const MessageBubble: React.FC<{ message: Message; minHeight?: number }> = ({
  message,
  minHeight,
}) => {
  const { role, content, thinkTrace, isLoading, timestamp, thinkEndTime } =
    message;
  const isUser = role === "user";

  // Generation phases: first the think block streams (no end time recorded
  // yet), then the visible answer streams.
  const isThinking = Boolean(isLoading && !thinkEndTime);
  const isAnswering = Boolean(isLoading && thinkEndTime);

  const durationSeconds =
    typeof thinkEndTime === "number" && typeof timestamp === "number"
      ? Math.max(Math.round((thinkEndTime - timestamp) / 1000), 0)
      : null;

  let thinkingText: string;
  let opacityClass = "";
  if (isThinking) {
    thinkingText = "Thinking...";
    opacityClass = "opacity-70 hover:opacity-100";
  } else if (thinkTrace) {
    thinkingText =
      durationSeconds !== null
        ? `Thought for ${durationSeconds} second${durationSeconds === 1 ? "" : "s"}`
        : "Thought interrupted";
  } else {
    thinkingText = "Show Thoughts";
  }

  const markdownContent = convertRefsToSuperscript(content);

  return (
    <div
      data-message-id={message.id}
      data-role={role}
      className={`message flex items-start animate-in fade-in slide-in-from-bottom-2 duration-300 py-2 ${isUser ? "justify-end" : "justify-start"}`}
      style={{
        minHeight,
      }}
    >
      <div
        className={`max-w-xl lg:max-w-2xl px-4 py-3 rounded-2xl ${
          isUser
            ? "bg-amber-500 text-white rounded-br-none"
            : "bg-white text-stone-800 rounded-bl-none shadow-sm border border-stone-200"
        }`}
      >
        {(thinkTrace || isLoading) && (
          <Collapsible
            title={
              <div className="flex items-center space-x-1.5 text-sm">
                <Brain size={16} />
                <span
                  className={`${isLoading ? "animate-glisten" : ""} ${opacityClass}`}
                >
                  {thinkingText}
                </span>
              </div>
            }
          >
            {/* The trace is the part being streamed while thinking. */}
            <Streamdown
              parseIncompleteMarkdown={false}
              className="text-xs text-stone-500"
              isAnimating={isThinking}
            >
              {thinkTrace || (isLoading ? "..." : "")}
            </Streamdown>
          </Collapsible>
        )}
        <div className={`${thinkTrace || isLoading ? "mt-2" : ""}`}>
          {/* The answer streams only after the think block has closed. */}
          <Streamdown
            parseIncompleteMarkdown={false}
            className="text-sm leading-relaxed text-stone-800"
            isAnimating={isAnswering}
          >
            {markdownContent}
          </Streamdown>
        </div>
      </div>
    </div>
  );
};
/**
 * Manages loading of the model and tokenizer.
 *
 * The loaded instances live in refs; `modelStatus` and `loadProgress` are
 * React state for the UI.
 *
 * @returns Status and progress state, the model and tokenizer refs, and
 *   `loadModel(dtype)`. `loadModel` downloads both with the given precision
 *   and places the model on WebGPU. It is a no-op (apart from reporting
 *   "ready") once both are loaded, and it ignores calls while a load is
 *   already in flight.
 */
const useLLM = () => {
  const [modelStatus, setModelStatus] = useState<
    "idle" | "loading" | "ready" | "error"
  >("idle");
  const [loadProgress, setLoadProgress] = useState(0);
  const modelRef = useRef<LlamaForCausalLM | null>(null);
  const tokenizerRef = useRef<PreTrainedTokenizer | null>(null);
  // Synchronous in-flight guard. Unlike the `modelStatus` captured in a
  // closure, it is updated immediately, so a second call made before the next
  // render is still rejected. It also lets `loadModel` keep a stable identity.
  const isLoadingRef = useRef(false);

  const loadModel = useCallback(async (dtype: Dtype) => {
    if (modelRef.current && tokenizerRef.current) {
      setModelStatus("ready");
      setLoadProgress(100);
      return;
    }
    if (isLoadingRef.current) return;
    isLoadingRef.current = true;
    setModelStatus("loading");
    setLoadProgress(0);

    // Only download-progress events for `.onnx_data` files update the
    // percentage. NOTE(review): presumably these are the external weight
    // files; if a dtype ships without one, progress stays at 0 until loading
    // finishes — confirm.
    const progress_callback = (info: unknown) => {
      if (typeof info !== "object" || info === null) return;
      const { status, file, loaded, total } = info as {
        status?: unknown;
        file?: unknown;
        loaded?: unknown;
        total?: unknown;
      };
      if (
        status === "progress" &&
        typeof total === "number" &&
        total > 0 && // avoid NaN / Infinity percentages
        typeof loaded === "number" &&
        typeof file === "string" &&
        file.endsWith(".onnx_data")
      ) {
        const percentage = Math.round((loaded / total) * 100);
        setLoadProgress(Math.min(Math.max(percentage, 0), 100));
      }
    };

    try {
      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
        progress_callback,
      });
      const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
        dtype,
        device: "webgpu",
        progress_callback,
      });
      tokenizerRef.current = tokenizer;
      modelRef.current = model;
      setLoadProgress(100);
      setModelStatus("ready");
    } catch (error) {
      console.error("Failed to load model", error);
      setModelStatus("error");
    } finally {
      isLoadingRef.current = false;
    }
  }, []);

  return {
    modelStatus,
    loadProgress,
    modelRef,
    tokenizerRef,
    loadModel,
  };
};
/**
 * Root component of the Baguettotron WebGPU chat demo.
 *
 * Before the model is ready, it shows a landing screen with a precision
 * (dtype) picker and a load button. Once loaded, it renders the conversation,
 * example prompts, and a composer with an optional RAG context box. Replies
 * are generated locally and streamed token by token into the latest
 * assistant message.
 */
const App: React.FC = () => {
  const [messages, setMessages] = useState<Message[]>([]);
  const [currentInput, setCurrentInput] = useState("");
  const [context, setContext] = useState("");
  const [showContext, setShowContext] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [lastMessageMinHeight, setLastMessageMinHeight] = useState<
    number | undefined
  >(undefined);
  const [selectedDtype, setSelectedDtype] = useState<Dtype>("fp16");
  const [dtypeMenuOpen, setDtypeMenuOpen] = useState(false);
  const stoppingCriteriaRef = useRef<InterruptableStoppingCriteria | null>(
    null,
  );
  // Exact prompt sent to the model for each user turn, keyed by message id.
  // A user message's `content` is display text: it is escaped and only
  // previews the sources. It must therefore not be replayed to the model as
  // conversation history.
  const modelPromptsRef = useRef<Map<number, string>>(new Map());
  const mainRef = useRef<HTMLDivElement>(null);
  const { modelStatus, loadProgress, modelRef, tokenizerRef, loadModel } =
    useLLM();

  useLayoutEffect(() => {
    if (!mainRef.current) return;
    const el = mainRef.current;
    // If the last message is from the assistant, calculate a min-height to prevent layout shifts.
    if (messages.at(-1)?.role === "assistant") {
      const userMessageElement = el.querySelector<HTMLDivElement>(
        `[data-message-id="${messages.at(-2)?.id}"]`,
      );
      if (userMessageElement) {
        const userMessageHeight =
          userMessageElement.getBoundingClientRect().height;
        const screenHeight = window.innerHeight;
        // NOTE(review): 270px presumably reserves room for header/footer chrome — confirm.
        const newMinHeight = Math.max(
          screenHeight - userMessageHeight - 270,
          0,
        );
        setLastMessageMinHeight(newMinHeight);
      }
    } else {
      setLastMessageMinHeight(undefined);
    }
    // Keyed on the message count only, so streamed token updates do not re-measure.
  }, [messages.length]);

  // Smooth-scroll to the bottom when a message is added or the min-height changes.
  useLayoutEffect(() => {
    if (mainRef.current) {
      const el = mainRef.current;
      setTimeout(() => {
        el.scrollTo({
          top: el.scrollHeight,
          behavior: "smooth",
        });
      }, 0);
    }
  }, [messages.length, lastMessageMinHeight]);

  /**
   * Appends a streamed token to the last assistant message and re-splits its
   * raw text into the think trace (before `</think>`) and the visible answer
   * (after it). Records the time `</think>` is first seen.
   */
  const handleStreamUpdate = useCallback((newToken: string) => {
    setMessages((prev) => {
      const previousLast = prev.at(-1);
      if (!previousLast || previousLast.role === "user") {
        return prev;
      }
      const lastMessage = { ...previousLast };
      lastMessage.rawStream = (lastMessage.rawStream || "") + newToken;
      const raw = lastMessage.rawStream;
      const thinkEndTag = "</think>";
      const thinkEndIndex = raw.indexOf(thinkEndTag);
      // Until `</think>` arrives, everything streamed so far is thought.
      let content = "";
      let thinkTrace = raw;
      if (thinkEndIndex !== -1) {
        // Think block is complete.
        thinkTrace = raw.substring(0, thinkEndIndex);
        // The streamer keeps special tokens, so strip the end-of-turn markers
        // wherever they occur, not only as one exact adjacent pair.
        content = raw
          .substring(thinkEndIndex + thinkEndTag.length)
          .replace(/<\|im_end\|>|<\|end_of_text\|>/g, "");
        if (!lastMessage.thinkEndTime) {
          lastMessage.thinkEndTime = Date.now();
        }
      }
      lastMessage.content = content.trim();
      lastMessage.thinkTrace = thinkTrace.trim();
      return [...prev.slice(0, -1), lastMessage];
    });
  }, []);

  const handleStopGeneration = useCallback(() => {
    stoppingCriteriaRef.current?.interrupt();
  }, []);

  /**
   * Runs generation for `historyForModel` and streams tokens into the
   * assistant placeholder.
   *
   * All work sits inside `try`, so the cleanup in `finally` always runs:
   * it clears `isLoading` and drops the streaming-only fields from the
   * placeholder. This holds even if the model is missing or templating throws.
   */
  const streamAssistantResponse = useCallback(
    async (
      historyForModel: { role: Role; content: string }[],
      assistantMessageId: number,
    ) => {
      try {
        const tokenizer = tokenizerRef.current;
        const model = modelRef.current;
        if (!tokenizer || !model) {
          console.error("Model or tokenizer is not loaded");
          return;
        }
        const inputs = tokenizer.apply_chat_template(historyForModel, {
          add_generation_prompt: true,
          return_dict: true,
        }) as any;
        const streamer = new TextStreamer(tokenizer, {
          skip_prompt: true,
          skip_special_tokens: false,
          callback_function: (token: string) => handleStreamUpdate(token),
        });
        const stoppingCriteria = new InterruptableStoppingCriteria();
        stoppingCriteriaRef.current = stoppingCriteria;
        await model.generate({
          ...inputs,
          max_new_tokens: 2048,
          streamer,
          stopping_criteria: stoppingCriteria,
          repetition_penalty: 1.2,
        });
      } catch (error) {
        console.error(error);
      } finally {
        stoppingCriteriaRef.current = null;
        setIsLoading(false);
        setMessages((prev) =>
          prev.map((msg) => {
            if (msg.id === assistantMessageId) {
              const { rawStream: _rawStream, isLoading: _isLoading, ...rest } =
                msg;
              return rest;
            }
            return msg;
          }),
        );
      }
    },
    [handleStreamUpdate, modelRef, tokenizerRef],
  );

  const handleSubmit = async (
    e?: React.FormEvent,
    prompt?: string,
    sources?: string,
  ) => {
    if (e) e.preventDefault();
    if (isLoading || modelStatus !== "ready") return;
    const input = prompt || currentInput;
    if (!input.trim()) return;
    const trimmedContext = (sources || context).trim();
    const {
      payload: sourcesPayload,
      count: sourceCount,
      segments: sourceSegments,
    } = buildSourcesPayload(trimmedContext);
    const fullPrompt = `${input}${sourcesPayload}`;
    let userMessageContent = sanitizeInput(input);
    if (sourceCount > 0) {
      // The previews are rendered as markdown just like the prompt, so escape
      // them too. Truncate first so an escape sequence is never cut in half.
      const sourcesList = sourceSegments
        .map(
          (seg, i) =>
            `${i + 1}. ${sanitizeInput(seg.substring(0, 75))}${seg.length > 75 ? "..." : ""}`,
        )
        .join("\n");
      userMessageContent += `\n\n[Source${sourceCount > 1 ? "s" : ""}]:\n${sourcesList}`;
    }
    const userMessage: Message = {
      id: messages.length,
      role: "user",
      content: userMessageContent,
    };
    const assistantPlaceholder: Message = {
      id: messages.length + 1,
      role: "assistant",
      content: "",
      thinkTrace: "",
      rawStream: "",
      isLoading: true,
      timestamp: Date.now(),
    };
    // Replay earlier user turns with the prompt the model actually saw.
    const historyForModel: { role: Role; content: string }[] = [
      ...messages.map(({ id, role, content }) => ({
        role,
        content:
          role === "user"
            ? (modelPromptsRef.current.get(id) ?? content)
            : content,
      })),
      { role: "user", content: fullPrompt },
    ];
    modelPromptsRef.current.set(userMessage.id, fullPrompt);

    setMessages((prev) => [...prev, userMessage, assistantPlaceholder]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(true);
    setLastMessageMinHeight(undefined);
    await streamAssistantResponse(historyForModel, assistantPlaceholder.id);
  };

  const handleNewChat = () => {
    // Only request a stop here. `isLoading` is cleared by
    // streamAssistantResponse's cleanup once generate() has actually
    // returned. Clearing it here would let a new prompt start a second
    // generation while the interrupted one is still finishing.
    handleStopGeneration();
    setMessages([]);
    modelPromptsRef.current.clear();
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setLastMessageMinHeight(undefined);
  };

  return (
    <div className="flex flex-col h-screen bg-amber-50 font-sans text-stone-800">
      {modelStatus === "ready" && (
        <header className="flex-shrink-0 sticky top-0 z-10 flex items-center justify-between p-4 bg-white/90 backdrop-blur-md shadow-sm border-b border-amber-200 h-[100px]">
          <button
            onClick={handleNewChat}
            className="p-2 rounded-full text-stone-500 hover:text-amber-600 hover:bg-amber-50 transition-colors"
            title="New Chat"
          >
            <Plus size={20} />
          </button>
          <div className="flex-1 text-center">
            <h1 className="text-2xl md:text-3xl font-serif font-bold text-amber-800">
              🥖 Baguettotron WebGPU
            </h1>
            <p className="text-sm text-stone-600">
              A small but powerful reasoning model
            </p>
          </div>
        </header>
      )}
      <main ref={mainRef} className="flex-grow overflow-y-auto">
        <div className="mx-auto w-full max-w-6xl p-4 md:p-6 space-y-2 h-full">
          {modelStatus !== "ready" ? (
            <div className="flex h-full flex-col items-center justify-center gap-6 text-center text-stone-600">
              <span className="text-8xl animate-wobble">🥖</span>
              <div>
                <h1 className="text-5xl font-bold text-amber-800">
                  Baguettotron WebGPU
                </h1>
                <p className="mt-4 max-w-xl text-md">
                  You are about to load Baguettotron, a 300M parameter reasoning
                  model optimized for in-browser inference. Everything runs
                  entirely in your browser with 🤗 Transformers.js and ONNX
                  Runtime Web, meaning no data is sent to a server. Once loaded,
                  it can even be used offline.
                </p>
              </div>
              <div className="relative inline-flex rounded-full shadow-sm">
                <button
                  onClick={() => loadModel(selectedDtype)}
                  disabled={modelStatus === "loading"}
                  className="rounded-l-full bg-amber-600 pl-6 pr-5 py-3 text-white font-medium transition hover:bg-amber-700 disabled:opacity-50 disabled:cursor-not-allowed"
                >
                  {modelStatus === "loading"
                    ? `Loading ${loadProgress}%`
                    : `Load model (${selectedDtype})`}
                </button>
                <button
                  onClick={() => setDtypeMenuOpen(!dtypeMenuOpen)}
                  disabled={modelStatus === "loading"}
                  className="rounded-r-full bg-amber-600 px-3 py-3 text-white transition hover:bg-amber-700 disabled:opacity-50 border-l border-amber-500"
                >
                  <ChevronDown
                    size={20}
                    className={`transform transition-transform ${dtypeMenuOpen ? "rotate-180" : ""}`}
                  />
                </button>
                {dtypeMenuOpen && (
                  <div className="absolute top-full mt-2 w-full bg-white rounded-md shadow-lg z-10 border border-stone-200">
                    {Object.entries(DTYPES).map(([dtype, label]) => (
                      <button
                        key={dtype}
                        onClick={() => {
                          setSelectedDtype(dtype as Dtype);
                          setDtypeMenuOpen(false);
                        }}
                        className="w-full text-left px-4 py-2 text-sm text-stone-700 hover:bg-amber-50"
                      >
                        {label}
                      </button>
                    ))}
                  </div>
                )}
              </div>
              {modelStatus === "error" && (
                <p className="text-sm text-red-600">
                  Model load failed. Check console for details and retry.
                </p>
              )}
            </div>
          ) : (
            <>
              {messages.length === 0 && (
                <div className="flex flex-col items-center justify-center h-full text-center text-stone-500">
                  <div className="p-8 rounded-2xl flex flex-col items-center">
                    <h2 className="text-3xl font-semibold mt-4 text-stone-700">
                      Welcome to Baguettotron
                    </h2>
                    <h3 className="max-w-xs mt-1 text-lg">
                      Ask me a question, or try one of the examples below!
                    </h3>
                  </div>
                  <div className="mt-2 flex flex-wrap justify-center gap-4">
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "What is the capital of France? Just provide the answer.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      Encyclopedic knowledge
                    </button>
                    {["fp32", "fp16"].includes(selectedDtype) && (
                      <button
                        onClick={() =>
                          handleSubmit(
                            undefined,
                            "Write me a short poem about machine learning.",
                          )
                        }
                        className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                      >
                        Creative writing
                      </button>
                    )}
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "Which is wider: Australia or the Moon?",
                          "Australia is approximately 4,000 km in width from east to west, according to Geoscience Australia.\n\nThe diameter of the Moon is about 3,476 km, according to Britannica.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      RAG with grounding
                    </button>
                  </div>
                </div>
              )}
              {messages.map((msg, index) => {
                const isLastAssistantMessage =
                  index === messages.length - 1 && msg.role === "assistant";
                const minHeight = isLastAssistantMessage
                  ? lastMessageMinHeight
                  : undefined;
                return (
                  <MessageBubble
                    key={msg.id}
                    message={msg}
                    minHeight={minHeight}
                  />
                );
              })}
            </>
          )}
        </div>
      </main>
      {modelStatus === "ready" && (
        <footer className="flex-shrink-0 sticky bottom-0 z-10 p-4 bg-white/90 backdrop-blur-md border-t border-amber-100">
          <form onSubmit={handleSubmit} className="max-w-3xl mx-auto">
            <div
              style={{
                maxHeight: showContext ? "120px" : "0px",
                transition: "max-height 0.3s ease-in-out",
                opacity: showContext ? 1 : 0,
              }}
              className="overflow-hidden relative"
            >
              <textarea
                value={context}
                onChange={(e) => setContext(e.target.value)}
                disabled={isLoading}
                placeholder="Add RAG context here. Separate multiple sources with two new lines."
                className="w-full h-28 p-2 mb-2 rounded-lg border border-stone-300 focus:ring-amber-500 focus:border-amber-500 text-sm resize-none"
              />
              <button
                type="button"
                onClick={() => setShowContext(false)}
                className="absolute top-2 right-2 p-1 text-stone-400 hover:text-stone-600 bg-white/50 rounded-full"
              >
                <X size={16} />
              </button>
            </div>
            <div className="flex items-center space-x-2">
              <button
                type="button"
                onClick={() => setShowContext(!showContext)}
                title="Add Context for RAG"
                className={`flex-shrink-0 p-2 rounded-full transition-colors ${
                  showContext
                    ? "bg-amber-100 text-amber-700"
                    : "text-stone-500 hover:text-amber-600 hover:bg-amber-50"
                }`}
              >
                <Paperclip size={20} />
              </button>
              <input
                type="text"
                value={currentInput}
                onChange={(e) => setCurrentInput(e.target.value)}
                placeholder="Send a message..."
                className="flex-grow px-4 py-2 rounded-full border border-stone-300 focus:ring-2 focus:ring-amber-500 focus:border-transparent outline-none transition-shadow"
                disabled={isLoading || modelStatus !== "ready"}
              />
              {isLoading ? (
                <button
                  type="button"
                  onClick={handleStopGeneration}
                  className="group flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full border border-stone-300 bg-white text-stone-600 hover:border-red-500"
                >
                  <span className="h-3.5 w-3.5 rounded-sm bg-stone-600 transition-colors group-hover:bg-red-500" />
                </button>
              ) : (
                <button
                  type="submit"
                  disabled={
                    isLoading || !currentInput.trim() || modelStatus !== "ready"
                  }
                  className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full bg-amber-600 text-white transition-all transform
                    hover:bg-amber-700 hover:scale-105 active:scale-95
                    disabled:bg-stone-300 disabled:scale-100 disabled:cursor-not-allowed"
                >
                  <Send size={20} />
                </button>
              )}
            </div>
            <p className="text-center text-xs text-stone-400 mt-2">
              ⚡ Powered by{" "}
              <a
                href="https://github.com/huggingface/transformers.js"
                target="_blank"
                rel="noopener noreferrer"
              >
                Transformers.js
              </a>{" "}
              — Runs locally in your browser on WebGPU.
            </p>
          </form>
        </footer>
      )}
    </div>
  );
};
| export default App; | |