import os
import logging
import copy
import gradio as gr
import torch
import re
from torchvision.transforms import GaussianBlur
from einops import rearrange

from modules import shared, script_callbacks
from modules.images import get_next_sequence_number
from modules.processing import StableDiffusionProcessing
from scripts.ui_wrapper import UIWrapper, arg
from scripts.incant_utils import module_hooks, plot_tools, prompt_utils

logger = logging.getLogger(__name__)

module_field_map = {
    'savemaps': True,
    'savemaps_batch': None,
    'savemaps_step': None,
    'savemaps_save_steps': None,
}

SUBMODULES = ['to_q', 'to_k', 'to_v']


class SaveAttentionMapsScript(UIWrapper):
    def __init__(self):
        self.infotext_fields: list = []
        self.paste_field_names: list = []

    def title(self) -> str:
        return "Save Attention Maps"

    def setup_ui(self, is_img2img) -> list:
        with gr.Accordion('Save Attention Maps', open=False):
            with gr.Row():
                active = gr.Checkbox(label='Active', value=False)
                map_types = gr.CheckboxGroup(
                    label='Map Types',
                    choices=['One-Hot Map', 'Per-Token Maps'],
                    value=['One-Hot Map'],
                    info='Select the type of attention maps to save.',
                )
                export_folder = gr.Textbox(visible=False, label='Export Folder', value='attention_maps', info='Folder to save attention maps to as a subdirectory of the outputs.')
                module_name_filter = gr.Textbox(label='Module Names', value='input_blocks_5_1_transformer_blocks_0_attn2', info='Module name to save attention maps for. If the substring is found in the module name, the attention maps will be saved for that module.')
                class_name_filter = gr.Textbox(label='Class Name Filter', value='CrossAttention', info='Filters eligible modules by class name.')
                save_every_n_step = gr.Slider(label='Save Every N Steps', value=0, minimum=0, maximum=100, step=1, info='Save attention maps every N steps. 0 saves only the last step.')
                print_modules = gr.Button(value='Print Modules To Console')
                print_modules.click(self.print_modules, inputs=[module_name_filter, class_name_filter])
        self.infotext_fields = []
        self.paste_field_names = []
        opts = [active, module_name_filter, class_name_filter, save_every_n_step, map_types]
        for opt in opts:
            opt.do_not_save_to_config = True
        return opts

    def before_process_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        # Always unhook the modules first
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        script_callbacks.remove_current_script_callbacks()
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

        setattr(p, 'savemaps_module_list', module_list)
        setattr(p, 'savemaps_map_types', map_types)

        if not active:
            return

        token_count, _ = prompt_utils.get_token_count(p.prompt, p.steps, True)
        if token_count <= 0:
            logger.warning("No tokens found in prompt. Skipping saving attention maps.")
            return

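        # Collect the prompt's token indices up front; the forward hooks later use them
        # to slice the cross-attention key projections down to just the prompt tokens.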
        setattr(p, 'savemaps_token_count', token_count)
        setattr(p, 'savemaps_step', 0)

        token_indices = []

        # Tokenize/decode the prompts
        tokenized_prompts = []
        batch_chunks, _ = prompt_utils.tokenize_prompt(p.prompt)
        for batch in batch_chunks:
            for sub_batch in batch:
                tokenized_prompts.append(prompt_utils.decode_tokenized_prompt(sub_batch.tokens))
        for tp_prompt in tokenized_prompts:
            for tp in tp_prompt:
                token_idx, token_id, word = tp
                # jank
                if token_id < 49406:
                    token_indices.append(token_idx)
                # sanitize tokenized prompts
                tp[2] = re.escape(word)
        setattr(p, 'savemaps_tokenized_prompts', tokenized_prompts)
        setattr(p, 'savemaps_token_indices', token_indices)

        # Make sure the output folder exists
        outpath_samples = p.outpath_samples  # Move this to plot tools?
        if not outpath_samples:
            logger.warning("No output path found. Skipping saving attention maps.")
            return
        output_folder_path = os.path.join(outpath_samples, 'attention_maps')
        if not os.path.exists(output_folder_path):
            logger.info(f"Creating directory: {output_folder_path}")
            os.makedirs(output_folder_path)

        # sequence number for saving
        seq_num = get_next_sequence_number(output_folder_path, basename='')
        setattr(p, 'savemaps_seq_num', seq_num)

        latent_shape = [p.height // p.rng.shape[1], p.width // p.rng.shape[2]]  # (height, width)

        save_steps = []
        min_step = max(save_every_n_step - 1, 0)
        if save_every_n_step > 0:
            save_steps = list(range(min_step, p.steps, save_every_n_step))
        else:
            save_steps = [p.steps - 1]
        # always save last step
        if p.steps - 1 not in save_steps:
            save_steps.append(p.steps - 1)
        setattr(p, 'savemaps_save_steps', save_steps)

        # Create fields in module
        value_map = copy.deepcopy(module_field_map)
        value_map['savemaps_save_steps'] = save_steps
        value_map['savemaps_step'] = 0
        #value_map['savemaps_shape'] = torch.tensor(latent_shape).to(device=shared.device, dtype=torch.int32)
        self.hook_modules(module_list, value_map, p)
        self.create_save_hook(module_list)

        def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
            """ Sets the step for all modules.
            The webui reports an incorrect step, so we just count it ourselves.
            """
            for module in module_list:
                module.savemaps_step = p.savemaps_step
                # logger.debug('Setting step to %d for %d modules', p.savemaps_step, len(module_list))
            p.savemaps_step += 1

        script_callbacks.on_cfg_denoiser(on_cfg_denoiser)

    def process(self, p, *args, **kwargs):
        pass

    def before_process(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def process_batch(self, p, *args, **kwargs):
        pass

    def postprocess_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        if getattr(p, 'savemaps_token_count', None) is None:
            self.unhook_modules(module_list, copy.deepcopy(module_field_map))
            return

        base_seq_num = getattr(p, 'savemaps_seq_num', None)
        map_types = getattr(p, 'savemaps_map_types', [])
        tokenized_prompts = getattr(p, 'savemaps_tokenized_prompts', None)
        token_indices = getattr(p, 'savemaps_token_indices', None)
        save_steps = getattr(p, 'savemaps_save_steps', None)
        save_image_path = os.path.join(p.outpath_samples, 'attention_maps')

        plot_is_self = False  # kind of useless

        for module in module_list:
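            # Each hooked module has accumulated one attention map per saved step in
            # 'savemaps_batch'; turn those into per-token and/or one-hot plots below.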
            network_layer_name = module.network_layer_name
            if not hasattr(module, 'savemaps_batch') or module.savemaps_batch is None:
                logger.error(f"No attention maps found for module: {network_layer_name}")
                continue

            # self attn maps are kind of useless atm
            is_self = getattr(module, 'savemaps_is_self', False)
            if is_self and not plot_is_self:
                continue

            # selfattn: seq_len = hw
            # crossattn: seq_len = # of tokens
            attn_maps = module.savemaps_batch  # (attn_map num, 2 * batch_num, height * width, sequence_len)
            attn_map_num, batch_num, hw, seq_len = attn_maps.shape
            token_indices = p.savemaps_token_indices
            save_steps = p.savemaps_save_steps

            downscale_h = round((hw * (p.height / p.width)) ** 0.5)
            downscale_w = hw // downscale_h

            gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)

            # Blur maps
            if is_self:
                attn_maps = attn_maps.view(attn_map_num * batch_num, downscale_h, downscale_w, seq_len)  # if self-attn, we need to blur over the sequence length
            attn_maps = attn_maps.permute(0, 3, 1, 2)  # (ab, seq_len, height, width)
            attn_maps = gaussian_blur(attn_maps)  # Applying Gaussian smoothing
            attn_maps = attn_maps.permute(0, 2, 3, 1)  # (ab, height, width, seq_len)

            if is_self:
                attn_maps = attn_maps.view(attn_map_num, 2, batch_num // 2, downscale_h * downscale_w, seq_len).mean(dim=1)  # (attn_map num, batch_num, hw, hw)
                attn_maps = attn_maps.unsqueeze(2)  # (attn_map num, batch_num, 1, hw, hw)
            else:
                attn_maps = rearrange(attn_maps, 'n (m b) (h w) t -> n m b t h w', m=2, h=downscale_h).mean(dim=1)  # (attn_map num, batch_num, token_idx, height, width)

            attn_map_num, batch_num, token_dim, h, w = attn_maps.shape

            output_dict_maps = []
            per_token_dict_maps = []
            one_hot_dict_maps = []

            if 'Per-Token Maps' in map_types:
                # write to dict
                for attn_map_idx in range(attn_maps.shape[0]):
                    for batch_idx in range(batch_num):
                        for token_idx in token_indices:
                            attnmap = attn_maps[attn_map_idx, batch_idx, token_idx]
                            _, token_id, word = tokenized_prompts[batch_idx][token_idx]
                            plot_type = f"({token_idx}, {token_id}, '{word}')"
                            filename_info = f'token{token_idx:04}'
                            plot_color = 'viridis'
                            map_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, filename_info, plot_color)
                            map_info.update({
                                'token_idx': token_idx,
                                'token_id': token_id,
                                'token_word': word,
                            })
                            output_dict_maps.append(map_info)

            if 'One-Hot Map' in map_types:
                one_hot_map = attn_maps[:, :, token_indices]  # (attn_map num, batch_num, token_idx, height, width)
                one_hot_map = one_hot_map.argmax(dim=2, keepdim=True)
                one_hot_map = one_hot_map.to(dtype=torch.float16)
                # quantize to a stable number of colors s.t. each token index maps to its own bin
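                # (argmax over the token dim picks the dominant token per pixel; scaling the
                # indices by 1/num_colors spreads them over [0, 1) for a stable colormap)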
                num_colors = max(len(token_indices), 1)
                min_val, max_val = one_hot_map.min(), one_hot_map.max()
                step = 1 / num_colors
                one_hot_map *= step
                one_hot_map = one_hot_map.sum(dim=2)  # (attn_map num, batch_num, height, width)

                # write to dict
                for attn_map_idx in range(one_hot_map.shape[0]):
                    for batch_idx in range(batch_num):
                        plot_type = "One Hot"
                        plot_color = 'plasma'
                        attnmap = one_hot_map[attn_map_idx, batch_idx]
                        ohm_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, 'ohm', plot_color)
                        output_dict_maps.append(ohm_info)

            # Save maps from map dict
            for md in output_dict_maps:
                base_seq_num = md['seq_num']
                network_layer_name = md['network_layer_name']
                savestep_num = md['savestep_num']
                attn_map_idx = md['attn_map_idx']
                batch_idx = md['batch_idx']

                # output filename and path
                filename_info = md['filename_info']
                if len(filename_info) > 0:
                    filename_info = f'{filename_info}_'
                out_file_name = f'{base_seq_num:04}-{network_layer_name}_{filename_info}step{savestep_num:04}_attnmap_{attn_map_idx:04}_batch{batch_idx:04}.png'
                out_save_path = os.path.join(save_image_path, out_file_name)

                # plot title
                plot_type = md['plot_type']
                plot_color = md['plot_color']
                plot_title = f"{network_layer_name}\nStep {savestep_num}"
                if len(plot_type) > 0:
                    plot_title += f", {plot_type}"

                attn_map = md['attnmap']
                plot_tools.plot_attention_map(
                    attention_map=attn_map,
                    title=plot_title,
                    save_path=out_save_path,
                    plot_type=plot_color,
                )

                if shared.state.interrupted:
                    self.unhook_modules(module_list, copy.deepcopy(module_field_map))
                    return

        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def create_base_dict(self, plot_type: str, base_seq_num: int, network_layer_name: str, save_steps: list, attn_map_idx: int, batch_idx: int, attnmap: torch.Tensor, filename_info: str, plot_color: str):
        """ Create a base dictionary for saving attention maps, with the minimum metadata that the save function expects.
        Arguments:
            plot_type: str - name of the type of plot, used in the plot title
            base_seq_num: int - start sequence number for saving, prefixes the filename with "000xx-" where xx is the sequence number
            network_layer_name: str - the module's network layer name
            save_steps: list[int] - list of steps to save attention maps for, should be the same length as the number of attention maps
            attn_map_idx: int - index of the attention map
            batch_idx: int - index of the batch
            attnmap: torch.Tensor - attention map of shape [C, H, W]
            filename_info: str - a string that goes in the middle of the filename f"000xx-{filename_info}-000yy.png"
            plot_color: str - one of the matplotlib color maps (default is 'viridis')
        """
        network_layer_name = network_layer_name.removeprefix('diffusion_model_')
        network_layer_name = network_layer_name.replace('transformer_blocks_', 'tr_bl_')
        base_dict = {
            'plot_type': plot_type,
            'seq_num': base_seq_num + batch_idx,
            'step': save_steps[attn_map_idx] + 1,
            'network_layer_name': network_layer_name,
            'attn_map_idx': attn_map_idx,
            'savestep_num': save_steps[attn_map_idx] + 1,
            'batch_idx': batch_idx,
            'attnmap': attnmap,
            'filename_info': filename_info,
            'plot_color': plot_color,
        }
        return base_dict

    def unhook_callbacks(self) -> None:
        pass

    def get_xyz_axis_options(self) -> dict:
        return {}

    def get_infotext_fields(self) -> list:
        return self.infotext_fields

    def create_save_hook(self, module_list):
        pass

    def hook_modules(self, module_list: list, value_map: dict, p: StableDiffusionProcessing):
        def savemaps_hook(module, input, kwargs, output):
            """ Hook to save attention maps every N steps, or the last step if N is 0.
            Saves attention maps to a field named 'savemaps_batch' on the module,
            with shape (num saved steps, batch_num, height * width, seq_len).
            """
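            # The to_q/to_k submodule hooks below stash the projected query/key tensors on
            # the parent module; the attention probabilities are recomputed from them here.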
            #module.savemaps_step += 1
            if module.savemaps_step not in module.savemaps_save_steps:
                return

            reweight_crossattn = True
            is_self = getattr(module, 'savemaps_is_self', False)
            to_q_map = getattr(module, 'savemaps_to_q_map', None)
            to_k_map = to_q_map if is_self else getattr(module, 'savemaps_to_k_map', None)

            # we want to reweight the attention scores by removing the influence of the first token
            orig_seq_len = to_k_map.shape[1]
            # token_count = module.savemaps_token_count
            # min_token = 0
            # max_token = min(token_count+1, orig_seq_len)
            token_indices = module.savemaps_token_indices
            if not is_self and reweight_crossattn:
                to_k_map = to_k_map[:, token_indices, :]

            attn_map = get_attention_scores(to_q_map, to_k_map, dtype=to_q_map.dtype)
            b, hw, seq_len = attn_map.shape

            if not is_self and reweight_crossattn:
                #to_attn_zeros = torch.zeros([b, hw]).unsqueeze(-1).to(device=shared.device, dtype=attn_map.dtype) # (batch, h*w, 1)
                #attn_map = torch.cat([to_attn_zeros, attn_map], dim=-1) # re pad to original token dim size
                left_pad = 1
                right_pad = orig_seq_len - seq_len - 1
                attn_map = torch.nn.functional.pad(attn_map, (left_pad, right_pad), value=0)  # re pad to original token dim size

            # multiply into text embeddings
            attn_map = attn_map.unsqueeze(0)
            #attn_map = attn_map.mean(dim=-1)

            if module.savemaps_batch is None:
                module.savemaps_batch = attn_map
            else:
                module.savemaps_batch = torch.cat([module.savemaps_batch, attn_map], dim=0)

        def savemaps_to_q_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_q_map', output)

        def savemaps_to_k_hook(module, input, kwargs, output):
            if not module.savemaps_parent_module[0].savemaps_is_self:
                setattr(module.savemaps_parent_module[0], 'savemaps_to_k_map', output)

        def savemaps_to_v_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_v_map', output)

        #for module, kv in zip(module_list, value_map.items()):
        for module in module_list:
            # logger.debug('Adding hook to %s', module.network_layer_name)
            for key_name, default_value in value_map.items():
                module_hooks.modules_add_field(module, key_name, default_value)
            module_hooks.module_add_forward_hook(module, savemaps_hook, 'forward', with_kwargs=True)
            module_hooks.modules_add_field(module, 'savemaps_token_count', p.savemaps_token_count)
            module_hooks.modules_add_field(module, 'savemaps_token_indices', p.savemaps_token_indices)
            if module.network_layer_name.endswith('attn1'):  # self attn
                module_hooks.modules_add_field(module, 'savemaps_is_self', True)
            if module.network_layer_name.endswith('attn2'):  # cross attn
                module_hooks.modules_add_field(module, 'savemaps_is_self', False)
            for module_name in SUBMODULES:
                if not hasattr(module, module_name):
                    logger.error(f"Submodule not found: {module_name} in module: {module.network_layer_name}")
                    continue
                submodule = getattr(module, module_name)
                hook_fn_name = f'savemaps_{module_name}_hook'
                hook_fn = locals().get(hook_fn_name, None)
                if not hook_fn:
                    logger.error(f"Hook function '{hook_fn_name}' not found for submodule: {module_name}")
                    continue
                module_hooks.modules_add_field(submodule, 'savemaps_parent_module', [module])
                module_hooks.module_add_forward_hook(submodule, hook_fn, 'forward', with_kwargs=True)

    def unhook_modules(self, module_list: list, value_map: dict):
        for module in module_list:
            for key_name, _ in value_map.items():
                module_hooks.modules_remove_field(module, key_name)
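            # Also remove the extra fields and submodule hooks added by hook_modules()
            # that are not part of value_map.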
            module_hooks.modules_remove_field(module, 'savemaps_is_self')
            module_hooks.modules_remove_field(module, 'savemaps_token_count')
            module_hooks.modules_remove_field(module, 'savemaps_token_indices')
            module_hooks.remove_module_forward_hook(module, 'savemaps_hook')
            for module_name in SUBMODULES:
                module_hooks.modules_remove_field(module, f'savemaps_{module_name}_map')
                if hasattr(module, module_name):
                    submodule = getattr(module, module_name)
                    module_hooks.modules_remove_field(submodule, 'savemaps_parent_module')
                    module_hooks.remove_module_forward_hook(submodule, f'savemaps_{module_name}_hook')

    def print_modules(self, module_name_filter, class_name_filter):
        logger.info("Module name filter: '%s', Class name filter: '%s'", module_name_filter, class_name_filter)
        modules = self.get_modules_by_filter(module_name_filter, class_name_filter)
        module_names = ""
        if len(modules) > 0:
            module_names = "\n".join([f"{m.network_layer_name}: {m.__class__.__name__}" for m in modules])
        logger.info("Modules found:\n----------\n%s\n----------\n", module_names)

    def get_modules_by_filter(self, module_name_filter, class_name_filter):
        if len(class_name_filter) == 0:
            class_name_filter = None
        if len(module_name_filter) == 0:
            module_name_filter = None
        found_modules = module_hooks.get_modules(module_name_filter, class_name_filter)
        if len(found_modules) == 0:
            logger.warning(f"No modules found with module name filter: {module_name_filter} and class name filter: {class_name_filter}")
        return found_modules


def get_attention_scores(to_q_map, to_k_map, dtype):
    """ Calculate the attention scores for the given query and key maps.
    Arguments:
        to_q_map: torch.Tensor - query map
        to_k_map: torch.Tensor - key map
        dtype: torch.dtype - data type of the tensor
    Returns:
        torch.Tensor - attention scores
    """
    # based on diffusers models/attention.py "get_attention_scores"
    # use in-place operations vs. softmax to save memory: https://stackoverflow.com/questions/53732209/torch-in-place-operations-to-save-memory-softmax
    # 512x: 2.65G -> 2.47G
    attn_probs = to_q_map @ to_k_map.transpose(-1, -2)
    attn_probs = attn_probs.to(dtype=torch.float32)
    channel_dim = to_q_map.size(1)
    attn_probs /= (channel_dim ** 0.5)
    attn_probs -= attn_probs.max()  # avoid nan by converting to float32 and subtracting max
    attn_probs = attn_probs.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype)
    attn_probs = attn_probs.to(dtype=dtype)
    return attn_probs