trt_playground / AILab_RMBG.py
saliacoel's picture
Upload AILab_RMBG.py
d86b613 verified
# ComfyUI-RMBG
# This custom node for ComfyUI provides functionality for background removal using various models,
# including RMBG-2.0, INSPYRENET, BEN, BEN2 and BIREFNET-HR. It leverages deep learning techniques
# to process images and generate masks for background removal.
#
# Models License Notice:
# - RMBG-2.0: Apache-2.0 License (https://huggingface.co/briaai/RMBG-2.0)
# - INSPYRENET: MIT License (https://github.com/plemeri/InSPyReNet)
# - BEN: Apache-2.0 License (https://huggingface.co/PramaLLC/BEN)
# - BEN2: Apache-2.0 License (https://huggingface.co/PramaLLC/BEN2)
#
# This integration script follows GPL-3.0 License.
# When using or modifying this code, please respect both the original model licenses
# and this integration's license terms.
#
# Source: https://github.com/1038lab/ComfyUI-RMBG
import os
import sys
import shutil
import types
import threading
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image, ImageFilter
from torchvision import transforms
from huggingface_hub import hf_hub_download
from transformers import AutoModelForImageSegmentation
import importlib.util
import folder_paths
device = "cuda" if torch.cuda.is_available() else "cpu"
folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))
# Model configuration
AVAILABLE_MODELS = {
"RMBG-2.0": {
"type": "rmbg",
"repo_id": "1038lab/RMBG-2.0",
"files": {
"config.json": "config.json",
"model.safetensors": "model.safetensors",
"birefnet.py": "birefnet.py",
"BiRefNet_config.py": "BiRefNet_config.py"
},
"cache_dir": "RMBG-2.0"
},
"INSPYRENET": {
"type": "inspyrenet",
"repo_id": "1038lab/inspyrenet",
"files": {
"inspyrenet.safetensors": "inspyrenet.safetensors"
},
"cache_dir": "INSPYRENET"
},
"BEN": {
"type": "ben",
"repo_id": "1038lab/BEN",
"files": {
"model.py": "model.py",
"BEN_Base.pth": "BEN_Base.pth"
},
"cache_dir": "BEN"
},
"BEN2": {
"type": "ben2",
"repo_id": "1038lab/BEN2",
"files": {
"BEN2_Base.pth": "BEN2_Base.pth",
"BEN2.py": "BEN2.py"
},
"cache_dir": "BEN2"
}
}
# Utility functions
def tensor2pil(image: torch.Tensor) -> Image.Image:
return Image.fromarray(np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def pil2tensor(image: Image.Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
def handle_model_error(message: str):
print(f"[RMBG ERROR] {message}")
raise RuntimeError(message)
class BaseModelLoader:
"""
Base loader with:
- disk cache handling
- single-instance model lifetime management
- thread-safety (so multiple RMBG nodes can share one loader safely)
"""
def __init__(self):
self.model = None
self.current_model_version = None
self.base_cache_dir = os.path.join(folder_paths.models_dir, "RMBG")
# Protects load/clear/inference for shared singleton loaders
self._lock = threading.RLock()
# Protects downloads (avoid parallel hf_hub_download collisions)
self._download_lock = threading.Lock()
def get_cache_dir(self, model_name: str) -> str:
cache_path = os.path.join(self.base_cache_dir, AVAILABLE_MODELS[model_name]["cache_dir"])
os.makedirs(cache_path, exist_ok=True)
return cache_path
def check_model_cache(self, model_name: str):
model_info = AVAILABLE_MODELS[model_name]
cache_dir = self.get_cache_dir(model_name)
if not os.path.exists(cache_dir):
return False, "Model directory not found"
missing_files = []
for filename in model_info["files"].keys():
if not os.path.exists(os.path.join(cache_dir, model_info["files"][filename])):
missing_files.append(filename)
if missing_files:
return False, f"Missing model files: {', '.join(missing_files)}"
return True, "Model cache verified"
def download_model(self, model_name: str):
model_info = AVAILABLE_MODELS[model_name]
cache_dir = self.get_cache_dir(model_name)
# Only one download at a time per loader instance
with self._download_lock:
try:
os.makedirs(cache_dir, exist_ok=True)
print(f"Downloading {model_name} model files...")
for filename in model_info["files"].keys():
print(f"Downloading {filename}...")
hf_hub_download(
repo_id=model_info["repo_id"],
filename=filename,
local_dir=cache_dir
)
return True, "Model files downloaded successfully"
except Exception as e:
return False, f"Error downloading model files: {str(e)}"
def clear_model(self):
with self._lock:
if self.model is not None:
try:
self.model.cpu()
except Exception:
pass
try:
del self.model
except Exception:
pass
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.model = None
self.current_model_version = None
class RMBGModel(BaseModelLoader):
def __init__(self):
super().__init__()
def load_model(self, model_name: str):
with self._lock:
if self.current_model_version == model_name and self.model is not None:
return
self.clear_model()
cache_dir = self.get_cache_dir(model_name)
try:
# Primary path: Modern transformers compatibility mode (optimized for newer versions)
try:
from transformers import PreTrainedModel
import json
config_path = os.path.join(cache_dir, "config.json")
with open(config_path, "r") as f:
_config = json.load(f)
birefnet_path = os.path.join(cache_dir, "birefnet.py")
BiRefNetConfig_path = os.path.join(cache_dir, "BiRefNet_config.py")
# Load the BiRefNetConfig
config_spec = importlib.util.spec_from_file_location("BiRefNetConfig", BiRefNetConfig_path)
config_module = importlib.util.module_from_spec(config_spec)
sys.modules["BiRefNetConfig"] = config_module
config_spec.loader.exec_module(config_module)
# Fix and load birefnet module
with open(birefnet_path, "r") as f:
birefnet_content = f.read()
birefnet_content = birefnet_content.replace(
"from .BiRefNet_config import BiRefNetConfig",
"from BiRefNetConfig import BiRefNetConfig"
)
module_name = f"custom_birefnet_model_{hash(birefnet_path)}"
module = types.ModuleType(module_name)
sys.modules[module_name] = module
exec(birefnet_content, module.__dict__)
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, PreTrainedModel) and attr != PreTrainedModel:
BiRefNetConfig = getattr(config_module, "BiRefNetConfig")
model_config = BiRefNetConfig()
self.model = attr(model_config)
weights_path = os.path.join(cache_dir, "model.safetensors")
try:
try:
import safetensors.torch
self.model.load_state_dict(safetensors.torch.load_file(weights_path))
except ImportError:
from transformers.modeling_utils import load_state_dict
state_dict = load_state_dict(weights_path)
self.model.load_state_dict(state_dict)
except Exception as load_error:
pytorch_weights = os.path.join(cache_dir, "pytorch_model.bin")
if os.path.exists(pytorch_weights):
self.model.load_state_dict(torch.load(pytorch_weights, map_location="cpu"))
else:
raise RuntimeError(f"Failed to load weights: {str(load_error)}")
break
if self.model is None:
raise RuntimeError("Could not find suitable model class")
except Exception as modern_e:
print("[RMBG INFO] Using standard transformers loading (fallback mode)...")
try:
self.model = AutoModelForImageSegmentation.from_pretrained(
cache_dir,
trust_remote_code=True,
local_files_only=True
)
except Exception as standard_e:
handle_model_error(
"Failed to load model with both modern and standard methods. "
f"Modern error: {str(modern_e)}. Standard error: {str(standard_e)}"
)
except Exception as e:
handle_model_error(f"Error loading model: {str(e)}")
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
torch.set_float32_matmul_precision("high")
self.model.to(device)
self.current_model_version = model_name
def process_image(self, images, model_name: str, params: dict):
# Lock the full path so multiple RMBG nodes don't run inference at the same time
# on the shared singleton model (avoids race/memory spikes).
with self._lock:
try:
self.load_model(model_name)
transform_image = transforms.Compose([
transforms.Resize((params["process_res"], params["process_res"])),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
if isinstance(images, torch.Tensor):
if len(images.shape) == 3:
images = [images]
else:
images = [img for img in images]
original_sizes = [tensor2pil(img).size for img in images]
input_tensors = [transform_image(tensor2pil(img)).unsqueeze(0) for img in images]
input_batch = torch.cat(input_tensors, dim=0).to(device)
with torch.no_grad():
outputs = self.model(input_batch)
if isinstance(outputs, list) and len(outputs) > 0:
results = outputs[-1].sigmoid().cpu()
elif isinstance(outputs, dict) and "logits" in outputs:
results = outputs["logits"].sigmoid().cpu()
elif isinstance(outputs, torch.Tensor):
results = outputs.sigmoid().cpu()
else:
try:
if hasattr(outputs, "last_hidden_state"):
results = outputs.last_hidden_state.sigmoid().cpu()
else:
results = None
for _k, v in outputs.items():
if isinstance(v, torch.Tensor):
results = v.sigmoid().cpu()
break
if results is None:
handle_model_error("Unable to recognize model output format")
except Exception:
handle_model_error("Unable to recognize model output format")
masks = []
for result, (orig_w, orig_h) in zip(results, original_sizes):
result = result.squeeze()
result = result * (1 + (1 - params["sensitivity"]))
result = torch.clamp(result, 0, 1)
result = F.interpolate(
result.unsqueeze(0).unsqueeze(0),
size=(orig_h, orig_w),
mode="bilinear"
).squeeze()
masks.append(tensor2pil(result))
# Hint to GC / free references (PyTorch will still keep a caching allocator)
del input_batch
return masks
except Exception as e:
handle_model_error(f"Error in batch processing: {str(e)}")
class InspyrenetModel(BaseModelLoader):
def __init__(self):
super().__init__()
def load_model(self, model_name: str):
with self._lock:
if self.current_model_version == model_name and self.model is not None:
return
self.clear_model()
try:
import transparent_background
self.model = transparent_background.Remover()
self.current_model_version = model_name
except Exception as e:
handle_model_error(f"Failed to initialize transparent_background: {str(e)}")
def process_image(self, image, model_name: str, params: dict):
with self._lock:
try:
self.load_model(model_name)
orig_image = tensor2pil(image)
w, h = orig_image.size
aspect_ratio = h / w
new_w = params["process_res"]
new_h = int(params["process_res"] * aspect_ratio)
resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)
foreground = self.model.process(resized_image, type="rgba")
foreground = foreground.resize((w, h), Image.LANCZOS)
mask = foreground.split()[-1]
return mask
except Exception as e:
handle_model_error(f"Error in Inspyrenet processing: {str(e)}")
class BENModel(BaseModelLoader):
def __init__(self):
super().__init__()
def load_model(self, model_name: str):
with self._lock:
if self.current_model_version == model_name and self.model is not None:
return
self.clear_model()
cache_dir = self.get_cache_dir(model_name)
model_path = os.path.join(cache_dir, "model.py")
module_name = f"custom_ben_model_{hash(model_path)}"
spec = importlib.util.spec_from_file_location(module_name, model_path)
ben_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = ben_module
spec.loader.exec_module(ben_module)
model_weights_path = os.path.join(cache_dir, "BEN_Base.pth")
self.model = ben_module.BEN_Base()
self.model.loadcheckpoints(model_weights_path)
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
torch.set_float32_matmul_precision("high")
self.model.to(device)
self.current_model_version = model_name
def process_image(self, image, model_name: str, params: dict):
with self._lock:
try:
self.load_model(model_name)
orig_image = tensor2pil(image)
w, h = orig_image.size
aspect_ratio = h / w
new_w = params["process_res"]
new_h = int(params["process_res"] * aspect_ratio)
resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)
processed_input = resized_image.convert("RGBA")
with torch.no_grad():
_, foreground = self.model.inference(processed_input)
foreground = foreground.resize((w, h), Image.LANCZOS)
mask = foreground.split()[-1]
return mask
except Exception as e:
handle_model_error(f"Error in BEN processing: {str(e)}")
class BEN2Model(BaseModelLoader):
def __init__(self):
super().__init__()
def load_model(self, model_name: str):
with self._lock:
if self.current_model_version == model_name and self.model is not None:
return
self.clear_model()
try:
cache_dir = self.get_cache_dir(model_name)
model_path = os.path.join(cache_dir, "BEN2.py")
module_name = f"custom_ben2_model_{hash(model_path)}"
spec = importlib.util.spec_from_file_location(module_name, model_path)
ben2_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = ben2_module
spec.loader.exec_module(ben2_module)
model_weights_path = os.path.join(cache_dir, "BEN2_Base.pth")
self.model = ben2_module.BEN_Base()
self.model.loadcheckpoints(model_weights_path)
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
torch.set_float32_matmul_precision("high")
self.model.to(device)
self.current_model_version = model_name
except Exception as e:
handle_model_error(f"Error loading BEN2 model: {str(e)}")
def process_image(self, images, model_name: str, params: dict):
with self._lock:
try:
self.load_model(model_name)
if isinstance(images, torch.Tensor):
if len(images.shape) == 3:
images = [images]
else:
images = [img for img in images]
batch_size = 3
all_masks = []
for i in range(0, len(images), batch_size):
batch_images = images[i:i + batch_size]
batch_pil_images = []
original_sizes = []
for img in batch_images:
orig_image = tensor2pil(img)
w, h = orig_image.size
original_sizes.append((w, h))
aspect_ratio = h / w
new_w = params["process_res"]
new_h = int(params["process_res"] * aspect_ratio)
resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)
processed_input = resized_image.convert("RGBA")
batch_pil_images.append(processed_input)
with torch.no_grad():
try:
foregrounds = self.model.inference(batch_pil_images)
if not isinstance(foregrounds, list):
foregrounds = [foregrounds]
except Exception as e:
handle_model_error(f"Error in BEN2 inference: {str(e)}")
for foreground, (orig_w, orig_h) in zip(foregrounds, original_sizes):
foreground = foreground.resize((orig_w, orig_h), Image.LANCZOS)
mask = foreground.split()[-1]
all_masks.append(mask)
if len(all_masks) == 1:
return all_masks[0]
return all_masks
except Exception as e:
handle_model_error(f"Error in BEN2 processing: {str(e)}")
def refine_foreground(image_bchw: torch.Tensor, masks_b1hw: torch.Tensor) -> torch.Tensor:
b, c, h, w = image_bchw.shape
if b != masks_b1hw.shape[0]:
raise ValueError("images and masks must have the same batch size")
image_np = image_bchw.cpu().numpy()
mask_np = masks_b1hw.cpu().numpy()
refined_fg = []
for i in range(b):
mask = mask_np[i, 0]
thresh = 0.45
mask_binary = (mask > thresh).astype(np.float32)
edge_blur = cv2.GaussianBlur(mask_binary, (3, 3), 0)
transition_mask = np.logical_and(mask > 0.05, mask < 0.95)
alpha = 0.85
mask_refined = np.where(
transition_mask,
alpha * mask + (1 - alpha) * edge_blur,
mask_binary
)
edge_region = np.logical_and(mask > 0.2, mask < 0.8)
mask_refined = np.where(edge_region, mask_refined * 0.98, mask_refined)
result = []
for cc in range(image_np.shape[1]):
channel = image_np[i, cc]
refined = channel * mask_refined
result.append(refined)
refined_fg.append(np.stack(result))
return torch.from_numpy(np.stack(refined_fg))
# -----------------------------------------------------------------------------
# Process-wide SINGLETON model loaders.
# This is the key change that prevents RMBG-2.0 from being loaded into VRAM
# multiple times when you have multiple RMBG nodes in the same workflow.
# -----------------------------------------------------------------------------
_SHARED_MODEL_LOADERS = {
"RMBG-2.0": RMBGModel(),
"INSPYRENET": InspyrenetModel(),
"BEN": BENModel(),
"BEN2": BEN2Model()
}
class RMBG:
def __init__(self):
# IMPORTANT: shared loaders across all RMBG node instances
self.models = _SHARED_MODEL_LOADERS
@classmethod
def INPUT_TYPES(s):
tooltips = {
"image": "Input image to be processed for background removal.",
"model": "Select the background removal model to use.",
"sensitivity": "Adjust mask detection strength (1.0 = default).",
"process_res": "Processing resolution (higher = more VRAM).",
"mask_blur": "Blur mask edges (0 = none).",
"mask_offset": "Expand/shrink mask boundary (+ expand / - shrink).",
"background": "Alpha = transparent output. Color = composite onto solid background.",
"background_color": "Only used when background='Color'. Ignored when background='Alpha' (transparent).",
"invert_output": "Invert mask output.",
"refine_foreground": "Fast Foreground Colour Estimation for better transparency."
}
return {
"required": {
"image": ("IMAGE", {"tooltip": tooltips["image"]}),
# Explicit default to RMBG-2.0
"model": (list(AVAILABLE_MODELS.keys()), {"default": "RMBG-2.0", "tooltip": tooltips["model"]}),
},
"optional": {
"sensitivity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": tooltips["sensitivity"]}),
"process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 8, "tooltip": tooltips["process_res"]}),
"mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
"mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
"invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
"refine_foreground": ("BOOLEAN", {"default": False, "tooltip": tooltips["refine_foreground"]}),
"background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
# NOTE: ignored if background == "Alpha"
"background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
}
}
RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
FUNCTION = "process_image"
CATEGORY = "🧪AILab/🧽RMBG"
def process_image(self, image, model, **params):
try:
processed_images = []
processed_masks = []
model_instance = self.models[model]
cache_status, message = model_instance.check_model_cache(model)
if not cache_status:
print(f"Cache check: {message}")
print("Downloading required model files...")
download_status, download_message = model_instance.download_model(model)
if not download_status:
handle_model_error(download_message)
print("Model files downloaded successfully")
model_type = AVAILABLE_MODELS[model]["type"]
def _process_pair(img, mask):
if isinstance(mask, list):
masks = [m.convert("L") for m in mask if isinstance(m, Image.Image)]
mask_local = masks[0] if masks else None
elif isinstance(mask, Image.Image):
mask_local = mask.convert("L")
else:
mask_local = mask
mask_tensor_local = pil2tensor(mask_local)
mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"]))
mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1)
mask_img_local = tensor2pil(mask_tensor_local)
if params["mask_blur"] > 0:
mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
if params["mask_offset"] != 0:
if params["mask_offset"] > 0:
for _ in range(params["mask_offset"]):
mask_img_local = mask_img_local.filter(ImageFilter.MaxFilter(3))
else:
for _ in range(-params["mask_offset"]):
mask_img_local = mask_img_local.filter(ImageFilter.MinFilter(3))
if params["invert_output"]:
mask_img_local = Image.fromarray(255 - np.array(mask_img_local))
img_tensor_local = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
mask_tensor_b1hw = torch.from_numpy(np.array(mask_img_local)).unsqueeze(0).unsqueeze(0) / 255.0
orig_image_local = tensor2pil(img)
if params.get("refine_foreground", False):
refined_fg_local = refine_foreground(img_tensor_local, mask_tensor_b1hw)
refined_fg_local = tensor2pil(refined_fg_local[0].permute(1, 2, 0))
r, g, b = refined_fg_local.split()
foreground_local = Image.merge("RGBA", (r, g, b, mask_img_local))
else:
orig_rgba_local = orig_image_local.convert("RGBA")
r, g, b, _a = orig_rgba_local.split()
foreground_local = Image.merge("RGBA", (r, g, b, mask_img_local))
# If background == "Alpha", the output is RGBA with transparent background.
# background_color is ignored in that mode.
if params["background"] == "Color":
def hex_to_rgba(hex_color: str):
hex_color = hex_color.lstrip("#")
if len(hex_color) == 6:
rr = int(hex_color[0:2], 16)
gg = int(hex_color[2:4], 16)
bb = int(hex_color[4:6], 16)
aa = 255
elif len(hex_color) == 8:
rr = int(hex_color[0:2], 16)
gg = int(hex_color[2:4], 16)
bb = int(hex_color[4:6], 16)
aa = int(hex_color[6:8], 16)
else:
raise ValueError("Invalid color format")
return (rr, gg, bb, aa)
background_color = params.get("background_color", "#222222")
rgba = hex_to_rgba(background_color)
bg_image = Image.new("RGBA", orig_image_local.size, rgba)
composite_image = Image.alpha_composite(bg_image, foreground_local)
processed_images.append(pil2tensor(composite_image.convert("RGB")))
else:
processed_images.append(pil2tensor(foreground_local))
processed_masks.append(pil2tensor(mask_img_local))
if model_type in ("rmbg", "ben2"):
images_list = [img for img in image]
chunk_size = 4
for start in range(0, len(images_list), chunk_size):
batch_imgs = images_list[start:start + chunk_size]
masks = model_instance.process_image(batch_imgs, model, params)
if isinstance(masks, Image.Image):
masks = [masks]
for img_item, mask_item in zip(batch_imgs, masks):
_process_pair(img_item, mask_item)
else:
for img in image:
mask = model_instance.process_image(img, model, params)
_process_pair(img, mask)
mask_images = []
for mask_tensor in processed_masks:
mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])) \
.movedim(1, -1) \
.expand(-1, -1, -1, 3)
mask_images.append(mask_image)
mask_image_output = torch.cat(mask_images, dim=0)
return (
torch.cat(processed_images, dim=0),
torch.cat(processed_masks, dim=0),
mask_image_output
)
except Exception as e:
handle_model_error(f"Error in image processing: {str(e)}")
empty_mask = torch.zeros((image.shape[0], image.shape[2], image.shape[3]))
empty_mask_image = empty_mask.reshape((-1, 1, empty_mask.shape[-2], empty_mask.shape[-1])) \
.movedim(1, -1) \
.expand(-1, -1, -1, 3)
return (image, empty_mask, empty_mask_image)
NODE_CLASS_MAPPINGS = {
"RMBG": RMBG
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RMBG": "Remove Background (RMBG)"
}