| BACKEND = None
|
|
|
| try:
|
| from comfy.model_patcher import ModelPatcher
|
| from comfy.samplers import calc_cond_batch
|
| from comfy.ldm.modules.attention import optimized_attention
|
| from .pag_utils import (
|
| parse_unet_blocks,
|
| perturbed_attention,
|
| rescale_guidance,
|
| seg_attention_wrapper,
|
| snf_guidance,
|
| )
|
|
|
| try:
|
| from comfy.model_patcher import set_model_options_patch_replace
|
| except ImportError:
|
| from .pag_utils import set_model_options_patch_replace
|
|
|
| BACKEND = "ComfyUI"
|
| except ImportError:
|
| from pag_utils import (
|
| parse_unet_blocks,
|
| set_model_options_patch_replace,
|
| perturbed_attention,
|
| rescale_guidance,
|
| seg_attention_wrapper,
|
| snf_guidance,
|
| )
|
|
|
| try:
|
| from ldm_patched.modules.model_patcher import ModelPatcher
|
| from ldm_patched.modules.samplers import calc_cond_uncond_batch
|
| from ldm_patched.ldm.modules.attention import optimized_attention
|
|
|
| BACKEND = "reForge"
|
| except ImportError:
|
| from backend.patcher.base import ModelPatcher
|
| from backend.sampling.sampling_function import calc_cond_uncond_batch
|
| from backend.attention import attention_function as optimized_attention
|
|
|
| BACKEND = "Forge"
|
|
|
|
|
| class PerturbedAttention:
|
| @classmethod
|
| def INPUT_TYPES(s):
|
| return {
|
| "required": {
|
| "model": ("MODEL",),
|
| "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
|
| "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
|
| "unet_block": (["input", "middle", "output"], {"default": "middle"}),
|
| "unet_block_id": ("INT", {"default": 0}),
|
| "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
| "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
| "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
|
| },
|
| "optional": {
|
| "unet_block_list": ("STRING", {"default": ""}),
|
| },
|
| }
|
|
|
| RETURN_TYPES = ("MODEL",)
|
| FUNCTION = "patch"
|
|
|
| CATEGORY = "model_patches/unet"
|
|
|
| def patch(
|
| self,
|
| model: ModelPatcher,
|
| scale: float = 3.0,
|
| adaptive_scale: float = 0.0,
|
| unet_block: str = "middle",
|
| unet_block_id: int = 0,
|
| sigma_start: float = -1.0,
|
| sigma_end: float = -1.0,
|
| rescale: float = 0.0,
|
| rescale_mode: str = "full",
|
| unet_block_list: str = "",
|
| ):
|
| m = model.clone()
|
|
|
| sigma_start = float("inf") if sigma_start < 0 else sigma_start
|
| if unet_block_list:
|
| blocks = parse_unet_blocks(model, unet_block_list)
|
| else:
|
| blocks = [(unet_block, unet_block_id, None)]
|
|
|
| def post_cfg_function(args):
|
| """CFG+PAG"""
|
| model = args["model"]
|
| cond_pred = args["cond_denoised"]
|
| uncond_pred = args["uncond_denoised"]
|
| cond = args["cond"]
|
| cfg_result = args["denoised"]
|
| sigma = args["sigma"]
|
| model_options = args["model_options"].copy()
|
| x = args["input"]
|
|
|
| signal_scale = scale
|
| if adaptive_scale > 0:
|
| t = model.model_sampling.timestep(sigma)[0].item()
|
| signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
|
| if signal_scale < 0:
|
| signal_scale = 0
|
|
|
| if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
|
| return cfg_result
|
|
|
|
|
| for block in blocks:
|
| layer, number, index = block
|
| model_options = set_model_options_patch_replace(model_options, perturbed_attention, "attn1", layer, number, index)
|
|
|
| if BACKEND == "ComfyUI":
|
| (pag_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
|
| if BACKEND in {"Forge", "reForge"}:
|
| (pag_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)
|
|
|
| pag = (cond_pred - pag_cond_pred) * signal_scale
|
|
|
| if rescale_mode == "snf":
|
| if uncond_pred.any():
|
| return uncond_pred + snf_guidance(cfg_result - uncond_pred, pag)
|
| return cfg_result + pag
|
|
|
| return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)
|
|
|
| m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
|
|
|
| return (m,)
|
|
|
|
|
| class SmoothedEnergyGuidanceAdvanced:
|
| @classmethod
|
| def INPUT_TYPES(s):
|
| return {
|
| "required": {
|
| "model": ("MODEL",),
|
| "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
|
| "blur_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 9999.0, "step": 0.01, "round": 0.001}),
|
| "unet_block": (["input", "middle", "output"], {"default": "middle"}),
|
| "unet_block_id": ("INT", {"default": 0}),
|
| "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
| "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
| "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
|
| },
|
| "optional": {
|
| "unet_block_list": ("STRING", {"default": ""}),
|
| },
|
| }
|
|
|
| RETURN_TYPES = ("MODEL",)
|
| FUNCTION = "patch"
|
|
|
| CATEGORY = "model_patches/unet"
|
|
|
| def patch(
|
| self,
|
| model: ModelPatcher,
|
| scale: float = 3.0,
|
| blur_sigma: float = -1.0,
|
| unet_block: str = "middle",
|
| unet_block_id: int = 0,
|
| sigma_start: float = -1.0,
|
| sigma_end: float = -1.0,
|
| rescale: float = 0.0,
|
| rescale_mode: str = "full",
|
| unet_block_list: str = "",
|
| ):
|
| m = model.clone()
|
|
|
| sigma_start = float("inf") if sigma_start < 0 else sigma_start
|
| if unet_block_list:
|
| blocks = parse_unet_blocks(model, unet_block_list)
|
| else:
|
| blocks = [(unet_block, unet_block_id, None)]
|
|
|
| def post_cfg_function(args):
|
| """CFG+SEG"""
|
| model = args["model"]
|
| cond_pred = args["cond_denoised"]
|
| uncond_pred = args["uncond_denoised"]
|
| cond = args["cond"]
|
| cfg_result = args["denoised"]
|
| sigma = args["sigma"]
|
| model_options = args["model_options"].copy()
|
| x = args["input"]
|
|
|
| signal_scale = scale
|
|
|
| if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
|
| return cfg_result
|
|
|
| seg_attention = seg_attention_wrapper(optimized_attention, blur_sigma)
|
|
|
|
|
| for block in blocks:
|
| layer, number, index = block
|
| model_options = set_model_options_patch_replace(model_options, seg_attention, "attn1", layer, number, index)
|
|
|
| if BACKEND == "ComfyUI":
|
| (seg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
|
| if BACKEND in {"Forge", "reForge"}:
|
| (seg_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)
|
|
|
| seg = (cond_pred - seg_cond_pred) * signal_scale
|
|
|
| if rescale_mode == "snf":
|
| if uncond_pred.any():
|
| return uncond_pred + snf_guidance(cfg_result - uncond_pred, seg)
|
| return cfg_result + seg
|
|
|
| return cfg_result + rescale_guidance(seg, cond_pred, cfg_result, rescale, rescale_mode)
|
|
|
| m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
|
|
|
| return (m,)
|
|
|