"""Image and video helper nodes for SwarmUI's ComfyUI backend."""

import math

import torch

import comfy
import comfy.utils  # explicit: `import comfy` alone does not load the submodule
from nodes import MAX_RESOLUTION

# Minimum total weight of fully-unmasked pixels needed before any colour
# correction is attempted; below this the source section is left untouched.
_MIN_REFERENCE_WEIGHT = 50


class SwarmImageScaleForMP:
    """Node that rescales an image to roughly match a target pixel count."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", {"default": 0, "min": 0, "max": 8192, "tooltip": "The target width of the image."}),
                "height": ("INT", {"default": 0, "min": 0, "max": 8192, "tooltip": "The target height of the image."}),
                "can_shrink": ("BOOLEAN", {"default": True, "tooltip": "If true, the image can be shrunk to fit the target size, otherwise it will only be scaled up or left the same."}),
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "scale"
    DESCRIPTION = "Scales an image to a target width and height, while keeping the aspect ratio."

    def scale(self, image, width, height, can_shrink):
        """Resize ``image`` so its area approximates ``width * height``.

        The aspect ratio of the input is kept; only the pixel count of the
        target is used. Both output dimensions are rounded to a multiple of 64.

        Args:
            image: Tensor indexed here as (batch, height, width, channels).
            width: Target width; only its product with ``height`` matters.
            height: Target height; only its product with ``width`` matters.
            can_shrink: When false, a scale factor below 1 returns the input
                unchanged.

        Returns:
            A one-element tuple holding the resized (or original) image.
        """
        mp_target = width * height
        old_width = image.shape[2]
        old_height = image.shape[1]
        old_area = old_width * old_height
        # width/height default to 0, which would give a scale factor of 0 and
        # an invalid 0x0 resize; a zero-area input would divide by zero.
        if mp_target <= 0 or old_area <= 0:
            return (image,)

        factor = math.sqrt(mp_target / old_area)
        if not can_shrink and factor < 1:
            return (image,)

        # Snap to multiples of 64, but never let a small target round to 0.
        new_width = max(64, int(round(old_width * factor / 64) * 64))
        new_height = max(64, int(round(old_height * factor / 64) * 64))

        samples = image.movedim(-1, 1)
        resized = comfy.utils.common_upscale(samples, new_width, new_height, "bilinear", "disabled")
        return (resized.movedim(1, -1),)


class SwarmImageCrop:
    """Node that cuts a rectangular region out of an image."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8, "tooltip": "The x coordinate in pixels of the top left corner of the crop."}),
                "y": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8, "tooltip": "The y coordinate in pixels of the top left corner of the crop."}),
                "width": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 8, "tooltip": "The width in pixels of the crop."}),
                "height": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 8, "tooltip": "The height in pixels of the crop."}),
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "crop"
    DESCRIPTION = "Crops an image to a specific region."

    def crop(self, image, x, y, width, height):
        """Return the ``width`` x ``height`` region of ``image`` starting at (x, y).

        A non-positive ``width`` or ``height`` returns the input unchanged.
        The region is truncated where it extends past the image edge.

        Returns:
            A one-element tuple holding the cropped image.
        """
        if width <= 0 or height <= 0:
            return (image,)
        # Keep the origin inside the image: an origin at or past the edge
        # would otherwise slice out an empty tensor.
        x = max(0, min(x, image.shape[2] - 1))
        y = max(0, min(y, image.shape[1] - 1))
        return (image[:, y:y + height, x:x + width, :],)


class SwarmVideoBoomerang:
    """Node that appends a reversed copy of a frame batch to itself."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
            }
        }

    CATEGORY = "SwarmUI/video"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "boomerang"
    DESCRIPTION = "Creates a boomerang effect by having the video play in reverse after the end, as a simple trick to make it appear to loop smoothly forever."

    def boomerang(self, images):
        """Return ``images`` followed by the same frames in reverse order."""
        reversed_frames = images.flip(0)
        return (torch.cat((images, reversed_frames), 0),)


class SwarmImageNoise:
    """Node that adds seeded gaussian noise to an image, optionally masked."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "amount": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01, "round": False}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
            },
            "optional": {
                "mask": ("MASK",)
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "add_noise"
    DESCRIPTION = "Adds random noise to an image."

    def add_noise(self, image, amount, seed, mask=None):
        """Add ``amount``-scaled normal noise to ``image`` and clamp to [0, 1].

        Args:
            image: Image tensor; leading dimensions are added until it is 4-D.
            amount: Multiplier applied to the standard-normal noise.
            seed: Seed for the noise. NOTE: ``torch.manual_seed`` reseeds the
                global default generator as a side effect.
            mask: Optional mask, resized to the image's spatial size with
                bicubic interpolation and multiplied into the noise.

        Returns:
            A one-element tuple holding the noised image.
        """
        generator = torch.manual_seed(seed)
        while image.dim() < 4:
            image = image.unsqueeze(0)
        # Noise is always drawn on the CPU generator, then moved next to the
        # image so the mask multiply and the add both happen on one device.
        noise = torch.randn(image.size(), dtype=image.dtype, layout=image.layout, generator=generator, device="cpu") * amount
        noise = noise.to(image.device)
        if mask is not None:
            # A 3-D mask is treated as (batch, height, width): insert the
            # channel axis after the batch axis so the batch is not mistaken
            # for channels by the resize below.
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            while mask.dim() < 4:
                mask = mask.unsqueeze(0)
            mask = torch.nn.functional.interpolate(mask.to(image.device), size=(image.shape[1], image.shape[2]), mode="bicubic")
            if image.shape[3] == 3 and image.shape[1] > 3:  # (channels-last)
                mask = mask.movedim(1, -1)
            noise = noise * mask
        noised = torch.clamp(image + noise, 0, 1)
        return (noised,)


class SwarmTrimFrames:
    """Node that drops frames from both ends of a frame batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "trim_start": ("INT", {"default": 0, "min": 0, "max": 4096}),
                "trim_end": ("INT", {"default": 0, "min": 0, "max": 4096})
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "trim"
    DESCRIPTION = "Trims frames from the start and end of a video."

    def trim(self, image, trim_start, trim_end):
        """Remove ``trim_start`` leading and ``trim_end`` trailing frames.

        Batches of one frame or fewer are returned unchanged. Both trim counts
        are clamped to the frame count; if together they cover every frame the
        result is an empty batch.

        Returns:
            A one-element tuple holding a copy of the remaining frames.
        """
        frame_count = image.shape[0]
        if frame_count <= 1:
            return (image,)
        start = max(0, min(frame_count, trim_start))
        end = max(0, min(frame_count, trim_end))
        return (image[start:frame_count - end].clone(),)


class SwarmCountFrames:
    """Node that reports the size of the first dimension of an image batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("INT",)
    FUNCTION = "count"
    DESCRIPTION = "Counts the number of frames in an image."

    def count(self, image):
        """Return ``image.shape[0]`` as a one-element tuple."""
        return (image.shape[0],)


class SwarmImageWidth:
    """Node that reports an image's width."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("INT",)
    FUNCTION = "get_width"
    DESCRIPTION = "Gets the width of an image."

    def get_width(self, image):
        """Return the second-to-last dimension of ``image`` as a one-element tuple."""
        return (image.shape[-2],)


class SwarmImageHeight:
    """Node that reports an image's height."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("INT",)
    FUNCTION = "get_height"
    DESCRIPTION = "Gets the height of an image."

    def get_height(self, image):
        """Return the third-to-last dimension of ``image`` as a one-element tuple."""
        return (image.shape[-3],)


class SwarmImageCompositeMaskedColorCorrecting:
    """Node that pastes a source image onto a destination through a mask,
    optionally colour-correcting the source first."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("IMAGE",),
                "source": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "mask": ("MASK",),
                "correction_method": (["None", "Uniform", "Linear", "Linear2"], )
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"
    DESCRIPTION = "Works like ImageCompositeMasked, but does color correction for inpainted images (ie outside-the-mask areas are expected to be identical)"

    def composite(self, destination, source, x, y, mask, correction_method):
        """Blend ``source`` into ``destination`` at (x, y) weighted by ``mask``.

        Inputs are cloned, so the caller's tensors are not modified. The mask
        is resized to the source's spatial size with bilinear interpolation.
        Where the mask is 1 the (optionally corrected) source is used, where
        it is 0 the destination is kept.

        Args:
            destination: Image that receives the paste.
            source: Image to paste; repeated to the destination's batch size.
            x: Left offset of the paste inside the destination.
            y: Top offset of the paste inside the destination.
            mask: Blend weights for the source.
            correction_method: One of "None", "Uniform", "Linear", "Linear2".

        Returns:
            A one-element tuple holding the composited image.
        """
        # Work channels-first: (batch, channels, height, width).
        destination = destination.clone().movedim(-1, 1)
        source = source.clone().movedim(-1, 1).to(destination.device)
        source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])

        x = max(-source.shape[3], min(x, destination.shape[3]))
        y = max(-source.shape[2], min(y, destination.shape[2]))

        left, top = (x, y)
        right, bottom = (left + source.shape[3], top + source.shape[2],)

        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

        # Part of the source that actually lands inside the destination.
        visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)

        mask = mask[:, :, :visible_height, :visible_width]
        inverse_mask = torch.ones_like(mask) - mask

        source_section = source[:, :, :visible_height, :visible_width]
        dest_section = destination[:, :, top:bottom, left:right]

        # Fall through on "None"
        if correction_method == "Uniform":
            source_section = color_correct_uniform(source_section, dest_section, inverse_mask)
        elif correction_method == "Linear":
            source_section = color_correct_linear(source_section, dest_section, inverse_mask)
        elif correction_method == "Linear2":
            source_section = color_correct_linear2(source_section, dest_section, inverse_mask)

        source_portion = mask * source_section
        destination_portion = inverse_mask * dest_section

        destination[:, :, top:bottom, left:right] = source_portion + destination_portion
        return (destination.movedim(1, -1),)


def _unmasked_weights(inverse_mask: torch.Tensor) -> torch.Tensor:
    """Map ``inverse_mask`` to weights that are non-zero only above 0.9999.

    Values at or below 0.9999 become 0 and a value of 1 becomes (about) 1, so
    only pixels that are essentially fully outside the mask contribute.
    """
    return (inverse_mask.clamp(0, 1) - 0.9999).clamp(0, 1) * 10000


def color_correct_uniform(source_section: torch.Tensor, dest_section: torch.Tensor, inverse_mask: torch.Tensor) -> torch.Tensor:
    """Shift the source's saturation and value by a constant to match the destination.

    The offset is the weighted mean HSV difference over pixels where
    ``inverse_mask`` is essentially 1. Hue is left unshifted. If too little
    weight is available the source is returned unchanged.
    """
    thresholded = _unmasked_weights(inverse_mask)
    thresholded_sum = thresholded.sum()
    if thresholded_sum > _MIN_REFERENCE_WEIGHT:
        source_hsv = rgb2hsv(source_section)
        dest_hsv = rgb2hsv(dest_section)
        source_hsv_masked = source_hsv * thresholded
        dest_hsv_masked = dest_hsv * thresholded
        diff = dest_hsv_masked - source_hsv_masked
        diff = diff.sum(dim=[0, 2, 3]) / thresholded_sum
        diff[0] = 0.0  # channel 0 is hue: do not shift it
        diff = diff.unsqueeze(0).unsqueeze(2).unsqueeze(2)
        source_hsv = source_hsv + diff
        source_hsv = source_hsv.clamp(0, 1)
        source_section = hsv2rgb(source_hsv)
    return source_section


def color_correct_linear(source_section: torch.Tensor, dest_section: torch.Tensor, inverse_mask: torch.Tensor) -> torch.Tensor:
    """Match the source to the destination with a linear fit on S and on V.

    Each of saturation and value gets its own ``m * x + b`` fit (see
    ``linear_fit``) over pixels where ``inverse_mask`` is essentially 1. Hue
    is kept. If too little weight is available the source is returned
    unchanged.
    """
    thresholded = _unmasked_weights(inverse_mask)
    thresholded_sum = thresholded.sum()
    if thresholded_sum > _MIN_REFERENCE_WEIGHT:
        source_hsv = rgb2hsv(source_section)
        dest_hsv = rgb2hsv(dest_section)
        source_h = source_hsv[:, 0:1, :, :]
        source_s = linear_fit(source_hsv[:, 1:2, :, :], dest_hsv[:, 1:2, :, :], thresholded)
        source_v = linear_fit(source_hsv[:, 2:3, :, :], dest_hsv[:, 2:3, :, :], thresholded)
        source_hsv = torch.cat([source_h, source_s, source_v], dim=1)
        source_section = hsv2rgb(source_hsv)
    return source_section


# like color_correct_linear, but fits s*v and v instead of s and v to avoid instability from dark pixels
def color_correct_linear2(source_section: torch.Tensor, dest_section: torch.Tensor, inverse_mask: torch.Tensor) -> torch.Tensor:
    """Match the source to the destination with linear fits on S*V and on V.

    Saturation is recovered as fitted ``S*V`` divided by fitted ``V`` (0 where
    the fitted V is 0) and clamped to [0, 1]. Hue is kept. If too little
    weight is available the source is returned unchanged.
    """
    thresholded = _unmasked_weights(inverse_mask)
    thresholded_sum = thresholded.sum()
    if thresholded_sum > _MIN_REFERENCE_WEIGHT:
        source_hsv = rgb2hsv(source_section)
        dest_hsv = rgb2hsv(dest_section)
        source_h = source_hsv[:, 0:1, :, :]
        source_sv_mul = source_hsv[:, 1:2, :, :] * source_hsv[:, 2:3, :, :]
        dest_sv_mul = dest_hsv[:, 1:2, :, :] * dest_hsv[:, 2:3, :, :]
        source_sv_mul = linear_fit(source_sv_mul, dest_sv_mul, thresholded)
        source_v = linear_fit(source_hsv[:, 2:3, :, :], dest_hsv[:, 2:3, :, :], thresholded)
        # Divide only where V is non-zero; saturation stays 0 elsewhere.
        nonzero_v = source_v != 0
        source_s = torch.zeros_like(source_sv_mul)
        source_s[nonzero_v] = source_sv_mul[nonzero_v] / source_v[nonzero_v]
        source_s = source_s.clamp(0, 1)
        source_hsv = torch.cat([source_h, source_s, source_v], dim=1)
        source_section = hsv2rgb(source_hsv)
    return source_section


def linear_fit(source_component: torch.Tensor, dest_component: torch.Tensor, thresholded: torch.Tensor) -> torch.Tensor:
    """Apply a weighted least-squares line mapping source values onto dest values.

    Args:
        source_component: Values to transform, reduced over dims 0, 2 and 3.
        dest_component: Reference values, same layout as the source.
        thresholded: Per-pixel weights. The caller must ensure their sum is
            non-zero, since it is used as a divisor.

    Returns:
        ``m * source_component + b`` clamped to [0, 1].
    """
    thresholded_sum = thresholded.sum()
    source_masked = source_component * thresholded
    dest_masked = dest_component * thresholded
    # Simple linear regression on dest as a function of source
    source_mean = source_masked.sum(dim=[0, 2, 3]) / thresholded_sum
    dest_mean = dest_masked.sum(dim=[0, 2, 3]) / thresholded_sum
    source_mean = source_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    dest_mean = dest_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    source_deviation = (source_component - source_mean) * thresholded
    dest_deviation = (dest_component - dest_mean) * thresholded
    numerator = torch.sum(source_deviation * dest_deviation, (0, 2, 3))
    denominator = torch.sum(source_deviation * source_deviation, (0, 2, 3))
    # When all src the same color, we fall back to assuming m = 1 (uniform offset).
    # ones_like keeps the fallback on the same device and dtype as the data.
    m = torch.where(denominator != 0, numerator / denominator, torch.ones_like(denominator))
    m = m.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    b = dest_mean - source_mean * m
    source_component = m * source_component + b
    source_component = source_component.clamp(0, 1)
    return source_component


# from https://github.com/limacv/RGB_HSV_HSL
def rgb2hsv(rgb: torch.Tensor) -> torch.Tensor:
    """Convert a channels-first RGB tensor to HSV, with hue scaled to [0, 1).

    Channels are read along dim 1 and the result concatenates H, S and V
    along that same dim.
    """
    cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
    cmin = torch.min(rgb, dim=1, keepdim=True)[0]
    delta = cmax - cmin
    hsv_h = torch.empty_like(rgb[:, 0:1, :, :])
    # Index 3 is a sentinel for pixels with no chroma (max == min); their hue
    # is set to 0 below, so the division by delta is never read for them.
    cmax_idx[delta == 0] = 3
    hsv_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0]
    hsv_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1]
    hsv_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2]
    hsv_h[cmax_idx == 3] = 0.0
    hsv_h /= 6.0
    hsv_s = torch.where(cmax == 0, torch.tensor(0.0).type_as(rgb), delta / cmax)
    hsv_v = cmax
    return torch.cat([hsv_h, hsv_s, hsv_v], dim=1)


def hsv2rgb(hsv: torch.Tensor) -> torch.Tensor:
    """Convert a channels-first HSV tensor (hue in [0, 1]) back to RGB.

    Channels are read along dim 1 as H, S, V and the result has R, G, B
    along that same dim.
    """
    hsv_h, hsv_s, hsv_v = hsv[:, 0:1], hsv[:, 1:2], hsv[:, 2:3]
    _c = hsv_v * hsv_s
    _x = _c * (- torch.abs(hsv_h * 6.0 % 2.0 - 1) + 1.0)
    _m = hsv_v - _c
    _o = torch.zeros_like(_c)
    # Hue sextant (0-5) selects which channel ordering of (c, x, 0) applies.
    idx = (hsv_h * 6.0).type(torch.uint8)
    idx = (idx % 6).expand(-1, 3, -1, -1)
    rgb = torch.empty_like(hsv)
    rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0]
    rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1]
    rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2]
    rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3]
    rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4]
    rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5]
    rgb += _m
    return rgb


NODE_CLASS_MAPPINGS = {
    "SwarmImageScaleForMP": SwarmImageScaleForMP,
    "SwarmImageCrop": SwarmImageCrop,
    "SwarmVideoBoomerang": SwarmVideoBoomerang,
    "SwarmImageNoise": SwarmImageNoise,
    "SwarmTrimFrames": SwarmTrimFrames,
    "SwarmCountFrames": SwarmCountFrames,
    "SwarmImageWidth": SwarmImageWidth,
    "SwarmImageHeight": SwarmImageHeight,
    "SwarmImageCompositeMaskedColorCorrecting": SwarmImageCompositeMaskedColorCorrecting,
}