Spaces:
Sleeping
Sleeping
| // Auto-segmentation using SlimSAM (Segment Anything Model) | |
| // Generates a grid of point prompts, runs mask decoder for each, | |
| // filters and deduplicates to find distinct parts of a drawing. | |
| import { getTransformers } from './segmentation.js'; | |
/** Model id passed to both `SamModel.from_pretrained` and `AutoProcessor.from_pretrained`. */
const SAM_MODEL = 'Xenova/slimsam-77-uniform';

/** Longest image side used for inference; a smaller input needs quadratically less peak memory. */
const SAM_DIM = 384;

/** Point prompts per grid axis: 6x6 = 36 points, later filtered to non-transparent ones. */
const GRID_SIZE = 6;

/** A grid point's best mask is kept only if its predicted IoU score reaches this value. */
const MIN_IOU_SCORE = 0.65;

/** Minimum mask area as a fraction of the image area. */
const MIN_AREA_FRAC = 0.005;

/** Mask IoU above which a lower-scoring candidate is dropped during NMS. */
const NMS_IOU_THRESHOLD = 0.5;
/**
 * Auto-segment an image into distinct parts using SlimSAM.
 *
 * Pipeline: load the model, downscale the image, prompt the mask decoder once
 * per opaque grid point and keep the best-scoring mask of each prompt, then
 * filter by area, de-duplicate with NMS and crop every surviving mask out of
 * the original (full-resolution) image.
 *
 * The model is disposed on every exit path, including failures.
 *
 * @param {Blob} imageBlob - background-removed PNG
 * @param {function} [onProgress] - called with { message, progress } (progress runs 0-100)
 * @returns {Promise<SegmentResult[]>} segments sorted by mask area, largest first
 * @throws when no dtype variant of the model loads, or when image preparation fails
 */
export async function autoSegment(imageBlob, onProgress) {
  const { SamModel, AutoProcessor, RawImage } = await getTransformers();
  const report = (message, progress) => {
    if (onProgress) onProgress({ message, progress });
  };

  // Prefer WebGPU (fast), fall back to WASM (slow but universal)
  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
  const device = hasWebGPU ? 'webgpu' : 'wasm';
  report(`Loading segmentation model (${device})...`, 0);

  const { model, loadedDtype } = await loadSamModel(SamModel, device, onProgress);

  let results;
  try {
    report(`sam: loaded ${loadedDtype}/${device}`, 22);
    const processor = await AutoProcessor.from_pretrained(SAM_MODEL);
    report('Preparing image...', 25);

    // Downscale for inference
    const workBlob = await downscale(imageBlob, SAM_DIM);

    // Build opacity map to skip transparent grid points (bg-removed image)
    const workBitmap = await createImageBitmap(workBlob);
    let imgW;
    let imgH;
    let opaData;
    try {
      imgW = workBitmap.width;
      imgH = workBitmap.height;
      const opaCanvas = new OffscreenCanvas(imgW, imgH);
      const opaCtx = opaCanvas.getContext('2d');
      opaCtx.drawImage(workBitmap, 0, 0);
      opaData = opaCtx.getImageData(0, 0, imgW, imgH).data;
    } finally {
      workBitmap.close();
    }

    // Load as RawImage for the processor; always revoke the temporary URL
    const url = URL.createObjectURL(workBlob);
    let rawImage;
    try {
      rawImage = await RawImage.fromURL(url);
    } finally {
      URL.revokeObjectURL(url);
    }
    report('Analyzing image...', 30);

    // Generate grid points, keep those with an opaque pixel within a small radius
    const gridPoints = generateGrid(imgW, imgH, GRID_SIZE).filter(([px, py]) => {
      const r = 3;
      for (let dy = -r; dy <= r; dy++) {
        for (let dx = -r; dx <= r; dx++) {
          const sx = Math.max(0, Math.min(imgW - 1, px + dx));
          const sy = Math.max(0, Math.min(imgH - 1, py + dy));
          if (opaData[(sy * imgW + sx) * 4 + 3] > 128) return true;
        }
      }
      return false;
    });
    const totalPoints = gridPoints.length;
    const allMasks = [];

    // Helper: yield to browser so page stays responsive
    const yieldToBrowser = () => new Promise((resolve) => setTimeout(resolve, 0));

    // Run mask decoder for each grid point
    for (let i = 0; i < totalPoints; i++) {
      report(`Finding parts... ${i + 1}/${totalPoints}`, 30 + (i / totalPoints) * 50); // 30-80%

      // Yield every few iterations to keep the page responsive
      if (i % 3 === 0) await yieldToBrowser();

      const [px, py] = gridPoints[i];
      let outputs = null;
      try {
        const inputs = await processor(rawImage, {
          input_points: [[[px, py]]],
          input_labels: [[1]],
        });
        outputs = await model(inputs);
        const masks = await processor.post_process_masks(
          outputs.pred_masks,
          inputs.original_sizes,
          inputs.reshaped_input_sizes,
        );

        // Pick the candidate with the highest predicted IoU score
        const iouScores = outputs.iou_scores.data;
        let bestIdx = 0;
        let bestScore = iouScores[0];
        for (let j = 1; j < iouScores.length; j++) {
          if (iouScores[j] > bestScore) {
            bestScore = iouScores[j];
            bestIdx = j;
          }
        }

        if (bestScore >= MIN_IOU_SCORE) {
          // masks[0][0] is treated as [numMasks, H, W]; copy out only the best plane
          const maskTensor = masks[0][0];
          const maskH = maskTensor.dims[1];
          const maskW = maskTensor.dims[2];
          const maskData = maskTensor.data;
          const maskSize = maskH * maskW;
          const offset = bestIdx * maskSize;
          const singleMask = new Float32Array(maskSize);
          for (let k = 0; k < maskSize; k++) {
            singleMask[k] = maskData[offset + k];
          }
          allMasks.push({
            mask: singleMask,
            maskW,
            maskH,
            score: bestScore,
          });
        }
      } catch (e) {
        console.warn(`Grid point ${i} failed:`, e);
      } finally {
        // Release the decoder outputs even when post-processing threw
        if (outputs) {
          disposeTensor(outputs.pred_masks);
          disposeTensor(outputs.iou_scores);
        }
      }
    }
    report('Filtering results...', 82);

    // Compute bounding box and area for each mask
    const minArea = imgW * imgH * MIN_AREA_FRAC;
    let candidates = allMasks.map((m, i) => {
      const { bbox, area } = computeMaskStats(m.mask, m.maskW, m.maskH);
      return { ...m, bbox, area, id: `seg-${i}` };
    });

    // Filter by area, then drop near-duplicates
    candidates = candidates.filter((m) => m.area >= minArea);
    candidates = nonMaxSuppression(candidates, NMS_IOU_THRESHOLD);
    report('Extracting parts...', 88);

    // Crop each mask from the original image; a failed crop only drops that segment
    results = [];
    for (let i = 0; i < candidates.length; i++) {
      const c = candidates[i];
      c.id = `seg-${i}`;
      try {
        c.croppedBlob = await extractMaskRegion(imageBlob, c);
        results.push(c);
      } catch (e) {
        console.warn(`Failed to extract segment ${i}:`, e);
      }
    }

    // Sort by area descending (largest first)
    results.sort((a, b) => b.area - a.area);
  } finally {
    await disposeSamModel(model);
  }

  report('Done!', 100);
  return results;
}

/**
 * Load the SAM model, trying the smallest dtype first and falling back
 * step-by-step if the model doesn't ship that variant or the runtime can't load it.
 *
 * @param {object} SamModel - model class obtained from getTransformers()
 * @param {string} device - 'webgpu' or 'wasm'
 * @param {function} [onProgress] - receives download progress mapped to 0-20%
 * @returns {Promise<{ model: object, loadedDtype: string }>}
 * @throws the last load error when every dtype fails
 */
async function loadSamModel(SamModel, device, onProgress) {
  const dtypeProgress = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress * 0.2, // 0-20% for download
      });
    }
  };
  const dtypePreference = device === 'webgpu'
    ? ['fp16', 'q8', 'fp32']
    : ['q8', 'fp16', 'fp32'];

  let lastErr = null;
  for (const dtype of dtypePreference) {
    try {
      const model = await SamModel.from_pretrained(SAM_MODEL, {
        device,
        dtype,
        progress_callback: dtypeProgress,
      });
      console.log(`[auto-segment] loaded SAM with dtype=${dtype}, device=${device}`);
      return { model, loadedDtype: dtype };
    } catch (err) {
      console.warn(`[auto-segment] dtype=${dtype} failed:`, err && err.message);
      lastErr = err;
    }
  }
  throw lastErr || new Error('Failed to load SAM');
}

/**
 * Best-effort model disposal. dispose() may be asynchronous, so it is awaited
 * to keep a rejection from escaping as an unhandled promise rejection.
 */
async function disposeSamModel(model) {
  if (!model || !model.dispose) return;
  try {
    await model.dispose();
  } catch (e) {
    console.warn('[auto-segment] model dispose failed:', e);
  }
}

/** Best-effort tensor disposal; a failure is logged and never propagates. */
function disposeTensor(tensor) {
  if (!tensor || !tensor.dispose) return;
  try {
    tensor.dispose();
  } catch (e) {
    console.warn('[auto-segment] tensor dispose failed:', e);
  }
}
// ---- Grid generation ----
/**
 * Build an n-by-n grid of integer [x, y] points spread evenly inside a w-by-h
 * area. The area is split into n + 1 steps per axis, so no point lies on the
 * border. Points are listed row by row.
 */
function generateGrid(w, h, n) {
  const cells = n + 1;
  const colStep = w / cells;
  const rowStep = h / cells;
  const ordinals = Array.from({ length: n }, (_, i) => i + 1);
  return ordinals.flatMap((row) =>
    ordinals.map((col) => [Math.round(col * colStep), Math.round(row * rowStep)]),
  );
}
// ---- Mask stats ----
/**
 * Compute the pixel area and normalized bounding box of a mask.
 * A pixel counts as foreground when its value is > 0.
 *
 * @param {ArrayLike<number>} mask - row-major mask values, w * h entries
 * @param {number} w - mask width in pixels
 * @param {number} h - mask height in pixels
 * @returns {{ bbox: { x: number, y: number, w: number, h: number }, area: number }}
 *   bbox is expressed as fractions of w / h; an empty mask yields an all-zero bbox.
 */
function computeMaskStats(mask, w, h) {
  let minX = w;
  let minY = h;
  let maxX = 0;
  let maxY = 0;
  let area = 0;
  for (let y = 0; y < h; y++) {
    const rowStart = y * w;
    for (let x = 0; x < w; x++) {
      if (mask[rowStart + x] > 0) {
        area++;
        if (x < minX) minX = x;
        if (x > maxX) maxX = x;
        if (y < minY) minY = y;
        if (y > maxY) maxY = y;
      }
    }
  }

  // With no foreground the min/max sentinels are untouched; deriving a bbox
  // from them would give a negative (or NaN) extent, so report an empty box.
  if (area === 0) {
    return { bbox: { x: 0, y: 0, w: 0, h: 0 }, area: 0 };
  }

  return {
    bbox: {
      x: minX / w,
      y: minY / h,
      w: (maxX - minX + 1) / w,
      h: (maxY - minY + 1) / h,
    },
    area,
  };
}
// ---- Non-maximum suppression ----
/**
 * Intersection-over-union of two segments' masks, where a pixel is foreground
 * when its value is > 0. Only the first a.mask.length entries are compared, so
 * both masks are expected to have the same dimensions. Returns 0 when neither
 * mask has any foreground.
 */
function maskIoU(a, b) {
  const first = a.mask;
  const second = b.mask;
  const total = first.length;
  let shared = 0;
  let covered = 0;
  for (let i = 0; i < total; i++) {
    const inFirst = first[i] > 0;
    const inSecond = second[i] > 0;
    if (inFirst || inSecond) {
      covered++;
      if (inFirst && inSecond) shared++;
    }
  }
  if (covered === 0) return 0;
  return shared / covered;
}
/**
 * Greedy non-maximum suppression over segment masks.
 * Candidates are visited from highest to lowest score; one is kept unless its
 * mask IoU with an already-kept candidate exceeds `threshold`. The input array
 * is left untouched and the result is ordered by descending score.
 */
function nonMaxSuppression(masks, threshold) {
  const ranked = Array.from(masks).sort((x, y) => y.score - x.score);
  return ranked.reduce((survivors, candidate) => {
    const isDuplicate = survivors.some((winner) => maskIoU(candidate, winner) > threshold);
    if (!isDuplicate) survivors.push(candidate);
    return survivors;
  }, []);
}
// ---- Extract mask region as cropped blob ----
/**
 * Crop a segment out of the full-resolution image and clear the alpha of every
 * pixel outside the segment's mask.
 *
 * @param {Blob} imageBlob - the original image
 * @param {object} segment - needs `bbox` (normalized x/y/w/h), `mask`, `maskW`, `maskH`
 * @returns {Promise<Blob>} PNG of the padded bounding box with the mask applied as alpha
 * @throws {Error} when the bounding box maps to an empty pixel region
 */
async function extractMaskRegion(imageBlob, segment) {
  const bitmap = await createImageBitmap(imageBlob);
  let crop;
  try {
    crop = drawPaddedCrop(bitmap, segment.bbox, 4);
  } finally {
    // Release the decoded image even when cropping fails
    bitmap.close();
  }
  const { canvas, ctx, bx, by, bw, bh, fullW, fullH } = crop;

  // Apply mask as alpha: the mask may be smaller than the image, so map each
  // crop pixel back into mask coordinates.
  const imgData = ctx.getImageData(0, 0, bw, bh);
  const scaleX = segment.maskW / fullW;
  const scaleY = segment.maskH / fullH;
  for (let y = 0; y < bh; y++) {
    const my = Math.min(Math.floor((by + y) * scaleY), segment.maskH - 1);
    for (let x = 0; x < bw; x++) {
      const mx = Math.min(Math.floor((bx + x) * scaleX), segment.maskW - 1);
      // Inside the mask the existing alpha is kept; outside it becomes transparent
      if (!(segment.mask[my * segment.maskW + mx] > 0)) {
        imgData.data[(y * bw + x) * 4 + 3] = 0;
      }
    }
  }
  ctx.putImageData(imgData, 0, 0);
  return canvas.convertToBlob({ type: 'image/png' });
}

/**
 * Draw the padded pixel region of a normalized bbox onto a new canvas.
 * Returns the canvas, its 2D context, the crop rectangle and the source size.
 */
function drawPaddedCrop(bitmap, bbox, pad) {
  const fullW = bitmap.width;
  const fullH = bitmap.height;
  // Convert normalized bbox to pixel coords with padding, clamped to the image
  const bx = Math.max(0, Math.floor(bbox.x * fullW) - pad);
  const by = Math.max(0, Math.floor(bbox.y * fullH) - pad);
  const bw = Math.min(fullW - bx, Math.ceil(bbox.w * fullW) + pad * 2);
  const bh = Math.min(fullH - by, Math.ceil(bbox.h * fullH) + pad * 2);
  if (!(bw > 0 && bh > 0)) {
    throw new Error(`Segment has an empty crop region (${bw}x${bh})`);
  }
  const canvas = new OffscreenCanvas(bw, bh);
  const ctx = canvas.getContext('2d');
  ctx.drawImage(bitmap, bx, by, bw, bh, 0, 0, bw, bh);
  return { canvas, ctx, bx, by, bw, bh, fullW, fullH };
}
// ---- Downscale ----
/**
 * Shrink an image so that neither side exceeds `maxDim`, keeping aspect ratio.
 *
 * @param {Blob} imageBlob - source image
 * @param {number} maxDim - maximum allowed width/height in pixels
 * @returns {Promise<Blob>} the original blob when it already fits, otherwise a resized PNG
 */
async function downscale(imageBlob, maxDim) {
  // The probe is only needed for its dimensions; close it right away so it is
  // not held (or leaked) while the resized copy is decoded.
  const probe = await createImageBitmap(imageBlob);
  const { width, height } = probe;
  probe.close();

  if (width <= maxDim && height <= maxDim) {
    return imageBlob;
  }

  const ratio = Math.min(maxDim / width, maxDim / height);
  // Extreme aspect ratios can round a side down to 0, which createImageBitmap
  // rejects; keep every side at least 1px.
  const newW = Math.max(1, Math.round(width * ratio));
  const newH = Math.max(1, Math.round(height * ratio));

  const resized = await createImageBitmap(imageBlob, {
    resizeWidth: newW, resizeHeight: newH, resizeQuality: 'medium',
  });
  let canvas;
  try {
    canvas = new OffscreenCanvas(newW, newH);
    canvas.getContext('2d').drawImage(resized, 0, 0);
  } finally {
    resized.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}
// ---- Mask to polygon (for adjustment) ----
/**
 * Convert a binary mask to a simplified polygon (normalized 0-1 coords).
 * Uses border tracing and Douglas-Peucker simplification.
 *
 * @param {ArrayLike<number>} mask - row-major mask, foreground where value > 0
 * @param {number} maskW - mask width in pixels
 * @param {number} maskH - mask height in pixels
 * @param {number} [maxPoints=30] - target upper bound on polygon points; the
 *   simplifier never goes below 2 points, so smaller values behave like 2
 * @returns {{ x: number, y: number }[]} polygon points, empty when the mask has no foreground
 */
export function maskToPolygon(mask, maskW, maskH, maxPoints = 30) {
  const normalize = (p) => ({ x: p.x / maskW, y: p.y / maskH });

  // Find contour points using simple border following
  const contour = traceContour(mask, maskW, maskH);
  // Too few points to simplify, but still return them in normalized coords
  if (contour.length < 3) return contour.map(normalize);

  // Simplify with Douglas-Peucker
  const tolerance = Math.max(maskW, maskH) * 0.015;
  let epsilon = tolerance;
  let simplified = douglasPeucker(contour, epsilon);

  // Cap points. The tolerance has to grow on every pass: re-running with the
  // same value can return the same polygon again and never terminate. The
  // simplifier bottoms out at 2 points, so that is the lowest reachable cap.
  const cap = Math.max(maxPoints, 2);
  while (simplified.length > cap) {
    epsilon *= 1.5;
    simplified = douglasPeucker(simplified, epsilon);
  }

  // Normalize to 0-1
  return simplified.map(normalize);
}
/**
 * Follow the outline of a mask region by stepping between 8-connected
 * foreground pixels (value > 0), starting from the first foreground pixel in
 * row-major order. Returns pixel-coordinate points, thinned to roughly 200 at
 * most, or an empty array when the mask has no foreground.
 */
function traceContour(mask, w, h) {
  // Locate the starting pixel
  let originX = -1;
  let originY = -1;
  for (let y = 0; y < h && originX < 0; y++) {
    for (let x = 0; x < w; x++) {
      if (mask[y * w + x] > 0) {
        originX = x;
        originY = y;
        break;
      }
    }
  }
  if (originX < 0) return [];

  // Neighbour offsets, indexed by heading
  const offsets = [
    [1, 0], [1, 1], [0, 1], [-1, 1],
    [-1, 0], [-1, -1], [0, -1], [1, -1],
  ];
  const outline = [];
  const stepLimit = w * h;
  let px = originX;
  let py = originY;
  let heading = 0;

  for (let step = 0; step < stepLimit; step++) {
    outline.push({ x: px, y: py });

    // Start probing from a heading turned back relative to the last move
    const firstProbe = (heading + 5) % 8;
    let chosen = -1;
    for (let k = 0; k < 8 && chosen < 0; k++) {
      const d = (firstProbe + k) % 8;
      const nx = px + offsets[d][0];
      const ny = py + offsets[d][1];
      if (nx >= 0 && nx < w && ny >= 0 && ny < h && mask[ny * w + nx] > 0) {
        chosen = d;
      }
    }
    if (chosen < 0) break; // isolated pixel: nowhere to go

    px += offsets[chosen][0];
    py += offsets[chosen][1];
    heading = chosen;

    // Back at the start after a few steps: the loop is closed
    if (px === originX && py === originY && step > 2) break;
  }

  // Subsample to avoid too many points
  if (outline.length > 200) {
    const stride = Math.ceil(outline.length / 200);
    return outline.filter((_, i) => i % stride === 0);
  }
  return outline;
}
/**
 * Recursive Douglas-Peucker polyline simplification.
 * Keeps the end points and, whenever the farthest interior point lies more
 * than `epsilon` from the chord between them, splits there and simplifies
 * both halves. Inputs of two points or fewer are returned as-is.
 */
function douglasPeucker(points, epsilon) {
  const count = points.length;
  if (count <= 2) return points;

  const head = points[0];
  const tail = points[count - 1];

  // Find the interior point farthest from the head-tail chord
  let splitAt = 0;
  let farthest = 0;
  for (let i = 1; i < count - 1; i++) {
    const dist = pointToLineDist(points[i], head, tail);
    if (dist > farthest) {
      farthest = dist;
      splitAt = i;
    }
  }

  if (!(farthest > epsilon)) return [head, tail];

  const before = douglasPeucker(points.slice(0, splitAt + 1), epsilon);
  const after = douglasPeucker(points.slice(splitAt), epsilon);
  // The split point ends `before` and starts `after`; keep it only once
  return before.slice(0, -1).concat(after);
}
/**
 * Distance from point `p` to the segment a-b. The projection of `p` is clamped
 * to the segment, so beyond either end the distance to that end point is
 * returned. When `a` and `b` coincide it is the distance to `a`.
 */
function pointToLineDist(p, a, b) {
  const segX = b.x - a.x;
  const segY = b.y - a.y;
  const segLenSq = segX * segX + segY * segY;

  // Position of the closest point along the segment, 0 at `a` and 1 at `b`
  let t = 0;
  if (segLenSq !== 0) {
    const along = ((p.x - a.x) * segX + (p.y - a.y) * segY) / segLenSq;
    t = Math.max(0, Math.min(1, along));
  }

  const offX = p.x - (a.x + t * segX);
  const offY = p.y - (a.y + t * segY);
  return Math.sqrt(offX ** 2 + offY ** 2);
}