// Auto-segmentation using SlimSAM (Segment Anything Model)
// Generates a grid of point prompts, runs mask decoder for each,
// filters and deduplicates to find distinct parts of a drawing.

import { getTransformers } from './segmentation.js';

const SAM_MODEL = 'Xenova/slimsam-77-uniform';
const SAM_DIM = 384; // smaller input → quadratically less peak memory
const GRID_SIZE = 6; // 6x6 = 36 points, filtered to non-transparent
const MIN_IOU_SCORE = 0.65;
const MIN_AREA_FRAC = 0.005; // minimum mask area as fraction of image
const NMS_IOU_THRESHOLD = 0.5;

/**
 * Auto-segment an image into distinct parts using SlimSAM.
 *
 * Pipeline: load the model, downscale the image to at most SAM_DIM per side,
 * prompt the model once per opaque grid point, keep the best-scoring mask of
 * each prompt, drop small and overlapping masks, then crop every surviving
 * mask out of the original (full-resolution) image.
 *
 * The model is disposed before returning, including when a step fails.
 *
 * @param {Blob} imageBlob - background-removed PNG
 * @param {function} [onProgress] - called with { message, progress }
 * @returns {Promise<Array<Object>>} segments sorted by area (largest first);
 *   each has mask, maskW, maskH, score, bbox (normalized), area, id and
 *   croppedBlob.
 */
export async function autoSegment(imageBlob, onProgress) {
  const report = (message, progress) => {
    if (onProgress) onProgress({ message, progress });
  };

  const { SamModel, AutoProcessor, RawImage } = await getTransformers();

  // Prefer WebGPU (fast), fall back to WASM (slow but universal)
  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
  const device = hasWebGPU ? 'webgpu' : 'wasm';

  report(`Loading segmentation model (${device})...`, 0);
  const { model, dtype } = await loadSamModel(SamModel, device, onProgress);

  let results;
  try {
    report(`sam: loaded ${dtype}/${device}`, 22);

    const processor = await AutoProcessor.from_pretrained(SAM_MODEL);

    report('Preparing image...', 25);

    // Downscale for inference
    const workBlob = await downscale(imageBlob, SAM_DIM);

    // RGBA pixels of the work image; the alpha channel is used to skip
    // transparent grid points (bg-removed image)
    const { width: imgW, height: imgH, data: rgba } = await readPixels(workBlob);

    // Load as RawImage for the processor; always release the object URL
    const url = URL.createObjectURL(workBlob);
    let rawImage;
    try {
      rawImage = await RawImage.fromURL(url);
    } finally {
      URL.revokeObjectURL(url);
    }

    report('Analyzing image...', 30);

    // Generate grid points, filter to non-transparent pixels
    const gridPoints = generateGrid(imgW, imgH, GRID_SIZE).filter(
      ([px, py]) => hasOpaquePixelNear(rgba, imgW, imgH, px, py),
    );
    const totalPoints = gridPoints.length;
    const allMasks = [];

    // Helper: yield to browser so page stays responsive
    const yieldToBrowser = () => new Promise((resolve) => setTimeout(resolve, 0));

    // Run mask decoder for each grid point
    for (let i = 0; i < totalPoints; i++) {
      // 30-80%
      report(`Finding parts... ${i + 1}/${totalPoints}`, 30 + (i / totalPoints) * 50);

      // Yield every few iterations to keep the page responsive
      if (i % 3 === 0) await yieldToBrowser();

      const [px, py] = gridPoints[i];
      let outputs = null;
      try {
        const inputs = await processor(rawImage, {
          input_points: [[[px, py]]],
          input_labels: [[1]],
        });
        outputs = await model(inputs);
        const masks = await processor.post_process_masks(
          outputs.pred_masks,
          inputs.original_sizes,
          inputs.reshaped_input_sizes,
        );

        // IoU scores - shape [1, numMasks]
        const best = selectBestMask(outputs.iou_scores.data, masks);
        if (best) allMasks.push(best);
      } catch (e) {
        // A single failing prompt must not abort the whole run
        console.warn(`Grid point ${i} failed:`, e);
      } finally {
        // Release the tensors even when post-processing threw
        if (outputs) {
          await disposeQuietly(outputs.pred_masks, 'pred_masks');
          await disposeQuietly(outputs.iou_scores, 'iou_scores');
        }
      }
    }

    report('Filtering results...', 82);

    // Compute bounding box and area for each mask, drop the small ones,
    // then remove near-duplicates
    const minArea = imgW * imgH * MIN_AREA_FRAC;
    const measured = allMasks
      .map((m) => ({ ...m, ...computeMaskStats(m.mask, m.maskW, m.maskH) }))
      .filter((m) => m.area >= minArea);
    const candidates = nonMaxSuppression(measured, NMS_IOU_THRESHOLD);

    report('Extracting parts...', 88);

    // Crop each mask from the original image. The original is decoded once
    // and shared across all segments.
    results = [];
    if (candidates.length > 0) {
      const fullBitmap = await createImageBitmap(imageBlob);
      try {
        for (let i = 0; i < candidates.length; i++) {
          const c = candidates[i];
          c.id = `seg-${i}`;
          try {
            c.croppedBlob = await extractMaskRegion(fullBitmap, c);
            results.push(c);
          } catch (e) {
            console.warn(`Failed to extract segment ${i}:`, e);
          }
        }
      } finally {
        fullBitmap.close();
      }
    }

    // Sort by area descending (largest first)
    results.sort((a, b) => b.area - a.area);
  } finally {
    // Dispose model (awaited so a failed disposal is observed, not floating)
    await disposeQuietly(model, 'model');
  }

  report('Done!', 100);
  return results;
}

// ---- Model loading / cleanup ----

/**
 * Load the SAM model, trying dtypes from smallest to largest and falling back
 * step-by-step if the model doesn't ship that variant or the runtime can't
 * load it.
 * @param {Object} SamModel - model class exposing from_pretrained
 * @param {string} device - 'webgpu' or 'wasm'
 * @param {function} [onProgress] - receives download progress mapped to 0-20%
 * @returns {Promise<{model: Object, dtype: string}>}
 * @throws the last load error when every dtype fails
 */
async function loadSamModel(SamModel, device, onProgress) {
  const progressCallback = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress * 0.2, // 0-20% for download
      });
    }
  };

  const dtypePreference = device === 'webgpu'
    ? ['fp16', 'q8', 'fp32']
    : ['q8', 'fp16', 'fp32'];

  let lastErr = null;
  for (const dtype of dtypePreference) {
    try {
      const model = await SamModel.from_pretrained(SAM_MODEL, {
        device,
        dtype,
        progress_callback: progressCallback,
      });
      console.log(`[auto-segment] loaded SAM with dtype=${dtype}, device=${device}`);
      return { model, dtype };
    } catch (err) {
      console.warn(`[auto-segment] dtype=${dtype} failed:`, err && err.message);
      lastErr = err;
    }
  }
  throw lastErr || new Error('Failed to load SAM');
}

/**
 * Best-effort disposal: calls resource.dispose() if present, waits for it when
 * it returns a promise, and logs (instead of throwing) on failure.
 */
async function disposeQuietly(resource, label) {
  if (!resource || typeof resource.dispose !== 'function') return;
  try {
    await resource.dispose();
  } catch (err) {
    console.warn(`[auto-segment] failed to dispose ${label}:`, err);
  }
}

// ---- Image helpers ----

/**
 * Decode a blob and return its dimensions and RGBA pixel data.
 * @returns {Promise<{width: number, height: number, data: Uint8ClampedArray}>}
 */
async function readPixels(blob) {
  const bitmap = await createImageBitmap(blob);
  try {
    const { width, height } = bitmap;
    const ctx = new OffscreenCanvas(width, height).getContext('2d');
    ctx.drawImage(bitmap, 0, 0);
    return { width, height, data: ctx.getImageData(0, 0, width, height).data };
  } finally {
    bitmap.close();
  }
}

/**
 * True when any pixel within `radius` of (px, py) has alpha > 128.
 * Coordinates outside the image are clamped to the nearest edge pixel.
 */
function hasOpaquePixelNear(rgba, w, h, px, py, radius = 3) {
  for (let dy = -radius; dy <= radius; dy++) {
    const sy = Math.max(0, Math.min(h - 1, py + dy));
    for (let dx = -radius; dx <= radius; dx++) {
      const sx = Math.max(0, Math.min(w - 1, px + dx));
      if (rgba[(sy * w + sx) * 4 + 3] > 128) return true;
    }
  }
  return false;
}

/**
 * Pick the mask candidate with the highest IoU score and copy it out.
 * @param {ArrayLike<number>} iouScores - one score per mask candidate
 * @param {Array} masks - result of processor.post_process_masks
 * @returns {{mask: Float32Array, maskW: number, maskH: number, score: number}|null}
 *   null when the best score is below MIN_IOU_SCORE
 */
function selectBestMask(iouScores, masks) {
  let bestIdx = 0;
  let bestScore = iouScores[0];
  for (let j = 1; j < iouScores.length; j++) {
    if (iouScores[j] > bestScore) {
      bestScore = iouScores[j];
      bestIdx = j;
    }
  }
  if (!(bestScore >= MIN_IOU_SCORE)) return null;

  // Extract the mask data for the best candidate
  const maskTensor = masks[0][0]; // [numMasks, H, W]
  const maskH = maskTensor.dims[1];
  const maskW = maskTensor.dims[2];
  const maskData = maskTensor.data;
  const maskSize = maskH * maskW;
  const offset = bestIdx * maskSize;

  // Copy just this mask's data so the source tensor can be disposed
  const mask = new Float32Array(maskSize);
  for (let k = 0; k < maskSize; k++) {
    mask[k] = maskData[offset + k];
  }
  return { mask, maskW, maskH, score: bestScore };
}

// ---- Grid generation ----

/**
 * Build an n x n grid of [x, y] points evenly spaced inside a w x h image,
 * leaving one step of margin on every side. Row-major order.
 */
function generateGrid(w, h, n) {
  const points = [];
  const stepX = w / (n + 1);
  const stepY = h / (n + 1);
  for (let row = 1; row <= n; row++) {
    for (let col = 1; col <= n; col++) {
      points.push([Math.round(col * stepX), Math.round(row * stepY)]);
    }
  }
  return points;
}

// ---- Mask stats ----

/**
 * Compute the pixel count and normalized bounding box of a mask.
 * A pixel is "set" when its value is > 0.
 * @returns {{bbox: {x: number, y: number, w: number, h: number}, area: number}}
 *   bbox is in 0-1 units; an empty mask yields a zero-sized bbox.
 */
function computeMaskStats(mask, w, h) {
  let minX = w;
  let minY = h;
  let maxX = 0;
  let maxY = 0;
  let area = 0;
  for (let y = 0; y < h; y++) {
    const rowStart = y * w;
    for (let x = 0; x < w; x++) {
      if (mask[rowStart + x] > 0) {
        area++;
        if (x < minX) minX = x;
        if (x > maxX) maxX = x;
        if (y < minY) minY = y;
        if (y > maxY) maxY = y;
      }
    }
  }

  // Without this guard the untouched min/max sentinels produce a bbox with
  // negative width and height.
  if (area === 0) {
    return { bbox: { x: 0, y: 0, w: 0, h: 0 }, area: 0 };
  }

  return {
    bbox: {
      x: minX / w,
      y: minY / h,
      w: (maxX - minX + 1) / w,
      h: (maxY - minY + 1) / h,
    },
    area,
  };
}

// ---- Non-maximum suppression ----

/**
 * Intersection-over-union of two masks (pixel set when value > 0).
 * Both masks must have the same dimensions. Returns 0 when both are empty.
 */
function maskIoU(a, b) {
  const len = a.mask.length;
  let intersection = 0;
  let union = 0;
  for (let i = 0; i < len; i++) {
    const inA = a.mask[i] > 0;
    const inB = b.mask[i] > 0;
    if (inA && inB) intersection++;
    if (inA || inB) union++;
  }
  return union === 0 ? 0 : intersection / union;
}

/**
 * Greedy non-maximum suppression: visit masks from highest to lowest score and
 * keep one only if its IoU with every already-kept mask is <= threshold.
 * The input array is not modified.
 */
function nonMaxSuppression(masks, threshold) {
  const byScoreDesc = [...masks].sort((a, b) => b.score - a.score);
  const kept = [];
  for (const candidate of byScoreDesc) {
    const dominated = kept.some((existing) => maskIoU(candidate, existing) > threshold);
    if (!dominated) kept.push(candidate);
  }
  return kept;
}

// ---- Extract mask region as cropped blob ----

/**
 * Crop a segment's padded bounding box out of the full image and make every
 * pixel outside the segment's mask fully transparent.
 * @param {Blob|ImageBitmap} imageBlob - the full image. A Blob is decoded (and
 *   the decoded bitmap closed) here; an ImageBitmap is used as-is and stays
 *   owned by the caller, which lets callers decode once for many segments.
 * @param {Object} segment - needs bbox (normalized), mask, maskW, maskH
 * @returns {Promise<Blob>} PNG of the cropped, masked region
 */
async function extractMaskRegion(imageBlob, segment) {
  const isDecoded = typeof ImageBitmap !== 'undefined' && imageBlob instanceof ImageBitmap;
  const bitmap = isDecoded ? imageBlob : await createImageBitmap(imageBlob);

  let canvas;
  let ctx;
  let fullW;
  let fullH;
  let bx;
  let by;
  let bw;
  let bh;
  try {
    fullW = bitmap.width;
    fullH = bitmap.height;

    // Convert normalized bbox to pixel coords with padding
    const pad = 4;
    bx = Math.max(0, Math.floor(segment.bbox.x * fullW) - pad);
    by = Math.max(0, Math.floor(segment.bbox.y * fullH) - pad);
    bw = Math.min(fullW - bx, Math.ceil(segment.bbox.w * fullW) + pad * 2);
    bh = Math.min(fullH - by, Math.ceil(segment.bbox.h * fullH) + pad * 2);

    canvas = new OffscreenCanvas(bw, bh);
    ctx = canvas.getContext('2d');
    ctx.drawImage(bitmap, bx, by, bw, bh, 0, 0, bw, bh);
  } finally {
    // Only close bitmaps decoded here; a caller-supplied bitmap is reused
    if (!isDecoded) bitmap.close();
  }

  // Apply mask as alpha: the mask may have a different resolution than the
  // full image, so each output pixel is mapped back into mask coordinates.
  const imgData = ctx.getImageData(0, 0, bw, bh);
  const pixels = imgData.data;
  const scaleX = segment.maskW / fullW;
  const scaleY = segment.maskH / fullH;
  for (let y = 0; y < bh; y++) {
    const my = Math.min(Math.floor((by + y) * scaleY), segment.maskH - 1);
    const maskRow = my * segment.maskW;
    for (let x = 0; x < bw; x++) {
      const mx = Math.min(Math.floor((bx + x) * scaleX), segment.maskW - 1);
      // Inside the mask the existing alpha is kept; outside it is cleared
      if (!(segment.mask[maskRow + mx] > 0)) {
        pixels[(y * bw + x) * 4 + 3] = 0;
      }
    }
  }
  ctx.putImageData(imgData, 0, 0);

  return canvas.convertToBlob({ type: 'image/png' });
}

// ---- Downscale ----

/**
 * Shrink an image so neither side exceeds maxDim, preserving aspect ratio.
 * @returns {Promise<Blob>} the original blob when it already fits, otherwise
 *   a resized PNG blob
 */
async function downscale(imageBlob, maxDim) {
  const probe = await createImageBitmap(imageBlob);
  const { width, height } = probe;
  probe.close(); // only the dimensions are needed

  if (width <= maxDim && height <= maxDim) {
    return imageBlob;
  }

  const ratio = Math.min(maxDim / width, maxDim / height);
  const newW = Math.round(width * ratio);
  const newH = Math.round(height * ratio);

  const resized = await createImageBitmap(imageBlob, {
    resizeWidth: newW,
    resizeHeight: newH,
    resizeQuality: 'medium',
  });

  const canvas = new OffscreenCanvas(newW, newH);
  try {
    canvas.getContext('2d').drawImage(resized, 0, 0);
  } finally {
    resized.close();
  }

  return canvas.convertToBlob({ type: 'image/png' });
}

// ---- Mask to polygon (for adjustment) ----

/**
 * Convert a binary mask to a simplified polygon (normalized 0-1 coords).
 * Uses border tracing and Douglas-Peucker simplification.
 *
 * @param {ArrayLike<number>} mask - row-major mask, pixel set when value > 0
 * @param {number} maskW - mask width in pixels
 * @param {number} maskH - mask height in pixels
 * @param {number} [maxPoints=30] - upper bound on returned points; values
 *   below 2 are treated as 2 because simplification always keeps both endpoints
 * @returns {Array<{x: number, y: number}>} polygon points in 0-1 coordinates;
 *   may hold fewer than 3 points (or none) for degenerate masks
 */
export function maskToPolygon(mask, maskW, maskH, maxPoints = 30) {
  const normalize = (p) => ({ x: p.x / maskW, y: p.y / maskH });

  // Find contour points using simple border following
  const contour = traceContour(mask, maskW, maskH);
  if (contour.length < 3) return contour.map(normalize);

  // Simplify with Douglas-Peucker
  let tolerance = Math.max(maskW, maskH) * 0.015;
  let simplified = douglasPeucker(contour, tolerance);

  // Cap points. The tolerance must grow on every pass: re-running the
  // simplification with an unchanged tolerance returns the same points, which
  // would loop forever.
  const cap = Math.max(maxPoints, 2);
  while (simplified.length > cap) {
    tolerance *= 1.5;
    simplified = douglasPeucker(simplified, tolerance);
  }

  // Normalize to 0-1
  return simplified.map(normalize);
}

/**
 * Trace the outline that starts at the first set pixel (row-major scan) by
 * repeatedly stepping to a neighbouring set pixel (8-connectivity).
 * @returns {Array<{x: number, y: number}>} contour in pixel coordinates,
 *   subsampled to at most 200 points; empty when the mask has no set pixel
 */
function traceContour(mask, w, h) {
  // Find first border pixel
  let startX = -1;
  let startY = -1;
  const total = w * h;
  for (let i = 0; i < total; i++) {
    if (mask[i] > 0) {
      startX = i % w;
      startY = Math.floor(i / w);
      break;
    }
  }
  if (startX < 0) return [];

  const contour = [];
  // Neighbour offsets, starting at +x and stepping 45 degrees at a time
  const dirs = [
    [1, 0], [1, 1], [0, 1], [-1, 1],
    [-1, 0], [-1, -1], [0, -1], [1, -1],
  ];

  let cx = startX;
  let cy = startY;
  let dir = 0;
  const maxSteps = w * h; // hard bound so a malformed mask cannot loop forever

  for (let step = 0; step < maxSteps; step++) {
    contour.push({ x: cx, y: cy });

    // Look for next border pixel
    let found = false;
    const startDir = (dir + 5) % 8; // turn back to find outline
    for (let i = 0; i < 8; i++) {
      const d = (startDir + i) % 8;
      const nx = cx + dirs[d][0];
      const ny = cy + dirs[d][1];
      if (nx >= 0 && nx < w && ny >= 0 && ny < h && mask[ny * w + nx] > 0) {
        cx = nx;
        cy = ny;
        dir = d;
        found = true;
        break;
      }
    }

    // Stop on an isolated pixel or once the walk is back at the start
    if (!found || (cx === startX && cy === startY && step > 2)) break;
  }

  // Subsample to avoid too many points
  if (contour.length > 200) {
    const stride = Math.ceil(contour.length / 200);
    return contour.filter((_, i) => i % stride === 0);
  }

  return contour;
}

/**
 * Douglas-Peucker polyline simplification: keeps the endpoints and,
 * recursively, every point farther than epsilon from the chord between the
 * endpoints of its sub-range.
 */
function douglasPeucker(points, epsilon) {
  if (points.length <= 2) return points;

  const first = points[0];
  const last = points[points.length - 1];

  let maxDist = 0;
  let maxIdx = 0;
  for (let i = 1; i < points.length - 1; i++) {
    const d = pointToLineDist(points[i], first, last);
    if (d > maxDist) {
      maxDist = d;
      maxIdx = i;
    }
  }

  if (maxDist > epsilon) {
    const left = douglasPeucker(points.slice(0, maxIdx + 1), epsilon);
    const right = douglasPeucker(points.slice(maxIdx), epsilon);
    // The split point ends `left` and starts `right`; keep it only once
    return [...left.slice(0, -1), ...right];
  }

  return [first, last];
}

/**
 * Distance from point p to the line segment a-b (not the infinite line).
 * When a and b coincide this is the distance from p to a.
 */
function pointToLineDist(p, a, b) {
  const dx = b.x - a.x;
  const dy = b.y - a.y;
  const len2 = dx * dx + dy * dy;
  if (len2 === 0) return Math.sqrt((p.x - a.x) ** 2 + (p.y - a.y) ** 2);
  // Clamp the projection parameter so the closest point stays on the segment
  const t = Math.max(0, Math.min(1, ((p.x - a.x) * dx + (p.y - a.y) * dy) / len2));
  const projX = a.x + t * dx;
  const projY = a.y + t * dy;
  return Math.sqrt((p.x - projX) ** 2 + (p.y - projY) ** 2);
}