| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from ...utils import apply_lora_scale, logging |
| from ..attention import AttentionMixin, FeedForward |
| from ..cache_utils import CacheMixin |
| from ..modeling_outputs import Transformer2DModelOutput |
| from ..modeling_utils import ModelMixin |
| from ..normalization import FP32LayerNorm |
| from .transformer_wan import ( |
| WanAttention, |
| WanAttnProcessor, |
| WanRotaryPosEmbed, |
| WanTimeTextImageEmbedding, |
| WanTransformerBlock, |
| ) |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class WanVACETransformerBlock(nn.Module):
    r"""
    Transformer block that processes the VACE control stream.

    The block runs self-attention, cross-attention and a feed-forward network over `control_hidden_states`. The
    self-attention and feed-forward stages are modulated by six shift/scale/gate tensors obtained from
    `scale_shift_table + temb`. Optionally the control stream is first projected and summed with the main
    `hidden_states` (input projection), and optionally a projected copy of the result is returned as the conditioning
    hint (output projection).

    Args:
        dim (`int`):
            Channel dimension of the hidden states.
        ffn_dim (`int`):
            Inner dimension of the feed-forward network.
        num_heads (`int`):
            Number of attention heads; each head uses `dim // num_heads` channels.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            Accepted for signature parity with the caller; it is not referenced by this block.
        cross_attn_norm (`bool`, defaults to `False`):
            If `True`, an affine `FP32LayerNorm` is applied before cross-attention; otherwise `nn.Identity` is used.
        eps (`float`, defaults to `1e-6`):
            Epsilon forwarded to the normalization and attention layers.
        added_kv_proj_dim (`int`, *optional*):
            Forwarded to the cross-attention layer as `added_kv_proj_dim`.
        apply_input_projection (`bool`, defaults to `False`):
            If `True`, create `proj_in`; in `forward` the control stream is then projected and the main
            `hidden_states` are added to it.
        apply_output_projection (`bool`, defaults to `False`):
            If `True`, create `proj_out`, whose output is returned as the conditioning hint.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        apply_input_projection: bool = False,
        apply_output_projection: bool = False,
    ):
        super().__init__()

        # 1. Input projection (optional)
        self.proj_in = None
        if apply_input_projection:
            self.proj_in = nn.Linear(dim, dim)

        # 2. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            processor=WanAttnProcessor(),
        )

        # 3. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            processor=WanAttnProcessor(),
            is_cross_attention=True,
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 4. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # 5. Output projection (optional)
        self.proj_out = None
        if apply_output_projection:
            self.proj_out = nn.Linear(dim, dim)

        # Learned base for the six modulation tensors (shift/scale/gate for attention and for the feed-forward).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        control_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        r"""
        Run the block on the control stream.

        Args:
            hidden_states (`torch.Tensor`):
                Main-stream hidden states; only used when the block has an input projection.
            encoder_hidden_states (`torch.Tensor`):
                Context passed to the cross-attention layer.
            control_hidden_states (`torch.Tensor`):
                Control-stream hidden states to be updated.
            temb (`torch.Tensor`):
                Modulation input; it is added to `scale_shift_table` (shape `(1, 6, dim)`) and split into six chunks
                along dim 1, so it must broadcast against that shape.
            rotary_emb (`torch.Tensor`):
                Rotary embedding forwarded to the self-attention layer.

        Returns:
            `tuple`: `(conditioning_states, control_hidden_states)`. `conditioning_states` is
            `proj_out(control_hidden_states)` when the block has an output projection and `None` otherwise.
        """
        if self.proj_in is not None:
            control_hidden_states = self.proj_in(control_hidden_states)
            control_hidden_states = control_hidden_states + hidden_states

        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table.to(temb.device) + temb.float()
        ).chunk(6, dim=1)

        # 1. Self-attention: normalize and modulate in float32, then cast back to the working dtype.
        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
            control_hidden_states
        )
        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

        # 2. Cross-attention: plain residual connection, no modulation or gating.
        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        control_hidden_states = control_hidden_states + attn_output

        # 3. Feed-forward: same modulate / gate pattern as self-attention, using the second triple of chunks.
        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            control_hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
            control_hidden_states
        )

        # 4. Optional conditioning hint derived from the updated control stream.
        conditioning_states = None
        if self.proj_out is not None:
            conditioning_states = self.proj_out(control_hidden_states)

        return conditioning_states, control_hidden_states
|
|
|
|
class WanVACETransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Wan model, extended with VACE control blocks.

    Args:
        patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads. The inner dimension is `num_attention_heads * attention_head_dim`.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output. Falls back to `in_channels` when falsy.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization setting forwarded to the transformer blocks.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Forwarded to the condition embedder as `image_embed_dim`.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length forwarded to the rotary position embedding.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Forwarded to the condition embedder as `pos_embed_seq_len`.
        vace_layers (`list[int]`, defaults to `[0, 5, 10, 15, 20, 25, 30, 35]`):
            Indices of the transformer layers after which a VACE conditioning hint is added. Must contain `0` and
            every index must be smaller than `num_layers`. One VACE block is created per entry.
        vace_in_channels (`int`, defaults to `96`):
            The number of channels of the control input fed to the VACE patch embedding.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: int | None = None,
        added_kv_proj_dim: int | None = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: int | None = None,
        # NOTE: the default list is only read (never mutated) in this class.
        vace_layers: list[int] = [0, 5, 10, 15, 20, 25, 30, 35],
        vace_in_channels: int = 96,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # The `vace_layers and ...` guard keeps `max()` from raising its own "empty sequence" error; an empty list
        # then falls through to the explicit "must include layer 0" message below.
        if vace_layers and max(vace_layers) >= num_layers:
            raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
        if 0 not in vace_layers:
            raise ValueError("VACE layers must include layer 0.")

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
        self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings (timestep, text and optional image)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # One VACE block per entry of `vace_layers`; only the first one projects the control input.
        self.vace_blocks = nn.ModuleList(
            [
                WanVACETransformerBlock(
                    inner_dim,
                    ffn_dim,
                    num_attention_heads,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    added_kv_proj_dim,
                    apply_input_projection=i == 0,
                    apply_output_projection=True,
                )
                for i in range(len(vace_layers))
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: torch.Tensor | None = None,
        control_hidden_states: torch.Tensor | None = None,
        control_hidden_states_scale: torch.Tensor | None = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> Transformer2DModelOutput | tuple[torch.Tensor]:
        r"""
        Run the transformer on a batch of latents together with a VACE control input.

        Args:
            hidden_states (`torch.Tensor`):
                5D input whose shape is unpacked as `(batch, channels, frames, height, width)`.
            timestep (`torch.LongTensor`):
                Timestep(s) passed to the condition embedder.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning passed to the condition embedder.
            encoder_hidden_states_image (`torch.Tensor`, *optional*):
                Image conditioning; when the embedder returns a non-`None` image embedding it is prepended to the text
                embedding along dim 1.
            control_hidden_states (`torch.Tensor`):
                Control input for the VACE branch. Required despite the `None` default (kept for signature
                compatibility).
            control_hidden_states_scale (`torch.Tensor`, *optional*):
                One scale per entry of `config.vace_layers` (split along dim 0). Defaults to all ones.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                Not read in this method body; it is named in the `apply_lora_scale` decorator.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is `True`, otherwise a 1-tuple containing the output tensor.

        Raises:
            ValueError: If `control_hidden_states` is `None`, or if the number of control scales does not match the
                number of VACE layers.
        """
        if control_hidden_states is None:
            # Fail early with a clear message instead of an AttributeError/TypeError further down.
            raise ValueError("`control_hidden_states` must be provided.")

        batch_size, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        num_vace_layers = len(self.config.vace_layers)
        if control_hidden_states_scale is None:
            control_hidden_states_scale = control_hidden_states.new_ones(num_vace_layers)
        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
        if len(control_hidden_states_scale) != num_vace_layers:
            raise ValueError(
                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
                f"equal to {num_vace_layers}."
            )

        # 1. Rotary position embedding
        rotary_emb = self.rope(hidden_states)

        # 2. Patch embedding: (B, C, F, H, W) -> (B, seq, inner_dim)
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # The control sequence is zero-padded at the end so it has the same length as the main sequence.
        control_hidden_states = self.vace_patch_embedding(control_hidden_states)
        control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
        control_hidden_states_padding = control_hidden_states.new_zeros(
            batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
        )
        control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)

        # 3. Time / text / image embedding
        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        # 4. Image embedding is prepended to the text embedding
        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 5. Transformer blocks. Both block types go through the same dispatcher so that the checkpointed and the
        # plain execution paths cannot drift apart.
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing

        def run_block(block, *block_args):
            if use_checkpointing:
                return self._gradient_checkpointing_func(block, *block_args)
            return block(*block_args)

        # 5a. VACE blocks: all hints are computed from the patch-embedded `hidden_states` before any main block runs.
        control_hidden_states_list = []
        for i, block in enumerate(self.vace_blocks):
            conditioning_states, control_hidden_states = run_block(
                block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
            )
            control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
        # Reversed so that `pop()` hands the hints out in the order they were produced; popping also removes each
        # hint from the list once it has been consumed.
        control_hidden_states_list = control_hidden_states_list[::-1]

        # 5b. Main blocks, adding the next scaled hint after every layer listed in `vace_layers`.
        for i, block in enumerate(self.blocks):
            hidden_states = run_block(block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
            if i in self.config.vace_layers:
                control_hint, scale = control_hidden_states_list.pop()
                hidden_states = hidden_states + control_hint * scale

        # 6. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # The modulation tensors are created on `temb`'s device; move them to wherever `hidden_states` lives
        # (the two can presumably differ under multi-device placement — TODO confirm).
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # (B, seq, p_t*p_h*p_w*C) -> (B, F', H', W', p_t, p_h, p_w, C) -> (B, C, F', p_t, H', p_h, W', p_w)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        # Merge each (grid, patch) pair, innermost first: -> (B, C, F'*p_t, H'*p_h, W'*p_w)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|