"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""

import torch
import torch.nn as nn
from typing import Optional, Dict, List, Union

try:
    from .configuration_mimi import MimiConfig
    from .configuration_text_sync_mimi import TextSyncMimiConfig
    from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel
    from .modeling_backbone_components import (
        CrossAttentionTransformer,
        CausalAttentionTransformer
    )
except ImportError:
    from configuration_mimi import MimiConfig
    from configuration_text_sync_mimi import TextSyncMimiConfig
    from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
    from modeling_backbone_components import (
        CrossAttentionTransformer,
        CausalAttentionTransformer
    )


class TextSyncMimi(MimiPreTrainedModel):
    """
    TextSyncMimi: Text-Synchronous Neural Audio Codec Model
    
    A neural audio codec model that combines text and speech representations for
    high-quality text-to-speech synthesis. Features:
    
    - Learnable text embeddings
    - Cross-attention transformer for text-speech alignment
    - Autoregressive transformer for causal speech generation
    - BCE-based end token prediction for dynamic duration control
    
    Architecture:
    - Text Embedding Layer: Maps token IDs to 4096-dim embeddings
    - Mimi Encoder: Pre-trained audio encoder (frozen)
    - Text Projection: Linear projection from 4096 to 512 dimensions
    - Cross-Attention Transformer: Aligns text with speech features
    - Autoregressive Transformer: Generates speech representations
    - End Token Classifier: Predicts when to stop generating
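    
    Example (a minimal, illustrative sketch; "kyutai/mimi" is the checkpoint named
    in the docstrings below, and the tensors are random dummies):
    
        model = TextSyncMimi(model_id="kyutai/mimi")
        text_ids = torch.randint(0, 1000, (1, 4))          # (B, L) text token IDs
        audio = torch.randn(1, 1, 24000)                    # 1 second of 24 kHz audio
        chunks = torch.full((1, 4), 3, dtype=torch.long)    # z tokens per text position
        outputs = model(
            text_token_ids=text_ids,
            input_values=audio,
            alignment_chunk_sizes=chunks,
        )
        # outputs contains 'loss', 'reconstruction_loss', and 'bce_end_token_loss'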
    """
    
    config_class = TextSyncMimiConfig
    
    def __init__(
        self, 
        config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
        model_id: Optional[str] = None,
        token: Optional[str] = None,
        alpha: Optional[float] = None,
        cross_attention_layers: Optional[int] = None,
        causal_attention_layers: Optional[int] = None,
        bce_threshold: Optional[float] = None,
        vocab_size: Optional[int] = None,
    ):
        """
        Initialize TextSyncMimi model.
        
        Args:
            config: Model configuration (TextSyncMimiConfig or MimiConfig)
            model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
            token: Hugging Face authentication token
            alpha: Weight for BCE end token loss. If None, uses config.alpha
            cross_attention_layers: Number of cross-attention layers. If None, uses config
            causal_attention_layers: Number of autoregressive layers. If None, uses config
            bce_threshold: Hinge margin for the BCE end-token loss; BCE values below this margin are not penalized. If None, uses config.bce_threshold
            vocab_size: Text vocabulary size. If None, uses config.vocab_size
        """
        # Handle config initialization for both manual instantiation and from_pretrained
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)
        
        super().__init__(config)
        
        # Extract parameters from config if not explicitly provided
        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")
        
        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
        causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        # load the mimi backbone
        self.config = config
        model = MimiModel.from_pretrained(model_id, token=token)

        # hyperparameters for auxiliary loss
        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Learnable text token embedding
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)

        # Text projection
        self.text_proj = nn.Linear(4096, 512)
        
        # Cross-attention transformer
        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        # Decoder part (v1)
        # Auto-regressive decoder over the interleaved sequence:
        # <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
        # The loss is masked (not computed) on <|text_speech_latent|>, [t_i], [s_i], and <|time_speech_start|>.
        # t_i is already mapped from 4096 (e.g., a LLaMA embedding) to 512
        # s_i is already 512
        # z is Mimi's decoder input, which is also 512
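        # Length bookkeeping (an illustrative note matching the sequence-building loop
        # in forward()): with L valid text tokens and per-token chunk sizes c_1..c_L,
        # the interleaved sequence has 1 + sum_i (4 + c_i) positions in total:
        # one <|text_speech_latent|>, plus t_i, s_i, <|time_speech_start|>,
        # c_i z-tokens, and <|time_speech_end|> for each text token i.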
        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        # embedding for special positions in the autoregressive decoder
        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        # Binary classification head for end token prediction
        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

        # Frozen Mimi components (gradients disabled; only the new modules above are trained)
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample
        for frozen_module in (self.encoder, self.encoder_transformer, self.quantizer, self.downsample, self.upsample):
            if frozen_module is not None:
                frozen_module.requires_grad_(False)

        # Print the number of parameters for each sub-network, in millions
        self._print_subnetwork_parameter_counts()

    def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
        """
        Initialize text embeddings from a weight matrix.
        
        Args:
            embedding_weight: Weight matrix of shape (vocab_size, 4096)
        """
        if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
            raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
        if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
            raise ValueError("Provided vocab_size does not match model's text_token_embedding")
        with torch.no_grad():
            self.text_token_embedding.weight.copy_(embedding_weight)
        for p in self.text_token_embedding.parameters():
            p.requires_grad = True

    def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
        """
        Initialize text embeddings from a LLaMA embedding module.
        
        Args:
            llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
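        
        Example (a minimal sketch, assuming the `transformers` library and a
        4096-dim LLaMA checkpoint whose vocabulary size matches this model's
        `vocab_size`; the checkpoint name is illustrative):
        
            from transformers import AutoModelForCausalLM
            llama = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
            model.initialize_text_embeddings_from_llama(llama.get_input_embeddings())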
        """
        if not hasattr(llama_embeddings_module, 'weight'):
            raise ValueError("llama_embeddings_module must have a 'weight' attribute")
        weight = llama_embeddings_module.weight.data
        self.initialize_text_embeddings_from_weights(weight)

    def _print_subnetwork_parameter_counts(self) -> None:
        """Print parameter counts for model subnetworks."""
        print("=" * 70)
        print("TextSyncMimi Parameter Counts")
        print("=" * 70)
        print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
        print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
        print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
        print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
        print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
        print("=" * 70)

    def encode_audio_to_representation(
        self,
        input_values: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode audio to speech representation.
        
        Args:
            input_values: Audio waveform (B, 1, audio_len)
            audio_attention_mask: Attention mask (B, audio_len)
            
        Returns:
            Speech embeddings (B, 512, 12.5 * T)
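        
        Note:
            T is the audio duration in seconds; the encoder produces 12.5 frames per
            second, so, for example, 2 s of 24 kHz audio (48,000 samples) maps to
            roughly 25 frames.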
        """
        batch_size = input_values.shape[0]
        device = input_values.device
        
        # Encode through Mimi encoder pipeline
        embeddings = self.encoder(input_values)
        encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
        embeddings = encoder_outputs[0].transpose(1, 2)
        embeddings = self.downsample(embeddings)
        
        # Apply attention mask if provided
        if audio_attention_mask is not None:
            speech_seq_len = embeddings.shape[-1]
            speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)
            
            for b in range(batch_size):
                actual_audio_len = audio_attention_mask[b].sum().item()
                actual_speech_len = int(actual_audio_len * 12.5 / 24000)
                actual_speech_len = min(actual_speech_len, speech_seq_len)
                if actual_speech_len > 0:
                    speech_attention_mask[b, :actual_speech_len] = True
            
            speech_mask_expanded = speech_attention_mask.unsqueeze(1)
            embeddings = embeddings * speech_mask_expanded.float()
            
        return embeddings

    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio autoregressively.
        
        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation
            
        Returns:
            List of z_tokens lists (one per batch item)
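        
        Example (a minimal sketch; `tokenizer` and `ref_audio` are hypothetical
        stand-ins for a text tokenizer and a (B, 1, samples) reference waveform):
        
            ids = tokenizer("Hello world", return_tensors="pt").input_ids
            z_tokens = model.generate_autoregressive(ids, input_values=ref_audio)
            print(len(z_tokens[0]))  # total number of generated 512-dim latents for item 0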
        """
        if device is None:
            device = text_token_ids.device
            
        self.eval()
        
        with torch.no_grad():
            # Get speech embeddings for cross-attention context
            if speech_embeddings is not None:
                # Use pre-computed speech embeddings (cached mode)
                # speech_embeddings should already be (B, T, 512)
                pass  # speech_embeddings is already provided
            else:
                # Compute speech embeddings from input_values (normal mode)
                if input_values is None:
                    raise ValueError("Either input_values or speech_embeddings must be provided")
                speech_embeddings = self.encode_audio_to_representation(
                    input_values, 
                    audio_attention_mask=audio_attention_mask
                )
                speech_embeddings = speech_embeddings.transpose(1, 2)  # (B, T, 512)
            
            # Embed token ids then project to 512
            text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
            text_embeddings_proj = self.text_proj(text_embeddings_4096)  # (B, L, 512)
            
            # Apply cross attention (same as in forward)
            # Create attention masks
            formatted_text_attention_mask = None
            formatted_speech_attention_mask = None
            
            batch_size, text_seq_len = text_embeddings_proj.shape[:2]
            
            if text_attention_mask is not None:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
                combined_mask = causal_mask * padding_mask
                formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
            else:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
            
            # Handle speech attention mask (use speech_attention_mask if available, otherwise audio_attention_mask)
            if speech_attention_mask is not None:
                # For cached data, speech_attention_mask is already in the right format
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = speech_attention_mask.bool()
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            elif audio_attention_mask is not None:
                # For non-cached data, convert audio_attention_mask to speech_attention_mask
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
                for b in range(batch_size):
                    audio_len = audio_attention_mask[b].sum().item()
                    speech_len = int(audio_len * 12.5 / 24000)
                    speech_len = min(speech_len, speech_seq_len)
                    speech_mask[b, :speech_len] = True
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            else:
                formatted_speech_attention_mask = None
            
            # Cross attention
            cross_attention_outputs = self.cross_attention_transformer(
                hidden_states=text_embeddings_proj,
                encoder_hidden_states=speech_embeddings,
                attention_mask=formatted_text_attention_mask,
                encoder_attention_mask=formatted_speech_attention_mask,
                alignment_chunk_sizes=None,  # V1 learns alignment
            )
            cross_attention_outputs = cross_attention_outputs.last_hidden_state
            
            # Get special embeddings
            text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))
            
            generated_z_tokens = []
            
            # Generate for each batch item
            for b in range(batch_size):
                # Get valid text length for this sample
                if text_attention_mask is not None:
                    valid_text_len = text_attention_mask[b].sum().item()
                else:
                    valid_text_len = text_embeddings_proj.shape[1]
                
                # Start sequence with text_speech_latent for context
                sequence = [text_speech_latent_emb]  # (1, 512)
                batch_z_tokens = []  # Store z_tokens for this batch item
                
                # Generate for each text position
                for i in range(valid_text_len):
                    # Add t_i and s_i
                    t_i = text_embeddings_proj[b, i:i+1]  # (1, 512)
                    s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                    sequence.extend([t_i, s_i])
                    
                    # Add time_speech_start
                    sequence.append(time_speech_start_emb)
                    
                    # Generate z tokens autoregressively for this text position
                    z_count = 0
                    while z_count < max_z_tokens:
                        # Prepare current sequence for AR transformer
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)  # (1, seq_len, 512)
                        
                        # Create attention mask for current sequence
                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)
                        
                        # Get prediction from AR transformer
                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )
                        
                        # Get the last prediction
                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]  # (1, 512)
                        
                        # Check stopping condition using BCE classifier (v1.1)
                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)  # (1,)
                        end_token_prob = torch.sigmoid(end_token_logit).item()  # Convert to probability
                        
                        # Stop if probability is high enough (>= threshold means stop)
                        if end_token_prob >= end_token_threshold:
                            # Stop generating z tokens
                            break
                        else:
                            # Add this prediction as next z token to both sequence (for context) and z_tokens (for output)
                            sequence.append(last_prediction)
                            batch_z_tokens.append(last_prediction.squeeze(0))  # Remove batch dimension for output
                            z_count += 1
                    
                    # Add time_speech_end to sequence for context
                    sequence.append(time_speech_end_emb)
                
                # Store z_tokens for this batch item
                generated_z_tokens.append(batch_z_tokens)
            
            return generated_z_tokens

    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.
        
        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Number of target speech frames aligned to each text token (B, L); required for loss computation
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)
            
        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
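            where loss = reconstruction_loss + alpha * max(bce_end_token_loss - bce_threshold, 0)
            if bce_threshold > 0, and reconstruction_loss + alpha * bce_end_token_loss otherwise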
        """
        if alignment_chunk_sizes is None:
            raise ValueError("alignment_chunk_sizes is required for the training forward pass")

        # Get speech embeddings
        if speech_embeddings is not None:
            # Cached mode: pre-computed speech embeddings are already (B, T, 512)
            pass
        elif input_values is not None:
            # Normal mode: compute speech embeddings from input_values
            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values,
                audio_attention_mask
            )
            # speech_embeddings_raw: (B, 512, 12.5*T) -> transpose to (B, 12.5*T, 512)
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)
        else:
            raise ValueError("Either input_values or speech_embeddings must be provided")
        # Embed token ids and project to 512-dim
        text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
        text_embeddings = self.text_proj(text_embeddings_4096)  # (B, L, 512)
        
        # Create proper attention masks for cross-attention
        formatted_text_attention_mask = None
        formatted_speech_attention_mask = None
        
        # Handle text attention mask (causal mask for decoder)
        batch_size, text_seq_len = text_embeddings.shape[:2]
            
        if text_attention_mask is not None:
            # Create causal mask and apply padding mask
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            
            # Apply padding mask to causal mask
            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask
            
            # Convert to attention scores (-inf for masked positions)
            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
        else:
            # Create causal mask for all positions (no padding mask)
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
        
        # Handle speech attention mask (encoder mask)
        # Use speech_attention_mask if available (cached mode), otherwise audio_attention_mask (normal mode)
        if speech_attention_mask is not None:
            # Cached mode: speech_attention_mask is already in the right format
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()
            
            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            # Normal mode: convert audio mask to speech embedding mask
            speech_seq_len = speech_embeddings.shape[1]
            
            # Create speech attention mask based on actual lengths
            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)
            
            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True
            
            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            # No masking
            formatted_speech_attention_mask = None

        # Cross attention: text attends to speech (no alignment constraints in V1)
        # hidden_states (decoder) = text, encoder_hidden_states = speech
        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,  # Causal mask for text (decoder)
            encoder_attention_mask=formatted_speech_attention_mask,  # Mask for speech (encoder)
            alignment_chunk_sizes=None, # v1 doesn't use alignment_chunk_sizes -- the model should learn the alignment itself
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

        # Auto-regressive decoder part
        # Following v0.5 where the target is the dequantized Mimi decoder-input
        # Compute target representation = Mimi decoder input (quantized then dequantized, at 12.5 frames/s)
        # T = 12.5 * duration_in_seconds
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)  # (B, 512, T)
            codes_kbt = self.quantizer.encode(embeddings_bct)   # [K, B, T]
            codes_bkt = codes_kbt.transpose(0, 1)               # [B, K, T]
            decoder_input_emb = self.quantizer.decode(codes_bkt)  # (B, 512, T)
            target_representation = decoder_input_emb.transpose(1, 2)  # (B, T, 512)

        # Build the interleaved sequence for the autoregressive decoder
        # as well as the mask for loss computation
        # Get special embeddings (all are single embeddings)
        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))    # (1, 512)  
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))        # (1, 512)

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []  # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
        bce_masks = []  # BCE mask: True for z tokens and time_speech_end_emb
        sequence_lengths = []  # Track actual sequence lengths before padding
        all_z_tokens = []  # Collect all valid z_tokens for separation loss
        max_total_length = 0

        for b in range(batch_size):
            # Start with text_speech_latent embedding
            sequence_parts = [text_speech_latent_emb]  # List to collect sequence parts
            loss_mask_parts = [False]  # Don't compute loss on special tokens
            bce_label_parts = [0]  # BCE labels (dummy for text_speech_latent_emb)
            bce_mask_parts = [False]  # BCE mask (False for text_speech_latent_emb)
            
            # Get valid text length for this batch item
            if text_attention_mask is not None:
                valid_text_len = text_attention_mask[b].sum().item()
            else:
                valid_text_len = text_embeddings.shape[1]
            
            # Track current position in target_representation
            speech_position = 0
            
            # For each text token
            for i in range(valid_text_len):
                # Add t_i (text embedding)
                t_i = text_embeddings[b, i:i+1]  # (1, 512)
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for t_i
                bce_mask_parts.append(False)  # No BCE loss for t_i
                
                # Add s_i (cross attention output)
                s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for s_i
                bce_mask_parts.append(False)  # No BCE loss for s_i
                
                # Add time_speech_start
                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for time_speech_start
                bce_mask_parts.append(False)  # No BCE loss for time_speech_start
                
                # Add z tokens for this chunk
                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:  # Only add if chunk size is positive
                    end_position = speech_position + chunk_size
                    # Make sure we don't exceed target_representation length
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position
                    
                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]  # (actual_chunk_size, 512)
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)  # Compute loss on z tokens
                        bce_label_parts.extend([0] * actual_chunk_size)  # Label 0 for z tokens
                        bce_mask_parts.extend([True] * actual_chunk_size)  # Compute BCE loss on z tokens
                        
                        # Collect z_tokens for separation loss computation
                        all_z_tokens.append(z_tokens)
                    
                    speech_position = end_position
                
                # Add time_speech_end
                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)
            
            # Concatenate all parts for this batch item
            full_sequence = torch.cat(sequence_parts, dim=0)  # (total_length, 512)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)
            
            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])  # Track actual length before padding
            max_total_length = max(max_total_length, full_sequence.shape[0])

        # Pad sequences
        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)
                
                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)
                
                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)
                
                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask
            
            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        # Stack into batch tensors
        interleaved_batch = torch.stack(padded_sequences, dim=0)  # (batch_size, max_total_length, 512)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)   # (batch_size, max_total_length)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)  # (batch_size, max_total_length)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)     # (batch_size, max_total_length)

        # Autoregressive prediction
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]  # (batch_size, max_total_length-1, 512)
            ar_targets = interleaved_batch[:, 1:, :]  # (batch_size, max_total_length-1, 512) 
            ar_loss_mask = loss_mask_batch[:, 1:]    # (batch_size, max_total_length-1) - shift mask left
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]  # (batch_size, max_total_length-1) - shift labels left
            ar_bce_mask = bce_mask_batch[:, 1:]      # (batch_size, max_total_length-1) - shift mask left
            
            # Create attention mask for autoregressive transformer
            # We need to mask padded positions while maintaining the causal property
            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True
            
            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,  # This will be combined with causal mask inside transformer
            )
            ar_predictions = ar_outputs.last_hidden_state  # (batch_size, max_total_length-1, 512)
            
            # Compute BCE predictions for end token classification
            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)  # (batch_size, max_total_length-1)
            
            # Compute L2 loss only where ar_loss_mask is True (z tokens)
            if ar_loss_mask.any():
                # Extract valid positions for loss computation
                valid_predictions = ar_predictions[ar_loss_mask]  # (num_valid_positions, 512)
                valid_targets = ar_targets[ar_loss_mask]          # (num_valid_positions, 512)
                
                # Compute L2 loss (MSE)
                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions, 
                    valid_targets, 
                    reduction='mean'
                )
            else:
                # Fallback if no valid positions (shouldn't happen in practice)
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            
            # Compute BCE loss for end token classification (v1.1)
            if ar_bce_mask.any():
                # Extract valid positions for BCE loss computation
                valid_bce_logits = bce_logits[ar_bce_mask]  # (num_valid_bce_positions,)
                valid_bce_labels = ar_bce_labels[ar_bce_mask]  # (num_valid_bce_positions,)
                
                # Compute BCE loss
                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits,
                    valid_bce_labels,
                    reduction='mean'
                )
            else:
                # Fallback if no valid BCE positions
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            
            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            # Degenerate case: the interleaved sequence is too short for next-step prediction
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + self.alpha * bce_end_token_loss

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }


__all__ = ["TextSyncMimi"]