# -*- coding: utf-8 -*- # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import PIL.Image import numpy as np from typing import Optional class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. """ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.skip = skip self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) if not self.skip: # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) self.attn_map = ip_attention_probs ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0(nn.Module): def __init__(self, hidden_size: Optional[int] = None, cross_attention_dim: Optional[int] = None): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0 or later.") def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: b, c, h, w = hidden_states.shape hidden_states = hidden_states.view(b, c, h * w).transpose(1, 2) # group norm if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) # q, k, v query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # reshape heads bsz = hidden_states.shape[0] head_dim = key.shape[-1] // attn.heads query = query.view(bsz, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(bsz, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(bsz, -1, attn.heads, head_dim).transpose(1, 2) if attention_mask is not None: pass out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) # merge heads out = out.transpose(1, 2).reshape(bsz, -1, attn.heads * head_dim).to(query.dtype) # out proj + dropout out = attn.to_out[1](attn.to_out[0](out)) if input_ndim == 4: out = out.transpose(-1, -2).reshape(bsz, c, h, w) if attn.residual_connection: out = out + residual out = out / attn.rescale_output_factor return out def prepare_mask(mask: PIL.Image.Image) -> torch.Tensor: """ mask: PIL.Image | np.ndarray | torch.Tensor 반환: (B,1,H,W) float32 in {0,1} """ if isinstance(mask, torch.Tensor): m = mask.clone() if m.ndim == 2: # (H,W) -> (1,1,H,W) m = m.unsqueeze(0).unsqueeze(0) elif m.ndim == 3: # (1,H,W) or (B,H,W) -> (B,1,H,W) if m.shape[0] == 1: m = m.unsqueeze(0) else: m = m.unsqueeze(1) if m.min() < 0 or m.max() > 1: raise ValueError("Mask tensor must be in [0,1].") m = (m >= 0.5).float() return m if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): arr = np.concatenate([np.array(m.convert("L"))[None, None, ...] for m in mask], axis=0).astype(np.float16) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): arr = np.concatenate([m[None, None, ...] for m in mask], axis=0).astype(np.float16) if arr.max() > 1.0: arr = arr / 255.0 else: raise TypeError("Unsupported mask type.") arr = (arr >= 0.5).astype(np.float16) return torch.from_numpy(arr) class IPAttnProcessor2_0(nn.Module): def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0, num_tokens: int = 4, skip: bool = False): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0 or later.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = float(scale) self.num_tokens = int(num_tokens) self.skip = bool(skip) proj_in = cross_attention_dim if cross_attention_dim is not None else hidden_size self.to_k_ip = nn.Linear(proj_in, hidden_size, bias=False) self.to_v_ip = nn.Linear(proj_in, hidden_size, bias=False) self.last_scale = None self.last_skip = None self.last_out_l2 = None self.last_layer_name = None self.last_group = None self.last_ip_source = None self.last_ip_mu = None def forward(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: b, c, h, w = hidden_states.shape hidden_states = hidden_states.view(b, c, h * w).transpose(1, 2) else: b = hidden_states.shape[0] if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: base_enc = hidden_states tail_ip_tokens = None else: if encoder_hidden_states.shape[1] >= self.num_tokens and self.num_tokens > 0: end_pos = encoder_hidden_states.shape[1] - self.num_tokens base_enc = encoder_hidden_states[:, :end_pos, :] # 텍스트(+기타)만 tail_ip_tokens = encoder_hidden_states[:, end_pos:, :] # 전역 concat된 이미지 토큰 else: base_enc = encoder_hidden_states tail_ip_tokens = None if attn.norm_cross: base_enc = attn.norm_encoder_hidden_states(base_enc) group = getattr(self, "group", "off") # "content" / "style" / "off" override = getattr(self, "ip_tokens_override", None) override_uncond = getattr(self, "ip_tokens_override_uncond", None) ip_tokens = None ip_source = "none" if group == "content": ip_tokens = tail_ip_tokens ip_source = "tail" if tail_ip_tokens is not None else "none" elif group == "style": if override is not None: N, T, D = override.shape if override_uncond is None: override_uncond = torch.zeros_like(override) if b == N: ip_tokens = override elif b == 2 * N: ip_tokens = torch.cat([override_uncond, override], dim=0) elif b % N == 0: reps = b // N ip_tokens = override.repeat(reps, 1, 1) else: ip_tokens = override.expand(b, -1, -1) ip_source = "override" else: ip_tokens = None ip_source = "none" else: ip_tokens = None ip_source = "none" key = attn.to_k(base_enc) value = attn.to_v(base_enc) head_dim = key.shape[-1] // attn.heads query = query.view(b, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(b, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(b, -1, attn.heads, head_dim).transpose(1, 2) out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) with torch.no_grad(): self.last_group = group self.last_ip_source = ip_source if ip_tokens is None: self.last_ip_mu = None else: mu = ip_tokens.detach().float().mean(dim=(0, 1)) # [D] self.last_ip_mu = mu.cpu() do_inject = (not self.skip) and (ip_tokens is not None) and (ip_tokens.shape[1] == self.num_tokens) if do_inject: ip_k = self.to_k_ip(ip_tokens).view(b, -1, attn.heads, head_dim).transpose(1, 2) ip_v = self.to_v_ip(ip_tokens).view(b, -1, attn.heads, head_dim).transpose(1, 2) ip_out = F.scaled_dot_product_attention(query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out + float(self.scale) * ip_out with torch.no_grad(): self.last_ip_out_l2 = ip_out.float().pow(2).sum(dim=tuple(range(1, ip_out.ndim))).sqrt().mean().item() out = out.transpose(1, 2).reshape(b, -1, attn.heads * head_dim).to(query.dtype) out = attn.to_out[1](attn.to_out[0](out)) if input_ndim == 4: out = out.transpose(-1, -2).reshape(b, c, h, w) if attn.residual_connection: out = out + residual out = out / attn.rescale_output_factor with torch.no_grad(): self.last_scale = float(self.scale) self.last_skip = bool(self.skip) if isinstance(out, torch.Tensor): if out.ndim >= 2: self.last_out_l2 = out.float().pow(2).sum(dim=tuple(range(1, out.ndim))).sqrt().mean().item() else: self.last_out_l2 = out.float().pow(2).sum().sqrt().item() else: self.last_out_l2 = None return out ## for controlnet class CNAttnProcessor: r""" Default processor for performing attention-related computations. """ def __init__(self, num_tokens=4): self.num_tokens = num_tokens def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class CNAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self, num_tokens=4): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.num_tokens = num_tokens def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states