Spaces:
Sleeping
Sleeping
// Background removal using @huggingface/transformers with MODNet.
// Runs entirely in the browser (WebGPU with WASM fallback).
//
// We use Xenova/modnet (portrait matting, ~6.5M params) instead of
// briaai/RMBG-1.4 because RMBG's ONNX has a fixed [1,3,1024,1024] input
// shape. U²-Net activations at 1024x1024 blow the WebContent heap on
// phones without WebGPU, triggering a Jetsam kill mid-inference. MODNet
// accepts dynamic shapes and preprocesses to a 512-short-edge image,
// which is ~4x less activation memory and actually fits on mobile WASM.
// Lazily imported transformers.js module namespace; stays null until
// getTransformers() has run once.
let transformers = null;

// Promise of the model load currently in flight, or null when no load is
// running. Lets concurrent callers share a single load.
let loadPromise = null;

// True only while a model load is in progress.
let loading = false;

// Hugging Face model repository used for matting.
const MODEL_ID = 'Xenova/modnet';

// MODNet preprocesses to 512 on the short edge, so activation memory is
// roughly proportional to the image we feed it. OUTPUT_DIM caps the longer
// edge of the saved image (and indirectly the model's working resolution).
const OUTPUT_DIM = 768;
// Workarounds for transformers.js + onnxruntime-web on iOS Safari
// (and memory-constrained mobile browsers in general).
//
// Cribbed from cmorenogit/app-profile@01081eb, which hit the same
// "tab silently dies during model load/inference" problem:
//
//   1. numThreads=1 — multi-threaded WASM on JavaScriptCore (iOS Safari)
//      spirals activation memory across worker threads. Issue #1242.
//      Slower, but single-threaded stays under the tab budget.
//   2. useBrowserCache=false — the Cache API adds a full-size *copy* of
//      the model during loading (fetch buffer -> Cache -> WASM heap),
//      tripling peak. Re-downloading each session is cheaper than
//      crashing.
//   3. No WebGPU on iOS — ONNX Runtime JSEP bug #26827 breaks it on
//      iOS 18.x Safari anyway; we were already on WASM in practice.
//      (Not enforced by a UA check: the device is chosen by WebGPU
//      feature detection when the model is loaded.)
//
// (1) and (2) are harmless on Android/desktop (just slightly slower
// loads), so we apply them unconditionally rather than sniffing UA.
/**
 * Import transformers.js from the CDN on first use and configure its
 * environment; later calls return the cached module namespace.
 *
 * @returns {Promise<object>} the @huggingface/transformers module namespace
 */
export async function getTransformers() {
  if (!transformers) {
    transformers = await import(
      'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3'
    );

    const { env } = transformers;
    env.allowLocalModels = false;

    // Skip the Cache API so no extra full-size copy of the model is held
    // while it loads.
    env.useBrowserCache = false;

    // Force single-threaded WASM when the ONNX WASM backend settings exist.
    const wasmSettings = env.backends?.onnx?.wasm;
    if (wasmSettings) {
      wasmSettings.numThreads = 1;
    }
  }
  return transformers;
}
/**
 * Start a model load, or join the one already in flight.
 *
 * While a load is running, every caller receives the same promise. Once it
 * settles (success or failure) the in-flight state is cleared, so the next
 * call starts a fresh load.
 *
 * @param {function} [onProgress] - progress callback, forwarded to _doLoad
 *   only by the call that actually starts the load
 * @returns {Promise<object>} whatever _doLoad resolves to
 */
async function loadModel(onProgress) {
  if (loadPromise) {
    return loadPromise;
  }

  loading = true;
  const pending = _doLoad(onProgress);
  loadPromise = pending;

  try {
    const loaded = await pending;
    return loaded;
  } finally {
    // Reset on both success and failure.
    loadPromise = null;
    loading = false;
  }
}
/**
 * Report whether WebGPU is actually usable, not merely whether the API
 * object exists: navigator.gpu can be present while no adapter is available.
 *
 * @returns {Promise<boolean>} true when an adapter could be obtained
 */
async function canUseWebGPU() {
  if (typeof navigator === 'undefined' || !navigator.gpu) return false;
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter != null;
  } catch (err) {
    console.warn('[segmentation] WebGPU adapter probe failed:', err && err.message);
    return false;
  }
}

/**
 * Load the model and its processor.
 *
 * Tries WebGPU first when an adapter is available, then falls back to WASM.
 * Within each device the dtypes are tried in order of preference until one
 * loads. The first (device, dtype) pair that succeeds wins.
 *
 * @param {function} [onProgress] - receives `{ status, message, progress? }`
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 * @throws the last load error when every (device, dtype) attempt fails, or
 *   the processor error when the processor cannot be loaded
 */
async function _doLoad(onProgress) {
  const { AutoModel, AutoProcessor } = await getTransformers();
  if (onProgress) onProgress({ status: 'loading', message: 'Loading AI model...' });

  const reportDownload = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        status: 'downloading',
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress,
      });
    }
  };

  // Try smallest dtype first to keep peak memory low on phones.
  const dtypesByDevice = {
    webgpu: ['fp16', 'q8', 'fp32'],
    wasm: ['q8', 'fp16', 'fp32'],
  };
  // WASM is always the last resort, so a broken WebGPU stack degrades to a
  // slower load instead of a hard failure.
  const devices = (await canUseWebGPU()) ? ['webgpu', 'wasm'] : ['wasm'];
  const attempts = devices.flatMap((device) =>
    dtypesByDevice[device].map((dtype) => ({ device, dtype })),
  );

  let model = null;
  let loadedDtype = null;
  let loadedDevice = null;
  let lastErr = null;
  for (const { device, dtype } of attempts) {
    try {
      model = await AutoModel.from_pretrained(MODEL_ID, {
        device,
        dtype,
        progress_callback: reportDownload,
      });
      loadedDtype = dtype;
      loadedDevice = device;
      console.log(`[segmentation] loaded ${MODEL_ID} with dtype=${dtype}, device=${device}`);
      break;
    } catch (err) {
      console.warn(`[segmentation] device=${device} dtype=${dtype} failed:`, err && err.message);
      lastErr = err;
    }
  }
  if (!model) throw lastErr || new Error(`Failed to load ${MODEL_ID}`);

  let processor;
  try {
    processor = await AutoProcessor.from_pretrained(MODEL_ID);
  } catch (err) {
    // Do not leave a loaded model behind when the processor fails.
    try {
      if (model.dispose) await model.dispose();
    } catch (disposeErr) {
      console.warn('[segmentation] model dispose failed:', disposeErr && disposeErr.message);
    }
    throw err;
  }
  return { model, processor, dtype: loadedDtype, device: loadedDevice };
}
/**
 * Decode the camera image once, downscaled so its longer edge is at most
 * OUTPUT_DIM (never upscaled). The model's own processor will further resize
 * this to short-edge 512 for inference, so the same blob is used both as the
 * model input and as the canvas we apply the final mask to.
 *
 * @param {Blob} imageBlob - source image
 * @returns {Promise<Blob>} PNG blob of the downscaled image
 */
async function prepareOutputImage(imageBlob) {
  // First decode only to learn the source dimensions.
  const probe = await createImageBitmap(imageBlob);
  const fullW = probe.width;
  const fullH = probe.height;
  probe.close();

  const ratio = Math.min(1, OUTPUT_DIM / Math.max(fullW, fullH));
  // Clamp to 1px so extreme aspect ratios cannot round an edge down to 0.
  const w = Math.max(1, Math.round(fullW * ratio));
  const h = Math.max(1, Math.round(fullH * ratio));

  const bitmap = await createImageBitmap(imageBlob, {
    resizeWidth: w, resizeHeight: h, resizeQuality: 'medium',
  });
  let canvas;
  try {
    canvas = new OffscreenCanvas(w, h);
    // Pass the destination size explicitly: if the resize options above were
    // ignored, this scales the full bitmap instead of cropping it.
    canvas.getContext('2d').drawImage(bitmap, 0, 0, w, h);
  } finally {
    // Release the decoded bitmap before encoding, even if drawing failed.
    bitmap.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}
// Model instance held between a successful load in removeBackground() and
// its disposal; null otherwise. Backs isModelLoaded().
let activeModel = null;

/**
 * Remove background from an image blob.
 *
 * Steps: downscale the photo, load the model, run one forward pass, release
 * the tensors and the model, then write the predicted matte into the alpha
 * channel of the downscaled image. The model and tensors are released even
 * when preprocessing or inference throws.
 *
 * @param {Blob} imageBlob - JPEG/PNG image
 * @param {function} [onProgress] - progress callback, called with
 *   `{ status: 'processing' | 'done', message }`
 * @returns {Promise<Blob>} - PNG with transparent background
 * @throws when decoding, model loading, or inference fails, or when the
 *   model output has no usable height/width
 */
export async function removeBackground(imageBlob, onProgress) {
  const tag = (message) => {
    if (onProgress) onProgress({ status: 'processing', message });
  };

  // 1. Downscale the camera photo BEFORE loading the model (lower peak mem)
  tag('bg: decoding photo');
  const outputBlob = await prepareOutputImage(imageBlob);

  // 2. Load model
  tag('bg: loading model');
  const { model, processor, dtype, device } = await loadModel(onProgress);
  activeModel = model;
  const label = `bg[${dtype}/${device}]`;

  let maskData;
  let maskH;
  let maskW;
  let pixelValues = null;
  let output = null;
  try {
    const { RawImage } = await getTransformers();

    // 3. Preprocess via the model's own processor.
    tag(`${label}: preprocessing`);
    const url = URL.createObjectURL(outputBlob);
    let image;
    try {
      image = await RawImage.fromURL(url);
    } finally {
      // Always release the object URL, even if decoding failed.
      URL.revokeObjectURL(url);
    }
    ({ pixel_values: pixelValues } = await processor(image));
    tag(`${label}: tensor ${pixelValues.dims.join('x')}`);

    // 4. Forward pass — biggest memory peak; the tag above shows the
    //    actual input dims.
    tag(`${label}: running inference`);
    ({ output } = await model({ input: pixelValues }));
    const matte = output[0];
    maskH = matte.dims[1] || output.dims?.[2];
    maskW = matte.dims[2] || output.dims?.[3];
    if (!(maskH > 0 && maskW > 0)) {
      throw new Error(`Unexpected model output dims: [${output.dims}]`);
    }
    // Copy the matte so it stays valid after the output tensor is released.
    maskData = matte.data.slice();

    tag(`${label}: disposing model`);
  } finally {
    // 5. Release tensors and model immediately — on success so the
    //    mask-application step has the heap to itself, on failure so nothing
    //    is leaked. The output tensor itself is disposed, not output[0]:
    //    indexing appears to return a fresh view, so disposing that would
    //    leave the real output alive (NOTE(review): confirm against the
    //    transformers.js Tensor implementation).
    disposeTensor(pixelValues);
    disposeTensor(output);
    await disposeModel(model);
  }

  // 6. Apply mask to the output-sized image
  tag(`${label}: applying mask`);
  const bitmap = await createImageBitmap(outputBlob);
  const canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
  const ctx = canvas.getContext('2d');
  try {
    ctx.drawImage(bitmap, 0, 0);
  } finally {
    bitmap.close();
  }

  const cw = canvas.width;
  const ch = canvas.height;
  const imgData = ctx.getImageData(0, 0, cw, ch);
  const data = imgData.data;
  // Nearest-neighbour sampling of the mask at each canvas pixel.
  const scaleX = maskW / cw;
  const scaleY = maskH / ch;
  for (let y = 0; y < ch; y++) {
    const my = Math.min(Math.floor(y * scaleY), maskH - 1);
    const maskRow = my * maskW;
    const rowOffset = y * cw * 4;
    for (let x = 0; x < cw; x++) {
      const mx = Math.min(Math.floor(x * scaleX), maskW - 1);
      const v = maskData[maskRow + mx];
      // Clamp the matte value to [0, 1] and write it as 8-bit alpha.
      data[rowOffset + x * 4 + 3] = v <= 0 ? 0 : v >= 1 ? 255 : Math.round(v * 255);
    }
  }
  ctx.putImageData(imgData, 0, 0);

  tag(`${label}: encoding result`);
  const result = await canvas.convertToBlob({ type: 'image/png' });
  if (onProgress) onProgress({ status: 'done', message: `${label}: done` });
  return result;
}

/**
 * Best-effort release of a tensor; tolerates null and objects without a
 * dispose method, and logs instead of throwing.
 */
function disposeTensor(tensor) {
  try {
    if (tensor && tensor.dispose) tensor.dispose();
  } catch (err) {
    console.warn('[segmentation] tensor dispose failed:', err && err.message);
  }
}

/**
 * Best-effort release of a model. Never rejects: disposal failures are
 * logged, not thrown.
 *
 * @param {object} model - model returned by loadModel
 * @returns {Promise<void>} settles once disposal has finished
 */
async function disposeModel(model) {
  // Clear bookkeeping synchronously, before any await.
  loadPromise = null;
  if (activeModel === model) activeModel = null;
  try {
    // dispose() may return a promise; await it so a rejection is caught here.
    if (model && model.dispose) await model.dispose();
  } catch (err) {
    console.warn('[segmentation] model dispose failed:', err && err.message);
  }
}

/**
 * @returns {boolean} true while removeBackground holds a loaded model that
 *   has not been disposed yet
 */
export function isModelLoaded() {
  return activeModel !== null;
}

/**
 * @returns {boolean} true while a model load is in progress
 */
export function isModelLoading() {
  return loading;
}