| import os |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw |
|
|
| import folder_paths |
| from AILab_ImageMaskTools import pil2tensor, tensor2pil |
|
|
| ULTRALYTICS_DIR = os.path.join(folder_paths.models_dir, "ultralytics") |
| YOLO_LEGACY_DIR = os.path.join(folder_paths.models_dir, "yolo") |
| os.makedirs(ULTRALYTICS_DIR, exist_ok=True) |
| os.makedirs(YOLO_LEGACY_DIR, exist_ok=True) |
| folder_paths.add_model_folder_path("ultralytics", ULTRALYTICS_DIR, is_default=True) |
| folder_paths.add_model_folder_path("ultralytics", YOLO_LEGACY_DIR) |
|
|
| DEVICE_CHOICES = ("auto", "cuda", "cpu", "mps") |
| MASK_COUNT_CHOICES = ("all", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10") |
| MASK_INDEX_CHOICES = ("none", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10") |
|
|
|
|
| class AILab_YoloV8Adv: |
| CATEGORY = "🧪AILab/🧽RMBG" |
| RETURN_TYPES = ("IMAGE", "MASK", "MASK") |
| RETURN_NAMES = ("ANNOTATED_IMAGE", "MASK", "MASK_LIST") |
| FUNCTION = "yolo_detect" |
|
|
| _MODEL_CACHE: Dict[str, "YOLO"] = {} |
|
|
| @classmethod |
| def _list_models(cls) -> List[str]: |
| files = folder_paths.get_filename_list("ultralytics") |
| return sorted(f for f in files if f.lower().endswith(".pt")) |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| models = cls._list_models() |
| if not models: |
| models = [f"Put .pt models into {ULTRALYTICS_DIR}"] |
| default_model = models[0] |
|
|
| return { |
| "required": { |
| "images": ("IMAGE",), |
| "yolo_model": (tuple(models), {"default": default_model, "tooltip": f"YOLOv8 weights stored under {ULTRALYTICS_DIR} (subfolders allowed)."}), |
| "mask_count": (MASK_COUNT_CHOICES, {"default": "all", "tooltip": "Merge this many detections. 'all' merges everything (or just the selected index when specified)."}), |
| }, |
| "optional": { |
| "select_mask_index": (MASK_INDEX_CHOICES, {"default": "none", "tooltip": "1-based index of the first mask to keep. Use 'none' to start from the first detection."}), |
| "conf": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Confidence threshold forwarded to Ultralytics."}), |
| "iou": ("FLOAT", {"default": 0.45, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "IOU used during NMS."}), |
| "classes": ("STRING", {"default": "", "placeholder": "e.g. 0,2,5-7", "tooltip": "Comma list or ranges of class IDs; empty keeps every class."}), |
| "device": (DEVICE_CHOICES, {"default": "auto", "tooltip": "Force a device or auto-detect CUDA → MPS → CPU."}), |
| "max_det": ("INT", {"default": 300, "min": 1, "max": 1000, "step": 1, "tooltip": "Maximum detections per image."}), |
| "retina_masks": ("BOOLEAN", {"default": True, "tooltip": "Use high-resolution masks (Ultralytics retina_masks flag)."}), |
| "agnostic_nms": ("BOOLEAN", {"default": False, "tooltip": "Enable class-agnostic NMS."}), |
| }, |
| } |
|
|
| def _resolve_device(self, requested: str) -> str: |
| if requested != "auto": |
| return requested |
| if torch.cuda.is_available(): |
| return "cuda" |
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
| return "mps" |
| return "cpu" |
|
|
| def _parse_classes(self, value: str) -> Optional[List[int]]: |
| if not value or not value.strip(): |
| return None |
| classes: List[int] = [] |
| try: |
| for chunk in value.split(","): |
| chunk = chunk.strip() |
| if not chunk: |
| continue |
| if "-" in chunk: |
| start, end = [int(x) for x in chunk.split("-", 1)] |
| if start > end: |
| start, end = end, start |
| classes.extend(range(start, end + 1)) |
| else: |
| classes.append(int(chunk)) |
| return sorted(set(classes)) |
| except ValueError: |
| print(f"[AILab_YoloV8] Invalid classes string: {value}. Ignoring filter.") |
| return None |
|
|
| def _resolve_model_path(self, name: str) -> str: |
| return folder_paths.get_full_path_or_raise("ultralytics", name) |
|
|
| def _get_model(self, model_path: str): |
| model = self._MODEL_CACHE.get(model_path) |
| if model is None: |
| from ultralytics import YOLO |
| model = YOLO(model_path) |
| self._MODEL_CACHE[model_path] = model |
| return model |
|
|
| def _result_to_tensor(self, result) -> torch.Tensor: |
| plotted = result.plot() |
| rgb = plotted[..., ::-1] |
| return pil2tensor(Image.fromarray(rgb)) |
|
|
| def _mask_from_tensor(self, mask_tensor: torch.Tensor, size: Image.Image.size): |
| mask_np = mask_tensor.detach().cpu().numpy() |
| mask_img = Image.fromarray((mask_np * 255).astype(np.uint8)) |
| if mask_img.size != size: |
| mask_img = mask_img.resize(size, Image.Resampling.NEAREST) |
| return torch.from_numpy(np.array(mask_img).astype(np.float32) / 255.0) |
|
|
| def _collect_masks(self, result, size) -> List[torch.Tensor]: |
| width, height = size |
| masks: List[torch.Tensor] = [] |
|
|
| if getattr(result, "masks", None) is not None and result.masks.data is not None: |
| for mask_tensor in result.masks.data: |
| masks.append(self._mask_from_tensor(mask_tensor, size)) |
|
|
| elif getattr(result, "boxes", None) is not None and len(result.boxes.xyxy) > 0: |
| for box in result.boxes: |
| x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()] |
| mask_img = Image.new("L", size, 0) |
| draw = ImageDraw.Draw(mask_img) |
| draw.rectangle([x1, y1, x2, y2], fill=255) |
| masks.append(torch.from_numpy(np.array(mask_img).astype(np.float32) / 255.0)) |
|
|
| if not masks: |
| masks.append(torch.zeros((height, width), dtype=torch.float32)) |
|
|
| return masks |
|
|
| def _merge_masks(self, masks: List[torch.Tensor]) -> torch.Tensor: |
| if not masks: |
| raise ValueError("Cannot merge empty mask list.") |
| merged = torch.zeros_like(masks[0]) |
|
|
| for mask in masks: |
| merged = torch.maximum(merged, mask) |
|
|
| return merged |
|
|
| def yolo_detect( |
| self, |
| images, |
| yolo_model, |
| mask_count="all", |
| conf=0.25, |
| iou=0.45, |
| classes="", |
| device="auto", |
| max_det=300, |
| retina_masks=True, |
| agnostic_nms=False, |
| select_mask_index: str = "none", |
| ): |
| model_path = self._resolve_model_path(yolo_model) |
| model = self._get_model(model_path) |
| device_target = self._resolve_device(device) |
| class_filter = self._parse_classes(classes) |
|
|
| merged_masks: List[torch.Tensor] = [] |
| annotated_images: List[torch.Tensor] = [] |
| mask_list: List[torch.Tensor] = [] |
|
|
| count_limit = 0 if mask_count == "all" else max(0, int(mask_count)) |
| chosen_index: Optional[int] = None |
| if select_mask_index != "none": |
| chosen_index = int(select_mask_index) - 1 |
|
|
| for idx in range(images.shape[0]): |
| image_pil = tensor2pil(images[idx]) |
|
|
| results = model( |
| image_pil, |
| conf=conf, |
| iou=iou, |
| classes=class_filter, |
| device=device_target, |
| max_det=max_det, |
| retina_masks=retina_masks, |
| agnostic_nms=agnostic_nms, |
| ) |
|
|
| if not results: |
| continue |
|
|
| result = results[0] |
| annotated_images.append(self._result_to_tensor(result)) |
|
|
| frame_masks = self._collect_masks(result, image_pil.size) |
|
|
| selected_masks: List[torch.Tensor] |
| if chosen_index is None: |
| if count_limit <= 0 or count_limit >= len(frame_masks): |
| selected_masks = frame_masks |
| else: |
| selected_masks = frame_masks[:count_limit] |
| else: |
| if chosen_index >= len(frame_masks): |
| selected_masks = [] |
| else: |
| span = count_limit if count_limit > 0 else 1 |
| selected_masks = frame_masks[chosen_index : chosen_index + span] |
|
|
| if selected_masks: |
| merged_masks.append(self._merge_masks(selected_masks)) |
| mask_list.extend(selected_masks) |
| else: |
| fallback = torch.zeros_like(frame_masks[0]) |
| merged_masks.append(fallback) |
| mask_list.append(fallback) |
|
|
| if not merged_masks: |
| width, height = tensor2pil(images[0]).size |
| merged_masks = [torch.zeros((height, width), dtype=torch.float32)] |
|
|
| if not mask_list: |
| width, height = merged_masks[0].shape[1], merged_masks[0].shape[0] |
| mask_list = [torch.zeros((height, width), dtype=torch.float32)] |
|
|
| if not annotated_images: |
| annotated_images = [images] |
|
|
| merged_tensor = torch.stack(merged_masks, dim=0) |
| annotated_tensor = torch.cat(annotated_images, dim=0) |
| mask_tensor = torch.stack(mask_list, dim=0) |
|
|
| return annotated_tensor, merged_tensor, mask_tensor |
|
|
|
|
| class AILab_YoloV8(AILab_YoloV8Adv): |
| FUNCTION = "yolo_detect_simple" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| models = cls._list_models() |
| if not models: |
| models = [f"Put .pt models into {ULTRALYTICS_DIR}"] |
| default_model = models[0] |
|
|
| return { |
| "required": { |
| "images": ("IMAGE",), |
| "yolo_model": (tuple(models), {"default": default_model, "tooltip": f"YOLOv8 weights stored under {ULTRALYTICS_DIR}. Advanced controls available on YOLOv8 Adv."}), |
| "mask_count": (MASK_COUNT_CHOICES, {"default": "all", "tooltip": "Merge this many detections. 'all' merges everything (or just the selected index when specified)."}), |
| }, |
| "optional": { |
| "select_mask_index": (MASK_INDEX_CHOICES, {"default": "none", "tooltip": "1-based index of the first mask to keep. Use 'none' to start from the first detection."}), |
| }, |
| } |
|
|
| def yolo_detect_simple(self, images, yolo_model, mask_count="all", select_mask_index="none"): |
| return super().yolo_detect( |
| images=images, |
| yolo_model=yolo_model, |
| mask_count=mask_count, |
| conf=0.25, |
| iou=0.45, |
| classes="", |
| device="auto", |
| max_det=300, |
| retina_masks=True, |
| agnostic_nms=False, |
| select_mask_index=select_mask_index, |
| ) |
|
|
|
|
| NODE_CLASS_MAPPINGS = { |
| "AILab_YoloV8": AILab_YoloV8, |
| "AILab_YoloV8Adv": AILab_YoloV8Adv, |
| } |
|
|
| NODE_DISPLAY_NAME_MAPPINGS = { |
| "AILab_YoloV8": "YOLOv8 (RMBG)", |
| "AILab_YoloV8Adv": "YOLOv8 Adv (RMBG)", |
| } |
|
|