Zhyw0 committed on
Commit
2590432
·
1 Parent(s): 692d70e

Add mossttsrealtime model

Browse files
mossttsrealtime/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from transformers.utils import _LazyModule
18
+ from transformers.utils.import_utils import define_import_structure
19
+
20
+
21
if TYPE_CHECKING:
    # Only static type checkers take this branch: they get the real symbols
    # from every submodule of the package.
    from .configuration_mossttsrealtime import *
    from .modeling_mossttsrealtime import *
    from .modeling_mossttsrealtime_local import *
    from .processing_mossttsrealtime import *
    from .streaming_mossttsrealtime import *
else:
    import sys

    # At runtime the package entry in ``sys.modules`` is swapped for a
    # ``_LazyModule`` built from the import structure derived from this file.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(
        __name__,
        _file,
        define_import_structure(_file),
        module_spec=__spec__,
    )
mossttsrealtime/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (840 Bytes). View file
 
mossttsrealtime/__pycache__/configuration_mossttsrealtime.cpython-312.pyc ADDED
Binary file (4.38 kB). View file
 
mossttsrealtime/__pycache__/modeling_mossttsrealtime.cpython-312.pyc ADDED
Binary file (9.15 kB). View file
 
mossttsrealtime/__pycache__/modeling_mossttsrealtime_local.cpython-312.pyc ADDED
Binary file (26.6 kB). View file
 
mossttsrealtime/__pycache__/processing_mossttsrealtime.cpython-312.pyc ADDED
Binary file (9.31 kB). View file
 
mossttsrealtime/__pycache__/streaming_mossttsrealtime.cpython-312.pyc ADDED
Binary file (52.2 kB). View file
 
mossttsrealtime/configuration_mossttsrealtime.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MossTTSRealtimeModel configuration."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.models.qwen3 import Qwen3Config
23
+
24
+
25
+ def _ensure_config(cfg: Any, cls: type[PretrainedConfig]) -> PretrainedConfig:
26
+ if isinstance(cfg, cls):
27
+ return cfg
28
+ if cfg is None:
29
+ return cls()
30
+ if isinstance(cfg, dict):
31
+ return cls(**cfg)
32
+ raise TypeError(f"Unsupported config type for {cls.__name__}: {type(cfg)}")
33
+
34
+
35
class MossTTSRealtimeLocalTransformerConfig(PretrainedConfig):
    """Configuration of the local transformer that decodes the RVQ codebooks.

    Args:
        head_dim: Per-head width used for the attention projections.
        use_cache: Stored as ``use_cache``.
        hidden_size: Width of the hidden states.
        rms_norm_eps: Epsilon handed to the RMSNorm layers.
        num_hidden_layers: Number of decoder layers.
        intermediate_size: Width of the MLP gate/up projections.
        num_attention_heads: Number of query heads.
        initializer_range: Stored as ``initializer_range``.
        attention_bias: Whether the q/k/v/o projections carry a bias.
        attention_dropout: Dropout probability applied to attention weights
            while training.
        max_position_embeddings: Initial maximum sequence length recorded by
            the rotary embedding.
        num_key_value_heads: Number of key/value heads.
        hidden_act: Key into ``ACT2FN`` selecting the MLP activation.
        rope_theta: Stored as ``rope_theta`` and used for the default
            ``rope_parameters``.
        rope_type: Stored as ``rope_type`` and used for the default
            ``rope_parameters``.
        pad_token_id: Stored as ``pad_token_id``.
        rope_parameters: Explicit RoPE parameter dict. When ``None`` it is
            built from ``rope_type`` and ``rope_theta`` with ``factor=1.0``.
        audio_pad_token: Padding index of the audio embedding tables.
        audio_vocab_size: Number of rows of each audio embedding table and
            output size of each LM head.
        rvq: Number of codebooks; the local transformer builds ``rvq - 1``
            embedding tables and ``rvq`` LM heads.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "moss_tts_realtime_local_transformer"

    def __init__(
        self,
        head_dim: int = 128,
        use_cache: bool = True,
        hidden_size: int = 2048,
        rms_norm_eps: float = 1e-6,
        num_hidden_layers: int = 4,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        initializer_range: float = 0.02,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        max_position_embeddings: int = 33,
        num_key_value_heads: int = 8,
        hidden_act: str = "silu",
        rope_theta: int = 1000000,
        rope_type: str = "linear",
        pad_token_id: int = 1024,
        rope_parameters: dict | None = None,
        audio_pad_token: int = 1024,
        audio_vocab_size: int = 1027,
        rvq: int = 16,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.rope_type = rope_type
        if rope_parameters is None:
            rope_parameters = {"rope_type": rope_type, "rope_theta": rope_theta, "factor": 1.0}
        self.rope_parameters = rope_parameters
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.pad_token_id = pad_token_id

        # These used to be hard-coded here, which silently discarded any value
        # supplied by the caller or by a serialized config. They are now real
        # parameters whose defaults reproduce the previous constants.
        self.audio_pad_token = audio_pad_token
        self.audio_vocab_size = audio_vocab_size
        self.rvq = rvq
83
+
84
+
85
class MossTTSRealtimeConfig(PretrainedConfig):
    """Top-level configuration combining a Qwen3 backbone and the local transformer.

    Args:
        language_config: ``Qwen3Config`` instance, a dict of its keyword
            arguments, or ``None`` for the defaults.
        local_config: ``MossTTSRealtimeLocalTransformerConfig`` instance, a
            dict of its keyword arguments, or ``None`` for the defaults.
        rvq: Number of audio embedding tables built by the model.
        audio_pad_token: Padding index of the audio embedding tables.
        audio_vocab_size: Number of rows of each audio embedding table.
        reference_audio_pad: Stored as-is (presumably a token id; confirm
            against the processor).
        text_pad: Stored as-is (presumably a token id; confirm against the
            processor).
        initializer_range: Standard deviation used when initialising weights.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "moss_tts_realtime"

    def __init__(
        self,
        language_config: Qwen3Config | dict | None = None,
        local_config: MossTTSRealtimeLocalTransformerConfig | dict | None = None,
        rvq: int = 16,
        audio_pad_token: int = 1024,
        audio_vocab_size: int = 1027,
        reference_audio_pad: int = 151654,
        text_pad: int = 151655,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Scalar settings.
        self.rvq = rvq
        self.initializer_range = initializer_range
        self.audio_pad_token = audio_pad_token
        self.audio_vocab_size = audio_vocab_size
        self.reference_audio_pad = reference_audio_pad
        self.text_pad = text_pad

        # Nested configs: accept instances, dicts or None.
        self.language_config = _ensure_config(language_config, Qwen3Config)
        self.local_config = _ensure_config(local_config, MossTTSRealtimeLocalTransformerConfig)

        # Both sub-models use the attention implementation chosen on this config.
        selected_attention = self._attn_implementation
        for sub_config in (self.language_config, self.local_config):
            sub_config._attn_implementation = selected_attention
113
+
114
+
115
+ __all__ = ["MossTTSRealtimeConfig", "MossTTSRealtimeLocalTransformerConfig"]
mossttsrealtime/modeling_mossttsrealtime.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MossTTSRealtime model."""
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from transformers import initialization as init
26
+ from transformers.cache_utils import Cache
27
+ from transformers.modeling_outputs import ModelOutput
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.models.qwen3 import Qwen3Model
30
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
31
+ from .configuration_mossttsrealtime import MossTTSRealtimeConfig
32
+ from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
33
+
34
+
35
class MossTTSRealtimePretrainedModel(PreTrainedModel):
    """Base class tying ``MossTTSRealtimeConfig`` to the shared model plumbing.

    It declares the capability flags read by the ``PreTrainedModel`` machinery
    and provides the weight initialisation for linear and embedding layers.
    """

    config_class = MossTTSRealtimeConfig
    config: MossTTSRealtimeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; linear biases and the
        embedding padding row are set to zero.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            # Indexing the weight yields a new tensor that no longer carries the
            # `_is_hf_initialized` marker, so the check has to be made on the
            # full parameter; otherwise a padding row that was already loaded
            # from a checkpoint would be zeroed again.
            already_initialized = getattr(module.weight, "_is_hf_initialized", False)
            if module.padding_idx is not None and not already_initialized:
                init.zeros_(module.weight[module.padding_idx])
62
+
63
+
64
@dataclass
class MossTTSRealtimeOutputWithPast(ModelOutput):
    """Output container returned by ``MossTTSRealtime.forward``.

    Fields without a ``local_`` prefix are filled from the backbone language
    model output; ``local_*`` fields are filled from the local transformer
    output and are only set when ``labels`` were supplied. ``loss`` is the
    local transformer loss. ``logits``, ``local_loss`` and ``backbone_loss``
    are not populated by the forward pass in this file.
    """

    # Backbone (language model) side.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Local transformer side.
    local_loss: Optional[torch.FloatTensor] = None
    local_logits: Optional[torch.FloatTensor] = None
    local_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    local_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    local_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Reserved slot; never assigned in this file.
    backbone_loss: Optional[torch.FloatTensor] = None
78
+
79
+
80
class MossTTSRealtime(MossTTSRealtimePretrainedModel):
    """Qwen3 backbone over summed text/audio embeddings plus a local RVQ decoder.

    ``embed_tokens[0]`` embeds the text channel and ``embed_tokens[1:]`` embed
    the ``config.rvq`` audio channels; the per-channel embeddings are summed
    before being fed to the backbone.
    """

    def __init__(self, config: MossTTSRealtimeConfig):
        super().__init__(config)
        self.config = config
        language_config = config.language_config
        self.embed_tokens = nn.ModuleList([])
        # Channel 0: text vocabulary.
        self.embed_tokens.append(
            nn.Embedding(
                language_config.vocab_size,
                language_config.hidden_size,
                language_config.pad_token_id,
            )
        )
        # Channels 1..rvq: one table per audio codebook.
        self.audio_vocab_size = self.config.audio_vocab_size
        for _ in range(self.config.rvq):
            self.embed_tokens.append(
                nn.Embedding(self.audio_vocab_size, language_config.hidden_size, self.config.audio_pad_token)
            )
        self.language_model = Qwen3Model._from_config(config.language_config)
        self.local_transformer = MossTTSRealtimeLocalTransformerForCausalLM._from_config(config.local_config)
        self.post_init()

    def get_input_embeddings(self, input_ids):
        """Embed every channel of ``input_ids`` and sum the results.

        ``input_ids`` is indexed as ``input_ids[..., i]`` for each embedding
        table ``i``, so its last dimension must have one entry per table. It is
        moved to the device of the embedding weights when necessary.

        NOTE(review): this takes an argument, unlike the usual
        ``PreTrainedModel.get_input_embeddings()`` accessor; confirm no
        framework utility calls it without arguments.
        """
        embedding_device = self.embed_tokens[0].weight.device
        if input_ids.device != embedding_device:
            input_ids = input_ids.to(embedding_device)
        inputs_embeds = self.embed_tokens[0](input_ids[..., 0])
        for channel in range(1, len(self.embed_tokens)):
            inputs_embeds = inputs_embeds + self.embed_tokens[channel](input_ids[..., channel])
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[list] = None,
        **kwargs,
    ):
        """Run the backbone and, when ``labels`` are given, the local transformer loss.

        Args:
            input_ids: Multi-channel token ids; ignored when ``inputs_embeds``
                is provided.
            attention_mask, position_ids, past_key_values, use_cache,
            cache_position: Forwarded to the backbone language model.
            inputs_embeds: Pre-computed backbone input embeddings.
            labels: Training targets; ``labels[:, :, 1:]`` are the audio
                channels and ``-100`` marks ignored entries.
            output_attentions, output_hidden_states, return_dict: Fall back to
                the config values when ``None``.
            hidden_out_layers: Accepted for interface compatibility; unused.
            **kwargs: Forwarded to both the backbone and the local transformer.

        Returns:
            ``MossTTSRealtimeOutputWithPast``, or its tuple form when
            ``return_dict`` is false.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            if input_ids is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")
            inputs_embeds = self.get_input_embeddings(input_ids)

        outputs = self.language_model(
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        loss = None
        local_outputs = None
        if labels is not None:
            audio_labels = labels[:, :, 1:]
            # Keep only positions where at least one audio label is not ignored.
            train_mask = ~(audio_labels == -100).all(dim=-1)
            # Boolean indexing copies, so the in-place edit below leaves `labels` untouched.
            local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
            local_input_ids[local_input_ids == -100] = self.config.audio_pad_token
            # Slot 0 is a placeholder that the local transformer overwrites with
            # the backbone hidden state.
            local_input_ids = F.pad(local_input_ids, (1, 0), value=0)

            train_idx = train_mask.nonzero(as_tuple=True)
            # The hidden state of the previous position conditions the current
            # frame; position 0 falls back to itself.
            hidden_positions = torch.clamp(train_idx[1] - 1, min=0)
            local_hidden_states = outputs.last_hidden_state[train_idx[0], hidden_positions, :].reshape(
                -1, 1, self.config.local_config.hidden_size
            )
            local_labels = audio_labels[train_mask]

            local_outputs = self.local_transformer(
                input_ids=local_input_ids,
                backbone_last_hidden_state=local_hidden_states,
                use_cache=use_cache,
                return_dict=True,
                labels=local_labels,
                **kwargs,
            )
            loss = local_outputs.loss

        output = MossTTSRealtimeOutputWithPast(
            loss=loss,
            logits=None,
            past_key_values=outputs.past_key_values,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            local_logits=local_outputs.logits if local_outputs is not None else None,
            local_past_key_values=local_outputs.past_key_values if local_outputs is not None else None,
            local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
            local_attentions=local_outputs.attentions if local_outputs is not None else None,
        )
        if not return_dict:
            return output.to_tuple()
        return output
188
+
189
+
190
+ __all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel"]
mossttsrealtime/modeling_mossttsrealtime_local.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Local transformer used by MossTTSRealtime for RVQ codebook decoding."""
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Optional, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from transformers.activations import ACT2FN
24
+ from transformers.cache_utils import Cache, StaticCache
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
+ from transformers.modeling_layers import GradientCheckpointingLayer
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
30
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
+ from transformers.masking_utils import create_causal_mask
32
+ from transformers.processing_utils import Unpack
33
+ from transformers.loss.loss_utils import ForCausalLMLoss
34
+ from transformers.utils import TransformersKwargs, logging
35
+ from .configuration_mossttsrealtime import MossTTSRealtimeLocalTransformerConfig
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
class MossTTSRealtimeLocalTransformerRMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise over the last dimension in float32, then rescale by ``weight``."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the input dtype before applying the learned scale.
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
55
+
56
+
57
class MossTTSRealtimeLocalTransformerMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
71
+
72
+
73
def rotate_half(x):
    """Return ``x`` with the halves of its last dimension swapped and the second half negated."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
77
+
78
+
79
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    ``cos`` and ``sin`` get an extra axis at ``unsqueeze_dim`` so they
    broadcast against ``q`` and ``k``. Returns the rotated ``(q, k)`` pair.
    """
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)

    def _rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
85
+
86
+
87
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Takes a ``(batch, num_key_value_heads, seq_len, head_dim)`` tensor and
    returns ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)`` with
    the copies of each head adjacent. The input is returned unchanged when
    ``n_rep == 1``.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
    widened = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return widened.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
93
+
94
+
95
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Keys and values are first expanded by ``module.num_key_value_groups``.
    An additive ``attention_mask`` (cropped to the key length) is applied
    before a float32 softmax; dropout is active only when ``module.training``.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has the head
        and sequence axes swapped back (``transpose(1, 2)``) and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        key_length = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :key_length]

    # Softmax in float32, then return to the query dtype.
    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    context = torch.matmul(probabilities, expanded_values)
    context = context.transpose(1, 2).contiguous()
    return context, probabilities
116
+
117
+
118
class MossTTSRealtimeLocalTransformerAttention(nn.Module):
    """Causal grouped-query self-attention with RMS-normalised queries and keys."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)
        # Normalisation is applied per head, over `head_dim`.
        self.q_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MossTTSRealtimeLocalTransformerRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project, rotate, optionally cache, attend and re-project.

        Args:
            hidden_states: Input whose last dimension is the model width.
            position_embeddings: ``(cos, sin)`` pair for the rotary embedding.
            attention_mask: Mask handed to the attention implementation.
            past_key_values: Cache updated with this layer's keys/values.
            cache_position: Positions passed to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        queries = self.q_norm(self.q_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        keys = self.k_norm(self.k_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        values = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            keys, values = past_key_values.update(keys, values, self.layer_idx, cache_kwargs)

        # Use the local eager kernel unless another backend was selected on the config.
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        else:
            attention_interface = eager_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
178
+
179
+
180
class MossTTSRealtimeLocalTransformerDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MossTTSRealtimeLocalTransformerAttention(config=config, layer_idx=layer_idx)
        self.mlp = MossTTSRealtimeLocalTransformerMLP(config)
        self.input_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MossTTSRealtimeLocalTransformerRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.attention_type = "full_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Return the hidden states after the attention and MLP sub-blocks."""
        # Attention sub-block; the attention weights are discarded here.
        attention_input = self.input_layernorm(hidden_states)
        attention_output, _ = self.self_attn(
            hidden_states=attention_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attention = hidden_states + attention_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + mlp_output
219
+
220
+
221
class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
    """Base class binding ``MossTTSRealtimeLocalTransformerConfig`` to the local transformer models.

    Only class-level declarations live here; no methods are overridden.
    """

    config_class = MossTTSRealtimeLocalTransformerConfig
    config: MossTTSRealtimeLocalTransformerConfig

    base_model_prefix = "local_transformer"
    supports_gradient_checkpointing = True
    # Decoder layers must not be split across devices.
    _no_split_modules = ["MossTTSRealtimeLocalTransformerDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends and compilation capabilities.
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    # Modules whose outputs may be recorded as hidden states / attentions.
    _can_record_outputs = {
        "hidden_states": MossTTSRealtimeLocalTransformerDecoderLayer,
        "attentions": MossTTSRealtimeLocalTransformerAttention,
    }
240
+
241
+
242
class MossTTSRealtimeLocalTransformerRotaryEmbedding(nn.Module):
    """Produces the ``(cos, sin)`` tables consumed by ``apply_rotary_pos_emb``."""

    inv_freq: torch.Tensor

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig, device=None):
        super().__init__()
        self.config = config
        self.rope_type = getattr(config, "rope_type", "linear")
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        # The initialiser is picked by RoPE type and yields the inverse
        # frequencies together with an attention scaling factor.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` in the dtype of ``x``.

        ``x`` is only consulted for its device and dtype.
        """
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is switched off for the frequency product; "mps" (and any
        # non-string device type) is mapped to "cpu" for the autocast context.
        raw_device_type = x.device.type
        if isinstance(raw_device_type, str) and raw_device_type != "mps":
            device_type = raw_device_type
        else:
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
268
+
269
+
270
class MossTTSRealtimeLocalTransformer(MossTTSRealtimeLocalTransformerPreTrainedModel):
    """Stack of decoder layers that models the RVQ codebooks of one audio frame.

    Sequence position 0 carries the backbone hidden state; positions
    ``1..rvq-1`` carry embeddings of the previously decoded codebook tokens,
    one embedding table per codebook (``rvq - 1`` tables in total).
    """

    def __init__(self, config: MossTTSRealtimeLocalTransformerConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.ModuleList(
            [nn.Embedding(config.audio_vocab_size, config.hidden_size, config.audio_pad_token) for _ in range(config.rvq - 1)]
        )
        self.layers = nn.ModuleList(
            [MossTTSRealtimeLocalTransformerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = MossTTSRealtimeLocalTransformerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = MossTTSRealtimeLocalTransformerRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = None
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        codebook_idx: Optional[int] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack over one frame's codebook sequence.

        Args:
            input_ids: Codebook token ids; mutually exclusive with
                ``inputs_embeds``.
            backbone_last_hidden_state: When given, its ``[:, 0, :]`` slice
                replaces the embedding at sequence position 0.
            attention_mask: Optional mask forwarded to ``create_causal_mask``.
            position_ids: Discarded outside of compilation; positions are
                always rebuilt from ``cache_position`` for the rotary embedding.
            past_key_values: Existing cache; a ``StaticCache`` of length
                ``config.rvq`` is created when ``use_cache`` is set and none
                is supplied.
            inputs_embeds: Pre-computed embeddings; mutually exclusive with
                ``input_ids``.
            use_cache: Whether to create/return the key-value cache.
            cache_position: Absolute positions of the inputs; derived from the
                cache length when omitted.
            codebook_idx: 1-based index of the embedding table to use for a
                single-token step. When omitted, the table for each position
                is derived from ``cache_position``.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` holding the normalised hidden states
            and, when ``use_cache`` is set, the cache.

        Raises:
            ValueError: If not exactly one of ``input_ids``/``inputs_embeds``
                is given, if ``codebook_idx`` is out of range, or if
                ``input_ids`` and ``cache_position`` disagree in length.
        """
        if position_ids is not None and not torch.compiler.is_compiling():
            position_ids = None

        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        # Remember the origin of the embeddings: a tensor owned by the caller
        # must not be modified in place further down.
        embeds_from_caller = inputs_embeds is not None
        device = inputs_embeds.device if embeds_from_caller else input_ids.device

        if use_cache and past_key_values is None:
            past_key_values = StaticCache(config=self.config, max_cache_len=self.config.rvq, device=device)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            inputs_seq_length = inputs_embeds.shape[1] if embeds_from_caller else input_ids.shape[1]
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)

        if inputs_embeds is None:
            num_tables = len(self.embed_tokens)
            if codebook_idx is not None:
                # Single-step decoding with an explicit (1-based) codebook table.
                if not 1 <= codebook_idx <= num_tables:
                    raise ValueError(f"`codebook_idx` must be in [1, {num_tables}], got {codebook_idx}.")
                if input_ids.ndim == 1:
                    input_ids = input_ids.unsqueeze(1)
                inputs_embeds = self.embed_tokens[codebook_idx - 1](input_ids[:, 0]).unsqueeze(1)  # [B,1,H]
            else:
                if input_ids.shape[1] != cache_position.shape[0]:
                    raise ValueError(
                        "`input_ids` and `cache_position` must align in sequence length: "
                        f"got {input_ids.shape[1]} and {cache_position.shape[0]}."
                    )
                # Position p uses table p - 1; position 0 is clamped to table 0
                # and is normally replaced by the backbone state below.
                codebook_idxs = torch.clamp(cache_position - 1, min=0, max=num_tables - 1)
                inputs_embeds = torch.stack(
                    [
                        self.embed_tokens[table_idx](input_ids[:, seq_idx])
                        for seq_idx, table_idx in enumerate(codebook_idxs.tolist())
                    ],
                    dim=1,
                )

        if backbone_last_hidden_state is not None:
            if embeds_from_caller:
                # Avoid writing into the caller's tensor.
                inputs_embeds = inputs_embeds.clone()
            inputs_embeds[:, 0, :] = backbone_last_hidden_state[:, 0, :]
        elif not torch.compiler.is_compiling() and bool(cache_position[0] == 0):
            # The tensor-to-bool conversion is evaluated only here, so it is
            # skipped when a backbone state is supplied or while compiling.
            logger.warning(
                "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
            )

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
378
+
379
+
380
class MossTTSRealtimeLocalTransformerForCausalLM(MossTTSRealtimeLocalTransformerPreTrainedModel, GenerationMixin):
    """Local transformer with one audio-vocabulary LM head per RVQ codebook.

    Wraps `MossTTSRealtimeLocalTransformer` and projects its hidden states with `config.rvq`
    bias-free linear heads of width `config.audio_vocab_size`. The head applied to a hidden state
    is chosen from that state's codebook position.
    """

    _tied_weights_keys = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config):
        super().__init__(config)
        self.model = MossTTSRealtimeLocalTransformer(config)
        self.audio_vocab_size = self.config.audio_vocab_size

        # One independent projection per codebook.
        self.local_lm_heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, config.audio_vocab_size, bias=False) for _ in range(config.rvq)]
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        codebook_idx: Optional[int] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        """Run the local transformer and project hidden states to audio-token logits.

        Args:
            input_ids, backbone_last_hidden_state, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, cache_position, codebook_idx, **kwargs:
                Forwarded unchanged to the wrapped `MossTTSRealtimeLocalTransformer`.
            labels: Optional targets. They are passed to `ForCausalLMLoss` as `shift_labels`, i.e. they
                are used as given and must already line up with the returned logits.
            codebook_idx: Index of the LM head to use when `cache_position` is provided.
            logits_to_keep: `0` keeps every position, a positive int keeps the last `logits_to_keep`
                positions, and a tensor is used directly as an index along the sequence dimension.

        Returns:
            `CausalLMOutputWithPast` whose `logits` are `[batch, 1, audio_vocab_size]` when
            `cache_position` is provided, and `[batch, kept_positions, audio_vocab_size]` otherwise.

        Raises:
            ValueError: If `cache_position` is provided without `codebook_idx`, or if a kept sequence
                position has no corresponding LM head.
        """
        outputs = self.model(
            input_ids=input_ids,
            backbone_last_hidden_state=backbone_last_hidden_state,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            codebook_idx=codebook_idx,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        seq_len = hidden_states.shape[1]

        # `slice(-0, None)` is `slice(0, None)`, so an int of 0 naturally keeps every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hs = hidden_states[:, slice_indices, :]

        if cache_position is not None:
            # Cached decoding: the caller names the head explicitly and only one state is projected.
            if codebook_idx is None:
                raise ValueError("`codebook_idx` must be provided when `cache_position` is provided.")
            logits = self.local_lm_heads[codebook_idx](hs[:, 0, :]).unsqueeze(1)
        else:
            # Head `p` belongs to sequence position `p`. Recover the absolute position of every kept
            # state so that slicing with `logits_to_keep` does not shift states onto the wrong head.
            if isinstance(slice_indices, slice):
                kept_positions = list(range(seq_len)[slice_indices])
            else:
                kept_positions = torch.arange(seq_len, device=slice_indices.device)[slice_indices].tolist()

            num_heads = len(self.local_lm_heads)
            if kept_positions and max(kept_positions) >= num_heads:
                raise ValueError(f"Cannot project {seq_len} codebooks with only {num_heads} LM heads.")

            logits = torch.stack(
                [self.local_lm_heads[position](hs[:, offset, :]) for offset, position in enumerate(kept_positions)],
                dim=1,
            )

        logits = logits.contiguous()
        loss = None
        if labels is not None:
            loss = ForCausalLMLoss(logits, None, self.audio_vocab_size, shift_labels=labels.contiguous())

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
460
+
461
+ __all__ = [
462
+ "MossTTSRealtimeLocalTransformer",
463
+ "MossTTSRealtimeLocalTransformerAttention",
464
+ "MossTTSRealtimeLocalTransformerConfig",
465
+ "MossTTSRealtimeLocalTransformerDecoderLayer",
466
+ "MossTTSRealtimeLocalTransformerForCausalLM",
467
+ "MossTTSRealtimeLocalTransformerPreTrainedModel",
468
+ "MossTTSRealtimeLocalTransformerRMSNorm",
469
+ "MossTTSRealtimeLocalTransformerRotaryEmbedding",
470
+ ]
471
+
mossttsrealtime/processing_mossttsrealtime.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Processing utilities for MossTTSRealtime."""
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Iterable, Optional
19
+
20
+ import numpy as np
21
+
22
+ from transformers.processing_utils import ProcessorMixin
23
+
24
+
25
class MossTTSRealtimeProcessor(ProcessorMixin):
    """Builds MossTTSRealtime prompt inputs with text and audio codebooks.

    This processor focuses on preparing the mixed text/audio token layout expected by MossTTSRealtime.
    It does not perform audio encoding/decoding by itself.

    Every prompt is returned as an ``int64`` array of shape ``[T, channels + 1]``: column 0 holds the
    text token ids and columns ``1..channels`` hold audio codebook tokens, filled with
    ``audio_channel_pad`` wherever no audio token is written.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer,
        audio_pad_token: str = "<|audio_pad|>",
        text_pad_token: str = "<|text_pad|>",
        tts_system_prompt: Optional[str] = None,
        channels: int = 16,
        audio_channel_pad: int = 1024,
        audio_bos_token: int = 1025,
        audio_eos_token: int = 1026,
        delay_tokens_len: int = 12,
    ):
        """Store the layout constants and resolve the pad tokens to tokenizer ids.

        Args:
            tokenizer: Tokenizer used for all text; handed to `ProcessorMixin`.
            audio_pad_token: Text token marking the positions that carry prompt audio.
            text_pad_token: Text token used to pad the text column alongside audio.
            tts_system_prompt: System prompt text; a built-in default is used when `None`.
            channels: Number of audio codebook columns.
            audio_channel_pad: Fill value for audio columns that carry no audio token.
            audio_bos_token: Value written to the first audio column just before the audio span.
            audio_eos_token: Value written to the first audio column just after the audio span.
            delay_tokens_len: Text-token offset at which audio starts for long enough texts.

        Raises:
            ValueError: If a pad token cannot be mapped to exactly one token id.
        """
        super().__init__(tokenizer=tokenizer)
        self.audio_pad_token = audio_pad_token
        self.text_pad_token = text_pad_token
        self.channels = channels
        self.audio_channel_pad = audio_channel_pad
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.delay_tokens_len = delay_tokens_len

        self.audio_pad_token_id = self._convert_token_to_id(audio_pad_token)
        self.text_pad_token_id = self._convert_token_to_id(text_pad_token)

        if tts_system_prompt is None:
            tts_system_prompt = (
                "<|im_start|>system\n"
                "You are a highly expressive text-to-speech (TTS) engine developed by Mosi Intelligence. \n"
                "You possess natural language understanding, emotional modeling, and multi-style speech generation "
                "capabilities, allowing you to generate the corresponding speech based on the text given in the assistant."
                "<|im_end|>\n"
            )
        self.tts_system_prompt = tts_system_prompt

    def _convert_token_to_id(self, token: str) -> int:
        """Return the single token id for `token`.

        The vocabulary lookup is tried first; if it is unavailable or yields `None`/the unknown-token
        id, the token is encoded without special tokens and must produce exactly one id.

        Raises:
            ValueError: If encoding yields no id or more than one id.
        """
        if hasattr(self.tokenizer, "convert_tokens_to_ids"):
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            if token_id is not None and token_id != self.tokenizer.unk_token_id:
                return int(token_id)

        token_ids = self.tokenizer.encode(token, add_special_tokens=False)
        if not token_ids:
            raise ValueError(f"Token '{token}' could not be converted to an id.")
        if len(token_ids) != 1:
            raise ValueError(f"Token '{token}' maps to multiple ids: {token_ids}")
        return int(token_ids[0])

    def _build_text_rows(self, token_ids) -> np.ndarray:
        """Return a `[len(token_ids), channels + 1]` int64 array with text in column 0 and audio pads elsewhere."""
        rows = np.full(
            shape=(len(token_ids), self.channels + 1), fill_value=self.audio_channel_pad, dtype=np.int64
        )
        rows[:, 0] = token_ids
        return rows

    def make_voice_clone_prompt(self, prompt_audio_tokens_len: int) -> str:
        """Return the context-section text holding `prompt_audio_tokens_len` audio pad tokens.

        The section is intentionally left without a closing `<|im_end|>`.
        """
        padded_audio_prompt = f"{self.audio_pad_token * prompt_audio_tokens_len}"
        return (
            "<|im_start|>context\n"
            "The assistant section should be synthesized using the following voice timbre:"
            f"{padded_audio_prompt}"
        )

    def _normalize_audio_tokens(self, audio_tokens: np.ndarray | Iterable) -> np.ndarray:
        """Return `audio_tokens` as a 2D array laid out `[T, channels]`.

        Accepts `[channels, T]` or `[T, channels]`. When neither axis equals `channels`, an axis
        longer than `channels` is treated as the channel axis and truncated (axis 0 is tried first).

        Raises:
            ValueError: If the input is not 2D or cannot be brought to `channels` columns.
        """
        tokens = np.array(audio_tokens)
        if tokens.ndim != 2:
            raise ValueError(f"Expected 2D audio tokens, got shape {tokens.shape}")

        rows, cols = tokens.shape
        if rows == self.channels:
            tokens = tokens.T
        elif cols != self.channels:
            # Neither axis matches exactly: fall back to truncating an over-long channel axis.
            if rows > self.channels:
                tokens = tokens[: self.channels, :].T
            elif cols > self.channels:
                tokens = tokens[:, : self.channels]

        if tokens.shape[1] != self.channels:
            raise ValueError(f"Expected {self.channels} channels, got shape {tokens.shape}")
        return tokens

    def make_ensemble(self, prompt_audio_tokens: Optional[np.ndarray] = None) -> np.ndarray:
        """Build the system prompt rows, optionally embedding voice-clone prompt audio.

        Args:
            prompt_audio_tokens: Optional 2D audio tokens (see `_normalize_audio_tokens`). When given,
                a context section with one audio pad token per audio frame is appended to the system
                prompt and the audio tokens are written into the audio columns of those rows.

        Returns:
            An int64 array of shape `[T, channels + 1]`.

        Raises:
            ValueError: If prompt audio is given but no audio pad token id appears in the tokenized prompt.
        """
        if prompt_audio_tokens is not None:
            prompt_audio_tokens = self._normalize_audio_tokens(prompt_audio_tokens)
            voice_clone = self.make_voice_clone_prompt(prompt_audio_tokens.shape[0])
            system_prompt_text = f"{self.tts_system_prompt}{voice_clone}"
        else:
            system_prompt_text = f"{self.tts_system_prompt}"

        system_prompt_tokens = self.tokenizer(system_prompt_text)["input_ids"]
        system_prompt_tokens_full = self._build_text_rows(system_prompt_tokens)

        if prompt_audio_tokens is not None:
            indices = np.where(np.array(system_prompt_tokens) == self.audio_pad_token_id)[0]
            if indices.size == 0:
                raise ValueError(f"No {self.audio_pad_token} tokens found in the system prompt.")
            # The audio span runs from the first to the last audio pad position, inclusive.
            prompt_audio_start_pos, prompt_audio_end_pos = indices[0], indices[-1]
            system_prompt_tokens_full[prompt_audio_start_pos : prompt_audio_end_pos + 1, 1:] = prompt_audio_tokens

        return system_prompt_tokens_full

    def make_user_prompt(self, text: str, audio_tokens: np.ndarray) -> np.ndarray:
        """Build the rows for a user turn with its audio, followed by the assistant header.

        The text column holds the user header, `text` and enough text pad tokens to cover the audio.
        If `text` has at least `delay_tokens_len` tokens, audio starts `delay_tokens_len` rows after
        the header; otherwise audio occupies the rows just before the final one. A BOS value is
        written to the first audio column on the row before the audio and an EOS value on the row
        after it.

        Args:
            text: User text.
            audio_tokens: 2D audio tokens (see `_normalize_audio_tokens`).

        Returns:
            An int64 array of shape `[T, channels + 1]`.
        """
        prefill_temp = "<|im_end|>\n<|im_start|>user\n"
        text_len = len(self.tokenizer(text)["input_ids"])
        text_start_pos = len(self.tokenizer.encode(prefill_temp))
        token = self._normalize_audio_tokens(audio_tokens)
        audio_len = token.shape[0]

        starts_after_delay = text_len >= self.delay_tokens_len
        if starts_after_delay:
            padded_text_len = audio_len + self.delay_tokens_len - text_len + 1
        else:
            padded_text_len = audio_len + 1

        # Use the configured pad token so a custom `text_pad_token` is honoured here as well.
        cur_input_id_ch1 = prefill_temp + text + self.text_pad_token * padded_text_len
        cur_input_id = self._build_text_rows(self.tokenizer(cur_input_id_ch1)["input_ids"])

        if starts_after_delay:
            audio_start = text_start_pos + self.delay_tokens_len
            cur_input_id[audio_start : audio_start + audio_len, 1:] = token
            cur_input_id[audio_start - 1, 1] = self.audio_bos_token
            cur_input_id[audio_start + audio_len, 1] = self.audio_eos_token
        else:
            # Anchor the audio to the end of the turn: ..., BOS, audio frames, EOS.
            cur_input_id[-(audio_len + 1) : -1, 1:] = token
            cur_input_id[-(audio_len + 2), 1] = self.audio_bos_token
            cur_input_id[-1, 1] = self.audio_eos_token

        begin_of_response = self.tokenizer.encode("<|im_end|>\n<|im_start|>assistant\n")
        begin_of_response_full = self._build_text_rows(begin_of_response)

        return np.concatenate([cur_input_id, begin_of_response_full], axis=0)
176
+
177
+
178
+ __all__ = ["MossTTSRealtimeProcessor"]
mossttsrealtime/streaming_mossttsrealtime.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Streaming inference utilities for MossTTSRealtime."""
15
+
16
+ from __future__ import annotations
17
+
18
+ import contextlib
19
+ import re
20
+ from typing import Iterable, Iterator, List, Optional, Sequence
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from transformers.cache_utils import DynamicCache, StaticCache
27
+ from transformers.utils import is_torchaudio_available, requires_backends
28
+ from transformers.utils.import_utils import requires
29
+
30
+ if is_torchaudio_available():
31
+ import torchaudio
32
+
33
+
34
+ @requires(backends=("torch",))
35
+ class MossTTSRealtimeInference:
36
+ """Step-wise inference wrapper for MossTTSRealtime.
37
+ This class mirrors the non-streaming inference logic but exposes a
38
+ prefill/step/finish API for streaming usage.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ model,
44
+ tokenizer,
45
+ max_length: int = 1000,
46
+ channels: int = 16,
47
+ audio_channel_pad: int = 1024,
48
+ audio_bos_token: int = 1025,
49
+ audio_eos_token: int = 1026,
50
+ text_pad_id: int = 151655,
51
+ aud_pad_id: int = 151654,
52
+ ):
53
+ self.model = model
54
+ self.tokenizer = tokenizer
55
+ self.max_length = max_length
56
+ self.channels = channels
57
+ self.audio_channel_pad = audio_channel_pad
58
+ self.audio_bos_token = audio_bos_token
59
+ self.audio_eos_token = audio_eos_token
60
+ self.text_pad_id = text_pad_id
61
+ self.aud_pad_id = aud_pad_id
62
+
63
+ self.past_key_values = None
64
+ self.attention_mask = None
65
+ self._generated_tokens: List[torch.Tensor] = []
66
+ self._is_stopping = None
67
+ self._last_audio_tokens = None
68
+ self._step_idx = 0
69
+ attn_impl = ""
70
+ for cfg in (
71
+ getattr(getattr(self.model, "local_transformer", None), "config", None),
72
+ getattr(getattr(self.model, "config", None), "local_config", None),
73
+ getattr(self.model, "config", None),
74
+ ):
75
+ if cfg is None:
76
+ continue
77
+ for name in ("_attn_implementation", "attn_implementation"):
78
+ candidate = getattr(cfg, name, None)
79
+ if isinstance(candidate, str) and candidate.strip():
80
+ attn_impl = candidate.strip().lower()
81
+ break
82
+ if attn_impl:
83
+ break
84
+ self._use_dynamic_local_cache = attn_impl == "flash_attention_2"
85
+ self._should_compile_local_transformer = not self._use_dynamic_local_cache
86
+ self._compiled_local_transformer = None
87
+
88
+ @property
89
+ def device(self):
90
+ return next(self.model.parameters()).device
91
+
92
+ @property
93
+ def is_finished(self) -> bool:
94
+ return self._is_stopping is not None and bool(self._is_stopping.all())
95
+
96
+ def _build_local_past_key_values(self):
97
+ if self._use_dynamic_local_cache:
98
+ return DynamicCache()
99
+ return StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
100
+
101
+ def _get_local_transformer_runner(self):
102
+ if not self._should_compile_local_transformer:
103
+ return self._generate_local_transformer_impl
104
+ if self._compiled_local_transformer is None:
105
+ self._compiled_local_transformer = torch.compile(self._generate_local_transformer_impl, fullgraph=True)
106
+ return self._compiled_local_transformer
107
+
108
+ def reset_generation_state(self, keep_cache: bool = True):
109
+ if not keep_cache:
110
+ self.past_key_values = None
111
+ self.attention_mask = None
112
+ # Keep the mask when reusing cache so it stays aligned with past_key_values.
113
+ # This allows concatenation with the next turn prefill mask.
114
+ self._generated_tokens = []
115
+ self._is_stopping = None
116
+ self._last_audio_tokens = None
117
+ self._step_idx = 0
118
+
119
+ def _normalize_input_ids(self, input_ids):
120
+ if isinstance(input_ids, torch.Tensor):
121
+ input_ids = input_ids.detach().cpu().numpy()
122
+ if isinstance(input_ids, np.ndarray):
123
+ if input_ids.ndim == 2:
124
+ return [input_ids]
125
+ if input_ids.ndim == 3:
126
+ return [input_ids[i] for i in range(input_ids.shape[0])]
127
+ if isinstance(input_ids, (list, tuple)):
128
+ return [np.array(item) for item in input_ids]
129
+ raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")
130
+
131
+ def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
132
+ if text_prefix_ids is None:
133
+ raise ValueError("text_prefix_ids must be provided for prefill.")
134
+ if isinstance(text_prefix_ids, torch.Tensor):
135
+ text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
136
+ if isinstance(text_prefix_ids, np.ndarray):
137
+ text_prefix_ids = text_prefix_ids.tolist()
138
+ if isinstance(text_prefix_ids, list):
139
+ if len(text_prefix_ids) == 0:
140
+ return [[] for _ in range(batch_size)]
141
+ if isinstance(text_prefix_ids[0], (int, np.integer)):
142
+ return [list(text_prefix_ids)]
143
+ if len(text_prefix_ids) == 1 and batch_size > 1:
144
+ return [list(text_prefix_ids[0]) for _ in range(batch_size)]
145
+ if len(text_prefix_ids) != batch_size:
146
+ raise ValueError(
147
+ f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
148
+ )
149
+ return [list(item) for item in text_prefix_ids]
150
+ raise ValueError("text_prefix_ids must be list-like or tensor-like.")
151
+
152
+ @torch.inference_mode()
153
+ def prefill(
154
+ self,
155
+ input_ids,
156
+ text_prefix_ids,
157
+ max_prefill_len: Optional[int] = None,
158
+ past_key_values=None,
159
+ device: Optional[torch.device] = None,
160
+ temperature: float = 0.8,
161
+ top_p: float = 0.6,
162
+ top_k: int = 30,
163
+ do_sample: bool = True,
164
+ repetition_penalty: Optional[float] = 1.1,
165
+ repetition_window: Optional[int] = 50,
166
+ ) -> torch.Tensor:
167
+ if device is None:
168
+ device = self.device
169
+
170
+ if past_key_values is not None:
171
+ self.past_key_values = past_key_values
172
+
173
+ input_ids_list = self._normalize_input_ids(input_ids)
174
+ batch_size = len(input_ids_list)
175
+ text_prefix_list = self._normalize_text_prefix(text_prefix_ids, batch_size)
176
+
177
+ concat_inputs_id_list = []
178
+ for i in range(batch_size):
179
+ prefix = text_prefix_list[i]
180
+ if max_prefill_len is not None:
181
+ prefix = prefix[:max_prefill_len]
182
+ if len(prefix) == 0:
183
+ raise ValueError("Prefill requires at least one text token.")
184
+
185
+ text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64)
186
+ text_seg[:, 0] = np.array(prefix, dtype=np.int64)
187
+ text_seg[len(prefix) - 1, 1] = self.audio_bos_token
188
+ concat_inputs_id = np.concatenate([input_ids_list[i], text_seg], axis=0)
189
+ concat_inputs_id_list.append(concat_inputs_id)
190
+
191
+ attention_masks = [np.ones(ids.shape[0], dtype=np.bool_) for ids in concat_inputs_id_list]
192
+ max_len = max(ids.shape[0] for ids in concat_inputs_id_list)
193
+ padded_input_ids, padded_attns = [], []
194
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id
195
+
196
+ for ids, attn in zip(concat_inputs_id_list, attention_masks):
197
+ pad_len = max_len - ids.shape[0]
198
+ input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64)
199
+ input_pad[:, 0] = pad_token_id
200
+ padded_input_ids.append(np.concatenate([input_pad, ids]))
201
+ attn_pad = np.zeros(pad_len, dtype=np.bool_)
202
+ padded_attns.append(np.concatenate([attn_pad, attn]))
203
+
204
+ current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
205
+ current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)
206
+
207
+ # For multi-turn continuation, concatenate the cached mask and the current prefill mask.
208
+ if self.attention_mask is not None and self.past_key_values is not None:
209
+ current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)
210
+
211
+ outputs = self.model(
212
+ input_ids=current_input_ids,
213
+ attention_mask=current_attention_mask,
214
+ past_key_values=self.past_key_values,
215
+ use_cache=True,
216
+ return_dict=True,
217
+ )
218
+ self.past_key_values = outputs.past_key_values
219
+ self.attention_mask = current_attention_mask
220
+
221
+ backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
222
+ audio_tokens = self.generate_local_transformer(
223
+ hidden_states=backbone_hidden_states,
224
+ temperature=temperature,
225
+ top_p=top_p,
226
+ top_k=top_k,
227
+ do_sample=do_sample,
228
+ repetition_penalty=repetition_penalty,
229
+ repetition_window=repetition_window,
230
+ generated_tokens=None,
231
+ gen_step=0,
232
+ )
233
+
234
+ self._generated_tokens = [audio_tokens]
235
+ self._last_audio_tokens = audio_tokens
236
+ self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
237
+ self._step_idx = 1
238
+ return audio_tokens
239
+
240
+ @torch.inference_mode()
241
+ def step(
242
+ self,
243
+ text_token: Optional[Iterable[int] | torch.Tensor | int],
244
+ temperature: float = 0.8,
245
+ top_p: float = 0.6,
246
+ top_k: int = 30,
247
+ do_sample: bool = True,
248
+ repetition_penalty: Optional[float] = 1.1,
249
+ repetition_window: Optional[int] = 50,
250
+ ) -> torch.Tensor:
251
+ if self._last_audio_tokens is None or self.attention_mask is None:
252
+ raise ValueError("You must call prefill() before step().")
253
+ if self.is_finished:
254
+ return self._last_audio_tokens
255
+
256
+ batch_size = self._last_audio_tokens.shape[0]
257
+ if text_token is None:
258
+ text_tokens = [self.text_pad_id] * batch_size
259
+ elif isinstance(text_token, torch.Tensor):
260
+ text_tokens = text_token.detach().cpu().tolist()
261
+ elif isinstance(text_token, (list, tuple, np.ndarray)):
262
+ text_tokens = list(text_token)
263
+ else:
264
+ text_tokens = [int(text_token)]
265
+
266
+ if len(text_tokens) != batch_size:
267
+ raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")
268
+
269
+ device = self._last_audio_tokens.device
270
+ text_t = torch.tensor(text_tokens, device=device, dtype=torch.long)
271
+ step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2)
272
+ self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1)
273
+
274
+ outputs = self.model(
275
+ input_ids=step_ids,
276
+ attention_mask=self.attention_mask,
277
+ past_key_values=self.past_key_values,
278
+ use_cache=True,
279
+ return_dict=True,
280
+ )
281
+ self.past_key_values = outputs.past_key_values
282
+ backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
283
+
284
+ history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None
285
+ audio_tokens = self.generate_local_transformer(
286
+ hidden_states=backbone_hidden_states,
287
+ temperature=temperature,
288
+ top_p=top_p,
289
+ top_k=top_k,
290
+ do_sample=do_sample,
291
+ repetition_penalty=repetition_penalty,
292
+ repetition_window=repetition_window,
293
+ generated_tokens=history,
294
+ gen_step=self._step_idx,
295
+ )
296
+
297
+ self._generated_tokens.append(audio_tokens)
298
+ self._last_audio_tokens = audio_tokens
299
+ self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
300
+ self._step_idx += 1
301
+ return audio_tokens
302
+
303
+ @torch.inference_mode()
304
+ def finish(
305
+ self,
306
+ max_steps: Optional[int] = None,
307
+ temperature: float = 0.8,
308
+ top_p: float = 0.6,
309
+ top_k: int = 30,
310
+ do_sample: bool = True,
311
+ repetition_penalty: Optional[float] = 1.1,
312
+ repetition_window: Optional[int] = 50,
313
+ ) -> list[torch.Tensor]:
314
+ outputs = []
315
+ steps_left = max_steps if max_steps is not None else self.max_length
316
+ while steps_left > 0 and not self.is_finished:
317
+ outputs.append(
318
+ self.step(
319
+ text_token=None,
320
+ temperature=temperature,
321
+ top_p=top_p,
322
+ top_k=top_k,
323
+ do_sample=do_sample,
324
+ repetition_penalty=repetition_penalty,
325
+ repetition_window=repetition_window,
326
+ )
327
+ )
328
+ steps_left -= 1
329
+ return outputs
330
+
331
+ def generate_local_transformer(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ temperature: float,
335
+ top_p: float,
336
+ top_k: int,
337
+ do_sample: bool,
338
+ repetition_penalty: Optional[float],
339
+ repetition_window: Optional[int],
340
+ generated_tokens: Optional[torch.Tensor],
341
+ gen_step: int,
342
+ ) -> torch.Tensor:
343
+ runner = self._get_local_transformer_runner()
344
+ return runner(
345
+ hidden_states=hidden_states,
346
+ temperature=temperature,
347
+ top_p=top_p,
348
+ top_k=top_k,
349
+ do_sample=do_sample,
350
+ repetition_penalty=repetition_penalty,
351
+ repetition_window=repetition_window,
352
+ generated_tokens=generated_tokens,
353
+ gen_step=gen_step,
354
+ )
355
+
356
+ def _generate_local_transformer_impl(
357
+ self,
358
+ hidden_states: torch.Tensor,
359
+ temperature: float,
360
+ top_p: float,
361
+ top_k: int,
362
+ do_sample: bool,
363
+ repetition_penalty: Optional[float],
364
+ repetition_window: Optional[int],
365
+ generated_tokens: Optional[torch.Tensor],
366
+ gen_step: int,
367
+ ) -> torch.Tensor:
368
+ batch_size = hidden_states.shape[0]
369
+ device = hidden_states.device
370
+ local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
371
+ output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)
372
+
373
+ past_key_values = self._build_local_past_key_values()
374
+ local_token = None
375
+
376
+ cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)
377
+
378
+ for i in range(self.channels):
379
+ cache_pos_t.fill_(i)
380
+
381
+ local_outputs = self.model.local_transformer(
382
+ input_ids=local_token,
383
+ inputs_embeds=local_inputs,
384
+ past_key_values=past_key_values,
385
+ cache_position=cache_pos_t,
386
+ codebook_idx=i,
387
+ use_cache=True,
388
+ logits_to_keep=1,
389
+ )
390
+ logits = local_outputs.logits
391
+
392
+ if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
393
+ logits = self.apply_repetition_penalty(
394
+ scores=logits,
395
+ history_tokens=generated_tokens[:, :gen_step, i],
396
+ penalty=float(repetition_penalty),
397
+ repetition_window=repetition_window,
398
+ )
399
+
400
+ local_token = self.sample_token(
401
+ logits=logits,
402
+ temperature=temperature,
403
+ top_p=top_p,
404
+ top_k=top_k,
405
+ do_sample=do_sample,
406
+ )
407
+ output_token[:, i] = local_token.squeeze(-1)
408
+
409
+ if i == 0:
410
+ local_inputs = None
411
+ return output_token
412
+
413
+ def apply_repetition_penalty(
414
+ self,
415
+ scores: torch.Tensor,
416
+ history_tokens: torch.Tensor,
417
+ penalty: float = 1.1,
418
+ repetition_window: Optional[int] = None,
419
+ ):
420
+ scores_ = scores[:, 0, :]
421
+ ht = history_tokens
422
+
423
+ if repetition_window is not None and repetition_window > 0:
424
+ ht = ht[:, -repetition_window:]
425
+
426
+ cur = scores_.gather(1, ht)
427
+ new = torch.where(cur < 0, cur * penalty, cur / penalty)
428
+ scores_.scatter_(1, ht, new)
429
+ return scores_
430
+
431
+ def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
432
+ if not do_sample or temperature == 0:
433
+ return torch.argmax(logits, dim=-1)
434
+ logits = logits / temperature
435
+ original_shape = logits.shape
436
+ vocab_size = original_shape[-1]
437
+ reshaped_logits = logits.reshape(-1, vocab_size)
438
+
439
+ if top_k is not None:
440
+ reshaped_logits = self.apply_top_k(reshaped_logits, top_k)
441
+
442
+ if top_p is not None:
443
+ reshaped_logits = self.apply_top_p(reshaped_logits, top_p)
444
+
445
+ probs = F.softmax(reshaped_logits, dim=-1)
446
+ next_tokens_flat = torch.multinomial(probs, num_samples=1)
447
+
448
+ output_shape = original_shape[:-1]
449
+ return next_tokens_flat.view(output_shape)
450
+
451
def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
    """Keep the ``top_k`` largest logits of every row and mask the rest.

    Args:
        logits: 2-D tensor of logits, one row per distribution.
        top_k: Strictly positive int; raised to ``min_tokens_to_keep`` and
            capped at the row length.
        filter_value: Value written over the masked entries.
        min_tokens_to_keep: Lower bound on the number of surviving entries.

    Returns:
        A new tensor in which every entry below the k-th largest value of its
        row is replaced by ``filter_value`` (ties with that value survive).

    Raises:
        ValueError: If ``top_k`` is not a strictly positive int.
    """
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

    # The unpacking doubles as a check that ``logits`` is 2-D.
    _, num_classes = logits.shape
    keep = min(max(top_k, min_tokens_to_keep), num_classes)

    kth_largest = torch.topk(logits, keep, dim=-1).values[..., -1, None]
    below_cutoff = logits < kth_largest
    return logits.masked_fill(below_cutoff, filter_value)
459
+
460
def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
    """Apply nucleus (top-p) filtering to every row of ``logits``.

    Entries are sorted ascending; the low-probability prefix whose cumulative
    probability is at most ``1 - top_p`` is masked out.

    Args:
        logits: 2-D tensor of logits, one row per distribution.
        top_p: Probability mass to keep, in ``[0, 1]``.
        filter_value: Value written over the masked entries.
        min_tokens_to_keep: Number of most likely entries that are never
            masked. Values below 1 are treated as 1.

    Returns:
        A new tensor with the masked entries replaced by ``filter_value``.

    Raises:
        ValueError: If ``top_p`` lies outside ``[0, 1]``.
    """
    top_p = float(top_p)
    if top_p < 0 or top_p > 1.0:
        raise ValueError(f"`top_p` has to be a float in the range [0, 1], but is {top_p}")

    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Always protect at least one entry. A slice of `-0:` would cover the whole
    # row and silently disable the filter, and removing every entry would leave
    # nothing to sample from.
    protected = max(int(min_tokens_to_keep), 1)
    sorted_indices_to_remove[..., -protected:] = False

    # Map the mask from sorted order back to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, filter_value)
473
+
474
+
475
+ @requires(backends=("torch",))
476
class MossTTSRealtimeStreamingSession:
    """Manage text-to-audio streaming for a single conversation.

    Text arrives incrementally through ``push_text`` / ``push_text_tokens`` and
    is queued as token ids. Once ``prefill_text_len`` tokens are queued (or the
    text has ended) the inferencer is prefilled with the turn prompt plus that
    text prefix; every further queued token is fed through ``inferencer.step``.
    Each streaming method returns the list of outputs the inferencer produced
    during that call.
    """

    # Text is cut right after any of these marks so that each segment is
    # tokenized as a unit.
    # NOTE(review): the fullwidth forms (U+FF01, U+FF1F, U+FF0C, U+FF1B,
    # U+FF1A) are written as escapes; the ASCII marks appeared twice in these
    # classes, which suggests the fullwidth variants were intended. Confirm.
    _split_pattern = re.compile(
        r"[。!?\uff01\uff1f\.\u2026]\s*"  # sentence boundaries: 。 ! ? (ASCII and fullwidth) . …
        r"|[,;:\uff0c\uff1b\uff1a\u2014\u2013\-]\s*"  # short pauses: , ; : (ASCII and fullwidth) — – -
        r"|\)\s*|\]\s*"  # closing brackets: ) ]
        r"|\n"
    )

    def __init__(
        self,
        inferencer: "MossTTSRealtimeInference",
        processor,
        codec=None,
        codec_sample_rate: int = 24000,
        codec_encode_kwargs: Optional[dict] = None,
        prefill_text_len: int = 12,
        text_buffer_size: int = 32,
        min_text_chunk_chars: int = 8,
        temperature: float = 0.8,
        top_p: float = 0.6,
        top_k: int = 30,
        do_sample: bool = True,
        repetition_penalty: Optional[float] = 1.1,
        repetition_window: Optional[int] = 50,
    ):
        """Store the collaborators and sampling settings; no model call is made.

        Args:
            inferencer: Object providing ``prefill``, ``step``, ``finish``,
                ``reset_generation_state``, ``is_finished`` and ``device``.
            processor: Object providing ``tokenizer``, ``channels``,
                ``make_user_prompt`` and ``make_ensemble``.
            codec: Optional codec used only to encode waveform voice prompts.
            codec_sample_rate: Sample rate waveforms are resampled to before
                encoding.
            codec_encode_kwargs: Extra keyword arguments for ``codec.encode``.
            prefill_text_len: Number of text tokens to collect before prefill.
            text_buffer_size: Cached text length from which a cut at the last
                space is allowed when no punctuation boundary is found.
            min_text_chunk_chars: Minimum length of a punctuation-cut segment.
            temperature, top_p, top_k, do_sample, repetition_penalty,
            repetition_window: Forwarded to the inferencer on every call.
        """
        self.inferencer = inferencer
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.codec = codec
        self.codec_sample_rate = codec_sample_rate
        self.codec_encode_kwargs = codec_encode_kwargs or {}

        self.prefill_text_len = prefill_text_len
        self.text_buffer_size = text_buffer_size
        self.min_text_chunk_chars = min_text_chunk_chars

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.repetition_penalty = repetition_penalty
        self.repetition_window = repetition_window

        self._voice_prompt_tokens = None
        self._turn_input_ids = None
        self._turn_idx = 0

        self._text_cache = ""
        self._pending_tokens: list[int] = []
        self._prefilled = False
        self._text_ended = False

    def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
        """Use already-encoded audio tokens as the voice prompt.

        Torch tensors are converted to NumPy, matching ``set_voice_prompt``.
        """
        if isinstance(audio_tokens, torch.Tensor):
            audio_tokens = audio_tokens.detach().cpu().numpy()
        self._voice_prompt_tokens = audio_tokens

    def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
        """Set voice prompt from either audio tokens or waveform.

        If `audio` is a 2D array whose shape matches the codebook channels, it is
        treated as audio tokens. Otherwise a codec is required to encode waveform
        prompts into tokens.

        Args:
            audio: 2-D token array/tensor, a waveform array/tensor, or a
                ``str``/``bytes`` value handed to ``torchaudio.load``.
            sample_rate: Sample rate of a waveform input. ``None`` skips
                resampling; it is overwritten when the audio is loaded from a
                file.

        Raises:
            ValueError: If a waveform is given without a codec, the audio type
                is unsupported, or the codec output carries no codes.
        """
        is_2d_array = isinstance(audio, np.ndarray) and audio.ndim == 2
        is_2d_tensor = isinstance(audio, torch.Tensor) and audio.dim() == 2
        if (is_2d_array or is_2d_tensor) and self.processor.channels in audio.shape:
            self.set_voice_prompt_tokens(audio)
            return

        if self.codec is None:
            raise ValueError("codec is required to encode waveform prompts.")

        waveform = audio
        if isinstance(audio, (str, bytes)):
            requires_backends(self, ["torchaudio"])
            wav, sr = torchaudio.load(audio)
            if wav.shape[0] > 1:
                # Down-mix multi-channel audio to mono.
                wav = wav.mean(dim=0, keepdim=True)
            waveform = wav.squeeze(0)
            sample_rate = sr

        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        if not isinstance(waveform, torch.Tensor):
            raise ValueError("Unsupported audio type for voice prompt.")

        if sample_rate is not None and sample_rate != self.codec_sample_rate:
            requires_backends(self, ["torchaudio"])
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)

        waveform = waveform.to(self.inferencer.device)
        encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)
        if isinstance(encode_out, dict):
            if "codes_list" in encode_out:
                tokens = encode_out["codes_list"][0]
            elif "audio_codes" in encode_out:
                tokens = encode_out["audio_codes"][0]
            else:
                raise ValueError("codec.encode output missing audio codes.")
        else:
            tokens = encode_out
        self.set_voice_prompt_tokens(tokens)

    def clear_voice_prompt(self):
        """Forget the current voice prompt."""
        self._voice_prompt_tokens = None

    def reset_turn(
        self,
        user_text: Optional[str] = None,
        user_audio_tokens: Optional[np.ndarray] = None,
        input_ids: Optional[np.ndarray] = None,
        include_system_prompt: Optional[bool] = None,
        reset_cache: bool = False,
    ):
        """Start a new turn and clear all per-turn streaming state.

        Args:
            user_text: User text for the prompt; required without ``input_ids``.
            user_audio_tokens: User audio tokens; required without
                ``input_ids``.
            input_ids: Ready-made prompt; used as-is when given.
            include_system_prompt: Whether to prepend the system prompt built
                from the voice prompt. ``None`` means "only on the first turn".
            reset_cache: When true the inferencer is asked not to keep its
                cache.

        Raises:
            ValueError: If neither ``input_ids`` nor both user inputs are given.
        """
        if include_system_prompt is None:
            include_system_prompt = self._turn_idx == 0

        if input_ids is None:
            if user_text is None or user_audio_tokens is None:
                raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
            user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens)
            if include_system_prompt:
                system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
                input_ids = np.concatenate([system_prompt, user_prompt], axis=0)
            else:
                input_ids = user_prompt

        self._turn_input_ids = input_ids
        self._turn_idx += 1

        self._text_cache = ""
        self._pending_tokens = []
        self._prefilled = False
        self._text_ended = False

        self.inferencer.reset_generation_state(keep_cache=not reset_cache)

    def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
        """Queue ready-made text token ids and run whatever generation is possible."""
        self._pending_tokens.extend(int(t) for t in tokens)
        return self._drain_pending_tokens()

    def push_text(self, text_fragment: str) -> list[torch.Tensor]:
        """Append raw text, tokenize every complete segment and generate."""
        self._text_cache += text_fragment
        for segment in self._extract_text_segments(force=False):
            self._pending_tokens.extend(self._tokenize(segment))
        return self._drain_pending_tokens()

    def end_text(self) -> list[torch.Tensor]:
        """Mark the text as complete, flush the cached text and generate."""
        self._text_ended = True
        for segment in self._extract_text_segments(force=True):
            self._pending_tokens.extend(self._tokenize(segment))
        return self._drain_pending_tokens()

    def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
        """Let the inferencer keep generating via ``finish``.

        Returns an empty list when prefill has not happened yet.
        """
        if not self._prefilled:
            return []
        return self.inferencer.finish(
            max_steps=max_steps,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            repetition_penalty=self.repetition_penalty,
            repetition_window=self.repetition_window,
        )

    def _tokenize(self, text: str) -> list[int]:
        """Encode ``text`` without special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _extract_text_segments(self, force: bool) -> list[str]:
        """Cut complete segments off the front of the text cache.

        With ``force`` the whole cache is returned as one segment. Otherwise a
        cut is made after the first punctuation boundary that ends at or beyond
        ``min_text_chunk_chars``; failing that, once the cache has reached
        ``text_buffer_size`` characters, after its last space. Text that cannot
        be cut yet stays in the cache.
        """
        segments = []
        if force:
            if self._text_cache:
                segments.append(self._text_cache)
                self._text_cache = ""
            return segments

        while self._text_cache:
            cache = self._text_cache
            cut_idx = None
            if len(cache) >= self.min_text_chunk_chars:
                # First boundary that yields a long-enough segment; the lazy
                # iterator stops scanning as soon as one is found.
                cut_idx = next(
                    (
                        match.end()
                        for match in self._split_pattern.finditer(cache)
                        if match.end() >= self.min_text_chunk_chars
                    ),
                    None,
                )
            if cut_idx is None and len(cache) >= self.text_buffer_size:
                whitespace_idx = cache.rfind(" ")
                if whitespace_idx != -1:
                    cut_idx = whitespace_idx + 1
            if cut_idx is None:
                break
            segments.append(cache[:cut_idx])
            self._text_cache = cache[cut_idx:]
        return segments

    def _prefill_if_needed(self) -> list[torch.Tensor]:
        """Run the one-time prefill when enough text tokens are queued.

        Returns a one-element list with the prefill output, or an empty list
        when prefill already happened or cannot happen yet.

        Raises:
            ValueError: If ``reset_turn`` has not provided a prompt.
        """
        if self._prefilled:
            return []
        if not self._pending_tokens and not self._text_ended:
            return []
        if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended:
            return []
        if self._turn_input_ids is None:
            raise ValueError("reset_turn must be called before streaming text.")

        if self._text_ended:
            prefill_len = len(self._pending_tokens)
        else:
            prefill_len = min(len(self._pending_tokens), self.prefill_text_len)

        if prefill_len == 0:
            return []

        # Take the prefix in one slice instead of repeated pop(0) calls.
        prefix_tokens = self._pending_tokens[:prefill_len]
        del self._pending_tokens[:prefill_len]
        audio_tokens = self.inferencer.prefill(
            input_ids=[self._turn_input_ids],
            text_prefix_ids=[prefix_tokens],
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            do_sample=self.do_sample,
            # NOTE(review): prefill is the only call that passes no repetition
            # penalty; presumably deliberate -- confirm.
            repetition_penalty=None,
            repetition_window=self.repetition_window,
        )
        self._prefilled = True
        return [audio_tokens]

    def _drain_pending_tokens(self) -> list[torch.Tensor]:
        """Prefill if possible, then feed queued tokens one ``step`` at a time.

        Stops early when the inferencer reports ``is_finished``; unconsumed
        tokens stay queued.
        """
        outputs: list[torch.Tensor] = []
        outputs.extend(self._prefill_if_needed())
        if not self._prefilled:
            return outputs

        while self._pending_tokens and not self.inferencer.is_finished:
            token = self._pending_tokens.pop(0)
            outputs.append(
                self.inferencer.step(
                    token,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=self.top_k,
                    do_sample=self.do_sample,
                    repetition_penalty=self.repetition_penalty,
                    repetition_window=self.repetition_window,
                )
            )
        return outputs
730
+
731
+
732
+ @requires(backends=("torch",))
733
+ class AudioStreamDecoder:
734
+ """Decode audio tokens into waveform chunks with optional crossfade."""
735
+
736
+ def __init__(
737
+ self,
738
+ codec,
739
+ chunk_frames: int = 40,
740
+ overlap_frames: int = 4,
741
+ initial_chunk_frames: Optional[int] = None,
742
+ decode_chunk_duration: Optional[float] = None,
743
+ decode_kwargs: Optional[dict] = None,
744
+ device: Optional[torch.device] = None,
745
+ ):
746
+ self.codec = codec
747
+ self.chunk_frames = chunk_frames
748
+ self.overlap_frames = overlap_frames
749
+ self.initial_chunk_frames = initial_chunk_frames
750
+ self.decode_chunk_duration = decode_chunk_duration
751
+ self.decode_kwargs = decode_kwargs or {}
752
+ self.device = device
753
+
754
+ self._buffer: list[torch.Tensor] = []
755
+ self._buffer_len = 0
756
+ self._prev_tail: Optional[torch.Tensor] = None
757
+ self._chunks_emitted = 0
758
+
759
+ def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
760
+ if isinstance(audio_tokens, np.ndarray):
761
+ audio_tokens = torch.from_numpy(audio_tokens)
762
+ if audio_tokens.dim() != 2:
763
+ raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}")
764
+ self._buffer.append(audio_tokens)
765
+ self._buffer_len += audio_tokens.shape[0]
766
+
767
+ @property
768
+ def _active_chunk_frames(self) -> int:
769
+ if self.initial_chunk_frames is not None:
770
+ return min(self.initial_chunk_frames + self._chunks_emitted, self.chunk_frames)
771
+ return self.chunk_frames
772
+
773
+ def audio_chunks(self) -> Iterable[torch.Tensor]:
774
+ while self._buffer_len >= self._active_chunk_frames:
775
+ chunk_tokens = self._consume_frames(self._active_chunk_frames)
776
+ wav = self._decode(chunk_tokens)
777
+ self._chunks_emitted += 1
778
+ yield self._apply_crossfade(wav)
779
+
780
+ def flush(self) -> Optional[torch.Tensor]:
781
+ if self._buffer_len == 0:
782
+ return None
783
+ chunk_tokens = self._consume_frames(self._buffer_len)
784
+ wav = self._decode(chunk_tokens)
785
+ return self._apply_crossfade(wav, final_chunk=True)
786
+
787
+ def _consume_frames(self, num_frames: int) -> torch.Tensor:
788
+ frames = []
789
+ remaining = num_frames
790
+ while remaining > 0 and self._buffer:
791
+ head = self._buffer[0]
792
+ if head.shape[0] <= remaining:
793
+ frames.append(head)
794
+ remaining -= head.shape[0]
795
+ self._buffer.pop(0)
796
+ else:
797
+ frames.append(head[:remaining])
798
+ self._buffer[0] = head[remaining:]
799
+ remaining = 0
800
+ self._buffer_len -= num_frames - remaining
801
+ return torch.cat(frames, dim=0)
802
+
803
+ def _decode(self, tokens: torch.Tensor) -> torch.Tensor:
804
+ device = self.device
805
+ if device is None:
806
+ if hasattr(self.codec, "device"):
807
+ device = self.codec.device
808
+ else:
809
+ try:
810
+ device = next(self.codec.parameters()).device
811
+ except Exception:
812
+ device = None
813
+ if device is not None:
814
+ tokens = tokens.to(device)
815
+ tokens_t = tokens.permute(1, 0)
816
+ decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
817
+ decoded = self.codec.decode(tokens_t, chunk_duration=self.decode_chunk_duration, **decode_kwargs)
818
+ if isinstance(decoded, dict):
819
+ wav = decoded["audio"][0]
820
+ else:
821
+ wav = decoded
822
+ if isinstance(wav, np.ndarray):
823
+ wav = torch.from_numpy(wav)
824
+ if wav.dim() > 1:
825
+ wav = wav.squeeze(0)
826
+ return wav
827
+
828
+ def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
829
+ if self.overlap_frames <= 0:
830
+ return wav
831
+ if self._prev_tail is None:
832
+ self._prev_tail = wav[-self._overlap_samples(wav) :].clone() if not final_chunk else None
833
+ return wav
834
+
835
+ overlap = self._overlap_samples(wav)
836
+ if overlap == 0:
837
+ return wav
838
+
839
+ prev_tail = self._prev_tail
840
+ if prev_tail.numel() < overlap:
841
+ overlap = prev_tail.numel()
842
+ if overlap == 0:
843
+ return wav
844
+
845
+ fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device)
846
+ fade_in = 1.0 - fade_out
847
+ cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in
848
+ merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1)
849
+
850
+ self._prev_tail = None if final_chunk else wav[-overlap:].clone()
851
+ return merged
852
+
853
+ def _overlap_samples(self, wav: torch.Tensor) -> int:
854
+ if self.chunk_frames <= 0:
855
+ return 0
856
+ return int(wav.numel() * (self.overlap_frames / self.chunk_frames))
857
+
858
+
859
class TextDeltaTokenizer:
    """Turn streamed text deltas into incrementally emitted token ids.

    Every delta is appended to one growing string. After each append the whole
    string is encoded again and only ids beyond those already handed out are
    returned, minus the last ``hold_back`` ids, which are withheld because
    later text may still change them. ``flush`` hands out whatever is left.
    """

    def __init__(self, tokenizer, *, hold_back: int = 3):
        self.tokenizer = tokenizer
        self.hold_back = max(0, int(hold_back))
        self._text = ""
        self._all_ids: list[int] = []
        self._emitted_count: int = 0

    @property
    def text(self) -> str:
        """All text received so far."""
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """A copy of the ids from the most recent encoding of the full text."""
        return [*self._all_ids]

    def _reencode(self) -> None:
        """Encode the accumulated text from scratch, without special tokens."""
        self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)

    def push_delta(self, delta: str) -> list[int]:
        """Append a text delta and return newly stable token ids (may be empty)."""
        if not delta:
            return []
        self._text += str(delta)
        self._reencode()

        already_out = self._emitted_count
        # Never move backwards, and keep the last `hold_back` ids withheld.
        stable_until = max(already_out, len(self._all_ids) - self.hold_back)
        self._emitted_count = stable_until
        return self._all_ids[already_out:stable_until]

    def flush(self) -> list[int]:
        """Emit all remaining token ids at end of stream."""
        self._reencode()
        start = self._emitted_count
        self._emitted_count = len(self._all_ids)
        return self._all_ids[start:]
904
+
905
+
906
def _sanitize_audio_tokens(
    tokens: torch.Tensor,
    *,
    codebook_size: int,
    audio_eos_token: int,
) -> tuple[torch.Tensor, bool]:
    """Cut the token rows at the first EOS or out-of-range row.

    A row ends the stream when its first column equals ``audio_eos_token`` or
    when any of its values lies outside ``[0, codebook_size)``.

    Returns:
        ``(rows, stop)``: the rows before the first such row and ``True``, or
        all rows and ``False`` when there is none. 1-D input is treated as a
        single row; empty input is returned unchanged with ``False``.
    """
    rows = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens
    if rows.numel() == 0:
        return rows, False

    hit_eos = rows[:, 0] == audio_eos_token
    out_of_range = ((rows < 0) | (rows >= codebook_size)).any(dim=1)

    # The first row matching either condition is where the stream ends.
    stop_rows = (hit_eos | out_of_range).nonzero(as_tuple=False)
    if stop_rows.numel() == 0:
        return rows, False
    return rows[: int(stop_rows[0].item())], True
931
+
932
+
933
def _maybe_codec_streaming(codec, *, batch_size: int):
    """Return ``codec.streaming(batch_size=...)`` if the codec offers it.

    A do-nothing context manager is returned when ``codec`` is ``None`` or has
    no ``streaming`` attribute, so callers can always use a ``with`` block.
    """
    if codec is not None and hasattr(codec, "streaming"):
        return codec.streaming(batch_size=batch_size)
    return contextlib.nullcontext()
937
+
938
+
939
+ @requires(backends=("torch",))
940
class MossTTSRealtimeTextStreamBridge:
    """
    Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.

    Usage overview:
    - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
    - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
      `push_text_delta()`.
    - Once the accumulated token count reaches `prefill_text_len`, the session will
      start generating audio tokens; the bridge will immediately decode them into WAV
      chunks and yield them.

    All `push_*`, `finish` and `stream_*` methods are generators: nothing is
    pushed or decoded until they are iterated.
    """

    def __init__(
        self,
        session: "MossTTSRealtimeStreamingSession",
        decoder: "AudioStreamDecoder",
        *,
        codebook_size: Optional[int] = None,
        audio_eos_token: Optional[int] = None,
        batch_size: int = 1,
    ):
        """Wire a streaming session to an audio decoder.

        Args:
            session: Source of audio-token frames.
            decoder: Sink that turns token frames into waveform chunks.
            codebook_size: Exclusive upper bound of valid token values.
                Defaults to the codec's ``codebook_size`` attribute, or 1024.
            audio_eos_token: Token marking the end of audio. Defaults to the
                inferencer's ``audio_eos_token`` attribute, or 1026.
            batch_size: Passed to ``codec.streaming`` when the codec offers it.
        """
        self.session = session
        self.decoder = decoder
        self.batch_size = int(batch_size)

        if codebook_size is None:
            codebook_size = int(getattr(self._streaming_codec(), "codebook_size", 1024))
        if audio_eos_token is None:
            audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))

        self.codebook_size = int(codebook_size)
        self.audio_eos_token = int(audio_eos_token)

    def _streaming_codec(self):
        """Return the session's codec, or the decoder's when the session has none.

        The session's codec is optional, while the decoder's codec is the one
        that performs the decoding.
        """
        codec = getattr(self.session, "codec", None)
        if codec is None:
            codec = getattr(self.decoder, "codec", None)
        return codec

    def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
        """
        Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.

        Internally, this directly calls `session.push_text()`, which segments the text
        based on punctuation/length and then tokenizes the *entire segment* at once,
        avoiding the prefix instability issues of incremental BPE tokenization.
        """
        audio_frames = self.session.push_text(delta)
        yield from self._decode_audio_frames(audio_frames)

    def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
        """Push token ids directly (for sources that stream token ids)."""
        # Materialize first: truth-testing a NumPy array or tensor with more
        # than one element raises instead of reporting emptiness.
        ids = [] if token_ids is None else list(token_ids)
        if not ids:
            return
        audio_frames = self.session.push_text_tokens(ids)
        yield from self._decode_audio_frames(audio_frames)

    def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """Mark text stream end and emit all remaining audio chunks (including flush)."""
        audio_frames = self.session.end_text()
        yield from self._decode_audio_frames(audio_frames)

        # Keep draining until the session returns nothing or reports completion.
        while True:
            more_frames = self.session.drain(max_steps=drain_step)
            if not more_frames:
                break
            yield from self._decode_audio_frames(more_frames)
            if self.session.inferencer.is_finished:
                break

        final = self.decoder.flush()
        if final is not None and final.numel() > 0:
            yield final.detach().cpu()

    def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
        """Consume a full delta iterator and continuously yield waveform chunks."""
        with _maybe_codec_streaming(self._streaming_codec(), batch_size=self.batch_size):
            for delta in deltas:
                yield from self.push_text_delta(delta)
            yield from self.finish(drain_step=drain_step)

    def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
        """Validate token frames, feed them to the decoder and yield ready chunks.

        Processing stops at the first frame that contains the EOS token or an
        out-of-range value.

        Raises:
            ValueError: If a frame is not 2-D (after dropping a leading axis of
                a 3-D frame) or holds more than one row.
        """
        for frame in audio_frames:
            tokens = frame
            if tokens.dim() == 3:
                tokens = tokens[0]
            if tokens.dim() != 2:
                raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
            if tokens.shape[0] != 1:
                raise ValueError(
                    f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
                )

            tokens, stop = _sanitize_audio_tokens(
                tokens,
                codebook_size=self.codebook_size,
                audio_eos_token=self.audio_eos_token,
            )
            if tokens.numel() == 0:
                if stop:
                    break
                continue

            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() == 0:
                    continue
                yield wav.detach().cpu()
            if stop:
                break
1043
+
1044
+
1045
# Names exported from this module.
__all__ = [
    "AudioStreamDecoder",
    "MossTTSRealtimeInference",
    "MossTTSRealtimeStreamingSession",
    "MossTTSRealtimeTextStreamBridge",
    "TextDeltaTokenizer",
]