saliacoel commited on
Commit
f9f33c8
·
verified ·
1 Parent(s): 46a5eb0

Upload 2 files

Browse files
Files changed (2) hide show
  1. eye_enlarger_v1.py +441 -0
  2. eye_enlarger_v2.py +269 -0
eye_enlarger_v1.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ # Optional deps. Node works without them, but:
9
+ # - SciPy or OpenCV is used for robust connected-components on the mask.
10
+ # - OpenCV is required for Lanczos interpolation.
11
+ try:
12
+ import cv2 # type: ignore
13
+ except Exception:
14
+ cv2 = None
15
+
16
+ try:
17
+ from scipy import ndimage # type: ignore
18
+ except Exception:
19
+ ndimage = None
20
+
21
+
22
+ def _ensure_mask_2d_numpy(mask_t: torch.Tensor) -> np.ndarray:
23
+ """
24
+ Convert a ComfyUI MASK/IMAGE-like tensor into a single 2D numpy float32 array in [0,1].
25
+ Accepts shapes:
26
+ - [H, W]
27
+ - [H, W, C]
28
+ - [B, H, W]
29
+ - [B, H, W, C]
30
+ """
31
+ m = mask_t.detach().float().cpu()
32
+
33
+ if m.ndim == 2:
34
+ pass
35
+ elif m.ndim == 3:
36
+ # Either [H,W,C] or [B,H,W]
37
+ if m.shape[-1] in (1, 3, 4):
38
+ m = m[..., 0]
39
+ else:
40
+ m = m[0]
41
+ elif m.ndim == 4:
42
+ # [B,H,W,C]
43
+ if m.shape[-1] in (1, 3, 4):
44
+ m = m[0, ..., 0]
45
+ else:
46
+ m = m[0, 0]
47
+ else:
48
+ raise ValueError(f"Unsupported mask ndim={m.ndim}")
49
+
50
+ m_np = m.numpy().astype(np.float32)
51
+ if m_np.size == 0:
52
+ return m_np
53
+
54
+ # If someone fed an 8-bit mask as float, normalize.
55
+ if m_np.max() > 1.5:
56
+ m_np = m_np / 255.0
57
+ return np.clip(m_np, 0.0, 1.0).astype(np.float32)
58
+
59
+
60
+ def _find_eye_centers_numpy(
61
+ mask_2d: np.ndarray,
62
+ max_centers: int = 2,
63
+ threshold: float = 0.5,
64
+ min_area: int = 16,
65
+ ) -> List[Tuple[float, float]]:
66
+ """
67
+ Detect up to `max_centers` white blobs in a black/white mask and return their centers (cx, cy),
68
+ in image pixel coordinates.
69
+
70
+ - Prefers OpenCV connectedComponentsWithStats when available
71
+ - Falls back to SciPy ndimage.label
72
+ - Final fallback: numpy-only split by largest x-gap (works well for 1-2 separated squares)
73
+ """
74
+ if mask_2d.size == 0:
75
+ return []
76
+
77
+ binary = mask_2d > threshold
78
+ if not binary.any():
79
+ return []
80
+
81
+ # --- OpenCV path (fast, robust) ---
82
+ if cv2 is not None:
83
+ bin_u8 = (binary.astype(np.uint8) * 255)
84
+ num, _labels, stats, _centroids = cv2.connectedComponentsWithStats(bin_u8, connectivity=8)
85
+ # label 0 is background
86
+ if num <= 1:
87
+ return []
88
+ areas = stats[1:, cv2.CC_STAT_AREA].astype(np.int64)
89
+ order = np.argsort(-areas)
90
+
91
+ centers: List[Tuple[float, float]] = []
92
+ for idx in order:
93
+ area = int(areas[idx])
94
+ if area < min_area:
95
+ continue
96
+ lab = idx + 1
97
+ x = int(stats[lab, cv2.CC_STAT_LEFT])
98
+ y = int(stats[lab, cv2.CC_STAT_TOP])
99
+ w = int(stats[lab, cv2.CC_STAT_WIDTH])
100
+ h = int(stats[lab, cv2.CC_STAT_HEIGHT])
101
+ # bounding-box center (matches "white square center" use-case)
102
+ cx = x + (w - 1) / 2.0
103
+ cy = y + (h - 1) / 2.0
104
+ centers.append((cx, cy))
105
+ if len(centers) >= max_centers:
106
+ break
107
+ return centers
108
+
109
+ # --- SciPy path ---
110
+ if ndimage is not None:
111
+ labeled, num = ndimage.label(binary)
112
+ if num <= 0:
113
+ return []
114
+ slices = ndimage.find_objects(labeled)
115
+ areas = ndimage.sum(binary.astype(np.uint8), labeled, index=np.arange(1, num + 1))
116
+ areas = np.asarray(areas, dtype=np.float32)
117
+ order = np.argsort(-areas)
118
+
119
+ centers = []
120
+ for idx in order:
121
+ if areas[idx] < min_area:
122
+ continue
123
+ sl = slices[idx]
124
+ if sl is None:
125
+ continue
126
+ ysl, xsl = sl
127
+ y0, y1 = ysl.start, ysl.stop
128
+ x0, x1 = xsl.start, xsl.stop
129
+ cx = (x0 + x1 - 1) / 2.0
130
+ cy = (y0 + y1 - 1) / 2.0
131
+ centers.append((cx, cy))
132
+ if len(centers) >= max_centers:
133
+ break
134
+ return centers
135
+
136
+ # --- Numpy-only fallback (assumes up to 2 separated blobs) ---
137
+ ys, xs = np.where(binary)
138
+ if xs.size < min_area:
139
+ return []
140
+
141
+ xs_sorted = np.sort(xs)
142
+ if xs_sorted.size < 2:
143
+ return [(float(xs_sorted[0]), float(ys[0]))]
144
+
145
+ diffs = np.diff(xs_sorted)
146
+ gap_idx = int(np.argmax(diffs))
147
+ gap = float(diffs[gap_idx])
148
+
149
+ gap_threshold = 10.0
150
+ centers: List[Tuple[float, float]] = []
151
+
152
+ if gap >= gap_threshold and max_centers >= 2:
153
+ split_x = (xs_sorted[gap_idx] + xs_sorted[gap_idx + 1]) / 2.0
154
+ left = binary.copy()
155
+ right = binary.copy()
156
+ left[:, int(math.ceil(split_x)) :] = False
157
+ right[:, : int(math.floor(split_x)) + 1] = False
158
+
159
+ for b in (left, right):
160
+ y2, x2 = np.where(b)
161
+ if x2.size < min_area:
162
+ continue
163
+ x0, x1 = int(x2.min()), int(x2.max())
164
+ y0, y1 = int(y2.min()), int(y2.max())
165
+ centers.append(((x0 + x1) / 2.0, (y0 + y1) / 2.0))
166
+ if len(centers) >= max_centers:
167
+ break
168
+ else:
169
+ x0, x1 = int(xs.min()), int(xs.max())
170
+ y0, y1 = int(ys.min()), int(ys.max())
171
+ centers.append(((x0 + x1) / 2.0, (y0 + y1) / 2.0))
172
+
173
+ return centers
174
+
175
+
176
+ def _falloff_weight_torch(r: torch.Tensor, R: float, hardness: int) -> torch.Tensor:
177
+ """
178
+ Brush hardness like GIMP:
179
+ hardness=0 -> smooth falloff from center to radius
180
+ hardness=100 -> hard edge (full strength until radius)
181
+ """
182
+ R = float(max(R, 1e-6))
183
+ h = float(np.clip(hardness, 0, 100))
184
+ inner = R * (h / 100.0)
185
+
186
+ if inner >= R - 1e-6:
187
+ return (r <= R).to(r.dtype)
188
+
189
+ t = (r - inner) / (R - inner)
190
+ t = t.clamp(0.0, 1.0)
191
+ smooth = t * t * (3.0 - 2.0 * t)
192
+ return (1.0 - smooth) * (r <= R).to(r.dtype)
193
+
194
+
195
+ def _bulge_warp_patch_torch_inplace(
196
+ img_nchw: torch.Tensor, # [1,C,H,W]
197
+ cx: float,
198
+ cy: float,
199
+ radius: float,
200
+ hardness: int,
201
+ strength: int,
202
+ mode: str, # "nearest" or "bilinear"
203
+ ) -> None:
204
+ _, _, H, W = img_nchw.shape
205
+ R = float(radius)
206
+
207
+ x0 = int(max(0, math.floor(cx - R)))
208
+ x1 = int(min(W, math.ceil(cx + R) + 1))
209
+ y0 = int(max(0, math.floor(cy - R)))
210
+ y1 = int(min(H, math.ceil(cy + R) + 1))
211
+ if (x1 - x0) < 2 or (y1 - y0) < 2:
212
+ return
213
+
214
+ patch = img_nchw[:, :, y0:y1, x0:x1]
215
+ ph = y1 - y0
216
+ pw = x1 - x0
217
+
218
+ cx_l = float(cx - x0)
219
+ cy_l = float(cy - y0)
220
+
221
+ device = patch.device
222
+ dtype = patch.dtype
223
+
224
+ ys = torch.arange(ph, device=device, dtype=torch.float32)
225
+ xs = torch.arange(pw, device=device, dtype=torch.float32)
226
+ y, x = torch.meshgrid(ys, xs, indexing="ij")
227
+
228
+ dx = x - cx_l
229
+ dy = y - cy_l
230
+ r = torch.sqrt(dx * dx + dy * dy + 1e-8)
231
+
232
+ w = _falloff_weight_torch(r, R, hardness)
233
+
234
+ amount = float(np.clip(strength, 0, 100)) / 100.0
235
+ s = 1.0 + amount * w # scale factor (dest samples closer to center -> enlarges feature)
236
+
237
+ src_x = cx_l + dx / s
238
+ src_y = cy_l + dy / s
239
+
240
+ # Normalize to [-1,1] for grid_sample
241
+ x_norm = (src_x / (pw - 1)) * 2.0 - 1.0
242
+ y_norm = (src_y / (ph - 1)) * 2.0 - 1.0
243
+ grid = torch.stack((x_norm, y_norm), dim=-1).unsqueeze(0) # [1,ph,pw,2]
244
+
245
+ warped = F.grid_sample(
246
+ patch,
247
+ grid.to(dtype),
248
+ mode=mode,
249
+ padding_mode="border",
250
+ align_corners=True,
251
+ )
252
+
253
+ img_nchw[:, :, y0:y1, x0:x1] = warped
254
+
255
+
256
+ def _falloff_weight_numpy(r: np.ndarray, R: float, hardness: int) -> np.ndarray:
257
+ R = float(max(R, 1e-6))
258
+ h = float(np.clip(hardness, 0, 100))
259
+ inner = R * (h / 100.0)
260
+
261
+ if inner >= R - 1e-6:
262
+ return (r <= R).astype(np.float32)
263
+
264
+ t = (r - inner) / (R - inner)
265
+ t = np.clip(t, 0.0, 1.0)
266
+ smooth = t * t * (3.0 - 2.0 * t)
267
+ return (1.0 - smooth).astype(np.float32) * (r <= R).astype(np.float32)
268
+
269
+
270
+ def _bulge_warp_patch_cv2_inplace(
271
+ img_hwc: np.ndarray, # float32, [H,W,C], in [0,1]
272
+ cx: float,
273
+ cy: float,
274
+ radius: float,
275
+ hardness: int,
276
+ strength: int,
277
+ interp: str, # "none"|"bilinear"|"lanczos"
278
+ ) -> None:
279
+ if cv2 is None:
280
+ raise RuntimeError("OpenCV (cv2) is required for Lanczos interpolation but is not installed.")
281
+
282
+ H, W, _ = img_hwc.shape
283
+ R = float(radius)
284
+
285
+ x0 = int(max(0, math.floor(cx - R)))
286
+ x1 = int(min(W, math.ceil(cx + R) + 1))
287
+ y0 = int(max(0, math.floor(cy - R)))
288
+ y1 = int(min(H, math.ceil(cy + R) + 1))
289
+ if (x1 - x0) < 2 or (y1 - y0) < 2:
290
+ return
291
+
292
+ patch = img_hwc[y0:y1, x0:x1]
293
+ ph, pw = patch.shape[:2]
294
+ cx_l = float(cx - x0)
295
+ cy_l = float(cy - y0)
296
+
297
+ xs, ys = np.meshgrid(np.arange(pw, dtype=np.float32), np.arange(ph, dtype=np.float32))
298
+ dx = xs - cx_l
299
+ dy = ys - cy_l
300
+ r = np.sqrt(dx * dx + dy * dy + 1e-8).astype(np.float32)
301
+
302
+ w = _falloff_weight_numpy(r, R, hardness)
303
+ amount = float(np.clip(strength, 0, 100)) / 100.0
304
+ s = 1.0 + amount * w
305
+
306
+ map_x = (cx_l + dx / s).astype(np.float32)
307
+ map_y = (cy_l + dy / s).astype(np.float32)
308
+
309
+ map_x = np.clip(map_x, 0.0, pw - 1.0)
310
+ map_y = np.clip(map_y, 0.0, ph - 1.0)
311
+
312
+ if interp == "none":
313
+ cv_interp = cv2.INTER_NEAREST
314
+ elif interp == "bilinear":
315
+ cv_interp = cv2.INTER_LINEAR
316
+ elif interp == "lanczos":
317
+ cv_interp = cv2.INTER_LANCZOS4
318
+ else:
319
+ cv_interp = cv2.INTER_LINEAR
320
+
321
+ warped = cv2.remap(
322
+ patch,
323
+ map_x,
324
+ map_y,
325
+ interpolation=cv_interp,
326
+ borderMode=cv2.BORDER_REFLECT_101,
327
+ )
328
+
329
+ img_hwc[y0:y1, x0:x1] = warped
330
+
331
+
332
+ class EyeWarpEnlargeFromMask_v1:
333
+ """
334
+ ComfyUI node:
335
+ - Input: IMAGE + MASK
336
+ - Finds 1-2 white squares/blobs in mask
337
+ - Uses their centers as "eye positions"
338
+ - Applies a local "Grow/Bulge" warp around each center
339
+ """
340
+
341
+ @classmethod
342
+ def INPUT_TYPES(cls):
343
+ return {
344
+ "required": {
345
+ "image": ("IMAGE",),
346
+ "mask": ("MASK",),
347
+ "size": ("INT", {"default": 60, "min": 1, "max": 2048, "step": 1}),
348
+ "hardness": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}),
349
+ "strength": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}),
350
+ "repeat": ("INT", {"default": 1, "min": 1, "max": 100, "step": 1}),
351
+ "interpolation": (["none", "bilinear", "lanczos"], {"default": "bilinear"}),
352
+ }
353
+ }
354
+
355
+ RETURN_TYPES = ("IMAGE",)
356
+ FUNCTION = "apply"
357
+ CATEGORY = "image/warp"
358
+
359
+ def apply(self, image, mask, size, hardness, strength, repeat, interpolation):
360
+ if not isinstance(image, torch.Tensor):
361
+ raise TypeError("image must be a torch.Tensor (ComfyUI IMAGE).")
362
+ if not isinstance(mask, torch.Tensor):
363
+ raise TypeError("mask must be a torch.Tensor (ComfyUI MASK).")
364
+
365
+ img = image
366
+ B, H, W, C = img.shape
367
+
368
+ # Align mask batch + resolution
369
+ m = mask
370
+ if m.ndim == 2:
371
+ m = m.unsqueeze(0)
372
+ if m.shape[0] != B:
373
+ if m.shape[0] == 1:
374
+ m = m.repeat(B, 1, 1)
375
+ else:
376
+ m = m[:B]
377
+
378
+ if m.shape[1] != H or m.shape[2] != W:
379
+ m = F.interpolate(m.unsqueeze(1), size=(H, W), mode="nearest").squeeze(1)
380
+
381
+ # Treat "size" like GIMP brush size (diameter): radius = size/2
382
+ radius = max(1.0, float(size) * 0.5)
383
+
384
+ outs = []
385
+ for i in range(B):
386
+ img_i = img[i : i + 1] # [1,H,W,C]
387
+ mask_i_np = _ensure_mask_2d_numpy(m[i]) # [H,W] np
388
+
389
+ centers = _find_eye_centers_numpy(mask_i_np, max_centers=2, threshold=0.5, min_area=16)
390
+
391
+ if len(centers) == 0 or strength <= 0 or repeat <= 0:
392
+ outs.append(img_i)
393
+ continue
394
+
395
+ # Lanczos via OpenCV (CPU). Otherwise torch path.
396
+ if interpolation == "lanczos" and cv2 is not None:
397
+ img_np = img_i[0].detach().float().cpu().numpy()
398
+ if img_np.max() > 1.5:
399
+ img_np = img_np / 255.0
400
+ img_np = np.clip(img_np, 0.0, 1.0).astype(np.float32)
401
+
402
+ for _ in range(int(repeat)):
403
+ for cx, cy in centers:
404
+ _bulge_warp_patch_cv2_inplace(
405
+ img_np, cx, cy, radius, hardness, strength, interp="lanczos"
406
+ )
407
+
408
+ out_i = torch.from_numpy(img_np).unsqueeze(0)
409
+ if img_i.device.type != "cpu":
410
+ out_i = out_i.to(img_i.device)
411
+ outs.append(out_i)
412
+ else:
413
+ mode = "nearest" if interpolation == "none" else "bilinear"
414
+ img_nchw = img_i.permute(0, 3, 1, 2).contiguous().clone()
415
+
416
+ for _ in range(int(repeat)):
417
+ for cx, cy in centers:
418
+ _bulge_warp_patch_torch_inplace(
419
+ img_nchw,
420
+ cx=cx,
421
+ cy=cy,
422
+ radius=radius,
423
+ hardness=hardness,
424
+ strength=strength,
425
+ mode=mode,
426
+ )
427
+
428
+ out_i = img_nchw.permute(0, 2, 3, 1).contiguous()
429
+ outs.append(out_i)
430
+
431
+ out = torch.cat(outs, dim=0).clamp(0.0, 1.0)
432
+ return (out,)
433
+
434
+
435
+ NODE_CLASS_MAPPINGS = {
436
+ "EyeWarpEnlargeFromMask_v1": EyeWarpEnlargeFromMask_v1,
437
+ }
438
+
439
+ NODE_DISPLAY_NAME_MAPPINGS = {
440
+ "EyeWarpEnlargeFromMask_v1": "Eye Warp Enlarge (Mask Centers)",
441
+ }
eye_enlarger_v2.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ try:
5
+ import cv2
6
+ except Exception:
7
+ cv2 = None
8
+
9
+
10
+ def _require_cv2():
11
+ if cv2 is None:
12
+ raise ImportError(
13
+ "Eye Warp Enlarge node requires OpenCV (cv2). "
14
+ "Install opencv-python in your ComfyUI environment."
15
+ )
16
+
17
+
18
+ def _as_mask_2d(mask_tensor: torch.Tensor) -> np.ndarray:
19
+ """
20
+ ComfyUI MASK is typically (H, W) or (H, W, 1) as float [0..1].
21
+ This returns float32 (H, W).
22
+ """
23
+ m = mask_tensor.detach().float().cpu().numpy()
24
+ if m.ndim == 3:
25
+ # e.g. (H, W, 1) or (H, W, C) -> take first channel
26
+ m = m[:, :, 0]
27
+ if m.ndim != 2:
28
+ raise ValueError(f"Unsupported mask shape: {m.shape}")
29
+ return m.astype(np.float32, copy=False)
30
+
31
+
32
+ def _resize_mask_to(mask_2d: np.ndarray, w: int, h: int) -> np.ndarray:
33
+ _require_cv2()
34
+ if mask_2d.shape == (h, w):
35
+ return mask_2d
36
+ return cv2.resize(mask_2d, (w, h), interpolation=cv2.INTER_NEAREST).astype(np.float32, copy=False)
37
+
38
+
39
+ def _find_eye_centers_from_mask(mask_2d: np.ndarray, max_centers: int = 2):
40
+ """
41
+ Finds up to 2 connected white components in a binary-ish mask.
42
+ Returns centers as [(cx, cy), ...] in pixel coordinates (float).
43
+ """
44
+ _require_cv2()
45
+
46
+ h, w = mask_2d.shape
47
+ # Threshold to binary (mask is black/white but may be float)
48
+ binary = (mask_2d > 0.5).astype(np.uint8)
49
+ if binary.max() == 0:
50
+ return []
51
+
52
+ # Connected components
53
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
54
+
55
+ # Filter components
56
+ img_area = float(h * w)
57
+ min_area = max(16, int(0.00001 * img_area)) # keep small squares, ignore tiny specks
58
+
59
+ comps = []
60
+ for label in range(1, num_labels):
61
+ area = int(stats[label, cv2.CC_STAT_AREA])
62
+ if area < min_area:
63
+ continue
64
+ # Ignore "everything is white" situations
65
+ if area > 0.95 * img_area:
66
+ continue
67
+ cx, cy = centroids[label]
68
+ comps.append((area, float(cx), float(cy)))
69
+
70
+ if not comps:
71
+ # Fallback: centroid of all white pixels
72
+ ys, xs = np.nonzero(binary)
73
+ if len(xs) == 0:
74
+ return []
75
+ return [(float(xs.mean()), float(ys.mean()))]
76
+
77
+ # Take largest components (squares)
78
+ comps.sort(key=lambda t: t[0], reverse=True)
79
+ comps = comps[:max_centers]
80
+ centers = [(cx, cy) for _, cx, cy in comps]
81
+ centers.sort(key=lambda p: p[0]) # left-to-right
82
+ return centers
83
+
84
+
85
+ def _build_magnify_maps(h: int, w: int, cx: float, cy: float, size: int, hardness: int, strength: int):
86
+ """
87
+ Builds (map_x, map_y) for cv2.remap that performs a "magnify/bulge" enlarge centered at (cx, cy).
88
+
89
+ size: brush DIAMETER in pixels (like typical brush size).
90
+ hardness: 0..100, controls how much of the radius is full-strength before falling off.
91
+ strength: 0..100, magnification intensity.
92
+ """
93
+ _require_cv2()
94
+
95
+ diameter = float(max(1, size))
96
+ radius = max(1.0, diameter * 0.5)
97
+
98
+ hardness = float(np.clip(hardness, 0, 100))
99
+ strength = float(np.clip(strength, 0, 100))
100
+ s = strength / 100.0 # 0..1
101
+
102
+ # Brush hardness model:
103
+ # inner = full-strength radius portion, outer ring falls off.
104
+ inner = radius * (hardness / 100.0)
105
+
106
+ # Coordinate grid
107
+ yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
108
+ dx = xx - np.float32(cx)
109
+ dy = yy - np.float32(cy)
110
+ dist = np.sqrt(dx * dx + dy * dy)
111
+
112
+ falloff = np.zeros((h, w), dtype=np.float32)
113
+
114
+ if inner >= radius - 1e-6:
115
+ # Very hard brush: almost step edge
116
+ falloff[dist <= radius] = 1.0
117
+ else:
118
+ # Full strength inside inner radius
119
+ inside = dist <= inner
120
+ falloff[inside] = 1.0
121
+
122
+ # Smooth falloff in the ring (inner..radius)
123
+ ring = (dist > inner) & (dist < radius)
124
+ if np.any(ring):
125
+ u = (dist[ring] - inner) / max(radius - inner, 1e-6) # 0..1
126
+ # Quadratic falloff (smooth-ish)
127
+ falloff[ring] = (1.0 - u) ** 2
128
+
129
+ active = dist < radius
130
+ if not np.any(active) or s <= 0.0:
131
+ # Identity maps
132
+ return xx, yy
133
+
134
+ # Local scale: 1..(1+s) depending on falloff
135
+ local_scale = 1.0 + (s * falloff)
136
+
137
+ map_x = xx.copy()
138
+ map_y = yy.copy()
139
+
140
+ # Inverse mapping for magnify:
141
+ # dest(p) samples from src closer to center: c + (p-c)/scale
142
+ map_x[active] = np.float32(cx) + dx[active] / local_scale[active]
143
+ map_y[active] = np.float32(cy) + dy[active] / local_scale[active]
144
+
145
+ return map_x.astype(np.float32, copy=False), map_y.astype(np.float32, copy=False)
146
+
147
+
148
+ def _interp_flag(mode: str) -> int:
149
+ _require_cv2()
150
+ mode = (mode or "").lower().strip()
151
+ if mode == "none":
152
+ return cv2.INTER_NEAREST
153
+ if mode == "bilinear":
154
+ return cv2.INTER_LINEAR
155
+ if mode == "lanczos":
156
+ return cv2.INTER_LANCZOS4
157
+ return cv2.INTER_LINEAR
158
+
159
+
160
+ class EyeWarpEnlargeFromMask_v2:
161
+ """
162
+ ComfyUI Node:
163
+ - image: IMAGE (B,H,W,C) float 0..1
164
+ - mask: MASK (B,H,W) float 0..1, containing 1 or 2 white squares
165
+ """
166
+
167
+ @classmethod
168
+ def INPUT_TYPES(cls):
169
+ return {
170
+ "required": {
171
+ "image": ("IMAGE",),
172
+ "mask": ("MASK",),
173
+
174
+ # Brush settings
175
+ "size": ("INT", {"default": 60, "min": 1, "max": 2048, "step": 1}),
176
+ "hardness": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}),
177
+ "strength": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}),
178
+
179
+ # No mouse distance => user controls repeats/passes
180
+ "repeat": ("INT", {"default": 1, "min": 1, "max": 100, "step": 1}),
181
+
182
+ # Interpolation
183
+ "interpolation": (["bilinear", "lanczos", "none"],),
184
+ }
185
+ }
186
+
187
+ RETURN_TYPES = ("IMAGE",)
188
+ FUNCTION = "apply"
189
+ CATEGORY = "image/warp"
190
+
191
+ def apply(self, image, mask, size=60, hardness=50, strength=50, repeat=1, interpolation="bilinear"):
192
+ _require_cv2()
193
+
194
+ if not isinstance(image, torch.Tensor):
195
+ raise TypeError("image must be a torch Tensor (ComfyUI IMAGE).")
196
+ if not isinstance(mask, torch.Tensor):
197
+ raise TypeError("mask must be a torch Tensor (ComfyUI MASK).")
198
+
199
+ device = image.device
200
+ dtype = image.dtype
201
+
202
+ # Ensure image is (B,H,W,C)
203
+ if image.ndim != 4:
204
+ raise ValueError(f"Expected image shape (B,H,W,C), got: {tuple(image.shape)}")
205
+ b, h, w, c = image.shape
206
+ if c not in (3, 4):
207
+ # still allow, but warn-ish via behavior
208
+ pass
209
+
210
+ # Ensure mask batch aligns (best effort)
211
+ if mask.ndim == 2:
212
+ mask = mask.unsqueeze(0)
213
+ if mask.ndim == 3 and mask.shape[0] != b:
214
+ # If single mask provided for a batch, broadcast it
215
+ if mask.shape[0] == 1 and b > 1:
216
+ mask = mask.repeat(b, 1, 1)
217
+ if mask.shape[0] != b:
218
+ raise ValueError(f"Mask batch ({mask.shape[0]}) does not match image batch ({b}).")
219
+
220
+ interp = _interp_flag(interpolation)
221
+
222
+ # Process on CPU with OpenCV, then return to original device
223
+ img_np = image.detach().float().cpu().numpy()
224
+ out_np = np.empty_like(img_np, dtype=np.float32)
225
+
226
+ for i in range(b):
227
+ frame = np.ascontiguousarray(img_np[i], dtype=np.float32) # (H,W,C)
228
+
229
+ mask_2d = _as_mask_2d(mask[i])
230
+ mask_2d = _resize_mask_to(mask_2d, w=w, h=h)
231
+
232
+ centers = _find_eye_centers_from_mask(mask_2d, max_centers=2)
233
+
234
+ if not centers or strength <= 0 or size <= 0 or repeat <= 0:
235
+ out_np[i] = frame
236
+ continue
237
+
238
+ # Precompute remap maps per center
239
+ maps = []
240
+ for (cx, cy) in centers:
241
+ mx, my = _build_magnify_maps(h=h, w=w, cx=cx, cy=cy, size=size, hardness=hardness, strength=strength)
242
+ maps.append((mx, my))
243
+
244
+ # Apply multiple passes (repeat)
245
+ for _ in range(int(repeat)):
246
+ for (mx, my) in maps:
247
+ frame = cv2.remap(
248
+ frame,
249
+ mx,
250
+ my,
251
+ interpolation=interp,
252
+ borderMode=cv2.BORDER_REFLECT_101,
253
+ )
254
+
255
+ # Clamp back to [0..1] to stay in ComfyUI's IMAGE range
256
+ frame = np.clip(frame, 0.0, 1.0)
257
+ out_np[i] = frame
258
+
259
+ out = torch.from_numpy(out_np).to(device=device, dtype=dtype)
260
+ return (out,)
261
+
262
+
263
+ NODE_CLASS_MAPPINGS = {
264
+ "EyeWarpEnlargeFromMask_v2": EyeWarpEnlargeFromMask_v2
265
+ }
266
+
267
+ NODE_DISPLAY_NAME_MAPPINGS = {
268
+ "EyeWarpEnlargeFromMask_v2": "Eye Warp Enlarge (Mask Centers) v2"
269
+ }