| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import Callable, Optional, Union |
| | import math |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.autograd import Function |
| |
|
| | from diffusers.utils import deprecate, logging |
| | from diffusers.utils.import_utils import is_xformers_available |
| |
|
# Module-wide logger obtained from diffusers' logging wrapper.
logger = logging.get_logger(__name__)

# xformers is optional: import it (and its `ops` submodule) only when diffusers reports it as
# installed; otherwise bind the name to None so that module-level references to `xformers` resolve.
if not is_xformers_available():
    xformers = None
else:
    import xformers
    import xformers.ops
| |
|
class Attention(nn.Module):
    r"""
    A cross attention layer.

    The layer owns the projections (`to_q`, `to_k`, `to_v`, `to_out` and, optionally, `add_k_proj` /
    `add_v_proj`) plus optional normalisation layers; the attention computation itself is delegated to a
    swappable `processor` object that `forward` calls.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Cast query and key to float32 before computing the scores in `get_attention_scores`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Cast the scores to float32 before the softmax in `get_attention_scores`.
        cross_attention_norm (`str`, *optional*):
            Normalisation applied to `encoder_hidden_states`: `None`, `"layer_norm"` or `"group_norm"`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups when `cross_attention_norm="group_norm"`.
        added_kv_proj_dim (`int`, *optional*):
            When set, extra key/value projections `add_k_proj` / `add_v_proj` taking this many channels are created.
        norm_num_groups (`int`, *optional*):
            When set, `self.group_norm` is a `GroupNorm` over `query_dim` channels; otherwise it is `None`.
        out_bias (`bool`, *optional*, defaults to True): Whether the output projection has a bias.
        scale_qk (`bool`, *optional*, defaults to True):
            Scale the scores by `dim_head**-0.5`; when False the scale is 1.0.
        only_cross_attention (`bool`, *optional*, defaults to False):
            Do not create `to_k` / `to_v`. Requires `added_kv_proj_dim`.
        processor (`AttnProcessor`, *optional*):
            Attention processor. Defaults to `AttnProcessor2_0` when `F.scaled_dot_product_attention` exists
            and `scale_qk` is True, else `AttnProcessor`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        # Multiplier applied to query @ key^T in `get_attention_scores`.
        self.scale = dim_head**-0.5 if scale_qk else 1.0

        self.heads = heads
        # Largest `slice_size` accepted by `set_attention_slice`.
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # With added KV projections the normalised encoder states are what `add_k_proj` /
                # `add_v_proj` consume, so their width is `added_kv_proj_dim`.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            # Keys/values must then come from the added KV projections.
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        # to_out[0] is the output projection, to_out[1] the dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # AttnProcessor2_0 does not pass `self.scale` to `F.scaled_dot_product_attention`, so it is
        # only chosen by default when `scale_qk` is True.
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Switch between the xformers processor and the default `AttnProcessor`.

        A LoRA processor is replaced by its xformers / non-xformers LoRA counterpart, with its state dict
        and device carried over.

        Raises:
            NotImplementedError: enabling xformers while `added_kv_proj_dim` is set.
            ModuleNotFoundError: enabling xformers when it is not installed.
            ValueError: enabling xformers when CUDA is not available.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
        )

        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
                raise NotImplementedError(
                    "Memory efficient attention with `xformers` is currently not supported when"
                    " `self.added_kv_proj_dim` is defined."
                )
            elif not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                # Smoke-test the kernel on a tiny CUDA input; any error propagates to the caller unchanged
                # (the former `except Exception as e: raise e` wrapper added nothing).
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )

            if is_lora:
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                processor = LoRAAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = AttnProcessor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """Install a sliced processor (or an unsliced one for `slice_size=None`) matching the KV layout.

        Raises:
            ValueError: if `slice_size` exceeds `self.sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            processor = AttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Install `processor`.

        If the current processor is a registered `nn.Module` and the new one is not, the old one is
        dropped from `self._modules` (with an info log) so its weights are no longer part of this module.
        """
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info("You are removing possibly trained weights of %s with %s", self.processor, processor)
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Delegate to `self.processor(self, hidden_states, ...)`; extra kwargs are forwarded verbatim."""
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Reshape `(batch * heads, seq, head_dim)` to `(batch, seq, heads * head_dim)`."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """Split the last dim into heads.

        `(batch, seq, dim)` -> `(batch * heads, seq, dim // heads)` when `out_dim == 3`, otherwise
        `(batch, heads, seq, dim // heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """Return `softmax(self.scale * query @ key^T [+ attention_mask])` in the original query dtype.

        `query` is `(b, n, d)` and `key` is `(b, m, d)`; the result is `(b, n, m)`.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 baddbmm ignores this (uninitialised) input entirely.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """Pad and repeat `attention_mask` per head for the processors.

        `None` is returned unchanged. If the mask's last dim differs from `target_length`, `target_length`
        zeros are appended to it. For `out_dim == 3` the first dim is repeat-interleaved `heads` times when
        it is smaller than `batch_size * heads`; for `out_dim == 4` a head dim is inserted at position 1 and
        repeated `heads` times. Omitting `batch_size` is deprecated and falls back to 1.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            # NOTE(review): this appends `target_length` zeros, not `target_length - current_length`.
            # The added-KV processors concatenate `[encoder keys, self keys]` where the self part has length
            # `target_length`, which appears to rely on this; for plain cross-attention it over-pads. Confirm
            # before changing.
            if attention_mask.device.type == "mps":
                # Builds the zero padding explicitly instead of F.pad on MPS; assumes a 3-D mask.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """Apply `self.norm_cross` to `encoder_hidden_states`.

        `LayerNorm` is applied directly; `GroupNorm` is applied with dims 1 and 2 swapped (and swapped
        back), i.e. it normalises over the last dim of the `(batch, seq, channels)`-shaped input.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # GroupNorm expects channels at dim 1.
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
| |
|
| |
|
class AttnProcessor:
    r"""
    Default processor: explicit `baddbmm` + softmax + `bmm` attention built from the `Attention` helpers.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        # The mask is sized against the key/value source: the encoder states for cross-attention,
        # the hidden states themselves for self-attention.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        k = attn.to_k(context)
        v = attn.to_v(context)

        q, k, v = (attn.head_to_batch_dim(t) for t in (q, k, v))

        probs = attn.get_attention_scores(q, k, attention_mask)
        merged = attn.batch_to_head_dim(torch.bmm(probs, v))

        # Output projection, then dropout.
        projected = attn.to_out[0](merged)
        return attn.to_out[1](projected)
| |
|
| |
|
class HRALinearLayer(nn.Module):
    """Householder-reflection adapter (HRA) for a frozen linear layer.

    ``forward(attn, x)`` evaluates ``attn`` (a module exposing ``weight`` and ``bias``) with its weight
    replaced by ``W @ H_0 @ ... @ H_{r-1}``, where ``H_i = I - 2 u_i u_i^T`` and ``u_i`` is the i-th column
    of ``hra_u`` scaled to unit norm. With ``apply_GS=True`` the columns are first orthonormalised by
    Gram-Schmidt into ``U`` and the weight becomes ``W @ (I - 2 U U^T)``.

    Args:
        in_features: Input width of the wrapped layer; the reflections act on this space.
        out_features: Output width of the wrapped layer (recorded only).
        bias: Unused; kept for signature compatibility.
        r: Number of reflection vectors (columns of ``hra_u``).
        apply_GS: Orthonormalise the reflection vectors before use.
    """

    def __init__(self, in_features, out_features, bias=False, r=8, apply_GS=False):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Registered as buffers so the dimensions are saved with the state dict.
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        self.r = r
        self.apply_GS = apply_GS

        if r % 2 == 0:
            # Paired init: every vector appears twice in a row, and a reflection is its own inverse,
            # so H_{2k} H_{2k+1} = I and the adapted weight starts equal to the base weight.
            half_u = torch.zeros(in_features, r // 2)
            nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
            hra_u = torch.repeat_interleave(half_u, 2, dim=1)
        else:
            # Odd r cannot be paired. Previously only r - 1 columns were created and forward()
            # indexed column r - 1, raising IndexError; initialise all r columns instead.
            import warnings

            warnings.warn(
                "HRALinearLayer: r is odd, so the symmetric (identity-at-init) initialisation cannot be used; "
                "initialising all r reflection vectors randomly."
            )
            hra_u = torch.zeros(in_features, r)
            nn.init.kaiming_uniform_(hra_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(hra_u, requires_grad=True)

    def forward(self, attn, x):
        """Apply the reflected weight of ``attn`` to ``x``.

        Args:
            attn: The frozen base layer; ``attn.weight.data`` (no gradient) and ``attn.bias`` are used.
            x: Input whose last dimension is ``in_features``.

        Returns:
            ``F.linear(x, W', attn.bias)`` with the reflected weight ``W'`` cast to ``x.dtype``.
        """
        # Compute in the adapter's dtype so a half/double base weight does not clash with the
        # reflection vectors, then cast to the input dtype for the final linear.
        compute_dtype = self.hra_u.dtype
        new_weight = attn.weight.data.to(compute_dtype)

        if self.apply_GS:
            basis = [(self.hra_u[:, 0] / self.hra_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, self.r):
                ui = self.hra_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (basis[j].t() @ ui) * basis[j]
                # NOTE(review): with the paired init, column i equals column i - 1 for odd i, so `ui` is
                # ~0 here and this normalisation is ill-defined until training separates the pair.
                basis.append((ui / ui.norm()).view(-1, 1))
            U = torch.cat(basis, dim=1)
            # W (I - 2 U U^T) == W - 2 (W U) U^T; avoids an in_features x in_features matrix.
            new_weight = new_weight - 2 * (new_weight @ U) @ U.t()
        else:
            hra_u_norm = self.hra_u / self.hra_u.norm(dim=0)
            for i in range(self.r):
                ui = hra_u_norm[:, i].view(-1, 1)
                # W (I - 2 u u^T) == W - 2 (W u) u^T, applied in order H_0, H_1, ...
                new_weight = new_weight - 2 * (new_weight @ ui) @ ui.t()

        bias = attn.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return nn.functional.linear(input=x, weight=new_weight.to(x.dtype), bias=bias)
| |
|
class HRAAttnProcessor(nn.Module):
    r"""
    Attention processor that routes the q/k/v/out projections of an `Attention` block through
    `HRALinearLayer` adapters while reusing the block's own (frozen) projection weights.

    Args:
        hidden_size: Width of the query/output projections.
        cross_attention_dim: Input width of the key/value projections; falls back to `hidden_size`.
        r: Number of reflection vectors per adapter.
        apply_GS: Forwarded to every `HRALinearLayer`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, r=8, apply_GS=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r

        kv_in_features = cross_attention_dim or hidden_size
        layer_kwargs = dict(r=r, apply_GS=apply_GS)
        # Adapters are registered in q, k, v, out order.
        self.to_q_hra = HRALinearLayer(hidden_size, hidden_size, **layer_kwargs)
        self.to_k_hra = HRALinearLayer(kv_in_features, hidden_size, **layer_kwargs)
        self.to_v_hra = HRALinearLayer(kv_in_features, hidden_size, **layer_kwargs)
        self.to_out_hra = HRALinearLayer(hidden_size, hidden_size, **layer_kwargs)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # `scale` is accepted for interface compatibility and is not used.
        is_self_attention = encoder_hidden_states is None
        kv_shape = hidden_states.shape if is_self_attention else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.head_to_batch_dim(self.to_q_hra(attn.to_q, hidden_states))

        if is_self_attention:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.head_to_batch_dim(self.to_k_hra(attn.to_k, context))
        value = attn.head_to_batch_dim(self.to_v_hra(attn.to_v, context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        merged = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Adapted output projection, then the block's dropout.
        projected = self.to_out_hra(attn.to_out[0], merged)
        return attn.to_out[1](projected)
| | |
| |
|
def project(R, eps):
    """Clamp ``R`` to the Frobenius ball of radius ``eps`` centred at the zero matrix.

    Returns ``R`` itself (the same object) when ``||R||_F <= eps``; otherwise returns the new tensor
    ``eps * R / ||R||_F``. The zero centre has shape ``(R.size(0), R.size(0))``, so ``R`` is expected
    to be square.
    """
    center = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    offset = R - center
    radius = torch.norm(offset)
    if radius <= eps:
        return R
    return center + eps * (offset / radius)
| |
|
def project_batch(R, eps=1e-5):
    """Batched :func:`project`: clamp every matrix ``R[b]`` to a Frobenius ball around zero.

    The radius is ``eps / sqrt(R.shape[0])``, computed as a 0-d tensor. Matrices already inside the
    ball are returned unchanged; the others are rescaled onto its boundary.

    Args:
        R: Tensor of shape ``(batch, n, n)``.
        eps: Radius before dividing by ``sqrt(batch)``.
    """
    radius = eps / torch.tensor(R.shape[0]).sqrt()
    n = R.size(1)
    center = torch.zeros((n, n), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    offset = R - center
    dist = torch.norm(offset, dim=(1, 2), keepdim=True)
    inside = dist <= radius
    # torch.where evaluates both branches; a zero-norm matrix yields NaN in the rescaled branch,
    # but (for non-negative eps) it is inside the ball and therefore taken from R.
    return torch.where(inside, R, center + radius * (offset / dist))
| |
|
| |
|
class OFTLinearLayer(nn.Module):
    """Orthogonal fine-tuning (OFT) adapter for a frozen linear layer.

    ``forward(attn, x)`` rotates the input space of ``attn.weight`` by a block-diagonal orthogonal matrix
    ``BD`` made of ``r`` Cayley-transformed blocks, i.e. it uses ``W' = W @ BD^T``. With
    ``block_share=True`` one block parameter is repeated on all ``r`` diagonal positions; otherwise each
    position has its own. With ``is_coft=True`` the parameter ``R`` is projected in place onto a small
    Frobenius ball around zero (see ``project`` / ``project_batch``) before every forward pass.

    Args:
        in_features: Input width of the wrapped layer; must be divisible by ``r``.
        out_features: Output width of the wrapped layer (recorded only).
        bias: Unused; kept for signature compatibility.
        block_share: Share a single rotation block across all diagonal positions.
        eps: COFT radius scale; multiplied by the squared block size.
        r: Number of diagonal blocks.
        is_coft: Enable the COFT projection of ``R``.
    """

    def __init__(self, in_features, out_features, bias=False, block_share=False, eps=6e-5, r=4, is_coft=False):
        super().__init__()

        self.r = r
        self.is_coft = is_coft

        assert in_features % self.r == 0, "in_features must be divisible by r"

        self.in_features = in_features
        self.out_features = out_features

        # Registered as buffers so the dimensions are saved with the state dict.
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        self.fix_filt_shape = [in_features, out_features]

        self.block_share = block_share
        if self.block_share:
            # One (block x block) parameter shared by every diagonal position.
            self.R_shape = [in_features // self.r, in_features // self.r]
            self.R = nn.Parameter(torch.zeros(self.R_shape[0], self.R_shape[0]), requires_grad=True)
            self.eps = eps * self.R_shape[0] * self.R_shape[0]
        else:
            # r independent (block x block) parameters, stacked along dim 0.
            self.R_shape = [self.r, in_features // self.r, in_features // self.r]
            R = torch.zeros(self.R_shape[1], self.R_shape[1])
            R = torch.stack([R] * self.r)
            self.R = nn.Parameter(R, requires_grad=True)
            self.eps = eps * self.R_shape[1] * self.R_shape[1]

        self.tmp = None

    def forward(self, attn, x):
        """Evaluate ``attn`` (a module exposing ``weight`` / ``bias``) on ``x`` with its rotated weight.

        The weight and bias are read through ``.data`` (no gradient reaches them); the rotation is computed
        in ``self.R.dtype`` and the final linear runs in ``x.dtype``.
        """
        orig_dtype = x.dtype
        dtype = self.R.dtype

        if self.block_share:
            if self.is_coft:
                with torch.no_grad():
                    self.R.copy_(project(self.R, eps=self.eps))
            orth_rotate = self.cayley(self.R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.R.copy_(project_batch(self.R, eps=self.eps))
            orth_rotate = self.cayley_batch(self.R)

        block_diagonal_matrix = self.block_diagonal(orth_rotate)

        # (BD @ W^T)^T == W @ BD^T
        fix_filt = attn.weight.data
        fix_filt = torch.transpose(fix_filt, 0, 1)
        filt = torch.mm(block_diagonal_matrix, fix_filt.to(dtype))
        filt = torch.transpose(filt, 0, 1)

        bias_term = attn.bias.data if attn.bias is not None else None
        if bias_term is not None:
            bias_term = bias_term.to(orig_dtype)

        return nn.functional.linear(input=x, weight=filt.to(orig_dtype), bias=bias_term)

    def cayley(self, data):
        """Cayley transform of the skew-symmetric part ``S`` of a square ``data``: ``(I - S)(I + S)^{-1}``."""
        r, _ = data.shape
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device, dtype=data.dtype)
        # (I - S) and (I + S) commute, so (I - S)(I + S)^{-1} == (I + S)^{-1}(I - S); solving the system
        # avoids forming the inverse explicitly. I + S is invertible for real skew-symmetric S.
        return torch.linalg.solve(I + skew, I - skew)

    def cayley_batch(self, data):
        """Batched :meth:`cayley` for ``data`` of shape ``(b, r, r)``."""
        b, r, c = data.shape
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device, dtype=data.dtype).unsqueeze(0).expand(b, r, c)
        return torch.linalg.solve(I + skew, I - skew)

    def block_diagonal(self, R):
        """Build the ``r``-block diagonal matrix: ``R`` repeated when 2-D, else ``R[0], ..., R[r-1]``."""
        if len(R.shape) == 2:
            blocks = [R] * self.r
        else:
            blocks = [R[i, ...] for i in range(self.r)]
        return torch.block_diag(*blocks)

    def is_orthogonal(self, R, eps=1e-5):
        """Return whether every entry of ``R^T R - I`` is below ``eps`` in absolute value."""
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def is_identity_matrix(self, tensor):
        """Return whether ``tensor`` is exactly a square identity matrix.

        Raises:
            TypeError: if ``tensor`` is not a torch tensor.
        """
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
| |
|
| |
|
class OFTAttnProcessor(nn.Module):
    r"""
    Attention processor that routes the q/k/v/out projections of an `Attention` block through
    `OFTLinearLayer` adapters while reusing the block's own (frozen) projection weights.

    Args:
        hidden_size: Width of the query/output projections.
        cross_attention_dim: Input width of the key/value projections; falls back to `hidden_size`.
        eps: COFT radius scale forwarded to every adapter.
        r: Number of diagonal blocks per adapter.
        is_coft: Forwarded to every `OFTLinearLayer`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, eps=2e-5, r=4, is_coft=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r
        self.is_coft = is_coft

        kv_in_features = cross_attention_dim or hidden_size
        oft_kwargs = dict(eps=eps, r=r, is_coft=is_coft)
        # Adapters are registered in q, k, v, out order.
        self.to_q_oft = OFTLinearLayer(hidden_size, hidden_size, **oft_kwargs)
        self.to_k_oft = OFTLinearLayer(kv_in_features, hidden_size, **oft_kwargs)
        self.to_v_oft = OFTLinearLayer(kv_in_features, hidden_size, **oft_kwargs)
        self.to_out_oft = OFTLinearLayer(hidden_size, hidden_size, **oft_kwargs)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # `scale` is accepted for interface compatibility and is not used.
        is_self_attention = encoder_hidden_states is None
        kv_shape = hidden_states.shape if is_self_attention else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.head_to_batch_dim(self.to_q_oft(attn.to_q, hidden_states))

        if is_self_attention:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.head_to_batch_dim(self.to_k_oft(attn.to_k, context))
        value = attn.head_to_batch_dim(self.to_v_oft(attn.to_v, context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        merged = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Adapted output projection, then the block's dropout.
        projected = self.to_out_oft(attn.to_out[0], merged)
        return attn.to_out[1](projected)
| |
|
| |
|
class AttnAddedKVProcessor:
    r"""
    Processor for `Attention` blocks with added key/value projections (`add_k_proj` / `add_v_proj`).

    The input is viewed as `(B, C, -1)` and transposed to `(B, N, C)`. Encoder keys/values come from the added
    projections and, unless `attn.only_cross_attention`, are concatenated in front of the self-attention
    keys/values. The output is reshaped back to the input's shape and added to the input (residual).
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # `attn.group_norm` is None unless the block was built with `norm_num_groups`; calling it
        # unconditionally raised "'NoneType' object is not callable".
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # Encoder tokens first, then the block's own tokens along the sequence dim.
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
| |
|
| |
|
class AttnAddedKVProcessor2_0:
    r"""
    PyTorch 2.0 variant of `AttnAddedKVProcessor` that computes attention with
    `F.scaled_dot_product_attention`.

    Raises:
        ImportError: at construction when `F.scaled_dot_product_attention` is unavailable.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # `attn.group_norm` is None unless the block was built with `norm_num_groups`.
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            # Encoder tokens first, then the block's own tokens (dim 2 is the sequence dim here).
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # NOTE(review): no `scale=` is passed, so SDPA uses its default 1/sqrt(head_dim) instead of
        # `attn.scale`; the two agree only when the block was built with `scale_qk=True`.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # (B, heads, N, head_dim) -> (B, N, heads * head_dim). The width is the projected inner dim that
        # `to_out[0]` expects; the input channel count `residual.shape[1]` differs when
        # heads * dim_head != query_dim.
        _, num_heads, _, head_dim = hidden_states.shape
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

        # Output projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
| |
|
| |
|
class XFormersAttnProcessor:
    r"""
    Processor that computes attention with `xformers.ops.memory_efficient_attention`.

    Args:
        attention_op (`Callable`, *optional*):
            Forwarded as `op=` to `memory_efficient_attention` (presumably `None` lets xformers pick one).
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None and attention_mask.ndim == 3:
            # Expand a singleton query dim, (batch*heads, 1, key_tokens) -> (batch*heads, query_tokens,
            # key_tokens), so the bias matches the score shape explicitly; xformers does not broadcast it.
            # A mask that already has query_tokens rows is unchanged.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
| |
|
| |
|
class AttnProcessor2_0:
    r"""
    Processor that computes attention with PyTorch 2.0's `F.scaled_dot_product_attention`.

    Raises:
        ImportError: at construction when `F.scaled_dot_product_attention` is unavailable.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA expects (batch, heads, query_len or 1, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Per-head width comes from the projected features (heads * dim_head). The input width is
        # `query_dim`, which need not equal heads * dim_head; using it split the heads incorrectly.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # NOTE(review): SDPA applies its default 1/sqrt(head_dim) scale, not `attn.scale`; `Attention`
        # only selects this processor by default when `scale_qk` is True.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
| |
|
| |
|
class SlicedAttnProcessor:
    r"""
    Processor that computes attention over the `(batch * heads)` dimension in chunks of `slice_size` rows,
    so the attention-probability tensor exists for only one chunk at a time.

    Args:
        slice_size (`int`): Number of `(batch * heads)` rows per chunk.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil division so a trailing partial chunk is processed too; floor division left the last
        # `batch_size_attention % slice_size` rows as zeros.
        num_slices = (batch_size_attention + self.slice_size - 1) // self.slice_size
        for i in range(num_slices):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size  # slicing clamps at the end of the tensor

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
| |
|
| |
|
class SlicedAttnAddedKVProcessor:
    r"""
    Sliced attention processor with extra key/value projections.

    Keys and values produced by ``attn.add_k_proj`` / ``attn.add_v_proj`` from ``encoder_hidden_states`` are
    prepended along the token axis to the self-attention keys/values (or used alone when
    ``attn.only_cross_attention`` is set). Attention is then evaluated over the flattened ``batch * heads``
    dimension in chunks of ``slice_size`` rows.

    Parameters:
        slice_size (`int`):
            Number of ``batch * heads`` rows processed per chunk. When the total is not a multiple of
            ``slice_size``, the final chunk is smaller.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Apply group norm, added-KV attention (sliced), the output projection and a residual connection.

        Args:
            attn: Module providing ``to_q``/``to_k``/``to_v``/``to_out``, ``add_k_proj``/``add_v_proj``,
                ``group_norm``, ``heads``, ``norm_cross``, ``only_cross_attention`` and the head-reshaping,
                scoring and mask-preparation helpers.
            hidden_states: Input with at least 2 dims. Every dim after the second is flattened into one
                token axis (presumably ``(batch, channels, *spatial)`` -- confirm against callers).
            encoder_hidden_states: Optional input for the added key/value projections. When ``None``, the
                flattened ``hidden_states`` (before group norm) is used.
            attention_mask: Optional mask, passed through ``attn.prepare_attention_mask`` and sliced along its
                first dimension together with the queries.

        Returns:
            A tensor with the same shape as the original ``hidden_states``: the attention output plus the
            residual input.
        """
        residual = hidden_states
        # (d0, d1, *rest) -> (d0, prod(rest), d1): flatten the trailing dims into one axis, then move it second.
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # group_norm is applied channels-second, so transpose around it.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        if not attn.only_cross_attention:
            key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            # Added (encoder) keys/values come first along the token axis.
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil-divide so a trailing partial slice is also processed. With floor division the last
        # ``batch_size_attention % slice_size`` rows stayed zero, and those batch elements returned
        # only the residual.
        num_slices = (batch_size_attention + self.slice_size - 1) // self.slice_size
        for i in range(num_slices):
            start_idx = i * self.slice_size
            end_idx = min(start_idx + self.slice_size, batch_size_attention)

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection, then the second ``to_out`` stage.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the flatten/transpose so the result matches the input layout, then add the residual.
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
| |
|
| |
|
# Union of the attention-processor classes listed below. Members are given as one explicit
# tuple, which typing.Union treats exactly like comma-separated subscripts; order is preserved.
AttentionProcessor = Union[
    (
        AttnProcessor,
        AttnProcessor2_0,
        XFormersAttnProcessor,
        SlicedAttnProcessor,
        AttnAddedKVProcessor,
        SlicedAttnAddedKVProcessor,
        AttnAddedKVProcessor2_0,
        OFTAttnProcessor,
        HRAAttnProcessor,
    )
]
| |
|