"""ComfyUI nodes for Visual Style Prompting and StyleAligned.

Visual Style Prompting injects the self-attention features of a reference
image into every element of the generation batch.  StyleAligned shares
attention keys/values and normalization statistics across a batch so that
all images adopt the style of the first (reference) batch element.
"""

import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Union

import comfy
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.sample
import comfy.samplers
import comfy.utils
from comfy.model_patcher import ModelPatcher

import latent_preview

from .VisualStylePrompting.attention_functions import VisualStyleProcessor

# Shorthand tensor alias used throughout this module.
T = torch.Tensor


class ApplyVisualStylePrompting:
    """Patch a model's CrossAttention layers so a reference image drives the
    style of the generated batch, and build the matching conditioning/latents.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "reference_image": ("IMAGE",),
                "reference_image_text": ("STRING", {"multiline": True}),
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "enabled": ("BOOLEAN", {"default": True}),
                "denoise": ("FLOAT", {"default": 1., "min": 0., "max": 1., "step": 1e-2}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 2})
            }
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("model", "positive", "negative", "latents")
    CATEGORY = "♾️Mixlab/Style"
    FUNCTION = "run"

    def run(
        self,
        reference_image,
        reference_image_text,
        model: comfy.model_patcher.ModelPatcher,
        clip,
        vae,
        positive,
        negative,
        enabled,
        denoise,
        batch_size=1,
    ):
        """Encode the reference image/text and patch every CrossAttention.

        Returns (model, positive, negative, latent-dict). The reference prompt
        is prepended to ``positive``; ``negative`` is doubled so both halves
        of the batch get negative conditioning; the latent batch carries a
        noise mask whose first slot is zero so the reference is never
        re-noised.
        """
        tokens = clip.tokenize(reference_image_text)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        reference_image_prompt = [[cond, {"pooled_output": pooled}]]

        # Half the batch renders the reference; (batch_size+1)//2 rounds up.
        reference_image = reference_image.repeat(((batch_size + 1) // 2, 1, 1, 1))

        self.model = model
        # Only RGB channels go to the VAE (drop a possible alpha channel).
        reference_latent = vae.encode(reference_image[:, :, :, :3])

        # Swap every CrossAttention forward for the visual-style processor.
        for n, m in model.model.diffusion_model.named_modules():
            if m.__class__.__name__ == "CrossAttention":
                processor = VisualStyleProcessor(m, enabled=enabled)
                setattr(m, 'forward', processor.visual_style_forward)

        conditioning_prompt = reference_image_prompt + positive
        negative_prompt = negative * 2

        latents = torch.zeros_like(reference_latent)
        latents = torch.cat([latents] * 2)
        if denoise < 1.0:
            # NOTE(review): [::1] assigns the first reference latent to EVERY
            # slot; [::2] (as in the full-denoise branch) may have been
            # intended — confirm against upstream before changing.
            latents[::1] = reference_latent[:1]
        else:
            latents[::2] = reference_latent

        denoise_mask = torch.ones_like(latents)[:, :1, ...] * denoise
        denoise_mask[0] = 0.  # the reference slot is never re-noised

        return (
            model,
            conditioning_prompt,
            negative_prompt,
            {"samples": latents, "noise_mask": denoise_mask},
        )


def exists(val):
    """Return True when *val* is not None."""
    return val is not None


def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if exists(val):
        return val
    return d


class StyleAlignedArgs:
    """Which attention statistics to share/AdaIN, parsed from a string such as
    ``"q+k+v"`` (presence of each letter enables the matching AdaIN)."""

    def __init__(self, share_attn: str) -> None:
        self.adain_keys = "k" in share_attn
        self.adain_values = "v" in share_attn
        self.adain_queries = "q" in share_attn

    # Class-level defaults (instance __init__ overrides the adain_* flags).
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = True


def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """
    Expand the first element so it has the same shape as the rest of the batch.

    The batch is assumed to hold cond/uncond halves: the style features are
    taken from index 0 and index b//2 and broadcast over their half. *scale*
    multiplies every copy except the leading one in each half.
    """
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """
    Concat the feature and the style feature expanded above along *dim*.
    """
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)


def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
    """Per-token mean and std over dim -2, shapes kept for broadcasting.

    *eps* guards against a zero variance before the square root.
    """
    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std


def adain(feat: T) -> T:
    """Adaptive instance normalization: re-normalize *feat* to the mean/std of
    the batch's style (first) element."""
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat


class SharedAttentionProcessor:
    """attn1 patch: AdaIN q/k/v per *args* and append the reference element's
    keys/values so every batch element attends to the style image."""

    def __init__(self, args: StyleAlignedArgs, scale: float):
        self.args = args
        self.scale = scale

    def __call__(self, q, k, v, extra_options):
        if self.args.adain_queries:
            q = adain(q)
        if self.args.adain_keys:
            k = adain(k)
        if self.args.adain_values:
            v = adain(v)
        if self.args.share_attention:
            # Keys are scaled to weight the shared attention; values are not.
            k = concat_first(k, -2, scale=self.scale)
            v = concat_first(v, -2)
        return q, k, v


def get_norm_layers(
    layer: nn.Module,
    norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
    share_layer_norm: bool,
    share_group_norm: bool,
):
    """Recursively collect LayerNorm/GroupNorm modules into *norm_layers_*.

    NOTE: the ``else`` binds to the GroupNorm check only, so recursion stops
    at GroupNorms that are collected but continues through everything else
    (including collected LayerNorms, which have no module children anyway).
    """
    if isinstance(layer, nn.LayerNorm) and share_layer_norm:
        norm_layers_["layer"].append(layer)
    if isinstance(layer, nn.GroupNorm) and share_group_norm:
        norm_layers_["group"].append(layer)
    else:
        for child_layer in layer.children():
            get_norm_layers(
                child_layer, norm_layers_, share_layer_norm, share_group_norm
            )


def register_norm_forward(
    norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
    """Monkey-patch *norm_layer* so its statistics include the style features.

    The original forward is stashed once in ``orig_forward``; the wrapper
    concatenates the expanded style features, normalizes, then slices the
    result back to the original token count.
    """
    if not hasattr(norm_layer, "orig_forward"):
        setattr(norm_layer, "orig_forward", norm_layer.forward)
    orig_forward = norm_layer.orig_forward

    def forward_(hidden_states: T) -> T:
        n = hidden_states.shape[-2]
        hidden_states = concat_first(hidden_states, dim=-2)
        hidden_states = orig_forward(hidden_states)  # type: ignore
        return hidden_states[..., :n, :]

    norm_layer.forward = forward_  # type: ignore
    return norm_layer


def register_shared_norm(
    model: ModelPatcher,
    share_group_norm: bool = True,
    share_layer_norm: bool = True,
):
    """Patch all selected norm layers of *model* for shared statistics and
    return the patched layers."""
    norm_layers = {"group": [], "layer": []}
    get_norm_layers(model.model, norm_layers, share_layer_norm, share_group_norm)
    print(
        f"Patching {len(norm_layers['group'])} group norms, {len(norm_layers['layer'])} layer norms."
    )
    return [register_norm_forward(layer) for layer in norm_layers["group"]] + [
        register_norm_forward(layer) for layer in norm_layers["layer"]
    ]


SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]


class StyleAlignedSampleReferenceLatents:
    """DDIM-invert a reference image, recording the latent at every step so a
    later sampler can pin the reference slot of the batch to it."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "reference_image": ("IMAGE",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "model": ("MODEL",),
                "vae": ("VAE",),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                # BUGFIX: the original called ``SCHEDULERS.reverse()``, which
                # returns None (so the widget options became ``(None,)``) and
                # mutated the global scheduler list in place on every call.
                # ``reversed`` keeps the intended ordering without either.
                "scheduler": (list(reversed(comfy.samplers.KSampler.SCHEDULERS)),),
                "denoise": ("FLOAT", {"default": 1, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("STEP_LATENTS", "LATENT")
    RETURN_NAMES = ("ref_latents", "noised_output")
    FUNCTION = "run"
    CATEGORY = "♾️Mixlab/Style"

    def run(self, reference_image, positive, negative, model, vae, seed,
            steps, cfg, scheduler, denoise):
        """Encode *reference_image* and run DDIM with flipped sigmas (an
        inversion), collecting the per-step latents of the reference.

        Returns (stacked per-step latents in reverse order, final noised
        latent dict).
        """
        # TODO noise_mask?

        def vae_encode_crop_pixels(pixels):
            # Crop H/W to multiples of 8 (VAE stride), centered.
            x = (pixels.shape[1] // 8) * 8
            y = (pixels.shape[2] // 8) * 8
            if pixels.shape[1] != x or pixels.shape[2] != y:
                x_offset = (pixels.shape[1] % 8) // 2
                y_offset = (pixels.shape[2] % 8) // 2
                pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
            return pixels

        pixels = vae_encode_crop_pixels(reference_image)
        t = vae.encode(pixels[:, :, :, :3])
        latent_image = {"samples": t}

        noise_seed = seed
        sampler_name = "ddim"
        sampler = comfy.samplers.sampler_object(sampler_name)

        total_steps = steps
        if denoise < 1.0:
            total_steps = int(steps / denoise)

        comfy.model_management.load_models_gpu([model])
        sigmas = comfy.samplers.calculate_sigmas_scheduler(
            model.model, scheduler, total_steps).cpu()
        sigmas = sigmas[-(steps + 1):]
        # Flip the schedule: we walk from clean image toward noise.
        sigmas = sigmas.flip(0)
        if sigmas[0] == 0:
            sigmas[0] = 0.0001

        latent = latent_image
        latent_image = latent["samples"]
        # Zero noise: inversion starts exactly from the encoded image.
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype,
                            layout=latent_image.layout, device="cpu")

        noise_mask = None
        if "noise_mask" in latent:
            noise_mask = latent["noise_mask"]

        ref_latents = []

        def callback(step: int, x0: T, x: T, steps: int):
            # Prepend so index 0 ends up holding the noisiest latent.
            ref_latents.insert(0, x[0])

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample_custom(
            model, noise, cfg, sampler, sigmas, positive, negative,
            latent_image, noise_mask=noise_mask, callback=callback,
            disable_pbar=disable_pbar, seed=noise_seed)

        out = latent.copy()
        out["samples"] = samples
        out_noised = out

        ref_latents = torch.stack(ref_latents)
        return (ref_latents, out_noised)


class StyleAlignedReferenceSampler:
    """Sample a batch whose first slot is pinned to pre-computed reference
    latents while sharing attention/norm statistics with the rest."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ref_latents": ("STEP_LATENTS",),
                "reference_image_text": ("STRING", {"multiline": True}),
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "share_norm": (SHARE_NORM_OPTIONS,),
                "share_attn": (SHARE_ATTN_OPTIONS,),
                "scale": ("FLOAT", {"default": 1, "min": 0, "max": 2.0, "step": 0.01}),
                "batch_size": ("INT", {"default": 2, "min": 1, "max": 8, "step": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("output", "denoised_output")
    FUNCTION = "patch"
    CATEGORY = "♾️Mixlab/Style"

    def patch(
        self,
        ref_latents,
        reference_image_text,
        model,
        clip,
        positive,
        negative,
        share_norm,
        share_attn,
        scale,
        batch_size,
        seed, steps, cfg, scheduler, denoise,
    ) -> "tuple[dict, dict]":
        """Run the style-aligned sampling.

        Returns (output latent dict, denoised latent dict); the reference slot
        is stripped from both.
        """
        m = model.clone()

        tokens = clip.tokenize(reference_image_text)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        ref_positive = [[cond, {"pooled_output": pooled}]]

        noise_seed = seed
        total_steps = steps
        if denoise < 1.0:
            total_steps = int(steps / denoise)

        sigmas = comfy.samplers.calculate_sigmas_scheduler(
            model.model, scheduler, total_steps).cpu()
        sigmas = sigmas[-(steps + 1):]

        sampler_name = "ddim"
        sampler = comfy.samplers.sampler_object(sampler_name)

        args = StyleAlignedArgs(share_attn)

        # Concat batch with style latent: slot 0 is the reference.
        style_latent_tensor = ref_latents[0].unsqueeze(0)
        height, width = style_latent_tensor.shape[-2:]
        latent_t = torch.zeros(
            [batch_size, 4, height, width], device=ref_latents.device
        )
        latent = {"samples": latent_t}
        noise = comfy.sample.prepare_noise(latent_t, noise_seed)

        latent_t = torch.cat((style_latent_tensor, latent_t), dim=0)
        # Reference slot gets zero noise — its trajectory is replayed below.
        ref_noise = torch.zeros_like(noise[0]).unsqueeze(0)
        noise = torch.cat((ref_noise, noise), dim=0)

        x0_output = {}
        preview_callback = latent_preview.prepare_callback(
            m, sigmas.shape[-1] - 1, x0_output)

        # Replace first latent with the corresponding reference latent after
        # each step.
        def callback(step: int, x0: T, x: T, steps: int):
            preview_callback(step, x0, x, steps)
            if (step + 1 < steps):
                # Clamp when ref_latents has fewer steps than the sampler.
                if step + 1 > len(ref_latents) - 1:
                    step = len(ref_latents) - 2
                x[0] = ref_latents[step + 1]
                x0[0] = ref_latents[step + 1]

        # Register shared norms.
        share_group_norm = share_norm in ["group", "both"]
        share_layer_norm = share_norm in ["layer", "both"]
        register_shared_norm(m, share_group_norm, share_layer_norm)

        # Patch self-attention (attn1) with the shared processor.
        m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))

        # Prepend the reference conditioning to every positive entry, repeating
        # the user conditioning across the batch.
        batched_condition = []
        for i, condition in enumerate(positive):
            additional = condition[1].copy()
            batch_with_reference = torch.cat(
                [ref_positive[i][0],
                 condition[0].repeat(
                     [batch_size] + [1] * len(condition[0].shape[1:]))],
                dim=0)
            if 'pooled_output' in additional and 'pooled_output' in ref_positive[i][1]:
                # Combine pooled output.
                pooled_output = torch.cat(
                    [ref_positive[i][1]['pooled_output'],
                     additional['pooled_output'].repeat(
                         [batch_size] + [1] * len(additional['pooled_output'].shape[1:]))],
                    dim=0)
                additional['pooled_output'] = pooled_output
            if 'control' in additional:
                if 'control' in ref_positive[i][1]:
                    # Combine control conditioning.
                    control_hint = torch.cat(
                        [ref_positive[i][1]['control'].cond_hint_original,
                         additional['control'].cond_hint_original.repeat(
                             [batch_size] + [1] * len(additional['control'].cond_hint_original.shape[1:]))],
                        dim=0)
                    cloned_controlnet = additional['control'].copy()
                    cloned_controlnet.set_cond_hint(
                        control_hint,
                        strength=additional['control'].strength,
                        timestep_percent_range=additional['control'].timestep_percent_range)
                    additional['control'] = cloned_controlnet
                else:
                    # Add zeros for first in batch.
                    control_hint = torch.cat(
                        [torch.zeros_like(additional['control'].cond_hint_original),
                         additional['control'].cond_hint_original.repeat(
                             [batch_size] + [1] * len(additional['control'].cond_hint_original.shape[1:]))],
                        dim=0)
                    cloned_controlnet = additional['control'].copy()
                    cloned_controlnet.set_cond_hint(
                        control_hint,
                        strength=additional['control'].strength,
                        timestep_percent_range=additional['control'].timestep_percent_range)
                    additional['control'] = cloned_controlnet
            batched_condition.append([batch_with_reference, additional])

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample_custom(
            m,
            noise,
            cfg,
            sampler,
            sigmas,
            batched_condition,
            negative,
            latent_t,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=noise_seed,
        )

        # Remove the reference image from the output batch.
        samples = samples[1:]
        out = latent.copy()
        out["samples"] = samples
        if "x0" in x0_output:
            out_denoised = latent.copy()
            x0 = x0_output["x0"][1:]
            out_denoised["samples"] = m.model.process_latent_out(x0.cpu())
        else:
            out_denoised = out
        return (out, out_denoised)


class StyleAlignedBatchAlign:
    """Patch a model so every element of a batch is style-aligned with the
    first element (shared norms + shared attention)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "share_norm": (SHARE_NORM_OPTIONS,),
                "share_attn": (SHARE_ATTN_OPTIONS,),
                "scale": ("FLOAT", {"default": 1, "min": 0, "max": 1.0, "step": 0.1}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "♾️Mixlab/Style"

    def patch(
        self,
        model: ModelPatcher,
        share_norm: str,
        share_attn: str,
        scale: float,
    ):
        """Return a cloned model with shared-norm and shared-attention patches."""
        m = model.clone()
        share_group_norm = share_norm in ["group", "both"]
        share_layer_norm = share_norm in ["layer", "both"]
        # NOTE(review): norms are registered on *model*, not the clone; both
        # share the same underlying model.model modules, so the effect is the
        # same — but the original is mutated too. Kept as-is.
        register_shared_norm(model, share_group_norm, share_layer_norm)
        args = StyleAlignedArgs(share_attn)
        m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))
        return (m,)