saliacoel commited on
Commit
f559035
·
verified ·
1 Parent(s): d68f6df

Update AILab_SAM3Segment.py

Browse files
Files changed (1) hide show
  1. AILab_SAM3Segment.py +1317 -106
AILab_SAM3Segment.py CHANGED
@@ -1,18 +1,37 @@
 
 
 
 
 
 
 
1
  import os
2
  import sys
 
 
 
 
 
3
  from contextlib import nullcontext
4
  from pathlib import Path
 
5
 
6
  import numpy as np
7
  import torch
8
- from PIL import Image, ImageFilter
 
9
  from torch.hub import download_url_to_file
10
 
11
  import folder_paths
12
  import comfy.model_management
 
13
 
14
  from AILab_ImageMaskTools import pil2tensor, tensor2pil
15
 
 
 
 
 
16
  CURRENT_DIR = os.path.dirname(__file__)
17
  SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
18
  if SAM3_LOCAL_DIR not in sys.path:
@@ -26,7 +45,7 @@ from sam3.model_builder import build_sam3_image_model # noqa: E402
26
  from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402
27
 
28
  _DEFAULT_PT_ENTRY = {
29
- "model_url": "https://huggingface.co/saliacoel/x/resolve/main/sam3.pt",
30
  "filename": "sam3.pt",
31
  }
32
 
@@ -36,11 +55,9 @@ SAM3_MODELS = {
36
 
37
 
38
  def get_sam3_pt_models():
39
- """Return a dictionary containing the PT model definition."""
40
  entry = SAM3_MODELS.get("sam3")
41
  if entry and entry.get("filename", "").endswith(".pt"):
42
  return {"sam3": entry}
43
- # Fallback: upgrade any legacy entry to PT naming
44
  for key, value in SAM3_MODELS.items():
45
  if value.get("filename", "").endswith(".pt"):
46
  return {"sam3": value}
@@ -193,7 +210,6 @@ class SAM3Segment:
193
  return result_image, mask_tensor, mask_rgb
194
 
195
  def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
196
-
197
  if image.ndim == 3:
198
  image = image.unsqueeze(0)
199
 
@@ -233,97 +249,1311 @@ class SAM3Segment:
233
 
234
 
235
  # ======================================================================================
236
- # NEW FUSED NODE: Salia_ezpz_gated_Duo2 -> SAM3Segment (hardcoded) -> apply_segment_4
237
  # ======================================================================================
238
 
239
- def _fallback_list_asset_pngs():
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  """
241
- Best-effort dropdown helper for both Salia_ezpz_gated_Duo2 and apply_segment_4.
242
- Tries to find a nearby 'assets/images' directory by walking upwards from this file.
243
- Returns relative posix paths (supports subfolders). If none found, returns placeholder.
244
  """
245
  here = Path(__file__).resolve()
246
- images_dir = None
247
  for parent in [here.parent] + list(here.parents)[:12]:
248
- cand = parent / "assets" / "images"
249
- if cand.is_dir():
250
- images_dir = cand
251
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- if images_dir is None:
254
- return ["<no pngs found>"]
 
 
 
 
 
 
 
 
255
 
 
 
 
 
 
 
 
 
 
256
  files = []
257
- for p in images_dir.rglob("*.png"):
258
- if p.is_file():
259
- files.append(p.relative_to(images_dir).as_posix())
260
  files.sort()
261
- return files or ["<no pngs found>"]
262
 
263
 
264
- def _safe_get_choices_from_node(node_name: str, input_key: str):
265
- """
266
- Try to mirror the exact dropdown options of another loaded node.
267
- Returns None on failure.
268
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  try:
270
- import nodes # comfy core module where custom nodes are registered
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- node_cls = nodes.NODE_CLASS_MAPPINGS.get(node_name)
273
- if node_cls is None:
274
- return None
275
 
276
- in_types = node_cls.INPUT_TYPES()
277
- req = in_types.get("required", {})
278
- field = req.get(input_key)
279
 
280
- # field is typically like: (choices, config_dict)
281
- if isinstance(field, tuple) and len(field) > 0:
282
- choices = field[0]
283
- if isinstance(choices, (list, tuple)) and len(choices) > 0:
284
- return list(choices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  except Exception:
286
  return None
287
- return None
288
 
289
 
290
- class SAM3Segment_Salia:
291
- """
292
- Fused node pipeline:
293
-
294
- if trigger_string == "":
295
- return input image unchanged
296
-
297
- else:
298
- 1) Salia_ezpz_gated_Duo2(image)-> (image, image_cropped)
299
- 2) SAM3Segment(image_cropped, prompt=...) -> (seg_image, seg_mask, _)
300
- hardcoded:
301
- sam3_model="sam3"
302
- device="GPU"
303
- confidence_threshold=0.50
304
- mask_blur=0
305
- mask_offset=0
306
- invert_output=False
307
- unload_model=False
308
- background="Alpha"
309
- 3) apply_segment_4(mask=seg_mask, img=seg_image, canvas=input image, x=X_coord, y=Y_coord)
310
-
311
- Output: Final_Image
312
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  CATEGORY = "image/salia"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  RETURN_TYPES = ("IMAGE",)
316
  RETURN_NAMES = ("Final_Image",)
317
  FUNCTION = "run"
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  @classmethod
320
- def INPUT_TYPES(cls):
321
- # Pull dropdown choices from the other nodes (if available), else fallback.
322
- assets_salia = _safe_get_choices_from_node("Salia_ezpz_gated_Duo2", "asset_image") or _fallback_list_asset_pngs()
323
- assets_apply = _safe_get_choices_from_node("apply_segment_4", "image") or _fallback_list_asset_pngs()
324
 
325
- upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
 
 
 
 
 
 
 
 
 
 
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  return {
328
  "required": {
329
  "image": ("IMAGE",),
@@ -332,23 +1562,17 @@ class SAM3Segment_Salia:
332
  "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
333
  "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
334
 
335
- # 3 prompts total
336
  "positive_prompt": ("STRING", {"default": "", "multiline": True}),
337
  "negative_prompt": ("STRING", {"default": "", "multiline": True}),
338
  "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "SAM3 prompt"}),
339
 
340
- # two different asset selections:
341
- # - asset_image => Salia_ezpz_gated_Duo2
342
- # - apply_asset_image => apply_segment_4
343
- "asset_image": (assets_salia, {}),
344
- "apply_asset_image": (assets_apply, {}),
345
 
346
- # Salia_ezpz_gated_Duo2 pass-1 inputs
347
  "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
348
  "upscale_factor_1": (upscale_choices, {"default": "4"}),
349
  "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
350
 
351
- # Salia_ezpz_gated_Duo2 pass-2 inputs
352
  "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
353
  "upscale_factor_2": (upscale_choices, {"default": "4"}),
354
  "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
@@ -356,22 +1580,9 @@ class SAM3Segment_Salia:
356
  }
357
 
358
  def __init__(self):
359
- # Reuse SAM3Segment instance to benefit from its processor_cache.
360
  self._sam3 = SAM3Segment()
361
- self._salia_node = None
362
- self._apply_node = None
363
-
364
- @staticmethod
365
- def _require_node_instance(node_name: str):
366
- import nodes # comfy core module where custom nodes are registered
367
-
368
- node_cls = nodes.NODE_CLASS_MAPPINGS.get(node_name)
369
- if node_cls is None:
370
- raise RuntimeError(
371
- f"Required node '{node_name}' was not found in nodes.NODE_CLASS_MAPPINGS. "
372
- f"Make sure its custom-node file is installed and loaded."
373
- )
374
- return node_cls()
375
 
376
  def run(
377
  self,
@@ -391,19 +1602,12 @@ class SAM3Segment_Salia:
391
  upscale_factor_2="4",
392
  denoise_2=0.35,
393
  ):
394
- # Hard bypass: if trigger_string is exactly empty, skip ALL processing.
395
  if trigger_string == "":
396
  return (image,)
397
 
398
- # Lazily instantiate dependent nodes.
399
- if self._salia_node is None:
400
- self._salia_node = self._require_node_instance("Salia_ezpz_gated_Duo2")
401
- if self._apply_node is None:
402
- self._apply_node = self._require_node_instance("apply_segment_4")
403
-
404
- # 1) Run Salia_ezpz_gated_Duo2 (pre-node)
405
- salia_fn = getattr(self._salia_node, getattr(self._salia_node, "FUNCTION", "run"))
406
- out_image, image_cropped = salia_fn(
407
  image=image,
408
  trigger_string=trigger_string,
409
  X_coord=int(X_coord),
@@ -419,7 +1623,7 @@ class SAM3Segment_Salia:
419
  denoise_2=float(denoise_2),
420
  )
421
 
422
- # 2) Run SAM3Segment (center node) on the CROPPED image, with hardcoded settings.
423
  seg_image, seg_mask, _mask_image = self._sam3.segment(
424
  image=image_cropped,
425
  prompt=str(prompt),
@@ -434,9 +1638,8 @@ class SAM3Segment_Salia:
434
  background_color="#222222",
435
  )
436
 
437
- # 3) Run apply_segment_4 (post-node) on the ORIGINAL canvas image.
438
- apply_fn = getattr(self._apply_node, getattr(self._apply_node, "FUNCTION", "run"))
439
- (final_image,) = apply_fn(
440
  mask=seg_mask,
441
  image=str(apply_asset_image),
442
  img=seg_image,
@@ -448,12 +1651,20 @@ class SAM3Segment_Salia:
448
  return (final_image,)
449
 
450
 
 
 
 
 
451
  NODE_CLASS_MAPPINGS = {
452
  "SAM3Segment": SAM3Segment,
 
 
453
  "SAM3Segment_Salia": SAM3Segment_Salia,
454
  }
455
 
456
  NODE_DISPLAY_NAME_MAPPINGS = {
457
  "SAM3Segment": "SAM3 Segmentation (RMBG)",
458
- "SAM3Segment_Salia": "SAM3Segment_Salia (EZPZ + SAM3 + apply_segment_4)",
459
- }
 
 
 
1
+ # AILab_SAM3Segment.py
2
+ # Integrated standalone nodes:
3
+ # - SAM3Segment
4
+ # - Salia_ezpz_gated_Duo2
5
+ # - apply_segment_4
6
+ # - SAM3Segment_Salia (fused)
7
+
8
  import os
9
  import sys
10
+ import hashlib
11
+ import shutil
12
+ import threading
13
+ import urllib.request
14
+ import heapq
15
  from contextlib import nullcontext
16
  from pathlib import Path
17
+ from typing import Any, Dict, Tuple, Optional, List
18
 
19
  import numpy as np
20
  import torch
21
+ import torch.nn.functional as F
22
+ from PIL import Image, ImageFilter, ImageOps
23
  from torch.hub import download_url_to_file
24
 
25
  import folder_paths
26
  import comfy.model_management
27
+ import comfy.model_management as model_management
28
 
29
  from AILab_ImageMaskTools import pil2tensor, tensor2pil
30
 
31
+ # ======================================================================================
32
+ # SAM3Segment (original, with syntax fix)
33
+ # ======================================================================================
34
+
35
  CURRENT_DIR = os.path.dirname(__file__)
36
  SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
37
  if SAM3_LOCAL_DIR not in sys.path:
 
45
  from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402
46
 
47
  _DEFAULT_PT_ENTRY = {
48
+ "model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
49
  "filename": "sam3.pt",
50
  }
51
 
 
55
 
56
 
57
  def get_sam3_pt_models():
 
58
  entry = SAM3_MODELS.get("sam3")
59
  if entry and entry.get("filename", "").endswith(".pt"):
60
  return {"sam3": entry}
 
61
  for key, value in SAM3_MODELS.items():
62
  if value.get("filename", "").endswith(".pt"):
63
  return {"sam3": value}
 
210
  return result_image, mask_tensor, mask_rgb
211
 
212
  def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
 
213
  if image.ndim == 3:
214
  image = image.unsqueeze(0)
215
 
 
249
 
250
 
251
  # ======================================================================================
252
+ # Salia_ezpz_gated_Duo2 (standalone)
253
  # ======================================================================================
254
 
255
+ # transformers is required for depth-estimation pipeline
256
+ try:
257
+ from transformers import pipeline
258
+ except Exception as e:
259
+ pipeline = None
260
+ _TRANSFORMERS_IMPORT_ERROR = e
261
+
262
+ _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
263
+ _CN_CACHE: Dict[str, Any] = {}
264
+ _CKPT_LOCK = threading.Lock()
265
+ _CN_LOCK = threading.Lock()
266
+
267
+
268
+ def _find_plugin_root() -> Path:
269
  """
270
+ Walk upwards from this file until we find an 'assets' folder.
271
+ If not found, fall back to this file's directory.
 
272
  """
273
  here = Path(__file__).resolve()
 
274
  for parent in [here.parent] + list(here.parents)[:12]:
275
+ if (parent / "assets").is_dir():
276
+ return parent
277
+ return here.parent
278
+
279
+
280
+ PLUGIN_ROOT = _find_plugin_root()
281
+
282
+
283
def _pil_lanczos():
    """Return the LANCZOS resampling constant across old and new Pillow APIs."""
    resampling = getattr(Image, "Resampling", None)
    if resampling is not None:
        return resampling.LANCZOS
    return Image.LANCZOS
287
+
288
+
289
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI IMAGE tensor ([B,H,W,C] or [H,W,C], values 0..1) to PIL."""
    frame = img[0] if img.ndim == 4 else img
    frame = frame.detach().cpu().float().clamp(0, 1)
    pixels = (frame.numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if pixels.shape[-1] == 4 else "RGB"
    return Image.fromarray(pixels, mode=mode)
297
+
298
+
299
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ComfyUI IMAGE tensor [1,H,W,C] in 0..1."""
    if pil.mode not in ("RGB", "RGBA"):
        target = "RGBA" if "A" in pil.getbands() else "RGB"
        pil = pil.convert(target)
    pixels = np.array(pil).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
305
+
306
+
307
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI MASK tensor ([B,H,W] or [H,W], 0..1) to an 8-bit 'L' image."""
    frame = mask[0] if mask.ndim == 3 else mask
    frame = frame.detach().cpu().float().clamp(0, 1)
    gray = (frame.numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(gray, mode="L")
313
+
314
+
315
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """Convert a grayscale PIL image to a ComfyUI MASK tensor [1,H,W] in 0..1."""
    gray = pil_l if pil_l.mode == "L" else pil_l.convert("L")
    values = np.array(gray).astype(np.float32) / 255.0
    return torch.from_numpy(values).unsqueeze(0)
321
+
322
+
323
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """LANCZOS-resize every image in a [B,H,W,C] batch to (w, h) via PIL."""
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    size = (int(w), int(h))
    resample = _pil_lanczos()
    resized = [
        _pil_to_image_tensor(
            _image_tensor_to_pil(frame.unsqueeze(0)).resize(size, resample=resample)
        )
        for frame in img
    ]
    return torch.cat(resized, dim=0)
332
+
333
+
334
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """LANCZOS-resize every mask in a [B,H,W] batch to (w, h) via PIL."""
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    size = (int(w), int(h))
    resample = _pil_lanczos()
    resized = [
        _pil_to_mask_tensor(
            _mask_tensor_to_pil(frame.unsqueeze(0)).resize(size, resample=resample)
        )
        for frame in mask
    ]
    return torch.cat(resized, dim=0)
343
+
344
+
345
+ def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
346
+ if rgb.ndim == 3:
347
+ rgb = rgb.unsqueeze(0)
348
+ if mask.ndim == 2:
349
+ mask = mask.unsqueeze(0)
350
+
351
+ if rgb.ndim != 4 or rgb.shape[-1] != 3:
352
+ raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
353
+ if mask.ndim != 3:
354
+ raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
355
+
356
+ if mask.shape[0] != rgb.shape[0]:
357
+ if mask.shape[0] == 1 and rgb.shape[0] > 1:
358
+ mask = mask.expand(rgb.shape[0], -1, -1)
359
+ else:
360
+ raise ValueError("Batch mismatch between rgb and mask.")
361
+
362
+ if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
363
+ raise ValueError(
364
+ f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
365
+ )
366
+
367
+ mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
368
+ alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1)
369
+ rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1)
370
+ return rgba
371
+
372
+
373
def _load_checkpoint_cached(ckpt_name: str):
    """
    Load a checkpoint via ComfyUI's CheckpointLoaderSimple, memoized per name.

    Returns the (model, clip, vae) triple. The lock is held across the whole
    load so concurrent callers never load the same checkpoint twice.
    """
    with _CKPT_LOCK:
        if ckpt_name in _CKPT_CACHE:
            return _CKPT_CACHE[ckpt_name]
        import nodes
        loader = nodes.CheckpointLoaderSimple()
        # Invoke the node through its declared FUNCTION attribute, as ComfyUI does.
        fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = fn(ckpt_name=ckpt_name)
        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return model, clip, vae
383
+
384
 
385
def _load_controlnet_cached(control_net_name: str):
    """
    Load a ControlNet via ComfyUI's ControlNetLoader, memoized per name.

    The lock is held across the whole load so concurrent callers never load
    the same ControlNet twice.
    """
    with _CN_LOCK:
        if control_net_name in _CN_CACHE:
            return _CN_CACHE[control_net_name]
        import nodes
        loader = nodes.ControlNetLoader()
        fn = getattr(loader, loader.FUNCTION)
        # The loader node returns a one-element tuple; unpack it.
        (cn,) = fn(control_net_name=control_net_name)
        _CN_CACHE[control_net_name] = cn
        return cn
395
 
396
+
397
def _assets_images_dir() -> Path:
    """Location of the plugin's bundled PNG assets: <plugin root>/assets/images."""
    return PLUGIN_ROOT.joinpath("assets", "images")
399
+
400
+
401
def _list_asset_pngs() -> list:
    """List PNGs under assets/images as sorted, posix-style relative paths."""
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    # rglob("*") + suffix check keeps the match case-insensitive (.PNG too).
    return sorted(
        entry.relative_to(img_dir).as_posix()
        for entry in img_dir.rglob("*")
        if entry.is_file() and entry.suffix.lower() == ".png"
    )
411
 
412
 
413
def _safe_asset_path(asset_rel_path: str) -> Path:
    """
    Resolve a dropdown-selected asset path to a real PNG under assets/images.

    Raises:
        FileNotFoundError: assets/images is missing, or the PNG does not exist.
        ValueError: absolute path, path-traversal attempt, or non-PNG suffix.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    base = img_dir.resolve()
    rel = Path(asset_rel_path)

    # Only paths relative to assets/images are valid selections.
    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    full = (base / rel).resolve()

    # Traversal guard: after resolution the target must still live under base.
    if base != full and base not in full.parents:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
435
+
436
+
437
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an asset PNG as (IMAGE [1,H,W,3], MASK [1,H,W]) tensors in 0..1.

    The mask follows the ComfyUI convention (inverted alpha): fully opaque
    pixels yield mask value 0, fully transparent pixels yield 1.
    """
    p = _safe_asset_path(asset_rel_path)

    im = Image.open(p)
    # Apply the EXIF orientation tag so pixels match how viewers show the file.
    im = ImageOps.exif_transpose(im)

    rgba = im.convert("RGBA")
    rgb = rgba.convert("RGB")

    rgb_arr = np.array(rgb).astype(np.float32) / 255.0
    img_t = torch.from_numpy(rgb_arr)[None, ...]

    alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
    # ComfyUI masks are "1 = masked", i.e. the inverse of the alpha channel.
    mask = 1.0 - alpha

    mask_t = torch.from_numpy(mask)[None, ...]
    return img_t, mask_t
454
+
455
+
456
+ MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
457
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
458
+
459
+ REQUIRED_FILES = {
460
+ "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
461
+ "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
462
+ "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
463
+ }
464
+
465
+ ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
466
+
467
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {}
468
+ _PIPE_LOCK = threading.Lock()
469
+
470
+
471
def _have_required_files() -> bool:
    """True when every depth-model file listed in REQUIRED_FILES exists locally."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True
473
+
474
+
475
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """Download *url* to *dst* atomically: write to a .tmp sibling, then rename."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    # Drop any stale partial download from a previous attempt (best effort).
    if tmp.exists():
        try:
            tmp.unlink()
        except Exception:
            pass

    request = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(request, timeout=timeout) as response, open(tmp, "wb") as out:
        shutil.copyfileobj(response, out)

    tmp.replace(dst)
490
+
491
+
492
def ensure_local_model_files() -> bool:
    """Download any missing depth-model files; return True when all are present."""
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            target = MODEL_DIR / fname
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        # Best effort: a network failure simply means "no local model".
        return False
504
+
505
+
506
def HWC3(x: np.ndarray) -> np.ndarray:
    """
    Normalize a uint8 image to 3-channel HWC.

    Grayscale (2D or 1-channel) is replicated across channels; RGBA is
    composited over a white background. 3-channel input is returned as-is.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    # RGBA: flatten onto white using the alpha channel.
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = color * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
522
+
523
+
524
def pad64(x: int) -> int:
    """Number of extra pixels needed to round *x* up to the next multiple of 64."""
    rounded_up = np.ceil(float(x) / 64.0) * 64
    return int(rounded_up - x)
526
+
527
+
528
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return an independent, C-contiguous copy of *x* (defensive against views)."""
    contiguous = np.ascontiguousarray(x.copy())
    return contiguous.copy()
530
+
531
+
532
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """
    Scale an image so its *shorter* side equals `resolution`, then pad height
    and width up to multiples of 64.

    Returns (padded_image, remove_pad) where remove_pad(x) crops x back to the
    pre-padding (scaled) size. If `resolution` <= 0 the image is returned
    unscaled with an identity remove_pad. Uses cv2 when importable, otherwise
    falls back to PIL resampling.
    """
    # cv2 is optional; fall back to PIL when it is not installed.
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No rescaling requested: identity crop function.
        return img, (lambda x: x)

    # Scale factor that brings the shorter side to `resolution`.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # Upscaling uses the requested method; downscaling uses INTER_AREA.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        pil = Image.fromarray(img)
        # PIL fallback: BICUBIC when enlarging, LANCZOS when shrinking.
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right so both dimensions become multiples of 64.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the scaled (pre-padding) size.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
 
 
579
 
 
 
 
580
 
581
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    Pad a uint8 image (no rescaling) so H and W become multiples of 64.

    Returns (padded_image, remove_pad) where remove_pad(x) crops x back to
    the original size.
    """
    normalized = HWC3(img_u8)
    orig_h, orig_w, _ = normalized.shape
    extra_h, extra_w = pad64(orig_h), pad64(orig_w)
    padded = np.pad(normalized, [[0, extra_h], [0, extra_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:orig_h, :orig_w, ...])

    return safer_memory(padded), remove_pad
591
+
592
+
593
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Flatten an RGBA image onto white; return (rgb, alpha_channel).

    For non-RGBA input returns (HWC3(image), None) so callers can detect
    whether an alpha channel was present.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        return HWC3(inp_u8), None
    rgba = inp_u8.astype(np.uint8)
    coverage = rgba[:, :, 3:4].astype(np.float32) / 255.0
    color = rgba[:, :, 0:3].astype(np.float32)
    over_white = (color * coverage + 255.0 * (1.0 - coverage)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
602
+
603
+
604
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply a depth map by its alpha so transparent regions become black."""
    rgb = HWC3(depth_rgb_u8)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb.astype(np.float32) * weight).clip(0, 255).astype(np.uint8)
609
+
610
+
611
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a 0..1 IMAGE tensor ([B,H,W,C] or [H,W,C]) to a uint8 HWC array."""
    frame = img[0] if img.ndim == 4 else img
    floats = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (floats * 255.0).round().astype(np.uint8)
617
+
618
+
619
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image array to a [1,H,W,3] float tensor in 0..1."""
    rgb = HWC3(img_u8)
    return torch.from_numpy(rgb.astype(np.float32) / 255.0).unsqueeze(0)
623
+
624
+
625
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Build (or fetch from cache) a transformers depth-estimation pipeline.

    Raises RuntimeError when transformers failed to import. The cache key is
    (model_source, device) so the same model may live on multiple devices.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    key = (model_source, str(device))
    with _PIPE_LOCK:
        if key in _PIPE_CACHE:
            return _PIPE_CACHE[key]

        p = pipeline(task="depth-estimation", model=model_source)
        # Best effort: move the model to the requested device. Some pipeline
        # versions don't expose these attributes, so failures are ignored.
        try:
            p.model = p.model.to(device)
            p.device = device
        except Exception:
            pass

        _PIPE_CACHE[key] = p
        return p
643
+
644
+
645
def get_depth_pipeline(device: torch.device):
    """
    Return a depth-estimation pipeline: local model first, then the ZoeDepth
    fallback repo. Returns None when neither source can be loaded.
    """
    sources = []
    if ensure_local_model_files():
        sources.append(str(MODEL_DIR))
    sources.append(ZOE_FALLBACK_REPO_ID)
    for source in sources:
        try:
            return _try_load_pipeline(source, device)
        except Exception:
            continue
    return None
 
655
 
656
 
657
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """
    Run a depth-estimation pipeline and post-process the result ZoeDepth-style.

    Args:
        pipe: a transformers depth-estimation pipeline.
        input_rgb_u8: uint8 RGB image (HWC).
        detect_resolution: min-side target for preprocessing; -1 means pad to a
            multiple of 64 without rescaling.
        upscale_method: cv2 interpolation name used when rescaling.

    Returns:
        uint8 RGB depth map (near = bright), cropped back to the working size.
    """
    # Preprocess: either pad-only (keep scale) or min-side resize + pad to /64.
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # The pipeline may hand back a PIL image or an array-like; np.array copes
    # with both, so no type branching is needed (the original had two
    # identical isinstance branches here).
    depth_array = np.array(depth, dtype=np.float32)

    # Robust normalization: stretch the 2nd..85th percentile range to 0..1.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        denom = 1e-6  # guard against a constant depth map
    depth_array = depth_array / denom

    # Invert so near objects are bright, then quantize to uint8.
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    return detected_map
699
+
700
+
701
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Resize a depth map back to its original (w0, h0); cv2 if available, else PIL."""
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        pil = Image.fromarray(depth_rgb_u8)
        pil = pil.resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)
710
+
711
+
712
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """
    Best-effort depth-map conversion of an IMAGE batch.

    Each image is run through the depth pipeline; on any per-image failure the
    original image is passed through unchanged, and if no pipeline can be
    loaded at all the whole batch is returned untouched.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None

    # No usable pipeline: pass the input through untouched.
    if pipe_obj is None:
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            # Original size, so the depth map can be resized back.
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA inputs are flattened onto white for estimation; the alpha
            # channel is kept to re-mask the depth map afterwards.
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            if had_rgba:
                # Bring the alpha back to the original size if needed
                # (cv2 preferred, PIL fallback), then black out transparency.
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # Per-image best effort: keep the original frame on failure.
            outs.append(image[i].unsqueeze(0))

    return torch.cat(outs, dim=0)
767
+
768
+
769
+ def _salia_alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
770
+ if base.ndim != 4 or overlay_rgba.ndim != 4:
771
+ raise ValueError("base and overlay must be [B,H,W,C].")
772
+
773
+ B, H, W, C = base.shape
774
+ b2, sH, sW, c2 = overlay_rgba.shape
775
+ if c2 != 4:
776
+ raise ValueError("overlay_rgba must have 4 channels (RGBA).")
777
+ if sH != sW:
778
+ raise ValueError("overlay must be square.")
779
+ s = sH
780
+
781
+ if x < 0 or y < 0 or x + s > W or y + s > H:
782
+ raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
783
 
784
+ if b2 != B:
785
+ if b2 == 1 and B > 1:
786
+ overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
787
+ else:
788
+ raise ValueError("Batch mismatch between base and overlay.")
789
+
790
+ out = base.clone()
791
+
792
+ overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
793
+ overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
794
+
795
+ base_rgb = out[:, y:y + s, x:x + s, 0:3]
796
+ comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
797
+ out[:, y:y + s, x:x + s, 0:3] = comp_rgb
798
+
799
+ if C == 4:
800
+ base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
801
+ comp_a = overlay_a + base_a * (1.0 - overlay_a)
802
+ out[:, y:y + s, x:x + s, 3:4] = comp_a
803
+
804
+ return out.clamp(0, 1)
805
+
806
+
807
# --- Hardcoded assets and sampler settings for Salia_ezpz_gated_Duo2 ---
# The node deliberately pins its checkpoint, ControlNet and sampler
# configuration; only crop/upscale/denoise are exposed as node inputs.
_HARDCODED_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"  # expected under models/checkpoints
_HARDCODED_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"  # expected under models/controlnet
_HARDCODED_CN_START = 0.00  # ControlNet start_percent (active from the first step)
_HARDCODED_CN_END = 1.00  # ControlNet end_percent (active through the last step)

# Sampler settings used for the first refinement pass.
_PASS1_SAMPLER_NAME = "dpmpp_2m_sde_heun_gpu"
_PASS1_SCHEDULER = "karras"
_PASS1_STEPS = 29
_PASS1_CFG = 2.6
_PASS1_CONTROLNET_STRENGTH = 0.33

# Sampler settings used for the second refinement pass.
_PASS2_SAMPLER_NAME = "res_multistep_ancestral_cfg_pp"
_PASS2_SCHEDULER = "karras"
_PASS2_STEPS = 30
_PASS2_CFG = 1.7
_PASS2_CONTROLNET_STRENGTH = 0.5
823
+
824
+
825
class Salia_ezpz_gated_Duo2:
    """Two-pass crop -> depth-ControlNet -> KSampler refinement node.

    A square region at (X_coord, Y_coord) is cropped, upscaled, refined with
    a depth-conditioned ControlNet sampling pass, masked by a bundled asset
    PNG and pasted back — twice, with independent size/upscale/denoise knobs.
    The checkpoint, ControlNet and sampler settings are hardcoded at module
    level (_HARDCODED_* / _PASS1_* / _PASS2_*).

    Gating: an empty ``trigger_string`` bypasses all processing and returns
    the input image unchanged (plus its square_size_2 crop).

    Returns (image, image_cropped): the full composited image and the final
    square crop at (X_coord, Y_coord) of size square_size_2.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("image", "image_cropped")
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare ComfyUI inputs; asset dropdown lists bundled PNGs."""
        # Placeholder entry keeps the widget valid when no assets exist.
        assets = _list_asset_pngs() or ["<no pngs found>"]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
                "asset_image": (assets, {}),
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        positive_prompt: str = "",
        negative_prompt: str = "",
        asset_image: str = "",
        square_size_1: int = 384,
        upscale_factor_1: str = "4",
        denoise_1: float = 0.35,
        square_size_2: int = 384,
        upscale_factor_2: str = "4",
        denoise_2: float = 0.35,
    ):
        """Run the gated two-pass refinement; see class docstring.

        Raises ValueError on bad shapes/bounds and FileNotFoundError when the
        hardcoded checkpoint/ControlNet or the selected asset is missing.
        """
        # Normalize a single image [H,W,C] to a batch [1,H,W,C].
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")

        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

        x = int(X_coord)
        y = int(Y_coord)
        s1 = int(square_size_1)
        s2 = int(square_size_2)

        def _validate_square_bounds(s: int, label: str):
            # The square at (x, y) must lie fully inside the input image.
            if s <= 0:
                raise ValueError(f"{label}: square_size must be > 0")
            if x < 0 or y < 0 or x + s > W or y + s > H:
                raise ValueError(f"{label}: out of bounds. image={W}x{H}, rect at ({x},{y}) size={s}")

        def _validate_upscale(up: int, s: int, label: str):
            # The upscaled crop goes through the VAE, which needs /8 sizes.
            if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
                raise ValueError(f"{label}: upscale_factor must be one of 1,2,4,6,8,10,12,14,16")
            if ((s * up) % 8) != 0:
                raise ValueError(f"{label}: square_size * upscale_factor must be divisible by 8 (VAE requirement).")

        def _crop_square(img: torch.Tensor, s: int) -> torch.Tensor:
            # Square crop at the node's (x, y) anchor.
            return img[:, y:y + s, x:x + s, :]

        # The final crop is returned even on bypass, so validate it up front.
        _validate_square_bounds(s2, "final crop (square_size_2)")

        # Gate: empty trigger -> pass the input through untouched.
        if trigger_string == "":
            out2 = image
            cropped = _crop_square(out2, s2)
            return (out2, cropped)

        _validate_square_bounds(s1, "pass1 (square_size_1)")
        _validate_square_bounds(s2, "pass2 (square_size_2)")

        up1 = int(upscale_factor_1)
        up2 = int(upscale_factor_2)
        _validate_upscale(up1, s1, "pass1")
        _validate_upscale(up2, s2, "pass2")

        # Clamp denoise into the valid [0, 1] sampling range.
        d1 = float(max(0.0, min(1.0, denoise_1)))
        d2 = float(max(0.0, min(1.0, denoise_2)))

        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        # Only the asset's mask is used (as the paste-back alpha).
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)

        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")

        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask vs input image batch.")

        # Imported lazily so this module can load outside a full ComfyUI runtime.
        import nodes

        try:
            model, clip, vae = _load_checkpoint_cached(_HARDCODED_CKPT_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("checkpoints") or []
            raise FileNotFoundError(
                f"Hardcoded ckpt not found: '{_HARDCODED_CKPT_NAME}'. "
                f"Put it in models/checkpoints. Available (first 50): {available[:50]}"
            ) from e

        try:
            controlnet = _load_controlnet_cached(_HARDCODED_CONTROLNET_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("controlnet") or []
            raise FileNotFoundError(
                f"Hardcoded controlnet not found: '{_HARDCODED_CONTROLNET_NAME}'. "
                f"Put it in models/controlnet. Available (first 50): {available[:50]}"
            ) from e

        # Drive the stock ComfyUI nodes through their declared FUNCTION entry
        # points; prompts are encoded once and reused by both passes.
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)

        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)
        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)

        def _run_pass(
            pass_index: int,
            in_image: torch.Tensor,
            s: int,
            up: int,
            denoise_v: float,
            steps_v: int,
            cfg_v: float,
            sampler_v: str,
            scheduler_v: str,
            controlnet_strength_v: float,
        ) -> torch.Tensor:
            """One refinement pass: crop, upscale, depth-guide, sample, paste back."""
            up_w = s * up
            up_h = s * up

            crop = in_image[:, y:y + s, x:x + s, :]
            crop_rgb = crop[:, :, :, 0:3].contiguous()

            # Depth is estimated at crop resolution, then upscaled alongside it.
            depth_small = _salia_depth_execute(crop_rgb, resolution=s)
            depth_up = _resize_image_lanczos(depth_small, up_w, up_h)

            crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)

            asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)

            pos_cn, neg_cn = cn_fn(
                strength=float(controlnet_strength_v),
                start_percent=float(_HARDCODED_CN_START),
                end_percent=float(_HARDCODED_CN_END),
                positive=pos_cond,
                negative=neg_cond,
                control_net=controlnet,
                image=depth_up,
                vae=vae,
            )

            (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)

            # Deterministic seed derived from every pass-relevant setting, so
            # identical inputs reproduce identical samples.
            seed_material = (
                f"{_HARDCODED_CKPT_NAME}|{_HARDCODED_CONTROLNET_NAME}|{asset_image}|"
                f"pass={pass_index}|x={x}|y={y}|s={s}|up={up}|"
                f"steps={steps_v}|cfg={cfg_v}|sampler={sampler_v}|scheduler={scheduler_v}|denoise={denoise_v}|"
                f"cn_strength={controlnet_strength_v}|"
                f"{positive_prompt}|{negative_prompt}"
            ).encode("utf-8", errors="ignore")
            seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)

            (sampled_latent,) = k_fn(
                seed=seed64,
                steps=int(steps_v),
                cfg=float(cfg_v),
                sampler_name=str(sampler_v),
                scheduler=str(scheduler_v),
                denoise=float(denoise_v),
                model=model,
                positive=pos_cn,
                negative=neg_cn,
                latent_image=latent,
            )

            (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)

            # Mask the refined crop with the asset alpha, shrink it back to the
            # original square size and composite it over the source image.
            rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)
            rgba_square = _resize_image_lanczos(rgba_up, s, s)
            out = _salia_alpha_over_region(in_image, rgba_square, x=x, y=y)
            return out

        out1 = _run_pass(
            pass_index=1,
            in_image=image,
            s=s1,
            up=up1,
            denoise_v=d1,
            steps_v=_PASS1_STEPS,
            cfg_v=_PASS1_CFG,
            sampler_v=_PASS1_SAMPLER_NAME,
            scheduler_v=_PASS1_SCHEDULER,
            controlnet_strength_v=_PASS1_CONTROLNET_STRENGTH,
        )

        # Pass 2 refines the result of pass 1, not the original image.
        out2 = _run_pass(
            pass_index=2,
            in_image=out1,
            s=s2,
            up=up2,
            denoise_v=d2,
            steps_v=_PASS2_STEPS,
            cfg_v=_PASS2_CFG,
            sampler_v=_PASS2_SAMPLER_NAME,
            scheduler_v=_PASS2_SCHEDULER,
            controlnet_strength_v=_PASS2_CONTROLNET_STRENGTH,
        )

        cropped = out2[:, y:y + s2, x:x + s2, :]
        return (out2, cropped)
1062
+
1063
+
1064
+ # ======================================================================================
1065
+ # apply_segment_4 (standalone, embedded) - rename internal alpha paste helper to avoid clash
1066
+ # ======================================================================================
1067
+
1068
+ # Expects: <this_file_dir>/assets/images/*.png
1069
# Expects: <this_file_dir>/assets/images/*.png
_AP4_ASSETS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", "images")


def ap4_list_pngs() -> List[str]:
    """Recursively list PNGs under the assets dir as sorted, '/'-separated relative paths."""
    if not os.path.isdir(_AP4_ASSETS_DIR):
        return []
    found: List[str] = []
    for dirpath, _dirs, names in os.walk(_AP4_ASSETS_DIR):
        for name in names:
            if not name.lower().endswith(".png"):
                continue
            full_path = os.path.join(dirpath, name)
            if os.path.isfile(full_path):
                rel_path = os.path.relpath(full_path, _AP4_ASSETS_DIR)
                found.append(rel_path.replace("\\", "/"))
    return sorted(found)
1084
+
1085
+
1086
def ap4_safe_path(filename: str) -> str:
    """Resolve *filename* inside the assets dir, rejecting path traversal.

    Returns the real (symlink-resolved) absolute path; raises ValueError if
    the resolved path escapes the assets directory.
    """
    resolved = os.path.realpath(os.path.join(_AP4_ASSETS_DIR, filename))
    assets_root = os.path.realpath(_AP4_ASSETS_DIR)
    inside = resolved == assets_root or resolved.startswith(assets_root + os.sep)
    if not inside:
        raise ValueError("Unsafe path (path traversal detected).")
    return resolved
1093
+
1094
+
1095
def ap4_file_hash(filename: str) -> str:
    """SHA-256 hex digest of an asset file, streamed in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(ap4_safe_path(filename), "rb") as fh:
        while True:
            chunk = fh.read(1024 * 1024)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
1102
+
1103
+
1104
def ap4_load_image_from_assets(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset PNG as a ComfyUI (IMAGE, MASK) pair.

    Returns (image, mask): image is [1,H,W,3] float32 in [0,1]; mask is
    [1,H,W] and is the INVERTED alpha/luminance (ComfyUI mask convention:
    1.0 where the image is transparent).
    """
    path = ap4_safe_path(filename)
    i = Image.open(path)
    # Honor EXIF orientation before any pixel work.
    i = ImageOps.exif_transpose(i)

    # 32-bit integer mode: scale down so convert("RGB") lands in 0..255.
    # NOTE(review): mirrors ComfyUI's LoadImage handling of mode "I" — assumes
    # 16-bit source values; confirm for other bit depths.
    if i.mode == "I":
        i = i.point(lambda px: px * (1 / 255))

    rgb = i.convert("RGB")
    rgb_np = np.array(rgb).astype(np.float32) / 255.0
    image = torch.from_numpy(rgb_np)[None, ...]

    bands = i.getbands()
    if "A" in bands:
        # Real alpha channel available: use it directly.
        a = np.array(i.getchannel("A")).astype(np.float32) / 255.0
        alpha = torch.from_numpy(a)
    else:
        # No alpha: fall back to luminance as a pseudo-alpha.
        l = np.array(i.convert("L")).astype(np.float32) / 255.0
        alpha = torch.from_numpy(l)

    # Invert: ComfyUI masks mark transparent regions with 1.0.
    mask = 1.0 - alpha
    mask = mask.clamp(0.0, 1.0).unsqueeze(0)
    return image, mask
1127
+
1128
+
1129
def ap4_as_image(img: torch.Tensor) -> torch.Tensor:
    """Validate a ComfyUI IMAGE tensor ([B,H,W,C] with C in {3,4}); return it unchanged."""
    if not isinstance(img, torch.Tensor):
        raise TypeError("IMAGE must be a torch.Tensor")
    shape = tuple(img.shape)
    if len(shape) != 4:
        raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {shape}")
    channels = shape[-1]
    if channels not in (3, 4):
        raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={channels}")
    return img
1137
+
1138
+
1139
def ap4_as_mask(mask: torch.Tensor) -> torch.Tensor:
    """Validate a ComfyUI MASK tensor, promoting [H,W] to [1,H,W]."""
    if not isinstance(mask, torch.Tensor):
        raise TypeError("MASK must be a torch.Tensor")
    rank = mask.dim()
    if rank == 2:
        return mask.unsqueeze(0)
    if rank == 3:
        return mask
    raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
1147
+
1148
+
1149
def ap4_ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return *img* as RGBA, appending a fully-opaque alpha channel when it is RGB."""
    img = ap4_as_image(img)
    if img.shape[-1] != 4:
        b, h, w, _ = img.shape
        opaque = torch.ones((b, h, w, 1), device=img.device, dtype=img.dtype)
        img = torch.cat([img, opaque], dim=-1)
    return img
1156
+
1157
+
1158
def ap4_alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Composite *overlay* over *canvas* at (x, y) with premultiplied-alpha "over".

    Unlike the square paste helper, the overlay may partially (or entirely)
    hang off the canvas — only the intersecting region is blended. Batch-1
    inputs are broadcast against the other operand. Returns a new tensor
    with the canvas' channel count (RGB canvases stay RGB).
    """
    overlay = ap4_as_image(overlay)
    canvas = ap4_as_image(canvas)

    # Broadcast a batch of 1 to match the other side; anything else is an error.
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")

    _, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape

    x = int(x)
    y = int(y)

    out = canvas.clone()

    # Clip the overlay rectangle to the canvas bounds.
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)

    # No overlap at all: return the untouched canvas copy.
    if x1c <= x0c or y1c <= y0c:
        return out

    # Corresponding source window inside the overlay.
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)

    canvas_region = out[:, y0c:y1c, x0c:x1c, :]
    overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]

    # Treat RGB inputs as fully opaque for the blend math.
    canvas_rgba = ap4_ensure_rgba(canvas_region)
    overlay_rgba = ap4_ensure_rgba(overlay_region)

    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)

    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)

    # Premultiply, apply the Porter-Duff "over" operator, then un-premultiply.
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a

    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)

    # eps guards the division where the composite is fully transparent.
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)

    # Write back with the canvas' native channel count.
    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)

    return out
1220
+
1221
+
1222
class AP4_AILab_MaskCombiner_Exact:
    """Embedded copy of the AILab MaskCombiner node logic.

    Kept behavior-identical to the upstream node ("Exact") so results match
    workflows that used the original. Combines up to four masks; omitted
    masks (None) are skipped.
    """

    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        """Combine masks; returns a 1-tuple with the clamped result.

        Modes: "combine" = element-wise max over ALL provided masks;
        "intersection" = element-wise min of the FIRST TWO masks;
        anything else = absolute difference of the FIRST TWO masks.
        NOTE(review): intersection/difference ignoring mask_3/mask_4 mirrors
        the upstream node — confirm before "fixing".
        """
        masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
        if len(masks) <= 1:
            # Zero/one mask: nothing to combine (empty 64x64 mask as fallback).
            return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)

        # All masks are resampled to the first mask's shape.
        ref_shape = masks[0].shape
        masks = [self._resize_if_needed(m, ref_shape) for m in masks]

        if mode == "combine":
            result = torch.maximum(masks[0], masks[1])
            for mask in masks[2:]:
                result = torch.maximum(result, mask)
        elif mode == "intersection":
            result = torch.minimum(masks[0], masks[1])
        else:
            result = torch.abs(masks[0] - masks[1])

        return (torch.clamp(result, 0, 1),)

    def _resize_if_needed(self, mask, target_shape):
        """Resize *mask* to *target_shape*'s H/W via PIL LANCZOS (no-op if equal)."""
        if mask.shape == target_shape:
            return mask

        # Normalize to [B,H,W] before per-slice resizing.
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        elif len(mask.shape) == 4:
            mask = mask.squeeze(1)

        target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
        target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]

        # Round-trip through 8-bit PIL images (quantizes mask values to 1/255 steps).
        resized_masks = []
        for i in range(mask.shape[0]):
            mask_np = mask[i].cpu().numpy()
            img = Image.fromarray((mask_np * 255).astype(np.uint8))
            img_resized = img.resize((target_width, target_height), Image.LANCZOS)
            mask_resized = np.array(img_resized).astype(np.float32) / 255.0
            resized_masks.append(torch.from_numpy(mask_resized))

        return torch.stack(resized_masks)
1263
+
1264
+
1265
def ap4_resize_mask_comfy(alpha_mask: torch.Tensor, image_shape_hwc: Tuple[int, int, int]) -> torch.Tensor:
    """Bilinearly resize a mask batch to the (H, W) of an image shape tuple."""
    target_h = int(image_shape_hwc[0])
    target_w = int(image_shape_hwc[1])
    batched = alpha_mask.reshape((-1, 1, alpha_mask.shape[-2], alpha_mask.shape[-1]))
    resized = F.interpolate(batched, size=(target_h, target_w), mode="bilinear")
    return resized.squeeze(1)
1273
+
1274
+
1275
def ap4_join_image_with_alpha_comfy(image: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Attach an inverted, resized MASK as the alpha channel of each image.

    Mirrors ComfyUI's JoinImageWithAlpha: the mask is inverted (mask 1.0 ->
    alpha 0.0, i.e. transparent) and bilinearly resized to the image size.
    Output is [min(B_img, B_mask), H, W, 4].
    """
    image = ap4_as_image(image)
    alpha = ap4_as_mask(alpha)
    alpha = alpha.to(device=image.device, dtype=image.dtype)

    # Invert: the ComfyUI mask convention is the opposite of alpha coverage.
    alpha_resized = 1.0 - ap4_resize_mask_comfy(alpha, image.shape[1:])

    count = min(len(image), len(alpha))
    joined = [
        torch.cat((image[i][:, :, :3], alpha_resized[i].unsqueeze(2)), dim=2)
        for i in range(count)
    ]
    return torch.stack(joined)
1289
+
1290
+
1291
def ap4_try_get_comfy_model_management():
    """Return comfy.model_management when running inside ComfyUI, else None."""
    try:
        import comfy.model_management as mm  # type: ignore
    except Exception:
        return None
    return mm
1297
+
1298
+
1299
def ap4_gaussian_kernel_1d(kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a normalized 1-D Gaussian kernel of *kernel_size* taps, centered on the window."""
    mid = (kernel_size - 1) / 2.0
    offsets = torch.arange(kernel_size, device=device, dtype=dtype) - mid
    weights = torch.exp(-(offsets * offsets) / (2.0 * sigma * sigma))
    return weights / weights.sum()
1305
+
1306
+
1307
def ap4_mask_blur(mask: torch.Tensor, amount: int = 8, device: str = "gpu") -> torch.Tensor:
    """Gaussian-blur a mask batch with a separable kernel.

    *amount* is the kernel size (forced odd); ``amount == 0`` returns the
    mask unchanged. *device* selects where the convolution runs: "gpu"
    (ComfyUI's torch device, else CUDA if available), "cpu", or anything
    else to stay on the mask's own device. The result keeps the input dtype.
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)

    if amount == 0:
        return mask

    # Gaussian kernels need an odd tap count to stay centered.
    k = int(amount)
    if k % 2 == 0:
        k += 1

    # Sigma-from-kernel-size formula (matches OpenCV's getGaussianKernel default).
    sigma = 0.3 * (((k - 1) * 0.5) - 1.0) + 0.8

    mm = ap4_try_get_comfy_model_management()

    # Pick the processing device; prefer ComfyUI's choice when available.
    if device == "gpu":
        if mm is not None:
            proc_device = mm.get_torch_device()
        else:
            proc_device = torch.device("cuda") if torch.cuda.is_available() else mask.device
    elif device == "cpu":
        proc_device = torch.device("cpu")
    else:
        proc_device = mask.device

    # Output lands on ComfyUI's intermediate device when under ComfyUI,
    # otherwise back on the input's device.
    out_device = mask.device
    if device in ("gpu", "cpu") and mm is not None:
        out_device = mm.intermediate_device()

    orig_dtype = mask.dtype
    x = mask.to(device=proc_device, dtype=torch.float32)

    _, H, W = x.shape
    pad = k // 2

    # Reflect padding needs enough interior pixels; fall back to replicate.
    pad_mode = "reflect" if (H > pad and W > pad and H > 1 and W > 1) else "replicate"

    x4 = x.unsqueeze(1)
    x4 = F.pad(x4, (pad, pad, pad, pad), mode=pad_mode)

    # Separable blur: one horizontal then one vertical 1-D convolution.
    kern1d = ap4_gaussian_kernel_1d(k, sigma, device=proc_device, dtype=torch.float32)
    w_h = kern1d.view(1, 1, 1, k)
    w_v = kern1d.view(1, 1, k, 1)

    x4 = F.conv2d(x4, w_h)
    x4 = F.conv2d(x4, w_v)

    out = x4.squeeze(1).clamp(0.0, 1.0)
    return out.to(device=out_device, dtype=orig_dtype)
1355
+
1356
+
1357
def ap4_dilate_mask(mask: torch.Tensor, dilation: int = 3) -> torch.Tensor:
    """Morphologically dilate (dilation > 0) or erode (dilation < 0) a mask.

    The structuring element is a square window of |dilation| taps; dilation
    is max-pooling, erosion is min-pooling (computed as -max(-x)). The input
    is validated like a ComfyUI MASK ([H,W] is promoted to [1,H,W]); values
    are clamped to [0, 1] and ``dilation == 0`` is a no-op.

    Fix: the previous version passed ``padding=k // 2`` straight into
    ``max_pool2d``, which grows the spatial size by one pixel whenever
    |dilation| is even. Padding is now applied explicitly and split
    asymmetrically so every window size preserves the input shape; odd
    sizes produce byte-identical results to the old code.
    """
    if not isinstance(mask, torch.Tensor):
        raise TypeError("MASK must be a torch.Tensor")
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)
    if mask.dim() != 3:
        raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
    mask = mask.clamp(0.0, 1.0)

    dilation = int(dilation)
    if dilation == 0:
        return mask

    k = abs(dilation)
    # Shape-preserving padding: (k - 1) total, split low/high.
    pad_lo = k // 2
    pad_hi = k - 1 - pad_lo
    x = mask.unsqueeze(1)

    if dilation > 0:
        # Pad with 0 — the neutral element for max over [0,1] masks.
        xp = F.pad(x, (pad_lo, pad_hi, pad_lo, pad_hi), mode="constant", value=0.0)
        y = F.max_pool2d(xp, kernel_size=k, stride=1)
    else:
        # Erosion: pad with 1 (neutral for min) so borders are not eaten.
        xp = F.pad(x, (pad_lo, pad_hi, pad_lo, pad_hi), mode="constant", value=1.0)
        y = -F.max_pool2d(-xp, kernel_size=k, stride=1)

    return y.squeeze(1).clamp(0.0, 1.0)
1372
+
1373
+
1374
def ap4_fill_holes_grayscale_numpy_heap(f: np.ndarray, connectivity: int = 8) -> np.ndarray:
    """Grayscale hole filling on a single [H,W] float image in [0, 1].

    For every pixel, computes the minimax path cost to the image border:
    the minimum, over all border-connected paths, of the MAXIMUM pixel
    value along that path. Interior basins surrounded by higher values are
    raised to their enclosing rim height, while border-reachable low areas
    stay unchanged — a pure-NumPy fallback for morphological reconstruction
    by erosion (see ap4_fill_holes_mask). Dijkstra-style priority flood
    seeded from every border pixel.
    """
    f = np.clip(f, 0.0, 1.0).astype(np.float32, copy=False)
    H, W = f.shape
    if H == 0 or W == 0:
        return f

    # cost[y, x] = best (smallest) minimax path value found so far.
    cost = np.full((H, W), np.inf, dtype=np.float32)
    finalized = np.zeros((H, W), dtype=np.bool_)
    heap: List[Tuple[float, int, int]] = []

    def push(y: int, x: int):
        # Relax a pixel with its own value as the candidate path maximum.
        c = float(f[y, x])
        if c < float(cost[y, x]):
            cost[y, x] = c
            heapq.heappush(heap, (c, y, x))

    # Seed the flood from every border pixel.
    for x in range(W):
        push(0, x)
        if H > 1:
            push(H - 1, x)
    for y in range(H):
        push(y, 0)
        if W > 1:
            push(y, W - 1)

    if connectivity == 4:
        neigh = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        neigh = [(-1, -1), (-1, 0), (-1, 1),
                 (0, -1), (0, 1),
                 (1, -1), (1, 0), (1, 1)]

    eps = 1e-8
    while heap:
        c, y, x = heapq.heappop(heap)

        # Skip stale heap entries (a better cost was found after this push).
        if finalized[y, x]:
            continue
        if c > float(cost[y, x]) + eps:
            continue

        finalized[y, x] = True

        for dy, dx in neigh:
            ny = y + dy
            nx = x + dx
            if ny < 0 or ny >= H or nx < 0 or nx >= W:
                continue
            if finalized[ny, nx]:
                continue

            # Extending the path to the neighbor: the path maximum is the
            # larger of the current cost and the neighbor's own value.
            v = float(f[ny, nx])
            nc = c if c >= v else v
            if nc < float(cost[ny, nx]) - eps:
                cost[ny, nx] = nc
                heapq.heappush(heap, (nc, ny, nx))

    return cost
1432
+
1433
+
1434
def ap4_fill_holes_mask(mask: torch.Tensor) -> torch.Tensor:
    """Fill enclosed dark regions ("holes") in each mask of a [B,H,W] batch.

    Prefers scikit-image's morphological reconstruction by erosion (seed =
    1.0 everywhere except the original border values); falls back to the
    pure-NumPy heap-based minimax flood when skimage is unavailable.
    Result keeps the input device/dtype and is clamped to [0, 1].
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)

    B, H, W = mask.shape
    device = mask.device
    dtype = mask.dtype

    # Work on a contiguous float32 CPU copy; write each filled slice back.
    mask_np = np.ascontiguousarray(mask.detach().cpu().numpy().astype(np.float32, copy=False))
    filled_np = np.empty_like(mask_np)

    try:
        # Optional dependency: failure anywhere drops to the NumPy fallback.
        from skimage.morphology import reconstruction  # type: ignore
        footprint = np.ones((3, 3), dtype=bool)  # 8-connectivity

        for b in range(B):
            f = mask_np[b]
            seed = f.copy()

            # Seed for reconstruction-by-erosion: maximal (1.0) in the
            # interior, original values kept on the border.
            if H > 2 and W > 2:
                seed[1:-1, 1:-1] = 1.0
            else:
                # Degenerate sizes have no true interior; restore all borders.
                seed[:, :] = 1.0
                seed[0, :] = f[0, :]
                seed[-1, :] = f[-1, :]
                seed[:, 0] = f[:, 0]
                seed[:, -1] = f[:, -1]

            filled_np[b] = reconstruction(seed, f, method="erosion", footprint=footprint).astype(np.float32)

    except Exception:
        # skimage missing/failed: equivalent heap-based flood, per slice.
        for b in range(B):
            filled_np[b] = ap4_fill_holes_grayscale_numpy_heap(mask_np[b], connectivity=8)

    out = torch.from_numpy(filled_np).to(device=device, dtype=dtype)
    return out.clamp(0.0, 1.0)
1469
+
1470
+
1471
class apply_segment_4:
    """Paste a segmented cutout onto a canvas through a cleaned-up mask.

    Pipeline: blur -> dilate -> fill-holes the input MASK, invert it,
    union it with the inverted alpha of a bundled asset PNG, attach the
    result as the alpha of *img*, and alpha-composite that RGBA image onto
    *canvas* at (x, y).
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare ComfyUI inputs; 'image' is a dropdown of bundled asset PNGs."""
        choices = ap4_list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
                "img": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Composite *img*, masked by the processed *mask* + asset, onto *canvas*."""
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images next to this node")

        mask_in = ap4_as_mask(mask).clamp(0.0, 1.0)

        # Clean up the segmentation mask: soften, grow, then close holes.
        blurred = ap4_mask_blur(mask_in, amount=8, device="gpu")
        dilated = ap4_dilate_mask(blurred, dilation=3)
        filled = ap4_fill_holes_mask(dilated)

        # Invert so the segmented area becomes opaque after the alpha join
        # (which inverts again).
        inversed_mask = 1.0 - filled

        # Only the asset's mask is used; its RGB content is discarded.
        _asset_img, loaded_mask = ap4_load_image_from_assets(image)

        combiner = AP4_AILab_MaskCombiner_Exact()

        # Combiner resizing round-trips through PIL, so work on CPU copies.
        inv_cpu = inversed_mask.detach().cpu()
        loaded_cpu = ap4_as_mask(loaded_mask).detach().cpu()

        # Union (element-wise max) of the processed mask and inverted asset mask.
        (alpha_mask,) = combiner.combine_masks(inv_cpu, mode="combine", mask_2=(1.0 - loaded_cpu))
        alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0)

        alpha_image = ap4_join_image_with_alpha_comfy(img, alpha_mask)

        canvas = ap4_as_image(canvas)
        alpha_image = alpha_image.to(device=canvas.device, dtype=canvas.dtype)
        final = ap4_alpha_over_region(alpha_image, canvas, x, y)

        return (final,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        """Re-run when the selected asset file's content changes (hash-based)."""
        if image == "<no pngs found>":
            return image
        return ap4_file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Pre-flight check: the asset selection must resolve to an existing, safe file."""
        if image == "<no pngs found>":
            return "No PNGs found in assets/images next to this node"
        try:
            path = ap4_safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
1539
 
1540
+
1541
+ # ======================================================================================
1542
+ # Fused node: Salia_ezpz_gated_Duo2 -> SAM3Segment (hardcoded) -> apply_segment_4
1543
+ # ======================================================================================
1544
+
1545
+ class SAM3Segment_Salia:
1546
+ CATEGORY = "image/salia"
1547
+ RETURN_TYPES = ("IMAGE",)
1548
+ RETURN_NAMES = ("Final_Image",)
1549
+ FUNCTION = "run"
1550
+
1551
+ @classmethod
1552
+ def INPUT_TYPES(cls):
1553
+ # Use the exact dropdown sources of the embedded nodes
1554
+ salia_assets = _list_asset_pngs() or ["<no pngs found>"]
1555
+ ap4_assets = ap4_list_pngs() or ["<no pngs found>"]
1556
+ upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
1557
  return {
1558
  "required": {
1559
  "image": ("IMAGE",),
 
1562
  "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
1563
  "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
1564
 
 
1565
  "positive_prompt": ("STRING", {"default": "", "multiline": True}),
1566
  "negative_prompt": ("STRING", {"default": "", "multiline": True}),
1567
  "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "SAM3 prompt"}),
1568
 
1569
+ "asset_image": (salia_assets, {}),
1570
+ "apply_asset_image": (ap4_assets, {}),
 
 
 
1571
 
 
1572
  "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
1573
  "upscale_factor_1": (upscale_choices, {"default": "4"}),
1574
  "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
1575
 
 
1576
  "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
1577
  "upscale_factor_2": (upscale_choices, {"default": "4"}),
1578
  "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
 
1580
  }
1581
 
1582
  def __init__(self):
 
1583
  self._sam3 = SAM3Segment()
1584
+ self._salia = Salia_ezpz_gated_Duo2()
1585
+ self._ap4 = apply_segment_4()
 
 
 
 
 
 
 
 
 
 
 
 
1586
 
1587
  def run(
1588
  self,
 
1602
  upscale_factor_2="4",
1603
  denoise_2=0.35,
1604
  ):
1605
+ # EXACT bypass: if trigger_string is empty, return input image as Final_Image
1606
  if trigger_string == "":
1607
  return (image,)
1608
 
1609
+ # 1) Pre-node: Salia_ezpz_gated_Duo2 -> image_cropped
1610
+ _out_image, image_cropped = self._salia.run(
 
 
 
 
 
 
 
1611
  image=image,
1612
  trigger_string=trigger_string,
1613
  X_coord=int(X_coord),
 
1623
  denoise_2=float(denoise_2),
1624
  )
1625
 
1626
+ # 2) Center: SAM3Segment with hardcoded settings on the CROPPED image
1627
  seg_image, seg_mask, _mask_image = self._sam3.segment(
1628
  image=image_cropped,
1629
  prompt=str(prompt),
 
1638
  background_color="#222222",
1639
  )
1640
 
1641
+ # 3) Post-node: apply_segment_4 onto ORIGINAL input canvas (not Duo2 output)
1642
+ (final_image,) = self._ap4.run(
 
1643
  mask=seg_mask,
1644
  image=str(apply_asset_image),
1645
  img=seg_image,
 
1651
  return (final_image,)
1652
 
1653
 
1654
+ # ======================================================================================
1655
+ # Node mappings (all nodes in this file)
1656
+ # ======================================================================================
1657
+
1658
# Registry consumed by ComfyUI at import time: node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SAM3Segment": SAM3Segment,
    "Salia_ezpz_gated_Duo2": Salia_ezpz_gated_Duo2,
    "apply_segment_4": apply_segment_4,
    "SAM3Segment_Salia": SAM3Segment_Salia,
}

# Human-readable labels shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SAM3Segment": "SAM3 Segmentation (RMBG)",
    "Salia_ezpz_gated_Duo2": "Salia_ezpz_gated_Duo2",
    "apply_segment_4": "apply_segment_4",
    "SAM3Segment_Salia": "SAM3Segment_Salia (Duo2 → SAM3 → apply_segment_4)",
}