File size: 4,276 Bytes
6f46f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# ComfyUI Custom Nodes: Batch Index Tools
# Put this file at:
#   ComfyUI/custom_nodes/batch_index_tools/__init__.py
# Restart ComfyUI, then look under category: "Batch/Index"

import torch


def _clamp_index(index: int, batch_size: int) -> int:
    """Return ``index`` confined to the valid positions of a batch.

    Args:
        index: Requested zero-based position; may lie outside the batch.
        batch_size: Number of items in the batch.

    Returns:
        ``index`` unchanged when it already addresses an item, otherwise the
        nearest valid position (``0`` or ``batch_size - 1``).

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    if batch_size <= 0:
        raise ValueError("Input batch is empty (batch_size <= 0).")

    too_low = index < 0
    if not (too_low or index >= batch_size):
        return index

    # Out of range: report it on stdout and fall back to the nearest end.
    warning = (
        f"[BatchIndexTools] index {index} out of range for batch_size {batch_size}; "
        f"clamping to valid range."
    )
    print(warning)
    return 0 if too_low else batch_size - 1


class BatchGetImageAtIndex:
    """Pick one image out of an IMAGE batch by its zero-based index.

    The result is still a batch (of size 1), so the leading batch dimension
    is preserved. An index outside the batch is clamped to the nearest valid
    position instead of being rejected.
    """

    CATEGORY = "Batch/Index"
    FUNCTION = "get"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the batch and the index to pick from it."""
        index_spec = ("INT", {"default": 0, "min": 0, "max": 10**9})
        required = {"images": ("IMAGE",), "index": index_spec}
        return {"required": required}

    def get(self, images, index):
        """Return the batch item at ``index`` as a one-element batch.

        Args:
            images: 4-D torch tensor laid out as [B,H,W,C].
            index: Zero-based position; converted with ``int()`` and clamped
                into the batch.

        Returns:
            A 1-tuple holding a tensor of shape [1,H,W,C].

        Raises:
            TypeError: If ``images`` is not a torch tensor.
            ValueError: If ``images`` is not 4-D, or the batch is empty.
        """
        if not torch.is_tensor(images):
            raise TypeError("Expected 'images' to be a torch Tensor (ComfyUI IMAGE type).")
        rank = images.ndim
        if rank != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={rank}.")

        start = _clamp_index(int(index), images.shape[0])

        # Slice (rather than integer-index) so the leading batch axis survives.
        picked = images[start : start + 1]
        return (picked,)


class BatchReplaceImageAtIndex:
    """Replace one item of an IMAGE batch with another image.

    Takes an IMAGE batch, a zero-based index and a replacement IMAGE, and
    outputs a copy of the batch whose item at that index is the replacement.

    Notes:
      - Index is zero-based (0 is the first image).
      - If index is out of range, it is clamped to the nearest valid index.
      - The replacement image must have the same H/W/C as the batch images.
      - If 'image' is a batch, only the first image is used.
      - An empty replacement batch is rejected with a ValueError.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the batch, the index, and the replacement."""
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "replace"
    CATEGORY = "Batch/Index"

    def replace(self, images, index, image):
        """Return a copy of ``images`` with the item at ``index`` replaced.

        Args:
            images: 4-D torch tensor laid out as [B,H,W,C]; not modified.
            index: Zero-based position; converted with ``int()`` and clamped
                into the batch.
            image: 4-D torch tensor [B,H,W,C]; only its first item is used.

        Returns:
            A 1-tuple holding the new batch tensor, same shape, dtype and
            device as ``images``.

        Raises:
            TypeError: If either input is not a torch tensor.
            ValueError: If either input is not 4-D, if either batch is empty,
                or if the replacement's [H,W,C] differs from the batch's.
        """
        if not torch.is_tensor(images) or not torch.is_tensor(image):
            raise TypeError("Expected 'images' and 'image' to be torch Tensors (ComfyUI IMAGE type).")
        if images.ndim != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={images.ndim}.")
        if image.ndim != 4:
            raise ValueError(f"Expected 'image' with shape [B,H,W,C], got ndim={image.ndim}.")

        # A zero-length replacement batch would pass the [H,W,C] comparison
        # below and then fail on `replacement[0]` with an opaque IndexError,
        # so reject it here with a clear message.
        if image.shape[0] == 0:
            raise ValueError("Replacement 'image' batch is empty (batch_size <= 0).")

        b = images.shape[0]
        idx = _clamp_index(int(index), b)

        # Use first image if a batch is provided
        replacement = image[:1]

        # Validate spatial/channel match
        if replacement.shape[1:] != images.shape[1:]:
            raise ValueError(
                "Replacement image must match batch image shape [H,W,C]. "
                f"Batch has [H,W,C]={tuple(images.shape[1:])}, "
                f"replacement has [H,W,C]={tuple(replacement.shape[1:])}."
            )

        # Make output without mutating input
        out = images.clone()

        # Ensure dtype/device match
        rep0 = replacement[0].to(device=out.device, dtype=out.dtype)

        out[idx] = rep0
        return (out,)


# Display titles for the nodes, keyed by the same identifiers that
# NODE_CLASS_MAPPINGS uses below.
NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchGetImageAtIndex": "Batch: Get Image @ Index",
    "BatchReplaceImageAtIndex": "Batch: Replace Image @ Index",
}

# Node identifier -> implementing class for every node this module provides.
NODE_CLASS_MAPPINGS = {
    "BatchGetImageAtIndex": BatchGetImageAtIndex,
    "BatchReplaceImageAtIndex": BatchReplaceImageAtIndex,
}