# Wan_Backup/custom_nodes/ComfyUI-RMBG/py/AILab_YoloV8.py
# Eji-Sensei14's picture
# Upload folder using huggingface_hub
# c6535db verified
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image, ImageDraw

import folder_paths
from AILab_ImageMaskTools import pil2tensor, tensor2pil
# Model weights are looked up in two folders below the models directory:
# "ultralytics" (registered as the default) and the older "yolo" folder.
ULTRALYTICS_DIR = os.path.join(folder_paths.models_dir, "ultralytics")
YOLO_LEGACY_DIR = os.path.join(folder_paths.models_dir, "yolo")

# Create both folders up front so they exist before being registered.
os.makedirs(ULTRALYTICS_DIR, exist_ok=True)
os.makedirs(YOLO_LEGACY_DIR, exist_ok=True)

# Both folders are registered under the single "ultralytics" key.
folder_paths.add_model_folder_path("ultralytics", ULTRALYTICS_DIR, is_default=True)
folder_paths.add_model_folder_path("ultralytics", YOLO_LEGACY_DIR)

# Drop-down choices offered by the nodes. The numbered entries are the
# strings "1" through "10"; the first entry is the "no limit" / "no index"
# sentinel that the node logic checks for.
DEVICE_CHOICES = ("auto", "cuda", "cpu", "mps")
MASK_COUNT_CHOICES = ("all", *(str(number) for number in range(1, 11)))
MASK_INDEX_CHOICES = ("none", *(str(number) for number in range(1, 11)))
class AILab_YoloV8Adv:
    """Node that runs an Ultralytics YOLO model and turns detections into masks.

    For each processed image the node emits an annotated preview, one merged
    mask, and the individual masks that went into the merge.
    """

    CATEGORY = "🧪AILab/🧽RMBG"
    RETURN_TYPES = ("IMAGE", "MASK", "MASK")
    RETURN_NAMES = ("ANNOTATED_IMAGE", "MASK", "MASK_LIST")
    FUNCTION = "yolo_detect"

    # Loaded models keyed by resolved weight path. This is a class attribute
    # that is only mutated (never rebound), so every instance shares it and
    # each weight file is loaded at most once per process.
    _MODEL_CACHE: Dict[str, "YOLO"] = {}

    @classmethod
    def _list_models(cls) -> List[str]:
        """Return the sorted ``.pt`` files registered under "ultralytics"."""
        files = folder_paths.get_filename_list("ultralytics")
        return sorted(f for f in files if f.lower().endswith(".pt"))

    @classmethod
    def _model_choices(cls) -> List[str]:
        """Return the model list, or a one-item hint when no model exists."""
        return cls._list_models() or [f"Put .pt models into {ULTRALYTICS_DIR}"]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (model, mask selection, inference options)."""
        models = cls._model_choices()
        default_model = models[0]
        return {
            "required": {
                "images": ("IMAGE",),
                "yolo_model": (
                    tuple(models),
                    {
                        "default": default_model,
                        "tooltip": f"YOLOv8 weights stored under {ULTRALYTICS_DIR} (subfolders allowed).",
                    },
                ),
                "mask_count": (
                    MASK_COUNT_CHOICES,
                    {
                        "default": "all",
                        "tooltip": "Merge this many detections. 'all' merges everything (or just the selected index when specified).",
                    },
                ),
            },
            "optional": {
                "select_mask_index": (
                    MASK_INDEX_CHOICES,
                    {
                        "default": "none",
                        "tooltip": "1-based index of the first mask to keep. Use 'none' to start from the first detection.",
                    },
                ),
                "conf": (
                    "FLOAT",
                    {
                        "default": 0.25,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "Confidence threshold forwarded to Ultralytics.",
                    },
                ),
                "iou": (
                    "FLOAT",
                    {
                        "default": 0.45,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "IOU used during NMS.",
                    },
                ),
                "classes": (
                    "STRING",
                    {
                        "default": "",
                        "placeholder": "e.g. 0,2,5-7",
                        "tooltip": "Comma list or ranges of class IDs; empty keeps every class.",
                    },
                ),
                "device": (
                    DEVICE_CHOICES,
                    {
                        "default": "auto",
                        "tooltip": "Force a device or auto-detect CUDA → MPS → CPU.",
                    },
                ),
                "max_det": (
                    "INT",
                    {
                        "default": 300,
                        "min": 1,
                        "max": 1000,
                        "step": 1,
                        "tooltip": "Maximum detections per image.",
                    },
                ),
                "retina_masks": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Use high-resolution masks (Ultralytics retina_masks flag).",
                    },
                ),
                "agnostic_nms": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Enable class-agnostic NMS.",
                    },
                ),
            },
        }

    def _resolve_device(self, requested: str) -> str:
        """Return ``requested`` unchanged, or pick CUDA, then MPS, then CPU for "auto"."""
        if requested != "auto":
            return requested
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _parse_classes(self, value: str) -> Optional[List[int]]:
        """Parse a class filter such as ``"0,2,5-7"`` into sorted unique IDs.

        Args:
            value: Comma-separated integers and ``start-end`` ranges. Ranges
                are inclusive and may be given in either order.

        Returns:
            The sorted, de-duplicated IDs, or ``None`` (meaning "no filter")
            when ``value`` is empty/blank or contains anything unparsable.
        """
        if not value or not value.strip():
            return None
        classes: List[int] = []
        try:
            for chunk in value.split(","):
                chunk = chunk.strip()
                if not chunk:
                    continue
                if "-" in chunk:
                    start, end = [int(x) for x in chunk.split("-", 1)]
                    if start > end:
                        start, end = end, start
                    classes.extend(range(start, end + 1))
                else:
                    classes.append(int(chunk))
            return sorted(set(classes))
        except ValueError:
            # Best effort: a malformed filter disables filtering instead of
            # aborting the whole node run.
            print(f"[AILab_YoloV8] Invalid classes string: {value}. Ignoring filter.")
            return None

    def _resolve_model_path(self, name: str) -> str:
        """Resolve a model name from the drop-down to a full path (raises if missing)."""
        return folder_paths.get_full_path_or_raise("ultralytics", name)

    def _get_model(self, model_path: str):
        """Return the cached model for ``model_path``, loading it on first use."""
        model = self._MODEL_CACHE.get(model_path)
        if model is None:
            # Imported lazily so the module can be imported without
            # ultralytics until a model is actually needed.
            from ultralytics import YOLO

            model = YOLO(model_path)
            self._MODEL_CACHE[model_path] = model
        return model

    def _result_to_tensor(self, result) -> torch.Tensor:
        """Render the annotated preview for ``result`` as an image tensor."""
        plotted = result.plot()
        # Reverse the channel axis; presumably plot() yields BGR and this
        # converts it to RGB — confirm against the Ultralytics docs.
        rgb = plotted[..., ::-1]
        return pil2tensor(Image.fromarray(rgb))

    def _mask_from_tensor(self, mask_tensor: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
        """Convert one model mask into a float32 CPU mask of the given size.

        Args:
            mask_tensor: Single mask from the model; values are expected to
                lie in [0, 1].
            size: Target ``(width, height)`` in PIL order.

        Returns:
            A float32 tensor with values in [0, 1], resized with
            nearest-neighbour sampling when its size differs from ``size``.
        """
        # Convert to float32 on the CPU first (``.numpy()`` is not available
        # for every dtype) and clamp so the uint8 cast below cannot wrap.
        mask_np = (
            mask_tensor.detach()
            .to(device="cpu", dtype=torch.float32)
            .clamp(0.0, 1.0)
            .numpy()
        )
        mask_img = Image.fromarray((mask_np * 255).astype(np.uint8))
        if mask_img.size != size:
            mask_img = mask_img.resize(size, Image.Resampling.NEAREST)
        return torch.from_numpy(np.array(mask_img).astype(np.float32) / 255.0)

    def _collect_masks(self, result, size) -> List[torch.Tensor]:
        """Build one mask per detection for a single result.

        Segmentation masks are used when the result has them; otherwise each
        bounding box is rasterised as a filled rectangle. When neither yields
        a mask, a single all-zero mask is returned so the list is never empty.

        Args:
            result: One Ultralytics result object.
            size: Image size as ``(width, height)``.
        """
        width, height = size
        masks: List[torch.Tensor] = []
        if getattr(result, "masks", None) is not None and result.masks.data is not None:
            for mask_tensor in result.masks.data:
                masks.append(self._mask_from_tensor(mask_tensor, size))
        elif getattr(result, "boxes", None) is not None and len(result.boxes.xyxy) > 0:
            for box in result.boxes:
                x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
                mask_img = Image.new("L", size, 0)
                draw = ImageDraw.Draw(mask_img)
                draw.rectangle([x1, y1, x2, y2], fill=255)
                masks.append(torch.from_numpy(np.array(mask_img).astype(np.float32) / 255.0))
        if not masks:
            masks.append(torch.zeros((height, width), dtype=torch.float32))
        return masks

    def _merge_masks(self, masks: List[torch.Tensor]) -> torch.Tensor:
        """Return the element-wise maximum of ``masks``, floored at zero.

        Raises:
            ValueError: If ``masks`` is empty.
        """
        if not masks:
            raise ValueError("Cannot merge empty mask list.")
        merged = torch.zeros_like(masks[0])
        for mask in masks:
            merged = torch.maximum(merged, mask)
        return merged

    @staticmethod
    def _select_masks(
        frame_masks: List[torch.Tensor],
        count_limit: int,
        chosen_index: Optional[int],
    ) -> List[torch.Tensor]:
        """Pick the masks of one frame according to the count/index settings.

        Args:
            frame_masks: All masks of the frame, in detection order.
            count_limit: Number of masks to keep; ``0`` means "no limit".
            chosen_index: 0-based index of the first mask to keep, or
                ``None`` to start from the first mask.

        Returns:
            The selected masks; empty when ``chosen_index`` is out of range.
        """
        if chosen_index is None:
            if count_limit <= 0 or count_limit >= len(frame_masks):
                return frame_masks
            return frame_masks[:count_limit]
        if chosen_index >= len(frame_masks):
            return []
        # With an explicit index, "all" (limit 0) keeps only that one mask.
        span = count_limit if count_limit > 0 else 1
        return frame_masks[chosen_index : chosen_index + span]

    def yolo_detect(
        self,
        images,
        yolo_model,
        mask_count="all",
        conf=0.25,
        iou=0.45,
        classes="",
        device="auto",
        max_det=300,
        retina_masks=True,
        agnostic_nms=False,
        select_mask_index: str = "none",
    ):
        """Run detection on every image of the batch and build the outputs.

        Args:
            images: Image batch; iterated along its first dimension.
            yolo_model: Model name as listed in the "ultralytics" folders.
            mask_count: "all" or a number (as string) of masks to merge.
            conf: Confidence threshold passed to the model.
            iou: IOU threshold passed to the model.
            classes: Class filter string, see ``_parse_classes``.
            device: "auto" or an explicit device name.
            max_det: Maximum detections passed to the model.
            retina_masks: Passed through to the model call.
            agnostic_nms: Passed through to the model call.
            select_mask_index: "none" or the 1-based index of the first mask
                to keep.

        Returns:
            ``(annotated_tensor, merged_tensor, mask_tensor)``: the annotated
            previews concatenated along dim 0, one merged mask per processed
            frame stacked along dim 0, and all selected individual masks
            stacked along dim 0.
        """
        model_path = self._resolve_model_path(yolo_model)
        model = self._get_model(model_path)
        device_target = self._resolve_device(device)
        class_filter = self._parse_classes(classes)

        merged_masks: List[torch.Tensor] = []
        annotated_images: List[torch.Tensor] = []
        mask_list: List[torch.Tensor] = []

        # 0 is the internal encoding of "all" (no limit).
        count_limit = 0 if mask_count == "all" else max(0, int(mask_count))
        chosen_index: Optional[int] = None
        if select_mask_index != "none":
            # The UI index is 1-based; list indexing below is 0-based.
            chosen_index = int(select_mask_index) - 1

        for idx in range(images.shape[0]):
            image_pil = tensor2pil(images[idx])
            results = model(
                image_pil,
                conf=conf,
                iou=iou,
                classes=class_filter,
                device=device_target,
                max_det=max_det,
                retina_masks=retina_masks,
                agnostic_nms=agnostic_nms,
            )
            if not results:
                # NOTE: a frame without any result object is skipped, so the
                # outputs can be shorter than the input batch in that case.
                continue
            result = results[0]
            annotated_images.append(self._result_to_tensor(result))

            frame_masks = self._collect_masks(result, image_pil.size)
            selected_masks = self._select_masks(frame_masks, count_limit, chosen_index)
            if selected_masks:
                merged_masks.append(self._merge_masks(selected_masks))
                mask_list.extend(selected_masks)
            else:
                # Selection was out of range: emit an empty mask for the frame.
                fallback = torch.zeros_like(frame_masks[0])
                merged_masks.append(fallback)
                mask_list.append(fallback)

        # Nothing was processed at all: fall back to a single empty mask and
        # pass the input images through as the "annotated" output.
        if not merged_masks:
            width, height = tensor2pil(images[0]).size
            merged_masks = [torch.zeros((height, width), dtype=torch.float32)]
        if not mask_list:
            width, height = merged_masks[0].shape[1], merged_masks[0].shape[0]
            mask_list = [torch.zeros((height, width), dtype=torch.float32)]
        if not annotated_images:
            annotated_images = [images]

        merged_tensor = torch.stack(merged_masks, dim=0)
        annotated_tensor = torch.cat(annotated_images, dim=0)
        mask_tensor = torch.stack(mask_list, dim=0)
        return annotated_tensor, merged_tensor, mask_tensor
class AILab_YoloV8(AILab_YoloV8Adv):
    """Reduced variant of the advanced node.

    Only the model and the mask-selection inputs are exposed; every other
    detection option is pinned to a fixed value in ``yolo_detect_simple``.
    """

    FUNCTION = "yolo_detect_simple"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: images, model, mask count and optional index."""
        # Fall back to a one-item hint when no model file is available.
        choices = tuple(cls._list_models() or [f"Put .pt models into {ULTRALYTICS_DIR}"])

        required = {
            "images": ("IMAGE",),
            "yolo_model": (
                choices,
                {
                    "default": choices[0],
                    "tooltip": f"YOLOv8 weights stored under {ULTRALYTICS_DIR}. Advanced controls available on YOLOv8 Adv.",
                },
            ),
            "mask_count": (
                MASK_COUNT_CHOICES,
                {
                    "default": "all",
                    "tooltip": "Merge this many detections. 'all' merges everything (or just the selected index when specified).",
                },
            ),
        }
        optional = {
            "select_mask_index": (
                MASK_INDEX_CHOICES,
                {
                    "default": "none",
                    "tooltip": "1-based index of the first mask to keep. Use 'none' to start from the first detection.",
                },
            ),
        }
        return {"required": required, "optional": optional}

    def yolo_detect_simple(self, images, yolo_model, mask_count="all", select_mask_index="none"):
        """Run the parent's ``yolo_detect`` with the advanced options pinned."""
        # Values for the options this node does not expose.
        pinned_options = {
            "conf": 0.25,
            "iou": 0.45,
            "classes": "",
            "device": "auto",
            "max_det": 300,
            "retina_masks": True,
            "agnostic_nms": False,
        }
        return super().yolo_detect(
            images=images,
            yolo_model=yolo_model,
            mask_count=mask_count,
            select_mask_index=select_mask_index,
            **pinned_options,
        )
# One row per exported node: (identifier, implementing class, display label).
# Both public mappings are derived from this table so their keys always match.
_NODE_TABLE = (
    ("AILab_YoloV8", AILab_YoloV8, "YOLOv8 (RMBG)"),
    ("AILab_YoloV8Adv", AILab_YoloV8Adv, "YOLOv8 Adv (RMBG)"),
)

# Identifier -> node class.
NODE_CLASS_MAPPINGS = {identifier: node_cls for identifier, node_cls, _ in _NODE_TABLE}

# Identifier -> human-readable node title.
NODE_DISPLAY_NAME_MAPPINGS = {identifier: label for identifier, _, label in _NODE_TABLE}