|
|
|
|
|
import torch |
|
|
|
|
|
class TemporalHintFromPair:
    """
    Concatenate two RGB images (current & previous) along the channel dim to produce a 6-channel IMAGE.

    Works with batched tensors. If ``previous`` is None (e.g. the first frame, where no
    earlier frame exists), it falls back to ``current``, so the hint is the current frame
    paired with itself.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs: ``current`` is required; ``previous`` and ``clip_to_range`` are optional."""
        # "previous" is optional so the node can run with it unconnected; make_hint then
        # receives previous=None via its default and uses the first-frame fallback.
        return {
            "required": {
                "current": ("IMAGE",),
            },
            "optional": {
                "previous": ("IMAGE",),
                "clip_to_range": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("temporal_hint",)
    FUNCTION = "make_hint"
    CATEGORY = "Temporal/Utils"

    @staticmethod
    def _ensure_batch(x):
        """Add a leading batch dimension to a 3-D (H, W, C) tensor; return other ranks unchanged."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        return x

    @staticmethod
    def _match_batch(a, b):
        """Reconcile the batch sizes of two 4-D tensors.

        - Equal batch sizes: returned unchanged.
        - One side has batch 1: that side is repeated to the other's batch size.
        - Otherwise: both are truncated to the smaller batch size. Trailing frames of
          the longer input are silently dropped.
        """
        ba, bb = a.shape[0], b.shape[0]
        if ba == bb:
            return a, b
        if ba == 1:
            a = a.repeat(bb, 1, 1, 1)
        elif bb == 1:
            b = b.repeat(ba, 1, 1, 1)
        else:
            n = min(ba, bb)
            a = a[:n]
            b = b[:n]
        return a, b

    def make_hint(self, current, previous=None, clip_to_range=True):
        """Build the 6-channel temporal hint.

        Args:
            current: channels-last image tensor, (B, H, W, 3) or unbatched (H, W, 3).
            previous: tensor with the same layout, or None to reuse ``current``.
                It is cast to current's device/dtype, its batch is reconciled via
                ``_match_batch``, and it is resized (nearest) to current's H/W if needed.
            clip_to_range: clamp both images to [0, 1] before concatenation.

        Returns:
            A 1-tuple holding a (B, H, W, 6) tensor: current's 3 channels followed by
            previous's 3 channels.

        Raises:
            ValueError: if an input is not 3-D/4-D or its last dimension is not 3.
        """
        if previous is None:
            # First frame: no earlier frame exists, so pair the current frame with itself.
            previous = current

        current = self._ensure_batch(current)
        previous = self._ensure_batch(previous)

        if current.dim() != 4 or previous.dim() != 4:
            raise ValueError(
                f"Expected 3-D (H, W, C) or 4-D (B, H, W, C) image tensors; "
                f"got {current.shape} & {previous.shape}"
            )

        if current.shape[-1] != 3 or previous.shape[-1] != 3:
            raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")

        # torch.cat needs both operands on the same device with the same dtype.
        previous = previous.to(device=current.device, dtype=current.dtype)

        current, previous = self._match_batch(current, previous)

        if current.shape[1:3] != previous.shape[1:3]:
            # interpolate works on channels-first (N, C, H, W); permute there and back.
            previous = torch.nn.functional.interpolate(
                previous.permute(0, 3, 1, 2),
                size=(current.shape[1], current.shape[2]),
                mode="nearest",
            ).permute(0, 2, 3, 1)

        if clip_to_range:
            current = current.clamp(0.0, 1.0)
            previous = previous.clamp(0.0, 1.0)

        # Concatenate along the trailing channel axis: 3 + 3 -> 6 channels.
        temporal_hint = torch.cat([current, previous], dim=3)
        return (temporal_hint,)
|
|
|
|
|
|
|
|
# Registry of node type name -> implementing class, presumably consumed by the host's custom-node loader.
NODE_CLASS_MAPPINGS = dict(TemporalHintFromPair=TemporalHintFromPair)
|
|
|
|
|
# Human-readable title for each registered node, keyed by the same names as NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = dict(TemporalHintFromPair="Temporal Hint From Pair (6ch)")
|
|
|