Upload salia_detailer_ezpz.py
Browse files- salia_detailer_ezpz.py +85 -57
salia_detailer_ezpz.py
CHANGED
|
@@ -22,7 +22,7 @@ except Exception as e:
|
|
| 22 |
|
| 23 |
|
| 24 |
# -------------------------------------------------------------------------------------
|
| 25 |
-
# Global caches (
|
| 26 |
# -------------------------------------------------------------------------------------
|
| 27 |
|
| 28 |
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
|
|
@@ -32,19 +32,19 @@ _CN_LOCK = threading.Lock()
|
|
| 32 |
|
| 33 |
|
| 34 |
# -------------------------------------------------------------------------------------
|
| 35 |
-
# Plugin root detection (
|
| 36 |
# -------------------------------------------------------------------------------------
|
| 37 |
|
| 38 |
def _find_plugin_root() -> Path:
|
| 39 |
"""
|
| 40 |
Walk upwards from this file until we find an 'assets' folder.
|
| 41 |
-
|
| 42 |
"""
|
| 43 |
here = Path(__file__).resolve()
|
| 44 |
-
for parent in [here.parent] + list(here.parents)[:
|
| 45 |
if (parent / "assets").is_dir():
|
| 46 |
return parent
|
| 47 |
-
# fallback: typical
|
| 48 |
return here.parent.parent
|
| 49 |
|
| 50 |
|
|
@@ -135,6 +135,48 @@ def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
|
|
| 135 |
return torch.cat(outs, dim=0)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# -------------------------------------------------------------------------------------
|
| 139 |
# Core lazy loaders (checkpoint + controlnet), cached globally
|
| 140 |
# -------------------------------------------------------------------------------------
|
|
@@ -202,10 +244,13 @@ def _safe_asset_path(asset_rel_path: str) -> Path:
|
|
| 202 |
|
| 203 |
base = img_dir.resolve()
|
| 204 |
rel = Path(asset_rel_path)
|
|
|
|
| 205 |
if rel.is_absolute():
|
| 206 |
raise ValueError("Absolute paths are not allowed for asset_image.")
|
| 207 |
|
| 208 |
full = (base / rel).resolve()
|
|
|
|
|
|
|
| 209 |
if base != full and base not in full.parents:
|
| 210 |
raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
|
| 211 |
|
|
@@ -222,37 +267,31 @@ def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch
|
|
| 222 |
Returns (IMAGE, MASK) in ComfyUI formats.
|
| 223 |
|
| 224 |
Mask semantics: match ComfyUI core LoadImage:
|
| 225 |
-
-
|
| 226 |
-
-
|
| 227 |
"""
|
| 228 |
p = _safe_asset_path(asset_rel_path)
|
| 229 |
|
| 230 |
im = Image.open(p)
|
| 231 |
im = ImageOps.exif_transpose(im)
|
| 232 |
|
| 233 |
-
had_alpha = ("A" in im.getbands())
|
| 234 |
rgba = im.convert("RGBA")
|
| 235 |
rgb = rgba.convert("RGB")
|
| 236 |
|
| 237 |
rgb_arr = np.array(rgb).astype(np.float32) / 255.0 # [H,W,3]
|
| 238 |
img_t = torch.from_numpy(rgb_arr)[None, ...]
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
mask = 1.0 - alpha
|
| 243 |
-
else:
|
| 244 |
-
h, w = rgb.size[1], rgb.size[0]
|
| 245 |
-
mask = np.zeros((h, w), dtype=np.float32)
|
| 246 |
|
| 247 |
mask_t = torch.from_numpy(mask)[None, ...]
|
| 248 |
return img_t, mask_t
|
| 249 |
|
| 250 |
|
| 251 |
# -------------------------------------------------------------------------------------
|
| 252 |
-
# Salia_Depth (INLINED
|
| 253 |
# -------------------------------------------------------------------------------------
|
| 254 |
|
| 255 |
-
# Local model path: assets/depth
|
| 256 |
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
|
| 257 |
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 258 |
|
|
@@ -264,6 +303,9 @@ REQUIRED_FILES = {
|
|
| 264 |
|
| 265 |
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
|
| 266 |
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
def _have_required_files() -> bool:
|
| 269 |
return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES.keys())
|
|
@@ -420,26 +462,24 @@ def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
|
|
| 420 |
return t.unsqueeze(0) # [1,H,W,C]
|
| 421 |
|
| 422 |
|
| 423 |
-
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
|
| 424 |
-
|
| 425 |
-
|
| 426 |
def _try_load_pipeline(model_source: str, device: torch.device):
|
| 427 |
if pipeline is None:
|
| 428 |
raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")
|
| 429 |
|
| 430 |
key = (model_source, str(device))
|
| 431 |
-
|
| 432 |
-
|
|
|
|
| 433 |
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
|
| 441 |
-
|
| 442 |
-
|
| 443 |
|
| 444 |
|
| 445 |
def get_depth_pipeline(device: torch.device):
|
|
@@ -511,28 +551,24 @@ def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray
|
|
| 511 |
|
| 512 |
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
|
| 513 |
"""
|
| 514 |
-
Internal callable version of
|
| 515 |
input: IMAGE [B,H,W,3 or 4]
|
| 516 |
output: IMAGE [B,H,W,3]
|
| 517 |
"""
|
| 518 |
-
# Get torch device
|
| 519 |
try:
|
| 520 |
device = model_management.get_torch_device()
|
| 521 |
except Exception:
|
| 522 |
device = torch.device("cpu")
|
| 523 |
|
| 524 |
-
|
| 525 |
-
pipe = None
|
| 526 |
try:
|
| 527 |
-
|
| 528 |
except Exception:
|
| 529 |
-
|
| 530 |
|
| 531 |
-
|
| 532 |
-
if pipe is None:
|
| 533 |
return image
|
| 534 |
|
| 535 |
-
# Batch support
|
| 536 |
if image.ndim == 3:
|
| 537 |
image = image.unsqueeze(0)
|
| 538 |
|
|
@@ -544,22 +580,18 @@ def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Ten
|
|
| 544 |
|
| 545 |
inp_u8 = comfy_tensor_to_u8(image[i])
|
| 546 |
|
| 547 |
-
# RGBA rule (pre)
|
| 548 |
rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
|
| 549 |
had_rgba = alpha_u8 is not None
|
| 550 |
|
| 551 |
-
# Depth
|
| 552 |
depth_rgb = depth_estimate_zoe_style(
|
| 553 |
-
pipe=
|
| 554 |
input_rgb_u8=rgb_for_depth,
|
| 555 |
detect_resolution=int(resolution),
|
| 556 |
upscale_method="INTER_CUBIC",
|
| 557 |
)
|
| 558 |
|
| 559 |
-
# Resize back to original size
|
| 560 |
depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
|
| 561 |
|
| 562 |
-
# RGBA rule (post)
|
| 563 |
if had_rgba:
|
| 564 |
if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
|
| 565 |
try:
|
|
@@ -602,7 +634,6 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
|
|
| 602 |
if x < 0 or y < 0 or x + s > W or y + s > H:
|
| 603 |
raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
|
| 604 |
|
| 605 |
-
# Match batch
|
| 606 |
if b2 != B:
|
| 607 |
if b2 == 1 and B > 1:
|
| 608 |
overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
|
|
@@ -618,7 +649,6 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
|
|
| 618 |
comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
|
| 619 |
out[:, y:y + s, x:x + s, 0:3] = comp_rgb
|
| 620 |
|
| 621 |
-
# If base has alpha, composite alpha too
|
| 622 |
if C == 4:
|
| 623 |
base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
|
| 624 |
comp_a = overlay_a + base_a * (1.0 - overlay_a)
|
|
@@ -703,6 +733,7 @@ class Salia_OneNode_WorkflowSquare:
|
|
| 703 |
scheduler: str,
|
| 704 |
denoise: float,
|
| 705 |
):
|
|
|
|
| 706 |
if image.ndim == 3:
|
| 707 |
image = image.unsqueeze(0)
|
| 708 |
if image.ndim != 4:
|
|
@@ -719,6 +750,7 @@ class Salia_OneNode_WorkflowSquare:
|
|
| 719 |
up = int(upscale_factor)
|
| 720 |
if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
|
| 721 |
raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
|
|
|
|
| 722 |
if s <= 0:
|
| 723 |
raise ValueError("square_size must be > 0")
|
| 724 |
if x < 0 or y < 0 or x + s > W or y + s > H:
|
|
@@ -727,6 +759,7 @@ class Salia_OneNode_WorkflowSquare:
|
|
| 727 |
up_w = s * up
|
| 728 |
up_h = s * up
|
| 729 |
|
|
|
|
| 730 |
if (up_w % 8) != 0 or (up_h % 8) != 0:
|
| 731 |
raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
|
| 732 |
|
|
@@ -735,18 +768,18 @@ class Salia_OneNode_WorkflowSquare:
|
|
| 735 |
if end_p < start_p:
|
| 736 |
start_p, end_p = end_p, start_p
|
| 737 |
|
| 738 |
-
# 1) Crop
|
| 739 |
crop = image[:, y:y + s, x:x + s, :]
|
| 740 |
crop_rgb = crop[:, :, :, 0:3].contiguous()
|
| 741 |
|
| 742 |
-
# 2) Depth (inline Salia_Depth) then Lanczos
|
| 743 |
depth_small = _salia_depth_execute(crop_rgb, resolution=s)
|
| 744 |
depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
|
| 745 |
|
| 746 |
-
# 3) Upscale crop for VAE
|
| 747 |
crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
|
| 748 |
|
| 749 |
-
# 4) Load asset mask
|
| 750 |
if asset_image == "<no pngs found>":
|
| 751 |
raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
|
| 752 |
_asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
|
|
@@ -826,18 +859,13 @@ class Salia_OneNode_WorkflowSquare:
|
|
| 826 |
latent_image=latent,
|
| 827 |
)
|
| 828 |
|
| 829 |
-
# 10) VAE Decode
|
| 830 |
vae_dec = nodes.VAEDecode()
|
| 831 |
vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
|
| 832 |
(decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
|
| 833 |
|
| 834 |
-
# 11) JoinImageWithAlpha
|
| 835 |
-
|
| 836 |
-
join_fn = getattr(join, join.FUNCTION)
|
| 837 |
-
try:
|
| 838 |
-
(rgba_up,) = join_fn(image=decoded_rgb, alpha=asset_mask_up)
|
| 839 |
-
except TypeError:
|
| 840 |
-
(rgba_up,) = join_fn(image=decoded_rgb, mask=asset_mask_up)
|
| 841 |
|
| 842 |
# 12) Downscale RGBA back to crop size
|
| 843 |
rgba_square = _resize_image_lanczos(rgba_up, s, s)
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
# -------------------------------------------------------------------------------------
|
| 25 |
+
# Global caches (checkpoint + controlnet) so using the node multiple times won't reload
|
| 26 |
# -------------------------------------------------------------------------------------
|
| 27 |
|
| 28 |
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
# -------------------------------------------------------------------------------------
|
| 35 |
+
# Plugin root detection (works whether file is in plugin root or nodes/)
|
| 36 |
# -------------------------------------------------------------------------------------
|
| 37 |
|
| 38 |
def _find_plugin_root() -> Path:
|
| 39 |
"""
|
| 40 |
Walk upwards from this file until we find an 'assets' folder.
|
| 41 |
+
Robust against hyphen/underscore package naming and different file placement.
|
| 42 |
"""
|
| 43 |
here = Path(__file__).resolve()
|
| 44 |
+
for parent in [here.parent] + list(here.parents)[:12]:
|
| 45 |
if (parent / "assets").is_dir():
|
| 46 |
return parent
|
| 47 |
+
# fallback: typical nodes/<file>.py
|
| 48 |
return here.parent.parent
|
| 49 |
|
| 50 |
|
|
|
|
| 135 |
return torch.cat(outs, dim=0)
|
| 136 |
|
| 137 |
|
| 138 |
+
# -------------------------------------------------------------------------------------
|
| 139 |
+
# ✅ ComfyUI 0.5.1 FIX: Manual JoinImageWithAlpha equivalent
|
| 140 |
+
# -------------------------------------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
"""
|
| 144 |
+
Make RGBA from:
|
| 145 |
+
rgb: IMAGE [B,H,W,3] float [0..1]
|
| 146 |
+
mask: MASK [B,H,W] float [0..1] (Comfy convention: 1=masked/transparent)
|
| 147 |
+
Output:
|
| 148 |
+
rgba: IMAGE [B,H,W,4] where alpha = 1 - mask (1=opaque, 0=transparent)
|
| 149 |
+
"""
|
| 150 |
+
if rgb.ndim == 3:
|
| 151 |
+
rgb = rgb.unsqueeze(0)
|
| 152 |
+
if mask.ndim == 2:
|
| 153 |
+
mask = mask.unsqueeze(0)
|
| 154 |
+
|
| 155 |
+
if rgb.ndim != 4 or rgb.shape[-1] != 3:
|
| 156 |
+
raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
|
| 157 |
+
if mask.ndim != 3:
|
| 158 |
+
raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
|
| 159 |
+
|
| 160 |
+
# Batch match
|
| 161 |
+
if mask.shape[0] != rgb.shape[0]:
|
| 162 |
+
if mask.shape[0] == 1 and rgb.shape[0] > 1:
|
| 163 |
+
mask = mask.expand(rgb.shape[0], -1, -1)
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError("Batch mismatch between rgb and mask.")
|
| 166 |
+
|
| 167 |
+
# Size match
|
| 168 |
+
if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
|
| 174 |
+
alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) # [B,H,W,1]
|
| 175 |
+
|
| 176 |
+
rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) # [B,H,W,4]
|
| 177 |
+
return rgba
|
| 178 |
+
|
| 179 |
+
|
| 180 |
# -------------------------------------------------------------------------------------
|
| 181 |
# Core lazy loaders (checkpoint + controlnet), cached globally
|
| 182 |
# -------------------------------------------------------------------------------------
|
|
|
|
| 244 |
|
| 245 |
base = img_dir.resolve()
|
| 246 |
rel = Path(asset_rel_path)
|
| 247 |
+
|
| 248 |
if rel.is_absolute():
|
| 249 |
raise ValueError("Absolute paths are not allowed for asset_image.")
|
| 250 |
|
| 251 |
full = (base / rel).resolve()
|
| 252 |
+
|
| 253 |
+
# path traversal protection
|
| 254 |
if base != full and base not in full.parents:
|
| 255 |
raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
|
| 256 |
|
|
|
|
| 267 |
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an asset image as a ComfyUI (IMAGE, MASK) pair.

    Mask semantics match ComfyUI core LoadImage:
      - alpha is the RGBA alpha channel normalized to [0..1]
      - mask = 1 - alpha (Comfy MASK convention: 1 = masked)

    Args:
        asset_rel_path: path relative to the plugin's assets/images folder;
            validated by _safe_asset_path (traversal-blocked).

    Returns:
        (IMAGE [1,H,W,3], MASK [1,H,W]) float32 tensors in [0..1].
    """
    path = _safe_asset_path(asset_rel_path)

    # Apply EXIF orientation before any channel work.
    opened = ImageOps.exif_transpose(Image.open(path))

    rgba = opened.convert("RGBA")
    rgb = rgba.convert("RGB")

    rgb_np = np.array(rgb, dtype=np.float32) / 255.0  # [H,W,3]
    img_t = torch.from_numpy(rgb_np)[None, ...]

    alpha_np = np.array(rgba.getchannel("A"), dtype=np.float32) / 255.0  # [H,W]
    mask_np = 1.0 - alpha_np  # Comfy MASK convention
    mask_t = torch.from_numpy(mask_np)[None, ...]

    return img_t, mask_t
|
| 289 |
|
| 290 |
|
| 291 |
# -------------------------------------------------------------------------------------
|
| 292 |
+
# Salia_Depth (INLINED, no imports from other files)
|
| 293 |
# -------------------------------------------------------------------------------------
|
| 294 |
|
|
|
|
| 295 |
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
|
| 296 |
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 297 |
|
|
|
|
| 303 |
|
| 304 |
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
|
| 305 |
|
| 306 |
+
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
|
| 307 |
+
_PIPE_LOCK = threading.Lock()
|
| 308 |
+
|
| 309 |
|
| 310 |
def _have_required_files() -> bool:
|
| 311 |
return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES.keys())
|
|
|
|
| 462 |
return t.unsqueeze(0) # [1,H,W,C]
|
| 463 |
|
| 464 |
|
|
|
|
|
|
|
|
|
|
| 465 |
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Build (or fetch from the global cache) a transformers depth-estimation
    pipeline for the given model source and device.

    Args:
        model_source: model identifier/path handed to transformers.pipeline.
        device: torch device the pipeline's model should live on.

    Returns:
        The cached or freshly constructed pipeline object.

    Raises:
        RuntimeError: if transformers failed to import at module load time.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    # One lock around both lookup and construction so concurrent callers
    # never load the same pipeline twice.
    with _PIPE_LOCK:
        cached = _PIPE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        pipe_obj = pipeline(task="depth-estimation", model=model_source)
        try:
            # Best-effort device placement; some pipeline versions handle
            # this internally, so failures are tolerated.
            pipe_obj.model = pipe_obj.model.to(device)
            pipe_obj.device = device
        except Exception:
            pass

        _PIPE_CACHE[cache_key] = pipe_obj
        return pipe_obj
|
| 483 |
|
| 484 |
|
| 485 |
def get_depth_pipeline(device: torch.device):
|
|
|
|
| 551 |
|
| 552 |
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
|
| 553 |
"""
|
| 554 |
+
Internal callable version of Salia_Depth:
|
| 555 |
input: IMAGE [B,H,W,3 or 4]
|
| 556 |
output: IMAGE [B,H,W,3]
|
| 557 |
"""
|
|
|
|
| 558 |
try:
|
| 559 |
device = model_management.get_torch_device()
|
| 560 |
except Exception:
|
| 561 |
device = torch.device("cpu")
|
| 562 |
|
| 563 |
+
pipe_obj = None
|
|
|
|
| 564 |
try:
|
| 565 |
+
pipe_obj = get_depth_pipeline(device)
|
| 566 |
except Exception:
|
| 567 |
+
pipe_obj = None
|
| 568 |
|
| 569 |
+
if pipe_obj is None:
|
|
|
|
| 570 |
return image
|
| 571 |
|
|
|
|
| 572 |
if image.ndim == 3:
|
| 573 |
image = image.unsqueeze(0)
|
| 574 |
|
|
|
|
| 580 |
|
| 581 |
inp_u8 = comfy_tensor_to_u8(image[i])
|
| 582 |
|
|
|
|
| 583 |
rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
|
| 584 |
had_rgba = alpha_u8 is not None
|
| 585 |
|
|
|
|
| 586 |
depth_rgb = depth_estimate_zoe_style(
|
| 587 |
+
pipe=pipe_obj,
|
| 588 |
input_rgb_u8=rgb_for_depth,
|
| 589 |
detect_resolution=int(resolution),
|
| 590 |
upscale_method="INTER_CUBIC",
|
| 591 |
)
|
| 592 |
|
|
|
|
| 593 |
depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
|
| 594 |
|
|
|
|
| 595 |
if had_rgba:
|
| 596 |
if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
|
| 597 |
try:
|
|
|
|
| 634 |
if x < 0 or y < 0 or x + s > W or y + s > H:
|
| 635 |
raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
|
| 636 |
|
|
|
|
| 637 |
if b2 != B:
|
| 638 |
if b2 == 1 and B > 1:
|
| 639 |
overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
|
|
|
|
| 649 |
comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
|
| 650 |
out[:, y:y + s, x:x + s, 0:3] = comp_rgb
|
| 651 |
|
|
|
|
| 652 |
if C == 4:
|
| 653 |
base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
|
| 654 |
comp_a = overlay_a + base_a * (1.0 - overlay_a)
|
|
|
|
| 733 |
scheduler: str,
|
| 734 |
denoise: float,
|
| 735 |
):
|
| 736 |
+
# Normalize input to [B,H,W,C]
|
| 737 |
if image.ndim == 3:
|
| 738 |
image = image.unsqueeze(0)
|
| 739 |
if image.ndim != 4:
|
|
|
|
| 750 |
up = int(upscale_factor)
|
| 751 |
if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
|
| 752 |
raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
|
| 753 |
+
|
| 754 |
if s <= 0:
|
| 755 |
raise ValueError("square_size must be > 0")
|
| 756 |
if x < 0 or y < 0 or x + s > W or y + s > H:
|
|
|
|
| 759 |
up_w = s * up
|
| 760 |
up_h = s * up
|
| 761 |
|
| 762 |
+
# VAE/UNet path likes multiples of 8
|
| 763 |
if (up_w % 8) != 0 or (up_h % 8) != 0:
|
| 764 |
raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
|
| 765 |
|
|
|
|
| 768 |
if end_p < start_p:
|
| 769 |
start_p, end_p = end_p, start_p
|
| 770 |
|
| 771 |
+
# 1) Crop square
|
| 772 |
crop = image[:, y:y + s, x:x + s, :]
|
| 773 |
crop_rgb = crop[:, :, :, 0:3].contiguous()
|
| 774 |
|
| 775 |
+
# 2) Depth (inline Salia_Depth) then upscale with Lanczos
|
| 776 |
depth_small = _salia_depth_execute(crop_rgb, resolution=s)
|
| 777 |
depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
|
| 778 |
|
| 779 |
+
# 3) Upscale crop for VAE encode
|
| 780 |
crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
|
| 781 |
|
| 782 |
+
# 4) Load asset mask and resize
|
| 783 |
if asset_image == "<no pngs found>":
|
| 784 |
raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
|
| 785 |
_asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
|
|
|
|
| 859 |
latent_image=latent,
|
| 860 |
)
|
| 861 |
|
| 862 |
+
# 10) VAE Decode -> RGB
|
| 863 |
vae_dec = nodes.VAEDecode()
|
| 864 |
vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
|
| 865 |
(decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
|
| 866 |
|
| 867 |
+
# 11) ✅ Manual "JoinImageWithAlpha"
|
| 868 |
+
rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
|
| 870 |
# 12) Downscale RGBA back to crop size
|
| 871 |
rgba_square = _resize_image_lanczos(rgba_up, s, s)
|