Zhyw commited on
Commit
8282490
·
1 Parent(s): 26b54dd
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MossTTSRealtime"
4
+ ],
5
+ "audio_pad_token": 1024,
6
+ "audio_vocab_size": 1027,
7
+ "dtype": "bfloat16",
8
+ "initializer_range": 0.02,
9
+ "language_config": {
10
+ "_name_or_path": ".",
11
+ "architectures": [
12
+ "Qwen3ForCausalLM"
13
+ ],
14
+ "attention_bias": false,
15
+ "attention_dropout": 0.0,
16
+ "bos_token_id": 151643,
17
+ "dtype": "bfloat16",
18
+ "eos_token_id": 151645,
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 2048,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 6144,
24
+ "layer_types": [
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention"
53
+ ],
54
+ "max_position_embeddings": 40960,
55
+ "max_window_layers": 28,
56
+ "model_type": "qwen3",
57
+ "num_attention_heads": 16,
58
+ "num_hidden_layers": 28,
59
+ "num_key_value_heads": 8,
60
+ "rms_norm_eps": 1e-06,
61
+ "rope_scaling": null,
62
+ "rope_theta": 1000000,
63
+ "sliding_window": null,
64
+ "use_cache": true,
65
+ "use_sliding_window": false,
66
+ "vocab_size": 151936
67
+ },
68
+ "local_config": {
69
+ "attention_bias": false,
70
+ "attention_dropout": 0.0,
71
+ "audio_pad_token": 1024,
72
+ "audio_vocab_size": 1027,
73
+ "head_dim": 128,
74
+ "hidden_act": "silu",
75
+ "hidden_size": 2048,
76
+ "initializer_range": 0.02,
77
+ "intermediate_size": 6144,
78
+ "max_position_embeddings": 33,
79
+ "model_type": "MossTTSRealtimeLocalTransformer",
80
+ "num_attention_heads": 16,
81
+ "num_hidden_layers": 4,
82
+ "num_key_value_heads": 8,
83
+ "rms_norm_eps": 1e-06,
84
+ "rope_theta": 1000000,
85
+ "rvq": 16,
86
+ "tie_word_embeddings": false,
87
+ "use_cache": true
88
+ },
89
+ "model_type": "moss_tts_realtime",
90
+ "reference_audio_pad": 151654,
91
+ "rvq": 16,
92
+ "text_pad": 151655,
93
+ "tie_word_embeddings": false,
94
+ "transformers_version": "4.57.1",
95
+ "vocab_size": 151936
96
+ }
configuration_mossttsrealtime.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MossTTSRealtimeModel configuration."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.models.qwen3 import Qwen3Config
23
+
24
+
25
+ def _ensure_config(cfg: Any, cls: type[PretrainedConfig]) -> PretrainedConfig:
26
+ if isinstance(cfg, cls):
27
+ return cfg
28
+ if cfg is None:
29
+ return cls()
30
+ if isinstance(cfg, dict):
31
+ return cls(**cfg)
32
+ raise TypeError(f"Unsupported config type for {cls.__name__}: {type(cfg)}")
33
+
34
+
35
class MossTTSRealtimeLocalTransformerConfig(PretrainedConfig):
    """Configuration for the small "local" transformer that decodes RVQ codebooks."""

    model_type = "moss_tts_realtime_local_transformer"

    def __init__(
        self,
        head_dim: int = 128,
        use_cache: bool = True,
        hidden_size: int = 2048,
        rms_norm_eps: float = 1e-6,
        num_hidden_layers: int = 4,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        initializer_range: float = 0.02,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        max_position_embeddings: int = 33,
        num_key_value_heads: int = 8,
        hidden_act: str = "silu",
        rope_theta: int = 1000000,
        rope_type: str = "linear",
        pad_token_id: int = 1024,
        rope_parameters: dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Attention geometry.
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Transformer body.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.use_cache = use_cache
        # Rotary embedding: keep the scalar fields and a dict form used by the
        # rope init helpers; a default dict is synthesized when none is given.
        self.rope_theta = rope_theta
        self.rope_type = rope_type
        self.rope_parameters = (
            {"rope_type": rope_type, "rope_theta": rope_theta, "factor": 1.0}
            if rope_parameters is None
            else rope_parameters
        )
        self.pad_token_id = pad_token_id

        # Audio codebook constants, hard-coded here (they match the defaults
        # on MossTTSRealtimeConfig in this same module).
        self.audio_pad_token = 1024
        self.audio_vocab_size = 1027
        self.rvq = 16
83
+
84
+
85
class MossTTSRealtimeConfig(PretrainedConfig):
    """Top-level configuration: a Qwen3 language backbone plus the local RVQ decoder."""

    model_type = "moss_tts_realtime"

    def __init__(
        self,
        language_config: Qwen3Config | dict | None = None,
        local_config: MossTTSRealtimeLocalTransformerConfig | dict | None = None,
        rvq: int = 16,
        audio_pad_token: int = 1024,
        audio_vocab_size: int = 1027,
        reference_audio_pad: int = 151654,
        text_pad: int = 151655,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Audio codec layout.
        self.rvq = rvq
        self.audio_pad_token = audio_pad_token
        self.audio_vocab_size = audio_vocab_size
        # Special token ids in the text vocabulary.
        self.reference_audio_pad = reference_audio_pad
        self.text_pad = text_pad
        self.initializer_range = initializer_range
        # Sub-configs accept an instance, a dict, or None (library defaults).
        self.language_config = _ensure_config(language_config, Qwen3Config)
        self.local_config = _ensure_config(local_config, MossTTSRealtimeLocalTransformerConfig)

        # Propagate the chosen attention backend to both sub-models.
        backend = self._attn_implementation
        self.language_config._attn_implementation = backend
        self.local_config._attn_implementation = backend
113
+
114
+
115
# Public API of this module.
__all__ = ["MossTTSRealtimeConfig", "MossTTSRealtimeLocalTransformerConfig"]
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13a7a2a0322fe63984ff982a042373a41f707d95df5963233eff40e068a070d4
3
+ size 4663931664
modeling_mossttsrealtime.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MossTTSRealtime backbone model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from transformers.cache_utils import Cache
27
+ from transformers.modeling_outputs import ModelOutput
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.models.qwen3 import Qwen3Model
30
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
31
+
32
+ from .configuration_mossttsrealtime import MossTTSRealtimeConfig
33
+ from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
34
+
35
+
36
class MossTTSRealtimePretrainedModel(PreTrainedModel):
    """Base class wiring MossTTSRealtime into the transformers loading machinery.

    Declares the config class, capability flags (SDPA / flex / flash attention,
    fullgraph compile) and the layer types whose outputs can be recorded.
    """

    config_class = MossTTSRealtimeConfig
    config: MossTTSRealtimeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range).

        Uses the ``transformers.initialization`` helpers instead of direct
        in-place ``.data`` mutation (the commented-out legacy form has been
        removed). Import is local so the symbol is only needed at init time.
        """
        from transformers import initialization as init

        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Keep the padding embedding row at zero.
            if module.padding_idx is not None:
                init.zeros_(module.weight[module.padding_idx])
70
+
71
+
72
@dataclass
class MossTTSRealtimeOutputWithPast(ModelOutput):
    """Output of ``MossTTSRealtime.forward``.

    Bundles the Qwen3 backbone outputs with the local RVQ transformer's
    outputs; local-transformer fields are prefixed with ``local_``.
    """

    # Training loss; in this file's forward it is taken from the local transformer.
    loss: Optional[torch.FloatTensor] = None
    # Backbone logits; forward in this file always returns None here.
    logits: Optional[torch.FloatTensor] = None
    # KV cache of the Qwen3 backbone.
    past_key_values: Optional[Cache] = None
    # Final hidden state of the backbone.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer backbone hidden states / attention maps (when recorded).
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Local (RVQ decoder) counterparts of the fields above.
    local_loss: Optional[torch.FloatTensor] = None
    local_logits: Optional[torch.FloatTensor] = None
    local_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    local_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    local_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Not populated by the forward in this file; presumably set by external
    # training code — TODO confirm.
    backbone_loss: Optional[torch.FloatTensor] = None
86
+
87
+
88
class MossTTSRealtime(MossTTSRealtimePretrainedModel):
    """TTS backbone: a Qwen3 language model whose last hidden state conditions
    a small local transformer that decodes RVQ audio codebooks.

    Input ids are expected with a trailing stream axis of size 1 + rvq
    (one text stream plus ``config.rvq`` audio codebook streams); the per-stream
    embeddings are summed into a single backbone input.
    """

    def __init__(self, config: MossTTSRealtimeConfig):
        super().__init__(config)
        self.config = config
        # embed_tokens[0] is the text embedding (Qwen3 vocab); the remaining
        # config.rvq entries embed the audio codebooks into the same width.
        self.embed_tokens = nn.ModuleList([])
        self.embed_tokens.append(
            nn.Embedding(
                config.language_config.vocab_size,
                config.language_config.hidden_size,
                config.language_config.pad_token_id,
            )
        )
        self.audio_vocab_size = self.config.audio_vocab_size
        for _ in range(self.config.rvq):
            self.embed_tokens.append(
                nn.Embedding(self.audio_vocab_size, config.language_config.hidden_size, self.config.audio_pad_token)
            )
        self.language_model = Qwen3Model._from_config(config.language_config)
        self.local_transformer = MossTTSRealtimeLocalTransformerForCausalLM._from_config(config.local_config)
        self.post_init()

    def get_input_embeddings(self, input_ids):
        """Sum the text embedding and every audio-codebook embedding.

        ``input_ids`` must carry one stream per embedding table on its last
        axis (text at index 0, codebooks after).
        """
        # Move ids to the embedding device if needed (multi-device loading).
        if input_ids.device != self.embed_tokens[0].weight.device:
            input_ids = input_ids.to(self.embed_tokens[0].weight.device)
        inputs_embeds = self.embed_tokens[0](input_ids[..., 0])
        for i, embed in enumerate(self.embed_tokens):
            if i == 0:
                continue
            inputs_embeds = inputs_embeds + embed(input_ids[..., i])
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[list] = None,
        **kwargs,
    ):
        """Run the backbone and, when ``labels`` are given, the local decoder.

        ``labels`` is treated as (batch, time, 1 + rvq): column 0 is the text
        label, the rest are the audio codebook targets; -100 marks ignored
        positions. Returns a MossTTSRealtimeOutputWithPast.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        outputs = self.language_model(
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=True,
            cache_position=cache_position,
            **kwargs,
        )

        loss = None
        local_outputs = None
        if labels is not None:
            # Drop the text column; keep only audio codebook targets.
            audio_labels = labels[:, :, 1:]
            # A frame is trainable if at least one codebook target is not -100.
            train_mask = ~(audio_labels == -100).all(dim=-1)
            # Local-transformer inputs: the first rvq-1 codebooks of each
            # supervised frame (advanced indexing copies, so the in-place
            # -100 -> 1024 (audio pad) rewrite does not touch `labels`).
            local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
            local_input_ids[local_input_ids == -100] = 1024
            # Prepend a slot (token id 0) at position 0; the local transformer
            # overwrites that position's embedding with the backbone state.
            local_input_ids = F.pad(local_input_ids, (1, 0), value=0)

            train_idx = train_mask.nonzero(as_tuple=True)
            # Conditioning vector: backbone hidden state one step BEFORE each
            # supervised frame. NOTE(review): at t == 0 `train_idx[1] - 1`
            # wraps to -1 (the last time step) — confirm frames at t=0 are
            # never supervised.
            local_hidden_states = outputs[0][train_idx[0], train_idx[1] - 1, :].reshape(
                -1, 1, self.config.local_config.hidden_size
            )
            local_labels = audio_labels[train_mask]

            local_outputs = self.local_transformer(
                input_ids=local_input_ids,
                backbone_last_hidden_state=local_hidden_states,
                use_cache=use_cache,
                return_dict=True,
                labels=local_labels,
                **kwargs,
            )
            loss = local_outputs.loss

        return MossTTSRealtimeOutputWithPast(
            loss=loss,
            logits=None,
            past_key_values=outputs.past_key_values,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            local_logits=local_outputs.logits if local_outputs is not None else None,
            local_past_key_values=local_outputs.past_key_values if local_outputs is not None else None,
            local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
            local_attentions=local_outputs.attentions if local_outputs is not None else None,
        )
190
+
191
+
192
# Public API; Qwen3Model is re-exported for auto-class / remote-code loading.
__all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel", "Qwen3Model"]
modeling_mossttsrealtime_local.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Local transformer used by MossTTSRealtime for RVQ codebook decoding."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Optional, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
26
+ from transformers.generation import GenerationMixin
27
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
+ from transformers.modeling_layers import GradientCheckpointingLayer
29
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
30
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
31
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
+ from transformers.masking_utils import create_causal_mask
33
+ from transformers.processing_utils import Unpack
34
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
35
+ from transformers.loss.loss_utils import ForCausalLMLoss
36
+
37
+ from .configuration_mossttsrealtime import MossTTSRealtimeLocalTransformerConfig
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
class MossTTSRealtimeLocalTransformerRMSNorm(nn.Module):
    """Root-mean-square layer norm: scale-only, no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
57
+
58
+
59
class MossTTSRealtimeLocalTransformerMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
73
+
74
+
75
def rotate_half(x):
    """Rotate the last dimension by half: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
79
+
80
+
81
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    ``unsqueeze_dim`` inserts the broadcast (head) axis into cos/sin so they
    line up with q/k of shape (batch, heads, seq, head_dim).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
87
+
88
+
89
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along the head axis (for GQA).

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    input (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim).
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
95
+
96
+
97
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (pure-PyTorch) scaled-dot-product attention with GQA support.

    Returns the attention output as (batch, seq, heads, dim) plus the
    post-softmax attention weights.
    """
    # Expand KV heads in place of calling repeat_kv: each of the kv heads is
    # duplicated num_key_value_groups times along the head axis.
    groups = module.num_key_value_groups
    if groups != 1:
        b, h, s, d = key.shape
        key = key[:, :, None, :, :].expand(b, h, groups, s, d).reshape(b, h * groups, s, d)
        b, h, s, d = value.shape
        value = value[:, :, None, :, :].expand(b, h, groups, s, d).reshape(b, h * groups, s, d)

    scores = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Additive mask, truncated to the actual key length.
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    # Softmax in float32, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
118
+
119
+
120
class MossTTSRealtimeLocalTransformerAttention(nn.Module):
    """Multi-head attention with per-head RMS norm on Q/K (Qwen3-style) and
    grouped-query KV heads, used by the local RVQ transformer."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Layer index identifies this layer's slot in the shared KV cache.
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Number of query heads served by each KV head (GQA).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # Per-head-dim RMS norms applied to Q and K before RoPE.
        self.q_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # No sliding-window attention in the local transformer.
        self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states``; returns (output, attention weights).

        ``position_embeddings`` is the precomputed (cos, sin) pair from the
        rotary embedding module.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Project, split into heads, normalize Q/K per head, then (B, H, S, D).
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Append the new K/V to the cache AFTER RoPE; the cache returns the
        # full key/value sequence for this layer.
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured backend (sdpa/flash/flex); fall back to
        # the eager reference implementation.
        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge heads back to (batch, seq, hidden) and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
180
+
181
+
182
class MossTTSRealtimeLocalTransformerDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: self-attention then gated MLP, each wrapped in
    an RMS-norm + residual connection."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MossTTSRealtimeLocalTransformerAttention(config=config, layer_idx=layer_idx)
        self.mlp = MossTTSRealtimeLocalTransformerMLP(config)
        self.input_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = "full_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Self-attention sub-block (pre-norm + residual).
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-block (pre-norm + residual).
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
221
+
222
+
223
class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
    """Base class for the local RVQ transformer: capability flags and the
    layer types whose intermediate outputs can be recorded."""

    config: MossTTSRealtimeLocalTransformerConfig
    base_model_prefix = "local_transformer"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device when sharding.
    _no_split_modules = ["MossTTSRealtimeLocalTransformerDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": MossTTSRealtimeLocalTransformerDecoderLayer,
        "attentions": MossTTSRealtimeLocalTransformerAttention,
    }
238
+
239
+
240
class MossTTSRealtimeLocalTransformerRotaryEmbedding(nn.Module):
    """Rotary position embedding for the local transformer.

    Computes (cos, sin) tables for the given position ids; the init function
    is looked up from ROPE_INIT_FUNCTIONS by ``config.rope_type``.
    NOTE(review): the fallback rope_type is "linear" (not "default") — the
    config's rope_parameters supplies the ``factor`` that "linear" scaling
    needs; confirm this is the intended default.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, device=None):
        super().__init__()
        self.config = config
        self.rope_type = getattr(config, "rope_type", "linear")
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: inverse frequencies are recomputed, not checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic_rope_update can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return (cos, sin), each (batch, seq, head_dim), in x's dtype."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 math outside autocast; "mps" lacks autocast support.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate frequencies so cos/sin cover the full head_dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
266
+
267
+
268
class MossTTSRealtimeLocalTransformer(MossTTSRealtimeLocalTransformerPreTrainedModel):
    """Small causal transformer that decodes the per-frame RVQ codebooks.

    Position 0 of each local sequence carries the backbone's hidden state;
    positions 1..rvq-1 carry the previously decoded codebook tokens, each
    embedded with its own table (``embed_tokens[i]`` embeds codebook i+1's
    input, i.e. codebook i's prediction target context).
    """

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        # One embedding table per codebook input stream (rvq - 1 of them:
        # the first position is the backbone hidden state, not a token).
        self.embed_tokens = nn.ModuleList(
            [nn.Embedding(config.audio_vocab_size, config.hidden_size, config.audio_pad_token) for _ in range(config.rvq - 1)]
        )
        self.layers = nn.ModuleList(
            [MossTTSRealtimeLocalTransformerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = MossTTSRealtimeLocalTransformerRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = None
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        codebook_idx: Optional[int] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the local decoder over one RVQ frame.

        ``codebook_idx`` (1-based) selects the embedding table for a
        single-step decode; otherwise the table is derived from
        ``cache_position``. Returns a BaseModelOutputWithPast.
        """
        # Positions always follow cache_position; ignore caller-supplied
        # position_ids except under torch.compile (where it is kept as-is).
        if position_ids is not None and not torch.compiler.is_compiling():
            position_ids = None

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if use_cache and past_key_values is None:
            # BUGFIX: the cache device was taken from `inputs_embeds.device`
            # unconditionally, crashing with AttributeError when only
            # `input_ids` is provided (the common decode path). The local
            # sequence is at most rvq (=16) positions long, hence max_cache_len.
            cache_device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
            past_key_values = StaticCache(config=self.config, max_cache_len=16, device=cache_device)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
            device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)

        if inputs_embeds is None:
            if codebook_idx is not None:
                # Single-step decode with an explicitly chosen codebook table.
                if input_ids.ndim == 1:
                    input_ids = input_ids.unsqueeze(1)
                token_emb = self.embed_tokens[codebook_idx - 1](input_ids[:, 0]).unsqueeze(1)  # [B,1,H]
                inputs_embeds = token_emb
            else:
                # NOTE(review): indexing a ModuleList with a tensor only works
                # for single-element tensors, and after the clamp the extra
                # `- 1` maps cache_position 0/1 to embed_tokens[-1] (the last
                # table) — confirm this double offset is intended.
                codebook_idxs = torch.clamp(cache_position - 1, min=0)
                inputs_embeds = self.embed_tokens[codebook_idxs - 1](input_ids)

        input_ids_are_first_codebook = cache_position[0] == 0
        if backbone_last_hidden_state is not None:
            # Position 0 is conditioned on the backbone hidden state.
            # NOTE(review): callers in this repo pass shape (N, 1, H) while the
            # destination slice is (N, H) — verify the assignment broadcasts.
            inputs_embeds[:, 0] = backbone_last_hidden_state
        else:
            if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
                logger.warning(
                    "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
                )

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
360
+
361
+
362
class MossTTSRealtimeLocalTransformerForCausalLM(MossTTSRealtimeLocalTransformerPreTrainedModel, GenerationMixin):
    """Causal-LM head over the local (per-codebook) transformer.

    Wraps ``MossTTSRealtimeLocalTransformer`` and projects its hidden states to
    audio-token logits through one bias-free ``nn.Linear`` head per RVQ codebook
    (``config.rvq`` heads in total, each of size ``audio_vocab_size``).
    """

    # No tied weights and no tensor-/pipeline-parallel plans for this model.
    _tied_weights_keys = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config):
        super().__init__(config)
        self.model = MossTTSRealtimeLocalTransformer(config)
        self.audio_vocab_size = self.config.audio_vocab_size

        # One independent LM head per codebook.
        self.local_lm_heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, config.audio_vocab_size, bias=False) for _ in range(config.rvq)]
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        codebook_idx: Optional[int] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        """Run the local transformer and score its hidden states with the codebook heads.

        Args:
            input_ids: Audio codebook token ids consumed by the local transformer.
            backbone_last_hidden_state: Hidden state produced by the backbone model;
                forwarded to ``self.model``, which injects it at the first position.
            codebook_idx: When set (cached, step-wise decoding), selects the single
                LM head used for the current step.
            labels: Optional targets; when given, ``ForCausalLMLoss`` is computed
                over ``audio_vocab_size`` classes with ``labels`` used as already
                shifted targets (``shift_labels=``).
            logits_to_keep: ``0`` keeps logits for all positions, a positive int
                keeps only the last ``logits_to_keep`` positions, and a tensor is
                used directly as an index on the sequence dimension.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` have shape
            ``[batch, kept_positions, audio_vocab_size]``.
        """
        outputs = self.model(
            input_ids=input_ids,
            backbone_last_hidden_state=backbone_last_hidden_state,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            codebook_idx=codebook_idx,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # Translate logits_to_keep into a slice/index over the sequence dimension.
        if isinstance(logits_to_keep, int):
            if logits_to_keep == 0:
                slice_indices = slice(0, None)
            else:
                slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        hs = hidden_states[:, slice_indices, :]

        if cache_position is not None:
            # Cached step-wise decoding: score only the first kept position with
            # the head for the current codebook.
            # NOTE(review): this branch indexes local_lm_heads with codebook_idx
            # and assumes it is not None whenever cache_position is set — confirm
            # all callers pass both together.
            logits = self.local_lm_heads[codebook_idx](hs[:, 0, :]).unsqueeze(1)
        else:
            # Uncached (e.g. training) path: position i in the kept slice is
            # scored by head i, i.e. one position per codebook.
            logits_list = []
            for i in range(hs.shape[1]):
                logits_list.append(self.local_lm_heads[i](hs[:, i, :]))
            logits = torch.stack(logits_list, dim=1)

        logits = logits.contiguous()
        loss = None
        if labels is not None:
            loss = ForCausalLMLoss(logits, None, self.audio_vocab_size, shift_labels=labels.contiguous())

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
436
+
437
+
438
+
439
+
440
# Public API of the local-transformer modeling module.
__all__ = [
    "MossTTSRealtimeLocalTransformer",
    "MossTTSRealtimeLocalTransformerAttention",
    "MossTTSRealtimeLocalTransformerConfig",
    "MossTTSRealtimeLocalTransformerDecoderLayer",
    "MossTTSRealtimeLocalTransformerForCausalLM",
    "MossTTSRealtimeLocalTransformerPreTrainedModel",
    "MossTTSRealtimeLocalTransformerRMSNorm",
    "MossTTSRealtimeLocalTransformerRotaryEmbedding",
]
processing_mossttsrealtime.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Processing utilities for MossTTSRealtime."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Iterable, Optional
20
+
21
+ import numpy as np
22
+
23
+
24
class MossTTSRealtimeProcessor:
    """Builds MossTTSRealtime prompt inputs with text and audio codebooks.

    This processor focuses on preparing the mixed text/audio token layout expected by MossTTSRealtime.
    It does not perform audio encoding/decoding by itself.

    Layout convention: every prepared array has shape ``[T, channels + 1]``
    where column 0 carries text-tokenizer ids and columns ``1:`` carry the
    audio codebook ids; positions without audio are filled with
    ``audio_channel_pad``.
    """

    def __init__(
        self,
        tokenizer,
        audio_pad_token: str = "<|audio_pad|>",
        text_pad_token: str = "<|text_pad|>",
        tts_system_prompt: Optional[str] = None,
        channels: int = 16,
        audio_channel_pad: int = 1024,
        audio_bos_token: int = 1025,
        audio_eos_token: int = 1026,
        delay_tokens_len: int = 12,
    ):
        """Store layout parameters and resolve the special-token ids.

        Args:
            tokenizer: Text tokenizer used for all text segments.
            audio_pad_token: Text-side placeholder token for audio positions.
            text_pad_token: Text-side padding token aligned with audio frames.
            tts_system_prompt: Optional override of the built-in system prompt.
            channels: Number of audio codebook channels.
            audio_channel_pad: Fill value for audio columns without audio.
            audio_bos_token: Audio begin-of-stream id (written to channel 1).
            audio_eos_token: Audio end-of-stream id (written to channel 1).
            delay_tokens_len: Offset (in frames) between text and audio streams.
        """
        self.tokenizer = tokenizer
        self.channels = channels
        self.audio_channel_pad = audio_channel_pad
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.delay_tokens_len = delay_tokens_len

        self.audio_pad_token_id = self._convert_token_to_id(audio_pad_token)
        self.text_pad_token_id = self._convert_token_to_id(text_pad_token)

        if tts_system_prompt is None:
            tts_system_prompt = (
                "<|im_start|>system\n"
                "You are a highly expressive text-to-speech (TTS) engine developed by Mosi Intelligence. \n"
                "You possess natural language understanding, emotional modeling, and multi-style speech generation "
                "capabilities, allowing you to generate the corresponding speech based on the text given in the assistant."
                "<|im_end|>\n"
            )
        self.ttsbase_system_prompt = tts_system_prompt

    def _convert_token_to_id(self, token: str) -> int:
        """Resolve a single special token string to its integer id.

        Tries ``convert_tokens_to_ids`` first; if that yields ``None`` or the
        unknown-token id, falls back to ``encode`` and requires the token to map
        to exactly one id.

        Raises:
            ValueError: If the token cannot be resolved, or maps to several ids.
        """
        if hasattr(self.tokenizer, "convert_tokens_to_ids"):
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            if token_id is not None and token_id != self.tokenizer.unk_token_id:
                return int(token_id)
        token_ids = self.tokenizer.encode(token, add_special_tokens=False)
        if not token_ids:
            raise ValueError(f"Token '{token}' could not be converted to an id.")
        if len(token_ids) != 1:
            raise ValueError(f"Token '{token}' maps to multiple ids: {token_ids}")
        return int(token_ids[0])

    def make_voice_clone_prompt(self, prompt_audio_tokens_len: int) -> str:
        """Return the voice-clone context prompt with one <|audio_pad|> per prompt frame."""
        padded_audio_prompt = f"{'<|audio_pad|>' * prompt_audio_tokens_len}"
        voice_clone = (
            "<|im_start|>context\n"
            "The assistant section should be synthesized using the following voice timbre:"
            f"{padded_audio_prompt}"
        )
        return voice_clone

    def _normalize_audio_tokens(self, audio_tokens: np.ndarray | Iterable) -> np.ndarray:
        """Coerce audio tokens to shape ``[T, channels]``.

        Accepts either ``[channels, T]`` or ``[T, channels]`` input (transposing
        or truncating extra channels as needed).

        NOTE(review): a square ``[channels, channels]`` input is ambiguous and is
        taken as already ``[T, channels]`` by the first branch — confirm callers
        never rely on the transposed interpretation for that edge case.

        Raises:
            ValueError: If the input is not 2D or cannot be reduced to
                ``channels`` columns.
        """
        tokens = np.array(audio_tokens)
        if tokens.ndim != 2:
            raise ValueError(f"Expected 2D audio tokens, got shape {tokens.shape}")
        if tokens.shape[0] == self.channels:
            tokens = tokens.T
        elif tokens.shape[1] == self.channels:
            tokens = tokens
        elif tokens.shape[0] > self.channels and tokens.shape[1] != self.channels:
            tokens = tokens[: self.channels, :].T
        elif tokens.shape[1] > self.channels and tokens.shape[0] != self.channels:
            tokens = tokens[:, : self.channels]
        if tokens.shape[1] != self.channels:
            raise ValueError(f"Expected {self.channels} channels, got shape {tokens.shape}")
        return tokens

    def make_ensemble(self, prompt_audio_tokens: Optional[np.ndarray] = None) -> np.ndarray:
        """Build the system-prompt block, optionally embedding a voice prompt.

        When ``prompt_audio_tokens`` is given, the system prompt is extended with
        the voice-clone context (one ``<|audio_pad|>`` per audio frame) and the
        audio codebook ids are written into columns ``1:`` of the rows spanned by
        those pad tokens.

        Returns:
            ``np.ndarray`` of shape ``[T, channels + 1]``.

        Raises:
            ValueError: If a voice prompt is given but no ``<|audio_pad|>``
                tokens are found in the tokenized system prompt.
        """
        if prompt_audio_tokens is not None:
            prompt_audio_tokens = self._normalize_audio_tokens(prompt_audio_tokens)
            prompt_audio_tokens = prompt_audio_tokens[:, : self.channels]
            system_prompt_text = f"{self.ttsbase_system_prompt}" + f"{self.make_voice_clone_prompt(prompt_audio_tokens.shape[0])}"
        else:
            system_prompt_text = f"{self.ttsbase_system_prompt}"

        system_prompt_tokens = self.tokenizer(system_prompt_text)["input_ids"]
        # Start from an all-audio-pad grid and place the text ids in column 0.
        system_prompt_tokens_full = np.full(
            shape=(len(system_prompt_tokens), self.channels + 1), fill_value=self.audio_channel_pad, dtype=np.int64
        )
        system_prompt_tokens_full[:, 0] = system_prompt_tokens

        if prompt_audio_tokens is not None:
            system_prompt_tokens = np.array(system_prompt_tokens)
            # The <|audio_pad|> run marks where the voice-prompt frames belong.
            indices = np.where(system_prompt_tokens == self.audio_pad_token_id)[0]
            if indices.size == 0:
                raise ValueError("No <|audio_pad|> tokens found in the system prompt.")
            prompt_audio_start_pos, prompt_audio_end_pos = indices[0], indices[-1]
            # NOTE(review): assumes the pad run is contiguous and exactly
            # prompt_audio_tokens.shape[0] long — verify tokenizer does not
            # merge/split <|audio_pad|>.
            system_prompt_tokens_full[prompt_audio_start_pos : prompt_audio_end_pos + 1, 1:] = prompt_audio_tokens

        return system_prompt_tokens_full

    def make_user_prompt(self, text: str, audio_tokens: np.ndarray) -> np.ndarray:
        """Build a user turn interleaving text tokens with delayed audio frames.

        The audio stream lags the text stream by ``delay_tokens_len`` frames;
        ``<|text_pad|>`` tokens extend the text channel so both streams have
        equal length, the audio BOS id is written to channel 1 one row before
        the first audio frame, and the audio EOS id one row after the last.
        A ``<|im_end|>\\n<|im_start|>assistant\\n`` header is appended at the end.

        Returns:
            ``np.ndarray`` of shape ``[T, channels + 1]``.
        """
        prefill_temp = "<|im_end|>\n<|im_start|>user\n"
        text_tokens = self.tokenizer(text)["input_ids"]
        text_start_pos = len(self.tokenizer.encode(prefill_temp))
        token = self._normalize_audio_tokens(audio_tokens)

        text_len = len(text_tokens)
        audio_len = token.shape[0]

        if text_len >= self.delay_tokens_len:
            # Text is long enough to cover the delay: pad text so total length
            # is text_start_pos + delay + audio_len + 1 (the +1 holds the EOS row).
            padded_text_len = audio_len + self.delay_tokens_len - text_len + 1
            cur_input_id_ch1 = prefill_temp + text + "<|text_pad|>" * padded_text_len
            assistant_tokens_ch1 = self.tokenizer(cur_input_id_ch1)["input_ids"]
            cur_input_id = np.full(
                shape=(len(assistant_tokens_ch1), self.channels + 1),
                fill_value=self.audio_channel_pad,
                dtype=np.int64,
            )
            cur_input_id[:, 0] = assistant_tokens_ch1
            # Audio frames start delay_tokens_len rows after the text starts.
            cur_input_id[
                text_start_pos + self.delay_tokens_len : text_start_pos + self.delay_tokens_len + audio_len, 1:
            ] = token
            cur_input_id[text_start_pos + self.delay_tokens_len - 1, 1] = self.audio_bos_token
            cur_input_id[text_start_pos + self.delay_tokens_len + audio_len, 1] = self.audio_eos_token
        else:
            # Short text: anchor the audio frames at the end of the turn instead.
            padded_text_len = audio_len + 1
            cur_input_id_ch1 = prefill_temp + text + "<|text_pad|>" * padded_text_len
            assistant_tokens_ch1 = self.tokenizer(cur_input_id_ch1)["input_ids"]
            cur_input_id = np.full(
                shape=(len(assistant_tokens_ch1), self.channels + 1),
                fill_value=self.audio_channel_pad,
                dtype=np.int64,
            )
            cur_input_id[:, 0] = assistant_tokens_ch1
            cur_input_id[-(audio_len + 1) : -1, 1:] = token
            cur_input_id[-(audio_len + 2), 1] = self.audio_bos_token
            cur_input_id[-1, 1] = self.audio_eos_token

        begin_of_response = self.tokenizer.encode("<|im_end|>\n<|im_start|>assistant\n")
        begin_of_response_full = np.full(
            shape=(len(begin_of_response), self.channels + 1), fill_value=self.audio_channel_pad, dtype=np.int64
        )
        begin_of_response_full[:, 0] = begin_of_response

        input_ids = np.concatenate([cur_input_id, begin_of_response_full], axis=0)
        return input_ids
170
+
171
+
172
# Public API of the processing module.
__all__ = ["MossTTSRealtimeProcessor"]
streaming_mossttsrealtime.py ADDED
@@ -0,0 +1,1003 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Streaming inference utilities for MossTTSRealtime."""
16
+
17
+ from __future__ import annotations
18
+
19
+
20
+ import re
21
+ import numpy as np
22
+ import contextlib
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torchaudio
27
+ from transformers.cache_utils import StaticCache
28
+ from transformers.utils.import_utils import requires
29
+ from typing import Iterable, Iterator, List, Optional, Sequence
30
+
31
@requires(backends=("torch",))
class MossTTSRealtimeInference:
    """Step-wise inference wrapper for MossTTSRealtime.

    This class mirrors the non-streaming inference logic but exposes a
    prefill/step/finish API for streaming usage.

    State carried between calls: the backbone KV cache (``past_key_values``),
    the running ``attention_mask``, the list of generated audio-frame tensors,
    a per-sample EOS flag, the last emitted frame, and a step counter.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_length: int = 1000,
        channels: int = 16,
        audio_channel_pad: int = 1024,
        audio_bos_token: int = 1025,
        audio_eos_token: int = 1026,
        text_pad_id: int = 151655,
        aud_pad_id: int = 151654,
    ):
        """Store the model/tokenizer and the audio-token layout constants.

        Args:
            model: MossTTSRealtime backbone + local transformer model.
            tokenizer: Text tokenizer (used for pad id lookup in prefill).
            max_length: Default step budget for ``finish``.
            channels: Number of audio codebook channels per frame.
            audio_channel_pad: Fill id for audio columns carrying no audio.
            audio_bos_token / audio_eos_token: Audio stream delimiters.
            text_pad_id / aud_pad_id: Text-side pad ids (defaults are
                tokenizer-specific — presumably Qwen special tokens; verify
                against the deployed tokenizer).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.channels = channels
        self.audio_channel_pad = audio_channel_pad
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.text_pad_id = text_pad_id
        self.aud_pad_id = aud_pad_id

        # Mutable generation state; see reset_generation_state().
        self.past_key_values = None
        self.attention_mask = None
        self._generated_tokens: List[torch.Tensor] = []
        self._is_stopping = None
        self._last_audio_tokens = None
        self._step_idx = 0

    @property
    def device(self):
        # Device of the model's first parameter.
        return next(self.model.parameters()).device

    @property
    def is_finished(self) -> bool:
        # True once every sample in the batch has emitted the audio EOS token.
        return self._is_stopping is not None and bool(self._is_stopping.all())

    def reset_generation_state(self, keep_cache: bool = True):
        """Clear per-turn generation state.

        With ``keep_cache=True`` (default) the KV cache and attention mask are
        retained so the next ``prefill`` can continue the same conversation.
        """
        # When keep_cache=True, retain the attention_mask so that its length matches past_key_values.
        # This is used for concatenation in the next prefill step.
        if not keep_cache:
            self.past_key_values = None
            self.attention_mask = None
        self._generated_tokens = []
        self._is_stopping = None
        self._last_audio_tokens = None
        self._step_idx = 0

    def _normalize_input_ids(self, input_ids):
        """Coerce input ids to a list of per-sample ``[T, C]`` numpy arrays."""
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.detach().cpu().numpy()
        if isinstance(input_ids, np.ndarray):
            if input_ids.ndim == 2:
                return [input_ids]
            if input_ids.ndim == 3:
                return [input_ids[i] for i in range(input_ids.shape[0])]
        if isinstance(input_ids, (list, tuple)):
            return [np.array(item) for item in input_ids]
        raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")

    def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
        """Coerce text prefix ids to one list of ints per batch element.

        A single flat list is treated as one sample; a singleton nested list is
        broadcast across the batch.

        Raises:
            ValueError: On ``None`` input, unsupported types, or a batch-size
                mismatch.
        """
        if text_prefix_ids is None:
            raise ValueError("text_prefix_ids must be provided for prefill.")
        if isinstance(text_prefix_ids, torch.Tensor):
            text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
        if isinstance(text_prefix_ids, np.ndarray):
            text_prefix_ids = text_prefix_ids.tolist()
        if isinstance(text_prefix_ids, list):
            if len(text_prefix_ids) == 0:
                return [[] for _ in range(batch_size)]
            if isinstance(text_prefix_ids[0], (int, np.integer)):
                return [list(text_prefix_ids)]
            if len(text_prefix_ids) == 1 and batch_size > 1:
                return [list(text_prefix_ids[0]) for _ in range(batch_size)]
            if len(text_prefix_ids) != batch_size:
                raise ValueError(
                    f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
                )
            return [list(item) for item in text_prefix_ids]
        raise ValueError("text_prefix_ids must be list-like or tensor-like.")

    @torch.inference_mode()
    def prefill(
        self,
        input_ids,
        text_prefix_ids,
        max_prefill_len: Optional[int] = None,
        past_key_values=None,
        device: Optional[torch.device] = None,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> torch.Tensor:
        """Feed the prompt plus a text prefix through the backbone and sample the first audio frame.

        The text prefix is appended as text-only rows (audio columns padded),
        with the audio BOS id placed in channel 1 of the last prefix row.
        Samples are left-padded to a common length; an incoming KV cache (or a
        retained one plus ``self.attention_mask``) is continued if present.

        Returns:
            The first sampled audio frame, shape ``[batch, channels]``.
        """
        if device is None:
            device = self.device

        if past_key_values is not None:
            self.past_key_values = past_key_values

        input_ids_list = self._normalize_input_ids(input_ids)
        batch_size = len(input_ids_list)
        text_prefix_list = self._normalize_text_prefix(text_prefix_ids, batch_size)

        concat_inputs_id_list = []
        for i in range(batch_size):
            prefix = text_prefix_list[i]
            if max_prefill_len is not None:
                prefix = prefix[:max_prefill_len]
            if len(prefix) == 0:
                raise ValueError("Prefill requires at least one text token.")

            # Text-only rows: column 0 = text ids, audio columns = pad,
            # audio BOS in channel 1 of the final prefix row.
            text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64)
            text_seg[:, 0] = np.array(prefix, dtype=np.int64)
            text_seg[len(prefix) - 1, 1] = self.audio_bos_token
            concat_inputs_id = np.concatenate([input_ids_list[i], text_seg], axis=0)
            concat_inputs_id_list.append(concat_inputs_id)

        # Left-pad all samples to the batch maximum length.
        attention_masks = [np.ones(ids.shape[0], dtype=np.bool_) for ids in concat_inputs_id_list]
        max_len = max(ids.shape[0] for ids in concat_inputs_id_list)
        padded_input_ids, padded_attns = [], []
        pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id

        for ids, attn in zip(concat_inputs_id_list, attention_masks):
            pad_len = max_len - ids.shape[0]
            input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64)
            input_pad[:, 0] = pad_token_id
            padded_input_ids.append(np.concatenate([input_pad, ids]))
            attn_pad = np.zeros(pad_len, dtype=np.bool_)
            padded_attns.append(np.concatenate([attn_pad, attn]))

        current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
        current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)

        # Continue an existing conversation: extend the retained mask so its
        # length matches the KV cache.
        if self.attention_mask is not None and self.past_key_values is not None:
            current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)

        outputs = self.model(
            input_ids=current_input_ids,
            attention_mask=current_attention_mask,
            past_key_values=self.past_key_values,
            use_cache=True,
            return_dict=True,
        )
        self.past_key_values = outputs.past_key_values
        self.attention_mask = current_attention_mask

        # Last backbone hidden state conditions the local transformer.
        backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
        audio_tokens = self.generate_local_transformer(
            hidden_states=backbone_hidden_states,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            generated_tokens=None,
            gen_step=0,
        )

        self._generated_tokens = [audio_tokens]
        self._last_audio_tokens = audio_tokens
        # EOS is detected on the first codebook channel.
        self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
        self._step_idx = 1
        return audio_tokens

    @torch.inference_mode()
    def step(
        self,
        text_token: Optional[Iterable[int] | torch.Tensor | int],
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> torch.Tensor:
        """Advance one frame: feed (text token, previous audio frame) and sample the next frame.

        ``text_token=None`` feeds ``text_pad_id`` (audio continues without new
        text). Returns the last frame unchanged once all samples have stopped.

        Raises:
            ValueError: If called before ``prefill`` or on a batch-size mismatch.
        """
        if self._last_audio_tokens is None or self.attention_mask is None:
            raise ValueError("You must call prefill() before step().")
        if self.is_finished:
            return self._last_audio_tokens

        batch_size = self._last_audio_tokens.shape[0]
        if text_token is None:
            text_tokens = [self.text_pad_id] * batch_size
        elif isinstance(text_token, torch.Tensor):
            text_tokens = text_token.detach().cpu().tolist()
        elif isinstance(text_token, (list, tuple, np.ndarray)):
            text_tokens = list(text_token)
        else:
            text_tokens = [int(text_token)]

        if len(text_tokens) != batch_size:
            raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")

        device = self._last_audio_tokens.device
        text_t = torch.tensor(text_tokens, device=device, dtype=torch.long)
        # One new row per sample: [text_id, prev_frame_channels...] -> [B, 1, C+1].
        step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2)
        # Already-stopped samples get a masked (0) position for the new row.
        self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1)

        outputs = self.model(
            input_ids=step_ids,
            attention_mask=self.attention_mask,
            past_key_values=self.past_key_values,
            use_cache=True,
            return_dict=True,
        )
        self.past_key_values = outputs.past_key_values
        backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]

        # History shape [B, steps, channels] for the repetition penalty.
        history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None
        audio_tokens = self.generate_local_transformer(
            hidden_states=backbone_hidden_states,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            generated_tokens=history,
            gen_step=self._step_idx,
        )

        self._generated_tokens.append(audio_tokens)
        self._last_audio_tokens = audio_tokens
        self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
        self._step_idx += 1
        return audio_tokens

    @torch.inference_mode()
    def finish(
        self,
        max_steps: Optional[int] = None,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ) -> list[torch.Tensor]:
        """Run pad-text steps until EOS or the step budget (``max_steps`` or ``self.max_length``) is spent.

        Returns:
            The list of audio frames produced by the drain steps.
        """
        outputs = []
        steps_left = max_steps if max_steps is not None else self.max_length
        while steps_left > 0 and not self.is_finished:
            outputs.append(
                self.step(
                    text_token=None,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    repetition_window=repetition_window,
                )
            )
            steps_left -= 1
        return outputs

    # NOTE(review): fullgraph=True compiles this whole method, yet it constructs
    # a StaticCache and calls torch.multinomial inside — confirm this traces
    # without graph breaks on the targeted torch version.
    @torch.compile(fullgraph=True)
    def generate_local_transformer(
        self,
        hidden_states: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample: bool,
        repetition_penalty: Optional[float],
        repetition_window: Optional[int],
        generated_tokens: Optional[torch.Tensor],
        gen_step: int,
    ) -> torch.Tensor:
        """Autoregressively sample one audio frame (all codebook channels) from the local transformer.

        Channel 0 is conditioned on the backbone hidden state (passed as
        ``inputs_embeds``); each subsequent channel is conditioned on the
        previously sampled token via the local KV cache.

        Returns:
            Long tensor of shape ``[batch, channels]``.
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device
        local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
        output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)

        # Fixed-size cache: exactly one position per codebook channel.
        # NOTE(review): no device= is passed to StaticCache — confirm it
        # allocates on the model device rather than CPU.
        past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
        local_token = None

        cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)

        for i in range(self.channels):
            cache_pos_t.fill_(i)

            local_outputs = self.model.local_transformer(
                input_ids=local_token,
                inputs_embeds=local_inputs,
                past_key_values=past_key_values,
                cache_position=cache_pos_t,
                codebook_idx=i,
                use_cache=True,
                logits_to_keep=1,
            )
            logits = local_outputs.logits

            if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
                # Penalize tokens previously emitted on this channel.
                # NOTE(review): this returns a [B, V] view (the [B, 1, V] shape
                # is dropped); sample_token handles both shapes, but downstream
                # shapes differ between the penalized and unpenalized paths.
                logits = self.apply_repetition_penalty(
                    scores=logits,
                    history_tokens=generated_tokens[:, :gen_step, i],
                    penalty=float(repetition_penalty),
                    repetition_window=repetition_window,
                )

            local_token = self.sample_token(
                logits=logits,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
            )
            output_token[:, i] = local_token.squeeze(-1)

            # After the first channel, condition on sampled token ids instead of
            # the backbone embedding.
            if i == 0:
                local_inputs = None
        return output_token

    def apply_repetition_penalty(
        self,
        scores: torch.Tensor,
        history_tokens: torch.Tensor,
        penalty: float = 1.1,
        repetition_window: Optional[int] = None,
    ):
        """Apply a CTRL-style repetition penalty in place.

        Each token id present in ``history_tokens`` (optionally limited to the
        last ``repetition_window`` steps) has its logit multiplied by ``penalty``
        if negative, or divided by ``penalty`` if non-negative.

        Args:
            scores: Logits of shape ``[B, 1, V]``; position 0 is penalized.
            history_tokens: Previously generated ids, shape ``[B, T]``.

        Returns:
            The penalized logits as a ``[B, V]`` view of ``scores`` (mutated
            in place).
        """
        scores_ = scores[:, 0, :]
        B, V = scores_.shape
        ht = history_tokens

        if repetition_window is not None and repetition_window > 0:
            ht = ht[:, -repetition_window:]

        # Deduplicate per row: sort, then drop consecutive duplicates.
        ht_sorted, _ = torch.sort(ht, dim=1)
        uniq = torch.unique_consecutive(ht_sorted, dim=1)

        b_idx = torch.arange(B, device=uniq.device).unsqueeze(1).expand_as(uniq)
        b_flat = b_idx.reshape(-1)
        t_flat = uniq.reshape(-1)

        cur = scores_[b_flat, t_flat]
        new = torch.where(cur < 0, cur * penalty, cur / penalty)

        scores_[b_flat, t_flat] = new

        return scores_

    def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
        """Sample token ids from ``logits``.

        Greedy argmax when ``do_sample`` is False or ``temperature`` is 0;
        otherwise temperature scaling, then top-k and top-p filtering, then
        multinomial sampling. The trailing vocab dimension is replaced by the
        sampled id (e.g. ``[B, 1, V] -> [B, 1]``).
        """
        if not do_sample or temperature == 0:
            return torch.argmax(logits, dim=-1)
        logits = logits / temperature
        original_shape = logits.shape
        vocab_size = original_shape[-1]
        reshaped_logits = logits.reshape(-1, vocab_size)

        if top_k is not None:
            reshaped_logits = self.apply_top_k(reshaped_logits, top_k)

        if top_p is not None:
            reshaped_logits = self.apply_top_p(reshaped_logits, top_p)

        probs = F.softmax(reshaped_logits, dim=-1)
        next_tokens_flat = torch.multinomial(probs, num_samples=1)

        output_shape = original_shape[:-1]
        return next_tokens_flat.view(output_shape)

    def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
        """Mask all but the ``top_k`` highest logits per row with ``filter_value``.

        Raises:
            ValueError: If ``top_k`` is not a strictly positive int.
        """
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
        batch_size, vocab_size = logits.shape
        top_k = max(top_k, min_tokens_to_keep)
        top_k = min(top_k, vocab_size)
        # Threshold = k-th largest logit per row; everything below is masked.
        indices_to_remove = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        return logits.masked_fill(logits < indices_to_remove, filter_value)

    def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
        """Nucleus filtering: mask the low-probability tail whose cumulative mass is <= 1 - top_p.

        Uses an ascending sort so the tail to drop sits at the front of the
        cumulative sum; the last ``min_tokens_to_keep`` (highest) entries are
        always retained.

        Raises:
            ValueError: If ``top_p`` is outside [0, 1].
        """
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        # Scatter the removal mask back to original vocab order.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(1, sorted_indices, sorted_indices_to_remove)
        logits_processed = logits.masked_fill(indices_to_remove, filter_value)
        return logits_processed
426
+
427
+
428
@requires(backends=("torch",))
class MossTTSRealtimeStreamingSession:
    """Manage text-to-audio streaming for a single conversation.

    Incoming text (e.g. LLM deltas) is buffered, segmented on punctuation or
    length, tokenized, and fed token-by-token to the wrapped
    ``MossTTSRealtimeInference``, which returns audio tokens per step.
    """

    # Segmentation boundaries: CJK/ASCII sentence enders, clause separators,
    # closing paren/bracket, or a newline (each optionally trailed by spaces).
    _split_pattern = re.compile(
        r"[。!?!?\.\u2026]\s*"
        r"|[,,;;::\u2014\u2013\-]\s*"
        r"|\)\s*|\]\s*"
        r"|\n"
    )

    def __init__(
        self,
        inferencer: MossTTSRealtimeInference,
        processor,
        codec=None,
        codec_sample_rate: int = 24000,
        codec_encode_kwargs: Optional[dict] = None,
        prefill_text_len: int = 12,
        text_buffer_size: int = 32,
        min_text_chunk_chars: int = 8,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ):
        """Create a streaming session.

        Args:
            inferencer: step-wise TTS inference engine driving audio tokens.
            processor: builds prompts and exposes ``tokenizer`` / ``channels``.
            codec: optional audio codec used to encode waveform voice prompts.
            codec_sample_rate: sample rate the codec expects for input audio.
            codec_encode_kwargs: extra kwargs forwarded to ``codec.encode``.
            prefill_text_len: number of text tokens accumulated before the
                first prefill call starts audio generation.
            text_buffer_size: max buffered chars before forcing a whitespace cut.
            min_text_chunk_chars: minimum chars before a punctuation cut.
            temperature / top_p / top_k / do_sample / repetition_penalty /
                repetition_window: sampling settings forwarded to the inferencer.
        """
        self.inferencer = inferencer
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.codec = codec
        self.codec_sample_rate = codec_sample_rate
        self.codec_encode_kwargs = codec_encode_kwargs or {}

        self.prefill_text_len = prefill_text_len
        self.text_buffer_size = text_buffer_size
        self.min_text_chunk_chars = min_text_chunk_chars

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.repetition_penalty = repetition_penalty
        self.repetition_window = repetition_window

        # Conversation-level state (survives across turns).
        self._voice_prompt_tokens = None
        self._turn_input_ids = None
        self._turn_idx = 0

        # Per-turn streaming state (reset by reset_turn()).
        self._text_cache = ""
        self._pending_tokens: list[int] = []
        self._prefilled = False
        self._text_ended = False

    def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
        """Install already-encoded voice-prompt audio tokens as-is."""
        self._voice_prompt_tokens = audio_tokens

    def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
        """Set voice prompt from either audio tokens or waveform.

        If `audio` is a 2D array whose shape matches the codebook channels, it is
        treated as audio tokens. Otherwise a codec is required to encode waveform
        prompts into tokens.
        """
        # Fast path: 2D arrays that look like [T, C] / [C, T] token grids.
        if isinstance(audio, np.ndarray) and audio.ndim == 2:
            if self.processor.channels in audio.shape:
                self._voice_prompt_tokens = audio
                return
        if isinstance(audio, torch.Tensor) and audio.dim() == 2:
            if self.processor.channels in audio.shape:
                self._voice_prompt_tokens = audio.detach().cpu().numpy()
                return

        if self.codec is None:
            raise ValueError("codec is required to encode waveform prompts.")

        waveform = audio
        # A path/bytes input is loaded and downmixed to mono.
        if isinstance(audio, (str, bytes)):
            wav, sr = torchaudio.load(audio)
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            waveform = wav.squeeze(0)
            sample_rate = sr

        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        if not isinstance(waveform, torch.Tensor):
            raise ValueError("Unsupported audio type for voice prompt.")

        # Resample only when the caller-provided rate differs from the codec's.
        if sample_rate is not None and sample_rate != self.codec_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)

        waveform = waveform.to(self.inferencer.device)
        encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)
        # Tolerate both dict-style and raw-tensor codec outputs.
        if isinstance(encode_out, dict):
            if "codes_list" in encode_out:
                tokens = encode_out["codes_list"][0]
            elif "audio_codes" in encode_out:
                tokens = encode_out["audio_codes"][0]
            else:
                raise ValueError("codec.encode output missing audio codes.")
        else:
            tokens = encode_out
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.detach().cpu().numpy()
        self._voice_prompt_tokens = tokens

    def clear_voice_prompt(self):
        """Drop any previously installed voice prompt."""
        self._voice_prompt_tokens = None

    def reset_turn(
        self,
        user_text: Optional[str] = None,
        user_audio_tokens: Optional[np.ndarray] = None,
        input_ids: Optional[np.ndarray] = None,
        include_system_prompt: Optional[bool] = None,
        reset_cache: bool = False,
    ):
        """Start a new dialogue turn and reset per-turn streaming state.

        Either pass prebuilt ``input_ids`` or both ``user_text`` and
        ``user_audio_tokens`` so the prompt can be constructed here. The system
        prompt (with the voice prompt) is prepended on the first turn by
        default. ``reset_cache=True`` also clears the inferencer's KV cache.
        """
        if include_system_prompt is None:
            include_system_prompt = self._turn_idx == 0

        if input_ids is None:
            if user_text is None or user_audio_tokens is None:
                raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
            user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens)
            if include_system_prompt:
                system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
                input_ids = np.concatenate([system_prompt, user_prompt], axis=0)
            else:
                input_ids = user_prompt

        self._turn_input_ids = input_ids
        self._turn_idx += 1

        self._text_cache = ""
        self._pending_tokens = []
        self._prefilled = False
        self._text_ended = False

        self.inferencer.reset_generation_state(keep_cache=not reset_cache)

    def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
        """Feed pre-tokenized text ids; returns any newly produced audio frames."""
        self._pending_tokens.extend([int(t) for t in tokens])
        return self._drain_pending_tokens()

    def push_text(self, text_fragment: str) -> list[torch.Tensor]:
        """Feed raw text; segments are tokenized whole to avoid unstable BPE prefixes."""
        self._text_cache += text_fragment
        segments = self._extract_text_segments(force=False)
        for segment in segments:
            self._pending_tokens.extend(self._tokenize(segment))
        return self._drain_pending_tokens()

    def end_text(self) -> list[torch.Tensor]:
        """Signal that no more text will arrive; flushes the remaining cache."""
        self._text_ended = True
        if self._text_cache:
            self._pending_tokens.extend(self._tokenize(self._text_cache))
            self._text_cache = ""
        return self._drain_pending_tokens()

    def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
        """Let the inferencer keep generating after all text has been consumed."""
        if not self._prefilled:
            return []
        return self.inferencer.finish(
            max_steps=max_steps,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=self.repetition_penalty,
            repetition_window=self.repetition_window,
        )

    def _tokenize(self, text: str) -> list[int]:
        """Tokenize without special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _extract_text_segments(self, force: bool) -> list[str]:
        """Cut the text cache into speakable segments.

        With ``force=True`` the whole cache is returned as one segment.
        Otherwise, cut at the first punctuation boundary past
        ``min_text_chunk_chars``; if the cache exceeds ``text_buffer_size``
        with no boundary, fall back to the last whitespace.
        """
        segments = []
        if force:
            if self._text_cache:
                segments.append(self._text_cache)
                self._text_cache = ""
            return segments

        while self._text_cache:
            cut_idx = None
            if len(self._text_cache) >= self.min_text_chunk_chars:
                matches = list(self._split_pattern.finditer(self._text_cache))
                for match in matches:
                    if match.end() >= self.min_text_chunk_chars:
                        cut_idx = match.end()
                        break
            if cut_idx is None and len(self._text_cache) >= self.text_buffer_size:
                whitespace_idx = self._text_cache.rfind(" ")
                if whitespace_idx != -1:
                    cut_idx = whitespace_idx + 1
            if cut_idx is None:
                break
            segments.append(self._text_cache[:cut_idx])
            self._text_cache = self._text_cache[cut_idx:]
        return segments

    def _prefill_if_needed(self) -> list[torch.Tensor]:
        """Run the one-time prefill once enough text tokens are buffered.

        Returns the first batch of audio tokens (as a one-element list) when
        the prefill actually runs, else an empty list.
        """
        if self._prefilled:
            return []
        if not self._pending_tokens and not self._text_ended:
            return []
        # Wait for prefill_text_len tokens unless the text stream already ended.
        if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended:
            return []
        if self._turn_input_ids is None:
            raise ValueError("reset_turn must be called before streaming text.")

        if self._text_ended:
            prefill_len = len(self._pending_tokens)
        else:
            prefill_len = min(len(self._pending_tokens), self.prefill_text_len)

        if prefill_len == 0:
            return []

        prefix_tokens = [self._pending_tokens.pop(0) for _ in range(prefill_len)]
        audio_tokens = self.inferencer.prefill(
            input_ids=[self._turn_input_ids],
            text_prefix_ids=[prefix_tokens],
            # NOTE: repetition penalty is deliberately disabled for the prefill
            # call; per-step calls below re-enable it.
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=None,
            repetition_window=self.repetition_window,
        )
        self._prefilled = True
        return [audio_tokens]

    def _drain_pending_tokens(self) -> list[torch.Tensor]:
        """Consume buffered text tokens one step at a time, collecting audio frames."""
        outputs: list[torch.Tensor] = []
        outputs.extend(self._prefill_if_needed())
        if not self._prefilled:
            return outputs

        while self._pending_tokens and not self.inferencer.is_finished:
            token = self._pending_tokens.pop(0)
            outputs.append(
                self.inferencer.step(
                    token,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=self.top_k,
                    do_sample=self.do_sample,
                    repetition_penalty=self.repetition_penalty,
                    repetition_window=self.repetition_window,
                )
            )
        return outputs
682
+
683
+
684
@requires(backends=("torch",))
class AudioStreamDecoder:
    """Decode audio tokens into waveform chunks with optional crossfade.

    Tokens accumulate in a FIFO buffer; once ``chunk_frames`` frames are
    available they are decoded through the codec and adjacent chunks are
    blended over a short linear crossfade to soften chunk boundaries.
    """

    def __init__(
        self,
        codec,
        chunk_frames: int = 40,
        overlap_frames: int = 4,
        decode_kwargs: Optional[dict] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            codec: object exposing ``decode(tokens, chunk_duration=..., **kw)``.
            chunk_frames: token frames decoded per chunk.
            overlap_frames: frames' worth of samples to crossfade (0 disables).
            decode_kwargs: extra kwargs forwarded to ``codec.decode``.
            device: decode device; inferred from the codec when None.
        """
        self.codec = codec
        self.chunk_frames = chunk_frames
        self.overlap_frames = overlap_frames
        self.decode_kwargs = decode_kwargs or {}
        self.device = device

        self._buffer: list[torch.Tensor] = []
        self._buffer_len = 0
        self._prev_tail: Optional[torch.Tensor] = None

    def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
        """Append [T, C] audio tokens to the internal buffer."""
        if isinstance(audio_tokens, np.ndarray):
            audio_tokens = torch.from_numpy(audio_tokens)
        if audio_tokens.dim() != 2:
            raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}")
        self._buffer.append(audio_tokens)
        self._buffer_len += audio_tokens.shape[0]

    def audio_chunks(self) -> Iterable[torch.Tensor]:
        """Yield decoded (and crossfaded) waveform chunks while enough frames exist."""
        while self._buffer_len >= self.chunk_frames:
            chunk_tokens = self._consume_frames(self.chunk_frames)
            wav = self._decode(chunk_tokens, chunk_duration=0.32)
            yield self._apply_crossfade(wav)

    def flush(self) -> Optional[torch.Tensor]:
        """Decode whatever frames remain; returns None when the buffer is empty."""
        if self._buffer_len == 0:
            return None
        chunk_tokens = self._consume_frames(self._buffer_len)
        wav = self._decode(chunk_tokens)
        return self._apply_crossfade(wav, final_chunk=True)

    def _consume_frames(self, num_frames: int) -> torch.Tensor:
        """Pop exactly ``num_frames`` frames off the front of the buffer."""
        frames = []
        remaining = num_frames
        while remaining > 0 and self._buffer:
            head = self._buffer[0]
            if head.shape[0] <= remaining:
                frames.append(head)
                remaining -= head.shape[0]
                self._buffer.pop(0)
            else:
                frames.append(head[:remaining])
                self._buffer[0] = head[remaining:]
                remaining = 0
        self._buffer_len -= num_frames - remaining
        return torch.cat(frames, dim=0)

    def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
        """Run the codec on [T, C] tokens and return a 1D waveform tensor."""
        device = self.device
        if device is None:
            # Fall back to the codec's own device when one wasn't supplied.
            if hasattr(self.codec, "device"):
                device = self.codec.device
            else:
                try:
                    device = next(self.codec.parameters()).device
                except Exception:
                    device = None
        if device is not None:
            tokens = tokens.to(device)
        tokens_t = tokens.permute(1, 0)
        # allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming)
        decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
        if "chunk_duration" in decode_kwargs:
            override = decode_kwargs.pop("chunk_duration")
            if override is None:
                chunk_duration_arg = None
            else:
                try:
                    override_f = float(override)
                except Exception:
                    override_f = None
                # Non-positive / unparseable overrides disable internal streaming.
                chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f
        else:
            chunk_duration_arg = chunk_duration

        decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
        # Tolerate both dict-style and raw-tensor codec outputs.
        if isinstance(decoded, dict):
            wav = decoded["audio"][0]
        else:
            wav = decoded
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav)
        if wav.dim() > 1:
            wav = wav.squeeze(0)
        return wav

    def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
        """Linearly blend ``wav``'s head with the stored tail of the previous chunk.

        NOTE(review): the crossfaded region re-emits the previous chunk's tail
        (already yielded as part of that chunk), so ``overlap`` samples of
        content are effectively played twice at each boundary — confirm this
        is the intended smoothing behavior.
        """
        if self.overlap_frames <= 0:
            return wav
        if self._prev_tail is None:
            if not final_chunk:
                tail_len = self._overlap_samples(wav)
                # BUG FIX: the original stored `wav[-tail_len:]` unconditionally;
                # when tail_len computed to 0 that slice is `wav[-0:]`, i.e. the
                # ENTIRE chunk, so the whole waveform was kept as the "tail".
                # Only keep a tail when it is genuinely non-empty.
                if tail_len > 0:
                    self._prev_tail = wav[-tail_len:].clone()
            return wav

        overlap = self._overlap_samples(wav)
        if overlap == 0:
            return wav

        prev_tail = self._prev_tail
        # Clamp to whatever tail we actually have (e.g. a short previous chunk).
        if prev_tail.numel() < overlap:
            overlap = prev_tail.numel()
        if overlap == 0:
            return wav

        fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device)
        fade_in = 1.0 - fade_out
        cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in
        merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1)

        self._prev_tail = None if final_chunk else wav[-overlap:].clone()
        return merged

    def _overlap_samples(self, wav: torch.Tensor) -> int:
        """Overlap length in samples, scaled from frames by this chunk's size."""
        if self.chunk_frames <= 0:
            return 0
        return int(wav.numel() * (self.overlap_frames / self.chunk_frames))
811
+
812
+
813
class TextDeltaTokenizer:
    """
    Convert LLM streaming text (delta) into “incremental token IDs”.

    Notes:
        - The input is a delta that is progressively appended to the same string
          (consistent with the common delta output behavior in vLLM).
        - Each time, re-encode the *full text* with the tokenizer, then take only
          the newly added token IDs.
        - This guarantees that tokenization is consistent with the final complete
          text, avoiding boundary mismatches caused by tokenizing partial segments.
    """

    def __init__(self, tokenizer, *, hold_back: int = 3):
        self.tokenizer = tokenizer
        self.hold_back = max(0, int(hold_back))
        self._accumulated = ""
        self._ids: list[int] = []
        self._released = 0

    @property
    def text(self) -> str:
        """Full text accumulated so far."""
        return self._accumulated

    @property
    def token_ids(self) -> list[int]:
        """Copy of the token ids for the full accumulated text."""
        return list(self._ids)

    def _retokenize(self) -> None:
        # Re-encode the complete text so token boundaries match the final string.
        self._ids = self.tokenizer.encode(self._accumulated, add_special_tokens=False)

    def push_delta(self, delta: str) -> list[int]:
        """Append a text delta; return token ids that are now stable."""
        if not delta:
            return []
        self._accumulated += str(delta)
        self._retokenize()
        # Withhold the last `hold_back` tokens: the tail may still change as
        # further deltas arrive.
        stable = len(self._ids) - self.hold_back
        if stable < self._released:
            stable = self._released
        fresh = self._ids[self._released:stable]
        self._released = stable
        return fresh

    def flush(self) -> list[int]:
        """Release every remaining (held-back) token id."""
        self._retokenize()
        tail = self._ids[self._released:]
        self._released = len(self._ids)
        return tail
857
+
858
+
859
+ def _sanitize_audio_tokens(
860
+ tokens: torch.Tensor,
861
+ *,
862
+ codebook_size: int,
863
+ audio_eos_token: int,
864
+ ) -> tuple[torch.Tensor, bool]:
865
+ if tokens.dim() == 1:
866
+ tokens = tokens.unsqueeze(0)
867
+ if tokens.numel() == 0:
868
+ return tokens, False
869
+
870
+ eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False)
871
+ invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)
872
+
873
+ stop_idx = None
874
+ if eos_rows.numel() > 0:
875
+ stop_idx = int(eos_rows[0].item())
876
+ if invalid_rows.any():
877
+ invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item())
878
+ stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
879
+
880
+ if stop_idx is not None:
881
+ return tokens[:stop_idx], True
882
+ return tokens, False
883
+
884
+
885
+ def _maybe_codec_streaming(codec, *, batch_size: int):
886
+ if codec is None or not hasattr(codec, "streaming"):
887
+ return contextlib.nullcontext()
888
+ return codec.streaming(batch_size=batch_size)
889
+
890
+
891
@requires(backends=("torch",))
class MossTTSRealtimeTextStreamBridge:
    """
    Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.

    Usage overview:
        - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
        - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
          `push_text_delta()`.
        - Once the accumulated token count reaches `prefill_text_len`, the session will
          start generating audio tokens; the bridge will immediately decode them into WAV
          chunks and yield them.
    """

    def __init__(
        self,
        session: MossTTSRealtimeStreamingSession,
        decoder: AudioStreamDecoder,
        *,
        codebook_size: Optional[int] = None,
        audio_eos_token: Optional[int] = None,
        batch_size: int = 1,
    ):
        """
        Args:
            session: streaming TTS session producing audio-token frames.
            decoder: token-to-waveform decoder with chunking/crossfade.
            codebook_size: valid token range bound; read from the session's
                codec when omitted (falls back to 1024).
            audio_eos_token: end-of-audio token id; read from the session's
                inferencer when omitted (falls back to 1026).
            batch_size: passed to the codec's streaming context.
        """
        self.session = session
        self.decoder = decoder
        self.batch_size = int(batch_size)

        if codebook_size is None:
            codebook_size = int(getattr(getattr(session, "codec", None), "codebook_size", 1024))
        if audio_eos_token is None:
            audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))

        self.codebook_size = int(codebook_size)
        self.audio_eos_token = int(audio_eos_token)

    def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
        """
        Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.

        Internally, this directly calls `session.push_text()`, which segments the text
        based on punctuation/length and then tokenizes the *entire segment* at once,
        avoiding the prefix instability issues of incremental BPE tokenization.
        """
        audio_frames = self.session.push_text(delta)
        yield from self._decode_audio_frames(audio_frames)

    def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
        """Push pre-tokenized text ids; yields any newly decoded WAV chunks."""
        if not token_ids:
            return
        audio_frames = self.session.push_text_tokens(token_ids)
        yield from self._decode_audio_frames(audio_frames)

    def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """End the text stream, drain remaining generation, and flush the decoder.

        ``drain_step`` bounds how many generation steps each drain call runs,
        so decoded audio is yielded promptly rather than after full generation.
        """
        audio_frames = self.session.end_text()
        yield from self._decode_audio_frames(audio_frames)

        while True:
            more_frames = self.session.drain(max_steps=drain_step)
            if not more_frames:
                break
            yield from self._decode_audio_frames(more_frames)
            if self.session.inferencer.is_finished:
                break

        # Emit whatever partial chunk is still buffered in the decoder.
        final = self.decoder.flush()
        if final is not None and final.numel() > 0:
            yield final.detach().cpu()

    def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """Consume an entire delta iterator in one go, continuously yielding wav chunks."""
        with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
            for delta in deltas:
                yield from self.push_text_delta(delta)
            yield from self.finish(drain_step=drain_step)

    def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
        """Sanitize per-step audio-token frames, feed the decoder, yield wav chunks.

        Stops early (breaks out of the frame loop) once a frame carries the EOS
        token or an out-of-range value.
        """
        for frame in audio_frames:
            tokens = frame
            # Accept an optional leading batch dimension of size 1.
            if tokens.dim() == 3:
                tokens = tokens[0]
            if tokens.dim() != 2:
                raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
            if tokens.shape[0] != 1:
                raise ValueError(
                    f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
                )

            tokens, stop = _sanitize_audio_tokens(
                tokens,
                codebook_size=self.codebook_size,
                audio_eos_token=self.audio_eos_token,
            )
            if tokens.numel() == 0:
                if stop:
                    break
                continue

            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() == 0:
                    continue
                yield wav.detach().cpu()
            if stop:
                break
995
+
996
+
997
# Public API of this module (`from <module> import *`).
__all__ = [
    "AudioStreamDecoder",
    "MossTTSRealtimeInference",
    "MossTTSRealtimeStreamingSession",
    "MossTTSRealtimeTextStreamBridge",
    "TextDeltaTokenizer",
]
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a474913a182c0ad6297e47244a990659638b3093121fa92a4701fd45bd95c921
3
+ size 11422650
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|audio_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|audio_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|audio_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|text_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|audio_start|>",
224
+ "<|audio_end|>",
225
+ "<|audio_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "processor_class": "AsteroidProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff