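"""Perturbed-Attention Guidance (PAG) and Smoothed Energy Guidance (SEG) nodes.

Each node clones the model and installs a post-CFG function that reruns the
conditional pass with patched self-attention, then adds the resulting guidance
term on top of the regular CFG output. Supports the ComfyUI, reForge and Forge
backends.
"""
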
BACKEND = None
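# Resolve backend-specific imports: try ComfyUI first, then fall back to reForge and Forge.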
try:
    from comfy.model_patcher import ModelPatcher
    from comfy.samplers import calc_cond_batch
    from comfy.ldm.modules.attention import optimized_attention

    from .pag_utils import (
        parse_unet_blocks,
        perturbed_attention,
        rescale_guidance,
        seg_attention_wrapper,
        snf_guidance,
    )

    try:
        from comfy.model_patcher import set_model_options_patch_replace
    except ImportError:
        from .pag_utils import set_model_options_patch_replace

    BACKEND = "ComfyUI"
except ImportError:
    from pag_utils import (
        parse_unet_blocks,
        set_model_options_patch_replace,
        perturbed_attention,
        rescale_guidance,
        seg_attention_wrapper,
        snf_guidance,
    )

    try:
        from ldm_patched.modules.model_patcher import ModelPatcher
        from ldm_patched.modules.samplers import calc_cond_uncond_batch
        from ldm_patched.ldm.modules.attention import optimized_attention

        BACKEND = "reForge"
    except ImportError:
        from backend.patcher.base import ModelPatcher
        from backend.sampling.sampling_function import calc_cond_uncond_batch
        from backend.attention import attention_function as optimized_attention

        BACKEND = "Forge"


class PerturbedAttention:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        adaptive_scale: float = 0.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        m = model.clone()

        # A negative sigma_start disables the upper bound, so guidance applies at every noise level.
        sigma_start = float("inf") if sigma_start < 0 else sigma_start

        if unet_block_list:
            blocks = parse_unet_blocks(model, unet_block_list)
        else:
            blocks = [(unet_block, unet_block_id, None)]

        def post_cfg_function(args):
            """CFG+PAG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            signal_scale = scale
            if adaptive_scale > 0:
                # Fade the guidance scale towards zero as sampling approaches timestep 0.
                t = model.model_sampling.timestep(sigma)[0].item()
                signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
                if signal_scale < 0:
                    signal_scale = 0

            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            # Replace self-attention with PAG
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(model_options, perturbed_attention, "attn1", layer, number, index)

            if BACKEND == "ComfyUI":
                (pag_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (pag_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)

            pag = (cond_pred - pag_cond_pred) * signal_scale

            if rescale_mode == "snf":
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, pag)
                return cfg_result + pag

            return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")

        return (m,)


class SmoothedEnergyGuidanceAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "blur_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 9999.0, "step": 0.01, "round": 0.001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        blur_sigma: float = -1.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        m = model.clone()

        sigma_start = float("inf") if sigma_start < 0 else sigma_start

        if unet_block_list:
            blocks = parse_unet_blocks(model, unet_block_list)
        else:
            blocks = [(unet_block, unet_block_id, None)]

        def post_cfg_function(args):
            """CFG+SEG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            signal_scale = scale
            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            # Wrap the backend attention so the conditional pass runs with smoothed (blurred) self-attention.
            seg_attention = seg_attention_wrapper(optimized_attention, blur_sigma)

            # Replace self-attention with SEG attention
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(model_options, seg_attention, "attn1", layer, number, index)

            if BACKEND == "ComfyUI":
                (seg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (seg_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)

            seg = (cond_pred - seg_cond_pred) * signal_scale

            if rescale_mode == "snf":
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, seg)
                return cfg_result + seg

            return cfg_result + rescale_guidance(seg, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")

        return (m,)
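

# Usage sketch (illustrative; the actual registration for this extension may live
# elsewhere, e.g. in the package __init__). ComfyUI discovers custom nodes through a
# module-level NODE_CLASS_MAPPINGS dict; the display names below are assumptions
# chosen for the example.
NODE_CLASS_MAPPINGS = {
    "PerturbedAttention": PerturbedAttention,
    "SmoothedEnergyGuidanceAdvanced": SmoothedEnergyGuidanceAdvanced,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "PerturbedAttention": "Perturbed-Attention Guidance",
    "SmoothedEnergyGuidanceAdvanced": "Smoothed Energy Guidance (Advanced)",
}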