dikdimon's picture
Upload extensions using SD-Hub extension
3dabe4a verified
import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
import scipy.stats as stats
from scripts.ui_wrapper import UIWrapper, arg
from modules import script_callbacks, patches
from modules.hypernetworks import hypernetwork
#import modules.sd_hijack_optimizations
from modules.script_callbacks import CFGDenoiserParams, CFGDenoisedParams, AfterCFGCallbackParams
from modules.prompt_parser import reconstruct_multicond_batch
from modules.processing import StableDiffusionProcessing
#from modules.shared import sd_model, opts
from modules.sd_samplers_cfg_denoiser import catenate_conds
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules import shared
import math
import torch
from torch.nn import functional as F
from torchvision.transforms import GaussianBlur
from warnings import warn
from typing import Callable, Dict, Optional
from collections import OrderedDict
import torch
from scripts.incant_utils import module_hooks
# from pytorch_memlab import LineProfiler, MemReporter
# reporter = MemReporter()
logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
incantations_debug = environ.get("INCANTAIONS_DEBUG", False)
"""
An unofficial implementation of "Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee" for Automatic1111 WebUI.
This builds upon the code provided in the official S-CFG repository: https://github.com/SmilesDZgk/S-CFG
@inproceedings{shen2024rethinking,
title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee},
author={Shen, Dazhong and Song, Guanglu and Xue, Zeyue and Wang, Fu-Yun and Liu, Yu},
booktitle={Proceedings of The IEEE/CVF Computer Vision and Pattern Recognition Conference (CVPR)},
year={2024}
}
Parts of the code are based on Diffusers under the Apache License 2.0:
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-incantations
"""
handles = []
global_scale = 1
SCFG_MODULES = ['to_q', 'to_k']
class SCFGStateParams:
def __init__(self):
self.scfg_scale:float = 0.8
self.rate_min = 0.8
self.rate_max = 3.0
self.rate_clamp = 15.0
self.R = 4
self.start_step = 0
self.end_step = 150
self.gaussian_smoothing = None
self.max_sampling_steps = -1
self.current_step = 0
self.height = -1
self.width = -1
self.statistics = {
"min_rate": float('inf'),
"max_rate": float('-inf'),
}
self.mask_t = None
self.mask_fore = None
self.denoiser = None
self.all_crossattn_modules = None
self.patched_combined_denoised = None
class SCFGExtensionScript(UIWrapper):
def __init__(self):
self.cached_c = [None, None]
self.handles = []
# Extension title in menu UI
def title(self) -> str:
return "S-CFG"
# Decide to show menu in txt2img or img2img
def show(self, is_img2img):
return scripts.AlwaysVisible
# Setup menu ui detail
def setup_ui(self, is_img2img) -> list:
with gr.Accordion('S-CFG', open=False):
active = gr.Checkbox(value=False, default=False, label="Active", elem_id='scfg_active', info="Computationally expensive. A batch size of 4 for 1024x1024 will max out a 24GB card!")
with gr.Row():
scfg_scale = gr.Slider(value = 1.0, minimum = 0, maximum = 10.0, step = 0.1, label="SCFG Scale", elem_id = 'scfg_scale', info="")
scfg_r = gr.Slider(value = 4, minimum = 1, maximum = 16, step = 1, label="SCFG R", elem_id = 'scfg_r', info="Scale factor. Greater R uses more memory.")
with gr.Row():
scfg_rate_min = gr.Slider(value = 0.8, minimum = 0, maximum = 30.0, step = 0.1, label="Min Rate", elem_id = 'scfg_rate_min', info="")
scfg_rate_max = gr.Slider(value = 3.0, minimum = 0, maximum = 30.0, step = 0.1, label="Max Rate", elem_id = 'scfg_rate_max', info="")
scfg_rate_clamp = gr.Slider(value = 0.0, minimum = 0, maximum = 30.0, step = 0.1, label="Clamp Rate", elem_id = 'scfg_rate_clamp', info="If > 0, clamp max rate to Clamp Rate / CFG Scale. Overrides max rate.")
with gr.Row():
start_step = gr.Slider(value = 0, minimum = 0, maximum = 150, step = 1, label="Start Step", elem_id = 'scfg_start_step', info="")
end_step = gr.Slider(value = 150, minimum = 0, maximum = 150, step = 1, label="End Step", elem_id = 'scfg_end_step', info="")
active.do_not_save_to_config = True
scfg_scale.do_not_save_to_config = True
scfg_rate_min.do_not_save_to_config = True
scfg_rate_max.do_not_save_to_config = True
scfg_rate_clamp.do_not_save_to_config = True
scfg_r.do_not_save_to_config = True
start_step.do_not_save_to_config = True
end_step.do_not_save_to_config = True
self.infotext_fields = [
(active, lambda d: gr.Checkbox.update(value='SCFG Active' in d)),
(scfg_scale, 'SCFG Scale'),
(scfg_rate_min, 'SCFG Rate Min'),
(scfg_rate_max, 'SCFG Rate Max'),
(scfg_rate_clamp, 'SCFG Rate Clamp'),
(start_step, 'SCFG Start Step'),
(end_step, 'SCFG End Step'),
(scfg_r, 'SCFG R'),
]
self.paste_field_names = [
'scfg_active',
'scfg_scale',
'scfg_rate_min',
'scfg_rate_max',
'scfg_rate_clamp',
'scfg_start_step',
'scfg_end_step',
'scfg_r',
]
return [active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r]
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
self.pag_process_batch(p, *args, **kwargs)
def pag_process_batch(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r, *args, **kwargs):
# cleanup previous hooks always
script_callbacks.remove_current_script_callbacks()
self.remove_all_hooks()
active = getattr(p, "scfg_active", active)
if active is False:
return
scfg_scale = getattr(p, "scfg_scale", scfg_scale)
scfg_rate_min = getattr(p, "scfg_rate_min", scfg_rate_min)
scfg_rate_max = getattr(p, "scfg_rate_max", scfg_rate_max)
scfg_rate_clamp = getattr(p, "scfg_rate_clamp", scfg_rate_clamp)
start_step = getattr(p, "scfg_start_step", start_step)
end_step = getattr(p, "scfg_end_step", end_step)
scfg_r = getattr(p, "scfg_r", scfg_r)
p.extra_generation_params.update({
"SCFG Active": active,
"SCFG Scale": scfg_scale,
"SCFG Rate Min": scfg_rate_min,
"SCFG Rate Max": scfg_rate_max,
"SCFG Rate Clamp": scfg_rate_clamp,
"SCFG Start Step": start_step,
"SCFG End Step": end_step,
"SCFG R": scfg_r,
})
self.create_hook(p, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r)
def create_hook(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r):
# Create a list of parameters for each concept
scfg_params = SCFGStateParams()
# Add to p
if not hasattr(p, 'incant_cfg_params'):
logger.error("No incant_cfg_params found in p")
p.incant_cfg_params['scfg_params'] = scfg_params
scfg_params.denoiser = None
scfg_params.all_crossattn_modules = self.get_all_crossattn_modules()
scfg_params.max_sampling_steps = p.steps
scfg_params.scfg_scale = scfg_scale
scfg_params.rate_min = scfg_rate_min
scfg_params.rate_max = scfg_rate_max
scfg_params.rate_clamp = scfg_rate_clamp
scfg_params.start_step = start_step
scfg_params.end_step = end_step
scfg_params.R = scfg_r
scfg_params.height = p.height
scfg_params.width = p.width
kernel_size = 3
sigma=0.5
scfg_params.gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(shared.device)
# Use lambda to call the callback function with the parameters to avoid global variables
#cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(callback_params, scfg_params)
cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(callback_params, scfg_params)
unhook_lambda = lambda _: self.unhook_callbacks(scfg_params)
self.ready_hijack_forward(scfg_params.all_crossattn_modules)
logger.debug('Hooked callbacks')
#script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
script_callbacks.on_cfg_denoised(cfg_denoised_lambda)
script_callbacks.on_script_unloaded(unhook_lambda)
def postprocess_batch(self, p, *args, **kwargs):
self.scfg_postprocess_batch(p, *args, **kwargs)
def scfg_postprocess_batch(self, p, active, *args, **kwargs):
script_callbacks.remove_current_script_callbacks()
logger.debug('Removed script callbacks')
active = getattr(p, "scfg_active", active)
if active is False:
return
if hasattr(p, 'incant_cfg_params') and 'scfg_params' in p.incant_cfg_params:
stats = p.incant_cfg_params['scfg_params'].statistics
logger.debug('SCFG Statistics: %s', stats)
self.remove_all_hooks()
def remove_all_hooks(self):
all_crossattn_modules = self.get_all_crossattn_modules()
for module in all_crossattn_modules:
self.remove_field_cross_attn_modules(module, 'scfg_last_to_q_map')
self.remove_field_cross_attn_modules(module, 'scfg_last_to_k_map')
if hasattr(module, 'to_q'):
handle_scfg_to_q = _remove_all_forward_hooks(module.to_q, 'scfg_to_q_hook')
self.remove_field_cross_attn_modules(module.to_q, 'scfg_parent_module')
if hasattr(module, 'to_k'):
handle_scfg_to_k = _remove_all_forward_hooks(module.to_k, 'scfg_to_k_hook')
self.remove_field_cross_attn_modules(module.to_k, 'scfg_parent_module')
def unhook_callbacks(self, scfg_params: SCFGStateParams):
pass
def ready_hijack_forward(self, all_crossattn_modules):
""" Create hooks in the forward pass of the cross attention modules
Copies the output of the to_v module to the parent module
"""
def scfg_self_attn_hook(module, input, kwargs, output):
# scfg_q_map = output.detach().clone()
scfg_q_map = prepare_attn_map(output, module.scfg_heads)
attn_scores = get_attention_scores(scfg_q_map, scfg_q_map)
setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores)
def scfg_cross_attn_hook(module, input, kwargs, output):
scfg_q_map = prepare_attn_map(module.scfg_parent_module[0].scfg_last_to_q_map, module.scfg_heads)
scfg_k_map = prepare_attn_map(output, module.scfg_heads)
#scfg_k_map = output.detach().clone()
attn_scores = get_attention_scores(scfg_q_map, scfg_k_map)
setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores)
# del module.parent_module[0].scfg_last_to_q_map
def scfg_to_q_hook(module, input, kwargs, output):
setattr(module.scfg_parent_module[0], 'scfg_last_to_q_map', output)
def scfg_to_k_hook(module, input, kwargs, output):
setattr(module.scfg_parent_module[0], 'scfg_last_to_k_map', output)
for module in all_crossattn_modules:
if not hasattr(module, 'to_q') or not hasattr(module, 'to_k'):
logger.error("CrossAttention module '%s' does not have to_q or to_k", module.network_layer_name)
continue
# to_q
self.add_field_cross_attn_modules(module.to_q, 'scfg_parent_module', [module])
self.add_field_cross_attn_modules(module, 'scfg_last_to_q_map', None)
handle_scfg_to_q = module_hooks.module_add_forward_hook(
module.to_q,
scfg_to_q_hook,
with_kwargs=True
)
# to_k
self.add_field_cross_attn_modules(module.to_k, 'scfg_parent_module', [module])
if module.network_layer_name.endswith('attn2'): # cross attn
self.add_field_cross_attn_modules(module, 'scfg_last_to_k_map', None)
handle_scfg_to_k = module_hooks.module_add_forward_hook(
module.to_k,
scfg_to_k_hook,
with_kwargs=True
)
def get_all_crossattn_modules(self):
"""
Get ALL attention modules
"""
modules = module_hooks.get_modules(
module_name_filter='CrossAttention'
)
return modules
def add_field_cross_attn_modules(self, module, field, value):
""" Add a field to a module if it doesn't exist """
module_hooks.modules_add_field(module, field, value)
def remove_field_cross_attn_modules(self, module, field):
""" Remove a field from a module if it exists """
module_hooks.modules_remove_field(module, field)
def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, scfg_params: SCFGStateParams):
# always unhook
self.unhook_callbacks(scfg_params)
def on_cfg_denoised_callback(self, params: CFGDenoisedParams, scfg_params: SCFGStateParams):
""" Callback function for the CFGDenoisedParams
Refer to pg.22 A.2 of the PAG paper for how CFG and PAG combine
"""
scfg_params.current_step = params.sampling_step
# Run only within interval
if not scfg_params.start_step <= params.sampling_step <= scfg_params.end_step:
return
if scfg_params.scfg_scale <= 0:
return
# S-CFG
R = scfg_params.R
max_latent_size = [params.x.shape[-2] // R, params.x.shape[-1] // R]
#with LineProfiler(get_mask) as lp:
ca_mask, fore_mask = get_mask(scfg_params.all_crossattn_modules,
scfg_params,
r = scfg_params.R,
latent_size = max_latent_size,
)
#lp.print_stats()
# todo parameterize this
mask_t = F.interpolate(ca_mask, scale_factor=R, mode='nearest')
mask_fore = F.interpolate(fore_mask, scale_factor=R, mode='nearest')
scfg_params.mask_t = mask_t
scfg_params.mask_fore = mask_fore
def get_xyz_axis_options(self) -> dict:
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
extra_axis_options = {
xyz_grid.AxisOption("[SCFG] Active", str, scfg_apply_override('scfg_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[SCFG] SCFG Scale", float, scfg_apply_field("scfg_scale")),
xyz_grid.AxisOption("[SCFG] SCFG Rate Min", float, scfg_apply_field("scfg_rate_min")),
xyz_grid.AxisOption("[SCFG] SCFG Rate Max", float, scfg_apply_field("scfg_rate_max")),
xyz_grid.AxisOption("[SCFG] SCFG Rate Clamp", float, scfg_apply_field("scfg_rate_clamp")),
xyz_grid.AxisOption("[SCFG] SCFG Start Step", int, scfg_apply_field("scfg_start_step")),
xyz_grid.AxisOption("[SCFG] SCFG End Step", int, scfg_apply_field("scfg_end_step")),
xyz_grid.AxisOption("[SCFG] SCFG R", int, scfg_apply_field("scfg_r")),
}
return extra_axis_options
def scfg_combine_denoised(model_delta, cfg_scale, scfg_params: SCFGStateParams):
""" The inner loop of the S-CFG denoiser
Arguments:
model_delta: torch.Tensor - defined by `x_out[cond_index] - denoised_uncond[i]`
cfg_scale: float - guidance scale
scfg_params: SCFGStateParams - the state parameters for the S-CFG denoiser
Returns:
int or torch.Tensor - 1.0 if not within interval or scale is 0, else the rate map tensor
"""
current_step = scfg_params.current_step
start_step = scfg_params.start_step
end_step = scfg_params.end_step
scfg_scale = scfg_params.scfg_scale
if not start_step <= current_step <= end_step:
return 1.0
if scfg_scale <= 0:
return 1.0
mask_t = scfg_params.mask_t
mask_fore = scfg_params.mask_fore
min_rate = scfg_params.rate_min
max_rate = scfg_params.rate_max
rate_clamp = scfg_params.rate_clamp
model_delta = model_delta.unsqueeze(0)
model_delta_norm = model_delta.norm(dim=1, keepdim=True)
eps = lambda dtype: torch.finfo(dtype).eps
# rescale map if necessary
if mask_t.shape[2:] != model_delta_norm.shape[2:]:
logger.debug('Rescaling mask_t from %s to %s', mask_t.shape[2:], model_delta_norm.shape[2:])
mask_t = F.interpolate(mask_t, size=model_delta_norm.shape[2:], mode='bilinear')
if mask_fore.shape[-2] != model_delta_norm.shape[-2]:
logger.debug('Rescaling mask_fore from %s to %s', mask_fore.shape[2:], model_delta_norm.shape[2:])
mask_fore = F.interpolate(mask_fore, size=model_delta_norm.shape[2:], mode='bilinear')
delta_mask_norms = (model_delta_norm * mask_t).sum([2,3])/(mask_t.sum([2,3])+eps(mask_t.dtype))
upnormmax = delta_mask_norms.max(dim=1)[0]
upnormmax = upnormmax.unsqueeze(-1)
fore_norms = (model_delta_norm * mask_fore).sum([2,3])/(mask_fore.sum([2,3])+eps(mask_fore.dtype))
up = fore_norms
down = delta_mask_norms
tmp_mask = (mask_t.sum([2,3])>0).float()
rate = up*(tmp_mask)/(down+eps(down.dtype)) # b 257
rate = (rate.unsqueeze(-1).unsqueeze(-1)*mask_t).sum(dim=1, keepdim=True) # b 1, 64 64
del model_delta_norm, delta_mask_norms, upnormmax, fore_norms, up, down, tmp_mask
# unscaled min/max rate
if rate.min().item() < scfg_params.statistics["min_rate"]:
scfg_params.statistics["min_rate"] = rate.min().item()
if rate.max().item() > scfg_params.statistics["max_rate"]:
scfg_params.statistics["max_rate"] = rate.max().item()
# should this go before or after the gaussian blur, or before/after the rate
rate = rate * scfg_scale
rate = torch.clamp(rate,min=min_rate, max=max_rate)
if rate_clamp > 0:
rate = torch.clamp_max(rate, rate_clamp/cfg_scale)
###Gaussian Smoothing
#kernel_size = 3
#sigma=0.5
#smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(rate.device)
smoothing = scfg_params.gaussian_smoothing
rate = F.pad(rate, (1, 1, 1, 1), mode='reflect')
rate = smoothing(rate)
return rate.squeeze(0)
# XYZ Plot
# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py
def scfg_apply_override(field, boolean: bool = False):
def fun(p, x, xs):
if boolean:
x = True if x.lower() == "true" else False
setattr(p, field, x)
if not hasattr(p, "scfg_active"):
setattr(p, "scfg_active", True)
return fun
def scfg_apply_field(field):
def fun(p, x, xs):
if not hasattr(p, "scfg_active"):
setattr(p, "scfg_active", True)
setattr(p, field, x)
return fun
def _remove_all_forward_hooks(
module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
module_hooks.remove_module_forward_hook(module, hook_fn_name)
"""
# below code modified from https://github.com/SmilesDZgk/S-CFG
@inproceedings{shen2024rethinking,
title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee},
author={Shen, Dazhong and Song, Guanglu and Xue, Zeyue and Wang, Fu-Yun and Liu, Yu},
booktitle={Proceedings of The IEEE/CVF Computer Vision and Pattern Recognition Conference (CVPR)},
year={2024}
}
"""
import math
import numbers
import torch
from torch import nn
from torch.nn import functional as F
class GaussianSmoothing(nn.Module):
"""
Apply gaussian smoothing on a
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=2):
super(GaussianSmoothing, self).__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
self.conv = F.conv1d
elif dim == 2:
self.conv = F.conv2d
elif dim == 3:
self.conv = F.conv3d
else:
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
)
def forward(self, input):
"""
Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
# based on diffusers/models/attention_processor.py Attention head_to_batch_dim
def head_to_batch_dim(x, heads, out_dim=3):
head_size = heads
if x.ndim == 3:
batch_size, seq_len, dim = x.shape
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = x.shape
x = x.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
x = x.permute(0, 2, 1, 3)
if out_dim == 3:
x = x.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return x
# based on diffusers/models/attention_processor.py Attention batch_to_head_dim
def batch_to_head_dim(x, heads):
head_size = heads
batch_size, seq_len, dim = x.shape
x = x.reshape(batch_size // head_size, head_size, seq_len, dim)
x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return x
def average_over_head_dim(x, heads):
x = rearrange(x, '(b h) s t -> b h s t', h=heads).mean(1)
return x
import torch.nn.functional as F
from einops import rearrange
def get_mask(attn_modules, scfg_params: SCFGStateParams, r, latent_size):
""" Aggregates the attention across the different layers and heads at the specified resolution.
In the original paper, r is a hyper-parameter set to 4.
Arguments:
attn_modules: List of attention modules
scfg_params: SCFGStateParams
r: int -
latent_size: tuple
"""
height = scfg_params.height
width = scfg_params.width
max_dims = height * width
latent_size = latent_size[-2:]
module_attn_sizes = set()
key_corss = f"r{r}_cross"
key_self = f"r{r}_self"
# The maximum value of the sizes of attention map to aggregate
max_r = r
max_sizes = r
# The current number of attention map resolutions aggregated
attnmap_r = 0
r_r = 1
new_ca = 0
new_fore=0
a_n=0
# corresponds to diffusers pipe.unet.config.sample_size
# sample_size = 64
# get a layer wise mapping
attention_store_proxy = {"r2_cross": [], "r4_cross": [], "r8_cross": [], "r16_cross": [],
"r2_self": [], "r4_self": [], "r8_self": [], "r16_self": []}
for module in attn_modules:
module_type = 'cross' if 'attn2' in module.network_layer_name else 'self'
to_q_map = getattr(module, 'scfg_last_to_q_map', None)
to_k_map = getattr(module, 'scfg_last_to_k_map', None)
# self-attn
if to_k_map is None:
to_k_map = to_q_map
to_q_map = prepare_attn_map(to_q_map, module.heads)
to_k_map = prepare_attn_map(to_k_map, module.heads)
module_attn_size = to_q_map.size(1)
module_attn_sizes.add(module_attn_size)
downscale_h = int((module_attn_size * (height / width)) ** 0.5)
downscale_w = module_attn_size // downscale_h
module_key = f"r{module_attn_size}_{module_type}"
attn_probs = get_attention_scores(to_q_map, to_k_map, to_q_map.dtype)
if module_type == 'self':
del module.scfg_last_to_q_map
else:
del module.scfg_last_to_q_map, module.scfg_last_to_k_map
if module_key not in attention_store_proxy:
attention_store_proxy[module_key] = []
try:
attention_store_proxy[module_key].append(attn_probs)
except KeyError:
continue
module_attn_sizes = sorted(list(module_attn_sizes))
attention_maps = attention_store_proxy
curr_r = module_attn_sizes.pop(0)
while curr_r != None and attnmap_r < max_sizes:
key_corss = f"r{curr_r}_cross"
key_self = f"r{curr_r}_self"
if key_self not in attention_maps.keys() or key_corss not in attention_maps.keys():
next_r = module_attn_sizes.pop(0)
attnmap_r += 1
curr_r = next_r
continue
if len(attention_maps[key_self]) == 0 or len(attention_maps[key_corss]) == 0:
curr_r = module_attn_sizes.pop(0)
attnmap_r += 1
curr_r = next_r
continue
sa = torch.stack(attention_maps[key_self], dim=1)
ca = torch.stack(attention_maps[key_corss], dim=1)
attn_num = sa.size(1)
sa = rearrange(sa, 'b n h w -> (b n) h w')
ca = rearrange(ca, 'b n h w -> (b n) h w')
curr = 0 # b hw c=hw
curr +=sa
# 4.1.2 Self-Attentiion
ssgc_sa = curr
ssgc_n = max_r
# summation from r=2 to R, we set ssgc_sa to curr which would be sa^1
# major memory hog
# active_bytes peak from 3.41G to 4.04G
# reserved_bytes peak from 3.70G to 4.64G
# optimization 1: active 4.03G -> 3.72G = 0.31G, reserved 4.64G -> 4.16G = 0.48G
for r_value in range(1, ssgc_n):
r_pow = r_value + 1
curr @= sa # optimization 1
# curr = torch.linalg.matrix_power(sa, r_pow) # sa^r
ssgc_sa += curr
ssgc_sa/=ssgc_n
sa = ssgc_sa
########smoothing ca
ca = sa@ca # b hw c
hw = ca.size(1)
downscale_h = round((hw * (height / width)) ** 0.5)
ca = rearrange(ca, 'b (h w) c -> b c h w', h=downscale_h )
# Scale the attention map to the expected size
max_size = latent_size
scale_factor = [
max_size[0] / ca.shape[-2],
max_size[1] / ca.shape[-1]
]
mode = 'bilinear' #'nearest' #
ca = F.interpolate(ca, scale_factor=scale_factor, mode=mode) # b 77 32 32
#####Gaussian Smoothing
#kernel_size = 3
#sigma = 0.5
#smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(ca.device)
smoothing = scfg_params.gaussian_smoothing
channel = ca.size(1)
ca= rearrange(ca, ' b c h w -> (b c) h w' ).unsqueeze(1)
ca = F.pad(ca, (1, 1, 1, 1), mode='reflect')
ca = smoothing(ca.float()).squeeze(1)
ca = rearrange(ca, ' (b c) h w -> b c h w' , c= channel)
ca_norm = ca/(ca.mean(dim=[2,3], keepdim=True)+torch.finfo(ca.dtype).eps) ### spatial normlization
new_ca+=rearrange(ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1)
fore_ca = torch.stack([ca[:,0],ca[:,1:].sum(dim=1)], dim=1)
froe_ca_norm = fore_ca/fore_ca.mean(dim=[2,3], keepdim=True) ### spatial normlization
new_fore += rearrange(froe_ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1)
a_n+=attn_num
if len(module_attn_sizes) > 0:
curr_r = module_attn_sizes.pop(0)
else:
curr_r = None
attnmap_r += 1
# r_r *= 2
# optimization 2: memory savings: 3.09G - 2.47G = 0.62G
del ca_norm, froe_ca_norm, fore_ca
# no memory savings
del attention_maps
del sa, ca, ssgc_sa, ssgc_n, curr
# variables used from above:
# new_ca, new_fore, a_n
new_ca = new_ca/a_n
new_fore = new_fore/a_n
_,new_ca = new_ca.chunk(2, dim=0) #[1]
fore_ca, _ = new_fore.chunk(2, dim=0)
max_ca, inds = torch.max(new_ca[:,:], dim=1)
max_ca = max_ca.unsqueeze(1) #
ca_mask = (new_ca==max_ca).float() # b 77/10 16 16
max_fore, inds = torch.max(fore_ca[:,:], dim=1)
max_fore = max_fore.unsqueeze(1) #
fore_mask = (fore_ca==max_fore).float() # b 77/10 16 16
fore_mask = 1.0-fore_mask[:,:1] # b 1 16 16
# no memory savings
del new_ca, new_fore, a_n, max_ca, max_fore, inds
return [ ca_mask, fore_mask]
def prepare_attn_map(to_k_map, heads):
to_k_map = head_to_batch_dim(to_k_map, heads)
to_k_map = average_over_head_dim(to_k_map, heads)
to_k_map = torch.stack([to_k_map[0], to_k_map[0]], dim=0)
return to_k_map
def get_attention_scores(to_q_map, to_k_map, dtype):
""" Calculate the attention scores for the given query and key maps
Arguments:
to_q_map: torch.Tensor - query map
to_k_map: torch.Tensor - key map
dtype: torch.dtype - data type of the tensor
Returns:
torch.Tensor - attention scores
"""
# based on diffusers models/attention.py "get_attention_scores"
# use in place operations vs. softmax to save memory: https://stackoverflow.com/questions/53732209/torch-in-place-operations-to-save-memory-softmax
# 512x: 2.65G -> 2.47G
# attn_probs = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype)
attn_probs = to_q_map @ to_k_map.transpose(-1, -2)
# avoid nan by converting to float32 and subtracting max
attn_probs = attn_probs.to(dtype=torch.float32) #
attn_probs -= torch.max(attn_probs)
torch.exp(attn_probs, out = attn_probs)
summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
attn_probs /= summed
attn_probs = attn_probs.to(dtype=dtype)
return attn_probs