| |
| |
| import torch |
| import comfy.model_management |
| import comfy.lora |
| import copy |
| from typing import Optional |
| from enum import Enum |
| from comfy.utils import load_torch_file |
| from comfy.conds import CONDRegular |
| from comfy_extras.nodes_compositing import JoinImageWithAlpha |
| try: |
| from .model import ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel |
| except: |
| ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel = None, None, None |
| from .attension_sharing import AttentionSharingPatcher |
| from ...config import LAYER_DIFFUSION, LAYER_DIFFUSION_DIR, LAYER_DIFFUSION_VAE |
| from ...libs.utils import to_lora_patch_dict, get_local_filepath, get_sd_version |
|
|
| load_layer_model_state_dict = load_torch_file |
| class LayerMethod(Enum): |
| FG_ONLY_ATTN = "Attention Injection" |
| FG_ONLY_CONV = "Conv Injection" |
| FG_TO_BLEND = "Foreground" |
| FG_BLEND_TO_BG = "Foreground to Background" |
| BG_TO_BLEND = "Background" |
| BG_BLEND_TO_FG = "Background to Foreground" |
| EVERYTHING = "Everything" |
|
|
| class LayerDiffuse: |
|
|
| def __init__(self) -> None: |
| self.vae_transparent_decoder = None |
| self.frames = 1 |
|
|
| def get_layer_diffusion_method(self, method, has_blend_latent): |
| method = LayerMethod(method) |
| if method == LayerMethod.BG_TO_BLEND and has_blend_latent: |
| method = LayerMethod.BG_BLEND_TO_FG |
| elif method == LayerMethod.FG_TO_BLEND and has_blend_latent: |
| method = LayerMethod.FG_BLEND_TO_BG |
| return method |
|
|
| def apply_layer_c_concat(self, cond, uncond, c_concat): |
| def write_c_concat(cond): |
| new_cond = [] |
| for t in cond: |
| n = [t[0], t[1].copy()] |
| if "model_conds" not in n[1]: |
| n[1]["model_conds"] = {} |
| n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat) |
| new_cond.append(n) |
| return new_cond |
|
|
| return (write_c_concat(cond), write_c_concat(uncond)) |
|
|
| def apply_layer_diffusion(self, model, method, weight, samples, blend_samples, positive, negative, image=None, additional_cond=(None, None, None)): |
| control_img: Optional[torch.TensorType] = None |
| sd_version = get_sd_version(model) |
| model_url = LAYER_DIFFUSION[method.value][sd_version]["model_url"] |
|
|
| if image is not None: |
| image = image.movedim(-1, 1) |
|
|
| try: |
| if hasattr(comfy.lora, "calculate_weight"): |
| comfy.lora.calculate_weight = calculate_weight_adjust_channel(comfy.lora.calculate_weight) |
| else: |
| ModelPatcher.calculate_weight = calculate_weight_adjust_channel(ModelPatcher.calculate_weight) |
| except: |
| pass |
|
|
| if method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN] and sd_version == 'sd1': |
| self.frames = 1 |
| elif method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG, LayerMethod.FG_BLEND_TO_BG] and sd_version == 'sd1': |
| self.frames = 2 |
| batch_size, _, height, width = samples['samples'].shape |
| if batch_size % 2 != 0: |
| raise Exception(f"The batch size should be a multiple of 2. 批次大小需为2的倍数") |
| control_img = image |
| elif method == LayerMethod.EVERYTHING and sd_version == 'sd1': |
| batch_size, _, height, width = samples['samples'].shape |
| self.frames = 3 |
| if batch_size % 3 != 0: |
| raise Exception(f"The batch size should be a multiple of 3. 批次大小需为3的倍数") |
| if model_url is None: |
| raise Exception(f"{method.value} is not supported for {sd_version} model") |
|
|
| model_path = get_local_filepath(model_url, LAYER_DIFFUSION_DIR) |
| layer_lora_state_dict = load_layer_model_state_dict(model_path) |
| work_model = model.clone() |
| if sd_version == 'sd1': |
| patcher = AttentionSharingPatcher( |
| work_model, self.frames, use_control=control_img is not None |
| ) |
| patcher.load_state_dict(layer_lora_state_dict, strict=True) |
| if control_img is not None: |
| patcher.set_control(control_img) |
| else: |
| layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict) |
| work_model.add_patches(layer_lora_patch_dict, weight) |
|
|
| |
| if method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.FG_ONLY_CONV]: |
| samp_model = work_model |
| elif sd_version == 'sdxl': |
| if method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND]: |
| c_concat = model.model.latent_format.process_in(samples["samples"]) |
| else: |
| c_concat = model.model.latent_format.process_in(torch.cat([samples["samples"], blend_samples["samples"]], dim=1)) |
| samp_model, positive, negative = (work_model,) + self.apply_layer_c_concat(positive, negative, c_concat) |
| elif sd_version == 'sd1': |
| if method in [LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG]: |
| additional_cond = (additional_cond[0], None) |
| elif method in [LayerMethod.FG_TO_BLEND, LayerMethod.FG_BLEND_TO_BG]: |
| additional_cond = (additional_cond[1], None) |
|
|
| work_model.model_options.setdefault("transformer_options", {}) |
| work_model.model_options["transformer_options"]["cond_overwrite"] = [ |
| cond[0][0] if cond is not None else None |
| for cond in additional_cond |
| ] |
| samp_model = work_model |
|
|
| return samp_model, positive, negative |
|
|
| def join_image_with_alpha(self, image, alpha): |
| out = image.movedim(-1, 1) |
| if out.shape[1] == 3: |
| out = torch.cat([out, torch.ones_like(out[:, :1, :, :])], dim=1) |
| for i in range(out.shape[0]): |
| out[i, 3, :, :] = alpha |
| return out.movedim(1, -1) |
|
|
| def image_to_alpha(self, image, latent): |
| pixel = image.movedim(-1, 1) |
| decoded = [] |
| sub_batch_size = 16 |
| for start_idx in range(0, latent.shape[0], sub_batch_size): |
| decoded.append( |
| self.vae_transparent_decoder.decode_pixel( |
| pixel[start_idx: start_idx + sub_batch_size], |
| latent[start_idx: start_idx + sub_batch_size], |
| ) |
| ) |
| pixel_with_alpha = torch.cat(decoded, dim=0) |
| |
| pixel_with_alpha = pixel_with_alpha.movedim(1, -1) |
| image = pixel_with_alpha[..., 1:] |
| alpha = pixel_with_alpha[..., 0] |
|
|
| alpha = 1.0 - alpha |
| try: |
| new_images, = JoinImageWithAlpha().execute(image, alpha) |
| except: |
| new_images, = JoinImageWithAlpha().join_image_with_alpha(image, alpha) |
| return new_images, alpha |
|
|
| def make_3d_mask(self, mask): |
| if len(mask.shape) == 4: |
| return mask.squeeze(0) |
|
|
| elif len(mask.shape) == 2: |
| return mask.unsqueeze(0) |
|
|
| return mask |
|
|
| def masks_to_list(self, masks): |
| if masks is None: |
| empty_mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") |
| return ([empty_mask],) |
|
|
| res = [] |
|
|
| for mask in masks: |
| res.append(mask) |
|
|
| return [self.make_3d_mask(x) for x in res] |
|
|
| def layer_diffusion_decode(self, layer_diffusion_method, latent, blend_samples, samp_images, model): |
| alpha = [] |
| if layer_diffusion_method is not None: |
| sd_version = get_sd_version(model) |
| if sd_version not in ['sdxl', 'sd1']: |
| raise Exception(f"Only SDXL and SD1.5 model supported for Layer Diffusion") |
| method = self.get_layer_diffusion_method(layer_diffusion_method, blend_samples is not None) |
| sd15_allow = True if sd_version == 'sd1' and method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.EVERYTHING, LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG] else False |
| sdxl_allow = True if sd_version == 'sdxl' and method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN, LayerMethod.BG_BLEND_TO_FG] else False |
| if sdxl_allow or sd15_allow: |
| if self.vae_transparent_decoder is None: |
| model_url = LAYER_DIFFUSION_VAE['decode'][sd_version]["model_url"] |
| if model_url is None: |
| raise Exception(f"{method.value} is not supported for {sd_version} model") |
| decoder_file = get_local_filepath(model_url, LAYER_DIFFUSION_DIR) |
| self.vae_transparent_decoder = TransparentVAEDecoder( |
| load_torch_file(decoder_file), |
| device=comfy.model_management.get_torch_device(), |
| dtype=(torch.float16 if comfy.model_management.should_use_fp16() else torch.float32), |
| ) |
| if method in [LayerMethod.EVERYTHING, LayerMethod.BG_BLEND_TO_FG, LayerMethod.BG_TO_BLEND]: |
| new_images = [] |
| sliced_samples = copy.copy({"samples": latent}) |
| for index in range(len(samp_images)): |
| if index % self.frames == 0: |
| img = samp_images[index::self.frames] |
| alpha_images, _alpha = self.image_to_alpha(img, sliced_samples["samples"][index::self.frames]) |
| alpha.append(self.make_3d_mask(_alpha[0])) |
| new_images.append(alpha_images[0]) |
| else: |
| new_images.append(samp_images[index]) |
| else: |
| new_images, alpha = self.image_to_alpha(samp_images, latent) |
| else: |
| new_images = samp_images |
| else: |
| new_images = samp_images |
|
|
|
|
| return (new_images, samp_images, alpha) |