saliacoel commited on
Commit
f5d70cd
·
verified ·
1 Parent(s): e409f0f

Upload Custom_Batch_Output.py

Browse files
Files changed (1) hide show
  1. Custom_Batch_Output.py +86 -0
Custom_Batch_Output.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Save as: ComfyUI/custom_nodes/special_batch_split.py
2
+ # Restart ComfyUI after saving.
3
+
4
+ import torch
5
+
6
+
7
+ class SpecialBatchSplit:
8
+ """
9
+ Input:
10
+ - images (IMAGE batch, typically torch.Tensor [B, H, W, C])
11
+
12
+ Outputs:
13
+ - Batch_Up: [ ID 7 ] + [ IDs 9..25 ] + [ IDs 27..31 ] + [ IDs 33..36 ]
14
+ - Rife_x3: [ ID 4 ] + [ ID 37 ] (2-image batch)
15
+
16
+ Indexing is 0-based and ranges are inclusive (e.g., 9..25 includes both 9 and 25).
17
+
18
+ Safety behavior:
19
+ - If the input batch is too small (needs at least indices up to 37 => B >= 38),
20
+ or input is not a proper IMAGE tensor, the node returns the original input batch
21
+ for BOTH outputs.
22
+ """
23
+
24
+ CATEGORY = "image/batch"
25
+ FUNCTION = "make_special_batches"
26
+
27
+ RETURN_TYPES = ("IMAGE", "IMAGE")
28
+ RETURN_NAMES = ("Batch_Up", "Rife_x3")
29
+
30
+ @classmethod
31
+ def INPUT_TYPES(cls):
32
+ return {"required": {"images": ("IMAGE",)}}
33
+
34
+ @staticmethod
35
+ def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
36
+ # Accept single image [H,W,C] and convert to [1,H,W,C]
37
+ if images.dim() == 3:
38
+ return images.unsqueeze(0)
39
+ return images
40
+
41
+ def make_special_batches(self, images):
42
+ # Basic validation + safe fallback
43
+ if not isinstance(images, torch.Tensor):
44
+ return (images, images)
45
+
46
+ images = self._normalize_to_batch(images)
47
+
48
+ # Expect [B,H,W,C]
49
+ if images.dim() != 4:
50
+ return (images, images)
51
+
52
+ b = int(images.shape[0])
53
+
54
+ # Need indices up to 37 => batch size at least 38
55
+ if b < 38:
56
+ return (images, images)
57
+
58
+ # Build Batch_Up indices (inclusive ranges)
59
+ batch_up_indices = (
60
+ [7]
61
+ + list(range(9, 26)) # 9..25
62
+ + list(range(27, 32)) # 27..31
63
+ + list(range(33, 37)) # 33..36
64
+ )
65
+
66
+ # Build Rife_x3 indices
67
+ rife_x3_indices = [4, 37]
68
+
69
+ # Gather using index_select (works on GPU/CPU, preserves dtype/device)
70
+ device = images.device
71
+ idx_up = torch.tensor(batch_up_indices, dtype=torch.long, device=device)
72
+ idx_rife = torch.tensor(rife_x3_indices, dtype=torch.long, device=device)
73
+
74
+ batch_up = torch.index_select(images, 0, idx_up).clone()
75
+ rife_x3 = torch.index_select(images, 0, idx_rife).clone()
76
+
77
+ return (batch_up, rife_x3)
78
+
79
+
80
+ NODE_CLASS_MAPPINGS = {
81
+ "Custom_Batch_Output": Custom_Batch_Output,
82
+ }
83
+
84
+ NODE_DISPLAY_NAME_MAPPINGS = {
85
+ "Custom_Batch_Output": "Custom_Batch_Output",
86
+ }