File size: 2,629 Bytes
f5d70cd
 
 
 
 
 
a336c71
f5d70cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Save as: ComfyUI/custom_nodes/special_batch_split.py
# Restart ComfyUI after saving.

import torch


class Custom_Batch_Output:
    """Split an IMAGE batch into two fixed sub-batches for downstream nodes.

    Input:
      - images: IMAGE batch, typically a torch.Tensor of shape [B, H, W, C].
        A single image [H, W, C] is accepted and treated as a batch of 1.

    Outputs:
      - Batch_Up: frames [7] + [9..25] + [27..31] + [33..36]  (27 images)
      - Rife_x3:  frames [4] + [37]                           (2-image batch)

    Indexing is 0-based and ranges are inclusive (e.g., 9..25 includes both
    9 and 25).

    Safety behavior:
      - If the input is not a torch.Tensor, is not 4-D after normalization,
        or is too small (needs indices up to 37 => B >= 38), the original
        input batch is returned unchanged for BOTH outputs.
    """

    CATEGORY = "image/batch"
    FUNCTION = "make_special_batches"

    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("Batch_Up", "Rife_x3")

    # Largest frame index referenced by the split below. Batches with fewer
    # than _MAX_INDEX + 1 images cannot be split and fall back to pass-through.
    _MAX_INDEX = 37

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI node-interface hook: declare the single required input.
        return {"required": {"images": ("IMAGE",)}}

    @staticmethod
    def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
        """Promote a single image [H, W, C] to a 1-image batch [1, H, W, C]."""
        if images.dim() == 3:
            return images.unsqueeze(0)
        return images

    def make_special_batches(self, images):
        """Return (Batch_Up, Rife_x3); on any validation failure, (images, images)."""
        # Basic validation + safe fallback: anything we cannot slice is
        # passed through untouched on both outputs.
        if not isinstance(images, torch.Tensor):
            return (images, images)

        images = self._normalize_to_batch(images)

        # Expect [B, H, W, C] after normalization.
        if images.dim() != 4:
            return (images, images)

        b = int(images.shape[0])

        # Need indices up to _MAX_INDEX => batch size at least _MAX_INDEX + 1.
        if b <= self._MAX_INDEX:
            return (images, images)

        # Build Batch_Up indices (inclusive ranges).
        batch_up_indices = (
            [7]
            + list(range(9, 26))   # 9..25
            + list(range(27, 32))  # 27..31
            + list(range(33, 37))  # 33..36
        )

        # Build Rife_x3 indices.
        rife_x3_indices = [4, 37]

        # index_select works on CPU/GPU and preserves dtype/device. It already
        # returns a new tensor (not a view of `images`), so no .clone() is
        # needed to decouple the outputs from the input batch.
        device = images.device
        idx_up = torch.tensor(batch_up_indices, dtype=torch.long, device=device)
        idx_rife = torch.tensor(rife_x3_indices, dtype=torch.long, device=device)

        batch_up = torch.index_select(images, 0, idx_up)
        rife_x3 = torch.index_select(images, 0, idx_rife)

        return (batch_up, rife_x3)


# ComfyUI registration: map the internal node key to its class and to the
# label shown in the node menu.
NODE_CLASS_MAPPINGS = {"Custom_Batch_Output": Custom_Batch_Output}

NODE_DISPLAY_NAME_MAPPINGS = {"Custom_Batch_Output": "Custom_Batch_Output"}