# MOSS-TTS-Realtime / modeling_mossttsrealtime.py
# gaoyang07
# Update modeling
# b1cede0
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MossTTSRealtime model."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import initialization as init
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3 import Qwen3Model
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
from .configuration_mossttsrealtime import MossTTSRealtimeConfig
from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
class MossTTSRealtimePretrainedModel(PreTrainedModel):
    """Base class connecting MossTTSRealtime to the Hugging Face ``PreTrainedModel`` machinery.

    Declares the config class, the attention backends and compilation features the
    model advertises, the submodules that must not be split across devices, and the
    initialization rule for ``nn.Linear`` / ``nn.Embedding`` layers.
    """

    config_class = MossTTSRealtimeConfig
    config: MossTTSRealtimeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Each Qwen3 decoder layer is kept whole on one device when a device map is used.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Submodule types whose outputs are recorded as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, ``config.initializer_range``). Linear biases and the embedding row at
        ``padding_idx`` are set to zero. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Indexing the weight returns a new tensor object that lacks the
            # `_is_hf_initialized` flag that the `init` helpers consult. Without
            # this explicit check, the padding row of a weight already loaded
            # from a checkpoint would be zeroed again.
            already_initialized = getattr(module.weight, "_is_hf_initialized", False)
            if module.padding_idx is not None and not already_initialized:
                init.zeros_(module.weight[module.padding_idx])
@dataclass
class MossTTSRealtimeOutputWithPast(ModelOutput):
    """Output container produced by ``MossTTSRealtime.forward``.

    The first group of fields holds results of the Qwen3 backbone. The ``local_*``
    fields hold results of the local transformer, which is only run when labels
    are supplied. Field order is significant because ``ModelOutput`` is also
    consumed as a tuple.
    """

    # Training loss. `MossTTSRealtime.forward` sets it to the local transformer's loss.
    loss: torch.FloatTensor | None = None
    # Not populated by `MossTTSRealtime.forward` (always None there).
    logits: torch.FloatTensor | None = None
    # Backbone key/value cache.
    past_key_values: Cache | None = None
    # Final hidden states of the backbone.
    last_hidden_state: torch.FloatTensor | None = None
    # Per-layer backbone hidden states, when requested.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Backbone attention weights, when requested.
    attentions: tuple[torch.FloatTensor] | None = None
    # Not populated by `MossTTSRealtime.forward`.
    local_loss: torch.FloatTensor | None = None
    # Logits returned by the local transformer.
    local_logits: torch.FloatTensor | None = None
    # Cache returned by the local transformer.
    local_past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
    # Hidden states returned by the local transformer.
    local_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Attention weights returned by the local transformer.
    local_attentions: tuple[torch.FloatTensor, ...] | None = None
    # Not populated by `MossTTSRealtime.forward`.
    backbone_loss: torch.FloatTensor | None = None
class MossTTSRealtime(MossTTSRealtimePretrainedModel):
    """Qwen3 backbone over summed text/audio-codebook embeddings, plus a local transformer.

    Every position carries ``1 + config.rvq`` token ids: one text id followed by one
    id per audio codebook. Their embeddings are summed and fed to the Qwen3
    backbone. When ``labels`` are given, the local transformer is trained on the
    audio-codebook labels. It is conditioned on the backbone hidden state of the
    preceding position.
    """

    def __init__(self, config: MossTTSRealtimeConfig):
        super().__init__(config)
        self.config = config
        hidden_size = config.language_config.hidden_size
        # Index 0: text embedding. Indices 1..rvq: one embedding table per audio codebook.
        self.embed_tokens = nn.ModuleList(
            [
                nn.Embedding(
                    config.language_config.vocab_size,
                    hidden_size,
                    config.language_config.pad_token_id,
                )
            ]
        )
        self.audio_vocab_size = self.config.audio_vocab_size
        for _ in range(self.config.rvq):
            self.embed_tokens.append(
                nn.Embedding(self.audio_vocab_size, hidden_size, self.config.audio_pad_token)
            )
        self.language_model = Qwen3Model._from_config(config.language_config)
        self.local_transformer = MossTTSRealtimeLocalTransformerForCausalLM._from_config(config.local_config)
        self.post_init()

    def get_input_embeddings(self, input_ids=None):
        """Embed multi-stream token ids, or return the text embedding module.

        Args:
            input_ids: Optional integer tensor. Its last dimension holds one id per
                embedding table (text id first, then one id per audio codebook).
                It is moved to the device of the embedding tables.

        Returns:
            If ``input_ids`` is given, the element-wise sum of every stream's
            embedding, with shape ``input_ids.shape[:-1] + (hidden_size,)``.
            If called without arguments, which is how Hugging Face utilities call
            this hook, the text ``nn.Embedding`` (``self.embed_tokens[0]``).
        """
        if input_ids is None:
            return self.embed_tokens[0]
        input_ids = input_ids.to(self.embed_tokens[0].weight.device)
        inputs_embeds = self.embed_tokens[0](input_ids[..., 0])
        for stream, embed in enumerate(self.embed_tokens[1:], start=1):
            inputs_embeds = inputs_embeds + embed(input_ids[..., stream])
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[list] = None,
        **kwargs,
    ):
        """Run the backbone and, when ``labels`` are given, the local transformer.

        Args:
            input_ids: Multi-stream ids, embedded via ``get_input_embeddings``.
                Required unless ``inputs_embeds`` is provided.
            attention_mask, position_ids, past_key_values, cache_position: Forwarded
                unchanged to the Qwen3 backbone.
            inputs_embeds: Precomputed embeddings. If given, ``input_ids`` is ignored.
            labels: Indexed as ``labels[:, :, 1:]``, i.e. stream 0 (text) is dropped
                and the remaining entries are audio-codebook labels. ``-100`` marks
                ignored entries.
            use_cache: Forwarded to both the backbone and the local transformer.
            output_attentions, output_hidden_states, return_dict: Default to the
                config values when ``None``.
            hidden_out_layers: Accepted for interface compatibility; unused here.
            **kwargs: Forwarded to both the backbone and the local transformer.

        Returns:
            A ``MossTTSRealtimeOutputWithPast``, or its ``to_tuple()`` when
            ``return_dict`` is false. ``loss`` is the local transformer's loss when
            ``labels`` are given, otherwise ``None``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.get_input_embeddings(input_ids)
        outputs = self.language_model(
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        loss = None
        local_outputs = None
        if labels is not None:
            # Drop stream 0 (text); keep the per-codebook audio labels.
            audio_labels = labels[:, :, 1:]
            # A position is supervised when at least one codebook label is not -100.
            # NOTE(review): if no position is supervised, the local batch is empty;
            # confirm the local transformer's loss handles that case.
            train_mask = (audio_labels != -100).any(dim=-1)
            # Local teacher-forcing inputs: the first rvq - 1 codebook labels, with -100
            # replaced by the audio pad token and a leading 0 prepended. Boolean
            # indexing copies, so the in-place replacement does not touch `labels`.
            local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
            local_input_ids[local_input_ids == -100] = self.config.audio_pad_token
            local_input_ids = F.pad(local_input_ids, (1, 0), value=0)
            last_hidden_state = outputs.last_hidden_state
            batch_idx, seq_idx = train_mask.nonzero(as_tuple=True)
            # Condition position t on the backbone output at t - 1 (clamped at 0).
            hidden_positions = torch.clamp(seq_idx - 1, min=0)
            # `labels` may live on a different device than the backbone output
            # (e.g. under a device map), so gather with indices on the output's device.
            hidden_device = last_hidden_state.device
            local_hidden_states = last_hidden_state[
                batch_idx.to(hidden_device), hidden_positions.to(hidden_device), :
            ].reshape(-1, 1, self.config.local_config.hidden_size)
            local_labels = audio_labels[train_mask]
            local_outputs = self.local_transformer(
                input_ids=local_input_ids,
                backbone_last_hidden_state=local_hidden_states,
                use_cache=use_cache,
                return_dict=True,
                labels=local_labels,
                **kwargs,
            )
            loss = local_outputs.loss
        output = MossTTSRealtimeOutputWithPast(
            loss=loss,
            logits=None,
            past_key_values=outputs.past_key_values,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            local_logits=local_outputs.logits if local_outputs is not None else None,
            local_past_key_values=local_outputs.past_key_values if local_outputs is not None else None,
            local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
            local_attentions=local_outputs.attentions if local_outputs is not None else None,
        )
        if not return_dict:
            return output.to_tuple()
        return output
# Names exported by `from modeling_mossttsrealtime import *`.
__all__ = [
    "MossTTSRealtime",
    "MossTTSRealtimeConfig",
    "MossTTSRealtimeOutputWithPast",
    "MossTTSRealtimePretrainedModel",
]