# VISTA / ip_adapter/attention_processor.py -- ssoxye, commit 689a987: "Clean Space repo (code only) + gradio app"
# -*- coding: utf-8 -*-
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import PIL.Image
import numpy as np
from typing import Optional
class AttnProcessor(nn.Module):
    r"""
    Baseline attention processor: explicit softmax attention evaluated with `torch.bmm`.

    Args:
        hidden_size: Accepted for signature parity with the other processors; not used.
        cross_attention_dim: Accepted for signature parity with the other processors; not used.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention through the projections and helpers owned by `attn`.

        With `encoder_hidden_states=None` the layer attends to `hidden_states` itself.
        A 4-D input is flattened to a token sequence and restored to 4-D on return.
        """
        skip_connection = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against whatever sequence supplies the keys.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            # group_norm is applied channels-first, hence the transpose round trip.
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        attended = attn.to_out[1](attn.to_out[0](attended))

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_connection

        return attended / attn.rescale_output_factor
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    The last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt
    tokens. They are attended to through the processor's own `to_k_ip` / `to_v_ip`
    projections and the result is added, weighted by `scale`, to the regular attention
    output computed over the remaining tokens.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to `False`):
            When `True`, the image-prompt branch is not evaluated.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention plus the optional image-prompt branch.

        When `encoder_hidden_states` is `None` there are no image tokens to split off,
        so the image branch is skipped and the call reduces to plain self-attention.
        When the image branch runs, its attention probabilities are kept on
        `self.attn_map`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None on the self-attention path so the image branch below is skipped
        # instead of failing on an unbound local.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split off the trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            # The image branch is never masked.
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class AttnProcessor2_0(nn.Module):
    r"""
    Attention processor built on `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        hidden_size: Accepted for signature parity with the other processors; not used.
        cross_attention_dim: Accepted for signature parity with the other processors; not used.

    Raises:
        ImportError: If the installed PyTorch does not provide `scaled_dot_product_attention`.
    """

    def __init__(self, hidden_size: Optional[int] = None, cross_attention_dim: Optional[int] = None):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0 or later.")

    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """
        Run attention through the projections owned by `attn`.

        With `encoder_hidden_states=None` the layer attends to `hidden_states` itself.
        A 4-D input is flattened to a token sequence and restored to 4-D on return.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            b, c, h, w = hidden_states.shape
            hidden_states = hidden_states.view(b, c, h * w).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            # Same treatment as CNAttnProcessor2_0: let the attention module expand
            # the mask, then expose the head axis, because scaled_dot_product_attention
            # expects a mask broadcastable to (batch, heads, query_len, key_len).
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # group norm
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # q, k, v
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # reshape heads: (B, S, heads * head_dim) -> (B, heads, S, head_dim)
        bsz = hidden_states.shape[0]
        head_dim = key.shape[-1] // attn.heads
        query = query.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(bsz, -1, attn.heads, head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # merge heads
        out = out.transpose(1, 2).reshape(bsz, -1, attn.heads * head_dim).to(query.dtype)

        # out proj + dropout
        out = attn.to_out[1](attn.to_out[0](out))

        if input_ndim == 4:
            out = out.transpose(-1, -2).reshape(bsz, c, h, w)
        if attn.residual_connection:
            out = out + residual
        out = out / attn.rescale_output_factor
        return out
def prepare_mask(mask: "PIL.Image.Image | np.ndarray | torch.Tensor | list") -> torch.Tensor:
    """
    Binarize a mask at 0.5 and return it as a float32 tensor with values in {0, 1}.

    Args:
        mask: One of
            * a `torch.Tensor` with values in [0, 1]; a 2-D tensor becomes
              (1, 1, H, W), a 3-D tensor gets a channel axis inserted (a leading
              axis of size 1 is treated as that channel axis), and any other rank
              is passed through unchanged;
            * a `PIL.Image.Image`, converted to mode "L" and divided by 255;
            * a `np.ndarray`, divided by 255 only when its maximum exceeds 1;
            * a list of PIL images or of arrays, stacked along the batch axis
              (the type of the first element selects the branch).

    Returns:
        A float32 tensor; for 2-D per-item inputs the shape is (B, 1, H, W).

    Raises:
        ValueError: If a tensor mask has values outside [0, 1], or the list is empty.
        TypeError: If the mask (or the first list element) has an unsupported type.
    """
    if isinstance(mask, torch.Tensor):
        m = mask.clone()
        if m.ndim == 2:  # (H, W) -> (1, 1, H, W)
            m = m.unsqueeze(0).unsqueeze(0)
        elif m.ndim == 3:  # (1, H, W) or (B, H, W) -> (B, 1, H, W)
            m = m.unsqueeze(0) if m.shape[0] == 1 else m.unsqueeze(1)
        if m.min() < 0 or m.max() > 1:
            raise ValueError("Mask tensor must be in [0,1].")
        return (m >= 0.5).float()

    items = mask if isinstance(mask, list) else [mask]
    if not items:
        raise ValueError("Mask list is empty.")

    first = items[0]
    # Thresholding is done in float32: float16 rounds values just below 0.5
    # (e.g. 0.4999) up to exactly 0.5, which would flip them to 1.
    if isinstance(first, np.ndarray):
        arr = np.concatenate([m[None, None, ...] for m in items], axis=0).astype(np.float32)
        if arr.max() > 1.0:
            arr = arr / 255.0
    elif isinstance(first, PIL.Image.Image):
        arr = np.concatenate(
            [np.array(m.convert("L"))[None, None, ...] for m in items], axis=0
        ).astype(np.float32) / 255.0
    else:
        raise TypeError("Unsupported mask type.")

    arr = (arr >= 0.5).astype(np.float32)
    return torch.from_numpy(arr)
class IPAttnProcessor2_0(nn.Module):
    r"""
    IP-Adapter attention processor built on `scaled_dot_product_attention`, with
    per-layer routing of the image tokens.

    The last `num_tokens` entries of `encoder_hidden_states` are always stripped from
    the keys/values of the regular attention. Which image tokens (if any) are then
    injected depends on attributes that callers may set on the instance:

    * `group` (`"content"`, `"style"` or anything else, default `"off"`):
      `"content"` injects the stripped tail tokens, `"style"` injects
      `ip_tokens_override`, any other value injects nothing.
    * `ip_tokens_override`: tensor of shape (N, T, D) used by the `"style"` group.
    * `ip_tokens_override_uncond`: tokens placed in front of the override when the
      batch is exactly `2 * N`; zeros are used when it is not set.
      NOTE(review): this presumably matches an unconditional-first classifier-free
      guidance batch layout -- confirm against the calling pipeline.

    After every call the `last_*` attributes hold diagnostics about that call.
    Several of them are obtained with `.item()` / `.cpu()`, so each call copies
    values back to the host.

    Args:
        hidden_size: Output width of the image-token key/value projections.
        cross_attention_dim: Input width of those projections; `hidden_size` is used when `None`.
        scale: Weight of the image-token attention output.
        num_tokens: Number of trailing image tokens in `encoder_hidden_states`.
        skip: When `True`, no image tokens are injected.

    Raises:
        ImportError: If the installed PyTorch does not provide `scaled_dot_product_attention`.
    """

    def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0, num_tokens: int = 4, skip: bool = False):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0 or later.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = float(scale)
        self.num_tokens = int(num_tokens)
        self.skip = bool(skip)

        proj_in = cross_attention_dim if cross_attention_dim is not None else hidden_size
        self.to_k_ip = nn.Linear(proj_in, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(proj_in, hidden_size, bias=False)

        # Diagnostics refreshed by every forward() call.
        self.last_scale = None
        self.last_skip = None
        self.last_out_l2 = None
        self.last_ip_out_l2 = None  # None whenever the last call injected nothing
        self.last_layer_name = None
        self.last_group = None
        self.last_ip_source = None
        self.last_ip_mu = None

    @staticmethod
    def _mean_l2(t: torch.Tensor) -> float:
        """Return the L2 norm of each batch item of `t`, averaged over the batch."""
        return t.detach().float().pow(2).sum(dim=tuple(range(1, t.ndim))).sqrt().mean().item()

    def _select_ip_tokens(self, group, tail_ip_tokens, batch):
        """Return `(tokens_or_None, source_label)` for the image tokens this call should inject."""
        if group == "content":
            return tail_ip_tokens, ("tail" if tail_ip_tokens is not None else "none")
        if group != "style":
            return None, "none"

        override = getattr(self, "ip_tokens_override", None)
        if override is None:
            return None, "none"

        n, _, _ = override.shape
        if batch == n:
            return override, "override"
        if batch == 2 * n:
            uncond = getattr(self, "ip_tokens_override_uncond", None)
            if uncond is None:
                uncond = torch.zeros_like(override)
            return torch.cat([uncond, override], dim=0), "override"
        if batch % n == 0:
            return override.repeat(batch // n, 1, 1), "override"
        # A batch that is not a multiple of N cannot be tiled; fail with a clear message.
        raise RuntimeError(
            f"ip_tokens_override has batch size {n}, which cannot be matched to an attention batch of {batch}."
        )

    def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """
        Run attention over the non-image tokens and add the scaled image-token attention.

        Injection happens only when `skip` is false, tokens were selected for this
        layer's `group`, and their token count equals `num_tokens`.

        NOTE(review): `attention_mask` is forwarded to `scaled_dot_product_attention`
        unchanged, so it must already broadcast against the stripped key length.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            b, c, h, w = hidden_states.shape
            hidden_states = hidden_states.view(b, c, h * w).transpose(1, 2)
        else:
            b = hidden_states.shape[0]

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            base_enc = hidden_states
            tail_ip_tokens = None
        else:
            if encoder_hidden_states.shape[1] >= self.num_tokens and self.num_tokens > 0:
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens
                base_enc = encoder_hidden_states[:, :end_pos, :]  # text (and anything else) only
                tail_ip_tokens = encoder_hidden_states[:, end_pos:, :]  # globally concatenated image tokens
            else:
                base_enc = encoder_hidden_states
                tail_ip_tokens = None
            if attn.norm_cross:
                base_enc = attn.norm_encoder_hidden_states(base_enc)

        group = getattr(self, "group", "off")  # "content" / "style" / "off"
        ip_tokens, ip_source = self._select_ip_tokens(group, tail_ip_tokens, b)

        key = attn.to_k(base_enc)
        value = attn.to_v(base_enc)

        # (B, S, heads * head_dim) -> (B, heads, S, head_dim)
        head_dim = key.shape[-1] // attn.heads
        query = query.view(b, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(b, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(b, -1, attn.heads, head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

        self.last_group = group
        self.last_ip_source = ip_source
        if ip_tokens is None:
            self.last_ip_mu = None
        else:
            # Mean token vector over batch and token axes, shape [D].
            self.last_ip_mu = ip_tokens.detach().float().mean(dim=(0, 1)).cpu()

        do_inject = (not self.skip) and (ip_tokens is not None) and (ip_tokens.shape[1] == self.num_tokens)
        if do_inject:
            ip_k = self.to_k_ip(ip_tokens).view(b, -1, attn.heads, head_dim).transpose(1, 2)
            ip_v = self.to_v_ip(ip_tokens).view(b, -1, attn.heads, head_dim).transpose(1, 2)
            ip_out = F.scaled_dot_product_attention(query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
            out = out + float(self.scale) * ip_out
            self.last_ip_out_l2 = self._mean_l2(ip_out)
        else:
            # Reset so a value from an earlier injecting call is not mistaken for this one.
            self.last_ip_out_l2 = None

        # merge heads, then linear projection (to_out[0]) and dropout (to_out[1])
        out = out.transpose(1, 2).reshape(b, -1, attn.heads * head_dim).to(query.dtype)
        out = attn.to_out[1](attn.to_out[0](out))

        if input_ndim == 4:
            out = out.transpose(-1, -2).reshape(b, c, h, w)
        if attn.residual_connection:
            out = out + residual
        out = out / attn.rescale_output_factor

        self.last_scale = float(self.scale)
        self.last_skip = bool(self.skip)
        self.last_out_l2 = self._mean_l2(out)

        return out
## for controlnet
class CNAttnProcessor:
    r"""
    Attention processor for ControlNet: attends to the text part of the prompt only.

    The last `num_tokens` entries of `encoder_hidden_states` are dropped before the
    keys and values are computed.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """
        Run `torch.bmm` attention through the projections and helpers owned by `attn`.

        With `encoder_hidden_states=None` the layer attends to `hidden_states` itself.
        A 4-D input is flattened to a token sequence and restored to 4-D on return.
        """
        skip_connection = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against the full (untrimmed) key source.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        attended = attn.to_out[1](attn.to_out[0](attended))

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_connection

        return attended / attn.rescale_output_factor
class CNAttnProcessor2_0:
    r"""
    ControlNet attention processor based on `scaled_dot_product_attention`.

    Only the text part of the prompt is attended to: the last `num_tokens` entries of
    `encoder_hidden_states` are dropped before the keys and values are computed.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention through the projections owned by `attn`.

        With `encoder_hidden_states=None` the layer attends to `hidden_states` itself.
        A 4-D input is flattened to a token sequence and restored to 4-D on return.
        """
        skip_connection = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads
        # (B, S, heads * head_dim) -> (B, heads, S, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back and match the query dtype.
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        attended = attn.to_out[1](attn.to_out[0](attended))

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_connection

        return attended / attn.rescale_output_factor