File size: 8,206 Bytes
a125618
 
 
 
 
 
 
 
 
a2d572d
76e689f
a2d572d
 
 
a125618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76e689f
 
 
 
a125618
 
 
 
 
 
 
76e689f
 
 
 
 
a2d572d
 
 
 
 
76e689f
a2d572d
 
 
 
 
 
 
76e689f
a2d572d
 
 
a125618
 
a2d572d
a125618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2d572d
a125618
 
 
76e689f
 
 
a125618
 
 
 
76e689f
a125618
76e689f
 
 
 
 
a125618
 
 
76e689f
a125618
 
76e689f
a125618
 
 
 
a2d572d
 
 
 
 
 
 
 
 
a125618
 
 
 
 
 
 
76e689f
 
a125618
 
76e689f
a125618
a2d572d
a125618
 
 
 
 
a2d572d
 
 
a125618
 
 
 
 
a2d572d
 
 
 
 
 
a125618
 
 
76e689f
 
 
 
a125618
 
76e689f
a2d572d
 
 
76e689f
a2d572d
76e689f
 
 
a125618
 
 
 
 
 
 
 
 
a2d572d
a125618
 
a2d572d
 
 
 
a125618
 
a2d572d
a125618
 
a2d572d
 
76e689f
 
 
 
 
 
 
a2d572d
76e689f
a2d572d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
// Background removal using @huggingface/transformers with MODNet.
// Runs entirely in the browser (WebGPU with WASM fallback).
//
// We use Xenova/modnet (portrait matting, ~6.5M params) instead of
// briaai/RMBG-1.4 because RMBG's ONNX has a fixed [1,3,1024,1024] input
// shape. U²-Net activations at 1024x1024 blow the WebContent heap on
// phones without WebGPU, triggering a Jetsam kill mid-inference. MODNet
// accepts dynamic shapes and preprocesses to a 512-short-edge image,
// which is ~4x less activation memory and actually fits on mobile WASM.

// Cached module namespace of @huggingface/transformers. Stays null until
// getTransformers() has completed its first dynamic import.
let transformers = null;
// True from the moment loadModel() starts a load until that load settles.
let loading = false;
// Promise of the load currently in flight; concurrent loadModel() callers
// receive this same promise. Reset to null when the load settles and again
// by disposeModel().
let loadPromise = null;

// MODNet preprocesses to 512 on the short edge, so activation memory is
// roughly proportional to the output image we feed it. OUTPUT_DIM caps the
// saved image (and indirectly the model's working resolution).
// prepareOutputImage() applies it to the longer edge and never upscales.
const OUTPUT_DIM = 768;
// Model id handed to AutoModel.from_pretrained / AutoProcessor.from_pretrained.
const MODEL_ID = 'Xenova/modnet';

// Workarounds for transformers.js + onnxruntime-web on iOS Safari
// (and memory-constrained mobile browsers in general).
//
// Cribbed from cmorenogit/app-profile@01081eb, which hit the same
// "tab silently dies during model load/inference" problem:
//
//   1. numThreads=1 — multi-threaded WASM on JavaScriptCore (iOS Safari)
//      spirals activation memory across worker threads. Issue #1242.
//      Slower, but single-threaded stays under the tab budget.
//   2. useBrowserCache=false — the Cache API adds a full-size *copy* of
//      the model during loading (fetch buffer -> Cache -> WASM heap),
//      tripling peak. Re-downloading each session is cheaper than
//      crashing.
//   3. No WebGPU on iOS — ONNX Runtime JSEP bug #26827 breaks it on
//      iOS 18.x Safari anyway; we were already on WASM in practice.
//      (Nothing in this file enforces this: _doLoad picks WebGPU whenever
//      navigator.gpu exists.)
//
// Workarounds 1 and 2 are harmless on Android/desktop (just slightly slower
// loads), so getTransformers applies them unconditionally rather than
// sniffing UA.
/**
 * Lazily import @huggingface/transformers from the CDN and configure its
 * `env` for memory-constrained browsers. The module namespace is cached in
 * the module-level `transformers` variable, so later calls return it
 * without importing again.
 *
 * @returns {Promise<object>} the transformers module namespace
 */
export async function getTransformers() {
  if (!transformers) {
    transformers = await import(
      'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3'
    );

    const { env } = transformers;
    // Models always come from the hub, never from a local path.
    env.allowLocalModels = false;
    // Skip the Cache API so no extra full-size copy of the model is held
    // while loading.
    env.useBrowserCache = false;

    // Force single-threaded WASM when the ONNX WASM backend settings exist.
    const wasmSettings = env.backends?.onnx?.wasm;
    if (wasmSettings) {
      wasmSettings.numThreads = 1;
    }
  }
  return transformers;
}

/**
 * Start a model load, or join the one that is already in flight.
 *
 * While a load is pending, every caller receives the same promise (the
 * onProgress of later callers is not used). Once the load settles — success
 * or failure — the `loading` flag and `loadPromise` are reset, so the next
 * call starts a fresh load.
 *
 * @param {function} [onProgress] - progress callback forwarded to _doLoad
 * @returns {Promise<object>} whatever _doLoad resolves with
 */
async function loadModel(onProgress) {
  if (loadPromise) {
    return loadPromise;
  }

  loading = true;
  const inFlight = _doLoad(onProgress);
  loadPromise = inFlight;

  return inFlight.finally(() => {
    loading = false;
    loadPromise = null;
  });
}

/**
 * Load the segmentation model and its processor.
 *
 * Device/dtype combinations are tried in order until one loads:
 * WebGPU first (only when `navigator.gpu` exists), then WASM. Falling back
 * to WASM matters because `navigator.gpu` being present does not guarantee
 * that a WebGPU session can actually be created. Within each device the
 * smallest dtype is tried first to keep peak memory low on phones.
 *
 * @param {function} [onProgress] - receives `{status: 'loading', message}`
 *   once, then `{status: 'downloading', message, progress}` for download
 *   progress events that carry a `progress` value
 * @returns {Promise<{model: object, processor: object, dtype: string, device: string}>}
 *   the loaded model/processor plus the dtype and device that succeeded
 * @throws the last model load error when every combination fails, or the
 *   processor load error (the model is disposed first in that case)
 */
async function _doLoad(onProgress) {
  const { AutoModel, AutoProcessor } = await getTransformers();

  if (onProgress) onProgress({ status: 'loading', message: 'Loading AI model...' });

  const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;

  const reportDownload = (p) => {
    if (onProgress && p.progress != null) {
      onProgress({
        status: 'downloading',
        message: `Downloading model: ${Math.round(p.progress)}%`,
        progress: p.progress,
      });
    }
  };

  // Ordered list of attempts. WASM is always included so a broken WebGPU
  // stack degrades to WASM instead of failing the whole load.
  const attempts = [
    ...(hasWebGPU
      ? ['fp16', 'q8', 'fp32'].map((dtype) => ({ device: 'webgpu', dtype }))
      : []),
    ...['q8', 'fp16', 'fp32'].map((dtype) => ({ device: 'wasm', dtype })),
  ];

  let loaded = null;
  let lastErr = null;
  for (const { device, dtype } of attempts) {
    try {
      const model = await AutoModel.from_pretrained(MODEL_ID, {
        device,
        dtype,
        progress_callback: reportDownload,
      });
      loaded = { model, dtype, device };
      console.log(`[segmentation] loaded ${MODEL_ID} with dtype=${dtype}, device=${device}`);
      break;
    } catch (err) {
      console.warn(`[segmentation] device=${device} dtype=${dtype} failed:`, err?.message);
      lastErr = err;
    }
  }
  if (!loaded) throw lastErr || new Error(`Failed to load ${MODEL_ID}`);

  let processor;
  try {
    processor = await AutoProcessor.from_pretrained(MODEL_ID);
  } catch (err) {
    // The caller never receives the model on this path, so release it here
    // rather than stranding its session memory.
    try {
      await loaded.model.dispose?.();
    } catch (disposeErr) {
      console.warn('[segmentation] model dispose failed:', disposeErr?.message);
    }
    throw err;
  }

  return { model: loaded.model, processor, dtype: loaded.dtype, device: loaded.device };
}

/**
 * Decode the camera image and re-encode it as a PNG whose longer edge is at
 * most OUTPUT_DIM (never upscaled). The model's own processor will further
 * resize this to short-edge 512 for inference, so the same blob is used both
 * as the model input and as the canvas we apply the final mask to.
 *
 * The image is decoded twice: once to read its dimensions, and once with
 * the resize options applied by createImageBitmap.
 *
 * @param {Blob} imageBlob - source image
 * @returns {Promise<Blob>} PNG blob of the downscaled image
 */
async function prepareOutputImage(imageBlob) {
  const probe = await createImageBitmap(imageBlob);
  const fullW = probe.width;
  const fullH = probe.height;
  probe.close();

  const ratio = Math.min(1, OUTPUT_DIM / Math.max(fullW, fullH));
  // Clamp to 1px: for extreme aspect ratios the short edge rounds to 0, and
  // createImageBitmap rejects a resizeWidth/resizeHeight of 0.
  const w = Math.max(1, Math.round(fullW * ratio));
  const h = Math.max(1, Math.round(fullH * ratio));

  const bitmap = await createImageBitmap(imageBlob, {
    resizeWidth: w, resizeHeight: h, resizeQuality: 'medium',
  });
  const canvas = new OffscreenCanvas(w, h);
  try {
    canvas.getContext('2d').drawImage(bitmap, 0, 0);
  } finally {
    // Release the decoded bitmap before PNG encoding, even if drawing throws.
    bitmap.close();
  }
  return canvas.convertToBlob({ type: 'image/png' });
}

/**
 * Write a matte into the alpha channel of an RGBA pixel buffer.
 *
 * The mask is sampled nearest-neighbour (floor) so it may have a different
 * resolution from the image. Mask values are clamped to [0, 1] and scaled
 * to 0..255; RGB bytes are left untouched.
 *
 * @param {Uint8ClampedArray} rgba - width*height*4 pixel bytes, modified in place
 * @param {number} width - image width in pixels
 * @param {number} height - image height in pixels
 * @param {ArrayLike<number>} maskData - row-major mask values, maskW*maskH long
 * @param {number} maskW - mask width
 * @param {number} maskH - mask height
 */
function applyMaskAlpha(rgba, width, height, maskData, maskW, maskH) {
  const scaleX = maskW / width;
  const scaleY = maskH / height;

  // The image-x -> mask-x mapping is the same for every row; compute it once.
  const maskCols = new Uint32Array(width);
  for (let x = 0; x < width; x++) {
    maskCols[x] = Math.min(Math.floor(x * scaleX), maskW - 1);
  }

  for (let y = 0; y < height; y++) {
    const maskRow = Math.min(Math.floor(y * scaleY), maskH - 1) * maskW;
    const rowOffset = y * width * 4;
    for (let x = 0; x < width; x++) {
      const v = maskData[maskRow + maskCols[x]];
      rgba[rowOffset + x * 4 + 3] = v <= 0 ? 0 : v >= 1 ? 255 : Math.round(v * 255);
    }
  }
}

/**
 * Remove background from an image blob.
 *
 * The model is loaded for this call and disposed again before the mask is
 * applied — also when preprocessing or inference fails — so a failed run
 * does not leave the model resident.
 *
 * @param {Blob} imageBlob - JPEG/PNG image
 * @param {function} onProgress - progress callback
 * @returns {Promise<Blob>} - PNG with transparent background
 * @throws {Error} if the model output has no usable mask dimensions; errors
 *   from decoding, model loading and inference propagate unchanged
 */
export async function removeBackground(imageBlob, onProgress) {
  const tag = (message) => {
    if (onProgress) onProgress({ status: 'processing', message });
  };

  // 1. Downscale the camera photo BEFORE loading the model (lower peak mem)
  tag('bg: decoding photo');
  const outputBlob = await prepareOutputImage(imageBlob);

  // 2. Load model
  tag('bg: loading model');
  const { model, processor, dtype, device } = await loadModel(onProgress);
  const label = `bg[${dtype}/${device}]`;

  let maskData;
  let maskW;
  let maskH;
  try {
    const { RawImage } = await getTransformers();

    // 3. Preprocess via the model's own processor. MODNet takes dynamic
    //    input sizes, short edge 512, so the working resolution scales
    //    with the output blob we feed it.
    tag(`${label}: preprocessing`);
    const url = URL.createObjectURL(outputBlob);
    let image;
    try {
      image = await RawImage.fromURL(url);
    } finally {
      // Revoke even when decoding fails, otherwise the blob stays pinned.
      URL.revokeObjectURL(url);
    }
    const { pixel_values } = await processor(image);

    try {
      tag(`${label}: tensor ${pixel_values.dims.join('x')}`);

      // 4. Forward pass — biggest memory peak. MODNet activations scale
      //    with the input tensor size; the tag above shows the actual dims.
      tag(`${label}: running inference`);
      const { output } = await model({ input: pixel_values });

      const mask = output[0];
      maskData = mask.data; // Float32Array
      maskH = mask.dims[1] || output.dims?.[2];
      maskW = mask.dims[2] || output.dims?.[3];
      if (mask.dispose) mask.dispose();

      // Without valid dims the alpha loop would silently write garbage.
      if (!(maskW > 0 && maskH > 0) || maskData.length < maskW * maskH) {
        throw new Error(`Unexpected mask shape from ${label}: ${maskW}x${maskH}`);
      }
    } finally {
      if (pixel_values.dispose) pixel_values.dispose();
    }

    // 5. Dispose tensors and model immediately so the mask-application step
    //    has the heap to itself
    tag(`${label}: disposing model`);
  } finally {
    // Runs on failure too, so a thrown inference does not leak the model.
    disposeModel(model);
  }

  // 6. Apply mask to the output-sized image
  tag(`${label}: applying mask`);
  const bitmap = await createImageBitmap(outputBlob);
  const canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
  const ctx = canvas.getContext('2d');
  try {
    ctx.drawImage(bitmap, 0, 0);
  } finally {
    bitmap.close();
  }

  const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  applyMaskAlpha(imgData.data, canvas.width, canvas.height, maskData, maskW, maskH);
  ctx.putImageData(imgData, 0, 0);

  tag(`${label}: encoding result`);
  const result = await canvas.convertToBlob({ type: 'image/png' });

  if (onProgress) onProgress({ status: 'done', message: `${label}: done` });
  return result;
}

/**
 * Best-effort release of a model, then forget any cached load promise so
 * the next loadModel() call starts a fresh load.
 *
 * Disposal failures are logged rather than thrown: the caller is tearing
 * down and has nothing useful to do with the error. `dispose()` may return
 * a promise, so its rejection is handled as well — otherwise it would
 * surface as an unhandled rejection.
 *
 * @param {object} [model] - object with an optional `dispose()` method
 */
function disposeModel(model) {
  const warn = (err) => {
    console.warn('[segmentation] model dispose failed:', err?.message);
  };
  try {
    const pending = model?.dispose?.();
    if (pending && typeof pending.then === 'function') {
      pending.then(undefined, warn);
    }
  } catch (err) {
    warn(err);
  }
  loadPromise = null;
}

/**
 * Report the module-level `loading` flag.
 *
 * FIXME(review): this returns exactly what isModelLoading() returns, so it
 * is true while a load is in flight and false otherwise — not "a model is
 * available". It looks like a copy-paste slip. Nothing in the visible
 * source tracks a "loaded" state (loadModel() resets its state when the
 * load settles), so a proper fix needs a new flag maintained by loadModel()
 * and disposeModel(). Confirm the intended contract with callers.
 *
 * @returns {boolean} the current value of `loading`
 */
export function isModelLoaded() {
  return loading;
}

/**
 * Tell whether a model load is currently in flight.
 *
 * @returns {boolean} the module-level `loading` flag, which loadModel() sets
 *   when it starts a load and clears when that load settles
 */
export function isModelLoading() {
  const inFlight = loading;
  return inFlight;
}