# ComfyUI-RMBG
# This custom node for ComfyUI provides image matting (alpha extraction) using the SDMatte model.
#
# reference from https://github.com/vivoCameraResearch/SDMatte
# model: https://huggingface.co/1038lab/SDMatte
#
# This integration script follows GPL-3.0 License.
# When using or modifying this code, please respect both the original model licenses
# and this integration's license terms.
#
# Source: https://github.com/AILab-AI/ComfyUI-RMBG

import os
import sys
import copy
import gc
from pathlib import Path

# Must run before torch is imported; note that it has no effect if CUDA was
# already initialised by an earlier import elsewhere in the process.
if os.environ.get('SDMATTE_CPU_ONLY', '').lower() in ('1', 'true', 'yes'):
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFilter
from torch.hub import download_url_to_file
from torchvision import transforms
from torchvision.transforms import InterpolationMode

import folder_paths

try:
    import comfy.model_management
    COMFY_AVAILABLE = True
except Exception as e:
    print(f"Warning: ComfyUI model management not available: {e}")
    COMFY_AVAILABLE = False

    class MockModelManagement:
        @staticmethod
        def get_torch_device():
            return torch.device('cpu')

    class MockComfy:
        model_management = MockModelManagement()

    comfy = MockComfy()

try:
    from safetensors.torch import load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    # _load_model_weights raises without safetensors, so say so instead of
    # promising a torch.load fallback that does not exist.
    print("Warning: safetensors not available. SDMatte model weights cannot be loaded.")

try:
    import diffusers
    import transformers
    DIFFUSERS_AVAILABLE = True
except ImportError:
    DIFFUSERS_AVAILABLE = False
    print("Warning: diffusers/transformers not available. SDMatte functionality will be limited.")

current_dir = Path(__file__).resolve().parent
repo_root = current_dir.parent
sdmatte_path = repo_root / "models" / "SDMatte"
sys.path.insert(0, str(sdmatte_path))

SDMATTE_MODELS = {
    "SDMatte": {
        "model_url": "https://huggingface.co/1038lab/SDMatte/resolve/main/SDMatte.safetensors",
        "filename": "SDMatte.safetensors",
        "repo_id": "1038lab/SDMatte"
    },
    "SDMatte_plus": {
        "model_url": "https://huggingface.co/1038lab/SDMatte/resolve/main/SDMatte_plus.safetensors",
        "filename": "SDMatte_plus.safetensors",
        "repo_id": "1038lab/SDMatte"
    }
}

REQUIRED_COMPONENTS = ["scheduler", "text_encoder", "tokenizer", "unet", "vae"]

# Files fetched from the Hugging Face repo for each component sub-folder.
# Only these config/tokenizer files are downloaded here; the network weights are
# loaded afterwards from the single SDMatte checkpoint (see load_sdmatte_model).
_COMPONENT_FILES = {
    "scheduler": ["scheduler_config.json"],
    "text_encoder": ["config.json"],
    "tokenizer": ["merges.txt", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
    "unet": ["config.json"],
    "vae": ["config.json"],
}


def get_or_download_model_file(filename, url, dirname):
    """Return a local path to *filename*, downloading it from *url* if needed.

    ComfyUI's registered model folders are searched first via
    ``folder_paths.get_full_path``; otherwise the file lives at
    ``<models_dir>/<dirname>/<filename>``.

    Raises:
        RuntimeError: if the download fails.
    """
    local_path = folder_paths.get_full_path(dirname, filename)
    if local_path:
        return local_path

    folder = os.path.join(folder_paths.models_dir, dirname)
    os.makedirs(folder, exist_ok=True)
    local_path = os.path.join(folder, filename)
    if not os.path.exists(local_path):
        print(f"Downloading {filename} from {url} ...")
        try:
            download_url_to_file(url, local_path)
        except Exception as e:
            raise RuntimeError(f"Failed to download {filename} from {url}: {e}") from e
    return local_path


def ensure_model_components(model_name):
    """Make sure the component config folders used to build SDMatte exist locally.

    Every file of every component in REQUIRED_COMPONENTS is checked individually,
    so an interrupted earlier download is completed rather than skipped because
    its folder is non-empty. Download failures are printed as warnings, not raised.

    Args:
        model_name: key into SDMATTE_MODELS (selects the Hugging Face repo).

    Returns:
        The directory containing the component sub-folders.
    """
    repo_id = SDMATTE_MODELS[model_name]["repo_id"]
    base_url = f"https://huggingface.co/{repo_id}/resolve/main"
    components_dir = os.path.join(folder_paths.models_dir, "RMBG", "SDMatte")

    missing = [
        (component, file)
        for component in REQUIRED_COMPONENTS
        for file in _COMPONENT_FILES.get(component, [])
        if not os.path.exists(os.path.join(components_dir, component, file))
    ]
    if missing:
        # dict.fromkeys de-duplicates while keeping REQUIRED_COMPONENTS order.
        missing_components = list(dict.fromkeys(component for component, _ in missing))
        print(f"Downloading missing SDMatte components: {missing_components}")

    for component, file in missing:
        component_dir = os.path.join(components_dir, component)
        os.makedirs(component_dir, exist_ok=True)
        file_url = f"{base_url}/{component}/{file}"
        file_path = os.path.join(component_dir, file)
        try:
            print(f"  Downloading {component}/{file}...")
            download_url_to_file(file_url, file_path)
        except Exception as e:
            print(f"  Warning: Failed to download {file}: {e}")
    return components_dir


def process_mask(mask_image: Image.Image, invert_output: bool = False, mask_blur: int = 0, mask_offset: int = 0) -> Image.Image:
    """Post-process a grayscale mask: optional inversion, Gaussian blur, grow/shrink.

    Args:
        mask_image: "L"-mode mask image.
        invert_output: replace each value v with 255 - v (applied first).
        mask_blur: Gaussian blur radius in pixels (0 disables).
        mask_offset: grow (positive) or shrink (negative) the mask by this many pixels.

    Returns:
        The processed mask image.
    """
    if invert_output:
        mask_image = Image.fromarray(255 - np.array(mask_image))
    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
    if mask_offset != 0:
        # Each 3x3 max/min pass moves the mask edge by exactly one pixel, so
        # |mask_offset| passes grow/shrink it by |mask_offset| pixels.
        filter_type = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        for _ in range(abs(mask_offset)):
            mask_image = mask_image.filter(filter_type(3))
    return mask_image


def _hex_to_rgba(hex_color):
    """Parse ``#RRGGBB`` (or short ``#RGB``) into an opaque RGBA tuple.

    Characters past the first six hex digits (e.g. an alpha pair) are ignored.

    Raises:
        ValueError: if the string is not a valid hex colour.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Invalid background color: {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return (r, g, b, 255)


def apply_background_color(image: Image.Image, mask_image: Image.Image, background: str = "Alpha", background_color: str = "#222222") -> Image.Image:
    """Apply *mask_image* as alpha to *image*, optionally compositing over a colour.

    Args:
        image: source image (any PIL mode; converted to RGBA).
        mask_image: mask used as the alpha channel (converted to "L").
        background: "Alpha" returns the RGBA cut-out; "Color" composites it
            over a solid *background_color* and returns RGB.
        background_color: hex colour string used when background == "Color".
    """
    rgba_image = image.copy().convert('RGBA')
    rgba_image.putalpha(mask_image.convert('L'))
    if background == "Color":
        bg_image = Image.new('RGBA', image.size, _hex_to_rgba(background_color))
        composite_image = Image.alpha_composite(bg_image, rgba_image)
        return composite_image.convert('RGB')
    return rgba_image


def pil2tensor(image):
    """Convert a PIL image to a float tensor in [0, 1] with a leading batch dim of 1."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def tensor2pil(image):
    """Convert a [0, 1] float tensor to a uint8 PIL image (singleton dims squeezed)."""
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


# Lazily imported SDMatte network class (populated on first model load).
SDMatteCore = None


def _resize_norm_image_bchw(image_bchw: torch.Tensor, size_hw=(1024, 1024)) -> torch.Tensor:
    """Drop a 4th (alpha) channel, resize to *size_hw* and normalise to [-1, 1]."""
    if image_bchw.shape[1] == 4:
        image_bchw = image_bchw[:, :3, :, :]
    resize = transforms.Resize(size_hw, interpolation=InterpolationMode.BILINEAR, antialias=True)
    norm = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return norm(resize(image_bchw))


def _resize_mask_b1hw(mask_b1hw: torch.Tensor, size_hw=(1024, 1024)) -> torch.Tensor:
    """Bilinearly resize a (B, 1, H, W) mask to *size_hw*."""
    resize = transforms.Resize(size_hw, interpolation=InterpolationMode.BILINEAR, antialias=True)
    return resize(mask_b1hw)


class AILab_SDMatte:
    """ComfyUI node that predicts an alpha matte with SDMatte from an image plus guide mask."""

    @classmethod
    def INPUT_TYPES(cls):
        tooltips = {
            "model": "SDMatte model variant: Standard or Plus version",
            "image": "Input image for matting extraction",
            "mask": "Mask: White=foreground, Black=background. If omitted and image has alpha, alpha will be used.",
            "process_res": "Processing resolution: higher = better quality but slower",
            "device": "Auto: smart detection, CPU: force CPU, GPU: force GPU",
            "transparent_object": "Whether input image contains transparent objects",
            "mask_refine": "Enable mask refinement using mask constraints",
            "sensitivity": "Sensitivity for mask constraint (0.1-1.0): higher = more strict",
            "mask_blur": "Blur mask edges (0 = disabled)",
            "mask_offset": "Expand/shrink mask (positive = expand)",
            "invert_output": "Invert the mask output",
            "background": "Background type for output",
            "background_color": "Background color (when not Alpha)",
        }
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (list(SDMATTE_MODELS.keys()), {"default": "SDMatte", "tooltip": tooltips["model"]}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto", "tooltip": tooltips["device"]}),
                "process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 8, "tooltip": tooltips["process_res"]}),
            },
            "optional": {
                "mask": ("MASK", {"tooltip": tooltips["mask"]}),
                "transparent_object": ("BOOLEAN", {"default": True, "tooltip": tooltips["transparent_object"]}),
                "mask_refine": ("BOOLEAN", {"default": True, "tooltip": tooltips["mask_refine"]}),
                "sensitivity": ("FLOAT", {"default": 0.9, "min": 0.1, "max": 1.0, "step": 0.1, "tooltip": tooltips["sensitivity"]}),
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
                "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
                "background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
                "background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "matting_inference"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # Loaded models keyed by f"{model_name}_{device_option}".
        self.model_cache = {}

    def _resolve_device(self, device):
        """Map the node's device option ("Auto"/"CPU"/"GPU") to a torch.device."""
        if device == "CPU":
            return torch.device('cpu')
        if device == "GPU" and not torch.cuda.is_available():
            print("SDMatte: GPU requested but CUDA not available, falling back to CPU")
            return torch.device('cpu')
        return comfy.model_management.get_torch_device()

    def _model_device(self, model, device):
        """Return the device *model*'s parameters live on, so inputs always match it."""
        try:
            return next(model.parameters()).device
        except (StopIteration, AttributeError):
            return self._resolve_device(device)

    def load_sdmatte_model(self, model_name, device="Auto"):
        """Return a cached, eval-mode SDMatte model for (*model_name*, *device*).

        Loading a new key evicts cache entries belonging to other model variants
        (entries for the same variant on other device options are kept), then
        builds the network, loads the checkpoint, and moves it to the device.

        Raises:
            ImportError: if diffusers/transformers (or safetensors) are missing.
            RuntimeError: if the checkpoint cannot be downloaded or read.
        """
        global SDMatteCore
        cache_key = f"{model_name}_{device}"

        if cache_key not in self.model_cache:
            # Device options contain no underscore, so rsplit recovers the exact
            # model name ("SDMatte_plus_Auto" -> "SDMatte_plus"); a prefix match
            # would wrongly treat "SDMatte_plus_*" as belonging to "SDMatte".
            stale_keys = [k for k in self.model_cache if k.rsplit('_', 1)[0] != model_name]
            for key in stale_keys:
                del self.model_cache[key]
            if stale_keys:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            if not DIFFUSERS_AVAILABLE:
                raise ImportError("diffusers and transformers are required for SDMatte functionality")

            if SDMatteCore is None:
                module_dir = os.path.dirname(__file__)
                if module_dir not in sys.path:
                    sys.path.insert(0, module_dir)
                from SDMatte.modeling.SDMatte.meta_arch import SDMatte as _SDMatteCore
                SDMatteCore = _SDMatteCore

            model_info = SDMATTE_MODELS[model_name]
            model_path = get_or_download_model_file(
                model_info["filename"],
                model_info["model_url"],
                "RMBG/SDMatte",
            )
            pretrained_repo = ensure_model_components(model_name)

            sdmatte_model = SDMatteCore(
                pretrained_model_name_or_path=pretrained_repo,
                load_weight=False,
                use_aux_input=True,
                aux_input="trimap",
                use_encoder_hidden_states=True,
                use_attention_mask=True,
                add_noise=False,
            )
            self._load_model_weights(sdmatte_model, model_path)

            device_obj = self._resolve_device(device)
            sdmatte_model.eval()
            sdmatte_model.to(device_obj)
            if device_obj.type == 'cuda':
                self._apply_memory_optimizations(sdmatte_model)

            self.model_cache[cache_key] = sdmatte_model

        return self.model_cache[cache_key]

    def _load_model_weights(self, model, model_path):
        """Load the safetensors checkpoint at *model_path* into *model* (non-strict).

        Only a failure to *read* the file is treated as corruption (the file is
        then deleted so the next run re-downloads it); errors while applying the
        state dict propagate unchanged and leave the file in place.

        Raises:
            ImportError: if safetensors is not installed.
            RuntimeError: if the checkpoint file cannot be read.
        """
        if not SAFETENSORS_AVAILABLE:
            raise ImportError("safetensors is required for SDMatte functionality")
        try:
            state_dict = load_file(model_path)
        except Exception as e:
            if os.path.exists(model_path):
                print(f"[SDMatte] Model file appears corrupted, deleting: {model_path}")
                os.remove(model_path)
            raise RuntimeError(f"Failed to load model weights. File may be corrupted. Please try again to re-download. Error: {e}") from e
        model.load_state_dict(state_dict, strict=False)

    def _apply_memory_optimizations(self, model):
        """Best-effort CUDA memory savings; any failure leaves the model unchanged."""
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        try:
            unet = getattr(model, 'unet', None)
            if unet is not None and hasattr(unet, 'set_attn_processor'):
                from diffusers.models.attention_processor import SlicedAttnProcessor
                # Sliced attention trades speed for a lower peak memory footprint.
                unet.set_attn_processor(SlicedAttnProcessor(slice_size=1))
        except Exception as e:
            print(f"SDMatte: sliced attention not applied: {e}")

    @staticmethod
    def _select_guide_mask(image, mask, b):
        """Return the (1, H, W) float guide mask for batch item *b*.

        Uses *mask* when given (a single mask is reused for every batch item),
        otherwise the image's alpha channel.

        Raises:
            ValueError: if there is neither a mask nor a 4-channel image.
        """
        if mask is not None:
            idx = min(b, mask.shape[0] - 1)
            guide = mask[idx:idx + 1]
        elif image.shape[-1] == 4:
            guide = image[b:b + 1, :, :, 3]
        else:
            raise ValueError("Mask required: provide a mask or use an image with alpha.")
        return guide.float()

    def matting_inference(self, image, model, process_res, device="Auto", mask=None, transparent_object=True, mask_refine=True, sensitivity=0.8, mask_blur=0, mask_offset=0, invert_output=False, background="Alpha", background_color="#222222"):
        """Run SDMatte on every image in the batch.

        Args:
            image: ComfyUI IMAGE tensor, indexed as (B, H, W, C); C == 4 enables
                the alpha-as-mask fallback.
            mask: optional ComfyUI MASK, (B, H, W) or (H, W); any resolution.
            process_res: square side length the model runs at.
            Other arguments mirror INPUT_TYPES.

        Returns:
            (images, masks, mask_images): the composited images, the (B, H, W)
            masks in [0, 1], and the masks expanded to 3-channel images.
        """
        sdmatte_model = self.load_sdmatte_model(model, device)
        # Follow the model's real placement so inputs can never end up on another device.
        device_obj = self._model_device(sdmatte_model, device)
        res = int(process_res)

        if mask is not None and mask.dim() == 2:
            mask = mask.unsqueeze(0)

        result_masks = []
        result_images = []
        result_mask_images = []

        for b in range(image.shape[0]):
            img_pil = tensor2pil(image[b])
            orig_h, orig_w = int(image.shape[1]), int(image.shape[2])

            img_bchw = image[b:b + 1].permute(0, 3, 1, 2).contiguous().to(device_obj)
            img_in = _resize_norm_image_bchw(img_bchw, (res, res))

            guide = self._select_guide_mask(image, mask, b)
            mask_b1hw = guide.unsqueeze(1).contiguous().to(device_obj)
            # Map the [0, 1] guide mask to the [-1, 1] range fed to the model as "trimap".
            tri = _resize_mask_b1hw(mask_b1hw, (res, res)) * 2 - 1

            data = {
                "image": img_in,
                "is_trans": torch.tensor([1 if transparent_object else 0], device=device_obj),
                "caption": [""],
                "trimap": tri,
                "trimap_coords": torch.tensor([[0, 0, 1, 1]], dtype=tri.dtype, device=device_obj),
            }

            with torch.inference_mode():
                if device_obj.type == 'cuda':
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        pred_alpha = sdmatte_model(data)
                else:
                    pred_alpha = sdmatte_model(data)
                out = transforms.Resize((orig_h, orig_w), interpolation=InterpolationMode.BILINEAR, antialias=True)(pred_alpha)
                # Autocast may yield fp16; convert to fp32 before leaving the device.
                out = out.squeeze(1).clamp(0, 1).float().cpu()

            if mask_refine:
                refine_guide = guide.cpu()
                # The guide mask may have a different resolution than the image.
                if tuple(refine_guide.shape[-2:]) != (orig_h, orig_w):
                    refine_guide = F.interpolate(
                        refine_guide.unsqueeze(1), size=(orig_h, orig_w),
                        mode="bilinear", align_corners=False,
                    ).squeeze(1)
                out = self._refine_mask(out, refine_guide, sensitivity)

            mask_pil = Image.fromarray((out[0].numpy() * 255).astype(np.uint8))
            mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)

            result_image = apply_background_color(img_pil, mask_image, background, background_color)
            result_image = result_image.convert("RGB" if background == "Color" else "RGBA")

            mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
            mask_image_vis = mask_tensor.unsqueeze(-1).expand(-1, -1, -1, 3)

            result_masks.append(mask_tensor)
            result_images.append(pil2tensor(result_image))
            result_mask_images.append(mask_image_vis)

        if device_obj.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

        return (torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0))

    def _refine_mask(self, mask, trimap, constraint, boost=1.2, alpha_threshold=0.3):
        """Constrain a predicted alpha *mask* with the guide mask *trimap*.

        Pixels where trimap < 1 - constraint are set to 0; where trimap >
        constraint the alpha is multiplied by *boost* (clamped to 1); in the
        remaining band, alpha below *alpha_threshold* is set to 0.

        NOTE(review): for constraint < 0.5 the two regions overlap; background
        zeroing runs first, so overlapping pixels end up 0.

        Args:
            mask: predicted alpha, same shape as *trimap*.
            trimap: guide mask in [0, 1].
            constraint: the node's "sensitivity" value.
        """
        trimap_cpu = trimap.cpu()
        foreground_regions = trimap_cpu > constraint
        background_regions = trimap_cpu < (1.0 - constraint)
        unknown_regions = ~(foreground_regions | background_regions)

        refined_mask = mask.clone()
        refined_mask[background_regions] = 0.0
        refined_mask[foreground_regions] = torch.clamp(refined_mask[foreground_regions] * boost, 0, 1)

        low_confidence = (refined_mask < alpha_threshold) & unknown_regions
        refined_mask[low_confidence] = 0.0
        return refined_mask


NODE_CLASS_MAPPINGS = {
    "AILab_SDMatte": AILab_SDMatte,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "AILab_SDMatte": "SDMatte Matting (RMBG)",
}