import torch


class TemporalHintFromPair:
    """
    Concatenate two RGB images (current & previous) along the channel dim
    to produce a 6-channel IMAGE.

    Images are channels-last tensors of shape (B, H, W, 3) or (H, W, 3);
    unbatched inputs get a leading batch dim of 1. Batch sizes are
    reconciled by broadcasting a batch of 1, or otherwise by truncating
    both inputs to the shorter batch. If spatial sizes differ, ``previous``
    is resized (nearest neighbour) to ``current``'s height/width.

    If previous is None, it falls back to current (no-op for first frame).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for the ComfyUI node registry."""
        return {
            "required": {
                "current": ("IMAGE",),
                "previous": ("IMAGE",),
            },
            "optional": {
                "clip_to_range": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("temporal_hint",)
    FUNCTION = "make_hint"
    CATEGORY = "Temporal/Utils"

    @staticmethod
    def _ensure_batch(x):
        """Add a leading batch dim to a single (H, W, C) image; pass others through."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        return x

    @staticmethod
    def _match_batch(a, b):
        """
        Make ``a`` and ``b`` share the same batch size.

        A batch of 1 is repeated to match the other input; otherwise, when
        both batch sizes differ and neither is 1, both are truncated to the
        smaller size.
        """
        ba, bb = a.shape[0], b.shape[0]
        if ba == bb:
            return a, b
        if ba == 1:
            a = a.repeat(bb, 1, 1, 1)
        elif bb == 1:
            b = b.repeat(ba, 1, 1, 1)
        else:
            n = min(ba, bb)
            a = a[:n]
            b = b[:n]
        return a, b

    def make_hint(self, current, previous=None, clip_to_range=True):
        """
        Build the 6-channel temporal hint.

        Args:
            current: current frame(s), (B, H, W, 3) or (H, W, 3).
            previous: previous frame(s), same layout as ``current``. If None,
                ``current`` is used in its place (first-frame case).
            clip_to_range: clamp both inputs to [0, 1] before concatenation.

        Returns:
            A 1-tuple holding a (B, H, W, 6) tensor: channels 0-2 come from
            ``current`` and channels 3-5 from ``previous``.

        Raises:
            ValueError: if an input is not 3-D/4-D or does not have 3 channels.
        """
        # Honour the documented fallback: without this, None would crash in
        # _ensure_batch (None has no .dim()).
        if previous is None:
            previous = current

        current = self._ensure_batch(current)
        previous = self._ensure_batch(previous)

        if current.dim() != 4 or previous.dim() != 4:
            raise ValueError(
                f"Expected images of shape (B, H, W, 3) or (H, W, 3); "
                f"got {tuple(current.shape)} & {tuple(previous.shape)}"
            )
        if current.shape[-1] != 3 or previous.shape[-1] != 3:
            raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")

        # torch.cat requires matching dtype and device; align previous to current.
        previous = previous.to(device=current.device, dtype=current.dtype)

        current, previous = self._match_batch(current, previous)

        if current.shape[1:3] != previous.shape[1:3]:
            # interpolate expects channels-first (B, C, H, W), so permute around it.
            previous = torch.nn.functional.interpolate(
                previous.permute(0, 3, 1, 2),
                size=(current.shape[1], current.shape[2]),
                mode="nearest",
            ).permute(0, 2, 3, 1)

        if clip_to_range:
            current = current.clamp(0.0, 1.0)
            previous = previous.clamp(0.0, 1.0)

        temporal_hint = torch.cat([current, previous], dim=3)
        return (temporal_hint,)


NODE_CLASS_MAPPINGS = {
    "TemporalHintFromPair": TemporalHintFromPair,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TemporalHintFromPair": "Temporal Hint From Pair (6ch)",
}