saliacoel commited on
Commit
7c57c5d
·
verified ·
1 Parent(s): d18b98d

Upload salia_square.py

Browse files
Files changed (1) hide show
  1. salia_square.py +454 -0
salia_square.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import threading
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ import folder_paths
12
+
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Global caches (shared across node instances)
16
+ # -----------------------------------------------------------------------------
17
+ _CKPT_CACHE: Dict[str, Dict[str, Any]] = {}
18
+ _CONTROLNET_CACHE: Dict[str, Dict[str, Any]] = {}
19
+ _CACHE_LOCK = threading.RLock()
20
+
21
+ _ALLOWED_UPSCALE_FACTORS = (1, 2, 4, 6, 8, 10, 12, 14, 16)
22
+
23
+
24
+ # -----------------------------------------------------------------------------
25
+ # Lazy imports / caching
26
+ # -----------------------------------------------------------------------------
27
+ def _lazy_import_nodes():
28
+ import nodes # comfy-core nodes module
29
+ return nodes
30
+
31
+
32
+ def _get_mtime(path: Optional[str]) -> Optional[float]:
33
+ if not path:
34
+ return None
35
+ try:
36
+ return float(os.path.getmtime(path))
37
+ except Exception:
38
+ return None
39
+
40
+
41
+ def _load_checkpoint_cached(ckpt_name: str):
42
+ """
43
+ Returns (model, clip, vae) for ckpt_name.
44
+ Cached by (ckpt_name + file mtime).
45
+ """
46
+ nodes = _lazy_import_nodes()
47
+
48
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
49
+ if not ckpt_path:
50
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_name}")
51
+ mtime = _get_mtime(ckpt_path)
52
+
53
+ with _CACHE_LOCK:
54
+ entry = _CKPT_CACHE.get(ckpt_name)
55
+ if entry and entry.get("mtime") == mtime:
56
+ return entry["model"], entry["clip"], entry["vae"]
57
+
58
+ loader = nodes.CheckpointLoaderSimple()
59
+ model, clip, vae = loader.load_checkpoint(ckpt_name)
60
+
61
+ with _CACHE_LOCK:
62
+ _CKPT_CACHE[ckpt_name] = {"mtime": mtime, "model": model, "clip": clip, "vae": vae}
63
+
64
+ return model, clip, vae
65
+
66
+
67
+ def _load_controlnet_cached(controlnet_name: str):
68
+ """
69
+ Returns controlnet for controlnet_name.
70
+ Cached by (controlnet_name + file mtime).
71
+ """
72
+ nodes = _lazy_import_nodes()
73
+
74
+ cn_path = folder_paths.get_full_path("controlnet", controlnet_name)
75
+ if not cn_path:
76
+ raise FileNotFoundError(f"ControlNet not found: {controlnet_name}")
77
+ mtime = _get_mtime(cn_path)
78
+
79
+ with _CACHE_LOCK:
80
+ entry = _CONTROLNET_CACHE.get(controlnet_name)
81
+ if entry and entry.get("mtime") == mtime:
82
+ return entry["controlnet"]
83
+
84
+ loader = nodes.ControlNetLoader()
85
+ (controlnet,) = loader.load_controlnet(control_net_name=controlnet_name)
86
+
87
+ with _CACHE_LOCK:
88
+ _CONTROLNET_CACHE[controlnet_name] = {"mtime": mtime, "controlnet": controlnet}
89
+
90
+ return controlnet
91
+
92
+
93
+ # -----------------------------------------------------------------------------
94
+ # Comfy tensor helpers (IMAGE/MASK)
95
+ # -----------------------------------------------------------------------------
96
+ def _ensure_batched_image(image: torch.Tensor) -> torch.Tensor:
97
+ # Accept [H,W,C] or [B,H,W,C]
98
+ if image.ndim == 3:
99
+ return image.unsqueeze(0)
100
+ return image
101
+
102
+
103
+ def _ensure_batched_mask(mask: torch.Tensor) -> torch.Tensor:
104
+ # Accept [H,W] or [B,H,W] or [B,H,W,1]
105
+ if mask.ndim == 2:
106
+ return mask.unsqueeze(0)
107
+ if mask.ndim == 4 and mask.shape[-1] == 1:
108
+ return mask[..., 0]
109
+ return mask
110
+
111
+
112
+ def _comfy_image_batch_to_pil_list(img_bhwc: torch.Tensor) -> List[Image.Image]:
113
+ img_bhwc = _ensure_batched_image(img_bhwc).detach().cpu().float().clamp(0.0, 1.0)
114
+ b, h, w, c = img_bhwc.shape
115
+ if c not in (3, 4):
116
+ raise ValueError(f"Expected IMAGE with 3 or 4 channels, got {c} channels.")
117
+
118
+ out: List[Image.Image] = []
119
+ for i in range(b):
120
+ arr = (img_bhwc[i].numpy() * 255.0).round().astype(np.uint8)
121
+ mode = "RGB" if c == 3 else "RGBA"
122
+ out.append(Image.fromarray(arr, mode=mode))
123
+ return out
124
+
125
+
126
+ def _pil_list_to_comfy_image_batch(pils: List[Image.Image], want_channels: int) -> torch.Tensor:
127
+ if want_channels not in (3, 4):
128
+ raise ValueError("want_channels must be 3 or 4")
129
+
130
+ tensors: List[torch.Tensor] = []
131
+ for p in pils:
132
+ p = p.convert("RGB") if want_channels == 3 else p.convert("RGBA")
133
+ arr = np.array(p).astype(np.float32) / 255.0
134
+ tensors.append(torch.from_numpy(arr))
135
+ return torch.stack(tensors, dim=0)
136
+
137
+
138
+ def _resize_comfy_image_lanczos(img_bhwc: torch.Tensor, width: int, height: int) -> torch.Tensor:
139
+ img_bhwc = _ensure_batched_image(img_bhwc)
140
+ if width <= 0 or height <= 0:
141
+ raise ValueError("width/height must be > 0")
142
+
143
+ b, h, w, c = img_bhwc.shape
144
+ if (w == width) and (h == height):
145
+ return img_bhwc
146
+
147
+ pils = _comfy_image_batch_to_pil_list(img_bhwc)
148
+ resized = [p.resize((width, height), resample=Image.LANCZOS) for p in pils]
149
+ return _pil_list_to_comfy_image_batch(resized, want_channels=c)
150
+
151
+
152
+ def _resize_comfy_mask_lanczos(mask_bhw: torch.Tensor, width: int, height: int) -> torch.Tensor:
153
+ mask_bhw = _ensure_batched_mask(mask_bhw).detach().cpu().float().clamp(0.0, 1.0)
154
+ b, h, w = mask_bhw.shape
155
+ if width <= 0 or height <= 0:
156
+ raise ValueError("width/height must be > 0")
157
+ if (w == width) and (h == height):
158
+ return mask_bhw
159
+
160
+ out: List[torch.Tensor] = []
161
+ for i in range(b):
162
+ arr = (mask_bhw[i].numpy() * 255.0).round().astype(np.uint8)
163
+ pil = Image.fromarray(arr, mode="L")
164
+ pil = pil.resize((width, height), resample=Image.LANCZOS)
165
+ arr2 = np.array(pil).astype(np.float32) / 255.0
166
+ out.append(torch.from_numpy(arr2))
167
+ return torch.stack(out, dim=0)
168
+
169
+
170
+ def _repeat_batch_if_needed(t: torch.Tensor, target_b: int) -> torch.Tensor:
171
+ if t.ndim == 4:
172
+ b = t.shape[0]
173
+ if b == target_b:
174
+ return t
175
+ if b == 1:
176
+ return t.repeat(target_b, 1, 1, 1)
177
+ raise ValueError(f"Batch mismatch: tensor batch {b} vs target {target_b}")
178
+ if t.ndim == 3:
179
+ b = t.shape[0]
180
+ if b == target_b:
181
+ return t
182
+ if b == 1:
183
+ return t.repeat(target_b, 1, 1)
184
+ raise ValueError(f"Batch mismatch: tensor batch {b} vs target {target_b}")
185
+ raise ValueError("Unsupported tensor rank for batching")
186
+
187
+
188
+ def _alpha_over_composite_at_xy(base_bhwc: torch.Tensor, overlay_bhwc: torch.Tensor, x: int, y: int) -> torch.Tensor:
189
+ """
190
+ Alpha composite overlay (must be RGBA) over base at (x,y).
191
+ Output channels match base channels (RGB stays RGB; RGBA stays RGBA).
192
+ """
193
+ base_bhwc = _ensure_batched_image(base_bhwc).detach().cpu().float().clamp(0.0, 1.0)
194
+ overlay_bhwc = _ensure_batched_image(overlay_bhwc).detach().cpu().float().clamp(0.0, 1.0)
195
+
196
+ b0, H, W, Cb = base_bhwc.shape
197
+ b1, h, w, Co = overlay_bhwc.shape
198
+
199
+ if Co != 4:
200
+ raise ValueError("overlay must be RGBA (4 channels).")
201
+ if Cb not in (3, 4):
202
+ raise ValueError("base must have 3 or 4 channels.")
203
+ if b1 != b0:
204
+ overlay_bhwc = _repeat_batch_if_needed(overlay_bhwc, b0)
205
+
206
+ if x < 0 or y < 0 or (x + w) > W or (y + h) > H:
207
+ raise ValueError(f"Overlay out of bounds: base {W}x{H}, overlay {w}x{h}, x={x}, y={y}")
208
+
209
+ out = base_bhwc.clone()
210
+
211
+ ov_rgb = overlay_bhwc[..., 0:3]
212
+ ov_a = overlay_bhwc[..., 3:4]
213
+
214
+ region = out[:, y : y + h, x : x + w, :]
215
+ bd_rgb = region[..., 0:3]
216
+
217
+ if Cb == 3:
218
+ out_rgb = ov_rgb * ov_a + bd_rgb * (1.0 - ov_a)
219
+ out[:, y : y + h, x : x + w, 0:3] = out_rgb
220
+ return out.clamp(0.0, 1.0)
221
+
222
+ bd_a = region[..., 3:4]
223
+ out_a = ov_a + bd_a * (1.0 - ov_a)
224
+ out_rgb_premul = ov_rgb * ov_a + bd_rgb * bd_a * (1.0 - ov_a)
225
+ out_rgb = torch.where(out_a > 1e-8, out_rgb_premul / out_a, torch.zeros_like(out_rgb_premul))
226
+
227
+ out[:, y : y + h, x : x + w, 0:3] = out_rgb
228
+ out[:, y : y + h, x : x + w, 3:4] = out_a
229
+ return out.clamp(0.0, 1.0)
230
+
231
+
232
+ def _list_asset_pngs_fallback() -> List[str]:
233
+ """
234
+ Best-effort asset PNG listing:
235
+ 1) Try comfyui-salia_online/utils/io.py:list_pngs()
236
+ 2) Else scan ../assets/images relative to this file
237
+ """
238
+ try:
239
+ from ..utils.io import list_pngs # your plugin helper
240
+ choices = list_pngs()
241
+ if choices:
242
+ return choices
243
+ except Exception:
244
+ pass
245
+
246
+ try:
247
+ plugin_root = Path(__file__).resolve().parent.parent
248
+ images_dir = plugin_root / "assets" / "images"
249
+ if images_dir.exists():
250
+ return sorted([p.name for p in images_dir.glob("*.png")])
251
+ except Exception:
252
+ pass
253
+
254
+ return []
255
+
256
+
257
+ # -----------------------------------------------------------------------------
258
+ # The one-node workflow
259
+ # -----------------------------------------------------------------------------
260
+ class Salia_OneNode_SquareWorkflow:
261
+ """
262
+ One-node replacement for the described workflow.
263
+ """
264
+
265
+ CATEGORY = "image/salia"
266
+ RETURN_TYPES = ("IMAGE",)
267
+ RETURN_NAMES = ("image",)
268
+ FUNCTION = "run"
269
+
270
+ @classmethod
271
+ def INPUT_TYPES(cls):
272
+ # Keep INPUT_TYPES light: no model loads here.
273
+ try:
274
+ import comfy.samplers as samplers
275
+ sampler_names = list(getattr(samplers.KSampler, "SAMPLERS", [])) or ["euler"]
276
+ scheduler_names = list(getattr(samplers.KSampler, "SCHEDULERS", [])) or ["karras"]
277
+ except Exception:
278
+ sampler_names = ["euler"]
279
+ scheduler_names = ["karras"]
280
+
281
+ ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
282
+ cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
283
+ assets = _list_asset_pngs_fallback() or ["<no pngs found>"]
284
+
285
+ upscale_choices = [str(v) for v in _ALLOWED_UPSCALE_FACTORS]
286
+
287
+ return {
288
+ "required": {
289
+ "image": ("IMAGE",),
290
+
291
+ "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
292
+ "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
293
+ "square_size": ("INT", {"default": 384, "min": 1, "max": 8192, "step": 1}),
294
+
295
+ "positive_prompt": ("STRING", {"multiline": True, "default": ""}),
296
+ "negative_prompt": ("STRING", {"multiline": True, "default": ""}),
297
+
298
+ "upscale_factor": (upscale_choices, {"default": "4"}),
299
+
300
+ "checkpoint_name": (ckpts, {}),
301
+ "controlnet_name": (cns, {}),
302
+ "assets_png": (assets, {}),
303
+
304
+ "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.0, "max": 10.0, "step": 0.01}),
305
+ "controlnet_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
306
+ "controlnet_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
307
+
308
+ "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
309
+ "cfg": ("FLOAT", {"default": 2.6, "min": 0.0, "max": 10.0, "step": 0.05}),
310
+ "sampler_name": (sampler_names, {"default": "euler"}),
311
+ "scheduler": (scheduler_names, {"default": "karras"}),
312
+ "denoise": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.01}),
313
+ }
314
+ }
315
+
316
+ def run(
317
+ self,
318
+ image: torch.Tensor,
319
+ X_coord: int,
320
+ Y_coord: int,
321
+ square_size: int,
322
+ positive_prompt: str,
323
+ negative_prompt: str,
324
+ upscale_factor: str,
325
+ checkpoint_name: str,
326
+ controlnet_name: str,
327
+ assets_png: str,
328
+ controlnet_strength: float = 0.33,
329
+ controlnet_start_percent: float = 0.0,
330
+ controlnet_end_percent: float = 1.0,
331
+ steps: int = 30,
332
+ cfg: float = 2.6,
333
+ sampler_name: str = "euler",
334
+ scheduler: str = "karras",
335
+ denoise: float = 0.35,
336
+ ):
337
+ # ---- validate ----
338
+ try:
339
+ uf = int(upscale_factor)
340
+ except Exception:
341
+ raise ValueError(f"Invalid upscale_factor: {upscale_factor}")
342
+
343
+ if uf not in _ALLOWED_UPSCALE_FACTORS:
344
+ raise ValueError(f"upscale_factor must be one of {_ALLOWED_UPSCALE_FACTORS}, got {uf}")
345
+
346
+ if square_size <= 0:
347
+ raise ValueError("square_size must be > 0")
348
+
349
+ # ---- crop ----
350
+ base = _ensure_batched_image(image)
351
+ b, H, W, C = base.shape
352
+ if C not in (3, 4):
353
+ raise ValueError(f"Input image must be RGB or RGBA (3/4 channels), got {C}")
354
+
355
+ x = int(X_coord)
356
+ y = int(Y_coord)
357
+ s = int(square_size)
358
+
359
+ if x < 0 or y < 0 or (x + s) > W or (y + s) > H:
360
+ raise ValueError(f"Crop out of bounds: image {W}x{H}, x={x}, y={y}, square_size={s}")
361
+
362
+ crop = base[:, y : y + s, x : x + s, :]
363
+ crop_rgb = crop[..., 0:3]
364
+
365
+ up_w = int(s * uf)
366
+ up_h = int(s * uf)
367
+
368
+ # ---- upscale crop (Lanczos) for VAEEncode ----
369
+ crop_up = _resize_comfy_image_lanczos(crop_rgb, width=up_w, height=up_h)
370
+
371
+ # ---- depth (Salia_Depth) then upscale depth (Lanczos) ----
372
+ # lazy import (don’t import transformers at module import time)
373
+ try:
374
+ from .salia_depth import Salia_Depth
375
+ except Exception:
376
+ from salia_depth import Salia_Depth
377
+
378
+ depth_node = Salia_Depth()
379
+ (depth_img,) = depth_node.execute(image=crop, resolution=-1) # keep original crop res
380
+ depth_img = _ensure_batched_image(depth_img)[..., 0:3]
381
+ depth_up = _resize_comfy_image_lanczos(depth_img, width=up_w, height=up_h)
382
+
383
+ # ---- load alpha mask from assets ----
384
+ try:
385
+ from .salia_loadimage_assets import LoadImage_SaliaOnline_Assets
386
+ except Exception:
387
+ from salia_loadimage_assets import LoadImage_SaliaOnline_Assets
388
+
389
+ assets_loader = LoadImage_SaliaOnline_Assets()
390
+ _asset_img, asset_mask = assets_loader.run(assets_png)
391
+ asset_mask = _ensure_batched_mask(asset_mask)
392
+ asset_mask = _resize_comfy_mask_lanczos(asset_mask, width=up_w, height=up_h)
393
+ asset_mask = _repeat_batch_if_needed(asset_mask, b)
394
+
395
+ # ---- load checkpoint + controlnet (cached) ----
396
+ model, clip, vae = _load_checkpoint_cached(checkpoint_name)
397
+ controlnet = _load_controlnet_cached(controlnet_name)
398
+
399
+ # ---- comfy core pipeline ----
400
+ nodes = _lazy_import_nodes()
401
+
402
+ (pos_cond,) = nodes.CLIPTextEncode().encode(clip=clip, text=positive_prompt)
403
+ (neg_cond,) = nodes.CLIPTextEncode().encode(clip=clip, text=negative_prompt)
404
+
405
+ pos_cn, neg_cn = nodes.ControlNetApplyAdvanced().apply_controlnet(
406
+ positive=pos_cond,
407
+ negative=neg_cond,
408
+ control_net=controlnet,
409
+ image=depth_up,
410
+ strength=float(controlnet_strength),
411
+ start_percent=float(controlnet_start_percent),
412
+ end_percent=float(controlnet_end_percent),
413
+ vae=vae,
414
+ )
415
+
416
+ (latent,) = nodes.VAEEncode().encode(pixels=crop_up, vae=vae)
417
+
418
+ # No seed input requested: generate a fresh seed per execution
419
+ seed = random.randint(0, 2**63 - 1)
420
+
421
+ (latent_out,) = nodes.KSampler().sample(
422
+ model=model,
423
+ seed=seed,
424
+ steps=int(steps),
425
+ cfg=float(cfg),
426
+ sampler_name=sampler_name,
427
+ scheduler=scheduler,
428
+ positive=pos_cn,
429
+ negative=neg_cn,
430
+ latent_image=latent,
431
+ denoise=float(denoise),
432
+ )
433
+
434
+ (decoded_rgb,) = nodes.VAEDecode().decode(samples=latent_out, vae=vae)
435
+
436
+ # Join alpha -> RGBA
437
+ (decoded_rgba_up,) = nodes.JoinImageWithAlpha().join(image=decoded_rgb, alpha=asset_mask)
438
+
439
+ # Downscale back to original square size (Lanczos)
440
+ decoded_rgba_down = _resize_comfy_image_lanczos(decoded_rgba_up, width=s, height=s)
441
+
442
+ # Composite onto original input at (x,y)
443
+ out = _alpha_over_composite_at_xy(base, decoded_rgba_down, x=x, y=y)
444
+
445
+ return (out,)
446
+
447
+
448
+ NODE_CLASS_MAPPINGS = {
449
+ "Salia_OneNode_SquareWorkflow": Salia_OneNode_SquareWorkflow,
450
+ }
451
+
452
+ NODE_DISPLAY_NAME_MAPPINGS = {
453
+ "Salia_OneNode_SquareWorkflow": "Salia One-Node Square Workflow",
454
+ }