saliacoel committed on
Commit
fc6c8da
·
verified ·
1 Parent(s): 0911006

Upload salia_detailer_ezpz.py

Browse files
Files changed (1) hide show
  1. salia_detailer_ezpz.py +85 -57
salia_detailer_ezpz.py CHANGED
@@ -22,7 +22,7 @@ except Exception as e:
22
 
23
 
24
  # -------------------------------------------------------------------------------------
25
- # Global caches (lazy-load + don't load duplicates across multiple node instances)
26
  # -------------------------------------------------------------------------------------
27
 
28
  _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
@@ -32,19 +32,19 @@ _CN_LOCK = threading.Lock()
32
 
33
 
34
  # -------------------------------------------------------------------------------------
35
- # Plugin root detection (robust against hyphen/underscore module naming)
36
  # -------------------------------------------------------------------------------------
37
 
38
  def _find_plugin_root() -> Path:
39
  """
40
  Walk upwards from this file until we find an 'assets' folder.
41
- This works regardless of how Comfy names the python module.
42
  """
43
  here = Path(__file__).resolve()
44
- for parent in [here.parent] + list(here.parents)[:10]:
45
  if (parent / "assets").is_dir():
46
  return parent
47
- # fallback: typical layout nodes/<thisfile>.py -> plugin root is parent.parent
48
  return here.parent.parent
49
 
50
 
@@ -135,6 +135,48 @@ def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
135
  return torch.cat(outs, dim=0)
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # -------------------------------------------------------------------------------------
139
  # Core lazy loaders (checkpoint + controlnet), cached globally
140
  # -------------------------------------------------------------------------------------
@@ -202,10 +244,13 @@ def _safe_asset_path(asset_rel_path: str) -> Path:
202
 
203
  base = img_dir.resolve()
204
  rel = Path(asset_rel_path)
 
205
  if rel.is_absolute():
206
  raise ValueError("Absolute paths are not allowed for asset_image.")
207
 
208
  full = (base / rel).resolve()
 
 
209
  if base != full and base not in full.parents:
210
  raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
211
 
@@ -222,37 +267,31 @@ def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch
222
  Returns (IMAGE, MASK) in ComfyUI formats.
223
 
224
  Mask semantics: match ComfyUI core LoadImage:
225
- - If PNG has alpha: mask = 1 - alpha
226
- - Else: mask = 0
227
  """
228
  p = _safe_asset_path(asset_rel_path)
229
 
230
  im = Image.open(p)
231
  im = ImageOps.exif_transpose(im)
232
 
233
- had_alpha = ("A" in im.getbands())
234
  rgba = im.convert("RGBA")
235
  rgb = rgba.convert("RGB")
236
 
237
  rgb_arr = np.array(rgb).astype(np.float32) / 255.0 # [H,W,3]
238
  img_t = torch.from_numpy(rgb_arr)[None, ...]
239
 
240
- if had_alpha:
241
- alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
242
- mask = 1.0 - alpha
243
- else:
244
- h, w = rgb.size[1], rgb.size[0]
245
- mask = np.zeros((h, w), dtype=np.float32)
246
 
247
  mask_t = torch.from_numpy(mask)[None, ...]
248
  return img_t, mask_t
249
 
250
 
251
  # -------------------------------------------------------------------------------------
252
- # Salia_Depth (INLINED: exact logic, no imports from other files)
253
  # -------------------------------------------------------------------------------------
254
 
255
- # Local model path: assets/depth
256
  MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
257
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
258
 
@@ -264,6 +303,9 @@ REQUIRED_FILES = {
264
 
265
  ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
266
 
 
 
 
267
 
268
def _have_required_files() -> bool:
    """Return True when every required model file exists under MODEL_DIR."""
    # Iterating the dict directly yields its keys; `.keys()` is redundant.
    return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES)
@@ -420,26 +462,24 @@ def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
420
  return t.unsqueeze(0) # [1,H,W,C]
421
 
422
 
423
- _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
424
-
425
-
426
  def _try_load_pipeline(model_source: str, device: torch.device):
427
  if pipeline is None:
428
  raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")
429
 
430
  key = (model_source, str(device))
431
- if key in _PIPE_CACHE:
432
- return _PIPE_CACHE[key]
 
433
 
434
- p = pipeline(task="depth-estimation", model=model_source)
435
- try:
436
- p.model = p.model.to(device)
437
- p.device = device
438
- except Exception:
439
- pass
440
 
441
- _PIPE_CACHE[key] = p
442
- return p
443
 
444
 
445
  def get_depth_pipeline(device: torch.device):
@@ -511,28 +551,24 @@ def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray
511
 
512
  def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
513
  """
514
- Internal callable version of your Salia_Depth node:
515
  input: IMAGE [B,H,W,3 or 4]
516
  output: IMAGE [B,H,W,3]
517
  """
518
- # Get torch device
519
  try:
520
  device = model_management.get_torch_device()
521
  except Exception:
522
  device = torch.device("cpu")
523
 
524
- # Load pipeline
525
- pipe = None
526
  try:
527
- pipe = get_depth_pipeline(device)
528
  except Exception:
529
- pipe = None
530
 
531
- # If everything fails, pass-through
532
- if pipe is None:
533
  return image
534
 
535
- # Batch support
536
  if image.ndim == 3:
537
  image = image.unsqueeze(0)
538
 
@@ -544,22 +580,18 @@ def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Ten
544
 
545
  inp_u8 = comfy_tensor_to_u8(image[i])
546
 
547
- # RGBA rule (pre)
548
  rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
549
  had_rgba = alpha_u8 is not None
550
 
551
- # Depth
552
  depth_rgb = depth_estimate_zoe_style(
553
- pipe=pipe,
554
  input_rgb_u8=rgb_for_depth,
555
  detect_resolution=int(resolution),
556
  upscale_method="INTER_CUBIC",
557
  )
558
 
559
- # Resize back to original size
560
  depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
561
 
562
- # RGBA rule (post)
563
  if had_rgba:
564
  if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
565
  try:
@@ -602,7 +634,6 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
602
  if x < 0 or y < 0 or x + s > W or y + s > H:
603
  raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
604
 
605
- # Match batch
606
  if b2 != B:
607
  if b2 == 1 and B > 1:
608
  overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
@@ -618,7 +649,6 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
618
  comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
619
  out[:, y:y + s, x:x + s, 0:3] = comp_rgb
620
 
621
- # If base has alpha, composite alpha too
622
  if C == 4:
623
  base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
624
  comp_a = overlay_a + base_a * (1.0 - overlay_a)
@@ -703,6 +733,7 @@ class Salia_OneNode_WorkflowSquare:
703
  scheduler: str,
704
  denoise: float,
705
  ):
 
706
  if image.ndim == 3:
707
  image = image.unsqueeze(0)
708
  if image.ndim != 4:
@@ -719,6 +750,7 @@ class Salia_OneNode_WorkflowSquare:
719
  up = int(upscale_factor)
720
  if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
721
  raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
 
722
  if s <= 0:
723
  raise ValueError("square_size must be > 0")
724
  if x < 0 or y < 0 or x + s > W or y + s > H:
@@ -727,6 +759,7 @@ class Salia_OneNode_WorkflowSquare:
727
  up_w = s * up
728
  up_h = s * up
729
 
 
730
  if (up_w % 8) != 0 or (up_h % 8) != 0:
731
  raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
732
 
@@ -735,18 +768,18 @@ class Salia_OneNode_WorkflowSquare:
735
  if end_p < start_p:
736
  start_p, end_p = end_p, start_p
737
 
738
- # 1) Crop
739
  crop = image[:, y:y + s, x:x + s, :]
740
  crop_rgb = crop[:, :, :, 0:3].contiguous()
741
 
742
- # 2) Depth (inline Salia_Depth) then Lanczos upscale
743
  depth_small = _salia_depth_execute(crop_rgb, resolution=s)
744
  depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
745
 
746
- # 3) Upscale crop for VAE Encode
747
  crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
748
 
749
- # 4) Load asset mask (inline) and resize
750
  if asset_image == "<no pngs found>":
751
  raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
752
  _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
@@ -826,18 +859,13 @@ class Salia_OneNode_WorkflowSquare:
826
  latent_image=latent,
827
  )
828
 
829
- # 10) VAE Decode
830
  vae_dec = nodes.VAEDecode()
831
  vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
832
  (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
833
 
834
- # 11) JoinImageWithAlpha
835
- join = nodes.JoinImageWithAlpha()
836
- join_fn = getattr(join, join.FUNCTION)
837
- try:
838
- (rgba_up,) = join_fn(image=decoded_rgb, alpha=asset_mask_up)
839
- except TypeError:
840
- (rgba_up,) = join_fn(image=decoded_rgb, mask=asset_mask_up)
841
 
842
  # 12) Downscale RGBA back to crop size
843
  rgba_square = _resize_image_lanczos(rgba_up, s, s)
 
22
 
23
 
24
  # -------------------------------------------------------------------------------------
25
+ # Global caches (checkpoint + controlnet) so using the node multiple times won't reload
26
  # -------------------------------------------------------------------------------------
27
 
28
  _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
 
32
 
33
 
34
  # -------------------------------------------------------------------------------------
35
+ # Plugin root detection (works whether file is in plugin root or nodes/)
36
  # -------------------------------------------------------------------------------------
37
 
38
  def _find_plugin_root() -> Path:
39
  """
40
  Walk upwards from this file until we find an 'assets' folder.
41
+ Robust against hyphen/underscore package naming and different file placement.
42
  """
43
  here = Path(__file__).resolve()
44
+ for parent in [here.parent] + list(here.parents)[:12]:
45
  if (parent / "assets").is_dir():
46
  return parent
47
+ # fallback: typical nodes/<file>.py
48
  return here.parent.parent
49
 
50
 
 
135
  return torch.cat(outs, dim=0)
136
 
137
 
138
+ # -------------------------------------------------------------------------------------
139
+ # ✅ ComfyUI 0.5.1 FIX: Manual JoinImageWithAlpha equivalent
140
+ # -------------------------------------------------------------------------------------
141
+
142
+ def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
143
+ """
144
+ Make RGBA from:
145
+ rgb: IMAGE [B,H,W,3] float [0..1]
146
+ mask: MASK [B,H,W] float [0..1] (Comfy convention: 1=masked/transparent)
147
+ Output:
148
+ rgba: IMAGE [B,H,W,4] where alpha = 1 - mask (1=opaque, 0=transparent)
149
+ """
150
+ if rgb.ndim == 3:
151
+ rgb = rgb.unsqueeze(0)
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(0)
154
+
155
+ if rgb.ndim != 4 or rgb.shape[-1] != 3:
156
+ raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
157
+ if mask.ndim != 3:
158
+ raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
159
+
160
+ # Batch match
161
+ if mask.shape[0] != rgb.shape[0]:
162
+ if mask.shape[0] == 1 and rgb.shape[0] > 1:
163
+ mask = mask.expand(rgb.shape[0], -1, -1)
164
+ else:
165
+ raise ValueError("Batch mismatch between rgb and mask.")
166
+
167
+ # Size match
168
+ if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
169
+ raise ValueError(
170
+ f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
171
+ )
172
+
173
+ mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
174
+ alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) # [B,H,W,1]
175
+
176
+ rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) # [B,H,W,4]
177
+ return rgba
178
+
179
+
180
  # -------------------------------------------------------------------------------------
181
  # Core lazy loaders (checkpoint + controlnet), cached globally
182
  # -------------------------------------------------------------------------------------
 
244
 
245
  base = img_dir.resolve()
246
  rel = Path(asset_rel_path)
247
+
248
  if rel.is_absolute():
249
  raise ValueError("Absolute paths are not allowed for asset_image.")
250
 
251
  full = (base / rel).resolve()
252
+
253
+ # path traversal protection
254
  if base != full and base not in full.parents:
255
  raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
256
 
 
267
  Returns (IMAGE, MASK) in ComfyUI formats.
268
 
269
  Mask semantics: match ComfyUI core LoadImage:
270
+ - alpha is RGBA alpha channel normalized to [0..1]
271
+ - mask = 1 - alpha
272
  """
273
  p = _safe_asset_path(asset_rel_path)
274
 
275
  im = Image.open(p)
276
  im = ImageOps.exif_transpose(im)
277
 
 
278
  rgba = im.convert("RGBA")
279
  rgb = rgba.convert("RGB")
280
 
281
  rgb_arr = np.array(rgb).astype(np.float32) / 255.0 # [H,W,3]
282
  img_t = torch.from_numpy(rgb_arr)[None, ...]
283
 
284
+ alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0 # [H,W]
285
+ mask = 1.0 - alpha # Comfy MASK convention
 
 
 
 
286
 
287
  mask_t = torch.from_numpy(mask)[None, ...]
288
  return img_t, mask_t
289
 
290
 
291
  # -------------------------------------------------------------------------------------
292
+ # Salia_Depth (INLINED, no imports from other files)
293
  # -------------------------------------------------------------------------------------
294
 
 
295
  MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
296
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
297
 
 
303
 
304
  ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
305
 
306
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
307
+ _PIPE_LOCK = threading.Lock()
308
+
309
 
310
def _have_required_files() -> bool:
    """Return True when every required model file exists under MODEL_DIR."""
    # Iterating the dict directly yields its keys; `.keys()` is redundant.
    return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES)
 
462
  return t.unsqueeze(0) # [1,H,W,C]
463
 
464
 
 
 
 
465
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Return a depth-estimation pipeline for (model_source, device), creating
    and caching it on first use.

    The whole lookup/create/store sequence runs under _PIPE_LOCK so that
    concurrent node executions never build the same pipeline twice.

    Raises:
        RuntimeError: if the transformers `pipeline` import failed at module load.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    with _PIPE_LOCK:
        if cache_key in _PIPE_CACHE:
            return _PIPE_CACHE[cache_key]

        depth_pipe = pipeline(task="depth-estimation", model=model_source)
        # Best-effort device placement — some transformers versions manage
        # the device themselves and may reject direct assignment.
        try:
            depth_pipe.model = depth_pipe.model.to(device)
            depth_pipe.device = device
        except Exception:
            pass

        _PIPE_CACHE[cache_key] = depth_pipe
        return depth_pipe
483
 
484
 
485
  def get_depth_pipeline(device: torch.device):
 
551
 
552
  def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
553
  """
554
+ Internal callable version of Salia_Depth:
555
  input: IMAGE [B,H,W,3 or 4]
556
  output: IMAGE [B,H,W,3]
557
  """
 
558
  try:
559
  device = model_management.get_torch_device()
560
  except Exception:
561
  device = torch.device("cpu")
562
 
563
+ pipe_obj = None
 
564
  try:
565
+ pipe_obj = get_depth_pipeline(device)
566
  except Exception:
567
+ pipe_obj = None
568
 
569
+ if pipe_obj is None:
 
570
  return image
571
 
 
572
  if image.ndim == 3:
573
  image = image.unsqueeze(0)
574
 
 
580
 
581
  inp_u8 = comfy_tensor_to_u8(image[i])
582
 
 
583
  rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
584
  had_rgba = alpha_u8 is not None
585
 
 
586
  depth_rgb = depth_estimate_zoe_style(
587
+ pipe=pipe_obj,
588
  input_rgb_u8=rgb_for_depth,
589
  detect_resolution=int(resolution),
590
  upscale_method="INTER_CUBIC",
591
  )
592
 
 
593
  depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
594
 
 
595
  if had_rgba:
596
  if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
597
  try:
 
634
  if x < 0 or y < 0 or x + s > W or y + s > H:
635
  raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
636
 
 
637
  if b2 != B:
638
  if b2 == 1 and B > 1:
639
  overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
 
649
  comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
650
  out[:, y:y + s, x:x + s, 0:3] = comp_rgb
651
 
 
652
  if C == 4:
653
  base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
654
  comp_a = overlay_a + base_a * (1.0 - overlay_a)
 
733
  scheduler: str,
734
  denoise: float,
735
  ):
736
+ # Normalize input to [B,H,W,C]
737
  if image.ndim == 3:
738
  image = image.unsqueeze(0)
739
  if image.ndim != 4:
 
750
  up = int(upscale_factor)
751
  if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
752
  raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
753
+
754
  if s <= 0:
755
  raise ValueError("square_size must be > 0")
756
  if x < 0 or y < 0 or x + s > W or y + s > H:
 
759
  up_w = s * up
760
  up_h = s * up
761
 
762
+ # VAE/UNet path likes multiples of 8
763
  if (up_w % 8) != 0 or (up_h % 8) != 0:
764
  raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
765
 
 
768
  if end_p < start_p:
769
  start_p, end_p = end_p, start_p
770
 
771
+ # 1) Crop square
772
  crop = image[:, y:y + s, x:x + s, :]
773
  crop_rgb = crop[:, :, :, 0:3].contiguous()
774
 
775
+ # 2) Depth (inline Salia_Depth) then upscale with Lanczos
776
  depth_small = _salia_depth_execute(crop_rgb, resolution=s)
777
  depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
778
 
779
+ # 3) Upscale crop for VAE encode
780
  crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
781
 
782
+ # 4) Load asset mask and resize
783
  if asset_image == "<no pngs found>":
784
  raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
785
  _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
 
859
  latent_image=latent,
860
  )
861
 
862
+ # 10) VAE Decode -> RGB
863
  vae_dec = nodes.VAEDecode()
864
  vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
865
  (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
866
 
867
+ # 11) ✅ Manual "JoinImageWithAlpha"
868
+ rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)
 
 
 
 
 
869
 
870
  # 12) Downscale RGBA back to crop size
871
  rgba_square = _resize_image_lanczos(rgba_up, s, s)