File size: 19,165 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 |
import logging
import typing as tg
from collections import OrderedDict
from os import environ
from warnings import warn
import gradio as gr
import torch
import modules.scripts as scripts
from modules import patches, script_callbacks, shared
from modules.processing import StableDiffusionProcessing
from modules.script_callbacks import (
AfterCFGCallbackParams,
CFGDenoisedParams,
CFGDenoiserParams,
)
from modules.sd_samplers_cfg_denoiser import catenate_conds
logger = logging.getLogger(__name__)
# SD_WEBUI_LOG_LEVEL may be a level name in any case ("debug", "INFO") or a
# number ("10"). Logger.setLevel only accepts registered upper-case names or
# ints and raises ValueError otherwise -- at import time that would stop the
# extension from loading -- so normalise the value and fall back to INFO.
_log_level = environ.get("SD_WEBUI_LOG_LEVEL", "INFO").strip()
try:
    logger.setLevel(int(_log_level) if _log_level.isdigit() else _log_level.upper())
except ValueError:
    logger.setLevel(logging.INFO)
del _log_level

# This text follows the imports, so Python does not pick it up as the module
# docstring on its own (a bare string here is a no-op); assign it explicitly.
__doc__ = """
An unofficial implementation of "Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance" for Automatic1111 WebUI.
@misc{ahn2024selfrectifying,
title={Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance},
author={Donghoon Ahn and Hyoungwon Cho and Jaewon Min and Wooseok Jang and Jungwoo Kim and SeonHwa Kim and Hyun Hee Park and Kyong Hwan Jin and Seungryong Kim},
year={2024},
eprint={2403.17377},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-incantations
"""
class PAGStateParams:
    """Per-batch state shared by the PAG hooks, callbacks and combine_denoised patch.

    Every field starts at a placeholder (-1, None or an empty list) and is
    filled in by the script as a batch is set up and as sampling steps run.
    """

    def __init__(self) -> None:
        # Settings copied from the processing object.
        self.pag_scale: float = -1  # PAG guidance scale
        self.guidance_scale: float = -1  # CFG scale
        self.batch_size = -1  # batch size

        # Denoiser inputs snapshotted each step. Text conditioning is a dict of
        # tensors for SD XL and a single tensor for SD 1.5.
        self.x_in = None
        self.sigma = None
        self.image_cond: torch.Tensor | None = None
        self.text_cond: dict | torch.Tensor | None = None
        self.text_uncond: dict | torch.Tensor | None = None
        self.make_condition_dict: tg.Callable | None = None

        # Attention modules targeted by PAG.
        self.crossattn_modules: list = []
        self.to_v_modules: list = []
        self.to_out_modules: list = []

        # Denoiser patching state and the perturbed-pass prediction.
        self.denoiser = None  # CFGDenoiser
        self.patched_combine_denoised = None
        self.pag_x_out = None
        self.conds_list = None
        self.uncond_shape_0 = None
class PAGExtensionScript(scripts.Script):
    """Always-visible WebUI script that applies Perturbed-Attention Guidance (PAG).

    For each batch it:
      * hooks the middle-block self-attention module(s) so that, while their
        ``pag_enable`` flag is set, the attention output is replaced by the
        identity-attention result ``to_out(V)``;
      * registers cfg_denoiser / cfg_denoised callbacks that snapshot the
        denoiser inputs and run one extra perturbed UNet pass per step;
      * patches the denoiser's ``combine_denoised`` so the perturbed prediction
        is folded into the guided output (combine_denoised_pass_conds_list).
    """

    def __init__(self):
        self.cached_c = [None, None]
        self.handles = []
        # State of the batch being generated; kept so pag_postprocess_batch
        # can undo the combine_denoised patch installed during sampling.
        self.pag_params = None

    # Extension title in menu UI
    def title(self) -> str:
        return "Perturbed Attention Guidance"

    # The accordion is always visible, in both txt2img and img2img.
    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img) -> list:
        """Build the UI; returns [active, pag_scale] in pag_process_batch's argument order."""
        with gr.Accordion("Perturbed Attention Guidance", open=False):
            active = gr.Checkbox(
                value=False, default=False, label="Active", elem_id="pag_active"
            )
            with gr.Row():
                pag_scale = gr.Slider(
                    value=3.0,
                    minimum=0,
                    maximum=20.0,
                    step=0.5,
                    label="PAG Scale",
                    elem_id="pag_scale",
                    info="",
                )
        # Restore the controls from pasted generation parameters.
        self.infotext_fields = [  # type: ignore
            (active, lambda d: gr.Checkbox.update(value="PAG Active" in d)),
            (pag_scale, "PAG Scale"),
        ]
        self.paste_field_names = [  # type: ignore
            "pag_active",
            "pag_scale",
        ]
        return [active, pag_scale]

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self.pag_process_batch(p, *args, **kwargs)

    def pag_process_batch(
        self, p: StableDiffusionProcessing, active, pag_scale, *args, **kwargs
    ):
        """Clear leftovers from earlier batches, then install PAG if active.

        ``p.pag_active`` / ``p.pag_scale`` (when present on ``p``) take
        precedence over the UI values.
        """
        # cleanup previous hooks always
        script_callbacks.remove_current_script_callbacks()
        self.remove_all_hooks()
        active = getattr(p, "pag_active", active)
        if active is False:
            return
        pag_scale = getattr(p, "pag_scale", pag_scale)
        p.extra_generation_params.update(
            {
                "PAG Active": active,
                "PAG Scale": pag_scale,
            }
        )
        self.create_hook(p, active, pag_scale)

    def create_hook(
        self, p: StableDiffusionProcessing, active, pag_scale, *args, **kwargs
    ):
        """Create the batch state, install forward hooks and register callbacks."""
        pag_params = PAGStateParams()
        pag_params.pag_scale = pag_scale
        pag_params.guidance_scale = p.cfg_scale
        pag_params.batch_size = p.batch_size
        pag_params.denoiser = None

        # Get the attention modules PAG perturbs
        cross_attn_modules = self.get_cross_attn_modules()
        if not cross_attn_modules:
            logger.error("No cross attention modules found, cannot proceed with PAG")
            return
        pag_params.crossattn_modules = [
            m for m in cross_attn_modules if "CrossAttention" in m.__class__.__name__
        ]
        self.pag_params = pag_params

        # Closures bind pag_params so the callbacks need no global state.
        cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(
            callback_params, pag_params
        )
        cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(
            callback_params, pag_params
        )
        # on_script_unloaded callbacks are invoked without arguments, so this
        # must not require a positional parameter.
        unhook_lambda = lambda *_: self.unhook_callbacks(pag_params)

        self.ready_hijack_forward(pag_params.crossattn_modules, pag_scale)
        logger.debug("Hooked callbacks")
        script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
        script_callbacks.on_cfg_denoised(cfg_denoised_lambda)
        script_callbacks.on_script_unloaded(unhook_lambda)

    def postprocess_batch(self, p, *args, **kwargs):
        self.pag_postprocess_batch(p, *args, **kwargs)

    def pag_postprocess_batch(self, p, active, *args, **kwargs):
        """Remove this batch's callbacks, denoiser patch and forward hooks."""
        script_callbacks.remove_current_script_callbacks()
        logger.debug("Removed script callbacks")
        active = getattr(p, "pag_active", active)
        if active is False:
            return
        # Undo the combine_denoised patch and drop the hooks (and the V tensor
        # they may hold) now, instead of leaving them until the next batch.
        pag_params = getattr(self, "pag_params", None)
        if pag_params is not None:
            self.unhook_callbacks(pag_params)
            self.pag_params = None
        self.remove_all_hooks()

    def remove_all_hooks(self):
        """Remove the PAG forward hooks and helper attributes from the attention modules."""
        for module in self.get_cross_attn_modules():
            to_v = getattr(module, "to_v", None)
            self.remove_field_cross_attn_modules(module, "pag_enable")
            self.remove_field_cross_attn_modules(module, "pag_last_to_v")
            _remove_all_forward_hooks(module, "pag_pre_hook")
            if to_v is not None:
                self.remove_field_cross_attn_modules(to_v, "pag_parent_module")
                _remove_all_forward_hooks(to_v, "to_v_pre_hook")

    def unhook_callbacks(self, pag_params: PAGStateParams):
        """Undo the combine_denoised patch on the denoiser recorded in ``pag_params``."""
        if pag_params is None:
            logger.error("PAG params is None")
            return
        if pag_params.denoiser is not None:
            denoiser = pag_params.denoiser
            setattr(denoiser, "combine_denoised_patched", False)
            try:
                patches.undo(__name__, denoiser, "combine_denoised")
            except KeyError:
                logger.exception("KeyError unhooking combine_denoised")
            except RuntimeError:
                logger.exception("RuntimeError unhooking combine_denoised")
        pag_params.denoiser = None

    def ready_hijack_forward(self, crossattn_modules, pag_scale):
        """Create hooks in the forward pass of the attention modules.

        ``to_v_pre_hook`` copies the output of each module's ``to_v`` onto the
        parent module. While the parent's ``pag_enable`` flag is set,
        ``pag_pre_hook`` then replaces the attention output with
        ``to_out(V)``: self-attention whose attention map is the identity.
        Both are forward (post) hooks despite their names; the names are kept
        because remove_all_hooks finds the hooks by ``__name__``.
        ``pag_scale`` is currently unused here.
        """
        for module in crossattn_modules:
            self.add_field_cross_attn_modules(module, "pag_enable", False)
            self.add_field_cross_attn_modules(module, "pag_last_to_v", None)
            to_v = getattr(module, "to_v", None)
            if to_v is not None:
                # Wrapped in a list so nn.Module.__setattr__ does not register
                # the parent as a child of to_v (a module reference cycle).
                self.add_field_cross_attn_modules(to_v, "pag_parent_module", [module])

        def to_v_pre_hook(module, input, kwargs, output):
            """Copy the output of to_v onto the parent module while PAG is enabled."""
            parent_module = getattr(module, "pag_parent_module", None)
            if parent_module is None:
                return
            parent = parent_module[0]
            # Only the perturbed pass consumes V; skip the copy otherwise.
            if not getattr(parent, "pag_enable", False):
                return
            setattr(parent, "pag_last_to_v", output.detach().clone())

        def pag_pre_hook(module, input, kwargs, output):
            """Replace the attention output with to_out(V) while PAG is enabled."""
            if not getattr(module, "pag_enable", False):
                return
            last_to_v = getattr(module, "pag_last_to_v", None)
            if last_to_v is None:
                # to_v did not run (or is not hooked) in this forward pass;
                # keep the regular attention output.
                return
            # The stashed V is only valid for this forward pass; release it.
            module.pag_last_to_v = None
            # Identity attention map: I @ V == V, so no identity matrix is
            # needed, but the module's output projection must still be applied.
            to_out = getattr(module, "to_out", None)
            if to_out is None:
                return last_to_v
            if isinstance(to_out, torch.nn.ModuleList):
                for layer in to_out:
                    last_to_v = layer(last_to_v)
                return last_to_v
            return to_out(last_to_v)

        for module in crossattn_modules:
            module.register_forward_hook(pag_pre_hook, with_kwargs=True)
            to_v = getattr(module, "to_v", None)
            if to_v is not None:
                to_v.register_forward_hook(to_v_pre_hook, with_kwargs=True)

    def get_middle_block_modules(self):
        """Get the attention module(s) of the middle block.

        Refer to page 22 of the PAG paper, Appendix A.2. Modules are looked up
        through the loaded model's ``network_layer_mapping``; on any error the
        exception is logged and an empty list is returned.
        """
        try:
            layer_mapping = shared.sd_model.network_layer_mapping
            return [
                layer
                for layer in layer_mapping.values()
                if "middle_block_1_transformer_blocks_0_attn1" in layer.network_layer_name
                and "CrossAttention" in layer.__class__.__name__
            ]
        except AttributeError:
            logger.exception(
                "AttributeError in get_middle_block_modules", stack_info=True
            )
            return []
        except Exception:
            logger.exception("Exception in get_middle_block_modules", stack_info=True)
            return []

    def get_cross_attn_modules(self):
        """Get the attention modules PAG targets (currently the middle block only)."""
        return self.get_middle_block_modules()

    def add_field_cross_attn_modules(self, module, field, value):
        """Add a field to a module if it doesn't exist"""
        if not hasattr(module, field):
            setattr(module, field, value)

    def remove_field_cross_attn_modules(self, module, field):
        """Remove a field from a module if it exists"""
        if hasattr(module, field):
            delattr(module, field)

    def on_cfg_denoiser_callback(
        self, params: CFGDenoiserParams, pag_params: PAGStateParams
    ):
        """Run before each denoiser model call.

        Re-applies the combine_denoised patch on ``params.denoiser`` and
        snapshots x, sigma and the image/text conditioning for the perturbed
        pass in on_cfg_denoised_callback.
        """
        # always unhook
        self.unhook_callbacks(pag_params)

        # patch combine_denoised
        if pag_params.denoiser is None:
            pag_params.denoiser = params.denoiser
        if getattr(params.denoiser, "combine_denoised_patched", False) is False:
            try:
                setattr(
                    params.denoiser,
                    "combine_denoised_original",
                    getattr(params.denoiser, "combine_denoised"),
                )
                # create patch that references the original function
                pass_conds_func = (
                    lambda *args, **kwargs: combine_denoised_pass_conds_list(
                        *args,
                        **kwargs,
                        original_func=getattr(
                            params.denoiser, "combine_denoised_original"
                        ),
                        pag_params=pag_params,
                    )
                )
                pag_params.patched_combine_denoised = patches.patch(
                    __name__, params.denoiser, "combine_denoised", pass_conds_func
                )
                setattr(params.denoiser, "combine_denoised_patched", True)
                setattr(
                    params.denoiser,
                    "combine_denoised_original",
                    patches.original(__name__, params.denoiser, "combine_denoised"),
                )
            except KeyError:
                logger.exception("KeyError patching combine_denoised")
            except RuntimeError:
                logger.exception("RuntimeError patching combine_denoised")

        # Both copies are taken from text_cond (not text_uncond): the perturbed
        # pass runs the conditional prompt in both halves of its batch.
        if isinstance(params.text_cond, dict):
            # SD XL: dict of conditioning tensors
            pag_params.text_cond = {
                key: value.clone().detach() for key, value in params.text_cond.items()
            }
            pag_params.text_uncond = {
                key: value.clone().detach() for key, value in params.text_cond.items()
            }
        else:
            # SD 1.5
            pag_params.text_cond = params.text_cond.clone().detach()
            pag_params.text_uncond = params.text_cond.clone().detach()

        pag_params.x_in = params.x.clone().detach()
        pag_params.sigma = params.sigma.clone().detach()
        pag_params.image_cond = params.image_cond.clone().detach()
        pag_params.denoiser = params.denoiser
        pag_params.make_condition_dict = get_make_condition_dict_fn(params.text_uncond)

    def on_cfg_denoised_callback(
        self, params: CFGDenoisedParams, pag_params: PAGStateParams
    ):
        """Run the perturbed-attention pass and store it in ``pag_params.pag_x_out``.

        Refer to pg.22 A.2 of the PAG paper for how CFG and PAG combine.
        """
        # Never let a failed pass leave the previous step's prediction behind.
        pag_params.pag_x_out = None
        x_in = pag_params.x_in
        if x_in is None:
            logger.error("PAG inputs were not captured, skipping PAG pass")
            return

        # concatenate the conditions the way the denoiser does
        # "modules/sd_samplers_cfg_denoiser.py:237"
        uncond = pag_params.text_uncond
        cond_in = catenate_conds([pag_params.text_cond, uncond])
        make_condition_dict = get_make_condition_dict_fn(uncond)
        conds = make_condition_dict(cond_in, pag_params.image_cond)

        # Enable the perturbation only for this extra forward; the finally
        # guarantees ordinary passes are not perturbed if the call raises.
        for module in pag_params.crossattn_modules:
            setattr(module, "pag_enable", True)
        try:
            pag_params.pag_x_out = params.inner_model(
                x_in, pag_params.sigma, cond=conds
            )
        finally:
            for module in pag_params.crossattn_modules:
                setattr(module, "pag_enable", False)

    def cfg_after_cfg_callback(
        self, params: AfterCFGCallbackParams, pag_params: PAGStateParams
    ):
        # Not registered; kept as a placeholder.
        pass
def combine_denoised_pass_conds_list(*args, **kwargs):
    """Hijacked replacement for CFGDenoiser.combine_denoised that adds the PAG term.

    Positional args are forwarded as ``(x_out, conds_list, uncond, cond_scale)``.
    Keyword args:

    * ``original_func`` -- the unpatched combine_denoised, called with the
      positional args when ``pag_params`` is missing.
    * ``pag_params`` -- PAGStateParams providing ``pag_x_out`` (the prediction
      of the perturbed-attention pass) and ``pag_scale``.

    For every image ``i`` and every ``(cond_index, weight)`` in ``conds_list[i]``::

        denoised[i] = uncond_out[i]
            + weight * cond_scale * (x_out[cond_index] - uncond_out[i])
            + weight * pag_scale  * (x_out[cond_index] - pag_x_out[cond_index])

    where ``uncond_out`` is the last ``uncond.shape[0]`` rows of ``x_out``.
    ``pag_x_out`` is indexed by ``cond_index`` because its rows line up with
    the conditional rows of ``x_out``. If ``pag_x_out`` is missing the PAG
    term is skipped (plain CFG).
    """
    original_func = kwargs.get("original_func", None)
    new_params = kwargs.get("pag_params", None)
    if new_params is None:
        logger.error("new_params is None")
        return original_func(*args)

    def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
        pag_x_out = new_params.pag_x_out
        pag_scale = new_params.pag_scale
        if pag_x_out is None:
            # Checked once per call instead of failing once per cond row.
            logger.error("PAG output is missing, skipping PAG guidance for this step")

        denoised_uncond = x_out[-uncond.shape[0] :]
        denoised = torch.clone(denoised_uncond)
        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                cond_out = x_out[cond_index]
                denoised[i] += (cond_out - denoised_uncond[i]) * (weight * cond_scale)
                if pag_x_out is None:
                    continue
                try:
                    denoised[i] += (cond_out - pag_x_out[cond_index]) * (
                        weight * pag_scale
                    )
                except TypeError:
                    logger.exception("TypeError in combine_denoised_pass_conds_list")
                except IndexError:
                    logger.exception("IndexError in combine_denoised_pass_conds_list")
        return denoised

    return new_combine_denoised(*args)
# from modules/sd_samplers_cfg_denoiser.py:187-195
def get_make_condition_dict_fn(text_uncond):
    """Pick the factory that packs conditioning into the model's ``cond`` dict.

    Mirrors the selection made by the WebUI denoiser:

    * model ``conditioning_key == "crossattn-adm"`` ->
      ``{"c_crossattn": [c_crossattn], "c_adm": c_adm}``
    * ``text_uncond`` is a dict (SD XL) ->
      the dict's own entries plus ``"c_concat": [c_concat]``
    * otherwise -> ``{"c_crossattn": [c_crossattn], "c_concat": [c_concat]}``

    The returned callable takes its two arguments positionally.
    """
    if shared.sd_model.model.conditioning_key == "crossattn-adm":

        def make_condition_dict(c_crossattn, c_adm):
            return {"c_crossattn": [c_crossattn], "c_adm": c_adm}

    elif isinstance(text_uncond, dict):

        def make_condition_dict(c_crossattn, c_concat):
            return {**c_crossattn, "c_concat": [c_concat]}

    else:

        def make_condition_dict(c_crossattn, c_concat):
            return {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

    return make_condition_dict
# thanks torch; removing hooks DOESN'T WORK
# thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455
def _remove_all_forward_hooks(
module: torch.nn.Module, hook_fn_name: str | None = None
) -> None:
"""
This function removes all forward hooks in the specified module, without requiring
any hook handles. This lets us clean up & remove any hooks that weren't property
deleted.
Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
caution should be exercised when removing all hooks. Users are recommended to give
their hook function a unique name that can be used to safely identify and remove
the target forward hooks.
Args:
module (nn.Module): The module instance to remove forward hooks from.
hook_fn_name (str, optional): Optionally only remove specific forward hooks
based on their function's __name__ attribute.
Default: None
"""
if hook_fn_name is None:
warn("Removing all active hooks can break some PyTorch modules & systems.")
def _remove_hooks(m: torch.nn.Module, name: str | None = None) -> None:
if hasattr(module, "_forward_hooks"):
if m._forward_hooks != OrderedDict():
if name is not None:
dict_items = list(m._forward_hooks.items())
m._forward_hooks = OrderedDict(
[(i, fn) for i, fn in dict_items if fn.__name__ != name]
)
else:
m._forward_hooks = OrderedDict()
def _remove_child_hooks(
target_module: torch.nn.Module, hook_name: str | None = None
) -> None:
for name, child in target_module._modules.items():
if child is not None:
_remove_hooks(child, hook_name)
_remove_child_hooks(child, hook_name)
# Remove hooks from target submodules
_remove_child_hooks(module, hook_fn_name)
# Remove hooks from the target module
_remove_hooks(module, hook_fn_name)
|