Li-Ruixiao commited on
Commit
e4aa3d2
·
1 Parent(s): 87cc7fb

Add initial implementation of MossTTSDelay model, configuration, and processing utilities

Browse files
__init__.py ADDED
File without changes
assets/prompt1.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2839f28ad240479cb73f5b374bed748608f2d6639e45efe3f2727c3aaecdde22
3
- size 232236
 
 
 
 
assets/prompt2.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d924b8fb75e1ad7bddd207dd4f9ba71fc867422b95f569ac6aa2262b7695cc14
3
- size 174512
 
 
 
 
assets/ref1.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:32e84754bbe11be6082baba99236217e238d3dbfb97ff2545a3e675e031e5fdd
3
- size 140692
 
 
 
 
assets/ref2.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e5112b5e2bef2a727534af85da1e56048a5ab5552de7aa7cbb5f48b0fa4f5eec
3
- size 448172
 
 
 
 
config.json ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "moss_tts_delay",
3
+ "architectures": [
4
+ "MossTTSDelayModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_moss_tts.MossTTSDelayConfig",
8
+ "AutoModel": "modeling_moss_tts.MossTTSDelayModel"
9
+ },
10
+ "dtype": "bfloat16",
11
+ "initializer_range": 0.02,
12
+ "language_config": {
13
+ "_name_or_path": "Qwen/Qwen3-8B",
14
+ "architectures": [
15
+ "Qwen3ForCausalLM"
16
+ ],
17
+ "attention_bias": false,
18
+ "attention_dropout": 0.0,
19
+ "bos_token_id": 151643,
20
+ "eos_token_id": 151645,
21
+ "pad_token_id": 151643,
22
+ "head_dim": 128,
23
+ "hidden_act": "silu",
24
+ "hidden_size": 4096,
25
+ "initializer_range": 0.02,
26
+ "intermediate_size": 12288,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention",
61
+ "full_attention",
62
+ "full_attention",
63
+ "full_attention"
64
+ ],
65
+ "max_position_embeddings": 40960,
66
+ "max_window_layers": 36,
67
+ "model_type": "qwen3",
68
+ "num_attention_heads": 32,
69
+ "num_hidden_layers": 36,
70
+ "num_key_value_heads": 8,
71
+ "rms_norm_eps": 1e-06,
72
+ "rope_scaling": null,
73
+ "rope_theta": 1000000,
74
+ "sliding_window": null,
75
+ "use_cache": true,
76
+ "use_sliding_window": false,
77
+ "vocab_size": 155648
78
+ },
79
+ "n_vq": 32,
80
+ "audio_vocab_size": 1024,
81
+ "audio_user_slot_token_id": 151654,
82
+ "audio_assistant_gen_slot_token_id": 151656,
83
+ "audio_assistant_delay_slot_token_id": 151662,
84
+ "audio_start_token_id": 151652,
85
+ "audio_end_token_id": 151653,
86
+ "audio_pad_code": 1024,
87
+ "sampling_rate": 24000,
88
+ "transformers_version": "4.57.1"
89
+ }
configuration_moss_tts.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MossTTSDelay model configuration """
16
+
17
+ from typing import Optional, Union
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+ from transformers.models.qwen3 import Qwen3Config
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class MossTTSDelayConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate an
    MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MossTTSDelay [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        language_config (`Union[Qwen3Config, dict]`, *optional*):
            Configuration for the backbone language model (Qwen3).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        n_vq (`int`, *optional*, defaults to 32):
            Number of additional VQ (Vector Quantization) heads/channels for audio.
            Determines the number of codebooks used in the audio representation.
        pad_token_id (`int`, *optional*, defaults to 151643):
            Padding token id, forwarded to [`PretrainedConfig`].
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Token id marking the start of a chat turn.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Token id marking the end of a chat turn.
        audio_vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size for the audio tokens (codebooks 1 to N).
        audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
            The specific token ID used as a placeholder/slot for user-side audio inputs in the prompt.
        audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
            The specific token ID representing the generation slot for the assistant's audio output.
            Acting as the trigger for the TTS generation process.
        audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
            The token ID used in the 'Delay Pattern' paradigm to represent the delayed/offset positions
            between different VQ channels.
        audio_start_token_id (`int`, *optional*, defaults to 151652):
            Special token ID used to denote the start of an audio sequence in the stream.
        audio_end_token_id (`int`, *optional*, defaults to 151653):
            Special token ID used to denote the end of an audio sequence (EOS for audio).
        audio_pad_code (`int`, *optional*, defaults to 1024):
            The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
        sampling_rate (`int`, *optional*, defaults to 24000):
            Audio sampling rate in Hz.
    """
    model_type = "moss_tts_delay"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        language_config: Optional[Union[Qwen3Config, dict]] = None,
        initializer_range: float = 0.02,
        n_vq: int = 32,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_vocab_size: int = 1024,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_gen_slot_token_id: int = 151656,
        audio_assistant_delay_slot_token_id: int = 151662,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_pad_code: int = 1024,
        sampling_rate: int = 24000,
        **kwargs,
    ):
        # Accept a ready Qwen3Config, a plain dict (e.g. loaded from config.json),
        # or None (fall back to Qwen3 defaults).
        if isinstance(language_config, dict):
            self.language_config = Qwen3Config(**language_config)
        elif language_config is None:
            self.language_config = Qwen3Config()
        else:
            self.language_config = language_config

        self.initializer_range = initializer_range
        self.n_vq = n_vq
        self.audio_vocab_size = audio_vocab_size
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
        self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_pad_code = audio_pad_code
        self.sampling_rate = sampling_rate

        # Mirror backbone dimensions at top level for convenience.
        self.hidden_size = self.language_config.hidden_size
        self.vocab_size = self.language_config.vocab_size
        # BUGFIX: the original also assigned `self.im_start_token_id = self.language_config`
        # here (the whole config object), which was dead code immediately overwritten below.
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id

        # BUGFIX: pass pad_token_id through to PretrainedConfig.__init__. The parent
        # re-reads `pad_token_id` from kwargs (None when absent), so setting the
        # attribute before calling super() silently discarded the explicit default.
        # `pad_token_id` is a named parameter above, so it can never also appear in
        # **kwargs — no duplicate-keyword risk.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

    def to_dict(self):
        """Serialize to a dict, expanding the nested backbone config."""
        output = super().to_dict()
        if hasattr(self.language_config, "to_dict"):
            output["language_config"] = self.language_config.to_dict()
        else:
            # Already a plain dict (e.g. round-tripped through JSON).
            output["language_config"] = self.language_config
        return output
inference_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import torch.nn.functional as F
4
+ from typing import Optional, List, Tuple
5
+ from tqdm import tqdm
6
+
7
+
8
def apply_top_k(logits, top_k):
    """Keep only the `top_k` highest logits per row; all others become -inf.

    Args:
        logits: [batch, vocab] tensor (not modified in place).
        top_k: number of entries to keep (clamped to the vocab size).
    """
    k = min(top_k, logits.size(-1))
    kept_values, kept_indices = torch.topk(logits, k, dim=-1)
    # Start from an all -inf tensor and scatter the surviving logits back in.
    masked = torch.full_like(logits, float("-inf"))
    return masked.scatter_(-1, kept_indices, kept_values)
16
+
17
+
18
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering: mask tokens outside the top-p probability mass.

    Operates row by row on a [batch, vocab] tensor; the input is left untouched
    and a filtered clone is returned.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)

    # Tokens whose cumulative mass exceeds top_p are dropped; shift right so
    # the first token crossing the threshold is always kept.
    drop = cum_probs > top_p
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False

    filtered = logits.clone()
    for row in range(logits.shape[0]):
        filtered[row, sorted_idx[row][drop[row]]] = float("-inf")
    return filtered
31
+
32
+
33
def apply_top_p_optimized(logits, top_p):
    """Vectorized nucleus (top-p) filtering.

    Same semantics as `apply_top_p` but without the per-row Python loop.

    BUGFIX: the original wrote -inf into `logits` in place, silently mutating the
    caller's tensor (inconsistent with `apply_top_p`, which clones). We now use
    `masked_fill`, which returns a new tensor and leaves the input untouched.

    Args:
        logits: [batch, vocab] tensor (not modified).
        top_p: cumulative probability threshold in (0, 1].
    Returns:
        Filtered logits with out-of-nucleus entries set to -inf.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Shift right so the first token crossing the threshold is always kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the removal mask from sorted order back to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    return logits.masked_fill(indices_to_remove, float("-inf"))
49
+
50
+
51
def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    Apply the repetition penalty independently for each H (VQ head): logits of
    previously emitted tokens are divided by `penalty` when positive and
    multiplied by it when non-positive (the standard CTRL-style rule).

    NOTE: modifies `logits` in place and also returns it.
    """
    # No-op fast path.
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Case 1: regular [N, V] (text layer)
    if logits.dim() == 2:
        unique_tokens = torch.unique(prev_tokens.reshape(-1))

        # Advanced indexing yields a copy: penalize the copy, then write it back.
        token_logits = logits[:, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, unique_tokens] = token_logits
        return logits

    # Case 2: Delay Pattern audio [B, H, V]
    assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
    n_heads = logits.shape[1]

    for h in range(n_heads):
        # History for head h only; works for [B, T, H] and [B, H] alike.
        unique_tokens = torch.unique(prev_tokens[..., h].reshape(-1))

        if unique_tokens.numel() == 0:
            continue

        token_logits = logits[:, h, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, h, unique_tokens] = token_logits

    return logits
98
+
99
+
100
def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Pick next tokens from (possibly multi-head) logits.

    `logits` may be [N, V] or [B, H, V]; the returned index tensor has the same
    shape minus the vocab dimension. Repetition penalty is applied before any
    flattening so each Delay-Pattern head is penalized independently; top-k /
    top-p filtering and multinomial sampling operate on the flattened rows.
    """
    # Penalize repeats first, while the head structure is still visible.
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(logits, prev_tokens, repetition_penalty)

    # Greedy path: no filtering required.
    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # Flatten to [rows, vocab] for filtering and sampling.
    lead_shape = logits.shape[:-1]
    flat = logits.view(-1, logits.size(-1))

    if top_k is not None and top_k > 0:
        flat = apply_top_k(flat, top_k)
    if top_p is not None and top_p < 1.0:
        flat = apply_top_p_optimized(flat, top_p)

    picks = torch.multinomial(F.softmax(flat, dim=-1), num_samples=1)
    return picks.view(lead_shape)
135
+
136
+
137
def find_last_equal_C(tensor, C):
    """Per row, return the index of the last element equal to C (-1 if absent).

    Args:
        tensor: [batch_size, seq_len] tensor.
        C: scalar value to search for.
    Returns:
        [batch_size] LongTensor of last-match indices, -1 where C never occurs.
    """
    batch_size, seq_len = tensor.shape

    # Cast the match mask to int so argmax is well-defined on it.
    hits = (tensor == C).int()
    # First match in the reversed sequence == last match in the original.
    first_from_right = hits.flip(dims=[1]).argmax(dim=1)
    last_indices = (seq_len - 1) - first_from_right

    # argmax returns 0 on all-zero rows, so verify the hit and mark misses with -1.
    rows = torch.arange(batch_size)
    last_indices[tensor[rows, last_indices] != C] = -1

    return last_indices
modeling_moss_tts.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Modeling classes for MossTTSDelay. """
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+ from tqdm import tqdm
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import CrossEntropyLoss
24
+
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import ModelOutput
27
+ from transformers.utils import (
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.cache_utils import Cache
34
+ from transformers.models.qwen3 import Qwen3Model
35
+ from transformers import initialization as init
36
+
37
+ from .configuration_moss_tts import MossTTSDelayConfig
38
+ from .inference_utils import sample_token, find_last_equal_C
39
+
40
+ try:
41
+ from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
42
+ except Exception:
43
+ UserMessage = None
44
+ AssistantMessage = None
45
+ MossTTSDelayProcessor = None
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CONFIG_FOR_DOC = "MossTTSDelayConfig"
50
+
51
+
52
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Weighted sum of channel losses (scalar training objective).
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Number of non-masked (label != -100) tokens per sample and channel.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample; only populated when `channelwise_loss_weight` is given.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + vq heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores, one tensor per head (text head first, then audio heads).
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of torch.FloatTensor (one for each layer) of the attention weights.
    """
    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
87
+
88
+
89
+
90
+
91
class MossTTSDelayPreTrainedModel(PreTrainedModel):
    """Base class hooking MossTTSDelay into the HF loading / init machinery."""

    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Transformers 5.0+ safe init:
        - MUST use transformers.initialization helpers
        - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
        """
        # Let HF handle its standard modules first (LayerNorm, Linear, Embedding, etc.).
        super()._init_weights(module)

        # Resolve std per HF convention: model-level initializer_range, then the
        # backbone's, then 0.02. (getattr chain replaces the original hasattr
        # ladder, whose first branch was always taken anyway.)
        std = getattr(self.config, "initializer_range", None)
        if std is None:
            std = getattr(
                getattr(self.config, "language_config", None), "initializer_range", 0.02
            )

        # Re-initialize only our extra audio embeddings, identified by their
        # table size of audio_vocab_size + 1 (avoids double-touching the LM's
        # own embeddings).
        if isinstance(module, nn.Embedding):
            if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
                init.normal_(module.weight, mean=0.0, std=std)
                # If a padding_idx is ever set, its row must be zeroed explicitly
                # (and param._is_hf_initialized respected when slicing).

        # NOTE: the original had a dead `isinstance(module, nn.Linear): pass`
        # branch here; `super()._init_weights` already covers the lm_heads
        # Linear layers, so it was removed.
135
+
136
+
137
+
138
# Shared docstring header injected into the model classes below via the
# `add_start_docstrings` / `add_start_docstrings_to_model_forward` decorators.
MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
153
+
154
+
155
+ @add_start_docstrings(
156
+ "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
157
+ MOSSTTS_START_DOCSTRING,
158
+ )
159
+ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
160
+ UserMessage = UserMessage
161
+ AssistantMessage = AssistantMessage
162
+ Processor = MossTTSDelayProcessor
163
+
164
+ def __init__(self, config: MossTTSDelayConfig):
165
+ super().__init__(config)
166
+ self.config = config
167
+
168
+ config.language_config.torch_dtype = config.torch_dtype
169
+
170
+ self.language_model = Qwen3Model(config.language_config)
171
+
172
+ # Audio VQ Embeddings (Extra channels)
173
+ # Note: input_ids[..., 0] uses Qwen's embedding.
174
+ # input_ids[..., 1:] use these extensions.
175
+ self.emb_ext = nn.ModuleList()
176
+ for vq_idx in range(self.config.n_vq):
177
+ # Add +1 for potential padding/special tokens logic if strictly required by upstream data prep
178
+ self.emb_ext.append(
179
+ nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
180
+ )
181
+
182
+ # Multi-Head Prediction Layers
183
+ # Head 0: Main language head
184
+ # Head 1..N: Audio VQ heads
185
+ self.lm_heads = nn.ModuleList([
186
+ nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
187
+ ])
188
+ for vq_idx in range(self.config.n_vq):
189
+ self.lm_heads.append(
190
+ nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
191
+ )
192
+
193
+ # Initialize weights and apply final processing
194
+ self.post_init()
195
+
196
+ def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
197
+ """
198
+ Computes the combined embeddings from text and multiple audio VQ channels.
199
+
200
+ Args:
201
+ input_ids: Shape (Batch, Seq_Len, 1 + n_vq)
202
+ """
203
+ # Base Text/Content Embedding
204
+ # input_ids[..., 0] is standard text or semantic tokens
205
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])
206
+
207
+ # Add VQ Embeddings
208
+ for i, embed_layer in enumerate(self.emb_ext):
209
+ # i corresponds to channel i+1 in input_ids
210
+ # We assume the data pipeline ensures indices are within range
211
+ inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])
212
+
213
+ return inputs_embeds
214
+
215
    def set_input_embeddings(self, value):
        # Replace the backbone's text embedding table (used by HF utilities
        # such as resize_token_embeddings).
        self.language_model.embed_tokens = value
217
+
218
    def get_output_embeddings(self):
        """Return the full ModuleList of heads (text head + n_vq audio heads)."""
        # Returning a list of heads might break some HF utilities expecting a single head.
        # However, for custom models, this is acceptable.
        return self.lm_heads
222
+
223
    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the masked language modeling loss. Positions
                with value -100 are ignored (standard CrossEntropyLoss masking).
            hidden_out_layers (`List[int]`, *optional*):
                Indices into `outputs.hidden_states` to feed each head; defaults
                to the last layer for all heads.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across different heads (Text vs Audio channels).

        Returns:
        """

        if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
            raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # 1. Prepare Embeddings (sum of text-channel and VQ-channel embeddings)
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # 2. Backbone Forward
        # Qwen3Model outputs standard CausalLMOutputWithPast or similar
        outputs = self.language_model(
            input_ids=None,  # Passed via inputs_embeds
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Always need hidden states for multi-head projection
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        # 3. Handle specific layer outputs if requested (Delay Pattern often requires features from specific layers)
        last_hidden_state = outputs.last_hidden_state
        if hidden_out_layers is None:
            # Default to using the last layer for all heads
            # In some architectures (like MusicGen), different codebooks come from different transformer layers.
            # Here we default to the final layer as per original code behavior [-1] * (n + 1).
            hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
        else:
            # If hidden_out_layers is provided (e.g. [-1, -2, -3...]), fetch them from all_hidden_states
            # Note: outputs.hidden_states includes embedding output at index 0 usually.
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        # 4. Project to Logits (Multi-Head)
        layer_logits = []
        for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            # Original code logic: Mask the last token index for audio heads (indices > 0)
            # This implies the vocab size is (N+1) but the model shouldn't predict the (N+1)-th token
            # (perhaps reserved for padding in the input but invalid for prediction).
            if i > 0:
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        # 5. Loss Calculation
        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:
            # Ensure labels match input shape rank (B, S, C)
            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")

            batch_size = labels.size(0)
            n_heads = len(layer_logits)

            # Container for per-sample, per-channel losses
            # Shape: [Batch, n_heads]
            all_sum_losses_list = []

            # Count valid tokens (not -100) per sample.
            # Note: Assuming mask is consistent across channels or we take sum over dim 1 (seq)
            # Usually strict masking means checking one channel or all.
            # Original code: torch.sum(labels != -100, dim=1) -> [B, C]
            all_token_nums = torch.sum(labels != -100, dim=1)

            for i, logits in enumerate(layer_logits):
                # logits: [B, S, V]
                # cur_labels: [B, S]
                cur_labels = labels[..., i]

                # Flatten for CrossEntropy
                # logits: [B*S, V], labels: [B*S]
                # reduction='none' keeps per-token losses so we can aggregate
                # per-sample and per-channel separately below.
                loss_fct = CrossEntropyLoss(reduction='none')
                vocab_size = logits.size(-1)

                reshaped_logits = logits.view(-1, vocab_size)
                reshaped_labels = cur_labels.contiguous().view(-1)

                # Calculate loss per token (masked positions contribute 0)
                per_token_loss = loss_fct(reshaped_logits, reshaped_labels)

                # Reshape back to [B, S] and sum over Sequence dimension to get per-sample loss
                per_token_loss = per_token_loss.view(batch_size, -1)
                per_sample_loss = torch.sum(per_token_loss, dim=-1)  # [B]

                all_sum_losses_list.append(per_sample_loss)

            # Stack to [B, n_heads]
            all_sum_losses = torch.stack(all_sum_losses_list, dim=1)

            # Weighted Loss Aggregation
            if channelwise_loss_weight is not None:
                if len(channelwise_loss_weight) != n_heads:
                    raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

                w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)

                # Sample losses: Weighted sum over channels per sample / Total weight
                # Normalize by token count per channel
                # Avoid division by zero with epsilon or mask
                token_counts_safe = all_token_nums.float().clamp(min=1.0)

                normalized_losses = all_sum_losses / token_counts_safe
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                # Channel losses: Sum over batch / Sum tokens over batch
                total_loss_per_channel = all_sum_losses.sum(dim=0)
                total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
                channel_losses = total_loss_per_channel / total_tokens_per_channel

                # Final scalar loss
                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                # Default average if no weights provided
                # NOTE(review): sample_losses stays None on this path — confirm intended.
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens
                channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)

        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
391
+
392
    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 1000,
        text_temperature: float = 1.5,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 1.5,
        audio_top_p: float = 0.8,
        audio_top_k: int = 50,
        audio_repetition_penalty: float = 1.0,
    ):
        """Autoregressively generate interleaved text/audio tokens with the delay pattern.

        Args:
            input_ids: [batch, seq, n_vq + 1] unified codes; channel 0 is the text
                channel, channels 1..n_vq are delayed audio codebooks.
            attention_mask: [batch, seq] mask matching `input_ids`.
            max_new_tokens: maximum number of decoding steps.
            text_temperature / text_top_p / text_top_k: sampling controls for the
                text channel; a temperature <= 0 switches to greedy decoding.
            audio_temperature / audio_top_p / audio_top_k / audio_repetition_penalty:
                sampling controls for the audio channels.

        Returns:
            list of (start_length, generation_ids) per sample: `generation_ids`
            starts 3 tokens after the last `<|im_start|>` (presumably skipping the
            role header — confirm against the chat template), `start_length` is
            how many of those positions came from the prompt.
        """
        # Temperature <= 0 means greedy decoding; reset to 1 so the division
        # below becomes a no-op.
        if text_temperature > 0:
            text_do_sample = True
        else:
            text_temperature = 1
            text_do_sample = False
        if audio_temperature > 0:
            audio_do_sample = True
        else:
            audio_temperature = 1
            audio_do_sample = False

        past_key_values = None
        device = input_ids.device
        current_input_ids = input_ids
        current_attention_mask = attention_mask
        batch_size, seq_len, n_vq = input_ids.shape
        # Channel 0 is text; only the remaining channels are audio codebooks.
        n_vq -= 1

        generation_ids = input_ids[:]
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Three per-sample phases: 1. not inside audio; 2. inside audio, delay
        # flush not started; 3. flushing the delay tail.
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)  # 0 marks phase 1
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)  # int64 max marks "delay flush not started" (phase 2)

        # Handle continuation where <audio_start> already sits inside input_ids.
        # NOTE: inputs that have already begun the delay flush are not supported.
        # Both continuation and from-scratch generation are handled here.
        is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
        audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Outside an audio span: ban pad / slot / audio_end tokens on the text
        # channel. Inside: only the two assistant slot tokens are allowed.
        pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
        pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False


        # `time_step` is not necessarily the output position in the dialogue,
        # because generation may be continuing an existing prefix.
        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)]  # List, len=n_vq+1, [batch_size, 1, vocab_size];
            next_token_logits[0] = next_token_logits[0].clone()
            # 1. Handle the text channel first.
            next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
            # audio_start, each gen-slot token and the FIRST delay-slot token are
            # sampled; subsequent delay-slot tokens and audio_end are forced.
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = self.config.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
            next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
            next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
            if time_step == 0:
                # NOTE(review): 151662 is assumed to be a tokenizer-specific special
                # token that must not open the reply — confirm against the vocab.
                next_token_logits[0][..., 151662] = float('-inf')
            if time_step <= n_vq:
                # Too early to close the turn: the delay tail alone takes n_vq steps.
                next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')

            # No repetition penalty is applied on the text channel.
            next_text_token[sampling_text_mask] = sample_token(
                logits=next_token_logits[0][sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample
            )
            is_audio[next_text_token == self.config.audio_start_token_id] = True
            # The only stop condition is next_text_token == <|im_end|>.
            is_stopping[next_text_token == self.config.im_end_token_id] = True

            # 2. Then handle the audio channels.
            # Everything outside audio_start..audio_end stays padded; only the
            # positions that carry real audio content are filled below.
            next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)

            # The gate depends on the distance from audio_start (delay pattern).
            # True in the mask means the channel carries a real value this step.
            pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
            post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
            post_audio_mask[delayed_lengths == torch_int64_max] = True
            sampling_audio_mask = pre_audio_mask & post_audio_mask
            next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code

            if sampling_audio_mask.sum() > 0:
                audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask]  # torch.stack -> [batch_size, n_vq - 1, vocab_size]
                audio_logits[..., self.config.audio_pad_code] = float('-inf')
                next_audio_tokens[sampling_audio_mask] = sample_token(
                    logits=audio_logits,
                    prev_tokens=generation_ids[:, :, 1:],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )

            # The updates below produce the audio_lengths / delayed_lengths state
            # that is valid at the NEXT time step.
            audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
            audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
            delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
            delayed_lengths[delayed_lengths != torch_int64_max] += 1
            delayed_lengths[delayed_lengths > n_vq] = torch_int64_max

            current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)  # [batch_size, 1, n_vq + 1]
            current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
            generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)  # [batch_size, seq_len, n_vq + 1]

            if is_stopping.sum() == batch_size:
                break

        # Slice each sample from 3 tokens past its last <|im_start|> and report
        # how much of the slice was prompt (start_length) vs generated.
        start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        output = []
        for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
            output.append((start_length, cur_generation_ids[start_idx:]))

        return output
processing_moss_tts.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Dict, List, Optional, Tuple, Type, Union, Literal, Final
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ import re
21
+ import torchaudio
22
+
23
+ import torch
24
+ from transformers import PreTrainedTokenizerBase, BatchFeature, ProcessorMixin, logging, AutoConfig, AutoModel, AutoTokenizer
25
+
26
+ from .configuration_moss_tts import MossTTSDelayConfig
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
AUDIO_PLACEHOLDER = "<|audio|>"


@dataclass
class Message:
    """Marker base class for chat messages consumed by the processor."""
    pass



@dataclass
class UserMessage(Message):
    """User turn holding TTS control fields, rendered into a <user_inst> block.

    Fields left as None are rendered literally as the string "None".
    `reference`, when given, is a per-speaker list whose entries may be wav
    paths or pre-encoded code tensors; a None entry means "no reference for
    that speaker index".
    """
    text: Optional[str] = None
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None

    def __post_init__(self):
        template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
{tokens}
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Text:
{text}
</user_inst>"""

        referenced_audio = []
        if self.reference is None:
            reference_text = "None"
        elif isinstance(self.reference, List):
            speaker_lines = []
            for speaker_idx, speaker_reference in enumerate(self.reference):
                if speaker_reference is not None:
                    speaker_lines.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
            reference_text = "\n".join(speaker_lines)
            referenced_audio = [entry for entry in self.reference if entry is not None]
        else:
            raise TypeError("`reference` should be exactly a list when it is not None.")

        # Substitute placeholders in template order. Plain str.replace (not
        # str.format) is used so braces inside user-supplied text stay intact.
        substitutions = (
            ("{reference}", str(reference_text)),
            ("{instruction}", str(self.instruction)),
            ("{tokens}", str(self.tokens)),
            ("{quality}", str(self.quality)),
            ("{sound_event}", str(self.sound_event)),
            ("{ambient_sound}", str(self.ambient_sound)),
            ("{language}", str(self.language)),
            ("{text}", str(self.text)),
        )
        content = template
        for placeholder, value in substitutions:
            content = content.replace(placeholder, value)

        self._content = content
        self._audio_codes_list = referenced_audio

    def to_dict(self):
        """Return the canonical message-dict form consumed by the processor."""
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list
        }
106
+
107
+
108
@dataclass
class AssistantMessage(Message):
    """Assistant turn: audio codes plus text content with <|audio|> placeholders.

    `content` defaults to the bare placeholder, i.e. a pure-audio reply.
    """
    # Entries are wav paths (str) or pre-encoded code tensors — one per
    # <|audio|> placeholder occurring in `content`.
    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        """Return the canonical message-dict form consumed by the processor."""
        return {
            "role": "assistant",
            "content": self.content,
            "audio_codes_list": self.audio_codes_list
        }
119
+
120
# Keys copied from a plain user-message dict when reconstructing a UserMessage
# in MossTTSDelayProcessor._normalize_message.
USER_MESSAGE_FIELDS = (
    "text",
    "reference",
    "instruction",
    "tokens",
    "quality",
    "sound_event",
    "ambient_sound",
    "language",
)
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+ class MossTTSDelayProcessor(ProcessorMixin):
138
+ tokenizer_class = "AutoTokenizer"
139
+ audio_tokenizer_class = "AutoModel"
140
+
141
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        audio_tokenizer: AutoModel = None,
        model_config: Optional[MossTTSDelayConfig] = None,
        **kwargs
    ):
        """Couple the text tokenizer, the audio codec model, and the model config.

        Args:
            tokenizer: text tokenizer providing the chat/special tokens.
            audio_tokenizer: codec model exposing `encode`/`decode` for waveforms.
            model_config: source of the special-token ids and audio constants;
                a default `MossTTSDelayConfig` is created when omitted.
        """
        super().__init__(
            tokenizer=tokenizer,
            audio_tokenizer=audio_tokenizer,
            **kwargs
        )
        if model_config is None:
            model_config = MossTTSDelayConfig()
        self.model_config = model_config

        # Cache frequently used special-token ids.
        self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # NOTE(review): 198 is assumed to be the "\n" token id of this tokenizer — confirm.
        self.newline_token_id = 198

        # String forms of the audio control tokens, used to build prompt text.
        self.audio_user_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_user_slot_token_id)
        self.audio_assistant_gen_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_gen_slot_token_id)
        self.audio_assistant_delay_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_delay_slot_token_id)
        self.audio_start_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_start_token_id)
        self.audio_end_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_end_token_id)
165
+ self.audio_end_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_end_token_id)
166
+
167
+ @classmethod
168
+ def from_pretrained(cls, pretrained_model_name_or_path, trust_remote_code=True, **kwargs):
169
+ kwargs.pop("_from_auto")
170
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
171
+ model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
172
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
173
+
174
+ audio_tokenizer_name_or_path = kwargs.pop("codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer")
175
+ assert isinstance(audio_tokenizer_name_or_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_name_or_path)}"
176
+ audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
177
+
178
+ return cls(
179
+ tokenizer=tokenizer,
180
+ audio_tokenizer=audio_tokenizer,
181
+ model_config=model_config,
182
+ **kwargs
183
+ )
184
+
185
+ def __call__(
186
+ self,
187
+ conversations: Union[Message, Dict, List[Message], List[Dict], List[List[Message]], List[List[Dict]]],
188
+ mode: str = "generation",
189
+ apply_chat_template: bool = True,
190
+ n_vq: Optional[int] = None
191
+ ) -> BatchFeature:
192
+
193
+ """
194
+ mode 只会在将 Message 转换为 to_dict 时起作用;
195
+ """
196
+
197
+ if mode not in {"generation", "continuation"}:
198
+ raise RuntimeError
199
+
200
+ if isinstance(conversations, (Message, Dict)):
201
+ conversations = [conversations]
202
+
203
+ truncation = False
204
+ if mode == "continuation":
205
+ truncation = True
206
+
207
+ input_ids_list = []
208
+ for conversation in conversations:
209
+ if isinstance(conversation, (Message, Dict)):
210
+ conversation = [conversation]
211
+
212
+ if (mode == "generation") ^ (len(conversation) % 2 != 0):
213
+ raise ValueError
214
+
215
+ if (mode == "generation") ^ (conversation[-1]['role'] == "user"):
216
+ raise ValueError
217
+
218
+ unified_codes = []
219
+ for message_idx, message in enumerate(conversation):
220
+ message = self._normalize_message(message)
221
+ if apply_chat_template:
222
+ add_generation_prompt = mode == "generation" and message_idx == len(conversation) - 1
223
+ try:
224
+ content = self.tokenizer.apply_chat_template(
225
+ [{"role": message["role"], "content": message["content"]}],
226
+ add_generation_prompt=add_generation_prompt,
227
+ tokenize=False,
228
+ )
229
+ except TypeError:
230
+ try:
231
+ content = self.tokenizer.apply_chat_template(
232
+ [{"role": message["role"], "content": message["content"]}],
233
+ add_generation_prompt=add_generation_prompt,
234
+ )
235
+ except Exception:
236
+ logger.warning("apply_chat_template failed; fallback to raw content.")
237
+ content = message["content"]
238
+ else:
239
+ content = message['content']
240
+
241
+ audio_codes_list = []
242
+ for audio_codes in message["audio_codes_list"]:
243
+ if isinstance(audio_codes, torch.Tensor):
244
+ if n_vq is not None and audio_codes.shape[1] != n_vq:
245
+ raise RuntimeError("audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs.")
246
+ else:
247
+ audio_codes = self.encode_audios_from_path(audio_codes, n_vq)[0]
248
+ audio_codes_list.append(audio_codes)
249
+ unified_codes.append(self._get_unified_codes(message['role'], content, audio_codes_list, truncation))
250
+
251
+ unified_codes = torch.cat(unified_codes)
252
+ input_ids_list.append(unified_codes)
253
+
254
+ return self._pad(input_ids_list)
255
+
256
+ @staticmethod
257
+ def build_user_message(
258
+ text: Optional[str] = None,
259
+ reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
260
+ instruction: Optional[str] = None,
261
+ tokens: Optional[int] = None,
262
+ quality: Optional[str] = None,
263
+ sound_event: Optional[str] = None,
264
+ ambient_sound: Optional[str] = None,
265
+ language: Optional[str] = None,
266
+ ) -> Dict:
267
+ if reference is not None and not isinstance(reference, list):
268
+ reference = [reference]
269
+ return UserMessage(
270
+ text=text,
271
+ reference=reference,
272
+ instruction=instruction,
273
+ tokens=tokens,
274
+ quality=quality,
275
+ sound_event=sound_event,
276
+ ambient_sound=ambient_sound,
277
+ language=language,
278
+ ).to_dict()
279
+
280
    @staticmethod
    def build_assistant_message(
        audio_codes_list: List[Union[str, torch.Tensor]],
        content: str = AUDIO_PLACEHOLDER,
    ) -> Dict:
        """Build the canonical dict form of an assistant turn.

        Args:
            audio_codes_list: wav paths or pre-encoded code tensors, one entry
                per <|audio|> placeholder in `content`.
            content: message text; defaults to a single bare audio placeholder.
        """
        return AssistantMessage(
            audio_codes_list=audio_codes_list,
            content=content,
        ).to_dict()
289
+
290
+ def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
291
+ if isinstance(message, Message):
292
+ return message.to_dict()
293
+ if not isinstance(message, dict):
294
+ raise TypeError("Each message must be a Message or dict.")
295
+ if "role" not in message:
296
+ raise ValueError("Message dict must include a 'role' field.")
297
+ if "content" in message and "audio_codes_list" in message:
298
+ return message
299
+ role = message["role"]
300
+ if role == "user":
301
+ kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
302
+ return self.build_user_message(**kwargs)
303
+ if role == "assistant":
304
+ return self.build_assistant_message(
305
+ audio_codes_list=message.get("audio_codes_list", []),
306
+ content=message.get("content", AUDIO_PLACEHOLDER),
307
+ )
308
+ raise ValueError(f"Unsupported role: {role}")
309
+
310
+ def _pad(self, input_ids_list: List[torch.Tensor]):
311
+ device = input_ids_list[0].device
312
+ lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
313
+ pad_input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=self.model_config.audio_pad_code, padding_side="left")
314
+ other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(1) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
315
+ pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
316
+ attention_mask = torch.zeros(pad_input_ids.shape[0], pad_input_ids.shape[1], device=device)
317
+ attention_mask[~other_channel_mask] = 1
318
+ attention_mask = attention_mask.bool()
319
+ return {
320
+ "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
321
+ "attention_mask": attention_mask,
322
+ }
323
+
324
+ @staticmethod
325
+ def _replace_audio_placeholders(
326
+ content: str,
327
+ lengths: List[int],
328
+ n_vq: int,
329
+ gen_slot_token: str,
330
+ delay_slot_token: str,
331
+ audio_start_token: str,
332
+ audio_end_token: str
333
+ ) -> str:
334
+ if n_vq < 1:
335
+ raise ValueError(f"n_vq must be >= 1, got {n_vq}")
336
+
337
+ num_placeholders = content.count(AUDIO_PLACEHOLDER)
338
+ if num_placeholders != len(lengths):
339
+ raise ValueError(
340
+ f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
341
+ f"does not match lengths ({len(lengths)})"
342
+ )
343
+
344
+ def build_audio_block(length: int) -> str:
345
+ if length < 0:
346
+ raise ValueError(f"length must be >= 0, got {length}")
347
+
348
+ if length == 0:
349
+ return f"{audio_start_token}{audio_end_token}"
350
+
351
+ step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
352
+ return f"{audio_start_token}{step_tokens}{audio_end_token}"
353
+
354
+ lengths_iter = iter(lengths)
355
+
356
+ def replacer(match: re.Match) -> str:
357
+ length = next(lengths_iter)
358
+ return build_audio_block(length)
359
+
360
+ result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
361
+
362
+ return result
363
+
364
    @staticmethod
    def _merge_consecutive_audio_placeholders(
        content: str,
        audio_codes_list: List[torch.Tensor],
    ) -> Tuple[str, List[torch.Tensor]]:
        """Collapse whitespace-separated runs of <|audio|> into one placeholder.

        The code tensors belonging to a collapsed run are concatenated along the
        time dimension (dim 0), so the placeholder count still matches the list
        afterwards.

        Raises:
            ValueError: if the placeholder count differs from len(audio_codes_list).
        """
        matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
        if len(matches) <= 1:
            return content, audio_codes_list

        if len(matches) != len(audio_codes_list):
            raise ValueError("Audio placeholders do not match the provided audio codes list.")

        new_audio_codes_list = []
        new_parts = []
        last_pos = 0
        i = 0
        while i < len(matches):
            # Grow j while the next placeholder is separated only by whitespace.
            j = i
            while (
                j + 1 < len(matches)
                and content[matches[j].end():matches[j + 1].start()].strip() == ""
            ):
                j += 1

            new_parts.append(content[last_pos:matches[i].start()])
            new_parts.append(AUDIO_PLACEHOLDER)
            last_pos = matches[j].end()

            if j == i:
                new_audio_codes_list.append(audio_codes_list[i])
            else:
                # Merge all codes of the run along time.
                new_audio_codes_list.append(torch.cat(audio_codes_list[i:j + 1], dim=0))

            i = j + 1

        new_parts.append(content[last_pos:])
        return "".join(new_parts), new_audio_codes_list
401
+
402
+ @staticmethod
403
+ def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
404
+ delayed_tokens = torch.full(
405
+ (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
406
+ pad_code,
407
+ device=codes.device,
408
+ dtype=codes.dtype,
409
+ )
410
+ for i in range(codes.shape[1]):
411
+ delayed_tokens[i: i + codes.shape[0], i] = codes[:, i]
412
+ return delayed_tokens
413
+
414
+ @staticmethod
415
+ def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
416
+ tokens = torch.full(
417
+ (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
418
+ 0,
419
+ device=delay_codes.device,
420
+ dtype=delay_codes.dtype,
421
+ )
422
+ for i in range(delay_codes.shape[1]):
423
+ tokens[:, i] = delay_codes[i: i + tokens.shape[0], i]
424
+ return tokens
425
+
426
+
427
    def _get_unified_codes(self, role: str, content: str, audio_codes_list: List[Union[str, torch.Tensor]], truncation: bool) -> torch.Tensor:
        """Align text codes with delayed audio codes into unified [T, n_vq + 1] codes.

        `content` already carries the chat-template formatting at this point.

        Args:
            role: "user" or assistant; selects which slot tokens fill the audio span.
            content: templated text still containing <|audio|> placeholders.
            audio_codes_list: one [T_i, n_vq] tensor per placeholder.
            truncation: drop the trailing n_vq - 1 delay rows of the last audio
                block (continuation mode, generation resumes mid-stream).
        """
        if role == "user":
            audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
        else:
            audio_gen_slot_token = self.audio_assistant_gen_slot_token
            audio_delay_slot_token = self.audio_assistant_delay_slot_token

        # Infer the codebook count from the provided codes; fall back to config.
        if len(audio_codes_list):
            n_vq = audio_codes_list[0].shape[1]
        else:
            n_vq = self.model_config.n_vq

        if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
            content, audio_codes_list = self._merge_consecutive_audio_placeholders(
                content, audio_codes_list
            )
        content = self._replace_audio_placeholders(
            content=content,
            lengths=[len(audio_codes) for audio_codes in audio_codes_list],
            n_vq=n_vq,
            gen_slot_token=audio_gen_slot_token,
            delay_slot_token=audio_delay_slot_token,
            audio_start_token=self.audio_start_token,
            audio_end_token=self.audio_end_token,
        )
        text_codes = torch.tensor(self.tokenizer.encode(content), device=audio_codes_list[0].device if audio_codes_list else None)

        # Locate each audio block's boundaries in the encoded text channel.
        audio_start_indices = torch.where(text_codes == self.model_config.audio_start_token_id)[0]
        audio_end_indices = torch.where(text_codes == self.model_config.audio_end_token_id)[0]
        if len(audio_start_indices) != len(audio_codes_list) or len(audio_end_indices) != len(audio_codes_list):
            raise ValueError("Audio placeholders do not match the provided audio codes list.")

        delay_audio_codes_list = []
        if len(audio_codes_list) == 0:
            # Pure-text message: every audio channel is padding.
            delay_audio_codes_list = torch.full(
                (len(text_codes), n_vq),
                self.model_config.audio_pad_code,
                device=text_codes.device,
                dtype=text_codes.dtype,
            )
        else:
            prefix_idx = 0
            for audio_start_idx, audio_end_idx, audio_codes in zip(audio_start_indices, audio_end_indices, audio_codes_list):
                delay_audio_codes = self.apply_delay_pattern(audio_codes, self.model_config.audio_pad_code)
                # Padding rows covering the text between the previous block's end
                # and this block's <audio_start> (inclusive).
                pad_codes = torch.full(
                    (audio_start_idx - prefix_idx + 1, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes.device,
                    dtype=audio_codes.dtype,
                )
                delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
                prefix_idx = audio_end_idx

            if truncation:
                # Continuation: cut the delay tail of the final block so
                # generation can resume inside the audio stream.
                delay_audio_codes_list[-1] = delay_audio_codes_list[-1][:-(n_vq - 1), :]
            else:
                # Padding rows from the last <audio_end> to the end of the text.
                pad_codes = torch.full(
                    (len(text_codes) - audio_end_indices[-1], n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes_list[0].device,
                    dtype=audio_codes_list[0].dtype,
                )
                delay_audio_codes_list.append(pad_codes)

            delay_audio_codes_list = torch.cat(delay_audio_codes_list)

        # In truncation mode the audio rows end early; trim the text channel to match.
        if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
            text_codes = text_codes[:delay_audio_codes_list.shape[0]]

        unified_codes = torch.cat([text_codes.unsqueeze(1), delay_audio_codes_list], dim=1)
        return unified_codes
501
+
502
+ def _parse_text_codes(self, start_length, text_codes):
503
+ text = self.tokenizer.decode(text_codes)
504
+ prefix = self.tokenizer.decode(text_codes[:start_length])
505
+ text = text[len(prefix):]
506
+
507
+ AUDIO_PATTERN = re.compile(
508
+ rf'(?:{self.audio_start_token})?'
509
+ rf'(?:{self.audio_assistant_gen_slot_token})*'
510
+ rf'(?:{self.audio_assistant_delay_slot_token})*'
511
+ rf'{self.audio_end_token}'
512
+ )
513
+
514
+ def normalize_audio_segments(text: str) -> str:
515
+ def repl(match: re.Match) -> str:
516
+ seg = match.group(0)
517
+ # 如果片段内包含至少一个 gen_slot,则替换为 <|audio|>
518
+ if self.audio_assistant_gen_slot_token in seg:
519
+ return AUDIO_PLACEHOLDER
520
+ # 否则直接删除
521
+ return ""
522
+
523
+ return AUDIO_PATTERN.sub(repl, text)
524
+
525
+ return normalize_audio_segments(text)
526
+
527
    def _parse_audio_codes(self, start_length, audio_codes):
        """Recover decoded audio segments from the delayed audio channels.

        Args:
            start_length: prompt length (in token steps) used to trim the prompt's
                share out of the first decoded segment.
            audio_codes: [T, n_vq] delayed audio codes (text channel removed).

        Returns:
            A list of decoded waveforms, one per contiguous non-pad segment
            (possibly empty).
        """
        # De-delay back to [T', n_vq]
        audio_codes = self.apply_de_delay_pattern(audio_codes)

        # Rows that are all pad are separators between real audio segments.
        is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
        non_pad = ~is_pad
        if not non_pad.any():
            return []

        # Split the non-pad row indices wherever they stop being consecutive.
        idx = torch.nonzero(non_pad).squeeze(1)
        breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
        if breaks.numel() == 0:
            segments_idx = [idx]
        else:
            segments_idx = torch.split(idx, breaks.tolist())

        audio_codes_list = [audio_codes[s] for s in segments_idx]

        decoded_audio_list = []
        for segment_codes in audio_codes_list:
            decoded_segment = self.decode_audio_codes([segment_codes])
            if len(decoded_segment) > 0:
                decoded_audio_list.append(decoded_segment[0])

        # Keep codec causal context by decoding the whole first segment first,
        # then trim at waveform level according to start_length ratio.
        if start_length > 0 and len(audio_codes_list) > 0 and len(decoded_audio_list) > 0:
            first_codes_length = audio_codes_list[0].shape[0]
            if first_codes_length > 0:
                trim_ratio = max(0.0, min(float(start_length) / float(first_codes_length), 1.0))
                first_audio = decoded_audio_list[0]
                if trim_ratio >= 1.0:
                    # Entire first segment belonged to the prompt: drop it.
                    decoded_audio_list = decoded_audio_list[1:]
                elif trim_ratio > 0.0:
                    trim_samples = int(first_audio.shape[-1] * trim_ratio)
                    decoded_audio_list[0] = first_audio[..., trim_samples:]

        return decoded_audio_list
566
+
567
+
568
+ def decode(self, output: List[Tuple[int, torch.Tensor]]):
569
+ """
570
+ 1. 这里不管怎样,都需要一个完整的 assistant generation ids;
571
+ 2. 支持从任意位置进行截断;
572
+ """
573
+
574
+ genearted_messages = []
575
+ for start_length, generation_ids in output:
576
+ content = self._parse_text_codes(start_length, generation_ids[:, 0])
577
+ audio_codes_list = self._parse_audio_codes(start_length, generation_ids[:, 1:])
578
+ if content == "":
579
+ message = None
580
+ else:
581
+ message = AssistantMessage(
582
+ content=content,
583
+ audio_codes_list=audio_codes_list
584
+ )
585
+ genearted_messages.append(message)
586
+ return genearted_messages
587
+
588
+ @staticmethod
589
+ def loudness_normalize(wav: torch.Tensor, target_dbfs: float = -20, gain_range: tuple[float, float] = (-3.0, 3.0)) -> torch.Tensor:
590
+ wav = wav.to(torch.float32)
591
+ if wav.numel() == 0: return wav
592
+ rms = torch.sqrt(torch.mean(wav ** 2))
593
+ current_dbfs = 20.0 * torch.log10(rms + 1e-9)
594
+ gain = float(target_dbfs - current_dbfs)
595
+ gain = max(gain_range[0], min(gain, gain_range[1]))
596
+ factor = 10.0 ** (gain / 20.0)
597
+ return wav * factor
598
+
599
+ def encode_audios_from_wav(self, wav_list: List[torch.Tensor], sampling_rate: int, n_vq: int = None):
600
+ if isinstance(wav_list, torch.Tensor):
601
+ wav_list = [wav_list]
602
+ wav_list_ = []
603
+ resample = False
604
+ if sampling_rate != self.model_config.sampling_rate:
605
+ resample = True
606
+ for wav in wav_list:
607
+ if wav.shape[0] > 1:
608
+ wav = torch.mean(wav, dim=0, keepdim=True)
609
+ if resample:
610
+ wav = torchaudio.functional.resample(waveform=wav, orig_freq=sampling_rate, new_freq=self.model_config.sampling_rate)
611
+ wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
612
+ return self.audio_tokenizer.encode(wav_list_, n_vq)
613
+
614
+ def encode_audios_from_path(self, wav_path_list: List[str], n_vq: int = None):
615
+ if isinstance(wav_path_list, str):
616
+ wav_path_list = [wav_path_list]
617
+ wav_list = []
618
+ sampling_rate = None
619
+ for wav_path in wav_path_list:
620
+ wav, sr = torchaudio.load(wav_path)
621
+ if sampling_rate is None:
622
+ sampling_rate = sr
623
+ elif sampling_rate != sr:
624
+ raise ValueError("sampling_rate of audios in the same batch should be the same.")
625
+ wav_list.append(wav)
626
+ return self.encode_audios_from_wav(wav_list, sampling_rate, n_vq)
627
+
628
    def decode_audio_codes(self, audio_tokens_list: List[torch.Tensor]):
        """Decode a list of de-delayed [T, n_vq] code tensors to waveforms via the codec."""
        return self.audio_tokenizer.decode(audio_tokens_list)
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "MossTTSDelayProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_tts.MossTTSDelayProcessor"
5
+ }
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|audio_start|>",
12
+ "<|audio_end|>",
13
+ "<|audio_user_slot|>",
14
+ "<|image_pad|>",
15
+ "<|audio_assistant_gen_slot|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }