saliacoel commited on
Commit
a7b7bed
·
verified ·
1 Parent(s): 48e6078

Upload Get_Correct_Batch_Img.py

Browse files
Files changed (1) hide show
  1. Get_Correct_Batch_Img.py +171 -0
Get_Correct_Batch_Img.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class Get_Correct_Batch_Img:
5
+ """
6
+ Given a batch of RGBA images, selects:
7
+ - the sprite with the widest visible span along a given Y row (max_img)
8
+ - the sprite with the thinnest visible span along that same row (min_img)
9
+ - the sprite whose width is closest to the midpoint between min/max widths (avg_img)
10
+
11
+ Visibility is determined from the alpha channel (A > 0).
12
+ Only images within [start_index, end_index] (inclusive) are considered.
13
+ """
14
+
15
+ # Where this node appears in the right-click menu:
16
+ CATEGORY = "image/batch"
17
+
18
+ @classmethod
19
+ def INPUT_TYPES(s):
20
+ return {
21
+ "required": {
22
+ # RGBA image batch: torch.Tensor [B, H, W, 4]
23
+ "images": ("IMAGE",),
24
+
25
+ # Sub-batch start index (inclusive, 0-based)
26
+ "start_index": (
27
+ "INT",
28
+ {
29
+ "default": 0,
30
+ "min": 0,
31
+ "max": 2_147_483_647,
32
+ "step": 1,
33
+ },
34
+ ),
35
+
36
+ # Sub-batch end index (inclusive, 0-based)
37
+ "end_index": (
38
+ "INT",
39
+ {
40
+ "default": 0,
41
+ "min": 0,
42
+ "max": 2_147_483_647,
43
+ "step": 1,
44
+ },
45
+ ),
46
+
47
+ # Y coordinate (row) used for the horizontal scan
48
+ "y_coord": (
49
+ "INT",
50
+ {
51
+ "default": 0,
52
+ "min": 0,
53
+ "max": 2_147_483_647,
54
+ "step": 1,
55
+ },
56
+ ),
57
+ }
58
+ }
59
+
60
+ # Three RGBA images out now
61
+ RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
62
+ RETURN_NAMES = ("max_img", "min_img", "avg_img")
63
+ FUNCTION = "select"
64
+
65
+ def select(self, images, start_index, end_index, y_coord):
66
+ # Basic sanity checks
67
+ if not isinstance(images, torch.Tensor):
68
+ raise TypeError(f"Expected IMAGE tensor, got {type(images)}")
69
+
70
+ if images.ndim != 4:
71
+ raise ValueError(
72
+ f"Expected IMAGE of shape [B,H,W,C], got {tuple(images.shape)}"
73
+ )
74
+
75
+ batch_size, height, width, channels = images.shape
76
+
77
+ if channels != 4:
78
+ raise ValueError(
79
+ f"Expected RGBA image with 4 channels, got {channels}. "
80
+ "Make sure your input batch is RGBA (not RGB)."
81
+ )
82
+
83
+ if batch_size == 0:
84
+ raise ValueError("Empty image batch passed to Get_Correct_Batch_Img.")
85
+
86
+ # Clamp and normalize indices
87
+ start = max(0, min(int(start_index), batch_size - 1))
88
+ end = max(0, min(int(end_index), batch_size - 1))
89
+ if start > end:
90
+ start, end = end, start # swap so start <= end
91
+
92
+ # Clamp Y coordinate into image bounds
93
+ y = max(0, min(int(y_coord), height - 1))
94
+
95
+ # Track widest and thinnest sprite
96
+ max_width = None
97
+ min_width = None
98
+ max_idx = start
99
+ min_idx = start
100
+
101
+ # For AVG: store (index, width_px) for all valid sprites
102
+ widths = []
103
+
104
+ # Small alpha threshold; alpha > 0 is "visible"
105
+ alpha_threshold = 0.0
106
+ any_visible = False
107
+
108
+ # Loop over the requested sub-batch only
109
+ for i in range(start, end + 1):
110
+ # row_alpha shape: [W]
111
+ row_alpha = images[i, y, :, 3]
112
+ visible = row_alpha > alpha_threshold
113
+
114
+ if not torch.any(visible):
115
+ # No visible pixels on this row for this image; skip it
116
+ continue
117
+
118
+ any_visible = True
119
+
120
+ # Indices of visible pixels along X
121
+ visible_indices = torch.nonzero(visible, as_tuple=False).squeeze(1)
122
+ left_x = int(visible_indices[0])
123
+ right_x = int(visible_indices[-1])
124
+ width_px = right_x - left_x + 1 # inclusive distance
125
+
126
+ widths.append((i, width_px))
127
+
128
+ # Update max width (widest sprite)
129
+ if max_width is None or width_px > max_width:
130
+ max_width = width_px
131
+ max_idx = i
132
+
133
+ # Update min width (thinnest sprite)
134
+ if min_width is None or width_px < min_width:
135
+ min_width = width_px
136
+ min_idx = i
137
+
138
+ # If nothing had visible pixels on that Y, just return the first image
139
+ # in the sub-batch as all three outputs (so the node never crashes).
140
+ if not any_visible:
141
+ base_img = images[start].unsqueeze(0)
142
+ return (base_img, base_img, base_img)
143
+
144
+ # Compute midpoint between MIN and MAX widths
145
+ center_width = (min_width + max_width) / 2.0
146
+
147
+ # Find sprite whose width is closest to this center_width
148
+ avg_idx = max_idx # default
149
+ closest_diff = None
150
+ for idx, w in widths:
151
+ diff = abs(w - center_width)
152
+ if closest_diff is None or diff < closest_diff:
153
+ closest_diff = diff
154
+ avg_idx = idx
155
+
156
+ # Extract chosen sprites as batch size 1 (B=1, H, W, C)
157
+ max_img = images[max_idx].unsqueeze(0)
158
+ min_img = images[min_idx].unsqueeze(0)
159
+ avg_img = images[avg_idx].unsqueeze(0)
160
+
161
+ return (max_img, min_img, avg_img)
162
+
163
+
164
+ # Register node with ComfyUI
165
+ NODE_CLASS_MAPPINGS = {
166
+ "Get_Correct_Batch_Img": Get_Correct_Batch_Img,
167
+ }
168
+
169
+ NODE_DISPLAY_NAME_MAPPINGS = {
170
+ "Get_Correct_Batch_Img": "Get_Correct_Batch_Img (Salia)",
171
+ }