| | from __future__ import annotations |
| |
|
| | from PIL import Image |
| | import os |
| | import urllib.request |
| | import gc |
| | import threading |
| | from typing import Dict, Tuple, Optional |
| |
|
| | import torch |
| | import numpy as np |
| | from transparent_background import Remover |
| | from tqdm import tqdm |
| |
|
| |
|
| | |
| | try: |
| | import comfy.model_management as comfy_mm |
| | except Exception: |
| | comfy_mm = None |
| |
|
| |
|
# Location where transparent-background looks for its base checkpoint, and a
# mirror URL used to fetch it when the file is missing.
CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
| |
|
| |
|
def _ensure_ckpt_base() -> None:
    """
    Ensure the transparent-background base checkpoint exists at CKPT_PATH.

    Downloads it from CKPT_URL into a ``.tmp`` file and publishes it with an
    atomic ``os.replace`` so readers never observe a partially written file.
    A zero-byte (corrupt) checkpoint is treated as missing.
    """
    try:
        if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
            return
    except Exception:
        # Stat failures are treated as "not present"; the download will decide.
        pass

    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    tmp_path = CKPT_PATH + ".tmp"

    def _copy_chunks(resp, f, pbar=None):
        # Stream in 1 MiB chunks so large checkpoints never sit fully in memory.
        while True:
            chunk = resp.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)
            if pbar is not None:
                pbar.update(len(chunk))

    try:
        # timeout guards against a hung connection blocking node import forever
        with urllib.request.urlopen(CKPT_URL, timeout=60) as resp:
            total = resp.headers.get("Content-Length")
            total = int(total) if total is not None else None

            with open(tmp_path, "wb") as f:
                if total:
                    with tqdm(
                        total=total,
                        unit="B",
                        unit_scale=True,
                        desc="Downloading ckpt_base.pth",
                    ) as pbar:
                        _copy_chunks(resp, f, pbar)
                else:
                    # No Content-Length header: download without a progress bar.
                    _copy_chunks(resp, f)

        # Atomic publish: either the full checkpoint appears, or nothing does.
        os.replace(tmp_path, CKPT_PATH)
    finally:
        # On any failure the temp file is discarded; after os.replace it is gone.
        if os.path.isfile(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception:
                pass
| |
|
| |
|
| | |
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a float image tensor in [0, 1] (optionally batch-of-1) to a PIL image."""
    data = image.detach().cpu().numpy()
    if data.ndim == 4 and data.shape[0] == 1:
        # Drop the singleton batch dimension (ComfyUI IMAGE tensors are NHWC).
        data = data[0]
    data = np.clip(data * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(data)
| |
|
| |
|
| | |
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float tensor in [0, 1] with a leading batch dim."""
    arr = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
| |
|
| |
|
def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
    """Flatten any transparency onto a white background and return an RGB image."""
    mode = pil_img.mode
    if mode == "RGBA":
        white = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
        return Image.alpha_composite(white, pil_img).convert("RGB")
    if mode == "RGB":
        # Already the right mode; hand back the original untouched.
        return pil_img
    return pil_img.convert("RGB")
| |
|
| |
|
def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
    """
    Opaque RGBA fallback (alpha=255), so you never get an "invisible" output.
    """
    rgba = pil_img.convert("RGBA")
    channels = rgba.split()
    opaque = Image.new("L", rgba.size, 255)
    # Keep the color bands, swap in a fully opaque alpha band.
    return Image.merge("RGBA", (*channels[:3], opaque))
| |
|
| |
|
def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
    """
    True if RGBA image alpha channel is entirely 0.
    """
    if pil_img.mode != "RGBA":
        return False
    try:
        # getextrema() on RGBA yields one (min, max) pair per band; band 3 is alpha.
        _alpha_min, alpha_max = pil_img.getextrema()[3]
        return alpha_max == 0
    except Exception:
        return False
| |
|
| |
|
| | def _is_oom_error(e: BaseException) -> bool: |
| | oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None) |
| | if oom_cuda_cls is not None and isinstance(e, oom_cuda_cls): |
| | return True |
| |
|
| | oom_torch_cls = getattr(torch, "OutOfMemoryError", None) |
| | if oom_torch_cls is not None and isinstance(e, oom_torch_cls): |
| | return True |
| |
|
| | msg = str(e).lower() |
| | if "out of memory" in msg: |
| | return True |
| | if "allocation on device" in msg: |
| | return True |
| | return ("cuda" in msg or "cublas" in msg or "hip" in msg) and ("memory" in msg) |
| |
|
| |
|
| | def _cuda_soft_cleanup() -> None: |
| | try: |
| | gc.collect() |
| | except Exception: |
| | pass |
| |
|
| | if torch.cuda.is_available(): |
| | try: |
| | torch.cuda.synchronize() |
| | except Exception: |
| | pass |
| | try: |
| | torch.cuda.empty_cache() |
| | except Exception: |
| | pass |
| | try: |
| | torch.cuda.ipc_collect() |
| | except Exception: |
| | pass |
| |
|
| |
|
def _comfy_soft_empty_cache() -> None:
    """Ask ComfyUI to drop cached VRAM; tolerates old signatures and any failure."""
    if comfy_mm is None:
        return
    fn = getattr(comfy_mm, "soft_empty_cache", None)
    if fn is None:
        return
    try:
        fn(force=True)
    except TypeError:
        # Older ComfyUI builds take no arguments.
        try:
            fn()
        except Exception:
            pass
    except Exception:
        pass
| |
|
| |
|
def _get_comfy_torch_device() -> torch.device:
    """
    Always prefer ComfyUI's chosen device; fall back to cuda:0, then CPU.
    """
    if comfy_mm is not None and hasattr(comfy_mm, "get_torch_device"):
        try:
            dev = comfy_mm.get_torch_device()
            # Normalize: ComfyUI may hand back a torch.device or something str-able.
            return dev if isinstance(dev, torch.device) else torch.device(str(dev))
        except Exception:
            pass

    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| |
|
| | def _set_current_cuda_device(dev: torch.device) -> None: |
| | """ |
| | Make sure mem_get_info() measurements are on the same device ComfyUI uses. |
| | """ |
| | if dev.type == "cuda": |
| | try: |
| | if dev.index is not None: |
| | torch.cuda.set_device(dev.index) |
| | except Exception: |
| | pass |
| |
|
| |
|
| | def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]: |
| | if dev.type != "cuda" or not torch.cuda.is_available(): |
| | return None |
| | try: |
| | _set_current_cuda_device(dev) |
| | free_b, _total_b = torch.cuda.mem_get_info() |
| | return int(free_b) |
| | except Exception: |
| | return None |
| |
|
| |
|
def _comfy_unload_one_smallest_model() -> bool:
    """
    Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.

    If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).

    Returns:
        True when a candidate model was found and the eviction path ran;
        False when ComfyUI is absent, its model list is unreadable, or no
        model on the current device reports positive loaded memory.
    """
    if comfy_mm is None:
        return False
    if not hasattr(comfy_mm, "current_loaded_models"):
        return False

    try:
        cur_dev = _get_comfy_torch_device()
    except Exception:
        cur_dev = None

    # Collect (loaded_bytes, model) pairs for models resident on our device.
    models = []
    try:
        for lm in list(comfy_mm.current_loaded_models):
            try:
                # Skip models living on a different device than the one we
                # need to free memory on.
                lm_dev = getattr(lm, "device", None)
                if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
                    continue

                # Newer ComfyUI exposes model_loaded_memory(); older builds
                # keep a plain loaded_memory attribute — NOTE(review): exact
                # attribute names vary across ComfyUI versions; verify.
                mem_fn = getattr(lm, "model_loaded_memory", None)
                if callable(mem_fn):
                    mem = int(mem_fn())
                else:
                    mem = int(getattr(lm, "loaded_memory", 0) or 0)

                if mem > 0:
                    models.append((mem, lm))
            except Exception:
                continue
    except Exception:
        return False

    if not models:
        return False

    # Evict the model occupying the least memory first.
    models.sort(key=lambda x: x[0])
    _mem, lm = models[0]

    try:
        unload_fn = getattr(lm, "model_unload", None)
        if callable(unload_fn):
            try:
                unload_fn(unpatch_weights=True)
            except TypeError:
                # Older signature without the unpatch_weights keyword.
                unload_fn()
    except Exception:
        pass

    # Let ComfyUI drop bookkeeping for unloaded models, then release caches.
    try:
        cleanup = getattr(comfy_mm, "cleanup_models", None)
        if callable(cleanup):
            cleanup()
    except Exception:
        pass

    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
    return True
| |
|
| |
|
def _comfy_unload_all_models() -> None:
    """Last-resort eviction: ask ComfyUI to unload every model, then clean caches."""
    if comfy_mm is None:
        return
    fn = getattr(comfy_mm, "unload_all_models", None)
    if fn is not None:
        try:
            fn()
        except Exception:
            pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Per-configuration Remover cache, keyed by (jit,); each instance gets its own
# run lock so inference calls on the same Remover are serialized.
_REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
_REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
_CACHE_LOCK = threading.Lock()  # guards creation/lookup in the two dicts above
| |
|
| |
|
def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
    """Return the cached (Remover, run-lock) pair for *jit*, creating lazily."""
    key = (jit,)
    with _CACHE_LOCK:
        remover = _REMOVER_CACHE.get(key)
        if remover is None:
            _ensure_ckpt_base()
            try:
                remover = Remover(jit=jit) if jit else Remover()
            except BaseException as e:
                # Free whatever we can before propagating an OOM to the caller.
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                raise
            _REMOVER_CACHE[key] = remover

        run_lock = _REMOVER_RUN_LOCKS.setdefault(key, threading.Lock())

    return remover, run_lock
| |
|
| |
|
| | |
| | |
| | |
| |
|
# State for the single shared "global" remover used by the *_Global nodes.
_GLOBAL_LOCK = threading.Lock()  # guards creation/deletion of the singleton
_GLOBAL_RUN_LOCK = threading.Lock()  # serializes inference on the singleton
_GLOBAL_REMOVER: Optional[Remover] = None
_GLOBAL_ON_DEVICE: str = "cpu"  # last known device string ("cpu" or "cuda:N")
_GLOBAL_VRAM_DELTA_BYTES: int = 0  # best-effort VRAM cost measured at load time
| |
|
| |
|
def _create_global_remover_cpu() -> Remover:
    """
    Create the Remover configured like InspyrenetRembg3 (jit=False),
    but *try* to force CPU init to avoid VRAM OOM during creation.
    """
    _ensure_ckpt_base()

    # Newer transparent_background accepts device=; older builds raise TypeError.
    try:
        remover = Remover(device="cpu")
        try:
            remover.device = "cpu"
        except Exception:
            pass
        return remover
    except TypeError:
        pass

    # Fallback: default init (may touch CUDA), then push weights back to CPU.
    remover = Remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
    except Exception:
        pass
    _cuda_soft_cleanup()
    return remover
| |
|
| |
|
def _get_global_remover() -> Remover:
    """Lazily create and return the process-wide Remover singleton (thread-safe)."""
    global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        if _GLOBAL_REMOVER is None:
            created = _create_global_remover_cpu()
            _GLOBAL_ON_DEVICE = str(getattr(created, "device", "cpu"))
            _GLOBAL_REMOVER = created
        return _GLOBAL_REMOVER
| |
|
| |
|
def _move_global_to_cpu() -> None:
    """Offload the global remover's weights to CPU and release CUDA caches."""
    global _GLOBAL_ON_DEVICE
    remover = _get_global_remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
        _GLOBAL_ON_DEVICE = "cpu"
    except Exception:
        # Best-effort: a failed move leaves the previous device recorded.
        pass
    _cuda_soft_cleanup()
| |
|
| |
|
def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
    """
    Load the global remover into VRAM on ComfyUI's chosen CUDA device.
    Never crashes on OOM: evicts smallest model first, then unload_all as last resort.
    Also records a best-effort VRAM delta.

    Args:
        max_evictions: upper bound on OOM-triggered eviction retries.

    Returns:
        True when the remover ends up on CUDA; False when it stays on CPU.
    """
    global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

    r = _get_global_remover()
    dev = _get_comfy_torch_device()

    # No CUDA target: make sure the remover is (back) on CPU and report failure.
    if dev.type != "cuda" or not torch.cuda.is_available():
        _move_global_to_cpu()
        return False

    # Already resident on some CUDA device: nothing to move.
    cur_dev = str(getattr(r, "device", "") or "")
    if cur_dev.startswith("cuda"):
        _GLOBAL_ON_DEVICE = cur_dev
        return True

    _set_current_cuda_device(dev)

    # Snapshot free VRAM so we can estimate the weights' footprint afterwards.
    free_before = _cuda_free_bytes_on(dev)

    for _ in range(max_evictions + 1):
        try:
            # Move the underlying model weights; keep the device string in sync.
            if hasattr(r, "model"):
                r.model = r.model.to(dev)
            r.device = str(dev)
            _GLOBAL_ON_DEVICE = str(dev)

            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Best-effort delta: weights residency only, not peak inference use.
            free_after = _cuda_free_bytes_on(dev)
            if free_before is not None and free_after is not None:
                delta = max(0, int(free_before) - int(free_after))
                if delta > 0:
                    _GLOBAL_VRAM_DELTA_BYTES = delta

            return True

        except BaseException as e:
            # Only OOM triggers the eviction path; anything else propagates.
            if not _is_oom_error(e):
                raise
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Evict the smallest tracked model; if none qualifies, drop them all.
            if not _comfy_unload_one_smallest_model():
                _comfy_unload_all_models()

    # All retries exhausted: fall back to CPU so the node still works.
    _move_global_to_cpu()
    return False
| |
|
| |
|
def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
    """
    Run remover.process() (rgba output), matching InspyrenetRembg3 behavior.

    Retries with escalating memory-recovery steps between OOM'd attempts:
      1) soft cache cleanup + evict the smallest ComfyUI-tracked model
      2) unload all ComfyUI models
      3) move the global remover to CPU
    Non-OOM errors from the first three attempts propagate to the caller;
    the final (post-CPU-move) attempt swallows everything and returns
    *fallback_rgba*, so callers never crash and never receive an
    "invisible" (fully transparent) image.
    """
    r = _get_global_remover()

    # Best-effort: get the weights onto ComfyUI's CUDA device before running.
    _load_global_to_comfy_cuda_no_crash()

    def _attempt() -> Image.Image:
        # Serialized + inference_mode, same as the non-global nodes.
        with _GLOBAL_RUN_LOCK:
            with torch.inference_mode():
                out = r.process(pil_rgb, type="rgba")
        # A fully transparent result would look like "nothing happened"
        # downstream — substitute the visible fallback instead.
        return fallback_rgba if _alpha_is_all_zero(out) else out

    def _recover_evict_smallest() -> None:
        # First escalation step: free caches, then evict one small model.
        _comfy_soft_empty_cache()
        _cuda_soft_cleanup()
        _comfy_unload_one_smallest_model()

    # Recovery steps tried in order between OOM'd attempts.
    recoveries = (_recover_evict_smallest, _comfy_unload_all_models, _move_global_to_cpu)

    for recover in recoveries:
        try:
            return _attempt()
        except BaseException as e:
            if not _is_oom_error(e):
                raise
        recover()

    # Last attempt (now on CPU): any failure here yields the visible fallback.
    try:
        return _attempt()
    except BaseException:
        return fallback_rgba
| |
|
| |
|
| | |
| | |
| | |
| |
|
class InspyrenetRembg2:
    """ComfyUI node: background removal via transparent-background; outputs RGBA image + alpha mask."""

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "torchscript_jit": (["default", "on"],)
            },
        }

    def remove_background(self, image, torchscript_jit):
        """Run the cached Remover over each frame; OOM is cleaned up and re-raised."""
        use_jit = (torchscript_jit != "default")
        remover, run_lock = _get_remover(jit=use_jit)

        results = []
        for frame in tqdm(image, "Inspyrenet Rembg2"):
            source = tensor2pil(frame)
            try:
                with run_lock, torch.inference_mode():
                    processed = remover.process(source, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
                raise

            results.append(pil2tensor(processed))
            # Drop per-frame PIL objects promptly to keep peak memory low.
            del source, processed

        batch = torch.cat(results, dim=0)
        # Channel 3 of the RGBA output doubles as the mask.
        return (batch, batch[:, :, :, 3])
| |
|
| |
|
class InspyrenetRembg3:
    """ComfyUI node: like InspyrenetRembg2 without the jit option; flattens RGBA inputs onto white first."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    def remove_background(self, image):
        """Flatten each frame onto white, run the shared Remover, return the RGBA batch."""
        remover, run_lock = _get_remover(jit=False)

        results = []
        for frame in tqdm(image, "Inspyrenet Rembg3"):
            source = tensor2pil(frame)
            flattened = _rgba_to_rgb_on_white(source)

            try:
                with run_lock, torch.inference_mode():
                    processed = remover.process(flattened, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg3: CUDA out of memory.") from e
                raise

            results.append(pil2tensor(processed))
            # Drop per-frame PIL objects promptly to keep peak memory low.
            del source, flattened, processed

        return (torch.cat(results, dim=0),)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Load_Inspyrenet_Global:
    """
    No inputs. Creates the global remover (once) and moves it to ComfyUI's CUDA device (if possible).
    Returns:
      - loaded_ok (BOOLEAN)
      - vram_delta_bytes (INT) best-effort (weights residency only; not peak inference)
    """

    RETURN_TYPES = ("BOOLEAN", "INT")
    FUNCTION = "load"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    def load(self):
        """Ensure the singleton exists, then try to place it on CUDA."""
        _get_global_remover()
        loaded = _load_global_to_comfy_cuda_no_crash()
        return (bool(loaded), int(_GLOBAL_VRAM_DELTA_BYTES))
| |
|
| |
|
class Remove_Inspyrenet_Global:
    """
    Offload global remover to CPU or delete it.
    """

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "remove"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "action": (["offload_to_cpu", "delete_instance"],),
            }
        }

    def remove(self, action):
        """Offload the singleton to CPU, or (delete_instance) drop it entirely."""
        global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES
        if action == "offload_to_cpu":
            _move_global_to_cpu()
            return (True,)

        # delete_instance: move weights off the GPU first, then drop the reference.
        with _GLOBAL_LOCK:
            try:
                instance = _GLOBAL_REMOVER
                if instance is not None:
                    try:
                        if hasattr(instance, "model"):
                            instance.model = instance.model.to("cpu")
                        instance.device = "cpu"
                    except Exception:
                        pass
                _GLOBAL_REMOVER = None
                _GLOBAL_ON_DEVICE = "cpu"
                _GLOBAL_VRAM_DELTA_BYTES = 0
            except Exception:
                pass

        _cuda_soft_cleanup()
        return (True,)
| |
|
| |
|
class Run_InspyrenetRembg_Global:
    """
    No settings. Same behavior as InspyrenetRembg3, but uses the global remover and won't crash on OOM.
    On failure/OOM, returns a visible passthrough (opaque RGBA), NOT an invisible image.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    def remove_background(self, image):
        """Run the global remover per frame with OOM-safe retries and a visible fallback."""
        _get_global_remover()  # make sure the singleton exists before the loop

        results = []
        for frame in tqdm(image, "Run InspyrenetRembg Global"):
            source = tensor2pil(frame)

            # Shown instead if processing fails or yields a fully transparent image.
            fallback = _force_rgba_opaque(source)

            flattened = _rgba_to_rgb_on_white(source)

            processed = _run_global_rgba_no_crash(flattened, fallback)
            results.append(pil2tensor(processed))

            # Drop per-frame PIL objects promptly to keep peak memory low.
            del source, fallback, flattened, processed

        return (torch.cat(results, dim=0),)
| |
|
| |
|
# ComfyUI registration: maps node identifiers to their implementing classes.
NODE_CLASS_MAPPINGS = {
    "InspyrenetRembg2": InspyrenetRembg2,
    "InspyrenetRembg3": InspyrenetRembg3,

    "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
    "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
    "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
}

# Human-readable names shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembg2": "Inspyrenet Rembg2",
    "InspyrenetRembg3": "Inspyrenet Rembg3",

    "Load_Inspyrenet_Global": "Load Inspyrenet Global",
    "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
    "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
}