File size: 2,958 Bytes
166476b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Save as: ComfyUI/custom_nodes/batch_slice_start_end.py
# Restart ComfyUI after saving.

import torch


class Get_Batch_Range_Start_To_End:
    """Pick an inclusive index range out of an image batch.

    Inputs:
      - start_id (INT): index of the first image to keep.
      - end_id   (INT): index of the last image to keep (inclusive).
      - images   (IMAGE): torch.Tensor, typically [B, H, W, C]. A 3-dim
        tensor is treated as a batch of one.

    Outputs:
      - images (IMAGE): the selected images, or the input on failure.
      - status (STRING): "ok" or a message beginning with "error:".
      - count  (INT): number of images in the *input* batch.

    Behavior:
      - An invalid request (start > end, negative or out-of-range index,
        empty batch, wrong rank, non-tensor input) is reported through
        ``status`` instead of being raised.
      - On failure the input is handed back in place of a slice. A 3-dim
        input has already been expanded to [1, H, W, C] by the time the
        index checks run, so that expanded tensor is what comes back.
    """

    CATEGORY = "image/batch"
    FUNCTION = "slice_batch"

    RETURN_TYPES = ("IMAGE", "STRING", "INT")
    RETURN_NAMES = ("images", "status", "count")

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs: two index widgets and a batch."""

        def index_widget():
            # Build a fresh options dict per call so the two widgets never
            # share one mutable object.
            return ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1})

        required = {
            "start_id": index_widget(),
            "end_id": index_widget(),
            "images": ("IMAGE",),
        }
        return {"required": required}

    def slice_batch(self, start_id, end_id, images):
        """Return ``(images, status, count)`` for the range start_id..end_id.

        On success ``images`` is an independent copy of the requested
        slice and ``status`` is "ok". On any failure the (possibly
        batch-expanded) input is returned with an "error: ..." status.
        ``count`` is the size of the input's leading dimension, or 0 when
        that cannot be determined or the batch is empty.
        """
        if not isinstance(images, torch.Tensor):
            # Nothing can be inspected on a non-tensor; pass it straight back.
            return (images, "error: images is not a torch.Tensor", 0)

        rank = images.dim()
        if rank not in (3, 4):
            leading = int(images.shape[0]) if rank > 0 else 0
            return (
                images,
                f"error: expected IMAGE with 3 or 4 dims, got {tuple(images.shape)}",
                leading,
            )

        # A lone [H, W, C] image becomes a batch of one: [1, H, W, C].
        batch = images.unsqueeze(0) if rank == 3 else images
        total = int(batch.shape[0])

        if total < 1:
            return (batch, "error: empty batch (B=0)", 0)

        # Checks run in a fixed order; the first one that fails decides the
        # message that is reported.
        problem = None
        if start_id > end_id:
            problem = f"error: start_id > end_id ({start_id} > {end_id})"
        elif start_id < 0 or end_id < 0:
            problem = f"error: negative index not allowed (start_id={start_id}, end_id={end_id})"
        elif start_id >= total or end_id >= total:
            problem = f"error: out of range (start_id={start_id}, end_id={end_id}, batch_size={total})"

        if problem is not None:
            return (batch, problem, total)

        # end_id is inclusive, hence the +1; clone() detaches the result from
        # the input tensor's storage.
        picked = batch[start_id : end_id + 1].clone()
        return (picked, "ok", total)


# Maps each node identifier string to the class that implements it.
# NOTE(review): presumably read by ComfyUI's custom-node loader when this
# module is imported — confirm against the ComfyUI version in use.
NODE_CLASS_MAPPINGS = {
    "Get_Batch_Range_Start_To_End": Get_Batch_Range_Start_To_End,
}

# Maps the same identifier string to a human-readable label.
# NOTE(review): presumably the title shown for the node in the UI — confirm.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Get_Batch_Range_Start_To_End": "Get Batch from Batch (From Start ID to End ID)",
}