import logging from os import environ import modules.scripts as scripts import gradio as gr from functools import reduce from scripts.incant_utils import plot_tools from einops import rearrange from scripts.ui_wrapper import UIWrapper from modules import script_callbacks from modules import extra_networks from modules import prompt_parser from modules import sd_hijack from modules.script_callbacks import CFGDenoiserParams from modules.processing import StableDiffusionProcessing from modules import shared import math import torch from torch.nn import functional as F from torchvision.transforms import GaussianBlur from warnings import warn from typing import Callable, Dict, Optional from collections import OrderedDict logger = logging.getLogger(__name__) logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO)) """ Unofficial implementation of algorithms in "Multi-Concept T2I-Zero: Tweaking Only The Text Embeddings and Nothing Else" Also implements some "Reduce distortion in generation" algorithms from "Enhancing Semantic Fidelity in Text-to-Image Synthesis: Attention Regulation in Diffusion Models" @misc{tunanyan2023multiconcept, title={Multi-Concept T2I-Zero: Tweaking Only The Text Embeddings and Nothing Else}, author={Hazarapet Tunanyan and Dejia Xu and Shant Navasardyan and Zhangyang Wang and Humphrey Shi}, year={2023}, eprint={2310.07419}, archivePrefix={arXiv}, primaryClass={cs.CV} } @misc{zhang2024enhancing, title={Enhancing Semantic Fidelity in Text-to-Image Synthesis: Attention Regulation in Diffusion Models}, author={Yang Zhang and Teoh Tze Tzun and Lim Wei Hern and Tiviatis Sim and Kenji Kawaguchi}, year={2024}, eprint={2403.06381}, archivePrefix={arXiv}, primaryClass={cs.CV} } Author: v0xie GitHub URL: https://github.com/v0xie/sd-webui-incantations """ handles = [] token_indices = [0] class T2I0StateParams: def __init__(self): self.attnreg: bool = False self.ema_smoothing_factor: float = 2.0 self.step_start : int = 0 self.step_end : int = 25 self.token_count: int = 0 self.tokens: list[int] = [] # [0, 20] self.window_size_period: int = 10 # [0, 20] self.ctnms_alpha: float = 0.05 # [0., 1.] if abs value of difference between uncodition and concept-conditioned is less than this, then zero out the concept-conditioned values less than this self.correction_threshold: float = 0.5 # [0., 1.] self.correction_strength: float = 0.25 # [0., 1.) # larger bm is less volatile changes in momentum self.strength = 1.0 self.width = None self.height = None self.dims = [] self.cbs_similarities: list = None # we can precompute this? class T2I0ExtensionScript(UIWrapper): def __init__(self): self.cached_c = [None, None] self.handles = [] # Extension title in menu UI def title(self) -> str: return "Multi T2I-Zero" # Decide to show menu in txt2img or img2img def show(self, is_img2img): return scripts.AlwaysVisible # Setup menu ui detail def setup_ui(self, is_img2img) -> list: with gr.Accordion('Multi-Concept T2I-Zero', open=False): active = gr.Checkbox(value=False, default=False, label="Active", elem_id='t2i0_active') step_start = gr.Slider(value=1, minimum=0, maximum=150, default=1, step=1, label="Step Start", elem_id='t2i0_step_start', info="Start applying the correction at this step. Set to > 1 if using EMA.") step_end = gr.Slider(value=25, minimum=0, maximum=150, default=1, step=1, label="Step End", elem_id='t2i0_step_end') with gr.Row(): tokens = gr.Textbox(visible=True, value="", label="Tokens", elem_id='t2i0_tokens', info="Comma separated list of indices of tokens to condition on. Leave empty to condition on all tokens. Example: For prompt 'A cat and a dog', 'A': 0, 'cat': 1, 'and': 2, 'a': 3, 'dog': 4") with gr.Row(): window_size = gr.Slider(value = 2, minimum = 0, maximum = 100, step = 1, label="Correction by Similarities Window Size", elem_id = 't2i0_window_size', info="Exclude contribution of tokens with indices += this value from the current token index.") correction_threshold = gr.Slider(value = 0.5, minimum = 0., maximum = 1.0, step = 0.01, label="CbS Score Threshold", elem_id = 't2i0_correction_threshold', info="Filter dimensions with similarity below this threshold") correction_strength = gr.Slider(value = 0.0, minimum = 0.0, maximum = 2.0, step = 0.01, label="CbS Correction Strength", elem_id = 't2i0_correction_strength', info="The strength of the correction") with gr.Row(): attnreg = gr.Checkbox(visible=False, value=False, default=False, label="Use Attention Regulation", elem_id='t2i0_use_attnreg') ctnms_alpha = gr.Slider(value = 0.1, minimum = 0.0, maximum = 1.0, step = 0.01, label="Alpha for Cross-Token Non-Maximum Suppression", elem_id = 't2i0_ctnms_alpha', info="Contribution of the suppressed attention map, default 0.1") ema_factor = gr.Slider(value=0.0, minimum=0.0, maximum=4.0, default=2.0, label="EMA Smoothing Factor", elem_id='t2i0_ema_factor', info="Based on method from [arXiv:2403.06381]") active.do_not_save_to_config = True attnreg.do_not_save_to_config = True step_start.do_not_save_to_config = True step_end.do_not_save_to_config = True window_size.do_not_save_to_config = True correction_threshold.do_not_save_to_config = True correction_strength.do_not_save_to_config = True attnreg.do_not_save_to_config = True ctnms_alpha.do_not_save_to_config = True ema_factor.do_not_save_to_config = True tokens.do_not_save_to_config = True self.infotext_fields = [ (active, lambda d: gr.Checkbox.update(value='T2I-0 Active' in d)), #(attnreg, lambda d: gr.Checkbox.update(value='T2I-0 AttnReg' in d)), (window_size, 'T2I-0 Window Size'), (step_start, 'T2I-0 Step Start'), (step_end, 'T2I-0 Step End'), (correction_threshold, 'T2I-0 CbS Score Threshold'), (correction_strength, 'T2I-0 CbS Correction Strength'), (ctnms_alpha, 'T2I-0 CTNMS Alpha'), (ema_factor, 'T2I-0 CTNMS EMA Smoothing Factor'), (tokens, 'T2I-0 Tokens'), ] self.paste_field_names = [ 't2i0_active', 't2i0_attnreg', 't2i0_window_size', 't2i0_ctnms_alpha', 't2i0_correction_threshold', 't2i0_correction_strength' 't2i0_ema_factor', 't2i0_step_start', 't2i0_step_end', 't2i0_tokens' ] return [active, attnreg, window_size, ctnms_alpha, correction_threshold, correction_strength, tokens, ema_factor, step_end, step_start] def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): self.t2i0_process_batch(p, *args, **kwargs) def t2i0_process_batch(self, p: StableDiffusionProcessing, active, attnreg, window_size, ctnms_alpha, correction_threshold, correction_strength, tokens, ema_factor, step_end, step_start, *args, **kwargs): active = getattr(p, "t2i0_active", active) # use_attnreg = getattr(p, "t2i0_attnreg", attnreg) ema_factor = getattr(p, "t2i0_ema_factor", ema_factor) step_start = getattr(p, "t2i0_step_start", step_start) step_end = getattr(p, "t2i0_step_end", step_end) if active is False: return window_size = getattr(p, "t2i0_window_size", window_size) ctnms_alpha = getattr(p, "t2i0_ctnms_alpha", ctnms_alpha) correction_threshold = getattr(p, "t2i0_correction_threshold", correction_threshold) correction_strength = getattr(p, "t2i0_correction_strength", correction_strength) tokens = getattr(p, "t2i0_tokens", tokens) p.extra_generation_params.update({ "T2I-0 Active": active, #"T2I-0 AttnReg": attnreg, "T2I-0 Window Size": window_size, "T2I-0 Step Start": step_start, "T2I-0 Step End": step_end, "T2I-0 CbS Score Threshold": correction_threshold, "T2I-0 CbS Correction Strength": correction_strength, "T2I-0 CTNMS Alpha": ctnms_alpha, "T2I-0 CTNMS EMA Smoothing Factor": ema_factor, "T2I-0 Tokens": tokens, }) self.create_hook(p, active, attnreg, window_size, ctnms_alpha, correction_threshold, correction_strength, tokens, ema_factor, step_end, step_start, p.width, p.height) def parse_concept_prompt(self, prompt:str) -> list[str]: """ Separate prompt by comma into a list of concepts TODO: parse prompt into a list of concepts using A1111 functions >>> g = lambda prompt: self.parse_concept_prompt(prompt) >>> g("") [] >>> g("apples") ['apples'] >>> g("apple, banana, carrot") ['apple', 'banana', 'carrot'] """ if len(prompt) == 0: return [] return [x.strip() for x in prompt.split(",")] def create_hook(self, p, active, attnreg, window_size, ctnms_alpha, correction_threshold, correction_strength, tokens, ema_factor, step_end, step_start, width, height, *args, **kwargs): # Sanity check cross_attn_modules = self.get_cross_attn_modules() if len(cross_attn_modules) == 0: logger.error("No cross attention modules found, cannot run T2I-0") return if len(tokens) > 0: try: token_indices = [int(x) for x in tokens.split(",")] except ValueError: logger.error("Invalid token indices, must be comma separated integers") raise else: token_indices = [] # Create a list of parameters for each concept t2i0_params = [] #for _, strength in concept_conds: params = T2I0StateParams() params.attnreg = attnreg params.ema_smoothing_factor = ema_factor params.step_start = step_start params.step_end = step_end params.window_size_period = window_size params.ctnms_alpha = ctnms_alpha params.correction_threshold = correction_threshold params.correction_strength = correction_strength params.strength = 1.0 params.width = width params.height = height params.dims = [width, height] params.token_count, _ = get_token_count(p.prompt, p.steps, True) token_indices = [x+1 for x in token_indices if x >= 0 and x < params.token_count] params.tokens = token_indices t2i0_params.append(params) # Use lambda to call the callback function with the parameters to avoid global variables y = lambda params: self.on_cfg_denoiser_callback(params, t2i0_params) # un = lambda params: self.unhook_callbacks() # Hook callbacks if ctnms_alpha > 0: self.ready_hijack_forward(ctnms_alpha, width, height, ema_factor, step_start, step_end, token_indices, params.token_count) logger.debug('Hooked callbacks') script_callbacks.on_cfg_denoiser(y) script_callbacks.on_script_unloaded(self.unhook_callbacks) def postprocess_batch(self, p, *args, **kwargs): self.t2i0_postprocess_batch(p, *args, **kwargs) def t2i0_postprocess_batch(self, p, active, *args, **kwargs): self.unhook_callbacks() active = getattr(p, "t2i0_active", active) if active is False: return def unhook_callbacks(self): global handles logger.debug('Unhooked callbacks') cross_attn_modules = self.get_cross_attn_modules() for module in cross_attn_modules: self.remove_field_cross_attn_modules(module, 't2i0_last_attn_map') self.remove_field_cross_attn_modules(module, 't2i0_step') self.remove_field_cross_attn_modules(module, 't2i0_step_start') self.remove_field_cross_attn_modules(module, 't2i0_step_end') self.remove_field_cross_attn_modules(module, 't2i0_ema_factor') self.remove_field_cross_attn_modules(module, 't2i0_ema') self.remove_field_cross_attn_modules(module, 'plot_num') self.remove_field_cross_attn_modules(module, 't2i0_tokens') self.remove_field_cross_attn_modules(module, 't2i0_token_count') self.remove_field_cross_attn_modules(module, 't2i0_to_v_map') self.remove_field_cross_attn_modules(module.to_k, 't2i0_parent_module') self.remove_field_cross_attn_modules(module.to_v, 't2i0_parent_module') _remove_all_forward_hooks(module, 'cross_token_non_maximum_suppression') # _remove_all_forward_hooks(module, 'cross_token_non_maximum_suppression_pre') # _remove_all_forward_hooks(module.to_k, 't2i0_to_k_hook') _remove_all_forward_hooks(module.to_v, 't2i0_to_v_hook') script_callbacks.remove_current_script_callbacks() def apply_attnreg(self, f, C, alpha, B, *args, **kwargs): """ Apply attention regulation on an embedding. Args: f (Tensor): The embedding tensor of shape (n, d). C (list): Indices of selected tokens. alpha (float): Attnreg strength. B (float): Lagrange multiplier B > 0 gamma (int): Window size for the windowing function. Returns: Tensor: The corrected embedding tensor. """ n, d = f.shape f_tilde = f.detach().clone() # Copy the embedding tensor # for token_idx, c in enumerate(C): # pass return f_tilde def correction_by_similarities(self, f, C, percentile, gamma, alpha, tokens=None, token_count=77): """ Apply the Correction by Similarities algorithm on embeddings. Args: f (Tensor): The embedding tensor of shape (n, d). C (list): Indices of selected tokens. percentile (float): Percentile to use for score threshold. gamma (int): Window size for the windowing function. alpha (float): Correction strength. tokens (list): List of token indices to condition on (default is all tokens if empty list). Returns: Tensor: The corrected embedding tensor. """ if alpha == 0: return f n, d = f.shape token_indices = tokens min_idx = 1 max_idx = min(token_count+1, n) if token_indices is None: token_indices = list(range(min_idx, max_idx)) if token_indices == []: token_indices = list(range(min_idx, max_idx)) else: token_indices = [x+1 for x in token_indices if x >= 0 and x < n] f_tilde = f.detach().clone() # Copy the embedding tensor # Define a windowing function def psi(c, gamma, n, dtype, device, min_idx, max_idx): window = torch.zeros(n, dtype=dtype, device=device) start = max(min_idx, c - gamma) end = min(max_idx, c + gamma + 1) window[start:end] = 1 return window def threshold_filter(t, tau): """ Threshold filter function Filters product values below a threshold tau and normalizes them to leave only the most similar dimensions. Arguments: t: torch.Tensor - The tensor to threshold tau: float - The threshold value Returns: bool: True if the value is above the threshold, False otherwise """ pass for c in token_indices: if c < 0 or c >= n: continue Sc = f[c] * f # Element-wise multiplication # calculate score threshold to filter out values under score threshold # often there is a huge difference between the max and min values, so we use a log-like function instead k = 10 e= 2.718281 pct_max = 1/(1+1e-10) pct_min = 1e-16 # max of 0.999... to 0.0000...1 pct = min(pct_max, max(pct_min, 1 - e**(-k * percentile))) tau = torch.quantile(Sc, pct) Sc_tilde = Sc * (Sc > tau) # Apply threshold and filter Sc_tilde /= Sc_tilde.max() # Normalize window = psi(c, gamma, n, Sc_tilde.dtype, Sc_tilde.device, min_idx, max_idx).unsqueeze(1) # Apply windowing function Sc_tilde *= window f_c_tilde = torch.sum(Sc_tilde * f, dim=0) # Combine embeddings f_tilde[c] = (1 - alpha) * f[c] + alpha * f_c_tilde # Blend embeddings return f_tilde def ready_hijack_forward(self, alpha, width, height, ema_factor, step_start, step_end, tokens, token_count): """ Create a hook to modify the output of the forward pass of the cross attention module Arguments: alpha: float - The strength of the CTNMS correction, default 0.1 width: int - The width of the final output image map height: int - The height of the final output image map ema_factor: float - EMA smoothing factor, default 2.0 step_start: int - Wait to apply CTNMS until this step step_end: int - The number of steps to apply the CTNMS correction, after which don't tokens: list[int] - List of token indices to condition on token_count: int - The number of tokens in the prompt Only modifies the output of the cross attention modules that get context (i.e. text embedding) """ cross_attn_modules = self.get_cross_attn_modules() if len(cross_attn_modules) == 0: logger.error("No cross attention modules found, cannot run T2I-0") return # add field for last_attn_map plot_num = 0 for module in cross_attn_modules: self.add_field_cross_attn_modules(module, 't2i0_last_attn_map', None) self.add_field_cross_attn_modules(module, 't2i0_step', int(-1)) self.add_field_cross_attn_modules(module, 't2i0_step_start', int(step_start)) self.add_field_cross_attn_modules(module, 't2i0_step_end', int(step_end)) self.add_field_cross_attn_modules(module, 't2i0_ema', None) self.add_field_cross_attn_modules(module, 't2i0_ema_factor', float(ema_factor)) self.add_field_cross_attn_modules(module, 'plot_num', int(plot_num)) self.add_field_cross_attn_modules(module, 't2i0_to_v_map', None) self.add_field_cross_attn_modules(module.to_v, 't2i0_parent_module', [module]) self.add_field_cross_attn_modules(module, 't2i0_token_count', int(token_count)) self.add_field_cross_attn_modules(module, 'gaussian_blur', GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device)) if tokens is not None: self.add_field_cross_attn_modules(module, 't2i0_tokens', torch.tensor(tokens).to(device=shared.device, dtype=torch.int64)) else: self.add_field_cross_attn_modules(module, 't2i0_tokens', None) plot_num += 1 # def cross_token_non_maximum_suppression_pre(module, args, kwargs): # pass # pass def cross_token_non_maximum_suppression(module, input, kwargs, output): module.t2i0_step += 1 context = kwargs.get('context', None) if context is None: return if context.shape[1] % 77 != 0: logger.error("Context shape is not divisible by 77, cannot run T2I-0") return current_step = module.t2i0_step start_step = module.t2i0_step_start end_step = module.t2i0_step_end # Select token indices, default is ALL tokens token_count = module.t2i0_token_count token_indices = module.t2i0_tokens if current_step > end_step and end_step > 0: return if current_step < start_step: return batch_size, sequence_length, inner_dim = output.shape max_dims = width*height factor = math.isqrt(max_dims // sequence_length) # should be a square of 2 downscale_width = width // factor downscale_height = height // factor if downscale_width * downscale_height != sequence_length: print(f"Error: Width: {width}, height: {height}, Downscale width: {downscale_width}, height: {downscale_height}, Factor: {factor}, Max dims: {max_dims}\n") return # h = module.heads # head_dim = inner_dim // h dtype = output.dtype device = output.device # Multiply text embeddings into visual embeddings to_v_map = module.t2i0_to_v_map.detach().clone() to_v_inner_dim = to_v_map.size(-2) to_v_map = (to_v_map @ output.transpose(1, 2)).transpose(1, 2) to_v_attention_map = to_v_map.view(batch_size, downscale_height, downscale_width, to_v_inner_dim) # Original attention map attention_map = output.view(batch_size, downscale_height, downscale_width, inner_dim) if token_indices is None: selected_tokens = torch.arange(1, token_count, device=output.device) elif len(token_indices) == 0: selected_tokens = torch.arange(1, token_count, device=output.device) else: selected_tokens = module.t2i0_tokens if module.t2i0_ema is None: module.t2i0_ema = output.detach().clone() # Extract the attention maps for the selected tokens AC = to_v_attention_map[:, :, :, selected_tokens] # Extracting relevant attention maps # Extract and process the selected attention maps # GaussianBlur expects the input [..., C, H, W] gaussian_blur = module.gaussian_blur AC = AC.permute(0, 3, 1, 2) AC = gaussian_blur(AC) # Applying Gaussian smoothing AC = AC.permute(0, 2, 3, 1) # Find the maximum contributing token for each pixel M = torch.argmax(AC, dim=-1) one_hot_M = F.one_hot(M, num_classes=to_v_attention_map.size(-1)).to(dtype=dtype, device=device) # the attention map is of shape [batch_size, height, width, inner_dim] one_hot_M_z = rearrange(one_hot_M, 'b h w c -> b (h w) c') one_hot_M_z = one_hot_M_z @ module.t2i0_to_v_map one_hot_M_z = rearrange(one_hot_M_z, 'b (h w) c -> b h w c', h=downscale_height, w=downscale_width) suppressed_attention_map = one_hot_M_z * attention_map # Reshape back to original dimensions suppressed_attention_map = suppressed_attention_map.view(batch_size, sequence_length, inner_dim) # Calculate the EMA of the suppressed attention map if module.t2i0_ema_factor > 0: ema = module.t2i0_ema ema_factor = module.t2i0_ema_factor / (1 + current_step) # Add the suppressed attention map to the EMA ema = ema_factor * ema + (1 - ema_factor) * suppressed_attention_map module.t2i0_ema = ema out_tensor = (1 -alpha) * output + (alpha) * ema #out_tensor = (1-alpha) * ema + alpha * suppressed_attention_map else: out_tensor = (1-alpha) * output + alpha * suppressed_attention_map return out_tensor def t2i0_to_k_hook(module, input, kwargs, output): pass pass def t2i0_to_v_hook(module, input, kwargs, output): module.t2i0_parent_module[0].t2i0_to_v_map = output # Hook for module in cross_attn_modules: # handle = module.to_k.register_forward_hook(t2i0_to_k_hook, with_kwargs=True) handle = module.to_v.register_forward_hook(t2i0_to_v_hook, with_kwargs=True) handle = module.register_forward_hook(cross_token_non_maximum_suppression, with_kwargs=True) # handle = module.register_forward_pre_hook(cross_token_non_maximum_suppression_pre, with_kwargs=True) def get_cross_attn_modules(self): """ Get all cross attention modules """ try: m = shared.sd_model nlm = m.network_layer_mapping cross_attn_modules = [m for m in nlm.values() if 'CrossAttention' in m.__class__.__name__ and 'attn2' in m.network_layer_name] return cross_attn_modules except AttributeError: logger.exception("AttributeError while getting cross attention modules") return [] except Exception: logger.exception("Error while getting cross attention modules") return [] def add_field_cross_attn_modules(self, module, field, value): """ Add a field to a module if it doesn't exist """ if not hasattr(module, field): setattr(module, field, value) def remove_field_cross_attn_modules(self, module, field): """ Remove a field from a module if it exists """ if hasattr(module, field): delattr(module, field) def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, t2i0_params: list[T2I0StateParams]): if isinstance(params.text_cond, dict): text_cond = params.text_cond['crossattn'] # SD XL else: text_cond = params.text_cond # SD 1.5 sp = t2i0_params[0] window_size = sp.window_size_period correction_strength = sp.correction_strength score_threshold = sp.correction_threshold step = params.sampling_step step_start = sp.step_start step_end = sp.step_end tokens = sp.tokens if sp.tokens is not None else [] if step_start > step: return if step > step_end: return for batch_idx, batch in enumerate(text_cond): window = list(range(0, len(batch))) f_bar = self.correction_by_similarities(batch, window, score_threshold, window_size, correction_strength, tokens) if isinstance(params.text_cond, dict): params.text_cond['crossattn'][batch_idx] = f_bar else: params.text_cond[batch_idx] = f_bar return def get_xyz_axis_options(self) -> dict: xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module extra_axis_options = { xyz_grid.AxisOption("[T2I-0] Active", str, t2i0_apply_override('t2i0_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), xyz_grid.AxisOption("[T2I-0] Step Start", int, t2i0_apply_field("t2i0_step_start")), xyz_grid.AxisOption("[T2I-0] Step End", int, t2i0_apply_field("t2i0_step_end")), xyz_grid.AxisOption("[T2I-0] CbS Window Size", int, t2i0_apply_field("t2i0_window_size")), xyz_grid.AxisOption("[T2I-0] CbS Score Threshold", float, t2i0_apply_field("t2i0_correction_threshold")), xyz_grid.AxisOption("[T2I-0] CbS Correction Strength", float, t2i0_apply_field("t2i0_correction_strength")), xyz_grid.AxisOption("[T2I-0] CTNMS Alpha", float, t2i0_apply_field("t2i0_ctnms_alpha")), xyz_grid.AxisOption("[T2I-0] CTNMS EMA Smoothing Factor", float, t2i0_apply_field("t2i0_ema_factor")), } return extra_axis_options def plot_attention_map(attention_map: torch.Tensor, title, x_label="X", y_label="Y", save_path=None, plot_type="default"): """ Plots an attention map using matplotlib.pyplot Arguments: attention_map: Tensor - The attention map to plot title: str - The title of the plot x_label: str (optional) - The x-axis label y_label: str (optional) - The y-axis label save_path: str (optional) - The path to save the plot Returns: PIL.Image: The plot as a PIL image """ if attention_map.dim() == 3: attention_map = attention_map.squeeze(0).mean(2) plot_tools.plot_attention_map(attention_map, title, x_label, y_label, save_path, plot_type) def debug_plot_attention_map(attention_map): """ Plots an attention map using matplotlib.pyplot Arguments: attention_map: Tensor - The attention map to plot title: str - The title of the plot x_label: str (optional) - The x-axis label y_label: str (optional) - The y-axis label save_path: str (optional) - The path to save the plot Returns: PIL.Image: The plot as a PIL image """ plot_attention_map( attention_map, "Debug Output", save_path="F:\\incant\\temp\\AAA_out_temp.png" ) # XYZ Plot # Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py def t2i0_apply_override(field, boolean: bool = False): def fun(p, x, xs): if boolean: x = True if x.lower() == "true" else False setattr(p, field, x) return fun def t2i0_apply_field(field): def fun(p, x, xs): if not hasattr(p, "t2i0_active"): p.t2i0_active = True setattr(p, field, x) return fun # taken from modules/ui.py def get_token_count(text, steps, is_positive: bool = True): try: text, _ = extra_networks.parse_prompt(text) if is_positive: _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) else: prompt_flat_list = [text] prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) except Exception: # a parsing error can happen here during typing, and we don't want to bother the user with # messages related to it in console prompt_schedules = [[[steps, text]]] flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] token_count, max_length = max([sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) return token_count, max_length # thanks torch; removing hooks DOESN'T WORK # thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455 def _remove_all_forward_hooks( module: torch.nn.Module, hook_fn_name: Optional[str] = None ) -> None: """ This function removes all forward hooks in the specified module, without requiring any hook handles. This lets us clean up & remove any hooks that weren't property deleted. Warning: Various PyTorch modules and systems make use of hooks, and thus extreme caution should be exercised when removing all hooks. Users are recommended to give their hook function a unique name that can be used to safely identify and remove the target forward hooks. Args: module (nn.Module): The module instance to remove forward hooks from. hook_fn_name (str, optional): Optionally only remove specific forward hooks based on their function's __name__ attribute. Default: None """ if hook_fn_name is None: warn("Removing all active hooks can break some PyTorch modules & systems.") def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: if hasattr(module, "_forward_hooks"): if m._forward_hooks != OrderedDict(): if name is not None: dict_items = list(m._forward_hooks.items()) m._forward_hooks = OrderedDict( [(i, fn) for i, fn in dict_items if fn.__name__ != name] ) else: m._forward_hooks: Dict[int, Callable] = OrderedDict() def _remove_child_hooks( target_module: torch.nn.Module, hook_name: Optional[str] = None ) -> None: for _, child in target_module._modules.items(): if child is not None: _remove_hooks(child, hook_name) _remove_child_hooks(child, hook_name) # Remove hooks from target submodules _remove_child_hooks(module, hook_fn_name) # Remove hooks from the target module _remove_hooks(module, hook_fn_name)