File size: 13,089 Bytes
c14d03d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import logging
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设这些是你原来的导入
from .mmdit_layers import compute_rope_rotations
from .mmdit_layers import TimestepEmbedder
from .mmdit_layers import MLP, ChannelLastConv1d, ConvMLP
from .mmdit_layers import (FinalBlock, MMDitSingleBlock, JointBlock_AT)

log = logging.getLogger()


@dataclass
class PreprocessedConditions:
    """Pre-projected text conditioning features, produced by
    MMAudio.preprocess_conditions and consumed by MMAudio.predict_flow."""
    # Per-token text features after text_input_proj: (B, N_text, hidden_dim).
    text_f: torch.Tensor
    # Pooled global text condition after text_cond_proj: (B, hidden_dim).
    text_f_c: torch.Tensor


class MMAudio(nn.Module):
    """
    A modified MMAudio whose interface is kept as close as possible to
    LayerFusionAudioDiT (translated from the original Chinese docstring).

    Given a noisy latent, a text condition and a time-aligned context, the
    model predicts a flow field; `forward` returns it transposed to
    (B, latent_seq_len, latent_dim) — see the __main__ smoke test's
    expected shape.
    """
    def __init__(self,
                 *,
                 latent_dim: int,
                 text_dim: int,
                 hidden_dim: int,
                 depth: int,
                 fused_depth: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 latent_seq_len: int,
                 text_seq_len: int = 640,
                 # --- new parameters, aligned with LayerFusionAudioDiT ---
                 ta_context_dim: int,
                 ta_context_fusion: str = 'add', # 'add' or 'concat'
                 ta_context_norm: bool = False,
                 # --- other original parameters ---
                 empty_string_feat: Optional[torch.Tensor] = None,
                 v2: bool = False) -> None:
        super().__init__()

        self.v2 = v2
        self.latent_dim = latent_dim
        self._latent_seq_len = latent_seq_len
        self._text_seq_len = text_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        
        # --- 1. Projection for time_aligned_context ---
        # A single projection layer is defined here instead of one per block.
        # This is more efficient and matches the idea in the original
        # comment: "currently each layer projects; change to no per-layer
        # mapping" — i.e. project once and pass the result to every layer.
        # NOTE(review): ta_context_projection / ta_context_norm are built here
        # but are never called in predict_flow as written — confirm whether
        # the blocks are expected to apply them, or whether this is dead code.
        self.ta_context_fusion = ta_context_fusion
        self.ta_context_norm_flag = ta_context_norm
        
        if self.ta_context_fusion == "add":
            # Additive fusion: project ta_context to the latent width (hidden_dim).
            self.ta_context_projection = nn.Linear(ta_context_dim, hidden_dim, bias=False)
            self.ta_context_norm = nn.LayerNorm(ta_context_dim) if self.ta_context_norm_flag else nn.Identity()
        elif self.ta_context_fusion == "concat":
            # Concat fusion is handled inside the blocks, so no main projection
            # is needed here. (The original code also projected after concat;
            # that can be implemented inside the blocks. For simplicity the
            # fusion logic is assumed to live in the blocks.)
            self.ta_context_projection = nn.Identity()
            self.ta_context_norm = nn.Identity()
        else:
            raise ValueError(f"Unknown ta_context_fusion type: {ta_context_fusion}")


        # --- Original input projections (mostly unchanged) ---
        # The input is now an editing pair, so the in-channels become latent_dim*2
        # (noisy latent concatenated with the time-aligned context; see predict_flow).
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(latent_dim*2, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )
        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )
            
        self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
        

        # Timestep embedder (sinusoidal frequency embedding -> MLP).
        self.t_embed = TimestepEmbedder(hidden_dim, frequency_embedding_size=256, max_period=10000)
            
        # --- Transformer blocks (mostly unchanged) ---
        # IMPORTANT: JointBlock_AT / MMDitSingleBlock forward signatures must be
        # adapted if they are to receive `time_aligned_context` directly.
        self.joint_blocks = nn.ModuleList([
            JointBlock_AT(hidden_dim, num_heads, mlp_ratio=mlp_ratio, pre_only=(i == depth - fused_depth - 1))
            for i in range(depth - fused_depth)
        ])
        self.fused_blocks = nn.ModuleList([
            MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for i in range(fused_depth)
        ])
        
        # --- Output head (unchanged) ---
        self.final_layer = FinalBlock(hidden_dim, latent_dim)

        
        # Learned-shape placeholder for the empty-string text feature;
        # frozen (requires_grad=False) so it is saved with the state dict
        # but never trained.
        if empty_string_feat is None:
            empty_string_feat = torch.zeros((text_seq_len, text_dim))
            
        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
        
        self.initialize_weights()
        self.initialize_rotations()

    def initialize_rotations(self):
        """(Re)build the RoPE rotation table for the current latent length."""
        base_freq = 1.0

        # The only place the latent sequence length is actually needed.
        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            self.hidden_dim // self.num_heads,
                                            10000,
                                            freq_scaling=base_freq,
                                            device="cuda" if torch.cuda.is_available() else "cpu")

        # Register as a non-persistent buffer: moves with .to(device) but is
        # excluded from the state dict (it is recomputable).
        self.register_buffer('latent_rot', latent_rot, persistent=False)
        # self.clip_rot = nn.Buffer(clip_rot, persistent=False)

    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        """Change sequence lengths and rebuild the rotation buffer.

        NOTE(review): _clip_seq_len / _sync_seq_len are stored but not read
        anywhere in this file — presumably kept for interface compatibility
        with the original MMAudio; confirm before removing.
        """
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self.initialize_rotations()

    def initialize_weights(self):
        """Xavier-init all Linear layers, then zero the adaLN/output heads
        (standard DiT-style init so blocks start as identity-like)."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks (compatibility guard):
        for block in self.joint_blocks:
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
        for block in self.fused_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.conv.weight, 0)
        nn.init.constant_(self.final_layer.conv.bias, 0)


    
    def preprocess_conditions(self, text_f: torch.Tensor) -> PreprocessedConditions:
        # Preprocess the text condition: project per-token features, then
        # mean-pool over the sequence for the global condition vector.
        # assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
        bs = text_f.shape[0]

        # The external LLM embedding is fixed here (projected once up front).
        text_f = self.text_input_proj(text_f)
        # Global condition: mean over the token axis -> (B, hidden_dim).
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))
        return PreprocessedConditions(text_f=text_f, text_f_c=text_f_c)

    def predict_flow(self, x: torch.Tensor, timesteps: torch.Tensor,
                     conditions: PreprocessedConditions, 
                     time_aligned_context: torch.Tensor) -> torch.Tensor:
        """
        Core prediction pipeline, now with time_aligned_context folded in.

        x is channel-first here: (B, latent_dim, latent_seq_len);
        time_aligned_context is channel-last: (B, latent_seq_len, ta_context_dim)
        — presumably ta_context_dim == latent_dim so the concat matches
        audio_input_proj's latent_dim*2 in-channels (TODO confirm).
        Returns channel-last flow (B, latent_seq_len, latent_dim).
        """
        assert x.shape[2] == self._latent_seq_len, f'{x.shape=} {self._latent_seq_len=}'
        
        # 1. Unpack the preprocessed conditions.
        text_f = conditions.text_f
        text_f_c = conditions.text_f_c
        
        timesteps = timesteps.to(x.dtype)  # keep the same dtype as the input tensor

        global_c = self.global_cond_mlp(text_f_c)  # (B, D)
        
        # 2. Fuse the timestep embedding into the global condition.
        global_c = self.t_embed(timesteps).unsqueeze(1) + global_c.unsqueeze(1) # (B, 1, D)
        extended_c = global_c # this will serve as the AdaLN condition
        """
        这里决定了x的形状,需要debug
        """
        # 3. Handle time_aligned_context: first approach — fuse directly with
        # the latent along the channel axis, then project
        # (channels latent_dim -> latent_dim*2, e.g. 128 -> 256).
        x = torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
        latent = self.audio_input_proj(x)  # (B, N, D)

        # 4. Run through the Transformer blocks.
        for block in self.joint_blocks:
            # NOTE: JointBlock_AT.forward must match this call signature.
            latent, text_f = block(latent, text_f, global_c, extended_c,
                                           self.latent_rot) 

        for block in self.fused_blocks:
            # NOTE: MMDitSingleBlock.forward must match this call signature.
            latent = block(latent, extended_c, self.latent_rot)

        # 5. Output head.
        flow = self.final_layer(latent, global_c)
        return flow

    def forward(self, 
                x: torch.Tensor, 
                timesteps: torch.Tensor,
                context: torch.Tensor,
                time_aligned_context: torch.Tensor,
                x_mask=None,
                context_mask=None,
               ) -> torch.Tensor:
        """
        Main model entry point; interface aligned with LayerFusionAudioDiT.
        - x: noisy latent, shape (B, latent_dim, N_latent) — channel-first,
          per the __main__ smoke test (the docstring's original (B, N_latent,
          latent_dim) claim disagrees with that test; confirm).
        - timesteps: diffusion timesteps, shape (B,) (or a scalar, expanded below)
        - context: text condition, shape (B, N_text, text_dim)
        - time_aligned_context: time-aligned condition, shape (B, N_ta, ta_context_dim)
        - x_mask / context_mask: accepted for interface parity but unused here.
        Returns flow of shape (B, N_latent, latent_dim) after the transpose.
        """

        # A scalar timestep is broadcast to the whole batch.
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
        
        text_conditions = self.preprocess_conditions(context)
        
        # Run the core prediction flow.
        flow = self.predict_flow(x, timesteps, text_conditions, time_aligned_context)
        

        flow = flow.transpose(1, 2)




        return flow



    @property
    def latent_seq_len(self) -> int:
        # Read-only view of the current latent sequence length.
        return self._latent_seq_len
    

# latent(b,500,128)

def small_16k(**kwargs) -> MMAudio:
    """Factory for the 16 kHz 'small' MMAudio preset.

    Fixes the architecture hyperparameters (12 blocks, 16 heads,
    hidden_dim = 1024, 500-step latent) and forwards any extra keyword
    arguments (e.g. ta_context_* options) straight to MMAudio.
    """
    heads = 16
    preset = dict(
        latent_dim=128,
        text_dim=1024,
        hidden_dim=64 * heads,
        depth=12,
        fused_depth=8,
        num_heads=heads,
        latent_seq_len=500,
    )
    return MMAudio(**preset, **kwargs)




if __name__ == '__main__':
    # Smoke test: build the small_16k model, run one forward pass on dummy
    # tensors, and verify the output shape.

    batch_size = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")


    # time_aligned_context options; 'concat' fuses it with the latent along
    # the channel axis inside predict_flow.
    config = {
        "ta_context_dim": 128,
        "ta_context_fusion": "concat", 
        "ta_context_norm": False
    }


    # 1. Instantiate the model.
    try:
        model = small_16k(**config).to(device)
        model.eval() # evaluation mode
        print("Model instantiated successfully!")
    except Exception as e:
        print(f"Error during model instantiation: {e}")
        exit()


    # 2. Report the parameter count (in millions).
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f'Number of parameters: {num_params:.2f}M')


    # 3. Build dummy inputs matching the small_16k configuration.
    latent_dim = 128
    latent_seq_len = 500
    text_dim = 1024
    # Text sequence length (matches the model's text_seq_len default).
    text_seq_len = 640
    ta_context_dim = config["ta_context_dim"]

    # x is channel-first: (B, latent_dim, latent_seq_len).
    dummy_x = torch.randn(batch_size,latent_dim, latent_seq_len, device=device)
    dummy_timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    dummy_context = torch.randn(batch_size, text_seq_len, text_dim, device=device)
    
    # time_aligned_context must share x's sequence length so it can be
    # concatenated along the feature axis.
    dummy_ta_context = torch.randn(batch_size, latent_seq_len, ta_context_dim, device=device)

    print("\n--- Input Shapes ---")
    print(f"x (latent):           {dummy_x.shape}")
    print(f"timesteps:            {dummy_timesteps.shape}")
    print(f"context (text):       {dummy_context.shape}")
    print(f"time_aligned_context: {dummy_ta_context.shape}")
    print("--------------------\n")
    
    # 4. Run a forward pass.
    try:
        with torch.no_grad(): # no gradients needed for this check
            output = model(
                x=dummy_x,
                timesteps=dummy_timesteps,
                context=dummy_context,
                time_aligned_context=dummy_ta_context
            )
        print("✅ Forward pass successful!")
        print(f"Output shape: {output.shape}")

        # 5. Validate the output shape.
        expected_shape = (batch_size, latent_seq_len, latent_dim)
        assert output.shape == expected_shape, \
            f"Output shape mismatch! Expected {expected_shape}, but got {output.shape}"
        print("✅ Output shape is correct!")

    except Exception as e:
        print(f"❌ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()