from typing import Optional, Union, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import logging
from diffusers.models.attention_processor import Attention

logger = logging.get_logger(__name__)


class CustomLiteLACrossAttnProcessor2_0:
    """
    Attention processor for linear cross-attention.
    This correctly uses `encoder_hidden_states` for the keys and values.
    """

    def __init__(self):
        self.kernel_func = nn.ReLU(inplace=False)
        self.eps = 1e-15
        self.pad_val = 1.0

    def apply_rotary_emb(self, x, freqs_cis):
        # x: [B, H, S, D]; freqs_cis: (cos, sin), each [S, D] with every
        # frequency repeated for the interleaved (real, imag) channel pairs.
        cos, sin = freqs_cis
        cos, sin = cos[None, None].to(x.device), sin[None, None].to(x.device)
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        # Accept either [B, C, H, W] or [B, S, C] inputs.
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = hidden_states.shape[0]
        dtype = hidden_states.dtype  # remembered so the output can be cast back after the float32 math

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H*D] -> [B, H, D, S] for query/value, [B, H, S, D] for key.
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D]; temporarily move the sequence axis forward.
        query = query.permute(0, 1, 3, 2)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            # Keys come from the encoder stream; use its rotary table when provided.
            key_freqs = kwargs.get("rotary_freqs_cis_cross", rotary_freqs_cis)
            key = self.apply_rotary_emb(key, key_freqs)

        query = query.permute(0, 1, 3, 2)

        # Linear attention: apply the ReLU kernel feature map to query and key.
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()
        # Append a row of ones to the value so the same matmuls also accumulate
        # the normalizer (denominator) of the linear-attention output.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
        vk = torch.matmul(value, key)
        hidden_states = torch.matmul(vk, query)

        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        # Cast back to the input dtype before the output projection.
        hidden_states = hidden_states.to(dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
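

# Minimal usage sketch (not part of the original module): how one might exercise
# `CustomLiteLACrossAttnProcessor2_0` against a stock `diffusers` `Attention`
# block. The dimensions, tensor shapes, and the helper name
# `_example_lite_la_cross_attention` are illustrative assumptions only.
def _example_lite_la_cross_attention():
    attn = Attention(
        query_dim=64,
        cross_attention_dim=32,
        heads=4,
        dim_head=16,
        processor=CustomLiteLACrossAttnProcessor2_0(),
    )
    hidden_states = torch.randn(2, 100, 64)         # [B, S_q, query_dim]
    encoder_hidden_states = torch.randn(2, 20, 32)  # [B, S_kv, cross_attention_dim]
    out = attn.processor(attn, hidden_states, encoder_hidden_states=encoder_hidden_states)
    assert out.shape == (2, 100, 64)
    return out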


class CustomLiteLAProcessor2_0:
    """
    Linear attention processor for SD3-like fused self/cross attention projections.
    Applies RMS norm to query and key (when configured on the module) and RoPE.
    """

    def __init__(self):
        self.kernel_func = nn.ReLU(inplace=False)
        self.eps = 1e-15
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to a query or key tensor using the given frequency
        tensors. The input is viewed as interleaved (real, imaginary) pairs, the
        cos/sin tables are broadcast over the batch and head dimensions, and the
        rotated result is returned as a real tensor in the input dtype.

        Args:
            x (`torch.Tensor`):
                Query or key tensor of shape `[B, H, S, D]`.
            freqs_cis (`Tuple[torch.Tensor]`):
                Precomputed `(cos, sin)` tensors, each of shape `[S, D]`.

        Returns:
            `torch.Tensor`: The input tensor with rotary embeddings applied.
        """
        cos, sin = freqs_cis
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        hidden_states_len = hidden_states.shape[1]

        # Accept either [B, C, H, W] or [B, S, C] inputs for both streams.
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            if not attn.is_cross_attention:
                # Joint attention: concatenate the two streams along the sequence axis.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # Cross attention: note that the raw (unprojected) states are used here.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H*D] -> [B, H, D, S] for query/value, [B, H, S, D] for key.
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = (
            key.transpose(-1, -2)
            .reshape(batch_size, attn.heads, head_dim, -1)
            .transpose(-1, -2)
        )
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE and the q/k norms expect [B, H, S, D]; move the sequence axis forward.
        query = query.permute(0, 1, 3, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        query = query.permute(0, 1, 3, 2)

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1], applied multiplicatively.
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)
            query = query * attention_mask.permute(0, 1, 3, 2)
            if not attn.is_cross_attention:
                key = key * attention_mask
                value = value * attention_mask.permute(0, 1, 3, 2)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # encoder_attention_mask: [B, S_kv] -> [B, 1, S_kv, 1].
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)
            key = key * encoder_attention_mask
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)

        # Linear attention: apply the ReLU kernel feature map to query and key.
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()

        # Append a row of ones to the value so the same matmuls also accumulate
        # the normalizer (denominator) of the linear-attention output.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)

        hidden_states = torch.matmul(vk, query)

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        hidden_states = hidden_states.view(
            batch_size, attn.heads * head_dim, -1
        ).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the fused sequence back into the hidden and encoder streams.
        if (
            encoder_hidden_states is not None
            and not attn.is_cross_attention
            and has_encoder_hidden_state_proj
        ):
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        # Guard against fp16 overflow when running under autocast.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
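

# Illustrative sketch (not part of the original module): one plausible way to
# build `(cos, sin)` tables matching the interleaved rotate-half convention used
# by `apply_rotary_emb` above, plus a self-attention smoke test for
# `CustomLiteLAProcessor2_0`. The helper names, dimensions, and the theta value
# are assumptions for demonstration only; the real model builds its rotary
# tables elsewhere.
def _example_rotary_tables(seq_len: int, head_dim: int, theta: float = 10000.0):
    # One frequency per (real, imag) pair, repeated so cos/sin end up [S, D].
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    cos = angles.cos().repeat_interleave(2, dim=-1)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    return cos, sin


def _example_lite_la_self_attention():
    attn = Attention(
        query_dim=64,
        heads=4,
        dim_head=16,
        processor=CustomLiteLAProcessor2_0(),
    )
    hidden_states = torch.randn(2, 100, 64)  # [B, S, query_dim]
    freqs_cis = _example_rotary_tables(seq_len=100, head_dim=16)
    out, _ = attn.processor(attn, hidden_states, rotary_freqs_cis=freqs_cis)
    assert out.shape == (2, 100, 64)
    return out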


class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to a query or key tensor using the given frequency
        tensors. The input is viewed as interleaved (real, imaginary) pairs, the
        cos/sin tables are broadcast over the batch and head dimensions, and the
        rotated result is returned as a real tensor in the input dtype.

        Args:
            x (`torch.Tensor`):
                Query or key tensor of shape `[B, H, S, D]`.
            freqs_cis (`Tuple[torch.Tensor]`):
                Precomputed `(cos, sin)` tensors, each of shape `[S, D]`.

        Returns:
            `torch.Tensor`: The input tensor with rotary embeddings applied.
        """
        cos, sin = freqs_cis
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H*D] -> [B, H, S, D]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # Combine the query-side and key-side boolean masks into an additive
            # bias: allowed positions get 0, masked positions get -inf.
            combined_mask = (
                attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            )
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )
        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects a [B, heads, S_q, S_kv] mask.
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
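

# Illustrative sketch (not part of the original module): how the boolean-to-additive
# mask conversion in `CustomerAttnProcessor2_0` behaves on its own. The helper name
# and the toy shapes are assumptions for demonstration only.
def _example_cross_attention_bias():
    heads = 4
    attention_mask = torch.tensor([[1, 1, 0]])        # query-side keep mask, [B, S_q]
    encoder_attention_mask = torch.tensor([[1, 0]])   # key-side keep mask, [B, S_kv]
    combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
    bias = torch.where(combined_mask == 1, 0.0, -torch.inf)
    bias = bias[:, None, :, :].expand(-1, heads, -1, -1)  # [B, heads, S_q, S_kv]
    return bias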