potsawee committed on
Commit
d7a2a0f
·
verified ·
1 Parent(s): deb8802

Upload modeling_text_sync_mimi.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_text_sync_mimi.py +713 -0
modeling_text_sync_mimi.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional, Dict, List, Union
6
+
7
+ from configuration_mimi import MimiConfig
8
+ from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
9
+ from modeling_backbone_components import (
10
+ CrossAttentionTransformer,
11
+ CausalAttentionTransformer
12
+ )
13
+
14
+
15
class TextSyncMimi(MimiPreTrainedModel):
    """
    TextSyncMimi: Text-Synchronous Neural Audio Codec Model

    A neural audio codec model that combines text and speech representations for
    high-quality text-to-speech synthesis. Features:

    - Learnable text embeddings
    - Cross-attention transformer for text-speech alignment
    - Autoregressive transformer for causal speech generation
    - BCE-based end token prediction for dynamic duration control

    Architecture:
    - Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
    - Mimi Encoder: Pre-trained audio encoder (see NOTE on freezing below)
    - Text Projection: Linear projection from 4,096 to 512 dimensions
    - Cross-Attention Transformer: Aligns text with speech features
    - Autoregressive Transformer: Generates speech representations
    - End Token Classifier: Predicts when to stop generating
    """

    def __init__(
        self,
        config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
        model_id: Optional[str] = None,
        token: Optional[str] = None,
        alpha: Optional[float] = None,
        cross_attention_layers: Optional[int] = None,
        causal_attention_layers: Optional[int] = None,
        bce_threshold: Optional[float] = None,
        vocab_size: Optional[int] = None,
    ):
        """
        Initialize TextSyncMimi model.

        Explicit keyword arguments always take precedence over the matching
        attribute on ``config``; unset arguments fall back to the config value
        or a hard-coded default.

        Args:
            config: Model configuration (TextSyncMimiConfig or MimiConfig)
            model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
            token: Hugging Face authentication token
            alpha: Weight for BCE end token loss. If None, uses config.alpha
            cross_attention_layers: Number of cross-attention layers. If None, uses config
            causal_attention_layers: Number of autoregressive layers. If None, uses config
            bce_threshold: BCE loss margin (see forward()). If None, uses config.bce_threshold
            vocab_size: Text vocabulary size. If None, uses config.vocab_size

        Raises:
            ValueError: If neither ``config`` nor ``model_id`` is provided, or
                if no model id can be resolved from either source.
        """
        # Handle config initialization for both manual instantiation and from_pretrained
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)

        super().__init__(config)

        # Extract parameters from config if not explicitly provided
        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")

        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
        causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        # 128256 matches the LLaMA-3 tokenizer vocabulary size -- presumably
        # chosen so LLaMA embeddings can be copied in; TODO confirm.
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        # load the mimi backbone (downloaded weights; requires network/cache access)
        self.config = config
        model = MimiModel.from_pretrained(model_id, token=token)

        # hyperparameters for auxiliary loss
        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Learnable text token embedding (4096-dim, LLaMA-compatible width)
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)

        # Text projection: 4096 -> 512 (Mimi's latent width)
        self.text_proj = nn.Linear(4096, 512)

        # Cross-attention transformer: clones the Mimi config, then overrides
        # depth and width. NOTE(review): MimiConfig(**self.config.__dict__)
        # relies on every config attribute being a valid constructor kwarg.
        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        # decoder part (v1)
        # Auto-regressive decoder:
        # <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
        # masking (not computing loss for <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|>
        # t_i already mapped from 4096 (e.g., llama embedding) -> 512
        # s_i already 512
        # z is mimi's decoder-input which is also 512
        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        # embedding for special positions in the autoregressive decoder
        # (single-entry tables acting as learned special-token vectors)
        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        # Binary classification head for end token prediction
        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

        # Mimi components attached AFTER post_init() -- presumably so weight
        # (re-)initialization does not overwrite the pre-trained Mimi weights;
        # TODO confirm this ordering is intentional.
        # NOTE(review): these are described as "frozen" but nothing here sets
        # requires_grad=False -- confirm freezing happens in the training setup.
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample

        # print the number of parameters for each sub network in Millions
        self._print_subnetwork_parameter_counts()
132
+ def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
133
+ """
134
+ Initialize text embeddings from a weight matrix.
135
+
136
+ Args:
137
+ embedding_weight: Weight matrix of shape (vocab_size, 4096)
138
+ """
139
+ if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
140
+ raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
141
+ if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
142
+ raise ValueError("Provided vocab_size does not match model's text_token_embedding")
143
+ with torch.no_grad():
144
+ self.text_token_embedding.weight.copy_(embedding_weight)
145
+ for p in self.text_token_embedding.parameters():
146
+ p.requires_grad = True
147
+
148
+ def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
149
+ """
150
+ Initialize text embeddings from a LLaMA embedding module.
151
+
152
+ Args:
153
+ llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
154
+ """
155
+ if not hasattr(llama_embeddings_module, 'weight'):
156
+ raise ValueError("llama_embeddings_module must have a 'weight' attribute")
157
+ weight = llama_embeddings_module.weight.data
158
+ self.initialize_text_embeddings_from_weights(weight)
159
+
160
+ def _print_subnetwork_parameter_counts(self) -> None:
161
+ """Print parameter counts for model subnetworks."""
162
+ print("=" * 70)
163
+ print("TextSyncMimi Parameter Counts")
164
+ print("=" * 70)
165
+ print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
166
+ print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
167
+ print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
168
+ print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
169
+ print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
170
+ print("=" * 70)
171
+
172
+ def encode_audio_to_representation(
173
+ self,
174
+ input_values: torch.Tensor,
175
+ audio_attention_mask: Optional[torch.Tensor] = None,
176
+ ) -> torch.Tensor:
177
+ """
178
+ Encode audio to speech representation.
179
+
180
+ Args:
181
+ input_values: Audio waveform (B, 1, audio_len)
182
+ audio_attention_mask: Attention mask (B, audio_len)
183
+
184
+ Returns:
185
+ Speech embeddings (B, 512, 12.5 * T)
186
+ """
187
+ batch_size = input_values.shape[0]
188
+ device = input_values.device
189
+
190
+ # Encode through Mimi encoder pipeline
191
+ embeddings = self.encoder(input_values)
192
+ encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
193
+ embeddings = encoder_outputs[0].transpose(1, 2)
194
+ embeddings = self.downsample(embeddings)
195
+
196
+ # Apply attention mask if provided
197
+ if audio_attention_mask is not None:
198
+ speech_seq_len = embeddings.shape[-1]
199
+ speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)
200
+
201
+ for b in range(batch_size):
202
+ actual_audio_len = audio_attention_mask[b].sum().item()
203
+ actual_speech_len = int(actual_audio_len * 12.5 / 24000)
204
+ actual_speech_len = min(actual_speech_len, speech_seq_len)
205
+ if actual_speech_len > 0:
206
+ speech_attention_mask[b, :actual_speech_len] = True
207
+
208
+ speech_mask_expanded = speech_attention_mask.unsqueeze(1)
209
+ embeddings = embeddings * speech_mask_expanded.float()
210
+
211
+ return embeddings
212
+
213
    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio representations autoregressively.

        Puts the model in eval mode and runs entirely under ``torch.no_grad()``.
        For each text position the AR transformer emits continuous z tokens
        until the end-token classifier's sigmoid probability reaches
        ``end_token_threshold`` or ``max_z_tokens`` is hit.

        NOTE(review): each generation step re-runs the AR transformer over the
        whole sequence (no KV cache), so cost grows quadratically with length.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, audio_len) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation (defaults to text_token_ids.device)

        Returns:
            List of z_tokens lists (one per batch item); each z token is a
            (512,) tensor.

        Raises:
            ValueError: If neither input_values nor speech_embeddings is given.
        """
        if device is None:
            device = text_token_ids.device

        self.eval()

        with torch.no_grad():
            # Get speech embeddings for cross-attention context
            if speech_embeddings is not None:
                # Use pre-computed speech embeddings (cached mode)
                # speech_embeddings should already be (B, T, 512)
                pass  # speech_embeddings is already provided
            else:
                # Compute speech embeddings from input_values (normal mode)
                if input_values is None:
                    raise ValueError("Either input_values or speech_embeddings must be provided")
                speech_embeddings = self.encode_audio_to_representation(
                    input_values,
                    audio_attention_mask=audio_attention_mask
                )
                speech_embeddings = speech_embeddings.transpose(1, 2)  # (B, T, 512)

            # Embed token ids then project to 512
            text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
            text_embeddings_proj = self.text_proj(text_embeddings_4096)  # (B, L, 512)

            # Apply cross attention (same mask construction as in forward();
            # NOTE(review): this duplicates forward()'s mask-building code)
            formatted_text_attention_mask = None
            formatted_speech_attention_mask = None

            batch_size, text_seq_len = text_embeddings_proj.shape[:2]

            # Additive masks: 0.0 at attendable positions, -inf at masked ones.
            if text_attention_mask is not None:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
                combined_mask = causal_mask * padding_mask
                formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
            else:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

            # Handle speech attention mask (use speech_attention_mask if available, otherwise audio_attention_mask)
            if speech_attention_mask is not None:
                # For cached data, speech_attention_mask is already in the right format
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = speech_attention_mask.bool()
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            elif audio_attention_mask is not None:
                # For non-cached data, convert audio_attention_mask to speech_attention_mask
                # (sample counts -> frame counts at 12.5 frames/s, 24 kHz audio)
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
                for b in range(batch_size):
                    audio_len = audio_attention_mask[b].sum().item()
                    speech_len = int(audio_len * 12.5 / 24000)
                    speech_len = min(speech_len, speech_seq_len)
                    speech_mask[b, :speech_len] = True
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            else:
                formatted_speech_attention_mask = None

            # Cross attention: text (decoder) attends to speech (encoder)
            cross_attention_outputs = self.cross_attention_transformer(
                hidden_states=text_embeddings_proj,
                encoder_hidden_states=speech_embeddings,
                attention_mask=formatted_text_attention_mask,
                encoder_attention_mask=formatted_speech_attention_mask,
                alignment_chunk_sizes=None,  # V1 learns alignment
            )
            cross_attention_outputs = cross_attention_outputs.last_hidden_state

            # Get special embeddings (index 0 of each single-entry table)
            text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

            generated_z_tokens = []

            # Generate for each batch item independently (no batched decoding)
            for b in range(batch_size):
                # Get valid text length for this sample
                if text_attention_mask is not None:
                    valid_text_len = text_attention_mask[b].sum().item()
                else:
                    valid_text_len = text_embeddings_proj.shape[1]

                # Start sequence with text_speech_latent for context
                sequence = [text_speech_latent_emb]  # (1, 512)
                batch_z_tokens = []  # Store z_tokens for this batch item

                # Generate for each text position, mirroring the training
                # layout: <latent> t_i s_i <start> z... <end>
                for i in range(valid_text_len):
                    # Add t_i and s_i
                    t_i = text_embeddings_proj[b, i:i+1]  # (1, 512)
                    s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                    sequence.extend([t_i, s_i])

                    # Add time_speech_start
                    sequence.append(time_speech_start_emb)

                    # Generate z tokens autoregressively for this text position
                    z_count = 0
                    while z_count < max_z_tokens:
                        # Prepare current sequence for AR transformer
                        # (re-encodes the whole prefix every step)
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)  # (1, seq_len, 512)

                        # Create attention mask for current sequence (all valid)
                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)

                        # Get prediction from AR transformer
                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )

                        # Get the last prediction
                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]  # (1, 512)

                        # Check stopping condition using BCE classifier (v1.1)
                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)  # (1,)
                        end_token_prob = torch.sigmoid(end_token_logit).item()  # Convert to probability

                        # Stop if probability is high enough (>= threshold means stop)
                        if end_token_prob >= end_token_threshold:
                            # Stop generating z tokens
                            break
                        else:
                            # Add this prediction as next z token to both sequence (for context) and z_tokens (for output)
                            sequence.append(last_prediction)
                            batch_z_tokens.append(last_prediction.squeeze(0))  # Remove batch dimension for output
                            z_count += 1

                    # Add time_speech_end to sequence for context
                    sequence.append(time_speech_end_emb)

                # Store z_tokens for this batch item
                generated_z_tokens.append(batch_z_tokens)

        return generated_z_tokens
386
    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        # NOTE(review): default is None but the value is indexed unconditionally
        # below (alignment_chunk_sizes[b, i]) -- callers must always pass it.
        alignment_chunk_sizes: torch.Tensor = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Builds, per batch item, the interleaved teacher-forced sequence
        <latent> t_i s_i <start> z_1..z_k <end> (repeated per text token),
        pads to the batch max length, and trains the AR transformer with:
        - MSE on next-step z-token prediction (reconstruction_loss)
        - BCE on the end-token classifier over z and <end> positions

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, audio_len) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Number of speech frames per text token (B, L); required.
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)

        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'

        Raises:
            ValueError: If neither input_values nor speech_embeddings is given.
        """
        # Get speech embeddings
        if speech_embeddings is not None:
            pass
        elif input_values is not None:
            # Normal mode: compute speech embeddings from input_values

            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values,
                audio_attention_mask
            )
            # speech_embeddings_raw.shape = (B, 512, 12.5*T)
            # Transpose: [B, 512, 12.5*T] -> [B, 12.5*T, 512]
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)
        else:
            raise ValueError("Either input_values or speech_embeddings must be provided")
        # Embed token ids and project to 512-dim
        text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
        text_embeddings = self.text_proj(text_embeddings_4096)  # (B, L, 512)

        # Create proper attention masks for cross-attention
        # (additive masks: 0.0 = attend, -inf = masked;
        # NOTE(review): duplicated in generate_autoregressive)
        formatted_text_attention_mask = None
        formatted_speech_attention_mask = None

        # Handle text attention mask (causal mask for decoder)
        batch_size, text_seq_len = text_embeddings.shape[:2]

        if text_attention_mask is not None:
            # Create causal mask and apply padding mask
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)

            # Apply padding mask to causal mask
            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask

            # Convert to attention scores (-inf for masked positions)
            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
        else:
            # Create causal mask for all positions (no padding mask)
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

        # Handle speech attention mask (encoder mask)
        # Use speech_attention_mask if available (cached mode), otherwise audio_attention_mask (normal mode)
        if speech_attention_mask is not None:
            # Cached mode: speech_attention_mask is already in the right format
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()

            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            # Normal mode: convert audio mask to speech embedding mask
            speech_seq_len = speech_embeddings.shape[1]

            # Create speech attention mask based on actual lengths
            # (sample counts -> frame counts at 12.5 frames/s, 24 kHz audio)
            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)

            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True

            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            # No masking
            formatted_speech_attention_mask = None

        # Cross attention: text attends to speech (no alignment constraints in V1)
        # hidden_states (decoder) = text, encoder_hidden_states = speech
        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,  # Causal mask for text (decoder)
            encoder_attention_mask=formatted_speech_attention_mask,  # Mask for speech (encoder)
            alignment_chunk_sizes=None,  # v1 doesn't use alignment_chunk_sizes -- the model should learn the alignment itself
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

        # Auto-regressive decoder part
        # Following v0.5 where the target is the dequantized Mimi decoder-input
        # Compute target representation = Mimi decoder-input (quantized->dequantized at 12.5*seconds)
        # 12.5*seconds => T
        # no_grad: the quantizer is used only to produce regression targets
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)  # (B, 512, T)
            codes_kbt = self.quantizer.encode(embeddings_bct)  # [K, B, T]
            codes_bkt = codes_kbt.transpose(0, 1)  # [B, K, T]
            decoder_input_emb = self.quantizer.decode(codes_bkt)  # (B, 512, T)
            target_representation = decoder_input_emb.transpose(1, 2)  # (B, T, 512)

        # Build the interleaved sequence for the autoregressive decoder
        # as well as the mask for loss computation
        # Get special embeddings (all are single embeddings)
        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []  # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
        bce_masks = []  # BCE mask: True for z tokens and time_speech_end_emb
        sequence_lengths = []  # Track actual sequence lengths before padding
        # NOTE(review): all_z_tokens is populated but never consumed in this
        # forward -- presumably a hook for a separation loss; confirm or remove.
        all_z_tokens = []  # Collect all valid z_tokens for separation loss
        max_total_length = 0

        for b in range(batch_size):
            # Start with text_speech_latent embedding
            sequence_parts = [text_speech_latent_emb]  # List to collect sequence parts
            loss_mask_parts = [False]  # Don't compute loss on special tokens
            bce_label_parts = [0]  # BCE labels (dummy for text_speech_latent_emb)
            bce_mask_parts = [False]  # BCE mask (False for text_speech_latent_emb)

            # Get valid text length for this batch item
            if text_attention_mask is not None:
                valid_text_len = text_attention_mask[b].sum().item()
            else:
                valid_text_len = text_embeddings.shape[1]

            # Track current position in target_representation
            speech_position = 0

            # For each text token
            for i in range(valid_text_len):
                # Add t_i (text embedding)
                t_i = text_embeddings[b, i:i+1]  # (1, 512)
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for t_i
                bce_mask_parts.append(False)  # No BCE loss for t_i

                # Add s_i (cross attention output)
                s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for s_i
                bce_mask_parts.append(False)  # No BCE loss for s_i

                # Add time_speech_start
                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for time_speech_start
                bce_mask_parts.append(False)  # No BCE loss for time_speech_start

                # Add z tokens for this chunk
                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:  # Only add if chunk size is positive
                    end_position = speech_position + chunk_size
                    # Make sure we don't exceed target_representation length
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position

                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]  # (actual_chunk_size, 512)
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)  # Compute loss on z tokens
                        bce_label_parts.extend([0] * actual_chunk_size)  # Label 0 for z tokens
                        bce_mask_parts.extend([True] * actual_chunk_size)  # Compute BCE loss on z tokens

                        # Collect z_tokens for separation loss computation
                        all_z_tokens.append(z_tokens)

                    speech_position = end_position

                # Add time_speech_end (BCE positive class: "stop here")
                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)

            # Concatenate all parts for this batch item
            full_sequence = torch.cat(sequence_parts, dim=0)  # (total_length, 512)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)

            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])  # Track actual length before padding
            max_total_length = max(max_total_length, full_sequence.shape[0])

        # Pad sequences (right-pad with zeros; padded positions excluded from
        # every loss via the masks below)
        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)

                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)

                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)

                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask

            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        # Stack into batch tensors
        interleaved_batch = torch.stack(padded_sequences, dim=0)  # (batch_size, max_total_length, 512)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)  # (batch_size, max_total_length)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)  # (batch_size, max_total_length)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)  # (batch_size, max_total_length)

        # Autoregressive prediction (standard shift-by-one teacher forcing)
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]  # (batch_size, max_total_length-1, 512)
            ar_targets = interleaved_batch[:, 1:, :]  # (batch_size, max_total_length-1, 512)
            ar_loss_mask = loss_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shift mask left
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]  # (batch_size, max_total_length-1) - shift labels left
            ar_bce_mask = bce_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shift mask left

            # Create attention mask for autoregressive transformer
            # We need to mask padded positions while maintaining causal property
            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True

            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,  # This will be combined with causal mask inside transformer
            )
            ar_predictions = ar_outputs.last_hidden_state  # (batch_size, max_total_length-1, 512)

            # Compute BCE predictions for end token classification
            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)  # (batch_size, max_total_length-1)

            # Compute L2 loss only where ar_loss_mask is True (z tokens)
            if ar_loss_mask.any():
                # Extract valid positions for loss computation
                valid_predictions = ar_predictions[ar_loss_mask]  # (num_valid_positions, 512)
                valid_targets = ar_targets[ar_loss_mask]  # (num_valid_positions, 512)

                # Compute L2 loss (MSE)
                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions,
                    valid_targets,
                    reduction='mean'
                )
            else:
                # Fallback if no valid positions (shouldn't happen in practice)
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # Compute BCE loss for end token classification (v1.1)
            if ar_bce_mask.any():
                # Extract valid positions for BCE loss computation
                valid_bce_logits = bce_logits[ar_bce_mask]  # (num_valid_bce_positions,)
                valid_bce_labels = ar_bce_labels[ar_bce_mask]  # (num_valid_bce_positions,)

                # Compute BCE loss
                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits,
                    valid_bce_labels,
                    reduction='mean'
                )
            else:
                # Fallback if no valid BCE positions
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # NOTE(review): despite its name, bce_threshold acts as a margin
            # ("free bits"): BCE below the threshold contributes zero gradient.
            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            # Degenerate batch (sequence too short to shift) -> zero losses
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + torch.tensor(0.0, device=device, requires_grad=True)

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }
712
+
713
# Public API of this module.
__all__ = ["TextSyncMimi"]