background-removal / js /auto-segment.js
sdragly's picture
Move segmentation to cloud
a125618
// Auto-segmentation using SlimSAM (Segment Anything Model)
// Generates a grid of point prompts, runs mask decoder for each,
// filters and deduplicates to find distinct parts of a drawing.
import { getTransformers } from './segmentation.js';
// Model id passed to SamModel.from_pretrained and AutoProcessor.from_pretrained.
const SAM_MODEL = 'Xenova/slimsam-77-uniform';
// Longest side (in pixels) the input image is downscaled to before inference.
const SAM_DIM = 384; // smaller input → quadratically less peak memory
// Number of point prompts along each side of the prompt grid.
const GRID_SIZE = 6; // 6x6 = 36 points, filtered to non-transparent
// A grid point's best mask is kept only when its IoU score is at least this value.
const MIN_IOU_SCORE = 0.65;
const MIN_AREA_FRAC = 0.005; // minimum mask area as fraction of image
// During NMS a mask is dropped when its IoU with an already-kept,
// higher-scoring mask exceeds this value.
const NMS_IOU_THRESHOLD = 0.5;
/**
 * Auto-segment an image into distinct parts using SlimSAM.
 *
 * Steps: load the model (trying several dtypes in order of preference),
 * downscale the image, prompt the model once per non-transparent grid point,
 * keep the best-scoring mask of each prompt, drop small and duplicate masks,
 * then crop every surviving mask out of the original full-resolution image.
 *
 * The model is always disposed before this function settles, including when
 * an intermediate step throws.
 *
 * @param {Blob} imageBlob - background-removed PNG
 * @param {function} [onProgress] - optional callback receiving { message, progress }
 * @returns {Promise<SegmentResult[]>} segments sorted by mask area, largest
 *   first; each carries mask, maskW, maskH, score, bbox, area, id, croppedBlob
 * @throws the last model-loading error when no dtype variant could be loaded
 */
export async function autoSegment(imageBlob, onProgress) {
  const report = (message, progress) => {
    if (onProgress) onProgress({ message, progress });
  };

  // Best-effort tensor release: a failing dispose must not abort segmentation.
  const disposeTensor = (tensor) => {
    try {
      if (tensor && tensor.dispose) tensor.dispose();
    } catch (err) {
      console.warn('[auto-segment] tensor dispose failed:', err);
    }
  };

  const { SamModel, AutoProcessor, RawImage } = await getTransformers();

  // Prefer WebGPU (fast), fall back to WASM (slow but universal)
  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
  const device = hasWebGPU ? 'webgpu' : 'wasm';
  report(`Loading segmentation model (${device})...`, 0);

  const dtypeProgress = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress * 0.2, // 0-20% for download
      });
    }
  };

  // Try the smallest dtype first; fall back step-by-step if the model
  // doesn't ship that variant or the runtime can't load it.
  const dtypePreference = device === 'webgpu'
    ? ['fp16', 'q8', 'fp32']
    : ['q8', 'fp16', 'fp32'];

  let model = null;
  let loadedDtype = null;
  let lastErr = null;
  for (const dtype of dtypePreference) {
    try {
      model = await SamModel.from_pretrained(SAM_MODEL, {
        device,
        dtype,
        progress_callback: dtypeProgress,
      });
      loadedDtype = dtype;
      console.log(`[auto-segment] loaded SAM with dtype=${dtype}, device=${device}`);
      break;
    } catch (err) {
      console.warn(`[auto-segment] dtype=${dtype} failed:`, err && err.message);
      lastErr = err;
    }
  }
  if (!model) throw lastErr || new Error('Failed to load SAM');

  let results;
  try {
    report(`sam: loaded ${loadedDtype}/${device}`, 22);
    const processor = await AutoProcessor.from_pretrained(SAM_MODEL);

    report('Preparing image...', 25);

    // Downscale for inference
    const workBlob = await downscale(imageBlob, SAM_DIM);
    const workBitmap = await createImageBitmap(workBlob);
    const imgW = workBitmap.width;
    const imgH = workBitmap.height;

    // Build opacity map to skip transparent grid points (bg-removed image)
    let opaData;
    try {
      const opaCanvas = new OffscreenCanvas(imgW, imgH);
      const opaCtx = opaCanvas.getContext('2d');
      opaCtx.drawImage(workBitmap, 0, 0);
      opaData = opaCtx.getImageData(0, 0, imgW, imgH).data;
    } finally {
      workBitmap.close();
    }

    // Load as RawImage for the processor; revoke the URL even if loading fails
    const url = URL.createObjectURL(workBlob);
    let rawImage;
    try {
      rawImage = await RawImage.fromURL(url);
    } finally {
      URL.revokeObjectURL(url);
    }

    report('Analyzing image...', 30);

    // Generate grid points, filter to non-transparent pixels
    const allGridPoints = generateGrid(imgW, imgH, GRID_SIZE);
    const gridPoints = allGridPoints.filter(([px, py]) => {
      // Check a small area around the point for any opaque pixel
      const r = 3;
      for (let dy = -r; dy <= r; dy++) {
        for (let dx = -r; dx <= r; dx++) {
          const sx = Math.max(0, Math.min(imgW - 1, px + dx));
          const sy = Math.max(0, Math.min(imgH - 1, py + dy));
          if (opaData[(sy * imgW + sx) * 4 + 3] > 128) return true;
        }
      }
      return false;
    });

    const totalPoints = gridPoints.length;
    const allMasks = [];

    // Helper: yield to browser so page stays responsive
    const yieldToBrowser = () => new Promise((r) => setTimeout(r, 0));

    // Run mask decoder for each grid point
    for (let i = 0; i < totalPoints; i++) {
      report(`Finding parts... ${i + 1}/${totalPoints}`, 30 + (i / totalPoints) * 50); // 30-80%

      // Yield every few iterations to keep the page responsive
      if (i % 3 === 0) await yieldToBrowser();

      const [px, py] = gridPoints[i];
      let outputs = null;
      try {
        const inputs = await processor(rawImage, {
          input_points: [[[px, py]]],
          input_labels: [[1]],
        });
        outputs = await model(inputs);
        const masks = await processor.post_process_masks(
          outputs.pred_masks,
          inputs.original_sizes,
          inputs.reshaped_input_sizes,
        );

        // Get IoU scores - shape [1, numMasks]
        const iouScores = outputs.iou_scores.data;

        // Find best mask candidate
        let bestIdx = 0;
        let bestScore = iouScores[0];
        for (let j = 1; j < iouScores.length; j++) {
          if (iouScores[j] > bestScore) {
            bestScore = iouScores[j];
            bestIdx = j;
          }
        }

        if (bestScore >= MIN_IOU_SCORE) {
          // Extract the mask data for the best candidate
          const maskTensor = masks[0][0]; // [numMasks, H, W]
          const maskH = maskTensor.dims[1];
          const maskW = maskTensor.dims[2];
          const maskData = maskTensor.data;
          const maskSize = maskH * maskW;
          const offset = bestIdx * maskSize;

          // Copy just this mask's data so it outlives the tensor
          const singleMask = new Float32Array(maskSize);
          for (let k = 0; k < maskSize; k++) {
            singleMask[k] = maskData[offset + k];
          }

          allMasks.push({
            mask: singleMask,
            maskW,
            maskH,
            score: bestScore,
          });
        }
      } catch (e) {
        console.warn(`Grid point ${i} failed:`, e);
      } finally {
        // Release the output tensors even when post-processing threw
        if (outputs) {
          disposeTensor(outputs.pred_masks);
          disposeTensor(outputs.iou_scores);
        }
      }
    }

    report('Filtering results...', 82);

    // Compute bounding box and area for each mask
    const minArea = imgW * imgH * MIN_AREA_FRAC;
    let candidates = allMasks.map((m) => {
      const { bbox, area } = computeMaskStats(m.mask, m.maskW, m.maskH);
      return { ...m, bbox, area };
    });

    // Filter by area
    candidates = candidates.filter((m) => m.area >= minArea);

    // NMS
    candidates = nonMaxSuppression(candidates, NMS_IOU_THRESHOLD);

    report('Extracting parts...', 88);

    // Crop each mask from the original image
    results = [];
    for (let i = 0; i < candidates.length; i++) {
      const c = candidates[i];
      // Ids follow the post-NMS order (assigned before the area sort below)
      c.id = `seg-${i}`;
      try {
        c.croppedBlob = await extractMaskRegion(imageBlob, c);
        results.push(c);
      } catch (e) {
        console.warn(`Failed to extract segment ${i}:`, e);
      }
    }

    // Sort by area descending (largest first)
    results.sort((a, b) => b.area - a.area);
  } finally {
    // Dispose model on success and on failure. dispose() may return a
    // promise, so await it inside the try to catch async rejections too.
    try {
      if (model.dispose) await model.dispose();
    } catch (err) {
      console.warn('[auto-segment] model dispose failed:', err);
    }
  }

  report('Done!', 100);
  return results;
}
// ---- Grid generation ----
/**
 * Build an n-by-n grid of prompt points evenly spaced inside a w-by-h area.
 * Points sit at multiples of w/(n+1) and h/(n+1), so none lies on the border.
 * @param {number} w - area width in pixels
 * @param {number} h - area height in pixels
 * @param {number} n - points per side
 * @returns {number[][]} [x, y] pairs in row-major order, rounded to integers
 */
function generateGrid(w, h, n) {
  const spacingX = w / (n + 1);
  const spacingY = h / (n + 1);
  // 1..n, shared by both axes
  const steps = Array.from({ length: n }, (_, k) => k + 1);

  const grid = [];
  for (const row of steps) {
    const y = Math.round(row * spacingY);
    for (const col of steps) {
      grid.push([Math.round(col * spacingX), y]);
    }
  }
  return grid;
}
// ---- Mask stats ----
/**
 * Compute the pixel area and the normalized bounding box of a mask.
 * A pixel belongs to the mask when its value is greater than 0.
 * @param {ArrayLike<number>} mask - row-major values, w * h entries
 * @param {number} w - mask width in pixels
 * @param {number} h - mask height in pixels
 * @returns {{bbox: {x: number, y: number, w: number, h: number}, area: number}}
 *   bbox is expressed as fractions of w and h; area is a pixel count.
 *   An empty mask yields area 0 and a zero-sized bbox at the origin.
 */
function computeMaskStats(mask, w, h) {
  let minX = w;
  let minY = h;
  let maxX = -1;
  let maxY = -1;
  let area = 0;

  for (let y = 0; y < h; y++) {
    for (let x = 0; x < w; x++) {
      if (mask[y * w + x] > 0) {
        area++;
        if (x < minX) minX = x;
        if (x > maxX) maxX = x;
        if (y < minY) minY = y;
        if (y > maxY) maxY = y;
      }
    }
  }

  // Without this guard the untouched min/max sentinels would produce a box
  // starting at (1, 1) with negative width and height.
  if (area === 0) {
    return { bbox: { x: 0, y: 0, w: 0, h: 0 }, area: 0 };
  }

  return {
    bbox: {
      x: minX / w,
      y: minY / h,
      w: (maxX - minX + 1) / w,
      h: (maxY - minY + 1) / h,
    },
    area,
  };
}
// ---- Non-maximum suppression ----
/**
 * Intersection-over-union of two segment masks, where a pixel counts as
 * "inside" when its value is greater than 0.
 * Both masks are expected to have the same dimensions; only the indices of
 * `a.mask` are visited.
 * @param {{mask: ArrayLike<number>}} a
 * @param {{mask: ArrayLike<number>}} b
 * @returns {number} IoU in [0, 1]; 0 when neither mask has any pixel set
 */
function maskIoU(a, b) {
  const pixelsA = a.mask;
  const pixelsB = b.mask;
  const total = pixelsA.length;

  let inBoth = 0;
  let inEither = 0;
  for (let idx = 0; idx < total; idx += 1) {
    const insideA = pixelsA[idx] > 0;
    const insideB = pixelsB[idx] > 0;
    if (insideA || insideB) {
      inEither += 1;
      if (insideA && insideB) inBoth += 1;
    }
  }

  if (inEither === 0) return 0;
  return inBoth / inEither;
}
/**
 * Greedy non-maximum suppression over segment masks.
 * Candidates are visited from highest to lowest score; one is kept unless
 * its IoU with an already-kept mask is greater than `threshold`.
 * The input array is not reordered.
 * @param {Array<{mask: ArrayLike<number>, score: number}>} masks
 * @param {number} threshold - IoU above which a candidate is a duplicate
 * @returns {Array} the surviving candidates, highest score first
 */
function nonMaxSuppression(masks, threshold) {
  const byScoreDesc = Array.from(masks).sort((left, right) => right.score - left.score);

  return byScoreDesc.reduce((survivors, contender) => {
    const isDuplicate = survivors.some((winner) => maskIoU(contender, winner) > threshold);
    if (!isDuplicate) survivors.push(contender);
    return survivors;
  }, []);
}
// ---- Extract mask region as cropped blob ----
/**
 * Crop one segment out of the full-resolution image and make every pixel
 * outside the segment's mask fully transparent.
 * The crop is the segment's bounding box plus a few pixels of padding,
 * clamped to the image bounds.
 * @param {Blob} imageBlob - source image, decoded via createImageBitmap
 * @param {{bbox: {x: number, y: number, w: number, h: number},
 *          mask: ArrayLike<number>, maskW: number, maskH: number}} segment
 *   bbox in normalized 0-1 coordinates; mask is row-major, maskW * maskH
 * @returns {Promise<Blob>} PNG blob of the cropped, masked region
 */
async function extractMaskRegion(imageBlob, segment) {
  const PADDING = 4;
  const { bbox, mask, maskW, maskH } = segment;

  const source = await createImageBitmap(imageBlob);
  const { width: fullW, height: fullH } = source;

  // Normalized bbox -> padded pixel rectangle inside the image
  const left = Math.max(0, Math.floor(bbox.x * fullW) - PADDING);
  const top = Math.max(0, Math.floor(bbox.y * fullH) - PADDING);
  const cropW = Math.min(fullW - left, Math.ceil(bbox.w * fullW) + PADDING * 2);
  const cropH = Math.min(fullH - top, Math.ceil(bbox.h * fullH) + PADDING * 2);

  const canvas = new OffscreenCanvas(cropW, cropH);
  const ctx = canvas.getContext('2d');
  ctx.drawImage(source, left, top, cropW, cropH, 0, 0, cropW, cropH);
  source.close();

  // The mask may have a different resolution than the image, so every
  // image pixel is mapped to its nearest-lower mask cell.
  const pixels = ctx.getImageData(0, 0, cropW, cropH);
  const rgba = pixels.data;
  const toMaskX = maskW / fullW;
  const toMaskY = maskH / fullH;

  for (let row = 0; row < cropH; row += 1) {
    const maskRow = Math.min(Math.floor((top + row) * toMaskY), maskH - 1);
    for (let col = 0; col < cropW; col += 1) {
      const maskCol = Math.min(Math.floor((left + col) * toMaskX), maskW - 1);
      const inside = mask[maskRow * maskW + maskCol] > 0;
      // Alpha is left as-is inside the mask and cleared outside it
      if (!inside) {
        rgba[(row * cropW + col) * 4 + 3] = 0;
      }
    }
  }

  ctx.putImageData(pixels, 0, 0);
  return canvas.convertToBlob({ type: 'image/png' });
}
// ---- Downscale ----
/**
 * Shrink an image so that neither side exceeds `maxDim`, keeping its aspect
 * ratio. Images that already fit are returned unchanged (same Blob object).
 * @param {Blob} imageBlob - source image, decoded via createImageBitmap
 * @param {number} maxDim - maximum allowed width and height in pixels
 * @returns {Promise<Blob>} the original blob, or a resized PNG blob
 */
async function downscale(imageBlob, maxDim) {
  // The probe is only needed for its dimensions, so release it right away
  // rather than holding a full-size decoded bitmap across the second decode.
  const probe = await createImageBitmap(imageBlob);
  const { width, height } = probe;
  probe.close();

  if (width <= maxDim && height <= maxDim) {
    return imageBlob;
  }

  const ratio = Math.min(maxDim / width, maxDim / height);
  // Clamp to 1px: for extreme aspect ratios the short side would otherwise
  // round to 0, and createImageBitmap rejects a zero resize dimension.
  const newW = Math.max(1, Math.round(width * ratio));
  const newH = Math.max(1, Math.round(height * ratio));

  const resized = await createImageBitmap(imageBlob, {
    resizeWidth: newW, resizeHeight: newH, resizeQuality: 'medium',
  });
  const canvas = new OffscreenCanvas(newW, newH);
  try {
    canvas.getContext('2d').drawImage(resized, 0, 0);
  } finally {
    resized.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}
// ---- Mask to polygon (for adjustment) ----
/**
 * Convert a binary mask to a simplified polygon (normalized 0-1 coords).
 * Uses border tracing and Douglas-Peucker simplification.
 *
 * @param {ArrayLike<number>} mask - row-major mask; values > 0 are inside
 * @param {number} maskW - mask width in pixels
 * @param {number} maskH - mask height in pixels
 * @param {number} [maxPoints=30] - upper bound on returned vertices; the
 *   simplification never goes below 2 points
 * @returns {Array<{x: number, y: number}>} polygon vertices as fractions of
 *   maskW / maskH; empty when the mask has no pixel set
 */
export function maskToPolygon(mask, maskW, maskH, maxPoints = 30) {
  const normalize = (pts) => pts.map((p) => ({ x: p.x / maskW, y: p.y / maskH }));

  // Find contour points using simple border following
  const contour = traceContour(mask, maskW, maskH);
  // Too short to simplify, but still return normalized coordinates like
  // every other path does.
  if (contour.length < 3) return normalize(contour);

  // Simplify with Douglas-Peucker
  let tolerance = Math.max(maskW, maskH) * 0.015;
  let simplified = douglasPeucker(contour, tolerance);

  // Cap points. The tolerance must grow on every pass: re-running with a
  // fixed tolerance can return the same points again and never terminate.
  while (simplified.length > maxPoints && simplified.length > 2) {
    tolerance *= 1.5;
    simplified = douglasPeucker(simplified, tolerance);
  }

  // Normalize to 0-1
  return normalize(simplified);
}
function traceContour(mask, w, h) {
// Find first border pixel
let startX = -1, startY = -1;
outer: for (let y = 0; y < h; y++) {
for (let x = 0; x < w; x++) {
if (mask[y * w + x] > 0) {
startX = x;
startY = y;
break outer;
}
}
}
if (startX < 0) return [];
const contour = [];
const dirs = [
[1, 0], [1, 1], [0, 1], [-1, 1],
[-1, 0], [-1, -1], [0, -1], [1, -1],
];
let cx = startX, cy = startY;
let dir = 0;
const maxSteps = w * h;
for (let step = 0; step < maxSteps; step++) {
contour.push({ x: cx, y: cy });
// Look for next border pixel
let found = false;
const startDir = (dir + 5) % 8; // turn back to find outline
for (let i = 0; i < 8; i++) {
const d = (startDir + i) % 8;
const nx = cx + dirs[d][0];
const ny = cy + dirs[d][1];
if (nx >= 0 && nx < w && ny >= 0 && ny < h && mask[ny * w + nx] > 0) {
cx = nx;
cy = ny;
dir = d;
found = true;
break;
}
}
if (!found || (cx === startX && cy === startY && step > 2)) break;
}
// Subsample to avoid too many points
if (contour.length > 200) {
const step = Math.ceil(contour.length / 200);
const subsampled = [];
for (let i = 0; i < contour.length; i += step) {
subsampled.push(contour[i]);
}
return subsampled;
}
return contour;
}
/**
 * Douglas-Peucker polyline simplification.
 * Keeps the two endpoints and recursively keeps the interior point farthest
 * from the endpoint-to-endpoint segment whenever that distance is greater
 * than `epsilon`.
 * @param {Array<{x: number, y: number}>} points - polyline vertices
 * @param {number} epsilon - distance tolerance
 * @returns {Array<{x: number, y: number}>} simplified polyline; the input
 *   array itself when it has two points or fewer
 */
function douglasPeucker(points, epsilon) {
  const count = points.length;
  if (count <= 2) return points;

  const head = points[0];
  const tail = points[count - 1];

  // Interior point with the largest distance to the head-tail segment
  let farthestIdx = 0;
  let farthestDist = 0;
  for (let i = 1; i < count - 1; i += 1) {
    const dist = pointToLineDist(points[i], head, tail);
    if (dist > farthestDist) {
      farthestDist = dist;
      farthestIdx = i;
    }
  }

  // Everything is within tolerance: collapse to the two endpoints
  if (!(farthestDist > epsilon)) return [head, tail];

  // Split at the farthest point; it ends the first half and starts the
  // second, so drop it once when joining.
  const before = douglasPeucker(points.slice(0, farthestIdx + 1), epsilon);
  const after = douglasPeucker(points.slice(farthestIdx), epsilon);
  return before.slice(0, -1).concat(after);
}
/**
 * Distance from point `p` to the line segment from `a` to `b`.
 * The projection of `p` is clamped to the segment, so beyond either end the
 * distance to the nearest endpoint is returned; when `a` and `b` coincide
 * the result is the distance from `p` to `a`.
 * @param {{x: number, y: number}} p
 * @param {{x: number, y: number}} a - segment start
 * @param {{x: number, y: number}} b - segment end
 * @returns {number} Euclidean distance
 */
function pointToLineDist(p, a, b) {
  const segX = b.x - a.x;
  const segY = b.y - a.y;
  const segLenSq = segX * segX + segY * segY;

  // Position of the nearest point along the segment: 0 at a, 1 at b.
  // Stays 0 for a degenerate segment so the nearest point is a itself.
  let t = 0;
  if (segLenSq !== 0) {
    const unclamped = ((p.x - a.x) * segX + (p.y - a.y) * segY) / segLenSq;
    t = Math.max(0, Math.min(1, unclamped));
  }

  const nearestX = a.x + t * segX;
  const nearestY = a.y + t * segY;
  return Math.sqrt((p.x - nearestX) ** 2 + (p.y - nearestY) ** 2);
}