# Source: icip_source_2/midi/models/attention_processor.py
# (Residue from the Hugging Face Hub file page: tags "Diffusers", "Safetensors";
#  uploaded by hansQAQ via "Upload folder using huggingface_hub", commit 278bf35, verified.)
from typing import Callable, List, Optional, Tuple, Union, Any
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils import logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
from einops import rearrange
from torch import nn
# Module-level logger from diffusers' logging utilities; not referenced by any processor in this file as shown.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class TripoSGAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the TripoSG model. It applies a normalization layer and rotary embedding on query and key vector.

    The q/k/v (self-attention) or k/v (cross-attention) projections are concatenated and re-split per head, because
    the pre-trained weights lay out heads first and q/k/v second (see the NOTE in ``__call__``).
    """

    def __init__(self):
        # F.scaled_dot_product_attention only exists from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "TripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply ``attn`` to ``hidden_states``.

        Args:
            attn: The diffusers ``Attention`` module owning the projections
                (``to_q``/``to_k``/``to_v``/``to_out``), the optional norms and ``heads``.
            hidden_states: Query-side input, either ``(batch, seq, dim)`` or
                ``(batch, channel, height, width)``; 4-D input is flattened to tokens
                and restored at the end.
            encoder_hidden_states: Key/value input for cross-attention. When ``None``
                the block attends over ``hidden_states`` itself.
            attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``
                and reshaped to ``(batch, heads, -1, key_len)`` for SDPA.
            temb: Only forwarded to ``attn.spatial_norm`` when that norm exists.
            image_rotary_emb: Optional RoPE frequencies, applied to the query and,
                for self-attention only, to the key.

        Returns:
            The projected attention output in the same 3-D/4-D layout as the input,
            with ``attn.residual_connection`` and ``attn.rescale_output_factor`` applied.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed; keys only share the query positions in self-attention.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class FusedTripoSGAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
    projection layers. This is used in the TripoSG model. It applies a normalization layer and rotary embedding on
    query and key vector.

    Self-attention (``encoder_hidden_states is None``) uses the fused ``attn.to_qkv`` projection; cross-attention uses
    ``attn.to_q`` plus the fused ``attn.to_kv``. In both cases the fused output is split per head first, matching the
    layout of the pre-trained weights.
    """

    def __init__(self):
        # F.scaled_dot_product_attention only exists from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply ``attn`` (with fused projections) to ``hidden_states``.

        Args:
            attn: The diffusers ``Attention`` module; must provide ``to_qkv`` for
                self-attention calls and ``to_q``/``to_kv`` for cross-attention calls.
            hidden_states: Query-side input, ``(batch, seq, dim)`` or
                ``(batch, channel, height, width)``.
            encoder_hidden_states: Key/value input for cross-attention; ``None``
                selects the fused self-attention path.
            attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``
                and reshaped to ``(batch, heads, -1, key_len)`` for SDPA.
            temb: Only forwarded to ``attn.spatial_norm`` when that norm exists.
            image_rotary_emb: Optional RoPE frequencies, applied to the query and,
                when ``attn`` is not a cross-attention module, to the key.

        Returns:
            The projected attention output in the same 3-D/4-D layout as the input,
            with ``attn.residual_connection`` and ``attn.rescale_output_factor`` applied.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        # NOTE that pre-trained split heads first, then split qkv
        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )
            query = attn.to_q(hidden_states)
            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed; keys only share the query positions in self-attention.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class MIAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the MIDI model. It applies a normalization layer and rotary embedding on query and key vector.

    With ``use_mi=True`` and instance information supplied, samples that belong to the same scene (consecutive rows
    of the batch) attend jointly over the concatenated key/value tokens of all their instances ("multi-instance"
    attention). Otherwise it behaves like ordinary per-sample attention.
    """

    def __init__(self, use_mi: bool = True):
        # F.scaled_dot_product_attention only exists from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MIAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.use_mi = use_mi

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        num_instances_per_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Apply (multi-instance) attention.

        Args:
            attn: The diffusers ``Attention`` module owning the projections and norms.
            hidden_states: Query-side input, ``(batch, seq, dim)`` or
                ``(batch, channel, height, width)``.
            encoder_hidden_states: Key/value input for cross-attention; ``None`` for
                self-attention.
            attention_mask: Optional mask; only used on the plain-attention path
                (the multi-instance paths ignore it).
            temb: Only forwarded to ``attn.spatial_norm`` when that norm exists.
            image_rotary_emb: Optional RoPE frequencies for query (and key in
                self-attention).
            num_instances: Inference (``num_instances_per_batch is None``): the
                number of instances grouped consecutively along the batch, i.e. the
                batch is laid out as ``(b ni)``. Training: an iterable of per-sample
                instance counts. ``None`` selects plain attention.
            num_instances_per_batch: Training only: number of batch rows reserved
                for each sample (real instances followed by padding rows).

        Returns:
            The projected attention output in the same 3-D/4-D layout as the input.

        Raises:
            ValueError: On the training path, if ``num_instances`` is empty or
                ``num_instances_per_batch`` is not positive (the loop would never end).
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        if not self.use_mi or num_instances is None:
            # Plain per-sample attention. This also covers use_mi=True without
            # instance information: otherwise no branch would run and the
            # un-attended input would be scrambled by the reshape below.
            hidden_states = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )
        elif num_instances_per_batch is None:
            # for inference: every group of `num_instances` consecutive samples
            # attends over the concatenated tokens of the whole group.
            key = rearrange(
                key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
            ).repeat_interleave(num_instances, dim=0)
            value = rearrange(
                value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
            ).repeat_interleave(num_instances, dim=0)
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            hidden_states = F.scaled_dot_product_attention(
                query,
                key,
                value,
                dropout_p=0.0,
                is_causal=False,
            )
        else:
            # for training (the same batch size is required)
            # process multi-instance attention among the first `num_instances` samples
            # and do object self-attention on the left samples in per batch
            if num_instances_per_batch <= 0 or len(num_instances) == 0:
                raise ValueError(
                    "num_instances must be non-empty and num_instances_per_batch positive, "
                    f"got num_instances={num_instances} and "
                    f"num_instances_per_batch={num_instances_per_batch}."
                )
            patch_hidden_states = []
            start_idx = 0
            while start_idx < batch_size:  # for classifier-free guidance
                for num in num_instances:
                    # Multi-object self-attention
                    query_ = query[start_idx : start_idx + num]
                    key_ = rearrange(
                        key[start_idx : start_idx + num],
                        "(b ni) h nt c -> b h (ni nt) c",
                        ni=num,
                    ).repeat_interleave(num, dim=0)
                    value_ = rearrange(
                        value[start_idx : start_idx + num],
                        "(b ni) h nt c -> b h (ni nt) c",
                        ni=num,
                    ).repeat_interleave(num, dim=0)
                    patch_hidden_states.append(
                        F.scaled_dot_product_attention(
                            query_,
                            key_,
                            value_,
                            dropout_p=0.0,
                            is_causal=False,
                        )
                    )

                    # Single-object self-attention for padding and regularization
                    query_ = query[
                        start_idx + num : start_idx + num_instances_per_batch
                    ]
                    key_ = key[start_idx + num : start_idx + num_instances_per_batch]
                    value_ = value[
                        start_idx + num : start_idx + num_instances_per_batch
                    ]
                    if query_.shape[0] > 0:
                        patch_hidden_states.append(
                            F.scaled_dot_product_attention(
                                query_,
                                key_,
                                value_,
                                dropout_p=0.0,
                                is_causal=False,
                            )
                        )

                    start_idx += num_instances_per_batch

            hidden_states = torch.cat(patch_hidden_states, dim=0)

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class SketchFusionAttnProcessor:
    r"""
    Scaled dot-product attention processor (PyTorch 2.0+) with optional spatial gating of the keys. Apart from the
    gating it follows the TripoSG processor: per-head re-split of q/k/v (or k/v), optional q/k normalization and
    rotary embedding on query and key.

    When both ``gating_map`` and ``gating_intensity`` are given, every key token ``j`` is multiplied by
    ``s_j = 1 - (1 - gating_intensity) * (1 - gating_map_j)``: tokens with ``gating_map == 1`` keep their key
    unchanged, tokens with ``gating_map == 0`` have it scaled by ``gating_intensity``. Scaling a key scales its
    pre-softmax logit ``q . k`` by the same factor.
    """

    # TODO (design note): spatial gating attention. We push the model to focus on sketch edges and lines, but the
    # control of the focus area should not be rigid. Sketch latents in shallow layers may carry more low-level
    # geometry, while deep layers carry semantics. Because a ViT mixes information across patches during the
    # forward pass, a patch token without drawn lines may still hold information from other patches, so naively
    # suppressing the attention scores of such tokens in deep layers may be unconvincing. Both approaches
    # discussed should be provided and tested for effectiveness in practice.
    # TODO: (unfinished note in the original, roughly "for now we just ...")

    def __init__(self):
        # F.scaled_dot_product_attention only exists from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SketchFusionAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        gating_map: Optional[torch.Tensor] = None,
        gating_intensity: Optional[torch.Tensor] = None,  # see sketch.sketch_utils.get_sketch_spatial_gating_map
        num_instances_per_batch: Optional[Any] = None,
    ) -> torch.Tensor:
        """Apply attention with optional spatial key gating.

        Args:
            attn: The diffusers ``Attention`` module owning the projections and norms.
            hidden_states: Query-side input, ``(batch, seq, dim)`` or
                ``(batch, channel, height, width)``.
            encoder_hidden_states: Key/value input for cross-attention; ``None`` for
                self-attention.
            attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``.
            temb: Only forwarded to ``attn.spatial_norm`` when that norm exists.
            image_rotary_emb: Optional RoPE frequencies for query (and key in
                self-attention).
            num_instances: Accepted for signature compatibility; unused here.
            gating_map: Per-key-token gate of shape ``(batch, key_len)``, or 3-D
                ``(batch, key_len, 1)`` / ``(batch, 1, key_len)``. Presumably in
                ``[0, 1]`` — TODO confirm against the producer.
            gating_intensity: Key scale applied where ``gating_map == 0``; a float
                or a tensor. NOTE(review): a tensor must broadcast against
                ``(batch, key_len)`` (e.g. scalar or ``(batch, 1)``) — confirm.
            num_instances_per_batch: Accepted for signature compatibility; unused here.

        Returns:
            The projected attention output in the same 3-D/4-D layout as the input.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Spatial gating: scale the keys rather than the attention scores. Scaling a
        # key scales its pre-softmax logit; the gradient w.r.t. the intensity differs
        # from scaling the scores directly, which is accepted here.
        if gating_map is not None and gating_intensity is not None:
            assert gating_map.shape[0] == key.shape[0]
            if gating_map.ndim == 3:
                # [B, L, 1] or [B, 1, L] -> [B, L]. Squeeze only the singleton
                # non-batch dim: a bare .squeeze() would also drop the batch dim
                # when B == 1 and then broadcast against the head axis of `key`.
                if gating_map.shape[-1] == 1:
                    gating_map = gating_map.squeeze(-1)
                elif gating_map.shape[1] == 1:
                    gating_map = gating_map.squeeze(1)
            assert gating_map.shape[-1] == key.shape[-2], f"Unequal sequence length of gating map and key vectors. Gating map: {gating_map.shape[-1]} while key values: {key.shape[-2]}"

            # Move gating_map and intensity to the same device as key
            gating_map = gating_map.to(key.device)
            # gating_intensity might be a float or a tensor
            if isinstance(gating_intensity, torch.Tensor):
                gating_intensity = gating_intensity.to(key.device)
            suppression = 1.0 - (1.0 - gating_intensity) * (1.0 - gating_map)
            # Cast suppression back to key dtype to avoid precision upcasting (float16 * float32 -> float32)
            suppression = suppression.to(key.dtype)
            # [B, L] -> [B, 1, L, 1] to broadcast over heads and head_dim.
            key = key * suppression.unsqueeze(1).unsqueeze(-1)

        key = key.to(query.dtype)
        value = value.to(query.dtype)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class SketchFusionAttnProcessor2:
    r"""
    Ablation-study variant of the sketch-fusion attention processor.

    It accepts the same call signature as the gated sketch-fusion processor, so it can be swapped in without changing
    callers, but ``num_instances``, ``gating_map``, ``gating_intensity`` and ``num_instances_per_batch`` are ignored:
    it performs plain scaled dot-product attention (PyTorch 2.0+) with the per-head q/k/v re-split, optional q/k
    normalization and rotary embedding.
    """

    def __init__(self):
        # F.scaled_dot_product_attention only exists from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SketchFusionAttnProcessor2 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        gating_map: Optional[torch.Tensor] = None,
        gating_intensity: Optional[torch.Tensor] = None,
        num_instances_per_batch: Optional[Any] = None,
    ) -> torch.Tensor:
        """Apply plain attention; the gating/instance arguments are ignored.

        Args:
            attn: The diffusers ``Attention`` module owning the projections and norms.
            hidden_states: Query-side input, ``(batch, seq, dim)`` or
                ``(batch, channel, height, width)``.
            encoder_hidden_states: Key/value input for cross-attention; ``None`` for
                self-attention.
            attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``.
            temb: Only forwarded to ``attn.spatial_norm`` when that norm exists.
            image_rotary_emb: Optional RoPE frequencies for query (and key in
                self-attention).
            num_instances: Ignored (signature compatibility).
            gating_map: Ignored (ablation: no spatial gating).
            gating_intensity: Ignored (ablation: no spatial gating).
            num_instances_per_batch: Ignored (signature compatibility).

        Returns:
            The projected attention output in the same 3-D/4-D layout as the input.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states