import os
import sys
import shutil
import types
import threading

import torch
import torch.nn.functional as F
import numpy as np
import cv2

from PIL import Image, ImageFilter
from torchvision import transforms
from huggingface_hub import hf_hub_download
from transformers import AutoModelForImageSegmentation

import importlib.util
import folder_paths

device = "cuda" if torch.cuda.is_available() else "cpu"

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

# Model registry: Hugging Face repo id, required files, and local cache
# folder for each supported background-removal model.
AVAILABLE_MODELS = {
    "RMBG-2.0": {
        "type": "rmbg",
        "repo_id": "1038lab/RMBG-2.0",
        "files": {
            "config.json": "config.json",
            "model.safetensors": "model.safetensors",
            "birefnet.py": "birefnet.py",
            "BiRefNet_config.py": "BiRefNet_config.py"
        },
        "cache_dir": "RMBG-2.0"
    },
    "INSPYRENET": {
        "type": "inspyrenet",
        "repo_id": "1038lab/inspyrenet",
        "files": {
            "inspyrenet.safetensors": "inspyrenet.safetensors"
        },
        "cache_dir": "INSPYRENET"
    },
    "BEN": {
        "type": "ben",
        "repo_id": "1038lab/BEN",
        "files": {
            "model.py": "model.py",
            "BEN_Base.pth": "BEN_Base.pth"
        },
        "cache_dir": "BEN"
    },
    "BEN2": {
        "type": "ben2",
        "repo_id": "1038lab/BEN2",
        "files": {
            "BEN2_Base.pth": "BEN2_Base.pth",
            "BEN2.py": "BEN2.py"
        },
        "cache_dir": "BEN2"
    }
}
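
# A minimal sketch (illustrative only, never called by the node) of how an
# AVAILABLE_MODELS entry maps to on-disk paths under the ComfyUI models dir.
def _demo_expected_cache_layout(model_name: str = "RMBG-2.0"):
    info = AVAILABLE_MODELS[model_name]
    root = os.path.join(folder_paths.models_dir, "RMBG", info["cache_dir"])
    for filename in info["files"]:
        # e.g. <models_dir>/RMBG/RMBG-2.0/config.json
        print(os.path.join(root, filename))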

# Tensor <-> PIL helpers. ComfyUI IMAGE tensors are float32, shaped
# (B, H, W, C), with values in [0, 1].
def tensor2pil(image: torch.Tensor) -> Image.Image:
    return Image.fromarray(np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image: Image.Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
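
# A quick self-check sketch for the converters above, assuming the (B, H, W, C)
# convention noted in the comment. Illustrative only, never called by the node.
def _demo_tensor_pil_roundtrip():
    img = torch.rand(1, 64, 64, 3)   # synthetic BHWC image
    pil = tensor2pil(img)            # -> 64x64 RGB PIL image
    back = pil2tensor(pil)           # -> (1, 64, 64, 3) float32 tensor
    assert back.shape == img.shape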

def handle_model_error(message: str):
    print(f"[RMBG ERROR] {message}")
    raise RuntimeError(message)

class BaseModelLoader:
    """
    Base loader with:
    - disk cache handling
    - single-instance model lifetime management
    - thread safety (so multiple RMBG nodes can share one loader safely)
    """
    def __init__(self):
        self.model = None
        self.current_model_version = None
        self.base_cache_dir = os.path.join(folder_paths.models_dir, "RMBG")

        # Reentrant lock guarding model load/unload and inference, so nodes
        # sharing this loader never race on the single model instance.
        self._lock = threading.RLock()

        # Separate lock serializing downloads, so concurrent nodes cannot
        # clobber each other's partially written files.
        self._download_lock = threading.Lock()
    def get_cache_dir(self, model_name: str) -> str:
        cache_path = os.path.join(self.base_cache_dir, AVAILABLE_MODELS[model_name]["cache_dir"])
        os.makedirs(cache_path, exist_ok=True)
        return cache_path

    def check_model_cache(self, model_name: str):
        model_info = AVAILABLE_MODELS[model_name]
        cache_dir = self.get_cache_dir(model_name)

        if not os.path.exists(cache_dir):
            return False, "Model directory not found"

        missing_files = [
            filename
            for filename, local_name in model_info["files"].items()
            if not os.path.exists(os.path.join(cache_dir, local_name))
        ]

        if missing_files:
            return False, f"Missing model files: {', '.join(missing_files)}"

        return True, "Model cache verified"
    def download_model(self, model_name: str):
        model_info = AVAILABLE_MODELS[model_name]
        cache_dir = self.get_cache_dir(model_name)

        # Only one thread downloads at a time; late arrivals block here and
        # then find the files already in place.
        with self._download_lock:
            try:
                os.makedirs(cache_dir, exist_ok=True)
                print(f"Downloading {model_name} model files...")

                for filename in model_info["files"]:
                    print(f"Downloading {filename}...")
                    hf_hub_download(
                        repo_id=model_info["repo_id"],
                        filename=filename,
                        local_dir=cache_dir
                    )

                return True, "Model files downloaded successfully"

            except Exception as e:
                return False, f"Error downloading model files: {str(e)}"
    def clear_model(self):
        with self._lock:
            if self.model is not None:
                try:
                    self.model.cpu()
                except Exception:
                    pass

                try:
                    del self.model
                except Exception:
                    pass

                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            self.model = None
            self.current_model_version = None
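
# A minimal usage sketch of the loader lifecycle BaseModelLoader provides:
# verify the disk cache, download on a miss, then let a subclass load the
# weights. Illustrative only, never called by the node.
def _demo_loader_lifecycle(loader: BaseModelLoader, model_name: str):
    ok, msg = loader.check_model_cache(model_name)
    if not ok:
        print(f"Cache miss ({msg}); downloading...")
        ok, msg = loader.download_model(model_name)
    print(msg)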

class RMBGModel(BaseModelLoader):
    def __init__(self):
        super().__init__()

    def load_model(self, model_name: str):
        with self._lock:
            if self.current_model_version == model_name and self.model is not None:
                return

            self.clear_model()

            cache_dir = self.get_cache_dir(model_name)
            try:
                # Preferred path: build the BiRefNet model directly from the
                # code files shipped alongside the checkpoint.
                try:
                    from transformers import PreTrainedModel
                    import json

                    config_path = os.path.join(cache_dir, "config.json")
                    with open(config_path, "r") as f:
                        _config = json.load(f)

                    birefnet_path = os.path.join(cache_dir, "birefnet.py")
                    BiRefNetConfig_path = os.path.join(cache_dir, "BiRefNet_config.py")

                    # Load BiRefNet_config.py as a real module so birefnet.py
                    # can import it by name.
                    config_spec = importlib.util.spec_from_file_location("BiRefNetConfig", BiRefNetConfig_path)
                    config_module = importlib.util.module_from_spec(config_spec)
                    sys.modules["BiRefNetConfig"] = config_module
                    config_spec.loader.exec_module(config_module)

                    # Patch the relative import so birefnet.py executes
                    # outside its original package.
                    with open(birefnet_path, "r") as f:
                        birefnet_content = f.read()

                    birefnet_content = birefnet_content.replace(
                        "from .BiRefNet_config import BiRefNetConfig",
                        "from BiRefNetConfig import BiRefNetConfig"
                    )

                    module_name = f"custom_birefnet_model_{hash(birefnet_path)}"
                    module = types.ModuleType(module_name)
                    sys.modules[module_name] = module
                    exec(birefnet_content, module.__dict__)

                    # Instantiate the first PreTrainedModel subclass the file
                    # defines, then load its weights.
                    for attr_name in dir(module):
                        attr = getattr(module, attr_name)
                        if isinstance(attr, type) and issubclass(attr, PreTrainedModel) and attr != PreTrainedModel:
                            BiRefNetConfig = getattr(config_module, "BiRefNetConfig")
                            model_config = BiRefNetConfig()
                            self.model = attr(model_config)

                            weights_path = os.path.join(cache_dir, "model.safetensors")
                            try:
                                try:
                                    import safetensors.torch
                                    self.model.load_state_dict(safetensors.torch.load_file(weights_path))
                                except ImportError:
                                    from transformers.modeling_utils import load_state_dict
                                    state_dict = load_state_dict(weights_path)
                                    self.model.load_state_dict(state_dict)
                            except Exception as load_error:
                                pytorch_weights = os.path.join(cache_dir, "pytorch_model.bin")
                                if os.path.exists(pytorch_weights):
                                    self.model.load_state_dict(torch.load(pytorch_weights, map_location="cpu"))
                                else:
                                    raise RuntimeError(f"Failed to load weights: {str(load_error)}")
                            break

                    if self.model is None:
                        raise RuntimeError("Could not find suitable model class")

                except Exception as modern_e:
                    print("[RMBG INFO] Using standard transformers loading (fallback mode)...")
                    try:
                        self.model = AutoModelForImageSegmentation.from_pretrained(
                            cache_dir,
                            trust_remote_code=True,
                            local_files_only=True
                        )
                    except Exception as standard_e:
                        handle_model_error(
                            "Failed to load model with both modern and standard methods. "
                            f"Modern error: {str(modern_e)}. Standard error: {str(standard_e)}"
                        )

            except Exception as e:
                handle_model_error(f"Error loading model: {str(e)}")

            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

            torch.set_float32_matmul_precision("high")
            self.model.to(device)
            self.current_model_version = model_name
    def process_image(self, images, model_name: str, params: dict):
        # Serialize inference: the model behind this loader is a single
        # shared instance.
        with self._lock:
            try:
                self.load_model(model_name)

                transform_image = transforms.Compose([
                    transforms.Resize((params["process_res"], params["process_res"])),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])

                if isinstance(images, torch.Tensor):
                    if len(images.shape) == 3:
                        images = [images]
                    else:
                        images = [img for img in images]

                original_sizes = [tensor2pil(img).size for img in images]

                input_tensors = [transform_image(tensor2pil(img)).unsqueeze(0) for img in images]
                input_batch = torch.cat(input_tensors, dim=0).to(device)

                with torch.no_grad():
                    outputs = self.model(input_batch)

                # Different checkpoints expose their logits differently; probe
                # the common output shapes before giving up.
                if isinstance(outputs, list) and len(outputs) > 0:
                    results = outputs[-1].sigmoid().cpu()
                elif isinstance(outputs, dict) and "logits" in outputs:
                    results = outputs["logits"].sigmoid().cpu()
                elif isinstance(outputs, torch.Tensor):
                    results = outputs.sigmoid().cpu()
                else:
                    try:
                        if hasattr(outputs, "last_hidden_state"):
                            results = outputs.last_hidden_state.sigmoid().cpu()
                        else:
                            results = None
                            for _k, v in outputs.items():
                                if isinstance(v, torch.Tensor):
                                    results = v.sigmoid().cpu()
                                    break
                            if results is None:
                                handle_model_error("Unable to recognize model output format")
                    except Exception:
                        handle_model_error("Unable to recognize model output format")

                masks = []
                for result, (orig_w, orig_h) in zip(results, original_sizes):
                    result = result.squeeze()
                    result = result * (1 + (1 - params["sensitivity"]))
                    result = torch.clamp(result, 0, 1)

                    result = F.interpolate(
                        result.unsqueeze(0).unsqueeze(0),
                        size=(orig_h, orig_w),
                        mode="bilinear"
                    ).squeeze()

                    masks.append(tensor2pil(result))

                # Drop the GPU batch before returning.
                del input_batch
                return masks

            except Exception as e:
                handle_model_error(f"Error in batch processing: {str(e)}")

class InspyrenetModel(BaseModelLoader):
    def __init__(self):
        super().__init__()

    def load_model(self, model_name: str):
        with self._lock:
            if self.current_model_version == model_name and self.model is not None:
                return
            self.clear_model()

            try:
                import transparent_background
                self.model = transparent_background.Remover()
                self.current_model_version = model_name
            except Exception as e:
                handle_model_error(f"Failed to initialize transparent_background: {str(e)}")

    def process_image(self, image, model_name: str, params: dict):
        with self._lock:
            try:
                self.load_model(model_name)

                orig_image = tensor2pil(image)
                w, h = orig_image.size

                aspect_ratio = h / w
                new_w = params["process_res"]
                new_h = int(params["process_res"] * aspect_ratio)
                resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)

                foreground = self.model.process(resized_image, type="rgba")
                foreground = foreground.resize((w, h), Image.LANCZOS)
                mask = foreground.split()[-1]

                return mask

            except Exception as e:
                handle_model_error(f"Error in Inspyrenet processing: {str(e)}")

class BENModel(BaseModelLoader):
    def __init__(self):
        super().__init__()

    def load_model(self, model_name: str):
        with self._lock:
            if self.current_model_version == model_name and self.model is not None:
                return

            self.clear_model()

            cache_dir = self.get_cache_dir(model_name)
            model_path = os.path.join(cache_dir, "model.py")
            module_name = f"custom_ben_model_{hash(model_path)}"

            spec = importlib.util.spec_from_file_location(module_name, model_path)
            ben_module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = ben_module
            spec.loader.exec_module(ben_module)

            model_weights_path = os.path.join(cache_dir, "BEN_Base.pth")
            self.model = ben_module.BEN_Base()
            self.model.loadcheckpoints(model_weights_path)

            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

            torch.set_float32_matmul_precision("high")
            self.model.to(device)
            self.current_model_version = model_name

    def process_image(self, image, model_name: str, params: dict):
        with self._lock:
            try:
                self.load_model(model_name)

                orig_image = tensor2pil(image)
                w, h = orig_image.size

                aspect_ratio = h / w
                new_w = params["process_res"]
                new_h = int(params["process_res"] * aspect_ratio)
                resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)

                processed_input = resized_image.convert("RGBA")

                with torch.no_grad():
                    _, foreground = self.model.inference(processed_input)

                foreground = foreground.resize((w, h), Image.LANCZOS)
                mask = foreground.split()[-1]

                return mask

            except Exception as e:
                handle_model_error(f"Error in BEN processing: {str(e)}")

class BEN2Model(BaseModelLoader):
    def __init__(self):
        super().__init__()

    def load_model(self, model_name: str):
        with self._lock:
            if self.current_model_version == model_name and self.model is not None:
                return

            self.clear_model()

            try:
                cache_dir = self.get_cache_dir(model_name)
                model_path = os.path.join(cache_dir, "BEN2.py")
                module_name = f"custom_ben2_model_{hash(model_path)}"

                spec = importlib.util.spec_from_file_location(module_name, model_path)
                ben2_module = importlib.util.module_from_spec(spec)
                sys.modules[module_name] = ben2_module
                spec.loader.exec_module(ben2_module)

                model_weights_path = os.path.join(cache_dir, "BEN2_Base.pth")
                self.model = ben2_module.BEN_Base()
                self.model.loadcheckpoints(model_weights_path)

                self.model.eval()
                for param in self.model.parameters():
                    param.requires_grad = False

                torch.set_float32_matmul_precision("high")
                self.model.to(device)
                self.current_model_version = model_name

            except Exception as e:
                handle_model_error(f"Error loading BEN2 model: {str(e)}")

    def process_image(self, images, model_name: str, params: dict):
        with self._lock:
            try:
                self.load_model(model_name)

                if isinstance(images, torch.Tensor):
                    if len(images.shape) == 3:
                        images = [images]
                    else:
                        images = [img for img in images]

                # Run inference in small batches to keep peak VRAM bounded.
                batch_size = 3
                all_masks = []

                for i in range(0, len(images), batch_size):
                    batch_images = images[i:i + batch_size]
                    batch_pil_images = []
                    original_sizes = []

                    for img in batch_images:
                        orig_image = tensor2pil(img)
                        w, h = orig_image.size
                        original_sizes.append((w, h))

                        aspect_ratio = h / w
                        new_w = params["process_res"]
                        new_h = int(params["process_res"] * aspect_ratio)
                        resized_image = orig_image.resize((new_w, new_h), Image.LANCZOS)
                        processed_input = resized_image.convert("RGBA")
                        batch_pil_images.append(processed_input)

                    with torch.no_grad():
                        try:
                            foregrounds = self.model.inference(batch_pil_images)
                            if not isinstance(foregrounds, list):
                                foregrounds = [foregrounds]
                        except Exception as e:
                            handle_model_error(f"Error in BEN2 inference: {str(e)}")

                    for foreground, (orig_w, orig_h) in zip(foregrounds, original_sizes):
                        foreground = foreground.resize((orig_w, orig_h), Image.LANCZOS)
                        mask = foreground.split()[-1]
                        all_masks.append(mask)

                if len(all_masks) == 1:
                    return all_masks[0]
                return all_masks

            except Exception as e:
                handle_model_error(f"Error in BEN2 processing: {str(e)}")

def refine_foreground(image_bchw: torch.Tensor, masks_b1hw: torch.Tensor) -> torch.Tensor:
    """Soften mask transition zones and multiply each image channel by the
    refined alpha, approximating fast foreground color estimation."""
    b, c, h, w = image_bchw.shape
    if b != masks_b1hw.shape[0]:
        raise ValueError("images and masks must have the same batch size")

    image_np = image_bchw.cpu().numpy()
    mask_np = masks_b1hw.cpu().numpy()

    refined_fg = []
    for i in range(b):
        mask = mask_np[i, 0]
        thresh = 0.45
        mask_binary = (mask > thresh).astype(np.float32)

        edge_blur = cv2.GaussianBlur(mask_binary, (3, 3), 0)
        transition_mask = np.logical_and(mask > 0.05, mask < 0.95)

        # Blend the soft mask with its blurred binarization inside the
        # transition band; keep hard values elsewhere.
        alpha = 0.85
        mask_refined = np.where(
            transition_mask,
            alpha * mask + (1 - alpha) * edge_blur,
            mask_binary
        )

        # Slightly attenuate mid-range edge values to reduce halos.
        edge_region = np.logical_and(mask > 0.2, mask < 0.8)
        mask_refined = np.where(edge_region, mask_refined * 0.98, mask_refined)

        result = []
        for cc in range(image_np.shape[1]):
            channel = image_np[i, cc]
            refined = channel * mask_refined
            result.append(refined)

        refined_fg.append(np.stack(result))

    return torch.from_numpy(np.stack(refined_fg))
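
# A small sketch exercising refine_foreground on synthetic data with the
# shapes it expects: BCHW images in [0, 1] and B1HW soft masks. Illustrative
# only, never called by the node.
def _demo_refine_foreground():
    img = torch.rand(1, 3, 32, 32)
    mask = torch.rand(1, 1, 32, 32)
    fg = refine_foreground(img, mask)  # BCHW foreground with softened edges
    assert fg.shape == img.shape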

# Module-level shared loader instances. Every RMBG node borrows these, so each
# architecture is loaded at most once per process rather than once per node;
# the per-loader locks in BaseModelLoader make the sharing safe across
# concurrently executing nodes.
_SHARED_MODEL_LOADERS = {
    "RMBG-2.0": RMBGModel(),
    "INSPYRENET": InspyrenetModel(),
    "BEN": BENModel(),
    "BEN2": BEN2Model()
}
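
# A hedged sketch (not wired into the node) showing how a caller could release
# all shared model memory; clear_model() is thread-safe and a no-op when
# nothing is loaded.
def _demo_release_all_models():
    for name, loader in _SHARED_MODEL_LOADERS.items():
        loader.clear_model()
        print(f"Released {name}")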

class RMBG:
    def __init__(self):
        # Borrow the shared loaders; nothing is loaded until first use.
        self.models = _SHARED_MODEL_LOADERS

    @classmethod
    def INPUT_TYPES(s):
        tooltips = {
            "image": "Input image to be processed for background removal.",
            "model": "Select the background removal model to use.",
            "sensitivity": "Adjust mask detection strength (1.0 = default).",
            "process_res": "Processing resolution (higher = more VRAM).",
            "mask_blur": "Blur mask edges (0 = none).",
            "mask_offset": "Expand/shrink mask boundary (+ expand / - shrink).",
            "background": "Alpha = transparent output. Color = composite onto solid background.",
            "background_color": "Only used when background='Color'. Ignored when background='Alpha' (transparent).",
            "invert_output": "Invert mask output.",
            "refine_foreground": "Fast Foreground Colour Estimation for better transparency."
        }

        return {
            "required": {
                "image": ("IMAGE", {"tooltip": tooltips["image"]}),
                # The selectable models come straight from the registry above.
                "model": (list(AVAILABLE_MODELS.keys()), {"default": "RMBG-2.0", "tooltip": tooltips["model"]}),
            },
            "optional": {
                "sensitivity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": tooltips["sensitivity"]}),
                "process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 8, "tooltip": tooltips["process_res"]}),
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
                "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
                "refine_foreground": ("BOOLEAN", {"default": False, "tooltip": tooltips["refine_foreground"]}),
                "background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
                # Hex color string; only consulted when background == "Color".
                "background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "process_image"
    CATEGORY = "🧪AILab/🧽RMBG"
    def process_image(self, image, model, **params):
        try:
            processed_images = []
            processed_masks = []

            model_instance = self.models[model]

            cache_status, message = model_instance.check_model_cache(model)
            if not cache_status:
                print(f"Cache check: {message}")
                print("Downloading required model files...")
                download_status, download_message = model_instance.download_model(model)
                if not download_status:
                    handle_model_error(download_message)
                print("Model files downloaded successfully")

            model_type = AVAILABLE_MODELS[model]["type"]

            def _process_pair(img, mask):
                if isinstance(mask, list):
                    masks = [m.convert("L") for m in mask if isinstance(m, Image.Image)]
                    mask_local = masks[0] if masks else None
                elif isinstance(mask, Image.Image):
                    mask_local = mask.convert("L")
                else:
                    mask_local = mask

                mask_tensor_local = pil2tensor(mask_local)
                mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"]))
                mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1)
                mask_img_local = tensor2pil(mask_tensor_local)

                if params["mask_blur"] > 0:
                    mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))

                if params["mask_offset"] != 0:
                    if params["mask_offset"] > 0:
                        for _ in range(params["mask_offset"]):
                            mask_img_local = mask_img_local.filter(ImageFilter.MaxFilter(3))
                    else:
                        for _ in range(-params["mask_offset"]):
                            mask_img_local = mask_img_local.filter(ImageFilter.MinFilter(3))

                if params["invert_output"]:
                    mask_img_local = Image.fromarray(255 - np.array(mask_img_local))

                img_tensor_local = torch.from_numpy(np.array(tensor2pil(img))).permute(2, 0, 1).unsqueeze(0) / 255.0
                mask_tensor_b1hw = torch.from_numpy(np.array(mask_img_local)).unsqueeze(0).unsqueeze(0) / 255.0

                orig_image_local = tensor2pil(img)

                if params.get("refine_foreground", False):
                    refined_fg_local = refine_foreground(img_tensor_local, mask_tensor_b1hw)
                    refined_fg_local = tensor2pil(refined_fg_local[0].permute(1, 2, 0))
                    r, g, b = refined_fg_local.split()
                    foreground_local = Image.merge("RGBA", (r, g, b, mask_img_local))
                else:
                    orig_rgba_local = orig_image_local.convert("RGBA")
                    r, g, b, _a = orig_rgba_local.split()
                    foreground_local = Image.merge("RGBA", (r, g, b, mask_img_local))

                # Either composite onto a solid color or keep the cutout
                # transparent.
                if params["background"] == "Color":
                    def hex_to_rgba(hex_color: str):
                        hex_color = hex_color.lstrip("#")
                        if len(hex_color) == 6:
                            rr = int(hex_color[0:2], 16)
                            gg = int(hex_color[2:4], 16)
                            bb = int(hex_color[4:6], 16)
                            aa = 255
                        elif len(hex_color) == 8:
                            rr = int(hex_color[0:2], 16)
                            gg = int(hex_color[2:4], 16)
                            bb = int(hex_color[4:6], 16)
                            aa = int(hex_color[6:8], 16)
                        else:
                            raise ValueError("Invalid color format")
                        return (rr, gg, bb, aa)

                    background_color = params.get("background_color", "#222222")
                    rgba = hex_to_rgba(background_color)
                    bg_image = Image.new("RGBA", orig_image_local.size, rgba)
                    composite_image = Image.alpha_composite(bg_image, foreground_local)
                    processed_images.append(pil2tensor(composite_image.convert("RGB")))
                else:
                    processed_images.append(pil2tensor(foreground_local))

                processed_masks.append(pil2tensor(mask_img_local))

            if model_type in ("rmbg", "ben2"):
                # These loaders accept lists of images, so feed them in chunks.
                images_list = [img for img in image]
                chunk_size = 4
                for start in range(0, len(images_list), chunk_size):
                    batch_imgs = images_list[start:start + chunk_size]
                    masks = model_instance.process_image(batch_imgs, model, params)
                    if isinstance(masks, Image.Image):
                        masks = [masks]
                    for img_item, mask_item in zip(batch_imgs, masks):
                        _process_pair(img_item, mask_item)
            else:
                for img in image:
                    mask = model_instance.process_image(img, model, params)
                    _process_pair(img, mask)

            mask_images = []
            for mask_tensor in processed_masks:
                mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])) \
                    .movedim(1, -1) \
                    .expand(-1, -1, -1, 3)
                mask_images.append(mask_image)

            mask_image_output = torch.cat(mask_images, dim=0)

            return (
                torch.cat(processed_images, dim=0),
                torch.cat(processed_masks, dim=0),
                mask_image_output
            )

        except Exception as e:
            # Log instead of re-raising so the fallback below is actually
            # reachable: pass the input through with an empty mask. The mask
            # is (B, H, W) because IMAGE tensors are (B, H, W, C).
            print(f"[RMBG ERROR] Error in image processing: {str(e)}")
            empty_mask = torch.zeros((image.shape[0], image.shape[1], image.shape[2]))
            empty_mask_image = empty_mask.reshape((-1, 1, empty_mask.shape[-2], empty_mask.shape[-1])) \
                .movedim(1, -1) \
                .expand(-1, -1, -1, 3)
            return (image, empty_mask, empty_mask_image)
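
# A hedged sketch of driving the node outside a ComfyUI graph with a synthetic
# BHWC image; the keyword names mirror INPUT_TYPES above. Illustrative only:
# it assumes the INSPYRENET backend (transparent_background) is installed or
# downloadable.
def _demo_run_node():
    node = RMBG()
    img = torch.rand(1, 256, 256, 3)
    out_image, out_mask, out_mask_image = node.process_image(
        img, "INSPYRENET",
        sensitivity=1.0, process_res=1024, mask_blur=0, mask_offset=0,
        invert_output=False, refine_foreground=False,
        background="Alpha", background_color="#222222",
    )
    print(out_image.shape, out_mask.shape, out_mask_image.shape)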

NODE_CLASS_MAPPINGS = {
    "RMBG": RMBG
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RMBG": "Remove Background (RMBG)"
}