saliacoel committed on
Commit
d9b69d9
·
verified ·
1 Parent(s): fac130f

Upload Inspyrenet_Rembg2.py

Browse files
Files changed (1) hide show
  1. Inspyrenet_Rembg2.py +178 -41
Inspyrenet_Rembg2.py CHANGED
@@ -1,6 +1,9 @@
1
  from PIL import Image
2
  import os
3
  import urllib.request
 
 
 
4
 
5
  import torch
6
  import numpy as np
@@ -22,7 +25,6 @@ def _ensure_ckpt_base():
22
  if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
23
  return
24
  except Exception:
25
- # If getsize fails for any reason, fall through to download attempt.
26
  pass
27
 
28
  os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
@@ -57,7 +59,6 @@ def _ensure_ckpt_base():
57
  os.replace(tmp_path, CKPT_PATH)
58
 
59
  finally:
60
- # Clean up partial download if something went wrong
61
  if os.path.isfile(tmp_path):
62
  try:
63
  os.remove(tmp_path)
@@ -68,7 +69,6 @@ def _ensure_ckpt_base():
68
  # Tensor to PIL
69
  def tensor2pil(image: torch.Tensor) -> Image.Image:
70
  arr = image.detach().cpu().numpy()
71
- # Handle accidental singleton batch dim
72
  if arr.ndim == 4 and arr.shape[0] == 1:
73
  arr = arr[0]
74
  arr = np.clip(255.0 * arr, 0, 255).astype(np.uint8)
@@ -82,11 +82,11 @@ def pil2tensor(image: Image.Image) -> torch.Tensor:
82
 
83
  def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
84
  """
85
- 5) If input is RGBA:
86
- - alpha composite over WHITE background
87
- - convert to RGB (drop alpha)
88
- If input is RGB:
89
- - carry on
90
  """
91
  if pil_img.mode == "RGBA":
92
  bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
@@ -99,9 +99,134 @@ def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
99
  return pil_img
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  class InspyrenetRembg2:
103
  """
104
- Original node kept (unchanged behavior/output), except it now ensures ckpt exists.
 
 
105
  """
106
  def __init__(self):
107
  pass
@@ -120,18 +245,29 @@ class InspyrenetRembg2:
120
  CATEGORY = "image"
121
 
122
  def remove_background(self, image, torchscript_jit):
123
- _ensure_ckpt_base()
124
-
125
- if (torchscript_jit == "default"):
126
- remover = Remover()
127
- else:
128
- remover = Remover(jit=True)
129
 
130
  img_list = []
131
- for img in tqdm(image, "Inspyrenet Rembg"):
132
- mid = remover.process(tensor2pil(img), type='rgba')
133
- out = pil2tensor(mid)
134
- img_list.append(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  img_stack = torch.cat(img_list, dim=0)
137
  mask = img_stack[:, :, :, 3]
@@ -140,12 +276,9 @@ class InspyrenetRembg2:
140
 
141
  class InspyrenetRembg3:
142
  """
143
- New node per requested changes:
144
- - ensures ckpt_base.pth exists (downloads if missing)
145
- - torchscript_jit hardcoded to "default" (no input, no JIT)
146
- - NO MASK output (IMAGE only)
147
- - if input is RGBA: composite over white, convert to RGB, then run remover
148
- - output remains RGBA (type='rgba')
149
  """
150
  def __init__(self):
151
  pass
@@ -163,24 +296,28 @@ class InspyrenetRembg3:
163
  CATEGORY = "image"
164
 
165
  def remove_background(self, image):
166
- _ensure_ckpt_base()
167
-
168
- # 3) hardcode torchscript_jit == "default"
169
- remover = Remover()
170
 
171
  img_list = []
172
- for img in tqdm(image, "Inspyrenet Rembg3"):
173
- pil_in = tensor2pil(img)
174
-
175
- # 5) normalize input to RGB for the model:
176
- # - if RGBA -> alpha composite on white -> RGB
177
- # - if RGB -> keep
178
- pil_rgb = _rgba_to_rgb_on_white(pil_in)
179
-
180
- # do functionality as usual, output RGBA
181
- mid = remover.process(pil_rgb, type="rgba")
182
- out = pil2tensor(mid)
183
- img_list.append(out)
 
 
 
 
 
 
 
184
 
185
  img_stack = torch.cat(img_list, dim=0)
186
  return (img_stack,)
 
1
  from PIL import Image
2
  import os
3
  import urllib.request
4
+ import gc
5
+ import threading
6
+ from typing import Dict, Tuple
7
 
8
  import torch
9
  import numpy as np
 
25
  if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
26
  return
27
  except Exception:
 
28
  pass
29
 
30
  os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
 
59
  os.replace(tmp_path, CKPT_PATH)
60
 
61
  finally:
 
62
  if os.path.isfile(tmp_path):
63
  try:
64
  os.remove(tmp_path)
 
69
  # Tensor to PIL
70
  def tensor2pil(image: torch.Tensor) -> Image.Image:
71
  arr = image.detach().cpu().numpy()
 
72
  if arr.ndim == 4 and arr.shape[0] == 1:
73
  arr = arr[0]
74
  arr = np.clip(255.0 * arr, 0, 255).astype(np.uint8)
 
82
 
83
  def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
84
  """
85
+ If input is RGBA:
86
+ - alpha composite over WHITE background
87
+ - convert to RGB (drop alpha)
88
+ If input is RGB:
89
+ - carry on
90
  """
91
  if pil_img.mode == "RGBA":
92
  bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
 
99
  return pil_img
100
 
101
 
102
+ # -----------------------------------------------------------------------------
103
+ # Process-wide singleton Remover + OOM guard
104
+ # -----------------------------------------------------------------------------
105
+
106
+ # One cached Remover per (jit_flag,) for the entire Python process.
107
+ _REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
108
+ # Lock per remover to avoid concurrent .process() calls (prevents VRAM spikes).
109
+ _REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
110
+ # Protects cache/lock creation.
111
+ _CACHE_LOCK = threading.Lock()
112
+
113
+
114
+ def _is_oom_error(e: BaseException) -> bool:
115
+ # torch.cuda.OutOfMemoryError only exists on CUDA builds; guard it.
116
+ oom_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
117
+ if oom_cls is not None and isinstance(e, oom_cls):
118
+ return True
119
+
120
+ msg = str(e).lower()
121
+ # Covers common RuntimeError("CUDA out of memory") patterns too.
122
+ return ("out of memory" in msg) and ("cuda" in msg or "cublas" in msg or "hip" in msg)
123
+
124
+
125
+ def _cuda_soft_cleanup() -> None:
126
+ """
127
+ Best-effort cleanup that should NOT evict "important" VRAM like model weights.
128
+
129
+ What it does:
130
+ - gc.collect(): drop dead Python objects sooner
131
+ - torch.cuda.empty_cache(): releases *unused* cached blocks back to the driver
132
+ - torch.cuda.ipc_collect(): helps in some multi-process cases
133
+
134
+ What it does NOT do:
135
+ - It does not unload models still referenced
136
+ - It does not free tensors that still have live references
137
+ """
138
+ try:
139
+ gc.collect()
140
+ except Exception:
141
+ pass
142
+
143
+ if torch.cuda.is_available():
144
+ try:
145
+ torch.cuda.synchronize()
146
+ except Exception:
147
+ pass
148
+ try:
149
+ torch.cuda.empty_cache()
150
+ except Exception:
151
+ pass
152
+ try:
153
+ torch.cuda.ipc_collect()
154
+ except Exception:
155
+ pass
156
+
157
+
158
def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
    """Return the process-wide cached Remover for *jit* plus its run lock.

    At most one Remover is ever constructed per jit setting for the whole
    process; construction happens under the cache lock, so concurrent
    first calls cannot race. If construction fails with an OOM error, a
    soft CUDA cleanup runs and the error is re-raised — a broken instance
    is never cached.
    """
    key = (jit,)
    with _CACHE_LOCK:
        if key not in _REMOVER_CACHE:
            _ensure_ckpt_base()
            try:
                _REMOVER_CACHE[key] = Remover(jit=True) if jit else Remover()
            except BaseException as err:
                if _is_oom_error(err):
                    _cuda_soft_cleanup()
                raise
        inst = _REMOVER_CACHE[key]
        # Serialize .process() calls per instance to avoid VRAM spikes.
        run_lock = _REMOVER_RUN_LOCKS.setdefault(key, threading.Lock())
    return inst, run_lock
184
+
185
+
186
def _remover_process_safe(
    remover: Remover,
    run_lock: threading.Lock,
    pil_img: Image.Image,
    *,
    out_type: str,
    retries: int = 1,
) -> Image.Image:
    """Run ``remover.process()`` with serialization, inference mode and OOM retry.

    The call is made while holding *run_lock* (prevents concurrent VRAM
    spikes) and inside ``torch.inference_mode()`` (lower memory footprint).
    An OOM failure triggers a soft CUDA cleanup and up to *retries* more
    attempts; any other exception — and the final OOM — propagates unchanged.
    """
    attempt = 0
    while True:
        try:
            with run_lock, torch.inference_mode():
                return remover.process(pil_img, type=out_type)
        except BaseException as err:
            is_oom = _is_oom_error(err)
            if is_oom:
                _cuda_soft_cleanup()
            if is_oom and attempt < retries:
                attempt += 1
                continue
            raise
219
+
220
+
221
+ # -----------------------------------------------------------------------------
222
+ # Nodes
223
+ # -----------------------------------------------------------------------------
224
+
225
  class InspyrenetRembg2:
226
  """
227
+ Original node behavior/output kept, but:
228
+ - Remover is now a process-wide singleton (per jit flag)
229
+ - OOM is caught -> soft CUDA cleanup -> retry once -> then raise
230
  """
231
  def __init__(self):
232
  pass
 
245
  CATEGORY = "image"
246
 
247
  def remove_background(self, image, torchscript_jit):
248
+ jit = (torchscript_jit != "default")
249
+ remover, run_lock = _get_remover(jit=jit)
 
 
 
 
250
 
251
  img_list = []
252
+ try:
253
+ for img in tqdm(image, "Inspyrenet Rembg"):
254
+ pil_in = tensor2pil(img)
255
+ mid = _remover_process_safe(remover, run_lock, pil_in, out_type="rgba", retries=1)
256
+ out = pil2tensor(mid)
257
+ img_list.append(out)
258
+
259
+ # Help Python drop refs earlier (mostly relevant if exceptions occur).
260
+ del pil_in, mid, out
261
+
262
+ except BaseException as e:
263
+ if _is_oom_error(e):
264
+ # Ensure cache cleanup already happened; do another best-effort pass.
265
+ _cuda_soft_cleanup()
266
+ raise RuntimeError(
267
+ "InspyrenetRembg2: CUDA out of memory during background removal. "
268
+ "Freed PyTorch CUDA cache (unused blocks) and retried once; still failed."
269
+ ) from e
270
+ raise
271
 
272
  img_stack = torch.cat(img_list, dim=0)
273
  mask = img_stack[:, :, :, 3]
 
276
 
277
  class InspyrenetRembg3:
278
  """
279
+ New node per your existing changes, plus:
280
+ - singleton Remover reuse
281
+ - OOM catch -> soft cleanup -> retry once -> then raise
 
 
 
282
  """
283
  def __init__(self):
284
  pass
 
296
  CATEGORY = "image"
297
 
298
  def remove_background(self, image):
299
+ remover, run_lock = _get_remover(jit=False)
 
 
 
300
 
301
  img_list = []
302
+ try:
303
+ for img in tqdm(image, "Inspyrenet Rembg3"):
304
+ pil_in = tensor2pil(img)
305
+ pil_rgb = _rgba_to_rgb_on_white(pil_in)
306
+
307
+ mid = _remover_process_safe(remover, run_lock, pil_rgb, out_type="rgba", retries=1)
308
+ out = pil2tensor(mid)
309
+ img_list.append(out)
310
+
311
+ del pil_in, pil_rgb, mid, out
312
+
313
+ except BaseException as e:
314
+ if _is_oom_error(e):
315
+ _cuda_soft_cleanup()
316
+ raise RuntimeError(
317
+ "InspyrenetRembg3: CUDA out of memory during background removal. "
318
+ "Freed PyTorch CUDA cache (unused blocks) and retried once; still failed."
319
+ ) from e
320
+ raise
321
 
322
  img_stack = torch.cat(img_list, dim=0)
323
  return (img_stack,)