Upload modeling_text_sync_mimi.py with huggingface_hub
modeling_text_sync_mimi.py  +713  -0
modeling_text_sync_mimi.py
ADDED
@@ -0,0 +1,713 @@
"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""

import torch
import torch.nn as nn
from typing import Optional, Dict, List, Union

from configuration_mimi import MimiConfig
from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
from modeling_backbone_components import (
    CrossAttentionTransformer,
    CausalAttentionTransformer,
)

class TextSyncMimi(MimiPreTrainedModel):
    """
    TextSyncMimi: Text-Synchronous Neural Audio Codec Model

    A neural audio codec model that combines text and speech representations for
    high-quality text-to-speech synthesis. Features:

    - Learnable text embeddings
    - Cross-attention transformer for text-speech alignment
    - Autoregressive transformer for causal speech generation
    - BCE-based end token prediction for dynamic duration control

    Architecture:
    - Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
    - Mimi Encoder: Pre-trained audio encoder (frozen)
    - Text Projection: Linear projection from 4,096 to 512 dimensions
    - Cross-Attention Transformer: Aligns text with speech features
    - Autoregressive Transformer: Generates speech representations
    - End Token Classifier: Predicts when to stop generating
    """

    def __init__(
        self,
        config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
        model_id: Optional[str] = None,
        token: Optional[str] = None,
        alpha: Optional[float] = None,
        cross_attention_layers: Optional[int] = None,
        causal_attention_layers: Optional[int] = None,
        bce_threshold: Optional[float] = None,
        vocab_size: Optional[int] = None,
    ):
        """
        Initialize the TextSyncMimi model.

        Args:
            config: Model configuration (TextSyncMimiConfig or MimiConfig)
            model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
            token: Hugging Face authentication token
            alpha: Weight for the BCE end-token loss. If None, uses config.alpha
            cross_attention_layers: Number of cross-attention layers. If None, uses config
            causal_attention_layers: Number of autoregressive layers. If None, uses config
            bce_threshold: BCE loss threshold. If None, uses config.bce_threshold
            vocab_size: Text vocabulary size. If None, uses config.vocab_size
        """
        # Handle config initialization for both manual instantiation and from_pretrained
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)

        super().__init__(config)

        # Extract parameters from config if not explicitly provided
        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")

        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
        causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        # Load the Mimi backbone
        self.config = config
        model = MimiModel.from_pretrained(model_id, token=token)

        # Hyperparameters for the auxiliary loss
        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Learnable text token embedding
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)

        # Text projection
        self.text_proj = nn.Linear(4096, 512)

        # Cross-attention transformer
        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        # Decoder part (v1)
        # Auto-regressive decoder sequence layout:
        #   <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
        # Masking: no loss is computed for <|text_speech_latent|>, [t_i], [s_i], or <|time_speech_start|>
        # t_i is already mapped from 4096 (e.g., a LLaMA embedding) -> 512
        # s_i is already 512
        # z is Mimi's decoder input, which is also 512
        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        # Embeddings for special positions in the autoregressive decoder
        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        # Binary classification head for end token prediction
        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

        # Frozen Mimi components
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample

        # Print the number of parameters for each subnetwork, in millions
        self._print_subnetwork_parameter_counts()

    def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
        """
        Initialize text embeddings from a weight matrix.

        Args:
            embedding_weight: Weight matrix of shape (vocab_size, 4096)
        """
        if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
            raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
        if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
            raise ValueError("Provided vocab_size does not match model's text_token_embedding")
        with torch.no_grad():
            self.text_token_embedding.weight.copy_(embedding_weight)
        for p in self.text_token_embedding.parameters():
            p.requires_grad = True

    def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
        """
        Initialize text embeddings from a LLaMA embedding module.

        Args:
            llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
        """
        if not hasattr(llama_embeddings_module, 'weight'):
            raise ValueError("llama_embeddings_module must have a 'weight' attribute")
        weight = llama_embeddings_module.weight.data
        self.initialize_text_embeddings_from_weights(weight)

    def _print_subnetwork_parameter_counts(self) -> None:
        """Print parameter counts for model subnetworks."""
        print("=" * 70)
        print("TextSyncMimi Parameter Counts")
        print("=" * 70)
        print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
        print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
        print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
        print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
        print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
        print("=" * 70)

    def encode_audio_to_representation(
        self,
        input_values: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode audio to speech representation.

        Args:
            input_values: Audio waveform (B, 1, audio_len)
            audio_attention_mask: Attention mask (B, audio_len)

        Returns:
            Speech embeddings (B, 512, 12.5 * T)
        """
        batch_size = input_values.shape[0]
        device = input_values.device

        # Encode through the Mimi encoder pipeline
        embeddings = self.encoder(input_values)
        encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
        embeddings = encoder_outputs[0].transpose(1, 2)
        embeddings = self.downsample(embeddings)

        # Apply attention mask if provided
        if audio_attention_mask is not None:
            speech_seq_len = embeddings.shape[-1]
            speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)

            for b in range(batch_size):
                actual_audio_len = audio_attention_mask[b].sum().item()
                actual_speech_len = int(actual_audio_len * 12.5 / 24000)
                actual_speech_len = min(actual_speech_len, speech_seq_len)
                if actual_speech_len > 0:
                    speech_attention_mask[b, :actual_speech_len] = True

            speech_mask_expanded = speech_attention_mask.unsqueeze(1)
            embeddings = embeddings * speech_mask_expanded.float()

        return embeddings

    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio autoregressively.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation

        Returns:
            List of z_tokens lists (one per batch item)
        """
        if device is None:
            device = text_token_ids.device

        self.eval()

        with torch.no_grad():
            # Get speech embeddings for cross-attention context
            if speech_embeddings is not None:
                # Use pre-computed speech embeddings (cached mode)
                # speech_embeddings should already be (B, T, 512)
                pass  # speech_embeddings is already provided
            else:
                # Compute speech embeddings from input_values (normal mode)
                if input_values is None:
                    raise ValueError("Either input_values or speech_embeddings must be provided")
                speech_embeddings = self.encode_audio_to_representation(
                    input_values,
                    audio_attention_mask=audio_attention_mask
                )
                speech_embeddings = speech_embeddings.transpose(1, 2)  # (B, T, 512)

            # Embed token ids then project to 512
            text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
            text_embeddings_proj = self.text_proj(text_embeddings_4096)  # (B, L, 512)

            # Apply cross attention (same as in forward)
            # Create attention masks
            formatted_text_attention_mask = None
            formatted_speech_attention_mask = None

            batch_size, text_seq_len = text_embeddings_proj.shape[:2]

            if text_attention_mask is not None:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
                combined_mask = causal_mask * padding_mask
                formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
            else:
                causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
                causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
                formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

            # Handle speech attention mask (use speech_attention_mask if available, otherwise audio_attention_mask)
            if speech_attention_mask is not None:
                # For cached data, speech_attention_mask is already in the right format
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = speech_attention_mask.bool()
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            elif audio_attention_mask is not None:
                # For non-cached data, convert audio_attention_mask to speech_attention_mask
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
                for b in range(batch_size):
                    audio_len = audio_attention_mask[b].sum().item()
                    speech_len = int(audio_len * 12.5 / 24000)
                    speech_len = min(speech_len, speech_seq_len)
                    speech_mask[b, :speech_len] = True
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            else:
                formatted_speech_attention_mask = None

            # Cross attention
            cross_attention_outputs = self.cross_attention_transformer(
                hidden_states=text_embeddings_proj,
                encoder_hidden_states=speech_embeddings,
                attention_mask=formatted_text_attention_mask,
                encoder_attention_mask=formatted_speech_attention_mask,
                alignment_chunk_sizes=None,  # V1 learns alignment
            )
            cross_attention_outputs = cross_attention_outputs.last_hidden_state

            # Get special embeddings
            text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

            generated_z_tokens = []

            # Generate for each batch item
            for b in range(batch_size):
                # Get valid text length for this sample
                if text_attention_mask is not None:
                    valid_text_len = text_attention_mask[b].sum().item()
                else:
                    valid_text_len = text_embeddings_proj.shape[1]

                # Start sequence with text_speech_latent for context
                sequence = [text_speech_latent_emb]  # (1, 512)
                batch_z_tokens = []  # Store z_tokens for this batch item

                # Generate for each text position
                for i in range(valid_text_len):
                    # Add t_i and s_i
                    t_i = text_embeddings_proj[b, i:i+1]  # (1, 512)
                    s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                    sequence.extend([t_i, s_i])

                    # Add time_speech_start
                    sequence.append(time_speech_start_emb)

                    # Generate z tokens autoregressively for this text position
                    z_count = 0
                    while z_count < max_z_tokens:
                        # Prepare current sequence for AR transformer
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)  # (1, seq_len, 512)

                        # Create attention mask for current sequence
                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)

                        # Get prediction from AR transformer
                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )

                        # Get the last prediction
                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]  # (1, 512)

                        # Check stopping condition using BCE classifier (v1.1)
                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)  # (1,)
                        end_token_prob = torch.sigmoid(end_token_logit).item()  # Convert to probability

                        # Stop if probability is high enough (>= threshold means stop)
                        if end_token_prob >= end_token_threshold:
                            # Stop generating z tokens
                            break
                        else:
                            # Add this prediction as next z token to both sequence (for context) and z_tokens (for output)
                            sequence.append(last_prediction)
                            batch_z_tokens.append(last_prediction.squeeze(0))  # Remove batch dimension for output
                            z_count += 1

                    # Add time_speech_end to sequence for context
                    sequence.append(time_speech_end_emb)

                # Store z_tokens for this batch item
                generated_z_tokens.append(batch_z_tokens)

        return generated_z_tokens

    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Alignment chunk sizes (B, L)
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)

        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
        """
        # Get speech embeddings
        if speech_embeddings is not None:
            pass
        elif input_values is not None:
            # Normal mode: compute speech embeddings from input_values
            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values,
                audio_attention_mask
            )
            # speech_embeddings_raw.shape = (B, 512, 12.5*T)
            # Transpose: [B, 512, 12.5*T] -> [B, 12.5*T, 512]
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)
        else:
            raise ValueError("Either input_values or speech_embeddings must be provided")

        # Embed token ids and project to 512-dim
        text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
        text_embeddings = self.text_proj(text_embeddings_4096)  # (B, L, 512)

        # Create proper attention masks for cross-attention
        formatted_text_attention_mask = None
        formatted_speech_attention_mask = None

        # Handle text attention mask (causal mask for decoder)
        batch_size, text_seq_len = text_embeddings.shape[:2]

        if text_attention_mask is not None:
            # Create causal mask and apply padding mask
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)

            # Apply padding mask to causal mask
            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask

            # Convert to attention scores (-inf for masked positions)
            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
        else:
            # Create causal mask for all positions (no padding mask)
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))

        # Handle speech attention mask (encoder mask)
        # Use speech_attention_mask if available (cached mode), otherwise audio_attention_mask (normal mode)
        if speech_attention_mask is not None:
            # Cached mode: speech_attention_mask is already in the right format
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()

            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            # Normal mode: convert audio mask to speech embedding mask
            speech_seq_len = speech_embeddings.shape[1]

            # Create speech attention mask based on actual lengths
            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)

            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True

            # Convert to attention format: [batch_size, 1, 1, speech_seq_len]
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            # No masking
            formatted_speech_attention_mask = None

        # Cross attention: text attends to speech (no alignment constraints in V1)
        # hidden_states (decoder) = text, encoder_hidden_states = speech
        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,  # Causal mask for text (decoder)
            encoder_attention_mask=formatted_speech_attention_mask,  # Mask for speech (encoder)
            alignment_chunk_sizes=None,  # v1 doesn't use alignment_chunk_sizes -- the model should learn the alignment itself
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

        # Auto-regressive decoder part
        # Following v0.5 where the target is the dequantized Mimi decoder-input
        # Compute target representation = Mimi decoder-input (quantized->dequantized at 12.5*seconds)
        # 12.5*seconds => T
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)  # (B, 512, T)
            codes_kbt = self.quantizer.encode(embeddings_bct)  # [K, B, T]
            codes_bkt = codes_kbt.transpose(0, 1)  # [B, K, T]
            decoder_input_emb = self.quantizer.decode(codes_bkt)  # (B, 512, T)
            target_representation = decoder_input_emb.transpose(1, 2)  # (B, T, 512)

        # Build the interleaved sequence for the autoregressive decoder,
        # as well as the mask for loss computation.
        # Get special embeddings (all are single embeddings)
        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []  # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
        bce_masks = []  # BCE mask: True for z tokens and time_speech_end_emb
        sequence_lengths = []  # Track actual sequence lengths before padding
        all_z_tokens = []  # Collect all valid z_tokens for separation loss
        max_total_length = 0

        for b in range(batch_size):
            # Start with text_speech_latent embedding
            sequence_parts = [text_speech_latent_emb]  # List to collect sequence parts
            loss_mask_parts = [False]  # Don't compute loss on special tokens
            bce_label_parts = [0]  # BCE labels (dummy for text_speech_latent_emb)
            bce_mask_parts = [False]  # BCE mask (False for text_speech_latent_emb)

            # Get valid text length for this batch item
            if text_attention_mask is not None:
                valid_text_len = text_attention_mask[b].sum().item()
            else:
                valid_text_len = text_embeddings.shape[1]

            # Track current position in target_representation
            speech_position = 0

            # For each text token
            for i in range(valid_text_len):
                # Add t_i (text embedding)
                t_i = text_embeddings[b, i:i+1]  # (1, 512)
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for t_i
                bce_mask_parts.append(False)  # No BCE loss for t_i

                # Add s_i (cross attention output)
                s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for s_i
                bce_mask_parts.append(False)  # No BCE loss for s_i

                # Add time_speech_start
                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # Dummy label for time_speech_start
                bce_mask_parts.append(False)  # No BCE loss for time_speech_start

                # Add z tokens for this chunk
                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:  # Only add if chunk size is positive
                    end_position = speech_position + chunk_size
                    # Make sure we don't exceed target_representation length
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position

                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]  # (actual_chunk_size, 512)
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)  # Compute loss on z tokens
                        bce_label_parts.extend([0] * actual_chunk_size)  # Label 0 for z tokens
                        bce_mask_parts.extend([True] * actual_chunk_size)  # Compute BCE loss on z tokens

                        # Collect z_tokens for separation loss computation
                        all_z_tokens.append(z_tokens)

                    speech_position = end_position

                # Add time_speech_end
                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)

            # Concatenate all parts for this batch item
            full_sequence = torch.cat(sequence_parts, dim=0)  # (total_length, 512)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)

            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])  # Track actual length before padding
            max_total_length = max(max_total_length, full_sequence.shape[0])

        # Pad sequences
        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)

                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)

                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)

                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask

            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        # Stack into batch tensors
        interleaved_batch = torch.stack(padded_sequences, dim=0)  # (batch_size, max_total_length, 512)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)  # (batch_size, max_total_length)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)  # (batch_size, max_total_length)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)  # (batch_size, max_total_length)

        # Autoregressive prediction
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]  # (batch_size, max_total_length-1, 512)
            ar_targets = interleaved_batch[:, 1:, :]  # (batch_size, max_total_length-1, 512)
            ar_loss_mask = loss_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shift mask left
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]  # (batch_size, max_total_length-1) - shift labels left
            ar_bce_mask = bce_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shift mask left

            # Create attention mask for the autoregressive transformer.
            # We need to mask padded positions while maintaining the causal property.
            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True

            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,  # This will be combined with the causal mask inside the transformer
            )
            ar_predictions = ar_outputs.last_hidden_state  # (batch_size, max_total_length-1, 512)

            # Compute BCE predictions for end token classification
            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)  # (batch_size, max_total_length-1)

            # Compute L2 loss only where ar_loss_mask is True (z tokens)
            if ar_loss_mask.any():
                # Extract valid positions for loss computation
                valid_predictions = ar_predictions[ar_loss_mask]  # (num_valid_positions, 512)
                valid_targets = ar_targets[ar_loss_mask]  # (num_valid_positions, 512)

                # Compute L2 loss (MSE)
                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions,
                    valid_targets,
                    reduction='mean'
                )
            else:
                # Fallback if no valid positions (shouldn't happen in practice)
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # Compute BCE loss for end token classification (v1.1)
            if ar_bce_mask.any():
                # Extract valid positions for BCE loss computation
                valid_bce_logits = bce_logits[ar_bce_mask]  # (num_valid_bce_positions,)
                valid_bce_labels = ar_bce_labels[ar_bce_mask]  # (num_valid_bce_positions,)

                # Compute BCE loss
                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits,
                    valid_bce_labels,
                    reduction='mean'
                )
            else:
                # Fallback if no valid BCE positions
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)

            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + torch.tensor(0.0, device=device, requires_grad=True)

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }


__all__ = ["TextSyncMimi"]
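
A minimal usage sketch of the class above, not part of the uploaded file. It assumes the companion modules (configuration_mimi, modeling_mimi_clean, modeling_backbone_components) are importable and that a Mimi checkpoint such as "kyutai/mimi" (the example id from the docstring) is accessible; the batch shapes, token IDs, and chunk sizes below are made up for illustration and follow the 24 kHz audio / 12.5 Hz frame convention used in the code.

# Illustrative sketch only; values are assumptions, not from the uploaded file.
import torch
from modeling_text_sync_mimi import TextSyncMimi

model = TextSyncMimi(model_id="kyutai/mimi")  # loads the frozen Mimi backbone

batch_size, text_len, seconds = 2, 8, 2
text_token_ids = torch.randint(0, 128256, (batch_size, text_len))
input_values = torch.randn(batch_size, 1, 24000 * seconds)  # 24 kHz waveform
# One speech chunk size (in 12.5 Hz frames) per text token; clamped internally to the frame count
alignment_chunk_sizes = torch.full((batch_size, text_len), 3, dtype=torch.long)

# Training-style forward pass: returns the combined loss and its components
out = model(
    text_token_ids=text_token_ids,
    input_values=input_values,
    alignment_chunk_sizes=alignment_chunk_sizes,
)
print(out["loss"], out["reconstruction_loss"], out["bce_end_token_loss"])

# Inference: generate z tokens per text position, stopping via the end-token classifier
z_tokens = model.generate_autoregressive(
    text_token_ids=text_token_ids,
    input_values=input_values,
    max_z_tokens=10,
    end_token_threshold=0.5,
)
print(len(z_tokens), len(z_tokens[0]))  # batch items, z tokens for the first item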