| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import sys |
| import copy |
| from pathlib import Path |
|
|
# Opt-in CPU-only mode: when SDMATTE_CPU_ONLY is "1", "true" or "yes"
# (case-insensitive), blank CUDA_VISIBLE_DEVICES. This runs before
# `import torch` further down the file.
if os.getenv('SDMATTE_CPU_ONLY', '').lower() in {'1', 'true', 'yes'}:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image, ImageFilter |
| from torch.hub import download_url_to_file |
| from torchvision import transforms |
| from torchvision.transforms import InterpolationMode |
|
|
| import folder_paths |
|
|
# Use ComfyUI's model management when it can be imported; otherwise install
# a stand-in exposing the same `comfy.model_management.get_torch_device()`
# attribute path that always reports the CPU device.
try:
    import comfy.model_management
except Exception as e:
    print(f"Warning: ComfyUI model management not available: {e}")
    COMFY_AVAILABLE = False

    class MockModelManagement:
        """Stand-in for ``comfy.model_management`` used when the import fails."""

        @staticmethod
        def get_torch_device():
            """Return the CPU torch device."""
            return torch.device('cpu')

    class MockComfy:
        """Stand-in for the ``comfy`` package; exposes ``model_management`` only."""

        model_management = MockModelManagement()

    comfy = MockComfy()
else:
    COMFY_AVAILABLE = True
|
|
# safetensors is the only checkpoint reader used in this file: the weight
# loader raises ImportError when SAFETENSORS_AVAILABLE is False, so the
# warning must not promise a torch.load fallback that does not exist.
try:
    from safetensors.torch import load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not available. SDMatte model weights cannot be loaded without it.")
|
|
# Both diffusers and transformers must import for DIFFUSERS_AVAILABLE to be
# True; model loading checks this flag before building the network.
try:
    import diffusers
    import transformers
except ImportError:
    DIFFUSERS_AVAILABLE = False
    print("Warning: diffusers/transformers not available. SDMatte functionality will be limited.")
else:
    DIFFUSERS_AVAILABLE = True
|
|
# Put <repo_root>/models/SDMatte at the front of sys.path, where repo_root is
# the parent of this file's directory.
# NOTE(review): presumably so modules stored there can be imported - confirm.
current_dir = Path(__file__).resolve().parent
repo_root = current_dir.parent
sdmatte_path = repo_root.joinpath("models", "SDMatte")
sys.path.insert(0, str(sdmatte_path))
|
|
# Checkpoint variants selectable in the node. Each entry records the download
# URL of the weights file, the local filename, and the Hugging Face repo id
# used to build the URLs of the component config files.
SDMATTE_MODELS = {
    "SDMatte": dict(
        model_url="https://huggingface.co/1038lab/SDMatte/resolve/main/SDMatte.safetensors",
        filename="SDMatte.safetensors",
        repo_id="1038lab/SDMatte",
    ),
    "SDMatte_plus": dict(
        model_url="https://huggingface.co/1038lab/SDMatte/resolve/main/SDMatte_plus.safetensors",
        filename="SDMatte_plus.safetensors",
        repo_id="1038lab/SDMatte",
    ),
}

# Sub-directories that must exist under the local SDMatte components folder.
REQUIRED_COMPONENTS = [
    "scheduler",
    "text_encoder",
    "tokenizer",
    "unet",
    "vae",
]
|
|
def get_or_download_model_file(filename, url, dirname):
    """Return a local path for *filename*, downloading it from *url* if needed.

    The file is first looked up through ``folder_paths.get_full_path``. If
    that yields nothing, ``<models_dir>/<dirname>`` is created and the file
    is downloaded there unless it already exists.

    Raises:
        RuntimeError: if the download fails.
    """
    registered = folder_paths.get_full_path(dirname, filename)
    if registered:
        return registered

    target_dir = os.path.join(folder_paths.models_dir, dirname)
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, filename)
    if os.path.exists(destination):
        return destination

    print(f"Downloading {filename} from {url} ...")
    try:
        download_url_to_file(url, destination)
    except Exception as e:
        raise RuntimeError(f"Failed to download {filename} from {url}: {e}")
    return destination
|
|
def ensure_model_components(model_name):
    """Make sure the SDMatte component config files exist locally.

    For every entry of ``REQUIRED_COMPONENTS`` the expected files are checked
    one by one under ``<models_dir>/RMBG/SDMatte/<component>/``. Anything
    absent is fetched from the Hugging Face repo of *model_name*. Checking per
    file (not just "directory is non-empty") lets a partially downloaded
    component be repaired on the next call.

    Download failures are reported with a warning and do not raise.

    Args:
        model_name: key of ``SDMATTE_MODELS``.

    Returns:
        The components directory path.
    """
    # Files expected inside each component directory.
    component_files = {
        "scheduler": ("scheduler_config.json",),
        "text_encoder": ("config.json",),
        "tokenizer": ("merges.txt", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"),
        "unet": ("config.json",),
        "vae": ("config.json",),
    }

    repo_id = SDMATTE_MODELS[model_name]["repo_id"]
    components_dir = os.path.join(folder_paths.models_dir, "RMBG", "SDMatte")

    # component -> list of file names that are not on disk yet
    missing_files = {}
    for component in REQUIRED_COMPONENTS:
        component_dir = os.path.join(components_dir, component)
        absent = [
            name
            for name in component_files.get(component, ())
            if not os.path.exists(os.path.join(component_dir, name))
        ]
        if absent:
            missing_files[component] = absent

    if not missing_files:
        return components_dir

    print(f"Downloading missing SDMatte components: {list(missing_files)}")
    base_url = f"https://huggingface.co/{repo_id}/resolve/main"

    for component, names in missing_files.items():
        component_dir = os.path.join(components_dir, component)
        os.makedirs(component_dir, exist_ok=True)
        for name in names:
            file_url = f"{base_url}/{component}/{name}"
            file_path = os.path.join(component_dir, name)
            try:
                print(f" Downloading {component}/{name}...")
                download_url_to_file(file_url, file_path)
            except Exception as e:
                # Best effort: keep going so the other files still download.
                print(f" Warning: Failed to download {name}: {e}")

    return components_dir
|
|
def process_mask(mask_image: Image.Image, invert_output: bool = False,
                 mask_blur: int = 0, mask_offset: int = 0) -> Image.Image:
    """Post-process a mask image: optional invert, then blur, then grow/shrink.

    Args:
        mask_image: mask to process.
        invert_output: when True, replace every value v with ``255 - v``.
        mask_blur: Gaussian blur radius; applied only when greater than 0.
        mask_offset: positive values apply a max filter, negative values a
            min filter; 0 leaves the mask untouched.

    Returns:
        The processed mask image.
    """
    if invert_output:
        mask_image = Image.fromarray(255 - np.array(mask_image))

    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    if mask_offset != 0:
        passes = abs(mask_offset)
        kernel = passes * 2 + 1
        morph = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        # NOTE(review): a (2k+1)-wide filter is applied k times, so the mask
        # edge moves by about k*k pixels rather than k - confirm this is intended.
        for _ in range(passes):
            mask_image = mask_image.filter(morph(kernel))

    return mask_image
|
|
def apply_background_color(image: Image.Image, mask_image: Image.Image,
                           background: str = "Alpha",
                           background_color: str = "#222222") -> Image.Image:
    """Attach *mask_image* as alpha and optionally flatten onto a solid colour.

    Args:
        image: source image.
        mask_image: mask used as the alpha channel (converted to mode ``L``).
        background: ``"Color"`` composites onto *background_color* and returns
            an RGB image; any other value returns the RGBA cut-out.
        background_color: hex colour, with or without a leading ``#``. Accepts
            ``RRGGBB`` (extra trailing digits are ignored, as before) and the
            shorthand ``RGB`` / ``RGBA`` forms.

    Returns:
        An RGB image when ``background == "Color"``, otherwise an RGBA image.

    Raises:
        ValueError: if *background_color* is not a valid hex colour.
    """
    # convert() already returns a new image, so the input is never modified.
    rgba_image = image.convert('RGBA')
    rgba_image.putalpha(mask_image.convert('L'))

    if background != "Color":
        return rgba_image

    digits = background_color.strip().lstrip('#')
    if len(digits) in (3, 4):
        # Expand shorthand "abc" -> "aabbcc".
        digits = ''.join(ch * 2 for ch in digits)
    try:
        if len(digits) < 6:
            raise ValueError(digits)
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        raise ValueError(
            f"Invalid background_color {background_color!r}: expected a hex colour such as '#222222'"
        ) from None

    bg_image = Image.new('RGBA', image.size, (r, g, b, 255))
    return Image.alpha_composite(bg_image, rgba_image).convert('RGB')
|
|
def pil2tensor(image):
    """Convert an image to a float32 tensor scaled by 1/255 with a leading batch axis."""
    pixels = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
|
|
def tensor2pil(image):
    """Convert a tensor to a PIL image.

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8.
    All size-1 axes are squeezed away before the image is built.
    """
    scaled = 255. * image.cpu().numpy().squeeze()
    as_bytes = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
|
# Placeholder for the SDMatte model class. It stays None until
# AILab_SDMatte.load_sdmatte_model imports the class on first use and
# rebinds this global.
SDMatteCore = None
|
|
def _resize_norm_image_bchw(image_bchw: torch.Tensor, size_hw=(1024, 1024)) -> torch.Tensor:
    """Resize an image batch to *size_hw* and normalise it with mean 0.5 / std 0.5.

    When axis 1 has exactly four entries only the first three are kept.
    Resizing is bilinear with antialiasing.
    """
    has_fourth_channel = image_bchw.shape[1] == 4
    rgb = image_bchw[:, :3, :, :] if has_fourth_channel else image_bchw

    to_size = transforms.Resize(size_hw, interpolation=InterpolationMode.BILINEAR, antialias=True)
    to_unit_range = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return to_unit_range(to_size(rgb))
|
|
def _resize_mask_b1hw(mask_b1hw: torch.Tensor, size_hw=(1024, 1024)) -> torch.Tensor:
    """Resize a mask tensor to *size_hw* with antialiased bilinear interpolation."""
    return transforms.Resize(
        size_hw,
        interpolation=InterpolationMode.BILINEAR,
        antialias=True,
    )(mask_b1hw)
|
|
class AILab_SDMatte:
    """ComfyUI node that predicts an alpha matte with SDMatte.

    The model is guided by a mask (or by the image's fourth channel when no
    mask is given). The node returns the composited image, the matte as a
    MASK, and a 3-channel visualisation of the matte.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (types, defaults, tooltips) for ComfyUI."""
        tooltips = {
            "model": "SDMatte model variant: Standard or Plus version",
            "image": "Input image for matting extraction",
            "mask": "Mask: White=foreground, Black=background. If omitted and image has alpha, alpha will be used.",
            "process_res": "Processing resolution: higher = better quality but slower",
            "device": "Auto: smart detection, CPU: force CPU, GPU: force GPU",
            "transparent_object": "Whether input image contains transparent objects",
            "mask_refine": "Enable mask refinement using mask constraints",
            "sensitivity": "Sensitivity for mask constraint (0.1-1.0): higher = more strict",
            "mask_blur": "Blur mask edges (0 = disabled)",
            "mask_offset": "Expand/shrink mask (positive = expand)",
            "invert_output": "Invert the mask output",
            "background": "Background type for output",
            "background_color": "Background color (when not Alpha)",
        }
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (list(SDMATTE_MODELS.keys()), {"default": "SDMatte", "tooltip": tooltips["model"]}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto", "tooltip": tooltips["device"]}),
                "process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 8, "tooltip": tooltips["process_res"]}),
            },
            "optional": {
                "mask": ("MASK", {"tooltip": tooltips["mask"]}),
                "transparent_object": ("BOOLEAN", {"default": True, "tooltip": tooltips["transparent_object"]}),
                "mask_refine": ("BOOLEAN", {"default": True, "tooltip": tooltips["mask_refine"]}),
                "sensitivity": ("FLOAT", {"default": 0.9, "min": 0.1, "max": 1.0, "step": 0.1, "tooltip": tooltips["sensitivity"]}),
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
                "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
                "background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
                "background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "matting_inference"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # (model_name, device option) -> loaded model
        self.model_cache = {}

    @staticmethod
    def _resolve_device(device, warn=False):
        """Map the UI device option ("Auto", "CPU", "GPU") to a torch device.

        "CPU" forces the CPU. "GPU" falls back to the CPU when CUDA is not
        available (printing a notice when *warn* is True). Everything else
        uses the device reported by comfy's model management.
        """
        if device == "CPU":
            return torch.device('cpu')
        if device == "GPU" and not torch.cuda.is_available():
            if warn:
                print("SDMatte: GPU requested but CUDA not available, falling back to CPU")
            return torch.device('cpu')
        return comfy.model_management.get_torch_device()

    def _model_device(self, sdmatte_model, device):
        """Return the device the loaded model lives on.

        Reading it from the model keeps the inputs on the same device as the
        weights even when the requested device was replaced by a fallback at
        load time. Falls back to resolving *device* if the model exposes no
        parameters.
        """
        try:
            return next(sdmatte_model.parameters()).device
        except (StopIteration, AttributeError):
            return self._resolve_device(device)

    def _evict_other_models(self, model_name):
        """Drop cached models whose name differs from *model_name* and free memory.

        Names are compared exactly, so "SDMatte" and "SDMatte_plus" are
        treated as different models.
        """
        import gc

        stale = [key for key in self.model_cache if key[0] != model_name]
        for key in stale:
            del self.model_cache[key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_sdmatte_model(self, model_name, device="Auto"):
        """Return the SDMatte model for *model_name* on *device*, loading it once.

        Loaded models are cached per (model_name, device). Before a new model
        is loaded, cached entries belonging to other model names are evicted.

        Raises:
            ImportError: if diffusers/transformers (or safetensors) are missing.
            RuntimeError: if the weights cannot be downloaded or loaded.
        """
        global SDMatteCore

        cache_key = (model_name, device)
        if cache_key in self.model_cache:
            return self.model_cache[cache_key]

        if self.model_cache:
            self._evict_other_models(model_name)

        if not DIFFUSERS_AVAILABLE:
            raise ImportError("diffusers and transformers are required for SDMatte functionality")

        if SDMatteCore is None:
            # Import lazily; this file's directory must be importable for the
            # `SDMatte` package import below.
            module_dir = os.path.dirname(__file__)
            if module_dir not in sys.path:
                sys.path.insert(0, module_dir)
            from SDMatte.modeling.SDMatte.meta_arch import SDMatte as SDMatteCore

        model_info = SDMATTE_MODELS[model_name]
        model_path = get_or_download_model_file(
            model_info["filename"],
            model_info["model_url"],
            "RMBG/SDMatte"
        )
        pretrained_repo = ensure_model_components(model_name)

        sdmatte_model = SDMatteCore(
            pretrained_model_name_or_path=pretrained_repo,
            load_weight=False,
            use_aux_input=True,
            aux_input="trimap",
            use_encoder_hidden_states=True,
            use_attention_mask=True,
            add_noise=False,
        )
        self._load_model_weights(sdmatte_model, model_path)

        device_obj = self._resolve_device(device, warn=True)
        sdmatte_model.eval()
        sdmatte_model.to(device_obj)
        if device_obj.type == 'cuda':
            self._apply_memory_optimizations(sdmatte_model)

        self.model_cache[cache_key] = sdmatte_model
        return sdmatte_model

    def _load_model_weights(self, model, model_path):
        """Load the safetensors checkpoint at *model_path* into *model*.

        The file is deleted only when it cannot be read at all, so the next
        run re-downloads it. A failure while applying the weights to the model
        leaves the file in place, since that is not evidence of corruption.

        Raises:
            ImportError: if safetensors is not installed.
            RuntimeError: if reading or applying the weights fails.
        """
        if not SAFETENSORS_AVAILABLE:
            raise ImportError("safetensors is required for SDMatte functionality")

        try:
            state_dict = load_file(model_path)
        except Exception as e:
            if os.path.exists(model_path):
                print(f"[SDMatte] Model file appears corrupted, deleting: {model_path}")
                try:
                    os.remove(model_path)
                except OSError as remove_error:
                    # Do not let a failed cleanup hide the real loading error.
                    print(f"[SDMatte] Could not delete {model_path}: {remove_error}")
            raise RuntimeError(f"Failed to load model weights. File may be corrupted. Please try again to re-download. Error: {e}") from e

        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load model weights. Checkpoint does not match the model. Error: {e}") from e

    def _apply_memory_optimizations(self, model):
        """Best-effort memory savings: clear the CUDA cache and slice attention.

        Failures are reported but never raised.
        """
        try:
            torch.cuda.empty_cache()
        except Exception as e:
            print(f"[SDMatte] Could not clear CUDA cache: {e}")

        try:
            unet = getattr(model, 'unet', None)
            if unet is not None and hasattr(unet, 'set_attn_processor'):
                from diffusers.models.attention_processor import SlicedAttnProcessor
                unet.set_attn_processor(SlicedAttnProcessor(slice_size=1))
        except Exception as e:
            print(f"[SDMatte] Sliced attention not enabled: {e}")

    @staticmethod
    def _select_guidance_mask(image, mask, index):
        """Return the guidance mask for batch item *index* with a leading axis of 1.

        Uses *mask* when given; a mask batch shorter than the image batch
        reuses its last entry. Without a mask the image's fourth channel is
        used.

        Raises:
            ValueError: if there is neither a usable mask nor a fourth channel.
        """
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
            if mask.shape[0] == 0:
                raise ValueError("Mask required: provide a mask or use an image with alpha.")
            idx = min(index, mask.shape[0] - 1)
            return mask[idx:idx + 1]
        if image.shape[-1] == 4:
            return image[index, :, :, 3].unsqueeze(0)
        raise ValueError("Mask required: provide a mask or use an image with alpha.")

    @staticmethod
    def _fit_guidance(guide, size_hw):
        """Bilinearly resize *guide* (leading axis, H, W) to *size_hw* when its size differs."""
        size_hw = tuple(size_hw)
        if tuple(guide.shape[-2:]) == size_hw:
            return guide
        resized = F.interpolate(guide.unsqueeze(1).float(), size=size_hw,
                                mode="bilinear", align_corners=False)
        return resized.squeeze(1)

    def matting_inference(self, image, model, process_res, device="Auto",
                          mask=None, transparent_object=True, mask_refine=True,
                          sensitivity=0.8, mask_blur=0, mask_offset=0,
                          invert_output=False, background="Alpha", background_color="#222222"):
        """Run SDMatte on every image of the batch.

        NOTE: the signature default ``sensitivity=0.8`` differs from the 0.9
        default declared in INPUT_TYPES; it is kept for backward compatibility.

        Returns:
            Tuple ``(images, masks, mask_images)`` concatenated along the
            batch axis.

        Raises:
            ValueError: if no mask is given and the image has no fourth channel.
        """
        import gc

        sdmatte_model = self.load_sdmatte_model(model, device)
        device_obj = self._model_device(sdmatte_model, device)
        target_hw = (int(process_res), int(process_res))

        result_masks = []
        result_images = []
        result_mask_images = []

        for b in range(image.shape[0]):
            img_pil = tensor2pil(image[b])
            orig_h, orig_w = img_pil.height, img_pil.width

            img_bchw = image[b:b+1].permute(0, 3, 1, 2).contiguous().to(device_obj)
            img_in = _resize_norm_image_bchw(img_bchw, target_hw)

            guide = self._select_guidance_mask(image, mask, b)
            mask_b1hw = guide.unsqueeze(1).contiguous().to(device_obj)

            # Rescale the resized mask from [0, 1] to [-1, 1] for the "trimap" input.
            tri = _resize_mask_b1hw(mask_b1hw, target_hw) * 2 - 1
            data = {"image": img_in,
                    "is_trans": torch.tensor([1 if transparent_object else 0], device=device_obj),
                    "caption": [""],
                    "trimap": tri,
                    "trimap_coords": torch.tensor([[0, 0, 1, 1]], dtype=tri.dtype, device=device_obj)}

            with torch.inference_mode():
                if device_obj.type == 'cuda':
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        pred_alpha = sdmatte_model(data)
                else:
                    pred_alpha = sdmatte_model(data)

            out = transforms.Resize((orig_h, orig_w), interpolation=InterpolationMode.BILINEAR, antialias=True)(pred_alpha)
            # float() undoes a possible float16 result from autocast.
            out = out.squeeze(1).clamp(0, 1).detach().cpu().float()

            if mask_refine:
                # The guide may have a different resolution than the image.
                refine_guide = self._fit_guidance(guide, out.shape[-2:])
                out = self._refine_mask(out, refine_guide, sensitivity)

            # A 2-D uint8 array is read as mode "L"; no explicit mode needed.
            mask_pil = Image.fromarray((out[0].numpy() * 255).astype(np.uint8))
            mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)

            result_image = apply_background_color(img_pil, mask_image, background, background_color)
            result_image = result_image.convert("RGB" if background == "Color" else "RGBA")

            mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
            mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)

            result_masks.append(mask_tensor)
            result_images.append(pil2tensor(result_image))
            result_mask_images.append(mask_image_vis)

            if device_obj.type == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()

        return (torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0))

    def _refine_mask(self, mask, trimap, constraint):
        """Constrain a predicted matte with the guidance mask.

        Guide values above *constraint* count as foreground, values below
        ``1 - constraint`` as background, the rest as unknown. Background
        pixels are zeroed, foreground pixels are multiplied by 1.2 and clamped
        to [0, 1], and unknown pixels below 0.3 are zeroed.

        Args:
            mask: predicted matte; not modified (a clone is returned).
            trimap: guidance mask with the same shape as *mask*.
            constraint: threshold in (0, 1].

        Returns:
            The refined matte.
        """
        trimap_cpu = trimap.cpu()
        foreground_regions = trimap_cpu > constraint
        background_regions = trimap_cpu < (1.0 - constraint)
        unknown_regions = ~(foreground_regions | background_regions)

        refined_mask = mask.clone()
        # Background is zeroed first; where both regions overlap
        # (constraint < 0.5) the foreground boost then scales zeros.
        refined_mask[background_regions] = 0.0
        refined_mask[foreground_regions] = torch.clamp(refined_mask[foreground_regions] * 1.2, 0, 1)

        alpha_threshold = 0.3
        low_confidence = (refined_mask < alpha_threshold) & unknown_regions
        refined_mask[low_confidence] = 0.0

        return refined_mask
|
|
|
|
# Node registration tables: node identifier -> implementing class, and
# node identifier -> display name.
NODE_CLASS_MAPPINGS = {"AILab_SDMatte": AILab_SDMatte}
NODE_DISPLAY_NAME_MAPPINGS = {"AILab_SDMatte": "SDMatte Matting (RMBG)"}
|
|