import os
import copy
import logging
import re

import gradio as gr
import torch
from einops import rearrange
from torchvision.transforms import GaussianBlur

from modules import shared, script_callbacks
from modules.images import get_next_sequence_number
from modules.processing import StableDiffusionProcessing
from scripts.ui_wrapper import UIWrapper, arg
from scripts.incant_utils import module_hooks, plot_tools, prompt_utils

logger = logging.getLogger(__name__)

# Fields attached to every hooked module; also used to clean the state up afterwards.
module_field_map = {
    'savemaps': True,
    'savemaps_batch': None,
    'savemaps_step': None,
    'savemaps_save_steps': None,
}

# Attention projection submodules whose outputs are captured by forward hooks.
SUBMODULES = ['to_q', 'to_k', 'to_v']

class SaveAttentionMapsScript(UIWrapper):
    def __init__(self):
        self.infotext_fields: list = []
        self.paste_field_names: list = []

    def title(self) -> str:
        return "Save Attention Maps"

    def setup_ui(self, is_img2img) -> list:
        with gr.Accordion('Save Attention Maps', open=False):
            with gr.Row():
                active = gr.Checkbox(label='Active', value=False)
                map_types = gr.CheckboxGroup(
                    label='Map Types',
                    choices=['One-Hot Map', 'Per-Token Maps'],
                    value=['One-Hot Map'],
                    info='Select the types of attention maps to save.',
                )
                export_folder = gr.Textbox(visible=False, label='Export Folder', value='attention_maps', info='Folder to save attention maps to, as a subdirectory of the outputs.')
                module_name_filter = gr.Textbox(label='Module Names', value='input_blocks_5_1_transformer_blocks_0_attn2', info='Module name to save attention maps for. If this substring is found in a module name, attention maps are saved for that module.')
                class_name_filter = gr.Textbox(label='Class Name Filter', value='CrossAttention', info='Filters eligible modules by class name.')
                save_every_n_step = gr.Slider(label='Save Every N Steps', value=0, minimum=0, maximum=100, step=1, info='Save attention maps every N steps. 0 saves only the last step.')
                print_modules = gr.Button(value='Print Modules To Console')
                print_modules.click(self.print_modules, inputs=[module_name_filter, class_name_filter])

        self.infotext_fields = []
        self.paste_field_names = []

        opts = [active, module_name_filter, class_name_filter, save_every_n_step, map_types]
        for opt in opts:
            opt.do_not_save_to_config = True
        return opts
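
    # The options returned by setup_ui are passed positionally to the
    # processing callbacks below, so the order of `opts` must match the
    # callback signatures.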

    def before_process_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)

        # Clear any callbacks and module state left over from a previous run.
        script_callbacks.remove_current_script_callbacks()
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

        setattr(p, 'savemaps_module_list', module_list)
        setattr(p, 'savemaps_map_types', map_types)

        if not active:
            return

        token_count, _ = prompt_utils.get_token_count(p.prompt, p.steps, True)
        if token_count <= 0:
            logger.warning("No tokens found in prompt. Skipping saving attention maps.")
            return

        setattr(p, 'savemaps_token_count', token_count)
        setattr(p, 'savemaps_step', 0)

        # Collect the indices of the prompt's real tokens, skipping CLIP's
        # special tokens (BOS is 49406, EOT is 49407).
        token_indices = []
        tokenized_prompts = []
        batch_chunks, _ = prompt_utils.tokenize_prompt(p.prompt)
        for batch in batch_chunks:
            for sub_batch in batch:
                tokenized_prompts.append(prompt_utils.decode_tokenized_prompt(sub_batch.tokens))
        for tp_prompt in tokenized_prompts:
            for tp in tp_prompt:
                token_idx, token_id, word = tp
                if token_id < 49406:
                    token_indices.append(token_idx)
                # Escape the decoded word so it is safe to use in regexes later.
                tp[2] = re.escape(word)

        setattr(p, 'savemaps_tokenized_prompts', tokenized_prompts)
        setattr(p, 'savemaps_token_indices', token_indices)

        outpath_samples = p.outpath_samples
        if not outpath_samples:
            logger.warning("No output path found. Skipping saving attention maps.")
            return
        output_folder_path = os.path.join(outpath_samples, 'attention_maps')
        if not os.path.exists(output_folder_path):
            logger.info(f"Creating directory: {output_folder_path}")
            os.makedirs(output_folder_path)

        seq_num = get_next_sequence_number(output_folder_path, basename='')
        setattr(p, 'savemaps_seq_num', seq_num)

        # Ratio between image and latent dimensions (currently unused).
        latent_shape = [p.height // p.rng.shape[1], p.width // p.rng.shape[2]]

        # Build the list of steps at which maps are saved; the final step is
        # always included.
        save_steps = []
        min_step = max(save_every_n_step - 1, 0)
        if save_every_n_step > 0:
            save_steps = list(range(min_step, p.steps, save_every_n_step))
        else:
            save_steps = [p.steps - 1]

        if p.steps - 1 not in save_steps:
            save_steps.append(p.steps - 1)
        setattr(p, 'savemaps_save_steps', save_steps)
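        # For example, with p.steps == 20 and save_every_n_step == 5 this
        # yields save_steps == [4, 9, 14, 19]; with save_every_n_step == 0 it
        # yields [19], i.e. only the last step.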

        # Initialize the per-module fields, then hook the attention modules.
        value_map = copy.deepcopy(module_field_map)
        value_map['savemaps_save_steps'] = save_steps
        value_map['savemaps_step'] = 0

        self.hook_modules(module_list, value_map, p)
        self.create_save_hook(module_list)

        def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
            """ Set the current step on all hooked modules.

            The webui reports an incorrect step, so we count it ourselves.
            """
            for module in module_list:
                module.savemaps_step = p.savemaps_step
            p.savemaps_step += 1

        script_callbacks.on_cfg_denoiser(on_cfg_denoiser)

    def process(self, p, *args, **kwargs):
        pass

    def before_process(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def process_batch(self, p, *args, **kwargs):
        pass

    def postprocess_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)

        if getattr(p, 'savemaps_token_count', None) is None:
            self.unhook_modules(module_list, copy.deepcopy(module_field_map))
            return

        base_seq_num = getattr(p, 'savemaps_seq_num', None)
        map_types = getattr(p, 'savemaps_map_types', [])
        tokenized_prompts = getattr(p, 'savemaps_tokenized_prompts', None)
        token_indices = getattr(p, 'savemaps_token_indices', None)
        save_steps = getattr(p, 'savemaps_save_steps', None)
        save_image_path = os.path.join(p.outpath_samples, 'attention_maps')

        # Plotting self-attention maps is disabled for now.
        plot_is_self = False

        for module in module_list:
            network_layer_name = module.network_layer_name

            if not hasattr(module, 'savemaps_batch') or module.savemaps_batch is None:
                logger.error(f"No attention maps found for module: {network_layer_name}")
                continue

            is_self = getattr(module, 'savemaps_is_self', False)
            if is_self and not plot_is_self:
                continue

            attn_maps = module.savemaps_batch
            attn_map_num, batch_num, hw, seq_len = attn_maps.shape

            # Recover the 2D spatial size of the flattened attention map from
            # the image's aspect ratio.
            downscale_h = round((hw * (p.height / p.width)) ** 0.5)
            downscale_w = hw // downscale_h
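            # For example, a 512x768 image whose maps have hw == 1536 gives
            # downscale_h == round((1536 * 512/768) ** 0.5) == 32 and
            # downscale_w == 1536 // 32 == 48.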
            gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)

            if is_self:
                # Blur self-attention maps spatially before averaging.
                attn_maps = attn_maps.view(attn_map_num * batch_num, downscale_h, downscale_w, seq_len)
                attn_maps = attn_maps.permute(0, 3, 1, 2)
                attn_maps = gaussian_blur(attn_maps)
                attn_maps = attn_maps.permute(0, 2, 3, 1)

            # Average over the two conditioning halves of the batch (m == 2).
            if is_self:
                attn_maps = attn_maps.view(attn_map_num, 2, batch_num // 2, downscale_h * downscale_w, seq_len).mean(dim=1)
                attn_maps = attn_maps.unsqueeze(2)
            else:
                attn_maps = rearrange(attn_maps, 'n (m b) (h w) t -> n m b t h w', m=2, h=downscale_h).mean(dim=1)
            attn_map_num, batch_num, token_dim, h, w = attn_maps.shape

            output_dict_maps = []

            if 'Per-Token Maps' in map_types:
                # Queue one plot per (map, batch, token) triple.
                for attn_map_idx in range(attn_maps.shape[0]):
                    for batch_idx in range(batch_num):
                        for token_idx in token_indices:
                            attnmap = attn_maps[attn_map_idx, batch_idx, token_idx]
                            _, token_id, word = tokenized_prompts[batch_idx][token_idx]

                            plot_type = f"({token_idx}, {token_id}, '{word}')"
                            filename_info = f'token{token_idx:04}'
                            plot_color = 'viridis'

                            map_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, filename_info, plot_color)
                            map_info.update({
                                'token_idx': token_idx,
                                'token_id': token_id,
                                'token_word': word,
                            })
                            output_dict_maps.append(map_info)

            if 'One-Hot Map' in map_types:
                # Give each pixel the index of its strongest token, then scale
                # the indices into [0, 1) so they map onto a colormap.
                one_hot_map = attn_maps[:, :, token_indices]
                one_hot_map = one_hot_map.argmax(dim=2, keepdim=True)
                one_hot_map = one_hot_map.to(dtype=torch.float16)

                num_colors = max(len(token_indices), 1)
                step = 1 / num_colors
                one_hot_map *= step
                one_hot_map = one_hot_map.sum(dim=2)

                for attn_map_idx in range(one_hot_map.shape[0]):
                    for batch_idx in range(batch_num):
                        plot_type = "One Hot"
                        plot_color = 'plasma'
                        attnmap = one_hot_map[attn_map_idx, batch_idx]
                        ohm_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, 'ohm', plot_color)
                        output_dict_maps.append(ohm_info)

            # Plot and save every queued attention map.
            for md in output_dict_maps:
                base_seq_num = md['seq_num']
                network_layer_name = md['network_layer_name']
                savestep_num = md['savestep_num']
                attn_map_idx = md['attn_map_idx']
                batch_idx = md['batch_idx']

                filename_info = md['filename_info']
                if len(filename_info) > 0:
                    filename_info = f'{filename_info}_'

                out_file_name = f'{base_seq_num:04}-{network_layer_name}_{filename_info}step{savestep_num:04}_attnmap_{attn_map_idx:04}_batch{batch_idx:04}.png'
                out_save_path = os.path.join(save_image_path, out_file_name)

                plot_type = md['plot_type']
                plot_color = md['plot_color']
                plot_title = f"{network_layer_name}\nStep {savestep_num}"
                if len(plot_type) > 0:
                    plot_title += f", {plot_type}"

                attn_map = md['attnmap']
                plot_tools.plot_attention_map(
                    attention_map=attn_map,
                    title=plot_title,
                    save_path=out_save_path,
                    plot_type=plot_color,
                )

            if shared.state.interrupted:
                self.unhook_modules(module_list, copy.deepcopy(module_field_map))
                return

        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def create_base_dict(self, plot_type: str, base_seq_num: int, network_layer_name: str, save_steps: list, attn_map_idx: int, batch_idx: int, attnmap: torch.Tensor, filename_info: str, plot_color: str):
        """ Create a base dictionary with the minimum metadata that the save function expects.

        Arguments:
            plot_type: str - name of the type of plot, used in the plot title
            base_seq_num: int - start sequence number for saving; prefixes the filename with "000xx-" where xx is the sequence number
            network_layer_name: str - the module's network layer name
            save_steps: list[int] - list of steps to save attention maps for; should be the same length as the number of attention maps
            attn_map_idx: int - index of the attention map
            batch_idx: int - index of the batch
            attnmap: torch.Tensor - attention map of shape [C, H, W]
            filename_info: str - a string that goes in the middle of the filename f"000xx-{filename_info}-000yy.png"
            plot_color: str - one of the matplotlib color maps (default is 'viridis')
        """
        # Shorten the layer name so filenames stay a manageable length.
        network_layer_name = network_layer_name.removeprefix('diffusion_model_')
        network_layer_name = network_layer_name.replace('transformer_blocks_', 'tr_bl_')
        base_dict = {
            'plot_type': plot_type,
            'seq_num': base_seq_num + batch_idx,
            'step': save_steps[attn_map_idx] + 1,
            'network_layer_name': network_layer_name,
            'attn_map_idx': attn_map_idx,
            'savestep_num': save_steps[attn_map_idx] + 1,
            'batch_idx': batch_idx,
            'attnmap': attnmap,
            'filename_info': filename_info,
            'plot_color': plot_color,
        }
        return base_dict
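
    # For illustration, a dict built with base_seq_num=3, batch_idx=0,
    # filename_info='ohm' and network_layer_name 'input_blocks_5_1_tr_bl_0_attn2'
    # at save step 19 yields the filename:
    #   0003-input_blocks_5_1_tr_bl_0_attn2_ohm_step0020_attnmap_0000_batch0000.png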

    def unhook_callbacks(self) -> None:
        pass

    def get_xyz_axis_options(self) -> dict:
        return {}

    def get_infotext_fields(self) -> list:
        return self.infotext_fields

    def create_save_hook(self, module_list):
        pass

    def hook_modules(self, module_list: list, value_map: dict, p: StableDiffusionProcessing):
        def savemaps_hook(module, input, kwargs, output):
            """ Hook to save attention maps every N steps, or the last step if N is 0.

            Accumulates maps into the module's 'savemaps_batch' field with shape
            (num_saved_steps, batch_num, height * width, seq_len).
            """
            if module.savemaps_step not in module.savemaps_save_steps:
                return

            # For cross-attention, compute scores only for the prompt's real
            # token indices, then pad the result back to the full length.
            reweight_crossattn = True

            is_self = getattr(module, 'savemaps_is_self', False)
            to_q_map = getattr(module, 'savemaps_to_q_map', None)
            # Self-attention attends to itself, so the query map doubles as the key map.
            to_k_map = to_q_map if is_self else getattr(module, 'savemaps_to_k_map', None)

            orig_seq_len = to_k_map.shape[1]
            token_indices = module.savemaps_token_indices

            if not is_self and reweight_crossattn:
                to_k_map = to_k_map[:, token_indices, :]

            attn_map = get_attention_scores(to_q_map, to_k_map, dtype=to_q_map.dtype)
            b, hw, seq_len = attn_map.shape

            if not is_self and reweight_crossattn:
                # Pad the token dimension back to the original sequence length
                # so indices line up with the tokenized prompt (BOS sits at
                # index 0, hence the left pad of 1).
                left_pad = 1
                right_pad = orig_seq_len - seq_len - 1
                attn_map = torch.nn.functional.pad(attn_map, (left_pad, right_pad), value=0)
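                # For example, with orig_seq_len == 77 and 10 kept token
                # indices, seq_len == 10, so left_pad == 1 and right_pad == 66
                # restore 77 entries along the token dimension.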

            attn_map = attn_map.unsqueeze(0)

            # Accumulate maps across saved steps.
            if module.savemaps_batch is None:
                module.savemaps_batch = attn_map
            else:
                module.savemaps_batch = torch.cat([module.savemaps_batch, attn_map], dim=0)

        def savemaps_to_q_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_q_map', output)

        def savemaps_to_k_hook(module, input, kwargs, output):
            if not module.savemaps_parent_module[0].savemaps_is_self:
                setattr(module.savemaps_parent_module[0], 'savemaps_to_k_map', output)

        def savemaps_to_v_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_v_map', output)

        for module in module_list:
            for key_name, default_value in value_map.items():
                module_hooks.modules_add_field(module, key_name, default_value)

            module_hooks.module_add_forward_hook(module, savemaps_hook, 'forward', with_kwargs=True)
            module_hooks.modules_add_field(module, 'savemaps_token_count', p.savemaps_token_count)
            module_hooks.modules_add_field(module, 'savemaps_token_indices', p.savemaps_token_indices)

            # attn1 modules are self-attention; attn2 modules are cross-attention.
            if module.network_layer_name.endswith('attn1'):
                module_hooks.modules_add_field(module, 'savemaps_is_self', True)
            if module.network_layer_name.endswith('attn2'):
                module_hooks.modules_add_field(module, 'savemaps_is_self', False)

            # Hook the q/k/v projections so the parent hook can read their outputs.
            for module_name in SUBMODULES:
                if not hasattr(module, module_name):
                    logger.error(f"Submodule not found: {module_name} in module: {module.network_layer_name}")
                    continue
                submodule = getattr(module, module_name)
                hook_fn_name = f'savemaps_{module_name}_hook'
                hook_fn = locals().get(hook_fn_name, None)
                if not hook_fn:
                    logger.error(f"Hook function '{hook_fn_name}' not found for submodule: {module_name}")
                    continue

                module_hooks.modules_add_field(submodule, 'savemaps_parent_module', [module])
                module_hooks.module_add_forward_hook(submodule, hook_fn, 'forward', with_kwargs=True)

    def unhook_modules(self, module_list: list, value_map: dict):
        for module in module_list:
            for key_name, _ in value_map.items():
                module_hooks.modules_remove_field(module, key_name)
            module_hooks.modules_remove_field(module, 'savemaps_is_self')
            module_hooks.modules_remove_field(module, 'savemaps_token_count')
            module_hooks.modules_remove_field(module, 'savemaps_token_indices')
            module_hooks.remove_module_forward_hook(module, 'savemaps_hook')
            for module_name in SUBMODULES:
                module_hooks.modules_remove_field(module, f'savemaps_{module_name}_map')

                if hasattr(module, module_name):
                    submodule = getattr(module, module_name)
                    module_hooks.modules_remove_field(submodule, 'savemaps_parent_module')
                    module_hooks.remove_module_forward_hook(submodule, f'savemaps_{module_name}_hook')

    def print_modules(self, module_name_filter, class_name_filter):
        logger.info("Module name filter: '%s', Class name filter: '%s'", module_name_filter, class_name_filter)
        modules = self.get_modules_by_filter(module_name_filter, class_name_filter)
        module_names = ""
        if len(modules) > 0:
            module_names = "\n".join([f"{m.network_layer_name}: {m.__class__.__name__}" for m in modules])
        logger.info("Modules found:\n----------\n%s\n----------\n", module_names)

    def get_modules_by_filter(self, module_name_filter, class_name_filter):
        # Empty filters are treated as "match everything".
        if len(class_name_filter) == 0:
            class_name_filter = None
        if len(module_name_filter) == 0:
            module_name_filter = None
        found_modules = module_hooks.get_modules(module_name_filter, class_name_filter)
        if len(found_modules) == 0:
            logger.warning(f"No modules found with module name filter: {module_name_filter} and class name filter: {class_name_filter}")
        return found_modules


def get_attention_scores(to_q_map, to_k_map, dtype):
    """ Calculate the attention scores for the given query and key maps.

    Arguments:
        to_q_map: torch.Tensor - query map of shape [batch, seq_q, dim]
        to_k_map: torch.Tensor - key map of shape [batch, seq_k, dim]
        dtype: torch.dtype - data type of the returned tensor
    Returns:
        torch.Tensor - attention scores of shape [batch, seq_q, seq_k]
    """
    attn_probs = to_q_map @ to_k_map.transpose(-1, -2)
    attn_probs = attn_probs.to(dtype=torch.float32)

    # Scale by sqrt(d_k), the channel dimension of the projections, and
    # subtract the maximum before the softmax for numerical stability
    # (softmax is invariant to subtracting a constant).
    channel_dim = to_q_map.size(-1)
    attn_probs /= (channel_dim ** 0.5)
    attn_probs -= attn_probs.max()

    attn_probs = attn_probs.softmax(dim=-1).to(device=shared.device, dtype=dtype)

    return attn_probs
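
# Illustrative usage with hypothetical shapes: a query map of [2, 4096, 320]
# and a key map of [2, 77, 320] yield per-pixel token probabilities of shape
# [2, 4096, 77], where each row sums to 1:
#   q = torch.randn(2, 4096, 320)
#   k = torch.randn(2, 77, 320)
#   probs = get_attention_scores(q, k, dtype=q.dtype)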