import logging
from os import environ

import modules.scripts as scripts
import gradio as gr
import scipy.stats as stats

from scripts.ui_wrapper import UIWrapper, arg
from modules import script_callbacks, patches
from modules.hypernetworks import hypernetwork

from modules.script_callbacks import CFGDenoiserParams, CFGDenoisedParams, AfterCFGCallbackParams
from modules.prompt_parser import reconstruct_multicond_batch
from modules.processing import StableDiffusionProcessing

from modules.sd_samplers_cfg_denoiser import catenate_conds
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules import shared

import math
import torch
from torch.nn import functional as F
from torchvision.transforms import GaussianBlur

from warnings import warn
from typing import Callable, Dict, Optional
from collections import OrderedDict

from scripts.incant_utils import module_hooks


logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))

incantations_debug = environ.get("INCANTATIONS_DEBUG", False)
""" |
|
|
An unofficial implementation of "Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee" for Automatic1111 WebUI. |
|
|
|
|
|
This builds upon the code provided in the official S-CFG repository: https://github.com/SmilesDZgk/S-CFG |
|
|
|
|
|
|
|
|
@inproceedings{shen2024rethinking, |
|
|
title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee}, |
|
|
author={Shen, Dazhong and Song, Guanglu and Xue, Zeyue and Wang, Fu-Yun and Liu, Yu}, |
|
|
booktitle={Proceedings of The IEEE/CVF Computer Vision and Pattern Recognition Conference (CVPR)}, |
|
|
year={2024} |
|
|
} |
|
|
|
|
|
Parts of the code are based on Diffusers under the Apache License 2.0: |
|
|
# Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
# |
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
# you may not use this file except in compliance with the License. |
|
|
# You may obtain a copy of the License at |
|
|
# |
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
# |
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
# See the License for the specific language governing permissions and |
|
|
# limitations under the License. |
|
|
|
|
|
Author: v0xie |
|
|
GitHub URL: https://github.com/v0xie/sd-webui-incantations |
|
|
|
|
|
""" |

handles = []
global_scale = 1

SCFG_MODULES = ['to_q', 'to_k']


class SCFGStateParams:
    def __init__(self):
        self.scfg_scale: float = 0.8
        self.rate_min = 0.8
        self.rate_max = 3.0
        self.rate_clamp = 15.0
        self.R = 4
        self.start_step = 0
        self.end_step = 150
        self.gaussian_smoothing = None

        self.max_sampling_steps = -1
        self.current_step = 0
        self.height = -1
        self.width = -1

        self.statistics = {
            "min_rate": float('inf'),
            "max_rate": float('-inf'),
        }

        self.mask_t = None
        self.mask_fore = None
        self.denoiser = None
        self.all_crossattn_modules = None
        self.patched_combined_denoised = None


class SCFGExtensionScript(UIWrapper):
    def __init__(self):
        self.cached_c = [None, None]
        self.handles = []

    def title(self) -> str:
        return "S-CFG"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def setup_ui(self, is_img2img) -> list:
        with gr.Accordion('S-CFG', open=False):
            active = gr.Checkbox(value=False, default=False, label="Active", elem_id='scfg_active', info="Computationally expensive. A batch size of 4 for 1024x1024 will max out a 24GB card!")
            with gr.Row():
                scfg_scale = gr.Slider(value=1.0, minimum=0, maximum=10.0, step=0.1, label="SCFG Scale", elem_id='scfg_scale', info="")
                scfg_r = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="SCFG R", elem_id='scfg_r', info="Scale factor. Greater R uses more memory.")
            with gr.Row():
                scfg_rate_min = gr.Slider(value=0.8, minimum=0, maximum=30.0, step=0.1, label="Min Rate", elem_id='scfg_rate_min', info="")
                scfg_rate_max = gr.Slider(value=3.0, minimum=0, maximum=30.0, step=0.1, label="Max Rate", elem_id='scfg_rate_max', info="")
                scfg_rate_clamp = gr.Slider(value=0.0, minimum=0, maximum=30.0, step=0.1, label="Clamp Rate", elem_id='scfg_rate_clamp', info="If > 0, clamp max rate to Clamp Rate / CFG Scale. Overrides max rate.")
            with gr.Row():
                start_step = gr.Slider(value=0, minimum=0, maximum=150, step=1, label="Start Step", elem_id='scfg_start_step', info="")
                end_step = gr.Slider(value=150, minimum=0, maximum=150, step=1, label="End Step", elem_id='scfg_end_step', info="")

        active.do_not_save_to_config = True
        scfg_scale.do_not_save_to_config = True
        scfg_rate_min.do_not_save_to_config = True
        scfg_rate_max.do_not_save_to_config = True
        scfg_rate_clamp.do_not_save_to_config = True
        scfg_r.do_not_save_to_config = True
        start_step.do_not_save_to_config = True
        end_step.do_not_save_to_config = True

        self.infotext_fields = [
            (active, lambda d: gr.Checkbox.update(value='SCFG Active' in d)),
            (scfg_scale, 'SCFG Scale'),
            (scfg_rate_min, 'SCFG Rate Min'),
            (scfg_rate_max, 'SCFG Rate Max'),
            (scfg_rate_clamp, 'SCFG Rate Clamp'),
            (start_step, 'SCFG Start Step'),
            (end_step, 'SCFG End Step'),
            (scfg_r, 'SCFG R'),
        ]
        self.paste_field_names = [
            'scfg_active',
            'scfg_scale',
            'scfg_rate_min',
            'scfg_rate_max',
            'scfg_rate_clamp',
            'scfg_start_step',
            'scfg_end_step',
            'scfg_r',
        ]
        return [active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r]
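
    # Note: these components are passed back to process_batch()/pag_process_batch() positionally,
    # so the order of this return list must match the parameter order of pag_process_batch().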

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self.pag_process_batch(p, *args, **kwargs)

    def pag_process_batch(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r, *args, **kwargs):

        script_callbacks.remove_current_script_callbacks()
        self.remove_all_hooks()

        active = getattr(p, "scfg_active", active)
        if active is False:
            return
        scfg_scale = getattr(p, "scfg_scale", scfg_scale)
        scfg_rate_min = getattr(p, "scfg_rate_min", scfg_rate_min)
        scfg_rate_max = getattr(p, "scfg_rate_max", scfg_rate_max)
        scfg_rate_clamp = getattr(p, "scfg_rate_clamp", scfg_rate_clamp)
        start_step = getattr(p, "scfg_start_step", start_step)
        end_step = getattr(p, "scfg_end_step", end_step)
        scfg_r = getattr(p, "scfg_r", scfg_r)

        p.extra_generation_params.update({
            "SCFG Active": active,
            "SCFG Scale": scfg_scale,
            "SCFG Rate Min": scfg_rate_min,
            "SCFG Rate Max": scfg_rate_max,
            "SCFG Rate Clamp": scfg_rate_clamp,
            "SCFG Start Step": start_step,
            "SCFG End Step": end_step,
            "SCFG R": scfg_r,
        })
        self.create_hook(p, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r)

    def create_hook(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r):

        scfg_params = SCFGStateParams()

        if not hasattr(p, 'incant_cfg_params'):
            logger.error("No incant_cfg_params found in p")
            p.incant_cfg_params = {}
        p.incant_cfg_params['scfg_params'] = scfg_params

        scfg_params.denoiser = None
        scfg_params.all_crossattn_modules = self.get_all_crossattn_modules()
        scfg_params.max_sampling_steps = p.steps
        scfg_params.scfg_scale = scfg_scale
        scfg_params.rate_min = scfg_rate_min
        scfg_params.rate_max = scfg_rate_max
        scfg_params.rate_clamp = scfg_rate_clamp
        scfg_params.start_step = start_step
        scfg_params.end_step = end_step
        scfg_params.R = scfg_r
        scfg_params.height = p.height
        scfg_params.width = p.width
        kernel_size = 3
        sigma = 0.5
        scfg_params.gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(shared.device)

        cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(callback_params, scfg_params)
        unhook_lambda = lambda _: self.unhook_callbacks(scfg_params)

        self.ready_hijack_forward(scfg_params.all_crossattn_modules)

        logger.debug('Hooked callbacks')

        script_callbacks.on_cfg_denoised(cfg_denoised_lambda)
        script_callbacks.on_script_unloaded(unhook_lambda)
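
    # Rough flow per sampling step, as I read this extension (assumptions noted):
    #   1. The to_q/to_k forward hooks installed by ready_hijack_forward() capture the query/key
    #      projections of every CrossAttention module during the U-Net forward pass.
    #   2. on_cfg_denoised_callback() turns those captures into a per-token segmentation mask
    #      (mask_t) and a foreground mask (mask_fore) at latent resolution via get_mask().
    #   3. scfg_combine_denoised() converts the masks into a per-pixel guidance-rate map; it is
    #      presumably invoked from the extension's shared combine_denoised patch (hence the
    #      `patched_combined_denoised` field), which is not part of this file.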

    def postprocess_batch(self, p, *args, **kwargs):
        self.scfg_postprocess_batch(p, *args, **kwargs)

    def scfg_postprocess_batch(self, p, active, *args, **kwargs):
        script_callbacks.remove_current_script_callbacks()

        logger.debug('Removed script callbacks')
        active = getattr(p, "scfg_active", active)
        if active is False:
            return

        if hasattr(p, 'incant_cfg_params') and 'scfg_params' in p.incant_cfg_params:
            stats = p.incant_cfg_params['scfg_params'].statistics
            logger.debug('SCFG Statistics: %s', stats)

        self.remove_all_hooks()

    def remove_all_hooks(self):
        all_crossattn_modules = self.get_all_crossattn_modules()
        for module in all_crossattn_modules:
            self.remove_field_cross_attn_modules(module, 'scfg_last_to_q_map')
            self.remove_field_cross_attn_modules(module, 'scfg_last_to_k_map')
            if hasattr(module, 'to_q'):
                handle_scfg_to_q = _remove_all_forward_hooks(module.to_q, 'scfg_to_q_hook')
                self.remove_field_cross_attn_modules(module.to_q, 'scfg_parent_module')
            if hasattr(module, 'to_k'):
                handle_scfg_to_k = _remove_all_forward_hooks(module.to_k, 'scfg_to_k_hook')
                self.remove_field_cross_attn_modules(module.to_k, 'scfg_parent_module')

    def unhook_callbacks(self, scfg_params: SCFGStateParams):
        pass

    def ready_hijack_forward(self, all_crossattn_modules):
        """ Create hooks in the forward pass of the cross attention modules
        Copies the outputs of the to_q and to_k modules to the parent module
        """

        # currently unused: hooks that would compute the attention scores directly
        def scfg_self_attn_hook(module, input, kwargs, output):
            scfg_q_map = prepare_attn_map(output, module.scfg_heads)
            attn_scores = get_attention_scores(scfg_q_map, scfg_q_map, scfg_q_map.dtype)
            setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores)

        def scfg_cross_attn_hook(module, input, kwargs, output):
            scfg_q_map = prepare_attn_map(module.scfg_parent_module[0].scfg_last_to_q_map, module.scfg_heads)
            scfg_k_map = prepare_attn_map(output, module.scfg_heads)

            attn_scores = get_attention_scores(scfg_q_map, scfg_k_map, scfg_q_map.dtype)
            setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores)

        def scfg_to_q_hook(module, input, kwargs, output):
            setattr(module.scfg_parent_module[0], 'scfg_last_to_q_map', output)

        def scfg_to_k_hook(module, input, kwargs, output):
            setattr(module.scfg_parent_module[0], 'scfg_last_to_k_map', output)

        for module in all_crossattn_modules:
            if not hasattr(module, 'to_q') or not hasattr(module, 'to_k'):
                logger.error("CrossAttention module '%s' does not have to_q or to_k", module.network_layer_name)
                continue

            # capture the query projection for every attention module
            self.add_field_cross_attn_modules(module.to_q, 'scfg_parent_module', [module])
            self.add_field_cross_attn_modules(module, 'scfg_last_to_q_map', None)
            handle_scfg_to_q = module_hooks.module_add_forward_hook(
                module.to_q,
                scfg_to_q_hook,
                with_kwargs=True
            )

            # capture the key projection only for cross-attention modules (attn2)
            self.add_field_cross_attn_modules(module.to_k, 'scfg_parent_module', [module])
            if module.network_layer_name.endswith('attn2'):
                self.add_field_cross_attn_modules(module, 'scfg_last_to_k_map', None)
                handle_scfg_to_k = module_hooks.module_add_forward_hook(
                    module.to_k,
                    scfg_to_k_hook,
                    with_kwargs=True
                )
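
    # Shape sketch for the captured projections (assuming a standard SD CrossAttention module and
    # a cond/uncond pair batched together by the CFG denoiser):
    #   scfg_last_to_q_map: [2 * batch, spatial_tokens, inner_dim]
    #   scfg_last_to_k_map: [2 * batch, text_tokens, inner_dim] (attn2 / cross-attention only)
    # prepare_attn_map() then folds heads into the batch dim, averages over heads, and duplicates
    # the first sample so the downstream math always sees a two-element batch.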

    def get_all_crossattn_modules(self):
        """
        Get ALL attention modules
        """
        modules = module_hooks.get_modules(
            module_name_filter='CrossAttention'
        )
        return modules

    def add_field_cross_attn_modules(self, module, field, value):
        """ Add a field to a module if it doesn't exist """
        module_hooks.modules_add_field(module, field, value)

    def remove_field_cross_attn_modules(self, module, field):
        """ Remove a field from a module if it exists """
        module_hooks.modules_remove_field(module, field)

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, scfg_params: SCFGStateParams):
        self.unhook_callbacks(scfg_params)

    def on_cfg_denoised_callback(self, params: CFGDenoisedParams, scfg_params: SCFGStateParams):
        """ Callback for CFGDenoisedParams.
        Aggregates the captured attention maps into the semantic segmentation mask (mask_t)
        and foreground mask (mask_fore) used by the S-CFG denoiser for this sampling step.
        """
        scfg_params.current_step = params.sampling_step

        if not scfg_params.start_step <= params.sampling_step <= scfg_params.end_step:
            return

        if scfg_params.scfg_scale <= 0:
            return

        R = scfg_params.R
        max_latent_size = [params.x.shape[-2] // R, params.x.shape[-1] // R]

        ca_mask, fore_mask = get_mask(scfg_params.all_crossattn_modules,
                                      scfg_params,
                                      r=scfg_params.R,
                                      latent_size=max_latent_size,
                                      )

        # upsample the masks from the aggregation resolution back to the latent resolution
        mask_t = F.interpolate(ca_mask, scale_factor=R, mode='nearest')
        mask_fore = F.interpolate(fore_mask, scale_factor=R, mode='nearest')
        scfg_params.mask_t = mask_t
        scfg_params.mask_fore = mask_fore
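
    # Illustrative shapes (assuming SD 1.5 at 512x512 with R=4): params.x is [B, 4, 64, 64], so
    # get_mask() aggregates at 16x16 and the nearest-neighbor upsample yields mask_t of shape
    # [B, tokens, 64, 64] (one channel per prompt-token segment) and mask_fore of [B, 1, 64, 64].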

    def get_xyz_axis_options(self) -> dict:
        xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
        extra_axis_options = {
            xyz_grid.AxisOption("[SCFG] Active", str, scfg_apply_override('scfg_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
            xyz_grid.AxisOption("[SCFG] SCFG Scale", float, scfg_apply_field("scfg_scale")),
            xyz_grid.AxisOption("[SCFG] SCFG Rate Min", float, scfg_apply_field("scfg_rate_min")),
            xyz_grid.AxisOption("[SCFG] SCFG Rate Max", float, scfg_apply_field("scfg_rate_max")),
            xyz_grid.AxisOption("[SCFG] SCFG Rate Clamp", float, scfg_apply_field("scfg_rate_clamp")),
            xyz_grid.AxisOption("[SCFG] SCFG Start Step", int, scfg_apply_field("scfg_start_step")),
            xyz_grid.AxisOption("[SCFG] SCFG End Step", int, scfg_apply_field("scfg_end_step")),
            xyz_grid.AxisOption("[SCFG] SCFG R", int, scfg_apply_field("scfg_r")),
        }
        return extra_axis_options


def scfg_combine_denoised(model_delta, cfg_scale, scfg_params: SCFGStateParams):
    """ The inner loop of the S-CFG denoiser
    Arguments:
        model_delta: torch.Tensor - defined by `x_out[cond_index] - denoised_uncond[i]`
        cfg_scale: float - guidance scale
        scfg_params: SCFGStateParams - the state parameters for the S-CFG denoiser
    Returns:
        float or torch.Tensor - 1.0 if outside the step interval or the scale is 0, else the rate map tensor
    """
    current_step = scfg_params.current_step
    start_step = scfg_params.start_step
    end_step = scfg_params.end_step
    scfg_scale = scfg_params.scfg_scale

    if not start_step <= current_step <= end_step:
        return 1.0

    if scfg_scale <= 0:
        return 1.0

    mask_t = scfg_params.mask_t
    mask_fore = scfg_params.mask_fore
    min_rate = scfg_params.rate_min
    max_rate = scfg_params.rate_max
    rate_clamp = scfg_params.rate_clamp

    model_delta = model_delta.unsqueeze(0)
    model_delta_norm = model_delta.norm(dim=1, keepdim=True)

    eps = lambda dtype: torch.finfo(dtype).eps

    if mask_t.shape[2:] != model_delta_norm.shape[2:]:
        logger.debug('Rescaling mask_t from %s to %s', mask_t.shape[2:], model_delta_norm.shape[2:])
        mask_t = F.interpolate(mask_t, size=model_delta_norm.shape[2:], mode='bilinear')
    if mask_fore.shape[-2] != model_delta_norm.shape[-2]:
        logger.debug('Rescaling mask_fore from %s to %s', mask_fore.shape[2:], model_delta_norm.shape[2:])
        mask_fore = F.interpolate(mask_fore, size=model_delta_norm.shape[2:], mode='bilinear')

    # mean norm of the CFG delta within each semantic region
    delta_mask_norms = (model_delta_norm * mask_t).sum([2, 3]) / (mask_t.sum([2, 3]) + eps(mask_t.dtype))
    upnormmax = delta_mask_norms.max(dim=1)[0]
    upnormmax = upnormmax.unsqueeze(-1)

    # mean norm of the CFG delta over the foreground
    fore_norms = (model_delta_norm * mask_fore).sum([2, 3]) / (mask_fore.sum([2, 3]) + eps(mask_fore.dtype))

    up = fore_norms
    down = delta_mask_norms

    tmp_mask = (mask_t.sum([2, 3]) > 0).float()
    rate = up * (tmp_mask) / (down + eps(down.dtype))
    rate = (rate.unsqueeze(-1).unsqueeze(-1) * mask_t).sum(dim=1, keepdim=True)

    del model_delta_norm, delta_mask_norms, upnormmax, fore_norms, up, down, tmp_mask

    if rate.min().item() < scfg_params.statistics["min_rate"]:
        scfg_params.statistics["min_rate"] = rate.min().item()
    if rate.max().item() > scfg_params.statistics["max_rate"]:
        scfg_params.statistics["max_rate"] = rate.max().item()

    rate = rate * scfg_scale

    rate = torch.clamp(rate, min=min_rate, max=max_rate)

    if rate_clamp > 0:
        rate = torch.clamp_max(rate, rate_clamp / cfg_scale)

    smoothing = scfg_params.gaussian_smoothing
    rate = F.pad(rate, (1, 1, 1, 1), mode='reflect')
    rate = smoothing(rate)

    return rate.squeeze(0)
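
# In rough equation form, the function above computes, for each semantic region j of mask_t,
#   rate_j = mean||delta||_foreground / mean||delta||_j
# (my reading of the code: `fore_norms / delta_mask_norms`), scatters rate_j back onto the latent
# through mask_t, multiplies by scfg_scale, clamps to [rate_min, rate_max] (and optionally to
# rate_clamp / cfg_scale), then Gaussian-smooths the result. A caller is then expected to use the
# returned map as a spatial multiplier on the CFG delta, e.g. (hypothetical wiring):
#   denoised[i] = denoised_uncond[i] + cfg_scale * rate * (x_out[cond_index] - denoised_uncond[i])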


def scfg_apply_override(field, boolean: bool = False):
    def fun(p, x, xs):
        if boolean:
            x = True if x.lower() == "true" else False
        setattr(p, field, x)
        if not hasattr(p, "scfg_active"):
            setattr(p, "scfg_active", True)
    return fun


def scfg_apply_field(field):
    def fun(p, x, xs):
        if not hasattr(p, "scfg_active"):
            setattr(p, "scfg_active", True)
        setattr(p, field, x)
    return fun


def _remove_all_forward_hooks(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    module_hooks.remove_module_forward_hook(module, hook_fn_name)


"""
# below code modified from https://github.com/SmilesDZgk/S-CFG
@inproceedings{shen2024rethinking,
  title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidance},
  author={Shen, Dazhong and Song, Guanglu and Xue, Zeyue and Wang, Fu-Yun and Liu, Yu},
  booktitle={Proceedings of The IEEE/CVF Computer Vision and Pattern Recognition Conference (CVPR)},
  year={2024}
}
"""

import math
import numbers
import torch
from torch import nn
from torch.nn import functional as F


class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # Build the kernel as the product of a gaussian over each dimension
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / (2 * std)) ** 2)

        # Make sure the sum of values in the gaussian kernel equals 1
        kernel = kernel / torch.sum(kernel)

        # Reshape to a depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
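
# Usage sketch, mirroring how create_hook() instantiates it (the caller does the reflect-padding
# because the depthwise convolution itself adds no padding):
#   smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(shared.device)
#   x = F.pad(rate_map, (1, 1, 1, 1), mode='reflect')   # rate_map: [B, 1, H, W]
#   smoothed = smoothing(x)                             # -> [B, 1, H, W]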


def head_to_batch_dim(x, heads, out_dim=3):
    head_size = heads
    if x.ndim == 3:
        batch_size, seq_len, dim = x.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = x.shape
    x = x.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
    x = x.permute(0, 2, 1, 3)
    if out_dim == 3:
        x = x.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    return x
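
# Shape example (this helper follows the diffusers Attention.head_to_batch_dim convention):
#   x: [batch, seq_len, dim], heads=8  ->  [batch * 8, seq_len, dim // 8]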


def batch_to_head_dim(x, heads):
    head_size = heads
    batch_size, seq_len, dim = x.shape
    x = x.reshape(batch_size // head_size, head_size, seq_len, dim)
    x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return x


def average_over_head_dim(x, heads):
    x = rearrange(x, '(b h) s t -> b h s t', h=heads).mean(1)
    return x


import torch.nn.functional as F
from einops import rearrange


def get_mask(attn_modules, scfg_params: SCFGStateParams, r, latent_size):
    """ Aggregates the attention across the different layers and heads at the specified resolution.
    In the original paper, r is a hyper-parameter set to 4.
    Arguments:
        attn_modules: list - cross/self attention modules to aggregate over
        scfg_params: SCFGStateParams - state parameters for S-CFG
        r: int - scale factor; bounds how many attention resolutions are aggregated
        latent_size: tuple - (height, width) of the latent at the aggregation resolution
    """
    height = scfg_params.height
    width = scfg_params.width
    max_dims = height * width
    latent_size = latent_size[-2:]
    module_attn_sizes = set()

    key_cross = f"r{r}_cross"
    key_self = f"r{r}_self"

    max_r = r
    max_sizes = r

    attnmap_r = 0

    r_r = 1
    new_ca = 0
    new_fore = 0
    a_n = 0

    attention_store_proxy = {"r2_cross": [], "r4_cross": [], "r8_cross": [], "r16_cross": [],
                             "r2_self": [], "r4_self": [], "r8_self": [], "r16_self": []}

    for module in attn_modules:
        module_type = 'cross' if 'attn2' in module.network_layer_name else 'self'

        to_q_map = getattr(module, 'scfg_last_to_q_map', None)
        to_k_map = getattr(module, 'scfg_last_to_k_map', None)

        if to_k_map is None:
            to_k_map = to_q_map

        to_q_map = prepare_attn_map(to_q_map, module.heads)
        to_k_map = prepare_attn_map(to_k_map, module.heads)

        module_attn_size = to_q_map.size(1)
        module_attn_sizes.add(module_attn_size)
        downscale_h = int((module_attn_size * (height / width)) ** 0.5)
        downscale_w = module_attn_size // downscale_h
        module_key = f"r{module_attn_size}_{module_type}"

        attn_probs = get_attention_scores(to_q_map, to_k_map, to_q_map.dtype)

        if module_type == 'self':
            del module.scfg_last_to_q_map
        else:
            del module.scfg_last_to_q_map, module.scfg_last_to_k_map

        if module_key not in attention_store_proxy:
            attention_store_proxy[module_key] = []
        try:
            attention_store_proxy[module_key].append(attn_probs)
        except KeyError:
            continue

    module_attn_sizes = sorted(list(module_attn_sizes))
    attention_maps = attention_store_proxy

    curr_r = module_attn_sizes.pop(0)
    while curr_r is not None and attnmap_r < max_sizes:
        key_cross = f"r{curr_r}_cross"
        key_self = f"r{curr_r}_self"

        if key_self not in attention_maps.keys() or key_cross not in attention_maps.keys():
            curr_r = module_attn_sizes.pop(0) if len(module_attn_sizes) > 0 else None
            attnmap_r += 1
            continue
        if len(attention_maps[key_self]) == 0 or len(attention_maps[key_cross]) == 0:
            curr_r = module_attn_sizes.pop(0) if len(module_attn_sizes) > 0 else None
            attnmap_r += 1
            continue

        sa = torch.stack(attention_maps[key_self], dim=1)
        ca = torch.stack(attention_maps[key_cross], dim=1)
        attn_num = sa.size(1)
        sa = rearrange(sa, 'b n h w -> (b n) h w')
        ca = rearrange(ca, 'b n h w -> (b n) h w')

        curr = 0
        curr += sa

        ssgc_sa = curr
        ssgc_n = max_r

        for r_value in range(1, ssgc_n):
            curr @= sa
            ssgc_sa += curr

        ssgc_sa /= ssgc_n
        sa = ssgc_sa

        ca = sa @ ca

        hw = ca.size(1)

        downscale_h = round((hw * (height / width)) ** 0.5)

        ca = rearrange(ca, 'b (h w) c -> b c h w', h=downscale_h)

        max_size = latent_size
        scale_factor = [
            max_size[0] / ca.shape[-2],
            max_size[1] / ca.shape[-1]
        ]
        mode = 'bilinear'
        ca = F.interpolate(ca, scale_factor=scale_factor, mode=mode)

        smoothing = scfg_params.gaussian_smoothing
        channel = ca.size(1)
        ca = rearrange(ca, 'b c h w -> (b c) h w').unsqueeze(1)
        ca = F.pad(ca, (1, 1, 1, 1), mode='reflect')
        ca = smoothing(ca.float()).squeeze(1)
        ca = rearrange(ca, '(b c) h w -> b c h w', c=channel)

        ca_norm = ca / (ca.mean(dim=[2, 3], keepdim=True) + torch.finfo(ca.dtype).eps)

        new_ca += rearrange(ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1)

        fore_ca = torch.stack([ca[:, 0], ca[:, 1:].sum(dim=1)], dim=1)
        fore_ca_norm = fore_ca / fore_ca.mean(dim=[2, 3], keepdim=True)
        new_fore += rearrange(fore_ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1)
        a_n += attn_num

        if len(module_attn_sizes) > 0:
            curr_r = module_attn_sizes.pop(0)
        else:
            curr_r = None
        attnmap_r += 1

        del ca_norm, fore_ca_norm, fore_ca

    del attention_maps
    del sa, ca, ssgc_sa, ssgc_n, curr

    new_ca = new_ca / a_n
    new_fore = new_fore / a_n
    _, new_ca = new_ca.chunk(2, dim=0)
    fore_ca, _ = new_fore.chunk(2, dim=0)

    max_ca, inds = torch.max(new_ca[:, :], dim=1)
    max_ca = max_ca.unsqueeze(1)
    ca_mask = (new_ca == max_ca).float()

    max_fore, inds = torch.max(fore_ca[:, :], dim=1)
    max_fore = max_fore.unsqueeze(1)
    fore_mask = (fore_ca == max_fore).float()
    fore_mask = 1.0 - fore_mask[:, :1]

    del new_ca, new_fore, a_n, max_ca, max_fore, inds

    return [ca_mask, fore_mask]
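
# Summary of the aggregation in get_mask(), as implemented above:
#   - per attention resolution, self-attention maps are averaged over layers/heads and smoothed by
#     summing the first `max_r` matrix powers of the self-attention (the "ssgc" accumulation),
#     which spreads attention along the token graph;
#   - the smoothed self-attention propagates the cross-attention (ca = sa @ ca), which is upsampled
#     to `latent_size`, Gaussian-smoothed and mean-normalized;
#   - segments are the argmax over token channels (ca_mask), and the foreground mask marks pixels
#     not won by the first token channel (typically the start-of-text token).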


def prepare_attn_map(to_k_map, heads):
    to_k_map = head_to_batch_dim(to_k_map, heads)
    to_k_map = average_over_head_dim(to_k_map, heads)
    to_k_map = torch.stack([to_k_map[0], to_k_map[0]], dim=0)
    return to_k_map


def get_attention_scores(to_q_map, to_k_map, dtype):
    """ Calculate the attention scores for the given query and key maps
    Arguments:
        to_q_map: torch.Tensor - query map
        to_k_map: torch.Tensor - key map
        dtype: torch.dtype - data type of the tensor
    Returns:
        torch.Tensor - attention scores
    """
    attn_probs = to_q_map @ to_k_map.transpose(-1, -2)

    # compute the softmax in float32 for numerical stability
    attn_probs = attn_probs.to(dtype=torch.float32)
    attn_probs -= torch.max(attn_probs)

    torch.exp(attn_probs, out=attn_probs)
    summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
    attn_probs /= summed

    attn_probs = attn_probs.to(dtype=dtype)

    return attn_probs
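
# Note: this is softmax(Q @ K^T) without the usual 1/sqrt(head_dim) scaling; the result is only
# used as a relative spatial weighting here. A functionally equivalent one-liner would be:
#   attn_probs = (to_q_map @ to_k_map.transpose(-1, -2)).float().softmax(dim=-1).to(dtype)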