saliacoel commited on
Commit
94419b7
·
verified ·
1 Parent(s): e01e207

Upload Inspyrenet_RembgX2.py

Browse files
Files changed (1) hide show
  1. Inspyrenet_RembgX2.py +734 -2
Inspyrenet_RembgX2.py CHANGED
@@ -1,4 +1,736 @@
 
1
 
 
 
 
 
 
 
2
 
3
- (Summarize the EMOTIONAL/RELATIONSHIP of Kara and Nick, what does she think about Nick right now?
4
- Write as Kara POV, in her internal monologue, thinking to herself/speaking to herself in her own mind, about Nick. It will be accessed later / permanently as a memory. This is the first ever memory, so you must mention and describe everything from scratch. About the first time they met, about the name of the boy, what he looked like etc. Assume the reader has no idea who and what you are talking about and has no context. Kara's Internal Memory Log – First Entry: Nick. Do not write like a machine/protocoll, write more natural.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
 
3
+ from PIL import Image
4
+ import os
5
+ import urllib.request
6
+ import gc
7
+ import threading
8
+ from typing import Dict, Tuple, Optional
9
 
10
+ import torch
11
+ import numpy as np
12
+ from transparent_background import Remover
13
+ from tqdm import tqdm
14
+
15
+
16
+ # Optional: ComfyUI memory manager (present inside ComfyUI)
17
+ try:
18
+ import comfy.model_management as comfy_mm
19
+ except Exception:
20
+ comfy_mm = None
21
+
22
+
23
+ CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
24
+ CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
25
+
26
+
27
+ def _ensure_ckpt_base():
28
+ try:
29
+ if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
30
+ return
31
+ except Exception:
32
+ pass
33
+
34
+ os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
35
+ tmp_path = CKPT_PATH + ".tmp"
36
+
37
+ try:
38
+ with urllib.request.urlopen(CKPT_URL) as resp:
39
+ total = resp.headers.get("Content-Length")
40
+ total = int(total) if total is not None else None
41
+
42
+ with open(tmp_path, "wb") as f:
43
+ if total:
44
+ with tqdm(
45
+ total=total,
46
+ unit="B",
47
+ unit_scale=True,
48
+ desc="Downloading ckpt_base.pth",
49
+ ) as pbar:
50
+ while True:
51
+ chunk = resp.read(1024 * 1024)
52
+ if not chunk:
53
+ break
54
+ f.write(chunk)
55
+ pbar.update(len(chunk))
56
+ else:
57
+ while True:
58
+ chunk = resp.read(1024 * 1024)
59
+ if not chunk:
60
+ break
61
+ f.write(chunk)
62
+
63
+ os.replace(tmp_path, CKPT_PATH)
64
+ finally:
65
+ if os.path.isfile(tmp_path):
66
+ try:
67
+ os.remove(tmp_path)
68
+ except Exception:
69
+ pass
70
+
71
+
72
+ # Tensor to PIL
73
+ def tensor2pil(image: torch.Tensor) -> Image.Image:
74
+ arr = image.detach().cpu().numpy()
75
+ if arr.ndim == 4 and arr.shape[0] == 1:
76
+ arr = arr[0]
77
+ arr = np.clip(255.0 * arr, 0, 255).astype(np.uint8)
78
+ return Image.fromarray(arr)
79
+
80
+
81
+ # Convert PIL to Tensor
82
+ def pil2tensor(image: Image.Image) -> torch.Tensor:
83
+ return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
84
+
85
+
86
+ def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
87
+ if pil_img.mode == "RGBA":
88
+ bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
89
+ composited = Image.alpha_composite(bg, pil_img)
90
+ return composited.convert("RGB")
91
+
92
+ if pil_img.mode != "RGB":
93
+ return pil_img.convert("RGB")
94
+
95
+ return pil_img
96
+
97
+
98
+ def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
99
+ """
100
+ Opaque RGBA fallback (alpha=255), so you never get an "invisible" output.
101
+ """
102
+ rgba = pil_img.convert("RGBA")
103
+ r, g, b, _a = rgba.split()
104
+ a = Image.new("L", rgba.size, 255)
105
+ return Image.merge("RGBA", (r, g, b, a))
106
+
107
+
108
+ def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
109
+ """
110
+ True if RGBA image alpha channel is entirely 0.
111
+ """
112
+ if pil_img.mode != "RGBA":
113
+ return False
114
+ try:
115
+ extrema = pil_img.getextrema() # ((min,max),(min,max),(min,max),(min,max))
116
+ return extrema[3][1] == 0
117
+ except Exception:
118
+ return False
119
+
120
+
121
+ def _is_oom_error(e: BaseException) -> bool:
122
+ oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
123
+ if oom_cuda_cls is not None and isinstance(e, oom_cuda_cls):
124
+ return True
125
+
126
+ oom_torch_cls = getattr(torch, "OutOfMemoryError", None)
127
+ if oom_torch_cls is not None and isinstance(e, oom_torch_cls):
128
+ return True
129
+
130
+ msg = str(e).lower()
131
+ if "out of memory" in msg:
132
+ return True
133
+ if "allocation on device" in msg:
134
+ return True
135
+ return ("cuda" in msg or "cublas" in msg or "hip" in msg) and ("memory" in msg)
136
+
137
+
138
+ def _cuda_soft_cleanup() -> None:
139
+ try:
140
+ gc.collect()
141
+ except Exception:
142
+ pass
143
+
144
+ if torch.cuda.is_available():
145
+ try:
146
+ torch.cuda.synchronize()
147
+ except Exception:
148
+ pass
149
+ try:
150
+ torch.cuda.empty_cache()
151
+ except Exception:
152
+ pass
153
+ try:
154
+ torch.cuda.ipc_collect()
155
+ except Exception:
156
+ pass
157
+
158
+
159
+ def _comfy_soft_empty_cache() -> None:
160
+ if comfy_mm is None:
161
+ return
162
+ if hasattr(comfy_mm, "soft_empty_cache"):
163
+ try:
164
+ comfy_mm.soft_empty_cache(force=True)
165
+ except TypeError:
166
+ try:
167
+ comfy_mm.soft_empty_cache()
168
+ except Exception:
169
+ pass
170
+ except Exception:
171
+ pass
172
+
173
+
174
+ def _get_comfy_torch_device() -> torch.device:
175
+ """
176
+ Always prefer ComfyUI's chosen device.
177
+ """
178
+ if comfy_mm is not None and hasattr(comfy_mm, "get_torch_device"):
179
+ try:
180
+ d = comfy_mm.get_torch_device()
181
+ if isinstance(d, torch.device):
182
+ return d
183
+ return torch.device(str(d))
184
+ except Exception:
185
+ pass
186
+
187
+ if torch.cuda.is_available():
188
+ return torch.device("cuda:0")
189
+ return torch.device("cpu")
190
+
191
+
192
+ def _set_current_cuda_device(dev: torch.device) -> None:
193
+ """
194
+ Make sure mem_get_info() measurements are on the same device ComfyUI uses.
195
+ """
196
+ if dev.type == "cuda":
197
+ try:
198
+ if dev.index is not None:
199
+ torch.cuda.set_device(dev.index)
200
+ except Exception:
201
+ pass
202
+
203
+
204
+ def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]:
205
+ if dev.type != "cuda" or not torch.cuda.is_available():
206
+ return None
207
+ try:
208
+ _set_current_cuda_device(dev)
209
+ free_b, _total_b = torch.cuda.mem_get_info()
210
+ return int(free_b)
211
+ except Exception:
212
+ return None
213
+
214
+
215
+ def _comfy_unload_one_smallest_model() -> bool:
216
+ """
217
+ Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.
218
+
219
+ If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).
220
+ """
221
+ if comfy_mm is None:
222
+ return False
223
+ if not hasattr(comfy_mm, "current_loaded_models"):
224
+ return False
225
+
226
+ try:
227
+ cur_dev = _get_comfy_torch_device()
228
+ except Exception:
229
+ cur_dev = None
230
+
231
+ models = []
232
+ try:
233
+ for lm in list(comfy_mm.current_loaded_models):
234
+ try:
235
+ # Prefer same device
236
+ lm_dev = getattr(lm, "device", None)
237
+ if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
238
+ continue
239
+
240
+ mem_fn = getattr(lm, "model_loaded_memory", None)
241
+ if callable(mem_fn):
242
+ mem = int(mem_fn())
243
+ else:
244
+ mem = int(getattr(lm, "loaded_memory", 0) or 0)
245
+
246
+ if mem > 0:
247
+ models.append((mem, lm))
248
+ except Exception:
249
+ continue
250
+ except Exception:
251
+ return False
252
+
253
+ if not models:
254
+ return False
255
+
256
+ models.sort(key=lambda x: x[0]) # smallest first
257
+ _mem, lm = models[0]
258
+
259
+ try:
260
+ unload_fn = getattr(lm, "model_unload", None)
261
+ if callable(unload_fn):
262
+ try:
263
+ unload_fn(unpatch_weights=True)
264
+ except TypeError:
265
+ unload_fn()
266
+ except Exception:
267
+ pass
268
+
269
+ # Cleanup hook if present
270
+ try:
271
+ cleanup = getattr(comfy_mm, "cleanup_models", None)
272
+ if callable(cleanup):
273
+ cleanup()
274
+ except Exception:
275
+ pass
276
+
277
+ _comfy_soft_empty_cache()
278
+ _cuda_soft_cleanup()
279
+ return True
280
+
281
+
282
+ def _comfy_unload_all_models() -> None:
283
+ if comfy_mm is None:
284
+ return
285
+ if hasattr(comfy_mm, "unload_all_models"):
286
+ try:
287
+ comfy_mm.unload_all_models()
288
+ except Exception:
289
+ pass
290
+ _comfy_soft_empty_cache()
291
+ _cuda_soft_cleanup()
292
+
293
+
294
+ # -----------------------------------------------------------------------------
295
+ # Existing singleton cache for Rembg2/Rembg3 (your original)
296
+ # -----------------------------------------------------------------------------
297
+
298
+ _REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
299
+ _REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
300
+ _CACHE_LOCK = threading.Lock()
301
+
302
+
303
+ def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
304
+ key = (jit,)
305
+ with _CACHE_LOCK:
306
+ inst = _REMOVER_CACHE.get(key)
307
+ if inst is None:
308
+ _ensure_ckpt_base()
309
+ try:
310
+ inst = Remover(jit=jit) if jit else Remover()
311
+ except BaseException as e:
312
+ if _is_oom_error(e):
313
+ _cuda_soft_cleanup()
314
+ raise
315
+ _REMOVER_CACHE[key] = inst
316
+
317
+ run_lock = _REMOVER_RUN_LOCKS.get(key)
318
+ if run_lock is None:
319
+ run_lock = threading.Lock()
320
+ _REMOVER_RUN_LOCKS[key] = run_lock
321
+
322
+ return inst, run_lock
323
+
324
+
325
+ # -----------------------------------------------------------------------------
326
+ # GLOBAL remover (for Load/Remove/Run Global nodes)
327
+ # -----------------------------------------------------------------------------
328
+
329
+ _GLOBAL_LOCK = threading.Lock()
330
+ _GLOBAL_RUN_LOCK = threading.Lock()
331
+ _GLOBAL_REMOVER: Optional[Remover] = None
332
+ _GLOBAL_ON_DEVICE: str = "cpu"
333
+ _GLOBAL_VRAM_DELTA_BYTES: int = 0
334
+
335
+
336
+ def _create_global_remover_cpu() -> Remover:
337
+ """
338
+ Create the Remover configured like InspyrenetRembg3 (jit=False),
339
+ but *try* to force CPU init to avoid VRAM OOM during creation.
340
+ """
341
+ _ensure_ckpt_base()
342
+
343
+ # Prefer constructing on CPU if supported by this library version.
344
+ try:
345
+ r = Remover(device="cpu") # type: ignore[arg-type]
346
+ try:
347
+ r.device = "cpu"
348
+ except Exception:
349
+ pass
350
+ return r
351
+ except TypeError:
352
+ pass
353
+
354
+ # Fallback: construct default and immediately offload to CPU
355
+ r = Remover()
356
+ try:
357
+ if hasattr(r, "model"):
358
+ r.model = r.model.to("cpu")
359
+ r.device = "cpu"
360
+ except Exception:
361
+ pass
362
+ _cuda_soft_cleanup()
363
+ return r
364
+
365
+
366
+ def _get_global_remover() -> Remover:
367
+ global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
368
+ with _GLOBAL_LOCK:
369
+ if _GLOBAL_REMOVER is None:
370
+ _GLOBAL_REMOVER = _create_global_remover_cpu()
371
+ _GLOBAL_ON_DEVICE = str(getattr(_GLOBAL_REMOVER, "device", "cpu"))
372
+ return _GLOBAL_REMOVER
373
+
374
+
375
+ def _move_global_to_cpu() -> None:
376
+ global _GLOBAL_ON_DEVICE
377
+ r = _get_global_remover()
378
+ try:
379
+ if hasattr(r, "model"):
380
+ r.model = r.model.to("cpu")
381
+ r.device = "cpu"
382
+ _GLOBAL_ON_DEVICE = "cpu"
383
+ except Exception:
384
+ pass
385
+ _cuda_soft_cleanup()
386
+
387
+
388
+ def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
389
+ """
390
+ Load the global remover into VRAM on ComfyUI's chosen CUDA device.
391
+ Never crashes on OOM: evicts smallest model first, then unload_all as last resort.
392
+ Also records a best-effort VRAM delta.
393
+ """
394
+ global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES
395
+
396
+ r = _get_global_remover()
397
+ dev = _get_comfy_torch_device()
398
+
399
+ if dev.type != "cuda" or not torch.cuda.is_available():
400
+ _move_global_to_cpu()
401
+ return False
402
+
403
+ # Already on CUDA?
404
+ cur_dev = str(getattr(r, "device", "") or "")
405
+ if cur_dev.startswith("cuda"):
406
+ _GLOBAL_ON_DEVICE = cur_dev
407
+ return True
408
+
409
+ _set_current_cuda_device(dev)
410
+
411
+ free_before = _cuda_free_bytes_on(dev)
412
+
413
+ for _ in range(max_evictions + 1):
414
+ try:
415
+ # Move model to the SAME device ComfyUI uses
416
+ if hasattr(r, "model"):
417
+ r.model = r.model.to(dev)
418
+ r.device = str(dev)
419
+ _GLOBAL_ON_DEVICE = str(dev)
420
+
421
+ _comfy_soft_empty_cache()
422
+ _cuda_soft_cleanup()
423
+
424
+ free_after = _cuda_free_bytes_on(dev)
425
+ if free_before is not None and free_after is not None:
426
+ delta = max(0, int(free_before) - int(free_after))
427
+ if delta > 0:
428
+ _GLOBAL_VRAM_DELTA_BYTES = delta
429
+
430
+ return True
431
+
432
+ except BaseException as e:
433
+ if not _is_oom_error(e):
434
+ raise
435
+ _comfy_soft_empty_cache()
436
+ _cuda_soft_cleanup()
437
+
438
+ # Evict ONE smallest model; if that fails, unload all.
439
+ if not _comfy_unload_one_smallest_model():
440
+ _comfy_unload_all_models()
441
+
442
+ # Could not load
443
+ _move_global_to_cpu()
444
+ return False
445
+
446
+
447
+ def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
448
+ """
449
+ Run remover.process() (rgba output), matching InspyrenetRembg3 behavior.
450
+ On OOM: evict models and retry, then CPU fallback.
451
+ If output alpha is fully transparent, return fallback (prevents "invisible" output).
452
+ """
453
+ r = _get_global_remover()
454
+
455
+ # Try to keep it on CUDA (Comfy device) if possible; do not crash if not.
456
+ _load_global_to_comfy_cuda_no_crash()
457
+
458
+ # Attempt 1: whatever device we're on (likely CUDA)
459
+ try:
460
+ with _GLOBAL_RUN_LOCK:
461
+ with torch.inference_mode():
462
+ out = r.process(pil_rgb, type="rgba")
463
+ if _alpha_is_all_zero(out):
464
+ # Treat as failure -> prevents invisible output
465
+ return fallback_rgba
466
+ return out
467
+ except BaseException as e:
468
+ if not _is_oom_error(e):
469
+ raise
470
+
471
+ # OOM path: evict one smallest and retry (still on CUDA if we are)
472
+ _comfy_soft_empty_cache()
473
+ _cuda_soft_cleanup()
474
+ _comfy_unload_one_smallest_model()
475
+
476
+ try:
477
+ with _GLOBAL_RUN_LOCK:
478
+ with torch.inference_mode():
479
+ out = r.process(pil_rgb, type="rgba")
480
+ if _alpha_is_all_zero(out):
481
+ return fallback_rgba
482
+ return out
483
+ except BaseException as e:
484
+ if not _is_oom_error(e):
485
+ raise
486
+
487
+ # OOM again: unload all comfy models and retry once
488
+ _comfy_unload_all_models()
489
+
490
+ try:
491
+ with _GLOBAL_RUN_LOCK:
492
+ with torch.inference_mode():
493
+ out = r.process(pil_rgb, type="rgba")
494
+ if _alpha_is_all_zero(out):
495
+ return fallback_rgba
496
+ return out
497
+ except BaseException as e:
498
+ if not _is_oom_error(e):
499
+ raise
500
+
501
+ # Final: CPU fallback
502
+ _move_global_to_cpu()
503
+ try:
504
+ with _GLOBAL_RUN_LOCK:
505
+ with torch.inference_mode():
506
+ out = r.process(pil_rgb, type="rgba")
507
+ if _alpha_is_all_zero(out):
508
+ return fallback_rgba
509
+ return out
510
+ except BaseException:
511
+ # Last resort: passthrough
512
+ return fallback_rgba
513
+
514
+
515
+ # -----------------------------------------------------------------------------
516
+ # Nodes
517
+ # -----------------------------------------------------------------------------
518
+
519
+ class InspyrenetRembg2:
520
+ def __init__(self):
521
+ pass
522
+
523
+ @classmethod
524
+ def INPUT_TYPES(s):
525
+ return {
526
+ "required": {
527
+ "image": ("IMAGE",),
528
+ "torchscript_jit": (["default", "on"],)
529
+ },
530
+ }
531
+
532
+ RETURN_TYPES = ("IMAGE", "MASK")
533
+ FUNCTION = "remove_background"
534
+ CATEGORY = "image"
535
+
536
+ def remove_background(self, image, torchscript_jit):
537
+ jit = (torchscript_jit != "default")
538
+ remover, run_lock = _get_remover(jit=jit)
539
+
540
+ img_list = []
541
+ for img in tqdm(image, "Inspyrenet Rembg2"):
542
+ pil_in = tensor2pil(img)
543
+ try:
544
+ with run_lock:
545
+ with torch.inference_mode():
546
+ mid = remover.process(pil_in, type="rgba")
547
+ except BaseException as e:
548
+ if _is_oom_error(e):
549
+ _cuda_soft_cleanup()
550
+ raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
551
+ raise
552
+
553
+ out = pil2tensor(mid)
554
+ img_list.append(out)
555
+ del pil_in, mid, out
556
+
557
+ img_stack = torch.cat(img_list, dim=0)
558
+ mask = img_stack[:, :, :, 3]
559
+ return (img_stack, mask)
560
+
561
+
562
+ class InspyrenetRembg3:
563
+ def __init__(self):
564
+ pass
565
+
566
+ @classmethod
567
+ def INPUT_TYPES(s):
568
+ return {
569
+ "required": {
570
+ "image": ("IMAGE",),
571
+ },
572
+ }
573
+
574
+ RETURN_TYPES = ("IMAGE",)
575
+ FUNCTION = "remove_background"
576
+ CATEGORY = "image"
577
+
578
+ def remove_background(self, image):
579
+ remover, run_lock = _get_remover(jit=False)
580
+
581
+ img_list = []
582
+ for img in tqdm(image, "Inspyrenet Rembg3"):
583
+ pil_in = tensor2pil(img)
584
+ pil_rgb = _rgba_to_rgb_on_white(pil_in)
585
+
586
+ try:
587
+ with run_lock:
588
+ with torch.inference_mode():
589
+ mid = remover.process(pil_rgb, type="rgba")
590
+ except BaseException as e:
591
+ if _is_oom_error(e):
592
+ _cuda_soft_cleanup()
593
+ raise RuntimeError("InspyrenetRembg3: CUDA out of memory.") from e
594
+ raise
595
+
596
+ out = pil2tensor(mid)
597
+ img_list.append(out)
598
+ del pil_in, pil_rgb, mid, out
599
+
600
+ img_stack = torch.cat(img_list, dim=0)
601
+ return (img_stack,)
602
+
603
+
604
+ # -----------------------------------------------------------------------------
605
+ # NEW: Global nodes (simple, no user settings on Load/Run)
606
+ # -----------------------------------------------------------------------------
607
+
608
+ class Load_Inspyrenet_Global:
609
+ """
610
+ No inputs. Creates the global remover (once) and moves it to ComfyUI's CUDA device (if possible).
611
+ Returns:
612
+ - loaded_ok (BOOLEAN)
613
+ - vram_delta_bytes (INT) best-effort (weights residency only; not peak inference)
614
+ """
615
+ def __init__(self):
616
+ pass
617
+
618
+ @classmethod
619
+ def INPUT_TYPES(s):
620
+ return {"required": {}}
621
+
622
+ RETURN_TYPES = ("BOOLEAN", "INT")
623
+ FUNCTION = "load"
624
+ CATEGORY = "image"
625
+
626
+ def load(self):
627
+ _get_global_remover()
628
+ ok = _load_global_to_comfy_cuda_no_crash()
629
+ return (bool(ok), int(_GLOBAL_VRAM_DELTA_BYTES))
630
+
631
+
632
+ class Remove_Inspyrenet_Global:
633
+ """
634
+ Offload global remover to CPU or delete it.
635
+ """
636
+ def __init__(self):
637
+ pass
638
+
639
+ @classmethod
640
+ def INPUT_TYPES(s):
641
+ return {
642
+ "required": {
643
+ "action": (["offload_to_cpu", "delete_instance"],),
644
+ }
645
+ }
646
+
647
+ RETURN_TYPES = ("BOOLEAN",)
648
+ FUNCTION = "remove"
649
+ CATEGORY = "image"
650
+
651
+ def remove(self, action):
652
+ global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES
653
+ if action == "offload_to_cpu":
654
+ _move_global_to_cpu()
655
+ return (True,)
656
+
657
+ # delete_instance
658
+ with _GLOBAL_LOCK:
659
+ try:
660
+ if _GLOBAL_REMOVER is not None:
661
+ try:
662
+ if hasattr(_GLOBAL_REMOVER, "model"):
663
+ _GLOBAL_REMOVER.model = _GLOBAL_REMOVER.model.to("cpu")
664
+ _GLOBAL_REMOVER.device = "cpu"
665
+ except Exception:
666
+ pass
667
+ _GLOBAL_REMOVER = None
668
+ _GLOBAL_ON_DEVICE = "cpu"
669
+ _GLOBAL_VRAM_DELTA_BYTES = 0
670
+ except Exception:
671
+ pass
672
+
673
+ _cuda_soft_cleanup()
674
+ return (True,)
675
+
676
+
677
+ class Run_InspyrenetRembg_Global:
678
+ """
679
+ No settings. Same behavior as InspyrenetRembg3, but uses the global remover and won't crash on OOM.
680
+ On failure/OOM, returns a visible passthrough (opaque RGBA), NOT an invisible image.
681
+ """
682
+ def __init__(self):
683
+ pass
684
+
685
+ @classmethod
686
+ def INPUT_TYPES(s):
687
+ return {
688
+ "required": {
689
+ "image": ("IMAGE",),
690
+ }
691
+ }
692
+
693
+ RETURN_TYPES = ("IMAGE",)
694
+ FUNCTION = "remove_background"
695
+ CATEGORY = "image"
696
+
697
+ def remove_background(self, image):
698
+ _get_global_remover()
699
+
700
+ img_list = []
701
+ for img in tqdm(image, "Run InspyrenetRembg Global"):
702
+ pil_in = tensor2pil(img)
703
+
704
+ # Visible fallback (never invisible)
705
+ fallback = _force_rgba_opaque(pil_in)
706
+
707
+ # Exactly like Rembg3 input path
708
+ pil_rgb = _rgba_to_rgb_on_white(pil_in)
709
+
710
+ out_pil = _run_global_rgba_no_crash(pil_rgb, fallback)
711
+ out = pil2tensor(out_pil)
712
+ img_list.append(out)
713
+
714
+ del pil_in, fallback, pil_rgb, out_pil, out
715
+
716
+ img_stack = torch.cat(img_list, dim=0)
717
+ return (img_stack,)
718
+
719
+
720
+ NODE_CLASS_MAPPINGS = {
721
+ "InspyrenetRembg2": InspyrenetRembg2,
722
+ "InspyrenetRembg3": InspyrenetRembg3,
723
+
724
+ "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
725
+ "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
726
+ "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
727
+ }
728
+
729
+ NODE_DISPLAY_NAME_MAPPINGS = {
730
+ "InspyrenetRembg2": "Inspyrenet Rembg2",
731
+ "InspyrenetRembg3": "Inspyrenet Rembg3",
732
+
733
+ "Load_Inspyrenet_Global": "Load Inspyrenet Global",
734
+ "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
735
+ "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
736
+ }