# dikdimon's picture
# Upload extensions using SD-Hub extension
# 3dabe4a verified
'''
This code is from the following repository: https://github.com/LeapLabTHU/Agent-Attention
@article{han2023agent,
title={Agent Attention: On the Integration of Softmax and Linear Attention},
author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Song, Shiji and Huang, Gao},
journal={arXiv preprint arXiv:2312.08874},
year={2023}
}
'''
import torch
import math
from typing import Type, Dict, Any, Tuple, Callable
from . import merge
from .utils import isinstance_str, init_generator
from torch import nn, einsum
from einops import rearrange, repeat
from inspect import isfunction
def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]:
    """
    Build the merge / unmerge callables used by one transformer block call.

    The latent size is read from ``tome_info["size"]`` and the patch settings from
    ``tome_info["args"]``. The downsample level of the current layer is inferred
    from ``x.shape[1]``; when it exceeds ``max_downsample`` the no-op functions
    from ``merge`` are used instead of a real matching.

    Side effect: ``tome_info["args"]["generator"]`` is (re)created when it is
    ``None`` or lives on a different device than ``x``.

    Returns:
        ``(m_a, m_c, m_m, u_a, u_c, u_m)`` -- the merge functions for
        self-attention, cross-attention and the MLP, followed by the matching
        unmerge functions.
    """
    full_h, full_w = tome_info["size"]
    full_token_count = full_h * full_w
    # How strongly this layer is downsampled relative to the recorded input size.
    level = int(math.ceil(math.sqrt(full_token_count // x.shape[1])))

    settings = tome_info["args"]
    passthrough = (merge.do_nothing_2, merge.do_nothing)

    if level > settings["max_downsample"]:
        m, u = passthrough
    else:
        grid_w = int(math.ceil(full_w / level))
        grid_h = int(math.ceil(full_h / level))
        num_merged = int(x.shape[1] * settings["ratio"])
        num_agent = int(x.shape[1] * settings["agent_ratio"])

        # (Re)build the generator on first use or after a device change.
        current_gen = settings["generator"]
        if current_gen is None:
            settings["generator"] = init_generator(x.device)
        elif current_gen.device != x.device:
            settings["generator"] = init_generator(x.device, fallback=current_gen)

        # With an odd batch size prompted and unprompted images cannot share the
        # batch, which causes artifacts with use_rand, so randomness is forced off.
        odd_batch = x.shape[0] % 2 == 1
        randomize = (not odd_batch) and settings["use_rand"]

        m, u = merge.bipartite_soft_matching_random2d(
            x, grid_w, grid_h, settings["sx"], settings["sy"], num_merged, num_agent,
            no_rand=not randomize, generator=settings["generator"],
        )

    def select(flag):
        # Real merge/unmerge pair when the flag is enabled, otherwise the no-ops.
        return (m, u) if settings[flag] else passthrough

    m_a, u_a = select("merge_attn")
    m_c, u_c = select("merge_crossattn")
    m_m, u_m = select("merge_mlp")

    return m_a, m_c, m_m, u_a, u_c, u_m
def make_tome_block(block_class: Type[torch.nn.Module], old_forward) -> Type[torch.nn.Module]:
    """
    Create, at runtime, a subclass of ``block_class`` whose ``_forward`` applies
    AgentSD and ToMe, so no specific model module has to be imported.

    Args:
        block_class: The transformer block class to derive from.
        old_forward: Stored on the new class as ``_old_forward``.

    Returns:
        The patched ``ToMeBlock`` class.
    """

    class ToMeBlock(block_class):
        # Kept on the class so the patch can be undone later
        _parent = block_class
        _old_forward = old_forward

        def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
            (merge_attn, merge_cross, merge_mlp,
             unmerge_attn, unmerge_cross, unmerge_mlp) = compute_merge(x, self._tome_info)

            # Self-attention; the external context is only used when the block
            # has self-attention disabled.
            tokens, agent_tokens = merge_attn(self.norm1(x))
            attn_out = self.attn1(
                tokens,
                agent=agent_tokens,
                context=context if self.disable_self_attn else None,
            )
            x = unmerge_attn(attn_out) + x

            # Cross-attention against the supplied context.
            tokens, agent_tokens = merge_cross(self.norm2(x))
            attn_out = self.attn2(tokens, agent=agent_tokens, context=context)
            x = unmerge_cross(attn_out) + x

            # Feed-forward; the agent tokens are not needed here.
            tokens, _ = merge_mlp(self.norm3(x))
            x = unmerge_mlp(self.ff(tokens)) + x

            return x

    return ToMeBlock
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """
    Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function it is called and its result is
    returned, any other fallback is returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def make_agent_attn(block_class: Type[torch.nn.Module], k_scale2, k_shortcut, attn_precision=None) -> Type[torch.nn.Module]:
    """
    Create a subclass of the attention class ``block_class`` whose forward applies AgentSD.

    Args:
        block_class: The attention class to derive from.
        k_scale2: Exponent applied to ``self.scale`` in the second (query/agent) attention.
        k_shortcut: Weight of the ``v`` shortcut added to the agent attention output.
        attn_precision: When equal to ``"fp32"`` the similarity products are computed
            in float32 with autocast disabled.

    Returns:
        The patched ``AgentAttention`` class. ``set_new_params`` must be called on an
        instance to copy the hyper-parameters onto it before ``forward`` is used.
    """

    class AgentAttention(block_class):
        # Kept on the class so the patch can be undone later
        _parent = block_class

        def set_new_params(self):
            # Copy the hyper-parameters captured by the factory onto this instance.
            self.k_scale2 = k_scale2
            self.k_shortcut = k_shortcut
            self.attn_precision = attn_precision

        def forward(self, x, agent=None, context=None, mask=None, *args, **kwargs):
            # Agent attention is only used when agent tokens are supplied and there
            # are fewer than half as many of them as there are query tokens.
            use_agent = agent is not None and agent.shape[1] * 2 < x.shape[1]

            if not use_agent:
                # ---- regular softmax attention ----
                heads = self.heads

                queries = self.to_q(x)
                context = default(context, x)
                keys = self.to_k(context)
                values = self.to_v(context)

                queries, keys, values = (
                    rearrange(t, 'b n (h d) -> (b h) n d', h=heads)
                    for t in (queries, keys, values)
                )

                scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

                if exists(mask):
                    mask = rearrange(mask, 'b ... -> b (...)')
                    fill_value = -torch.finfo(scores.dtype).max
                    mask = repeat(mask, 'b j -> (b h) () j', h=heads)
                    scores.masked_fill_(~mask, fill_value)

                weights = scores.softmax(dim=-1)

                result = einsum('b i j, b j d -> b i d', weights, values)
                result = rearrange(result, '(b h) n d -> b n (h d)', h=heads)
                return self.to_out(result)

            # ---- agent attention ----
            scale_exponent = self.k_scale2
            shortcut_weight = self.k_shortcut

            heads = self.heads

            queries = self.to_q(x)
            context = default(context, x)
            keys = self.to_k(context)
            values = self.to_v(context)
            # Agent tokens go through the same projection as the queries.
            agent = self.to_q(agent)

            queries, keys, values, agent = (
                rearrange(t, 'b n (h d) -> (b h) n d', h=heads)
                for t in (queries, keys, values, agent)
            )

            if exists(mask):
                print('Mask not supported yet!')

            # Stage 1: agent tokens attend to the keys and aggregate the values.
            if self.attn_precision == "fp32":
                # force cast to fp32 to avoid overflowing
                with torch.autocast(enabled=False, device_type='cuda'):
                    agent, keys = agent.float(), keys.float()
                    agent_scores = einsum('b i d, b j d -> b i j', agent, keys) * self.scale
                del keys
            else:
                agent_scores = einsum('b i d, b j d -> b i j', agent, keys) * self.scale

            agent_weights = agent_scores.softmax(dim=-1)
            agent_feature = einsum('b i j, b j d -> b i d', agent_weights, values)

            # Stage 2: queries attend to the agent tokens, with a rescaled temperature.
            if self.attn_precision == "fp32":
                # force cast to fp32 to avoid overflowing
                with torch.autocast(enabled=False, device_type='cuda'):
                    queries = queries.float()
                    query_scores = einsum('b i d, b j d -> b i j', queries, agent) * self.scale ** scale_exponent
                del queries, agent
            else:
                query_scores = einsum('b i d, b j d -> b i j', queries, agent) * self.scale ** scale_exponent

            query_weights = query_scores.softmax(dim=-1)

            result = einsum('b i j, b j d -> b i d', query_weights, agent_feature)

            # Add the weighted value shortcut.
            result = result * 1.0 + values * shortcut_weight
            result = rearrange(result, '(b h) n d -> b n (h d)', h=heads)
            return self.to_out(result)

    return AgentAttention
def make_diffusers_tome_block(block_class: Type[torch.nn.Module], old_forward) -> Type[torch.nn.Module]:
    """
    Make a patched class for a diffusers model.
    This patch applies ToMe to the forward function of the block.

    Args:
        block_class: The transformer block class to derive from.
        old_forward: Stored on the new class as ``_old_forward``.

    Returns:
        The patched ``ToMeBlock`` class.
    """

    class ToMeBlock(block_class):
        # Save for unpatching later
        _parent = block_class
        _old_forward = old_forward

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=None,
            cross_attention_kwargs=None,
            class_labels=None,
        ) -> torch.Tensor:
            # (1) ToMe
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # (2) ToMe m_a
            # The merge callables built by compute_merge return a (tokens, agent)
            # pair, so the pair has to be unpacked before the tokens are used.
            norm_hidden_states, agent = m_a(norm_hidden_states)

            # 1. Self-Attention
            # NOTE(review): the AgentAttention.forward defined in this file names its
            # conditioning input ``context``; ``encoder_hidden_states`` and
            # ``attention_mask`` are passed exactly as before and would end up in its
            # **kwargs -- confirm against the attention class actually in use.
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            attn_output = self.attn1(
                norm_hidden_states,
                agent=agent,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output

            # (3) ToMe u_a
            hidden_states = u_a(attn_output) + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )
                # (4) ToMe m_c
                norm_hidden_states, agent = m_c(norm_hidden_states)

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    agent=agent,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                # (5) ToMe u_c
                hidden_states = u_c(attn_output) + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            # (6) ToMe m_m -- the agent half of the pair is not used by the MLP
            norm_hidden_states, _ = m_m(norm_hidden_states)

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            # (7) ToMe u_m
            hidden_states = u_m(ff_output) + hidden_states

            return hidden_states

    return ToMeBlock
def hook_tome_model(model: torch.nn.Module):
    """
    Register a forward pre-hook on ``model`` that stores dimensions 2 and 3 of the
    first positional input in ``model._tome_info["size"]``.

    The hook handle is appended to ``model._tome_info["hooks"]`` so that it can be
    removed again with remove_patch.
    """
    registered = model._tome_info["hooks"]

    def record_size(module, inputs):
        first_input = inputs[0]
        module._tome_info["size"] = (first_input.shape[2], first_input.shape[3])
        return None

    registered.append(model.register_forward_pre_hook(record_size))
def apply_patch(
        model: torch.nn.Module,
        ratio: float = 0.5,
        max_downsample: int = 1,
        sx: int = 2, sy: int = 2,
        agent_ratio: float = 0.8,
        k_scale2=0.3,
        k_shortcut=0.075,
        attn_precision=None,
        use_rand: bool = True,
        merge_attn: bool = True,
        merge_crossattn: bool = False,
        merge_mlp: bool = False):
    """
    Patches a stable diffusion model with AgentSD.
    Apply this to the highest level stable diffusion object (i.e., it should have a .model.diffusion_model).

    Any patch already present on the model is removed first, so calling this
    repeatedly re-patches the model with the new arguments.

    Important Args:
     - model: A top level Stable Diffusion module to patch in place. Should have a ".model.diffusion_model"
     - ratio: The ratio of tokens to merge. I.e., 0.4 would reduce the total number of tokens by 40%.
              The maximum value for this is 1-(1/(sx*sy)). By default, the max is 0.75 (I recommend <= 0.5 though).
              Higher values result in more speed-up, but with more visual quality loss.
     - agent_ratio: The ratio of tokens to merge when producing agent tokens.

    Args to tinker with if you want:
     - max_downsample [1, 2, 4, or 8]: Apply AgentSD to layers with at most this amount of downsampling.
                                       E.g., 1 only applies to layers with no downsampling (4/15) while
                                       8 applies to all layers (15/15). I recommend a value of 1 or 2.
     - sx, sy: The stride for computing dst sets (see paper). A higher stride means you can merge more tokens,
               but the default of (2, 2) works well in most cases. Doesn't have to divide image size.
     - k_scale2: The scale used for the second attention is head_dim ** (-0.5 * k_scale2)
     - k_shortcut: The ratio used in O = sigma(QA^T) sigma(AK^T) V + k * V.
     - attn_precision: Set attn_precision="fp32" to avoid numerical instabilities on SD v2.1 model.
     - use_rand: Whether or not to allow random perturbations when computing dst sets (see paper). Usually
                 you'd want to leave this on, but if you're having weird artifacts try turning this off.
     - merge_attn: Whether or not to merge tokens for attention (recommended).
     - merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
     - merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).

    Returns:
        The ``model`` that was passed in (patched in place).

    Raises:
        RuntimeError: If ``model`` is not a diffusers object and has no
            ``.model.diffusion_model``.
    """

    # Make sure the module is not currently patched
    remove_patch(model)

    is_diffusers = isinstance_str(model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin")

    if not is_diffusers:
        if not hasattr(model, "model") or not hasattr(model.model, "diffusion_model"):
            # Provided model not supported
            raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.")
        diffusion_model = model.model.diffusion_model
    else:
        # Supports "pipe.unet" and "unet"
        diffusion_model = model.unet if hasattr(model, "unet") else model

    # Shared by the model-level size hook and by every patched block.
    diffusion_model._tome_info = {
        "size": None,
        "hooks": [],
        "args": {
            "ratio": ratio,
            "max_downsample": max_downsample,
            "sx": sx, "sy": sy,
            "agent_ratio": agent_ratio,
            "use_rand": use_rand,
            "generator": None,
            "merge_attn": merge_attn,
            "merge_crossattn": merge_crossattn,
            "merge_mlp": merge_mlp
        }
    }
    hook_tome_model(diffusion_model)

    make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block

    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if not isinstance_str(module, "BasicTransformerBlock"):
            continue

        # Saved under the attribute name that remove_patch checks for.
        module._old__class__ = [module.__class__]
        module.__class__ = make_tome_block_fn(module.__class__, None)
        module._tome_info = diffusion_model._tome_info

        for attn_name in ("attn1", "attn2"):
            attn = getattr(module, attn_name, None)
            if attn is None:
                # The diffusers forward allows ``attn2`` to be None; a class cannot be
                # derived from NoneType, so there is nothing to patch here.
                continue
            setattr(module, "_old_" + attn_name, [attn.__class__])
            attn.__class__ = make_agent_attn(
                attn.__class__, k_scale2=k_scale2, k_shortcut=k_shortcut, attn_precision=attn_precision
            )
            attn.set_new_params()

        # Something introduced in SD 2.0 (LDM only)
        if not hasattr(module, "disable_self_attn") and not is_diffusers:
            module.disable_self_attn = False

        # Something needed for older versions of diffusers
        if not hasattr(module, "use_ada_layer_norm_zero") and is_diffusers:
            module.use_ada_layer_norm = False
            module.use_ada_layer_norm_zero = False

    return model
def remove_patch(model: torch.nn.Module):
    """
    Removes a patch from a AgentSD Diffusion module if it was already patched.

    For every sub-module this removes the registered ToMe hooks, restores the
    original class of patched ``ToMeBlock`` modules and restores the original
    classes of their ``attn1`` / ``attn2`` layers.

    Returns:
        The module that was traversed, i.e. ``model.unet`` when ``model`` has a
        ``unet`` attribute and ``model`` itself otherwise.
    """
    # For diffusers
    model = model.unet if hasattr(model, "unet") else model

    for _, module in model.named_modules():
        if hasattr(module, "_tome_info"):
            for hook in module._tome_info["hooks"]:
                hook.remove()
            module._tome_info["hooks"].clear()

        if module.__class__.__name__ == "ToMeBlock":
            # apply_patch has stored the original class under ``_old_class_`` while
            # this function looked for ``_old__class__``; accept either spelling and
            # fall back to the parent recorded on the patched class.
            saved_class = getattr(module, "_old__class__", None) or getattr(module, "_old_class_", None)
            module.__class__ = saved_class[0] if saved_class else module._parent

        for attn_name in ("attn1", "attn2"):
            saved_attn = getattr(module, "_old_" + attn_name, None)
            attn = getattr(module, attn_name, None)
            # A block may have no such layer (e.g. ``attn2`` is None); assigning
            # ``__class__`` on None would raise, so only real modules are restored.
            if saved_attn and attn is not None:
                attn.__class__ = saved_attn[0]

    return model