// background-removal/js/segmentation.js
// Last commit: "Move segmentation to cloud" (a125618), by sdragly
// Background removal using @huggingface/transformers with MODNet.
// Runs entirely in the browser (WebGPU with WASM fallback).
//
// We use Xenova/modnet (portrait matting, ~6.5M params) instead of
// briaai/RMBG-1.4 because RMBG's ONNX has a fixed [1,3,1024,1024] input
// shape. U²-Net activations at 1024x1024 blow the WebContent heap on
// phones without WebGPU, triggering a Jetsam kill mid-inference. MODNet
// accepts dynamic shapes and preprocesses to a 512-short-edge image,
// which is ~4x less activation memory and actually fits on mobile WASM.
// Cached namespace of the dynamically imported transformers.js module;
// stays null until getTransformers() has finished its import.
let transformers = null;
// True while loadModel() has a load in flight; reset when that load settles.
let loading = false;
// Promise for the in-flight model load, shared by overlapping loadModel()
// calls. Reset to null when the load settles and again by disposeModel().
let loadPromise = null;
// MODNet preprocesses to 512 on the short edge, so activation memory is
// roughly proportional to the output image we feed it. OUTPUT_DIM caps the
// longest edge (in pixels) of the saved image, and indirectly the model's
// working resolution. Images that are already smaller are not upscaled.
const OUTPUT_DIM = 768;
// Model id passed to AutoModel.from_pretrained / AutoProcessor.from_pretrained.
const MODEL_ID = 'Xenova/modnet';
// Workarounds for transformers.js + onnxruntime-web on iOS Safari
// (and memory-constrained mobile browsers in general).
//
// Cribbed from cmorenogit/app-profile@01081eb, which hit the same
// "tab silently dies during model load/inference" problem:
//
// 1. numThreads=1 — multi-threaded WASM on JavaScriptCore (iOS Safari)
// spirals activation memory across worker threads. Issue #1242.
// Slower, but single-threaded stays under the tab budget.
// 2. useBrowserCache=false — the Cache API adds a full-size *copy* of
// the model during loading (fetch buffer -> Cache -> WASM heap),
// tripling peak. Re-downloading each session is cheaper than
// crashing.
// 3. No WebGPU on iOS — ONNX Runtime JSEP bug #26827 breaks it on
// iOS 18.x Safari anyway; we were already on WASM in practice.
// Nothing in this file forces that: the device is chosen in _doLoad
// from the presence of navigator.gpu.
//
// (1) and (2) are harmless on Android/desktop (just slightly slower loads),
// so we apply them unconditionally rather than sniffing UA.
/**
 * Lazily import transformers.js from the CDN and apply the mobile
 * memory workarounds to its `env`. The module namespace is cached in
 * `transformers`, so later calls resolve to the same object without
 * importing or configuring again.
 *
 * @returns {Promise<object>} the transformers.js module namespace
 */
export async function getTransformers() {
  if (!transformers) {
    transformers = await import(
      'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3'
    );
    const { env } = transformers;
    env.allowLocalModels = false;
    // No Cache API copy — the fetch goes straight to the WASM heap.
    env.useBrowserCache = false;
    // Single-threaded WASM — avoids the multi-copy activation spiral.
    // The wasm backend config may be absent, hence the guard.
    const wasmBackend = env.backends?.onnx?.wasm;
    if (wasmBackend) {
      wasmBackend.numThreads = 1;
    }
  }
  return transformers;
}
/**
 * Load the model and processor, sharing a single in-flight load between
 * overlapping callers.
 *
 * The first caller starts the load and, once it settles (success or
 * failure), clears both `loading` and `loadPromise`, so nothing is cached
 * across loads. A caller that arrives while a load is in flight receives
 * that same promise; its own `onProgress` is not used.
 *
 * @param {function} [onProgress] - progress callback forwarded to _doLoad
 * @returns {Promise<object>} whatever _doLoad resolves to
 */
async function loadModel(onProgress) {
  if (!loadPromise) {
    loading = true;
    loadPromise = _doLoad(onProgress);
    try {
      const loaded = await loadPromise;
      return loaded;
    } finally {
      loadPromise = null;
      loading = false;
    }
  }
  return loadPromise;
}
/**
 * Load the MODNet model and its processor.
 *
 * Attempts are made in order until one succeeds:
 *   - when `navigator.gpu` exists: WebGPU with fp16, q8, fp32;
 *   - then always: WASM with q8, fp16, fp32.
 * Smaller dtypes come first to keep peak memory low on phones. The WASM
 * attempts run even when WebGPU was tried, because the presence of
 * `navigator.gpu` does not guarantee that a WebGPU session can be created.
 *
 * @param {function} [onProgress] - receives `{ status, message, progress? }`
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 *   the loaded pair plus the dtype/device combination that actually worked
 * @throws the last load error if every attempt fails, or the processor
 *   load error (the already-loaded model is disposed first)
 */
async function _doLoad(onProgress) {
  const { AutoModel, AutoProcessor } = await getTransformers();
  if (onProgress) onProgress({ status: 'loading', message: 'Loading AI model...' });

  const reportDownload = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        status: 'downloading',
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress,
      });
    }
  };

  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
  const attempts = [
    ...(hasWebGPU
      ? ['fp16', 'q8', 'fp32'].map((dtype) => ({ device: 'webgpu', dtype }))
      : []),
    ...['q8', 'fp16', 'fp32'].map((dtype) => ({ device: 'wasm', dtype })),
  ];

  let model = null;
  let loaded = null;
  let lastErr = null;
  for (const { device, dtype } of attempts) {
    try {
      model = await AutoModel.from_pretrained(MODEL_ID, {
        device,
        dtype,
        progress_callback: reportDownload,
      });
      loaded = { device, dtype };
      console.log(`[segmentation] loaded ${MODEL_ID} with dtype=${dtype}, device=${device}`);
      break;
    } catch (err) {
      console.warn(`[segmentation] dtype=${dtype}, device=${device} failed:`, err?.message);
      lastErr = err;
    }
  }
  if (!model) throw lastErr || new Error(`Failed to load ${MODEL_ID}`);

  let processor;
  try {
    processor = await AutoProcessor.from_pretrained(MODEL_ID);
  } catch (err) {
    // The caller never sees the model in this case, so release it here
    // instead of leaving it resident.
    try {
      await model.dispose?.();
    } catch (disposeErr) {
      console.warn('[segmentation] model dispose failed:', disposeErr?.message);
    }
    throw err;
  }
  return { model, processor, dtype: loaded.dtype, device: loaded.device };
}
/**
 * Downscale the camera image so its longest edge is at most OUTPUT_DIM
 * (never upscaling) and re-encode it as PNG. The resulting blob is used
 * both as the model input and as the image the final mask is applied to.
 *
 * The blob is decoded twice: a first time only to read its dimensions, and
 * a second time with resize options so the retained bitmap is already at
 * the target size.
 *
 * @param {Blob} imageBlob - source image
 * @returns {Promise<Blob>} PNG blob at the capped resolution
 */
async function prepareOutputImage(imageBlob) {
  const probe = await createImageBitmap(imageBlob);
  const fullW = probe.width;
  const fullH = probe.height;
  probe.close();

  const ratio = Math.min(1, OUTPUT_DIM / Math.max(fullW, fullH));
  // Clamp to 1px: on extreme aspect ratios the short edge would otherwise
  // round to 0, and createImageBitmap rejects a zero resize dimension.
  const w = Math.max(1, Math.round(fullW * ratio));
  const h = Math.max(1, Math.round(fullH * ratio));

  const bitmap = await createImageBitmap(imageBlob, {
    resizeWidth: w,
    resizeHeight: h,
    resizeQuality: 'medium',
  });
  let canvas;
  try {
    canvas = new OffscreenCanvas(w, h);
    canvas.getContext('2d').drawImage(bitmap, 0, 0);
  } finally {
    // Release the bitmap before encoding, even if drawing failed.
    bitmap.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}
/**
 * Remove the background from an image blob.
 *
 * Steps: downscale the photo, load the model, preprocess, run inference,
 * release the tensors and model, then write the predicted matte into the
 * alpha channel of the downscaled image and encode it as PNG.
 *
 * The model and the input tensor are released in a `finally`, so a failure
 * during preprocessing or inference does not leave them resident.
 *
 * NOTE(review): overlapping calls can receive the same model from
 * loadModel(), and each call disposes it when done — confirm callers
 * serialize invocations.
 *
 * @param {Blob} imageBlob - JPEG/PNG image
 * @param {function} [onProgress] - receives `{ status, message }` updates
 * @returns {Promise<Blob>} PNG with transparent background
 * @throws {Error} if loading, preprocessing or inference fails, or if the
 *   model output has no usable mask dimensions
 */
export async function removeBackground(imageBlob, onProgress) {
  const tag = (message) => {
    if (onProgress) onProgress({ status: 'processing', message });
  };

  // 1. Downscale the camera photo BEFORE loading the model (lower peak mem).
  tag('bg: decoding photo');
  const outputBlob = await prepareOutputImage(imageBlob);

  // 2. Load model.
  tag('bg: loading model');
  const { model, processor, dtype, device } = await loadModel(onProgress);
  const label = `bg[${dtype}/${device}]`;

  let pixelValues = null;
  let maskData;
  let maskH;
  let maskW;
  try {
    const { RawImage } = await getTransformers();

    // 3. Preprocess via the model's own processor, so the working
    //    resolution follows from the blob we feed it.
    tag(`${label}: preprocessing`);
    const url = URL.createObjectURL(outputBlob);
    let image;
    try {
      image = await RawImage.fromURL(url);
    } finally {
      // Revoke even when decoding fails, otherwise the blob stays pinned.
      URL.revokeObjectURL(url);
    }
    ({ pixel_values: pixelValues } = await processor(image));
    tag(`${label}: tensor ${pixelValues.dims.join('x')}`);

    // 4. Forward pass — biggest memory peak; the tag above shows the
    //    actual input dims.
    tag(`${label}: running inference`);
    const { output } = await model({ input: pixelValues });
    const mask = output[0]; // take the first batch item once
    maskData = mask.data; // Float32Array
    maskH = mask.dims[1] || output.dims?.[2];
    maskW = mask.dims[2] || output.dims?.[3];
    if (!(maskH > 0) || !(maskW > 0)) {
      // Without dimensions every mask lookup would miss and the result
      // would silently come out fully transparent.
      throw new Error(`${label}: unexpected mask shape`);
    }

    // 5. Release the output now; the model and input tensor follow in the
    //    finally below, so the mask-application step has the heap to itself.
    tag(`${label}: disposing model`);
    if (mask.dispose) mask.dispose();
    // The indexed item may be a separate wrapper object, so release the
    // parent tensor as well. maskData stays referenced here.
    if (output.dispose) output.dispose();
  } finally {
    // Awaiting is harmless when disposeModel returns nothing, and waits
    // for the release when it returns a promise.
    await disposeModel(model);
    if (pixelValues?.dispose) pixelValues.dispose();
  }

  // 6. Apply mask to the output-sized image.
  tag(`${label}: applying mask`);
  const bitmap = await createImageBitmap(outputBlob);
  const canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
  const ctx = canvas.getContext('2d');
  ctx.drawImage(bitmap, 0, 0);
  bitmap.close();
  const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  applyMaskAlpha(imgData.data, canvas.width, canvas.height, maskData, maskW, maskH);
  ctx.putImageData(imgData, 0, 0);

  tag(`${label}: encoding result`);
  const result = await canvas.convertToBlob({ type: 'image/png' });
  if (onProgress) onProgress({ status: 'done', message: `${label}: done` });
  return result;
}

/**
 * Write a matte into the alpha channel of an RGBA buffer, sampling the
 * mask with nearest-neighbour lookup. Mask values are clamped to [0, 1]
 * and scaled to 0..255; the colour channels are left untouched.
 */
function applyMaskAlpha(rgba, width, height, maskData, maskW, maskH) {
  const scaleX = maskW / width;
  const scaleY = maskH / height;
  for (let y = 0; y < height; y++) {
    const my = Math.min(Math.floor(y * scaleY), maskH - 1);
    const maskRow = my * maskW;
    const rowOffset = y * width * 4;
    for (let x = 0; x < width; x++) {
      const mx = Math.min(Math.floor(x * scaleX), maskW - 1);
      const v = maskData[maskRow + mx];
      rgba[rowOffset + x * 4 + 3] = v <= 0 ? 0 : v >= 1 ? 255 : Math.round(v * 255);
    }
  }
}
/**
 * Best-effort release of a loaded model; never throws or rejects.
 *
 * `loadPromise` is cleared synchronously, before the first await, so
 * callers that do not await the result still see the reset at once.
 * `dispose()` may return a promise, so it is awaited here: a rejection is
 * then logged instead of surfacing as an unhandled rejection.
 *
 * @param {object|null|undefined} model - model to release; a missing model
 *   or one without `dispose` is ignored
 * @returns {Promise<void>} settles once disposal has finished; callers may
 *   ignore it
 */
async function disposeModel(model) {
  loadPromise = null;
  try {
    await model?.dispose?.();
  } catch (err) {
    console.warn('[segmentation] model dispose failed:', err?.message);
  }
}
/**
 * Report the module's `loading` flag.
 *
 * NOTE(review): this reads the same flag as isModelLoading(), so it is
 * true only while a load is in flight and false once loading has finished.
 * That looks like a copy/paste slip, but no separate "loaded" state is
 * visible in this file to return instead — confirm the intended semantics
 * with callers before changing it.
 *
 * @returns {boolean} current value of `loading`
 */
export function isModelLoaded() {
  const flag = loading;
  return flag;
}
/**
 * Whether a model load started by loadModel() is currently in flight.
 *
 * @returns {boolean} current value of `loading`
 */
export function isModelLoading() {
  const inFlight = loading;
  return inFlight;
}