# MyCustomNodes / Inspyrenet_Rembg2.py
# (upload residue: saliacoel, revision 71c89c1)
from __future__ import annotations
from PIL import Image
import os
import urllib.request
import gc
import threading
from typing import Dict, Tuple, Optional
import torch
import numpy as np
from transparent_background import Remover
from tqdm import tqdm
# Optional: ComfyUI memory manager (present inside ComfyUI)
try:
import comfy.model_management as comfy_mm
except Exception:
comfy_mm = None
# Local cache path where transparent-background looks for its base checkpoint.
CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
# Mirror URL used to download the checkpoint when the local copy is missing.
CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
def _ensure_ckpt_base():
    """Ensure the transparent-background base checkpoint exists at CKPT_PATH.

    Downloads CKPT_URL into a temporary sibling file and atomically renames
    it into place, so a partial download is never visible at CKPT_PATH.

    Raises:
        OSError / urllib.error.URLError: if the download fails, times out,
            or is truncated relative to the server's Content-Length.
    """
    # Fast path: a non-empty file on disk is assumed to be a valid checkpoint.
    try:
        if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
            return
    except Exception:
        pass
    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    tmp_path = CKPT_PATH + ".tmp"
    try:
        # Timeout guards against a hung connection stalling node load forever.
        with urllib.request.urlopen(CKPT_URL, timeout=60) as resp:
            header = resp.headers.get("Content-Length")
            expected = int(header) if header is not None else None
            # Progress bar only when the server reports a total size.
            pbar = (
                tqdm(
                    total=expected,
                    unit="B",
                    unit_scale=True,
                    desc="Downloading ckpt_base.pth",
                )
                if expected
                else None
            )
            try:
                with open(tmp_path, "wb") as f:
                    while True:
                        chunk = resp.read(1024 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                        if pbar is not None:
                            pbar.update(len(chunk))
            finally:
                if pbar is not None:
                    pbar.close()
        # Reject truncated downloads instead of installing a corrupt checkpoint.
        if expected is not None and os.path.getsize(tmp_path) != expected:
            raise OSError(
                f"ckpt_base.pth download truncated: got "
                f"{os.path.getsize(tmp_path)} of {expected} bytes"
            )
        # Atomic publish: readers see either the old state or the full file.
        os.replace(tmp_path, CKPT_PATH)
    finally:
        # Clean up the temp file on any failure path (it is gone on success).
        if os.path.isfile(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception:
                pass
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a float tensor in [0, 1] (optionally with a leading batch
    dimension of size 1) into an 8-bit PIL image."""
    data = image.detach().cpu().numpy()
    # Squeeze a singleton batch axis so Image.fromarray sees H x W x C.
    if data.ndim == 4 and data.shape[0] == 1:
        data = data[0]
    data = np.clip(data * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(data)
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image (or any array-like) into a float tensor in
    [0, 1] with a leading batch dimension of size 1."""
    scaled = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
    """Flatten any image to RGB; RGBA input is composited over white so
    transparent regions become white rather than black."""
    if pil_img.mode == "RGBA":
        white = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
        return Image.alpha_composite(white, pil_img).convert("RGB")
    return pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")
def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
    """Return an RGBA copy whose alpha channel is fully opaque (255), so
    the result can never be an "invisible" image."""
    rgba = pil_img.convert("RGBA")
    opaque = Image.new("L", rgba.size, 255)
    red, green, blue = rgba.split()[:3]
    return Image.merge("RGBA", (red, green, blue, opaque))
def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
"""
True if RGBA image alpha channel is entirely 0.
"""
if pil_img.mode != "RGBA":
return False
try:
extrema = pil_img.getextrema() # ((min,max),(min,max),(min,max),(min,max))
return extrema[3][1] == 0
except Exception:
return False
def _is_oom_error(e: BaseException) -> bool:
oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
if oom_cuda_cls is not None and isinstance(e, oom_cuda_cls):
return True
oom_torch_cls = getattr(torch, "OutOfMemoryError", None)
if oom_torch_cls is not None and isinstance(e, oom_torch_cls):
return True
msg = str(e).lower()
if "out of memory" in msg:
return True
if "allocation on device" in msg:
return True
return ("cuda" in msg or "cublas" in msg or "hip" in msg) and ("memory" in msg)
def _cuda_soft_cleanup() -> None:
try:
gc.collect()
except Exception:
pass
if torch.cuda.is_available():
try:
torch.cuda.synchronize()
except Exception:
pass
try:
torch.cuda.empty_cache()
except Exception:
pass
try:
torch.cuda.ipc_collect()
except Exception:
pass
def _comfy_soft_empty_cache() -> None:
    """Ask ComfyUI (when present) to release cached VRAM; never raises."""
    if comfy_mm is None:
        return
    soft_empty = getattr(comfy_mm, "soft_empty_cache", None)
    if soft_empty is None:
        return
    try:
        soft_empty(force=True)
    except TypeError:
        # Older ComfyUI versions take no arguments.
        try:
            soft_empty()
        except Exception:
            pass
    except Exception:
        pass
def _get_comfy_torch_device() -> torch.device:
    """Return ComfyUI's preferred device, falling back to cuda:0 or cpu."""
    if comfy_mm is not None and hasattr(comfy_mm, "get_torch_device"):
        try:
            dev = comfy_mm.get_torch_device()
            # Normalize whatever ComfyUI hands back into a torch.device.
            return dev if isinstance(dev, torch.device) else torch.device(str(dev))
        except Exception:
            pass
    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _set_current_cuda_device(dev: torch.device) -> None:
"""
Make sure mem_get_info() measurements are on the same device ComfyUI uses.
"""
if dev.type == "cuda":
try:
if dev.index is not None:
torch.cuda.set_device(dev.index)
except Exception:
pass
def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]:
if dev.type != "cuda" or not torch.cuda.is_available():
return None
try:
_set_current_cuda_device(dev)
free_b, _total_b = torch.cuda.mem_get_info()
return int(free_b)
except Exception:
return None
def _comfy_unload_one_smallest_model() -> bool:
    """
    Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.
    If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).
    """
    # Outside ComfyUI, or on a version without the tracking list, nothing to do.
    if comfy_mm is None:
        return False
    if not hasattr(comfy_mm, "current_loaded_models"):
        return False
    try:
        cur_dev = _get_comfy_torch_device()
    except Exception:
        cur_dev = None
    models = []
    try:
        # Copy the list so eviction below cannot disturb iteration.
        for lm in list(comfy_mm.current_loaded_models):
            try:
                # Prefer same device
                lm_dev = getattr(lm, "device", None)
                if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
                    continue
                # Measure resident memory via whichever API this ComfyUI
                # version exposes: a method, or a plain attribute.
                mem_fn = getattr(lm, "model_loaded_memory", None)
                if callable(mem_fn):
                    mem = int(mem_fn())
                else:
                    mem = int(getattr(lm, "loaded_memory", 0) or 0)
                if mem > 0:
                    models.append((mem, lm))
            except Exception:
                # Skip any model we cannot inspect; keep scanning the rest.
                continue
    except Exception:
        return False
    if not models:
        return False
    models.sort(key=lambda x: x[0])  # smallest first
    _mem, lm = models[0]
    try:
        # Newer ComfyUI accepts unpatch_weights; older versions take no args.
        unload_fn = getattr(lm, "model_unload", None)
        if callable(unload_fn):
            try:
                unload_fn(unpatch_weights=True)
            except TypeError:
                unload_fn()
    except Exception:
        pass
    # Cleanup hook if present
    try:
        cleanup = getattr(comfy_mm, "cleanup_models", None)
        if callable(cleanup):
            cleanup()
    except Exception:
        pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
    return True
def _comfy_unload_all_models() -> None:
    """Evict every ComfyUI-tracked model, then flush caches; never raises.

    No-op when running outside ComfyUI.
    """
    if comfy_mm is None:
        return
    if hasattr(comfy_mm, "unload_all_models"):
        try:
            comfy_mm.unload_all_models()
        except Exception:
            pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
# -----------------------------------------------------------------------------
# Existing singleton cache for Rembg2/Rembg3 (your original)
# -----------------------------------------------------------------------------
# Singleton Remover instances, keyed by the (jit,) flag tuple.
_REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
# Per-instance locks serializing process() calls on each cached Remover.
_REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
# Guards creation and lookup in the two dicts above.
_CACHE_LOCK = threading.Lock()
def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
    """Return the cached Remover for *jit* (creating it on first use) plus
    the lock that serializes inference calls on it."""
    key = (jit,)
    with _CACHE_LOCK:
        remover = _REMOVER_CACHE.get(key)
        if remover is None:
            _ensure_ckpt_base()
            try:
                remover = Remover(jit=jit) if jit else Remover()
            except BaseException as err:
                # Free what we can on OOM, then let the caller see the error.
                if _is_oom_error(err):
                    _cuda_soft_cleanup()
                raise
            _REMOVER_CACHE[key] = remover
        lock = _REMOVER_RUN_LOCKS.setdefault(key, threading.Lock())
        return remover, lock
# -----------------------------------------------------------------------------
# GLOBAL remover (for Load/Remove/Run Global nodes)
# -----------------------------------------------------------------------------
# Serializes creation/teardown of the shared global remover.
_GLOBAL_LOCK = threading.Lock()
# Serializes inference calls on the shared global remover.
_GLOBAL_RUN_LOCK = threading.Lock()
# Lazily-created shared Remover instance (see _get_global_remover).
_GLOBAL_REMOVER: Optional[Remover] = None
# Device string the global remover currently lives on ("cpu" or "cuda:N").
_GLOBAL_ON_DEVICE: str = "cpu"
# Best-effort VRAM footprint (bytes) recorded when loading onto CUDA.
_GLOBAL_VRAM_DELTA_BYTES: int = 0
def _create_global_remover_cpu() -> Remover:
    """Build the global Remover (jit=False, like InspyrenetRembg3), trying
    to keep initialization on the CPU so creation itself cannot OOM VRAM."""
    _ensure_ckpt_base()
    try:
        # Newer library versions accept an explicit device argument.
        remover = Remover(device="cpu")  # type: ignore[arg-type]
    except TypeError:
        # Older versions: build with defaults, then push weights to CPU.
        remover = Remover()
        try:
            if hasattr(remover, "model"):
                remover.model = remover.model.to("cpu")
            remover.device = "cpu"
        except Exception:
            pass
        _cuda_soft_cleanup()
        return remover
    try:
        remover.device = "cpu"
    except Exception:
        pass
    return remover
def _get_global_remover() -> Remover:
    """Return the shared Remover, creating it (on CPU) on first use."""
    global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        if _GLOBAL_REMOVER is None:
            _GLOBAL_REMOVER = _create_global_remover_cpu()
            # Record where the freshly built instance actually landed.
            _GLOBAL_ON_DEVICE = str(getattr(_GLOBAL_REMOVER, "device", "cpu"))
        return _GLOBAL_REMOVER
def _move_global_to_cpu() -> None:
    """Offload the global remover's weights to CPU and flush CUDA caches."""
    global _GLOBAL_ON_DEVICE
    remover = _get_global_remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
        _GLOBAL_ON_DEVICE = "cpu"
    except Exception:
        pass
    _cuda_soft_cleanup()
def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
    """
    Load the global remover into VRAM on ComfyUI's chosen CUDA device.
    Never crashes on OOM: evicts smallest model first, then unload_all as last resort.
    Also records a best-effort VRAM delta.

    Returns True when the remover ends up on CUDA, False when it had to
    stay (or be moved back) on CPU.
    """
    global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES
    r = _get_global_remover()
    dev = _get_comfy_torch_device()
    # Without a CUDA target there is nothing to load; park the model on CPU.
    if dev.type != "cuda" or not torch.cuda.is_available():
        _move_global_to_cpu()
        return False
    # Already on CUDA?
    cur_dev = str(getattr(r, "device", "") or "")
    if cur_dev.startswith("cuda"):
        _GLOBAL_ON_DEVICE = cur_dev
        return True
    _set_current_cuda_device(dev)
    # Snapshot free VRAM once, so the post-load difference approximates the
    # weights' footprint.
    free_before = _cuda_free_bytes_on(dev)
    # Initial attempt plus up to max_evictions OOM-triggered retries.
    for _ in range(max_evictions + 1):
        try:
            # Move model to the SAME device ComfyUI uses
            if hasattr(r, "model"):
                r.model = r.model.to(dev)
            r.device = str(dev)
            _GLOBAL_ON_DEVICE = str(dev)
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()
            free_after = _cuda_free_bytes_on(dev)
            if free_before is not None and free_after is not None:
                delta = max(0, int(free_before) - int(free_after))
                if delta > 0:
                    _GLOBAL_VRAM_DELTA_BYTES = delta
            return True
        except BaseException as e:
            # Anything other than OOM is a real error and must propagate.
            if not _is_oom_error(e):
                raise
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()
            # Evict ONE smallest model; if that fails, unload all.
            if not _comfy_unload_one_smallest_model():
                _comfy_unload_all_models()
    # Could not load
    _move_global_to_cpu()
    return False
def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
    """Run the global remover's process(type="rgba") without ever crashing.

    Matches InspyrenetRembg3 behavior, with escalating OOM recovery between
    attempts:
      1. soft cache flush + evict ComfyUI's smallest loaded model,
      2. unload all ComfyUI-tracked models,
      3. offload the remover itself to CPU.
    Non-OOM exceptions from the first three attempts propagate unchanged.
    A fully transparent result — or any failure of the final CPU attempt —
    yields *fallback_rgba*, so the caller never gets an "invisible" image.

    Args:
        pil_rgb: RGB input image to matte.
        fallback_rgba: visible image returned on failure or empty alpha.
    Returns:
        RGBA PIL image: the matte, or the fallback.
    """
    remover = _get_global_remover()
    # Prefer ComfyUI's CUDA device when possible; never crash trying.
    _load_global_to_comfy_cuda_no_crash()

    def _attempt() -> Image.Image:
        # Serialize inference on the shared remover. An all-zero alpha is
        # treated as failure so we never return an invisible image.
        with _GLOBAL_RUN_LOCK:
            with torch.inference_mode():
                out = remover.process(pil_rgb, type="rgba")
            return fallback_rgba if _alpha_is_all_zero(out) else out

    def _recover_evict_one() -> None:
        # OOM step 1: flush caches, then evict the smallest tracked model.
        _comfy_soft_empty_cache()
        _cuda_soft_cleanup()
        _comfy_unload_one_smallest_model()

    for recover in (_recover_evict_one, _comfy_unload_all_models, _move_global_to_cpu):
        try:
            return _attempt()
        except BaseException as e:
            if not _is_oom_error(e):
                raise
            recover()

    # Final attempt (now on CPU); on any failure, return the passthrough.
    try:
        return _attempt()
    except BaseException:
        return fallback_rgba
# -----------------------------------------------------------------------------
# Nodes
# -----------------------------------------------------------------------------
class InspyrenetRembg2:
    """ComfyUI node: Inspyrenet background removal returning an RGBA IMAGE
    batch plus the alpha channel as a MASK."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "torchscript_jit": (["default", "on"],),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image, torchscript_jit):
        """Matte each frame of the batch with the cached Remover; the mask
        is the alpha channel of the RGBA output."""
        use_jit = torchscript_jit != "default"
        remover, run_lock = _get_remover(jit=use_jit)
        outputs = []
        for frame in tqdm(image, "Inspyrenet Rembg2"):
            source = tensor2pil(frame)
            try:
                with run_lock:
                    with torch.inference_mode():
                        matted = remover.process(source, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
                raise
            outputs.append(pil2tensor(matted))
            # Drop per-frame PIL objects promptly to keep peak memory low.
            del source, matted
        batch = torch.cat(outputs, dim=0)
        return (batch, batch[:, :, :, 3])
class InspyrenetRembg3:
    """ComfyUI node: background removal that first flattens input onto a
    white background and returns only the RGBA IMAGE batch."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        """Matte each frame with the cached (jit=False) Remover."""
        remover, run_lock = _get_remover(jit=False)
        results = []
        for frame in tqdm(image, "Inspyrenet Rembg3"):
            source = tensor2pil(frame)
            flattened = _rgba_to_rgb_on_white(source)
            try:
                with run_lock:
                    with torch.inference_mode():
                        matted = remover.process(flattened, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg3: CUDA out of memory.") from e
                raise
            results.append(pil2tensor(matted))
            # Drop per-frame PIL objects promptly to keep peak memory low.
            del source, flattened, matted
        return (torch.cat(results, dim=0),)
# -----------------------------------------------------------------------------
# NEW: Global nodes (simple, no user settings on Load/Run)
# -----------------------------------------------------------------------------
class Load_Inspyrenet_Global:
    """ComfyUI node with no inputs: ensure the global remover exists and
    try to place it on ComfyUI's CUDA device.

    Outputs:
        loaded_ok (BOOLEAN): True when the model ended up on CUDA.
        vram_delta_bytes (INT): best-effort weight-residency VRAM delta
            (not peak inference memory).
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("BOOLEAN", "INT")
    FUNCTION = "load"
    CATEGORY = "image"

    def load(self):
        _get_global_remover()
        on_cuda = _load_global_to_comfy_cuda_no_crash()
        return (bool(on_cuda), int(_GLOBAL_VRAM_DELTA_BYTES))
class Remove_Inspyrenet_Global:
    """ComfyUI node: offload the global remover to CPU, or drop it entirely."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "action": (["offload_to_cpu", "delete_instance"],),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "remove"
    CATEGORY = "image"

    def remove(self, action):
        if action == "offload_to_cpu":
            _move_global_to_cpu()
        else:
            self._delete_instance()
        return (True,)

    @staticmethod
    def _delete_instance():
        # Move weights off-GPU first (best effort), then drop all references
        # and reset the bookkeeping globals.
        global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES
        with _GLOBAL_LOCK:
            try:
                if _GLOBAL_REMOVER is not None:
                    try:
                        if hasattr(_GLOBAL_REMOVER, "model"):
                            _GLOBAL_REMOVER.model = _GLOBAL_REMOVER.model.to("cpu")
                        _GLOBAL_REMOVER.device = "cpu"
                    except Exception:
                        pass
                _GLOBAL_REMOVER = None
                _GLOBAL_ON_DEVICE = "cpu"
                _GLOBAL_VRAM_DELTA_BYTES = 0
            except Exception:
                pass
        _cuda_soft_cleanup()
class Run_InspyrenetRembg_Global:
    """ComfyUI node without settings: same pipeline as InspyrenetRembg3,
    but via the shared global remover and OOM-proof — on failure it emits
    a visible opaque-RGBA passthrough instead of crashing or producing an
    invisible image."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        _get_global_remover()
        results = []
        for frame in tqdm(image, "Run InspyrenetRembg Global"):
            source = tensor2pil(frame)
            # Opaque copy used if matting fails — never an invisible frame.
            safe = _force_rgba_opaque(source)
            # Same preprocessing as InspyrenetRembg3.
            flattened = _rgba_to_rgb_on_white(source)
            matted = _run_global_rgba_no_crash(flattened, safe)
            results.append(pil2tensor(matted))
            del source, safe, flattened, matted
        return (torch.cat(results, dim=0),)
# Registration tables read by ComfyUI at import time:
# internal node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "InspyrenetRembg2": InspyrenetRembg2,
    "InspyrenetRembg3": InspyrenetRembg3,
    "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
    "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
    "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
}
# internal node identifier -> label shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembg2": "Inspyrenet Rembg2",
    "InspyrenetRembg3": "Inspyrenet Rembg3",
    "Load_Inspyrenet_Global": "Load Inspyrenet Global",
    "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
    "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
}