import torch from typing import List class BatchFilterKeepFirstLast: """ Batch filter node (IMAGE -> IMAGE) that always keeps the first and last image. Modes (int): - 0 : passthrough (no changes) - 10 : keep 1st, 3rd, 5th, ... (drop every 2nd), but always keep last - <10: keep slightly MORE than mode 10 (adds back frames, evenly distributed) - >10: keep slightly FEWER than mode 10 (removes extra frames, evenly distributed) Notes: - In ComfyUI, IMAGE is a batch (torch.Tensor) of shape [B, H, W, C]. We only filter B. :contentReference[oaicite:2]{index=2} - Works with RGBA (C=4) or RGB (C=3) since we do not modify channels. """ CATEGORY = "image/batch" RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("images",) FUNCTION = "filter_batch" # How much each +/-1 step away from mode 10 adjusts the batch, as a fraction of B. # With your reference batch size B=40: # round(40 * 0.05) = 2 images per step (i.e., mode 9 adds ~2; mode 11 removes ~2). ADJUST_PER_STEP_FRACTION = 0.05 @classmethod def INPUT_TYPES(cls): return { "required": { "images": ("IMAGE",), "mode": ("INT", {"default": 10, "min": 0, "max": 20, "step": 1}), } } def filter_batch(self, images: torch.Tensor, mode: int): if not isinstance(images, torch.Tensor): raise TypeError("images must be a torch.Tensor") if images.ndim != 4: raise ValueError(f"Expected images with shape [B,H,W,C], got {tuple(images.shape)}") b = int(images.shape[0]) if b <= 1 or mode == 0: return (images,) # Base pattern for mode 10: keep 0,2,4,... plus always keep last. keep = list(range(0, b, 2)) if (b - 1) not in keep: keep.append(b - 1) keep = sorted(set(keep)) if mode != 10: delta = mode - 10 # <0 => keep more, >0 => keep fewer step = max(1, int(round(b * self.ADJUST_PER_STEP_FRACTION))) min_keep = 1 if b == 1 else 2 # first+last (or just one if batch=1) if delta < 0: # Add frames back from those we dropped, evenly spread. add_count = min((-delta) * step, b - len(keep)) if add_count > 0: candidates = [i for i in range(b) if i not in keep and i not in (0, b - 1)] add_idxs = self._evenly_pick(candidates, add_count) keep = sorted(set(keep + add_idxs)) elif delta > 0: # Remove extra frames from the kept set (but never first/last), evenly spread. max_removable = max(0, len(keep) - min_keep) remove_count = min(delta * step, max_removable) if remove_count > 0: removable = [i for i in keep if i not in (0, b - 1)] remove_idxs = set(self._evenly_pick(removable, remove_count)) keep = [i for i in keep if i not in remove_idxs] keep = sorted(set(keep)) # Enforce rule: always keep first and last. if 0 not in keep: keep.insert(0, 0) if (b - 1) not in keep: keep.append(b - 1) keep = sorted(set(keep)) out = images[keep, ...] return (out,) @staticmethod def _evenly_pick(items: List[int], k: int) -> List[int]: """ Pick k unique elements from items, evenly distributed across the list. Deterministic, preserves ordering of selected indices in 'items'. """ m = len(items) if k <= 0 or m == 0: return [] k = min(k, m) # Choose k positions in (0..m-1) spread out, avoiding endpoints bias. # This yields strictly increasing positions for k<=m. positions = [int((i + 1) * (m + 1) / (k + 1)) - 1 for i in range(k)] return [items[p] for p in positions] NODE_CLASS_MAPPINGS = { "BatchFilterKeepFirstLast": BatchFilterKeepFirstLast, } NODE_DISPLAY_NAME_MAPPINGS = { "BatchFilterKeepFirstLast": "Batch Filter (Keep First/Last)", }