import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
import scipy.stats as stats

from scripts.ui_wrapper import UIWrapper, arg
from modules import script_callbacks, patches
from modules.hypernetworks import hypernetwork
#import modules.sd_hijack_optimizations
from modules.script_callbacks import CFGDenoiserParams, CFGDenoisedParams, AfterCFGCallbackParams
from modules.prompt_parser import reconstruct_multicond_batch
from modules.processing import StableDiffusionProcessing
#from modules.shared import sd_model, opts
from modules.sd_samplers_cfg_denoiser import catenate_conds
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules import shared

import math
import torch
from torch.nn import functional as F
from torchvision.transforms import GaussianBlur

from warnings import warn
from typing import Callable, Dict, Optional
from collections import OrderedDict

logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
incantations_debug = environ.get("INCANTATIONS_DEBUG", False)

"""
An unofficial implementation of "Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance" for Automatic1111 WebUI.

@misc{ahn2024selfrectifying,
      title={Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance},
      author={Donghoon Ahn and Hyoungwon Cho and Jaewon Min and Wooseok Jang and Jungwoo Kim and SeonHwa Kim and Hyun Hee Park and Kyong Hwan Jin and Seungryong Kim},
      year={2024},
      eprint={2403.17377},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Includes a noise interval for CFG and PAG guidance in the sampling process, from "Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models".

@misc{kynkäänniemi2024applying,
      title={Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models},
      author={Tuomas Kynkäänniemi and Miika Aittala and Tero Karras and Samuli Laine and Timo Aila and Jaakko Lehtinen},
      year={2024},
      eprint={2404.07724},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Includes CFG weight schedulers from "Analysis of Classifier-Free Guidance Weight Schedulers".

@misc{wang2024analysis,
      title={Analysis of Classifier-Free Guidance Weight Schedulers},
      author={Xi Wang and Nicolas Dufour and Nefeli Andreou and Marie-Paule Cani and Victoria Fernandez Abrevaya and David Picard and Vicky Kalogeiton},
      year={2024},
      eprint={2404.13040},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Saliency-adaptive noise fusion from "High-fidelity Person-centric Subject-to-Image Synthesis" (arXiv:2311.10329).

@misc{wang2024highfidelity,
      title={High-fidelity Person-centric Subject-to-Image Synthesis},
      author={Yibin Wang and Weizhong Zhang and Jianwei Zheng and Cheng Jin},
      year={2024},
      eprint={2311.10329},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-incantations
"""

handles = []
global_scale = 1

SCHEDULES = [
    'Constant',
    'Clamp-Linear (c=4.0)',
    'Clamp-Linear (c=2.0)',
    'Clamp-Linear (c=1.0)',
    'Linear',
    'Inverse-Linear',
    'Cosine',
    'Clamp-Cosine (c=4.0)',
    'Clamp-Cosine (c=2.0)',
    'Clamp-Cosine (c=1.0)',
    'Sine',
    'Interval',
    'PCS (s=0.01)',
    'PCS (s=0.1)',
    'PCS (s=1.0)',
    'PCS (s=2.0)',
    'PCS (s=4.0)',
]


class PAGStateParams:
    def __init__(self):
        self.pag_active: bool = False               # PAG enabled
        self.pag_sanf: bool = False                 # saliency-adaptive noise fusion, handled in cfg_combiner
        self.pag_scale: float = -1                  # PAG guidance scale
        self.pag_start_step: int = 0
        self.pag_end_step: int = 150
        self.cfg_interval_enable: bool = False
        self.cfg_interval_schedule: str = 'Constant'
        self.cfg_interval_low: float = 0
        self.cfg_interval_high: float = 50.0
        self.cfg_interval_scheduled_value: float = 7.0
        self.step: int = 0
        self.max_sampling_step: int = 1
        self.guidance_scale: float = -1             # CFG scale
        self.current_noise_level: float = 100.0
        self.x_in = None
        self.text_cond = None
        self.image_cond = None
        self.sigma = None
        self.text_uncond = None
        self.make_condition_dict = None             # callable lambda
        self.crossattn_modules = []                 # hooked CrossAttention modules
        self.to_v_modules = []
        self.to_out_modules = []
        self.pag_x_out = None
        self.batch_size = -1                        # batch size
        self.denoiser = None                        # CFGDenoiser
        self.patched_combine_denoised = None
        self.conds_list = None
        self.uncond_shape_0 = None
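
# Illustrative sketch (never called by the extension): the combined CFG + PAG
# update rule from Appendix A.2 of the PAG paper, which the hijacked
# combine_denoised below implements per-cond with prompt weights. The function
# name and tensor arguments are hypothetical placeholders.
def _sketch_cfg_pag_update(eps_uncond: torch.Tensor,
                           eps_cond: torch.Tensor,
                           eps_perturbed: torch.Tensor,
                           cfg_scale: float,
                           pag_scale: float) -> torch.Tensor:
    """Combine unconditional, conditional, and attention-perturbed predictions."""
    # Classifier-free guidance term: push away from the unconditional prediction
    guided = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
    # Perturbed-attention guidance term: push away from the perturbed prediction
    guided = guided + pag_scale * (eps_cond - eps_perturbed)
    return guided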
class PAGExtensionScript(UIWrapper):
    def __init__(self):
        self.cached_c = [None, None]
        self.handles = []

    # Extension title in menu UI
    def title(self) -> str:
        return "Perturbed Attention Guidance"

    # Decide whether to show the menu in txt2img or img2img
    def show(self, is_img2img):
        return scripts.AlwaysVisible

    # Set up the menu UI
    def setup_ui(self, is_img2img) -> list:
        with gr.Accordion('Perturbed Attention Guidance', open=False):
            active = gr.Checkbox(value=False, default=False, label="Active", elem_id='pag_active')
            pag_sanf = gr.Checkbox(value=False, default=False, label="Use Saliency-Adaptive Noise Fusion", elem_id='pag_sanf')
            with gr.Row():
                pag_scale = gr.Slider(value=0, minimum=0, maximum=20.0, step=0.5, label="PAG Scale", elem_id='pag_scale', info="")
            with gr.Row():
                start_step = gr.Slider(value=0, minimum=0, maximum=150, step=1, label="Start Step", elem_id='pag_start_step', info="")
                end_step = gr.Slider(value=150, minimum=0, maximum=150, step=1, label="End Step", elem_id='pag_end_step', info="")
            with gr.Accordion('CFG Scheduler', open=False):
                cfg_interval_enable = gr.Checkbox(
                    value=False, default=False, label="Enable CFG Scheduler", elem_id='cfg_interval_enable',
                    info="If enabled, CFG is applied only within the selected noise interval, using the selected schedule type. PAG must be enabled (its scale may be 0). For SDXL, the paper recommends CFG 15 with noise interval (0.28, 5.42]."
                )
                with gr.Row():
                    cfg_schedule = gr.Dropdown(
                        value='Constant',
                        choices=SCHEDULES,
                        label="CFG Schedule Type",
                        elem_id='cfg_interval_schedule',
                    )
                    cfg_interval_low = gr.Slider(value=0, minimum=0, maximum=100, step=0.1, label="CFG Noise Interval Low", elem_id='cfg_interval_low', info="")
                    cfg_interval_high = gr.Slider(value=100, minimum=0, maximum=100, step=0.1, label="CFG Noise Interval High", elem_id='cfg_interval_high', info="")

        active.do_not_save_to_config = True
        pag_sanf.do_not_save_to_config = True
        pag_scale.do_not_save_to_config = True
        start_step.do_not_save_to_config = True
        end_step.do_not_save_to_config = True
        cfg_interval_enable.do_not_save_to_config = True
        cfg_schedule.do_not_save_to_config = True
        cfg_interval_low.do_not_save_to_config = True
        cfg_interval_high.do_not_save_to_config = True

        self.infotext_fields = [
            (active, lambda d: gr.Checkbox.update(value='PAG Active' in d)),
            (pag_sanf, lambda d: gr.Checkbox.update(value='PAG SANF' in d)),
            (pag_scale, 'PAG Scale'),
            (start_step, 'PAG Start Step'),
            (end_step, 'PAG End Step'),
            (cfg_interval_enable, 'CFG Interval Enable'),
            (cfg_schedule, 'CFG Interval Schedule'),
            (cfg_interval_low, 'CFG Interval Low'),
            (cfg_interval_high, 'CFG Interval High'),
        ]
        self.paste_field_names = [
            'pag_active',
            'pag_sanf',
            'pag_scale',
            'pag_start_step',
            'pag_end_step',
            'cfg_interval_enable',
            'cfg_interval_schedule',
            'cfg_interval_low',
            'cfg_interval_high',
        ]
        return [active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf]

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self.pag_process_batch(p, *args, **kwargs)

    def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs):
        # Always clean up previous hooks first
        script_callbacks.remove_current_script_callbacks()
        self.remove_all_hooks()

        # XYZ Plot (or other code) may override the UI values via attributes on p
        active = getattr(p, "pag_active", active)
        pag_sanf = getattr(p, "pag_sanf", pag_sanf)
        cfg_interval_enable = getattr(p, "cfg_interval_enable", cfg_interval_enable)
        if active is False and cfg_interval_enable is False:
            return
        pag_scale = getattr(p, "pag_scale", pag_scale)
        start_step = getattr(p, "pag_start_step", start_step)
        end_step = getattr(p, "pag_end_step", end_step)
        cfg_schedule = getattr(p, "cfg_interval_schedule", cfg_schedule)
        cfg_interval_low = getattr(p, "cfg_interval_low", cfg_interval_low)
        cfg_interval_high = getattr(p, "cfg_interval_high", cfg_interval_high)

        if active:
            p.extra_generation_params.update({
                "PAG Active": active,
                "PAG SANF": pag_sanf,
                "PAG Scale": pag_scale,
                "PAG Start Step": start_step,
                "PAG End Step": end_step,
            })
        if cfg_interval_enable:
            p.extra_generation_params.update({
                "CFG Interval Enable": cfg_interval_enable,
                "CFG Interval Schedule": cfg_schedule,
                "CFG Interval Low": cfg_interval_low,
                "CFG Interval High": cfg_interval_high,
            })
        self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf)
    def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs):
        # Create the state parameters for this batch
        pag_params = PAGStateParams()

        # Add to p's incant_cfg_params
        if not hasattr(p, 'incant_cfg_params'):
            logger.error("No incant_cfg_params found in p")
            p.incant_cfg_params = {}  # initialize so the assignment below doesn't raise
        p.incant_cfg_params['pag_params'] = pag_params

        pag_params.pag_active = active
        pag_params.pag_sanf = pag_sanf
        pag_params.pag_scale = pag_scale
        pag_params.pag_start_step = start_step
        pag_params.pag_end_step = end_step
        pag_params.cfg_interval_enable = cfg_interval_enable
        pag_params.cfg_interval_schedule = cfg_schedule
        pag_params.max_sampling_step = p.steps
        pag_params.guidance_scale = p.cfg_scale
        pag_params.batch_size = p.batch_size
        pag_params.denoiser = None
        pag_params.cfg_interval_scheduled_value = p.cfg_scale

        if pag_params.cfg_interval_enable:
            # Refer to 3.1 Practice in the paper:
            # round the high and low noise levels to the nearest step-aligned values
            low_index = find_closest_index(cfg_interval_low, pag_params.max_sampling_step)
            high_index = find_closest_index(cfg_interval_high, pag_params.max_sampling_step)
            pag_params.cfg_interval_low = calculate_noise_level(low_index, pag_params.max_sampling_step)
            pag_params.cfg_interval_high = calculate_noise_level(high_index, pag_params.max_sampling_step)
            logger.debug(f"Step-aligned CFG interval indices (low, high): ({low_index}, {high_index}), step-aligned CFG interval: ({round(pag_params.cfg_interval_low, 4)}, {round(pag_params.cfg_interval_high, 4)})")

        # Get all the cross-attention modules
        cross_attn_modules = self.get_cross_attn_modules()
        if len(cross_attn_modules) == 0:
            logger.error("No cross attention modules found, cannot proceed with PAG")
            return
        pag_params.crossattn_modules = [m for m in cross_attn_modules if 'CrossAttention' in m.__class__.__name__]

        # Use lambdas to pass the parameters into the callbacks and avoid global state
        cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(callback_params, pag_params)
        cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(callback_params, pag_params)
        #after_cfg_lambda = lambda x: self.cfg_after_cfg_callback(x, pag_params)
        unhook_lambda = lambda _: self.unhook_callbacks(pag_params)

        if pag_params.pag_active:
            self.ready_hijack_forward(pag_params.crossattn_modules, pag_scale)

        logger.debug('Hooked callbacks')
        script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
        script_callbacks.on_cfg_denoised(cfg_denoised_lambda)
        #script_callbacks.on_cfg_after_cfg(after_cfg_lambda)
        script_callbacks.on_script_unloaded(unhook_lambda)

    def postprocess_batch(self, p, *args, **kwargs):
        self.pag_postprocess_batch(p, *args, **kwargs)

    def pag_postprocess_batch(self, p, active, *args, **kwargs):
        script_callbacks.remove_current_script_callbacks()
        logger.debug('Removed script callbacks')
        active = getattr(p, "pag_active", active)
        if active is False:
            return

    def remove_all_hooks(self):
        cross_attn_modules = self.get_cross_attn_modules()
        for module in cross_attn_modules:
            to_v = getattr(module, 'to_v', None)
            self.remove_field_cross_attn_modules(module, 'pag_enable')
            self.remove_field_cross_attn_modules(module, 'pag_last_to_v')
            self.remove_field_cross_attn_modules(to_v, 'pag_parent_module')
            _remove_all_forward_hooks(module, 'pag_pre_hook')
            _remove_all_forward_hooks(to_v, 'to_v_pre_hook')

    def unhook_callbacks(self, pag_params: PAGStateParams):
        global handles
        # Early return: the combine_denoised un-patching below appears to be
        # intentionally disabled (combining is handled in the CFG combiner);
        # the code is kept for reference but is currently unreachable.
        return
        if pag_params is None:
            logger.error("PAG params is None")
            return
        if pag_params.denoiser is not None:
            denoiser = pag_params.denoiser
            setattr(denoiser, 'combine_denoised_patched', False)
            try:
                patches.undo(__name__, denoiser, "combine_denoised")
            except KeyError:
                logger.exception("KeyError unhooking combine_denoised")
            except RuntimeError:
                logger.exception("RuntimeError unhooking combine_denoised")
            pag_params.denoiser = None
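
    # Hedged sketch (not called anywhere in this file): how the
    # combine_denoised patch/undo pairing referenced above is expected to work.
    # Assumes WebUI's patches.patch(key, obj, field, replacement) returns the
    # original function and patches.undo(key, obj, field) restores it; the
    # method name is hypothetical.
    @staticmethod
    def _sketch_patch_combine_denoised(denoiser, pag_params):
        # Replace the denoiser's combine_denoised with the hijacked version,
        # forwarding the original function and state params via kwargs
        original = patches.patch(
            __name__, denoiser, "combine_denoised",
            lambda *args, **kwargs: combine_denoised_pass_conds_list(
                *args, original_func=original, pag_params=pag_params, **kwargs)
        )
        setattr(denoiser, 'combine_denoised_patched', True)
        return original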
    def ready_hijack_forward(self, crossattn_modules, pag_scale):
        """
        Create hooks in the forward pass of the cross attention modules.
        Copies the output of each to_v module to its parent module, then applies
        the PAG perturbation to the output of the cross attention module
        (replacing the attention map with the identity).
        """

        # add the fields used by the hooks
        for module in crossattn_modules:
            to_v = getattr(module, 'to_v', None)
            self.add_field_cross_attn_modules(module, 'pag_enable', False)
            self.add_field_cross_attn_modules(module, 'pag_last_to_v', None)
            self.add_field_cross_attn_modules(to_v, 'pag_parent_module', [module])
            # self.add_field_cross_attn_modules(to_out, 'pag_parent_module', [module])

        def to_v_pre_hook(module, input, kwargs, output):
            """ Copy the output of the to_v module to the parent module """
            parent_module = getattr(module, 'pag_parent_module', None)
            # copy the output of the to_v module to the parent module
            setattr(parent_module[0], 'pag_last_to_v', output.detach().clone())

        def pag_pre_hook(module, input, kwargs, output):
            if hasattr(module, 'pag_enable') and getattr(module, 'pag_enable', False) is False:
                return
            if not hasattr(module, 'pag_last_to_v'):
                # oops, we forgot to unhook
                return

            # get the last to_v output
            last_to_v = getattr(module, 'pag_last_to_v', None)
            if last_to_v is None:
                # this is bad
                return output

            batch_size, seq_len, inner_dim = output.shape
            # Perturbed self-attention: replace the attention map with the
            # identity so that each token attends only to itself, reducing the
            # block output to V (PAG paper, Appendix A.2)
            identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1)
            new_output = torch.einsum('bij,bjk->bik', identity, last_to_v[:, :seq_len, :])
            return new_output

        # Create the hooks
        for module in crossattn_modules:
            handle_parent = module.register_forward_hook(pag_pre_hook, with_kwargs=True)
            to_v = getattr(module, 'to_v', None)
            handle_to_v = to_v.register_forward_hook(to_v_pre_hook, with_kwargs=True)
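
    # Illustrative sketch (never called by the extension): with the attention
    # map replaced by the identity, the einsum above is simply a copy of V,
    # i.e. each token attends only to itself. The shapes here are arbitrary
    # example values.
    @staticmethod
    def _demo_identity_attention():
        batch_size, seq_len, inner_dim = 2, 8, 4
        v = torch.randn(batch_size, seq_len, inner_dim)
        identity = torch.eye(seq_len, dtype=v.dtype).expand(batch_size, -1, -1)
        out = torch.einsum('bij,bjk->bik', identity, v)
        assert torch.allclose(out, v)  # identity attention returns V unchanged
        return out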
    def get_middle_block_modules(self):
        """
        Get all attention modules from the middle block.
        Refer to page 22 of the PAG paper, Appendix A.2.
        """
        try:
            model = shared.sd_model
            nlm = model.network_layer_mapping
            middle_block_modules = [m for m in nlm.values()
                                    if 'middle_block_1_transformer_blocks_0_attn1' in m.network_layer_name
                                    and 'CrossAttention' in m.__class__.__name__]
            return middle_block_modules
        except AttributeError:
            logger.exception("AttributeError in get_middle_block_modules", stack_info=True)
            return []
        except Exception:
            logger.exception("Exception in get_middle_block_modules", stack_info=True)
            return []

    def get_cross_attn_modules(self):
        """ Get all cross attention modules """
        return self.get_middle_block_modules()

    def add_field_cross_attn_modules(self, module, field, value):
        """ Add a field to a module if it doesn't already exist """
        if not hasattr(module, field):
            setattr(module, field, value)

    def remove_field_cross_attn_modules(self, module, field):
        """ Remove a field from a module if it exists """
        if hasattr(module, field):
            delattr(module, field)

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, pag_params: PAGStateParams):
        # always unhook
        self.unhook_callbacks(pag_params)

        pag_params.step = params.sampling_step

        # CFG Interval
        # TODO: set rho based on SDXL or SD 1.5
        pag_params.current_noise_level = calculate_noise_level(
            i=pag_params.step,
            N=pag_params.max_sampling_step,
        )

        if pag_params.cfg_interval_enable:
            if pag_params.cfg_interval_schedule != 'Constant':
                # Normalize the noise interval bounds
                start = pag_params.cfg_interval_low
                end = pag_params.cfg_interval_high
                begin_range = start if start <= end else end
                end_range = end if start <= end else start
                # Scheduled CFG value, applied only within the interval
                scheduled_cfg_scale = cfg_scheduler(pag_params.cfg_interval_schedule, pag_params.step, pag_params.max_sampling_step, pag_params.guidance_scale)
                pag_params.cfg_interval_scheduled_value = scheduled_cfg_scale if begin_range <= pag_params.current_noise_level <= end_range else 1.0

        # Run PAG only if active and within the PAG step interval
        if not pag_params.pag_active or pag_params.pag_scale <= 0:
            return
        if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step:
            return

        if isinstance(params.text_cond, dict):
            # SDXL: the perturbed pass is conditioned on the prompt, so the
            # uncond slot is also filled with the conditional embeddings
            pag_params.text_cond = {}
            pag_params.text_uncond = {}
            for key, value in params.text_cond.items():
                pag_params.text_cond[key] = value.clone().detach()
                pag_params.text_uncond[key] = value.clone().detach()
        else:
            # SD 1.5
            text_cond = params.text_cond
            pag_params.text_cond = text_cond.clone().detach()
            pag_params.text_uncond = text_cond.clone().detach()

        pag_params.x_in = params.x.clone().detach()
        pag_params.sigma = params.sigma.clone().detach()
        pag_params.image_cond = params.image_cond.clone().detach()
        pag_params.denoiser = params.denoiser
        pag_params.make_condition_dict = get_make_condition_dict_fn(params.text_uncond)
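
    # Minimal sketch (not used at runtime) of the interval gating applied
    # above and again in combine_denoised_pass_conds_list: the scheduled CFG
    # weight is used only while the current noise level lies inside the
    # user-selected interval; outside it, guidance falls back to 1.0 (no CFG),
    # following Kynkäänniemi et al. 2024. The method name is hypothetical.
    @staticmethod
    def _sketch_interval_gated_scale(noise_level: float, low: float, high: float, scheduled_scale: float) -> float:
        begin_range, end_range = min(low, high), max(low, high)
        return scheduled_scale if begin_range <= noise_level <= end_range else 1.0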
    def on_cfg_denoised_callback(self, params: CFGDenoisedParams, pag_params: PAGStateParams):
        """
        Callback function for CFGDenoisedParams.
        Refer to pg. 22, A.2 of the PAG paper for how CFG and PAG combine.
        """
        # Run PAG only if active and within the PAG step interval
        if not pag_params.pag_active or pag_params.pag_scale <= 0:
            return
        if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step:
            return

        # passed from on_cfg_denoiser_callback
        x_in = pag_params.x_in
        tensor = pag_params.text_cond
        uncond = pag_params.text_uncond
        image_cond_in = pag_params.image_cond
        sigma_in = pag_params.sigma

        # concatenate the conditions
        # "modules/sd_samplers_cfg_denoiser.py:237"
        cond_in = catenate_conds([tensor, uncond])
        make_condition_dict = get_make_condition_dict_fn(uncond)
        conds = make_condition_dict(cond_in, image_cond_in)

        # enable the PAG perturbation in the hooked cross attention modules
        for module in pag_params.crossattn_modules:
            setattr(module, 'pag_enable', True)

        # get the PAG guidance with an extra forward pass
        # (is there a way to optimize this so we don't have to calculate it twice?)
        pag_x_out = params.inner_model(x_in, sigma_in, cond=conds)

        # update pag_x_out
        pag_params.pag_x_out = pag_x_out

        # disable the PAG perturbation again
        for module in pag_params.crossattn_modules:
            setattr(module, 'pag_enable', False)

    def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams, pag_params: PAGStateParams):
        #self.unhook_callbacks(pag_params)
        pass

    def get_xyz_axis_options(self) -> set:
        xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
        extra_axis_options = {
            xyz_grid.AxisOption("[PAG] Active", str, pag_apply_override('pag_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
            xyz_grid.AxisOption("[PAG] SANF", str, pag_apply_override('pag_sanf', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
            xyz_grid.AxisOption("[PAG] PAG Scale", float, pag_apply_field("pag_scale")),
            xyz_grid.AxisOption("[PAG] PAG Start Step", int, pag_apply_field("pag_start_step")),
            xyz_grid.AxisOption("[PAG] PAG End Step", int, pag_apply_field("pag_end_step")),
            xyz_grid.AxisOption("[PAG] Enable CFG Scheduler", str, pag_apply_override('cfg_interval_enable', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
            xyz_grid.AxisOption("[PAG] CFG Noise Interval Low", float, pag_apply_field("cfg_interval_low")),
            xyz_grid.AxisOption("[PAG] CFG Noise Interval High", float, pag_apply_field("cfg_interval_high")),
            xyz_grid.AxisOption("[PAG] CFG Schedule Type", str, pag_apply_override('cfg_interval_schedule', boolean=False), choices=lambda: SCHEDULES),
            #xyz_grid.AxisOption("[PAG] ctnms_alpha", float, pag_apply_field("pag_ctnms_alpha")),
        }
        return extra_axis_options
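
# Illustrative sketch (never called): the conds_list structure iterated by the
# hijacked combine_denoised below. For a batch of one prompt with a single
# condition of weight 1.0, conds_list is [[(0, 1.0)]]; x_out stacks the
# conditional outputs followed by the unconditional ones. All tensors here are
# random placeholders and the function name is hypothetical.
def _demo_combine_denoised_structure():
    cond_count, batch_size = 1, 1
    x_out = torch.randn(cond_count + batch_size, 4, 8, 8)  # [conds..., unconds...]
    conds_list = [[(0, 1.0)]]
    denoised_uncond = x_out[-batch_size:]
    denoised = torch.clone(denoised_uncond)
    cfg_scale = 7.0
    for i, conds in enumerate(conds_list):
        for cond_index, weight in conds:
            # same CFG accumulation as new_combine_denoised below
            denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cfg_scale)
    return denoised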
def combine_denoised_pass_conds_list(*args, **kwargs):
    """ Hijacked replacement for combine_denoised in CFGDenoiser """
    original_func = kwargs.get('original_func', None)
    new_params = kwargs.get('pag_params', None)

    if new_params is None:
        logger.error("new_params is None")
        return original_func(*args)

    def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        noise_level = calculate_noise_level(new_params.step, new_params.max_sampling_step)

        # Calculate the CFG scale for this step
        cfg_scale = cond_scale
        new_params.cfg_interval_scheduled_value = cfg_scale

        if new_params.cfg_interval_enable:
            if new_params.cfg_interval_schedule != 'Constant':
                # Normalize the noise interval bounds
                start = new_params.cfg_interval_low
                end = new_params.cfg_interval_high
                begin_range = start if start <= end else end
                end_range = end if start <= end else start
                # Scheduled CFG value
                scheduled_cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
                # Only apply CFG within the interval
                cfg_scale = scheduled_cfg_scale if begin_range <= noise_level <= end_range else 1.0
                new_params.cfg_interval_scheduled_value = scheduled_cfg_scale

        # This may be temporarily necessary for compatibility with scfg
        # if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step:
        #     return original_func(*args)

        if incantations_debug:
            logger.debug(f"Schedule: {new_params.cfg_interval_schedule}, CFG Scale: {cfg_scale}, Noise level: {round(noise_level, 3)}")

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                # CFG term (cfg_scale already reflects the interval scheduler
                # when it is enabled)
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cfg_scale)

                # Apply PAG guidance only within the PAG step interval
                if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step or new_params.pag_scale <= 0:
                    continue
                try:
                    denoised[i] += (x_out[cond_index] - new_params.pag_x_out[i]) * (weight * new_params.pag_scale)
                except TypeError:
                    logger.exception("TypeError in combine_denoised_pass_conds_list")
                except IndexError:
                    logger.exception("IndexError in combine_denoised_pass_conds_list")
                #logger.debug(f"added PAG guidance to denoised - pag_scale:{global_scale}")
        return denoised

    return new_combine_denoised(*args)


# from modules/sd_samplers_cfg_denoiser.py:187-195
def get_make_condition_dict_fn(text_uncond):
    if shared.sd_model.model.conditioning_key == "crossattn-adm":
        make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
    else:
        if isinstance(text_uncond, dict):
            make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
        else:
            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
    return make_condition_dict


def calculate_noise_level(i, N, sigma_min=0.002, sigma_max=80.0, rho=3):
    """
    Calculate the noise level for a given sampling step index.

    Parameters:
        i (int): Index of the current sampling step (0-based).
        N (int): Total number of sampling steps.
        sigma_min (float): Minimum sigma value (minimum noise level), default 0.002.
        sigma_max (float): Maximum sigma value (maximum noise level), default 80.0.
        rho (int): Discretization parameter, default 3 for SD-XL, 7 for EDM2.

    Returns:
        float: Calculated noise level for the given step.
    """
    if i == 0:
        return sigma_max
    if i >= N:
        return 0.0
    sigma_max_p = sigma_max ** (1 / rho)
    sigma_min_p = sigma_min ** (1 / rho)
    inner_term = sigma_max_p + (i / (N - 1)) * (sigma_min_p - sigma_max_p)
    noise_level = inner_term ** rho
    return noise_level


def find_closest_index(noise_level: float, N: int, sigma_min=0.002, sigma_max=80.0, rho=3, tol=1e-6):
    """
    Given a noise level, find the closest integer index in the range [0, N-1]
    that corresponds to that noise level.

    Parameters:
        noise_level (float): Target noise level to find the closest index for.
        N (int): Total number of sampling steps.
        sigma_min (float): Minimum sigma value (minimum noise level), default 0.002.
        sigma_max (float): Maximum sigma value (maximum noise level), default 80.0.
        rho (int): Discretization parameter, default 3 for SD-XL, 7 for EDM2.

    Returns:
        int: The closest index to the specified noise level.
    """
    # Clamp to the min/max noise levels for the given range
    if noise_level <= sigma_min:
        return N
    if noise_level >= sigma_max:
        return 0
        #return N - 1

    # Binary search: calculate_noise_level is strictly decreasing in i
    low, high = 0, N - 1
    while low <= high:
        mid = (low + high) // 2
        mid_nl = calculate_noise_level(mid, N)
        if abs(mid_nl - noise_level) < tol:
            return mid
        elif mid_nl < noise_level:
            high = mid - 1
        else:
            low = mid + 1

    # If no exact match was found, return the index with the closest noise level
    return low if abs(calculate_noise_level(low, N) - noise_level) < abs(calculate_noise_level(high, N) - noise_level) else high
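
# Worked example (sketch, never called at runtime): step-aligning a
# user-chosen noise interval, as create_hook does above. With N=50 steps and
# the SDXL-style defaults (sigma in [0.002, 80], rho=3), an interval such as
# (0.28, 5.42] maps to step indices and back to the nearest representable
# noise levels. The function name and argument values are hypothetical.
def _example_step_aligned_interval(N: int = 50, low: float = 0.28, high: float = 5.42):
    low_index = find_closest_index(low, N)
    high_index = find_closest_index(high, N)
    aligned_low = calculate_noise_level(low_index, N)
    aligned_high = calculate_noise_level(high_index, N)
    # aligned_low/aligned_high are the closest noise levels the sampler
    # actually visits, so the interval check becomes exact at those steps
    return (low_index, aligned_low), (high_index, aligned_high)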
""" match schedule: case 'Constant': return constant_schedule(step, max_steps, w0) case 'Linear': return linear_schedule(step, max_steps, w0) case 'Clamp-Linear (c=4.0)': return clamp_linear_schedule(step, max_steps, w0, 4.0) case 'Clamp-Linear (c=2.0)': return clamp_linear_schedule(step, max_steps, w0, 2.0) case 'Clamp-Linear (c=1.0)': return clamp_linear_schedule(step, max_steps, w0, 1.0) case 'Inverse-Linear': return invlinear_schedule(step, max_steps, w0) case 'PCS (s=0.01)': return powered_cosine_schedule(step, max_steps, w0, 0.01) case 'PCS (s=0.1)': return powered_cosine_schedule(step, max_steps, w0, 0.1) case 'PCS (s=1.0)': return powered_cosine_schedule(step, max_steps, w0, 1.0) case 'PCS (s=2.0)': return powered_cosine_schedule(step, max_steps, w0, 2.0) case 'PCS (s=4.0)': return powered_cosine_schedule(step, max_steps, w0, 4.0) case 'Clamp-Cosine (c=4.0)': return clamp_cosine_schedule(step, max_steps, w0, 4.0) case 'Clamp-Cosine (c=2.0)': return clamp_cosine_schedule(step, max_steps, w0, 2.0) case 'Clamp-Cosine (c=1.0)': return clamp_cosine_schedule(step, max_steps, w0, 1.0) case 'Cosine': return cosine_schedule(step, max_steps, w0) case 'Sine': return sine_schedule(step, max_steps, w0) case 'V-Shape': return v_shape_schedule(step, max_steps, w0) case 'A-Shape': return a_shape_schedule(step, max_steps, w0) case 'Interval': return interval_schedule(step, max_steps, w0, 0.25, 5.42) case _: logger.error(f"Invalid CFG schedule: {schedule}") return constant_schedule(step, max_steps, w0) def constant_schedule(step: int, max_steps: int, w0: float): """ Constant scheduler for CFG guidance weight. """ return w0 def linear_schedule(step: int, max_steps: int, w0: float): """ Normalized linear scheduler for CFG guidance weight. Such that integral 0-> T ~ w(t) dt = w*T """ # return w0 * (1 - step / max_steps) return w0 * 2 * (1 - step / max_steps) def clamp_linear_schedule(step: int, max_steps: int, w0: float, c: float): """ Normalized clamp-linear scheduler for CFG guidance weight. """ return max(c, linear_schedule(step, max_steps, w0)) def clamp_cosine_schedule(step: int, max_steps: int, w0: float, c: float): """ Normalized clamp-cosine scheduler for CFG guidance weight. """ return max(c, cosine_schedule(step, max_steps, w0)) def invlinear_schedule(step: int, max_steps: int, w0: float): """ Normalized inverse linear scheduler for CFG guidance weight. """ # return w0 * (step / max_steps) return w0 * 2 * (step / max_steps) def powered_cosine_schedule(step: int, max_steps: int, w0: float, s: float): """ Normalized cosine scheduler for CFG guidance weight. """ return w0 * ((1 - math.cos(math.pi * ((max_steps - step) / max_steps)**s))/2.0) def cosine_schedule(step: int, max_steps: int, w0: float): """ Normalized cosine scheduler for CFG guidance weight. """ return w0 * (1 + math.cos(math.pi * step / max_steps)) def sine_schedule(step: int, max_steps: int, w0: float): """ Normalized sine scheduler for CFG guidance weight. """ return w0 * (math.sin((math.pi * step / max_steps) - (math.pi / 2)) + 1) def v_shape_schedule(step: int, max_steps: int, w0: float): """ Normalized V-shape scheduler for CFG guidance weight. """ if step < max_steps / 2: return invlinear_schedule(step, max_steps, w0) return linear_schedule(step, max_steps, w0) def a_shape_schedule(step: int, max_steps: int, w0: float): """ Normalized A-shape scheduler for CFG guidance weight. 
""" if step < max_steps / 2: return linear_schedule(step, max_steps, w0) return invlinear_schedule(step, max_steps, w0) def interval_schedule(step: int, max_steps: int, w0: float, low: float, high: float): """ Normalized interval scheduler for CFG guidance weight. """ if low <= step <= high: return w0 return 1.0 # XYZ Plot # Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py def pag_apply_override(field, boolean: bool = False): def fun(p, x, xs): if boolean: x = True if x.lower() == "true" else False setattr(p, field, x) if not hasattr(p, "pag_active"): setattr(p, "pag_active", True) if 'cfg_interval_' in field and not hasattr(p, "cfg_interval_enable"): setattr(p, "cfg_interval_enable", True) return fun def pag_apply_field(field): def fun(p, x, xs): if not hasattr(p, "pag_active"): setattr(p, "pag_active", True) setattr(p, field, x) return fun # thanks torch; removing hooks DOESN'T WORK # thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455 def _remove_all_forward_hooks( module: torch.nn.Module, hook_fn_name: Optional[str] = None ) -> None: """ This function removes all forward hooks in the specified module, without requiring any hook handles. This lets us clean up & remove any hooks that weren't property deleted. Warning: Various PyTorch modules and systems make use of hooks, and thus extreme caution should be exercised when removing all hooks. Users are recommended to give their hook function a unique name that can be used to safely identify and remove the target forward hooks. Args: module (nn.Module): The module instance to remove forward hooks from. hook_fn_name (str, optional): Optionally only remove specific forward hooks based on their function's __name__ attribute. Default: None """ if hook_fn_name is None: warn("Removing all active hooks can break some PyTorch modules & systems.") def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: if hasattr(module, "_forward_hooks"): if m._forward_hooks != OrderedDict(): if name is not None: dict_items = list(m._forward_hooks.items()) m._forward_hooks = OrderedDict( [(i, fn) for i, fn in dict_items if fn.__name__ != name] ) else: m._forward_hooks: Dict[int, Callable] = OrderedDict() def _remove_child_hooks( target_module: torch.nn.Module, hook_name: Optional[str] = None ) -> None: for name, child in target_module._modules.items(): if child is not None: _remove_hooks(child, hook_name) _remove_child_hooks(child, hook_name) # Remove hooks from target submodules _remove_child_hooks(module, hook_fn_name) # Remove hooks from the target module _remove_hooks(module, hook_fn_name)