dtarnow commited on
Commit
7db13b4
·
verified ·
1 Parent(s): f0addf6

Upload temporal_hint_concat.py

Browse files
Files changed (1) hide show
  1. temporal_hint_concat.py +76 -0
temporal_hint_concat.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ class TemporalHintFromPair:
5
+ """
6
+ Concatenate two RGB images (current & previous) along channel dim to produce a 6-channel IMAGE.
7
+ Works with batched tensors. If previous is None, it falls back to current (no-op for first frame).
8
+ """
9
+
10
+ @classmethod
11
+ def INPUT_TYPES(cls):
12
+ return {
13
+ "required": {
14
+ "current": ("IMAGE",),
15
+ "previous": ("IMAGE",),
16
+ },
17
+ "optional": {
18
+ "clip_to_range": ("BOOLEAN", {"default": True}),
19
+ },
20
+ }
21
+
22
+ RETURN_TYPES = ("IMAGE",)
23
+ RETURN_NAMES = ("temporal_hint",)
24
+ FUNCTION = "make_hint"
25
+ CATEGORY = "Temporal/Utils"
26
+
27
+ @staticmethod
28
+ def _ensure_batch(x):
29
+ if x.dim() == 3:
30
+ x = x.unsqueeze(0)
31
+ return x
32
+
33
+ @staticmethod
34
+ def _match_batch(a, b):
35
+ ba, bb = a.shape[0], b.shape[0]
36
+ if ba == bb:
37
+ return a, b
38
+ if ba == 1:
39
+ a = a.repeat(bb, 1, 1, 1)
40
+ elif bb == 1:
41
+ b = b.repeat(ba, 1, 1, 1)
42
+ else:
43
+ n = min(ba, bb)
44
+ a = a[:n]
45
+ b = b[:n]
46
+ return a, b
47
+
48
+ def make_hint(self, current, previous, clip_to_range=True):
49
+ current = self._ensure_batch(current)
50
+ previous = self._ensure_batch(previous)
51
+
52
+ if current.shape[-1] != 3 or previous.shape[-1] != 3:
53
+ raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")
54
+
55
+ current, previous = self._match_batch(current, previous)
56
+
57
+ if current.shape[1:3] != previous.shape[1:3]:
58
+ previous = torch.nn.functional.interpolate(
59
+ previous.permute(0,3,1,2), size=(current.shape[1], current.shape[2]), mode="nearest"
60
+ ).permute(0,2,3,1)
61
+
62
+ if clip_to_range:
63
+ current = current.clamp(0.0, 1.0)
64
+ previous = previous.clamp(0.0, 1.0)
65
+
66
+ temporal_hint = torch.cat([current, previous], dim=3)
67
+ return (temporal_hint,)
68
+
69
+
70
+ NODE_CLASS_MAPPINGS = {
71
+ "TemporalHintFromPair": TemporalHintFromPair,
72
+ }
73
+
74
+ NODE_DISPLAY_NAME_MAPPINGS = {
75
+ "TemporalHintFromPair": "Temporal Hint From Pair (6ch)",
76
+ }