| import { useEffect, useState, useRef } from "react"; |
|
|
| import Chat from "./components/Chat"; |
| import ArrowRightIcon from "./components/icons/ArrowRightIcon"; |
| import StopIcon from "./components/icons/StopIcon"; |
| import Progress from "./components/Progress"; |
| import ImageIcon from "./components/icons/ImageIcon"; |
| import ImagePreview from "./components/ImagePreview"; |
|
|
/**
 * Whether the current runtime exposes the WebGPU API (`navigator.gpu`).
 *
 * The `typeof` guard keeps module evaluation from throwing a ReferenceError
 * in environments that have no `navigator` global at all (for example
 * server-side rendering or a Node-based test runner); there the flag is
 * simply `false`.
 */
const IS_WEBGPU_AVAILABLE =
  typeof navigator !== "undefined" && Boolean(navigator.gpu);

/**
 * Maximum distance (in pixels) between the current scroll position and the
 * bottom of the chat container for which the view keeps following new output.
 */
const STICKY_SCROLL_THRESHOLD = 120;

/**
 * Starter prompts offered while the conversation is still empty.
 *
 * - `prompt`:  text submitted as the user message (required).
 * - `display`: optional label shown instead of `prompt`.
 * - `image`:   optional image URL submitted together with the prompt.
 */
const EXAMPLES = [
  {
    display: "Generate an image of a cute baby fox.",
    prompt:
      "/imagine A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, Petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
  },
  {
    prompt: "Convert the formula into latex code.",
    image:
      "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/quadratic_formula.png",
  },
  {
    prompt: "What is the difference between AI and ML?",
  },
  {
    prompt: "Write python code to compute the nth fibonacci number.",
  },
];
|
|
| function App() { |
| |
// --- Refs -------------------------------------------------------------
// Holds the Web Worker instance once it has been created.
const worker = useRef(null);

// DOM handles: the prompt textarea, the scrollable chat container and the
// hidden file input used for image uploads.
const textareaRef = useRef(null);
const chatContainerRef = useRef(null);
const imageUploadRef = useRef(null);

// --- Model loading state ----------------------------------------------
// `status` is null until the worker reports back, then moves through
// "idle" / "loading" / "ready" as worker messages arrive.
const [status, setStatus] = useState(null);
const [error, setError] = useState(null);
const [loadingMessage, setLoadingMessage] = useState("");
// One entry per file the worker has reported progress for.
const [progressItems, setProgressItems] = useState([]);
// True from the moment a prompt is submitted until the worker reports "complete".
const [isRunning, setIsRunning] = useState(false);

// --- Conversation state -----------------------------------------------
const [input, setInput] = useState(""); // current textarea contents
const [image, setImage] = useState(null); // pending image attachment, if any
const [messages, setMessages] = useState([]); // { role, content, image? } turns
// Statistics reported by the worker for the latest generation.
const [tps, setTps] = useState(null);
const [numTokens, setNumTokens] = useState(null);
const [imageProgress, setImageProgress] = useState(null);
const [imageGenerationTime, setImageGenerationTime] = useState(null);
|
|
/**
 * Submit a user turn and put the UI into its "running" state.
 *
 * @param {string} message - Text of the user message.
 * @param {string} [img] - Image to attach; when nullish, the image currently
 *   held in the `image` state is attached instead.
 */
function onEnter(message, img) {
  const userTurn = { role: "user", content: message, image: img ?? image };
  setMessages((history) => [...history, userTurn]);

  // Mark the request as in flight and clear the composer.
  setIsRunning(true);
  setInput("");
  setImage(null);

  // Drop statistics left over from the previous generation.
  setTps(null);
  setNumTokens(null);
  setImageProgress(null);
  setImageGenerationTime(null);
}
|
|
/**
 * Post an "interrupt" message to the worker.
 *
 * `isRunning` is deliberately left untouched here: it is only cleared when
 * the worker itself reports "complete".
 * NOTE(review): presumably the worker aborts the generation in progress on
 * receiving this message - confirm against worker.js.
 */
function onInterrupt() {
  const { current: activeWorker } = worker;
  activeWorker.postMessage({ type: "interrupt" });
}
|
|
/**
 * Fit the prompt textarea to its content, keeping its height between
 * 24px and 200px. Does nothing while the textarea is not mounted.
 */
function resizeInput() {
  const field = textareaRef.current;
  if (!field) return;

  const MIN_HEIGHT = 24;
  const MAX_HEIGHT = 200;

  // Collapse first so that scrollHeight reflects the content rather than
  // the height applied on a previous call.
  field.style.height = "auto";
  const clamped = Math.max(MIN_HEIGHT, Math.min(field.scrollHeight, MAX_HEIGHT));
  field.style.height = `${clamped}px`;
}
|
|
// Re-fit the textarea every time its controlled value changes.
useEffect(
  function syncTextareaHeight() {
    resizeInput();
  },
  [input],
);
|
|
| |
// Create the worker (once) and translate its status messages into UI state.
useEffect(() => {
  // Reuse an existing worker if this effect runs again; a freshly created
  // worker is immediately sent a "check" request.
  if (!worker.current) {
    worker.current = new Worker(new URL("./worker.js", import.meta.url), {
      type: "module",
    });
    worker.current.postMessage({ type: "check" });
  }

  // Pin the instance so that cleanup detaches the listeners from exactly the
  // object they were attached to, regardless of what the ref holds later.
  const activeWorker = worker.current;

  /**
   * Merge a patch into the last message of the conversation.
   *
   * @param {(last: object) => object} getPatch - Receives the current last
   *   message and returns the properties to overwrite on it.
   */
  const patchLastMessage = (getPatch) => {
    setMessages((prev) => {
      // An update that arrives while the history is empty has nothing to
      // attach to; ignore it instead of throwing or corrupting the array.
      if (prev.length === 0) return prev;
      const last = prev.at(-1);
      return [...prev.slice(0, -1), { ...last, ...getPatch(last) }];
    });
  };

  // Dispatch on the `status` field of every message posted by the worker.
  const onMessageReceived = (e) => {
    const payload = e.data;

    switch (payload.status) {
      // Outcome of the initial "check" request.
      case "success": {
        setStatus("idle");
        break;
      }
      case "error": {
        setError(payload.data);
        break;
      }

      // Model loading has started; `data` carries the message to display.
      case "loading": {
        setStatus("loading");
        setLoadingMessage(payload.data);
        break;
      }

      // A new file is being tracked: add a progress entry for it.
      case "initiate": {
        setProgressItems((prev) => [...prev, payload]);
        break;
      }

      // Progress for a tracked file: merge the new fields into its entry.
      case "progress": {
        setProgressItems((prev) =>
          prev.map((item) =>
            item.file === payload.file ? { ...item, ...payload } : item,
          ),
        );
        break;
      }

      // A file is finished: drop its progress entry.
      case "done": {
        setProgressItems((prev) =>
          prev.filter((item) => item.file !== payload.file),
        );
        break;
      }

      // Everything is loaded; the chat UI can be shown.
      case "ready": {
        setStatus("ready");
        break;
      }

      // Generation started: append an empty assistant message to fill in.
      case "start": {
        setMessages((prev) => [...prev, { role: "assistant", content: "" }]);
        break;
      }

      // Streamed text: record the statistics and append the new text.
      case "text-update": {
        // Aliased so the locals do not shadow the `tps` / `numTokens` state.
        const { output, tps: nextTps, numTokens: nextNumTokens } = payload;
        setTps(nextTps);
        setNumTokens(nextNumTokens);
        patchLastMessage((last) => ({ content: last.content + output }));
        break;
      }

      // Image generation: either a finished image (blob) or a progress tick.
      case "image-update": {
        const { blob, progress, time } = payload;
        if (blob) {
          // Expose the blob through an object URL stored on the last message.
          const url = URL.createObjectURL(blob);
          patchLastMessage(() => ({ image: url }));
        } else {
          setImageProgress(progress);
          setImageGenerationTime(time);
        }
        break;
      }

      // Generation finished; re-enable the input controls.
      case "complete": {
        setIsRunning(false);
        break;
      }

      default:
        // Statuses this component does not know about are ignored.
        break;
    }
  };

  // Errors raised by the worker itself are only logged.
  const onErrorReceived = (e) => {
    console.error("Worker error:", e);
  };

  activeWorker.addEventListener("message", onMessageReceived);
  activeWorker.addEventListener("error", onErrorReceived);

  // Detach the listeners when the component unmounts or the effect re-runs.
  return () => {
    activeWorker.removeEventListener("message", onMessageReceived);
    activeWorker.removeEventListener("error", onErrorReceived);
  };
}, []);
|
|
| |
// Ask the worker for a reply whenever the conversation ends with a turn
// that is not from the assistant.
useEffect(() => {
  const hasUserTurn = messages.some(({ role }) => role === "user");
  // Short-circuits before touching `.at(-1)` on an empty history.
  const awaitingReply = hasUserTurn && messages.at(-1).role !== "assistant";
  if (!awaitingReply) return;

  setTps(null);
  worker.current.postMessage({ type: "generate", data: messages });
}, [messages, isRunning]);
|
|
// While a generation is running, keep the chat pinned to the bottom as long
// as the user has not scrolled further up than STICKY_SCROLL_THRESHOLD.
useEffect(() => {
  const container = chatContainerRef.current;
  if (!(container && isRunning)) return;

  const distanceFromBottom =
    container.scrollHeight - container.scrollTop - container.clientHeight;
  if (distanceFromBottom < STICKY_SCROLL_THRESHOLD) {
    container.scrollTop = container.scrollHeight;
  }
}, [messages, isRunning]);
|
|
| return IS_WEBGPU_AVAILABLE ? ( |
| <div className="flex flex-col h-screen mx-auto items justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900"> |
| {(status === null || status === "idle") && messages.length === 0 && ( |
| <div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative"> |
| <div className="flex flex-col items-center mb-1 max-w-[350px] text-center"> |
| <img |
| src="logo.png" |
| width="80%" |
| height="auto" |
| className="block" |
| ></img> |
| <h1 className="text-5xl font-bold mb-1">Janus WebGPU</h1> |
| <h2 className="font-semibold"> |
| A novel autoregressive framework for unified multimodal |
| understanding and generation. |
| </h2> |
| </div> |
| |
| <div className="flex flex-col items-center px-4"> |
| <p className="max-w-[452px] mb-4"> |
| <br /> |
| You are about to load{" "} |
| <a |
| href="https://huggingface.co/onnx-community/Janus-1.3B-ONNX" |
| target="_blank" |
| rel="noreferrer" |
| className="font-medium underline" |
| > |
| Janus-1.3B |
| </a> |
| , a multimodal vision-language model that is optimized for |
| inference on the web. Everything runs 100% locally in your browser |
| with{" "} |
| <a |
| href="https://huggingface.co/docs/transformers.js" |
| target="_blank" |
| rel="noreferrer" |
| className="underline" |
| > |
| 🤗 Transformers.js |
| </a>{" "} |
| and ONNX Runtime Web, meaning no data is sent to a server. Once |
| the model has loaded, it can even be used offline. The source code |
| for the demo can be found on{" "} |
| <a |
| href="https://github.com/huggingface/transformers.js-examples/tree/main/janus-webgpu" |
| target="_blank" |
| rel="noreferrer" |
| className="font-medium underline" |
| > |
| GitHub |
| </a> |
| . |
| </p> |
| |
| {error && ( |
| <div className="text-red-500 text-center mb-2"> |
| <p className="mb-1"> |
| Unable to load model due to the following error: |
| </p> |
| <p className="text-sm">{error}</p> |
| </div> |
| )} |
| |
| {!error && ( |
| <button |
| className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none" |
| onClick={() => { |
| worker.current.postMessage({ type: "load" }); |
| setStatus("loading"); |
| }} |
| disabled={status === null || status === "loading"} |
| > |
| {status === null ? "Running feature checks..." : "Load model"} |
| </button> |
| )} |
| </div> |
| </div> |
| )} |
| {status === "loading" && ( |
| <> |
| <div className="w-full max-w-[500px] text-left mx-auto p-4 bottom-0 mt-auto"> |
| <p className="text-center mb-1">{loadingMessage}</p> |
| {progressItems.map(({ file, progress, total }, i) => ( |
| <Progress |
| key={i} |
| text={file} |
| percentage={progress} |
| total={total} |
| /> |
| ))} |
| </div> |
| </> |
| )} |
|
|
| {status === "ready" && ( |
| <div |
| ref={chatContainerRef} |
| className="overflow-y-auto scrollbar-thin w-full flex flex-col items-center h-full" |
| > |
| <Chat messages={messages} /> |
| {messages.length === 0 && !image && ( |
| <div className="flex flex-col center"> |
| {EXAMPLES.map(({ display, prompt, image }, i) => ( |
| <div |
| key={i} |
| className="max-w-[600px] m-1 border dark:border-gray-600 rounded-md p-2 bg-gray-100 dark:bg-gray-700 cursor-pointer" |
| onClick={() => onEnter(prompt, image)} |
| > |
| {display ?? prompt} |
| </div> |
| ))} |
| </div> |
| )} |
| |
| <p className="text-center text-sm min-h-6 text-gray-500 dark:text-gray-300"> |
| {messages.length > 0 && ( |
| <> |
| {tps ? ( |
| <> |
| {!isRunning && ( |
| <span> |
| Generated {numTokens} tokens in{" "} |
| {(numTokens / tps).toFixed(2)} seconds ( |
| </span> |
| )} |
| <span className="font-medium font-mono text-center mr-1 text-black dark:text-white"> |
| {tps.toFixed(2)} |
| </span> |
| <span className="text-gray-500 dark:text-gray-300"> |
| tokens/second |
| </span> |
| {!isRunning && <span className="mr-1">).</span>} |
| </> |
| ) : ( |
| imageProgress && ( |
| <> |
| {isRunning ? ( |
| <> |
| <span>Generating image...</span> ( |
| <span className="font-medium font-mono text-center text-black dark:text-white"> |
| {(imageProgress * 100).toFixed(2)}% |
| </span> |
| <span className="mr-1">)</span> |
| </> |
| ) : ( |
| <span> |
| Generated image in{" "} |
| {(imageGenerationTime / 1000).toFixed(2)}{" "} |
| seconds. |
| </span> |
| )} |
| </> |
| ) |
| )} |
|
|
| {!isRunning && ( |
| <span |
| className="underline cursor-pointer" |
| onClick={() => setMessages([])} |
| > |
| Reset |
| </span> |
| )} |
| </> |
| )} |
| </p> |
| </div> |
| )} |
|
|
| <div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex"> |
| <label |
| htmlFor="file-upload" |
| className={ |
| status === "ready" |
| ? "cursor-pointer" |
| : "cursor-not-allowed pointer-events-none" |
| } |
| > |
| <ImageIcon |
| className={`h-8 w-8 p-1 rounded-md ${status === "ready" ? "text-gray-800 dark:text-gray-100" : "text-gray-400 dark:text-gray-500"} absolute bottom-3 left-1.5`} |
| ></ImageIcon> |
| <input |
| ref={imageUploadRef} |
| id="file-upload" |
| type="file" |
| accept="image/*" |
| className="hidden" |
| onInput={(e) => { |
| const file = e.target.files[0]; |
| if (!file) { |
| return; |
| } |
| |
| const reader = new FileReader(); |
| |
| // Set up a callback when the file is loaded |
| reader.onload = (e2) => { |
| setImage(e2.target.result); |
| e.target.value = ""; |
| }; |
| |
| reader.readAsDataURL(file); |
| }} |
| ></input> |
| </label> |
| <div className="w-full flex flex-col"> |
| {image && ( |
| <ImagePreview |
| onRemove={() => { |
| setImage(null); |
| }} |
| src={image} |
| className="w-20 h-20 min-w-20 min-h-20 relative p-2" |
| /> |
| )} |
| |
| <textarea |
| ref={textareaRef} |
| className="scrollbar-thin w-full pl-11 pr-12 dark:bg-gray-700 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-100 placeholder-gray-500 disabled:placeholder-gray-200 dark:placeholder-gray-300 dark:disabled:placeholder-gray-500 resize-none disabled:cursor-not-allowed" |
| placeholder="Type message or use '/imagine <prompt>' to generate an image." |
| type="text" |
| rows={1} |
| value={input} |
| disabled={status !== "ready"} |
| title={ |
| status === "ready" ? "Model is ready" : "Model not loaded yet" |
| } |
| onKeyDown={(e) => { |
| if ( |
| input.length > 0 && |
| !isRunning && |
| e.key === "Enter" && |
| !e.shiftKey |
| ) { |
| e.preventDefault(); // Prevent default behavior of Enter key |
| onEnter(input, image); |
| } |
| }} |
| onInput={(e) => setInput(e.target.value)} |
| /> |
| </div> |
| {isRunning ? ( |
| <div className="cursor-pointer" onClick={onInterrupt}> |
| <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" /> |
| </div> |
| ) : input.length > 0 ? ( |
| <div className="cursor-pointer" onClick={() => onEnter(input)}> |
| <ArrowRightIcon |
| className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`} |
| /> |
| </div> |
| ) : ( |
| <div> |
| <ArrowRightIcon |
| className={`h-8 w-8 p-1 bg-gray-200 dark:bg-gray-600 text-gray-50 dark:text-gray-800 rounded-md absolute right-3 bottom-3`} |
| /> |
| </div> |
| )} |
| </div> |
|
|
| <p className="text-xs text-gray-400 text-center mb-3"> |
| Disclaimer: Generated content may be inaccurate or false. |
| </p> |
| </div> |
| ) : ( |
| <div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center"> |
| WebGPU is not supported |
| <br /> |
| by this browser :( |
| </div> |
| ); |
| } |
|
|
| export default App; |
|
|