| import logging |
| import os |
| import json |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from typing_extensions import override |
|
|
| import folder_paths |
| import node_helpers |
| from comfy_api.latest import ComfyExtension, io |
|
|
|
|
def load_and_process_images(image_files, input_dir):
    """Load image files and convert each one to a float RGB tensor.

    Args:
        image_files: List of image filenames (or paths); each entry is joined
            onto ``input_dir`` with ``os.path.join``.
        input_dir: Base directory containing the images.

    Returns:
        list[torch.Tensor]: One float32 tensor per input file, each of shape
        [1, H, W, 3] with pixel values divided by 255. The images are NOT
        stacked into a single batch, so they may differ in size.

    Raises:
        ValueError: If ``image_files`` is empty.
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []
    for file in image_files:
        image_path = os.path.join(input_dir, file)
        # Open the image as a context manager so the underlying file handle is
        # released as soon as the pixel data has been converted/copied.
        with node_helpers.pillow(Image.open, image_path) as source:
            img = source
            if img.mode == "I":
                # Rescale integer-mode pixels before the RGB conversion.
                img = img.point(lambda i: i * (1 / 255))
            img = img.convert("RGB")
        img_array = np.array(img).astype(np.float32) / 255.0
        # [None,] adds the leading batch dimension of size 1.
        output_images.append(torch.from_numpy(img_array)[None,])

    return output_images
|
|
|
|
class LoadImageDataSetFromFolderNode(io.ComfyNode):
    """Node that loads every supported image found directly inside an input subfolder."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one folder combo input and one image-list output."""
        folder_input = io.Combo.Input(
            "folder",
            options=folder_paths.get_input_subfolders(),
            tooltip="The folder to load images from.",
        )
        images_output = io.Image.Output(
            display_name="images",
            is_output_list=True,
            tooltip="List of loaded images",
        )
        return io.Schema(
            node_id="LoadImageDataSetFromFolder",
            display_name="Load Image Dataset from Folder",
            category="dataset",
            is_experimental=True,
            inputs=[folder_input],
            outputs=[images_output],
        )

    @classmethod
    def execute(cls, folder):
        """Load all files in ``folder`` whose name ends with a supported image extension."""
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = (".png", ".jpg", ".jpeg", ".webp")

        image_files = []
        for entry in os.listdir(sub_input_dir):
            # str.endswith accepts a tuple, so one call covers every extension.
            if entry.lower().endswith(valid_extensions):
                image_files.append(entry)

        loaded_images = load_and_process_images(image_files, sub_input_dir)
        return io.NodeOutput(loaded_images)
|
|
|
|
class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
    """Node that loads images together with their sidecar ``.txt`` captions.

    Images are taken from the selected folder and from its immediate
    subfolders. A subfolder whose name starts with ``<digits>_`` has its
    images repeated that many times in the output.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: one folder combo input, image and text list outputs."""
        return io.Schema(
            node_id="LoadImageTextDataSetFromFolder",
            display_name="Load Image and Text Dataset from Folder",
            category="dataset",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder to load images from.",
                )
            ],
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=True,
                    tooltip="List of loaded images",
                ),
                io.String.Output(
                    display_name="texts",
                    is_output_list=True,
                    tooltip="List of text captions",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder):
        """Collect images and captions from ``folder``.

        Returns:
            io.NodeOutput carrying the list of image tensors and a parallel
            list of captions (an empty string where no caption file exists).
        """
        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = (".png", ".jpg", ".jpeg", ".webp")

        # Paths are kept RELATIVE to sub_input_dir. They are joined onto it
        # exactly once below, which stays correct even when the input
        # directory itself is a relative path.
        image_files = []
        for item in os.listdir(sub_input_dir):
            if item.lower().endswith(valid_extensions):
                image_files.append(item)
                continue

            path = os.path.join(sub_input_dir, item)
            if not os.path.isdir(path):
                continue

            # "<N>_name" subfolders repeat their images N times.
            repeat = 1
            head = item.split("_")[0]
            # isdecimal() only accepts characters that int() can parse.
            if head.isdecimal():
                repeat = int(head)
            nested = [
                os.path.join(item, f)
                for f in os.listdir(path)
                if f.lower().endswith(valid_extensions)
            ]
            image_files.extend(nested * repeat)

        captions = []
        for image_file in image_files:
            # Swap only the final extension; str.replace would also rewrite
            # any earlier occurrence of the extension inside the path.
            caption_file = os.path.splitext(image_file)[0] + ".txt"
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    captions.append(f.read().strip())
            else:
                captions.append("")

        output_tensor = load_and_process_images(image_files, sub_input_dir)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, captions)
|
|
|
|
def save_images_to_folder(image_list, output_dir, prefix="image"):
    """Save a list of image tensors to disk as PNG files.

    Args:
        image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or
            [C, H, W]) with values expected in the 0..1 range.
        output_dir: Directory to save images to (created if missing).
        prefix: Filename prefix; files are named ``{prefix}_{index:05d}.png``.

    Returns:
        List of saved filenames (without the directory part).

    Raises:
        ValueError: If an element of ``image_list`` is not a torch.Tensor.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved_files = []

    for idx, img_tensor in enumerate(image_list):
        if not isinstance(img_tensor, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")

        # Drop a leading batch dimension of size 1: [1, H, W, C] -> [H, W, C].
        if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
            img_tensor = img_tensor.squeeze(0)

        # Heuristic: a leading dim of 1/3/4 followed by two larger dims is
        # treated as channels-first and moved to channels-last.
        if (
            img_tensor.dim() == 3
            and img_tensor.shape[0] in (1, 3, 4)
            and img_tensor.shape[1] > 4
            and img_tensor.shape[2] > 4
        ):
            img_tensor = img_tensor.permute(1, 2, 0)

        # Image.fromarray rejects [H, W, 1] arrays; a single channel must be
        # passed as a 2-D [H, W] array (saved as a grayscale image).
        if img_tensor.dim() == 3 and img_tensor.shape[-1] == 1:
            img_tensor = img_tensor.squeeze(-1)

        # detach() lets tensors that still require grad be converted.
        img_array = img_tensor.detach().cpu().numpy()
        img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
        img = Image.fromarray(img_array)

        filename = f"{prefix}_{idx:05d}.png"
        img.save(os.path.join(output_dir, filename))
        saved_files.append(filename)

    return saved_files
|
|
|
|
class SaveImageDataSetToFolderNode(io.ComfyNode):
    """Output node that writes a list of images into a folder under the output directory."""

    @classmethod
    def define_schema(cls):
        """Describe the node: image list plus folder name and filename prefix inputs."""
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to save."),
            io.String.Input(
                "folder_name",
                default="dataset",
                tooltip="Name of the folder to save images to (inside output directory).",
            ),
            io.String.Input(
                "filename_prefix",
                default="image",
                tooltip="Prefix for saved image filenames.",
            ),
        ]
        return io.Schema(
            node_id="SaveImageDataSetToFolder",
            display_name="Save Image Dataset to Folder",
            category="dataset",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=[],
        )

    @classmethod
    def execute(cls, images, folder_name, filename_prefix):
        """Save ``images`` as PNG files and return an empty node output."""
        # The schema sets is_input_list=True and this code reads element 0,
        # so the scalar inputs are taken from the first list entry.
        target_folder = folder_name[0]
        prefix = filename_prefix[0]

        output_dir = os.path.join(folder_paths.get_output_directory(), target_folder)
        written = save_images_to_folder(images, output_dir, prefix)

        logging.info(f"Saved {len(written)} images to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
class SaveImageTextDataSetToFolderNode(io.ComfyNode):
    """Output node that writes images and matching ``.txt`` caption files to a folder."""

    @classmethod
    def define_schema(cls):
        """Describe the node: image and text lists plus folder name and filename prefix."""
        return io.Schema(
            node_id="SaveImageTextDataSetToFolder",
            display_name="Save Image and Text Dataset to Folder",
            category="dataset",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,
            inputs=[
                io.Image.Input("images", tooltip="List of images to save."),
                io.String.Input("texts", tooltip="List of text captions to save."),
                io.String.Input(
                    "folder_name",
                    default="dataset",
                    tooltip="Name of the folder to save images to (inside output directory).",
                ),
                io.String.Input(
                    "filename_prefix",
                    default="image",
                    tooltip="Prefix for saved image filenames.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, images, texts, folder_name, filename_prefix):
        """Save each image as a PNG and each caption next to it as a ``.txt`` file."""
        # The schema sets is_input_list=True and this code reads element 0,
        # so the scalar inputs are taken from the first list entry.
        folder_name = folder_name[0]
        filename_prefix = filename_prefix[0]

        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        saved_files = save_images_to_folder(images, output_dir, filename_prefix)

        if len(texts) != len(saved_files):
            # zip() below stops at the shorter list; make the truncation visible.
            logging.warning(
                f"Got {len(saved_files)} images but {len(texts)} captions; "
                "only matching pairs get a caption file."
            )

        for filename, caption in zip(saved_files, texts):
            # Swap only the final extension: a prefix that itself contains
            # ".png" must not be rewritten.
            caption_filename = os.path.splitext(filename)[0] + ".txt"
            caption_path = os.path.join(output_dir, caption_filename)
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(caption)

        logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
| |
|
|
|
|
def tensor_to_pil(img_tensor):
    """Convert an image tensor to a PIL Image.

    A leading batch dimension of size 1 is removed first. Values are scaled
    by 255, clipped to 0..255 and cast to uint8 before building the image.
    """
    has_unit_batch = img_tensor.dim() == 4 and img_tensor.shape[0] == 1
    if has_unit_batch:
        img_tensor = img_tensor.squeeze(0)
    scaled = img_tensor.cpu().numpy() * 255
    pixels = scaled.clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
def pil_to_tensor(img):
    """Convert a PIL Image to a float32 tensor with a leading batch dimension.

    Pixel values are divided by 255.
    """
    pixels = np.array(img).astype(np.float32)
    normalized = pixels / 255.0
    return torch.from_numpy(normalized).unsqueeze(0)
|
|
|
|
| |
|
|
|
|
class ImageProcessingNode(io.ComfyNode):
    """Base class for nodes that transform images.

    Subclasses configure themselves through class attributes:
        node_id: Unique node identifier (required).
        display_name: Display name; ``node_id`` is used when unset.
        description: Human-readable description of the node.
        extra_inputs: Additional io inputs appended after "images".
        is_group_process: True (group), False (individual), or None to
            auto-detect from which hook the subclass overrides.
        is_output_list: Whether the output is a list; when None it follows
            the detected processing mode.

    Subclasses implement exactly ONE of:
        _process(cls, image, **kwargs) -> tensor (single-item processing)
        _group_process(cls, images, **kwargs) -> list[tensor] (group processing)
    """

    node_id = None
    display_name = None
    # NOTE(review): description is not forwarded to io.Schema in define_schema
    # below — confirm whether that is intended.
    description = None
    extra_inputs = []
    is_group_process = None
    is_output_list = None

    @classmethod
    def _detect_processing_mode(cls):
        """Decide between group and individual processing.

        Returns:
            bool: True for group processing, False for individual processing.

        Raises:
            ValueError: If auto-detection finds both hooks overridden, or neither.
        """
        # An explicit setting always wins over auto-detection.
        if cls.is_group_process is not None:
            return cls.is_group_process

        def first_definer(method_name):
            # First class in the MRO that defines the method in its own body.
            return next(
                (klass for klass in cls.__mro__ if method_name in klass.__dict__),
                None,
            )

        process_owner = first_definer("_process")
        group_owner = first_definer("_group_process")

        # A hook counts as overridden only when it is defined somewhere other
        # than on this base class.
        overrides_process = (
            process_owner is not None and process_owner is not ImageProcessingNode
        )
        overrides_group = (
            group_owner is not None and group_owner is not ImageProcessingNode
        )

        if overrides_process and overrides_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not (overrides_process or overrides_group):
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )

        return overrides_group

    @classmethod
    def define_schema(cls):
        """Build the node schema from the class attributes.

        Raises:
            NotImplementedError: If ``node_id`` has not been set.
        """
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        # Without an explicit is_output_list, list output mirrors group mode.
        if cls.is_output_list is not None:
            output_is_list = cls.is_output_list
        else:
            output_is_list = is_group

        if is_group:
            images_tooltip = "List of images to process."
        else:
            images_tooltip = "Image to process."

        inputs = [io.Image.Input("images", tooltip=images_tooltip)]
        inputs.extend(cls.extra_inputs)

        images_output = io.Image.Output(
            display_name="images",
            is_output_list=output_is_list,
            tooltip="Processed images",
        )
        return io.Schema(
            node_id=cls.node_id,
            display_name=cls.display_name or cls.node_id,
            category="dataset/image",
            is_experimental=True,
            is_input_list=is_group,
            inputs=inputs,
            outputs=[images_output],
        )

    @classmethod
    def execute(cls, images, **kwargs):
        """Execute the node, routing to _group_process or _process by mode."""
        is_group = cls._detect_processing_mode()

        # Extra parameters that arrive as one-element lists are unwrapped to
        # their single value; anything else is passed through unchanged.
        params = {
            key: value[0] if isinstance(value, list) and len(value) == 1 else value
            for key, value in kwargs.items()
        }

        handler = cls._group_process if is_group else cls._process
        return io.NodeOutput(handler(images, **params))

    @classmethod
    def _process(cls, image, **kwargs):
        """Hook for single-item processing; override in subclasses.

        Args:
            image: tensor - Single image tensor
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            tensor - Processed image
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, images, **kwargs):
        """Hook for group processing; override in subclasses.

        Args:
            images: list[tensor] - List of image tensors
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[tensor] - Processed images
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
|
|
|
|
class TextProcessingNode(io.ComfyNode):
    """Base class for nodes that transform texts.

    Subclasses configure themselves through class attributes:
        node_id: Unique node identifier (required).
        display_name: Display name; ``node_id`` is used when unset.
        description: Human-readable description of the node.
        extra_inputs: Additional io inputs appended after "texts".
        is_group_process: True (group), False (individual), or None to
            auto-detect from which hook the subclass overrides.
        is_output_list: Passed to the output schema as-is (None when unset).

    Subclasses implement exactly ONE of:
        _process(cls, text, **kwargs) -> str (single-item processing)
        _group_process(cls, texts, **kwargs) -> list[str] (group processing)
    """

    node_id = None
    display_name = None
    # NOTE(review): description is not forwarded to io.Schema in define_schema
    # below — confirm whether that is intended.
    description = None
    extra_inputs = []
    is_group_process = None
    is_output_list = None

    @classmethod
    def _detect_processing_mode(cls):
        """Decide between group and individual processing.

        Returns:
            bool: True for group processing, False for individual processing.

        Raises:
            ValueError: If auto-detection finds both hooks overridden, or neither.
        """
        # An explicit setting always wins over auto-detection.
        if cls.is_group_process is not None:
            return cls.is_group_process

        def first_definer(method_name):
            # First class in the MRO that defines the method in its own body.
            return next(
                (klass for klass in cls.__mro__ if method_name in klass.__dict__),
                None,
            )

        process_owner = first_definer("_process")
        group_owner = first_definer("_group_process")

        # A hook counts as overridden only when it is defined somewhere other
        # than on this base class.
        overrides_process = (
            process_owner is not None and process_owner is not TextProcessingNode
        )
        overrides_group = (
            group_owner is not None and group_owner is not TextProcessingNode
        )

        if overrides_process and overrides_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not (overrides_process or overrides_group):
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )

        return overrides_group

    @classmethod
    def define_schema(cls):
        """Build the node schema from the class attributes.

        Raises:
            NotImplementedError: If ``node_id`` has not been set.
        """
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        if is_group:
            texts_tooltip = "List of texts to process."
        else:
            texts_tooltip = "Text to process."

        inputs = [io.String.Input("texts", tooltip=texts_tooltip)]
        inputs.extend(cls.extra_inputs)

        # NOTE(review): unlike ImageProcessingNode in this file, is_output_list
        # is passed through without falling back to is_group — confirm intent.
        texts_output = io.String.Output(
            display_name="texts",
            is_output_list=cls.is_output_list,
            tooltip="Processed texts",
        )
        return io.Schema(
            node_id=cls.node_id,
            display_name=cls.display_name or cls.node_id,
            category="dataset/text",
            is_experimental=True,
            is_input_list=is_group,
            inputs=inputs,
            outputs=[texts_output],
        )

    @classmethod
    def execute(cls, texts, **kwargs):
        """Execute the node, routing to _group_process or _process by mode."""
        is_group = cls._detect_processing_mode()

        # Extra parameters that arrive as one-element lists are unwrapped to
        # their single value; anything else is passed through unchanged.
        params = {
            key: value[0] if isinstance(value, list) and len(value) == 1 else value
            for key, value in kwargs.items()
        }

        handler = cls._group_process if is_group else cls._process
        result = handler(texts, **params)

        # Only a group result on a list-output node is forwarded as-is; every
        # other combination is wrapped in a one-element list.
        # NOTE(review): ImageProcessingNode forwards its result unwrapped —
        # confirm the wrapping here matches what downstream nodes expect.
        if cls.is_output_list and is_group:
            payload = result
        else:
            payload = [result]
        return io.NodeOutput(payload)

    @classmethod
    def _process(cls, text, **kwargs):
        """Hook for single-item processing; override in subclasses.

        Args:
            text: str - Single text string
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            str - Processed text
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, texts, **kwargs):
        """Hook for group processing; override in subclasses.

        Args:
            texts: list[str] - List of text strings
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[str] - Processed texts
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
|
|
|
|
| |
|
|
|
|
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
    """Resize images so the shorter edge equals a target length, keeping aspect ratio."""

    node_id = "ResizeImagesByShorterEdge"
    display_name = "Resize Images by Shorter Edge"
    description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
    extra_inputs = [
        io.Int.Input(
            "shorter_edge",
            default=512,
            min=1,
            max=8192,
            tooltip="Target length for the shorter edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, shorter_edge):
        """Resize every image in ``image`` and return them as one batch.

        Args:
            image: Image tensor, either a batch (4-D) or a single unbatched
                image (3-D).
            shorter_edge: Target length in pixels for the shorter edge.

        Returns:
            Tensor with a leading batch dimension holding the resized images.
        """
        # Treat an unbatched image as a batch of one so that iterating over
        # the first dimension always yields whole images.
        frames = image.unsqueeze(0) if image.dim() == 3 else image

        resized_images = []
        for frame in frames:
            img = tensor_to_pil(frame)
            w, h = img.size
            # Integer arithmetic avoids float round-off: int(w * (edge / h))
            # can land one pixel short (e.g. a 49x49 image at edge 512 gave
            # 511). max(1, ...) keeps the size valid for PIL.
            if w < h:
                new_w = shorter_edge
                new_h = max(1, (h * shorter_edge) // w)
            else:
                new_h = shorter_edge
                new_w = max(1, (w * shorter_edge) // h)
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized_images.append(pil_to_tensor(img))
        return torch.cat(resized_images, dim=0)
|
|
|
|
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
    """Resize images so the longer edge equals a target length, keeping aspect ratio."""

    node_id = "ResizeImagesByLongerEdge"
    display_name = "Resize Images by Longer Edge"
    description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
    extra_inputs = [
        io.Int.Input(
            "longer_edge",
            default=1024,
            min=1,
            max=8192,
            tooltip="Target length for the longer edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, longer_edge):
        """Resize every image in ``image`` and return them as one batch.

        Args:
            image: Image tensor, either a batch (4-D) or a single unbatched
                image (3-D).
            longer_edge: Target length in pixels for the longer edge.

        Returns:
            Tensor with a leading batch dimension holding the resized images.
        """
        # Treat an unbatched image as a batch of one; iterating a 3-D tensor
        # directly would walk over its rows instead of over images.
        frames = image.unsqueeze(0) if image.dim() == 3 else image

        resized_images = []
        for image_i in frames:
            img = tensor_to_pil(image_i)
            w, h = img.size
            # Integer arithmetic avoids float round-off in the scaled side,
            # and max(1, ...) stops extreme aspect ratios collapsing it to 0,
            # which PIL rejects.
            if w > h:
                new_w = longer_edge
                new_h = max(1, (h * longer_edge) // w)
            else:
                new_h = longer_edge
                new_w = max(1, (w * longer_edge) // h)
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized_images.append(pil_to_tensor(img))
        return torch.cat(resized_images, dim=0)
|
|
|
|
class CenterCropImagesNode(ImageProcessingNode):
    """Crop the central region of an image to the requested size."""

    node_id = "CenterCropImages"
    display_name = "Center Crop Images"
    description = "Center crop all images to the specified dimensions."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
    ]

    @classmethod
    def _process(cls, image, width, height):
        """Return the centered crop; the box is clamped to the image bounds."""
        img = tensor_to_pil(image)
        img_w, img_h = img.width, img.height

        # Offsets are never negative, so an image smaller than the request is
        # returned from its top-left corner at its own size.
        left = max(0, (img_w - width) // 2)
        top = max(0, (img_h - height) // 2)
        box = (left, top, min(img_w, left + width), min(img_h, top + height))

        return pil_to_tensor(img.crop(box))
|
|
|
|
class RandomCropImagesNode(ImageProcessingNode):
    """Crop a randomly positioned region of the requested size from an image."""

    node_id = "RandomCropImages"
    display_name = "Random Crop Images"
    description = (
        "Randomly crop all images to the specified dimensions (for data augmentation)."
    )
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _process(cls, image, width, height, seed):
        """Return a seeded random crop; the box is clamped to the image bounds."""
        # The modulo brings the seed into the range numpy's global seeding
        # accepts. Note this reseeds numpy's GLOBAL random state.
        np.random.seed(seed % (2**32 - 1))
        img = tensor_to_pil(image)

        # Room available for moving the crop window along each axis.
        slack_x = max(0, img.width - width)
        slack_y = max(0, img.height - height)

        # A random offset is drawn only when there is room to move; the draw
        # order (left, then top) is kept so results per seed stay the same.
        left = 0
        if slack_x > 0:
            left = np.random.randint(0, slack_x + 1)
        top = 0
        if slack_y > 0:
            top = np.random.randint(0, slack_y + 1)

        box = (left, top, min(img.width, left + width), min(img.height, top + height))
        return pil_to_tensor(img.crop(box))
|
|
|
|
class NormalizeImagesNode(ImageProcessingNode):
    """Normalize image values as ``(image - mean) / std``."""

    node_id = "NormalizeImages"
    display_name = "Normalize Images"
    description = "Normalize images using mean and standard deviation."
    extra_inputs = [
        io.Float.Input(
            "mean",
            default=0.5,
            min=0.0,
            max=1.0,
            tooltip="Mean value for normalization.",
        ),
        io.Float.Input(
            "std",
            default=0.5,
            min=0.001,
            max=1.0,
            tooltip="Standard deviation for normalization.",
        ),
    ]

    @classmethod
    def _process(cls, image, mean, std):
        """Subtract ``mean`` from ``image`` and divide by ``std``; no clamping is applied."""
        centered = image - mean
        return centered / std
|
|
|
|
class AdjustBrightnessNode(ImageProcessingNode):
    """Scale image values by a brightness factor and clamp to 0..1."""

    node_id = "AdjustBrightness"
    display_name = "Adjust Brightness"
    description = "Adjust brightness of all images."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        """Multiply ``image`` by ``factor`` and clamp the result to [0.0, 1.0]."""
        scaled = image * factor
        return scaled.clamp(0.0, 1.0)
|
|
|
|
class AdjustContrastNode(ImageProcessingNode):
    """Scale image values around 0.5 by a contrast factor and clamp to 0..1."""

    node_id = "AdjustContrast"
    display_name = "Adjust Contrast"
    description = "Adjust contrast of all images."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        """Stretch values about the 0.5 midpoint and clamp to [0.0, 1.0]."""
        offset_from_mid = image - 0.5
        adjusted = offset_from_mid * factor + 0.5
        return adjusted.clamp(0.0, 1.0)
|
|
|
|
class ShuffleDatasetNode(ImageProcessingNode):
    """Reorder the images of a dataset using a seeded random permutation."""

    node_id = "ShuffleDataset"
    display_name = "Shuffle Image Dataset"
    description = "Randomly shuffle the order of images in the dataset."
    is_group_process = True
    extra_inputs = [
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _group_process(cls, images, seed):
        """Return a new list with ``images`` in seeded random order."""
        # The modulo brings the seed into the range numpy's global seeding
        # accepts. Note this reseeds numpy's GLOBAL random state.
        np.random.seed(seed % (2**32 - 1))
        order = np.random.permutation(len(images))

        shuffled = []
        for position in order:
            shuffled.append(images[position])
        return shuffled
|
|
|
|
class ShuffleImageTextDatasetNode(io.ComfyNode):
    """Node that applies one shared random permutation to images and texts."""

    @classmethod
    def define_schema(cls):
        """Describe the node: image list, text list and seed in; shuffled lists out."""
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to shuffle."),
            io.String.Input("texts", tooltip="List of texts to shuffle."),
            io.Int.Input(
                "seed",
                default=0,
                min=0,
                max=0xFFFFFFFFFFFFFFFF,
                tooltip="Random seed.",
            ),
        ]
        node_outputs = [
            io.Image.Output(
                display_name="images",
                is_output_list=True,
                tooltip="Shuffled images",
            ),
            io.String.Output(
                display_name="texts", is_output_list=True, tooltip="Shuffled texts"
            ),
        ]
        return io.Schema(
            node_id="ShuffleImageTextDataset",
            display_name="Shuffle Image-Text Dataset",
            category="dataset/image",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, texts, seed):
        """Shuffle both lists with the same permutation so pairs stay aligned."""
        # With is_input_list=True the seed arrives as a list; use its first entry.
        seed_value = seed[0]
        # Note this reseeds numpy's GLOBAL random state.
        np.random.seed(seed_value % (2**32 - 1))

        # The permutation length comes from images; texts is indexed with the
        # same positions.
        order = np.random.permutation(len(images))
        shuffled_images = []
        shuffled_texts = []
        for position in order:
            shuffled_images.append(images[position])
        for position in order:
            shuffled_texts.append(texts[position])
        return io.NodeOutput(shuffled_images, shuffled_texts)
|
|
|
|
| |
|
|
|
|
class TextToLowercaseNode(TextProcessingNode):
    """Text node that lowercases its input."""

    node_id = "TextToLowercase"
    display_name = "Text to Lowercase"
    description = "Convert all texts to lowercase."

    @classmethod
    def _process(cls, text):
        """Return ``text`` converted to lowercase."""
        lowered = text.lower()
        return lowered
|
|
|
|
class TextToUppercaseNode(TextProcessingNode):
    """Text node that uppercases its input."""

    node_id = "TextToUppercase"
    display_name = "Text to Uppercase"
    description = "Convert all texts to uppercase."

    @classmethod
    def _process(cls, text):
        """Return ``text`` converted to uppercase."""
        uppered = text.upper()
        return uppered
|
|
|
|
class TruncateTextNode(TextProcessingNode):
    """Text node that cuts its input down to a maximum number of characters."""

    node_id = "TruncateText"
    display_name = "Truncate Text"
    description = "Truncate all texts to a maximum length."
    extra_inputs = [
        io.Int.Input(
            "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
        ),
    ]

    @classmethod
    def _process(cls, text, max_length):
        """Return the first ``max_length`` characters of ``text``."""
        keep = slice(None, max_length)
        return text[keep]
|
|
|
|
class AddTextPrefixNode(TextProcessingNode):
    """Text node that prepends a fixed string to its input."""

    node_id = "AddTextPrefix"
    display_name = "Add Text Prefix"
    description = "Add a prefix to all texts."
    extra_inputs = [
        io.String.Input("prefix", default="", tooltip="Prefix to add."),
    ]

    @classmethod
    def _process(cls, text, prefix):
        """Return ``prefix`` followed by ``text``."""
        combined = prefix + text
        return combined
|
|
|
|
class AddTextSuffixNode(TextProcessingNode):
    """Text node that appends a fixed string to its input."""

    node_id = "AddTextSuffix"
    display_name = "Add Text Suffix"
    description = "Add a suffix to all texts."
    extra_inputs = [
        io.String.Input("suffix", default="", tooltip="Suffix to add."),
    ]

    @classmethod
    def _process(cls, text, suffix):
        """Return ``text`` followed by ``suffix``."""
        combined = text + suffix
        return combined
|
|
|
|
class ReplaceTextNode(TextProcessingNode):
    """Text node that replaces every occurrence of a substring."""

    node_id = "ReplaceText"
    display_name = "Replace Text"
    description = "Replace text in all texts."
    extra_inputs = [
        io.String.Input("find", default="", tooltip="Text to find."),
        io.String.Input("replace", default="", tooltip="Text to replace with."),
    ]

    @classmethod
    def _process(cls, text, find, replace):
        """Return ``text`` with each occurrence of ``find`` replaced by ``replace``.

        An empty ``find`` (the input's default) leaves ``text`` unchanged.
        """
        # str.replace("", x) would insert x before every character and at the
        # end, which is never what an empty search string is meant to do.
        if not find:
            return text
        return text.replace(find, replace)
|
|
|
|
class StripWhitespaceNode(TextProcessingNode):
    """Text node that trims whitespace from both ends of its input."""

    node_id = "StripWhitespace"
    display_name = "Strip Whitespace"
    description = "Strip leading and trailing whitespace from all texts."

    @classmethod
    def _process(cls, text):
        """Return ``text`` without leading or trailing whitespace."""
        trimmed = text.strip()
        return trimmed
|
|
|
|
| |
|
|
|
|
class ImageDeduplicationNode(ImageProcessingNode):
    """Drop images whose 8x8 average hash is too close to an already kept image."""

    node_id = "ImageDeduplication"
    display_name = "Image Deduplication"
    description = "Remove duplicate or very similar images from the dataset."
    is_group_process = True
    extra_inputs = [
        io.Float.Input(
            "similarity_threshold",
            default=0.95,
            min=0.0,
            max=1.0,
            tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
        ),
    ]

    @classmethod
    def _group_process(cls, images, similarity_threshold):
        """Return the images that are not near-duplicates of an earlier kept image.

        Images are visited in order, and the first image of each group of
        similar images is the one that is kept.
        """
        if len(images) == 0:
            return []

        def compute_hash(img_tensor):
            """Hash an image as 64 bits: one per pixel of an 8x8 grayscale thumbnail."""
            thumbnail = tensor_to_pil(img_tensor).resize(
                (8, 8), Image.Resampling.LANCZOS
            )
            pixels = list(thumbnail.convert("L").getdata())
            avg = sum(pixels) / len(pixels)
            # A bit is "1" where the pixel is brighter than the average.
            bits = []
            for p in pixels:
                bits.append("1" if p > avg else "0")
            return "".join(bits)

        def hamming_distance(hash1, hash2):
            """Count the positions at which the two hash strings differ."""
            differing = 0
            for c1, c2 in zip(hash1, hash2):
                if c1 != c2:
                    differing += 1
            return differing

        hashes = []
        for img in images:
            hashes.append(compute_hash(img))

        keep_indices = []
        for i, candidate_hash in enumerate(hashes):
            for j in keep_indices:
                distance = hamming_distance(candidate_hash, hashes[j])
                # 64 is the number of bits in the 8x8 hash.
                similarity = 1.0 - (distance / 64.0)
                if similarity >= similarity_threshold:
                    logging.info(
                        f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
                    )
                    break
            else:
                # No kept image was similar enough, so this one is unique.
                keep_indices.append(i)

        unique_images = [images[index] for index in keep_indices]
        logging.info(
            f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
        )
        return unique_images
|
|
|
|
class ImageGridNode(ImageProcessingNode):
    """Paste all input images into one grid image with fixed-size cells."""

    node_id = "ImageGrid"
    display_name = "Image Grid"
    description = "Arrange multiple images into a grid layout."
    is_group_process = True
    is_output_list = False
    extra_inputs = [
        io.Int.Input(
            "columns",
            default=4,
            min=1,
            max=20,
            tooltip="Number of columns in the grid.",
        ),
        io.Int.Input(
            "cell_width",
            default=256,
            min=32,
            max=2048,
            tooltip="Width of each cell in the grid.",
        ),
        io.Int.Input(
            "cell_height",
            default=256,
            min=32,
            max=2048,
            tooltip="Height of each cell in the grid.",
        ),
        io.Int.Input(
            "padding", default=4, min=0, max=50, tooltip="Padding between images."
        ),
    ]

    @classmethod
    def _group_process(cls, images, columns, cell_width, cell_height, padding):
        """Build the grid image and return it as a single tensor.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("Cannot create grid from empty image list")

        num_images = len(images)
        # Ceiling division: a partially filled last row still needs a row.
        rows, leftover = divmod(num_images, columns)
        if leftover:
            rows += 1

        # Padding sits between cells only, not around the outer border.
        grid_width = columns * cell_width + (columns - 1) * padding
        grid_height = rows * cell_height + (rows - 1) * padding

        # Black RGB canvas; unused cells and padding stay black.
        grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))

        step_x = cell_width + padding
        step_y = cell_height + padding
        for idx, img_tensor in enumerate(images):
            row, col = divmod(idx, columns)

            # Every image is stretched to the cell size (aspect ratio is not kept).
            cell = tensor_to_pil(img_tensor).resize(
                (cell_width, cell_height), Image.Resampling.LANCZOS
            )
            grid.paste(cell, (col * step_x, row * step_y))

        logging.info(
            f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
        )
        return pil_to_tensor(grid)
|
|
|
|
class MergeImageListsNode(ImageProcessingNode):
    """Group node that passes its image list through unchanged."""

    node_id = "MergeImageLists"
    display_name = "Merge Image Lists"
    description = "Concatenate multiple image lists into one."
    is_group_process = True

    @classmethod
    def _group_process(cls, images):
        """Log the size of ``images`` and return the same list.

        No merging happens here; per the original docstring the list is
        already merged by input handling.
        """
        image_count = len(images)
        logging.info(f"Merged image list contains {image_count} images")
        return images
|
|
|
|
class MergeTextListsNode(TextProcessingNode):
    """Group node that passes its text list through unchanged."""

    node_id = "MergeTextLists"
    display_name = "Merge Text Lists"
    description = "Concatenate multiple text lists into one."
    is_group_process = True

    @classmethod
    def _group_process(cls, texts):
        """Log the size of ``texts`` and return the same list.

        No merging happens here; per the original docstring the list is
        already merged by input handling.
        """
        text_count = len(texts)
        logging.info(f"Merged text list contains {text_count} texts")
        return texts
|
|
|
|
| |
|
|
|
|
class ResolutionBucket(io.ComfyNode):
    """Group latent samples and their conditions into per-resolution batches."""

    @classmethod
    def define_schema(cls):
        """Describe the node: latent and conditioning lists in, bucketed lists out."""
        node_inputs = [
            io.Latent.Input(
                "latents",
                tooltip="List of latent dicts to bucket by resolution.",
            ),
            io.Conditioning.Input(
                "conditioning",
                tooltip="List of conditioning lists (must match latents length).",
            ),
        ]
        node_outputs = [
            io.Latent.Output(
                display_name="latents",
                is_output_list=True,
                tooltip="List of batched latent dicts, one per resolution bucket.",
            ),
            io.Conditioning.Output(
                display_name="conditioning",
                is_output_list=True,
                tooltip="List of condition lists, one per resolution bucket.",
            ),
        ]
        return io.Schema(
            node_id="ResolutionBucket",
            display_name="Resolution Bucket",
            category="dataset",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, latents, conditioning):
        """Bucket every latent sample by the size of its last two dimensions.

        Returns:
            io.NodeOutput carrying a list of ``{"samples": stacked}`` dicts and
            a parallel list of condition lists, one pair per bucket, in order
            of each resolution's first appearance.

        Raises:
            ValueError: If ``latents`` and ``conditioning`` differ in length.
        """
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        # Flatten each batch into individual samples and file them straight
        # into the bucket for their (height, width) key.
        buckets = {}
        total_samples = 0
        for latent_dict, cond in zip(latents, conditioning):
            samples = latent_dict["samples"]
            for index in range(samples.shape[0]):
                sample = samples[index]
                # cond is indexed with the same position as the sample.
                sample_cond = cond[index]

                resolution = (sample.shape[-2], sample.shape[-1])
                bucket = buckets.setdefault(
                    resolution, {"latents": [], "conditions": []}
                )
                bucket["latents"].append(sample)
                bucket["conditions"].append(sample_cond)
                total_samples += 1

        output_latents = []
        output_conditions = []
        for (h, w), bucket_data in buckets.items():
            # Samples in one bucket share their last two dims and are stacked
            # into a new leading batch dimension.
            output_latents.append(
                {"samples": torch.stack(bucket_data["latents"], dim=0)}
            )
            output_conditions.append(bucket_data["conditions"])

            logging.info(
                f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
            )

        logging.info(f"Created {len(buckets)} resolution buckets from {total_samples} samples")
        return io.NodeOutput(output_latents, output_conditions)
|
|
|
|
class MakeTrainingDataset(io.ComfyNode):
    """Turn images and captions into latents and conditioning for training."""

    @classmethod
    def define_schema(cls):
        """Declare the node: images, VAE, CLIP and optional captions in; two lists out."""
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to encode."),
            io.Vae.Input(
                "vae", tooltip="VAE model for encoding images to latents."
            ),
            io.Clip.Input(
                "clip", tooltip="CLIP model for encoding text to conditioning."
            ),
            io.String.Input(
                "texts",
                optional=True,
                tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
            ),
        ]
        node_outputs = [
            io.Latent.Output(
                display_name="latents",
                is_output_list=True,
                tooltip="List of latent dicts",
            ),
            io.Conditioning.Output(
                display_name="conditioning",
                is_output_list=True,
                tooltip="List of conditioning lists",
            ),
        ]
        return io.Schema(
            node_id="MakeTrainingDataset",
            display_name="Make Training Dataset",
            category="dataset",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, vae, clip, texts=None):
        """Encode every image with the VAE and every caption with CLIP.

        Args:
            images: List of image tensors; each is indexed with four
                dimensions and only the first three channels of the last
                dimension are passed to the VAE.
            vae: List whose first element is the VAE to use.
            clip: List whose first element is the CLIP model to use.
            texts: Optional list of captions. ``None`` or an empty list is
                treated as a single empty caption; a single caption is
                repeated for every image.

        Returns:
            ``io.NodeOutput`` with a list of ``{"samples": latent}`` dicts
            and a list of conditioning results, one per image/caption.

        Raises:
            ValueError: If the caption count is neither 1 nor equal to the
                number of images.
        """
        # The model inputs are indexed as lists here; only entry 0 is used.
        vae, clip = vae[0], clip[0]

        num_images = len(images)
        captions = [""] if texts is None or len(texts) == 0 else texts

        if len(captions) != num_images:
            if len(captions) == 1 and num_images > 1:
                # One caption for the whole set: repeat it per image.
                captions = captions * num_images
            else:
                raise ValueError(
                    f"Number of texts ({len(captions)}) does not match number of images ({num_images}). "
                    f"Text list should have length {num_images}, 1, or 0."
                )

        logging.info(f"Encoding {num_images} images with VAE...")
        latents_list = [
            {"samples": vae.encode(image[:, :, :, :3])} for image in images
        ]

        logging.info(f"Encoding {len(captions)} texts with CLIP...")
        conditioning_list = [
            clip.encode_from_tokens_scheduled(clip.tokenize(caption))
            for caption in captions
        ]

        logging.info(
            f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
        )
        return io.NodeOutput(latents_list, conditioning_list)
|
|
|
|
class SaveTrainingDataset(io.ComfyNode):
    """Save encoded training dataset (latents + conditioning) to disk.

    The dataset is written as ``shard_NNNN.pkl`` files (via ``torch.save``)
    plus a ``metadata.json`` summary inside a folder under the output
    directory.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: latents, conditioning, target folder and shard size."""
        return io.Schema(
            node_id="SaveTrainingDataset",
            display_name="Save Training Dataset",
            category="dataset",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts from MakeTrainingDataset.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists from MakeTrainingDataset.",
                ),
                io.String.Input(
                    "folder_name",
                    default="training_dataset",
                    tooltip="Name of folder to save dataset (inside output directory).",
                ),
                io.Int.Input(
                    "shard_size",
                    default=1000,
                    min=1,
                    max=100000,
                    tooltip="Number of samples per shard file.",
                ),
            ],
            outputs=[],
        )

    @staticmethod
    def _remove_stale_shards(output_dir, written):
        """Delete numbered shard files in ``output_dir`` that are not in ``written``.

        Only names of the form ``shard_<digits>.pkl`` are considered, so
        unrelated files in the folder are never touched.
        """
        prefix, suffix = "shard_", ".pkl"
        for name in os.listdir(output_dir):
            if name in written:
                continue
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            stem = name[len(prefix):-len(suffix)]
            if stem.isascii() and stem.isdigit():
                os.remove(os.path.join(output_dir, name))
                logging.info(f"Removed stale shard from a previous save: {name}")

    @classmethod
    def execute(cls, latents, conditioning, folder_name, shard_size):
        """Write the dataset to ``<output dir>/<folder_name>`` in shards.

        Args:
            latents: List of latent dicts, one per sample.
            conditioning: List of conditioning entries, parallel to ``latents``.
            folder_name: List whose first element is the target folder name.
            shard_size: List whose first element is the samples-per-shard count.

        Returns:
            An empty ``io.NodeOutput``.

        Raises:
            ValueError: If ``latents`` and ``conditioning`` differ in length.
        """
        # The scalar inputs are indexed as lists here; only entry 0 is used.
        folder_name = folder_name[0]
        shard_size = shard_size[0]

        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                f"Something went wrong in dataset preparation."
            )

        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        os.makedirs(output_dir, exist_ok=True)

        num_samples = len(latents)
        num_shards = (num_samples + shard_size - 1) // shard_size  # ceiling division

        logging.info(
            f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
        )

        written = set()
        for shard_idx in range(num_shards):
            start_idx = shard_idx * shard_size
            end_idx = min(start_idx + shard_size, num_samples)

            shard_data = {
                "latents": latents[start_idx:end_idx],
                "conditioning": conditioning[start_idx:end_idx],
            }

            shard_filename = f"shard_{shard_idx:04d}.pkl"
            shard_path = os.path.join(output_dir, shard_filename)

            with open(shard_path, "wb") as f:
                torch.save(shard_data, f)
            written.add(shard_filename)

            logging.info(
                f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
            )

        # An earlier, larger save into the same folder leaves higher-numbered
        # shards behind. A reader that loads every shard_*.pkl in the folder
        # would then mix old samples into this dataset and disagree with
        # metadata.json, so drop them once the new shards are safely written.
        cls._remove_stale_shards(output_dir, written)

        metadata = {
            "num_samples": num_samples,
            "num_shards": num_shards,
            "shard_size": shard_size,
        }
        metadata_path = os.path.join(output_dir, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
class LoadTrainingDataset(io.ComfyNode):
    """Load encoded training dataset from disk.

    Reads every ``shard_*.pkl`` file in a folder under the output directory
    and concatenates their ``"latents"`` and ``"conditioning"`` lists.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: a folder name in; latent and conditioning lists out."""
        return io.Schema(
            node_id="LoadTrainingDataset",
            display_name="Load Training Dataset",
            category="dataset",
            is_experimental=True,
            inputs=[
                io.String.Input(
                    "folder_name",
                    default="training_dataset",
                    tooltip="Name of folder containing the saved dataset (inside output directory).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of latent dicts",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of conditioning lists",
                ),
            ],
        )

    @staticmethod
    def _shard_sort_key(filename):
        """Order shards by numeric index; non-numeric names sort last by name.

        Plain string sorting would put ``shard_10000.pkl`` before
        ``shard_1001.pkl`` once indices outgrow the four-digit padding.
        """
        stem = filename[len("shard_"):-len(".pkl")]
        if stem.isascii() and stem.isdigit():
            return (0, int(stem), filename)
        return (1, 0, filename)

    @classmethod
    def execute(cls, folder_name):
        """Load all shards from ``<output dir>/<folder_name>``.

        Args:
            folder_name: Name of the dataset folder inside the output directory.

        Returns:
            ``io.NodeOutput`` with the concatenated latents list and the
            concatenated conditioning list, in shard order.

        Raises:
            ValueError: If the folder is missing (or is not a directory) or
                contains no ``shard_*.pkl`` files.
        """
        dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)

        # isdir (not exists): a regular file at this path must be rejected
        # here rather than failing later inside os.listdir.
        if not os.path.isdir(dataset_dir):
            raise ValueError(f"Dataset directory not found: {dataset_dir}")

        shard_files = sorted(
            (
                f
                for f in os.listdir(dataset_dir)
                if f.startswith("shard_") and f.endswith(".pkl")
            ),
            key=cls._shard_sort_key,
        )

        if not shard_files:
            raise ValueError(f"No shard files found in {dataset_dir}")

        logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")

        all_latents = []
        all_conditioning = []

        for shard_file in shard_files:
            shard_path = os.path.join(dataset_dir, shard_file)

            # SECURITY NOTE(review): torch.load unpickles the file. Depending
            # on the installed PyTorch version's default for `weights_only`,
            # a crafted shard can execute arbitrary code. Only load datasets
            # from trusted sources; consider passing weights_only=True once
            # it is confirmed that all stored conditioning content loads
            # under that mode.
            with open(shard_path, "rb") as f:
                shard_data = torch.load(f)

            all_latents.extend(shard_data["latents"])
            all_conditioning.extend(shard_data["conditioning"])

            logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")

        logging.info(
            f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
        )
        return io.NodeOutput(all_latents, all_conditioning)
|
|
|
|
| |
|
|
|
|
class DatasetExtension(ComfyExtension):
    """Extension whose node list enumerates the dataset-related node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes of this extension in a fixed order."""
        # Folder-based loading and saving of image / image+text datasets.
        folder_nodes = [
            LoadImageDataSetFromFolderNode,
            LoadImageTextDataSetFromFolderNode,
            SaveImageDataSetToFolderNode,
            SaveImageTextDataSetToFolderNode,
        ]
        # Image transforms and dataset shuffling.
        image_nodes = [
            ResizeImagesByShorterEdgeNode,
            ResizeImagesByLongerEdgeNode,
            CenterCropImagesNode,
            RandomCropImagesNode,
            NormalizeImagesNode,
            AdjustBrightnessNode,
            AdjustContrastNode,
            ShuffleDatasetNode,
            ShuffleImageTextDatasetNode,
        ]
        # Caption / text transforms.
        text_nodes = [
            TextToLowercaseNode,
            TextToUppercaseNode,
            TruncateTextNode,
            AddTextPrefixNode,
            AddTextSuffixNode,
            ReplaceTextNode,
            StripWhitespaceNode,
        ]
        # Deduplication, grid and list-merging helpers.
        utility_nodes = [
            ImageDeduplicationNode,
            ImageGridNode,
            MergeImageListsNode,
            MergeTextListsNode,
        ]
        # Encoded training-dataset creation, persistence and bucketing.
        training_nodes = [
            MakeTrainingDataset,
            SaveTrainingDataset,
            LoadTrainingDataset,
            ResolutionBucket,
        ]
        return [
            *folder_nodes,
            *image_nodes,
            *text_nodes,
            *utility_nodes,
            *training_nodes,
        ]
|
|
|
|
async def comfy_entrypoint() -> DatasetExtension:
    """Create and return a new ``DatasetExtension`` instance."""
    extension = DatasetExtension()
    return extension
|
|