| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ Modeling classes for MossTTSDelay. """ |
| |
|
| | from dataclasses import dataclass |
| | from typing import List, Optional, Tuple, Union |
| | from tqdm import tqdm |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import CrossEntropyLoss |
| |
|
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.modeling_outputs import ModelOutput |
| | from transformers.utils import ( |
| | add_start_docstrings, |
| | add_start_docstrings_to_model_forward, |
| | logging, |
| | replace_return_docstrings, |
| | ) |
| | from transformers.cache_utils import Cache |
| | from transformers.models.qwen3 import Qwen3Model |
| | from transformers import initialization as init |
| |
|
| | from .configuration_moss_tts import MossTTSDelayConfig |
| | from .inference_utils import sample_token, find_last_equal_C |
| |
|
# The processing helpers are optional at import time: if the processing module
# (or one of its extra dependencies) is unavailable, the model class still
# imports, with these three names set to None.
try:
    from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
except Exception:
    UserMessage = None
    AssistantMessage = None
    MossTTSDelayProcessor = None
| |
|
# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)

# Config class name consumed by the @replace_return_docstrings decorator below.
_CONFIG_FOR_DOC = "MossTTSDelayConfig"
| |
|
| |
|
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Weighted sum of channel losses.
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of non-masked tokens per sample.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + vq heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores from each head.
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of torch.FloatTensor (one for each layer) of the attention weights.
    """
    # Field order defines the tuple order of this ModelOutput — do not reorder.
    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
| |
|
| |
|
| |
|
| |
|
class MossTTSDelayPreTrainedModel(PreTrainedModel):
    """Base class wiring MossTTSDelay into the transformers `PreTrainedModel` machinery."""

    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    # Both flash-attn capability flags are kept for compatibility across
    # transformers versions (the attribute name changed between releases).
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Transformers 5.0+ safe init:
        - MUST use transformers.initialization helpers
        - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params

        Everything except the extra audio VQ embeddings is delegated to the stock
        transformers initialization via `super()._init_weights(module)`.
        """
        super()._init_weights(module)

        # Resolve the init std: top-level config first, then the wrapped language
        # config, finally the transformers default of 0.02.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        elif hasattr(self.config, "language_config") and hasattr(self.config.language_config, "initializer_range"):
            std = self.config.language_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Embedding):
            # Only the audio VQ embeddings are re-initialized here; they are
            # identified by their vocab size (audio_vocab_size + 1, the +1 being
            # the audio pad code). The text embedding belongs to the wrapped
            # Qwen3 model and is handled by the parent call above.
            if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
                init.normal_(module.weight, mean=0.0, std=std)
| |
|
| |
|
| |
|
# Shared docstring text injected into the model class and its forward method by
# the @add_start_docstrings / @add_start_docstrings_to_model_forward decorators.
MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
| |
|
| |
|
@add_start_docstrings(
    "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
    MOSSTTS_START_DOCSTRING,
)
class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
    # Convenience re-exports so callers can reach the chat-message / processor
    # helpers directly from the model class (None when processing_moss_tts
    # failed to import — see the guarded import at module top).
    UserMessage = UserMessage
    AssistantMessage = AssistantMessage
    Processor = MossTTSDelayProcessor

    def __init__(self, config: MossTTSDelayConfig):
        """Build the Qwen3 backbone plus per-VQ-channel embeddings and heads."""
        super().__init__(config)
        self.config = config

        # Keep the wrapped language model in the same dtype as the outer model.
        config.language_config.torch_dtype = config.torch_dtype

        self.language_model = Qwen3Model(config.language_config)

        # One extra embedding table per VQ channel; vocab is audio_vocab_size + 1
        # (the +1 is the audio pad code). Their outputs are summed with the text
        # embedding in `get_input_embeddings`.
        self.emb_ext = nn.ModuleList()
        for vq_idx in range(self.config.n_vq):

            self.emb_ext.append(
                nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
            )

        # Prediction heads: head 0 predicts text tokens over the language vocab,
        # heads 1..n_vq each predict one audio VQ channel.
        self.lm_heads = nn.ModuleList([
            nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
        ])
        for vq_idx in range(self.config.n_vq):
            self.lm_heads.append(
                nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
            )

        self.post_init()

    def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Computes the combined embeddings from text and multiple audio VQ channels.

        NOTE(review): this overrides `PreTrainedModel.get_input_embeddings` with an
        incompatible signature (takes ids, returns a tensor, instead of returning
        the embedding module) — confirm no transformers internals rely on the
        stock contract.

        Args:
            input_ids: Shape (Batch, Seq_Len, 1 + n_vq)
        """
        # Channel 0 is the text stream; embed it with the backbone's own table.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])

        # Add each audio VQ channel's embedding on top of the text embedding.
        for i, embed_layer in enumerate(self.emb_ext):

            inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])

        return inputs_embeds

    def set_input_embeddings(self, value):
        # Standard transformers hook: swap the backbone's text embedding table.
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        # Returns the whole ModuleList (text head + audio heads), not a single
        # Linear as the stock transformers contract expects.
        return self.lm_heads

    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the masked language modeling loss.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across different heads (Text vs Audio channels).

        Returns:
        """
        # NOTE(review): this dereferences `input_ids` unconditionally, so calling
        # with only `inputs_embeds` (which the signature allows) raises instead of
        # being supported — confirm whether embeds-only calls are intended.
        if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
            raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # Run the backbone on the summed embeddings. Hidden states are always
        # requested because `hidden_out_layers` may tap intermediate layers below.
        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        last_hidden_state = outputs.last_hidden_state
        if hidden_out_layers is None:
            # Default: every head reads the final hidden state.
            hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
        else:
            # Optional: route each head to a caller-chosen backbone layer
            # (indices into `outputs.hidden_states`).
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        layer_logits = []
        for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            # Audio heads (i > 0): forbid predicting the last index, i.e. the
            # extra (+1) audio pad code that is only ever used as input padding.
            if i > 0:
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:

            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")

            batch_size = labels.size(0)
            n_heads = len(layer_logits)

            all_sum_losses_list = []

            # (batch, 1 + n_vq): count of supervised (!= -100) positions per channel.
            all_token_nums = torch.sum(labels != -100, dim=1)

            # NOTE(review): logits and labels are compared position-for-position —
            # the usual causal shift is assumed to have been applied by the data
            # pipeline; confirm against the collator.
            for i, logits in enumerate(layer_logits):

                cur_labels = labels[..., i]

                # reduction='none' keeps per-token losses; the default
                # ignore_index=-100 zeroes out masked positions.
                loss_fct = CrossEntropyLoss(reduction='none')
                vocab_size = logits.size(-1)

                reshaped_logits = logits.view(-1, vocab_size)
                reshaped_labels = cur_labels.contiguous().view(-1)

                per_token_loss = loss_fct(reshaped_logits, reshaped_labels)

                # Sum (not mean) per sample; normalization happens further below.
                per_token_loss = per_token_loss.view(batch_size, -1)
                per_sample_loss = torch.sum(per_token_loss, dim=-1)

                all_sum_losses_list.append(per_sample_loss)

            # (batch, 1 + n_vq): summed loss per sample per channel.
            all_sum_losses = torch.stack(all_sum_losses_list, dim=1)

            if channelwise_loss_weight is not None:
                if len(channelwise_loss_weight) != n_heads:
                    raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

                w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)

                # Per-sample, per-channel mean loss; the clamp avoids division by
                # zero for channels with no supervised tokens.
                token_counts_safe = all_token_nums.float().clamp(min=1.0)

                normalized_losses = all_sum_losses / token_counts_safe
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                # Channel means are computed over the whole batch, then combined
                # with the same weights into the scalar training loss.
                total_loss_per_channel = all_sum_losses.sum(dim=0)
                total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
                channel_losses = total_loss_per_channel / total_tokens_per_channel

                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                # Unweighted: plain token-level mean across all channels.
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens
                channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)

        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = None,
        text_temperature: float = 1.1,
        text_top_p: float = 0.9,
        text_top_k: int = 50,
        audio_temperature: Optional[float] = None,
        audio_top_p: Optional[float] = None,
        audio_top_k: Optional[int] = None,
        audio_repetition_penalty: Optional[float] = None,
    ):
        """
        Autoregressive decoding of interleaved text + delayed audio VQ channels.

        The text head (channel 0) drives control flow: audio_start opens an audio
        span, a delay-slot token starts the wind-down, audio_end closes the span,
        im_end stops the sample. Audio channel i lags channel 0 by i steps
        (delay pattern), handled via `audio_lengths` / `delayed_lengths` below.

        Returns:
            A list of (start_length, generated_ids) tuples, one per batch item,
            where generated_ids starts at the last chat-turn boundary.
        """
        generation_config = getattr(self, "generation_config", None)

        def _cfg_value(name: str, default_value: Union[int, float]) -> Union[int, float]:
            # Pull a default from `generation_config` when present, else fall back.
            if generation_config is None:
                return default_value
            value = getattr(generation_config, name, None)
            if value is None:
                return default_value
            return value

        # Resolve unspecified sampling params from generation_config, guarding
        # against non-numeric config values.
        if max_new_tokens is None:
            try:
                max_new_tokens = int(_cfg_value("max_new_tokens", 1000))
            except (TypeError, ValueError):
                max_new_tokens = 1000
        if audio_temperature is None:
            try:
                audio_temperature = float(_cfg_value("temperature", 1.1))
            except (TypeError, ValueError):
                audio_temperature = 1.1
        if audio_top_p is None:
            try:
                audio_top_p = float(_cfg_value("top_p", 0.9))
            except (TypeError, ValueError):
                audio_top_p = 0.9
        if audio_top_k is None:
            try:
                audio_top_k = int(_cfg_value("top_k", 50))
            except (TypeError, ValueError):
                audio_top_k = 50
        if audio_repetition_penalty is None:
            try:
                audio_repetition_penalty = float(_cfg_value("repetition_penalty", 1.1))
            except (TypeError, ValueError):
                audio_repetition_penalty = 1.1

        # Temperature <= 0 means greedy decoding; temperature is reset to 1 so
        # the logit division below becomes a no-op.
        if text_temperature > 0:
            text_do_sample = True
        else:
            text_temperature = 1
            text_do_sample = False
        if audio_temperature > 0:
            audio_do_sample = True
        else:
            audio_temperature = 1
            audio_do_sample = False

        past_key_values = None
        device = input_ids.device
        current_input_ids = input_ids
        current_attention_mask = attention_mask
        batch_size, seq_len, n_vq = input_ids.shape
        # Trailing dim is [text, VQ_0 .. VQ_{n_vq-1}], so drop the text channel.
        n_vq -= 1

        generation_ids = input_ids[:]
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Delay-pattern bookkeeping:
        #   audio_lengths[b]   — frames produced in the current audio span; channel i
        #                        only becomes active once audio_lengths > i (ramp-up).
        #   delayed_lengths[b] — steps since the text stream emitted a delay slot
        #                        (wind-down); the int64-max sentinel means "not
        #                        currently winding down".
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)

        # If the prompt already ends inside an audio span, recover how many frames
        # that span contains so the ramp-up masks stay consistent.
        is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
        audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Token-level constraints on the text head:
        #   mask0 — ids never sampled outside audio mode (pad / slot / end markers);
        #   mask1 — inside audio mode, everything except the two slot tokens is banned.
        pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
        pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False

        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Temperature-scale the last-position logits (head 0 = text, rest = audio).
            next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)]
            # Clone: the text logits are masked in place below.
            next_token_logits[0] = next_token_logits[0].clone()

            next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)

            # Wind-down: while the delayed audio channels flush, the text stream
            # emits delay slots, then exactly one audio_end token.
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = self.config.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
            next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
            next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
            if time_step == 0:
                # NOTE(review): hard-coded token id 151662 blocked at the very
                # first step — presumably a special token that must not open the
                # reply; confirm and replace with a named config field.
                next_token_logits[0][..., 151662] = float('-inf')
            if time_step <= n_vq:
                # Too early to stop: the delay pattern needs at least n_vq + 1 steps.
                next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')

            # Sample text only where no forced token was written above.
            next_text_token[sampling_text_mask] = sample_token(
                logits=next_token_logits[0][sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample
            )
            is_audio[next_text_token == self.config.audio_start_token_id] = True

            is_stopping[next_text_token == self.config.im_end_token_id] = True

            # Audio channels default to the pad code; only positions selected by
            # the ramp-up/wind-down masks are actually sampled.
            next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)

            # Channel i is active once the span has produced more than i frames
            # (ramp-up) and has not yet passed its wind-down deadline.
            pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
            post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
            post_audio_mask[delayed_lengths == torch_int64_max] = True
            sampling_audio_mask = pre_audio_mask & post_audio_mask
            next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code

            if sampling_audio_mask.sum() > 0:
                # Channels 1..n_vq-1 share one batched sampling call.
                audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
                audio_logits[..., self.config.audio_pad_code] = float('-inf')
                # Channel 0 is sampled separately.
                # NOTE(review): hard-coded index 1024 masked here — presumably
                # the same value as `self.config.audio_pad_code`; confirm they match.
                audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
                audio_ch0_logits[..., 1024] = float('-inf')
                next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
                    logits=audio_ch0_logits,
                    prev_tokens=generation_ids[:, :, 1],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )

                next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
                    logits=audio_logits,
                    prev_tokens=generation_ids[:, :, 2:],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )

            # Advance the delay-pattern counters for the tokens just chosen:
            # any audio-span text token extends the span, audio_end resets it,
            # the first delay slot arms the wind-down counter, which then ticks
            # up and returns to the sentinel once it passes n_vq.
            audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
            audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
            delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
            delayed_lengths[delayed_lengths != torch_int64_max] += 1
            delayed_lengths[delayed_lengths > n_vq] = torch_int64_max

            # Feed only the freshly generated frame back in (the KV cache holds
            # everything earlier); finished samples get a 0 in the attention mask.
            current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
            current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
            generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)

            if is_stopping.sum() == batch_size:
                break

        # Trim each sequence back to the start of the last chat turn.
        # NOTE(review): the "+ 3" presumably skips the im_start marker plus the
        # role/newline tokens that follow it — confirm against the prompt template.
        start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        output = []
        for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
            output.append((start_length, cur_generation_ids[start_idx:]))

        return output
| |
|