first add
Browse files- __init__.py +0 -0
- config.json +96 -0
- configuration_mossttsrealtime.py +115 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- modeling_mossttsrealtime.py +192 -0
- modeling_mossttsrealtime_local.py +449 -0
- processing_mossttsrealtime.py +172 -0
- streaming_mossttsrealtime.py +1003 -0
- tokenizer.json +3 -0
- tokenizer_config.json +240 -0
- vocab.json +0 -0
__init__.py
ADDED
|
File without changes
|
config.json
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MossTTSRealtime"
|
| 4 |
+
],
|
| 5 |
+
"audio_pad_token": 1024,
|
| 6 |
+
"audio_vocab_size": 1027,
|
| 7 |
+
"dtype": "bfloat16",
|
| 8 |
+
"initializer_range": 0.02,
|
| 9 |
+
"language_config": {
|
| 10 |
+
"_name_or_path": ".",
|
| 11 |
+
"architectures": [
|
| 12 |
+
"Qwen3ForCausalLM"
|
| 13 |
+
],
|
| 14 |
+
"attention_bias": false,
|
| 15 |
+
"attention_dropout": 0.0,
|
| 16 |
+
"bos_token_id": 151643,
|
| 17 |
+
"dtype": "bfloat16",
|
| 18 |
+
"eos_token_id": 151645,
|
| 19 |
+
"head_dim": 128,
|
| 20 |
+
"hidden_act": "silu",
|
| 21 |
+
"hidden_size": 2048,
|
| 22 |
+
"initializer_range": 0.02,
|
| 23 |
+
"intermediate_size": 6144,
|
| 24 |
+
"layer_types": [
|
| 25 |
+
"full_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"full_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"full_attention",
|
| 32 |
+
"full_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"full_attention",
|
| 36 |
+
"full_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"full_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"full_attention",
|
| 44 |
+
"full_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"full_attention",
|
| 48 |
+
"full_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"full_attention",
|
| 52 |
+
"full_attention"
|
| 53 |
+
],
|
| 54 |
+
"max_position_embeddings": 40960,
|
| 55 |
+
"max_window_layers": 28,
|
| 56 |
+
"model_type": "qwen3",
|
| 57 |
+
"num_attention_heads": 16,
|
| 58 |
+
"num_hidden_layers": 28,
|
| 59 |
+
"num_key_value_heads": 8,
|
| 60 |
+
"rms_norm_eps": 1e-06,
|
| 61 |
+
"rope_scaling": null,
|
| 62 |
+
"rope_theta": 1000000,
|
| 63 |
+
"sliding_window": null,
|
| 64 |
+
"use_cache": true,
|
| 65 |
+
"use_sliding_window": false,
|
| 66 |
+
"vocab_size": 151936
|
| 67 |
+
},
|
| 68 |
+
"local_config": {
|
| 69 |
+
"attention_bias": false,
|
| 70 |
+
"attention_dropout": 0.0,
|
| 71 |
+
"audio_pad_token": 1024,
|
| 72 |
+
"audio_vocab_size": 1027,
|
| 73 |
+
"head_dim": 128,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"hidden_size": 2048,
|
| 76 |
+
"initializer_range": 0.02,
|
| 77 |
+
"intermediate_size": 6144,
|
| 78 |
+
"max_position_embeddings": 33,
|
| 79 |
+
"model_type": "MossTTSRealtimeLocalTransformer",
|
| 80 |
+
"num_attention_heads": 16,
|
| 81 |
+
"num_hidden_layers": 4,
|
| 82 |
+
"num_key_value_heads": 8,
|
| 83 |
+
"rms_norm_eps": 1e-06,
|
| 84 |
+
"rope_theta": 1000000,
|
| 85 |
+
"rvq": 16,
|
| 86 |
+
"tie_word_embeddings": false,
|
| 87 |
+
"use_cache": true
|
| 88 |
+
},
|
| 89 |
+
"model_type": "moss_tts_realtime",
|
| 90 |
+
"reference_audio_pad": 151654,
|
| 91 |
+
"rvq": 16,
|
| 92 |
+
"text_pad": 151655,
|
| 93 |
+
"tie_word_embeddings": false,
|
| 94 |
+
"transformers_version": "4.57.1",
|
| 95 |
+
"vocab_size": 151936
|
| 96 |
+
}
|
configuration_mossttsrealtime.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""MossTTSRealtimeModel configuration."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.models.qwen3 import Qwen3Config
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _ensure_config(cfg: Any, cls: type[PretrainedConfig]) -> PretrainedConfig:
|
| 26 |
+
if isinstance(cfg, cls):
|
| 27 |
+
return cfg
|
| 28 |
+
if cfg is None:
|
| 29 |
+
return cls()
|
| 30 |
+
if isinstance(cfg, dict):
|
| 31 |
+
return cls(**cfg)
|
| 32 |
+
raise TypeError(f"Unsupported config type for {cls.__name__}: {type(cfg)}")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MossTTSRealtimeLocalTransformerConfig(PretrainedConfig):
|
| 36 |
+
model_type = "moss_tts_realtime_local_transformer"
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
head_dim: int = 128,
|
| 41 |
+
use_cache: bool = True,
|
| 42 |
+
hidden_size: int = 2048,
|
| 43 |
+
rms_norm_eps: float = 1e-6,
|
| 44 |
+
num_hidden_layers: int = 4,
|
| 45 |
+
intermediate_size: int = 6144,
|
| 46 |
+
num_attention_heads: int = 16,
|
| 47 |
+
initializer_range: float = 0.02,
|
| 48 |
+
attention_bias: bool = False,
|
| 49 |
+
attention_dropout: float = 0.0,
|
| 50 |
+
max_position_embeddings: int = 33,
|
| 51 |
+
num_key_value_heads: int = 8,
|
| 52 |
+
hidden_act: str = "silu",
|
| 53 |
+
rope_theta: int = 1000000,
|
| 54 |
+
rope_type: str = "linear",
|
| 55 |
+
pad_token_id: int = 1024,
|
| 56 |
+
rope_parameters: dict | None = None,
|
| 57 |
+
**kwargs,
|
| 58 |
+
):
|
| 59 |
+
super().__init__(**kwargs)
|
| 60 |
+
self.head_dim = head_dim
|
| 61 |
+
self.hidden_size = hidden_size
|
| 62 |
+
self.intermediate_size = intermediate_size
|
| 63 |
+
self.num_hidden_layers = num_hidden_layers
|
| 64 |
+
self.num_attention_heads = num_attention_heads
|
| 65 |
+
self.initializer_range = initializer_range
|
| 66 |
+
self.rms_norm_eps = rms_norm_eps
|
| 67 |
+
self.use_cache = use_cache
|
| 68 |
+
self.hidden_act = hidden_act
|
| 69 |
+
self.rope_theta = rope_theta
|
| 70 |
+
self.rope_type = rope_type
|
| 71 |
+
if rope_parameters is None:
|
| 72 |
+
rope_parameters = {"rope_type": rope_type, "rope_theta": rope_theta, "factor": 1.0}
|
| 73 |
+
self.rope_parameters = rope_parameters
|
| 74 |
+
self.attention_bias = attention_bias
|
| 75 |
+
self.attention_dropout = attention_dropout
|
| 76 |
+
self.num_key_value_heads = num_key_value_heads
|
| 77 |
+
self.max_position_embeddings = max_position_embeddings
|
| 78 |
+
self.pad_token_id = pad_token_id
|
| 79 |
+
|
| 80 |
+
self.audio_pad_token = 1024
|
| 81 |
+
self.audio_vocab_size = 1027
|
| 82 |
+
self.rvq = 16
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class MossTTSRealtimeConfig(PretrainedConfig):
|
| 86 |
+
model_type = "moss_tts_realtime"
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
language_config: Qwen3Config | dict | None = None,
|
| 91 |
+
local_config: MossTTSRealtimeLocalTransformerConfig | dict | None = None,
|
| 92 |
+
rvq: int = 16,
|
| 93 |
+
audio_pad_token: int = 1024,
|
| 94 |
+
audio_vocab_size: int = 1027,
|
| 95 |
+
reference_audio_pad: int = 151654,
|
| 96 |
+
text_pad: int = 151655,
|
| 97 |
+
initializer_range: float = 0.02,
|
| 98 |
+
**kwargs,
|
| 99 |
+
):
|
| 100 |
+
super().__init__(**kwargs)
|
| 101 |
+
self.rvq = rvq
|
| 102 |
+
self.initializer_range = initializer_range
|
| 103 |
+
self.audio_pad_token = audio_pad_token
|
| 104 |
+
self.audio_vocab_size = audio_vocab_size
|
| 105 |
+
self.reference_audio_pad = reference_audio_pad
|
| 106 |
+
self.text_pad = text_pad
|
| 107 |
+
self.language_config = _ensure_config(language_config, Qwen3Config)
|
| 108 |
+
self.local_config = _ensure_config(local_config, MossTTSRealtimeLocalTransformerConfig)
|
| 109 |
+
|
| 110 |
+
attn_impl = self._attn_implementation
|
| 111 |
+
self.language_config._attn_implementation = attn_impl
|
| 112 |
+
self.local_config._attn_implementation = attn_impl
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
__all__ = ["MossTTSRealtimeConfig", "MossTTSRealtimeLocalTransformerConfig"]
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13a7a2a0322fe63984ff982a042373a41f707d95df5963233eff40e068a070d4
|
| 3 |
+
size 4663931664
|
modeling_mossttsrealtime.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""MossTTSRealtime backbone model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
from transformers.cache_utils import Cache
|
| 27 |
+
from transformers.modeling_outputs import ModelOutput
|
| 28 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 29 |
+
from transformers.models.qwen3 import Qwen3Model
|
| 30 |
+
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
|
| 31 |
+
|
| 32 |
+
from .configuration_mossttsrealtime import MossTTSRealtimeConfig
|
| 33 |
+
from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MossTTSRealtimePretrainedModel(PreTrainedModel):
|
| 37 |
+
config_class = MossTTSRealtimeConfig
|
| 38 |
+
config: MossTTSRealtimeConfig
|
| 39 |
+
base_model_prefix = "model"
|
| 40 |
+
supports_gradient_checkpointing = True
|
| 41 |
+
_no_split_modules = ["Qwen3DecoderLayer"]
|
| 42 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 43 |
+
_supports_sdpa = True
|
| 44 |
+
_supports_flex_attn = True
|
| 45 |
+
_supports_flash_attn = True
|
| 46 |
+
_can_compile_fullgraph = True
|
| 47 |
+
_supports_attention_backend = True
|
| 48 |
+
_can_record_outputs = {
|
| 49 |
+
"hidden_states": Qwen3DecoderLayer,
|
| 50 |
+
"attentions": Qwen3Attention,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
def _init_weights(self, module):
|
| 54 |
+
|
| 55 |
+
from transformers import initialization as init
|
| 56 |
+
|
| 57 |
+
std = self.config.initializer_range
|
| 58 |
+
if isinstance(module, nn.Linear):
|
| 59 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
| 60 |
+
init.normal_(module.weight, mean=0.0, std=std)
|
| 61 |
+
if module.bias is not None:
|
| 62 |
+
# module.bias.data.zero_()
|
| 63 |
+
init.zeros_(module.bias)
|
| 64 |
+
elif isinstance(module, nn.Embedding):
|
| 65 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
| 66 |
+
init.normal_(module.weight, mean=0.0, std=std)
|
| 67 |
+
if module.padding_idx is not None:
|
| 68 |
+
# module.weight.data[module.padding_idx].zero_()
|
| 69 |
+
init.zeros_(module.weight[module.padding_idx])
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
|
| 73 |
+
class MossTTSRealtimeOutputWithPast(ModelOutput):
|
| 74 |
+
loss: Optional[torch.FloatTensor] = None
|
| 75 |
+
logits: Optional[torch.FloatTensor] = None
|
| 76 |
+
past_key_values: Optional[Cache] = None
|
| 77 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 78 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 79 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 80 |
+
local_loss: Optional[torch.FloatTensor] = None
|
| 81 |
+
local_logits: Optional[torch.FloatTensor] = None
|
| 82 |
+
local_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
|
| 83 |
+
local_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 84 |
+
local_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 85 |
+
backbone_loss: Optional[torch.FloatTensor] = None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class MossTTSRealtime(MossTTSRealtimePretrainedModel):
|
| 89 |
+
def __init__(self, config: MossTTSRealtimeConfig):
|
| 90 |
+
super().__init__(config)
|
| 91 |
+
self.config = config
|
| 92 |
+
self.embed_tokens = nn.ModuleList([])
|
| 93 |
+
self.embed_tokens.append(
|
| 94 |
+
nn.Embedding(
|
| 95 |
+
config.language_config.vocab_size,
|
| 96 |
+
config.language_config.hidden_size,
|
| 97 |
+
config.language_config.pad_token_id,
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
self.audio_vocab_size = self.config.audio_vocab_size
|
| 101 |
+
for _ in range(self.config.rvq):
|
| 102 |
+
self.embed_tokens.append(
|
| 103 |
+
nn.Embedding(self.audio_vocab_size, config.language_config.hidden_size, self.config.audio_pad_token)
|
| 104 |
+
)
|
| 105 |
+
self.language_model = Qwen3Model._from_config(config.language_config)
|
| 106 |
+
self.local_transformer = MossTTSRealtimeLocalTransformerForCausalLM._from_config(config.local_config)
|
| 107 |
+
self.post_init()
|
| 108 |
+
|
| 109 |
+
def get_input_embeddings(self, input_ids):
|
| 110 |
+
if input_ids.device != self.embed_tokens[0].weight.device:
|
| 111 |
+
input_ids = input_ids.to(self.embed_tokens[0].weight.device)
|
| 112 |
+
inputs_embeds = self.embed_tokens[0](input_ids[..., 0])
|
| 113 |
+
for i, embed in enumerate(self.embed_tokens):
|
| 114 |
+
if i == 0:
|
| 115 |
+
continue
|
| 116 |
+
inputs_embeds = inputs_embeds + embed(input_ids[..., i])
|
| 117 |
+
return inputs_embeds
|
| 118 |
+
|
| 119 |
+
def forward(
|
| 120 |
+
self,
|
| 121 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 122 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 123 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 124 |
+
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
| 125 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 126 |
+
labels: Optional[torch.LongTensor] = None,
|
| 127 |
+
use_cache: Optional[bool] = False,
|
| 128 |
+
output_attentions: Optional[bool] = None,
|
| 129 |
+
output_hidden_states: Optional[bool] = None,
|
| 130 |
+
return_dict: Optional[bool] = None,
|
| 131 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 132 |
+
hidden_out_layers: Optional[list] = None,
|
| 133 |
+
**kwargs,
|
| 134 |
+
):
|
| 135 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 136 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 137 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 138 |
+
|
| 139 |
+
if inputs_embeds is None:
|
| 140 |
+
inputs_embeds = self.get_input_embeddings(input_ids)
|
| 141 |
+
|
| 142 |
+
outputs = self.language_model(
|
| 143 |
+
position_ids=position_ids,
|
| 144 |
+
attention_mask=attention_mask,
|
| 145 |
+
past_key_values=past_key_values,
|
| 146 |
+
inputs_embeds=inputs_embeds,
|
| 147 |
+
use_cache=use_cache,
|
| 148 |
+
output_hidden_states=True,
|
| 149 |
+
cache_position=cache_position,
|
| 150 |
+
**kwargs,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
loss = None
|
| 154 |
+
local_outputs = None
|
| 155 |
+
if labels is not None:
|
| 156 |
+
audio_labels = labels[:, :, 1:]
|
| 157 |
+
train_mask = ~(audio_labels == -100).all(dim=-1)
|
| 158 |
+
local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
|
| 159 |
+
local_input_ids[local_input_ids == -100] = 1024
|
| 160 |
+
local_input_ids = F.pad(local_input_ids, (1, 0), value=0)
|
| 161 |
+
|
| 162 |
+
train_idx = train_mask.nonzero(as_tuple=True)
|
| 163 |
+
local_hidden_states = outputs[0][train_idx[0], train_idx[1] - 1, :].reshape(
|
| 164 |
+
-1, 1, self.config.local_config.hidden_size
|
| 165 |
+
)
|
| 166 |
+
local_labels = audio_labels[train_mask]
|
| 167 |
+
|
| 168 |
+
local_outputs = self.local_transformer(
|
| 169 |
+
input_ids=local_input_ids,
|
| 170 |
+
backbone_last_hidden_state=local_hidden_states,
|
| 171 |
+
use_cache=use_cache,
|
| 172 |
+
return_dict=True,
|
| 173 |
+
labels=local_labels,
|
| 174 |
+
**kwargs,
|
| 175 |
+
)
|
| 176 |
+
loss = local_outputs.loss
|
| 177 |
+
|
| 178 |
+
return MossTTSRealtimeOutputWithPast(
|
| 179 |
+
loss=loss,
|
| 180 |
+
logits=None,
|
| 181 |
+
past_key_values=outputs.past_key_values,
|
| 182 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 183 |
+
hidden_states=outputs.hidden_states,
|
| 184 |
+
attentions=outputs.attentions,
|
| 185 |
+
local_logits=local_outputs.logits if local_outputs is not None else None,
|
| 186 |
+
local_past_key_values=local_outputs.past_key_values if local_outputs is not None else None,
|
| 187 |
+
local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
|
| 188 |
+
local_attentions=local_outputs.attentions if local_outputs is not None else None,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
__all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel", "Qwen3Model"]
|
modeling_mossttsrealtime_local.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Local transformer used by MossTTSRealtime for RVQ codebook decoding."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from transformers.activations import ACT2FN
|
| 25 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 26 |
+
from transformers.generation import GenerationMixin
|
| 27 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 28 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 29 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 30 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 31 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 32 |
+
from transformers.masking_utils import create_causal_mask
|
| 33 |
+
from transformers.processing_utils import Unpack
|
| 34 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 35 |
+
from transformers.loss.loss_utils import ForCausalLMLoss
|
| 36 |
+
|
| 37 |
+
from .configuration_mossttsrealtime import MossTTSRealtimeLocalTransformerConfig
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class MossTTSRealtimeLocalTransformerRMSNorm(nn.Module):
|
| 43 |
+
def __init__(self, hidden_size, eps=1e-6) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 46 |
+
self.variance_epsilon = eps
|
| 47 |
+
|
| 48 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 49 |
+
input_dtype = hidden_states.dtype
|
| 50 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 51 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 52 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 53 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 54 |
+
|
| 55 |
+
def extra_repr(self):
|
| 56 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MossTTSRealtimeLocalTransformerMLP(nn.Module):
|
| 60 |
+
def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.config = config
|
| 63 |
+
self.hidden_size = config.hidden_size
|
| 64 |
+
self.intermediate_size = config.intermediate_size
|
| 65 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 66 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 67 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 68 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 72 |
+
return down_proj
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def rotate_half(x):
|
| 76 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 77 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 78 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 82 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 83 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 84 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 85 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 86 |
+
return q_embed, k_embed
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 90 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 91 |
+
if n_rep == 1:
|
| 92 |
+
return hidden_states
|
| 93 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 94 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def eager_attention_forward(
|
| 98 |
+
module: nn.Module,
|
| 99 |
+
query: torch.Tensor,
|
| 100 |
+
key: torch.Tensor,
|
| 101 |
+
value: torch.Tensor,
|
| 102 |
+
attention_mask: Optional[torch.Tensor],
|
| 103 |
+
scaling: float,
|
| 104 |
+
dropout: float = 0.0,
|
| 105 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 106 |
+
):
|
| 107 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 108 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 109 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 110 |
+
if attention_mask is not None:
|
| 111 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 112 |
+
attn_weights = attn_weights + causal_mask
|
| 113 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 114 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 115 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 116 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 117 |
+
return attn_output, attn_weights
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class MossTTSRealtimeLocalTransformerAttention(nn.Module):
|
| 121 |
+
def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.config = config
|
| 124 |
+
self.layer_idx = layer_idx
|
| 125 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 126 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 127 |
+
self.scaling = self.head_dim**-0.5
|
| 128 |
+
self.attention_dropout = config.attention_dropout
|
| 129 |
+
self.is_causal = True
|
| 130 |
+
|
| 131 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
|
| 132 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 133 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 134 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
|
| 135 |
+
self.q_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 136 |
+
self.k_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 137 |
+
self.sliding_window = None
|
| 138 |
+
|
| 139 |
+
def forward(
|
| 140 |
+
self,
|
| 141 |
+
hidden_states: torch.Tensor,
|
| 142 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 143 |
+
attention_mask: Optional[torch.Tensor],
|
| 144 |
+
past_key_values: Optional[Cache] = None,
|
| 145 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 146 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 147 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 148 |
+
input_shape = hidden_states.shape[:-1]
|
| 149 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 150 |
+
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 151 |
+
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 152 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 153 |
+
cos, sin = position_embeddings
|
| 154 |
+
|
| 155 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 156 |
+
|
| 157 |
+
if past_key_values is not None:
|
| 158 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 159 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 160 |
+
|
| 161 |
+
attention_interface = eager_attention_forward
|
| 162 |
+
if self.config._attn_implementation != "eager":
|
| 163 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 164 |
+
|
| 165 |
+
attn_output, attn_weights = attention_interface(
|
| 166 |
+
self,
|
| 167 |
+
query_states,
|
| 168 |
+
key_states,
|
| 169 |
+
value_states,
|
| 170 |
+
attention_mask,
|
| 171 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 172 |
+
scaling=self.scaling,
|
| 173 |
+
sliding_window=self.sliding_window,
|
| 174 |
+
**kwargs,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 178 |
+
attn_output = self.o_proj(attn_output)
|
| 179 |
+
return attn_output, attn_weights
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class MossTTSRealtimeLocalTransformerDecoderLayer(GradientCheckpointingLayer):
|
| 183 |
+
def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.hidden_size = config.hidden_size
|
| 186 |
+
self.self_attn = MossTTSRealtimeLocalTransformerAttention(config=config, layer_idx=layer_idx)
|
| 187 |
+
self.mlp = MossTTSRealtimeLocalTransformerMLP(config)
|
| 188 |
+
self.input_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 189 |
+
self.post_attention_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 190 |
+
self.attention_type = "full_attention"
|
| 191 |
+
|
| 192 |
+
def forward(
|
| 193 |
+
self,
|
| 194 |
+
hidden_states: torch.Tensor,
|
| 195 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 196 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 197 |
+
past_key_values: Optional[Cache] = None,
|
| 198 |
+
use_cache: Optional[bool] = False,
|
| 199 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 200 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 201 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 202 |
+
) -> torch.Tensor:
|
| 203 |
+
residual = hidden_states
|
| 204 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 205 |
+
hidden_states, _ = self.self_attn(
|
| 206 |
+
hidden_states=hidden_states,
|
| 207 |
+
attention_mask=attention_mask,
|
| 208 |
+
position_ids=position_ids,
|
| 209 |
+
past_key_values=past_key_values,
|
| 210 |
+
use_cache=use_cache,
|
| 211 |
+
cache_position=cache_position,
|
| 212 |
+
position_embeddings=position_embeddings,
|
| 213 |
+
**kwargs,
|
| 214 |
+
)
|
| 215 |
+
hidden_states = residual + hidden_states
|
| 216 |
+
residual = hidden_states
|
| 217 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 218 |
+
hidden_states = self.mlp(hidden_states)
|
| 219 |
+
hidden_states = residual + hidden_states
|
| 220 |
+
return hidden_states
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
|
| 224 |
+
config: MossTTSRealtimeLocalTransformerConfig
|
| 225 |
+
base_model_prefix = "local_transformer"
|
| 226 |
+
supports_gradient_checkpointing = True
|
| 227 |
+
_no_split_modules = ["MossTTSRealtimeLocalTransformerDecoderLayer"]
|
| 228 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 229 |
+
_supports_sdpa = True
|
| 230 |
+
_supports_flex_attn = True
|
| 231 |
+
_supports_flash_attn = True
|
| 232 |
+
_can_compile_fullgraph = True
|
| 233 |
+
_supports_attention_backend = True
|
| 234 |
+
_can_record_outputs = {
|
| 235 |
+
"hidden_states": MossTTSRealtimeLocalTransformerDecoderLayer,
|
| 236 |
+
"attentions": MossTTSRealtimeLocalTransformerAttention,
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class MossTTSRealtimeLocalTransformerRotaryEmbedding(nn.Module):
|
| 241 |
+
inv_freq: torch.Tensor
|
| 242 |
+
|
| 243 |
+
def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, device=None):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.config = config
|
| 246 |
+
self.rope_type = getattr(config, "rope_type", "linear")
|
| 247 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 248 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 249 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 250 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 251 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 252 |
+
self.original_inv_freq = self.inv_freq
|
| 253 |
+
|
| 254 |
+
@torch.no_grad()
|
| 255 |
+
@dynamic_rope_update
|
| 256 |
+
def forward(self, x, position_ids):
|
| 257 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 258 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 259 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 260 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 261 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 262 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 263 |
+
cos = emb.cos() * self.attention_scaling
|
| 264 |
+
sin = emb.sin() * self.attention_scaling
|
| 265 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class MossTTSRealtimeLocalTransformer(MossTTSRealtimeLocalTransformerPreTrainedModel):
|
| 269 |
+
def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
|
| 270 |
+
super().__init__(config)
|
| 271 |
+
self.padding_idx = config.pad_token_id
|
| 272 |
+
self.embed_tokens = nn.ModuleList(
|
| 273 |
+
[nn.Embedding(config.audio_vocab_size, config.hidden_size, config.audio_pad_token) for _ in range(config.rvq - 1)]
|
| 274 |
+
)
|
| 275 |
+
self.layers = nn.ModuleList(
|
| 276 |
+
[MossTTSRealtimeLocalTransformerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 277 |
+
)
|
| 278 |
+
self.norm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 279 |
+
self.rotary_emb = MossTTSRealtimeLocalTransformerRotaryEmbedding(config=config)
|
| 280 |
+
self.gradient_checkpointing = False
|
| 281 |
+
self.has_sliding_layers = None
|
| 282 |
+
self.post_init()
|
| 283 |
+
|
| 284 |
+
def forward(
|
| 285 |
+
self,
|
| 286 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 287 |
+
backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
|
| 288 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 289 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 290 |
+
past_key_values: Optional[Cache] = None,
|
| 291 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 292 |
+
use_cache: Optional[bool] = None,
|
| 293 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 294 |
+
codebook_idx: Optional[int] = None,
|
| 295 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 296 |
+
) -> BaseModelOutputWithPast:
|
| 297 |
+
if position_ids is not None and not torch.compiler.is_compiling():
|
| 298 |
+
position_ids = None
|
| 299 |
+
|
| 300 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 301 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
|
| 302 |
+
|
| 303 |
+
if use_cache and past_key_values is None:
|
| 304 |
+
past_key_values = StaticCache(config=self.config, max_cache_len=16, device=inputs_embeds.device)
|
| 305 |
+
|
| 306 |
+
if cache_position is None:
|
| 307 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 308 |
+
inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
|
| 309 |
+
device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
|
| 310 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
|
| 311 |
+
|
| 312 |
+
if inputs_embeds is None:
|
| 313 |
+
if codebook_idx is not None:
|
| 314 |
+
if input_ids.ndim == 1:
|
| 315 |
+
input_ids = input_ids.unsqueeze(1)
|
| 316 |
+
token_emb = self.embed_tokens[codebook_idx - 1](input_ids[:, 0]).unsqueeze(1) # [B,1,H]
|
| 317 |
+
inputs_embeds = token_emb
|
| 318 |
+
else:
|
| 319 |
+
codebook_idxs = torch.clamp(cache_position - 1, min=0)
|
| 320 |
+
inputs_embeds = self.embed_tokens[codebook_idxs - 1](input_ids)
|
| 321 |
+
|
| 322 |
+
input_ids_are_first_codebook = cache_position[0] == 0
|
| 323 |
+
if backbone_last_hidden_state is not None:
|
| 324 |
+
inputs_embeds[:, 0] = backbone_last_hidden_state
|
| 325 |
+
else:
|
| 326 |
+
if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
|
| 327 |
+
logger.warning(
|
| 328 |
+
"When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
causal_mask = create_causal_mask(
|
| 332 |
+
config=self.config,
|
| 333 |
+
input_embeds=inputs_embeds,
|
| 334 |
+
attention_mask=attention_mask,
|
| 335 |
+
cache_position=cache_position,
|
| 336 |
+
past_key_values=past_key_values,
|
| 337 |
+
position_ids=position_ids,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
hidden_states = inputs_embeds
|
| 341 |
+
position_ids = cache_position.unsqueeze(0)
|
| 342 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 343 |
+
|
| 344 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 345 |
+
hidden_states = decoder_layer(
|
| 346 |
+
hidden_states,
|
| 347 |
+
attention_mask=causal_mask,
|
| 348 |
+
position_ids=position_ids,
|
| 349 |
+
past_key_values=past_key_values,
|
| 350 |
+
use_cache=use_cache,
|
| 351 |
+
cache_position=cache_position,
|
| 352 |
+
position_embeddings=position_embeddings,
|
| 353 |
+
**kwargs,
|
| 354 |
+
)
|
| 355 |
+
hidden_states = self.norm(hidden_states)
|
| 356 |
+
return BaseModelOutputWithPast(
|
| 357 |
+
last_hidden_state=hidden_states,
|
| 358 |
+
past_key_values=past_key_values if use_cache else None,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class MossTTSRealtimeLocalTransformerForCausalLM(MossTTSRealtimeLocalTransformerPreTrainedModel, GenerationMixin):
|
| 363 |
+
_tied_weights_keys = None
|
| 364 |
+
_tp_plan = None
|
| 365 |
+
_pp_plan = None
|
| 366 |
+
|
| 367 |
+
def __init__(self, config):
|
| 368 |
+
super().__init__(config)
|
| 369 |
+
self.model = MossTTSRealtimeLocalTransformer(config)
|
| 370 |
+
self.audio_vocab_size = self.config.audio_vocab_size
|
| 371 |
+
|
| 372 |
+
self.local_lm_heads = nn.ModuleList(
|
| 373 |
+
[nn.Linear(config.hidden_size, config.audio_vocab_size, bias=False) for _ in range(config.rvq)]
|
| 374 |
+
)
|
| 375 |
+
self.post_init()
|
| 376 |
+
|
| 377 |
+
def forward(
|
| 378 |
+
self,
|
| 379 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 380 |
+
backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
|
| 381 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 382 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 383 |
+
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
| 384 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 385 |
+
labels: Optional[torch.LongTensor] = None,
|
| 386 |
+
use_cache: Optional[bool] = None,
|
| 387 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 388 |
+
codebook_idx: Optional[int] = None,
|
| 389 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 390 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 391 |
+
) -> Union[tuple, CausalLMOutputWithPast]:
|
| 392 |
+
outputs = self.model(
|
| 393 |
+
input_ids=input_ids,
|
| 394 |
+
backbone_last_hidden_state=backbone_last_hidden_state,
|
| 395 |
+
inputs_embeds=inputs_embeds,
|
| 396 |
+
attention_mask=attention_mask,
|
| 397 |
+
position_ids=position_ids,
|
| 398 |
+
past_key_values=past_key_values,
|
| 399 |
+
use_cache=use_cache,
|
| 400 |
+
cache_position=cache_position,
|
| 401 |
+
codebook_idx=codebook_idx,
|
| 402 |
+
**kwargs,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
hidden_states = outputs.last_hidden_state
|
| 406 |
+
|
| 407 |
+
if isinstance(logits_to_keep, int):
|
| 408 |
+
if logits_to_keep == 0:
|
| 409 |
+
slice_indices = slice(0, None)
|
| 410 |
+
else:
|
| 411 |
+
slice_indices = slice(-logits_to_keep, None)
|
| 412 |
+
else:
|
| 413 |
+
slice_indices = logits_to_keep
|
| 414 |
+
hs = hidden_states[:, slice_indices, :]
|
| 415 |
+
|
| 416 |
+
if cache_position is not None:
|
| 417 |
+
logits = self.local_lm_heads[codebook_idx](hs[:, 0, :]).unsqueeze(1)
|
| 418 |
+
else:
|
| 419 |
+
logits_list = []
|
| 420 |
+
for i in range(hs.shape[1]):
|
| 421 |
+
logits_list.append(self.local_lm_heads[i](hs[:, i, :]))
|
| 422 |
+
logits = torch.stack(logits_list, dim=1)
|
| 423 |
+
|
| 424 |
+
logits = logits.contiguous()
|
| 425 |
+
loss = None
|
| 426 |
+
if labels is not None:
|
| 427 |
+
loss = ForCausalLMLoss(logits, None, self.audio_vocab_size, shift_labels=labels.contiguous())
|
| 428 |
+
|
| 429 |
+
return CausalLMOutputWithPast(
|
| 430 |
+
loss=loss,
|
| 431 |
+
logits=logits,
|
| 432 |
+
past_key_values=outputs.past_key_values,
|
| 433 |
+
hidden_states=outputs.hidden_states,
|
| 434 |
+
attentions=outputs.attentions,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
__all__ = [
|
| 441 |
+
"MossTTSRealtimeLocalTransformer",
|
| 442 |
+
"MossTTSRealtimeLocalTransformerAttention",
|
| 443 |
+
"MossTTSRealtimeLocalTransformerConfig",
|
| 444 |
+
"MossTTSRealtimeLocalTransformerDecoderLayer",
|
| 445 |
+
"MossTTSRealtimeLocalTransformerForCausalLM",
|
| 446 |
+
"MossTTSRealtimeLocalTransformerPreTrainedModel",
|
| 447 |
+
"MossTTSRealtimeLocalTransformerRMSNorm",
|
| 448 |
+
"MossTTSRealtimeLocalTransformerRotaryEmbedding",
|
| 449 |
+
]
|
processing_mossttsrealtime.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Processing utilities for MossTTSRealtime."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Iterable, Optional
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MossTTSRealtimeProcessor:
|
| 25 |
+
"""Builds MossTTSRealtime prompt inputs with text and audio codebooks.
|
| 26 |
+
|
| 27 |
+
This processor focuses on preparing the mixed text/audio token layout expected by MossTTSRealtime.
|
| 28 |
+
It does not perform audio encoding/decoding by itself.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
tokenizer,
|
| 34 |
+
audio_pad_token: str = "<|audio_pad|>",
|
| 35 |
+
text_pad_token: str = "<|text_pad|>",
|
| 36 |
+
tts_system_prompt: Optional[str] = None,
|
| 37 |
+
channels: int = 16,
|
| 38 |
+
audio_channel_pad: int = 1024,
|
| 39 |
+
audio_bos_token: int = 1025,
|
| 40 |
+
audio_eos_token: int = 1026,
|
| 41 |
+
delay_tokens_len: int = 12,
|
| 42 |
+
):
|
| 43 |
+
self.tokenizer = tokenizer
|
| 44 |
+
self.channels = channels
|
| 45 |
+
self.audio_channel_pad = audio_channel_pad
|
| 46 |
+
self.audio_bos_token = audio_bos_token
|
| 47 |
+
self.audio_eos_token = audio_eos_token
|
| 48 |
+
self.delay_tokens_len = delay_tokens_len
|
| 49 |
+
|
| 50 |
+
self.audio_pad_token_id = self._convert_token_to_id(audio_pad_token)
|
| 51 |
+
self.text_pad_token_id = self._convert_token_to_id(text_pad_token)
|
| 52 |
+
|
| 53 |
+
if tts_system_prompt is None:
|
| 54 |
+
tts_system_prompt = (
|
| 55 |
+
"<|im_start|>system\n"
|
| 56 |
+
"You are a highly expressive text-to-speech (TTS) engine developed by Mosi Intelligence. \n"
|
| 57 |
+
"You possess natural language understanding, emotional modeling, and multi-style speech generation "
|
| 58 |
+
"capabilities, allowing you to generate the corresponding speech based on the text given in the assistant."
|
| 59 |
+
"<|im_end|>\n"
|
| 60 |
+
)
|
| 61 |
+
self.ttsbase_system_prompt = tts_system_prompt
|
| 62 |
+
|
| 63 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 64 |
+
if hasattr(self.tokenizer, "convert_tokens_to_ids"):
|
| 65 |
+
token_id = self.tokenizer.convert_tokens_to_ids(token)
|
| 66 |
+
if token_id is not None and token_id != self.tokenizer.unk_token_id:
|
| 67 |
+
return int(token_id)
|
| 68 |
+
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
| 69 |
+
if not token_ids:
|
| 70 |
+
raise ValueError(f"Token '{token}' could not be converted to an id.")
|
| 71 |
+
if len(token_ids) != 1:
|
| 72 |
+
raise ValueError(f"Token '{token}' maps to multiple ids: {token_ids}")
|
| 73 |
+
return int(token_ids[0])
|
| 74 |
+
|
| 75 |
+
def make_voice_clone_prompt(self, prompt_audio_tokens_len: int) -> str:
|
| 76 |
+
padded_audio_prompt = f"{'<|audio_pad|>' * prompt_audio_tokens_len}"
|
| 77 |
+
voice_clone = (
|
| 78 |
+
"<|im_start|>context\n"
|
| 79 |
+
"The assistant section should be synthesized using the following voice timbre:"
|
| 80 |
+
f"{padded_audio_prompt}"
|
| 81 |
+
)
|
| 82 |
+
return voice_clone
|
| 83 |
+
|
| 84 |
+
def _normalize_audio_tokens(self, audio_tokens: np.ndarray | Iterable) -> np.ndarray:
|
| 85 |
+
tokens = np.array(audio_tokens)
|
| 86 |
+
if tokens.ndim != 2:
|
| 87 |
+
raise ValueError(f"Expected 2D audio tokens, got shape {tokens.shape}")
|
| 88 |
+
if tokens.shape[0] == self.channels:
|
| 89 |
+
tokens = tokens.T
|
| 90 |
+
elif tokens.shape[1] == self.channels:
|
| 91 |
+
tokens = tokens
|
| 92 |
+
elif tokens.shape[0] > self.channels and tokens.shape[1] != self.channels:
|
| 93 |
+
tokens = tokens[: self.channels, :].T
|
| 94 |
+
elif tokens.shape[1] > self.channels and tokens.shape[0] != self.channels:
|
| 95 |
+
tokens = tokens[:, : self.channels]
|
| 96 |
+
if tokens.shape[1] != self.channels:
|
| 97 |
+
raise ValueError(f"Expected {self.channels} channels, got shape {tokens.shape}")
|
| 98 |
+
return tokens
|
| 99 |
+
|
| 100 |
+
def make_ensemble(self, prompt_audio_tokens: Optional[np.ndarray] = None) -> np.ndarray:
|
| 101 |
+
if prompt_audio_tokens is not None:
|
| 102 |
+
prompt_audio_tokens = self._normalize_audio_tokens(prompt_audio_tokens)
|
| 103 |
+
prompt_audio_tokens = prompt_audio_tokens[:, : self.channels]
|
| 104 |
+
system_prompt_text = f"{self.ttsbase_system_prompt}" + f"{self.make_voice_clone_prompt(prompt_audio_tokens.shape[0])}"
|
| 105 |
+
else:
|
| 106 |
+
system_prompt_text = f"{self.ttsbase_system_prompt}"
|
| 107 |
+
|
| 108 |
+
system_prompt_tokens = self.tokenizer(system_prompt_text)["input_ids"]
|
| 109 |
+
system_prompt_tokens_full = np.full(
|
| 110 |
+
shape=(len(system_prompt_tokens), self.channels + 1), fill_value=self.audio_channel_pad, dtype=np.int64
|
| 111 |
+
)
|
| 112 |
+
system_prompt_tokens_full[:, 0] = system_prompt_tokens
|
| 113 |
+
|
| 114 |
+
if prompt_audio_tokens is not None:
|
| 115 |
+
system_prompt_tokens = np.array(system_prompt_tokens)
|
| 116 |
+
indices = np.where(system_prompt_tokens == self.audio_pad_token_id)[0]
|
| 117 |
+
if indices.size == 0:
|
| 118 |
+
raise ValueError("No <|audio_pad|> tokens found in the system prompt.")
|
| 119 |
+
prompt_audio_start_pos, prompt_audio_end_pos = indices[0], indices[-1]
|
| 120 |
+
system_prompt_tokens_full[prompt_audio_start_pos : prompt_audio_end_pos + 1, 1:] = prompt_audio_tokens
|
| 121 |
+
|
| 122 |
+
return system_prompt_tokens_full
|
| 123 |
+
|
| 124 |
+
def make_user_prompt(self, text: str, audio_tokens: np.ndarray) -> np.ndarray:
|
| 125 |
+
prefill_temp = "<|im_end|>\n<|im_start|>user\n"
|
| 126 |
+
text_tokens = self.tokenizer(text)["input_ids"]
|
| 127 |
+
text_start_pos = len(self.tokenizer.encode(prefill_temp))
|
| 128 |
+
token = self._normalize_audio_tokens(audio_tokens)
|
| 129 |
+
|
| 130 |
+
text_len = len(text_tokens)
|
| 131 |
+
audio_len = token.shape[0]
|
| 132 |
+
|
| 133 |
+
if text_len >= self.delay_tokens_len:
|
| 134 |
+
padded_text_len = audio_len + self.delay_tokens_len - text_len + 1
|
| 135 |
+
cur_input_id_ch1 = prefill_temp + text + "<|text_pad|>" * padded_text_len
|
| 136 |
+
assistant_tokens_ch1 = self.tokenizer(cur_input_id_ch1)["input_ids"]
|
| 137 |
+
cur_input_id = np.full(
|
| 138 |
+
shape=(len(assistant_tokens_ch1), self.channels + 1),
|
| 139 |
+
fill_value=self.audio_channel_pad,
|
| 140 |
+
dtype=np.int64,
|
| 141 |
+
)
|
| 142 |
+
cur_input_id[:, 0] = assistant_tokens_ch1
|
| 143 |
+
cur_input_id[
|
| 144 |
+
text_start_pos + self.delay_tokens_len : text_start_pos + self.delay_tokens_len + audio_len, 1:
|
| 145 |
+
] = token
|
| 146 |
+
cur_input_id[text_start_pos + self.delay_tokens_len - 1, 1] = self.audio_bos_token
|
| 147 |
+
cur_input_id[text_start_pos + self.delay_tokens_len + audio_len, 1] = self.audio_eos_token
|
| 148 |
+
else:
|
| 149 |
+
padded_text_len = audio_len + 1
|
| 150 |
+
cur_input_id_ch1 = prefill_temp + text + "<|text_pad|>" * padded_text_len
|
| 151 |
+
assistant_tokens_ch1 = self.tokenizer(cur_input_id_ch1)["input_ids"]
|
| 152 |
+
cur_input_id = np.full(
|
| 153 |
+
shape=(len(assistant_tokens_ch1), self.channels + 1),
|
| 154 |
+
fill_value=self.audio_channel_pad,
|
| 155 |
+
dtype=np.int64,
|
| 156 |
+
)
|
| 157 |
+
cur_input_id[:, 0] = assistant_tokens_ch1
|
| 158 |
+
cur_input_id[-(audio_len + 1) : -1, 1:] = token
|
| 159 |
+
cur_input_id[-(audio_len + 2), 1] = self.audio_bos_token
|
| 160 |
+
cur_input_id[-1, 1] = self.audio_eos_token
|
| 161 |
+
|
| 162 |
+
begin_of_response = self.tokenizer.encode("<|im_end|>\n<|im_start|>assistant\n")
|
| 163 |
+
begin_of_response_full = np.full(
|
| 164 |
+
shape=(len(begin_of_response), self.channels + 1), fill_value=self.audio_channel_pad, dtype=np.int64
|
| 165 |
+
)
|
| 166 |
+
begin_of_response_full[:, 0] = begin_of_response
|
| 167 |
+
|
| 168 |
+
input_ids = np.concatenate([cur_input_id, begin_of_response_full], axis=0)
|
| 169 |
+
return input_ids
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
__all__ = ["MossTTSRealtimeProcessor"]
|
streaming_mossttsrealtime.py
ADDED
|
@@ -0,0 +1,1003 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Streaming inference utilities for MossTTSRealtime."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
import re
|
| 21 |
+
import numpy as np
|
| 22 |
+
import contextlib
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
import torchaudio
|
| 27 |
+
from transformers.cache_utils import StaticCache
|
| 28 |
+
from transformers.utils.import_utils import requires
|
| 29 |
+
from typing import Iterable, Iterator, List, Optional, Sequence
|
| 30 |
+
|
| 31 |
+
@requires(backends=("torch",))
|
| 32 |
+
class MossTTSRealtimeInference:
|
| 33 |
+
"""Step-wise inference wrapper for MossTTSRealtime.
|
| 34 |
+
|
| 35 |
+
This class mirrors the non-streaming inference logic but exposes a
|
| 36 |
+
prefill/step/finish API for streaming usage.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
model,
|
| 42 |
+
tokenizer,
|
| 43 |
+
max_length: int = 1000,
|
| 44 |
+
channels: int = 16,
|
| 45 |
+
audio_channel_pad: int = 1024,
|
| 46 |
+
audio_bos_token: int = 1025,
|
| 47 |
+
audio_eos_token: int = 1026,
|
| 48 |
+
text_pad_id: int = 151655,
|
| 49 |
+
aud_pad_id: int = 151654,
|
| 50 |
+
):
|
| 51 |
+
self.model = model
|
| 52 |
+
self.tokenizer = tokenizer
|
| 53 |
+
self.max_length = max_length
|
| 54 |
+
self.channels = channels
|
| 55 |
+
self.audio_channel_pad = audio_channel_pad
|
| 56 |
+
self.audio_bos_token = audio_bos_token
|
| 57 |
+
self.audio_eos_token = audio_eos_token
|
| 58 |
+
self.text_pad_id = text_pad_id
|
| 59 |
+
self.aud_pad_id = aud_pad_id
|
| 60 |
+
|
| 61 |
+
self.past_key_values = None
|
| 62 |
+
self.attention_mask = None
|
| 63 |
+
self._generated_tokens: List[torch.Tensor] = []
|
| 64 |
+
self._is_stopping = None
|
| 65 |
+
self._last_audio_tokens = None
|
| 66 |
+
self._step_idx = 0
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def device(self):
|
| 70 |
+
return next(self.model.parameters()).device
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def is_finished(self) -> bool:
|
| 74 |
+
return self._is_stopping is not None and bool(self._is_stopping.all())
|
| 75 |
+
|
| 76 |
+
def reset_generation_state(self, keep_cache: bool = True):
|
| 77 |
+
# When keep_cache=True, retain the attention_mask so that its length matches past_key_values.
|
| 78 |
+
# This is used for concatenation in the next prefill step.
|
| 79 |
+
if not keep_cache:
|
| 80 |
+
self.past_key_values = None
|
| 81 |
+
self.attention_mask = None
|
| 82 |
+
self._generated_tokens = []
|
| 83 |
+
self._is_stopping = None
|
| 84 |
+
self._last_audio_tokens = None
|
| 85 |
+
self._step_idx = 0
|
| 86 |
+
|
| 87 |
+
def _normalize_input_ids(self, input_ids):
|
| 88 |
+
if isinstance(input_ids, torch.Tensor):
|
| 89 |
+
input_ids = input_ids.detach().cpu().numpy()
|
| 90 |
+
if isinstance(input_ids, np.ndarray):
|
| 91 |
+
if input_ids.ndim == 2:
|
| 92 |
+
return [input_ids]
|
| 93 |
+
if input_ids.ndim == 3:
|
| 94 |
+
return [input_ids[i] for i in range(input_ids.shape[0])]
|
| 95 |
+
if isinstance(input_ids, (list, tuple)):
|
| 96 |
+
return [np.array(item) for item in input_ids]
|
| 97 |
+
raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")
|
| 98 |
+
|
| 99 |
+
def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
|
| 100 |
+
if text_prefix_ids is None:
|
| 101 |
+
raise ValueError("text_prefix_ids must be provided for prefill.")
|
| 102 |
+
if isinstance(text_prefix_ids, torch.Tensor):
|
| 103 |
+
text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
|
| 104 |
+
if isinstance(text_prefix_ids, np.ndarray):
|
| 105 |
+
text_prefix_ids = text_prefix_ids.tolist()
|
| 106 |
+
if isinstance(text_prefix_ids, list):
|
| 107 |
+
if len(text_prefix_ids) == 0:
|
| 108 |
+
return [[] for _ in range(batch_size)]
|
| 109 |
+
if isinstance(text_prefix_ids[0], (int, np.integer)):
|
| 110 |
+
return [list(text_prefix_ids)]
|
| 111 |
+
if len(text_prefix_ids) == 1 and batch_size > 1:
|
| 112 |
+
return [list(text_prefix_ids[0]) for _ in range(batch_size)]
|
| 113 |
+
if len(text_prefix_ids) != batch_size:
|
| 114 |
+
raise ValueError(
|
| 115 |
+
f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
|
| 116 |
+
)
|
| 117 |
+
return [list(item) for item in text_prefix_ids]
|
| 118 |
+
raise ValueError("text_prefix_ids must be list-like or tensor-like.")
|
| 119 |
+
|
| 120 |
+
@torch.inference_mode()
|
| 121 |
+
def prefill(
|
| 122 |
+
self,
|
| 123 |
+
input_ids,
|
| 124 |
+
text_prefix_ids,
|
| 125 |
+
max_prefill_len: Optional[int] = None,
|
| 126 |
+
past_key_values=None,
|
| 127 |
+
device: Optional[torch.device] = None,
|
| 128 |
+
temperature: float = 0.8,
|
| 129 |
+
top_p: float = 0.6,
|
| 130 |
+
top_k: int = 30,
|
| 131 |
+
do_sample: bool = True,
|
| 132 |
+
repetition_penalty: Optional[float] = 1.1,
|
| 133 |
+
repetition_window: Optional[int] = 50,
|
| 134 |
+
) -> torch.Tensor:
|
| 135 |
+
if device is None:
|
| 136 |
+
device = self.device
|
| 137 |
+
|
| 138 |
+
if past_key_values is not None:
|
| 139 |
+
self.past_key_values = past_key_values
|
| 140 |
+
|
| 141 |
+
input_ids_list = self._normalize_input_ids(input_ids)
|
| 142 |
+
batch_size = len(input_ids_list)
|
| 143 |
+
text_prefix_list = self._normalize_text_prefix(text_prefix_ids, batch_size)
|
| 144 |
+
|
| 145 |
+
concat_inputs_id_list = []
|
| 146 |
+
for i in range(batch_size):
|
| 147 |
+
prefix = text_prefix_list[i]
|
| 148 |
+
if max_prefill_len is not None:
|
| 149 |
+
prefix = prefix[:max_prefill_len]
|
| 150 |
+
if len(prefix) == 0:
|
| 151 |
+
raise ValueError("Prefill requires at least one text token.")
|
| 152 |
+
|
| 153 |
+
text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64)
|
| 154 |
+
text_seg[:, 0] = np.array(prefix, dtype=np.int64)
|
| 155 |
+
text_seg[len(prefix) - 1, 1] = self.audio_bos_token
|
| 156 |
+
concat_inputs_id = np.concatenate([input_ids_list[i], text_seg], axis=0)
|
| 157 |
+
concat_inputs_id_list.append(concat_inputs_id)
|
| 158 |
+
|
| 159 |
+
attention_masks = [np.ones(ids.shape[0], dtype=np.bool_) for ids in concat_inputs_id_list]
|
| 160 |
+
max_len = max(ids.shape[0] for ids in concat_inputs_id_list)
|
| 161 |
+
padded_input_ids, padded_attns = [], []
|
| 162 |
+
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id
|
| 163 |
+
|
| 164 |
+
for ids, attn in zip(concat_inputs_id_list, attention_masks):
|
| 165 |
+
pad_len = max_len - ids.shape[0]
|
| 166 |
+
input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64)
|
| 167 |
+
input_pad[:, 0] = pad_token_id
|
| 168 |
+
padded_input_ids.append(np.concatenate([input_pad, ids]))
|
| 169 |
+
attn_pad = np.zeros(pad_len, dtype=np.bool_)
|
| 170 |
+
padded_attns.append(np.concatenate([attn_pad, attn]))
|
| 171 |
+
|
| 172 |
+
current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
|
| 173 |
+
current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)
|
| 174 |
+
|
| 175 |
+
if self.attention_mask is not None and self.past_key_values is not None:
|
| 176 |
+
current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)
|
| 177 |
+
|
| 178 |
+
outputs = self.model(
|
| 179 |
+
input_ids=current_input_ids,
|
| 180 |
+
attention_mask=current_attention_mask,
|
| 181 |
+
past_key_values=self.past_key_values,
|
| 182 |
+
use_cache=True,
|
| 183 |
+
return_dict=True,
|
| 184 |
+
)
|
| 185 |
+
self.past_key_values = outputs.past_key_values
|
| 186 |
+
self.attention_mask = current_attention_mask
|
| 187 |
+
|
| 188 |
+
backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
|
| 189 |
+
audio_tokens = self.generate_local_transformer(
|
| 190 |
+
hidden_states=backbone_hidden_states,
|
| 191 |
+
temperature=temperature,
|
| 192 |
+
top_p=top_p,
|
| 193 |
+
top_k=top_k,
|
| 194 |
+
do_sample=do_sample,
|
| 195 |
+
repetition_penalty=repetition_penalty,
|
| 196 |
+
repetition_window=repetition_window,
|
| 197 |
+
generated_tokens=None,
|
| 198 |
+
gen_step=0,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
self._generated_tokens = [audio_tokens]
|
| 202 |
+
self._last_audio_tokens = audio_tokens
|
| 203 |
+
self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
|
| 204 |
+
self._step_idx = 1
|
| 205 |
+
return audio_tokens
|
| 206 |
+
|
| 207 |
+
@torch.inference_mode()
|
| 208 |
+
def step(
|
| 209 |
+
self,
|
| 210 |
+
text_token: Optional[Iterable[int] | torch.Tensor | int],
|
| 211 |
+
temperature: float = 0.8,
|
| 212 |
+
top_p: float = 0.6,
|
| 213 |
+
top_k: int = 30,
|
| 214 |
+
do_sample: bool = True,
|
| 215 |
+
repetition_penalty: Optional[float] = 1.1,
|
| 216 |
+
repetition_window: Optional[int] = 50,
|
| 217 |
+
) -> torch.Tensor:
|
| 218 |
+
if self._last_audio_tokens is None or self.attention_mask is None:
|
| 219 |
+
raise ValueError("You must call prefill() before step().")
|
| 220 |
+
if self.is_finished:
|
| 221 |
+
return self._last_audio_tokens
|
| 222 |
+
|
| 223 |
+
batch_size = self._last_audio_tokens.shape[0]
|
| 224 |
+
if text_token is None:
|
| 225 |
+
text_tokens = [self.text_pad_id] * batch_size
|
| 226 |
+
elif isinstance(text_token, torch.Tensor):
|
| 227 |
+
text_tokens = text_token.detach().cpu().tolist()
|
| 228 |
+
elif isinstance(text_token, (list, tuple, np.ndarray)):
|
| 229 |
+
text_tokens = list(text_token)
|
| 230 |
+
else:
|
| 231 |
+
text_tokens = [int(text_token)]
|
| 232 |
+
|
| 233 |
+
if len(text_tokens) != batch_size:
|
| 234 |
+
raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")
|
| 235 |
+
|
| 236 |
+
device = self._last_audio_tokens.device
|
| 237 |
+
text_t = torch.tensor(text_tokens, device=device, dtype=torch.long)
|
| 238 |
+
step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2)
|
| 239 |
+
self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1)
|
| 240 |
+
|
| 241 |
+
outputs = self.model(
|
| 242 |
+
input_ids=step_ids,
|
| 243 |
+
attention_mask=self.attention_mask,
|
| 244 |
+
past_key_values=self.past_key_values,
|
| 245 |
+
use_cache=True,
|
| 246 |
+
return_dict=True,
|
| 247 |
+
)
|
| 248 |
+
self.past_key_values = outputs.past_key_values
|
| 249 |
+
backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
|
| 250 |
+
|
| 251 |
+
history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None
|
| 252 |
+
audio_tokens = self.generate_local_transformer(
|
| 253 |
+
hidden_states=backbone_hidden_states,
|
| 254 |
+
temperature=temperature,
|
| 255 |
+
top_p=top_p,
|
| 256 |
+
top_k=top_k,
|
| 257 |
+
do_sample=do_sample,
|
| 258 |
+
repetition_penalty=repetition_penalty,
|
| 259 |
+
repetition_window=repetition_window,
|
| 260 |
+
generated_tokens=history,
|
| 261 |
+
gen_step=self._step_idx,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
self._generated_tokens.append(audio_tokens)
|
| 265 |
+
self._last_audio_tokens = audio_tokens
|
| 266 |
+
self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
|
| 267 |
+
self._step_idx += 1
|
| 268 |
+
return audio_tokens
|
| 269 |
+
|
| 270 |
+
@torch.inference_mode()
|
| 271 |
+
def finish(
|
| 272 |
+
self,
|
| 273 |
+
max_steps: Optional[int] = None,
|
| 274 |
+
temperature: float = 0.8,
|
| 275 |
+
top_p: float = 0.6,
|
| 276 |
+
top_k: int = 30,
|
| 277 |
+
do_sample: bool = True,
|
| 278 |
+
repetition_penalty: Optional[float] = 1.1,
|
| 279 |
+
repetition_window: Optional[int] = 50,
|
| 280 |
+
) -> list[torch.Tensor]:
|
| 281 |
+
outputs = []
|
| 282 |
+
steps_left = max_steps if max_steps is not None else self.max_length
|
| 283 |
+
while steps_left > 0 and not self.is_finished:
|
| 284 |
+
outputs.append(
|
| 285 |
+
self.step(
|
| 286 |
+
text_token=None,
|
| 287 |
+
temperature=temperature,
|
| 288 |
+
top_p=top_p,
|
| 289 |
+
top_k=top_k,
|
| 290 |
+
do_sample=do_sample,
|
| 291 |
+
repetition_penalty=repetition_penalty,
|
| 292 |
+
repetition_window=repetition_window,
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
steps_left -= 1
|
| 296 |
+
return outputs
|
| 297 |
+
|
| 298 |
+
@torch.compile(fullgraph=True)
|
| 299 |
+
def generate_local_transformer(
|
| 300 |
+
self,
|
| 301 |
+
hidden_states: torch.Tensor,
|
| 302 |
+
temperature: float,
|
| 303 |
+
top_p: float,
|
| 304 |
+
top_k: int,
|
| 305 |
+
do_sample: bool,
|
| 306 |
+
repetition_penalty: Optional[float],
|
| 307 |
+
repetition_window: Optional[int],
|
| 308 |
+
generated_tokens: Optional[torch.Tensor],
|
| 309 |
+
gen_step: int,
|
| 310 |
+
) -> torch.Tensor:
|
| 311 |
+
batch_size = hidden_states.shape[0]
|
| 312 |
+
device = hidden_states.device
|
| 313 |
+
local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
|
| 314 |
+
output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)
|
| 315 |
+
|
| 316 |
+
past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
|
| 317 |
+
local_token = None
|
| 318 |
+
|
| 319 |
+
cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)
|
| 320 |
+
|
| 321 |
+
for i in range(self.channels):
|
| 322 |
+
cache_pos_t.fill_(i)
|
| 323 |
+
|
| 324 |
+
local_outputs = self.model.local_transformer(
|
| 325 |
+
input_ids=local_token,
|
| 326 |
+
inputs_embeds=local_inputs,
|
| 327 |
+
past_key_values=past_key_values,
|
| 328 |
+
cache_position=cache_pos_t,
|
| 329 |
+
codebook_idx=i,
|
| 330 |
+
use_cache=True,
|
| 331 |
+
logits_to_keep=1,
|
| 332 |
+
)
|
| 333 |
+
logits = local_outputs.logits
|
| 334 |
+
|
| 335 |
+
if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
|
| 336 |
+
logits = self.apply_repetition_penalty(
|
| 337 |
+
scores=logits,
|
| 338 |
+
history_tokens=generated_tokens[:, :gen_step, i],
|
| 339 |
+
penalty=float(repetition_penalty),
|
| 340 |
+
repetition_window=repetition_window,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
local_token = self.sample_token(
|
| 344 |
+
logits=logits,
|
| 345 |
+
temperature=temperature,
|
| 346 |
+
top_p=top_p,
|
| 347 |
+
top_k=top_k,
|
| 348 |
+
do_sample=do_sample,
|
| 349 |
+
)
|
| 350 |
+
output_token[:, i] = local_token.squeeze(-1)
|
| 351 |
+
|
| 352 |
+
if i == 0:
|
| 353 |
+
local_inputs = None
|
| 354 |
+
return output_token
|
| 355 |
+
|
| 356 |
+
def apply_repetition_penalty(
|
| 357 |
+
self,
|
| 358 |
+
scores: torch.Tensor,
|
| 359 |
+
history_tokens: torch.Tensor,
|
| 360 |
+
penalty: float = 1.1,
|
| 361 |
+
repetition_window: Optional[int] = None,
|
| 362 |
+
):
|
| 363 |
+
scores_ = scores[:, 0, :]
|
| 364 |
+
B, V = scores_.shape
|
| 365 |
+
ht = history_tokens
|
| 366 |
+
|
| 367 |
+
if repetition_window is not None and repetition_window > 0:
|
| 368 |
+
ht = ht[:, -repetition_window:]
|
| 369 |
+
|
| 370 |
+
ht_sorted, _ = torch.sort(ht, dim=1)
|
| 371 |
+
uniq = torch.unique_consecutive(ht_sorted, dim=1)
|
| 372 |
+
|
| 373 |
+
b_idx = torch.arange(B, device=uniq.device).unsqueeze(1).expand_as(uniq)
|
| 374 |
+
b_flat = b_idx.reshape(-1)
|
| 375 |
+
t_flat = uniq.reshape(-1)
|
| 376 |
+
|
| 377 |
+
cur = scores_[b_flat, t_flat]
|
| 378 |
+
new = torch.where(cur < 0, cur * penalty, cur / penalty)
|
| 379 |
+
|
| 380 |
+
scores_[b_flat, t_flat] = new
|
| 381 |
+
|
| 382 |
+
return scores_
|
| 383 |
+
|
| 384 |
+
def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
|
| 385 |
+
if not do_sample or temperature == 0:
|
| 386 |
+
return torch.argmax(logits, dim=-1)
|
| 387 |
+
logits = logits / temperature
|
| 388 |
+
original_shape = logits.shape
|
| 389 |
+
vocab_size = original_shape[-1]
|
| 390 |
+
reshaped_logits = logits.reshape(-1, vocab_size)
|
| 391 |
+
|
| 392 |
+
if top_k is not None:
|
| 393 |
+
reshaped_logits = self.apply_top_k(reshaped_logits, top_k)
|
| 394 |
+
|
| 395 |
+
if top_p is not None:
|
| 396 |
+
reshaped_logits = self.apply_top_p(reshaped_logits, top_p)
|
| 397 |
+
|
| 398 |
+
probs = F.softmax(reshaped_logits, dim=-1)
|
| 399 |
+
next_tokens_flat = torch.multinomial(probs, num_samples=1)
|
| 400 |
+
|
| 401 |
+
output_shape = original_shape[:-1]
|
| 402 |
+
return next_tokens_flat.view(output_shape)
|
| 403 |
+
|
| 404 |
+
def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
|
| 405 |
+
if not isinstance(top_k, int) or top_k <= 0:
|
| 406 |
+
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
| 407 |
+
batch_size, vocab_size = logits.shape
|
| 408 |
+
top_k = max(top_k, min_tokens_to_keep)
|
| 409 |
+
top_k = min(top_k, vocab_size)
|
| 410 |
+
indices_to_remove = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
|
| 411 |
+
return logits.masked_fill(logits < indices_to_remove, filter_value)
|
| 412 |
+
|
| 413 |
+
def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
|
| 414 |
+
top_p = float(top_p)
|
| 415 |
+
if top_p < 0 or top_p > 1.0:
|
| 416 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
| 417 |
+
|
| 418 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
| 419 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 420 |
+
|
| 421 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 422 |
+
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
|
| 423 |
+
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 424 |
+
logits_processed = logits.masked_fill(indices_to_remove, filter_value)
|
| 425 |
+
return logits_processed
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@requires(backends=("torch",))
|
| 429 |
+
class MossTTSRealtimeStreamingSession:
|
| 430 |
+
"""Manage text-to-audio streaming for a single conversation."""
|
| 431 |
+
|
| 432 |
+
_split_pattern = re.compile(
|
| 433 |
+
r"[。!?!?\.\u2026]\s*"
|
| 434 |
+
r"|[,,;;::\u2014\u2013\-]\s*"
|
| 435 |
+
r"|\)\s*|\]\s*"
|
| 436 |
+
r"|\n"
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
def __init__(
|
| 440 |
+
self,
|
| 441 |
+
inferencer: MossTTSRealtimeInference,
|
| 442 |
+
processor,
|
| 443 |
+
codec=None,
|
| 444 |
+
codec_sample_rate: int = 24000,
|
| 445 |
+
codec_encode_kwargs: Optional[dict] = None,
|
| 446 |
+
prefill_text_len: int = 12,
|
| 447 |
+
text_buffer_size: int = 32,
|
| 448 |
+
min_text_chunk_chars: int = 8,
|
| 449 |
+
temperature: float = 0.8,
|
| 450 |
+
top_p: float = 0.6,
|
| 451 |
+
top_k: int = 30,
|
| 452 |
+
do_sample: bool = True,
|
| 453 |
+
repetition_penalty: Optional[float] = 1.1,
|
| 454 |
+
repetition_window: Optional[int] = 50,
|
| 455 |
+
):
|
| 456 |
+
self.inferencer = inferencer
|
| 457 |
+
self.processor = processor
|
| 458 |
+
self.tokenizer = processor.tokenizer
|
| 459 |
+
self.codec = codec
|
| 460 |
+
self.codec_sample_rate = codec_sample_rate
|
| 461 |
+
self.codec_encode_kwargs = codec_encode_kwargs or {}
|
| 462 |
+
|
| 463 |
+
self.prefill_text_len = prefill_text_len
|
| 464 |
+
self.text_buffer_size = text_buffer_size
|
| 465 |
+
self.min_text_chunk_chars = min_text_chunk_chars
|
| 466 |
+
|
| 467 |
+
self.temperature = temperature
|
| 468 |
+
self.top_p = top_p
|
| 469 |
+
self.top_k = top_k
|
| 470 |
+
self.do_sample = do_sample
|
| 471 |
+
self.repetition_penalty = repetition_penalty
|
| 472 |
+
self.repetition_window = repetition_window
|
| 473 |
+
|
| 474 |
+
self._voice_prompt_tokens = None
|
| 475 |
+
self._turn_input_ids = None
|
| 476 |
+
self._turn_idx = 0
|
| 477 |
+
|
| 478 |
+
self._text_cache = ""
|
| 479 |
+
self._pending_tokens: list[int] = []
|
| 480 |
+
self._prefilled = False
|
| 481 |
+
self._text_ended = False
|
| 482 |
+
|
| 483 |
+
def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
|
| 484 |
+
self._voice_prompt_tokens = audio_tokens
|
| 485 |
+
|
| 486 |
+
def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
|
| 487 |
+
"""Set voice prompt from either audio tokens or waveform.
|
| 488 |
+
|
| 489 |
+
If `audio` is a 2D array whose shape matches the codebook channels, it is
|
| 490 |
+
treated as audio tokens. Otherwise a codec is required to encode waveform
|
| 491 |
+
prompts into tokens.
|
| 492 |
+
"""
|
| 493 |
+
if isinstance(audio, np.ndarray) and audio.ndim == 2:
|
| 494 |
+
if self.processor.channels in audio.shape:
|
| 495 |
+
self._voice_prompt_tokens = audio
|
| 496 |
+
return
|
| 497 |
+
if isinstance(audio, torch.Tensor) and audio.dim() == 2:
|
| 498 |
+
if self.processor.channels in audio.shape:
|
| 499 |
+
self._voice_prompt_tokens = audio.detach().cpu().numpy()
|
| 500 |
+
return
|
| 501 |
+
|
| 502 |
+
if self.codec is None:
|
| 503 |
+
raise ValueError("codec is required to encode waveform prompts.")
|
| 504 |
+
|
| 505 |
+
waveform = audio
|
| 506 |
+
if isinstance(audio, (str, bytes)):
|
| 507 |
+
wav, sr = torchaudio.load(audio)
|
| 508 |
+
if wav.shape[0] > 1:
|
| 509 |
+
wav = wav.mean(dim=0, keepdim=True)
|
| 510 |
+
waveform = wav.squeeze(0)
|
| 511 |
+
sample_rate = sr
|
| 512 |
+
|
| 513 |
+
if isinstance(waveform, np.ndarray):
|
| 514 |
+
waveform = torch.from_numpy(waveform)
|
| 515 |
+
if not isinstance(waveform, torch.Tensor):
|
| 516 |
+
raise ValueError("Unsupported audio type for voice prompt.")
|
| 517 |
+
|
| 518 |
+
if sample_rate is not None and sample_rate != self.codec_sample_rate:
|
| 519 |
+
waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)
|
| 520 |
+
|
| 521 |
+
waveform = waveform.to(self.inferencer.device)
|
| 522 |
+
encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)
|
| 523 |
+
if isinstance(encode_out, dict):
|
| 524 |
+
if "codes_list" in encode_out:
|
| 525 |
+
tokens = encode_out["codes_list"][0]
|
| 526 |
+
elif "audio_codes" in encode_out:
|
| 527 |
+
tokens = encode_out["audio_codes"][0]
|
| 528 |
+
else:
|
| 529 |
+
raise ValueError("codec.encode output missing audio codes.")
|
| 530 |
+
else:
|
| 531 |
+
tokens = encode_out
|
| 532 |
+
if isinstance(tokens, torch.Tensor):
|
| 533 |
+
tokens = tokens.detach().cpu().numpy()
|
| 534 |
+
self._voice_prompt_tokens = tokens
|
| 535 |
+
|
| 536 |
+
def clear_voice_prompt(self):
|
| 537 |
+
self._voice_prompt_tokens = None
|
| 538 |
+
|
| 539 |
+
def reset_turn(
|
| 540 |
+
self,
|
| 541 |
+
user_text: Optional[str] = None,
|
| 542 |
+
user_audio_tokens: Optional[np.ndarray] = None,
|
| 543 |
+
input_ids: Optional[np.ndarray] = None,
|
| 544 |
+
include_system_prompt: Optional[bool] = None,
|
| 545 |
+
reset_cache: bool = False,
|
| 546 |
+
):
|
| 547 |
+
if include_system_prompt is None:
|
| 548 |
+
include_system_prompt = self._turn_idx == 0
|
| 549 |
+
|
| 550 |
+
if input_ids is None:
|
| 551 |
+
if user_text is None or user_audio_tokens is None:
|
| 552 |
+
raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
|
| 553 |
+
user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens)
|
| 554 |
+
if include_system_prompt:
|
| 555 |
+
system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
|
| 556 |
+
input_ids = np.concatenate([system_prompt, user_prompt], axis=0)
|
| 557 |
+
else:
|
| 558 |
+
input_ids = user_prompt
|
| 559 |
+
|
| 560 |
+
self._turn_input_ids = input_ids
|
| 561 |
+
self._turn_idx += 1
|
| 562 |
+
|
| 563 |
+
self._text_cache = ""
|
| 564 |
+
self._pending_tokens = []
|
| 565 |
+
self._prefilled = False
|
| 566 |
+
self._text_ended = False
|
| 567 |
+
|
| 568 |
+
self.inferencer.reset_generation_state(keep_cache=not reset_cache)
|
| 569 |
+
|
| 570 |
+
def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
|
| 571 |
+
self._pending_tokens.extend([int(t) for t in tokens])
|
| 572 |
+
return self._drain_pending_tokens()
|
| 573 |
+
|
| 574 |
+
def push_text(self, text_fragment: str) -> list[torch.Tensor]:
|
| 575 |
+
self._text_cache += text_fragment
|
| 576 |
+
segments = self._extract_text_segments(force=False)
|
| 577 |
+
for segment in segments:
|
| 578 |
+
self._pending_tokens.extend(self._tokenize(segment))
|
| 579 |
+
return self._drain_pending_tokens()
|
| 580 |
+
|
| 581 |
+
def end_text(self) -> list[torch.Tensor]:
|
| 582 |
+
self._text_ended = True
|
| 583 |
+
if self._text_cache:
|
| 584 |
+
self._pending_tokens.extend(self._tokenize(self._text_cache))
|
| 585 |
+
self._text_cache = ""
|
| 586 |
+
return self._drain_pending_tokens()
|
| 587 |
+
|
| 588 |
+
def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
|
| 589 |
+
if not self._prefilled:
|
| 590 |
+
return []
|
| 591 |
+
return self.inferencer.finish(
|
| 592 |
+
max_steps=max_steps,
|
| 593 |
+
temperature=self.temperature,
|
| 594 |
+
top_p=self.top_p,
|
| 595 |
+
top_k=self.top_k,
|
| 596 |
+
do_sample=self.do_sample,
|
| 597 |
+
repetition_penalty=self.repetition_penalty,
|
| 598 |
+
repetition_window=self.repetition_window,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
def _tokenize(self, text: str) -> list[int]:
|
| 602 |
+
return self.tokenizer.encode(text, add_special_tokens=False)
|
| 603 |
+
|
| 604 |
+
def _extract_text_segments(self, force: bool) -> list[str]:
|
| 605 |
+
segments = []
|
| 606 |
+
if force:
|
| 607 |
+
if self._text_cache:
|
| 608 |
+
segments.append(self._text_cache)
|
| 609 |
+
self._text_cache = ""
|
| 610 |
+
return segments
|
| 611 |
+
|
| 612 |
+
while self._text_cache:
|
| 613 |
+
cut_idx = None
|
| 614 |
+
if len(self._text_cache) >= self.min_text_chunk_chars:
|
| 615 |
+
matches = list(self._split_pattern.finditer(self._text_cache))
|
| 616 |
+
for match in matches:
|
| 617 |
+
if match.end() >= self.min_text_chunk_chars:
|
| 618 |
+
cut_idx = match.end()
|
| 619 |
+
break
|
| 620 |
+
if cut_idx is None and len(self._text_cache) >= self.text_buffer_size:
|
| 621 |
+
whitespace_idx = self._text_cache.rfind(" ")
|
| 622 |
+
if whitespace_idx != -1:
|
| 623 |
+
cut_idx = whitespace_idx + 1
|
| 624 |
+
if cut_idx is None:
|
| 625 |
+
break
|
| 626 |
+
segments.append(self._text_cache[:cut_idx])
|
| 627 |
+
self._text_cache = self._text_cache[cut_idx:]
|
| 628 |
+
return segments
|
| 629 |
+
|
| 630 |
+
def _prefill_if_needed(self) -> list[torch.Tensor]:
|
| 631 |
+
if self._prefilled:
|
| 632 |
+
return []
|
| 633 |
+
if not self._pending_tokens and not self._text_ended:
|
| 634 |
+
return []
|
| 635 |
+
if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended:
|
| 636 |
+
return []
|
| 637 |
+
if self._turn_input_ids is None:
|
| 638 |
+
raise ValueError("reset_turn must be called before streaming text.")
|
| 639 |
+
|
| 640 |
+
if self._text_ended:
|
| 641 |
+
prefill_len = len(self._pending_tokens)
|
| 642 |
+
else:
|
| 643 |
+
prefill_len = min(len(self._pending_tokens), self.prefill_text_len)
|
| 644 |
+
|
| 645 |
+
if prefill_len == 0:
|
| 646 |
+
return []
|
| 647 |
+
|
| 648 |
+
prefix_tokens = [self._pending_tokens.pop(0) for _ in range(prefill_len)]
|
| 649 |
+
audio_tokens = self.inferencer.prefill(
|
| 650 |
+
input_ids=[self._turn_input_ids],
|
| 651 |
+
text_prefix_ids=[prefix_tokens],
|
| 652 |
+
temperature=self.temperature,
|
| 653 |
+
top_p=self.top_p,
|
| 654 |
+
top_k=self.top_k,
|
| 655 |
+
do_sample=self.do_sample,
|
| 656 |
+
repetition_penalty=None,
|
| 657 |
+
repetition_window=self.repetition_window,
|
| 658 |
+
)
|
| 659 |
+
self._prefilled = True
|
| 660 |
+
return [audio_tokens]
|
| 661 |
+
|
| 662 |
+
def _drain_pending_tokens(self) -> list[torch.Tensor]:
|
| 663 |
+
outputs: list[torch.Tensor] = []
|
| 664 |
+
outputs.extend(self._prefill_if_needed())
|
| 665 |
+
if not self._prefilled:
|
| 666 |
+
return outputs
|
| 667 |
+
|
| 668 |
+
while self._pending_tokens and not self.inferencer.is_finished:
|
| 669 |
+
token = self._pending_tokens.pop(0)
|
| 670 |
+
outputs.append(
|
| 671 |
+
self.inferencer.step(
|
| 672 |
+
token,
|
| 673 |
+
temperature=self.temperature,
|
| 674 |
+
top_p=self.top_p,
|
| 675 |
+
top_k=self.top_k,
|
| 676 |
+
do_sample=self.do_sample,
|
| 677 |
+
repetition_penalty=self.repetition_penalty,
|
| 678 |
+
repetition_window=self.repetition_window,
|
| 679 |
+
)
|
| 680 |
+
)
|
| 681 |
+
return outputs
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
@requires(backends=("torch",))
|
| 685 |
+
class AudioStreamDecoder:
|
| 686 |
+
"""Decode audio tokens into waveform chunks with optional crossfade."""
|
| 687 |
+
|
| 688 |
+
def __init__(
|
| 689 |
+
self,
|
| 690 |
+
codec,
|
| 691 |
+
chunk_frames: int = 40,
|
| 692 |
+
overlap_frames: int = 4,
|
| 693 |
+
decode_kwargs: Optional[dict] = None,
|
| 694 |
+
device: Optional[torch.device] = None,
|
| 695 |
+
):
|
| 696 |
+
self.codec = codec
|
| 697 |
+
self.chunk_frames = chunk_frames
|
| 698 |
+
self.overlap_frames = overlap_frames
|
| 699 |
+
self.decode_kwargs = decode_kwargs or {}
|
| 700 |
+
self.device = device
|
| 701 |
+
|
| 702 |
+
self._buffer: list[torch.Tensor] = []
|
| 703 |
+
self._buffer_len = 0
|
| 704 |
+
self._prev_tail: Optional[torch.Tensor] = None
|
| 705 |
+
|
| 706 |
+
def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
|
| 707 |
+
if isinstance(audio_tokens, np.ndarray):
|
| 708 |
+
audio_tokens = torch.from_numpy(audio_tokens)
|
| 709 |
+
if audio_tokens.dim() != 2:
|
| 710 |
+
raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}")
|
| 711 |
+
self._buffer.append(audio_tokens)
|
| 712 |
+
self._buffer_len += audio_tokens.shape[0]
|
| 713 |
+
|
| 714 |
+
def audio_chunks(self) -> Iterable[torch.Tensor]:
|
| 715 |
+
while self._buffer_len >= self.chunk_frames:
|
| 716 |
+
chunk_tokens = self._consume_frames(self.chunk_frames)
|
| 717 |
+
wav = self._decode(chunk_tokens, chunk_duration=0.32)
|
| 718 |
+
yield self._apply_crossfade(wav)
|
| 719 |
+
|
| 720 |
+
def flush(self) -> Optional[torch.Tensor]:
|
| 721 |
+
if self._buffer_len == 0:
|
| 722 |
+
return None
|
| 723 |
+
chunk_tokens = self._consume_frames(self._buffer_len)
|
| 724 |
+
wav = self._decode(chunk_tokens)
|
| 725 |
+
return self._apply_crossfade(wav, final_chunk=True)
|
| 726 |
+
|
| 727 |
+
def _consume_frames(self, num_frames: int) -> torch.Tensor:
|
| 728 |
+
frames = []
|
| 729 |
+
remaining = num_frames
|
| 730 |
+
while remaining > 0 and self._buffer:
|
| 731 |
+
head = self._buffer[0]
|
| 732 |
+
if head.shape[0] <= remaining:
|
| 733 |
+
frames.append(head)
|
| 734 |
+
remaining -= head.shape[0]
|
| 735 |
+
self._buffer.pop(0)
|
| 736 |
+
else:
|
| 737 |
+
frames.append(head[:remaining])
|
| 738 |
+
self._buffer[0] = head[remaining:]
|
| 739 |
+
remaining = 0
|
| 740 |
+
self._buffer_len -= num_frames - remaining
|
| 741 |
+
return torch.cat(frames, dim=0)
|
| 742 |
+
|
| 743 |
+
def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
|
| 744 |
+
device = self.device
|
| 745 |
+
if device is None:
|
| 746 |
+
if hasattr(self.codec, "device"):
|
| 747 |
+
device = self.codec.device
|
| 748 |
+
else:
|
| 749 |
+
try:
|
| 750 |
+
device = next(self.codec.parameters()).device
|
| 751 |
+
except Exception:
|
| 752 |
+
device = None
|
| 753 |
+
if device is not None:
|
| 754 |
+
tokens = tokens.to(device)
|
| 755 |
+
tokens_t = tokens.permute(1, 0)
|
| 756 |
+
# allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming)
|
| 757 |
+
decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
|
| 758 |
+
if "chunk_duration" in decode_kwargs:
|
| 759 |
+
override = decode_kwargs.pop("chunk_duration")
|
| 760 |
+
if override is None:
|
| 761 |
+
chunk_duration_arg = None
|
| 762 |
+
else:
|
| 763 |
+
try:
|
| 764 |
+
override_f = float(override)
|
| 765 |
+
except Exception:
|
| 766 |
+
override_f = None
|
| 767 |
+
chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f
|
| 768 |
+
else:
|
| 769 |
+
chunk_duration_arg = chunk_duration
|
| 770 |
+
|
| 771 |
+
decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
|
| 772 |
+
if isinstance(decoded, dict):
|
| 773 |
+
wav = decoded["audio"][0]
|
| 774 |
+
else:
|
| 775 |
+
wav = decoded
|
| 776 |
+
if isinstance(wav, np.ndarray):
|
| 777 |
+
wav = torch.from_numpy(wav)
|
| 778 |
+
if wav.dim() > 1:
|
| 779 |
+
wav = wav.squeeze(0)
|
| 780 |
+
return wav
|
| 781 |
+
|
| 782 |
+
def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
|
| 783 |
+
if self.overlap_frames <= 0:
|
| 784 |
+
return wav
|
| 785 |
+
if self._prev_tail is None:
|
| 786 |
+
self._prev_tail = wav[-self._overlap_samples(wav) :].clone() if not final_chunk else None
|
| 787 |
+
return wav
|
| 788 |
+
|
| 789 |
+
overlap = self._overlap_samples(wav)
|
| 790 |
+
if overlap == 0:
|
| 791 |
+
return wav
|
| 792 |
+
|
| 793 |
+
prev_tail = self._prev_tail
|
| 794 |
+
if prev_tail.numel() < overlap:
|
| 795 |
+
overlap = prev_tail.numel()
|
| 796 |
+
if overlap == 0:
|
| 797 |
+
return wav
|
| 798 |
+
|
| 799 |
+
fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device)
|
| 800 |
+
fade_in = 1.0 - fade_out
|
| 801 |
+
cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in
|
| 802 |
+
merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1)
|
| 803 |
+
|
| 804 |
+
self._prev_tail = None if final_chunk else wav[-overlap:].clone()
|
| 805 |
+
return merged
|
| 806 |
+
|
| 807 |
+
def _overlap_samples(self, wav: torch.Tensor) -> int:
|
| 808 |
+
if self.chunk_frames <= 0:
|
| 809 |
+
return 0
|
| 810 |
+
return int(wav.numel() * (self.overlap_frames / self.chunk_frames))
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class TextDeltaTokenizer:
|
| 814 |
+
"""
|
| 815 |
+
Convert LLM streaming text (delta) into “incremental token IDs”.
|
| 816 |
+
|
| 817 |
+
Notes:
|
| 818 |
+
- The input is a delta that is progressively appended to the same string
|
| 819 |
+
(consistent with the common delta output behavior in vLLM).
|
| 820 |
+
- Each time, re-encode the *full text* with the tokenizer, then take only
|
| 821 |
+
the newly added token IDs.
|
| 822 |
+
- This guarantees that tokenization is consistent with the final complete
|
| 823 |
+
text, avoiding boundary mismatches caused by tokenizing partial segments.
|
| 824 |
+
"""
|
| 825 |
+
|
| 826 |
+
def __init__(self, tokenizer, *, hold_back: int = 3):
|
| 827 |
+
self.tokenizer = tokenizer
|
| 828 |
+
self.hold_back = max(0, int(hold_back))
|
| 829 |
+
self._text = ""
|
| 830 |
+
self._all_ids: list[int] = []
|
| 831 |
+
self._emitted_count: int = 0
|
| 832 |
+
|
| 833 |
+
@property
|
| 834 |
+
def text(self) -> str:
|
| 835 |
+
return self._text
|
| 836 |
+
|
| 837 |
+
@property
|
| 838 |
+
def token_ids(self) -> list[int]:
|
| 839 |
+
return list(self._all_ids)
|
| 840 |
+
|
| 841 |
+
def push_delta(self, delta: str) -> list[int]:
|
| 842 |
+
if not delta:
|
| 843 |
+
return []
|
| 844 |
+
self._text += str(delta)
|
| 845 |
+
self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
|
| 846 |
+
# 留 hold_back 个 token 不输出(尾部可能随后续 delta 而改变)
|
| 847 |
+
stable_count = max(self._emitted_count, len(self._all_ids) - self.hold_back)
|
| 848 |
+
new_ids = self._all_ids[self._emitted_count : stable_count]
|
| 849 |
+
self._emitted_count = stable_count
|
| 850 |
+
return new_ids
|
| 851 |
+
|
| 852 |
+
def flush(self) -> list[int]:
|
| 853 |
+
self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
|
| 854 |
+
remaining = self._all_ids[self._emitted_count :]
|
| 855 |
+
self._emitted_count = len(self._all_ids)
|
| 856 |
+
return remaining
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def _sanitize_audio_tokens(
|
| 860 |
+
tokens: torch.Tensor,
|
| 861 |
+
*,
|
| 862 |
+
codebook_size: int,
|
| 863 |
+
audio_eos_token: int,
|
| 864 |
+
) -> tuple[torch.Tensor, bool]:
|
| 865 |
+
if tokens.dim() == 1:
|
| 866 |
+
tokens = tokens.unsqueeze(0)
|
| 867 |
+
if tokens.numel() == 0:
|
| 868 |
+
return tokens, False
|
| 869 |
+
|
| 870 |
+
eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False)
|
| 871 |
+
invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)
|
| 872 |
+
|
| 873 |
+
stop_idx = None
|
| 874 |
+
if eos_rows.numel() > 0:
|
| 875 |
+
stop_idx = int(eos_rows[0].item())
|
| 876 |
+
if invalid_rows.any():
|
| 877 |
+
invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item())
|
| 878 |
+
stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
|
| 879 |
+
|
| 880 |
+
if stop_idx is not None:
|
| 881 |
+
return tokens[:stop_idx], True
|
| 882 |
+
return tokens, False
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def _maybe_codec_streaming(codec, *, batch_size: int):
|
| 886 |
+
if codec is None or not hasattr(codec, "streaming"):
|
| 887 |
+
return contextlib.nullcontext()
|
| 888 |
+
return codec.streaming(batch_size=batch_size)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
@requires(backends=("torch",))
|
| 892 |
+
class MossTTSRealtimeTextStreamBridge:
|
| 893 |
+
"""
|
| 894 |
+
Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
|
| 895 |
+
|
| 896 |
+
Usage overview:
|
| 897 |
+
- First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
|
| 898 |
+
- Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
|
| 899 |
+
`push_text_delta()`.
|
| 900 |
+
- Once the accumulated token count reaches `prefill_text_len`, the session will
|
| 901 |
+
start generating audio tokens; the bridge will immediately decode them into WAV
|
| 902 |
+
chunks and yield them.
|
| 903 |
+
"""
|
| 904 |
+
|
| 905 |
+
def __init__(
|
| 906 |
+
self,
|
| 907 |
+
session: MossTTSRealtimeStreamingSession,
|
| 908 |
+
decoder: AudioStreamDecoder,
|
| 909 |
+
*,
|
| 910 |
+
codebook_size: Optional[int] = None,
|
| 911 |
+
audio_eos_token: Optional[int] = None,
|
| 912 |
+
batch_size: int = 1,
|
| 913 |
+
):
|
| 914 |
+
self.session = session
|
| 915 |
+
self.decoder = decoder
|
| 916 |
+
self.batch_size = int(batch_size)
|
| 917 |
+
|
| 918 |
+
if codebook_size is None:
|
| 919 |
+
codebook_size = int(getattr(getattr(session, "codec", None), "codebook_size", 1024))
|
| 920 |
+
if audio_eos_token is None:
|
| 921 |
+
audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))
|
| 922 |
+
|
| 923 |
+
self.codebook_size = int(codebook_size)
|
| 924 |
+
self.audio_eos_token = int(audio_eos_token)
|
| 925 |
+
|
| 926 |
+
def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
|
| 927 |
+
"""
|
| 928 |
+
Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
|
| 929 |
+
|
| 930 |
+
Internally, this directly calls `session.push_text()`, which segments the text
|
| 931 |
+
based on punctuation/length and then tokenizes the *entire segment* at once,
|
| 932 |
+
avoiding the prefix instability issues of incremental BPE tokenization.
|
| 933 |
+
"""
|
| 934 |
+
audio_frames = self.session.push_text(delta)
|
| 935 |
+
yield from self._decode_audio_frames(audio_frames)
|
| 936 |
+
|
| 937 |
+
def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
|
| 938 |
+
if not token_ids:
|
| 939 |
+
return
|
| 940 |
+
audio_frames = self.session.push_text_tokens(token_ids)
|
| 941 |
+
yield from self._decode_audio_frames(audio_frames)
|
| 942 |
+
|
| 943 |
+
def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
|
| 944 |
+
audio_frames = self.session.end_text()
|
| 945 |
+
yield from self._decode_audio_frames(audio_frames)
|
| 946 |
+
|
| 947 |
+
while True:
|
| 948 |
+
more_frames = self.session.drain(max_steps=drain_step)
|
| 949 |
+
if not more_frames:
|
| 950 |
+
break
|
| 951 |
+
yield from self._decode_audio_frames(more_frames)
|
| 952 |
+
if self.session.inferencer.is_finished:
|
| 953 |
+
break
|
| 954 |
+
|
| 955 |
+
final = self.decoder.flush()
|
| 956 |
+
if final is not None and final.numel() > 0:
|
| 957 |
+
yield final.detach().cpu()
|
| 958 |
+
|
| 959 |
+
def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
|
| 960 |
+
"""一口气消费一个 delta 迭代器,并持续 yield wav chunk。"""
|
| 961 |
+
with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
|
| 962 |
+
for delta in deltas:
|
| 963 |
+
yield from self.push_text_delta(delta)
|
| 964 |
+
yield from self.finish(drain_step=drain_step)
|
| 965 |
+
|
| 966 |
+
def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
|
| 967 |
+
for frame in audio_frames:
|
| 968 |
+
tokens = frame
|
| 969 |
+
if tokens.dim() == 3:
|
| 970 |
+
tokens = tokens[0]
|
| 971 |
+
if tokens.dim() != 2:
|
| 972 |
+
raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
|
| 973 |
+
if tokens.shape[0] != 1:
|
| 974 |
+
raise ValueError(
|
| 975 |
+
f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
tokens, stop = _sanitize_audio_tokens(
|
| 979 |
+
tokens,
|
| 980 |
+
codebook_size=self.codebook_size,
|
| 981 |
+
audio_eos_token=self.audio_eos_token,
|
| 982 |
+
)
|
| 983 |
+
if tokens.numel() == 0:
|
| 984 |
+
if stop:
|
| 985 |
+
break
|
| 986 |
+
continue
|
| 987 |
+
|
| 988 |
+
self.decoder.push_tokens(tokens.detach())
|
| 989 |
+
for wav in self.decoder.audio_chunks():
|
| 990 |
+
if wav.numel() == 0:
|
| 991 |
+
continue
|
| 992 |
+
yield wav.detach().cpu()
|
| 993 |
+
if stop:
|
| 994 |
+
break
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
__all__ = [
|
| 998 |
+
"AudioStreamDecoder",
|
| 999 |
+
"MossTTSRealtimeInference",
|
| 1000 |
+
"MossTTSRealtimeStreamingSession",
|
| 1001 |
+
"MossTTSRealtimeTextStreamBridge",
|
| 1002 |
+
"TextDeltaTokenizer",
|
| 1003 |
+
]
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a474913a182c0ad6297e47244a990659638b3093121fa92a4701fd45bd95c921
|
| 3 |
+
size 11422650
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|audio_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|audio_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|audio_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|text_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"additional_special_tokens": [
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|object_ref_start|>",
|
| 218 |
+
"<|object_ref_end|>",
|
| 219 |
+
"<|box_start|>",
|
| 220 |
+
"<|box_end|>",
|
| 221 |
+
"<|quad_start|>",
|
| 222 |
+
"<|quad_end|>",
|
| 223 |
+
"<|audio_start|>",
|
| 224 |
+
"<|audio_end|>",
|
| 225 |
+
"<|audio_pad|>",
|
| 226 |
+
"<|image_pad|>",
|
| 227 |
+
"<|video_pad|>"
|
| 228 |
+
],
|
| 229 |
+
"bos_token": null,
|
| 230 |
+
"clean_up_tokenization_spaces": false,
|
| 231 |
+
"eos_token": "<|im_end|>",
|
| 232 |
+
"errors": "replace",
|
| 233 |
+
"extra_special_tokens": {},
|
| 234 |
+
"model_max_length": 131072,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"processor_class": "AsteroidProcessor",
|
| 237 |
+
"split_special_tokens": false,
|
| 238 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 239 |
+
"unk_token": null
|
| 240 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|