| import { |
| SamModel, |
| AutoProcessor, |
| RawImage, |
| Tensor, |
| } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.5"; |
|
|
| |
// Small lookup helper so each element reference reads as a plain id mapping.
const byId = (id) => document.getElementById(id);

// Status text and image loading controls.
const statusLabel = byId("status");
const fileUpload = byId("upload");
const uploadButton = byId("upload-button");
const example = byId("example");

// Image area and the template markers cloned for each placed point.
const imageContainer = byId("container");
const starIcon = byId("star-icon");
const crossIcon = byId("cross-icon");

// Action buttons.
const resetButton = byId("reset-image");
const clearButton = byId("clear-points");
const cutButton = byId("cut-mask");

// Canvas the predicted mask is painted onto, and its 2D drawing context.
const maskCanvas = byId("mask-output");
const maskContext = maskCanvas.getContext("2d");
|
|
// Sample image loaded when the "example" link is clicked.
const EXAMPLE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// Current image: the loaded RawImage, the processor output for it, and the
// cached image embeddings that every decode reuses.
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;

// Prompt state: the points sent to the decoder, and whether the user has
// switched from hover preview to explicitly placed points.
let lastPoints = null;
let isMultiMaskMode = false;

// Re-entrancy flags for the async encode/decode pipelines. `decodePending`
// records that a decode was requested while another was still running.
let isEncoding = false;
let isDecoding = false;
let decodePending = false;
|
|
/**
 * Runs the mask decoder for the current `lastPoints` against the cached image
 * embeddings and draws the best mask via `updateMaskOverlay`.
 *
 * Calls are coalesced: while one decode is in flight, further calls only set
 * `decodePending`; a single follow-up decode (using the latest `lastPoints`)
 * starts once the current one settles.
 *
 * @returns {Promise<void>} Rejects if the model or post-processing fails. The
 *   in-flight flag is cleared either way, so later decodes still run.
 */
async function decode() {
  if (isDecoding) {
    decodePending = true;
    return;
  }
  // Nothing to decode: points were cleared or no image is encoded (e.g. after reset).
  if (!lastPoints || lastPoints.length === 0 || !imageEmbeddings || !imageProcessed) {
    return;
  }
  isDecoding = true;

  // Snapshot the state this decode is based on; event handlers may replace
  // the module-level values while we are awaiting the model.
  const points = lastPoints;
  const embeddings = imageEmbeddings;
  const processed = imageProcessed;

  try {
    // Points are stored as fractions of the displayed image; scale them to the
    // processor's resized input. `reshaped` is indexed as [height, width].
    const reshaped = processed.reshaped_input_sizes[0];
    const coords = points.flatMap(({ position: [x, y] }) => [x * reshaped[1], y * reshaped[0]]);
    const labels = points.map(({ label }) => BigInt(label));

    const numPoints = points.length;
    const input_points = new Tensor("float32", coords, [1, 1, numPoints, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, numPoints]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    // Map the predicted masks back to the original image size (per `original_sizes`).
    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Drop stale results: the image was reset/replaced, or the points were
    // cleared while the model was running.
    if (embeddings === imageEmbeddings && lastPoints !== null) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } finally {
    // Always release the in-flight flag, even on failure, so decoding never wedges.
    isDecoding = false;
    if (decodePending) {
      decodePending = false;
      // Fire-and-forget, like the event handlers that call decode().
      void decode();
    }
  }
}
|
|
/**
 * Paints the highest-scoring candidate mask onto the overlay canvas and shows
 * its score in the status label.
 *
 * @param {RawImage} mask - Mask image whose channels are the candidate masks,
 *   interleaved per pixel (channel `k` of pixel `i` is `data[numMasks * i + k]`,
 *   where `numMasks` is `scores.length`). A value of 1 marks "inside the mask".
 * @param {ArrayLike<number>} scores - One score per candidate mask.
 */
function updateMaskOverlay(mask, scores) {
  // Match the canvas to the mask's resolution.
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  const imageData = maskContext.createImageData(maskCanvas.width, maskCanvas.height);

  // Select the candidate with the highest score (first one wins ties).
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // One iteration per pixel (not per RGBA byte): colour every pixel of the
  // selected mask; pixels outside it are left as created.
  const pixelData = imageData.data;
  const numPixels = maskCanvas.width * maskCanvas.height;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0;
      pixelData[offset + 1] = 114;
      pixelData[offset + 2] = 189;
      pixelData[offset + 3] = 255;
    }
  }

  maskContext.putImageData(imageData, 0, 0);
}
|
|
/**
 * Discards every placed point and the drawn mask, returning the UI to
 * single-point hover preview (the mode in which `mousemove` decodes).
 */
function clearPointsAndMask() {
  // Erase whatever mask is currently drawn.
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);

  // Remove all point markers (elements with class "icon").
  for (const marker of document.querySelectorAll(".icon")) {
    marker.remove();
  }

  // Nothing to cut out once the points are gone.
  cutButton.disabled = true;

  // Forget the prompt and leave multi-point mode.
  lastPoints = null;
  isMultiMaskMode = false;
}
clearButton.addEventListener("click", clearPointsAndMask);

// Reset: drop the current image and everything derived from it, and show the
// upload UI again.
resetButton.addEventListener("click", () => {
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  // A queued follow-up decode would otherwise run against the cleared image state.
  decodePending = false;

  clearPointsAndMask();

  // Clear the file input so that choosing the same file again is seen as a
  // change; otherwise its "change" event may not fire.
  fileUpload.value = "";

  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});
|
|
/**
 * Loads the image at `url`, displays it, and computes its image embeddings
 * once so later decodes can reuse them.
 *
 * Ignored while another encode is running. `imageProcessed` and
 * `imageEmbeddings` are replaced together only after both are ready.
 *
 * @param {string} url - Image URL (a remote URL or a data URL).
 * @returns {Promise<void>} Rejects if loading, preprocessing or embedding
 *   fails; `isEncoding` is cleared either way.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    const image = await RawImage.fromURL(url);
    imageInput = image;

    // Show the new image right away; embedding extraction happens afterwards.
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Compute into locals and publish both together, so a decode triggered
    // while we await never pairs the new image's sizes with old embeddings.
    const processed = await processor(image);
    const embeddings = await model.get_image_embeddings(processed);
    imageProcessed = processed;
    imageEmbeddings = embeddings;

    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    statusLabel.textContent = "Failed to extract image embedding";
    throw err;
  } finally {
    // Always release the guard so another image can be loaded after a failure.
    isEncoding = false;
  }
}
|
|
| |
// When the user picks a file, read it as a data URL and encode it.
fileUpload.addEventListener("change", (event) => {
  const [file] = event.target.files;
  if (!file) {
    return;
  }

  const reader = new FileReader();
  // The data URL becomes available on the reader once loading completes.
  reader.onload = () => encode(reader.result);
  reader.readAsDataURL(file);
});
|
|
/**
 * Click handler for the example link: suppresses the link's default action
 * and encodes the bundled sample image instead.
 */
function loadExampleImage(event) {
  event.preventDefault();
  encode(EXAMPLE_URL);
}
example.addEventListener("click", loadExampleImage);
|
|
| |
// Place a point on click: primary button -> label 1 (star marker),
// secondary button -> label 0 (cross marker). Other buttons are ignored.
imageContainer.addEventListener("mousedown", (event) => {
  const isPointButton = event.button === 0 || event.button === 2;
  if (!isPointButton || !imageEmbeddings) {
    return;
  }

  // The first click leaves hover-preview mode and starts an explicit point list.
  if (!isMultiMaskMode) {
    isMultiMaskMode = true;
    lastPoints = [];
    cutButton.disabled = false;
  }

  const point = getPoint(event);
  lastPoints.push(point);

  // Drop a marker at the same fractional position inside the container.
  const template = point.label === 1 ? starIcon : crossIcon;
  const marker = template.cloneNode();
  const [xFraction, yFraction] = point.position;
  marker.style.left = `${xFraction * 100}%`;
  marker.style.top = `${yFraction * 100}%`;
  imageContainer.appendChild(marker);

  decode();
});
|
|
| |
/**
 * Limits `x` to the range [min, max] (defaults to [0, 1]).
 * The upper bound is applied first, so `min` wins if `min > max`.
 */
function clamp(x, min = 0, max = 1) {
  const belowMax = Math.min(x, max);
  return Math.max(belowMax, min);
}
|
|
/**
 * Converts a mouse event into a prompt point relative to the image container.
 *
 * @param {MouseEvent} e - Supplies `clientX`, `clientY` and `button`.
 * @returns {{position: number[], label: number}} `position` is [x, y] as
 *   fractions of the container's width and height, clamped to [0, 1];
 *   `label` is 0 for the secondary button (`button === 2`) and 1 otherwise.
 */
function getPoint(e) {
  const { left, top, width, height } = imageContainer.getBoundingClientRect();
  const label = e.button === 2 ? 0 : 1;
  const x = clamp((e.clientX - left) / width);
  const y = clamp((e.clientY - top) / height);
  return { position: [x, y], label };
}
|
|
| |
// Suppress the browser's context menu so a right click can place a point.
imageContainer.addEventListener("contextmenu", (event) => {
  event.preventDefault();
});

// Hover preview: while an image is encoded and no explicit points have been
// placed, segment from the single point under the cursor.
imageContainer.addEventListener("mousemove", (event) => {
  const canPreview = imageEmbeddings && !isMultiMaskMode;
  if (canPreview) {
    lastPoints = [getPoint(event)];
    decode();
  }
});
|
|
| |
// Download the masked region of the source image as a PNG. Pixels inside the
// mask take the image's colours; the rest keep the overlay's alpha.
cutButton.addEventListener("click", async () => {
  if (!imageInput) {
    return;
  }
  const [w, h] = [maskCanvas.width, maskCanvas.height];

  // Start from the overlay pixels; any pixel with non-zero alpha is "in the mask".
  const maskImageData = maskContext.getImageData(0, 0, w, h);
  const maskPixelData = maskImageData.data;

  // The decoded image is not guaranteed to be 3-channel RGB, so take the
  // source stride from the image itself instead of hard-coding 3.
  const imagePixelData = imageInput.data;
  const channels = imageInput.channels;
  for (let i = 0; i < w * h; ++i) {
    const targetOffset = 4 * i; // overlay is RGBA
    if (maskPixelData[targetOffset + 3] === 0) {
      continue;
    }
    const sourceOffset = channels * i;
    for (let j = 0; j < 3; ++j) {
      // Sources with fewer than 3 channels replicate channel 0 into R, G and B.
      const sourceChannel = channels >= 3 ? j : 0;
      maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + sourceChannel];
    }
  }

  const cutCanvas = new OffscreenCanvas(w, h);
  cutCanvas.getContext("2d").putImageData(maskImageData, 0, 0);

  // Trigger the download through a temporary link.
  const blobUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = blobUrl;
  link.click();
  link.remove();

  // Release the blob; wait before revoking so the download is not cut short.
  setTimeout(() => URL.revokeObjectURL(blobUrl), 40_000);
});
|
|
/**
 * Loads the SAM model and its processor for `modelId` in parallel (the two
 * downloads are independent).
 *
 * @param {string} modelId - Hugging Face model id.
 * @returns {Promise<[SamModel, AutoProcessor]>} The loaded model and processor.
 *   On failure, the error is shown in the status label and rethrown.
 */
async function loadModelAndProcessor(modelId) {
  try {
    return await Promise.all([
      SamModel.from_pretrained(modelId, {
        dtype: "fp16",
        device: "webgpu",
      }),
      AutoProcessor.from_pretrained(modelId),
    ]);
  } catch (err) {
    statusLabel.textContent = `Failed to load model: ${err}`;
    throw err;
  }
}

const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
const [model, processor] = await loadModelAndProcessor(model_id);
statusLabel.textContent = "Ready";

// Enable the image inputs only once the model is ready.
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";
|
|