# MyCustomNodes / BatchEvenMotionPruner.py
# Uploaded by saliacoel ("Upload BatchEvenMotionPruner.py", commit cbb0df3, verified).
from __future__ import annotations

import heapq
import math
from typing import Dict, List, Tuple

import torch
class BatchEvenMotionPruner:
    """
    Remove the most redundant interior frame from an IMAGE batch until the
    requested batch size is reached.

    Redundancy score for an interior frame i:
        mean_abs_diff(frame[i], frame[left_neighbor]) +
        mean_abs_diff(frame[i], frame[right_neighbor])

    The neighbors are the nearest frames that have not been removed yet, so
    scores are refreshed for the two survivors around each removed frame.
    The frame with the LOWEST score is removed first (ties go to the lower
    frame index). The first and last frames are never removed, so a batch of
    two or more frames is never pruned below two frames.

    A pair difference that evaluates to NaN (e.g. a frame containing NaN
    pixels) is treated as +inf, so such frames are the last candidates for
    removal and the heap ordering stays well-defined.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "prune"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: an IMAGE batch and the desired frame count."""
        return {
            "required": {
                "images": ("IMAGE", {}),
                "target_count": (
                    "INT",
                    {
                        "default": 16,
                        "min": 1,
                        "max": 4096,
                        "step": 1,
                    },
                ),
            }
        }

    @staticmethod
    def _validate_images(images: torch.Tensor) -> torch.Tensor:
        """Return ``images`` as a 4-D batch tensor.

        A 3-D tensor is treated as a single frame and given a leading batch
        dimension; a 4-D tensor is returned unchanged (same object).

        Raises:
            TypeError: if ``images`` is not a ``torch.Tensor``.
            ValueError: if ``images`` is neither 3-D nor 4-D.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("Expected 'images' to be a torch.Tensor.")
        # ComfyUI IMAGE is normally [B, H, W, C]. Accept [H, W, C] defensively.
        if images.ndim == 3:
            images = images.unsqueeze(0)
        elif images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(images.shape)}."
            )
        return images

    @staticmethod
    def _pair_key(a: int, b: int) -> Tuple[int, int]:
        """Return an order-independent cache key for the frame pair (a, b)."""
        return (a, b) if a < b else (b, a)

    def _pair_difference(
        self,
        images: torch.Tensor,
        left_idx: int,
        right_idx: int,
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """Return the mean absolute difference between two frames, memoized.

        The difference is symmetric, so it is cached under an order-independent
        key. A NaN result is stored and returned as +inf: ``heapq`` relies on a
        total order, and NaN compares False against everything, which would
        make removal order depend on heap layout instead of the score.
        """
        key = self._pair_key(left_idx, right_idx)
        cached = cache.get(key)
        if cached is not None:
            return cached

        left = images[left_idx].float()
        right = images[right_idx].float()
        # Mean Absolute Difference over all pixels/channels.
        value = torch.mean(torch.abs(left - right)).item()
        if math.isnan(value):
            value = math.inf
        cache[key] = value
        return value

    def _candidate_score(
        self,
        images: torch.Tensor,
        idx: int,
        prev_idx: List[int],
        next_idx: List[int],
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """Return the redundancy score of frame ``idx`` w.r.t. its live neighbors.

        Raises:
            ValueError: if ``idx`` has no left or right neighbor (an endpoint).
        """
        left = prev_idx[idx]
        right = next_idx[idx]
        if left == -1 or right == -1:
            raise ValueError("Endpoints must not be scored for removal.")
        return (
            self._pair_difference(images, left, idx, cache)
            + self._pair_difference(images, idx, right, cache)
        )

    def prune(self, images: torch.Tensor, target_count: int):
        """Greedily drop the lowest-scoring interior frames.

        Args:
            images: IMAGE batch (normally [B, H, W, C]; a 3-D tensor is treated
                as a single frame).
            target_count: Requested number of frames to keep. Values below 2
                are raised to 2 because both endpoints are always kept.

        Returns:
            A 1-tuple with the pruned batch; surviving frames keep their
            original order. The input is returned as-is when no frame needs to
            be removed.
        """
        images = self._validate_images(images)
        batch_size = int(images.shape[0])

        # Endpoints are protected, so no batch of 2+ frames can shrink below 2.
        # This single check also covers batch_size <= 1 (nothing to prune).
        desired_count = max(int(target_count), 2)
        if desired_count >= batch_size:
            return (images,)

        # Doubly linked list over the surviving frames; -1 means "no neighbor".
        prev_idx = [-1] + [i - 1 for i in range(1, batch_size)]
        next_idx = [i + 1 for i in range(batch_size - 1)] + [-1]
        alive = [True] * batch_size
        # Each rescore bumps the frame's version; older heap entries for that
        # frame become stale and are skipped when popped (lazy deletion).
        candidate_version = [0] * batch_size
        pair_cache: Dict[Tuple[int, int], float] = {}
        heap: List[Tuple[float, int, int]] = []

        def push_candidate(i: int) -> None:
            # Only live interior frames with both neighbors are removable.
            if i <= 0 or i >= batch_size - 1:
                return
            if not alive[i]:
                return
            if prev_idx[i] == -1 or next_idx[i] == -1:
                return
            candidate_version[i] += 1
            score = self._candidate_score(images, i, prev_idx, next_idx, pair_cache)
            heapq.heappush(heap, (score, i, candidate_version[i]))

        # Seed all removable interior frames.
        for i in range(1, batch_size - 1):
            push_candidate(i)

        remaining = batch_size
        while remaining > desired_count and heap:
            _score, idx, version = heapq.heappop(heap)

            # Ignore stale heap entries.
            if not alive[idx] or candidate_version[idx] != version:
                continue
            if prev_idx[idx] == -1 or next_idx[idx] == -1:
                continue

            left = prev_idx[idx]
            right = next_idx[idx]

            # Remove idx from the linked list.
            alive[idx] = False
            remaining -= 1
            next_idx[left] = right
            prev_idx[right] = left
            prev_idx[idx] = -1
            next_idx[idx] = -1

            # Only the neighbors around the removed frame need updated scores.
            push_candidate(left)
            push_candidate(right)

        keep_indices = [i for i, is_alive in enumerate(alive) if is_alive]
        keep_tensor = torch.tensor(keep_indices, device=images.device, dtype=torch.long)
        output = images.index_select(0, keep_tensor)
        return (output,)
# Node registration tables (ComfyUI custom-node convention): the first maps
# the node identifier to its implementing class, the second to the label
# shown in the UI.
NODE_CLASS_MAPPINGS = dict(BatchEvenMotionPruner=BatchEvenMotionPruner)
NODE_DISPLAY_NAME_MAPPINGS = dict(BatchEvenMotionPruner="Batch Even Motion Pruner")