saliacoel commited on
Commit
15e94d2
·
verified ·
1 Parent(s): ba66a5e

Upload img_alpha_composite_coords.py

Browse files
Files changed (1) hide show
  1. img_alpha_composite_coords.py +179 -0
img_alpha_composite_coords.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # paste_rgba_at_xy.py
2
+ # ComfyUI custom node: Paste a small RGBA image onto a larger RGBA canvas at (x, y)
3
+ # Supports alpha compositing ("source-over") or hard replace.
4
+ #
5
+ # Usage:
6
+ # - canvas: big RGBA image (IMAGE)
7
+ # - overlay: small RGBA image (IMAGE)
8
+ # - x, y: top-left destination coordinate on the canvas
9
+ # - blend_mode: "alpha_over" (default) or "replace"
10
+ # - If either input is RGB (3ch), it will be upgraded to RGBA with alpha=1.0
11
+ # - Batching:
12
+ # * If one input has batch size 1 and the other >1, the single image is broadcast.
13
+ # * If both have batch >1, their batch sizes must match.
14
+
15
+ import torch
16
+
17
+ class PasteRGBAAtXY:
18
+ """
19
+ Paste a small RGBA image on a larger RGBA canvas at a specified (x, y) coordinate.
20
+
21
+ - Default blend mode is proper alpha compositing (SRC over).
22
+ - "replace" mode copies the overlay's RGBA pixels directly (no blending).
23
+ - Handles out-of-bounds and negative coordinates by clipping.
24
+ - Works with batches. Broadcasts a single overlay across a batch of canvases (and vice versa).
25
+ """
26
+
27
+ CATEGORY = "image/compose"
28
+ RETURN_TYPES = ("IMAGE",)
29
+ FUNCTION = "paste"
30
+
31
+ @classmethod
32
+ def INPUT_TYPES(cls):
33
+ return {
34
+ "required": {
35
+ "canvas": ("IMAGE",),
36
+ "overlay": ("IMAGE",),
37
+ "x": ("INT", {"default": 0, "min": -32768, "max": 32768, "step": 1}),
38
+ "y": ("INT", {"default": 0, "min": -32768, "max": 32768, "step": 1}),
39
+ "blend_mode": (["alpha_over", "replace"], {"default": "alpha_over"}),
40
+ }
41
+ }
42
+
43
+ @staticmethod
44
+ def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Ensure an IMAGE tensor is RGBA. If RGB, append opaque alpha.
47
+ Shape convention in ComfyUI for IMAGE is [B, H, W, C] with float32 in [0, 1].
48
+ """
49
+ if img.ndim != 4:
50
+ raise ValueError(f"Expected image tensor with 4 dims [B,H,W,C], got shape {tuple(img.shape)}")
51
+ c = img.shape[-1]
52
+ if c == 4:
53
+ return img
54
+ if c == 3:
55
+ alpha = torch.ones((*img.shape[:-1], 1), dtype=img.dtype, device=img.device)
56
+ return torch.cat([img, alpha], dim=-1)
57
+ raise ValueError(f"Expected 3 or 4 channels, got {c}")
58
+
59
+ @staticmethod
60
+ def _pair_count(b1: int, b2: int) -> int:
61
+ """Return the output batch size if broadcasting is allowed."""
62
+ if b1 == b2:
63
+ return b1
64
+ if b1 == 1 or b2 == 1:
65
+ return max(b1, b2)
66
+ raise ValueError(f"Incompatible batch sizes: {b1} vs {b2}")
67
+
68
+ @staticmethod
69
+ def _get_batch(img: torch.Tensor, i: int) -> torch.Tensor:
70
+ """Fetch the i-th batch image with broadcasting if needed."""
71
+ b = img.shape[0]
72
+ return img[0] if b == 1 else img[i]
73
+
74
+ @staticmethod
75
+ def _alpha_over(dst_rgba: torch.Tensor, src_rgba: torch.Tensor, dx: int, dy: int) -> None:
76
+ """
77
+ Alpha-composite src onto dst in-place at integer offset (dx, dy).
78
+ Both tensors are HxWx4, float in [0,1]. Clips automatically when out of bounds.
79
+ """
80
+ Hc, Wc, _ = dst_rgba.shape
81
+ Ho, Wo, _ = src_rgba.shape
82
+
83
+ # Compute intersection rectangle on canvas (destination)
84
+ x0 = max(0, dx)
85
+ y0 = max(0, dy)
86
+ x1 = min(Wc, dx + Wo)
87
+ y1 = min(Hc, dy + Ho)
88
+
89
+ if x1 <= x0 or y1 <= y0:
90
+ return # Nothing overlaps
91
+
92
+ # Corresponding source crop
93
+ sx0 = x0 - dx
94
+ sy0 = y0 - dy
95
+ w = x1 - x0
96
+ h = y1 - y0
97
+
98
+ dst_region = dst_rgba[y0:y0+h, x0:x0+w, :]
99
+ src_region = src_rgba[sy0:sy0+h, sx0:sx0+w, :]
100
+
101
+ # Split channels
102
+ cb = dst_region[..., :3]
103
+ ab = dst_region[..., 3:4]
104
+ co = src_region[..., :3]
105
+ ao = src_region[..., 3:4]
106
+
107
+ # Premultiply colors
108
+ cb_p = cb * ab
109
+ co_p = co * ao
110
+
111
+ # Source-over composition
112
+ out_a = ao + ab * (1.0 - ao)
113
+ out_c_p = co_p + cb_p * (1.0 - ao)
114
+
115
+ # Convert back to straight (guard against divide by zero)
116
+ eps = 1e-8
117
+ out_c = torch.where(out_a > eps, out_c_p / out_a.clamp_min(eps), torch.zeros_like(out_c_p))
118
+
119
+ # Write back (clamp just in case)
120
+ dst_region[..., :3] = out_c.clamp(0.0, 1.0)
121
+ dst_region[..., 3:4] = out_a.clamp(0.0, 1.0)
122
+
123
+ dst_rgba[y0:y0+h, x0:x0+w, :] = dst_region
124
+
125
+ @staticmethod
126
+ def _replace(dst_rgba: torch.Tensor, src_rgba: torch.Tensor, dx: int, dy: int) -> None:
127
+ """
128
+ Direct overwrite (copy) of src_rgba into dst_rgba at (dx, dy), clipped to bounds.
129
+ """
130
+ Hc, Wc, _ = dst_rgba.shape
131
+ Ho, Wo, _ = src_rgba.shape
132
+
133
+ x0 = max(0, dx)
134
+ y0 = max(0, dy)
135
+ x1 = min(Wc, dx + Wo)
136
+ y1 = min(Hc, dy + Ho)
137
+ if x1 <= x0 or y1 <= y0:
138
+ return
139
+
140
+ sx0 = x0 - dx
141
+ sy0 = y0 - dy
142
+ w = x1 - x0
143
+ h = y1 - y0
144
+
145
+ dst_rgba[y0:y0+h, x0:x0+w, :] = src_rgba[sy0:sy0+h, sx0:sx0+w, :]
146
+
147
+ def paste(self, canvas, overlay, x, y, blend_mode):
148
+ # Ensure RGBA
149
+ canvas = self._ensure_rgba(canvas)
150
+ overlay = self._ensure_rgba(overlay)
151
+
152
+ b_canvas = canvas.shape[0]
153
+ b_overlay = overlay.shape[0]
154
+ out_b = self._pair_count(b_canvas, b_overlay)
155
+
156
+ out_list = []
157
+ for i in range(out_b):
158
+ base = self._get_batch(canvas, i).clone() # HxWx4
159
+ over = self._get_batch(overlay, i) # HxWx4
160
+
161
+ if blend_mode == "alpha_over":
162
+ self._alpha_over(base, over, int(x), int(y))
163
+ else: # "replace"
164
+ self._replace(base, over, int(x), int(y))
165
+
166
+ out_list.append(base)
167
+
168
+ out = torch.stack(out_list, dim=0)
169
+ return (out,)
170
+
171
+
172
+ # --- ComfyUI registration ---
173
+ NODE_CLASS_MAPPINGS = {
174
+ "PasteRGBAAtXY": PasteRGBAAtXY,
175
+ }
176
+
177
+ NODE_DISPLAY_NAME_MAPPINGS = {
178
+ "PasteRGBAAtXY": "Paste RGBA at (X,Y)",
179
+ }