# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import importlib.metadata
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import is_torch_version, logging
from einops import rearrange

try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
except ImportError:
    flash_attn_func = None
    flash_attn_qkvpacked_func = None

# (pre, post) tensor-layout transforms associated with each attention backend.
MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
| """ | |
| Perform QKV self attention. | |
| Args: | |
| q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. | |
| k (torch.Tensor): Key tensor with shape [b, s1, a, d] | |
| v (torch.Tensor): Value tensor with shape [b, s1, a, d] | |
| mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. | |
| drop_rate (float): Dropout rate in attention map. (default: 0) | |
| attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). | |
| (default: None) | |
| causal (bool): Whether to use causal attention. (default: False) | |
| cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, | |
| used to index into q. | |
| cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, | |
| used to index into kv. | |
| max_seqlen_q (int): The maximum sequence length in the batch of q. | |
| max_seqlen_kv (int): The maximum sequence length in the batch of k and v. | |
| Returns: | |
| torch.Tensor: Output tensor after self attention with shape [b, s, ad] | |
| """ | |
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    if mode != "flash":
        # flash_attn_func consumes the [b, s, a, d] layout directly; the torch
        # and vanilla paths expect [b, a, s, d], so move the head axis first.
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    elif mode == "flash":
        x = flash_attn_func(
            q,
            k,
            v,
        )
        # flash_attn_func returns [b, s, a, d]; reshape explicitly using the
        # provided batch_size / max_seqlen_q.
        x = x.view(batch_size, max_seqlen_q, x.shape[-2],
                   x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)

        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert (
                attn_mask
                is None), "Causal mask and attn_mask cannot be used together"

            temp_mask = torch.ones(
                b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
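

# Usage sketch (illustrative, not part of the original file): with the SDPA
# ("torch") backend the inputs stay in [batch, seq, heads, dim] layout and the
# output is flattened to [batch, seq, heads * dim]. The shapes below are
# placeholder values chosen only for this demonstration.
def _attention_torch_example():
    q = torch.randn(2, 16, 8, 64)  # [b, s, a, d]
    k = torch.randn(2, 16, 8, 64)
    v = torch.randn(2, 16, 8, 64)
    out = attention(q, k, v, mode="torch")
    assert out.shape == (2, 16, 8 * 64)  # [b, s, a * d]
    return out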


class CausalConv1d(nn.Module):

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 pad_mode='replicate',
                 **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # left-only padding on the time axis
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(
            chan_in,
            chan_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
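

# Usage sketch (illustrative, not part of the original file): the left-only
# "replicate" padding keeps the temporal length unchanged for stride=1 and makes
# the output at step t depend only on inputs at steps <= t. The channel sizes
# below are arbitrary placeholders.
def _causal_conv1d_example():
    conv = CausalConv1d(chan_in=4, chan_out=8, kernel_size=3, stride=1)
    x = torch.randn(2, 4, 10)  # [batch, channels, time]
    y = conv(x)
    assert y.shape == (2, 8, 10)  # stride=1 preserves the time axis
    return y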


class MotionEncoder_tc(nn.Module):

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 num_heads: int,
                 need_global=True,
                 dtype=None,
                 device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.num_heads = num_heads
        self.need_global = need_global

        self.conv1_local = CausalConv1d(
            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(
                in_dim, hidden_dim // 4, 3, stride=1)
            self.norm1 = nn.LayerNorm(
                hidden_dim // 4,
                elementwise_affine=False,
                eps=1e-6,
                **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim,
                                          **factory_kwargs)

        self.norm1 = nn.LayerNorm(
            hidden_dim // 4,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm2 = nn.LayerNorm(
            hidden_dim // 2,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm3 = nn.LayerNorm(
            hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()
        b, c, t = x.shape

        # Local path: split channels into per-head groups and run each head
        # stream through the causal conv stack.
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        # Append a learned padding token along the head axis.
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        # Global path: a single-head variant of the same conv stack applied to
        # the original input, followed by a linear projection.
        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local
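

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the sizes below are placeholders,
    # not values used by any released Wan checkpoint). With need_global=True
    # the encoder returns a global feature of shape [b, t', 1, hidden_dim] and
    # a local feature of shape [b, t', num_heads + 1, hidden_dim], where t' is
    # the temporal length after the two stride-2 causal convolutions.
    encoder = MotionEncoder_tc(
        in_dim=512, hidden_dim=1024, num_heads=8, need_global=True)
    motion = torch.randn(2, 33, 512)  # [batch, frames, features]
    global_feat, local_feat = encoder(motion)
    print(global_feat.shape)  # torch.Size([2, 9, 1, 1024])
    print(local_feat.shape)   # torch.Size([2, 9, 9, 1024])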