|
|
from typing import Optional |
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
from einops import rearrange |
|
|
from torch import Tensor |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from einops.layers.torch import Rearrange |
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from .modules import RMSNorm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_rope_rotations(length: int,
                           dim: int,
                           theta: int,
                           *,
                           freq_scaling: float = 1.0,
                           device: Union[torch.device, str] = 'cpu') -> Tensor:
    """Precompute 2x2 RoPE rotation matrices for every (position, frequency) pair.

    Args:
        length: number of sequence positions.
        dim: per-head feature dimension; must be even (pairs of channels rotate).
        theta: RoPE base for the frequency schedule.
        freq_scaling: multiplier applied to all frequencies (e.g. for NTK scaling).
        device: device on which the table is built.

    Returns:
        Float32 tensor of shape (1, length, dim // 2, 2, 2) holding
        [[cos, -sin], [sin, cos]] for each position/frequency.
    """
    assert dim % 2 == 0

    # Always build the table in float32, regardless of any surrounding autocast.
    with torch.amp.autocast(device_type='cuda', enabled=False):
        positions = torch.arange(length, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        inv_freq = inv_freq * freq_scaling

        # Outer product: angle for every (position, frequency) pair.
        angles = torch.outer(positions, inv_freq)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # Pack the four matrix entries, then reshape into 2x2 blocks with a
        # leading broadcast dimension: (length, dim/2, 4) -> (1, length, dim/2, 2, 2).
        table = torch.stack([cos, -sin, sin, cos], dim=-1)
        return table.view(length, dim // 2, 2, 2).unsqueeze(0)
|
|
|
|
|
|
|
|
def apply_rope(x: Tensor, rot: Tensor) -> Tensor:
    """Apply precomputed RoPE rotation matrices to ``x``.

    Args:
        x: features whose last dimension is split into consecutive channel
           pairs; each pair is rotated. Any leading batch/head dims broadcast
           against ``rot``.
        rot: rotation table of shape (1, length, dim // 2, 2, 2) as produced
             by ``compute_rope_rotations``.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Note:
        The original return annotation was ``tuple[Tensor, Tensor]``; the
        function returns a single tensor, so the annotation is corrected here.
    """
    # Rotate in float32 for numerical stability, then cast back.
    with torch.amp.autocast(device_type='cuda', enabled=False):
        _x = x.float()
        # Pair up the last dimension: (..., dim) -> (..., dim/2, 1, 2).
        _x = _x.view(*_x.shape[:-1], -1, 1, 2)
        # 2x2 matrix-vector product, broadcast over leading dims:
        # out[i] = rot[i, 0] * x[0] + rot[i, 1] * x[1].
        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
        return x_out.reshape(*x.shape).to(dtype=x.dtype)
|
|
|
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations.

    A fixed bank of sinusoidal frequencies turns each scalar timestep into a
    cos/sin feature vector, which a small SiLU MLP then projects to ``dim``.
    """

    def __init__(self, dim, frequency_embedding_size, max_period):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.dim = dim
        self.max_period = max_period
        assert dim % 2 == 0, 'dim must be even.'

        # Frequencies are fixed at construction time; force float32 math.
        with torch.autocast('cuda', enabled=False):
            half_indices = torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32)
            initial_freqs = 1.0 / (10000**(half_indices / frequency_embedding_size))
            # Rescale so the slowest oscillation corresponds to `max_period`.
            freqs_tensor = (10000 / max_period) * initial_freqs

        # Non-persistent: recomputed on load rather than stored in checkpoints.
        self.register_buffer('freqs', freqs_tensor, persistent=False)

    def timestep_embedding(self, t):
        """Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :return: an (N, frequency_embedding_size) Tensor of cos/sin features.
        """
        angles = t[:, None].float() * self.freqs[None]
        return torch.cat([angles.cos(), angles.sin()], dim=-1)

    def forward(self, t):
        features = self.timestep_embedding(t).to(t.dtype)
        return self.mlp(features)
|
|
|
|
|
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` operating on channel-last input: (B, T, C) in, (B, T, C_out) out."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swap to channel-first (B, C, T) for the convolution, then swap back.
        y = super().forward(x.transpose(1, 2))
        return y.transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Initialize the feed-forward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Requested hidden dimension; shrunk to 2/3 of the
                nominal value (SwiGLU convention) and rounded up to a multiple
                of ``multiple_of``.
            multiple_of (int): Value to ensure the hidden dimension is a
                multiple of this value.

        Attributes:
            w1 (nn.Linear): gate projection, dim -> hidden.
            w2 (nn.Linear): output projection, hidden -> dim.
            w3 (nn.Linear): value projection, dim -> hidden.
        """
        super().__init__()
        # SwiGLU uses 2/3 of the nominal hidden size, rounded up for alignment.
        adjusted = int(2 * hidden_dim / 3)
        adjusted = multiple_of * ((adjusted + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, adjusted, bias=False)
        self.w2 = nn.Linear(adjusted, dim, bias=False)
        self.w3 = nn.Linear(dim, adjusted, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
|
|
|
|
class ConvMLP(nn.Module):
    """SwiGLU feed-forward built from channel-last 1-D convolutions:
    ``w2(silu(w1(x)) * w3(x))`` over (B, T, C) input.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Initialize the convolutional feed-forward module.

        Args:
            dim (int): Input (and output) channel dimension.
            hidden_dim (int): Requested hidden dimension; shrunk to 2/3 of the
                nominal value (SwiGLU convention) and rounded up to a multiple
                of ``multiple_of``.
            multiple_of (int): Value to ensure the hidden dimension is a
                multiple of this value.
            kernel_size (int): Convolution kernel size along the time axis.
            padding (int): Convolution padding.

        Attributes:
            w1 (ChannelLastConv1d): gate convolution, dim -> hidden.
            w2 (ChannelLastConv1d): output convolution, hidden -> dim.
            w3 (ChannelLastConv1d): value convolution, dim -> hidden.
        """
        super().__init__()
        # SwiGLU uses 2/3 of the nominal hidden size, rounded up for alignment.
        adjusted = int(2 * hidden_dim / 3)
        adjusted = multiple_of * ((adjusted + multiple_of - 1) // multiple_of)

        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)
        self.w1 = ChannelLastConv1d(dim, adjusted, **conv_kwargs)
        self.w2 = ChannelLastConv1d(adjusted, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, adjusted, **conv_kwargs)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
|
|
|
|
|
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """adaLN modulation: scale ``x`` by ``1 + scale`` and add ``shift`` (broadcast)."""
    return shift + x * (scale + 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Scaled-dot-product attention over (B, H, N, D) tensors.

    Returns the attended values with heads merged: (B, N, H * D).
    """
    # SDPA can require contiguous inputs for its fused kernels.
    out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
    b, h, n, d = out.shape
    # Merge heads: (B, H, N, D) -> (B, N, H*D).
    return out.transpose(1, 2).reshape(b, n, h * d).contiguous()
|
|
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention with per-head RMSNorm on q/k and optional RoPE.

    ``pre_attention`` exposes the (q, k, v) heads so callers can run joint
    attention across streams; ``forward`` runs attention over this stream only.
    """

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.dim = dim
        self.nheads = nheads

        # Fused QKV projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        # Per-head normalization of queries and keys (head dim = dim // nheads).
        self.q_norm = RMSNorm(dim // nheads)
        self.k_norm = RMSNorm(dim // nheads)

        # (B, N, 3*dim) -> (B, H, N, head_dim, 3); last axis indexes q/k/v.
        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to normalized (q, k, v) heads; applies RoPE when ``rot``
        is given.

        BUG FIX: ``rot`` now defaults to None. Previously it was a required
        positional argument, so ``forward`` (which called ``pre_attention(x)``)
        raised TypeError at runtime.
        """
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)

        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        rot: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention over ``x``; ``rot`` (new, optional) enables RoPE.

        Backward compatible: existing ``forward(x)`` callers keep the old
        (no-RoPE) behavior.
        """
        q, k, v = self.pre_attention(x, rot)
        out = attention(q, k, v)
        return out
|
|
|
|
|
|
|
|
class MMDitSingleBlock(nn.Module):
    """One MM-DiT stream block: adaLN-modulated self-attention plus a
    (conv-)MLP, split into ``pre_attention`` / ``post_attention`` halves so a
    joint-attention caller can run the attention step itself across
    concatenated streams.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        """
        Args:
            dim: feature dimension of this stream.
            nhead: number of attention heads.
            mlp_ratio: nominal FFN expansion factor.
            pre_only: if True, the block only produces (q, k, v) and its
                post-attention path is a no-op (no linear1/norm2/ffn created).
            kernel_size: 1 selects pointwise (Linear/MLP) layers; any other
                value selects channel-last Conv1d layers.
            padding: conv padding (unused when kernel_size == 1).
        """
        super().__init__()
        # adaLN supplies shift/scale externally, so norms have no learned affine.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)

        self.pre_only = pre_only
        if pre_only:
            # Only (shift, scale) for the attention input are needed.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        else:
            if kernel_size == 1:
                self.linear1 = nn.Linear(dim, dim)
            else:
                self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

            if kernel_size == 1:
                self.ffn = MLP(dim, int(dim * mlp_ratio))
            else:
                self.ffn = ConvMLP(dim,
                                   int(dim * mlp_ratio),
                                   kernel_size=kernel_size,
                                   padding=padding)

            # (shift, scale, gate) for both the attention and FFN sub-blocks.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate ``x`` with conditioning ``c`` and compute (q, k, v).

        Returns ``((q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp))``;
        the second tuple is all-None when ``pre_only``.
        """
        # NOTE(review): c is assumed broadcastable against x with trailing
        # dim `dim` (per-sequence or per-token conditioning) -- confirm callers.
        modulation = self.adaLN_modulation(c)
        if self.pre_only:
            (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
            gate_msa = shift_mlp = scale_mlp = gate_mlp = None
        else:
            # Chunk order must match the unpacking in post_attention.
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = modulation.chunk(6, dim=-1)

        x = modulate(self.norm1(x), shift_msa, scale_msa)
        q, k, v = self.attn.pre_attention(x, rot)
        return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor,
                       c: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
        """Apply the gated attention residual, then the gated FFN residual.

        ``c`` is the modulation tuple returned by ``pre_attention``. A
        ``pre_only`` block passes ``x`` through unchanged.
        """
        if self.pre_only:
            return x

        (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
        # Gated residual: project the joint-attention output back into x.
        x = x + self.linear1(attn_out) * gate_msa
        r = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.ffn(r) * gate_mlp

        return x

    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        """Full single-stream block: pre-attention, attention over this stream
        only, then post-attention."""
        x_qkv, x_conditions = self.pre_attention(x, cond, rot)
        attn_out = attention(*x_qkv)
        x = self.post_attention(x, attn_out, x_conditions)

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JointBlock_AT(nn.Module):
    """Audio + text joint attention block (the clip branch is removed).

    Runs pre-attention separately on the latent (audio) and text streams,
    concatenates their q/k/v along the sequence axis for a single joint
    attention pass, then routes each stream through its own post-attention
    path. Returns ``(latent, text_f)``.
    """

    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        """
        Args:
            dim: shared feature dimension of both streams.
            nhead: number of attention heads.
            mlp_ratio: FFN expansion factor passed to both sub-blocks.
            pre_only: if True, the text block has no post-attention path and
                ``text_f`` is returned unchanged.
        """
        super().__init__()
        self.pre_only = pre_only
        # Latent branch always keeps its full post-attention path; it uses
        # conv layers (kernel_size=3) over the time axis.
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)

        # Text branch is pointwise (kernel_size=1) and may be pre-only.
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def forward(self, latent: torch.Tensor, text_f: torch.Tensor,
                global_c: torch.Tensor, extended_c: torch.Tensor,
                latent_rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Joint attention over the concatenated latent+text sequence.

        Args:
            latent: audio latent stream, sequence along dim 1.
            text_f: text feature stream, sequence along dim 1.
            global_c: conditioning for the text block.
            extended_c: conditioning for the latent block.
            latent_rot: RoPE rotations for the latent stream (text gets none).
        """
        # Latent stream: RoPE + extended conditioning.
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        # Text stream: global conditioning, no rotary embedding.
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        latent_len = latent.shape[1]
        # (Removed unused local `text_len`; the split below only needs latent_len.)

        # Concatenate q, k, v along the sequence axis (dim 2 of (B, H, N, D)).
        joint_qkv = [torch.cat([x_qkv[i], t_qkv[i]], dim=2) for i in range(3)]

        attn_out = attention(*joint_qkv)
        # attention() merges heads, so the split is on dim 1 of (B, N, H*D).
        x_attn_out = attn_out[:, :latent_len]
        t_attn_out = attn_out[:, latent_len:]

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)

        return latent, text_f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FinalBlock(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a channel-last conv
    projecting ``dim`` -> ``out_dim``.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        # Conditioning produces the (shift, scale) pair for the final norm.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm(latent), shift, scale)
        return self.conv(modulated)