| |
| import math |
| from copy import deepcopy |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.attention import FeedForward |
| from diffusers.models.embeddings import ( |
| PixArtAlphaTextProjection, |
| TimestepEmbedding, |
| Timesteps, |
| ) |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.models.normalization import FP32LayerNorm |
| from einops import rearrange |
|
|
# Prefer `flash_attn_interface` when it is importable, otherwise fall back to
# `flash_attn`.  Only ImportError is caught so that unrelated failures (and
# KeyboardInterrupt / SystemExit) are not silently swallowed.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:

        def flash_attn_func(*args, **kwargs):
            """Placeholder bound when no flash-attention package is importable.

            Keeps this module importable (the 'torch' attention mode does not
            need flash-attention) and fails with a clear message only when the
            flash-attention path is actually executed.

            Raises:
                ImportError: always.
            """
            raise ImportError(
                "attn_mode='flashattn' requires the `flash_attn_interface` or "
                "`flash_attn` package; install one of them or use "
                "attn_mode='torch'.")
|
|
# Names exported by a star import of this module: only the top-level model class.
__all__ = ['WanTransformer3DModel']
|
|
|
|
def custom_sdpa(q, k, v):
    """Run PyTorch scaled-dot-product attention on token-major tensors.

    Dims 1 and 2 of ``q``, ``k`` and ``v`` are swapped before calling
    ``F.scaled_dot_product_attention`` and the result is swapped back, so the
    output uses the same axis order as ``q``.  The attention module in this
    file passes tensors laid out as (batch, tokens, heads, head_dim).

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.

    Returns:
        The attention output with dims 1 and 2 in the same order as the inputs.
    """
    heads_first = [tensor.transpose(1, 2) for tensor in (q, k, v)]
    attended = F.scaled_dot_product_attention(*heads_first)
    return attended.transpose(1, 2)
|
|
|
|
class WanTimeTextImageEmbedding(nn.Module):
    """Timestep embedding stack that also owns the text projection layer.

    ``forward`` only handles timesteps.  ``text_embedder`` is created here so
    that callers can apply it directly; it is not used by ``forward``.

    Args:
        dim: Width of the timestep embedding and of the projected text features.
        time_freq_dim: Number of channels produced by ``Timesteps``.
        time_proj_dim: Output width of the final timestep projection.
        text_embed_dim: Input width given to ``PixArtAlphaTextProjection``.
        pos_embed_seq_len: Accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        super().__init__()

        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore state-dict layout / RNG consumption) stays stable.
        self.timesteps_proj = Timesteps(
            num_channels=time_freq_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(
            text_embed_dim,
            dim,
            act_fn="gelu_tanh",
        )

    def forward(
        self,
        timestep: torch.Tensor,
        dtype=None,
    ):
        """Embed a 2-D grid of timesteps.

        Args:
            timestep: Tensor with exactly two dimensions; it is flattened for
                the embedding layers and the leading two dims are restored on
                the outputs.
            dtype: dtype the embedding is cast to before the final projection.

        Returns:
            ``(temb, timestep_proj)``, both reshaped to
            ``(timestep.shape[0], timestep.shape[1], -1)``.
        """
        batch, tokens = timestep.shape
        freq_features = self.timesteps_proj(timestep.reshape(-1))

        # Match the embedder's weight dtype, except when the weights are int8
        # (in which case the features are left untouched).
        weight_dtype = self.time_embedder.linear_1.weight.dtype
        needs_cast = (weight_dtype != torch.int8
                      and freq_features.dtype != weight_dtype)
        if needs_cast:
            freq_features = freq_features.to(weight_dtype)

        temb = self.time_embedder(freq_features).to(dtype=dtype)
        projected = self.time_proj(self.act_fn(temb))

        out_shape = (batch, tokens, -1)
        return temb.reshape(*out_shape), projected.reshape(*out_shape)
|
|
|
|
class WanRotaryPosEmbed(nn.Module):
    """Turn three rows of grid coordinates into complex rotary factors.

    The head dimension is split into three parts (``f_dim``, ``h_dim``,
    ``w_dim`` -- presumably frame / height / width, confirm with the caller
    that builds ``grid_ids``).  ``h_dim`` and ``w_dim`` each get
    ``attention_head_dim // 3`` channels and ``f_dim`` gets the remainder.

    Args:
        attention_head_dim: Per-head channel count that is split across axes.
        patch_size: Stored on the module; not used in the computation here.
        max_seq_len: Stored on the module; not used in the computation here.
        theta: Base of the inverse-frequency schedule.
    """

    def __init__(
        self,
        attention_head_dim,
        patch_size,
        max_seq_len,
        theta=10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta

        third = attention_head_dim // 3
        self.f_dim = attention_head_dim - 2 * third
        self.h_dim = third
        self.w_dim = third

        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict.
        buffer_names = ("f_freqs_base", "h_freqs_base", "w_freqs_base")
        for name, base in zip(buffer_names, self._precompute_freqs_base()):
            self.register_buffer(name, base, persistent=False)

    def _precompute_freqs_base(self):
        """Return the float64 inverse frequencies for the f, h and w parts."""

        def inverse_frequencies(axis_dim):
            exponents = torch.arange(0, axis_dim, 2)[:axis_dim // 2].double() / axis_dim
            return 1.0 / (self.theta**exponents)

        return tuple(
            inverse_frequencies(axis_dim)
            for axis_dim in (self.f_dim, self.h_dim, self.w_dim))

    def forward(self, grid_ids):
        """Build unit-magnitude complex rotary factors.

        Args:
            grid_ids: Tensor indexed as ``grid_ids[:, i, :]`` for ``i`` in
                0, 1, 2; row ``i`` is multiplied with the i-th frequency base.

        Returns:
            Complex tensor ``polar(1, angle)`` whose last dimension is the
            concatenation of the three per-axis angle sets (angles are cast to
            float32 before conversion).  Computed without gradient tracking.
        """
        bases = (self.f_freqs_base, self.h_freqs_base, self.w_freqs_base)
        with torch.no_grad():
            per_axis = [
                grid_ids[:, row, :].unsqueeze(-1) * base
                for row, base in enumerate(bases)
            ]
            angles = torch.cat(per_axis, dim=-1).float()
            freqs_cis = torch.polar(torch.ones_like(angles), angles)

        return freqs_cis
|
|
|
|
class WanAttention(torch.nn.Module):
    """Multi-head attention with RMS-normalised Q/K and an optional KV cache.

    Queries, keys and values come from linear projections; queries and keys
    are RMS-normalised over the full projected width before being split into
    heads.

    When ``cross_attention_dim_head`` is ``None`` the module owns a dict of
    named KV caches (``attn_caches``).  Otherwise ``attn_caches`` is ``None``
    and ``forward`` never touches a cache.

    A cache entry is a fixed pool of token slots stored as a dict:

    * ``'k'`` / ``'v'``: tensors of shape (batch, slots, heads, head_dim);
    * ``'id'``: per-slot stamp (-1 initially); every write gets
      ``max(id of used slots) + 1`` and the lowest stamps are evicted first;
    * ``'mask'``: bool, slot currently holds a valid key/value;
    * ``'is_pred'``: bool, slot was written with ``update_cache == 1``.

    Args:
        dim: Input and output channel width.
        heads: Number of attention heads.
        dim_head: Channels per head for the query projection.
        eps: Epsilon for the Q/K RMS norms.
        dropout: Dropout probability applied after the output projection.
        cross_attention_dim_head: Channels per head for the key/value
            projections; ``None`` means "same as the query" and enables the
            KV cache.
        attn_mode: ``'torch'`` (PyTorch SDPA) or ``'flashattn'``.

    Raises:
        ValueError: if ``attn_mode`` is not one of the supported values.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode='torch',
    ):
        super().__init__()
        if attn_mode == 'torch':
            self.attn_op = custom_sdpa
        elif attn_mode == 'flashattn':
            self.attn_op = flash_attn_func
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support torch and flashattn"
            )

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim, dim, bias=True),
            torch.nn.Dropout(dropout),
        ])
        self.norm_q = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        # Plain dict (not a buffer): cache tensors are not part of the state dict.
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        """Mark every slot flagged ``is_pred`` in ``cache_name`` as free.

        Does nothing when this module has no caches, or when the named cache
        was never created or has been cleared.
        """
        if self.attn_caches is None:
            return
        cache = self.attn_caches.get(cache_name)
        if cache is None:
            return
        is_pred = cache['is_pred']
        cache['mask'][is_pred] = False

    def clear_cache(self, cache_name):
        """Drop the named cache by replacing its entry with ``None``."""
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim,
                      device, dtype, batch_size):
        """Allocate an empty cache with ``total_tolen`` token slots.

        Key/value storage is uninitialised (``torch.empty``); validity is
        tracked by the ``'mask'`` entry, which starts all-False.
        """
        if self.attn_caches is None:
            return
        kv_shape = [batch_size, total_tolen, num_head, head_dim]
        self.attn_caches[cache_name] = {
            'k': torch.empty(kv_shape, device=device, dtype=dtype),
            'v': torch.empty(kv_shape, device=device, dtype=dtype),
            'id': torch.full((total_tolen, ), -1, device=device),
            "mask": torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
            "is_pred": torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        """Return indices of ``key_size`` free slots, evicting if necessary.

        When fewer than ``key_size`` slots are free, the used slots with the
        smallest ``'id'`` stamps are released until enough are available.

        Raises:
            AssertionError: if the pool is smaller than ``key_size``.
        """
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)

            # Evict the slots carrying the lowest stamps first.
            used_ids = ids[used]
            order = torch.argsort(used_ids)
            need = key_size - free.numel()
            to_free = used[order[:need]]

            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        assert free.numel() >= key_size, (
            f"KV cache '{cache_name}' has {mask.numel()} slots but "
            f"{key_size} were requested")
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        """Return the stamp for the next write: max stamp in use + 1, or 0."""
        ids = self.attn_caches[cache_name]['id']
        mask = self.attn_caches[cache_name]['mask']

        if mask.any():
            return ids[mask].max() + 1
        else:
            return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        """Write ``key``/``value`` into free slots and return the slot indices.

        The number of slots taken is ``key.shape[1]``.  All written slots get
        the same new stamp and the given ``is_pred`` flag.
        """
        cache = self.attn_caches[cache_name]

        key_size = key.shape[1]
        slots = self.allocate_slots(cache_name, key_size)

        # Computed after allocation so evicted slots no longer count.
        new_id = self._next_cache_id(cache_name)

        cache['k'][:, slots] = key
        cache['v'][:, slots] = value
        cache['mask'][slots] = True
        cache['id'][slots] = new_id
        cache['is_pred'][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        """Mark ``slots`` as free again (their stored tensors are left as is)."""
        self.attn_caches[cache_name]['mask'][slots] = False

    @staticmethod
    def _apply_rotary_emb(x, freqs):
        """Rotate consecutive channel pairs of ``x`` by the complex ``freqs``.

        ``x`` is viewed as complex numbers built from pairs along its last
        dimension (computed in float64), multiplied by ``freqs`` and cast
        back to the dtype of ``x``.
        """
        x_out = torch.view_as_complex(
            x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2],
                                        -1, 2))
        x_out = torch.view_as_real(x_out * freqs).flatten(3)
        return x_out.to(x.dtype)

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ):
        """Attend ``q`` over ``k``/``v`` plus any cached keys and values.

        Args:
            q, k, v: Inputs to the query / key / value projections; dim 2 of
                each projection is split into ``heads``.
            rotary_emb: Complex rotary factors applied to query and key, or
                ``None`` to skip rotary embedding.
            update_cache: ``0`` uses the cache for this call only (the newly
                written slots are released afterwards); ``1`` keeps the new
                slots and flags them ``is_pred``; any other value keeps them
                without the flag.
            cache_name: Which named cache to use.  If this module has no
                caches, or the named cache was never created or has been
                cleared, attention runs on the current keys/values only.

        Returns:
            Output of the final projection followed by dropout.
        """
        # `.get` so that a cache name that was never initialised behaves like
        # a cleared cache instead of raising KeyError.
        kv_cache = None
        if self.attn_caches is not None:
            kv_cache = self.attn_caches.get(cache_name)
        use_cache = kv_cache is not None and kv_cache['k'] is not None

        query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
        query = self.norm_q(query)
        query = query.unflatten(2, (self.heads, -1))
        key = self.norm_k(key)
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))
        if rotary_emb is not None:
            query = self._apply_rotary_emb(query, rotary_emb)
            key = self._apply_rotary_emb(key, rotary_emb)

        slots = None
        if use_cache:
            slots = self.update_cache(cache_name,
                                      key,
                                      value,
                                      is_pred=(update_cache == 1))
            # Attend over every valid slot, which now includes the new tokens.
            valid = kv_cache['mask'].nonzero(as_tuple=False).squeeze(-1)
            key = kv_cache['k'][:, valid]
            value = kv_cache['v'][:, valid]

        hidden_states = self.attn_op(query, key, value)

        # update_cache == 0: the tokens were only written temporarily.
        if update_cache == 0 and use_cache:
            self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
|
|
|
|
class WanTransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN.

    A learned ``scale_shift_table`` of shape (1, 6, dim) is added to the
    per-token conditioning ``temb`` and split into shift / scale / gate
    triples for the self-attention branch and for the feed-forward branch.

    Args:
        dim: Channel width of the hidden states.
        ffn_dim: Inner width of the feed-forward network.
        num_heads: Number of attention heads (head size is ``dim // num_heads``).
        cross_attn_norm: If true, a LayerNorm with affine parameters is applied
            before cross-attention; otherwise that step is an identity.
        eps: Epsilon shared by the norms and the attention Q/K norms.
        attn_mode: Attention backend name forwarded to both attention modules.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "flashattn",
    ):
        super().__init__()
        self.attn_mode = attn_mode
        head_dim = dim // num_heads

        # 1. Self-attention (owns the KV cache: cross_attention_dim_head=None).
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention over the encoder hidden states.
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            cross_attention_dim_head=head_dim,
            attn_mode=attn_mode,
        )
        if cross_attn_norm:
            self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm2 = nn.Identity()

        # 3. Feed-forward.
        self.ffn = FeedForward(dim,
                               inner_dim=ffn_dim,
                               activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states``.

        Args:
            hidden_states: Token features; the output keeps their dtype.
            encoder_hidden_states: Keys/values source for cross-attention.
            temb: 4-D conditioning tensor laid out as (b, l, n, c); its ``n``
                axis is split into the six modulation terms.
            rotary_emb: Rotary factors for self-attention.
            update_cache: Cache mode forwarded to self-attention only;
                cross-attention is always called with ``update_cache=0``.
            cache_name: Name of the KV cache to use.

        Returns:
            The updated hidden states.
        """
        modulation = rearrange(self.scale_shift_table[None] + temb.float(),
                               'b l n c -> b n l c')
        (shift_msa, scale_msa, gate_msa,
         c_shift_msa, c_scale_msa, c_gate_msa) = [
             part.squeeze(1) for part in modulation.chunk(6, dim=1)
         ]

        # 1. Self-attention; norm and residual are computed in float32.
        attn_input = (self.norm1(hidden_states.float()) * (1. + scale_msa) +
                      shift_msa).type_as(hidden_states)
        self_attn_out = self.attn1(attn_input,
                                   attn_input,
                                   attn_input,
                                   rotary_emb,
                                   update_cache=update_cache,
                                   cache_name=cache_name)
        hidden_states = (hidden_states.float() +
                         self_attn_out * gate_msa).type_as(hidden_states)

        # 2. Cross-attention: no rotary embedding, no gating.
        cross_input = self.norm2(hidden_states.float()).type_as(hidden_states)
        cross_attn_out = self.attn2(cross_input,
                                    encoder_hidden_states,
                                    encoder_hidden_states,
                                    None,
                                    update_cache=0,
                                    cache_name=cache_name)
        hidden_states = hidden_states + cross_attn_out

        # 3. Feed-forward with its own shift / scale / gate.
        ffn_input = (self.norm3(hidden_states.float()) * (1. + c_scale_msa) +
                     c_shift_msa).type_as(hidden_states)
        ff_output = self.ffn(ffn_input)
        hidden_states = (hidden_states.float() +
                         ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
|
|
|
|
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    """Transformer that denoises either patchified latents or action tokens.

    Two input/output heads share one stack of transformer blocks:

    * latent mode: ``patch_embedding_mlp`` in, ``proj_out`` out;
    * action mode: ``action_embedder`` in, ``action_proj_out`` out, with a
      separate timestep embedder (``condition_embedder_action``, created as a
      deep copy of ``condition_embedder``).

    Text features are always projected with
    ``condition_embedder.text_embedder``.
    """

    @register_to_config
    def __init__(self,
                 patch_size=[1, 2, 2],
                 num_attention_heads=24,
                 attention_head_dim=128,
                 in_channels=48,
                 out_channels=48,
                 action_dim=30,
                 text_dim=4096,
                 freq_dim=256,
                 ffn_dim=14336,
                 num_layers=30,
                 cross_attn_norm=True,
                 eps=1e-06,
                 rope_max_seq_len=1024,
                 pos_embed_seq_len=None,
                 attn_mode="torch"):
        """Build the model.

        Args:
            patch_size: Three patch extents used to patchify the latents.
            num_attention_heads: Heads per attention layer.
            attention_head_dim: Channels per head; the model width is
                ``num_attention_heads * attention_head_dim``.
            in_channels: Latent channels per position before patchifying.
            out_channels: Latent channels per position produced by ``proj_out``.
            action_dim: Width of action tokens (input and output).
            text_dim: Width of the incoming text embeddings.
            freq_dim: Channel count of the sinusoidal timestep features.
            ffn_dim: Inner width of each block's feed-forward network.
            num_layers: Number of transformer blocks.
            cross_attn_norm: Forwarded to every block.
            eps: Epsilon for all norms.
            rope_max_seq_len: Forwarded to the rotary embedding module.
            pos_embed_seq_len: Forwarded to the condition embedder.
            attn_mode: Attention backend name forwarded to every block.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size,
                                      rope_max_seq_len)

        patch_features = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
        self.patch_embedding_mlp = nn.Linear(patch_features, inner_dim)
        self.action_embedder = nn.Linear(action_dim, inner_dim)

        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        # The action branch starts from a copy of the same weights.
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(
                WanTransformerBlock(inner_dim,
                                    ffn_dim,
                                    num_attention_heads,
                                    cross_attn_norm,
                                    eps,
                                    attn_mode=attn_mode))

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim,
                                  out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    def clear_cache(self, cache_name):
        """Drop the named self-attention KV cache in every block."""
        for layer in self.blocks:
            layer.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        """Free the slots flagged as predictions in every block's named cache."""
        for layer in self.blocks:
            layer.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(self, cache_name, attn_window,
                           latent_token_per_chunk, action_token_per_chunk,
                           device, dtype, batch_size):
        """Allocate an empty self-attention KV cache in every block.

        The pool holds ``attn_window // 2`` chunks of latent tokens plus
        ``attn_window // 2`` chunks of action tokens.
        """
        half_window = attn_window // 2
        total_tolen = (half_window * latent_token_per_chunk +
                       half_window * action_token_per_chunk)
        for layer in self.blocks:
            layer.attn1.init_kv_cache(cache_name, total_tolen,
                                      self.num_attention_heads,
                                      self.attention_head_dim, device, dtype,
                                      batch_size)

    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
    ):
        """Run the transformer on latent or action tokens.

        Args:
            input_dict: Mapping with the keys

                * ``'noisy_latents'``: 5-D tensor laid out as (b, c, f, h, w).
                  In latent mode it is split into patches of ``patch_size``;
                  in action mode every (f, h, w) position is one token.
                * ``'text_emb'``: text features fed to the text projection
                  (shape not fixed here -- check the caller).
                * ``'grid_id'``: coordinates for the rotary embedding, indexed
                  as ``[:, 0..2, :]``.
                * ``'timesteps'``: 2-D tensor; each entry along dim 1 is
                  repeated once per token that shares it.
            update_cache: Cache mode forwarded to every block's self-attention.
            cache_name: Name of the KV cache to use.
            action_mode: Select the action heads instead of the latent heads.

        Returns:
            Action mode: tensor of shape (b, tokens, action_dim).
            Latent mode: tensor of shape
            (b, tokens * prod(patch_size), out_channels).
        """
        noisy_latents = input_dict['noisy_latents']

        if action_mode:
            tokens = rearrange(noisy_latents, 'b c f h w -> b (f h w) c')
            latent_hidden_states = self.action_embedder(tokens)
        else:
            tokens = rearrange(
                noisy_latents,
                'b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)',
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2])
            latent_hidden_states = self.patch_embedding_mlp(tokens)
        text_hidden_states = self.condition_embedder.text_embedder(
            input_dict["text_emb"])

        rotary_emb = self.rope(input_dict['grid_id'])[:, :, None]

        # Action tokens are not patchified, so their spatial stride is 1.
        if action_mode:
            stride_h, stride_w = 1, 1
            active_embedder = self.condition_embedder_action
        else:
            stride_h, stride_w = self.patch_size[1], self.patch_size[2]
            active_embedder = self.condition_embedder

        # Expand each timestep along dim 1 to one entry per spatial token.
        tokens_per_step = ((noisy_latents.shape[-2] // stride_h) *
                           (noisy_latents.shape[-1] // stride_w))
        latent_time_steps = torch.repeat_interleave(input_dict['timesteps'],
                                                    tokens_per_step,
                                                    dim=1)
        temb, timestep_proj = active_embedder(
            latent_time_steps, dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))

        for layer in self.blocks:
            latent_hidden_states = layer(latent_hidden_states,
                                         text_hidden_states,
                                         timestep_proj,
                                         rotary_emb,
                                         update_cache=update_cache,
                                         cache_name=cache_name)

        # Final per-token shift/scale modulation before the output head.
        modulation = rearrange(
            self.scale_shift_table[None] + temb[:, :, None, ...],
            'b l n c -> b n l c')
        shift, scale = (
            part.to(latent_hidden_states.device).squeeze(1)
            for part in modulation.chunk(2, dim=1))
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) *
                                (1. + scale) +
                                shift).type_as(latent_hidden_states)

        if action_mode:
            return self.action_proj_out(latent_hidden_states)

        latent_hidden_states = self.proj_out(latent_hidden_states)
        latent_hidden_states = rearrange(latent_hidden_states,
                                         'b l (n c) -> b (l n) c',
                                         n=math.prod(self.patch_size))
        return latent_hidden_states
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the model with explicit settings (the same values as
    # the constructor defaults) and print its module tree.  Note that this
    # instantiates all 30 layers at a width of 24 * 128 channels.
    demo_config = dict(
        patch_size=[1, 2, 2],
        num_attention_heads=24,
        attention_head_dim=128,
        in_channels=48,
        out_channels=48,
        action_dim=30,
        text_dim=4096,
        freq_dim=256,
        ffn_dim=14336,
        num_layers=30,
        cross_attn_norm=True,
        eps=1e-6,
        rope_max_seq_len=1024,
        pos_embed_seq_len=None,
        attn_mode="torch",
    )
    model = WanTransformer3DModel(**demo_config)
    print(model)
|
|