"""PyTorch TextSyncMimi model - Text-synchronous neural audio codec based on Mimi.""" import torch import torch.nn as nn from typing import Optional, Dict, List, Union try: from .configuration_mimi import MimiConfig from .configuration_text_sync_mimi import TextSyncMimiConfig from .modeling_mimi_clean import MimiPreTrainedModel, MimiModel from .modeling_backbone_components import ( CrossAttentionTransformer, CausalAttentionTransformer ) except ImportError: from configuration_mimi import MimiConfig from configuration_text_sync_mimi import TextSyncMimiConfig from modeling_mimi_clean import MimiPreTrainedModel, MimiModel from modeling_backbone_components import ( CrossAttentionTransformer, CausalAttentionTransformer ) class TextSyncMimi(MimiPreTrainedModel): """ TextSyncMimi: Text-Synchronous Neural Audio Codec Model A neural audio codec model that combines text and speech representations for high-quality text-to-speech synthesis. Features: - Learnable text embeddings - Cross-attention transformer for text-speech alignment - Autoregressive transformer for causal speech generation - BCE-based end token prediction for dynamic duration control Architecture: - Text Embedding Layer: Maps token IDs to 4,096-dim embeddings - Mimi Encoder: Pre-trained audio encoder (frozen) - Text Projection: Linear projection from 4,096 to 512 dimensions - Cross-Attention Transformer: Aligns text with speech features - Autoregressive Transformer: Generates speech representations - End Token Classifier: Predicts when to stop generating """ config_class = TextSyncMimiConfig def __init__( self, config: Optional[Union[MimiConfig, 'TextSyncMimiConfig']] = None, model_id: Optional[str] = None, token: Optional[str] = None, alpha: Optional[float] = None, cross_attention_layers: Optional[int] = None, causal_attention_layers: Optional[int] = None, bce_threshold: Optional[float] = None, vocab_size: Optional[int] = None, ): """ Initialize TextSyncMimi model. Args: config: Model configuration (TextSyncMimiConfig or MimiConfig) model_id: Mimi model ID (e.g., "kyutai/mimi"). If None, uses config.mimi_model_id token: Hugging Face authentication token alpha: Weight for BCE end token loss. If None, uses config.alpha cross_attention_layers: Number of cross-attention layers. If None, uses config causal_attention_layers: Number of autoregressive layers. If None, uses config bce_threshold: BCE loss threshold. If None, uses config.bce_threshold vocab_size: Text vocabulary size. 
        # Handle config initialization for both manual instantiation and from_pretrained
        if config is None:
            if model_id is None:
                raise ValueError("Either config or model_id must be provided")
            config = MimiConfig.from_pretrained(model_id, token=token)
        super().__init__(config)

        # Extract parameters from config if not explicitly provided
        if hasattr(config, 'mimi_model_id'):
            model_id = model_id or config.mimi_model_id
        if model_id is None:
            raise ValueError("model_id must be provided either as argument or in config.mimi_model_id")
        alpha = alpha if alpha is not None else getattr(config, 'alpha', 1.0)
        cross_attention_layers = (
            cross_attention_layers if cross_attention_layers is not None
            else getattr(config, 'cross_attention_layers', 2)
        )
        causal_attention_layers = (
            causal_attention_layers if causal_attention_layers is not None
            else getattr(config, 'causal_attention_layers', 2)
        )
        bce_threshold = bce_threshold if bce_threshold is not None else getattr(config, 'bce_threshold', 0.1)
        vocab_size = vocab_size if vocab_size is not None else getattr(config, 'vocab_size', 128256)

        # Load the Mimi backbone
        self.config = config
        model = MimiModel.from_pretrained(model_id, token=token)

        # Hyperparameters for the auxiliary loss
        self.alpha = alpha
        self.bce_threshold = bce_threshold

        # Learnable text token embedding
        self.text_token_embedding = nn.Embedding(vocab_size, 4096)

        # Text projection
        self.text_proj = nn.Linear(4096, 512)

        # Cross-attention transformer
        cross_attention_config = MimiConfig(**self.config.__dict__)
        cross_attention_config.num_hidden_layers = cross_attention_layers
        cross_attention_config.hidden_size = 512
        self.cross_attention_transformer = CrossAttentionTransformer(cross_attention_config)

        # Decoder part (v1)
        # Autoregressive decoder sequence layout:
        # <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|> [z_(i,1)] [z_(i,2)] ... [z_(i,K)] <|time_speech_end|>
        # Loss is masked out (not computed) for <|text_speech_latent|> [t_i] [s_i] <|time_speech_start|>.
        # t_i is already mapped from 4096 (e.g., a LLaMA embedding) -> 512
        # s_i is already 512
        # z is Mimi's decoder input, which is also 512
        causal_attention_config = MimiConfig(**self.config.__dict__)
        causal_attention_config.num_hidden_layers = causal_attention_layers
        causal_attention_config.hidden_size = 512
        self.ar_transformer = CausalAttentionTransformer(causal_attention_config)

        # Embeddings for special positions in the autoregressive decoder
        self.text_speech_latent_embed = nn.Embedding(1, 512)
        self.time_speech_start_embed = nn.Embedding(1, 512)
        self.time_speech_end_embed = nn.Embedding(1, 512)

        # Binary classification head for end-token prediction
        self.end_token_classifier = nn.Linear(512, 1)

        self.post_init()

        # Frozen Mimi components (assigned after post_init so their pretrained
        # weights are not re-initialized)
        self.encoder = model.encoder
        self.encoder_transformer = model.encoder_transformer
        self.quantizer = model.quantizer
        self.downsample = model.downsample
        self.upsample = model.upsample
        # Disable gradients so the pretrained Mimi components actually stay
        # frozen, as documented in the class docstring.
        for module in (self.encoder, self.encoder_transformer, self.quantizer,
                       self.downsample, self.upsample):
            for p in module.parameters():
                p.requires_grad = False

        # Print the number of parameters for each sub-network, in millions
        self._print_subnetwork_parameter_counts()
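
    # Construction sketch (hypothetical usage; "kyutai/mimi" is the checkpoint
    # named in the __init__ docstring, and downloading it requires network access):
    #
    #   model = TextSyncMimi(model_id="kyutai/mimi")
    #   # or with an explicit config object carrying mimi_model_id, alpha, etc.:
    #   # model = TextSyncMimi(config=my_text_sync_mimi_config)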

    def initialize_text_embeddings_from_weights(self, embedding_weight: torch.Tensor) -> None:
        """
        Initialize text embeddings from a weight matrix.

        Args:
            embedding_weight: Weight matrix of shape (vocab_size, 4096)
        """
        if embedding_weight.dim() != 2 or embedding_weight.size(1) != 4096:
            raise ValueError("embedding_weight must have shape (vocab_size, 4096)")
        if embedding_weight.size(0) != self.text_token_embedding.num_embeddings:
            raise ValueError("Provided vocab_size does not match model's text_token_embedding")
        with torch.no_grad():
            self.text_token_embedding.weight.copy_(embedding_weight)
        # The text embeddings stay trainable after initialization
        for p in self.text_token_embedding.parameters():
            p.requires_grad = True

    def initialize_text_embeddings_from_llama(self, llama_embeddings_module: torch.nn.Module) -> None:
        """
        Initialize text embeddings from a LLaMA embedding module.

        Args:
            llama_embeddings_module: LLaMA embedding module with weight shape (vocab_size, 4096)
        """
        if not hasattr(llama_embeddings_module, 'weight'):
            raise ValueError("llama_embeddings_module must have a 'weight' attribute")
        weight = llama_embeddings_module.weight.data
        self.initialize_text_embeddings_from_weights(weight)

    def _print_subnetwork_parameter_counts(self) -> None:
        """Print parameter counts for model subnetworks."""
        print("=" * 70)
        print("TextSyncMimi Parameter Counts")
        print("=" * 70)
        print(f"Encoder: {sum(p.numel() for p in self.encoder.parameters()) / 1e6:.2f}M")
        print(f"Encoder Transformer: {sum(p.numel() for p in self.encoder_transformer.parameters()) / 1e6:.2f}M")
        print(f"Cross-Attention Transformer: {sum(p.numel() for p in self.cross_attention_transformer.parameters()) / 1e6:.2f}M")
        print(f"AR Transformer: {sum(p.numel() for p in self.ar_transformer.parameters()) / 1e6:.2f}M")
        print(f"Quantizer: {sum(p.numel() for p in self.quantizer.parameters()) / 1e6:.2f}M")
        print("=" * 70)

    def encode_audio_to_representation(
        self,
        input_values: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode audio to a speech representation.

        Args:
            input_values: Audio waveform (B, 1, audio_len)
            audio_attention_mask: Attention mask (B, audio_len)

        Returns:
            Speech embeddings (B, 512, T_frames), where T_frames = 12.5 * seconds of audio
        """
        batch_size = input_values.shape[0]
        device = input_values.device

        # Encode through the Mimi encoder pipeline
        embeddings = self.encoder(input_values)
        encoder_outputs = self.encoder_transformer(embeddings.transpose(1, 2))
        embeddings = encoder_outputs[0].transpose(1, 2)
        embeddings = self.downsample(embeddings)

        # Apply the attention mask if provided
        if audio_attention_mask is not None:
            speech_seq_len = embeddings.shape[-1]
            speech_attention_mask = torch.zeros(batch_size, speech_seq_len, device=device, dtype=torch.bool)
            for b in range(batch_size):
                actual_audio_len = audio_attention_mask[b].sum().item()
                # Convert 24 kHz sample counts to 12.5 Hz frame counts
                actual_speech_len = int(actual_audio_len * 12.5 / 24000)
                actual_speech_len = min(actual_speech_len, speech_seq_len)
                if actual_speech_len > 0:
                    speech_attention_mask[b, :actual_speech_len] = True
            speech_mask_expanded = speech_attention_mask.unsqueeze(1)
            embeddings = embeddings * speech_mask_expanded.float()

        return embeddings
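
    # Frame-rate arithmetic used by the mask conversion above: Mimi consumes
    # 24 kHz audio and this pipeline emits embeddings at 12.5 frames per second,
    # so a 3-second clip (72,000 samples) maps to int(72000 * 12.5 / 24000) = 37
    # valid frames in the speech attention mask.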

    def generate_autoregressive(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        max_z_tokens: int = 50,
        end_token_threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> List[List[torch.Tensor]]:
        """
        Generate audio latents autoregressively.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            audio_attention_mask: Audio mask (B, audio_seq_len) - for normal mode
            speech_attention_mask: Speech mask (B, speech_seq_len) - for cached mode
            text_attention_mask: Text mask (B, text_seq_len)
            max_z_tokens: Maximum number of z tokens per text position
            end_token_threshold: Probability threshold for stopping
            device: Device for computation

        Returns:
            List of z_token lists (one per batch item)
        """
        if device is None:
            device = text_token_ids.device

        self.eval()
        with torch.no_grad():
            # Get speech embeddings for the cross-attention context.
            # Cached mode: speech_embeddings is already provided as (B, T, 512).
            if speech_embeddings is None:
                # Normal mode: compute speech embeddings from input_values
                if input_values is None:
                    raise ValueError("Either input_values or speech_embeddings must be provided")
                speech_embeddings = self.encode_audio_to_representation(
                    input_values, audio_attention_mask=audio_attention_mask
                )
                speech_embeddings = speech_embeddings.transpose(1, 2)  # (B, T, 512)

            # Embed token IDs, then project to 512 dimensions
            text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
            text_embeddings_proj = self.text_proj(text_embeddings_4096)  # (B, L, 512)

            # Build attention masks for cross-attention (same as in forward)
            batch_size, text_seq_len = text_embeddings_proj.shape[:2]
            causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device,
                                                dtype=text_embeddings_proj.dtype))
            causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
            if text_attention_mask is not None:
                padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
                combined_mask = causal_mask * padding_mask
            else:
                combined_mask = causal_mask
            formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))

            # Handle the speech attention mask (prefer speech_attention_mask,
            # otherwise derive one from audio_attention_mask)
            if speech_attention_mask is not None:
                # Cached mode: speech_attention_mask is already in the right format
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = speech_attention_mask.bool()
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            elif audio_attention_mask is not None:
                # Normal mode: convert audio_attention_mask to a speech-frame mask
                speech_seq_len = speech_embeddings.shape[1]
                speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool, device=device)
                for b in range(batch_size):
                    audio_len = audio_attention_mask[b].sum().item()
                    speech_len = int(audio_len * 12.5 / 24000)
                    speech_len = min(speech_len, speech_seq_len)
                    speech_mask[b, :speech_len] = True
                formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
                formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
            else:
                formatted_speech_attention_mask = None
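
            # Worked example of the additive mask convention built above: with
            # text_attention_mask = [1, 1, 0], the combined causal+padding row for
            # query position 1 is [0.0, 0.0, -inf] -- 0.0 marks keys that may be
            # attended to, -inf marks keys that are blocked.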

            # Cross-attention: align text with the speech context
            cross_attention_outputs = self.cross_attention_transformer(
                hidden_states=text_embeddings_proj,
                encoder_hidden_states=speech_embeddings,
                attention_mask=formatted_text_attention_mask,
                encoder_attention_mask=formatted_speech_attention_mask,
                alignment_chunk_sizes=None,  # V1 learns the alignment
            )
            cross_attention_outputs = cross_attention_outputs.last_hidden_state

            # Get the special-position embeddings
            text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
            time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

            generated_z_tokens = []

            # Generate for each batch item
            for b in range(batch_size):
                # Get the valid text length for this sample
                if text_attention_mask is not None:
                    valid_text_len = int(text_attention_mask[b].sum().item())
                else:
                    valid_text_len = text_embeddings_proj.shape[1]

                # Start the sequence with text_speech_latent for context
                sequence = [text_speech_latent_emb]  # (1, 512)
                batch_z_tokens = []  # z_tokens for this batch item

                # Generate for each text position
                for i in range(valid_text_len):
                    # Add t_i and s_i
                    t_i = text_embeddings_proj[b, i:i+1]  # (1, 512)
                    s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                    sequence.extend([t_i, s_i])

                    # Add time_speech_start
                    sequence.append(time_speech_start_emb)

                    # Generate z tokens autoregressively for this text position.
                    # Note: the full prefix is re-encoded at every step (no KV cache).
                    z_count = 0
                    while z_count < max_z_tokens:
                        # Prepare the current sequence for the AR transformer
                        current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)  # (1, seq_len, 512)

                        # Create an attention mask for the current sequence
                        seq_len = current_sequence.shape[1]
                        ar_attention_mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)

                        # Get the prediction from the AR transformer
                        ar_outputs = self.ar_transformer(
                            hidden_states=current_sequence,
                            attention_mask=ar_attention_mask,
                        )

                        # Take the last prediction
                        last_prediction = ar_outputs.last_hidden_state[0, -1:, :]  # (1, 512)

                        # Check the stopping condition with the BCE classifier (v1.1)
                        end_token_logit = self.end_token_classifier(last_prediction).squeeze(-1)  # (1,)
                        end_token_prob = torch.sigmoid(end_token_logit).item()  # convert to a probability

                        # Stop once the probability reaches the threshold
                        if end_token_prob >= end_token_threshold:
                            break

                        # Otherwise append the prediction to both the sequence (for
                        # context) and the z_tokens (for output)
                        sequence.append(last_prediction)
                        batch_z_tokens.append(last_prediction.squeeze(0))  # drop the batch dimension
                        z_count += 1

                    # Add time_speech_end to the sequence for context
                    sequence.append(time_speech_end_emb)

                # Store the z_tokens for this batch item
                generated_z_tokens.append(batch_z_tokens)

        return generated_z_tokens
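
    # Generation sketch (hypothetical tensors; shapes follow the docstring above):
    #
    #   z_tokens = model.generate_autoregressive(
    #       text_token_ids,          # (B, L)
    #       input_values=waveform,   # (B, 1, 24000 * seconds)
    #   )
    #   # z_tokens[b] is a list of (512,) latents; decoding them to a waveform
    #   # needs Mimi's decoder, which this class does not retain.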

    def forward(
        self,
        text_token_ids: torch.LongTensor,
        input_values: Optional[torch.Tensor] = None,
        speech_embeddings: Optional[torch.Tensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        speech_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            text_token_ids: Text token IDs (B, L)
            input_values: Audio input (B, 1, 24000 * T) - for normal mode
            speech_embeddings: Pre-computed speech embeddings (B, T, 512) - for cached mode
            alignment_chunk_sizes: Alignment chunk sizes (B, L); required for training
            audio_attention_mask: Audio mask (B, audio_seq_len)
            speech_attention_mask: Speech mask (B, speech_seq_len)
            text_attention_mask: Text mask (B, text_seq_len)

        Returns:
            Dictionary with 'loss', 'reconstruction_loss', and 'bce_end_token_loss'
        """
        # Get speech embeddings
        if speech_embeddings is None:
            if input_values is None:
                raise ValueError("Either input_values or speech_embeddings must be provided")
            # Normal mode: compute speech embeddings from input_values
            speech_embeddings_raw = self.encode_audio_to_representation(
                input_values, audio_attention_mask
            )  # (B, 512, T)
            # Transpose: (B, 512, T) -> (B, T, 512)
            speech_embeddings = speech_embeddings_raw.transpose(1, 2)

        # Embed token IDs and project to 512 dimensions
        text_embeddings_4096 = self.text_token_embedding(text_token_ids)  # (B, L, 4096)
        text_embeddings = self.text_proj(text_embeddings_4096)  # (B, L, 512)

        # Create attention masks for cross-attention.
        # Text side (decoder): causal mask, combined with the padding mask if given.
        batch_size, text_seq_len = text_embeddings.shape[:2]
        causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len,
                                            device=text_embeddings.device, dtype=text_embeddings.dtype))
        causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
        if text_attention_mask is not None:
            padding_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
            combined_mask = causal_mask * padding_mask
        else:
            combined_mask = causal_mask
        # Convert to attention scores (-inf for masked positions)
        formatted_text_attention_mask = torch.where(combined_mask.bool(), 0.0, float('-inf'))

        # Speech side (encoder mask): use speech_attention_mask if available
        # (cached mode), otherwise derive it from audio_attention_mask (normal mode)
        if speech_attention_mask is not None:
            # Cached mode: speech_attention_mask is already in the right format
            speech_seq_len = speech_embeddings.shape[1]
            speech_mask = speech_attention_mask.bool()
            # Convert to attention format: (batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        elif audio_attention_mask is not None:
            # Normal mode: convert the audio mask to a speech-embedding mask
            speech_seq_len = speech_embeddings.shape[1]
            # Build the speech attention mask from the actual audio lengths
            speech_mask = torch.zeros(batch_size, speech_seq_len, dtype=torch.bool,
                                      device=speech_embeddings.device)
            for b in range(batch_size):
                audio_len = audio_attention_mask[b].sum().item()
                speech_len = int(audio_len * 12.5 / 24000)
                speech_len = min(speech_len, speech_seq_len)
                speech_mask[b, :speech_len] = True
            # Convert to attention format: (batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = speech_mask.view(batch_size, 1, 1, speech_seq_len)
            formatted_speech_attention_mask = torch.where(formatted_speech_attention_mask, 0.0, float('-inf'))
        else:
            # No masking
            formatted_speech_attention_mask = None

        # Cross-attention: text attends to speech (no alignment constraints in V1)
        # hidden_states (decoder) = text, encoder_hidden_states = speech
        cross_attention_outputs = self.cross_attention_transformer(
            hidden_states=text_embeddings,
            encoder_hidden_states=speech_embeddings,
            attention_mask=formatted_text_attention_mask,  # causal mask for text (decoder)
            encoder_attention_mask=formatted_speech_attention_mask,  # mask for speech (encoder)
            alignment_chunk_sizes=None,  # v1 does not use alignment_chunk_sizes; the model should learn the alignment itself
        )
        cross_attention_outputs = cross_attention_outputs.last_hidden_state

        # Autoregressive decoder part.
        # Following v0.5, the target is the dequantized Mimi decoder input: the
        # quantize->dequantize round trip of the encoder output, at the 12.5 Hz
        # frame rate (T = 12.5 * seconds).
        with torch.no_grad():
            embeddings_bct = speech_embeddings.transpose(1, 2)  # (B, 512, T)
            codes_kbt = self.quantizer.encode(embeddings_bct)  # (K, B, T)
            codes_bkt = codes_kbt.transpose(0, 1)  # (B, K, T)
            decoder_input_emb = self.quantizer.decode(codes_bkt)  # (B, 512, T)
            target_representation = decoder_input_emb.transpose(1, 2)  # (B, T, 512)
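
        # Worked layout example for one batch item with two text tokens and
        # alignment_chunk_sizes = [2, 1] (built by the loop below):
        #   [latent] t_0 s_0 [start] z_0 z_1 [end] t_1 s_1 [start] z_2 [end]
        # loss_mask is True only at z positions; the BCE labels are 0 at z
        # positions and 1 at each [end], with the BCE mask covering both.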

        # Build the interleaved sequence for the autoregressive decoder,
        # along with the mask for the loss computation.

        # Get the special embeddings (each is a single embedding)
        device = text_embeddings.device
        text_speech_latent_emb = self.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_start_emb = self.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)
        time_speech_end_emb = self.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))  # (1, 512)

        batch_size = text_embeddings.shape[0]
        interleaved_sequences = []
        loss_masks = []
        bce_labels_batch = []  # BCE labels: 0 for z tokens, 1 for time_speech_end_emb
        bce_masks = []  # BCE mask: True for z tokens and time_speech_end_emb
        sequence_lengths = []  # actual sequence lengths before padding
        all_z_tokens = []  # valid z_tokens for the separation loss (collected here, not used in this loss)
        max_total_length = 0

        for b in range(batch_size):
            # Start with the text_speech_latent embedding
            sequence_parts = [text_speech_latent_emb]  # collects sequence parts
            loss_mask_parts = [False]  # no loss on special tokens
            bce_label_parts = [0]  # dummy label for text_speech_latent_emb
            bce_mask_parts = [False]  # no BCE loss for text_speech_latent_emb

            # Get the valid text length for this batch item
            if text_attention_mask is not None:
                valid_text_len = int(text_attention_mask[b].sum().item())
            else:
                valid_text_len = text_embeddings.shape[1]

            # Track the current position in target_representation
            speech_position = 0

            # For each text token
            for i in range(valid_text_len):
                # Add t_i (text embedding)
                t_i = text_embeddings[b, i:i+1]  # (1, 512)
                sequence_parts.append(t_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # dummy label for t_i
                bce_mask_parts.append(False)  # no BCE loss for t_i

                # Add s_i (cross-attention output)
                s_i = cross_attention_outputs[b, i:i+1]  # (1, 512)
                sequence_parts.append(s_i)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # dummy label for s_i
                bce_mask_parts.append(False)  # no BCE loss for s_i

                # Add time_speech_start
                sequence_parts.append(time_speech_start_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(0)  # dummy label for time_speech_start
                bce_mask_parts.append(False)  # no BCE loss for time_speech_start

                # Add the z tokens for this chunk
                chunk_size = alignment_chunk_sizes[b, i].item()
                if chunk_size > 0:  # only add if the chunk size is positive
                    end_position = speech_position + chunk_size
                    # Make sure we don't exceed the target_representation length
                    end_position = min(end_position, target_representation.shape[1])
                    actual_chunk_size = end_position - speech_position

                    if actual_chunk_size > 0:
                        z_tokens = target_representation[b, speech_position:end_position]  # (actual_chunk_size, 512)
                        sequence_parts.append(z_tokens)
                        loss_mask_parts.extend([True] * actual_chunk_size)  # compute loss on z tokens
                        bce_label_parts.extend([0] * actual_chunk_size)  # label 0 for z tokens
                        bce_mask_parts.extend([True] * actual_chunk_size)  # compute BCE loss on z tokens

                        # Collect z_tokens for the separation-loss computation
                        all_z_tokens.append(z_tokens)

                    speech_position = end_position

                # Add time_speech_end
                sequence_parts.append(time_speech_end_emb)
                loss_mask_parts.append(False)
                bce_label_parts.append(1)
                bce_mask_parts.append(True)

            # Concatenate all parts for this batch item
            full_sequence = torch.cat(sequence_parts, dim=0)  # (total_length, 512)
            loss_mask = torch.tensor(loss_mask_parts, dtype=torch.bool, device=device)
            bce_labels = torch.tensor(bce_label_parts, dtype=torch.float, device=device)
            bce_mask = torch.tensor(bce_mask_parts, dtype=torch.bool, device=device)

            interleaved_sequences.append(full_sequence)
            loss_masks.append(loss_mask)
            bce_labels_batch.append(bce_labels)
            bce_masks.append(bce_mask)
            sequence_lengths.append(full_sequence.shape[0])  # actual length before padding
            max_total_length = max(max_total_length, full_sequence.shape[0])

        # Pad the sequences to a common length
        padded_sequences = []
        padded_loss_masks = []
        padded_bce_labels = []
        padded_bce_masks = []

        for sequence, loss_mask, bce_labels, bce_mask in zip(interleaved_sequences, loss_masks,
                                                             bce_labels_batch, bce_masks):
            current_length = sequence.shape[0]
            if current_length < max_total_length:
                padding = torch.zeros(max_total_length - current_length, 512, device=device, dtype=sequence.dtype)
                padded_sequence = torch.cat([sequence, padding], dim=0)

                mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_mask = torch.cat([loss_mask, mask_padding], dim=0)

                bce_label_padding = torch.zeros(max_total_length - current_length, dtype=torch.float, device=device)
                padded_bce_label = torch.cat([bce_labels, bce_label_padding], dim=0)

                bce_mask_padding = torch.zeros(max_total_length - current_length, dtype=torch.bool, device=device)
                padded_bce_mask = torch.cat([bce_mask, bce_mask_padding], dim=0)
            else:
                padded_sequence = sequence
                padded_mask = loss_mask
                padded_bce_label = bce_labels
                padded_bce_mask = bce_mask

            padded_sequences.append(padded_sequence)
            padded_loss_masks.append(padded_mask)
            padded_bce_labels.append(padded_bce_label)
            padded_bce_masks.append(padded_bce_mask)

        # Stack into batch tensors
        interleaved_batch = torch.stack(padded_sequences, dim=0)  # (batch_size, max_total_length, 512)
        loss_mask_batch = torch.stack(padded_loss_masks, dim=0)  # (batch_size, max_total_length)
        bce_labels_batch_tensor = torch.stack(padded_bce_labels, dim=0)  # (batch_size, max_total_length)
        bce_mask_batch = torch.stack(padded_bce_masks, dim=0)  # (batch_size, max_total_length)
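
        # Teacher forcing: the hidden state at position t is trained to predict
        # the embedding at position t + 1, so the inputs drop the last step, the
        # targets drop the first, and all masks/labels are shifted to match.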
        # Autoregressive prediction
        if max_total_length > 1:
            ar_input = interleaved_batch[:, :-1, :]  # (batch_size, max_total_length-1, 512)
            ar_targets = interleaved_batch[:, 1:, :]  # (batch_size, max_total_length-1, 512)
            ar_loss_mask = loss_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shifted mask
            ar_bce_labels = bce_labels_batch_tensor[:, 1:]  # (batch_size, max_total_length-1) - shifted labels
            ar_bce_mask = bce_mask_batch[:, 1:]  # (batch_size, max_total_length-1) - shifted mask

            # Create an attention mask for the autoregressive transformer:
            # padded positions are masked out while the causal property is preserved
            ar_seq_len = ar_input.shape[1]
            ar_attention_mask = torch.zeros(batch_size, ar_seq_len, dtype=torch.bool, device=device)
            for b in range(batch_size):
                valid_len = min(ar_seq_len, sequence_lengths[b] - 1)
                if valid_len > 0:
                    ar_attention_mask[b, :valid_len] = True

            ar_outputs = self.ar_transformer(
                hidden_states=ar_input,
                attention_mask=ar_attention_mask,  # combined with the causal mask inside the transformer
            )
            ar_predictions = ar_outputs.last_hidden_state  # (batch_size, max_total_length-1, 512)

            # BCE predictions for end-token classification
            bce_logits = self.end_token_classifier(ar_predictions).squeeze(-1)  # (batch_size, max_total_length-1)

            # Compute the L2 loss only where ar_loss_mask is True (z tokens)
            if ar_loss_mask.any():
                # Extract the valid positions for the loss computation
                valid_predictions = ar_predictions[ar_loss_mask]  # (num_valid_positions, 512)
                valid_targets = ar_targets[ar_loss_mask]  # (num_valid_positions, 512)
                # L2 loss (MSE)
                reconstruction_loss = nn.functional.mse_loss(
                    valid_predictions, valid_targets, reduction='mean'
                )
            else:
                # Fallback if there are no valid positions (should not happen in practice)
                reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # BCE loss for end-token classification (v1.1)
            if ar_bce_mask.any():
                # Extract the valid positions for the BCE loss computation
                valid_bce_logits = bce_logits[ar_bce_mask]  # (num_valid_bce_positions,)
                valid_bce_labels = ar_bce_labels[ar_bce_mask]  # (num_valid_bce_positions,)
                bce_end_token_loss = nn.functional.binary_cross_entropy_with_logits(
                    valid_bce_logits, valid_bce_labels, reduction='mean'
                )
            else:
                # Fallback if there are no valid BCE positions
                bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)

            # If a threshold is set, only penalize the BCE loss above it
            if self.bce_threshold > 0.0:
                clamped_bce_loss = torch.clamp(bce_end_token_loss - self.bce_threshold, min=0.0)
                total_loss = reconstruction_loss + self.alpha * clamped_bce_loss
            else:
                total_loss = reconstruction_loss + self.alpha * bce_end_token_loss
        else:
            reconstruction_loss = torch.tensor(0.0, device=device, requires_grad=True)
            bce_end_token_loss = torch.tensor(0.0, device=device, requires_grad=True)
            total_loss = reconstruction_loss + bce_end_token_loss

        return {
            'loss': total_loss,
            'reconstruction_loss': reconstruction_loss,
            'bce_end_token_loss': bce_end_token_loss,
        }


__all__ = ["TextSyncMimi"]
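

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the public API. It assumes the
    # "kyutai/mimi" checkpoint named in the __init__ docstring is reachable
    # (network access required) and simply checks that forward() returns losses.
    model = TextSyncMimi(model_id="kyutai/mimi")
    model.eval()

    batch, text_len = 1, 4
    text_token_ids = torch.randint(0, 128256, (batch, text_len))
    input_values = torch.randn(batch, 1, 24000)  # 1 second of 24 kHz audio
    # One speech frame per text token; real chunk sizes come from an aligner.
    alignment_chunk_sizes = torch.ones(batch, text_len, dtype=torch.long)

    with torch.no_grad():
        outputs = model(
            text_token_ids,
            input_values=input_values,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
    print({name: loss.item() for name, loss in outputs.items()})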