| import React, { useContext, useEffect, useState } from "react"; |
| import AppContext from "./hooks/createContext"; |
| import { ToolProps, QueueStatus } from "./helpers/Interfaces"; |
| import * as _ from "underscore"; |
| import { describeMask, describeMaskWithoutStreaming } from "../services/maskApi"; |
| import ErrorModal from './ErrorModal'; |
| import { DescriptionState } from "./Stage"; |
|
|
/**
 * Fixed prompt sent with every describe-mask request.
 * NOTE(review): the leading `<image>` line is presumably the vision model's image
 * placeholder token — confirm against the backend before editing.
 */
const prompt = ["<image>", "Describe the masked region in detail."].join("\n");
|
|
/** Tailwind classes for the semi-transparent mask overlay drawn on top of the image. */
const maskImageClasses = "absolute opacity-40 pointer-events-none";

/**
 * Space (px) subtracted from the window width/height before deciding how to fit the image.
 * NOTE(review): presumably reserved for page padding and surrounding UI — confirm.
 */
const FIT_MARGIN_X_PX = 64;
const FIT_MARGIN_Y_PX = 200;

/**
 * Return the payload of a data URL (everything after the first comma), or the
 * input unchanged when it contains no comma.
 */
const stripDataUrlPrefix = (dataUrl: string): string => {
  const commaIndex = dataUrl.indexOf(",");
  return commaIndex === -1 ? dataUrl : dataUrl.slice(commaIndex + 1);
};

/**
 * Image view that previews a mask over the image and, after a click, requests a
 * text description of the masked region.
 *
 * Description states set here: 'describing' while a request runs (with streamed
 * partial text when streaming), 'described' on success, and 'ready' plus an error
 * modal on failure. If the streaming API fails, the same request is retried once
 * with the non-streaming API, and streaming stays disabled for later requests.
 *
 * Props:
 * - handleMouseMove: called on mouse move, touch start and click over the image.
 * - descriptionState / setDescriptionState: clicks are ignored unless state is 'ready'.
 * - queueStatus: not read by this component.
 * - setQueueStatus: receives queue updates from the streaming API; reset to
 *   `{ inQueue: false }` when a request finishes or fails.
 */
const Tool = ({
  handleMouseMove,
  descriptionState,
  setDescriptionState,
  queueStatus,
  setQueueStatus,
}: ToolProps) => {
  const appContext = useContext(AppContext);
  if (!appContext) {
    throw new Error("Tool must be rendered inside an AppContext provider");
  }
  const {
    image: [image],
    maskImg: [maskImg, setMaskImg],
    maskImgData: [maskImgData, setMaskImgData],
    isClicked: [isClicked, setIsClicked],
  } = appContext;

  const [shouldFitToWidth, setShouldFitToWidth] = useState(true);

  // Decide whether the image is constrained by width or by height, and redo the
  // decision whenever the page body resizes. One observer per image, torn down on cleanup.
  useEffect(() => {
    if (!image) return;

    const fitToPage = () => {
      const maxWidth = window.innerWidth - FIT_MARGIN_X_PX;
      const maxHeight = window.innerHeight - FIT_MARGIN_Y_PX;
      const imageAspectRatio = image.width / image.height;
      const containerAspectRatio = maxWidth / maxHeight;
      setShouldFitToWidth(imageAspectRatio > containerAspectRatio || image.width > maxWidth);
    };

    fitToPage();
    const resizeObserver = new ResizeObserver(fitToPage);
    resizeObserver.observe(document.body);
    return () => resizeObserver.disconnect();
  }, [image]);

  const [error, setError] = useState<string | null>(null);
  // Once streaming fails, later requests go straight to the non-streaming API.
  const [useStreaming, setUseStreaming] = useState(true);

  useEffect(() => {
    if (!isClicked || !maskImg || !maskImgData || !image || descriptionState.state !== "ready") {
      return;
    }

    // Shared failure path: surface the error and return to a clickable state.
    const resetAfterFailure = (err: unknown) => {
      setError("Failed to generate description. Please try again.");
      setDescriptionState({
        state: "ready",
        description: "",
      } as DescriptionState);
      setQueueStatus({ inQueue: false });
      setIsClicked(false);
      console.error("Failed to describe mask:", err);
    };

    try {
      setDescriptionState({
        state: "describing",
        description: "",
      } as DescriptionState);

      // Re-encode the source image as JPEG so both inputs are sent as base64 payloads.
      const canvas = document.createElement("canvas");
      canvas.width = image.width;
      canvas.height = image.height;
      const context2d = canvas.getContext("2d");
      if (!context2d) {
        // Without a context the canvas stays blank; fail loudly instead of describing nothing.
        throw new Error("2D canvas context is unavailable");
      }
      context2d.drawImage(image, 0, 0);
      const imageBase64 = stripDataUrlPrefix(canvas.toDataURL("image/jpeg"));
      const maskBase64 = stripDataUrlPrefix(maskImgData);

      // `streaming` is this attempt's mode. The fallback must test it, not the
      // `useStreaming` state captured by this closure: that value stays `true`
      // during the retry, so testing it would retry non-streaming failures forever.
      const requestDescription = async (streaming: boolean): Promise<void> => {
        try {
          const result = streaming
            ? await describeMask(
                maskBase64,
                imageBase64,
                prompt,
                (streamResult: string) => {
                  setDescriptionState({
                    state: "describing",
                    description: streamResult,
                  } as DescriptionState);
                },
                (status: QueueStatus) => {
                  setQueueStatus(status);
                }
              )
            : await describeMaskWithoutStreaming(maskBase64, imageBase64, prompt);

          setDescriptionState({
            state: "described",
            description: result,
          } as DescriptionState);
          setQueueStatus({ inQueue: false });
          setIsClicked(false);
        } catch (err) {
          if (!streaming) {
            resetAfterFailure(err);
            return;
          }
          console.log("Error describing mask, switching to non-streaming", err);
          setUseStreaming(false);
          await requestDescription(false);
        }
      };

      // Errors are handled inside requestDescription; the promise is intentionally not awaited.
      void requestDescription(useStreaming);
    } catch (err) {
      resetAfterFailure(err);
    }
    // Deliberately keyed on maskImgData only; the guard above reads the other values
    // from the current render. NOTE(review): adding them as dependencies may re-trigger
    // requests on unrelated state changes — verify before changing.
  }, [maskImgData]);

  const handleClick = (e: React.MouseEvent<HTMLImageElement>) => {
    if (descriptionState.state !== "ready") return;
    // Clear the mask and flag the click; the describe effect runs when new mask data
    // arrives (presumably produced by handleMouseMove — confirm in the caller).
    setMaskImg(null);
    setMaskImgData(null);
    setIsClicked(true);
    handleMouseMove(e);
  };

  const sizeClass = shouldFitToWidth ? "w-full" : "h-full";

  return (
    <>
      {error && <ErrorModal message={error} onClose={() => setError(null)} />}
      <div className="relative flex items-center justify-center w-full h-full">
        {image && (
          <img
            onMouseMove={handleMouseMove}
            onMouseLeave={() =>
              _.defer(() => {
                // Keep the mask visible while a click is pending or a description is running.
                if (descriptionState.state === "ready" && !isClicked) setMaskImg(null);
              })
            }
            onTouchStart={handleMouseMove}
            onClick={handleClick}
            src={image.src}
            className={`${sizeClass} object-contain max-h-full max-w-full`}
          />
        )}
        {maskImg && (
          <img
            src={maskImg.src}
            alt=""
            className={`${sizeClass} ${maskImageClasses} object-contain max-h-full max-w-full`}
          />
        )}
      </div>
    </>
  );
};

export default Tool;
|
|