import torch
from typing import List
class BatchFilterKeepFirstLast:
    """
    Batch filter node (IMAGE -> IMAGE) that always keeps the first and last image.

    Modes (int):
      - 0  : passthrough (no changes)
      - 10 : keep 1st, 3rd, 5th, ... (drop every 2nd), but always keep last
      - <10: keep slightly MORE than mode 10 (adds back frames, evenly distributed)
      - >10: keep slightly FEWER than mode 10 (removes extra frames, evenly distributed)

    Notes:
      - In ComfyUI, IMAGE is a batch (torch.Tensor) of shape [B, H, W, C];
        only the batch dimension B is filtered.
      - Works with RGBA (C=4) or RGB (C=3) since channels are not modified.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "filter_batch"

    # How much each +/-1 step away from mode 10 adjusts the batch, as a fraction of B.
    # Example with a reference batch size B=40:
    #   round(40 * 0.05) = 2 images per step (mode 9 adds ~2; mode 11 removes ~2).
    ADJUST_PER_STEP_FRACTION = 0.05

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for ComfyUI."""
        return {
            "required": {
                "images": ("IMAGE",),
                "mode": ("INT", {"default": 10, "min": 0, "max": 20, "step": 1}),
            }
        }

    def filter_batch(self, images: torch.Tensor, mode: int):
        """
        Filter the batch dimension of `images` according to `mode`.

        Args:
            images: torch.Tensor of shape [B, H, W, C] (ComfyUI IMAGE batch).
            mode:   0 = passthrough; 10 = drop every 2nd frame (keep first/last);
                    values below/above 10 keep more/fewer frames respectively.

        Returns:
            A 1-tuple containing the filtered tensor (ComfyUI node convention).

        Raises:
            TypeError:  if `images` is not a torch.Tensor.
            ValueError: if `images` is not 4-dimensional.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.ndim != 4:
            raise ValueError(f"Expected images with shape [B,H,W,C], got {tuple(images.shape)}")

        b = int(images.shape[0])
        # Nothing to do for empty/single-frame batches or explicit passthrough.
        if b <= 1 or mode == 0:
            return (images,)

        # Base pattern for mode 10: keep even indices 0, 2, 4, ... plus always the last.
        keep = list(range(0, b, 2))
        if (b - 1) not in keep:
            keep.append(b - 1)

        if mode != 10:
            delta = mode - 10  # <0 => keep more, >0 => keep fewer
            # Number of frames each +/-1 mode step adds/removes (at least 1).
            step = max(1, int(round(b * self.ADJUST_PER_STEP_FRACTION)))
            min_keep = 2  # first + last; b > 1 is guaranteed by the early return above
            if delta < 0:
                # Add frames back from those we dropped, evenly spread.
                add_count = min((-delta) * step, b - len(keep))
                if add_count > 0:
                    kept = set(keep)  # O(1) membership instead of O(B) list scans
                    candidates = [i for i in range(b)
                                  if i not in kept and i not in (0, b - 1)]
                    keep = sorted(set(keep) | set(self._evenly_pick(candidates, add_count)))
            elif delta > 0:
                # Remove extra frames from the kept set (never first/last), evenly spread.
                max_removable = max(0, len(keep) - min_keep)
                remove_count = min(delta * step, max_removable)
                if remove_count > 0:
                    removable = [i for i in keep if i not in (0, b - 1)]
                    remove_idxs = set(self._evenly_pick(removable, remove_count))
                    keep = [i for i in keep if i not in remove_idxs]

        # Enforce the invariant: first and last frames are always kept.
        keep = sorted(set(keep) | {0, b - 1})

        # Advanced indexing on dim 0 selects the kept frames (copies, preserves order).
        return (images[keep, ...],)

    @staticmethod
    def _evenly_pick(items: List[int], k: int) -> List[int]:
        """
        Pick k unique elements from items, evenly distributed across the list.
        Deterministic, preserves ordering of selected indices in 'items'.
        """
        m = len(items)
        if k <= 0 or m == 0:
            return []
        k = min(k, m)
        # Choose k positions in (0..m-1) spread out, avoiding endpoint bias.
        # This yields strictly increasing positions for k <= m.
        positions = [int((i + 1) * (m + 1) / (k + 1)) - 1 for i in range(k)]
        return [items[p] for p in positions]
# Registration table read by ComfyUI at import time: maps the internal node
# identifier to its implementing class.
NODE_CLASS_MAPPINGS = {
    "BatchFilterKeepFirstLast": BatchFilterKeepFirstLast,
}
# Human-readable title shown for the node in the ComfyUI editor.
NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchFilterKeepFirstLast": "Batch Filter (Keep First/Last)",
}