File size: 2,190 Bytes
7db13b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

import torch

class TemporalHintFromPair:
    """
    Concatenate two RGB images (current & previous) along the channel dim to produce a 6-channel IMAGE.

    Inputs are channels-last tensors shaped (H, W, 3) or (B, H, W, 3); unbatched inputs get a
    leading batch dim of 1. Batch sizes are reconciled before concatenation: a batch of 1 is
    repeated to match the other input, otherwise both are truncated to the shorter batch.
    If the spatial sizes differ, ``previous`` is resized to ``current``'s (H, W) with
    nearest-neighbour interpolation. The result holds ``current`` in channels 0-2 and
    ``previous`` in channels 3-5. If ``previous`` is None, ``current`` is used in its place
    (no-op pairing for the first frame).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: two IMAGE tensors plus an optional clamp toggle."""
        return {
            "required": {
                "current": ("IMAGE",),
                "previous": ("IMAGE",),
            },
            "optional": {
                "clip_to_range": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("temporal_hint",)
    FUNCTION = "make_hint"
    CATEGORY = "Temporal/Utils"

    @staticmethod
    def _ensure_batch(x):
        """Prepend a batch dim to a 3-D tensor; tensors of any other rank are returned unchanged."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        return x

    @staticmethod
    def _match_batch(a, b):
        """
        Make two batched tensors agree on batch size (dim 0).

        A batch of size 1 is repeated to match the other tensor's batch size. When both
        batches are larger than 1 and differ, both are truncated to the shorter length.
        """
        ba, bb = a.shape[0], b.shape[0]
        if ba == bb:
            return a, b
        if ba == 1:
            a = a.repeat(bb, 1, 1, 1)
        elif bb == 1:
            b = b.repeat(ba, 1, 1, 1)
        else:
            n = min(ba, bb)
            a = a[:n]
            b = b[:n]
        return a, b

    def make_hint(self, current, previous=None, clip_to_range=True):
        """
        Build the 6-channel temporal hint from a (current, previous) image pair.

        Args:
            current: channels-last image tensor, (H, W, 3) or (B, H, W, 3).
            previous: image tensor with the same layout as ``current``, or None to reuse
                ``current`` (e.g. for the first frame of a sequence).
            clip_to_range: if True, clamp both images to [0, 1] before concatenation.

        Returns:
            A 1-tuple holding a (B, H, W, 6) tensor on ``current``'s device and dtype.

        Raises:
            ValueError: if either input's last dimension is not 3.
        """
        if previous is None:
            # No earlier frame available: pair the frame with itself, as documented.
            previous = current

        current = self._ensure_batch(current)
        previous = self._ensure_batch(previous)

        if current.shape[-1] != 3 or previous.shape[-1] != 3:
            raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")

        # torch.cat requires both operands on the same device; follow `current`'s device and
        # dtype so the output type is determined by the primary input.
        previous = previous.to(device=current.device, dtype=current.dtype)

        current, previous = self._match_batch(current, previous)

        if current.shape[1:3] != previous.shape[1:3]:
            # interpolate expects channels-first (B, C, H, W); convert and convert back.
            previous = torch.nn.functional.interpolate(
                previous.permute(0, 3, 1, 2),
                size=(current.shape[1], current.shape[2]),
                mode="nearest",
            ).permute(0, 2, 3, 1)

        if clip_to_range:
            current = current.clamp(0.0, 1.0)
            previous = previous.clamp(0.0, 1.0)

        temporal_hint = torch.cat([current, previous], dim=3)
        return (temporal_hint,)


# Node registry: maps the node type key to its implementing class.
# NOTE(review): the INPUT_TYPES/RETURN_TYPES/FUNCTION convention suggests ComfyUI's
# custom-node loader reads these module-level dicts — confirm against the host.
NODE_CLASS_MAPPINGS = dict(TemporalHintFromPair=TemporalHintFromPair)

# Human-readable title for each registered node key.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    TemporalHintFromPair="Temporal Hint From Pair (6ch)",
)