| import torch |
| import torch.nn.functional as F |
| from torchvision.transforms import functional as TF |
| from PIL import Image, ImageDraw, ImageFilter, ImageFont |
| import scipy.ndimage |
| import numpy as np |
| from contextlib import nullcontext |
| import os |
| from tqdm import tqdm |
|
|
| from comfy import model_management |
| from comfy.utils import ProgressBar |
| from comfy.utils import common_upscale |
| from nodes import MAX_RESOLUTION |
|
|
| import folder_paths |
|
|
| from ..utility.utility import tensor2pil, pil2tensor |
|
|
# Directory one level above the folder that contains this module.
script_directory = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
# Compute and offload devices, resolved once at import time through
# ComfyUI's model management.
main_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
|
|
class BatchCLIPSeg:
    """ComfyUI node that segments a batch of images from a text prompt using CLIPSeg.

    The CLIPSeg model is stored on the node instance after the first load, so
    later executions of the same node reuse it instead of loading it again.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Return the required and optional input definitions of the node."""
        return {"required":
                    {
                        "images": ("IMAGE",),
                        "text": ("STRING", {"multiline": False}),
                        "threshold": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 10.0, "step": 0.001}),
                        "binary_mask": ("BOOLEAN", {"default": True}),
                        "combine_mask": ("BOOLEAN", {"default": False}),
                        "use_cuda": ("BOOLEAN", {"default": True}),
                    },
                "optional":
                    {
                        "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                        "opt_model": ("CLIPSEGMODEL", ),
                        "prev_mask": ("MASK", {"default": None}),
                        "image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "invert": ("BOOLEAN", {"default": False}),
                    }
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "IMAGE", )
    RETURN_NAMES = ("Mask", "Image", )
    FUNCTION = "segment_image"
    DESCRIPTION = """
Segments an image or batch of images using CLIPSeg.
"""

    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert= False, image_bg_level=0.5):
        """Segment ``images`` with the prompt ``text`` and return a mask and a preview.

        Args:
            images: Image batch, unpacked here as (B, H, W, C).
            text: Prompt; the same prompt is used for every image in the batch.
            threshold: After min-max normalisation, mask values that are not
                greater than this are set to 0.
            binary_mask: When true, every remaining non-zero value becomes 1.
            combine_mask: When true, the per-image masks are merged with a
                maximum over the batch and the result is repeated per image.
            use_cuda: When false the model runs on the CPU; otherwise on the
                device reported by ``model_management.get_torch_device()``.
            blur_sigma: Gaussian blur sigma applied to the mask when > 0.
            opt_model: Optional dict with ``'model'`` and ``'processor'``
                entries; when given, nothing is downloaded or loaded here.
            prev_mask: Optional mask added to the result. It is resized with
                nearest-neighbour interpolation when its shape differs.
            invert: When true the final mask is inverted (``1 - mask``).
            image_bg_level: Value written to the preview image where the mask
                is 0.

        Returns:
            Tuple ``(mask, image)``, both moved to the CPU as float tensors.
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
        import torchvision.transforms as transforms
        offload_device = model_management.unet_offload_device()
        device = model_management.get_torch_device()
        if not use_cuda:
            device = torch.device("cpu")
        dtype = model_management.unet_dtype()

        if opt_model is None:
            default_path = os.path.join(folder_paths.models_dir,'clip_seg', 'clipseg-rd64-refined-fp16')
            # Reuse the checkpoint the cached model was really loaded from, so
            # the processor still resolves after a fallback to the hub repo.
            checkpoint_path = getattr(self, "_checkpoint_path", default_path)
            if not hasattr(self, "model"):
                checkpoint_path = default_path
                try:
                    if not os.path.exists(checkpoint_path):
                        from huggingface_hub import snapshot_download
                        snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
                except Exception:
                    # Best effort: fall back to the original repository when the
                    # local fp16 checkpoint cannot be downloaded or loaded.
                    checkpoint_path = "CIDAS/clipseg-rd64-refined"
                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
                self._checkpoint_path = checkpoint_path
            processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
        else:
            self.model = opt_model['model']
            processor = opt_model['processor']

        self.model.to(dtype).to(device)

        B, H, W, C = images.shape
        images = images.to(device)

        autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
        with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            # The processor is fed PIL images, so convert each tensor to 8-bit.
            PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
            prompt = [text] * len(images)
            input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")

            for key in input_prc:
                input_prc[key] = input_prc[key].to(device)
            outputs = self.model(**input_prc)

        mask_tensor = torch.sigmoid(outputs.logits)
        # Min-max normalise over the whole batch before thresholding.
        mask_tensor = (mask_tensor - mask_tensor.min()) / (mask_tensor.max() - mask_tensor.min())
        # zeros_like keeps the fill value on the mask's own device and dtype.
        mask_tensor = torch.where(mask_tensor > (threshold), mask_tensor, torch.zeros_like(mask_tensor))
        print(mask_tensor.shape)
        if len(mask_tensor.shape) == 2:
            # A single image yields a 2D result; restore the batch dimension.
            mask_tensor = mask_tensor.unsqueeze(0)
        mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest')
        mask_tensor = mask_tensor.squeeze(1)

        self.model.to(offload_device)

        if binary_mask:
            mask_tensor = (mask_tensor > 0).float()
        if blur_sigma > 0:
            kernel_size = int(6 * int(blur_sigma) + 1)
            blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
            mask_tensor = blur(mask_tensor)

        if combine_mask:
            mask_tensor = torch.max(mask_tensor, dim=0)[0]
            mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images),1,1)

        del outputs
        model_management.soft_empty_cache()

        if prev_mask is not None:
            if prev_mask.shape != mask_tensor.shape:
                # Drop the temporary channel dim again so the sum stays 3D.
                prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
            mask_tensor = mask_tensor + prev_mask.to(device)
            # torch.clamp is out-of-place; the result has to be assigned.
            mask_tensor = torch.clamp(mask_tensor, min=0.0, max=1.0)

        if invert:
            mask_tensor = 1 - mask_tensor

        # Keep the image where the mask is set, flat background level elsewhere.
        image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
        image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()

        mask_tensor = mask_tensor.cpu().float()

        return mask_tensor, image_tensor,
|
|
class DownloadAndLoadCLIPSeg:
    """ComfyUI node that downloads (if needed) and loads a CLIPSeg model and processor."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Return the input definition: a choice between the supported repositories."""
        return {"required":
                    {
                        "model": (
                            [   'Kijai/clipseg-rd64-refined-fp16',
                                'CIDAS/clipseg-rd64-refined',
                            ],
                        ),
                    },
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("CLIPSEGMODEL",)
    RETURN_NAMES = ("clipseg_model",)
    FUNCTION = "segment_image"
    DESCRIPTION = """
Downloads and loads CLIPSeg model with huggingface_hub,
to ComfyUI/models/clip_seg
"""

    def segment_image(self, model):
        """Load the selected CLIPSeg checkpoint and return it with its processor.

        Args:
            model: Hugging Face repository id; its last path component is used
                as the folder name below ``<models_dir>/clip_seg``.

        Returns:
            One-element tuple holding a dict with the keys ``'model'`` and
            ``'processor'``.
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
        checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', os.path.basename(model))
        # Reload when a different repository was selected since the last run;
        # checking only for the attribute would hand back the stale model.
        already_loaded = hasattr(self, "model") and getattr(self, "_loaded_path", None) == checkpoint_path
        if not already_loaded:
            if not os.path.exists(checkpoint_path):
                from huggingface_hub import snapshot_download
                snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
            self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
            self._loaded_path = checkpoint_path

        processor = CLIPSegProcessor.from_pretrained(checkpoint_path)

        clipseg_model = {}
        clipseg_model['model'] = self.model
        clipseg_model['processor'] = processor

        return clipseg_model,
|
|
class CreateTextMask:
    """ComfyUI node that renders word-wrapped text into an image batch and a mask batch."""

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "createtextmask"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Creates a text image and mask.
Looks for fonts from this folder:
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts

If start_rotation and/or end_rotation are different values,
creates animation between them.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "text_x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                "text_y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                "font_size": ("INT", {"default": 32,"min": 8, "max": 4096, "step": 1}),
                "font_color": ("STRING", {"default": "white"}),
                "text": ("STRING", {"default": "HELLO!", "multiline": True}),
                "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
                "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "start_rotation": ("INT", {"default": 0,"min": 0, "max": 359, "step": 1}),
                "end_rotation": ("INT", {"default": 0,"min": -359, "max": 359, "step": 1}),
            },
        }

    @staticmethod
    def _text_width(font, text):
        """Return the rendered width of ``text`` for ``font``."""
        try:
            return font.getbbox(text)[2]
        except Exception:
            # Fall back to getsize() when getbbox() cannot be used.
            return font.getsize(text)[0]

    @classmethod
    def _wrap_lines(cls, text, font, max_width):
        """Split ``text`` into lines that fit ``max_width``, keeping blank lines."""
        lines = []
        for text_line in text.split('\n'):
            if text_line.strip() == "":
                # Preserve intentional empty lines of the input.
                lines.append("")
                continue
            current_line = []
            for word in text_line.split():
                candidate = " ".join(current_line + [word])
                if cls._text_width(font, candidate) <= max_width:
                    current_line.append(word)
                else:
                    # Only flush real content: a first word that is already too
                    # wide must not produce an empty line in front of it.
                    if current_line:
                        lines.append(" ".join(current_line))
                    current_line = [word]
            if current_line:
                lines.append(" ".join(current_line))
        return lines

    def createtextmask(self, frames, width, height, invert, text_x, text_y, text, font_size, font_color, font, start_rotation, end_rotation):
        """Draw ``text`` on black frames and return the images and their masks.

        Args:
            frames: Number of frames to generate.
            width, height: Frame size in pixels.
            invert: When true both outputs are returned as ``1.0 - value``.
            text_x, text_y: Top-left position of the first line. ``text_x`` is
                also used as the right margin for word wrapping.
            text: Text to draw; newlines start new lines.
            font_size: Font size, also used as the line height.
            font_color: Fill colour passed to Pillow.
            font: Font file name, resolved through the ``kjnodes_fonts`` folder.
            start_rotation, end_rotation: When they differ, each frame is
                rotated by an angle stepped linearly from start to end. The
                rotation centre is the centre of the last drawn line.

        Returns:
            Tuple ``(images, masks)``; the mask is the first channel of the image.
        """
        batch_size = frames
        out = []
        masks = []
        rotation = start_rotation
        animate = start_rotation != end_rotation
        # With a single frame there is nothing to step through; dividing by
        # (batch_size - 1) would raise ZeroDivisionError.
        rotation_increment = 0
        if animate and batch_size > 1:
            rotation_increment = (end_rotation - start_rotation) / (batch_size - 1)

        font_path = folder_paths.get_full_path("kjnodes_fonts", font)
        # Font and line layout do not change between frames, so build them once.
        pil_font = ImageFont.truetype(font_path, font_size)
        lines = self._wrap_lines(text, pil_font, width - 2 * text_x)
        text_height = font_size

        for i in range(batch_size):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)

            y_offset = text_y
            text_center_x = text_x
            text_center_y = text_y
            for line in lines:
                text_width = pil_font.getlength(line)
                text_center_x = text_x + text_width / 2
                text_center_y = y_offset + text_height / 2
                try:
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color, features=['-liga'])
                except Exception:
                    # Retry without the font feature when the call above is rejected.
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color)
                y_offset += text_height

            if animate:
                image = image.rotate(rotation, center=(text_center_x, text_center_y))
                rotation += rotation_increment

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        if invert:
            return (1.0 - torch.cat(out, dim=0), 1.0 - torch.cat(masks, dim=0),)
        return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
|
|
class ColorToMask:
    """ComfyUI node that turns pixels close to a chosen RGB colour into a mask."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "clip"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Converts chosen RGB value to a mask.
With batch inputs, the **per_batch**
controls the number of images processed at once.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "images": ("IMAGE",),
                "invert": ("BOOLEAN", {"default": False}),
                "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
                "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
            },
        }

    def clip(self, images, red, green, blue, threshold, invert, per_batch):
        """Return a mask that is 1.0 where a pixel lies within ``threshold`` of the colour.

        The distance is the Euclidean norm between the pixel scaled to 0..255
        and ``(red, green, blue)``. With ``invert`` the mask is 1.0 outside the
        threshold instead. Images are processed ``per_batch`` at a time.
        """
        target = torch.tensor([red, green, blue], dtype=torch.uint8)

        total = images.shape[0]
        pbar = ProgressBar(total)
        chunks = []

        for offset in range(0, total, per_batch):
            chunk = images[offset:offset + per_batch]
            distance = torch.norm(chunk * 255 - target, dim=-1)
            within = distance <= threshold
            # Matching pixels are white (1.0), or black when inverted.
            selected = ~within if invert else within
            chunk_mask = selected.float()

            chunks.append(chunk_mask.cpu())
            pbar.update(chunk_mask.shape[0])

        result = torch.clamp(torch.cat(chunks, dim=0), min=0.0, max=1.0)
        return result,
| |
class CreateFluidMask:
    """ComfyUI node that generates an animated image and mask from a fluid simulation."""

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "createfluidmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "inflow_count": ("INT", {"default": 3,"min": 0, "max": 255, "step": 1}),
                "inflow_velocity": ("INT", {"default": 1,"min": 0, "max": 255, "step": 1}),
                "inflow_radius": ("INT", {"default": 8,"min": 0, "max": 255, "step": 1}),
                "inflow_padding": ("INT", {"default": 50,"min": 0, "max": 255, "step": 1}),
                "inflow_duration": ("INT", {"default": 60,"min": 0, "max": 255, "step": 1}),
            },
        }

    def createfluidmask(self, frames, width, height, invert, inflow_count, inflow_velocity, inflow_radius, inflow_padding, inflow_duration):
        """Run the fluid solver for ``frames`` steps and return images and masks.

        Inflow sources are placed evenly on a circle around the centre, each
        pushing towards the centre. They add velocity and dye while the frame
        index is at most ``inflow_duration``.

        Args:
            frames: Number of simulation steps / output frames.
            width, height: Resolution handed to the solver.
            invert: When true both outputs are returned as ``1.0 - value``.
            inflow_count: Number of inflow sources.
            inflow_velocity: Velocity added per source and step.
            inflow_radius: Radius of each source.
            inflow_padding: Subtracted from the smaller half-resolution to get
                the radius of the circle the sources sit on.
            inflow_duration: Last frame index on which the sources are active.

        Returns:
            Tuple ``(images, masks)``; the mask is the first channel of the image.
        """
        from ..utility.fluid import Fluid
        # erf lives in scipy.special. The previous fallback to scipy.spatial
        # could never succeed and only hid the real import error.
        from scipy.special import erf
        out = []
        masks = []
        RESOLUTION = width, height
        DURATION = frames

        INFLOW_PADDING = inflow_padding
        INFLOW_DURATION = inflow_duration
        INFLOW_RADIUS = inflow_radius
        INFLOW_VELOCITY = inflow_velocity
        INFLOW_COUNT = inflow_count

        print('Generating fluid solver, this may take some time.')
        fluid = Fluid(RESOLUTION, 'dye')

        center = np.floor_divide(RESOLUTION, 2)
        r = np.min(center) - INFLOW_PADDING

        # Unit directions on a circle; the normals point back to the centre.
        points = np.linspace(-np.pi, np.pi, INFLOW_COUNT, endpoint=False)
        points = tuple(np.array((np.cos(p), np.sin(p))) for p in points)
        normals = tuple(-p for p in points)
        points = tuple(r * p + center for p in points)

        inflow_velocity = np.zeros_like(fluid.velocity)
        inflow_dye = np.zeros(fluid.shape)
        for p, n in zip(points, normals):
            mask = np.linalg.norm(fluid.indices - p[:, None, None], axis=0) <= INFLOW_RADIUS
            inflow_velocity[:, mask] += n[:, None] * INFLOW_VELOCITY
            inflow_dye[mask] = 1

        for f in range(DURATION):
            print(f'Computing frame {f + 1} of {DURATION}.')
            if f <= INFLOW_DURATION:
                fluid.velocity += inflow_velocity
                fluid.dye += inflow_dye

            # Second element of the step result; treated as the curl field here.
            curl = fluid.step()[1]
            # Squash the value through erf before using it as a colour channel.
            curl = (erf(curl * 2) + 1) / 4

            color = np.dstack((curl, np.ones(fluid.shape), fluid.dye))
            # Quantise to 8 bit first, then convert back to float in 0..1.
            color = (np.clip(color, 0, 1) * 255).astype('uint8')
            image = np.array(color).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        if invert:
            return (1.0 - torch.cat(out, dim=0),1.0 - torch.cat(masks, dim=0),)
        return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
|
|
class CreateAudioMask:
    """Deprecated ComfyUI node that draws a circle whose size follows an audio spectrogram."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "createaudiomask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 16,"min": 1, "max": 255, "step": 1}),
                "scale": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 2.0, "step": 0.01}),
                "audio_path": ("STRING", {"default": "audio.wav"}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createaudiomask(self, frames, width, height, invert, audio_path, scale):
        """Render one frame per spectrogram column with a centred white circle.

        The radius of frame ``i`` is ``int(height * mean(column i)) * scale``.

        Args:
            frames: Number of frames; frame ``i`` reads spectrogram column ``i``.
            width, height: Frame size in pixels.
            invert: When true the image batch is returned as ``1.0 - image``.
            audio_path: Audio file to load. The default name ``"audio.wav"`` is
                resolved relative to ``script_directory``.
            scale: Multiplier applied to the radius.

        Returns:
            ``(images,)`` when ``invert`` is true, otherwise ``(images, masks)``.
            The differing length is kept as is for backward compatibility;
            ``RETURN_TYPES`` declares only the image output.

        Raises:
            Exception: When librosa cannot be imported.
        """
        try:
            import librosa
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise Exception("Can not import librosa. Install it with 'pip install librosa'") from err
        batch_size = frames
        out = []
        masks = []
        if audio_path == "audio.wav":
            audio_path = os.path.join(script_directory, audio_path)
        audio, sr = librosa.load(audio_path)
        spectrogram = np.abs(librosa.stft(audio))

        # The circle is always centred, so compute the centre once.
        circle_center = (width // 2, height // 2)

        for i in range(batch_size):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)
            frame = spectrogram[:, i]
            circle_radius = int(height * np.mean(frame))
            circle_radius *= scale

            draw.ellipse([(circle_center[0] - circle_radius, circle_center[1] - circle_radius),
                          (circle_center[0] + circle_radius, circle_center[1] + circle_radius)],
                         fill='white')

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        if invert:
            return (1.0 - torch.cat(out, dim=0),)
        return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
| |
class CreateGradientMask:
    """ComfyUI node that builds a batch of horizontal gradient masks shifted per frame."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createmask(self, frames, width, height, invert):
        """Return ``frames`` masks holding a 1.0 -> 0.0 ramp lowered by ``i / frames``.

        Every row of frame ``i`` is the same left-to-right ramp minus
        ``i / frames``; the values are not clamped. With ``invert`` the result
        is returned as ``1.0 - mask``.
        """
        ramp = np.linspace(1.0, 0.0, width, dtype=np.float32)
        # Per-frame offsets, kept in float32 like the ramp they are subtracted from.
        shifts = np.array([index / frames for index in range(frames)], dtype=np.float32)
        rows = ramp[np.newaxis, :] - shifts[:, np.newaxis]
        # Copy each frame's row down all `height` rows.
        image_batch = np.repeat(rows[:, np.newaxis, :], height, axis=1)

        gradient_mask = torch.from_numpy(image_batch)
        if invert:
            return (1.0 - gradient_mask,)
        return (gradient_mask,)
|
|
class CreateFadeMask:
    """Deprecated ComfyUI node that fades a uniform mask from a start to a mid to an end level."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 2,"min": 2, "max": 10000, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
                "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                "midpoint_level": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                "end_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                "midpoint_frame": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
            },
        }

    def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level, midpoint_frame):
        """Return ``frames`` uniform masks whose level follows a two-segment fade.

        Frames up to and including ``midpoint_frame`` move from ``start_level``
        to ``midpoint_level``; later frames move from ``midpoint_level``
        towards ``end_level``. A ``midpoint_frame`` of 0 selects
        ``frames // 2``. ``interpolation`` picks the easing curve; any other
        value leaves the progress linear. With ``invert`` the result is
        returned as ``1.0 - mask``.
        """
        curves = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }
        shape_curve = curves.get(interpolation, lambda t: t)

        if midpoint_frame == 0:
            midpoint_frame = frames // 2

        levels = []
        for index in range(frames):
            if index > midpoint_frame:
                # Second segment: midpoint level towards end level.
                progress = shape_curve((index - midpoint_frame) / (frames - midpoint_frame))
                level = midpoint_level - progress * (midpoint_level - end_level)
            else:
                # First segment: start level towards midpoint level.
                progress = shape_curve(index / midpoint_frame)
                level = start_level - progress * (start_level - midpoint_level)
            levels.append(level)

        clipped = np.clip(np.array(levels, dtype=np.float64), 0, 255)
        image_batch = np.empty((frames, height, width), dtype=np.float32)
        # Fill every frame with its single level.
        image_batch[...] = clipped[:, np.newaxis, np.newaxis]

        fade = torch.from_numpy(image_batch)
        if invert:
            return (1.0 - fade,)
        return (fade,)
|
|
class CreateFadeMaskAdvanced:
    """ComfyUI node that builds uniform masks interpolated between scheduled key frames."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Create a batch of masks interpolated between given frames and values.
Uses same syntax as Fizz' BatchValueSchedule.
First value is the frame index (not that this starts from 0, not 1)
and the second value inside the brackets is the float value of the mask in range 0.0 - 1.0

For example the default values:
0:(0.0)
7:(1.0)
15:(0.0)

Would create a mask batch fo 16 frames, starting from black,
interpolating with the chosen curve to fully white at the 8th frame,
and interpolating from that to fully black at the 16th frame.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 16,"min": 2, "max": 10000, "step": 1}),
                "width": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "none", "default_to_black"],),
            },
        }

    def createfademask(self, frames, width, height, invert, points_string, interpolation):
        """Return ``frames`` uniform masks following the ``frame:(value)`` schedule.

        Args:
            frames: Number of masks to create.
            width, height: Mask size in pixels.
            invert: When true the result is returned as ``1.0 - mask``.
            points_string: Comma separated ``frame:(value)`` entries. Empty or
                whitespace-only entries are ignored.
            interpolation: ``"linear"``, ``"ease_in"``, ``"ease_out"`` and
                ``"ease_in_out"`` interpolate between neighbouring points.
                ``"none"`` holds the previous point's value between points.
                ``"default_to_black"`` is 0 on every frame without a point.

        Unless ``interpolation`` is ``"default_to_black"``, a point for the last
        frame is appended when the last listed point is not on that frame; it
        repeats the value of the last listed point. Mask values are clipped to
        the range 0.0 - 1.0.

        Returns:
            One-element tuple with the mask batch.
        """
        easing = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }

        # Parse the schedule into (frame, value) tuples.
        points = []
        for point_str in points_string.split(','):
            if not point_str.strip():
                # Trailing commas and blank lines leave empty segments behind.
                continue
            frame_str, color_str = point_str.split(':')
            frame = int(frame_str.strip())
            color = float(color_str.strip()[1:-1])  # drop the surrounding brackets
            points.append((frame, color))

        if (interpolation != "default_to_black") and (len(points) == 0 or points[-1][0] != frames - 1):
            points.append((frames - 1, points[-1][1] if points else 0))

        points.sort(key=lambda x: x[0])

        # Exact-frame lookup; the first point wins when a frame is listed twice.
        color_at = {}
        for frame, color in points:
            color_at.setdefault(frame, color)

        batch_size = frames
        image_batch = np.zeros((batch_size, height, width), dtype=np.float32)

        # Index of the point currently interpolated towards.
        next_point = 1

        for i in range(batch_size):
            while next_point < len(points) and i > points[next_point][0]:
                next_point += 1
            prev_point = next_point - 1

            if interpolation == "none":
                color = color_at[i] if i in color_at else points[prev_point][1]
            elif interpolation == "default_to_black":
                color = color_at.get(i, 0)
            elif next_point >= len(points):
                # Nothing left to interpolate towards (for example a single
                # point on the last frame): hold the previous value.
                color = points[prev_point][1]
            else:
                prev_frame, prev_color = points[prev_point]
                next_frame, next_color = points[next_point]
                span = next_frame - prev_frame
                # Two points on the same frame would divide by zero.
                t = (i - prev_frame) / span if span else 0.0
                if interpolation in easing:
                    t = easing[interpolation](t)
                color = prev_color - t * (prev_color - next_color)

            # Mask values live in 0..1; frames before the first point are
            # extrapolated and could otherwise leave that range.
            color = np.clip(color, 0.0, 1.0)
            image_batch[i] = np.full((height, width), color, dtype=np.float32)

        mask = torch.from_numpy(image_batch)
        if invert:
            return (1.0 - mask,)
        return (mask,)
|
|
class CreateMagicMask:
    """ComfyUI node that renders animated procedural textures as a mask batch."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createmagicmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                "depth": ("INT", {"default": 12,"min": 1, "max": 500, "step": 1}),
                "distortion": ("FLOAT", {"default": 1.5,"min": 0.0, "max": 100.0, "step": 0.01}),
                "seed": ("INT", {"default": 123,"min": 0, "max": 99999999, "step": 1}),
                "transitions": ("INT", {"default": 1,"min": 1, "max": 20, "step": 1}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createmagicmask(self, frames, transitions, depth, distortion, seed, frame_width, frame_height):
        """Render ``frames`` texture masks blended across ``transitions`` segments.

        Each transition draws two random coordinate transforms and blends
        linearly from the first towards the second. Every texture is drawn
        with matplotlib and converted to a greyscale mask in 0..1.

        Args:
            frames: Total number of masks. Frames are spread over the
                transitions; when the count does not divide evenly the first
                transitions get one extra frame, so exactly ``frames`` masks
                are returned.
            transitions: Number of blend segments.
            depth, distortion: Passed through to ``magic``.
            seed: Seed of the random generator used for the transforms.
            frame_width, frame_height: Size handed to ``coordinate_grid``.

        Returns:
            Tuple ``(mask, mask_inverted)``.
        """
        from ..utility.magictex import coordinate_grid, random_transform, magic
        import matplotlib.pyplot as plt
        rng = np.random.default_rng(seed)
        out = []
        coords = coordinate_grid((frame_width, frame_height))

        # Plain floor division would silently drop the remainder frames and
        # yield an empty batch whenever transitions > frames.
        base_count, remainder = divmod(frames, transitions)

        base_params = {
            "coords": random_transform(coords, rng),
            "depth": depth,
            "distortion": distortion,
        }
        # NOTE(review): the figure is 10x10 inches at frame_width / 10 dpi, so
        # the canvas looks square and independent of frame_height. Confirm
        # whether that is intended.
        dpi = frame_width / 10
        for t in range(transitions):
            # Two fresh random transforms per transition: blend start and end.
            params1 = base_params.copy()
            params2 = base_params.copy()

            params1['coords'] = random_transform(coords, rng)
            params2['coords'] = random_transform(coords, rng)

            frames_in_transition = base_count + (1 if t < remainder else 0)
            for i in range(frames_in_transition):
                alpha = i / frames_in_transition

                params = params1.copy()
                params['coords'] = (1 - alpha) * params1['coords'] + alpha * params2['coords']

                tex = magic(**params)

                fig = plt.figure(figsize=(10, 10), dpi=dpi)

                ax = fig.add_subplot(111)
                plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

                ax.get_yaxis().set_ticks([])
                ax.get_xaxis().set_ticks([])
                ax.imshow(tex, aspect='auto')

                fig.canvas.draw()
                # Public accessor for the rendered RGBA buffer; copied because
                # the figure is closed right afterwards.
                img = np.array(fig.canvas.buffer_rgba())

                plt.close(fig)

                pil_img = Image.fromarray(img).convert("L")
                mask = torch.tensor(np.array(pil_img)) / 255.0

                out.append(mask)

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
| |
class CreateShapeMask:
    """ComfyUI node that draws a circle, square or triangle into a batch of masks."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createshapemask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Creates a mask or batch of masks with the specified shape.
Locations are center locations.
Grow value is the amount to grow the shape on each frame, creating animated masks.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "shape": (
                    [   'circle',
                        'square',
                        'triangle',
                    ],
                    {
                        "default": 'circle'
                    }),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
                "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
            },
        }

    def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, shape_width, shape_height, grow, shape):
        """Draw ``shape`` centred on the location for every frame.

        The shape size grows by ``grow`` pixels per frame and never drops
        below zero. The mask is the first channel of the white-on-black
        drawing.

        Returns:
            Tuple ``(mask, mask_inverted)``.
        """
        fill = "white"
        frame_masks = []
        for step in range(frames):
            canvas = Image.new("RGB", (frame_width, frame_height), "black")
            pen = ImageDraw.Draw(canvas)

            # Half extents of the shape on this frame.
            half_w = max(0, shape_width + step * grow) // 2
            half_h = max(0, shape_height + step * grow) // 2
            left, right = location_x - half_w, location_x + half_w
            top, bottom = location_y - half_h, location_y + half_h

            if shape == 'circle':
                pen.ellipse([(left, top), (right, bottom)], fill=fill)
            elif shape == 'square':
                pen.rectangle([(left, top), (right, bottom)], fill=fill)
            elif shape == 'triangle':
                # Apex at the top centre, base along the bottom edge.
                pen.polygon([(location_x, top), (left, bottom), (right, bottom)], fill=fill)

            frame_masks.append(pil2tensor(canvas)[:, :, :, 0])

        stacked = torch.cat(frame_masks, dim=0)
        return (stacked, 1.0 - stacked,)
| |
class CreateVoronoiMask:
    """ComfyUI node that renders an animated Voronoi line pattern as a mask batch."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createvoronoi"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definitions of the node."""
        return {
            "required": {
                "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                "num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
                "line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
                "speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
        """Render ``frames`` masks of Voronoi ridges between moving random points.

        Random start and end positions are drawn once; every frame interpolates
        between them, builds the Voronoi diagram and draws its finite ridges as
        black lines with matplotlib.

        Args:
            frames: Number of masks to create.
            num_points: Number of Voronoi seed points.
            line_width: Line width of the drawn ridges.
            speed: Fraction of the start-to-end path covered over the batch.
            frame_width, frame_height: Used for the aspect ratio and the
                figure size.

        Returns:
            Tuple ``(mask, mask_inverted)``.
        """
        from scipy.spatial import Voronoi
        from matplotlib import pyplot as plt

        batch_size = frames
        out = []

        aspect_ratio = frame_width / frame_height

        # Points live in [0, aspect_ratio] x [0, 1], the same range as the
        # axes limits set below.
        start_points = np.random.rand(num_points, 2)
        start_points[:, 0] *= aspect_ratio

        end_points = np.random.rand(num_points, 2)
        end_points[:, 0] *= aspect_ratio

        for i in range(batch_size):
            # A single frame has no path to travel; stay on the start points
            # instead of dividing by zero.
            t = (i * speed) / (batch_size - 1) if batch_size > 1 else 0.0
            t = np.clip(t, 0, 1)
            # Start and end points are already aspect corrected. Scaling x a
            # second time pushed points outside the visible axes range.
            points = (1 - t) * start_points + t * end_points

            vor = Voronoi(points)

            fig, ax = plt.subplots()
            plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.set_xlim([0, aspect_ratio]); ax.set_ylim([0, 1])
            ax.axis('off')
            ax.margins(0, 0)
            fig.set_size_inches(aspect_ratio * frame_height/100, frame_height/100)
            ax.fill_between([0, 1], [0, 1], color='white')

            # Only ridges whose vertex indices are all non-negative are drawn.
            for simplex in vor.ridge_vertices:
                simplex = np.asarray(simplex)
                if np.all(simplex >= 0):
                    plt.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)

            fig.canvas.draw()
            # Public accessor for the rendered RGBA buffer; copied because the
            # figure is closed right afterwards.
            img = np.array(fig.canvas.buffer_rgba())

            plt.close(fig)

            pil_img = Image.fromarray(img).convert("L")
            mask = torch.tensor(np.array(pil_img)) / 255.0

            out.append(mask)

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
| |
class GetMaskSizeAndCount:
    """ComfyUI node that reports a mask batch's size and passes the mask through."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the required input definition: a single mask."""
        return {"required": {
            "mask": ("MASK",),
        }}

    RETURN_TYPES = ("MASK","INT", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height", "count",)
    FUNCTION = "getsize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Returns the width, height and batch size of the mask,
and passes it through unchanged.

"""

    def getsize(self, mask):
        """Return the mask with its width, height and count, plus a UI text summary.

        The first three dimensions of ``mask.shape`` are read as
        (count, height, width).
        """
        count, height, width = mask.shape[0], mask.shape[1], mask.shape[2]
        summary = f"{count}x{width}x{height}"
        return {
            "ui": {"text": [summary]},
            "result": (mask, width, height, count),
        }
|
|
class GrowMaskWithBlur:
    """Grow or shrink a mask batch, with optional frame blending, hole filling and blur."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "incremental_expandrate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
                "flip_input": ("BOOLEAN", {"default": False}),
                "blur_radius": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100,
                    "step": 0.1
                }),
                "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "fill_holes": ("BOOLEAN", {"default": False}),
            },
        }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "expand_mask"
    DESCRIPTION = """
# GrowMaskWithBlur
- mask: Input mask or mask batch
- expand:  Expand or contract mask or mask batch by a given amount
- incremental_expandrate: increase expand rate by a given amount per frame
- tapered_corners: use tapered corners
- flip_input: flip input mask
- blur_radius: value higher than 0 will blur the mask
- lerp_alpha: alpha value for interpolation between frames
- decay_factor: decay value for interpolation between frames
- fill_holes: fill holes in the mask (slow)"""

    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
        """Expand or contract every frame of ``mask`` and return it with its inverse.

        Args:
            mask: Mask or mask batch; it is reshaped to (-1, H, W) before processing.
            expand: Number of 3x3 dilation (positive) or erosion (negative) passes
                applied to the first frame.
            tapered_corners: Use a cross-shaped kernel instead of a full 3x3 square.
            flip_input: Invert the mask (``1.0 - mask``) before processing.
            blur_radius: When non-zero, each result is Gaussian-blurred through PIL.
            incremental_expandrate: Amount added to the magnitude of the expand
                value after every frame.
            lerp_alpha: When below 1.0, each frame is interpolated with the previous output.
            decay_factor: When below 1.0, the previous output scaled by this factor
                is added to the frame and the sum is renormalised by its maximum.
            fill_holes: Binarise the frame (``> 0``) and fill enclosed holes with SciPy.

        Returns:
            Tuple ``(mask, mask_inverted)``.
        """
        import kornia.morphology as morph

        alpha = lerp_alpha
        decay = decay_factor
        if flip_input:
            mask = 1.0 - mask

        growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))

        # The structuring element never changes between frames, so build it once.
        # kornia needs it on the same device as the frames, which are moved to main_device.
        if tapered_corners:
            kernel_values = [[0, 1, 0],
                             [1, 1, 1],
                             [0, 1, 0]]
        else:
            kernel_values = [[1, 1, 1],
                             [1, 1, 1],
                             [1, 1, 1]]
        kernel = torch.tensor(kernel_values, dtype=torch.float32, device=main_device)

        out = []
        previous_output = None
        current_expand = expand
        for m in tqdm(growmask, desc="Expanding/Contracting Mask"):
            # kornia expects (B, C, H, W)
            output = m.unsqueeze(0).unsqueeze(0).to(main_device)
            steps = abs(round(current_expand))
            if steps > 0:
                morph_op = morph.erosion if current_expand < 0 else morph.dilation
                for _ in range(steps):
                    output = morph_op(output, kernel)

            output = output.squeeze(0).squeeze(0)

            # The expand magnitude grows per frame, keeping its sign.
            if current_expand < 0:
                current_expand -= abs(incremental_expandrate)
            else:
                current_expand += abs(incremental_expandrate)

            if fill_holes:
                binary_mask = output > 0
                output_np = binary_mask.cpu().numpy()
                filled = scipy.ndimage.binary_fill_holes(output_np)
                output = torch.from_numpy(filled.astype(np.float32)).to(output.device)

            if alpha < 1.0 and previous_output is not None:
                output = alpha * output + (1 - alpha) * previous_output
            if decay < 1.0 and previous_output is not None:
                # Out-of-place add: when no morphology ran, `output` can still be a
                # view of the caller's mask, and `+=` would silently modify the input.
                output = output + decay * previous_output
                peak = output.max()
                # An all-zero frame would otherwise become NaN through 0 / 0.
                if peak > 0:
                    output = output / peak
            previous_output = output
            out.append(output.cpu())

        if blur_radius != 0:
            # Blur each frame through PIL and convert back to a tensor.
            for idx, tensor in enumerate(out):
                pil_image = tensor2pil(tensor.cpu().detach())[0]
                pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
                out[idx] = pil2tensor(pil_image)
            blurred = torch.cat(out, dim=0)
            return (blurred, 1.0 - blurred)

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
| |
class MaskBatchMulti:
    """Concatenate several mask inputs into a single batch."""

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("masks",)
    FUNCTION = "combine"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Creates an image batch from multiple masks.
You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
            "mask_1": ("MASK", ),
            "mask_2": ("MASK", ),
        }
        return {"required": required_inputs}

    def combine(self, inputcount, **kwargs):
        """Concatenate ``mask_1`` .. ``mask_{inputcount}`` along the batch axis.

        A mask whose spatial size differs from the batch built so far is
        resampled with bicubic interpolation to that size before being appended.
        """
        batch = kwargs["mask_1"]
        for slot in range(2, inputcount + 1):
            incoming = kwargs[f"mask_{slot}"]
            if incoming.shape[1:] != batch.shape[1:]:
                target_hw = (batch.shape[1], batch.shape[2])
                # interpolate needs a channel axis, added and removed around the call
                resized = F.interpolate(incoming.unsqueeze(1), size=target_hw, mode="bicubic")
                incoming = resized.squeeze(1)
            batch = torch.cat((batch, incoming), dim=0)
        return (batch,)
|
|
class OffsetMask:
    """Translate, roll and/or rotate a mask batch, optionally duplicating it first."""

    # The UI offers 'border' / 'reflection'; the torch-style spellings are
    # accepted as aliases for callers that pass them directly.
    _BORDER_MODES = ('border', 'replicate')
    _REFLECT_MODES = ('reflection', 'reflect')

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
                "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
                "roll": ("BOOLEAN", { "default": False }),
                "incremental": ("BOOLEAN", { "default": False }),
                "padding_mode": (
                [
                    'empty',
                    'border',
                    'reflection',

                ], {
                   "default": 'empty'
                }),
            }
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "offset"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Offsets the mask by the specified amount.
 - mask: Input mask or mask batch
 - x: Horizontal offset
 - y: Vertical offset
 - angle: Angle in degrees
 - roll: roll edge wrapping
 - duplication_factor: Number of times to duplicate the mask to form a batch
 - border padding_mode: Padding mode for the mask
"""

    @staticmethod
    def _clamp_shift(shift, size):
        """Limit a shift to +/-(size - 1) so at least one row/column of the frame survives."""
        limit = size - 1
        return max(-limit, min(shift, limit))

    @classmethod
    def _shift_axis(cls, frame, shift, dim, padding_mode):
        """Return a 2D ``frame`` moved by ``shift`` pixels along ``dim``.

        The exposed strip is zero-filled ('empty'), filled with the nearest edge
        value ('border') or mirrored around the edge without repeating it
        ('reflection'). ``shift`` must already be clamped to +/-(size - 1).
        """
        if shift == 0:
            return frame
        size = frame.shape[dim]
        if padding_mode == 'empty':
            kept = size - abs(shift)
            # zeros_like keeps the frame's device and dtype
            shifted = torch.zeros_like(frame)
            src_start, dst_start = (0, shift) if shift > 0 else (-shift, 0)
            shifted.narrow(dim, dst_start, kept).copy_(frame.narrow(dim, src_start, kept))
            return shifted

        # For every destination index, compute the source index it reads from.
        source = torch.arange(size, device=frame.device) - shift
        if padding_mode in cls._BORDER_MODES:
            source = source.clamp(0, size - 1)
        else:
            source = source.abs()
            source = torch.where(source > size - 1, 2 * (size - 1) - source, source)
        return frame.index_select(dim, source)

    def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"):
        """Offset (and optionally rotate) every frame of a mask batch.

        Args:
            mask: Mask batch of shape (B, H, W).
            x: Horizontal shift in pixels; positive moves content to the right.
            y: Vertical shift in pixels; positive moves content down.
            angle: Rotation in degrees passed to torchvision's ``rotate``; 0 disables it.
            roll: Wrap content around the edges with ``torch.roll`` instead of padding.
            incremental: Multiply shift and angle by the 1-based frame index.
            duplication_factor: Number of times the batch is repeated first.
            padding_mode: 'empty', 'border' or 'reflection'; only used when ``roll`` is False.

        Returns:
            One-element tuple with the transformed mask batch. The input is not modified.

        Raises:
            ValueError: If ``roll`` is False and ``padding_mode`` is not recognised.
        """
        # repeat + clone gives a private copy that is safe to modify in place
        mask = mask.repeat(duplication_factor, 1, 1).clone()

        batch_size, height, width = mask.shape

        # Any non-zero angle rotates, including negative ones.
        if angle != 0:
            for i in range(batch_size):
                rotation_angle = angle * (i + 1) if incremental else angle
                mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)

        if roll:
            if incremental:
                for i in range(batch_size):
                    shift_x = min(x * (i + 1), width - 1)
                    shift_y = min(y * (i + 1), height - 1)
                    if shift_x != 0:
                        mask[i] = torch.roll(mask[i], shifts=shift_x, dims=1)
                    if shift_y != 0:
                        mask[i] = torch.roll(mask[i], shifts=shift_y, dims=0)
            else:
                shift_x = min(x, width - 1)
                shift_y = min(y, height - 1)
                if shift_x != 0:
                    mask = torch.roll(mask, shifts=shift_x, dims=2)
                if shift_y != 0:
                    mask = torch.roll(mask, shifts=shift_y, dims=1)
        else:
            known_modes = ('empty',) + self._BORDER_MODES + self._REFLECT_MODES
            if padding_mode not in known_modes:
                raise ValueError(f"Unsupported padding_mode: {padding_mode!r}")

            for i in range(batch_size):
                step = i + 1 if incremental else 1
                # Clamp both directions so an oversized negative shift cannot
                # yield a tensor of the wrong shape.
                shift_x = self._clamp_shift(x * step, width)
                shift_y = self._clamp_shift(y * step, height)
                if shift_x == 0 and shift_y == 0:
                    continue
                frame = self._shift_axis(mask[i], shift_x, 1, padding_mode)
                frame = self._shift_axis(frame, shift_y, 0, padding_mode)
                mask[i] = frame

        return mask,
| |
class RoundMask:
    """Node that snaps mask values to the nearest integer."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "round"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Rounds the mask or batch of masks to a binary mask.
<img src="https://github.com/kijai/ComfyUI-KJNodes/assets/40791699/52c85202-f74e-4b96-9dac-c8bda5ddcc40" width="300" height="250" alt="RoundMask example">

"""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "mask": ("MASK",),
        }
        return {"required": required_inputs}

    def round(self, mask):
        """Return a one-element tuple holding the element-wise rounded mask."""
        rounded = torch.round(mask)
        return (rounded,)
| |
class ResizeMask:
    """Resize a mask batch to a target size, optionally keeping its aspect ratio."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "keep_proportions": ("BOOLEAN", { "default": False }),
                "upscale_method": (s.upscale_methods,),
                "crop": (["disabled","center"],),
            }
        }

    RETURN_TYPES = ("MASK", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height",)
    FUNCTION = "resize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Resizes the mask or batch of masks to the specified width and height.
"""

    def resize(self, mask, width, height, keep_proportions, upscale_method,crop):
        """Resize ``mask`` and return it with its resulting width and height.

        Args:
            mask: Mask batch; the last two axes are read as height and width.
            width: Target width; 0 keeps the input width.
            height: Target height; 0 keeps the input height.
            keep_proportions: Scale uniformly so the result fits inside
                ``width`` x ``height``.
            upscale_method: One of ``upscale_methods``, forwarded to ``common_upscale``.
            crop: Crop mode forwarded to ``common_upscale``.

        Returns:
            Tuple ``(mask, width, height)`` taken from the resized tensor.
        """
        oh, ow = mask.shape[-2], mask.shape[-1]
        # The input spec allows 0. Treat it as "keep this dimension" in every
        # mode, so a zero size never reaches the resampler.
        width = ow if width == 0 else width
        height = oh if height == 0 else height

        if keep_proportions:
            ratio = min(width / ow, height / oh)
            width = round(ow*ratio)
            height = round(oh*ratio)

        if upscale_method == "lanczos":
            # The mask is expanded to 3 identical channels for this path and
            # only the first one is kept afterwards.
            out_mask = common_upscale(mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop=crop).movedim(1,-1)[:, :, :, 0]
        else:
            out_mask = common_upscale(mask.unsqueeze(1), width, height, upscale_method, crop=crop).squeeze(1)

        return(out_mask, out_mask.shape[2], out_mask.shape[1],)
|
|
class RemapMaskRange:
    """Node that rescales mask values into a new [min, max] interval."""

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "remap"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Sets new min and max values for the mask.
"""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "mask": ("MASK",),
            "min": ("FLOAT", {"default": 0.0,"min": -10.0, "max": 1.0, "step": 0.01}),
            "max": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 10.0, "step": 0.01}),
        }
        return {"required": required_inputs}

    def remap(self, mask, min, max):
        """Normalise the mask by its peak, stretch it to [min, max], clamp to [0, 1].

        The divisor falls back to 1 when the largest mask value is not positive.
        """
        peak = torch.max(mask)
        if not peak > 0:
            peak = 1

        normalised = mask / peak
        span = max - min
        remapped = normalised * span + min

        # Masks stay within [0, 1] whatever range was requested.
        return (remapped.clamp(min=0.0, max=1.0), )
|
|
|
|
def get_mask_polygon(self, mask_np):
    """Helper function to get polygon points from mask.

    Finds the external contours of ``mask_np`` with OpenCV, keeps the one with
    the largest area and simplifies it with ``cv2.approxPolyDP`` using a
    tolerance of 2% of the contour's perimeter.

    Args:
        self: Unused; kept so existing call sites keep working.
        mask_np: Mask array passed straight to ``cv2.findContours``
            (presumably a single-channel uint8 image; verify against callers).

    Returns:
        The simplified polygon with singleton axes squeezed out, or ``None``
        when no contour is found.
    """
    # The docstring must come first: placed after the import it was only a
    # discarded string expression.
    import cv2

    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return None

    # Only the largest external contour is considered.
    largest_contour = max(contours, key=cv2.contourArea)

    # The simplification tolerance scales with the contour's perimeter.
    epsilon = 0.02 * cv2.arcLength(largest_contour, True)
    polygon = cv2.approxPolyDP(largest_contour, epsilon, True)

    return polygon.squeeze()
|
|
| import cv2 |
class SeparateMasks:
    """Split a mask into one mask per connected component that passes a size threshold."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK", ),
                "size_threshold_width" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
                "size_threshold_height" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
                "mode": (["convex_polygons", "area", "box"],),
                "max_poly_points": ("INT", {"default": 8, "min": 3, "max": 32, "step": 1}),

            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "separate"
    CATEGORY = "KJNodes/masking"
    OUTPUT_NODE = True
    DESCRIPTION = "Separates a mask into multiple masks based on the size of the connected components."

    def polygon_to_mask(self, polygon, shape):
        """Rasterise ``polygon`` into a uint8 mask of size ``shape[0]`` x ``shape[1]``.

        A polygon that is not a 2D point array is ignored and an all-zero mask
        is returned.
        """
        mask = np.zeros((shape[0], shape[1]), dtype=np.uint8)

        if len(polygon.shape) == 2:  # Check if polygon points are valid
            polygon = polygon.astype(np.int32)
            cv2.fillPoly(mask, [polygon], 1)
        return mask

    def get_mask_polygon(self, mask_np, max_points):
        """Approximate the convex hull of the largest contour with about ``max_points`` vertices.

        The approxPolyDP tolerance is searched by bisection between 0.1% and 20%
        of the hull perimeter. The approximation whose vertex count is closest
        to ``max_points`` is returned, immediately on an exact match. Returns
        ``None`` when ``mask_np`` has no contour.
        """
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        hull = cv2.convexHull(largest_contour)

        perimeter = cv2.arcLength(hull, True)

        min_eps = perimeter * 0.001
        max_eps = perimeter * 0.2

        best_approx = None
        best_diff = float('inf')
        max_iterations = 20

        for _ in range(max_iterations):
            curr_eps = (min_eps + max_eps) / 2
            approx = cv2.approxPolyDP(hull, curr_eps, True)
            points_diff = len(approx) - max_points

            if abs(points_diff) < best_diff:
                best_approx = approx
                best_diff = abs(points_diff)

            if len(approx) > max_points:
                # Too many vertices: a larger tolerance is needed.
                min_eps = curr_eps * 1.1
            elif len(approx) < max_points:
                # Too few vertices: a smaller tolerance is needed.
                max_eps = curr_eps * 0.9
            else:
                return approx.squeeze()

            # Stop once the search interval has collapsed.
            if abs(max_eps - min_eps) < perimeter * 0.0001:
                break

        return best_approx.squeeze() if best_approx is not None else hull.squeeze()

    def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
        """Return one mask per sufficiently large connected component.

        Args:
            mask: Mask batch of shape (B, H, W); it is rounded before labelling.
            size_threshold_width: Minimum bounding-box width a component must have.
            size_threshold_height: Minimum bounding-box height a component must have.
            max_poly_points: Target vertex count for ``convex_polygons`` mode.
            mode: ``"convex_polygons"`` (filled hull approximation), ``"box"``
                (filled bounding box) or anything else for the raw component area.

        Returns:
            One-element tuple with the component masks stacked along the batch
            axis, ordered by the horizontal centre of their bounding boxes. When
            no component qualifies, a single all-zero 64x64 mask is returned.
        """
        from scipy.ndimage import label

        B, H, W = mask.shape
        separated = []

        mask = mask.round()

        for b in range(B):
            mask_np = mask[b].cpu().numpy().astype(np.uint8)
            # A 3x3 block of ones makes diagonal neighbours count as connected.
            structure = np.ones((3, 3), dtype=np.int8)
            labeled, ncomponents = label(mask_np, structure=structure)
            pbar = ProgressBar(ncomponents)

            for component in range(1, ncomponents + 1):
                component_mask_np = (labeled == component).astype(np.uint8)

                rows = np.any(component_mask_np, axis=1)
                cols = np.any(component_mask_np, axis=0)
                y_min, y_max = np.where(rows)[0][[0, -1]]
                x_min, x_max = np.where(cols)[0][[0, -1]]

                width = x_max - x_min + 1
                height = y_max - y_min + 1
                # Horizontal centre of the bounding box; used as the sort key below.
                centroid_x = (x_min + x_max) / 2
                print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}")

                if width >= size_threshold_width and height >= size_threshold_height:
                    if mode == "convex_polygons":
                        polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
                        if polygon is not None:
                            poly_mask = self.polygon_to_mask(polygon, (H, W))
                            poly_mask = torch.tensor(poly_mask, device=mask.device)
                            separated.append((centroid_x, poly_mask))
                    elif mode == "box":
                        box_mask = np.zeros((H, W), dtype=np.uint8)
                        box_mask[y_min:y_max+1, x_min:x_max+1] = 1
                        box_mask = torch.tensor(box_mask, device=mask.device)
                        separated.append((centroid_x, box_mask))
                    else:
                        area_mask = torch.tensor(component_mask_np, device=mask.device)
                        separated.append((centroid_x, area_mask))
                pbar.update(1)

        if len(separated) > 0:
            # Order the results left to right.
            separated.sort(key=lambda x: x[0])
            separated = [x[1] for x in separated]
            out_masks = torch.stack(separated, dim=0)
            return out_masks,
        else:
            # zeros, not empty: torch.empty would expose uninitialized memory as a mask.
            return torch.zeros((1, 64, 64), device=mask.device),
|
|
|
class ConsolidateMasksKJ:
    """Merge separate masks whose combined bounding box fits inside a tile."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "masks": ("MASK",),
                "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
                "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
                "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "consolidate"

    CATEGORY = "KJNodes/masking"
    DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."

    def consolidate(self, masks, width=512, height=512, padding=0):
        """Greedily merge masks whose joint bounding box fits in a ``width`` x ``height`` tile.

        Args:
            masks: Mask batch of shape (B, H, W).
            width: Tile width that a merged bounding box plus padding must fit in.
            height: Tile height that a merged bounding box plus padding must fit in.
            padding: Extra margin added to the bounding box in the fit test.

        Returns:
            One-element tuple with the merged masks stacked along the batch
            axis. Masks with no non-zero pixel have no bounding box and are
            skipped; if nothing remains, a single all-zero (1, H, W) mask is
            returned.
        """
        B, H, W = masks.shape

        def mask_fits(coords, candidate_coords):
            """Return (fits, bbox); bbox is the union on success, else ``coords`` unchanged."""
            x_min, y_min, x_max, y_max = coords
            cx_min, cy_min, cx_max, cy_max = candidate_coords
            nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min)
            nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max)
            if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding:
                return False, coords
            return True, (nx_min, ny_min, nx_max, ny_max)

        # (bounding box, mask) for every non-empty mask
        separated = []
        for b in range(B):
            m = masks[b]
            ys = torch.where(m.any(dim=1))[0]
            xs = torch.where(m.any(dim=0))[0]
            if ys.numel() == 0:
                # An all-zero mask would raise IndexError in the lookup below.
                continue
            separated.append(((xs[0].item(), ys[0].item(), xs[-1].item(), ys[-1].item()), m))

        separated.sort(key=lambda x: x[0])

        # For each mask, collect the other masks that fit into a tile grown from it.
        fits = []
        for i, (start_coord, _) in enumerate(separated):
            coord = start_coord
            fits_in_box = []
            for j, (cand_coord, _) in enumerate(separated):
                if i == j:
                    continue
                r, coord = mask_fits(coord, cand_coord)
                if r:
                    fits_in_box.append(j)
            fits.append((i, fits_in_box))
        fits.sort(key=lambda x: -len(x[1]))

        # Rank groups by size, then by how many members no earlier group has claimed.
        seen = set()
        unique_fits = []
        for idx, fs in fits:
            uniq = [i for i in fs if i not in seen]
            unique_fits.append((idx, fs, uniq))
            seen.update(uniq)
        unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2])))

        final_masks = []
        merged = set()
        for mask_idx, fitting_masks, _ in unique_fits:
            if mask_idx in merged:
                continue
            fitting_masks = [i for i in fitting_masks if i not in merged]
            # clone keeps the in-place additions away from the caller's tensor
            combined_mask = separated[mask_idx][1].clone()
            for i in fitting_masks:
                combined_mask += separated[i][1]
                merged.add(i)
            merged.add(mask_idx)
            final_masks.append(combined_mask)

        if not final_masks:
            # Nothing to consolidate: return one blank mask rather than failing in torch.stack.
            final_masks.append(masks.new_zeros((H, W)))

        print(f"Consolidated {B} masks into {len(final_masks)}")
        return (torch.stack(final_masks, dim=0),)
|
|
|
|
class DrawMaskOnImage:
    """Paint a solid colour onto images wherever a mask is set, with alpha blending."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image": ("IMAGE", ),
                    "mask": ("MASK", ),
                    "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB/RGBA values in range 0-255 or 0.0-1.0, separated by commas. Ex: 255, 0, 0, 128"}),
                  },
                "optional": {
                    "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
                }
                }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("images",)
    FUNCTION = "apply"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = "Applies the provided masks to the input images with Alpha Blending support."

    def apply(self, image, mask, color, device="cpu"):
        """Blend ``color`` into ``image`` using ``mask`` as per-pixel opacity.

        Args:
            image: Image batch of shape (B, H, W, C). With 4 channels the last
                one is treated as alpha.
            mask: Mask batch of shape (BM, HM, WM). It is resized with
                nearest-exact interpolation when its size differs from the
                image, and repeated or truncated to B frames.
            color: ``"r, g, b[, a]"`` with components in 0-255 or 0.0-1.0
                (values above 1.0 are divided by 255), or a hex string
                (``#rgb``, ``#rgba``, ``#rrggbb``, ``#rrggbbaa``).
            device: ``"gpu"`` to process on ``main_device``, anything else for CPU.

        Returns:
            One-element tuple with the blended image batch on the CPU.

        Raises:
            ValueError: For a hex colour of unsupported length or a component
                that cannot be parsed as a number.
        """
        B, H, W, C = image.shape
        BM, HM, WM = mask.shape

        processing_device = main_device if device == "gpu" else torch.device("cpu")

        # Nothing below writes in place, so the inputs need no defensive clone.
        in_masks = mask.to(processing_device)
        in_images = image.to(processing_device)

        if HM != H or WM != W:
            # Resize the device-resident copy. Using the original `mask` here
            # would put the result back on its source device and break GPU runs.
            in_masks = F.interpolate(in_masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)

        if B > BM:
            # Cycle the masks until there is one per image.
            in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B]
        elif BM > B:
            in_masks = in_masks[:B]

        color = color.strip()
        color_values = []

        if color.startswith('#'):
            hex_color = color.lstrip('#')
            if len(hex_color) in (3, 4):
                # Short form: each digit is doubled (#abc -> #aabbcc).
                color_values = [int(c*2, 16) / 255.0 for c in hex_color]
            elif len(hex_color) in (6, 8):
                color_values = [int(hex_color[i:i+2], 16) / 255.0 for i in range(0, len(hex_color), 2)]
            else:
                raise ValueError(f"Invalid hex color format: {color}")
        else:
            for x in color.split(","):
                val = float(x.strip())
                color_values.append(val / 255.0 if val > 1.0 else val)

        rgb = color_values[:3]
        # Alpha is only taken when exactly four components were given.
        alpha_val = color_values[3] if len(color_values) == 4 else 1.0

        fill_color = torch.tensor(rgb, dtype=torch.float32, device=processing_device)

        output_images = []
        for i in tqdm(range(B), desc="DrawMaskOnImage batch"):
            curr_mask = in_masks[i]
            curr_image = in_images[i]

            blend_factor = curr_mask.unsqueeze(-1) * alpha_val
            img_channels = curr_image.shape[-1]

            if img_channels == 4:
                img_rgb = curr_image[..., :3]
                img_a = curr_image[..., 3:]
                out_rgb = img_rgb * (1 - blend_factor) + fill_color * blend_factor
                # The painted region is at least as opaque as the blend factor.
                out_a = torch.maximum(img_a, blend_factor)
                masked_image = torch.cat((out_rgb, out_a), dim=-1)
            else:
                masked_image = curr_image * (1 - blend_factor) + fill_color * blend_factor
            output_images.append(masked_image)

        if not output_images:
            return (torch.zeros((0, H, W, C), dtype=image.dtype),)

        out_tensor = torch.stack(output_images, dim=0).cpu()

        return (out_tensor, )
|
|
class BlockifyMask:
    """Replace each mask with a coarse block pattern covering its occupied area."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "masks": ("MASK",),
                    "block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}),
                },
                "optional": {
                    "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
                }
                }

    RETURN_TYPES = ("MASK", )
    RETURN_NAMES = ("mask",)
    FUNCTION = "process"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask."

    def process(self, masks, block_size, device="cpu"):
        """Blockify every mask in the batch.

        The bounding box of the positive pixels of each mask is divided into a
        grid of roughly ``block_size``-sized cells. Every cell that contains at
        least one positive pixel is filled completely.

        Args:
            masks: Mask batch of shape (B, H, W).
            block_size: Approximate edge length of a block in pixels.
            device: ``"gpu"`` to process on ``main_device``, anything else for CPU.

        Returns:
            One-element tuple with the block masks, on the same device as the
            input ``masks``. Frames without positive pixels stay all-zero.
        """
        processing_device = main_device if device == "gpu" else torch.device("cpu")
        input_device = masks.device

        masks = masks.to(processing_device)
        batch_size, height, width = masks.shape

        result_masks = torch.zeros_like(masks)

        for i in tqdm(range(batch_size), desc="BlockifyMask batch"):
            mask_bool = masks[i] > 0
            if not mask_bool.any():
                continue

            y_indices = torch.nonzero(mask_bool.any(dim=1), as_tuple=True)[0]
            x_indices = torch.nonzero(mask_bool.any(dim=0), as_tuple=True)[0]

            # Plain ints keep the grid arithmetic below out of tensor land.
            y_min, y_max = int(y_indices[0]), int(y_indices[-1])
            x_min, x_max = int(x_indices[0]), int(x_indices[-1])

            bbox_width = x_max - x_min + 1
            bbox_height = y_max - y_min + 1

            # Number of cells per axis and the resulting cell size.
            w_divisions = max(1, bbox_width // block_size)
            h_divisions = max(1, bbox_height // block_size)
            w_slice = bbox_width // w_divisions
            h_slice = bbox_height // h_divisions

            y_coords = torch.arange(y_min, y_max + 1, device=processing_device).view(-1, 1)
            x_coords = torch.arange(x_min, x_max + 1, device=processing_device).view(1, -1)

            # Cell index of every pixel. The clamp folds the remainder pixels
            # of an uneven split into the last cell.
            w_block_indices = ((x_coords - x_min) // w_slice).clamp(0, w_divisions - 1)
            h_block_indices = ((y_coords - y_min) // h_slice).clamp(0, h_divisions - 1)
            block_ids = h_block_indices * w_divisions + w_block_indices

            # Count positive pixels per cell. Using the boolean occupancy rather
            # than the raw values keeps source and accumulator in the same
            # dtype (required by scatter_add_) and matches the `> 0` test above.
            occupied = mask_bool[y_min:y_max+1, x_min:x_max+1].to(torch.float32)
            max_blocks = h_divisions * w_divisions
            block_content = torch.zeros(max_blocks, dtype=torch.float32, device=processing_device)
            block_content.scatter_add_(0, block_ids.flatten(), occupied.flatten())

            has_content = block_content > 0
            block_mask = has_content[block_ids]

            result_masks[i, y_min:y_max+1, x_min:x_max+1] = block_mask.float()

        # Return the result on the device the input arrived on.
        return (result_masks.clamp(0, 1).to(input_device),)
|
|