# MechLORA / temporal_hint_concat.py
# Uploaded by dtarnow, commit 7db13b4 ("Upload temporal_hint_concat.py").
import torch
class TemporalHintFromPair:
    """
    Concatenate two RGB images (current & previous) along the channel dim to
    produce a 6-channel IMAGE.

    Images are channels-last tensors, either unbatched ``(H, W, 3)`` or batched
    ``(B, H, W, 3)``; an unbatched image is promoted to a batch of one. The
    output is ``(B, H, W, 6)``: channels 0-2 come from ``current``, 3-5 from
    ``previous``. If ``previous`` is None (e.g. the first frame of a sequence,
    when called programmatically), ``current`` is used in its place, so the
    hint is simply ``current`` duplicated.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Node input spec: both frames are IMAGE inputs; clamping to [0, 1]
        # is an optional boolean that defaults to on.
        return {
            "required": {
                "current": ("IMAGE",),
                "previous": ("IMAGE",),
            },
            "optional": {
                "clip_to_range": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("temporal_hint",)
    FUNCTION = "make_hint"
    CATEGORY = "Temporal/Utils"

    @staticmethod
    def _ensure_batch(x):
        """Promote a 3-D ``(H, W, C)`` tensor to ``(1, H, W, C)``; return others unchanged."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        return x

    @staticmethod
    def _match_batch(a, b):
        """
        Make two batched tensors agree on their leading (batch) dimension.

        - Equal batch sizes: returned unchanged.
        - One side has batch 1: it is repeated to match the other.
        - Otherwise: both are truncated to the smaller batch size (silently).
        """
        ba, bb = a.shape[0], b.shape[0]
        if ba == bb:
            return a, b
        if ba == 1:
            a = a.repeat(bb, 1, 1, 1)
        elif bb == 1:
            b = b.repeat(ba, 1, 1, 1)
        else:
            n = min(ba, bb)
            a = a[:n]
            b = b[:n]
        return a, b

    def make_hint(self, current, previous, clip_to_range=True):
        """
        Build the 6-channel temporal hint from two frames.

        Args:
            current: RGB image tensor, ``(H, W, 3)`` or ``(B, H, W, 3)``.
            previous: RGB image tensor of the same layout, or None to reuse
                ``current``. Resized (nearest) to ``current``'s H x W if needed
                and moved to ``current``'s device.
            clip_to_range: if True, clamp both frames to ``[0, 1]``.

        Returns:
            A 1-tuple holding a ``(B, H, W, 6)`` tensor.

        Raises:
            ValueError: if either input is not 3-D/4-D or does not have
                exactly 3 channels in its last dimension.
        """
        # Documented first-frame fallback: with no previous frame, pair the
        # current frame with itself.
        if previous is None:
            previous = current
        current = self._ensure_batch(current)
        previous = self._ensure_batch(previous)
        # Reject other ranks early; otherwise repeat/permute below fail with
        # obscure errors, or a 5-D input yields a non-6-channel result.
        if current.dim() != 4 or previous.dim() != 4:
            raise ValueError(
                f"Expected (H, W, 3) or (B, H, W, 3) images; got {current.shape} & {previous.shape}"
            )
        if current.shape[-1] != 3 or previous.shape[-1] != 3:
            raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")
        # torch.cat requires a common device; dtype is left to cat's promotion.
        previous = previous.to(current.device)
        current, previous = self._match_batch(current, previous)
        if current.shape[1:3] != previous.shape[1:3]:
            # interpolate expects channels-first (B, C, H, W): permute in and back out.
            previous = torch.nn.functional.interpolate(
                previous.permute(0, 3, 1, 2),
                size=(current.shape[1], current.shape[2]),
                mode="nearest",
            ).permute(0, 2, 3, 1)
        if clip_to_range:
            current = current.clamp(0.0, 1.0)
            previous = previous.clamp(0.0, 1.0)
        temporal_hint = torch.cat([current, previous], dim=3)
        return (temporal_hint,)
# Node registration tables (ComfyUI-style naming; presumably read by the
# host's custom-node loader). Both are keyed by the same node identifier.

# node identifier -> implementing class
NODE_CLASS_MAPPINGS = dict(
    TemporalHintFromPair=TemporalHintFromPair,
)

# node identifier -> human-readable display name
NODE_DISPLAY_NAME_MAPPINGS = dict(
    TemporalHintFromPair="Temporal Hint From Pair (6ch)",
)