| |
| |
|
|
| |
| |
|
|
| import { InferenceSession, Tensor } from "onnxruntime-web"; |
| import React, { useContext, useEffect, useState, useRef } from "react"; |
| import axios from "axios"; |
| import "./assets/scss/App.scss"; |
| import { handleImageScale } from "./components/helpers/scaleHelper"; |
| import { modelScaleProps, QueueStatus } from "./components/helpers/Interfaces"; |
| import { onnxMaskToImage, arrayToImageData, imageDataToURL } from "./components/helpers/maskUtils"; |
| import { modelData } from "./components/helpers/onnxModelAPI"; |
| import Stage, { DescriptionState } from "./components/Stage"; |
| import AppContext from "./components/hooks/createContext"; |
| import { imageToSamEmbedding } from "./services/maskApi"; |
| import LoadingOverlay from "./components/LoadingOverlay"; |
| import ErrorModal from './components/ErrorModal'; |
| import QueueStatusIndicator from "./components/QueueStatusIndicator"; |
|
|
// Namespace import of the ONNX runtime, referenced below as `ort.Tensor`.
// An ES import (instead of CommonJS `require`) keeps the binding typed and
// consistent with the named import from the same package above.
import * as ort from "onnxruntime-web";

/** Path of the ONNX model file that App loads into an inference session on mount. */
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
|
/** Builds a fresh "ready" description state with an empty description. */
const createReadyDescription = (): DescriptionState => ({
  state: 'ready',
  description: '',
});

/**
 * Root component of the app.
 *
 * Responsibilities visible in this file:
 *  - loads the ONNX model from MODEL_DIR once on mount;
 *  - on image upload, scales the image, stores it in the shared context and
 *    requests its embedding through `imageToSamEmbedding`;
 *  - re-runs the ONNX model whenever `clicks` changes and publishes the
 *    resulting mask through the shared context;
 *  - renders the stage, the description box and its control buttons.
 */
const App = () => {
  const {
    clicks: [clicks, setClicks],
    image: [image, setImage],
    maskImg: [maskImg, setMaskImg],
    maskImgData: [maskImgData, setMaskImgData],
    isClicked: [isClicked, setIsClicked],
  } = useContext(AppContext)!;
  const [model, setModel] = useState<InferenceSession | null>(null);
  const [tensor, setTensor] = useState<Tensor | null>(null);
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
  const [isLoading, setIsLoading] = useState<boolean>(false);
  const [error, setError] = useState<string | null>(null);
  const [descriptionState, setDescriptionState] = useState<DescriptionState>(createReadyDescription());
  const [queueStatus, setQueueStatus] = useState<QueueStatus>({ inQueue: false });

  // Latest committed values of state that may change while an inference is
  // awaiting. A function created during one render only ever sees that
  // render's values, so checks performed after an `await` must read from
  // this ref to observe newer state.
  const latest = useRef({ descriptionState, isClicked, maskImgData });
  useEffect(() => {
    latest.current = { descriptionState, isClicked, maskImgData };
  });

  // Create the inference session once, when the component mounts.
  useEffect(() => {
    const initModel = async () => {
      try {
        if (MODEL_DIR === undefined) return;
        const session = await InferenceSession.create(MODEL_DIR);
        setModel(session);
      } catch (e) {
        // Without a model no mask can ever be produced, so tell the user
        // instead of only logging.
        setError('Failed to load the segmentation model. Please reload the page.');
        console.error('Error loading ONNX model:', e);
      }
    };
    void initModel();
  }, []);

  /**
   * Requests the embedding for `img` and stores it as a float32 tensor of
   * shape [1, 256, 64, 64].
   *
   * Failures are reported through the error modal and logged; this function
   * does not rethrow.
   */
  const fetchImageEmbedding = async (img: HTMLImageElement) => {
    // Drop the previous image's embedding so that no mask is computed
    // against it while (or if) the new request is pending or fails.
    setTensor(null);
    try {
      // Re-encode the image as JPEG via an off-screen canvas.
      const canvas = document.createElement('canvas');
      canvas.width = img.width;
      canvas.height = img.height;
      const ctx = canvas.getContext('2d');
      if (!ctx) {
        // Previously a missing context was skipped silently, which sent a
        // blank image to the service.
        throw new Error('2D canvas context is unavailable');
      }
      // Explicit destination size: without it drawImage uses the image's
      // intrinsic size, which crops whenever the canvas is smaller.
      ctx.drawImage(img, 0, 0, canvas.width, canvas.height);

      // Strip the "data:image/jpeg;base64," prefix.
      const base64Image = canvas.toDataURL('image/jpeg').split(',')[1];

      const samEmbedding = await imageToSamEmbedding(
        base64Image,
        (status: QueueStatus) => {
          setQueueStatus(status);
        }
      );

      // Decode the base64 payload into raw bytes.
      const binaryString = window.atob(samEmbedding);
      const byteCount = binaryString.length;
      const bytes = new Uint8Array(byteCount);
      for (let i = 0; i < byteCount; i++) {
        bytes[i] = binaryString.charCodeAt(i);
      }

      // Fail with a descriptive message when the payload cannot fill the
      // expected tensor shape.
      const embeddingDims = [1, 256, 64, 64];
      const expectedBytes =
        embeddingDims.reduce((product, dim) => product * dim, 1) * Float32Array.BYTES_PER_ELEMENT;
      if (bytes.byteLength !== expectedBytes) {
        throw new Error(
          `Unexpected embedding size: got ${bytes.byteLength} bytes, expected ${expectedBytes}`
        );
      }

      const embedding = new ort.Tensor(
        'float32',
        new Float32Array(bytes.buffer),
        embeddingDims
      );
      setTensor(embedding);
    } catch (error) {
      setQueueStatus({ inQueue: false });
      let errorMessage = 'Failed to process image. Please try again.';
      if (axios.isAxiosError(error)) {
        errorMessage = error.response?.data?.message || errorMessage;
      }
      setError(errorMessage);
      console.error('Error fetching embedding:', error);
    }
  };

  /**
   * Loads the image at `url`, applies the dimensions computed by
   * `handleImageScale`, publishes it and fetches its embedding.
   *
   * The returned promise settles only after the image has loaded and the
   * embedding request has finished. It rejects when the image cannot be
   * loaded or scaled. The loading overlay is always cleared.
   */
  const loadImage = async (url: URL) => {
    setIsLoading(true);
    try {
      const img = await new Promise<HTMLImageElement>((resolve, reject) => {
        const element = new Image();
        // Attach handlers before assigning `src` so no event can be missed.
        element.onload = () => resolve(element);
        element.onerror = () => reject(new Error(`Unable to load image from ${url.href}`));
        element.src = url.href;
      });

      const { height, width, samScale } = handleImageScale(img);
      setModelScale({ height, width, samScale });
      img.width = width;
      img.height = height;
      setImage(img);

      await fetchImageEmbedding(img);
    } finally {
      setIsLoading(false);
    }
  };

  /** File-input change handler: loads the first selected file as the working image. */
  const handleImageUpload = async (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0];
    if (!file) return;

    let objectUrl: string | undefined;
    try {
      objectUrl = URL.createObjectURL(file);
      await loadImage(new URL(objectUrl));
    } catch (error) {
      // loadImage only rejects before the image is published, so the object
      // URL is not referenced anywhere and can be released. On success it
      // must stay alive because the published image element points at it.
      if (objectUrl !== undefined) {
        URL.revokeObjectURL(objectUrl);
      }
      setError('Failed to load image. Please try again with a different image.');
      console.error('Error loading image:', error);
    }
  };

  /**
   * Runs the ONNX model for the current clicks and publishes the mask.
   *
   * Does nothing unless the description state is 'ready' and the model,
   * clicks, embedding tensor and model scale are all available.
   */
  const runONNX = async () => {
    try {
      if (descriptionState.state !== 'ready') return;

      console.log('Running ONNX model with:', {
        modelLoaded: model !== null,
        hasClicks: clicks !== null,
        hasTensor: tensor !== null,
        hasModelScale: modelScale !== null
      });

      if (
        model === null ||
        clicks === null ||
        tensor === null ||
        modelScale === null
      ) {
        console.log('Missing required inputs, returning early');
        return;
      }

      console.log('Preparing model feeds with:', {
        clicks,
        tensorShape: tensor.dims,
        modelScale
      });

      const feeds = modelData({
        clicks,
        tensor,
        modelScale,
      });

      if (feeds === undefined) {
        console.log('Model feeds undefined, returning early');
        return;
      }

      console.log('Running model with feeds:', feeds);
      const results = await model.run(feeds);
      console.log('Model run complete, got results:', results);

      const output = results[model.outputNames[0]];
      console.log('Processing output with dims:', output.dims);

      // Count positive mask values in place; copying the whole mask into
      // intermediate arrays on every run is needlessly expensive.
      // The cast only satisfies the type checker for numeric indexing.
      const maskValues = output.data as ArrayLike<number>;
      let maskArea = 0;
      for (let i = 0; i < maskValues.length; i++) {
        if (maskValues[i] > 0) maskArea++;
      }
      console.log('Mask area (number of non-zero pixels):', maskArea);

      // Re-check against the latest state: it may have changed while the
      // model was running, and this closure still holds the old values.
      const current = latest.current;
      if (current.descriptionState.state !== 'ready') return;
      if (current.isClicked && current.maskImgData != null) return;

      if (maskArea > 0) {
        setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3], false));
        setMaskImgData(imageDataToURL(arrayToImageData(output.data, output.dims[2], output.dims[3], true)));
      } else {
        console.warn('No mask area detected, clearing mask');
        setMaskImg(null);
      }

      console.log('Mask processing complete');
    } catch (e) {
      setError('Failed to process the image. Please try again.');
      console.error('Error running ONNX model:', e);
    }
  };

  // Recompute the mask whenever the clicks change. Only `clicks` is listed
  // as a dependency, so other state changes do not trigger a run.
  useEffect(() => {
    void runONNX();
  }, [clicks]);

  /** Returns to the 'ready' state so another region of the same image can be described. */
  const handleNewRegion = () => {
    setDescriptionState(createReadyDescription());
    setMaskImg(null);
    setIsClicked(false);
  };

  /** Copies the current description to the clipboard, reporting failures to the user. */
  const handleCopyDescription = async () => {
    try {
      // `navigator.clipboard` is undefined in insecure contexts and
      // writeText can reject; both cases end up in the catch below.
      await navigator.clipboard.writeText(descriptionState.description);
    } catch (e) {
      setError('Failed to copy the description to the clipboard.');
      console.error('Error copying description:', e);
    }
  };

  /** Clears the description, mask, image and clicks so a new image can be uploaded. */
  const handleReset = () => {
    setDescriptionState(createReadyDescription());
    setMaskImg(null);
    setImage(null);
    setClicks(null);
    setIsClicked(false);
  };

  return (
    <div className="flex flex-col h-screen">
      {isLoading && <LoadingOverlay />}
      {error && <ErrorModal message={error} onClose={() => setError(null)} />}
      <QueueStatusIndicator queueStatus={queueStatus} />
      <div className="flex-1">
        <Stage
          onImageUpload={handleImageUpload}
          descriptionState={descriptionState}
          setDescriptionState={setDescriptionState}
          queueStatus={queueStatus}
          setQueueStatus={setQueueStatus}
        />
      </div>
      <div className="description-container">
        <div className={`description-box ${descriptionState.state !== 'described' ? descriptionState.state : ''}`}>
          {descriptionState.description ? (
            descriptionState.description + (descriptionState.state === 'describing' ? '...' : '')
          ) : descriptionState.state === 'describing' ? (
            <em>Describing the region... (this may take a while if compute resources are busy)</em>
          ) : (
            image ? (
              <em>Click on the image to describe the region</em>
            ) : (
              <em>Upload an image to describe the region</em>
            )
          )}
        </div>
        <div className="description-controls">
          <button
            onClick={handleCopyDescription}
            disabled={descriptionState.state !== 'described'}
          >
            Copy description
          </button>
          <button
            onClick={handleNewRegion}
            disabled={descriptionState.state !== 'described'}
          >
            Describe a new region
          </button>
          <button
            onClick={handleReset}
            className="reset-button"
            disabled={descriptionState.state === 'describing' || !image}
          >
            Try a new image
          </button>
        </div>
      </div>
    </div>
  );
};

export default App;
|
|