"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi."""
import torch
import torch.nn as nn
from typing import Optional, Dict, List, Union
try:
from .configuration_mimi import MimiConfig
from .configuration_text_sync_mimi import TextSyncMimiConfig
from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel
from .modeling_backbone_components import (
CrossAttentionTransformer,
CausalAttentionTransformer
)
except ImportError:
from configuration_mimi import MimiConfig
from configuration_text_sync_mimi import TextSyncMimiConfig
from modeling_mimi_clean import MimiPreTrainedModel, MimiModel
from modeling_backbone_components import (
CrossAttentionTransformer,
CausalAttentionTransformer
)
class TextSyncMimi(MimiPreTrainedModel):
"""
TextSyncMimi: Text-Synchronous Neural Audio Codec Model
A neural audio codec model that combines text and speech representations for
high-quality text-to-speech synthesis. Features:
- Learnable text embeddings
- Cross-attention transformer for text-speech alignment
- Autoregressive transformer for causal speech generation
- BCE-based end token prediction for dynamic duration control
Architecture:
- Text Embedding Layer: Maps token IDs to 4,096-dim embeddings
- Mimi Encoder: Pre-trained audio encoder (frozen)
- Text Projection: Linear projection from 4,096 to 512 dimensions
- Cross-Attention Transformer: Aligns text with speech features
- Autoregressive Transformer: Generates speech representations
- End Token Classifier: Predicts when to stop generating
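Example (illustrative sketch, not a guaranteed API; the config fields and model id
shown are assumptions based on the defaults used in __init__):
    config = TextSyncMimiConfig(mimi_model_id="kyutai/mimi")
    model = TextSyncMimi(config)
    text_ids = torch.randint(0, 128256, (1, 8))
    audio = torch.randn(1, 1, 24000)   # 1 s of reference audio at 24 kHz
    z_tokens = model.generate_autoregressive(text_ids, input_values=audio)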
"""
config_class = TextSyncMimiConfig
def __init__(
self,
config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None,
model_id: Optional[str] = None,
token: Optional[str] = None,
alpha: Optional[float] = None,
cross_attention_layers: Optional[int] = None,
causal_attention_layers: Optional[int] = None,
bce_threshold: Optional[float] = None,
vocab_size: Optional[int] = None,
):
"""
Initialize TextSyncMimi model.
Args:
config: Model configuration (TextSyncMimiConfig or MimiConfig)
model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id
token: Hugging Face authentication token
alpha: Weight for BCE end token loss. If None, uses config.alpha
cross_attention_layers: Number of cross-attention layers. If None, uses config
causal_attention_layers: Number of autoregressive layers. If None, uses config
bce_threshold: Threshold below which the BCE end-token loss is not penalized (the loss is clamped in forward). If None, uses config.bce_threshold
vocab_size: Text vocabulary size. If None, uses config.vocab_size
"""
# Handle config initialization for both manual instantiation and from_pretrained
if config is None:
if model_id is None:
raise ValueError("Either config or model_id must be provided")
config = MimiConfig.from_pretrained(model_id, token=token)
super().__init__(config)
# Extract parameters from config if not explicitly provided
if hasattr(config, 'mimi_model_id'):
model_id = model_id or config.mimi_model_id
if model_id is None:
raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")
alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
cross_attention_layers = cross_attention_layers if cross_attention_layers is not None else getattr(config, 'cross_attention_layers', 2)
causal_attention_layers = causal_attention_layers if causal_attention_layers is not None else getattr(config, 'causal_attention_layers', 2)
bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)
# load the mimi backbone
self.config = config
model = MimiModel.from_pretrained(model_id, token=token)
# hyperparameters for auxiliary loss
self.alpha = alpha
self.bce_threshold = bce_threshold
# Learnable text token embedding
self.text_token_embedding = nn.Embedding(vocab_size, 4096)
# Text projection
self.text_proj = nn.Linear(4096, 512)
# Cross-attention transformer
cross_attention_config = MimiConfig(**self.config.__dict__)
cross_attention_config.num_hidden_layers = cross_attention_layers
cross_attention_config.hidden_size = 512
self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)
# decoder part (v1)
# Auto-regressive decoder:
# <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
# masking: no reconstruction loss is computed for <|text_speech_latent|>, [t_i], [s_i], or <|time_speech_start|>
# t_i already mapped from 4096 (e.g., llama embedding) -> 512
# s_i already 512
# z is mimi's decoder-input which is also 512
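# Illustration (matches the sequence built in forward): with 2 text tokens and
# alignment chunk sizes (3, 2), the autoregressive input is
#   <|text_speech_latent|> t_1 s_1 <|time_speech_start|> z_(1,1) z_(1,2) z_(1,3) <|time_speech_end|>
#                          t_2 s_2 <|time_speech_start|> z_(2,1) z_(2,2)         <|time_speech_end|>
# and the reconstruction loss is computed only on the z positions.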
causal_attention_config = MimiConfig(**self.config.__dict__)
causal_attention_config.num_hidden_layers = causal_attention_layers
causal_attention_config.hidden_size = 512
self.ar_transformer = CausalAttentionTransformer(causal_attention_config)
# embedding for special positions in the autoregressive decoder
self.text_speech_latent_embed = nn.Embedding(1, 512)
self.time_speech_start_embed = nn.Embedding(1, 512)
self.time_speech_end_embed = nn.Embedding(1, 512)
# Binary classification head for end token prediction
self.end_token_classifier = nn.Linear(512, 1)
self.post_init()
# Mimi components reused from the pretrained model (intended to remain frozen during training)
self.encoder = model.encoder
self.encoder_transformer = model.encoder_transformer
self.quantizer = model.quantizer
self.downsample = model.downsample
self.upsample = model.upsample
# print the number of parameters for each sub network in Millions
self._print_subnetwork_parameter_counts()
def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
"""
Initialize text embeddings from a weight matrix.
Args:
embedding_weight: Weight matrix of shape (vocab_size, 4096)
"""
if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
raise ValueError("Provided vocab_size does not match model's text_token_embedding")
with torch.no_grad():
self.text_token_embedding.weight.copy_(embedding_weight)
for p in self.text_token_embedding.parameters():
p.requires_grad = True
def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
"""
Initialize text embeddings from a LLaMA embedding module.
Args:
llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
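Example (sketch; the checkpoint id is a placeholder and must expose 4096-dim
input embeddings with a vocabulary size matching this model):
    from transformers import AutoModelForCausalLM
    llama = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    model.initialize_text_embeddings_from_llama(llama.get_input_embeddings())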
"""
if not hasattr(llama_embeddings_module, 'weight'):
raise ValueError("llama_embeddings_module must have a 'weight' attribute")
weight = llama_embeddings_module.weight.data
self.initialize_text_embeddings_from_weights(weight)
def _print_subnetwork_parameter_counts(self) -> None:
"""Print parameter counts for model subnetworks."""
print("=" * 70)
print("TextSyncMimi Parameter Counts")
print("=" * 70)
print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
print("=" * 70)
def encode_audio_to_representation(
self,
input_values: torch.Tensor,
audio_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Encode audio to speech representation.
Args:
input_values: Audio waveform (B, 1, audio_len)
audio_attention_mask: Attention mask (B, audio_len)
Returns:
Speech embeddings of shape (B, 512, 12.5 * T), where T is the audio duration in seconds (Mimi produces 12.5 frames per second)
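Example (shape sketch): 2 s of 24 kHz audio, i.e. input_values of shape (B, 1, 48000),
yields roughly int(2 * 12.5) = 25 frames, so embeddings of shape (B, 512, 25).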
"""
batch_size = input_values.shape[0]
device = input_values.device
# Encode through Mimi encoder pipeline
embeddings = self.encoder(input_values)
encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
embeddings = encoder_outputs[0].transpose(1, 2)
embeddings = self.downsample(embeddings)
# Apply attention mask if provided
if audio_attention_mask is not None:
speech_seq_len = embeddings.shape[-1]
speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)
for b in range(batch_size):
actual_audio_len = audio_attention_mask[b].sum().item()
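# 24 kHz audio samples map to 12.5 frames/s after the Mimi encoder and downsampling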
actual_speech_len = int(actual_audio_len * 12.5 / 24000)
actual_speech_len = min(actual_speech_len, speech_seq_len)
if actual_speech_len > 0:
speech_attention_mask[b, :actual_speech_len] = True
speech_mask_expanded = speech_attention_mask.unsqueeze(1)
embeddings = embeddings * speech_mask_expanded.float()
return embeddings
def generate_autoregressive(
self,
text_token_ids: torch.LongTensor,
input_values: Optional[torch.Tensor] = None,
speech_embeddings: Optional[torch.Tensor] = None,
audio_attention_mask: Optional[torch.Tensor] = None,
speech_attention_mask: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
max_z_tokens: int = 50,
end_token_threshold: float = 0.5,
device: Optional[torch.device] = None,
) -> List[List[torch.Tensor]]:
"""
Generate audio autoregressively.
Args:
text_token_ids: Text token IDs (B, L)
input_values: Audio input (B, 1, 24000 * T) - for normal mode
speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
text_attention_mask: Text mask (B, text_seq_len)
max_z_tokens: Maximum z tokens per text position
end_token_threshold: Probability threshold for stopping
device: Device for computation
Returns:
List of z_tokens lists (one per batch item)
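Example (illustrative; assumes an initialized model and random placeholder inputs):
    text_ids = torch.randint(0, 128256, (1, 8))
    audio = torch.randn(1, 1, 24000)   # 1 s of 24 kHz reference audio
    z_lists = model.generate_autoregressive(text_ids, input_values=audio)
    # z_lists has one entry per batch item; each entry is a list of (512,) tensors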
"""
if device is None:
device = text_token_ids.device
self.eval()
with torch.no_grad():
# Get speech embeddings for cross-attention context
if speech_embeddings is not None:
# Use pre-computed speech embeddings (cached mode)
# speech_embeddings should already be (B, T, 512)
pass # speech_embeddings is already provided
else:
# Compute speech embeddings from input_values (normal mode)
if input_values is None:
raise ValueError("Either input_values or speech_embeddings must be provided")
speech_embeddings = self.encode_audio_to_representation(
input_values,
audio_attention_mask=audio_attention_mask
)
speech_embeddings = speech_embeddings.transpose(1, 2) # (B, T, 512)
# Embed token ids then project to 512
text_embeddings_4096 = self.text_token_embedding(text_token_ids) # (B, L, 4096)
text_embeddings_proj = self.text_proj(text_embeddings_4096) # (B, L, 512)
# Apply cross attention (same as in forward)
# Create attention masks
formatted_text_attention_mask = None
formatted_speech_attention_mask = None
batch_size, text_seq_len = text_embeddings_proj.shape[:2]
if text_attention_mask is not None:
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
combined_mask = causal_mask * padding_mask
formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
else:
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_embeddings_proj.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
# Handle speech attention mask (use speech_attention_mask if available, otherwise audio_attention_mask)
if speech_attention_mask is not None:
# For cached data, speech_attention_mask is already in the right format
speech_seq_len = speech_embeddings.shape[1]
speech_mask = speech_attention_mask.bool()
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
elif audio_attention_mask is not None:
# For non-cached data, convert audio_attention_mask to speech_attention_mask
speech_seq_len = speech_embeddings.shape[1]
speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
for b in range(batch_size):
audio_len = audio_attention_mask[b].sum().item()
speech_len = int(audio_len * 12.5 / 24000)
speech_len = min(speech_len, speech_seq_len)
speech_mask[b, :speech_len] = True
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
else:
formatted_speech_attention_mask = None
# Cross attention
cross_attention_outputs = self.cross_attention_transformer(
hidden_states=text_embeddings_proj,
encoder_hidden_states=speech_embeddings,
attention_mask=formatted_text_attention_mask,
encoder_attention_mask=formatted_speech_attention_mask,
alignment_chunk_sizes=None, # V1 learns alignment
)
cross_attention_outputs = cross_attention_outputs.last_hidden_state
# Get special embeddings
text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))
generated_z_tokens = []
# Generate for each batch item
for b in range(batch_size):
# Get valid text length for this sample
if text_attention_mask is not None:
valid_text_len = text_attention_mask[b].sum().item()
else:
valid_text_len = text_embeddings_proj.shape[1]
# Start sequence with text_speech_latent for context
sequence = [text_speech_latent_emb] # (1, 512)
batch_z_tokens = [] # Store z_tokens for this batch item
# Generate for each text position
for i in range(valid_text_len):
# Add t_i and s_i
t_i = text_embeddings_proj[b, i:i+1] # (1, 512)
s_i = cross_attention_outputs[b, i:i+1] # (1, 512)
sequence.extend([t_i, s_i])
# Add time_speech_start
sequence.append(time_speech_start_emb)
# Generate z tokens autoregressively for this text position
z_count = 0
while z_count < max_z_tokens:
# Prepare current sequence for AR transformer
current_sequence = torch.cat(sequence, dim=0).unsqueeze(0) # (1, seq_len, 512)
# Create attention mask for current sequence
seq_len = current_sequence.shape[1]
ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)
# Get prediction from AR transformer
ar_outputs = self.ar_transformer(
hidden_states=current_sequence,
attention_mask=ar_attention_mask,
)
# Get the last prediction
last_prediction = ar_outputs.last_hidden_state[0, -1:, :] # (1, 512)
# Check stopping condition using BCE classifier (v1.1)
end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1) # (1,)
end_token_prob = torch.sigmoid(end_token_logit).item() # Convert to probability
# Stop if probability is high enough (>= threshold means stop)
if end_token_prob >= end_token_threshold:
# Stop generating z tokens
break
else:
# Add this prediction as next z token to both sequence (for context) and z_tokens (for output)
sequence.append(last_prediction)
batch_z_tokens.append(last_prediction.squeeze(0)) # Remove batch dimension for output
z_count += 1
# Add time_speech_end to sequence for context
sequence.append(time_speech_end_emb)
# Store z_tokens for this batch item
generated_z_tokens.append(batch_z_tokens)
return generated_z_tokens
def forward(
self,
text_token_ids: torch.LongTensor,
input_values: Optional[torch.Tensor] = None,
speech_embeddings: Optional[torch.Tensor] = None,
alignment_chunk_sizes: Optional[torch.Tensor] = None,
audio_attention_mask: Optional[torch.Tensor] = None,
speech_attention_mask: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Forward pass for training.
Args:
text_token_ids: Text token IDs (B, L)
input_values: Audio input (B, 1, 24000 * T) - for normal mode
speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
alignment_chunk_sizes: Number of speech frames aligned to each text token (B, L); required for the training loss
audio_attention_mask: Audio mask (B, audio_seq_len)
speech_attention_mask: Speech mask (B, speech_seq_len)
text_attention_mask: Text mask (B, text_seq_len)
Returns:
Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
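Example (training-step sketch with random placeholder tensors; assumes an initialized model):
    text_ids = torch.randint(0, 128256, (2, 6))
    audio = torch.randn(2, 1, 48000)                    # 2 s at 24 kHz -> ~25 frames
    chunks = torch.full((2, 6), 4, dtype=torch.long)    # 4 speech frames per text token
    out = model(text_ids, input_values=audio, alignment_chunk_sizes=chunks)
    out['loss'].backward()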
"""
# Get speech embeddings
if speech_embeddings is not None:
pass
elif input_values is not None:
# Normal mode: compute speech embeddings from input_values
speech_embeddings_raw = self.encode_audio_to_representation(
input_values,
audio_attention_mask
)
# speech_embeddings_raw.shape = (B, 512, 12.5*T)
# Transpose: [B, 512, 12.5*T] -> [B, 12.5*T, 512]
speech_embeddings = speech_embeddings_raw.transpose(1, 2)
else:
raise ValueError("Either input_values or speech_embeddings must be provided")
# Embed token ids and project to 512-dim
text_embeddings_4096 = self.text_token_embedding(text_token_ids) # (B, L, 4096)
text_embeddings = self.text_proj(text_embeddings_4096) # (B, L, 512)
# Create proper attention masks for cross-attention
formatted_text_attention_mask = None
formatted_speech_attention_mask = None
# Handle text attention mask (causal mask for decoder)
batch_size, text_seq_len = text_embeddings.shape[:2]
if text_attention_mask is not None:
# Create causal mask and apply padding mask
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
# Apply padding mask to causal mask
padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
combined_mask = causal_mask * padding_mask
# Convert to attention scores (-inf for masked positions)
formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))
else:
# Create causal mask for all positions (no padding mask)
causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=text_embeddings.device, dtype=text_embeddings.dtype))
causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
formatted_text_attention_mask = torch.where(causal_mask.bool(), 0.0, float('-inf'))
# Handle speech attention mask (encoder mask)
# Use speech_attention_mask if available (cached mode), otherwise audio_attention_mask (normal mode)
if speech_attention_mask is not None:
# Cached mode: speech_attention_mask is already in the right format
speech_seq_len = speech_embeddings.shape[1]
speech_mask = speech_attention_mask.bool()
# Convert to attention format: [batch_size, 1, 1, speech_seq_len]
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
elif audio_attention_mask is not None:
# Normal mode: convert audio mask to speech embedding mask
speech_seq_len = speech_embeddings.shape[1]
# Create speech attention mask based on actual lengths
speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=speech_embeddings.device)
for b in range(batch_size):
audio_len = audio_attention_mask[b].sum().item()
speech_len = int(audio_len * 12.5 / 24000)
speech_len = min(speech_len, speech_seq_len)
speech_mask[b, :speech_len] = True
# Convert to attention format: [batch_size, 1, 1, speech_seq_len]
formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
else:
# No masking
formatted_speech_attention_mask = None
# Cross attention: text attends to speech (no alignment constraints in V1)
# hidden_states (decoder) = text, encoder_hidden_states = speech
cross_attention_outputs = self.cross_attention_transformer(
hidden_states=text_embeddings,
encoder_hidden_states=speech_embeddings,
attention_mask=formatted_text_attention_mask, # Causal mask for text (decoder)
encoder_attention_mask=formatted_speech_attention_mask, # Mask for speech (encoder)
alignment_chunk_sizes=None, # v1 doesn't use alignment_chunk_sizes -- the model should learn the alignment itself
)
cross_attention_outputs = cross_attention_outputs.last_hidden_state
# Auto-regressive decoder part
# Following v0.5 where the target is the dequantized Mimi decoder-input
# Compute the target representation = Mimi decoder input (encoder embeddings quantized
# and then dequantized) at 12.5 frames/s, so the frame axis has length T = 12.5 * seconds
with torch.no_grad():
embeddings_bct = speech_embeddings.transpose(1, 2) # (B, 512, T)
codes_kbt = self.quantizer.encode(embeddings_bct) # [K, B, T]
codes_bkt = codes_kbt.transpose(0, 1) # [B, K, T]
decoder_input_emb = self.quantizer.decode(codes_bkt) # (B, 512, T)
target_representation = decoder_input_emb.transpose(1, 2) # (B, T, 512)
# Build the interleaved sequence for the autoregressive decoder
# as well as the mask for loss computation
# Get special embeddings (all are single embeddings)
device = text_embeddings.device
text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device)) # (1, 512)
batch_size = text_embeddings.shape[0]
interleaved_sequences = []
loss_masks = []
bce_labels_batch = [] # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
bce_masks = [] # BCE mask: True for z tokens and time_speech_end_emb
sequence_lengths = [] # Track actual sequence lengths before padding
all_z_tokens = [] # Collected z tokens (currently unused; reserved for a separation loss)
max_total_length = 0
for b in range(batch_size):
# Start with text_speech_latent embedding
sequence_parts = [text_speech_latent_emb] # List to collect sequence parts
loss_mask_parts = [False] # Don't compute loss on special tokens
bce_label_parts = [0] # BCE labels (dummy for text_speech_latent_emb)
bce_mask_parts = [False] # BCE mask (False for text_speech_latent_emb)
# Get valid text length for this batch item
if text_attention_mask is not None:
valid_text_len = text_attention_mask[b].sum().item()
else:
valid_text_len = text_embeddings.shape[1]
# Track current position in target_representation
speech_position = 0
# For each text token
for i in range(valid_text_len):
# Add t_i (text embedding)
t_i = text_embeddings[b, i:i+1] # (1, 512)
sequence_parts.append(t_i)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for t_i
bce_mask_parts.append(False) # No BCE loss for t_i
# Add s_i (cross attention output)
s_i = cross_attention_outputs[b, i:i+1] # (1, 512)
sequence_parts.append(s_i)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for s_i
bce_mask_parts.append(False) # No BCE loss for s_i
# Add time_speech_start
sequence_parts.append(time_speech_start_emb)
loss_mask_parts.append(False)
bce_label_parts.append(0) # Dummy label for time_speech_start
bce_mask_parts.append(False) # No BCE loss for time_speech_start
# Add z tokens for this chunk
chunk_size = alignment_chunk_sizes[b, i].item()
if chunk_size > 0: # Only add if chunk size is positive
end_position = speech_position + chunk_size
# Make sure we don't exceed target_representation length
end_position = min(end_position, target_representation.shape[1])
actual_chunk_size = end_position - speech_position
if actual_chunk_size > 0:
z_tokens = target_representation[b, speech_position:end_position] # (actual_chunk_size, 512)
sequence_parts.append(z_tokens)
loss_mask_parts.extend([True] * actual_chunk_size) # Compute loss on z tokens
bce_label_parts.extend([0] * actual_chunk_size) # Label 0 for z tokens
bce_mask_parts.extend([True] * actual_chunk_size) # Compute BCE loss on z tokens
# Collect z_tokens for separation loss computation
all_z_tokens.append(z_tokens)
speech_position = end_position
# Add time_speech_end
sequence_parts.append(time_speech_end_emb)
loss_mask_parts.append(False)
bce_label_parts.append(1)
bce_mask_parts.append(True)
# Concatenate all parts for this batch item
full_sequence = torch.cat(sequence_parts, dim=0) # (total_length, 512)
loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)
interleaved_sequences.append(full_sequence)
loss_masks.append(loss_mask)
bce_labels_batch.append(bce_labels)
bce_masks.append(bce_mask)
sequence_lengths.append(full_sequence.shape[0]) # Track actual length before padding
max_total_length = max(max_total_length, full_sequence.shape[0])
# Pad sequences
padded_sequences = []
padded_loss_masks = []
padded_bce_labels = []
padded_bce_masks = []
for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks, bce_labels_batch, bce_masks):
current_length = sequence.shape[0]
if current_length < max_total_length:
padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
padded_sequence = torch.cat([sequence, padding], dim=0)
mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
padded_mask = torch.cat([loss_mask, mask_padding], dim=0)
bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)
bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
else:
padded_sequence = sequence
padded_mask = loss_mask
padded_bce_label = bce_labels
padded_bce_mask = bce_mask
padded_sequences.append(padded_sequence)
padded_loss_masks.append(padded_mask)
padded_bce_labels.append(padded_bce_label)
padded_bce_masks.append(padded_bce_mask)
# Stack into batch tensors
interleaved_batch = torch.stack(padded_sequences, dim=0) # (batch_size, max_total_length, 512)
loss_mask_batch = torch.stack(padded_loss_masks, dim=0) # (batch_size, max_total_length)
bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0) # (batch_size, max_total_length)
bce_mask_batch = torch.stack(padded_bce_masks, dim=0) # (batch_size, max_total_length)
# Autoregressive prediction
if max_total_length > 1:
ar_input = interleaved_batch[:, :-1, :] # (batch_size, max_total_length-1, 512)
ar_targets = interleaved_batch[:, 1:, :] # (batch_size, max_total_length-1, 512)
ar_loss_mask = loss_mask_batch[:, 1:] # (batch_size, max_total_length-1) - shift mask left
ar_bce_labels = bce_labels_batch_tensor[:, 1:] # (batch_size, max_total_length-1) - shift labels left
ar_bce_mask = bce_mask_batch[:, 1:] # (batch_size, max_total_length-1) - shift mask left
# Create attention mask for autoregressive transformer
# We need to mask padded positions while maintaining causal property
ar_seq_len = ar_input.shape[1]
ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
for b in range(batch_size):
valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
if valid_len > 0:
ar_attention_mask[b, :valid_len] = True
ar_outputs = self.ar_transformer(
hidden_states=ar_input,
attention_mask=ar_attention_mask, # This will be combined with causal mask inside transformer
)
ar_predictions = ar_outputs.last_hidden_state # (batch_size, max_total_length-1, 512)
# Compute BCE predictions for end token classification
bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1) # (batch_size, max_total_length-1)
# Compute L2 loss only where ar_loss_mask is True (z tokens)
if ar_loss_mask.any():
# Extract valid positions for loss computation
valid_predictions = ar_predictions[ar_loss_mask] # (num_valid_positions, 512)
valid_targets = ar_targets[ar_loss_mask] # (num_valid_positions, 512)
# Compute L2 loss (MSE)
reconstruction_loss = nn.functional.mse_loss(
valid_predictions,
valid_targets,
reduction='mean'
)
else:
# Fallback if no valid positions (shouldn't happen in practice)
reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
# Compute BCE loss for end token classification (v1.1)
if ar_bce_mask.any():
# Extract valid positions for BCE loss computation
valid_bce_logits = bce_logits[ar_bce_mask] # (num_valid_bce_positions,)
valid_bce_labels = ar_bce_labels[ar_bce_mask] # (num_valid_bce_positions,)
# Compute BCE loss
bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
valid_bce_logits,
valid_bce_labels,
reduction='mean'
)
else:
# Fallback if no valid BCE positions
bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
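# Hinge the BCE term at bce_threshold: only the part of the loss above the threshold
# is penalized, i.e. total = reconstruction + alpha * max(0, bce - threshold)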
if self.bce_threshold > 0.0:
clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
else:
total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
else:
reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
total_loss = reconstruction_loss + torch.tensor(0.0, device=device, requires_grad=True)
return {
'loss': total_loss,
'reconstruction_loss': reconstruction_loss,
'bce_end_token_loss': bce_end_token_loss,
}
__all__ = ["TextSyncMimi"]