gaoyang07 commited on
Commit
b1cede0
·
1 Parent(s): 8282490

Update modeling

Browse files
modeling_mossttsrealtime.py CHANGED
@@ -11,8 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
- """MossTTSRealtime backbone model."""
16
 
17
  from __future__ import annotations
18
 
@@ -23,12 +22,12 @@ import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
25
 
 
26
  from transformers.cache_utils import Cache
27
  from transformers.modeling_outputs import ModelOutput
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.models.qwen3 import Qwen3Model
30
  from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
31
-
32
  from .configuration_mossttsrealtime import MossTTSRealtimeConfig
33
  from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
34
 
@@ -51,21 +50,14 @@ class MossTTSRealtimePretrainedModel(PreTrainedModel):
51
  }
52
 
53
  def _init_weights(self, module):
54
-
55
- from transformers import initialization as init
56
-
57
  std = self.config.initializer_range
58
  if isinstance(module, nn.Linear):
59
- # module.weight.data.normal_(mean=0.0, std=std)
60
  init.normal_(module.weight, mean=0.0, std=std)
61
  if module.bias is not None:
62
- # module.bias.data.zero_()
63
  init.zeros_(module.bias)
64
  elif isinstance(module, nn.Embedding):
65
- # module.weight.data.normal_(mean=0.0, std=std)
66
  init.normal_(module.weight, mean=0.0, std=std)
67
  if module.padding_idx is not None:
68
- # module.weight.data[module.padding_idx].zero_()
69
  init.zeros_(module.weight[module.padding_idx])
70
 
71
 
@@ -145,7 +137,9 @@ class MossTTSRealtime(MossTTSRealtimePretrainedModel):
145
  past_key_values=past_key_values,
146
  inputs_embeds=inputs_embeds,
147
  use_cache=use_cache,
148
- output_hidden_states=True,
 
 
149
  cache_position=cache_position,
150
  **kwargs,
151
  )
@@ -156,11 +150,12 @@ class MossTTSRealtime(MossTTSRealtimePretrainedModel):
156
  audio_labels = labels[:, :, 1:]
157
  train_mask = ~(audio_labels == -100).all(dim=-1)
158
  local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
159
- local_input_ids[local_input_ids == -100] = 1024
160
  local_input_ids = F.pad(local_input_ids, (1, 0), value=0)
161
 
162
  train_idx = train_mask.nonzero(as_tuple=True)
163
- local_hidden_states = outputs[0][train_idx[0], train_idx[1] - 1, :].reshape(
 
164
  -1, 1, self.config.local_config.hidden_size
165
  )
166
  local_labels = audio_labels[train_mask]
@@ -175,7 +170,7 @@ class MossTTSRealtime(MossTTSRealtimePretrainedModel):
175
  )
176
  loss = local_outputs.loss
177
 
178
- return MossTTSRealtimeOutputWithPast(
179
  loss=loss,
180
  logits=None,
181
  past_key_values=outputs.past_key_values,
@@ -187,6 +182,9 @@ class MossTTSRealtime(MossTTSRealtimePretrainedModel):
187
  local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
188
  local_attentions=local_outputs.attentions if local_outputs is not None else None,
189
  )
 
 
 
190
 
191
 
192
- __all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel", "Qwen3Model"]
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ """MossTTSRealtime model."""
 
15
 
16
  from __future__ import annotations
17
 
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
 
25
+ from transformers import initialization as init
26
  from transformers.cache_utils import Cache
27
  from transformers.modeling_outputs import ModelOutput
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.models.qwen3 import Qwen3Model
30
  from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer
 
31
  from .configuration_mossttsrealtime import MossTTSRealtimeConfig
32
  from .modeling_mossttsrealtime_local import MossTTSRealtimeLocalTransformerForCausalLM
33
 
 
50
  }
51
 
52
  def _init_weights(self, module):
 
 
 
53
  std = self.config.initializer_range
54
  if isinstance(module, nn.Linear):
 
55
  init.normal_(module.weight, mean=0.0, std=std)
56
  if module.bias is not None:
 
57
  init.zeros_(module.bias)
58
  elif isinstance(module, nn.Embedding):
 
59
  init.normal_(module.weight, mean=0.0, std=std)
60
  if module.padding_idx is not None:
 
61
  init.zeros_(module.weight[module.padding_idx])
62
 
63
 
 
137
  past_key_values=past_key_values,
138
  inputs_embeds=inputs_embeds,
139
  use_cache=use_cache,
140
+ output_attentions=output_attentions,
141
+ output_hidden_states=output_hidden_states,
142
+ return_dict=True,
143
  cache_position=cache_position,
144
  **kwargs,
145
  )
 
150
  audio_labels = labels[:, :, 1:]
151
  train_mask = ~(audio_labels == -100).all(dim=-1)
152
  local_input_ids = audio_labels[train_mask][..., : self.config.rvq - 1]
153
+ local_input_ids[local_input_ids == -100] = self.config.audio_pad_token
154
  local_input_ids = F.pad(local_input_ids, (1, 0), value=0)
155
 
156
  train_idx = train_mask.nonzero(as_tuple=True)
157
+ hidden_positions = torch.clamp(train_idx[1] - 1, min=0)
158
+ local_hidden_states = outputs.last_hidden_state[train_idx[0], hidden_positions, :].reshape(
159
  -1, 1, self.config.local_config.hidden_size
160
  )
161
  local_labels = audio_labels[train_mask]
 
170
  )
171
  loss = local_outputs.loss
172
 
173
+ output = MossTTSRealtimeOutputWithPast(
174
  loss=loss,
175
  logits=None,
176
  past_key_values=outputs.past_key_values,
 
182
  local_hidden_states=local_outputs.hidden_states if local_outputs is not None else None,
183
  local_attentions=local_outputs.attentions if local_outputs is not None else None,
184
  )
185
+ if not return_dict:
186
+ return output.to_tuple()
187
+ return output
188
 
189
 
190
+ __all__ = ["MossTTSRealtime", "MossTTSRealtimeConfig", "MossTTSRealtimeOutputWithPast", "MossTTSRealtimePretrainedModel"]
modeling_mossttsrealtime_local.py CHANGED
@@ -11,7 +11,6 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
  """Local transformer used by MossTTSRealtime for RVQ codebook decoding."""
16
 
17
  from __future__ import annotations
@@ -22,7 +21,7 @@ import torch
22
  import torch.nn as nn
23
 
24
  from transformers.activations import ACT2FN
25
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
26
  from transformers.generation import GenerationMixin
27
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
  from transformers.modeling_layers import GradientCheckpointingLayer
@@ -31,9 +30,8 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
31
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
  from transformers.masking_utils import create_causal_mask
33
  from transformers.processing_utils import Unpack
34
- from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
35
  from transformers.loss.loss_utils import ForCausalLMLoss
36
-
37
  from .configuration_mossttsrealtime import MossTTSRealtimeLocalTransformerConfig
38
 
39
  logger = logging.get_logger(__name__)
@@ -221,7 +219,10 @@ class MossTTSRealtimeLocalTransformerDecoderLayer(GradientCheckpointingLayer):
221
 
222
 
223
  class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
 
 
224
  config: MossTTSRealtimeLocalTransformerConfig
 
225
  base_model_prefix = "local_transformer"
226
  supports_gradient_checkpointing = True
227
  _no_split_modules = ["MossTTSRealtimeLocalTransformerDecoderLayer"]
@@ -231,6 +232,7 @@ class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
231
  _supports_flash_attn = True
232
  _can_compile_fullgraph = True
233
  _supports_attention_backend = True
 
234
  _can_record_outputs = {
235
  "hidden_states": MossTTSRealtimeLocalTransformerDecoderLayer,
236
  "attentions": MossTTSRealtimeLocalTransformerAttention,
@@ -297,11 +299,12 @@ class MossTTSRealtimeLocalTransformer(MossTTSRealtimeLocalTransformerPreTrainedM
297
  if position_ids is not None and not torch.compiler.is_compiling():
298
  position_ids = None
299
 
300
- if (input_ids is None) ^ (inputs_embeds is not None):
301
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
302
 
303
  if use_cache and past_key_values is None:
304
- past_key_values = StaticCache(config=self.config, max_cache_len=16, device=inputs_embeds.device)
 
305
 
306
  if cache_position is None:
307
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -311,17 +314,32 @@ class MossTTSRealtimeLocalTransformer(MossTTSRealtimeLocalTransformerPreTrainedM
311
 
312
  if inputs_embeds is None:
313
  if codebook_idx is not None:
 
 
 
 
314
  if input_ids.ndim == 1:
315
  input_ids = input_ids.unsqueeze(1)
316
  token_emb = self.embed_tokens[codebook_idx - 1](input_ids[:, 0]).unsqueeze(1) # [B,1,H]
317
  inputs_embeds = token_emb
318
  else:
319
- codebook_idxs = torch.clamp(cache_position - 1, min=0)
320
- inputs_embeds = self.embed_tokens[codebook_idxs - 1](input_ids)
321
-
322
- input_ids_are_first_codebook = cache_position[0] == 0
 
 
 
 
 
 
 
 
 
 
 
323
  if backbone_last_hidden_state is not None:
324
- inputs_embeds[:, 0] = backbone_last_hidden_state
325
  else:
326
  if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
327
  logger.warning(
@@ -414,8 +432,14 @@ class MossTTSRealtimeLocalTransformerForCausalLM(MossTTSRealtimeLocalTransformer
414
  hs = hidden_states[:, slice_indices, :]
415
 
416
  if cache_position is not None:
 
 
417
  logits = self.local_lm_heads[codebook_idx](hs[:, 0, :]).unsqueeze(1)
418
  else:
 
 
 
 
419
  logits_list = []
420
  for i in range(hs.shape[1]):
421
  logits_list.append(self.local_lm_heads[i](hs[:, i, :]))
@@ -434,9 +458,6 @@ class MossTTSRealtimeLocalTransformerForCausalLM(MossTTSRealtimeLocalTransformer
434
  attentions=outputs.attentions,
435
  )
436
 
437
-
438
-
439
-
440
  __all__ = [
441
  "MossTTSRealtimeLocalTransformer",
442
  "MossTTSRealtimeLocalTransformerAttention",
@@ -446,4 +467,4 @@ __all__ = [
446
  "MossTTSRealtimeLocalTransformerPreTrainedModel",
447
  "MossTTSRealtimeLocalTransformerRMSNorm",
448
  "MossTTSRealtimeLocalTransformerRotaryEmbedding",
449
- ]
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  """Local transformer used by MossTTSRealtime for RVQ codebook decoding."""
15
 
16
  from __future__ import annotations
 
21
  import torch.nn as nn
22
 
23
  from transformers.activations import ACT2FN
24
+ from transformers.cache_utils import Cache, StaticCache
25
  from transformers.generation import GenerationMixin
26
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
  from transformers.modeling_layers import GradientCheckpointingLayer
 
30
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
  from transformers.masking_utils import create_causal_mask
32
  from transformers.processing_utils import Unpack
 
33
  from transformers.loss.loss_utils import ForCausalLMLoss
34
+ from transformers.utils import TransformersKwargs, logging
35
  from .configuration_mossttsrealtime import MossTTSRealtimeLocalTransformerConfig
36
 
37
  logger = logging.get_logger(__name__)
 
219
 
220
 
221
  class MossTTSRealtimeLocalTransformerPreTrainedModel(PreTrainedModel):
222
+
223
+ config_class = MossTTSRealtimeLocalTransformerConfig
224
  config: MossTTSRealtimeLocalTransformerConfig
225
+
226
  base_model_prefix = "local_transformer"
227
  supports_gradient_checkpointing = True
228
  _no_split_modules = ["MossTTSRealtimeLocalTransformerDecoderLayer"]
 
232
  _supports_flash_attn = True
233
  _can_compile_fullgraph = True
234
  _supports_attention_backend = True
235
+
236
  _can_record_outputs = {
237
  "hidden_states": MossTTSRealtimeLocalTransformerDecoderLayer,
238
  "attentions": MossTTSRealtimeLocalTransformerAttention,
 
299
  if position_ids is not None and not torch.compiler.is_compiling():
300
  position_ids = None
301
 
302
+ if (input_ids is None) == (inputs_embeds is None):
303
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
304
 
305
  if use_cache and past_key_values is None:
306
+ device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
307
+ past_key_values = StaticCache(config=self.config, max_cache_len=self.config.rvq, device=device)
308
 
309
  if cache_position is None:
310
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
314
 
315
  if inputs_embeds is None:
316
  if codebook_idx is not None:
317
+ if codebook_idx <= 0:
318
+ raise ValueError(f"`codebook_idx` must be in [1, {len(self.embed_tokens)}], got {codebook_idx}.")
319
+ if codebook_idx > len(self.embed_tokens):
320
+ raise ValueError(f"`codebook_idx` must be in [1, {len(self.embed_tokens)}], got {codebook_idx}.")
321
  if input_ids.ndim == 1:
322
  input_ids = input_ids.unsqueeze(1)
323
  token_emb = self.embed_tokens[codebook_idx - 1](input_ids[:, 0]).unsqueeze(1) # [B,1,H]
324
  inputs_embeds = token_emb
325
  else:
326
+ if input_ids.shape[1] != cache_position.shape[0]:
327
+ raise ValueError(
328
+ "`input_ids` and `cache_position` must align in sequence length: "
329
+ f"got {input_ids.shape[1]} and {cache_position.shape[0]}."
330
+ )
331
+ codebook_idxs = torch.clamp(cache_position - 1, min=0, max=len(self.embed_tokens) - 1)
332
+ inputs_embeds = torch.stack(
333
+ [
334
+ self.embed_tokens[codebook_idx](input_ids[:, seq_idx])
335
+ for seq_idx, codebook_idx in enumerate(codebook_idxs.tolist())
336
+ ],
337
+ dim=1,
338
+ )
339
+
340
+ input_ids_are_first_codebook = bool(cache_position[0] == 0)
341
  if backbone_last_hidden_state is not None:
342
+ inputs_embeds[:, 0, :] = backbone_last_hidden_state[:, 0, :]
343
  else:
344
  if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
345
  logger.warning(
 
432
  hs = hidden_states[:, slice_indices, :]
433
 
434
  if cache_position is not None:
435
+ if codebook_idx is None:
436
+ raise ValueError("`codebook_idx` must be provided when `cache_position` is provided.")
437
  logits = self.local_lm_heads[codebook_idx](hs[:, 0, :]).unsqueeze(1)
438
  else:
439
+ if hs.shape[1] > len(self.local_lm_heads):
440
+ raise ValueError(
441
+ f"Cannot project {hs.shape[1]} codebooks with only {len(self.local_lm_heads)} LM heads."
442
+ )
443
  logits_list = []
444
  for i in range(hs.shape[1]):
445
  logits_list.append(self.local_lm_heads[i](hs[:, i, :]))
 
458
  attentions=outputs.attentions,
459
  )
460
 
 
 
 
461
  __all__ = [
462
  "MossTTSRealtimeLocalTransformer",
463
  "MossTTSRealtimeLocalTransformerAttention",
 
467
  "MossTTSRealtimeLocalTransformerPreTrainedModel",
468
  "MossTTSRealtimeLocalTransformerRMSNorm",
469
  "MossTTSRealtimeLocalTransformerRotaryEmbedding",
470
+ ]
processing_mossttsrealtime.py CHANGED
@@ -11,7 +11,6 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
  """Processing utilities for MossTTSRealtime."""
16
 
17
  from __future__ import annotations
@@ -20,14 +19,19 @@ from typing import Iterable, Optional
20
 
21
  import numpy as np
22
 
 
 
23
 
24
- class MossTTSRealtimeProcessor:
25
  """Builds MossTTSRealtime prompt inputs with text and audio codebooks.
26
 
27
  This processor focuses on preparing the mixed text/audio token layout expected by MossTTSRealtime.
28
  It does not perform audio encoding/decoding by itself.
29
  """
30
 
 
 
 
31
  def __init__(
32
  self,
33
  tokenizer,
@@ -40,7 +44,9 @@ class MossTTSRealtimeProcessor:
40
  audio_eos_token: int = 1026,
41
  delay_tokens_len: int = 12,
42
  ):
43
- self.tokenizer = tokenizer
 
 
44
  self.channels = channels
45
  self.audio_channel_pad = audio_channel_pad
46
  self.audio_bos_token = audio_bos_token
@@ -58,7 +64,7 @@ class MossTTSRealtimeProcessor:
58
  "capabilities, allowing you to generate the corresponding speech based on the text given in the assistant."
59
  "<|im_end|>\n"
60
  )
61
- self.ttsbase_system_prompt = tts_system_prompt
62
 
63
  def _convert_token_to_id(self, token: str) -> int:
64
  if hasattr(self.tokenizer, "convert_tokens_to_ids"):
@@ -73,7 +79,7 @@ class MossTTSRealtimeProcessor:
73
  return int(token_ids[0])
74
 
75
  def make_voice_clone_prompt(self, prompt_audio_tokens_len: int) -> str:
76
- padded_audio_prompt = f"{'<|audio_pad|>' * prompt_audio_tokens_len}"
77
  voice_clone = (
78
  "<|im_start|>context\n"
79
  "The assistant section should be synthesized using the following voice timbre:"
@@ -85,6 +91,7 @@ class MossTTSRealtimeProcessor:
85
  tokens = np.array(audio_tokens)
86
  if tokens.ndim != 2:
87
  raise ValueError(f"Expected 2D audio tokens, got shape {tokens.shape}")
 
88
  if tokens.shape[0] == self.channels:
89
  tokens = tokens.T
90
  elif tokens.shape[1] == self.channels:
@@ -101,9 +108,9 @@ class MossTTSRealtimeProcessor:
101
  if prompt_audio_tokens is not None:
102
  prompt_audio_tokens = self._normalize_audio_tokens(prompt_audio_tokens)
103
  prompt_audio_tokens = prompt_audio_tokens[:, : self.channels]
104
- system_prompt_text = f"{self.ttsbase_system_prompt}" + f"{self.make_voice_clone_prompt(prompt_audio_tokens.shape[0])}"
105
  else:
106
- system_prompt_text = f"{self.ttsbase_system_prompt}"
107
 
108
  system_prompt_tokens = self.tokenizer(system_prompt_text)["input_ids"]
109
  system_prompt_tokens_full = np.full(
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  """Processing utilities for MossTTSRealtime."""
15
 
16
  from __future__ import annotations
 
19
 
20
  import numpy as np
21
 
22
+ from transformers.processing_utils import ProcessorMixin
23
+
24
 
25
+ class MossTTSRealtimeProcessor(ProcessorMixin):
26
  """Builds MossTTSRealtime prompt inputs with text and audio codebooks.
27
 
28
  This processor focuses on preparing the mixed text/audio token layout expected by MossTTSRealtime.
29
  It does not perform audio encoding/decoding by itself.
30
  """
31
 
32
+ attributes = ["tokenizer"]
33
+ tokenizer_class = "AutoTokenizer"
34
+
35
  def __init__(
36
  self,
37
  tokenizer,
 
44
  audio_eos_token: int = 1026,
45
  delay_tokens_len: int = 12,
46
  ):
47
+ super().__init__(tokenizer=tokenizer)
48
+ self.audio_pad_token = audio_pad_token
49
+ self.text_pad_token = text_pad_token
50
  self.channels = channels
51
  self.audio_channel_pad = audio_channel_pad
52
  self.audio_bos_token = audio_bos_token
 
64
  "capabilities, allowing you to generate the corresponding speech based on the text given in the assistant."
65
  "<|im_end|>\n"
66
  )
67
+ self.tts_system_prompt = tts_system_prompt
68
 
69
  def _convert_token_to_id(self, token: str) -> int:
70
  if hasattr(self.tokenizer, "convert_tokens_to_ids"):
 
79
  return int(token_ids[0])
80
 
81
  def make_voice_clone_prompt(self, prompt_audio_tokens_len: int) -> str:
82
+ padded_audio_prompt = f"{self.audio_pad_token * prompt_audio_tokens_len}"
83
  voice_clone = (
84
  "<|im_start|>context\n"
85
  "The assistant section should be synthesized using the following voice timbre:"
 
91
  tokens = np.array(audio_tokens)
92
  if tokens.ndim != 2:
93
  raise ValueError(f"Expected 2D audio tokens, got shape {tokens.shape}")
94
+ # Accept [channels, T] or [T, channels], and slice to expected channels if needed.
95
  if tokens.shape[0] == self.channels:
96
  tokens = tokens.T
97
  elif tokens.shape[1] == self.channels:
 
108
  if prompt_audio_tokens is not None:
109
  prompt_audio_tokens = self._normalize_audio_tokens(prompt_audio_tokens)
110
  prompt_audio_tokens = prompt_audio_tokens[:, : self.channels]
111
+ system_prompt_text = f"{self.tts_system_prompt}" + f"{self.make_voice_clone_prompt(prompt_audio_tokens.shape[0])}"
112
  else:
113
+ system_prompt_text = f"{self.tts_system_prompt}"
114
 
115
  system_prompt_tokens = self.tokenizer(system_prompt_text)["input_ids"]
116
  system_prompt_tokens_full = np.full(
streaming_mossttsrealtime.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -11,22 +11,25 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
  """Streaming inference utilities for MossTTSRealtime."""
16
 
17
  from __future__ import annotations
18
 
19
-
20
- import re
21
- import numpy as np
22
  import contextlib
 
 
23
 
 
24
  import torch
25
  import torch.nn.functional as F
26
- import torchaudio
27
  from transformers.cache_utils import StaticCache
 
28
  from transformers.utils.import_utils import requires
29
- from typing import Iterable, Iterator, List, Optional, Sequence
 
 
 
30
 
31
  @requires(backends=("torch",))
32
  class MossTTSRealtimeInference:
@@ -74,11 +77,11 @@ class MossTTSRealtimeInference:
74
  return self._is_stopping is not None and bool(self._is_stopping.all())
75
 
76
  def reset_generation_state(self, keep_cache: bool = True):
77
- # When keep_cache=True, retain the attention_mask so that its length matches past_key_values.
78
- # This is used for concatenation in the next prefill step.
79
  if not keep_cache:
80
  self.past_key_values = None
81
  self.attention_mask = None
 
 
82
  self._generated_tokens = []
83
  self._is_stopping = None
84
  self._last_audio_tokens = None
@@ -172,6 +175,7 @@ class MossTTSRealtimeInference:
172
  current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
173
  current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)
174
 
 
175
  if self.attention_mask is not None and self.past_key_values is not None:
176
  current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)
177
 
@@ -321,7 +325,7 @@ class MossTTSRealtimeInference:
321
  for i in range(self.channels):
322
  cache_pos_t.fill_(i)
323
 
324
- local_outputs = self.model.local_transformer(
325
  input_ids=local_token,
326
  inputs_embeds=local_inputs,
327
  past_key_values=past_key_values,
@@ -335,7 +339,7 @@ class MossTTSRealtimeInference:
335
  if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
336
  logits = self.apply_repetition_penalty(
337
  scores=logits,
338
- history_tokens=generated_tokens[:, :gen_step, i],
339
  penalty=float(repetition_penalty),
340
  repetition_window=repetition_window,
341
  )
@@ -355,22 +359,22 @@ class MossTTSRealtimeInference:
355
 
356
  def apply_repetition_penalty(
357
  self,
358
- scores: torch.Tensor,
359
  history_tokens: torch.Tensor,
360
  penalty: float = 1.1,
361
  repetition_window: Optional[int] = None,
362
  ):
363
  scores_ = scores[:, 0, :]
364
- B, V = scores_.shape
365
  ht = history_tokens
366
 
367
  if repetition_window is not None and repetition_window > 0:
368
- ht = ht[:, -repetition_window:]
369
 
370
  ht_sorted, _ = torch.sort(ht, dim=1)
371
  uniq = torch.unique_consecutive(ht_sorted, dim=1)
372
 
373
- b_idx = torch.arange(B, device=uniq.device).unsqueeze(1).expand_as(uniq)
374
  b_flat = b_idx.reshape(-1)
375
  t_flat = uniq.reshape(-1)
376
 
@@ -430,9 +434,9 @@ class MossTTSRealtimeStreamingSession:
430
  """Manage text-to-audio streaming for a single conversation."""
431
 
432
  _split_pattern = re.compile(
433
- r"[。!?!?\.\u2026]\s*"
434
- r"|[,,;;::\u2014\u2013\-]\s*"
435
- r"|\)\s*|\]\s*"
436
  r"|\n"
437
  )
438
 
@@ -504,6 +508,7 @@ class MossTTSRealtimeStreamingSession:
504
 
505
  waveform = audio
506
  if isinstance(audio, (str, bytes)):
 
507
  wav, sr = torchaudio.load(audio)
508
  if wav.shape[0] > 1:
509
  wav = wav.mean(dim=0, keepdim=True)
@@ -516,6 +521,7 @@ class MossTTSRealtimeStreamingSession:
516
  raise ValueError("Unsupported audio type for voice prompt.")
517
 
518
  if sample_rate is not None and sample_rate != self.codec_sample_rate:
 
519
  waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)
520
 
521
  waveform = waveform.to(self.inferencer.device)
@@ -839,17 +845,19 @@ class TextDeltaTokenizer:
839
  return list(self._all_ids)
840
 
841
  def push_delta(self, delta: str) -> list[int]:
 
842
  if not delta:
843
  return []
844
  self._text += str(delta)
845
  self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
846
- # hold_back token 不输出(尾部可能随后续 delta 而改变)
847
  stable_count = max(self._emitted_count, len(self._all_ids) - self.hold_back)
848
  new_ids = self._all_ids[self._emitted_count : stable_count]
849
  self._emitted_count = stable_count
850
  return new_ids
851
 
852
  def flush(self) -> list[int]:
 
853
  self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
854
  remaining = self._all_ids[self._emitted_count :]
855
  self._emitted_count = len(self._all_ids)
@@ -862,6 +870,7 @@ def _sanitize_audio_tokens(
862
  codebook_size: int,
863
  audio_eos_token: int,
864
  ) -> tuple[torch.Tensor, bool]:
 
865
  if tokens.dim() == 1:
866
  tokens = tokens.unsqueeze(0)
867
  if tokens.numel() == 0:
@@ -935,12 +944,14 @@ class MossTTSRealtimeTextStreamBridge:
935
  yield from self._decode_audio_frames(audio_frames)
936
 
937
  def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
 
938
  if not token_ids:
939
  return
940
  audio_frames = self.session.push_text_tokens(token_ids)
941
  yield from self._decode_audio_frames(audio_frames)
942
 
943
  def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
 
944
  audio_frames = self.session.end_text()
945
  yield from self._decode_audio_frames(audio_frames)
946
 
@@ -957,7 +968,7 @@ class MossTTSRealtimeTextStreamBridge:
957
  yield final.detach().cpu()
958
 
959
  def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
960
- """一口气消费一个 delta 迭代器,并持续 yield wav chunk。"""
961
  with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
962
  for delta in deltas:
963
  yield from self.push_text_delta(delta)
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  """Streaming inference utilities for MossTTSRealtime."""
15
 
16
  from __future__ import annotations
17
 
 
 
 
18
  import contextlib
19
+ import re
20
+ from typing import Iterable, Iterator, List, Optional, Sequence
21
 
22
+ import numpy as np
23
  import torch
24
  import torch.nn.functional as F
25
+
26
  from transformers.cache_utils import StaticCache
27
+ from transformers.utils import is_torchaudio_available, requires_backends
28
  from transformers.utils.import_utils import requires
29
+
30
+ if is_torchaudio_available():
31
+ import torchaudio
32
+
33
 
34
  @requires(backends=("torch",))
35
  class MossTTSRealtimeInference:
 
77
  return self._is_stopping is not None and bool(self._is_stopping.all())
78
 
79
  def reset_generation_state(self, keep_cache: bool = True):
 
 
80
  if not keep_cache:
81
  self.past_key_values = None
82
  self.attention_mask = None
83
+ # Keep the mask when reusing cache so it stays aligned with past_key_values.
84
+ # This allows concatenation with the next turn prefill mask.
85
  self._generated_tokens = []
86
  self._is_stopping = None
87
  self._last_audio_tokens = None
 
175
  current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
176
  current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)
177
 
178
+ # For multi-turn continuation, concatenate the cached mask and the current prefill mask.
179
  if self.attention_mask is not None and self.past_key_values is not None:
180
  current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)
181
 
 
325
  for i in range(self.channels):
326
  cache_pos_t.fill_(i)
327
 
328
+ local_outputs = self.model.local_transformer(
329
  input_ids=local_token,
330
  inputs_embeds=local_inputs,
331
  past_key_values=past_key_values,
 
339
  if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
340
  logits = self.apply_repetition_penalty(
341
  scores=logits,
342
+ history_tokens=generated_tokens[:, :gen_step, i],
343
  penalty=float(repetition_penalty),
344
  repetition_window=repetition_window,
345
  )
 
359
 
360
  def apply_repetition_penalty(
361
  self,
362
+ scores: torch.Tensor,
363
  history_tokens: torch.Tensor,
364
  penalty: float = 1.1,
365
  repetition_window: Optional[int] = None,
366
  ):
367
  scores_ = scores[:, 0, :]
368
+ batch_size = scores_.shape[0]
369
  ht = history_tokens
370
 
371
  if repetition_window is not None and repetition_window > 0:
372
+ ht = ht[:, -repetition_window:]
373
 
374
  ht_sorted, _ = torch.sort(ht, dim=1)
375
  uniq = torch.unique_consecutive(ht_sorted, dim=1)
376
 
377
+ b_idx = torch.arange(batch_size, device=uniq.device).unsqueeze(1).expand_as(uniq)
378
  b_flat = b_idx.reshape(-1)
379
  t_flat = uniq.reshape(-1)
380
 
 
434
  """Manage text-to-audio streaming for a single conversation."""
435
 
436
  _split_pattern = re.compile(
437
+ r"[。!?!?\.\u2026]\s*" # sentence boundaries: 。!? ! ? . …
438
+ r"|[,,;;::\u2014\u2013\-]\s*" # short pauses: , , ; ; : : — – -
439
+ r"|\)\s*|\]\s*" # closing brackets: ) ]
440
  r"|\n"
441
  )
442
 
 
508
 
509
  waveform = audio
510
  if isinstance(audio, (str, bytes)):
511
+ requires_backends(self, ["torchaudio"])
512
  wav, sr = torchaudio.load(audio)
513
  if wav.shape[0] > 1:
514
  wav = wav.mean(dim=0, keepdim=True)
 
521
  raise ValueError("Unsupported audio type for voice prompt.")
522
 
523
  if sample_rate is not None and sample_rate != self.codec_sample_rate:
524
+ requires_backends(self, ["torchaudio"])
525
  waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)
526
 
527
  waveform = waveform.to(self.inferencer.device)
 
845
  return list(self._all_ids)
846
 
847
  def push_delta(self, delta: str) -> list[int]:
848
+ """Append a text delta and return newly stable token ids (may be empty)."""
849
  if not delta:
850
  return []
851
  self._text += str(delta)
852
  self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
853
+ # Keep the tail un-emitted because the latest tokens can still change.
854
  stable_count = max(self._emitted_count, len(self._all_ids) - self.hold_back)
855
  new_ids = self._all_ids[self._emitted_count : stable_count]
856
  self._emitted_count = stable_count
857
  return new_ids
858
 
859
  def flush(self) -> list[int]:
860
+ """Emit all remaining token ids at end of stream."""
861
  self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
862
  remaining = self._all_ids[self._emitted_count :]
863
  self._emitted_count = len(self._all_ids)
 
870
  codebook_size: int,
871
  audio_eos_token: int,
872
  ) -> tuple[torch.Tensor, bool]:
873
+ """Trim rows after EOS/invalid tokens and return whether decoding should stop."""
874
  if tokens.dim() == 1:
875
  tokens = tokens.unsqueeze(0)
876
  if tokens.numel() == 0:
 
944
  yield from self._decode_audio_frames(audio_frames)
945
 
946
  def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
947
+ """Push token ids directly (for sources that stream token ids)."""
948
  if not token_ids:
949
  return
950
  audio_frames = self.session.push_text_tokens(token_ids)
951
  yield from self._decode_audio_frames(audio_frames)
952
 
953
  def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
954
+ """Mark text stream end and emit all remaining audio chunks (including flush)."""
955
  audio_frames = self.session.end_text()
956
  yield from self._decode_audio_frames(audio_frames)
957
 
 
968
  yield final.detach().cpu()
969
 
970
  def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
971
+ """Consume a full delta iterator and continuously yield waveform chunks."""
972
  with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
973
  for delta in deltas:
974
  yield from self.push_text_delta(delta)