# Save as: ComfyUI/custom_nodes/special_batch_split.py
# Restart ComfyUI after saving.

import torch


class Custom_Batch_Output:
    """Split an IMAGE batch into two fixed sub-batches.

    Input:
      - images (IMAGE batch, typically torch.Tensor [B, H, W, C])

    Outputs:
      - Batch_Up: [ ID 7 ] + [ IDs 9..25 ] + [ IDs 27..31 ] + [ IDs 33..36 ]
      - Rife_x3:  [ ID 4 ] + [ ID 37 ]  (2-image batch)

    Indexing is 0-based and ranges are inclusive
    (e.g., 9..25 includes both 9 and 25).

    Safety behavior:
      - If the input is not a torch.Tensor, or is not 4-D after a single
        3-D image has been promoted to a 1-image batch, or the batch is too
        small (indices go up to 37, so B >= 38 is required), the node
        returns the (possibly promoted) input for BOTH outputs instead of
        raising.
    """

    CATEGORY = "image/batch"
    FUNCTION = "make_special_batches"

    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("Batch_Up", "Rife_x3")

    # Frame indices (0-based) gathered into each output, in output order.
    # range() stops are exclusive, hence the +1 relative to the inclusive
    # ranges given in the class docstring.
    _BATCH_UP_INDICES = (
        7,
        *range(9, 26),   # 9..25
        *range(27, 32),  # 27..31
        *range(33, 37),  # 33..36
    )
    _RIFE_X3_INDICES = (4, 37)

    # Smallest batch size for which every index above is valid (= 38).
    _MIN_BATCH_SIZE = max(_BATCH_UP_INDICES + _RIFE_X3_INDICES) + 1

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: an IMAGE batch."""
        return {"required": {"images": ("IMAGE",)}}

    @staticmethod
    def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
        """Promote a single [H, W, C] image to [1, H, W, C].

        Tensors of any other rank are returned unchanged.
        """
        if images.dim() == 3:
            return images.unsqueeze(0)
        return images

    def make_special_batches(self, images):
        """Build the Batch_Up and Rife_x3 outputs from ``images``.

        Args:
            images: The incoming IMAGE batch. Anything that is not a
                torch.Tensor is passed through untouched.

        Returns:
            A 2-tuple ``(batch_up, rife_x3)``. On the normal path these are
            freshly allocated tensors holding the selected frames along
            dim 0. On any validation failure both elements are the same
            input object (after 3-D -> 4-D promotion, if that applied).
        """
        # Basic validation + safe fallback.
        if not isinstance(images, torch.Tensor):
            return (images, images)

        images = self._normalize_to_batch(images)

        # Expect [B, H, W, C].
        if images.dim() != 4:
            return (images, images)

        # Every configured index must exist in the batch.
        if int(images.shape[0]) < self._MIN_BATCH_SIZE:
            return (images, images)

        # Index tensors must live on the same device as the data.
        device = images.device
        idx_up = torch.tensor(
            self._BATCH_UP_INDICES, dtype=torch.long, device=device
        )
        idx_rife = torch.tensor(
            self._RIFE_X3_INDICES, dtype=torch.long, device=device
        )

        # index_select works on CPU and GPU, preserves dtype/device, and
        # already returns a tensor with its own storage, so no extra
        # .clone() is needed to decouple the outputs from the input.
        batch_up = torch.index_select(images, 0, idx_up)
        rife_x3 = torch.index_select(images, 0, idx_rife)

        return (batch_up, rife_x3)


NODE_CLASS_MAPPINGS = {
    "Custom_Batch_Output": Custom_Batch_Output,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Custom_Batch_Output": "Custom_Batch_Output",
}