File size: 4,300 Bytes
c7a0808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from typing import List


class BatchFilterKeepFirstLast:
    """Batch filter node (IMAGE -> IMAGE) that always keeps the first and last image.

    Modes (int):
      - 0  : passthrough (no changes)
      - 10 : keep 1st, 3rd, 5th, ... (drop every 2nd), but always keep last
      - <10: keep slightly MORE than mode 10 (adds back frames, evenly distributed)
      - >10: keep slightly FEWER than mode 10 (removes extra frames, evenly distributed)

    Notes:
      - In ComfyUI, IMAGE is a batch (torch.Tensor) of shape [B, H, W, C]. We only filter B.
      - Works with RGBA (C=4) or RGB (C=3) since we do not modify channels.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "filter_batch"

    # How much each +/-1 step away from mode 10 adjusts the batch, as a fraction of B.
    # With a reference batch size B=40:
    #   round(40 * 0.05) = 2 images per step (i.e., mode 9 adds ~2; mode 11 removes ~2).
    ADJUST_PER_STEP_FRACTION = 0.05

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: an IMAGE batch and an integer mode (0..20, default 10)."""
        return {
            "required": {
                "images": ("IMAGE",),
                "mode": ("INT", {"default": 10, "min": 0, "max": 20, "step": 1}),
            }
        }

    def filter_batch(self, images: torch.Tensor, mode: int):
        """Select a subset of the batch dimension according to ``mode``.

        Args:
            images: 4-D tensor; only the first (batch) dimension is filtered.
            mode: 0 for passthrough, 10 for "every other frame", lower values
                keep more frames, higher values keep fewer (see class docstring).

        Returns:
            A one-element tuple holding the filtered tensor. For batches of
            size <= 1 or ``mode == 0`` the input tensor itself is returned.

        Raises:
            TypeError: if ``images`` is not a ``torch.Tensor``.
            ValueError: if ``images`` is not 4-dimensional.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")

        if images.ndim != 4:
            raise ValueError(f"Expected images with shape [B,H,W,C], got {tuple(images.shape)}")

        b = int(images.shape[0])
        if b <= 1 or mode == 0:
            return (images,)

        last = b - 1

        # Base pattern for mode 10: keep 0,2,4,... plus always keep last.
        # A set gives O(1) membership tests and makes re-deduplication unnecessary.
        keep = set(range(0, b, 2))
        keep.add(last)

        delta = mode - 10  # <0 => keep more, >0 => keep fewer
        if delta:
            # round() uses round-half-to-even, so e.g. B=10 yields 0 and is clamped to 1.
            step = max(1, int(round(b * self.ADJUST_PER_STEP_FRACTION)))

            if delta < 0:
                # Add frames back from those we dropped, evenly spread.
                # First and last are always kept already, so only interior
                # indices can be candidates; _evenly_pick caps the count.
                dropped = [i for i in range(1, last) if i not in keep]
                keep.update(self._evenly_pick(dropped, -delta * step))
            else:
                # Remove extra frames from the kept set (but never first/last),
                # evenly spread; _evenly_pick caps the count at what is removable.
                removable = sorted(keep - {0, last})
                keep.difference_update(self._evenly_pick(removable, delta * step))

        # 0 and last were never eligible for removal, so the first/last rule holds here.
        out = images[sorted(keep), ...]
        return (out,)

    @staticmethod
    def _evenly_pick(items: List[int], k: int) -> List[int]:
        """Pick up to ``k`` unique elements from ``items``, evenly distributed.

        Deterministic; the selected elements keep their relative order from
        ``items``. ``k`` is capped at ``len(items)``; a non-positive ``k`` or
        an empty ``items`` yields an empty list.
        """
        m = len(items)
        if k <= 0 or m == 0:
            return []
        k = min(k, m)

        # Choose k positions in (0..m-1) spread out, avoiding endpoints bias.
        # Since (m + 1) / (k + 1) >= 1 for k <= m, the positions are strictly
        # increasing. Integer floor division keeps the arithmetic exact.
        positions = [(i + 1) * (m + 1) // (k + 1) - 1 for i in range(k)]
        return [items[p] for p in positions]


# Node key -> implementing class.
# NOTE(review): presumably read by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = {"BatchFilterKeepFirstLast": BatchFilterKeepFirstLast}

# Node key -> human-readable label; the key matches the one in NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {"BatchFilterKeepFirstLast": "Batch Filter (Keep First/Last)"}