| import os
|
| import sys
|
| import copy
|
| from pathlib import Path
|
| import torch
|
| import numpy as np
|
| from PIL import Image, ImageFilter
|
| from torch.hub import download_url_to_file
|
| from safetensors.torch import load_file
|
|
|
| import folder_paths
|
| import comfy.model_management
|
|
|
| from hydra import initialize_config_dir
|
| from hydra.core.global_hydra import GlobalHydra
|
|
|
| try:
|
| from groundingdino.util.slconfig import SLConfig
|
| from groundingdino.models import build_model
|
| from groundingdino.util.utils import clean_state_dict
|
| from groundingdino.util import box_ops
|
| from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
|
| GROUNDINGDINO_AVAILABLE = True
|
| except ImportError:
|
| GROUNDINGDINO_AVAILABLE = False
|
| print("Warning: GroundingDINO not available. Text prompts will use fallback method.")
|
|
|
| current_dir = Path(__file__).resolve().parent
|
| repo_root = current_dir.parent
|
| models_path = repo_root / "models"
|
| sam2_path = models_path / "sam2"
|
| sys.path.insert(0, str(models_path))
|
|
|
| from contextlib import contextmanager
|
|
|
| @contextmanager
|
| def _sam2_no_jit():
|
| _orig = torch.jit.script
|
| torch.jit.script = lambda x, *a, **k: x
|
| try:
|
| yield
|
| finally:
|
| torch.jit.script = _orig
|
|
|
| from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| from AILab_ImageMaskTools import pil2tensor, tensor2pil
|
|
|
|
|
| SAM2_MODELS = {
|
| "sam2.1_hiera_tiny": {
|
| "fp32": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_tiny.safetensors",
|
| "filename": "sam2.1_hiera_tiny.safetensors"
|
| },
|
| "fp16": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_tiny-fp16.safetensors",
|
| "filename": "sam2.1_hiera_tiny-fp16.safetensors"
|
| }
|
| },
|
| "sam2.1_hiera_small": {
|
| "fp32": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_small.safetensors",
|
| "filename": "sam2.1_hiera_small.safetensors"
|
| },
|
| "fp16": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_small-fp16.safetensors",
|
| "filename": "sam2.1_hiera_small-fp16.safetensors"
|
| }
|
| },
|
| "sam2.1_hiera_base_plus": {
|
| "fp32": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_base_plus.safetensors",
|
| "filename": "sam2.1_hiera_base_plus.safetensors"
|
| },
|
| "fp16": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_base_plus-fp16.safetensors",
|
| "filename": "sam2.1_hiera_base_plus-fp16.safetensors"
|
| }
|
| },
|
| "sam2.1_hiera_large": {
|
| "fp32": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_large.safetensors",
|
| "filename": "sam2.1_hiera_large.safetensors"
|
| },
|
| "fp16": {
|
| "model_url": "https://huggingface.co/1038lab/sam2/resolve/main/sam2.1_hiera_large-fp16.safetensors",
|
| "filename": "sam2.1_hiera_large-fp16.safetensors"
|
| }
|
| }
|
| }
|
|
|
|
|
| DINO_MODELS = {
|
| "GroundingDINO_SwinT_OGC (694MB)": {
|
| "config_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
|
| "model_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/groundingdino_swint_ogc.safetensors",
|
| "config_filename": "GroundingDINO_SwinT_OGC.cfg.py",
|
| "model_filename": "groundingdino_swint_ogc.safetensors"
|
| },
|
| "GroundingDINO_SwinB (938MB)": {
|
| "config_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
|
| "model_url": "https://huggingface.co/1038lab/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.safetensors",
|
| "config_filename": "GroundingDINO_SwinB.cfg.py",
|
| "model_filename": "groundingdino_swinb_cogcoor.safetensors"
|
| }
|
| }
|
|
|
| def get_or_download_model_file(filename, url, dirname):
|
| local_path = folder_paths.get_full_path(dirname, filename)
|
| if local_path:
|
| return local_path
|
| folder = os.path.join(folder_paths.models_dir, dirname)
|
| os.makedirs(folder, exist_ok=True)
|
| local_path = os.path.join(folder, filename)
|
| if not os.path.exists(local_path):
|
| print(f"Downloading {filename} from {url} ...")
|
| download_url_to_file(url, local_path)
|
| return local_path
|
|
|
| def process_mask(mask_image: Image.Image, invert_output: bool = False,
|
| mask_blur: int = 0, mask_offset: int = 0) -> Image.Image:
|
| if invert_output:
|
| mask_np = np.array(mask_image)
|
| mask_image = Image.fromarray(255 - mask_np)
|
| if mask_blur > 0:
|
| mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
| if mask_offset != 0:
|
| filter_type = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
|
| size = abs(mask_offset) * 2 + 1
|
| for _ in range(abs(mask_offset)):
|
| mask_image = mask_image.filter(filter_type(size))
|
| return mask_image
|
|
|
| def apply_background_color(image: Image.Image, mask_image: Image.Image,
|
| background: str = "Alpha",
|
| background_color: str = "#222222") -> Image.Image:
|
| rgba_image = image.copy().convert('RGBA')
|
| rgba_image.putalpha(mask_image.convert('L'))
|
|
|
| if background == "Color":
|
| def hex_to_rgba(hex_color):
|
| hex_color = hex_color.lstrip('#')
|
| r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| return (r, g, b, 255)
|
| rgba = hex_to_rgba(background_color)
|
| bg_image = Image.new('RGBA', image.size, rgba)
|
| composite_image = Image.alpha_composite(bg_image, rgba_image)
|
| return composite_image.convert('RGB')
|
| return rgba_image
|
|
|
|
|
| class SAM2Segment:
|
| @classmethod
|
| def INPUT_TYPES(cls):
|
| tooltips = {
|
| "prompt": "Enter text description of object to segment",
|
| "sam2_model": "SAM2 model size: Tiny (fastest) to Large (best quality)",
|
| "device": "Auto: smart detection, CPU: force CPU, GPU: force GPU",
|
| "dino_model": "GroundingDINO model for text-to-box detection",
|
| "threshold": "Detection threshold (higher = more strict)",
|
| "mask_blur": "Blur mask edges (0 = disabled)",
|
| "mask_offset": "Expand/shrink mask (positive = expand)",
|
| "invert_output": "Invert the mask output",
|
| "background": "Background type",
|
| "background_color": "Background color (when not Alpha)",
|
| }
|
| return {
|
| "required": {
|
| "image": ("IMAGE",),
|
| "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Object to segment", "tooltip": tooltips["prompt"]}),
|
| "sam2_model": (list(SAM2_MODELS.keys()), {"default": "sam2.1_hiera_tiny", "tooltip": tooltips["sam2_model"]}),
|
| "dino_model": (list(DINO_MODELS.keys()), {"default": "GroundingDINO_SwinT_OGC (694MB)", "tooltip": tooltips["dino_model"]}),
|
| "device": (["Auto", "CPU", "GPU"], {"default": "Auto", "tooltip": tooltips["device"]}),
|
| },
|
| "optional": {
|
| "threshold": ("FLOAT", {"default": 0.35, "min": 0.05, "max": 0.95, "step": 0.01, "tooltip": tooltips["threshold"]}),
|
| "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
|
| "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
|
| "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
|
| "background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
|
| "background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
|
| }
|
| }
|
|
|
| RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
|
| RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
|
| FUNCTION = "segment_v2"
|
| CATEGORY = "🧪AILab/🧽RMBG"
|
|
|
| def __init__(self):
|
| self.dino_model_cache = {}
|
| self.sam2_model_cache = {}
|
|
|
| def load_sam2(self, model_name, device="Auto"):
|
| cache_key = f"{model_name}_{device}"
|
| if cache_key not in self.sam2_model_cache:
|
| model_info = SAM2_MODELS[model_name]
|
| device_obj = comfy.model_management.get_torch_device()
|
|
|
|
|
| if device == "Auto":
|
| precision = "fp16" if hasattr(device_obj, 'type') and device_obj.type == 'cuda' else "fp32"
|
| elif device == "GPU":
|
| precision = "fp16" if hasattr(device_obj, 'type') and device_obj.type == 'cuda' else "fp32"
|
| else:
|
| precision = "fp32"
|
|
|
| print(f"Loading {model_name} in {precision.upper()} precision")
|
|
|
| model_path = get_or_download_model_file(model_info[precision]["filename"], model_info[precision]["model_url"], "sam2")
|
|
|
|
|
| if GlobalHydra().is_initialized():
|
| GlobalHydra.instance().clear()
|
|
|
| initialize_config_dir(config_dir=os.path.join(sam2_path, "configs"), job_name="sam2")
|
|
|
| config_map = {
|
| "sam2.1_hiera_tiny": "sam2.1/sam2.1_hiera_t.yaml",
|
| "sam2.1_hiera_small": "sam2.1/sam2.1_hiera_s.yaml",
|
| "sam2.1_hiera_base_plus": "sam2.1/sam2.1_hiera_b+.yaml",
|
| "sam2.1_hiera_large": "sam2.1/sam2.1_hiera_l.yaml"
|
| }
|
|
|
| config_file = config_map[model_name]
|
| sam_device = comfy.model_management.get_torch_device()
|
|
|
| from sam2.build_sam import build_sam2
|
| from hydra import compose
|
| from omegaconf import OmegaConf
|
| from hydra.utils import instantiate
|
|
|
| cfg = compose(config_name=config_file)
|
| OmegaConf.resolve(cfg)
|
| sam_model = instantiate(cfg.model, _recursive_=True)
|
|
|
| state_dict = load_file(model_path)
|
|
|
|
|
| dtype = {"fp16": torch.float16, "fp32": torch.float32}[precision]
|
| sam_model.load_state_dict(state_dict, strict=False)
|
| sam_model = sam_model.to(dtype).to(sam_device).eval()
|
|
|
|
|
| with _sam2_no_jit():
|
| predictor = SAM2ImagePredictor(sam_model)
|
|
|
| self.sam2_model_cache[cache_key] = predictor
|
| return self.sam2_model_cache[cache_key]
|
|
|
| def segment_v2(self, image, prompt, sam2_model, dino_model, device, threshold=0.35,
|
| mask_blur=0, mask_offset=0, background="Alpha",
|
| background_color="#222222", invert_output=False):
|
| device_obj = comfy.model_management.get_torch_device()
|
|
|
|
|
| batch_size = image.shape[0] if len(image.shape) == 4 else 1
|
| if len(image.shape) == 3:
|
| image = image.unsqueeze(0)
|
|
|
| result_images = []
|
| result_masks = []
|
| result_mask_images = []
|
|
|
| for b in range(batch_size):
|
| img_pil = tensor2pil(image[b])
|
| img_np = np.array(img_pil.convert("RGB"))
|
|
|
|
|
| dino_info = DINO_MODELS[dino_model]
|
| config_path = get_or_download_model_file(dino_info["config_filename"], dino_info["config_url"], "grounding-dino")
|
| weights_path = get_or_download_model_file(dino_info["model_filename"], dino_info["model_url"], "grounding-dino")
|
|
|
|
|
| dino_key = (config_path, weights_path, device_obj)
|
| if dino_key not in self.dino_model_cache:
|
| args = SLConfig.fromfile(config_path)
|
| model = build_model(args)
|
| checkpoint = load_file(weights_path)
|
| if isinstance(checkpoint, dict) and 'model' in checkpoint:
|
| checkpoint = clean_state_dict(checkpoint['model'])
|
| model.load_state_dict(checkpoint, strict=False)
|
| model.eval()
|
| model.to(device_obj)
|
| self.dino_model_cache[dino_key] = model
|
| dino = self.dino_model_cache[dino_key]
|
|
|
|
|
| predictor = self.load_sam2(sam2_model, device)
|
|
|
|
|
| transform = Compose([
|
| RandomResize([800], max_size=1333),
|
| ToTensor(),
|
| Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
| ])
|
| image_tensor, _ = transform(img_pil.convert("RGB"), None)
|
| image_tensor = image_tensor.unsqueeze(0).to(device_obj)
|
|
|
|
|
| text_prompt = prompt if prompt.endswith(".") else prompt + "."
|
|
|
|
|
| with torch.no_grad():
|
| outputs = dino(image_tensor, captions=[text_prompt])
|
| logits = outputs["pred_logits"].sigmoid()[0]
|
| boxes = outputs["pred_boxes"][0]
|
|
|
|
|
| filt_mask = logits.max(dim=1)[0] > threshold
|
| boxes_filt = boxes[filt_mask]
|
|
|
|
|
| if boxes_filt.shape[0] == 0:
|
| width, height = img_pil.size
|
| empty_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")
|
| empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
|
| result_image = apply_background_color(img_pil, Image.fromarray((empty_mask[0].numpy() * 255).astype(np.uint8)), background, background_color)
|
| result_images.append(pil2tensor(result_image))
|
| result_masks.append(empty_mask)
|
| result_mask_images.append(empty_mask_rgb)
|
| continue
|
|
|
|
|
| H, W = img_pil.size[1], img_pil.size[0]
|
| boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_filt)
|
| boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32, device=boxes_xyxy.device)
|
| boxes_xyxy = boxes_xyxy.cpu().numpy()
|
|
|
|
|
| from contextlib import nullcontext
|
|
|
| if device == "Auto":
|
| precision = "fp16" if hasattr(device_obj, 'type') and device_obj.type == 'cuda' else "fp32"
|
| elif device == "GPU":
|
| precision = "fp16" if hasattr(device_obj, 'type') and device_obj.type == 'cuda' else "fp32"
|
| else:
|
| precision = "fp32"
|
|
|
| autocast_condition = not comfy.model_management.is_device_mps(device_obj)
|
| with torch.autocast(comfy.model_management.get_autocast_device(device_obj), dtype=torch.float16 if precision == "fp16" else torch.float32) if autocast_condition else nullcontext():
|
| predictor.set_image(img_pil)
|
|
|
|
|
| all_masks = []
|
| for box in boxes_xyxy:
|
| with torch.no_grad():
|
| masks, iou_predictions, low_res_masks = predictor.predict(
|
| point_coords=None,
|
| point_labels=None,
|
| box=box,
|
| multimask_output=False
|
| )
|
| all_masks.append(masks)
|
|
|
|
|
| if len(all_masks) == 1:
|
| mask = all_masks[0]
|
| else:
|
| combined_mask = np.zeros_like(all_masks[0])
|
| for single_mask in all_masks:
|
| combined_mask = np.maximum(combined_mask, single_mask)
|
| mask = combined_mask
|
|
|
|
|
| if mask.ndim > 2:
|
| mask = mask.squeeze()
|
|
|
| mask = (mask * 255).astype(np.uint8)
|
| mask_pil = Image.fromarray(mask, mode="L")
|
|
|
|
|
| mask_image = process_mask(mask_pil, invert_output, mask_blur, mask_offset)
|
| result_image = apply_background_color(img_pil, mask_image, background, background_color)
|
| if background == "Color":
|
| result_image = result_image.convert("RGB")
|
| else:
|
| result_image = result_image.convert("RGBA")
|
|
|
|
|
| mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
|
| mask_image_vis = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
|
|
|
| result_images.append(pil2tensor(result_image))
|
| result_masks.append(mask_tensor)
|
| result_mask_images.append(mask_image_vis)
|
|
|
|
|
| if len(result_images) == 0:
|
| width, height = tensor2pil(image[0]).size
|
| empty_mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32, device="cpu")
|
| empty_mask_rgb = empty_mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
|
| return (image, empty_mask, empty_mask_rgb)
|
|
|
|
|
| return (torch.cat(result_images, dim=0),
|
| torch.cat(result_masks, dim=0),
|
| torch.cat(result_mask_images, dim=0))
|
|
|
| NODE_CLASS_MAPPINGS = {
|
| "SAM2Segment": SAM2Segment,
|
| }
|
|
|
| NODE_DISPLAY_NAME_MAPPINGS = {
|
| "SAM2Segment": "SAM2 Segmentation (RMBG)",
|
| }
|
|
|