// Background removal using @huggingface/transformers with MODNet. // Runs entirely in the browser (WebGPU with WASM fallback). // // We use Xenova/modnet (portrait matting, ~6.5M params) instead of // briaai/RMBG-1.4 because RMBG's ONNX has a fixed [1,3,1024,1024] input // shape. U²-Net activations at 1024x1024 blow the WebContent heap on // phones without WebGPU, triggering a Jetsam kill mid-inference. MODNet // accepts dynamic shapes and preprocesses to a 512-short-edge image, // which is ~4x less activation memory and actually fits on mobile WASM. let transformers = null; let loading = false; let loadPromise = null; // MODNet preprocesses to 512 on the short edge, so activation memory is // roughly proportional to the output image we feed it. OUTPUT_DIM caps the // saved image (and indirectly the model's working resolution). const OUTPUT_DIM = 768; const MODEL_ID = 'Xenova/modnet'; // Workarounds for transformers.js + onnxruntime-web on iOS Safari // (and memory-constrained mobile browsers in general). // // Cribbed from cmorenogit/app-profile@01081eb, which hit the same // "tab silently dies during model load/inference" problem: // // 1. numThreads=1 — multi-threaded WASM on JavaScriptCore (iOS Safari) // spirals activation memory across worker threads. Issue #1242. // Slower, but single-threaded stays under the tab budget. // 2. useBrowserCache=false — the Cache API adds a full-size *copy* of // the model during loading (fetch buffer -> Cache -> WASM heap), // tripling peak. Re-downloading each session is cheaper than // crashing. // 3. No WebGPU on iOS — ONNX Runtime JSEP bug #26827 breaks it on // iOS 18.x Safari anyway; we were already on WASM in practice. // // These are harmless on Android/desktop (just slightly slower loads), // so we apply them unconditionally rather than sniffing UA. 
// Loaded { model, processor, dtype, device } bundle while a removeBackground()
// job holds it; null otherwise. Backs isModelLoaded().
let residentModel = null;

// Tail of the removeBackground() job queue. Each job loads the model, runs it
// once, and disposes it. Without serialization, overlapping jobs would share a
// single model through loadModel()'s load dedupe. The first job to finish
// would then dispose it while the other was still running inference.
let removalQueue = Promise.resolve();

/**
 * Import transformers.js from the CDN once and apply the memory workarounds
 * described in the header comment (single-threaded WASM, no Cache API copy).
 * @returns {Promise<object>} The cached transformers.js module namespace.
 */
export async function getTransformers() {
  if (transformers) return transformers;
  transformers = await import(
    'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3'
  );
  const { env } = transformers;
  env.allowLocalModels = false;
  // (2) no Cache API copy — fetch goes straight to WASM heap
  env.useBrowserCache = false;
  // (1) single-threaded WASM — kills the multi-copy activation spiral
  if (env.backends?.onnx?.wasm) {
    env.backends.onnx.wasm.numThreads = 1;
  }
  return transformers;
}

/**
 * Load the model + processor, or join a load that is already in flight.
 * The shared promise is cleared once the load settles, so a later call starts
 * a fresh load (each removeBackground() job disposes its model when done).
 * @param {function} [onProgress] - Progress callback. Only the caller that
 *   starts the load receives download progress.
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 */
async function loadModel(onProgress) {
  if (loadPromise) return loadPromise;
  loading = true;
  loadPromise = _doLoad(onProgress);
  try {
    const bundle = await loadPromise;
    residentModel = bundle;
    return bundle;
  } finally {
    loading = false;
    loadPromise = null;
  }
}

/**
 * Fetch the processor and model, trying the smallest dtype first.
 * @param {function} [onProgress] - Progress callback.
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 * @throws The last from_pretrained error if every dtype fails to load.
 */
async function _doLoad(onProgress) {
  const { AutoModel, AutoProcessor } = await getTransformers();
  onProgress?.({ status: 'loading', message: 'Loading AI model...' });

  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
  const device = hasWebGPU ? 'webgpu' : 'wasm';

  // Load the processor BEFORE the model, so that a processor failure cannot
  // strand an already-loaded model in memory (presumably it is only a small
  // preprocessing config).
  const processor = await AutoProcessor.from_pretrained(MODEL_ID);

  const dtypeProgress = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        status: 'downloading',
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress,
      });
    }
  };

  // Try smallest dtype first to keep peak memory low on phones.
  const dtypePreference = device === 'webgpu'
    ? ['fp16', 'q8', 'fp32']
    : ['q8', 'fp16', 'fp32'];

  let lastErr = null;
  for (const dtype of dtypePreference) {
    try {
      const model = await AutoModel.from_pretrained(MODEL_ID, {
        device,
        dtype,
        progress_callback: dtypeProgress,
      });
      console.log(`[segmentation] loaded ${MODEL_ID} with dtype=${dtype}, device=${device}`);
      return { model, processor, dtype, device };
    } catch (err) {
      console.warn(`[segmentation] dtype=${dtype} failed:`, err?.message);
      lastErr = err;
    }
  }
  throw lastErr || new Error(`Failed to load ${MODEL_ID}`);
}

/**
 * Decode the camera image once, downscaled so its longer edge is at most
 * OUTPUT_DIM. The model's own processor will further resize this (short edge
 * 512) for inference. The same blob is therefore used both as the model input
 * and as the canvas the final mask is applied to.
 * @param {Blob} imageBlob - Encoded source image.
 * @returns {Promise<Blob>} PNG at the output resolution.
 */
async function prepareOutputImage(imageBlob) {
  const probe = await createImageBitmap(imageBlob);
  const fullW = probe.width;
  const fullH = probe.height;
  probe.close();

  const ratio = Math.min(1, OUTPUT_DIM / Math.max(fullW, fullH));
  // Clamp to 1px: for extreme aspect ratios the short side would otherwise
  // round to 0, which createImageBitmap rejects.
  const w = Math.max(1, Math.round(fullW * ratio));
  const h = Math.max(1, Math.round(fullH * ratio));

  const bitmap = await createImageBitmap(imageBlob, {
    resizeWidth: w,
    resizeHeight: h,
    resizeQuality: 'medium',
  });
  const canvas = new OffscreenCanvas(w, h);
  try {
    canvas.getContext('2d').drawImage(bitmap, 0, 0);
  } finally {
    bitmap.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}

/**
 * Run MODNet on an image and return its alpha mask as plain data. All input
 * and output tensors are disposed before returning, including on error.
 * @returns {Promise<{data: ArrayLike<number>, width: number, height: number}>}
 */
async function predictMask(model, processor, RawImage, imageBlob, tag, label) {
  // 3. Preprocess via the model's own processor. MODNet takes dynamic
  //    input sizes, short edge 512, so the working resolution scales
  //    with the blob we feed it.
  tag(`${label}: preprocessing`);
  const url = URL.createObjectURL(imageBlob);
  let image;
  try {
    image = await RawImage.fromURL(url);
  } finally {
    URL.revokeObjectURL(url);
  }
  const { pixel_values } = await processor(image);

  try {
    tag(`${label}: tensor ${pixel_values.dims.join('x')}`);

    // 4. Forward pass — biggest memory peak. MODNet activations scale with
    //    the input tensor size; the tag above shows the actual dims.
    tag(`${label}: running inference`);
    const { output } = await model({ input: pixel_values });
    try {
      // The mask plane is the last two dims (H, W). The first H*W values of
      // output.data are the batch-0 / channel-0 plane, the same values
      // output[0].data exposed before.
      // NOTE(review): assumes a float32 output in [0, 1] — confirm for fp16.
      const { dims } = output;
      return {
        data: output.data, // read before dispose() invalidates the tensor
        width: dims[dims.length - 1],
        height: dims[dims.length - 2],
      };
    } finally {
      output.dispose?.();
    }
  } finally {
    pixel_values.dispose?.();
  }
}

/**
 * Draw the image onto a canvas and write the mask into its alpha channel,
 * using nearest-neighbour sampling when the mask and image sizes differ.
 * @param {Blob} imageBlob - Output-sized image.
 * @param {{data: ArrayLike<number>, width: number, height: number}} mask
 * @returns {Promise<OffscreenCanvas>} The masked canvas.
 */
async function applyMask(imageBlob, mask) {
  const bitmap = await createImageBitmap(imageBlob);
  const canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
  const ctx = canvas.getContext('2d');
  try {
    ctx.drawImage(bitmap, 0, 0);
  } finally {
    bitmap.close();
  }

  const { width: cw, height: ch } = canvas;
  const imgData = ctx.getImageData(0, 0, cw, ch);
  const pixels = imgData.data;
  const { data: maskData, width: maskW, height: maskH } = mask;
  const scaleX = maskW / cw;
  const scaleY = maskH / ch;

  for (let y = 0; y < ch; y++) {
    const maskRow = Math.min(Math.floor(y * scaleY), maskH - 1) * maskW;
    const rowOffset = y * cw * 4;
    for (let x = 0; x < cw; x++) {
      const mx = Math.min(Math.floor(x * scaleX), maskW - 1);
      const v = maskData[maskRow + mx];
      pixels[rowOffset + x * 4 + 3] = v <= 0 ? 0 : v >= 1 ? 255 : Math.round(v * 255);
    }
  }
  ctx.putImageData(imgData, 0, 0);
  return canvas;
}

/** One removeBackground() job; callers go through the serializing queue. */
async function runRemoveBackground(imageBlob, onProgress) {
  const tag = (message) => {
    onProgress?.({ status: 'processing', message });
  };

  // 1. Downscale the camera photo BEFORE loading the model (lower peak mem)
  tag('bg: decoding photo');
  const outputBlob = await prepareOutputImage(imageBlob);

  // 2. Load model
  tag('bg: loading model');
  const { model, processor, dtype, device } = await loadModel(onProgress);
  const label = `bg[${dtype}/${device}]`;

  // 3-5. Inference. The model is disposed in `finally` so that a failure in
  //      preprocessing or inference cannot leak it, and so the mask step has
  //      the heap to itself.
  let mask;
  try {
    const { RawImage } = await getTransformers();
    mask = await predictMask(model, processor, RawImage, outputBlob, tag, label);
  } finally {
    tag(`${label}: disposing model`);
    disposeModel(model);
  }

  // 6. Apply mask to the output-sized image
  tag(`${label}: applying mask`);
  const canvas = await applyMask(outputBlob, mask);

  tag(`${label}: encoding result`);
  const result = await canvas.convertToBlob({ type: 'image/png' });
  onProgress?.({ status: 'done', message: `${label}: done` });
  return result;
}

/**
 * Remove background from an image blob. Calls are serialized: an overlapping
 * call waits for the previous one to finish, so only one model is in memory
 * at a time.
 * @param {Blob} imageBlob - JPEG/PNG image
 * @param {function} [onProgress] - progress callback
 * @returns {Promise<Blob>} - PNG with transparent background
 */
export async function removeBackground(imageBlob, onProgress) {
  const job = removalQueue.then(() => runRemoveBackground(imageBlob, onProgress));
  // Keep the queue alive after a failed job. The error still reaches this
  // caller through `job`; only the queue's own tail ignores it.
  removalQueue = job.catch(() => {});
  return job;
}

/**
 * Release a model's inference sessions (best effort) and clear loader state.
 * @param {object} model - Model returned by loadModel().
 */
function disposeModel(model) {
  const warn = (err) => console.warn('[segmentation] model dispose failed:', err?.message);
  try {
    // dispose() may return a promise; observe its rejection as well, which a
    // plain try/catch cannot.
    Promise.resolve(model.dispose?.()).catch(warn);
  } catch (err) {
    warn(err);
  }
  if (residentModel?.model === model) residentModel = null;
  loadPromise = null;
}

/** @returns {boolean} true while a loaded model is resident in memory. */
export function isModelLoaded() {
  return residentModel !== null;
}

/** @returns {boolean} true while a model load is in progress. */
export function isModelLoading() {
  return loading;
}