# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
    PixArtAlphaTextProjection,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from einops import rearrange

# FlashAttention is optional: it is only needed for attn_mode="flashattn".
# Prefer `flash_attn_interface`, fall back to `flash_attn`, and record None
# when neither is installed so attn_mode="torch" still works.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None

__all__ = ['WanTransformer3DModel']


def custom_sdpa(q, k, v):
    """Run PyTorch SDPA on tensors laid out as (batch, seq, heads, head_dim).

    ``F.scaled_dot_product_attention`` expects the heads axis before the
    sequence axis, so axes 1 and 2 are swapped on the way in and swapped back
    on the way out.
    """
    out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                         k.transpose(1, 2),
                                         v.transpose(1, 2))
    return out.transpose(1, 2)


def _apply_rotary_emb(x, freqs):
    """Rotate the 4-D tensor ``x`` by the complex rotary factors ``freqs``.

    Consecutive pairs along the last axis of ``x`` are viewed as complex
    numbers (in float64), multiplied by ``freqs`` and flattened back. The
    result is cast to ``x``'s original dtype.
    """
    x_complex = torch.view_as_complex(
        x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
    x_out = torch.view_as_real(x_complex * freqs).flatten(3)
    return x_out.to(x.dtype)


class WanTimeTextImageEmbedding(nn.Module):
    """Timestep embedding, plus the text projection used by the transformer.

    Args:
        dim: Width of the timestep embedding and of the projected text.
        time_freq_dim: Number of channels produced by ``Timesteps``.
        time_proj_dim: Output width of ``time_proj``.
        text_embed_dim: Input feature width of the text embeddings.
        pos_embed_seq_len: Accepted for config compatibility; unused here.
    """

    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim,
                                        flip_sin_to_cos=True,
                                        downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim,
                                               time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim,
                                                       dim,
                                                       act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        dtype=None,
    ):
        """Embed a 2-D ``(B, L)`` tensor of timesteps.

        Returns:
            ``(temb, timestep_proj)`` reshaped to ``(B, L, dim)`` and
            ``(B, L, time_proj_dim)``. ``temb`` is cast to ``dtype`` before it
            feeds ``time_proj``.
        """
        B, L = timestep.shape
        timestep = self.timesteps_proj(timestep.reshape(-1))

        # Match the embedder's weight dtype, except when the weights are int8
        # (presumably quantized) -- never cast the float input to int8.
        time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)

        temb = self.time_embedder(timestep).to(dtype=dtype)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)


class WanRotaryPosEmbed(nn.Module):
    """Three-axis (frame / height / width) rotary position embedding.

    The head dimension is split into ``f_dim`` (which absorbs the remainder),
    ``h_dim`` and ``w_dim``. Each part gets its own inverse-frequency ladder
    ``1 / theta ** (2k / part_dim)``. ``patch_size`` and ``max_seq_len`` are
    stored but not used by this module.
    """

    def __init__(
        self,
        attention_head_dim,
        patch_size,
        max_seq_len,
        theta=10000.0,
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3)
        self.h_dim = self.attention_head_dim // 3
        self.w_dim = self.attention_head_dim // 3

        # The frequencies depend only on the dimensions, so compute them once.
        # Non-persistent buffers follow .to(device) but are not saved.
        f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
        self.register_buffer("f_freqs_base", f_freqs_base, persistent=False)
        self.register_buffer("h_freqs_base", h_freqs_base, persistent=False)
        self.register_buffer("w_freqs_base", w_freqs_base, persistent=False)

    def _precompute_freqs_base(self):
        """Return float64 inverse frequencies ``1 / theta ** (2k / dim)`` per axis."""

        def inv_freqs(axis_dim):
            exponents = torch.arange(0, axis_dim, 2)[:(axis_dim // 2)].double() / axis_dim
            return 1.0 / (self.theta**exponents)

        return inv_freqs(self.f_dim), inv_freqs(self.h_dim), inv_freqs(self.w_dim)

    def forward(self, grid_ids):
        """Turn per-token grid positions into unit-magnitude complex rotations.

        Args:
            grid_ids: 3-D tensor indexed as ``grid_ids[:, axis, :]``, where
                axis 0 / 1 / 2 hold the frame / height / width positions.

        Returns:
            complex64 tensor of shape
            ``(B, L, f_dim // 2 + h_dim // 2 + w_dim // 2)``.
        """
        with torch.no_grad():
            f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base
            h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base
            w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base
            freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
            return torch.polar(torch.ones_like(freqs), freqs)


class WanAttention(torch.nn.Module):
    """Multi-head attention with Q/K RMSNorm, optional RoPE and a KV cache.

    Self-attention instances (``cross_attention_dim_head is None``) keep a dict
    of named, slot-based KV caches. Cross-attention instances set
    ``attn_caches`` to ``None`` and never cache.

    Each cache entry is a fixed pool of slots holding keys and values, plus:

    * ``mask``    -- which slots are occupied;
    * ``id``      -- the write generation of each slot (oldest evicted first);
    * ``is_pred`` -- slots written with ``update_cache == 1``, so they can be
      dropped together by ``clear_pred_cache``.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode='torch',
    ):
        super().__init__()
        if attn_mode == 'torch':
            self.attn_op = custom_sdpa
        elif attn_mode == 'flashattn':
            if flash_attn_func is None:
                raise ImportError(
                    "attn_mode='flashattn' requires flash_attn_interface or flash_attn to be installed"
                )
            self.attn_op = flash_attn_func
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support torch and flashattn"
            )

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = (self.inner_dim if cross_attention_dim_head is None
                             else cross_attention_dim_head * heads)

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim, dim, bias=True),
            torch.nn.Dropout(dropout),
        ])
        self.norm_q = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=True)
        # Keys come out of to_k with width kv_inner_dim, so normalize over that.
        self.norm_k = torch.nn.RMSNorm(self.kv_inner_dim, eps=eps, elementwise_affine=True)
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        """Free every slot of ``cache_name`` that was written with ``is_pred=True``.

        Does nothing for cross-attention, or when the named cache is missing or
        has been cleared.
        """
        if self.attn_caches is None:
            return
        cache = self.attn_caches.get(cache_name)
        if cache is None:
            return
        cache['mask'][cache['is_pred']] = False

    def clear_cache(self, cache_name):
        """Drop the named cache entirely (it becomes ``None``)."""
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
        """Allocate an empty cache of ``total_tolen`` slots under ``cache_name``.

        ``k``/``v`` are uninitialized ``(batch_size, total_tolen, num_head,
        head_dim)`` tensors. All slots start unoccupied, with id ``-1``.
        """
        if self.attn_caches is None:
            return
        kv_shape = [batch_size, total_tolen, num_head, head_dim]
        self.attn_caches[cache_name] = {
            'k': torch.empty(kv_shape, device=device, dtype=dtype),
            'v': torch.empty(kv_shape, device=device, dtype=dtype),
            'id': torch.full((total_tolen, ), -1, device=device),
            "mask": torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
            "is_pred": torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        """Return indices of ``key_size`` free slots, evicting the oldest if needed.

        Eviction frees occupied slots in ascending ``id`` order until enough
        slots are free.
        """
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]

        free = (~mask).nonzero(as_tuple=False).squeeze(-1)
        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)
            oldest_first = torch.argsort(ids[used])
            to_free = used[oldest_first[:key_size - free.numel()]]
            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        assert free.numel() >= key_size, (
            f"KV cache '{cache_name}' has {free.numel()} slots, needs {key_size}")
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        """Return one past the largest id among occupied slots (0 if none)."""
        ids = self.attn_caches[cache_name]['id']
        mask = self.attn_caches[cache_name]['mask']
        if mask.any():
            return ids[mask].max() + 1
        return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        """Write ``key``/``value`` (token axis 1) into freshly allocated slots.

        Returns the slot indices that were written.
        """
        cache = self.attn_caches[cache_name]
        slots = self.allocate_slots(cache_name, key.shape[1])
        new_id = self._next_cache_id(cache_name)
        cache['k'][:, slots] = key
        cache['v'][:, slots] = value
        cache['mask'][slots] = True
        cache['id'][slots] = new_id
        cache['is_pred'][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        """Mark ``slots`` as free again (undo a temporary write)."""
        self.attn_caches[cache_name]['mask'][slots] = False

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ):
        """Attend from ``q`` to ``k``/``v``, optionally through the KV cache.

        Args:
            q, k, v: ``(B, L, C)`` inputs (axis 2 is split into heads).
            rotary_emb: Complex RoPE factors, or ``None`` to skip RoPE.
            update_cache: How the new keys/values are kept, when the named
                cache exists:
                ``0`` -- written only for this call, then freed again;
                ``1`` -- kept and flagged ``is_pred``;
                anything else -- kept, not flagged.
            cache_name: Which named cache to use. A missing or cleared cache
                means plain, uncached attention.

        Returns:
            ``(B, L, C)`` tensor after the output projection and dropout.
        """
        kv_cache = None if self.attn_caches is None else self.attn_caches.get(cache_name)

        query = self.norm_q(self.to_q(q)).unflatten(2, (self.heads, -1))
        key = self.norm_k(self.to_k(k)).unflatten(2, (self.heads, -1))
        value = self.to_v(v).unflatten(2, (self.heads, -1))

        if rotary_emb is not None:
            query = _apply_rotary_emb(query, rotary_emb)
            key = _apply_rotary_emb(key, rotary_emb)

        use_cache = kv_cache is not None and kv_cache['k'] is not None
        slots = None
        if use_cache:
            # Write the new keys/values into the pool, then attend over every
            # occupied slot (older cached tokens plus the ones just written).
            slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1))
            valid = kv_cache['mask'].nonzero(as_tuple=False).squeeze(-1)
            key = kv_cache['k'][:, valid]
            value = kv_cache['v'][:, valid]

        hidden_states = self.attn_op(query, key, value)
        # Some flash_attn_interface versions return (out, softmax_lse).
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        if update_cache == 0 and use_cache:
            self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3).type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states


class WanTransformerBlock(nn.Module):
    """Self-attention, cross-attention and a feed-forward layer, each residual.

    Self-attention and the feed-forward input are modulated by per-token
    shift/scale/gate terms derived from the timestep embedding.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "flashattn",
    ):
        super().__init__()
        self.attn_mode = attn_mode

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=dim // num_heads,
            attn_mode=attn_mode,
        )
        self.norm2 = (FP32LayerNorm(dim, eps, elementwise_affine=True)
                      if cross_attn_norm else nn.Identity())

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ) -> torch.Tensor:
        """Run the block.

        Args:
            hidden_states: ``(B, L, C)`` token features.
            encoder_hidden_states: Context tokens for cross-attention.
            temb: ``(B, L, 6, C)`` per-token modulation terms.
            rotary_emb: RoPE factors for self-attention.
            update_cache, cache_name: Forwarded to the self-attention KV cache.
        """
        temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
        # Six (B, L, C) terms: shift/scale/gate for attention, then the FFN's.
        (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa) = [
            term.squeeze(1) for term in rearrange(
                temb_scale_shift_table, 'b l n c -> b n l c').chunk(6, dim=1)
        ]

        # 1. Self-attention (modulated, gated residual)
        norm_hidden_states = (self.norm1(hidden_states.float()) *
                              (1. + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(norm_hidden_states,
                                 norm_hidden_states,
                                 norm_hidden_states,
                                 rotary_emb,
                                 update_cache=update_cache,
                                 cache_name=cache_name)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (plain residual, never cached)
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states,
                                 encoder_hidden_states,
                                 encoder_hidden_states,
                                 None,
                                 update_cache=0,
                                 cache_name=cache_name)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward (modulated, gated residual)
        norm_hidden_states = (self.norm3(hidden_states.float()) *
                              (1. + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() +
                         ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states


class WanTransformer3DModel(ModelMixin, ConfigMixin):
    r"""Transformer that denoises either patchified video latents or actions.

    Both modes share the transformer blocks and the text projection. They use
    separate input embedders (``patch_embedding_mlp`` / ``action_embedder``),
    timestep embedders (``condition_embedder`` / ``condition_embedder_action``)
    and output heads (``proj_out`` / ``action_proj_out``).
    """

    @register_to_config
    def __init__(self,
                 patch_size=[1, 2, 2],
                 num_attention_heads=24,
                 attention_head_dim=128,
                 in_channels=48,
                 out_channels=48,
                 action_dim=30,
                 text_dim=4096,
                 freq_dim=256,
                 ffn_dim=14336,
                 num_layers=30,
                 cross_attn_norm=True,
                 eps=1e-06,
                 rope_max_seq_len=1024,
                 pos_embed_seq_len=None,
                 attn_mode="torch"):
        r"""Build the model.

        Args:
            patch_size: ``(p_f, p_h, p_w)`` latent patch size (not mutated).
            num_attention_heads, attention_head_dim: Attention layout; the
                hidden width is their product.
            in_channels, out_channels: Latent channels in / out.
            action_dim: Feature width of action tokens, in and out.
            text_dim: Feature width of ``input_dict["text_emb"]``.
            freq_dim: Sinusoidal timestep channels.
            ffn_dim: Feed-forward inner width.
            num_layers: Number of transformer blocks.
            cross_attn_norm: Whether cross-attention input gets a LayerNorm.
            eps: Normalization epsilon.
            rope_max_seq_len: Passed to the rotary embedding (stored only).
            pos_embed_seq_len: Passed to the condition embedder (unused).
            attn_mode: ``"torch"`` or ``"flashattn"``.
        """
        super().__init__()

        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)

        self.patch_embedding_mlp = nn.Linear(
            in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim)
        self.action_embedder = nn.Linear(action_dim, inner_dim)

        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        self.blocks = nn.ModuleList([
            WanTransformerBlock(inner_dim,
                                ffn_dim,
                                num_attention_heads,
                                cross_attn_norm,
                                eps,
                                attn_mode=attn_mode) for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    def clear_cache(self, cache_name):
        """Drop the named self-attention KV cache in every block."""
        for block in self.blocks:
            block.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        """Free the ``is_pred`` slots of the named cache in every block."""
        for block in self.blocks:
            block.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(self, cache_name, attn_window, latent_token_per_chunk,
                           action_token_per_chunk, device, dtype, batch_size):
        """Allocate a self-attention KV cache in every block.

        Capacity is ``(attn_window // 2)`` chunks, each holding
        ``latent_token_per_chunk + action_token_per_chunk`` tokens.
        """
        total_tokens = (attn_window // 2) * (latent_token_per_chunk + action_token_per_chunk)
        for block in self.blocks:
            block.attn1.init_kv_cache(cache_name, total_tokens, self.num_attention_heads,
                                      self.attention_head_dim, device, dtype, batch_size)

    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
    ):
        r"""Denoise latent or action tokens.

        Args:
            input_dict: Dict with keys
                ``"noisy_latents"`` -- 5-D ``(b, c, f, h, w)`` tensor. In
                    action mode ``c == action_dim``; otherwise ``f, h, w``
                    must be divisible by ``patch_size``.
                ``"text_emb"`` -- text features for the text projection.
                ``"grid_id"`` -- positions for the rotary embedding
                    (indexed ``[:, axis, :]``).
                ``"timesteps"`` -- 2-D tensor, repeated along dim 1 once per
                    spatial token.
            update_cache: Forwarded to the self-attention KV caches.
            cache_name: Name of the KV cache to use.
            action_mode: Select the action embedders / head instead of the
                latent ones.

        Returns:
            Action mode: ``(B, L, action_dim)``. Otherwise
            ``(B, L * prod(patch_size), out_channels)``.
        """
        if action_mode:
            # Action tokens: one token per (f, h, w) position, no patching.
            latent_hidden_states = rearrange(input_dict['noisy_latents'],
                                             'b c f h w -> b (f h w) c')
            latent_hidden_states = self.action_embedder(latent_hidden_states)
        else:
            # Latent tokens: flatten each (p1, p2, p3) patch into one token.
            latent_hidden_states = rearrange(
                input_dict['noisy_latents'],
                'b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)',
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2])
            latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states)

        text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"])

        # Extra axis so the rotary factors broadcast over attention heads.
        rotary_emb = self.rope(input_dict['grid_id'])[:, :, None]

        # One timestep per token: repeat each entry once per spatial token.
        patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1],
                                                                   self.patch_size[2])
        tokens_per_frame = ((input_dict['noisy_latents'].shape[-2] // patch_scale_h) *
                            (input_dict['noisy_latents'].shape[-1] // patch_scale_w))
        latent_time_steps = torch.repeat_interleave(input_dict['timesteps'],
                                                    tokens_per_frame,
                                                    dim=1)

        current_condition_embedder = (self.condition_embedder_action
                                      if action_mode else self.condition_embedder)
        temb, timestep_proj = current_condition_embedder(latent_time_steps,
                                                         dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C

        for block in self.blocks:
            latent_hidden_states = block(latent_hidden_states,
                                         text_hidden_states,
                                         timestep_proj,
                                         rotary_emb,
                                         update_cache=update_cache,
                                         cache_name=cache_name)

        # Output modulation: per-token shift/scale from temb.
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table, 'b l n c -> b n l c').chunk(2, dim=1)
        shift = shift.to(latent_hidden_states.device).squeeze(1)
        scale = scale.to(latent_hidden_states.device).squeeze(1)
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) *
                                (1. + scale) + shift).type_as(latent_hidden_states)

        if action_mode:
            latent_hidden_states = self.action_proj_out(latent_hidden_states)
        else:
            latent_hidden_states = self.proj_out(latent_hidden_states)
            # Unfold each token's patch back into prod(patch_size) tokens.
            latent_hidden_states = rearrange(latent_hidden_states,
                                             'b l (n c) -> b (l n) c',
                                             n=math.prod(self.patch_size))
        return latent_hidden_states


if __name__ == '__main__':
    model = WanTransformer3DModel(patch_size=[1, 2, 2],
                                  num_attention_heads=24,
                                  attention_head_dim=128,
                                  in_channels=48,
                                  out_channels=48,
                                  action_dim=30,
                                  text_dim=4096,
                                  freq_dim=256,
                                  ffn_dim=14336,
                                  num_layers=30,
                                  cross_attn_norm=True,
                                  eps=1e-6,
                                  rope_max_seq_len=1024,
                                  pos_embed_seq_len=None,
                                  attn_mode="torch")
    print(model)