|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import torch
|
| import torch.nn as nn
|
| import numpy as np
|
| from typing import Tuple, Union
|
| from PIL import Image, ImageFilter
|
| from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
|
| import folder_paths
|
| from huggingface_hub import hf_hub_download
|
| import shutil
|
| from torchvision import transforms
|
|
|
| def pil2tensor(image: Image.Image) -> torch.Tensor:
|
| return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
|
|
|
| def tensor2pil(image: torch.Tensor) -> Image.Image:
|
| return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
|
|
|
| def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Image.Image:
|
| if isinstance(mask, torch.Tensor):
|
| mask = mask2image(mask)
|
| if mask.size != image.size:
|
| mask = mask.resize(image.size, Image.Resampling.LANCZOS)
|
| return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L')))
|
|
|
| def mask2image(mask: torch.Tensor) -> Image.Image:
|
| if len(mask.shape) == 2:
|
| mask = mask.unsqueeze(0)
|
| return tensor2pil(mask)
|
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
| folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))
|
|
|
| AVAILABLE_MODELS = {
|
| "segformer_fashion": "1038lab/segformer_fashion"
|
| }
|
|
|
| class FashionSegmentAccessories:
|
| @classmethod
|
| def INPUT_TYPES(cls):
|
| accessories_classes = [
|
|
|
| "hat",
|
| "glasses",
|
| "headband, head covering, hair accessory",
|
|
|
| "scarf",
|
| "tie",
|
|
|
| "glove",
|
| "watch",
|
|
|
| "belt",
|
|
|
| "leg warmer",
|
|
|
| "bag, wallet",
|
| "umbrella"
|
| ]
|
|
|
| details_classes = [
|
|
|
| "collar",
|
| "lapel",
|
| "neckline",
|
| "epaulette",
|
| "pocket",
|
|
|
| "buckle",
|
| "zipper",
|
| "applique",
|
| "bow",
|
| "flower",
|
| "bead",
|
| "fringe",
|
| "ribbon",
|
| "rivet",
|
| "ruffle",
|
| "sequin",
|
| "tassel"
|
| ]
|
|
|
| return {
|
| "required": {},
|
| "optional": {
|
| **{cls_name: ("BOOLEAN", {"default": False})
|
| for cls_name in accessories_classes + details_classes},
|
| },
|
| }
|
|
|
| RETURN_TYPES = ("ACCESSORIES_OPTIONS",)
|
| RETURN_NAMES = ("accessories_options",)
|
| FUNCTION = "get_options"
|
| CATEGORY = "🧪AILab/🧽RMBG"
|
|
|
| def get_options(self, **class_selections):
|
| selected = [name for name, selected in class_selections.items() if selected]
|
| return (selected,)
|
|
|
| class FashionSegmentClothing:
|
| def __init__(self):
|
| self.processor = None
|
| self.model = None
|
| self.cache_dir = os.path.join(folder_paths.models_dir, "RMBG", "segformer_fashion")
|
| self.class_map = {
|
| "Unlabelled": 0, "shirt, blouse": 1, "top, t-shirt, sweatshirt": 2, "sweater": 3,
|
| "cardigan": 4, "jacket": 5, "vest": 6, "pants": 7, "shorts": 8, "skirt": 9, "coat": 10,
|
| "dress": 11, "jumpsuit": 12, "cape": 13,
|
| "glasses": 14, "hat": 15, "headband, head covering, hair accessory": 16, "tie": 17, "glove": 18,
|
| "watch": 19, "belt": 20, "leg warmer": 21, "tights, stockings": 22,
|
| "sock": 23, "shoe": 24, "bag, wallet": 25, "scarf": 26, "umbrella": 27,
|
| "hood": 28, "collar": 29, "lapel": 30, "epaulette": 31, "sleeve": 32,
|
| "pocket": 33, "neckline": 34, "buckle": 35, "zipper": 36, "applique": 37,
|
| "bead": 38, "bow": 39, "flower": 40, "fringe": 41, "ribbon": 42,
|
| "rivet": 43, "ruffle": 44, "sequin": 45, "tassel": 46
|
| }
|
|
|
| @classmethod
|
| def INPUT_TYPES(cls):
|
| clothing_classes = [
|
|
|
| "coat",
|
| "jacket",
|
| "cardigan",
|
| "vest",
|
| "sweater",
|
| "hood",
|
| "shirt, blouse",
|
| "top, t-shirt, sweatshirt",
|
| "sleeve",
|
|
|
| "dress",
|
| "jumpsuit",
|
| "cape",
|
|
|
| "pants",
|
| "shorts",
|
| "skirt",
|
|
|
| "tights, stockings",
|
| "sock",
|
| "shoe"
|
| ]
|
|
|
| tooltips = {
|
| "accessories_options": "Select the accessories to be segmented",
|
| "process_res": "Processing resolution (higher = more VRAM)",
|
| "mask_blur": "Blur amount for mask edges",
|
| "mask_offset": "Expand/Shrink mask boundary",
|
| "invert_output": "Invert both image and mask output",
|
| "background": "Choose background type: Alpha (transparent) or Color (custom background color).",
|
| "background_color": "Choose background color (Alpha = transparent)",
|
| }
|
|
|
| return {
|
| "required": {
|
| "images": ("IMAGE",),
|
| },
|
| "optional": {
|
| "accessories_options": ("ACCESSORIES_OPTIONS",),
|
| **{cls_name: ("BOOLEAN", {"default": False,})
|
| for cls_name in clothing_classes},
|
| "process_res": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 32, "tooltip": tooltips["process_res"]}),
|
| "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}),
|
| "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}),
|
| "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),
|
| "background": (["Alpha", "Color"], {"default": "Alpha", "tooltip": tooltips["background"]}),
|
| "background_color": ("COLORCODE", {"default": "#222222", "tooltip": tooltips["background_color"]}),
|
| },
|
| }
|
|
|
| RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
|
| RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
|
| FUNCTION = "segment_fashion"
|
| CATEGORY = "🧪AILab/🧽RMBG"
|
|
|
| def check_model_cache(self):
|
| if not os.path.exists(self.cache_dir):
|
| return False, "Model directory not found"
|
|
|
| required_files = [
|
| 'config.json',
|
| 'model.safetensors',
|
| 'preprocessor_config.json'
|
| ]
|
|
|
| missing_files = [f for f in required_files if not os.path.exists(os.path.join(self.cache_dir, f))]
|
| if missing_files:
|
| return False, f"Missing required model files: {', '.join(missing_files)}"
|
| return True, "Model cache verified"
|
|
|
| def clear_model(self):
|
| if self.model is not None:
|
| self.model.cpu()
|
| del self.model
|
| self.model = None
|
| self.processor = None
|
| torch.cuda.empty_cache()
|
|
|
| def download_model_files(self):
|
| model_id = AVAILABLE_MODELS["segformer_fashion"]
|
| model_files = {
|
| 'config.json': 'config.json',
|
| 'model.safetensors': 'model.safetensors',
|
| 'preprocessor_config.json': 'preprocessor_config.json'
|
| }
|
|
|
| os.makedirs(self.cache_dir, exist_ok=True)
|
| print(f"Downloading fashion segmentation model files...")
|
|
|
| try:
|
| for save_name, repo_path in model_files.items():
|
| print(f"Downloading {save_name}...")
|
| downloaded_path = hf_hub_download(
|
| repo_id=model_id,
|
| filename=repo_path,
|
| local_dir=self.cache_dir,
|
| local_dir_use_symlinks=False
|
| )
|
|
|
| if os.path.dirname(downloaded_path) != self.cache_dir:
|
| target_path = os.path.join(self.cache_dir, save_name)
|
| shutil.move(downloaded_path, target_path)
|
| return True, "Model files downloaded successfully"
|
| except Exception as e:
|
| return False, f"Error downloading model files: {str(e)}"
|
|
|
| def segment_fashion(self, images, accessories_options=None, process_res=512, mask_blur=0, mask_offset=0,
|
| background="Alpha", background_color="#222222", invert_output=False, **class_selections):
|
| if accessories_options is None:
|
| accessories_options = []
|
| try:
|
|
|
| cache_status, message = self.check_model_cache()
|
| if not cache_status:
|
| print(f"Cache check: {message}")
|
| download_status, download_message = self.download_model_files()
|
| if not download_status:
|
| raise RuntimeError(download_message)
|
|
|
|
|
| if self.processor is None:
|
| self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
|
| self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
|
| self.model.eval()
|
| for param in self.model.parameters():
|
| param.requires_grad = False
|
| self.model.to(device)
|
|
|
|
|
| selected_classes = []
|
|
|
| selected_classes.extend([name for name, selected in class_selections.items() if selected])
|
|
|
| selected_classes.extend(accessories_options)
|
|
|
| if not selected_classes:
|
| selected_classes = ["shirt, blouse"]
|
|
|
| transform_image = transforms.Compose([
|
| transforms.Resize((process_res, process_res)),
|
| transforms.ToTensor(),
|
| ])
|
|
|
| batch_tensor = []
|
| batch_masks = []
|
|
|
| for image in images:
|
| orig_image = tensor2pil(image)
|
| w, h = orig_image.size
|
|
|
| inputs = self.processor(images=orig_image, return_tensors="pt")
|
| inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
| with torch.no_grad():
|
| outputs = self.model(**inputs)
|
| logits = outputs.logits.cpu()
|
|
|
| upsampled_logits = nn.functional.interpolate(
|
| logits,
|
| size=(h, w),
|
| mode="bilinear",
|
| align_corners=False,
|
| )
|
| pred_seg = upsampled_logits.argmax(dim=1)[0]
|
|
|
|
|
| combined_mask = None
|
| for class_name in selected_classes:
|
| mask = (pred_seg == self.class_map[class_name]).float()
|
| if combined_mask is None:
|
| combined_mask = mask
|
| else:
|
| combined_mask = torch.clamp(combined_mask + mask, 0, 1)
|
|
|
|
|
| mask_image = Image.fromarray((combined_mask.numpy() * 255).astype(np.uint8))
|
|
|
| if mask_blur > 0:
|
| mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
|
|
| if mask_offset != 0:
|
| if mask_offset > 0:
|
| mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
|
| else:
|
| mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))
|
|
|
| if invert_output:
|
| mask_image = Image.fromarray(255 - np.array(mask_image))
|
|
|
|
|
| if background == "Alpha":
|
| rgba_image = RGB2RGBA(orig_image, mask_image)
|
| result_image = pil2tensor(rgba_image)
|
| else:
|
| def hex_to_rgba(hex_color):
|
| hex_color = hex_color.lstrip('#')
|
| if len(hex_color) == 6:
|
| r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| a = 255
|
| elif len(hex_color) == 8:
|
| r, g, b, a = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16), int(hex_color[6:8], 16)
|
| else:
|
| raise ValueError("Invalid color format")
|
| return (r, g, b, a)
|
| rgba_image = RGB2RGBA(orig_image, mask_image)
|
| rgba = hex_to_rgba(background_color)
|
| bg_image = Image.new('RGBA', orig_image.size, rgba)
|
| composite_image = Image.alpha_composite(bg_image, rgba_image)
|
| result_image = pil2tensor(composite_image.convert('RGB'))
|
|
|
| batch_tensor.append(result_image)
|
| batch_masks.append(pil2tensor(mask_image))
|
|
|
|
|
| mask_images = []
|
| for mask_tensor in batch_masks:
|
|
|
| mask_image = mask_tensor.reshape((-1, 1, mask_tensor.shape[-2], mask_tensor.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
| mask_images.append(mask_image)
|
|
|
| mask_image_output = torch.cat(mask_images, dim=0)
|
|
|
| batch_tensor = torch.cat(batch_tensor, dim=0)
|
| batch_masks = torch.cat(batch_masks, dim=0)
|
|
|
| return (batch_tensor, batch_masks, mask_image_output)
|
|
|
| except Exception as e:
|
| self.clear_model()
|
| raise RuntimeError(f"Error in fashion segmentation: {str(e)}")
|
| finally:
|
| if self.model is not None and not self.model.training:
|
| self.clear_model()
|
|
|
| def __del__(self):
|
| self.clear_model()
|
|
|
| NODE_CLASS_MAPPINGS = {
|
| "FashionSegmentAccessories": FashionSegmentAccessories,
|
| "FashionSegmentClothing": FashionSegmentClothing
|
| }
|
|
|
| NODE_DISPLAY_NAME_MAPPINGS = {
|
| "FashionSegmentAccessories": "Accessories Segment (RMBG)",
|
| "FashionSegmentClothing": "Fashion Segment (RMBG)"
|
| } |