| import torch |
| import torch.nn as nn |
| from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d |
| from einops import rearrange |
| from ..core import gradient_checkpoint_forward |
|
|
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
    """Build a 1D RoPE frequency table and split it into three equal chunks.

    Delegates table construction to ``precompute_freqs_cis`` and returns the
    result chunked along the last dimension (one chunk per rotary sub-band).
    """
    full_table = precompute_freqs_cis(dim, end, theta)
    return full_table.chunk(3, dim=-1)
|
|
class MovaAudioDit(WanModel):
    """Audio variant of ``WanModel``: a diffusion transformer over a 1D
    (temporal) token sequence instead of video patches.

    Differences from the parent:
    - the patch embedding is a pointwise ``nn.Conv1d`` over audio features;
    - rotary position frequencies are 1D (three chunks of a single frequency
      table, re-concatenated per-position in ``forward``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-attention-head dimension; defaults mirror the keyword defaults
        # used below. NOTE(review): assumes dim/num_heads are passed as
        # keywords — if a caller passes them positionally via *args these
        # defaults would silently be wrong. TODO confirm call sites.
        head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
        self.freqs = precompute_freqs_cis_1d(head_dim)
        # Replace the parent's patch embedding with a pointwise 1D conv that
        # maps audio feature channels (in_dim) to the transformer width (dim).
        self.patch_embedding = nn.Conv1d(
            kwargs.get("in_dim", 128),
            kwargs.get("dim", 1536),
            kernel_size=1,
            stride=1,
        )

    def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
        """Recompute the 1D RoPE table (e.g. to extend context length ``end``).

        Bug fix: previously this stored the result only in
        ``self.f_freqs_cis``, which nothing reads — ``forward`` uses
        ``self.freqs`` — so recomputation silently had no effect. Both
        attributes are now updated (``f_freqs_cis`` kept for backward
        compatibility with any external reader).
        """
        self.freqs = precompute_freqs_cis_1d(dim, end, theta)
        self.f_freqs_cis = self.freqs

    def forward(self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **kwargs,
    ):
        """Run the transformer over an audio token sequence.

        Args:
            x: input features; ``self.patchify`` turns it into ``f`` temporal
               tokens. Presumably (batch, in_dim, length) given the Conv1d
               patch embedding — TODO confirm against ``WanModel.patchify``.
            timestep: diffusion timestep(s) fed to the time embedding.
            context: conditioning sequence fed through the text embedding.
            use_gradient_checkpointing: recompute block activations in
                backward to save memory.
            use_gradient_checkpointing_offload: additionally offload
                checkpointed activations (handled by
                ``gradient_checkpoint_forward``).

        Returns:
            Tensor un-patchified back to shape (batch, channels, f * patch).
        """
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        # Six per-block modulation vectors (AdaLN-style), shape (B, 6, dim).
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        x, (f, ) = self.patchify(x)
        # Reassemble the rotary table for the first f positions. Each chunk
        # sliced to length f and viewed as (f, -1) already has its final 2D
        # shape, so the original trailing ``.expand(f, -1)`` was a no-op and
        # has been dropped.
        freqs = torch.cat([
            self.freqs[0][:f].view(f, -1),
            self.freqs[1][:f].view(f, -1),
            self.freqs[2][:f].view(f, -1),
        ], dim=-1).reshape(f, 1, -1).to(x.device)

        for block in self.blocks:
            x = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                x, context, t_mod, freqs,
            )
        x = self.head(x, t)
        x = self.unpatchify(x, (f, ))
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Invert patchify: fold (b, f, p*c) tokens back to (b, c, f*p)."""
        return rearrange(
            x, 'b f (p c) -> b c (f p)',
            f=grid_size[0],
            p=self.patch_size[0],
        )
|
|