| | |
| | import init, { Model } from "./build/m.js"; |
| |
|
/**
 * Downloads `url` and returns its body as a Uint8Array.
 *
 * When `cacheModel` is true (the default) the response is looked up in, and on
 * a miss stored into, the "sam-candle-cache" Cache Storage bucket, so repeated
 * requests for the same URL are served without hitting the network.
 *
 * @param {string} url - Resource to fetch.
 * @param {boolean} [cacheModel=true] - Whether to read from / write to Cache Storage.
 * @returns {Promise<Uint8Array>} The response body bytes.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject non-2xx responses so error pages are never treated as data
  // (and, on the cached path, never persisted in Cache Storage).
  const fetchOk = async (init) => {
    const res = await fetch(url, init);
    if (!res.ok) {
      throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
    }
    return res;
  };

  if (!cacheModel) {
    const res = await fetchOk();
    return new Uint8Array(await res.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    return new Uint8Array(await cachedResponse.arrayBuffer());
  }

  const res = await fetchOk({ cache: "force-cache" });
  // Clone before reading: a Response body can only be consumed once.
  const copy = res.clone();
  const [body] = await Promise.all([
    res.arrayBuffer(),
    // Caching is best-effort (e.g. quota exceeded): report it, don't fail the load.
    cache
      .put(url, copy)
      .catch((err) => console.warn(`Could not cache ${url}:`, err)),
  ]);
  return new Uint8Array(body);
}
/**
 * Static registry of SAM models loaded in this worker.
 *
 * Each model is created once per `modelID` and reused. For each model the
 * fingerprint of the image whose embeddings were last set is remembered, so a
 * repeated request for the same image can skip `set_image_embeddings`.
 */
class SAMModel {
  // modelID -> Model constructed from that model's weights.
  static instance = {};

  // modelID -> getSimpleHash() fingerprint of the image last embedded by that model.
  static imageArrayHash = {};

  // ID passed to the most recent getInstance() call; setImageEmbeddings() acts on it.
  static currentModelID = null;

  /**
   * Returns the model registered under `modelID`, creating it on first request.
   *
   * On first use this runs the wasm `init()`, downloads the weights from
   * `modelURL` via fetchArrayBuffer (Cache Storage backed) and constructs a
   * `Model`. In every case `modelID` becomes the current model.
   *
   * @param {string} modelURL - URL of the model weights.
   * @param {string} modelID - Registry key; IDs matching /tiny|mobile/ pass
   *   `true` as the second `Model` constructor argument.
   * @returns {Promise<Model>} The model for `modelID`.
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      // NOTE(review): the flag presumably selects the lightweight encoder used
      // by tiny/mobile checkpoints — confirm against the Rust bindings.
      this.instance[modelID] = new Model(
        weightsArrayU8,
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }

    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  /**
   * Computes image embeddings on the current model, unless the same image
   * (by fingerprint) was the last one successfully embedded by that model.
   *
   * @param {Uint8Array} imageArrayU8 - Image bytes to embed.
   * @throws Whatever `set_image_embeddings` throws; a TypeError if no model is
   *   loaded for `currentModelID`.
   */
  static setImageEmbeddings(imageArrayU8) {
    const modelID = this.currentModelID;
    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    if (
      this.imageArrayHash[modelID] === imageArrayHash &&
      this.instance[modelID]
    ) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }
    this.instance[modelID].set_image_embeddings(imageArrayU8);
    // Record the fingerprint only after success, so a failed attempt is not
    // mistaken for already-computed embeddings on the next request.
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Fingerprints a byte array so identical images can be recognised cheaply.
   *
   * Every byte contributes: 32-bit FNV-1a over the whole array, prefixed with
   * the length. Sparse sampling or a tiny hash space would let different
   * images collide and silently reuse stale embeddings.
   *
   * @param {Uint8Array} imageArrayU8 - Bytes to fingerprint.
   * @returns {string} `<length hex>-<hash hex>`.
   */
  static getSimpleHash(imageArrayU8) {
    let hash = 0x811c9dc5; // FNV-1a 32-bit offset basis
    for (let i = 0; i < imageArrayU8.length; i++) {
      hash ^= imageArrayU8[i];
      hash = Math.imul(hash, 0x01000193); // FNV 32-bit prime, mod 2^32
    }
    return `${imageArrayU8.length.toString(16)}-${(hash >>> 0).toString(16)}`;
  }
}
| |
|
/**
 * Renders a segmentation mask as a black overlay, stretched to the original
 * image size, and returns it as an object URL for the encoded canvas blob.
 *
 * Each mask pixel becomes an RGBA pixel (0, 0, 0, mask_data[i] * 255).
 * NOTE(review): the URL is never revoked here; presumably the receiver calls
 * URL.revokeObjectURL when done — confirm.
 *
 * @param {{mask_shape: number[], mask_data: ArrayLike<number>}} mask
 * @param {{original_width: number, original_height: number, width: number, height: number}} image
 * @returns {Promise<string>} A blob: URL for the rendered overlay.
 */
async function createImageCanvas(
  { mask_shape, mask_data },
  { original_width, original_height, width, height }
) {
  // NOTE(review): indices 2 and 3 are read as width and height. If the mask
  // tensor is laid out (N, C, H, W) these are swapped — harmless for square
  // masks; confirm against the Rust side.
  const [, , maskWidth, maskHeight] = mask_shape;
  const maskCanvas = new OffscreenCanvas(maskWidth, maskHeight);
  const maskCtx = maskCanvas.getContext("2d");
  const outputCanvas = new OffscreenCanvas(original_width, original_height);
  const outputCtx = outputCanvas.getContext("2d");

  // Paint the mask into an RGBA buffer: black colour, alpha from the mask value.
  const maskImage = maskCtx.createImageData(
    maskCanvas.width,
    maskCanvas.height
  );
  const pixels = maskImage.data;
  const pixelCount = pixels.length / 4;
  for (let i = 0; i < pixelCount; i++) {
    const offset = i * 4;
    pixels[offset] = 0;
    pixels[offset + 1] = 0;
    pixels[offset + 2] = 0;
    pixels[offset + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(maskImage, 0, 0);

  // Crop the mask to the image's aspect ratio: the longer side spans the whole
  // mask, the shorter side only its leading fraction (assumes the mask covers
  // the image padded to a square from the top-left — TODO confirm).
  const isLandscape = original_height < original_width;
  const sx = isLandscape ? 1 : original_width / original_height;
  const sy = isLandscape ? original_height / original_width : 1;
  outputCtx.drawImage(
    maskCanvas,
    0,
    0,
    maskCanvas.width * sx,
    maskCanvas.height * sy,
    0,
    0,
    original_width,
    original_height
  );

  const blob = await outputCanvas.convertToBlob();
  return URL.createObjectURL(blob);
}
| |
|
// Worker request handler. Each message carries { modelURL, modelID, imageURL, points }.
// Progress is reported as { status, message }. The final reply is one of:
// "complete-embedding" (no points given), "complete" with output.maskURL,
// or { error } on failure.
self.addEventListener("message", async (event) => {
  const { modelURL, modelID, imageURL, points } = event.data;
  try {
    self.postMessage({ status: "loading", message: "Starting SAM" });
    const sam = await SAMModel.getInstance(modelURL, modelID);

    self.postMessage({ status: "loading", message: "Loading Image" });
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);

    self.postMessage({ status: "embedding", message: "Creating Embeddings" });
    SAMModel.setImageEmbeddings(imageArrayU8);
    if (!points) {
      // Embedding-only request: nothing to segment yet.
      self.postMessage({
        status: "complete-embedding",
        message: "Embeddings Complete",
      });
      return;
    }

    self.postMessage({ status: "segmenting", message: "Segmenting" });
    const { mask, image } = sam.mask_for_point({ points });
    const maskDataURL = await createImageCanvas(mask, image);

    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL: maskDataURL },
    });
  } catch (e) {
    try {
      self.postMessage({ error: e });
    } catch {
      // postMessage throws (DataCloneError) when `e` cannot be structured-
      // cloned. Without this fallback the rejection escapes the listener and
      // the main thread never receives a reply for this request.
      self.postMessage({ error: String(e?.message ?? e) });
    }
  }
});
| |
|