|
|
|
|
|
import importlib.metadata
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import is_torch_version, logging
from einops import rearrange

try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
except ImportError:
    # flash-attn is optional; without it only the "torch" and "vanilla" modes are usable.
    flash_attn_func = None
    flash_attn_qkvpacked_func = None
|
|
|
|
|
# Per-backend (pre, post) tensor-layout transforms used around the attention call.
MEMORY_LAYOUT = {
    # Packs [b, s, a, d] into a varlen-style [(b*s), a, d] layout.
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    # F.scaled_dot_product_attention expects the head axis before the sequence axis.
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
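
# A small layout sketch (illustrative comment only, not part of the original module): the
# "torch" entry swaps the sequence and head axes before the attention call and swaps them back
# afterwards.
#
#   pre, post = MEMORY_LAYOUT["torch"]
#   pre(torch.empty(1, 16, 2, 8)).shape    # [b, s, a, d] -> torch.Size([1, 2, 16, 8])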
|
|
|
|
|
|
|
|
def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform QKV self-attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d].
        v (torch.Tensor): Value tensor with shape [b, s1, a, d].
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate applied to the attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross attention), or
            [b, a, s, s1] (torch or vanilla). (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        max_seqlen_q (int): The maximum query sequence length in the batch; used to reshape the
            flash-attention output. (default: None)
        batch_size (int): The batch size; used to reshape the flash-attention output. (default: 1)

    Returns:
        torch.Tensor: Output tensor after self-attention with shape [b, s, a*d].
    """
|
|
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    if mode != "flash":
        # flash_attn_func consumes [b, s, a, d] directly; the other backends expect the head
        # axis before the sequence axis, so rearrange q/k/v here.
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    elif mode == "flash":
        x = flash_attn_func(
            q,
            k,
            v,
        )
        # flash_attn_func returns [b, s, a, d]; make the batch/sequence dims explicit.
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            assert (
                attn_mask is None
            ), "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(
                b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Restore [b, s, a, d] and flatten the head dimension.
    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
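
# A minimal shape sketch for `attention` (illustrative only; the tensors and sizes below are
# arbitrary, not taken from the original module). With 2 heads of dim 8 over a length-16
# sequence, the head dimension is flattened on output.
#
#   q = k = v = torch.randn(1, 16, 2, 8)      # [b, s, a, d]
#   out = attention(q, k, v, mode="torch")    # -> [1, 16, 16], i.e. [b, s, a*d]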
|
|
|
|
|
|
|
|
class CausalConv1d(nn.Module):
    """1D convolution with left-only (causal) padding along the time axis."""

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 pad_mode='replicate',
                 **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # Pad only on the left so that output step t never sees inputs later than t.
        padding = (kernel_size - 1, 0)
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(
            chan_in,
            chan_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
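
# A small usage sketch (illustrative values only): at stride 1 the causal padding preserves the
# temporal length, and output step t depends only on inputs up to and including step t.
#
#   conv = CausalConv1d(chan_in=4, chan_out=8, kernel_size=3)
#   y = conv(torch.randn(2, 4, 10))    # [b, c, t] -> [2, 8, 10]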
|
|
|
|
|
|
|
|
class MotionEncoder_tc(nn.Module):
    """Temporal-causal motion encoder producing per-head local tokens and an optional global token."""

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 num_heads: int,
                 need_global=True,
                 dtype=None,
                 device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.need_global = need_global
        # Local branch: one channel group of width hidden_dim // 4 per head.
        self.conv1_local = CausalConv1d(
            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(
                in_dim, hidden_dim // 4, 3, stride=1)
        self.act = nn.SiLU()
        # Each stride-2 causal conv halves the temporal length while widening the channels.
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim,
                                          **factory_kwargs)

        self.norm1 = nn.LayerNorm(
            hidden_dim // 4,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)

        self.norm2 = nn.LayerNorm(
            hidden_dim // 2,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)

        self.norm3 = nn.LayerNorm(
            hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
|
|
|
|
    def forward(self, x):
        # x: [b, t, c] -> channels-first for the causal convolutions.
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()
        b, c, t = x.shape

        # Local branch: split the conv output into one stream per head.
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        # Append a learnable padding token per time step alongside the head tokens.
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        # Global branch: a single (head-free) stream over the original input.
        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local
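
# A shape sketch for MotionEncoder_tc (illustrative values, not from the original module): the
# two stride-2 convs reduce t=8 frames to 2 steps; the local output carries num_heads + 1 tokens
# per step (the heads plus the learnable padding token), the global output a single token.
#
#   enc = MotionEncoder_tc(in_dim=16, hidden_dim=32, num_heads=4)
#   x_global, x_local = enc(torch.randn(2, 8, 16))    # input is [b, t, c]
#   # x_global: [2, 2, 1, 32]    x_local: [2, 2, 5, 32]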
|
|
|