# MMEdit / models/dit/mmdit_back.py
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# These are assumed to be the original layer imports.
from .mmdit_layers import compute_rope_rotations
from .mmdit_layers import TimestepEmbedder
from .mmdit_layers import MLP, ChannelLastConv1d, ConvMLP
from .mmdit_layers import (FinalBlock, MMDitSingleBlock, JointBlock_AT)
log = logging.getLogger()
@dataclass
class PreprocessedConditions:
    text_f: torch.Tensor    # per-token text features, (B, N_text, hidden_dim)
    text_f_c: torch.Tensor  # pooled global text condition, (B, hidden_dim)
class MMAudio(nn.Module):
"""
一个修改版的 MMAudio 接口尽量和LayerFusionAudioDiT一致。
"""
def __init__(self,
*,
latent_dim: int,
text_dim: int,
hidden_dim: int,
depth: int,
fused_depth: int,
num_heads: int,
mlp_ratio: float = 4.0,
latent_seq_len: int,
text_seq_len: int = 640,
                 # --- New parameters, aligned with LayerFusionAudioDiT ---
ta_context_dim: int,
ta_context_fusion: str = 'add', # 'add' or 'concat'
ta_context_norm: bool = False,
                 # --- Other original parameters ---
empty_string_feat: Optional[torch.Tensor] = None,
v2: bool = False) -> None:
super().__init__()
self.v2 = v2
self.latent_dim = latent_dim
self._latent_seq_len = latent_seq_len
self._text_seq_len = text_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
        # --- 1. Projection for time_aligned_context ---
        # A single projection is defined here rather than one per block: project
        # once, then pass the result to every layer. (This replaces the earlier
        # per-layer projection mentioned in the old comments.)
        self.ta_context_fusion = ta_context_fusion
        self.ta_context_norm_flag = ta_context_norm
        if self.ta_context_fusion == "add":
            # Additive fusion: project ta_context to the latent width (hidden_dim).
            self.ta_context_projection = nn.Linear(ta_context_dim, hidden_dim, bias=False)
            self.ta_context_norm = nn.LayerNorm(ta_context_dim) if self.ta_context_norm_flag else nn.Identity()
        elif self.ta_context_fusion == "concat":
            # Concat fusion: the context is concatenated with the latent before
            # audio_input_proj (see predict_flow), so no extra projection or norm
            # is needed here.
            self.ta_context_projection = nn.Identity()
            self.ta_context_norm = nn.Identity()
        else:
            raise ValueError(f"Unknown ta_context_fusion type: {ta_context_fusion}")
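        # Shape sketch of the two fusion modes (B = batch, N = latent_seq_len,
        # D = hidden_dim; channel-last layouts as used in predict_flow):
        #   'add':    norm/proj(ta_ctx): (B, N, ta_context_dim) -> (B, N, D),
        #             added to the embedded latent after audio_input_proj
        #   'concat': cat([latent, ta_ctx], dim=-1) -> (B, N, latent_dim + ta_context_dim),
        #             then projected to (B, N, D) by audio_input_proj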
        # --- Input projection (adapted for editing) ---
        # In 'concat' mode the input carries both the noisy latent and the
        # time-aligned context, so the input width depends on the fusion mode
        # (previously hard-coded as latent_dim * 2, which assumed
        # ta_context_dim == latent_dim).
        audio_in_dim = latent_dim + ta_context_dim if ta_context_fusion == 'concat' else latent_dim
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(audio_in_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )
self.text_input_proj = nn.Sequential(
nn.Linear(text_dim, hidden_dim),
MLP(hidden_dim, hidden_dim * 4),
)
self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
self.t_embed = TimestepEmbedder(hidden_dim, frequency_embedding_size=256, max_period=10000)
        # --- Transformer blocks (largely unchanged) ---
        # NOTE: JointBlock_AT and MMDitSingleBlock must accept the call
        # signatures used in predict_flow below.
self.joint_blocks = nn.ModuleList([
JointBlock_AT(hidden_dim, num_heads, mlp_ratio=mlp_ratio, pre_only=(i == depth - fused_depth - 1))
for i in range(depth - fused_depth)
])
self.fused_blocks = nn.ModuleList([
MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
for i in range(fused_depth)
])
        # --- Output layer (unchanged) ---
self.final_layer = FinalBlock(hidden_dim, latent_dim)
if empty_string_feat is None:
empty_string_feat = torch.zeros((text_seq_len, text_dim))
self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
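        # Frozen "empty text" feature; in MMAudio-style models this stands in for
        # the unconditional branch of classifier-free guidance (see the sketch in
        # __main__ below).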
self.initialize_weights()
self.initialize_rotations()
    def initialize_rotations(self):
        base_freq = 1.0
        # The only place the latent sequence length is needed.
        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            self.hidden_dim // self.num_heads,
                                            10000,
                                            freq_scaling=base_freq,
                                            device="cuda" if torch.cuda.is_available() else "cpu")
        # Registered as a non-persistent buffer: it follows model.to(...) but is
        # not saved in checkpoints, and can be re-registered on length changes.
        self.register_buffer('latent_rot', latent_rot, persistent=False)
    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        # clip/sync lengths are kept only for interface compatibility with the
        # original MMAudio; only the latent rotations are rebuilt here.
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
self.initialize_rotations()
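    # Hypothetical usage (values illustrative): retarget the RoPE table before
    # running on a longer clip:
    #   model.update_seq_lengths(latent_seq_len=1000, clip_seq_len=0, sync_seq_len=0)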
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in DiT blocks (compatibility with the original init):
for block in self.joint_blocks:
nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
for block in self.fused_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.conv.weight, 0)
nn.init.constant_(self.final_layer.conv.bias, 0)
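        # Note: zero-initializing the adaLN modulations and the final conv makes
        # every block start near an identity/zero mapping, the standard DiT init.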
    def preprocess_conditions(self, text_f: torch.Tensor) -> PreprocessedConditions:
        # Preprocess the text condition. Kept separate from predict_flow so it
        # can be computed once and reused across sampling steps.
        # assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
        # The external LLM embedding is taken as given and only projected here.
        text_f = self.text_input_proj(text_f)
        # Pooled global condition.
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))
        return PreprocessedConditions(text_f=text_f, text_f_c=text_f_c)
def predict_flow(self, x: torch.Tensor, timesteps: torch.Tensor,
conditions: PreprocessedConditions,
time_aligned_context: torch.Tensor) -> torch.Tensor:
"""
核心的预测流程,现在加入了 time_aligned_context。
"""
assert x.shape[2] == self._latent_seq_len, f'{x.shape=} {self._latent_seq_len=}'
# 1. 预处理各种输入
text_f = conditions.text_f
text_f_c = conditions.text_f_c
timesteps = timesteps.to(x.dtype) # 保持和输入张量同 dtype
global_c = self.global_cond_mlp(text_f_c) # (B, D)
# 2. 融合 timestep
global_c = self.t_embed(timesteps).unsqueeze(1) + global_c.unsqueeze(1) # (B, 1, D)
extended_c = global_c # 这个将作为 AdaLN 的条件
"""
这里决定了x的形状,需要debug
"""
# 3. **处理 time_aligned_context** 这里第一种方式是直接和latent进行融合,然后投影
# 从128->256
x = torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
latent = self.audio_input_proj(x) # (B, N, D)
        # 4. Run the transformer blocks.
        for block in self.joint_blocks:
            latent, text_f = block(latent, text_f, global_c, extended_c,
                                   self.latent_rot)
        for block in self.fused_blocks:
            latent = block(latent, extended_c, self.latent_rot)
        # 5. Output layer.
flow = self.final_layer(latent, global_c)
return flow
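    # A minimal flow-matching sampler sketch (hypothetical: step count, schedule,
    # and the Euler update are illustrative, not part of this file):
    #   cond = model.preprocess_conditions(text_features)  # computed once, reused
    #   x = torch.randn(B, latent_dim, N)
    #   for t in torch.linspace(0.0, 1.0, num_steps):
    #       v = model.predict_flow(x, t.expand(B), cond, ta_ctx)  # (B, N, latent_dim)
    #       x = x + dt * v.transpose(1, 2)                        # back to (B, latent_dim, N)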
def forward(self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
time_aligned_context: torch.Tensor,
                x_mask=None,       # accepted for interface compatibility; unused
                context_mask=None,
) -> torch.Tensor:
"""
模型主入口,接口已对齐 LayerFusionAudioDiT。
- x: 噪声 latent, shape (B, N_latent, latent_dim)
- timesteps: 时间步, shape (B,)
- context: 文本条件, shape (B, N_text, text_dim)
- time_aligned_context: 时间对齐的条件, shape (B, N_ta, ta_context_dim)
"""
if timesteps.dim() == 0:
timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
text_conditions = self.preprocess_conditions(context)
        # Run the core prediction pass, then transpose back to the input layout.
        flow = self.predict_flow(x, timesteps, text_conditions, time_aligned_context)
        flow = flow.transpose(1, 2)  # (B, N, latent_dim) -> (B, latent_dim, N)
return flow
@property
def latent_seq_len(self) -> int:
return self._latent_seq_len
# small_16k config: latent_dim=128, latent_seq_len=500.
def small_16k(**kwargs) -> MMAudio:
num_heads = 16
return MMAudio(latent_dim=128,
text_dim=1024,
hidden_dim=64 * num_heads,
depth=12,
fused_depth=8,
num_heads=num_heads,
latent_seq_len=500,
**kwargs)
if __name__ == '__main__':
batch_size = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = {
"ta_context_dim": 128,
"ta_context_fusion": "concat",
"ta_context_norm": False
}
try:
model = small_16k(**config).to(device)
        model.eval()  # evaluation mode for the smoke test
print("Model instantiated successfully!")
except Exception as e:
print(f"Error during model instantiation: {e}")
exit()
num_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f'Number of parameters: {num_params:.2f}M')
latent_dim = 128
latent_seq_len = 500
text_dim = 1024
text_seq_len = 640
ta_context_dim = config["ta_context_dim"]
    dummy_x = torch.randn(batch_size, latent_dim, latent_seq_len, device=device)  # (B, C, N)
dummy_timesteps = torch.randint(0, 1000, (batch_size,), device=device)
dummy_context = torch.randn(batch_size, text_seq_len, text_dim, device=device)
    # time_aligned_context must cover the same sequence positions as x so the two
    # can be concatenated along the feature dim (after x is moved to channel-last).
dummy_ta_context = torch.randn(batch_size, latent_seq_len, ta_context_dim, device=device)
print("\n--- Input Shapes ---")
print(f"x (latent): {dummy_x.shape}")
print(f"timesteps: {dummy_timesteps.shape}")
print(f"context (text): {dummy_context.shape}")
print(f"time_aligned_context: {dummy_ta_context.shape}")
print("--------------------\n")
    # 4. Run a forward pass.
    try:
        with torch.no_grad():  # no gradients needed for this smoke test
output = model(
x=dummy_x,
timesteps=dummy_timesteps,
context=dummy_context,
time_aligned_context=dummy_ta_context
)
print("✅ Forward pass successful!")
print(f"Output shape: {output.shape}")
            # 5. Check the output shape: forward() returns the same layout as x,
            # i.e. (B, latent_dim, N), not (B, N, latent_dim).
            expected_shape = (batch_size, latent_dim, latent_seq_len)
            assert output.shape == expected_shape, \
f"Output shape mismatch! Expected {expected_shape}, but got {output.shape}"
print("✅ Output shape is correct!")
except Exception as e:
print(f"❌ Error during forward pass: {e}")
import traceback
traceback.print_exc()
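
    # --- Optional: minimal classifier-free-guidance sketch (hypothetical) ---
    # `empty_string_feat` is registered in __init__ but unused in forward(); in
    # MMAudio-style models it provides the "empty text" condition for CFG.
    # The guidance scale below is illustrative, not taken from the original code.
    try:
        with torch.no_grad():
            uncond_context = model.empty_string_feat.unsqueeze(0).expand(batch_size, -1, -1)
            cond_flow = model(dummy_x, dummy_timesteps, dummy_context, dummy_ta_context)
            uncond_flow = model(dummy_x, dummy_timesteps, uncond_context, dummy_ta_context)
            guidance_scale = 4.5  # hypothetical value
            guided_flow = uncond_flow + guidance_scale * (cond_flow - uncond_flow)
            print(f"CFG sketch OK, guided flow shape: {guided_flow.shape}")
    except Exception as e:
        print(f"CFG sketch failed: {e}")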