from __future__ import annotations

import gc
import os
import threading
import urllib.request
from typing import Dict, Optional, Tuple

from PIL import Image
import torch
import numpy as np
from transparent_background import Remover
from tqdm import tqdm

# Optional: ComfyUI memory manager (present inside ComfyUI)
try:
    import comfy.model_management as comfy_mm
except Exception:
    comfy_mm = None

# NOTE(review): the checkpoint location is hard-coded under /root; presumably
# this matches where transparent_background looks when running as root --
# confirm for non-root deployments.
CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"

# Size of each read() while streaming the checkpoint to disk.
_DOWNLOAD_CHUNK_BYTES = 1024 * 1024


def _parse_content_length(raw: Optional[str]) -> Optional[int]:
    """Return a positive Content-Length as int, or None if absent/unusable."""
    if raw is None:
        return None
    try:
        value = int(raw)
    except ValueError:
        # A malformed header should not abort the download; treat as unknown.
        return None
    return value if value > 0 else None


def _ensure_ckpt_base():
    """Make sure a non-empty checkpoint file exists at CKPT_PATH.

    If the file is missing or empty it is downloaded from CKPT_URL into a
    ``.tmp`` sibling and atomically moved into place with ``os.replace``.
    The temporary file is removed on any failure.

    Raises:
        OSError: if the server announced a Content-Length and the number of
            bytes received differs (truncated download). Network errors from
            ``urllib`` propagate unchanged.
    """
    try:
        if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
            return
    except OSError:
        pass

    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    tmp_path = CKPT_PATH + ".tmp"

    try:
        written = 0
        with urllib.request.urlopen(CKPT_URL) as resp:
            total = _parse_content_length(resp.headers.get("Content-Length"))
            # The progress bar is only shown when the size is known; with an
            # unknown size the bar is disabled and update() is a no-op.
            with open(tmp_path, "wb") as f, tqdm(
                total=total,
                unit="B",
                unit_scale=True,
                desc="Downloading ckpt_base.pth",
                disable=not total,
            ) as pbar:
                while True:
                    chunk = resp.read(_DOWNLOAD_CHUNK_BYTES)
                    if not chunk:
                        break
                    f.write(chunk)
                    written += len(chunk)
                    pbar.update(len(chunk))

        # Never promote a truncated file: the size>0 check above would accept
        # it on every later run.
        if total is not None and written != total:
            raise OSError(
                f"Incomplete download of ckpt_base.pth: got {written} of {total} bytes"
            )

        os.replace(tmp_path, CKPT_PATH)
    finally:
        # After a successful os.replace the tmp file no longer exists.
        if os.path.isfile(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass


# Tensor to PIL
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a float tensor to a PIL image.

    Values are scaled by 255 and clipped to [0, 255]. A leading batch
    dimension of size 1 on a 4-D tensor is dropped.
    """
    arr = image.detach().cpu().numpy()
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]
    arr = np.clip(255.0 * arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)


# Convert PIL to Tensor
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor scaled by 1/255, with a new leading dim."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
    """Return an RGB image; RGBA input is alpha-composited over opaque white."""
    if pil_img.mode == "RGBA":
        bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
        composited = Image.alpha_composite(bg, pil_img)
        return composited.convert("RGB")
    if pil_img.mode != "RGB":
        return pil_img.convert("RGB")
    return pil_img


def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
    """
    Opaque RGBA fallback (alpha=255), so you never get an "invisible" output.
    """
    rgba = pil_img.convert("RGBA")
    r, g, b, _a = rgba.split()
    a = Image.new("L", rgba.size, 255)
    return Image.merge("RGBA", (r, g, b, a))


def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
    """
    True if RGBA image alpha channel is entirely 0.

    Non-RGBA images, and images whose extrema cannot be read, yield False.
    """
    if pil_img.mode != "RGBA":
        return False
    try:
        extrema = pil_img.getextrema()  # ((min,max),(min,max),(min,max),(min,max))
        return extrema[3][1] == 0
    except Exception:
        return False


def _is_oom_error(e: BaseException) -> bool:
    """Best-effort check whether *e* represents a device out-of-memory error.

    Matches torch's OutOfMemoryError classes when they exist, otherwise falls
    back to substring heuristics on the lower-cased message.
    """
    oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
    if oom_cuda_cls is not None and isinstance(e, oom_cuda_cls):
        return True

    oom_torch_cls = getattr(torch, "OutOfMemoryError", None)
    if oom_torch_cls is not None and isinstance(e, oom_torch_cls):
        return True

    msg = str(e).lower()
    if "out of memory" in msg:
        return True
    if "allocation on device" in msg:
        return True
    # NOTE(review): plain substring tests; "hip" also matches inside longer
    # words, so this can over-match on unrelated messages mentioning "memory".
    return ("cuda" in msg or "cublas" in msg or "hip" in msg) and ("memory" in msg)


def _cuda_soft_cleanup() -> None:
    """Run gc and, if CUDA is available, synchronize/empty_cache/ipc_collect.

    Every step is best-effort: failures are ignored so cleanup never raises
    from inside an error-recovery path.
    """
    try:
        gc.collect()
    except Exception:
        pass

    if torch.cuda.is_available():
        try:
            torch.cuda.synchronize()
        except Exception:
            pass
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        try:
            torch.cuda.ipc_collect()
        except Exception:
            pass


def _comfy_soft_empty_cache() -> None:
    """Call ComfyUI's soft_empty_cache if available (best-effort, never raises)."""
    if comfy_mm is None:
        return
    if hasattr(comfy_mm, "soft_empty_cache"):
        try:
            comfy_mm.soft_empty_cache(force=True)
        except TypeError:
            # Signature without the `force` keyword.
            try:
                comfy_mm.soft_empty_cache()
            except Exception:
                pass
        except Exception:
            pass


def _get_comfy_torch_device() -> torch.device:
    """
    Always prefer ComfyUI's chosen device.

    Falls back to cuda:0 when CUDA is available, else CPU.
    """
    if comfy_mm is not None and hasattr(comfy_mm, "get_torch_device"):
        try:
            d = comfy_mm.get_torch_device()
            if isinstance(d, torch.device):
                return d
            return torch.device(str(d))
        except Exception:
            pass

    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def _set_current_cuda_device(dev: torch.device) -> None:
    """
    Make sure mem_get_info() measurements are on the same device ComfyUI uses.
    """
    if dev.type == "cuda":
        try:
            if dev.index is not None:
                torch.cuda.set_device(dev.index)
        except Exception:
            pass


def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]:
    """Return free bytes reported by mem_get_info() for *dev*, or None if unavailable."""
    if dev.type != "cuda" or not torch.cuda.is_available():
        return None
    try:
        _set_current_cuda_device(dev)
        free_b, _total_b = torch.cuda.mem_get_info()
        return int(free_b)
    except Exception:
        return None


def _comfy_unload_one_smallest_model() -> bool:
    """
    Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.
    If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).

    Returns True when a candidate model was found and an unload was attempted
    (the unload call itself is best-effort), False when there was nothing to
    evict.
    """
    if comfy_mm is None:
        return False
    if not hasattr(comfy_mm, "current_loaded_models"):
        return False

    try:
        cur_dev = _get_comfy_torch_device()
    except Exception:
        cur_dev = None

    models = []
    try:
        for lm in list(comfy_mm.current_loaded_models):
            try:
                # Prefer same device
                lm_dev = getattr(lm, "device", None)
                if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
                    continue

                mem_fn = getattr(lm, "model_loaded_memory", None)
                if callable(mem_fn):
                    mem = int(mem_fn())
                else:
                    mem = int(getattr(lm, "loaded_memory", 0) or 0)

                if mem > 0:
                    models.append((mem, lm))
            except Exception:
                continue
    except Exception:
        return False

    if not models:
        return False

    # Smallest first; min() returns the first of equal-sized entries.
    _mem, lm = min(models, key=lambda entry: entry[0])

    try:
        unload_fn = getattr(lm, "model_unload", None)
        if callable(unload_fn):
            try:
                unload_fn(unpatch_weights=True)
            except TypeError:
                unload_fn()
    except Exception:
        pass

    # Cleanup hook if present
    try:
        cleanup = getattr(comfy_mm, "cleanup_models", None)
        if callable(cleanup):
            cleanup()
    except Exception:
        pass

    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
    return True


def _comfy_unload_all_models() -> None:
    """Ask ComfyUI to unload every model (if supported), then free caches."""
    if comfy_mm is None:
        return
    if hasattr(comfy_mm, "unload_all_models"):
        try:
            comfy_mm.unload_all_models()
        except Exception:
            pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()


# -----------------------------------------------------------------------------
# Singleton cache used by the InspyrenetRembg2 / InspyrenetRembg3 nodes
# -----------------------------------------------------------------------------
_REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
_REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
_CACHE_LOCK = threading.Lock()


def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
    """Return the cached Remover for *jit* and the lock guarding its use.

    The Remover is created on first request (after ensuring the checkpoint
    exists). Construction errors propagate; on an OOM-looking error CUDA
    caches are cleaned first.
    """
    key = (jit,)
    with _CACHE_LOCK:
        inst = _REMOVER_CACHE.get(key)
        if inst is None:
            _ensure_ckpt_base()
            try:
                inst = Remover(jit=jit) if jit else Remover()
            except Exception as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                raise
            _REMOVER_CACHE[key] = inst

        run_lock = _REMOVER_RUN_LOCKS.get(key)
        if run_lock is None:
            run_lock = threading.Lock()
            _REMOVER_RUN_LOCKS[key] = run_lock

    return inst, run_lock


# -----------------------------------------------------------------------------
# GLOBAL remover (for Load/Remove/Run Global nodes)
# -----------------------------------------------------------------------------
_GLOBAL_LOCK = threading.Lock()  # guards creation/deletion of _GLOBAL_REMOVER
_GLOBAL_RUN_LOCK = threading.Lock()  # serializes process() calls on it
_GLOBAL_REMOVER: Optional[Remover] = None
_GLOBAL_ON_DEVICE: str = "cpu"
_GLOBAL_VRAM_DELTA_BYTES: int = 0


def _create_global_remover_cpu() -> Remover:
    """
    Create the Remover configured like InspyrenetRembg3 (jit=False),
    but *try* to force CPU init to avoid VRAM OOM during creation.
    """
    _ensure_ckpt_base()

    # Prefer constructing on CPU if supported by this library version.
    try:
        r = Remover(device="cpu")  # type: ignore[arg-type]
        try:
            r.device = "cpu"
        except Exception:
            pass
        return r
    except TypeError:
        pass

    # Fallback: construct default and immediately offload to CPU
    r = Remover()
    try:
        if hasattr(r, "model"):
            r.model = r.model.to("cpu")
        r.device = "cpu"
    except Exception:
        pass
    _cuda_soft_cleanup()
    return r


def _get_global_remover() -> Remover:
    """Return the global Remover, creating it (on CPU) on first use."""
    global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        if _GLOBAL_REMOVER is None:
            _GLOBAL_REMOVER = _create_global_remover_cpu()
            _GLOBAL_ON_DEVICE = str(getattr(_GLOBAL_REMOVER, "device", "cpu"))
        return _GLOBAL_REMOVER


def _move_global_to_cpu() -> None:
    """Move the global Remover (if one exists) to CPU and free CUDA caches.

    Does not create the Remover: offloading when nothing is loaded must not
    build a model or trigger a checkpoint download.
    """
    global _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        r = _GLOBAL_REMOVER

    if r is not None:
        try:
            if hasattr(r, "model"):
                r.model = r.model.to("cpu")
            r.device = "cpu"
            _GLOBAL_ON_DEVICE = "cpu"
        except Exception:
            pass
    _cuda_soft_cleanup()


def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
    """
    Load the global remover into VRAM on ComfyUI's chosen CUDA device.
    Never crashes on OOM: evicts smallest model first, then unload_all as last resort.
    Also records a best-effort VRAM delta.

    Args:
        max_evictions: upper bound on eviction-and-retry rounds.

    Returns:
        True if the remover is on a CUDA device afterwards, False if it was
        left on (or moved back to) CPU. Non-OOM errors propagate.
    """
    global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

    r = _get_global_remover()
    dev = _get_comfy_torch_device()

    if dev.type != "cuda" or not torch.cuda.is_available():
        _move_global_to_cpu()
        return False

    # Already on CUDA?
    cur_dev = str(getattr(r, "device", "") or "")
    if cur_dev.startswith("cuda"):
        _GLOBAL_ON_DEVICE = cur_dev
        return True

    _set_current_cuda_device(dev)
    free_before = _cuda_free_bytes_on(dev)

    unloaded_all = False
    for _ in range(max_evictions + 1):
        try:
            # Move model to the SAME device ComfyUI uses
            if hasattr(r, "model"):
                r.model = r.model.to(dev)
            r.device = str(dev)
            _GLOBAL_ON_DEVICE = str(dev)

            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            free_after = _cuda_free_bytes_on(dev)
            if free_before is not None and free_after is not None:
                delta = max(0, int(free_before) - int(free_after))
                if delta > 0:
                    _GLOBAL_VRAM_DELTA_BYTES = delta
            return True
        except Exception as e:
            if not _is_oom_error(e):
                raise
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Evict ONE smallest model; if that fails, unload all.
            if _comfy_unload_one_smallest_model():
                continue
            if unloaded_all:
                # Everything was already unloaded and there is still nothing
                # to evict: further retries cannot free more memory.
                break
            _comfy_unload_all_models()
            unloaded_all = True

    # Could not load
    _move_global_to_cpu()
    return False


def _process_global_once(
    r: Remover, pil_rgb: Image.Image, fallback_rgba: Image.Image
) -> Image.Image:
    """Run one process() call under the global run lock; map invisible output to fallback."""
    with _GLOBAL_RUN_LOCK:
        with torch.inference_mode():
            out = r.process(pil_rgb, type="rgba")
    if _alpha_is_all_zero(out):
        # Treat as failure -> prevents invisible output
        return fallback_rgba
    return out


def _evict_one_for_retry() -> None:
    """First OOM recovery step: free caches, then evict one smallest ComfyUI model."""
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
    _comfy_unload_one_smallest_model()


def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
    """
    Run remover.process() (rgba output), matching InspyrenetRembg3 behavior.
    On OOM: evict models and retry, then CPU fallback.
    If output alpha is fully transparent, return fallback (prevents "invisible" output).

    Non-OOM errors from the first three attempts propagate. Any Exception
    from the final CPU attempt yields *fallback_rgba*; KeyboardInterrupt and
    SystemExit are never swallowed.
    """
    r = _get_global_remover()

    # Try to keep it on CUDA (Comfy device) if possible; do not crash if not.
    _load_global_to_comfy_cuda_no_crash()

    # Escalating recovery applied *before* each attempt:
    #   1. nothing (whatever device we're on, likely CUDA)
    #   2. evict one smallest model
    #   3. unload all ComfyUI models
    recovery_steps = (None, _evict_one_for_retry, _comfy_unload_all_models)
    for recover in recovery_steps:
        if recover is not None:
            recover()
        try:
            return _process_global_once(r, pil_rgb, fallback_rgba)
        except Exception as e:
            if not _is_oom_error(e):
                raise

    # Final: CPU fallback
    _move_global_to_cpu()
    try:
        return _process_global_once(r, pil_rgb, fallback_rgba)
    except Exception:
        # Last resort: passthrough
        return fallback_rgba


# -----------------------------------------------------------------------------
# Nodes
# -----------------------------------------------------------------------------
class InspyrenetRembg2:
    """Background removal node returning the RGBA image batch and its alpha as a mask."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "torchscript_jit": (["default", "on"],)
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image, torchscript_jit):
        """Process each image of the batch; raise RuntimeError on CUDA OOM."""
        jit = (torchscript_jit != "default")
        remover, run_lock = _get_remover(jit=jit)

        img_list = []
        for img in tqdm(image, "Inspyrenet Rembg2"):
            pil_in = tensor2pil(img)
            try:
                with run_lock:
                    with torch.inference_mode():
                        mid = remover.process(pil_in, type="rgba")
            except Exception as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
                raise

            out = pil2tensor(mid)
            img_list.append(out)
            del pil_in, mid, out

        img_stack = torch.cat(img_list, dim=0)
        # Channel index 3 of the RGBA output is the alpha channel.
        mask = img_stack[:, :, :, 3]
        return (img_stack, mask)


class InspyrenetRembg3:
    """Background removal node; input is flattened onto white before processing."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        """Process each image of the batch; raise RuntimeError on CUDA OOM."""
        remover, run_lock = _get_remover(jit=False)

        img_list = []
        for img in tqdm(image, "Inspyrenet Rembg3"):
            pil_in = tensor2pil(img)
            pil_rgb = _rgba_to_rgb_on_white(pil_in)
            try:
                with run_lock:
                    with torch.inference_mode():
                        mid = remover.process(pil_rgb, type="rgba")
            except Exception as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg3: CUDA out of memory.") from e
                raise

            out = pil2tensor(mid)
            img_list.append(out)
            del pil_in, pil_rgb, mid, out

        img_stack = torch.cat(img_list, dim=0)
        return (img_stack,)


# -----------------------------------------------------------------------------
# Global nodes (simple, no user settings on Load/Run)
# -----------------------------------------------------------------------------
class Load_Inspyrenet_Global:
    """
    No inputs. Creates the global remover (once) and moves it to ComfyUI's CUDA device (if possible).
    Returns:
      - loaded_ok (BOOLEAN)
      - vram_delta_bytes (INT) best-effort (weights residency only; not peak inference)
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    RETURN_TYPES = ("BOOLEAN", "INT")
    FUNCTION = "load"
    CATEGORY = "image"

    def load(self):
        """Create the global remover if needed and try to place it on CUDA."""
        _get_global_remover()
        ok = _load_global_to_comfy_cuda_no_crash()
        return (bool(ok), int(_GLOBAL_VRAM_DELTA_BYTES))


class Remove_Inspyrenet_Global:
    """
    Offload global remover to CPU or delete it.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "action": (["offload_to_cpu", "delete_instance"],),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "remove"
    CATEGORY = "image"

    def remove(self, action):
        """Apply *action* to the global remover; always returns ``(True,)``."""
        global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

        if action == "offload_to_cpu":
            _move_global_to_cpu()
            return (True,)

        # delete_instance
        with _GLOBAL_LOCK:
            if _GLOBAL_REMOVER is not None:
                # Move weights off the GPU before dropping the reference so
                # the cleanup below can actually release the memory.
                try:
                    if hasattr(_GLOBAL_REMOVER, "model"):
                        _GLOBAL_REMOVER.model = _GLOBAL_REMOVER.model.to("cpu")
                    _GLOBAL_REMOVER.device = "cpu"
                except Exception:
                    pass
            _GLOBAL_REMOVER = None
            _GLOBAL_ON_DEVICE = "cpu"
            _GLOBAL_VRAM_DELTA_BYTES = 0

        _cuda_soft_cleanup()
        return (True,)


class Run_InspyrenetRembg_Global:
    """
    No settings.
    Same behavior as InspyrenetRembg3, but uses the global remover and won't crash on OOM.
    On failure/OOM, returns a visible passthrough (opaque RGBA), NOT an invisible image.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        """Process each image of the batch with the global remover."""
        _get_global_remover()

        img_list = []
        for img in tqdm(image, "Run InspyrenetRembg Global"):
            pil_in = tensor2pil(img)

            # Visible fallback (never invisible)
            fallback = _force_rgba_opaque(pil_in)

            # Exactly like Rembg3 input path
            pil_rgb = _rgba_to_rgb_on_white(pil_in)

            out_pil = _run_global_rgba_no_crash(pil_rgb, fallback)
            out = pil2tensor(out_pil)
            img_list.append(out)

            del pil_in, fallback, pil_rgb, out_pil, out

        img_stack = torch.cat(img_list, dim=0)
        return (img_stack,)


NODE_CLASS_MAPPINGS = {
    "InspyrenetRembg2": InspyrenetRembg2,
    "InspyrenetRembg3": InspyrenetRembg3,
    "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
    "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
    "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembg2": "Inspyrenet Rembg2",
    "InspyrenetRembg3": "Inspyrenet Rembg3",
    "Load_Inspyrenet_Global": "Load Inspyrenet Global",
    "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
    "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
}