Spaces:
Runtime error
Runtime error
| from typing import Any, List, Tuple, Optional, Union, Dict | |
| from einops import rearrange | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.models import ModelMixin | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps | |
| from modules.models.attention import attention, get_cu_seqlens, get_preferred_attention_backend | |
| from .posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed | |
| from .modulate_layers import load_modulation, modulate, apply_gate | |
| class RMSNorm(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| elementwise_affine=True, | |
| eps: float = 1e-6, | |
| device=None, | |
| dtype=None, | |
| ): | |
| """ | |
| Initialize the RMSNorm normalization layer. | |
| Args: | |
| dim (int): The dimension of the input tensor. | |
| eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. | |
| Attributes: | |
| eps (float): A small value added to the denominator for numerical stability. | |
| weight (nn.Parameter): Learnable scaling parameter. | |
| """ | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| super().__init__() | |
| self.eps = eps | |
| if elementwise_affine: | |
| self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) | |
| def _norm(self, x): | |
| """ | |
| Apply the RMSNorm normalization to the input tensor. | |
| Args: | |
| x (torch.Tensor): The input tensor. | |
| Returns: | |
| torch.Tensor: The normalized tensor. | |
| """ | |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| def forward(self, x): | |
| """ | |
| Forward pass through the RMSNorm layer. | |
| Args: | |
| x (torch.Tensor): The input tensor. | |
| Returns: | |
| torch.Tensor: The output tensor after applying RMSNorm. | |
| """ | |
| output = self._norm(x.float()).type_as(x) | |
| if hasattr(self, "weight"): | |
| output = output * self.weight | |
| return output | |
| class MMDoubleStreamBlock(nn.Module): | |
| """ | |
| A multimodal dit block with seperate modulation for | |
| text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 | |
| (Flux.1): https://github.com/black-forest-labs/flux | |
| """ | |
| def __init__( | |
| self, | |
| hidden_size: int, | |
| heads_num: int, | |
| mlp_width_ratio: float, | |
| mlp_act_type: str = "gelu_tanh", | |
| dtype: Optional[torch.dtype] = None, | |
| device: Optional[torch.device] = None, | |
| dit_modulation_type: Optional[str] = "wanx", | |
| attn_backend: str = 'auto', | |
| ): | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| super().__init__() | |
| self.attn_backend = get_preferred_attention_backend() if attn_backend == 'auto' else attn_backend | |
| self.dit_modulation_type = dit_modulation_type | |
| self.heads_num = heads_num | |
| head_dim = hidden_size // heads_num | |
| mlp_hidden_dim = int(hidden_size * mlp_width_ratio) | |
| self.img_mod = load_modulation( | |
| modulate_type=self.dit_modulation_type, | |
| hidden_size=hidden_size, | |
| factor=6, | |
| **factory_kwargs, | |
| ) | |
| self.img_norm1 = nn.LayerNorm( | |
| hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs | |
| ) | |
| self.img_attn_qkv = nn.Linear( | |
| hidden_size, hidden_size * 3, bias=True, **factory_kwargs | |
| ) | |
| self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, | |
| eps=1e-6, **factory_kwargs) | |
| self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, | |
| eps=1e-6, **factory_kwargs) | |
| self.img_attn_proj = nn.Linear( | |
| hidden_size, hidden_size, bias=True, **factory_kwargs | |
| ) | |
| self.img_norm2 = nn.LayerNorm( | |
| hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs | |
| ) | |
| # There is no dtype fpr FeedForward, because FSDP2 casts the dtype for all parameters. | |
| # You may need to give the dtype when no autocast and fsdp !!! | |
| self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, | |
| activation_fn="gelu-approximate") | |
| self.txt_mod = load_modulation( | |
| modulate_type=self.dit_modulation_type, | |
| hidden_size=hidden_size, | |
| factor=6, | |
| **factory_kwargs, | |
| ) | |
| self.txt_norm1 = nn.LayerNorm( | |
| hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs | |
| ) | |
| self.txt_attn_qkv = nn.Linear( | |
| hidden_size, hidden_size * 3, bias=True, **factory_kwargs | |
| ) | |
| self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, | |
| eps=1e-6, **factory_kwargs) | |
| self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, | |
| eps=1e-6, **factory_kwargs) | |
| self.txt_attn_proj = nn.Linear( | |
| hidden_size, hidden_size, bias=True, **factory_kwargs | |
| ) | |
| self.txt_norm2 = nn.LayerNorm( | |
| hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs | |
| ) | |
| self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, | |
| activation_fn="gelu-approximate") | |
| def forward( | |
| self, | |
| img: torch.Tensor, | |
| txt: torch.Tensor, | |
| vec: torch.Tensor, | |
| vis_freqs_cis: tuple = None, | |
| txt_freqs_cis: tuple = None, | |
| attn_kwargs: Optional[dict] = {}, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| tt, th, tw = attn_kwargs['thw'] | |
| ( | |
| img_mod1_shift, | |
| img_mod1_scale, | |
| img_mod1_gate, | |
| img_mod2_shift, | |
| img_mod2_scale, | |
| img_mod2_gate, | |
| ) = self.img_mod(vec) | |
| ( | |
| txt_mod1_shift, | |
| txt_mod1_scale, | |
| txt_mod1_gate, | |
| txt_mod2_shift, | |
| txt_mod2_scale, | |
| txt_mod2_gate, | |
| ) = self.txt_mod(vec) | |
| # Prepare image for attention. | |
| img_modulated = self.img_norm1(img) | |
| img_modulated = modulate( | |
| img_modulated, shift=img_mod1_shift, scale=img_mod1_scale | |
| ) | |
| img_qkv = self.img_attn_qkv(img_modulated) | |
| img_q, img_k, img_v = rearrange( | |
| img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num | |
| ) | |
| # Apply QK-Norm if needed | |
| img_q = self.img_attn_q_norm(img_q).to(img_v) | |
| img_k = self.img_attn_k_norm(img_k).to(img_v) | |
| # Apply RoPE if needed. | |
| if vis_freqs_cis is not None: | |
| img_qq, img_kk = apply_rotary_emb( | |
| img_q, img_k, vis_freqs_cis, head_first=False) | |
| assert ( | |
| img_qq.shape == img_q.shape and img_kk.shape == img_k.shape | |
| ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" | |
| img_q, img_k = img_qq, img_kk | |
| # Prepare txt for attention. | |
| txt_modulated = self.txt_norm1(txt) | |
| txt_modulated = modulate( | |
| txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale | |
| ) | |
| txt_qkv = self.txt_attn_qkv(txt_modulated) | |
| txt_q, txt_k, txt_v = rearrange( | |
| txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num | |
| ) | |
| # Apply QK-Norm if needed. | |
| txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) | |
| txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) | |
| if txt_freqs_cis is not None: | |
| raise NotImplementedError("RoPE text is not supported for inference") | |
| txt_qq, txt_kk = apply_rotary_emb( | |
| txt_q, txt_k, txt_freqs_cis, head_first=False) | |
| assert ( | |
| txt_qq.shape == txt_q.shape and txt_kk.shape == txt_k.shape | |
| ), f"txt_kk: {txt_qq.shape}, txt_q: {txt_q.shape}, txt_kk: {txt_kk.shape}, txt_k: {txt_k.shape}" | |
| txt_q, txt_k = txt_qq, txt_kk | |
| # attention computation start | |
| q = torch.cat((img_q, txt_q), dim=1) | |
| k = torch.cat((img_k, txt_k), dim=1) | |
| v = torch.cat((img_v, txt_v), dim=1) | |
| attn = attention( | |
| q, k, v, | |
| backend=self.attn_backend, | |
| attn_kwargs=attn_kwargs, | |
| ) | |
| attn = attn.flatten(2, 3) | |
| # attention computation end | |
| img_attn, txt_attn = attn[:, | |
| : img.shape[1]], attn[:, img.shape[1]:] | |
| # Calculate the img bloks. | |
| img = img + apply_gate(self.img_attn_proj(img_attn), | |
| gate=img_mod1_gate) | |
| img = img + apply_gate( | |
| self.img_mlp( | |
| modulate( | |
| self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale | |
| ) | |
| ), | |
| gate=img_mod2_gate, | |
| ) | |
| # Calculate the txt bloks. | |
| txt = txt + apply_gate(self.txt_attn_proj(txt_attn), | |
| gate=txt_mod1_gate) | |
| txt = txt + apply_gate( | |
| self.txt_mlp( | |
| modulate( | |
| self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale | |
| ) | |
| ), | |
| gate=txt_mod2_gate, | |
| ) | |
| return img, txt | |
| class WanTimeTextImageEmbedding(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| time_freq_dim: int, | |
| time_proj_dim: int, | |
| text_embed_dim: int, | |
| image_embed_dim: Optional[int] = None, | |
| pos_embed_seq_len: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.timesteps_proj = Timesteps( | |
| num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) | |
| self.time_embedder = TimestepEmbedding( | |
| in_channels=time_freq_dim, time_embed_dim=dim) | |
| self.act_fn = nn.SiLU() | |
| self.time_proj = nn.Linear(dim, time_proj_dim) | |
| self.text_embedder = PixArtAlphaTextProjection( | |
| text_embed_dim, dim, act_fn="gelu_tanh") | |
| def forward( | |
| self, | |
| timestep: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| ): | |
| timestep = self.timesteps_proj(timestep) | |
| time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype | |
| if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: | |
| timestep = timestep.to(time_embedder_dtype) | |
| temb = self.time_embedder(timestep).type_as(encoder_hidden_states) | |
| timestep_proj = self.time_proj(self.act_fn(temb)) | |
| encoder_hidden_states = self.text_embedder(encoder_hidden_states) | |
| return temb, timestep_proj, encoder_hidden_states | |
| class Transformer3DModel(ModelMixin, ConfigMixin): | |
| _fsdp_shard_conditions: list = [ | |
| lambda name, module: isinstance(module, (MMDoubleStreamBlock))] | |
| _supports_gradient_checkpointing = True | |
| def __init__( | |
| self, | |
| args: Any, | |
| patch_size: list = [1, 2, 2], | |
| in_channels: int = 4, # Should be VAE.config.latent_channels. | |
| out_channels: int = None, | |
| hidden_size: int = 3072, | |
| heads_num: int = 24, | |
| text_states_dim: int = 4096, | |
| mlp_width_ratio: float = 4.0, | |
| mm_double_blocks_depth: int = 20, | |
| rope_dim_list: List[int] = [16, 56, 56], | |
| rope_type: str = 'rope', | |
| dtype: Optional[torch.dtype] = None, | |
| device: Optional[torch.device] = None, | |
| dit_modulation_type: str = "wanx", | |
| attn_backend: str = 'auto', | |
| theta: int = 256, | |
| ): | |
| self.args = args | |
| self.out_channels = out_channels or in_channels | |
| self.patch_size = patch_size | |
| self.hidden_size = hidden_size | |
| self.heads_num = heads_num | |
| self.rope_dim_list = rope_dim_list | |
| self.dit_modulation_type = dit_modulation_type | |
| self.mm_double_blocks_depth = mm_double_blocks_depth | |
| self.attn_backend = get_preferred_attention_backend() if attn_backend == 'auto' else attn_backend | |
| self.rope_type = rope_type | |
| self.theta = theta | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| super().__init__() | |
| self.hidden_size = hidden_size | |
| if hidden_size % heads_num != 0: | |
| raise ValueError( | |
| f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" | |
| ) | |
| # image projection | |
| self.img_in = nn.Conv3d( | |
| in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) | |
| # condition embedding | |
| self.condition_embedder = WanTimeTextImageEmbedding( | |
| dim=hidden_size, | |
| time_freq_dim=256, | |
| time_proj_dim=hidden_size * 6, | |
| text_embed_dim=text_states_dim, | |
| ) | |
| # double blocks | |
| self.double_blocks = nn.ModuleList( | |
| [ | |
| MMDoubleStreamBlock( | |
| self.hidden_size, | |
| self.heads_num, | |
| mlp_width_ratio=mlp_width_ratio, | |
| dit_modulation_type=self.dit_modulation_type, | |
| attn_backend=attn_backend, | |
| **factory_kwargs, | |
| ) | |
| for _ in range(mm_double_blocks_depth) | |
| ] | |
| ) | |
| # Output norm & projection | |
| self.norm_out = nn.LayerNorm( | |
| hidden_size, elementwise_affine=False, eps=1e-6 | |
| ) | |
| self.proj_out = nn.Linear( | |
| hidden_size, out_channels * math.prod(patch_size), | |
| **factory_kwargs) | |
| def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None): | |
| target_ndim = 3 | |
| ndim = 5 - 2 | |
| if len(vis_rope_size) != target_ndim: | |
| vis_rope_size = [1] * (target_ndim - len(vis_rope_size) | |
| ) + vis_rope_size # time axis | |
| head_dim = self.hidden_size // self.heads_num | |
| rope_dim_list = self.rope_dim_list | |
| if rope_dim_list is None: | |
| rope_dim_list = [head_dim // | |
| target_ndim for _ in range(target_ndim)] | |
| assert ( | |
| sum(rope_dim_list) == head_dim | |
| ), "sum(rope_dim_list) should equal to head_dim of attention layer" | |
| vis_freqs, txt_freqs = get_nd_rotary_pos_embed( | |
| rope_dim_list, | |
| vis_rope_size, | |
| txt_rope_size=txt_rope_size, | |
| theta=self.theta, | |
| use_real=True, | |
| theta_rescale_factor=1, | |
| ) | |
| return vis_freqs, txt_freqs | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| timestep: torch.Tensor, # Should be in range(0, 1000). | |
| encoder_hidden_states: torch.Tensor = None, | |
| encoder_hidden_states_mask: torch.Tensor = None, | |
| return_dict: bool = True, | |
| ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: | |
| # For Multi-item Input: hidden_states: (b, n, c, t, h, w) | |
| # Permute the items into the temporal dimension | |
| is_multi_item = (len(hidden_states.shape) == 6) | |
| num_items = 0 | |
| if is_multi_item: | |
| num_items = hidden_states.shape[1] | |
| if num_items > 1: | |
| assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1" | |
| # Move the last item to the first position | |
| hidden_states = torch.cat( | |
| [ | |
| hidden_states[:, -1:], | |
| hidden_states[:, :-1] | |
| ], | |
| dim=1 | |
| ) | |
| hidden_states = rearrange( | |
| hidden_states, 'b n c t h w -> b c (n t) h w') | |
| out = {} | |
| batch_size, _, ot, oh, ow = hidden_states.shape | |
| tt, th, tw = ( | |
| ot // self.patch_size[0], | |
| oh // self.patch_size[1], | |
| ow // self.patch_size[2], | |
| ) | |
| # Text Mask | |
| if encoder_hidden_states_mask == None: | |
| encoder_hidden_states_mask = torch.ones( | |
| (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), dtype=torch.bool).to(encoder_hidden_states.device) | |
| # Prepare img, txt, vec. | |
| img = self.img_in(hidden_states).flatten(2).transpose(1, 2) | |
| temb, vec, txt = self.condition_embedder( | |
| timestep, encoder_hidden_states) | |
| if vec.shape[-1] > self.hidden_size: | |
| vec = vec.unflatten(1, (6, -1)) | |
| txt_seq_len = txt.shape[1] | |
| img_seq_len = img.shape[1] | |
| # rope | |
| vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(vis_rope_size=( | |
| tt, th, tw), txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None) | |
| # Compute attn_kwargs | |
| attn_kwargs = {'thw': [tt, th, tw], 'txt_len': txt_seq_len} | |
| if self.attn_backend == 'flash_attn': | |
| cu_seqlens_q = get_cu_seqlens( | |
| encoder_hidden_states_mask, img_seq_len) | |
| cu_seqlens_kv = cu_seqlens_q | |
| max_seqlen_q = img_seq_len + txt_seq_len | |
| max_seqlen_kv = max_seqlen_q | |
| attn_kwargs.update({ | |
| 'cu_seqlens_q': cu_seqlens_q, | |
| 'cu_seqlens_kv': cu_seqlens_kv, | |
| 'max_seqlen_q': max_seqlen_q, | |
| 'max_seqlen_kv': max_seqlen_kv, | |
| }) | |
| # --------------------- Pass through DiT blocks ------------------------ | |
| for _, block in enumerate(self.double_blocks): | |
| double_block_args = [ | |
| img, | |
| txt, | |
| vec, | |
| vis_freqs_cis, | |
| txt_freqs_cis, | |
| attn_kwargs | |
| ] | |
| img, txt = block(*double_block_args) | |
| img_len = img.shape[1] | |
| x = torch.cat((img, txt), 1) | |
| img = x[:, :img_len, ...] | |
| # ---------------------------- Final layer ------------------------------ | |
| img = self.proj_out(self.norm_out(img)) | |
| img = self.unpatchify(img, tt, th, tw) | |
| # Reshape back to multiple items | |
| if is_multi_item: | |
| img = rearrange( | |
| img, 'b c (n t) h w -> b n c t h w', n=num_items) | |
| if num_items > 1: | |
| # Move the first item back to the last position | |
| img = torch.cat( | |
| [ | |
| img[:, 1:], | |
| img[:, :1] | |
| ], | |
| dim=1 | |
| ) | |
| return (img, txt) | |
| def unpatchify(self, x, t, h, w): | |
| """ | |
| x: (N, T, patch_size**2 * C) | |
| imgs: (N, H, W, C) | |
| """ | |
| c = self.out_channels | |
| pt, ph, pw = self.patch_size | |
| assert t * h * w == x.shape[1] | |
| x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) | |
| x = torch.einsum("nthwopqc->nctohpwq", x) | |
| imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) | |
| return imgs | |