// Background removal using @huggingface/transformers with MODNet.
// Runs entirely in the browser (WebGPU with WASM fallback).
//
// We use Xenova/modnet (portrait matting, ~6.5M params) instead of
// briaai/RMBG-1.4 because RMBG's ONNX has a fixed [1,3,1024,1024] input
// shape. U²-Net activations at 1024x1024 blow the WebContent heap on
// phones without WebGPU, triggering a Jetsam kill mid-inference. MODNet
// accepts dynamic shapes and preprocesses to a 512-short-edge image,
// which is ~4x less activation memory and actually fits on mobile WASM.
// Module namespace of @huggingface/transformers. Stays null until
// getTransformers() has imported it, then is reused by every later call.
let transformers = null;
// True while loadModel() has a model load in flight; reset to false when
// that load settles (successfully or not).
let loading = false;
// Promise for the in-flight model load so overlapping loadModel() calls
// share one load. Reset to null when the load settles and again when the
// model is disposed.
let loadPromise = null;
// MODNet preprocesses to 512 on the short edge, so activation memory is
// roughly proportional to the output image we feed it. OUTPUT_DIM caps the
// longest edge (in pixels) of the saved image, and indirectly the model's
// working resolution. Images already within the cap are not upscaled.
const OUTPUT_DIM = 768;
// Model id handed to AutoModel.from_pretrained / AutoProcessor.from_pretrained.
const MODEL_ID = 'Xenova/modnet';
// Workarounds for transformers.js + onnxruntime-web on iOS Safari
// (and memory-constrained mobile browsers in general).
//
// Cribbed from cmorenogit/app-profile@01081eb, which hit the same
// "tab silently dies during model load/inference" problem:
//
// 1. numThreads=1 — multi-threaded WASM on JavaScriptCore (iOS Safari)
// spirals activation memory across worker threads. Issue #1242.
// Slower, but single-threaded stays under the tab budget.
// 2. useBrowserCache=false — the Cache API adds a full-size *copy* of
// the model during loading (fetch buffer -> Cache -> WASM heap),
// tripling peak. Re-downloading each session is cheaper than
// crashing.
// 3. No WebGPU on iOS — ONNX Runtime JSEP bug #26827 breaks it on
// iOS 18.x Safari anyway; we were already on WASM in practice.
// (Nothing in this file forces WASM: the device is chosen from
// navigator.gpu when the model is loaded.)
//
// Items 1 and 2 are harmless on Android/desktop (just slightly slower
// loads), so we apply them unconditionally rather than sniffing UA.
/**
 * Import @huggingface/transformers from the CDN on first use and apply the
 * memory-saving environment settings; later calls return the cached module.
 * @returns {Promise<object>} The transformers.js module namespace.
 */
export async function getTransformers() {
  if (!transformers) {
    transformers = await import(
      'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3'
    );
    const { env } = transformers;
    env.allowLocalModels = false;
    // Skip the Cache API so no extra full-size copy of the model is held
    // while it loads.
    env.useBrowserCache = false;
    // Single-threaded WASM keeps activation memory from multiplying across
    // worker threads. Only set when the WASM backend config is exposed.
    const wasmBackend = env.backends?.onnx?.wasm;
    if (wasmBackend) {
      wasmBackend.numThreads = 1;
    }
  }
  return transformers;
}
/**
 * Load the model, sharing a single in-flight load between overlapping calls.
 *
 * The first caller starts the load and clears the shared state once it
 * settles; callers arriving while it is pending receive the same promise
 * (their own onProgress is not attached).
 * @param {function} [onProgress] - progress callback forwarded to the load
 * @returns {Promise<object>} Whatever _doLoad resolves with.
 */
async function loadModel(onProgress) {
  if (!loadPromise) {
    loading = true;
    loadPromise = _doLoad(onProgress);
    try {
      const loaded = await loadPromise;
      return loaded;
    } finally {
      // Runs on success and failure alike so a failed load can be retried.
      loadPromise = null;
      loading = false;
    }
  }
  return loadPromise;
}
/**
 * Load the segmentation model and its processor.
 *
 * WebGPU is attempted only when an adapter can actually be obtained
 * (navigator.gpu existing does not guarantee one). If no dtype loads on
 * WebGPU, the load is retried on WASM before giving up.
 *
 * @param {function} [onProgress] - receives { status, message, progress? }
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 *   The loaded model/processor plus the dtype and device that succeeded.
 * @throws The last load error when no device/dtype combination works, or
 *   the processor's load error (the model is disposed first in that case).
 */
async function _doLoad(onProgress) {
  const { AutoModel, AutoProcessor } = await getTransformers();
  if (onProgress) onProgress({ status: 'loading', message: 'Loading AI model...' });

  // requestAdapter() resolves to null when no usable adapter exists, so probe
  // it instead of trusting the mere presence of navigator.gpu.
  let hasWebGPU = false;
  if (typeof navigator !== 'undefined' && navigator.gpu) {
    try {
      hasWebGPU = Boolean(await navigator.gpu.requestAdapter());
    } catch (err) {
      console.warn('[segmentation] WebGPU adapter probe failed:', err?.message);
    }
  }
  const devices = hasWebGPU ? ['webgpu', 'wasm'] : ['wasm'];

  const dtypeProgress = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        status: 'downloading',
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress,
      });
    }
  };

  let loaded = null;
  let lastErr = null;
  for (const device of devices) {
    // Try smallest dtype first to keep peak memory low on phones.
    const dtypePreference = device === 'webgpu'
      ? ['fp16', 'q8', 'fp32']
      : ['q8', 'fp16', 'fp32'];
    for (const dtype of dtypePreference) {
      try {
        const model = await AutoModel.from_pretrained(MODEL_ID, {
          device,
          dtype,
          progress_callback: dtypeProgress,
        });
        loaded = { model, dtype, device };
        console.log(`[segmentation] loaded ${MODEL_ID} with dtype=${dtype}, device=${device}`);
        break;
      } catch (err) {
        console.warn(`[segmentation] dtype=${dtype} failed:`, err?.message);
        lastErr = err;
      }
    }
    if (loaded) break;
    console.warn(`[segmentation] no dtype could be loaded on device=${device}`);
  }
  if (!loaded) throw lastErr || new Error(`Failed to load ${MODEL_ID}`);

  let processor;
  try {
    processor = await AutoProcessor.from_pretrained(MODEL_ID);
  } catch (err) {
    // Without a processor the model is unusable; release it rather than
    // leaving it resident after the failure.
    try {
      await loaded.model.dispose?.();
    } catch (disposeErr) {
      console.warn('[segmentation] model dispose failed:', disposeErr?.message);
    }
    throw err;
  }
  return {
    model: loaded.model,
    processor,
    dtype: loaded.dtype,
    device: loaded.device,
  };
}
/**
 * Decode the camera image and downscale it so its longest edge is at most
 * OUTPUT_DIM (never upscaling), returning the result as a PNG blob. The
 * caller uses that blob both as the model input and as the image the final
 * mask is applied to.
 *
 * @param {Blob} imageBlob - source image
 * @returns {Promise<Blob>} PNG blob of the downscaled image
 */
async function prepareOutputImage(imageBlob) {
  // First decode is only used to read the source dimensions.
  const probe = await createImageBitmap(imageBlob);
  const fullW = probe.width;
  const fullH = probe.height;
  probe.close();

  const ratio = Math.min(1, OUTPUT_DIM / Math.max(fullW, fullH));
  // Clamp to 1px: for extreme aspect ratios the short edge would otherwise
  // round to 0, which is not a valid resize dimension.
  const w = Math.max(1, Math.round(fullW * ratio));
  const h = Math.max(1, Math.round(fullH * ratio));

  const bitmap = await createImageBitmap(imageBlob, {
    resizeWidth: w, resizeHeight: h, resizeQuality: 'medium',
  });
  let canvas;
  try {
    canvas = new OffscreenCanvas(w, h);
    canvas.getContext('2d').drawImage(bitmap, 0, 0);
  } finally {
    // Release the decoded bitmap before encoding, even if drawing failed.
    bitmap.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}
/**
 * Remove background from an image blob.
 *
 * Steps: downscale the photo, load the model, run the matting model on the
 * downscaled image, release the model and tensors, then write the predicted
 * matte into the alpha channel of the downscaled image.
 *
 * The model and tensors are released even when preprocessing or inference
 * throws, so a failed run does not leave them resident.
 *
 * @param {Blob} imageBlob - JPEG/PNG image
 * @param {function} onProgress - progress callback
 * @returns {Promise<Blob>} - PNG with transparent background
 * @throws {Error} If the model output does not have a usable mask shape,
 *   or if any of the decode/load/inference steps fail.
 */
export async function removeBackground(imageBlob, onProgress) {
  const tag = (message) => {
    if (onProgress) onProgress({ status: 'processing', message });
  };

  // 1. Downscale the camera photo BEFORE loading the model (lower peak mem)
  tag('bg: decoding photo');
  const outputBlob = await prepareOutputImage(imageBlob);

  // 2. Load model
  tag('bg: loading model');
  const { model, processor, dtype, device } = await loadModel(onProgress);
  const label = `bg[${dtype}/${device}]`;

  let pixelValues = null;
  let maskTensor = null;
  let maskData;
  let maskH;
  let maskW;
  try {
    const { RawImage } = await getTransformers();

    // 3. Preprocess via the model's own processor.
    tag(`${label}: preprocessing`);
    const url = URL.createObjectURL(outputBlob);
    let image;
    try {
      image = await RawImage.fromURL(url);
    } finally {
      // Revoke even if decoding fails so the blob URL is not leaked.
      URL.revokeObjectURL(url);
    }
    ({ pixel_values: pixelValues } = await processor(image));
    tag(`${label}: tensor ${pixelValues.dims.join('x')}`);

    // 4. Forward pass — biggest memory peak; the tag above shows the
    //    actual input dims.
    tag(`${label}: running inference`);
    const { output } = await model({ input: pixelValues });
    maskTensor = output[0];
    // Copy the values out: they are read after the tensor is disposed.
    maskData = maskTensor.data.slice();
    maskH = maskTensor.dims[1] || output.dims?.[2];
    maskW = maskTensor.dims[2] || output.dims?.[3];

    // 5. Tensors and model are released in the finally block so the
    //    mask-application step has the heap to itself.
    tag(`${label}: disposing model`);
  } finally {
    releaseTensor(pixelValues);
    releaseTensor(maskTensor);
    // Awaiting is a no-op when disposeModel returns nothing.
    await disposeModel(model);
  }

  if (!(maskH > 0) || !(maskW > 0) || maskData.length < maskH * maskW) {
    throw new Error(
      `Unexpected mask shape ${maskH}x${maskW} for ${maskData.length} values`,
    );
  }

  // 6. Apply mask to the output-sized image
  tag(`${label}: applying mask`);
  const bitmap = await createImageBitmap(outputBlob);
  let canvas;
  let ctx;
  try {
    canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
    ctx = canvas.getContext('2d');
    ctx.drawImage(bitmap, 0, 0);
  } finally {
    bitmap.close();
  }
  const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  applyAlphaMask(imgData.data, canvas.width, canvas.height, maskData, maskW, maskH);
  ctx.putImageData(imgData, 0, 0);

  tag(`${label}: encoding result`);
  const result = await canvas.convertToBlob({ type: 'image/png' });
  if (onProgress) onProgress({ status: 'done', message: `${label}: done` });
  return result;
}

// Overwrite the alpha channel of RGBA `pixels` (width x height) with the
// mask, sampled nearest-neighbour; mask values are clamped to [0, 1] and
// scaled to 0..255.
function applyAlphaMask(pixels, width, height, maskData, maskW, maskH) {
  const scaleX = maskW / width;
  const scaleY = maskH / height;
  for (let y = 0; y < height; y++) {
    const my = Math.min(Math.floor(y * scaleY), maskH - 1);
    const maskRow = my * maskW;
    const rowOffset = y * width * 4;
    for (let x = 0; x < width; x++) {
      const mx = Math.min(Math.floor(x * scaleX), maskW - 1);
      const v = maskData[maskRow + mx];
      pixels[rowOffset + x * 4 + 3] = v <= 0 ? 0 : v >= 1 ? 255 : Math.round(v * 255);
    }
  }
}

// Best-effort tensor disposal; a failure is logged rather than thrown so it
// cannot mask the error that triggered the cleanup.
function releaseTensor(tensor) {
  try {
    tensor?.dispose?.();
  } catch (err) {
    console.warn('[segmentation] tensor dispose failed:', err?.message);
  }
}
/**
 * Release a model's resources and clear the shared load promise.
 *
 * The model's dispose() may be asynchronous, so it is awaited here: a
 * rejected disposal is logged instead of surfacing as an unhandled
 * rejection, and callers may await the returned promise to know the
 * release has finished. Callers that ignore the return value still work.
 *
 * @param {object|null|undefined} model - model to release (missing is a no-op)
 * @returns {Promise<void>} Resolves once disposal has settled; never rejects.
 */
async function disposeModel(model) {
  // Cleared synchronously, before the first await.
  loadPromise = null;
  try {
    await model?.dispose?.();
  } catch (err) {
    // Disposal is best-effort, but record why it failed.
    console.warn('[segmentation] model dispose failed:', err?.message);
  }
}
/**
 * Report model status to callers.
 *
 * NOTE(review): this returns the module-level `loading` flag — the same
 * value isModelLoading() returns — so it is true only while a load is in
 * flight and false once the load has finished. That looks like a copy/paste
 * slip, but this module keeps no "model is resident" state to return
 * instead (removeBackground releases the model after each run). Confirm
 * with callers what they expect before changing it.
 *
 * @returns {boolean} Current value of the `loading` flag.
 */
export function isModelLoaded() {
return loading;
}
/**
 * Whether a model load is currently in flight.
 * @returns {boolean} The module-level `loading` flag, which loadModel()
 *   sets to true when it starts a load and back to false when it settles.
 */
export function isModelLoading() {
return loading;
}
|