|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from .mmdit_layers import compute_rope_rotations |
|
|
from .mmdit_layers import TimestepEmbedder |
|
|
from .mmdit_layers import MLP, ChannelLastConv1d, ConvMLP |
|
|
from .mmdit_layers import (FinalBlock, MMDitSingleBlock, JointBlock_AT) |
|
|
|
|
|
# Module-level logger. Using __name__ (instead of the root logger) lets the
# application configure/filter this module's output independently.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class PreprocessedConditions:
    """Pre-projected text conditioning produced by MMAudio.preprocess_conditions()."""

    text_f: torch.Tensor    # per-token text features after text_input_proj, (B, N_text, hidden_dim)
    text_f_c: torch.Tensor  # pooled (token-mean) global text vector after text_cond_proj, (B, hidden_dim)
|
|
|
|
|
|
|
|
class MMAudio(nn.Module):
    """A modified MMAudio whose interface mirrors LayerFusionAudioDiT as
    closely as possible.

    The model predicts a flow for an audio latent sequence, conditioned on
    (a) a text sequence, pooled into a global condition, and
    (b) a time-aligned context fused into the latent stream, either by
        channel concatenation ('concat') or by projected addition ('add').
    """

    def __init__(self,
                 *,
                 latent_dim: int,
                 text_dim: int,
                 hidden_dim: int,
                 depth: int,
                 fused_depth: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 latent_seq_len: int,
                 text_seq_len: int = 640,
                 ta_context_dim: int,
                 ta_context_fusion: str = 'add',
                 ta_context_norm: bool = False,
                 empty_string_feat: Optional[torch.Tensor] = None,
                 v2: bool = False) -> None:
        """
        Args:
            latent_dim: channel dim of the audio latent.
            text_dim: channel dim of the raw text features.
            hidden_dim: transformer width.
            depth: total number of transformer blocks.
            fused_depth: number of trailing single-stream blocks; the first
                ``depth - fused_depth`` blocks are joint audio/text blocks.
            num_heads: number of attention heads.
            mlp_ratio: MLP expansion ratio inside the blocks.
            latent_seq_len: default latent sequence length (sizes the RoPE table).
            text_seq_len: text sequence length (sizes empty_string_feat).
            ta_context_dim: channel dim of the time-aligned context.
            ta_context_fusion: 'add' or 'concat' (see class docstring).
            ta_context_norm: if True, LayerNorm the context before the 'add'
                projection.
            empty_string_feat: optional (text_seq_len, text_dim) features for
                the empty prompt; zeros when omitted.
            v2: stored flag; not used by this module.

        Raises:
            ValueError: if ``ta_context_fusion`` is not 'add' or 'concat'.
        """
        super().__init__()

        self.v2 = v2
        self.latent_dim = latent_dim
        self._latent_seq_len = latent_seq_len
        self._text_seq_len = text_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.ta_context_fusion = ta_context_fusion
        self.ta_context_norm_flag = ta_context_norm

        if self.ta_context_fusion == "add":
            # 'add': (optionally) normalize, project the context to hidden_dim,
            # and sum it with the projected latent in predict_flow().  The
            # audio projection therefore only sees the raw latent channels.
            self.ta_context_projection = nn.Linear(ta_context_dim, hidden_dim, bias=False)
            self.ta_context_norm = nn.LayerNorm(ta_context_dim) if self.ta_context_norm_flag else nn.Identity()
            audio_in_dim = latent_dim
        elif self.ta_context_fusion == "concat":
            # 'concat': the context is concatenated channel-wise with the
            # latent before the audio projection.  (Generalizes the previous
            # hard-coded ``latent_dim * 2``, which silently assumed
            # ta_context_dim == latent_dim; identical when they are equal.)
            self.ta_context_projection = nn.Identity()
            self.ta_context_norm = nn.Identity()
            audio_in_dim = latent_dim + ta_context_dim
        else:
            raise ValueError(f"Unknown ta_context_fusion type: {ta_context_fusion}")

        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(audio_in_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )
        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )

        # Pooled text -> global conditioning vector.
        self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)

        self.t_embed = TimestepEmbedder(hidden_dim, frequency_embedding_size=256, max_period=10000)

        # Joint (audio+text) blocks followed by single-stream fused blocks.
        # The last joint block is 'pre_only' since its text output is unused.
        self.joint_blocks = nn.ModuleList([
            JointBlock_AT(hidden_dim, num_heads, mlp_ratio=mlp_ratio, pre_only=(i == depth - fused_depth - 1))
            for i in range(depth - fused_depth)
        ])
        self.fused_blocks = nn.ModuleList([
            MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for i in range(fused_depth)
        ])

        self.final_layer = FinalBlock(hidden_dim, latent_dim)

        if empty_string_feat is None:
            empty_string_feat = torch.zeros((text_seq_len, text_dim))
        # Kept as a frozen Parameter (not a buffer) so existing checkpoints
        # continue to load it.
        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)

        self.initialize_weights()
        self.initialize_rotations()

    def initialize_rotations(self) -> None:
        """(Re)build the RoPE rotation table for the current latent length."""
        base_freq = 1.0

        # Build the table on the parameters' device so a call from
        # update_seq_lengths() after .to(device) stays consistent.
        # (Previously this always picked CUDA when available, which could
        # mismatch a CPU-resident model.)
        try:
            device = next(self.parameters()).device
        except StopIteration:
            device = torch.device('cpu')

        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            self.hidden_dim // self.num_heads,
                                            10000,
                                            freq_scaling=base_freq,
                                            device=device)

        # Non-persistent: recomputed on demand, never stored in checkpoints.
        self.register_buffer('latent_rot', latent_rot, persistent=False)

    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        """Update sequence lengths and rebuild the RoPE table.

        ``clip_seq_len`` / ``sync_seq_len`` are stored for interface
        compatibility with the video-conditioned variant but are unused here.
        """
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self.initialize_rotations()

    def initialize_weights(self) -> None:
        """Xavier init for every Linear, then DiT-style overrides."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Timestep-embedding MLP gets a small normal init.
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)

        # Zero-init all adaLN modulations so each block starts near identity.
        for block in self.joint_blocks:
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
        for block in self.fused_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-init the output head so the initial predicted flow is zero.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.conv.weight, 0)
        nn.init.constant_(self.final_layer.conv.bias, 0)

    def preprocess_conditions(self, text_f: torch.Tensor) -> PreprocessedConditions:
        """Project text features and pool them into a global condition.

        Args:
            text_f: raw text features, (B, N_text, text_dim).

        Returns:
            PreprocessedConditions with per-token features
            (B, N_text, hidden_dim) and a pooled vector (B, hidden_dim).
        """
        text_f = self.text_input_proj(text_f)
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))
        return PreprocessedConditions(text_f=text_f, text_f_c=text_f_c)

    def predict_flow(self, x: torch.Tensor, timesteps: torch.Tensor,
                     conditions: PreprocessedConditions,
                     time_aligned_context: torch.Tensor) -> torch.Tensor:
        """Core prediction pass with the time-aligned context fused in.

        Args:
            x: noisy latent, channel-first (B, latent_dim, N_latent).
            timesteps: (B,) timesteps.
            conditions: output of preprocess_conditions().
            time_aligned_context: (B, N_latent, ta_context_dim).

        Returns:
            Predicted flow in the model's channel-last layout.
        """
        assert x.shape[2] == self._latent_seq_len, f'{x.shape=} {self._latent_seq_len=}'

        text_f = conditions.text_f
        text_f_c = conditions.text_f_c

        timesteps = timesteps.to(x.dtype)

        # Global conditioning: timestep embedding + pooled text feature,
        # broadcast over the sequence dimension.
        global_c = self.global_cond_mlp(text_f_c)
        global_c = self.t_embed(timesteps).unsqueeze(1) + global_c.unsqueeze(1)
        extended_c = global_c

        # Fuse the time-aligned context into the (channel-last) latent.
        # This fixes the previous behavior where the 'add' branch also
        # concatenated and its projection/norm modules were never applied.
        if self.ta_context_fusion == "concat":
            x = torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
            latent = self.audio_input_proj(x)
        else:  # "add" (validated in __init__)
            latent = self.audio_input_proj(x.transpose(1, 2))
            latent = latent + self.ta_context_projection(self.ta_context_norm(time_aligned_context))

        for block in self.joint_blocks:
            latent, text_f = block(latent, text_f, global_c, extended_c,
                                   self.latent_rot)

        for block in self.fused_blocks:
            latent = block(latent, extended_c, self.latent_rot)

        flow = self.final_layer(latent, global_c)
        return flow

    def forward(self,
                x: torch.Tensor,
                timesteps: torch.Tensor,
                context: torch.Tensor,
                time_aligned_context: torch.Tensor,
                x_mask=None,
                context_mask=None,
                ) -> torch.Tensor:
        """Main entry point; interface aligned with LayerFusionAudioDiT.

        Args:
            x: noisy latent, (B, latent_dim, N_latent) — channel-first, as
               enforced by the assert in predict_flow().
            timesteps: timesteps, shape (B,), or a scalar that is broadcast.
            context: text condition, (B, N_text, text_dim).
            time_aligned_context: time-aligned condition,
                (B, N_latent, ta_context_dim).
            x_mask, context_mask: accepted for interface compatibility;
                currently unused.

        Returns:
            Predicted flow, transposed back from the model's channel-last
            layout.
        """
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)

        text_conditions = self.preprocess_conditions(context)

        flow = self.predict_flow(x, timesteps, text_conditions, time_aligned_context)

        # Swap back from the internal channel-last layout.
        flow = flow.transpose(1, 2)

        return flow

    @property
    def latent_seq_len(self) -> int:
        return self._latent_seq_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def small_16k(**kwargs) -> MMAudio:
    """Build the 'small' 16 kHz MMAudio configuration.

    Extra keyword arguments (e.g. ta_context_dim, ta_context_fusion) are
    forwarded verbatim to the MMAudio constructor.
    """
    heads = 16
    base_config = dict(
        latent_dim=128,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=12,
        fused_depth=8,
        num_heads=heads,
        latent_seq_len=500,
    )
    return MMAudio(**base_config, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the small 16 kHz model and run one forward pass.

    batch_size = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Time-aligned context settings; with 'concat' the context channels are
    # concatenated onto the latent channels before the audio projection.
    config = {
        "ta_context_dim": 128,
        "ta_context_fusion": "concat",
        "ta_context_norm": False
    }

    try:
        model = small_16k(**config).to(device)
        model.eval()
        print("Model instantiated successfully!")
    except Exception as e:
        print(f"Error during model instantiation: {e}")
        # raise SystemExit instead of bare exit(): exit() is injected by the
        # `site` module and is not guaranteed to exist (e.g. under python -S).
        raise SystemExit(1)

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f'Number of parameters: {num_params:.2f}M')

    # Dummy inputs matching the small_16k configuration above.
    latent_dim = 128
    latent_seq_len = 500
    text_dim = 1024
    text_seq_len = 640
    ta_context_dim = config["ta_context_dim"]

    # Note: x is channel-first (B, latent_dim, N_latent).
    dummy_x = torch.randn(batch_size, latent_dim, latent_seq_len, device=device)
    dummy_timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    dummy_context = torch.randn(batch_size, text_seq_len, text_dim, device=device)
    dummy_ta_context = torch.randn(batch_size, latent_seq_len, ta_context_dim, device=device)

    print("\n--- Input Shapes ---")
    print(f"x (latent): {dummy_x.shape}")
    print(f"timesteps: {dummy_timesteps.shape}")
    print(f"context (text): {dummy_context.shape}")
    print(f"time_aligned_context: {dummy_ta_context.shape}")
    print("--------------------\n")

    try:
        with torch.no_grad():
            output = model(
                x=dummy_x,
                timesteps=dummy_timesteps,
                context=dummy_context,
                time_aligned_context=dummy_ta_context
            )
        print("✅ Forward pass successful!")
        print(f"Output shape: {output.shape}")

        expected_shape = (batch_size, latent_seq_len, latent_dim)
        assert output.shape == expected_shape, \
            f"Output shape mismatch! Expected {expected_shape}, but got {output.shape}"
        print("✅ Output shape is correct!")

    except Exception as e:
        print(f"❌ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()