|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""MossTTSRealtime model.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers import initialization as init |
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.models.qwen3 import Qwen3Model |
|
|
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer |
|
|
from .configuration_mossttsrealtime import MossTTSRealtimeConfig |
|
|
from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM |
|
|
|
|
|
|
|
|
class MossTTSRealtimePretrainedModel(PreTrainedModel):
    """Base class wiring MossTTSRealtime into the transformers framework.

    Declares the config class, device-placement / attention-backend capability
    flags, and the weight-initialization scheme shared by all MossTTSRealtime
    models.
    """

    config_class = MossTTSRealtimeConfig
    config: MossTTSRealtimeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each Qwen3 decoder layer on a single device when sharding.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }

    def _init_weights(self, module):
        """Initialize a single submodule.

        Linear and Embedding weights are drawn from N(0, initializer_range);
        Linear biases and the Embedding padding row are zeroed. Other module
        types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            init.zeros_(module.bias)
        if isinstance(module, nn.Embedding) and module.padding_idx is not None:
            # The padding row must stay zero so pad tokens contribute nothing.
            init.zeros_(module.weight[module.padding_idx])
|
|
|
|
|
|
|
|
@dataclass
class MossTTSRealtimeOutputWithPast(ModelOutput):
    """Output of `MossTTSRealtime.forward`.

    Bundles the backbone (global Qwen3 language model) outputs together with
    the local-transformer outputs produced during training. Field order is
    significant: `ModelOutput.to_tuple()` returns the non-None fields in
    declaration order when `return_dict=False`.
    """

    # Training loss. In the visible forward this is the local-transformer loss.
    loss: Optional[torch.FloatTensor] = None
    # Backbone LM logits. The visible forward always sets this to None.
    logits: Optional[torch.FloatTensor] = None
    # KV cache of the backbone language model (when use_cache=True).
    past_key_values: Optional[Cache] = None
    # Final backbone hidden states, used to condition the local transformer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer backbone hidden states (when output_hidden_states=True).
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Backbone attention weights (when output_attentions=True).
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # NOTE(review): local_loss and backbone_loss are declared but never
    # populated by the forward in this file — presumably reserved for callers
    # or other code paths; confirm before relying on them.
    local_loss: Optional[torch.FloatTensor] = None
    # Logits of the local transformer over the audio codebooks.
    local_logits: Optional[torch.FloatTensor] = None
    # KV cache of the local transformer.
    local_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    # Hidden states of the local transformer.
    local_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the local transformer.
    local_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # See NOTE above: not set by the visible forward.
    backbone_loss: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
class MossTTSRealtime(MossTTSRealtimePretrainedModel):
    """Two-stage realtime TTS model.

    A Qwen3 backbone consumes a multi-channel token stream (one text channel
    plus `config.rvq` residual-vector-quantized audio channels, summed in
    embedding space). Its last hidden state conditions a small local
    transformer (`MossTTSRealtimeLocalTransformerForCausalLM`) that models the
    audio codebooks at each backbone position.
    """

    def __init__(self, config: MossTTSRealtimeConfig):
        super().__init__(config)
        self.config = config
        # embed_tokens[0] embeds the text channel; embed_tokens[1..rvq] embed
        # the audio codebook channels. All share the backbone hidden size so
        # the per-channel embeddings can be summed into one input embedding.
        self.embed_tokens = nn.ModuleList([])
        self.embed_tokens.append(
            nn.Embedding(
                config.language_config.vocab_size,
                config.language_config.hidden_size,
                config.language_config.pad_token_id,
            )
        )
        self.audio_vocab_size = self.config.audio_vocab_size
        for _ in range(self.config.rvq):
            # audio_pad_token doubles as padding_idx, so pad positions embed to zero.
            self.embed_tokens.append(
                nn.Embedding(self.audio_vocab_size, config.language_config.hidden_size, self.config.audio_pad_token)
            )
        self.language_model = Qwen3Model._from_config(config.language_config)
        self.local_transformer = MossTTSRealtimeLocalTransformerForCausalLM._from_config(config.local_config)
        self.post_init()

    def get_input_embeddings(self, input_ids):
        """Sum the per-channel embeddings of a multi-channel token tensor.

        `input_ids` has `1 + config.rvq` channels in its last dimension:
        channel 0 is text, channels 1..rvq are audio codebooks. Returns the
        elementwise sum of all channel embeddings.

        NOTE(review): this overrides `PreTrainedModel.get_input_embeddings`,
        which normally takes no arguments and returns the embedding module —
        framework helpers that rely on the standard signature (e.g. token
        resizing) may misbehave; confirm intended.
        """
        if input_ids.device != self.embed_tokens[0].weight.device:
            input_ids = input_ids.to(self.embed_tokens[0].weight.device)
        inputs_embeds = self.embed_tokens[0](input_ids[..., 0])
        for i, embed in enumerate(self.embed_tokens):
            if i == 0:
                # Text channel already embedded above.
                continue
            inputs_embeds = inputs_embeds + embed(input_ids[..., i])
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # (batch, seq, 1 + rvq); required when inputs_embeds is None
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # presumably (batch, seq, 1 + rvq), -100 = ignore — TODO confirm
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[list] = None,  # NOTE(review): accepted but unused in this method
        **kwargs,
    ):
        """Run the backbone and, when `labels` is given, the local transformer.

        Returns a `MossTTSRealtimeOutputWithPast` (or its tuple form when
        `return_dict` resolves to False). `loss`, when set, is the
        local-transformer loss; the backbone itself produces no logits here.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # Backbone pass; **kwargs are forwarded to the Qwen3 model as-is.
        outputs = self.language_model(
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        loss = None
        local_outputs = None
        if labels is not None:
            # Drop the text channel; keep the audio codebook labels.
            audio_labels = labels[:, :, 1:]
            # A position is trainable unless every codebook label is -100.
            train_mask = ~(audio_labels == -100).all(dim=-1)
            # Teacher-forcing inputs: first rvq-1 codebooks of each trainable
            # position. Boolean indexing copies, so the in-place -100
            # replacement below does not touch `labels`.
            local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
            local_input_ids[local_input_ids == -100] = self.config.audio_pad_token
            # Prepend a 0 token as the local sequence start.
            local_input_ids = F.pad(local_input_ids, (1, 0), value=0)

            # Condition each trainable position on the backbone hidden state of
            # the *previous* position (clamped to 0 at the sequence start).
            train_idx = train_mask.nonzero(as_tuple=True)
            hidden_positions = torch.clamp(train_idx[1] - 1, min=0)
            local_hidden_states = outputs.last_hidden_state[train_idx[0], hidden_positions, :].reshape(
                -1, 1, self.config.local_config.hidden_size
            )
            local_labels = audio_labels[train_mask]

            # NOTE(review): the same **kwargs forwarded to the backbone are also
            # forwarded here — confirm the local transformer accepts them.
            local_outputs = self.local_transformer(
                input_ids=local_input_ids,
                backbone_last_hidden_state=local_hidden_states,
                use_cache=use_cache,
                return_dict=True,
                labels=local_labels,
                **kwargs,
            )
            loss = local_outputs.loss

        output = MossTTSRealtimeOutputWithPast(
            loss=loss,
            logits=None,
            past_key_values=outputs.past_key_values,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            local_logits=local_outputs.logits if local_outputs is not None else None,
            local_past_key_values=local_outputs.past_key_values if local_outputs is not None else None,
            local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
            local_attentions=local_outputs.attentions if local_outputs is not None else None,
        )
        if not return_dict:
            return output.to_tuple()
        return output
|
|
|
|
|
|
|
|
__all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel"] |
|
|
|