saliacoel commited on
Commit
6f46f5e
·
verified ·
1 Parent(s): bfef095

Upload batch_img_replace.py

Browse files
Files changed (1) hide show
  1. batch_img_replace.py +128 -0
batch_img_replace.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI Custom Nodes: Batch Index Tools
2
+ # Put this file at:
3
+ # ComfyUI/custom_nodes/batch_index_tools/__init__.py
4
+ # Restart ComfyUI, then look under category: "Batch/Index"
5
+
6
+ import torch
7
+
8
+
9
+ def _clamp_index(index: int, batch_size: int) -> int:
10
+ """Clamp index into [0, batch_size-1]."""
11
+ if batch_size <= 0:
12
+ raise ValueError("Input batch is empty (batch_size <= 0).")
13
+ if index < 0 or index >= batch_size:
14
+ print(
15
+ f"[BatchIndexTools] index {index} out of range for batch_size {batch_size}; "
16
+ f"clamping to valid range."
17
+ )
18
+ index = max(0, min(index, batch_size - 1))
19
+ return index
20
+
21
+
22
+ class BatchGetImageAtIndex:
23
+ """
24
+ Node 1:
25
+ - Takes an IMAGE batch and an integer index
26
+ - Outputs the image at that index (as a batch of size 1)
27
+ Notes:
28
+ - Index is zero-based (0 is the first image).
29
+ - If index is out of range, it is clamped to the nearest valid index.
30
+ """
31
+
32
+ @classmethod
33
+ def INPUT_TYPES(cls):
34
+ return {
35
+ "required": {
36
+ "images": ("IMAGE",),
37
+ "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
38
+ }
39
+ }
40
+
41
+ RETURN_TYPES = ("IMAGE",)
42
+ RETURN_NAMES = ("image",)
43
+ FUNCTION = "get"
44
+ CATEGORY = "Batch/Index"
45
+
46
+ def get(self, images, index):
47
+ if not torch.is_tensor(images):
48
+ raise TypeError("Expected 'images' to be a torch Tensor (ComfyUI IMAGE type).")
49
+ if images.ndim != 4:
50
+ raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={images.ndim}.")
51
+
52
+ b = images.shape[0]
53
+ idx = _clamp_index(int(index), b)
54
+
55
+ # Keep batch dimension (return a batch of 1)
56
+ out = images[idx : idx + 1]
57
+ return (out,)
58
+
59
+
60
+ class BatchReplaceImageAtIndex:
61
+ """
62
+ Node 2:
63
+ - Takes an IMAGE batch, an integer index, and a single IMAGE
64
+ - Replaces the batch item at that index with the provided image
65
+ - Outputs the modified batch
66
+ Notes:
67
+ - Index is zero-based (0 is the first image).
68
+ - If index is out of range, it is clamped to the nearest valid index.
69
+ - The replacement image must have the same H/W/C as the batch images.
70
+ - If 'image' is a batch, only the first image is used.
71
+ """
72
+
73
+ @classmethod
74
+ def INPUT_TYPES(cls):
75
+ return {
76
+ "required": {
77
+ "images": ("IMAGE",),
78
+ "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
79
+ "image": ("IMAGE",),
80
+ }
81
+ }
82
+
83
+ RETURN_TYPES = ("IMAGE",)
84
+ RETURN_NAMES = ("images",)
85
+ FUNCTION = "replace"
86
+ CATEGORY = "Batch/Index"
87
+
88
+ def replace(self, images, index, image):
89
+ if not torch.is_tensor(images) or not torch.is_tensor(image):
90
+ raise TypeError("Expected 'images' and 'image' to be torch Tensors (ComfyUI IMAGE type).")
91
+ if images.ndim != 4:
92
+ raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={images.ndim}.")
93
+ if image.ndim != 4:
94
+ raise ValueError(f"Expected 'image' with shape [B,H,W,C], got ndim={image.ndim}.")
95
+
96
+ b = images.shape[0]
97
+ idx = _clamp_index(int(index), b)
98
+
99
+ # Use first image if a batch is provided
100
+ replacement = image[:1]
101
+
102
+ # Validate spatial/channel match
103
+ if replacement.shape[1:] != images.shape[1:]:
104
+ raise ValueError(
105
+ "Replacement image must match batch image shape [H,W,C]. "
106
+ f"Batch has [H,W,C]={tuple(images.shape[1:])}, "
107
+ f"replacement has [H,W,C]={tuple(replacement.shape[1:])}."
108
+ )
109
+
110
+ # Make output without mutating input
111
+ out = images.clone()
112
+
113
+ # Ensure dtype/device match
114
+ rep0 = replacement[0].to(device=out.device, dtype=out.dtype)
115
+
116
+ out[idx] = rep0
117
+ return (out,)
118
+
119
+
120
+ NODE_CLASS_MAPPINGS = {
121
+ "BatchGetImageAtIndex": BatchGetImageAtIndex,
122
+ "BatchReplaceImageAtIndex": BatchReplaceImageAtIndex,
123
+ }
124
+
125
+ NODE_DISPLAY_NAME_MAPPINGS = {
126
+ "BatchGetImageAtIndex": "Batch: Get Image @ Index",
127
+ "BatchReplaceImageAtIndex": "Batch: Replace Image @ Index",
128
+ }