File size: 5,503 Bytes
cbb0df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from __future__ import annotations

import heapq
from typing import Dict, List, Tuple

import torch


class BatchEvenMotionPruner:
    """
    Thin out an IMAGE batch by repeatedly dropping its most redundant frame.

    Every interior frame is scored against the frames currently adjacent
    to it:

        score(i) = mean_abs_diff(frame[i], frame[left_neighbor])
                 + mean_abs_diff(frame[i], frame[right_neighbor])

    The frame with the smallest score is dropped first (equal scores resolve
    to the lower frame index), after which the scores of its two former
    neighbors are recomputed. The first and last frames are always kept, so
    a batch of two or more frames never shrinks below two frames.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "prune"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: an IMAGE batch and an integer target size."""
        target_count_options = {
            "default": 16,
            "min": 1,
            "max": 4096,
            "step": 1,
        }
        required = {
            "images": ("IMAGE", {}),
            "target_count": ("INT", target_count_options),
        }
        return {"required": required}

    @staticmethod
    def _validate_images(images: torch.Tensor) -> torch.Tensor:
        """
        Return ``images`` as a 4-D tensor.

        A 4-D tensor is returned as-is; a 3-D tensor gets a leading batch
        axis of size one (ComfyUI IMAGE is normally [B, H, W, C], a bare
        [H, W, C] is tolerated defensively).

        Raises:
            TypeError: ``images`` is not a ``torch.Tensor``.
            ValueError: ``images`` has neither 3 nor 4 dimensions.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("Expected 'images' to be a torch.Tensor.")

        rank = images.ndim
        if rank == 4:
            return images
        if rank == 3:
            return images.unsqueeze(0)

        raise ValueError(
            f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(images.shape)}."
        )

    @staticmethod
    def _pair_key(a: int, b: int) -> Tuple[int, int]:
        """Return the two indices as an ascending tuple (order-independent cache key)."""
        return (min(a, b), max(a, b))

    def _pair_difference(
        self,
        images: torch.Tensor,
        left_idx: int,
        right_idx: int,
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Return the mean absolute difference between two frames, memoized.

        The value is looked up in ``cache`` under the order-independent key
        of the two indices; on a miss it is computed in float32 over every
        element of both frames and stored in ``cache`` before returning.
        """
        key = self._pair_key(left_idx, right_idx)
        value = cache.get(key)
        if value is None:
            first = images[left_idx].float()
            second = images[right_idx].float()
            # .item() converts the 0-dim result tensor to a plain Python float.
            value = (first - second).abs().mean().item()
            cache[key] = value
        return value

    def _candidate_score(
        self,
        images: torch.Tensor,
        idx: int,
        prev_idx: List[int],
        next_idx: List[int],
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Return the redundancy score of frame ``idx``.

        The score is the frame's difference to its current left neighbor
        plus its difference to its current right neighbor, as given by
        ``prev_idx`` / ``next_idx`` (``-1`` marks "no neighbor").

        Raises:
            ValueError: ``idx`` lacks a neighbor on either side.
        """
        before, after = prev_idx[idx], next_idx[idx]
        if before == -1 or after == -1:
            raise ValueError("Endpoints must not be scored for removal.")

        to_left = self._pair_difference(images, before, idx, cache)
        to_right = self._pair_difference(images, idx, after, cache)
        return to_left + to_right

    def prune(self, images: torch.Tensor, target_count: int):
        """
        Reduce ``images`` to ``target_count`` frames, dropping the most
        redundant interior frames first.

        Args:
            images: IMAGE tensor, [B, H, W, C] (a 3-D tensor is treated as a
                batch of one).
            target_count: Requested number of frames. Values below 2 are
                treated as 2 because the first and last frames are protected.

        Returns:
            A 1-tuple holding the pruned batch, in original frame order. When
            nothing can or needs to be removed, the validated input tensor is
            returned unchanged.
        """
        images = self._validate_images(images)

        frame_total = int(images.shape[0])
        target_count = int(target_count)

        # Both endpoints are protected, so two frames is the lowest reachable
        # size for any batch that gets past the single-frame guard below.
        floor_count = max(target_count, 2)
        if frame_total <= 1 or floor_count >= frame_total:
            return (images,)

        final = frame_total - 1

        # Doubly linked list over frame indices; -1 means "no neighbor".
        before = list(range(-1, final))
        after = list(range(1, frame_total)) + [-1]
        alive = [True] * frame_total
        # Bumped on every (re)score so older heap entries can be recognised.
        stamp_of = [0] * frame_total
        diff_cache: Dict[Tuple[int, int], float] = {}
        heap: List[Tuple[float, int, int]] = []

        def enqueue(i: int) -> None:
            # Only live interior frames that still have both neighbors qualify.
            if 0 < i < final and alive[i] and before[i] != -1 and after[i] != -1:
                stamp_of[i] += 1
                score = self._candidate_score(images, i, before, after, diff_cache)
                heapq.heappush(heap, (score, i, stamp_of[i]))

        # Seed the heap with every removable interior frame.
        for i in range(1, final):
            enqueue(i)

        surplus = frame_total - floor_count

        while surplus > 0 and heap:
            _, victim, stamp = heapq.heappop(heap)

            # Skip entries that were superseded by a rescore or whose frame
            # has already been unlinked.
            stale = (
                not alive[victim]
                or stamp != stamp_of[victim]
                or before[victim] == -1
                or after[victim] == -1
            )
            if stale:
                continue

            lo, hi = before[victim], after[victim]

            # Unlink the frame and splice its neighbors together.
            alive[victim] = False
            surplus -= 1
            after[lo] = hi
            before[hi] = lo
            before[victim] = -1
            after[victim] = -1

            # Only the two frames flanking the gap have new neighbors.
            enqueue(lo)
            enqueue(hi)

        survivors = [i for i in range(frame_total) if alive[i]]
        selector = torch.tensor(survivors, device=images.device, dtype=torch.long)
        return (images.index_select(0, selector),)


# Node identifier -> implementing class.
# NOTE(review): presumably read by the ComfyUI custom-node loader — confirm.
NODE_CLASS_MAPPINGS = {"BatchEvenMotionPruner": BatchEvenMotionPruner}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {"BatchEvenMotionPruner": "Batch Even Motion Pruner"}