|
|
import logging |
|
|
from os import environ |
|
|
import modules.scripts as scripts |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from collections import OrderedDict |
|
|
from typing import Union |
|
|
import agentsd |
|
|
|
|
|
from modules import script_callbacks, rng, shared |
|
|
from modules.script_callbacks import CFGDenoiserParams |
|
|
|
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__)

# SD_WEBUI_LOG_LEVEL may hold a level name in any letter case ("debug",
# "INFO") or a numeric level ("10"). Normalise it before handing it to
# setLevel(), which only accepts upper-case registered names or integers.
_aa_raw_level = environ.get("SD_WEBUI_LOG_LEVEL", "").strip().upper()
_aa_level = int(_aa_raw_level) if _aa_raw_level.isdigit() else _aa_raw_level
try:
    # An unset or empty variable falls back to INFO.
    logger.setLevel(_aa_level or logging.INFO)
except ValueError:
    # Unknown level name: keep the extension importable instead of failing
    # at import time because of a typo in the environment.
    logger.setLevel(logging.INFO)
    logger.warning(
        "Ignoring unrecognised SD_WEBUI_LOG_LEVEL value %r; using INFO",
        _aa_raw_level,
    )
|
|
|
|
|
""" |
|
|
An implementation of Agent Attention for stable-diffusion-webui: https://github.com/LeapLabTHU/Agent-Attention |
|
|
|
|
|
@misc{han2023agent, |
|
|
title={Agent Attention: On the Integration of Softmax and Linear Attention}, |
|
|
author={Dongchen Han and Tianzhu Ye and Yizeng Han and Zhuofan Xia and Shiji Song and Gao Huang}, |
|
|
year={2023}, |
|
|
eprint={2312.08874}, |
|
|
archivePrefix={arXiv}, |
|
|
primaryClass={cs.CV} |
|
|
} |
|
|
|
|
|
Author: v0xie |
|
|
GitHub URL: https://github.com/v0xie/sd-webui-agentattention |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentAttentionExtensionScript(scripts.Script):
    """Webui script that patches the loaded model with Agent Attention.

    The script exposes a set of Gradio controls, records the chosen settings
    in ``p.extra_generation_params`` and registers a CFG-denoiser callback
    that applies / swaps / removes the ``agentsd`` patch on
    ``shared.sd_model`` depending on the current sampling step.

    Every setting can be overridden per job through an ``aa_*`` attribute on
    the processing object ``p`` (these are set by the axis-option helpers at
    the bottom of this file).
    """

    def title(self):
        """Return the display name of the script."""
        return "Agent Attention"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the Gradio controls and the infotext / paste metadata.

        Returns:
            The list of input components. Its order matches the positional
            parameters of ``before_process_batch`` / ``postprocess_batch``.
        """
        with gr.Accordion('AgentAttention', open=False):
            active = gr.Checkbox(value=False, label="Active", elem_id='aa_active')
            # NOTE(review): which controls share this row is inferred from
            # the statement order — confirm against the intended layout.
            with gr.Row():
                hires_fix_only = gr.Checkbox(value=False, label="Apply to Hires. Fix Only", elem_id='aa_hires_fix_only')
                use_fp32 = gr.Checkbox(value=False, label="Use FP32 Precision (for SD2.1)", elem_id='aa_use_fp32')
            use_sp = gr.Checkbox(value=False, label="Use Second Pass", elem_id='aa_use_sp')
            sp_step = gr.Slider(value=20, minimum=0, maximum=100, step=1, label="Second Pass Step", elem_id='aa_sp_step')
            max_downsample = gr.Radio(choices=[1, 2, 4, 8], value=1, label="Max Downsample", elem_id='aa_max_downsample', info="For SDXL set to values > 1")
            with gr.Accordion('First Pass', open=False):
                sx = gr.Slider(value=4, minimum=0, maximum=10, step=1, label="sx", elem_id='aa_sx')
                sy = gr.Slider(value=4, minimum=0, maximum=10, step=1, label="sy", elem_id='aa_sy')
                ratio = gr.Slider(value=0.4, minimum=0.0, maximum=1.0, step=0.01, label="Ratio", elem_id='aa_ratio')
                agent_ratio = gr.Slider(value=0.95, minimum=0.0, maximum=1.0, step=0.01, label="Agent Ratio", elem_id='aa_agent_ratio')
            with gr.Accordion('Second Pass', open=False):
                sp_sx = gr.Slider(value=2, minimum=0, maximum=10, step=1, label="sx", elem_id='aa_sp_sx')
                sp_sy = gr.Slider(value=2, minimum=0, maximum=10, step=1, label="sy", elem_id='aa_sp_sy')
                sp_ratio = gr.Slider(value=0.4, minimum=0.0, maximum=1.0, step=0.01, label="Ratio", elem_id='aa_sp_ratio')
                sp_agent_ratio = gr.Slider(value=0.5, minimum=0.0, maximum=1.0, step=0.01, label="Agent Ratio", elem_id='aa_sp_agent_ratio')
            with gr.Accordion('Advanced', open=False):
                # Manual escape hatch: strips the patch from the loaded model.
                btn_remove_patch = gr.Button(value="Remove Patch", elem_id='aa_remove_patch')
                btn_remove_patch.click(self.remove_patch)

        # Order matters: it is the positional argument order of the callbacks.
        components = [
            active, use_sp, sp_step,
            sx, sy, ratio, agent_ratio,
            sp_sx, sp_sy, sp_ratio, sp_agent_ratio,
            use_fp32, max_downsample, hires_fix_only,
        ]
        for component in components:
            component.do_not_save_to_config = True

        # Maps each control to the infotext key written by setup_hook().
        self.infotext_fields = [
            (active, lambda d: gr.Checkbox.update(value='AgAt Active' in d)),
            (use_sp, 'AgAt Use Second Pass'),
            (sp_step, 'AgAt Second Pass Step'),
            (sx, 'AgAt First Pass sx'),
            (sy, 'AgAt First Pass sy'),
            (ratio, 'AgAt First Pass Ratio'),
            (agent_ratio, 'AgAt First Pass Agent Ratio'),
            (sp_sx, 'AgAt Second Pass sx'),
            (sp_sy, 'AgAt Second Pass sy'),
            (sp_ratio, 'AgAt Second Pass Ratio'),
            (sp_agent_ratio, 'AgAt Second Pass Agent Ratio'),
            (use_fp32, 'AgAt Use FP32 Precision'),
            (max_downsample, 'AgAt Max Downsample'),
            (hires_fix_only, 'AgAt Apply to Hires. Fix Only'),
        ]
        # Trailing commas on every entry: a missing comma silently glues two
        # adjacent string literals into one bogus field name.
        self.paste_field_names = [
            'aa_active',
            'aa_use_sp',
            'aa_sp_step',
            'aa_sx',
            'aa_sy',
            'aa_ratio',
            'aa_agent_ratio',
            'aa_sp_sx',
            'aa_sp_sy',
            'aa_sp_ratio',
            'aa_sp_agent_ratio',
            'aa_use_fp32',
            'aa_max_downsample',
            'aa_hires_fix_only',
        ]

        return components

    def before_process_batch(self, p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only, *args, **kwargs):
        """Install the denoiser hook for the first pass, unless deferred.

        Does nothing when the extension is inactive. When "Apply to Hires.
        Fix Only" is set, only the two flags ``before_hr`` needs are recorded
        in ``p.extra_generation_params`` and the hook is not installed here.
        """
        active = getattr(p, "aa_active", active)
        if active is False:
            return

        hires_fix_only = getattr(p, "aa_hires_fix_only", hires_fix_only)
        if hires_fix_only is True:
            # Merge into the existing dict (as setup_hook does) instead of
            # replacing it, so entries that are already present survive.
            p.extra_generation_params.update({
                "AgAt Active": active,
                "AgAt Apply to Hires. Fix Only": hires_fix_only,
            })
            logger.debug('Hires. Fix Only is True, skipping')
            return

        return self.setup_hook(p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only)

    def setup_hook(self, p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only):
        """Resolve the effective settings, record them and install the hook.

        Each UI value is replaced by the matching ``aa_*`` attribute of ``p``
        when that attribute exists. The resolved values are merged into
        ``p.extra_generation_params`` and passed on to ``create_hook``.
        """
        active = getattr(p, "aa_active", active)
        if active is False:
            return

        use_sp = getattr(p, "aa_use_sp", use_sp)
        sp_step = getattr(p, "aa_sp_step", sp_step)
        sx = getattr(p, "aa_sx", sx)
        sy = getattr(p, "aa_sy", sy)
        ratio = getattr(p, "aa_ratio", ratio)
        agent_ratio = getattr(p, "aa_agent_ratio", agent_ratio)
        sp_sx = getattr(p, "aa_sp_sx", sp_sx)
        sp_sy = getattr(p, "aa_sp_sy", sp_sy)
        sp_ratio = getattr(p, "aa_sp_ratio", sp_ratio)
        sp_agent_ratio = getattr(p, "aa_sp_agent_ratio", sp_agent_ratio)
        use_fp32 = getattr(p, "aa_use_fp32", use_fp32)
        max_downsample = getattr(p, "aa_max_downsample", max_downsample)
        hires_fix_only = getattr(p, "aa_hires_fix_only", hires_fix_only)

        p.extra_generation_params.update({
            "AgAt Active": active,
            "AgAt Use Second Pass": use_sp,
            "AgAt Second Pass Step": sp_step,
            "AgAt First Pass sx": sx,
            "AgAt First Pass sy": sy,
            "AgAt First Pass Ratio": ratio,
            "AgAt First Pass Agent Ratio": agent_ratio,
            "AgAt Second Pass sx": sp_sx,
            "AgAt Second Pass sy": sp_sy,
            "AgAt Second Pass Ratio": sp_ratio,
            "AgAt Second Pass Agent Ratio": sp_agent_ratio,
            "AgAt Use FP32 Precision": use_fp32,
            "AgAt Max Downsample": max_downsample,
            "AgAt Apply to Hires. Fix Only": hires_fix_only,
        })
        self.create_hook(p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only)

    def create_hook(self, p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only):
        """Register the CFG-denoiser callback bound to the given settings."""

        def denoiser_hook(params):
            # Closure over the resolved settings for this job.
            return self.on_cfg_denoiser_callback(
                params,
                active=active,
                use_sp=use_sp,
                sp_step=sp_step,
                sx=sx,
                sy=sy,
                ratio=ratio,
                agent_ratio=agent_ratio,
                sp_sx=sp_sx,
                sp_sy=sp_sy,
                sp_ratio=sp_ratio,
                sp_agent_ratio=sp_agent_ratio,
                use_fp32=use_fp32,
                max_downsample=max_downsample,
                hires_fix_only=hires_fix_only,
            )

        logger.debug('Hooked callbacks')
        script_callbacks.on_cfg_denoiser(denoiser_hook)
        script_callbacks.on_script_unloaded(self.unhook_callbacks)

    def postprocess_batch(self, p, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only, *args, **kwargs):
        """Remove the patch and this script's callbacks after a batch."""
        self.unhook_callbacks()

    def unhook_callbacks(self):
        """Remove the model patch and the callbacks of the current script."""
        logger.debug('Unhooked callbacks')
        self.remove_patch()
        script_callbacks.remove_current_script_callbacks()

    def apply_patch(self, sx=2, sy=2, ratio=0.4, agent_ratio=0.95, use_fp32=False, max_downsample=1):
        """Apply the ``agentsd`` patch to ``shared.sd_model``.

        ``use_fp32`` is translated to ``attn_precision='fp32'``; otherwise
        ``attn_precision`` is passed as ``None``.
        """
        logger.debug('Applied patch with sx: %d, sy: %d, ratio: %f, agent_ratio: %f, use_fp32: %s, max_downsample: %d', sx, sy, ratio, agent_ratio, use_fp32, max_downsample)
        attn_precision = 'fp32' if use_fp32 else None
        agentsd.apply_patch(shared.sd_model, sx=sx, sy=sy, ratio=ratio, agent_ratio=agent_ratio, attn_precision=attn_precision, max_downsample=max_downsample)

    def remove_patch(self):
        """Remove the ``agentsd`` patch from ``shared.sd_model``."""
        logger.debug('Removed patch')
        agentsd.remove_patch(shared.sd_model)

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, active, use_sp, sp_step, sx, sy, ratio, agent_ratio, sp_sx, sp_sy, sp_ratio, sp_agent_ratio, use_fp32, max_downsample, hires_fix_only, *args, **kwargs):
        """Switch the patch according to ``params.sampling_step``.

        * Step 0: any existing patch is removed and the first-pass settings
          are applied.
        * Step ``sp_step``: the patch is removed; the second-pass settings
          are applied only when ``use_sp`` is truthy.

        Both branches run when ``sp_step`` is 0.
        """
        sampling_step = params.sampling_step

        if sampling_step == 0:
            self.remove_patch()
            self.apply_patch(sx=sx, sy=sy, ratio=ratio, agent_ratio=agent_ratio, use_fp32=use_fp32, max_downsample=max_downsample)

        if sampling_step == sp_step:
            self.remove_patch()
            if use_sp:
                self.apply_patch(sx=sp_sx, sy=sp_sy, ratio=sp_ratio, agent_ratio=sp_agent_ratio, use_fp32=use_fp32, max_downsample=max_downsample)

    def before_hr(self, p, *args, **kwargs):
        """Re-install the hook for the hires pass when it was deferred.

        Existing hooks are always removed first. The hook is installed again
        only if the recorded generation params mark the extension as active
        and "Apply to Hires. Fix Only" as enabled.
        """
        self.unhook_callbacks()

        params = getattr(p, "extra_generation_params", None)
        # Only a truly absent attribute is an error; an empty dict simply
        # means nothing was recorded and falls through to the checks below.
        if params is None:
            logger.error("Missing attribute extra_generation_params")
            return

        active = params.get("AgAt Active", False)
        if active is False:
            return

        apply_to_hr_pass = params.get("AgAt Apply to Hires. Fix Only", False)
        if apply_to_hr_pass is False:
            logger.debug("Disabled for hires. fix")
            return

        self.setup_hook(p, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def aa_apply_override(field, boolean: bool = False):
    """Build an axis-option setter that writes a value to ``p.<field>``.

    Args:
        field: Name of the attribute to set on the processing object.
        boolean: When true, the incoming value is treated as text and
            converted to ``True`` only if it equals "true" ignoring case;
            anything else becomes ``False``.

    Returns:
        A ``(p, x, xs)`` callable that stores the (possibly converted) value.
    """
    def _apply(p, x, xs):
        value = (x.lower() == "true") if boolean else x
        setattr(p, field, value)

    return _apply
|
|
|
|
|
def aa_apply_field(field):
    """Build an axis-option setter that writes ``x`` to ``p.<field>``.

    The returned ``(p, x, xs)`` callable also sets ``p.aa_active`` to
    ``True`` when ``p`` has no ``aa_active`` attribute yet; an existing
    value (including ``False``) is left alone.
    """
    def _apply(p, x, xs):
        try:
            p.aa_active
        except AttributeError:
            # No explicit choice was made for this job: default to active.
            p.aa_active = True
        setattr(p, field, x)

    return _apply
|
|
|
|
|
def make_axis_options():
    """Append the AgentAttention axis options to the xyz_grid script.

    The xyz_grid module is located through ``scripts.scripts_data``. The
    options are appended only if no existing axis option label already
    contains "[AgentAttention]", so calling this repeatedly does not
    duplicate them.

    Raises:
        IndexError: If no xyz_grid script is found in ``scripts.scripts_data``.
    """
    candidates = [
        data for data in scripts.scripts_data
        if data.script_class.__module__ in ('scripts.xyz_grid', 'xyz_grid.py')
    ]
    xyz_grid = candidates[0].module

    already_registered = any("[AgentAttention]" in option.label for option in xyz_grid.axis_options)
    if already_registered:
        return

    # A list, not a set: the options must be appended in this fixed order
    # rather than in arbitrary hash order.
    extra_axis_options = [
        xyz_grid.AxisOption("[AgentAttention] Active", str, aa_apply_override('aa_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[AgentAttention] Use Second Pass", str, aa_apply_override('aa_use_sp', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[AgentAttention] Second Pass Step", int, aa_apply_field("aa_sp_step")),
        xyz_grid.AxisOption("[AgentAttention] First Pass sx", int, aa_apply_field("aa_sx")),
        xyz_grid.AxisOption("[AgentAttention] First Pass sy", int, aa_apply_field("aa_sy")),
        xyz_grid.AxisOption("[AgentAttention] First Pass Ratio", float, aa_apply_field("aa_ratio")),
        xyz_grid.AxisOption("[AgentAttention] First Pass Agent Ratio", float, aa_apply_field("aa_agent_ratio")),
        xyz_grid.AxisOption("[AgentAttention] Second Pass sx", int, aa_apply_field("aa_sp_sx")),
        xyz_grid.AxisOption("[AgentAttention] Second Pass sy", int, aa_apply_field("aa_sp_sy")),
        xyz_grid.AxisOption("[AgentAttention] Second Pass Ratio", float, aa_apply_field("aa_sp_ratio")),
        xyz_grid.AxisOption("[AgentAttention] Second Pass Agent Ratio", float, aa_apply_field("aa_sp_agent_ratio")),
        xyz_grid.AxisOption("[AgentAttention] Use FP32", str, aa_apply_override('aa_use_fp32', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[AgentAttention] Max Downsample", int, aa_apply_field('aa_max_downsample')),
    ]
    xyz_grid.axis_options.extend(extra_axis_options)
|
|
|
|
|
def callback_before_ui():
    """Register the xyz_grid axis options, logging any failure.

    Registration is best-effort: an error is logged with its traceback and
    not propagated. Only ``Exception`` subclasses are caught, so
    ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
    """
    try:
        make_axis_options()
    except Exception:
        logger.exception("AgentAttention: Error while making axis options")


# Hand the callback to the webui's before-UI callback registry.
script_callbacks.on_before_ui(callback_before_ui)
|
|
|