File size: 13,089 Bytes
c14d03d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume these are the original project-local imports.
from .mmdit_layers import compute_rope_rotations
from .mmdit_layers import TimestepEmbedder
from .mmdit_layers import MLP, ChannelLastConv1d, ConvMLP
from .mmdit_layers import (FinalBlock, MMDitSingleBlock, JointBlock_AT)
# Module-level logger named after this module. It is a child of the root logger,
# so root-level logging configuration still applies through propagation.
log = logging.getLogger(__name__)
@dataclass()
class PreprocessedConditions:
    """Text conditioning produced by ``MMAudio.preprocess_conditions``."""

    # Projected text token features; fed (and updated) through the joint blocks.
    text_f: torch.Tensor
    # Projected mean-over-sequence text feature, used to build the global condition.
    text_f_c: torch.Tensor
class MMAudio(nn.Module):
    """MMAudio-style flow-prediction DiT whose interface mirrors LayerFusionAudioDiT.

    The noisy latent ``x`` is concatenated channel-wise with a time-aligned
    context and projected to ``hidden_dim``. It then passes through joint
    audio/text blocks followed by fused single-stream blocks. Text features
    provide a token sequence for the joint blocks and a pooled condition.
    Together with the timestep embedding, the pooled condition forms the
    block/final-layer condition.
    """

    def __init__(self,
                 *,
                 latent_dim: int,
                 text_dim: int,
                 hidden_dim: int,
                 depth: int,
                 fused_depth: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 latent_seq_len: int,
                 text_seq_len: int = 640,
                 # --- parameters aligned with LayerFusionAudioDiT ---
                 ta_context_dim: int,
                 ta_context_fusion: str = 'add',  # 'add' or 'concat'
                 ta_context_norm: bool = False,
                 # --- original MMAudio parameters ---
                 empty_string_feat: Optional[torch.Tensor] = None,
                 v2: bool = False) -> None:
        """Build the model.

        Args:
            latent_dim: Channel count of the noisy latent ``x`` and of the predicted flow.
            text_dim: Feature size of the incoming text embeddings.
            hidden_dim: Transformer width.
            depth: Total number of transformer blocks (joint + fused).
            fused_depth: How many of the ``depth`` blocks are fused single-stream blocks.
            num_heads: Attention heads; ``hidden_dim // num_heads`` is the RoPE head size.
            mlp_ratio: MLP expansion ratio passed to every transformer block.
            latent_seq_len: Number of latent frames; ``predict_flow`` checks
                ``x.shape[2]`` against it.
            text_seq_len: Length of the default (all-zero) empty-string text feature.
            ta_context_dim: Channel count of ``time_aligned_context``. It is
                concatenated to the latent channels, so the input projection
                takes ``latent_dim + ta_context_dim`` channels (``latent_dim * 2``
                when both are equal).
            ta_context_fusion: ``'add'`` or ``'concat'``; anything else raises ValueError.
            ta_context_norm: Whether to create a LayerNorm for the context (``'add'`` only).
            empty_string_feat: Optional empty-prompt text feature, expected to be
                ``(text_seq_len, text_dim)``. Defaults to zeros of that shape and is
                stored as a frozen parameter.
            v2: Stored as ``self.v2``; not read elsewhere in this class.

        Raises:
            ValueError: If ``ta_context_fusion`` is not ``'add'`` or ``'concat'``.
        """
        super().__init__()
        self.v2 = v2
        self.latent_dim = latent_dim
        self._latent_seq_len = latent_seq_len
        self._text_seq_len = text_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Optional projection/normalization for the time-aligned context.
        # NOTE(review): predict_flow never calls these modules; it always
        # concatenates the raw context. They are kept so the parameter set
        # (and thus the state_dict keys) stays unchanged.
        self.ta_context_fusion = ta_context_fusion
        self.ta_context_norm_flag = ta_context_norm
        if self.ta_context_fusion == "add":
            self.ta_context_projection = nn.Linear(ta_context_dim, hidden_dim, bias=False)
            self.ta_context_norm = (nn.LayerNorm(ta_context_dim)
                                    if self.ta_context_norm_flag else nn.Identity())
        elif self.ta_context_fusion == "concat":
            self.ta_context_projection = nn.Identity()
            self.ta_context_norm = nn.Identity()
        else:
            raise ValueError(f"Unknown ta_context_fusion type: {ta_context_fusion}")

        # Editing input: latent and time-aligned context are concatenated along
        # the channel axis before projection to hidden_dim.
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(latent_dim + ta_context_dim, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )
        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )
        self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
        self.t_embed = TimestepEmbedder(hidden_dim, frequency_embedding_size=256, max_period=10000)

        # Joint audio/text blocks run first; the last joint block is built with
        # pre_only=True. The time-aligned context is fused at the input projection,
        # so the blocks do not receive it directly.
        self.joint_blocks = nn.ModuleList([
            JointBlock_AT(hidden_dim, num_heads, mlp_ratio=mlp_ratio,
                          pre_only=(i == depth - fused_depth - 1))
            for i in range(depth - fused_depth)
        ])
        self.fused_blocks = nn.ModuleList([
            MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for _ in range(fused_depth)
        ])
        self.final_layer = FinalBlock(hidden_dim, latent_dim)

        if empty_string_feat is None:
            empty_string_feat = torch.zeros((text_seq_len, text_dim))
        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)

        self.initialize_weights()
        self.initialize_rotations()

    def _rotation_device(self) -> torch.device:
        """Device for the RoPE table: the existing buffer's, else the first parameter's."""
        current = getattr(self, 'latent_rot', None)
        if current is not None:
            return current.device
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def initialize_rotations(self) -> None:
        """(Re)compute the RoPE table for ``self._latent_seq_len`` latent positions.

        The table is stored as the non-persistent buffer ``latent_rot`` (excluded
        from the state_dict). It is built on the device the module currently
        lives on, so a model kept on CPU does not get a CUDA buffer just because
        CUDA is available.
        """
        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            self.hidden_dim // self.num_heads,
                                            10000,
                                            freq_scaling=1.0,
                                            device=self._rotation_device())
        # Registering an already-registered buffer name replaces its value.
        self.register_buffer('latent_rot', latent_rot, persistent=False)

    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        """Change the expected latent length and rebuild the RoPE table.

        ``clip_seq_len`` and ``sync_seq_len`` are stored but not read elsewhere in
        this class. Presumably they are kept for signature compatibility with
        upstream MMAudio.
        """
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self.initialize_rotations()

    def initialize_weights(self) -> None:
        """Initialize weights: Xavier for Linear layers, zero for modulation/output.

        All ``nn.Linear`` weights get Xavier-uniform init with zero biases. The
        timestep MLP uses N(0, 0.02). Every adaLN modulation output layer and the
        final conv are zeroed.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Timestep embedding MLP.
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in the transformer blocks.
        for block in self.joint_blocks:
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
        for block in self.fused_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out the output layer.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.conv.weight, 0)
        nn.init.constant_(self.final_layer.conv.bias, 0)

    def preprocess_conditions(self, text_f: torch.Tensor) -> PreprocessedConditions:
        """Project text features and derive the pooled text condition.

        Args:
            text_f: External text embeddings, assumed ``(B, N_text, text_dim)``.
                The length is not checked against ``text_seq_len``.

        Returns:
            ``PreprocessedConditions``. ``text_f`` holds the projected tokens.
            ``text_f_c`` is ``text_cond_proj`` applied to their mean over dim 1.
        """
        text_f = self.text_input_proj(text_f)
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))
        return PreprocessedConditions(text_f=text_f, text_f_c=text_f_c)

    def predict_flow(self, x: torch.Tensor, timesteps: torch.Tensor,
                     conditions: PreprocessedConditions,
                     time_aligned_context: torch.Tensor) -> torch.Tensor:
        """Predict the flow for one step, fusing in ``time_aligned_context``.

        Args:
            x: Noisy latent, channel-first ``(B, latent_dim, latent_seq_len)``.
                ``x.shape[2]`` is asserted to equal ``latent_seq_len``.
            timesteps: Per-sample timesteps; cast to ``x.dtype``.
            conditions: Output of ``preprocess_conditions``.
            time_aligned_context: Channel-last ``(B, latent_seq_len, ta_context_dim)``.
                It is concatenated with ``x.transpose(1, 2)`` on the last axis.

        Returns:
            The ``final_layer`` output for the latent sequence. This is presumably
            channel-last ``(B, latent_seq_len, latent_dim)``; TODO confirm against
            FinalBlock.
        """
        assert x.shape[2] == self._latent_seq_len, f'{x.shape=} {self._latent_seq_len=}'

        text_f = conditions.text_f
        timesteps = timesteps.to(x.dtype)  # keep the same dtype as the latent
        global_c = self.global_cond_mlp(conditions.text_f_c)
        # Add the timestep embedding; unsqueeze(1) adds a length-1 sequence axis
        # ((B, 1, D) if both terms are (B, D)).
        global_c = self.t_embed(timesteps).unsqueeze(1) + global_c.unsqueeze(1)
        extended_c = global_c  # AdaLN condition for the blocks

        # (B, C, N) -> (B, N, C), then concatenate the context channels and project.
        fused_input = torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
        latent = self.audio_input_proj(fused_input)

        for block in self.joint_blocks:
            latent, text_f = block(latent, text_f, global_c, extended_c, self.latent_rot)
        for block in self.fused_blocks:
            latent = block(latent, extended_c, self.latent_rot)

        return self.final_layer(latent, global_c)

    def forward(self,
                x: torch.Tensor,
                timesteps: torch.Tensor,
                context: torch.Tensor,
                time_aligned_context: torch.Tensor,
                x_mask=None,
                context_mask=None,
                ) -> torch.Tensor:
        """Model entry point; interface aligned with LayerFusionAudioDiT.

        Args:
            x: Noisy latent, channel-first ``(B, latent_dim, latent_seq_len)``.
            timesteps: ``(B,)`` tensor, or a 0-dim tensor broadcast to the batch.
                Its dtype is preserved here (fractional values stay fractional)
                and it is moved to ``x.device``; ``predict_flow`` casts it to
                ``x.dtype``.
            context: Text features, assumed ``(B, N_text, text_dim)``.
            time_aligned_context: ``(B, latent_seq_len, ta_context_dim)``.
            x_mask: Accepted for interface compatibility; ignored.
            context_mask: Accepted for interface compatibility; ignored.

        Returns:
            The ``predict_flow`` output with its last two axes swapped. If
            ``final_layer`` is channel-last, this is ``(B, latent_dim,
            latent_seq_len)``, the same layout as ``x`` (TODO confirm).
        """
        if timesteps.dim() == 0:
            # Broadcast a scalar timestep without casting to an integer dtype:
            # an integer cast would truncate fractional timesteps, unlike the
            # (B,)-shaped path.
            timesteps = timesteps.expand(x.shape[0])
        timesteps = timesteps.to(x.device)

        text_conditions = self.preprocess_conditions(context)
        flow = self.predict_flow(x, timesteps, text_conditions, time_aligned_context)
        return flow.transpose(1, 2)

    @property
    def latent_seq_len(self) -> int:
        """Number of latent frames the model (and its RoPE table) expects."""
        return self._latent_seq_len
# small_16k latents: 500 frames x 128 channels; forward() takes x as (B, 128, 500).
def small_16k(**kwargs) -> MMAudio:
    """Build the 16 kHz "small" MMAudio preset.

    Fixed settings: 128 latent channels, 500 latent frames, 1024-d text
    features, 16 heads of width 64 (hidden_dim 1024), and 12 blocks of which
    8 are fused. All remaining arguments are forwarded from ``kwargs``;
    ``ta_context_dim`` has no default and must be supplied there. Passing one
    of the fixed names again raises TypeError.
    """
    heads = 16
    head_width = 64
    preset = dict(latent_dim=128,
                  text_dim=1024,
                  hidden_dim=head_width * heads,
                  depth=12,
                  fused_depth=8,
                  num_heads=heads,
                  latent_seq_len=500)
    return MMAudio(**preset, **kwargs)
if __name__ == '__main__':
    # Smoke test: build small_16k, run one forward pass on random inputs and
    # check the output shape. Any failure exits with a non-zero status.
    import traceback

    batch_size = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = {
        "ta_context_dim": 128,
        "ta_context_fusion": "concat",
        "ta_context_norm": False,
    }
    try:
        model = small_16k(**config).to(device)
        model.eval()  # evaluation mode for the smoke test
        print("Model instantiated successfully!")
    except Exception as e:
        print(f"Error during model instantiation: {e}")
        traceback.print_exc()
        raise SystemExit(1)  # non-zero status so callers/CI see the failure

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f'Number of parameters: {num_params:.2f}M')

    # Take sizes from the model rather than repeating small_16k's constants.
    latent_dim = model.latent_dim
    latent_seq_len = model.latent_seq_len
    # With no empty_string_feat supplied, it is zeros of shape (text_seq_len, text_dim).
    text_seq_len, text_dim = model.empty_string_feat.shape
    ta_context_dim = config["ta_context_dim"]

    # x is channel-first (B, latent_dim, N): predict_flow checks x.shape[2].
    dummy_x = torch.randn(batch_size, latent_dim, latent_seq_len, device=device)
    dummy_timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    dummy_context = torch.randn(batch_size, text_seq_len, text_dim, device=device)
    # time_aligned_context is channel-last (B, N, ta_context_dim). It is
    # concatenated with x.transpose(1, 2) on the last axis, so N must match x.
    dummy_ta_context = torch.randn(batch_size, latent_seq_len, ta_context_dim, device=device)

    print("\n--- Input Shapes ---")
    print(f"x (latent): {dummy_x.shape}")
    print(f"timesteps: {dummy_timesteps.shape}")
    print(f"context (text): {dummy_context.shape}")
    print(f"time_aligned_context: {dummy_ta_context.shape}")
    print("--------------------\n")

    try:
        with torch.no_grad():  # no gradients needed for a shape check
            output = model(
                x=dummy_x,
                timesteps=dummy_timesteps,
                context=dummy_context,
                time_aligned_context=dummy_ta_context,
            )
        print("✅ Forward pass successful!")
        print(f"Output shape: {output.shape}")
        # forward() swaps the last two axes of the flow. Assuming final_layer is
        # channel-last (B, N, latent_dim) (TODO confirm), the output then has
        # x's channel-first (B, latent_dim, N) layout.
        expected_shape = (batch_size, latent_dim, latent_seq_len)
        assert output.shape == expected_shape, \
            f"Output shape mismatch! Expected {expected_shape}, but got {output.shape}"
        print("✅ Output shape is correct!")
    except Exception as e:
        print(f"❌ Error during forward pass: {e}")
        traceback.print_exc()
        raise SystemExit(1)