File size: 2,247 Bytes
bc8c4af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
    """Return the rotary frequency table split into chunks along its last axis.

    The table itself is produced by ``precompute_freqs_cis(dim, end, theta)``;
    this wrapper only splits it.

    Args:
        dim: Forwarded unchanged to ``precompute_freqs_cis``.
        end: Forwarded unchanged to ``precompute_freqs_cis``.
        theta: Forwarded unchanged to ``precompute_freqs_cis``.

    Returns:
        A tuple of tensors obtained by asking ``torch.chunk`` for three pieces
        along the last dimension. ``torch.chunk`` may hand back fewer than
        three pieces when that dimension is too small to split three ways.
    """
    table = precompute_freqs_cis(dim, end, theta)
    return torch.chunk(table, 3, dim=-1)
class MovaAudioDit(WanModel):
    """``WanModel`` variant that works on a 1-D token sequence.

    Differences from the base class that are visible in this class:

    * the patch embedding is an ``nn.Conv1d`` with kernel size 1 and stride 1;
    * the rotary table ``self.freqs`` comes from ``precompute_freqs_cis_1d``;
    * ``unpatchify`` folds tokens back into a ``(b, c, f * p)`` tensor.
    """

    def __init__(self, *args, **kwargs):
        """Build the base model, then replace the rotary table and patch embedding.

        Only keyword arguments are inspected here: ``dim`` (fallback 1536),
        ``num_heads`` (fallback 12) and ``in_dim`` (fallback 128).

        NOTE(review): values passed positionally reach ``WanModel`` through
        ``*args`` but are invisible to the ``kwargs.get`` lookups below, which
        then use the fallbacks -- callers should pass these by keyword.
        """
        super().__init__(*args, **kwargs)
        dim = kwargs.get("dim", 1536)
        head_dim = dim // kwargs.get("num_heads", 12)
        self.freqs = precompute_freqs_cis_1d(head_dim)
        # Plain ints are the documented form; Conv1d treats them like [1].
        self.patch_embedding = nn.Conv1d(
            kwargs.get("in_dim", 128), dim, kernel_size=1, stride=1
        )

    def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
        """Compute a 1-D rotary table and store it on ``self.f_freqs_cis``.

        NOTE(review): ``forward`` reads ``self.freqs`` (set in ``__init__``),
        not ``self.f_freqs_cis``, so calling this method does not change the
        table ``forward`` uses -- confirm whether that is intended.
        """
        self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **kwargs,
    ):
        """Run the model on ``x`` conditioned on ``timestep`` and ``context``.

        Args:
            x: Input handed to ``self.patchify``.
            timestep: Passed through ``sinusoidal_embedding_1d`` and
                ``self.time_embedding``.
            context: Passed through ``self.text_embedding`` before the blocks.
            use_gradient_checkpointing: Forwarded to
                ``gradient_checkpoint_forward`` for every block.
            use_gradient_checkpointing_offload: Forwarded to
                ``gradient_checkpoint_forward`` for every block.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``self.head`` rearranged by ``unpatchify``.

        Raises:
            ValueError: If the number of tokens exceeds the length of the
                precomputed rotary table.
        """
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        x, (f,) = self.patchify(x)

        parts = (self.freqs[0], self.freqs[1], self.freqs[2])
        table_len = min(part.shape[0] for part in parts)
        if f > table_len:
            # Slicing past the end would silently truncate and then either
            # fail with an opaque shape error or yield mis-shaped frequencies.
            raise ValueError(
                f"sequence length {f} exceeds the precomputed rotary table "
                f"length {table_len}"
            )
        freqs = (
            torch.cat([part[:f].reshape(f, -1) for part in parts], dim=-1)
            .reshape(f, 1, -1)
            .to(x.device)
        )

        for block in self.blocks:
            x = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                x, context, t_mod, freqs,
            )

        x = self.head(x, t)
        x = self.unpatchify(x, (f,))
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: tuple):
        """Rearrange ``x`` from ``b f (p c)`` to ``b c (f p)``.

        Args:
            x: Tensor laid out as ``b f (p c)``.
            grid_size: Sequence whose first element is the token count ``f``
                (``forward`` passes the 1-tuple ``(f,)``).

        Returns:
            The rearranged tensor, with ``p`` taken from ``self.patch_size[0]``.
        """
        return rearrange(
            x, 'b f (p c) -> b c (f p)',
            f=grid_size[0],
            p=self.patch_size[0]
        )
|