CaasiHUANG committed on
Commit 4ba9398 · 1 Parent(s): 67b524e

Add model weights
__init__.py ADDED
File without changes
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|audio_end|>": 151653,
+   "<|audio_pad|>": 151654,
+   "<|audio_start|>": 151652,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656
+ }
chat_template.jinja ADDED
@@ -0,0 +1,4 @@
+ {% for message in messages %}<|im_start|>{{ message['role'] }}
+ {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content.get('type') == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}<|im_end|>
+ {% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
+ {% endif %}
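
For reference, a minimal sketch of what this template renders once the tokenizer is loaded (the repo id is taken from the config docstring below; adjust to your checkout):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("OpenMOSS/mosstts-8b", trust_remote_code=True)
messages = [{"role": "user", "content": "Hello"}]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant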
config.json ADDED
@@ -0,0 +1,85 @@
+ {
+   "model_type": "moss_tts_delay",
+   "architectures": [
+     "MossTTSDelayModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_moss_tts.MossTTSDelayConfig",
+     "AutoModel": "modeling_moss_tts.MossTTSDelayModel"
+   },
+   "dtype": "bfloat16",
+   "gen_token_id": 151656,
+   "initializer_range": 0.02,
+   "language_config": {
+     "_name_or_path": "Qwen/Qwen3-1.7B",
+     "architectures": [
+       "Qwen3ForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bos_token_id": 151643,
+     "eos_token_id": 151645,
+     "pad_token_id": 151643,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 2048,
+     "initializer_range": 0.02,
+     "intermediate_size": 6144,
+     "layer_types": [
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention"
+     ],
+     "max_position_embeddings": 40960,
+     "max_window_layers": 28,
+     "model_type": "qwen3",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 28,
+     "num_key_value_heads": 8,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": null,
+     "rope_theta": 1000000,
+     "sliding_window": null,
+     "tie_word_embeddings": true,
+     "use_cache": true,
+     "use_sliding_window": false,
+     "vocab_size": 155648
+   },
+   "n_vq": 16,
+   "audio_ch0_vocab_size": 1024,
+   "audio_token_id": 151654,
+   "audio_vocab_size": 1024,
+   "audio_user_slot_token_id": 151654,
+   "audio_assistant_gen_slot_token_id": 151656,
+   "audio_assistant_delay_slot_token_id": 151662,
+   "audio_start_token_id": 151652,
+   "audio_end_token_id": 151653,
+   "audio_pad_code": 1024,
+   "sampling_rate": 24000,
+   "transformers_version": "4.57.1"
+ }
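
A minimal loading sketch for this checkpoint; the `auto_map` entries above route `AutoConfig`/`AutoModel` to the custom code files in this repo (the local path is illustrative):

import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("./", trust_remote_code=True)
model = AutoModel.from_pretrained("./", trust_remote_code=True, torch_dtype=torch.bfloat16)
print(config.n_vq, config.sampling_rate)  # 16 24000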
configuration_moss_tts.py ADDED
@@ -0,0 +1,114 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """MossTTSDelay model configuration."""
+
+ from typing import Optional, Union
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.models.qwen3 import Qwen3Config
+
+ logger = logging.get_logger(__name__)
+
+
+ class MossTTSDelayConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate a
+     MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a configuration similar to that of the MossTTSDelay [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         language_config (`Union[Qwen3Config, dict]`, *optional*):
+             Configuration for the backbone language model (Qwen3).
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         n_vq (`int`, *optional*, defaults to 32):
+             Number of additional VQ (Vector Quantization) heads/channels for audio.
+             Determines the number of codebooks used in the audio representation.
+         audio_vocab_size (`int`, *optional*, defaults to 1024):
+             Vocabulary size for the audio tokens (codebooks 1 to N).
+         audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
+             The token ID used as a placeholder/slot for user-side audio inputs in the prompt.
+         audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
+             The token ID representing the generation slot for the assistant's audio output,
+             acting as the trigger for the TTS generation process.
+         audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
+             The token ID used in the 'Delay Pattern' paradigm to represent the delayed/offset positions
+             between different VQ channels.
+         audio_start_token_id (`int`, *optional*, defaults to 151652):
+             Special token ID denoting the start of an audio sequence in the stream.
+         audio_end_token_id (`int`, *optional*, defaults to 151653):
+             Special token ID denoting the end of an audio sequence (EOS for audio).
+         audio_pad_code (`int`, *optional*, defaults to 1024):
+             The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
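+
+     Example (a minimal sketch, assuming the module is importable from a local checkout):
+
+     ```python
+     >>> from configuration_moss_tts import MossTTSDelayConfig
+     >>> config = MossTTSDelayConfig()
+     >>> (config.n_vq, config.audio_pad_code)
+     (32, 1024)
+     ```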
+     """
+
+     model_type = "moss_tts_delay"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         language_config: Optional[Union[Qwen3Config, dict]] = None,
+         initializer_range: float = 0.02,
+         n_vq: int = 32,
+         pad_token_id: int = 151643,
+         im_start_token_id: int = 151644,
+         im_end_token_id: int = 151645,
+         audio_vocab_size: int = 1024,
+         audio_user_slot_token_id: int = 151654,
+         audio_assistant_gen_slot_token_id: int = 151656,
+         audio_assistant_delay_slot_token_id: int = 151662,
+         audio_start_token_id: int = 151652,
+         audio_end_token_id: int = 151653,
+         audio_pad_code: int = 1024,
+         sampling_rate: int = 24000,
+         **kwargs,
+     ):
+         if isinstance(language_config, dict):
+             self.language_config = Qwen3Config(**language_config)
+         elif language_config is None:
+             self.language_config = Qwen3Config()
+         else:
+             self.language_config = language_config
+
+         self.initializer_range = initializer_range
+         self.n_vq = n_vq
+         self.audio_vocab_size = audio_vocab_size
+         self.audio_user_slot_token_id = audio_user_slot_token_id
+         self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
+         self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
+         self.audio_start_token_id = audio_start_token_id
+         self.audio_end_token_id = audio_end_token_id
+         self.audio_pad_code = audio_pad_code
+         self.sampling_rate = sampling_rate
+
+         self.hidden_size = self.language_config.hidden_size
+         self.vocab_size = self.language_config.vocab_size
+         self.pad_token_id = pad_token_id
+         self.im_start_token_id = im_start_token_id
+         self.im_end_token_id = im_end_token_id
+
+         super().__init__(**kwargs)
+
+     def to_dict(self):
+         output = super().to_dict()
+         if hasattr(self.language_config, "to_dict"):
+             output["language_config"] = self.language_config.to_dict()
+         else:
+             output["language_config"] = self.language_config
+         return output
inference_utils.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn.functional as F
+ from typing import Optional
+
+
+ def apply_top_k(logits, top_k):
+     """Keep only the `top_k` highest-scoring logits per row; set the rest to -inf."""
+     batch_size, vocab_size = logits.shape
+     top_k = min(top_k, vocab_size)
+     top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
+     filtered_logits = torch.full_like(logits, float("-inf"))
+     batch_indices = torch.arange(batch_size).unsqueeze(-1)
+     filtered_logits[batch_indices, top_k_indices] = top_k_values
+     return filtered_logits
+
+
+ def apply_top_p(logits, top_p):
+     """Reference (loop-based) nucleus filtering; see apply_top_p_optimized for the vectorized version."""
+     probs = F.softmax(logits, dim=-1)
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+     sorted_indices_to_remove = cumulative_probs > top_p
+     # Shift right so the token that crosses the threshold is kept.
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+     batch_size = logits.shape[0]
+     filtered_logits = logits.clone()
+     for i in range(batch_size):
+         indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+         filtered_logits[i, indices_to_remove] = float("-inf")
+     return filtered_logits
+
+
+ def apply_top_p_optimized(logits, top_p):
+     """Vectorized nucleus (top-p) filtering; modifies `logits` in place."""
+     probs = F.softmax(logits, dim=-1)
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+     sorted_indices_to_remove = cumulative_probs > top_p
+     # Shift right so the token that crosses the threshold is kept.
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+
+     # Scatter the removal mask back from sorted order to vocabulary order.
+     indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
+         dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+     )
+
+     logits[indices_to_remove] = float("-inf")
+     return logits
+
+
+ def apply_repetition_penalty_delay_pattern(
+     logits: torch.Tensor,
+     prev_tokens: torch.LongTensor,
+     penalty: float,
+ ):
+     """
+     logits: [B, H, V] or [N, V]
+     prev_tokens: [B, T, H] or [N, T] or [B, H]
+
+     Apply the repetition penalty independently for each H (VQ head).
+     """
+     if penalty == 1.0 or prev_tokens is None:
+         return logits
+
+     # Case 1: regular [N, V] (text layer)
+     if logits.dim() == 2:
+         prev_tokens_flat = prev_tokens.reshape(-1)
+         unique_tokens = torch.unique(prev_tokens_flat)
+
+         token_logits = logits[:, unique_tokens]
+         pos_mask = token_logits > 0
+         token_logits[pos_mask] /= penalty
+         token_logits[~pos_mask] *= penalty
+         logits[:, unique_tokens] = token_logits
+         return logits
+
+     # Case 2: Delay Pattern audio [B, H, V]
+     assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
+     B, H, V = logits.shape
+
+     for h in range(H):
+         # prev_tokens_h: [B, T] or [B]
+         prev_tokens_h = prev_tokens[..., h].reshape(-1)
+         unique_tokens = torch.unique(prev_tokens_h)
+
+         if unique_tokens.numel() == 0:
+             continue
+
+         token_logits = logits[:, h, unique_tokens]
+         pos_mask = token_logits > 0
+         token_logits[pos_mask] /= penalty
+         token_logits[~pos_mask] *= penalty
+         logits[:, h, unique_tokens] = token_logits
+
+     return logits
+
+
+ def sample_token(
+     logits,
+     prev_tokens: Optional[torch.LongTensor] = None,
+     repetition_penalty: float = 1.0,
+     top_p=None,
+     top_k=None,
+     do_sample=True,
+ ):
+     vocab_size = logits.size(-1)
+
+     # ===== Repetition penalty (before reshaping!) =====
+     if prev_tokens is not None and repetition_penalty != 1.0:
+         logits = apply_repetition_penalty_delay_pattern(
+             logits,
+             prev_tokens,
+             repetition_penalty,
+         )
+
+     if not do_sample:
+         return torch.argmax(logits, dim=-1)
+
+     # ===== Only flatten after this, for top-k / top-p / multinomial =====
+     original_shape = logits.shape
+     reshaped_logits = logits.view(-1, vocab_size)
+
+     if top_k is not None and top_k > 0:
+         reshaped_logits = apply_top_k(reshaped_logits, top_k)
+
+     if top_p is not None and top_p < 1.0:
+         reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)
+
+     probs = F.softmax(reshaped_logits, dim=-1)
+     next_tokens = torch.multinomial(probs, num_samples=1)
+
+     return next_tokens.view(original_shape[:-1])
+
+
+ def find_last_equal_C(tensor, C):
+     """
+     tensor: torch.Tensor of shape [batch_size, seq_len]
+     C: scalar value to match
+     Returns: torch.Tensor of shape [batch_size] with the last index equal to C,
+     or -1 for rows with no match.
+     """
+     mask = (tensor == C).int()  # [batch_size, seq_len], 1 where equal to C
+     flipped_mask = mask.flip(dims=[1])  # Flip along the sequence dimension
+     flipped_indices = flipped_mask.argmax(dim=1)  # First match in the flipped view
+     seq_len = tensor.shape[1]
+     last_indices = (seq_len - 1) - flipped_indices  # Convert back to original indices
+
+     # Handle rows with no match: argmax returns 0 there, so verify and set to -1.
+     actual_values = tensor[torch.arange(tensor.shape[0]), last_indices]
+     no_match = actual_values != C
+     last_indices[no_match] = -1
+
+     return last_indices
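+
+
+ # A minimal self-check sketch of the helpers above (illustrative values only,
+ # not part of the model pipeline):
+ if __name__ == "__main__":
+     t = torch.tensor([[1, 5, 2, 5, 3], [7, 7, 7, 7, 7]])
+     print(find_last_equal_C(t, 5))  # tensor([ 3, -1])
+
+     # Nucleus filtering keeps the smallest prefix of the sorted distribution
+     # whose cumulative probability covers top_p; here only the last entry
+     # (prob 0.05) is masked to -inf.
+     logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
+     print(apply_top_p_optimized(logits.clone(), top_p=0.8))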
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dbe345257ff9f6cc84195bed830a268b39d5e0b728ff3ba90e715150a49b16d4
+ size 4228278872
modeling_moss_tts.py ADDED
@@ -0,0 +1,532 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Modeling classes for MossTTSDelay."""
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+ from tqdm import tqdm
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.cache_utils import Cache
+ from transformers.models.qwen3 import Qwen3Model
+ from transformers import initialization as init
+
+ from .configuration_moss_tts import MossTTSDelayConfig
+ from .inference_utils import sample_token, find_last_equal_C
+
+ try:
+     from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
+ except Exception:
+     UserMessage = None
+     AssistantMessage = None
+     MossTTSDelayProcessor = None
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "MossTTSDelayConfig"
+
+
+ @dataclass
+ class MossTTSDelayOutputWithPast(ModelOutput):
+     """
+     Base class for model outputs that may also contain past key/values (to speed up sequential decoding).
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Weighted sum of channel losses.
+         all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
+             Sum of losses for each sample and each channel before averaging.
+         all_token_nums (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Number of non-masked tokens per sample.
+         sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
+             Loss per sample.
+         channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
+             Loss per channel (text head + VQ heads).
+         logits (`List[torch.FloatTensor]`, *optional*):
+             List of prediction scores from each head.
+         past_key_values (`Cache`, *optional*):
+             Pre-computed hidden states (keys and values in the self-attention blocks).
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
+             Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer).
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
+             Tuple of torch.FloatTensor (one for each layer) of the attention weights.
+     """
+     loss: Optional[torch.FloatTensor] = None
+     all_sum_losses: Optional[torch.FloatTensor] = None
+     all_token_nums: Optional[torch.LongTensor] = None
+     sample_losses: Optional[torch.FloatTensor] = None
+     channel_losses: Optional[torch.FloatTensor] = None
+     logits: Optional[List[torch.FloatTensor]] = None
+     past_key_values: Optional[Cache] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class MossTTSDelayPreTrainedModel(PreTrainedModel):
+     config_class = MossTTSDelayConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Qwen3DecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn = True
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+
+     def _init_weights(self, module):
+         """
+         Transformers 5.0+ safe init:
+         - MUST use transformers.initialization helpers
+         - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
+         """
+         # Let HF handle its standard modules first (LayerNorm, Linear, Embedding, etc.)
+         super()._init_weights(module)
+
+         # Pick a std consistent with HF conventions;
+         # prefer the model/text config initializer_range if present.
+         if hasattr(self.config, "initializer_range"):
+             std = self.config.initializer_range
+         elif hasattr(self.config, "language_config") and hasattr(self.config.language_config, "initializer_range"):
+             std = self.config.language_config.initializer_range
+         else:
+             std = 0.02
+
+         # Initialize the extra audio embeddings.
+         if isinstance(module, nn.Embedding):
+             # Only touch our extra embeddings (avoid double-touching the LM's embeddings).
+             if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
+                 init.normal_(module.weight, mean=0.0, std=std)
+             # If a padding_idx is set later, it must be explicitly zeroed
+             # (init.zeros_ checks param flags internally, but slicing needs manual care).
+
+         # The multi-head projections (lm_heads) are plain nn.Linear modules and are
+         # already covered by super()._init_weights above.
+
+
+ MOSSTTS_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`MossTTSDelayConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ @add_start_docstrings(
+     "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
+     MOSSTTS_START_DOCSTRING,
+ )
+ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
+     UserMessage = UserMessage
+     AssistantMessage = AssistantMessage
+     Processor = MossTTSDelayProcessor
+
+     def __init__(self, config: MossTTSDelayConfig):
+         super().__init__(config)
+         self.config = config
+
+         config.language_config.torch_dtype = config.torch_dtype
+
+         self.language_model = Qwen3Model(config.language_config)
+
+         # Audio VQ embeddings (extra channels).
+         # Note: input_ids[..., 0] uses Qwen's embedding;
+         # input_ids[..., 1:] use these extensions.
+         self.emb_ext = nn.ModuleList()
+         for vq_idx in range(self.config.n_vq):
+             # +1 accounts for the padding/special code used by upstream data prep.
+             self.emb_ext.append(
+                 nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
+             )
+
+         # Multi-head prediction layers:
+         # head 0 is the main language head, heads 1..N are the audio VQ heads.
+         self.lm_heads = nn.ModuleList([
+             nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
+         ])
+         for vq_idx in range(self.config.n_vq):
+             self.lm_heads.append(
+                 nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
+             )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """
+         Computes the combined embeddings from text and multiple audio VQ channels.
+
+         Args:
+             input_ids: Shape (Batch, Seq_Len, 1 + n_vq)
+         """
+         # Base text/content embedding:
+         # input_ids[..., 0] holds standard text or semantic tokens.
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])
+
+         # Add the VQ embeddings.
+         for i, embed_layer in enumerate(self.emb_ext):
+             # emb_ext[i] corresponds to channel i+1 in input_ids;
+             # the data pipeline is assumed to keep indices within range.
+             inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])
+
+         return inputs_embeds
+
+     def set_input_embeddings(self, value):
+         self.language_model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         # Returning a list of heads might break some HF utilities expecting a single head;
+         # for custom models, this is acceptable.
+         return self.lm_heads
+
+     @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
+     @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         hidden_out_layers: Optional[List[int]] = None,
+         channelwise_loss_weight: Optional[List[float]] = None,
+         **kwargs,
+     ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
+                 Indices of input sequence tokens in the vocabulary.
+                 Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
+                 Labels for computing the masked language modeling loss.
+             channelwise_loss_weight (`List[float]`, *optional*):
+                 Manual weights for summing losses across different heads (text vs. audio channels).
+
+         Returns:
+         """
+
+         if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
+             raise ValueError("`input_ids` must have shape (batch_size, sequence_length, 1 + n_vq).")
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+         # 1. Prepare embeddings.
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings(input_ids)
+
+         # 2. Backbone forward.
+         # Qwen3Model outputs a standard CausalLMOutputWithPast or similar.
+         outputs = self.language_model(
+             input_ids=None,  # Passed via inputs_embeds
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,  # Hidden states are always needed for multi-head projection
+             return_dict=True,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         # 3. Handle specific layer outputs if requested (the delay pattern may use features from specific layers).
+         last_hidden_state = outputs.last_hidden_state
+         if hidden_out_layers is None:
+             # Default to using the last layer for all heads.
+             # In some architectures (like MusicGen), different codebooks come from different transformer layers;
+             # here we default to the final layer, matching the original behavior of [-1] * (n + 1).
+             hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
+         else:
+             # If hidden_out_layers is provided (e.g. [-1, -2, -3, ...]), fetch them from all hidden states.
+             # Note: outputs.hidden_states usually includes the embedding output at index 0.
+             all_hs = outputs.hidden_states
+             hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]
+
+         # 4. Project to logits (multi-head).
+         layer_logits = []
+         for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
+             logits = head(hs)
+             # Mask the last vocabulary index for audio heads (indices > 0):
+             # the audio vocab is (N+1)-sized, but the extra code is reserved for
+             # padding in the input and is invalid as a prediction.
+             if i > 0:
+                 logits[..., -1] = float("-inf")
+             layer_logits.append(logits)
+
+         # 5. Loss calculation.
+         loss = None
+         all_sum_losses = None
+         all_token_nums = None
+         sample_losses = None
+         channel_losses = None
+
+         if labels is not None:
+             # Ensure labels match the input shape rank (B, S, C).
+             if labels.dim() != 3:
+                 raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")
+
+             batch_size = labels.size(0)
+             n_heads = len(layer_logits)
+
+             # Container for per-sample, per-channel losses; final shape [Batch, n_heads].
+             all_sum_losses_list = []
+
+             # Count valid tokens (not -100) per sample and channel:
+             # torch.sum(labels != -100, dim=1) -> [B, C]
+             all_token_nums = torch.sum(labels != -100, dim=1)
+
+             for i, logits in enumerate(layer_logits):
+                 # logits: [B, S, V]; cur_labels: [B, S]
+                 cur_labels = labels[..., i]
+
+                 # Flatten for cross-entropy: logits [B*S, V], labels [B*S].
+                 loss_fct = CrossEntropyLoss(reduction='none')
+                 vocab_size = logits.size(-1)
+
+                 reshaped_logits = logits.view(-1, vocab_size)
+                 reshaped_labels = cur_labels.contiguous().view(-1)
+
+                 # Loss per token.
+                 per_token_loss = loss_fct(reshaped_logits, reshaped_labels)
+
+                 # Reshape back to [B, S] and sum over the sequence dimension for the per-sample loss.
+                 per_token_loss = per_token_loss.view(batch_size, -1)
+                 per_sample_loss = torch.sum(per_token_loss, dim=-1)  # [B]
+
+                 all_sum_losses_list.append(per_sample_loss)
+
+             # Stack to [B, n_heads].
+             all_sum_losses = torch.stack(all_sum_losses_list, dim=1)
+
+             # Weighted loss aggregation.
+             if channelwise_loss_weight is not None:
+                 if len(channelwise_loss_weight) != n_heads:
+                     raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")
+
+                 w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)
+
+                 # Sample losses: weighted sum over channels per sample / total weight.
+                 # Normalize by token count per channel; clamp avoids division by zero.
+                 token_counts_safe = all_token_nums.float().clamp(min=1.0)
+
+                 normalized_losses = all_sum_losses / token_counts_safe
+                 sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()
+
+                 # Channel losses: sum over the batch / token count over the batch.
+                 total_loss_per_channel = all_sum_losses.sum(dim=0)
+                 total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
+                 channel_losses = total_loss_per_channel / total_tokens_per_channel
+
+                 # Final scalar loss.
+                 loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
+             else:
+                 # Default average if no weights are provided.
+                 total_tokens = all_token_nums.sum().float().clamp(min=1.0)
+                 loss = all_sum_losses.sum() / total_tokens
+                 channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)
+
+         return MossTTSDelayOutputWithPast(
+             loss=loss,
+             all_sum_losses=all_sum_losses,
+             all_token_nums=all_token_nums,
+             sample_losses=sample_losses,
+             channel_losses=channel_losses,
+             logits=layer_logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
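+     # A sketch of the delay pattern generate() emits (illustrative, with n_vq = 4;
+     # "gen" = audio_assistant_gen_slot, "D" = audio_assistant_delay_slot,
+     # "P" = audio_pad_code, a_t^c = codebook-c token for frame t):
+     #
+     #   text ch:  <|audio_start|>  gen    gen    ...  D      D      D      D    <|audio_end|>
+     #   vq ch 1:  P                a_0^1  a_1^1  ...  a_T^1  P      P      P    P
+     #   vq ch 2:  P                P      a_0^2  ...  ...    a_T^2  P      P    P
+     #   ...
+     #
+     # Channel c lags channel 1 by c-1 steps, so once the text channel stops
+     # emitting gen slots, n_vq delay slots flush the remaining codebooks.
+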
+     @torch.inference_mode()
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         max_new_tokens: int = 1000,
+         text_temperature: float = 1.5,
+         text_top_p: float = 1.0,
+         text_top_k: int = 50,
+         audio_temperature: float = 1.5,
+         audio_top_p: float = 0.6,
+         audio_top_k: int = 50,
+         audio_repetition_penalty: float = 1.1,
+     ):
+         if text_temperature > 0:
+             text_do_sample = True
+         else:
+             text_temperature = 1
+             text_do_sample = False
+         if audio_temperature > 0:
+             audio_do_sample = True
+         else:
+             audio_temperature = 1
+             audio_do_sample = False
+
+         past_key_values = None
+         device = input_ids.device
+         current_input_ids = input_ids
+         current_attention_mask = attention_mask
+         batch_size, seq_len, n_vq = input_ids.shape
+         n_vq -= 1
+
+         generation_ids = input_ids[:]
+         is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+         # Three phases: 1. non-audio; 2. audio, not yet delayed; 3. audio, in the delay flush.
+         audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)  # 0 marks phase 1
+         torch_int64_max = torch.iinfo(torch.int64).max
+         delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)  # int64 max marks phase 2
+
+         # Handle continuation, where audio_start is already present in input_ids.
+         # NOTE: inputs that have already entered the delay phase are currently not supported.
+         # Both continuation and direct generation must be handled.
+         is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
+         audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
+         audio_start_mask = is_continuation & (audio_start_indices != -1)
+         audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]
+
+         is_audio = audio_start_mask.clone()
+
+         pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
+         pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
+         pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
+
+         # Note: time_step does not necessarily match the output token position in the
+         # actual dialogue, because of continuation.
+         for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
+             outputs = self(
+                 input_ids=current_input_ids,
+                 attention_mask=current_attention_mask,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+             past_key_values = outputs.past_key_values
+
+             # List of length n_vq + 1, each entry [batch_size, vocab_size].
+             next_token_logits = [
+                 logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature
+                 for logit_idx, logit in enumerate(outputs.logits)
+             ]
+             next_token_logits[0] = next_token_logits[0].clone()
+             # 1. Handle the text token first.
+             next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
+             # Delay slots after the first one, and audio_end, are forced rather than sampled;
+             # audio_start, every audio_assistant_gen_slot_token_id, and the first
+             # audio_assistant_delay_slot_token_id are sampled.
+             next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
+             is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
+             next_text_token[is_audio_eos] = self.config.audio_end_token_id
+             is_audio[is_audio_eos] = False
+             sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
+             next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
+             next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
+             if time_step == 0:
+                 # The very first step can never be a delay slot.
+                 next_token_logits[0][..., self.config.audio_assistant_delay_slot_token_id] = float('-inf')
+             if time_step <= n_vq:
+                 # Do not allow <|im_end|> before the first n_vq + 1 steps.
+                 next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
+
+             # The text layer uses no repetition penalty.
+             next_text_token[sampling_text_mask] = sample_token(
+                 logits=next_token_logits[0][sampling_text_mask],
+                 top_p=text_top_p,
+                 top_k=text_top_k,
+                 do_sample=text_do_sample
+             )
+             is_audio[next_text_token == self.config.audio_start_token_id] = True
+             # The only stop condition: next_text_token == <|im_end|>.
+             is_stopping[next_text_token == self.config.im_end_token_id] = True
+
+             # 2. Then handle the audio tokens.
+             # Everything outside audio_start..audio_end is padded by default;
+             # only the positions that carry values need filling.
+             next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
+
+             # What matters is the distance from audio_start.
+             # First check for padding; True marks positions that carry values.
+             pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
+             post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
+             post_audio_mask[delayed_lengths == torch_int64_max] = True
+             sampling_audio_mask = pre_audio_mask & post_audio_mask
+             next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
+
+             if sampling_audio_mask.sum() > 0:
+                 audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask]  # torch.stack -> [batch_size, n_vq, vocab_size]
+                 audio_logits[..., self.config.audio_pad_code] = float('-inf')
+                 next_audio_tokens[sampling_audio_mask] = sample_token(
+                     logits=audio_logits,
+                     prev_tokens=generation_ids[:, :, 1:],
+                     repetition_penalty=audio_repetition_penalty,
+                     top_p=audio_top_p,
+                     top_k=audio_top_k,
+                     do_sample=audio_do_sample
+                 )
+
+             # Advance audio_lengths and delayed_lengths to the state used at the next time step.
+             audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
+             audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
+             delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
+             delayed_lengths[delayed_lengths != torch_int64_max] += 1
+             delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
+
+             current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)  # [batch_size, 1, n_vq + 1]
+             current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
+             generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)  # [batch_size, seq_len, n_vq + 1]
+
+             if is_stopping.sum() == batch_size:
+                 break
+
+         start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
+         start_lengths = seq_len - start_indices
+
+         output = []
+         for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
+             output.append((start_length, cur_generation_ids[start_idx:]))
+
+         return output
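
Pieced together from the classes in this commit, a minimal end-to-end inference sketch (the local path, device placement, and sampling budget are illustrative assumptions; the processor additionally fetches OpenMOSS-Team/MOSS-Audio-Tokenizer):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("./", trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
processor = model.Processor.from_pretrained("./")  # MossTTSDelayProcessor

message = processor.build_user_message(text="Hello there!", language="English")
batch = processor([message], mode="generation")
outputs = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=500,
)
# Each item is (prompt_length, generated_ids of shape [T, 1 + n_vq]);
# the VQ channels can then be decoded back to audio with the audio tokenizer.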
processing_moss_tts.py ADDED
@@ -0,0 +1,1017 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
+ from dataclasses import dataclass
+ from pathlib import Path
+ import re
+ import torchaudio
+
+ from transformers import processing_utils
+
+ processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
+
+ import torch
+ from transformers import (
+     PreTrainedTokenizerBase,
+     BatchFeature,
+     ProcessorMixin,
+     logging,
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+ )
+
+ from .configuration_moss_tts import MossTTSDelayConfig
+
+
+ def normalize_instruction(instruction: str) -> str:
+     """
+     Normalize an instruction:
+     1. Remove [] and {} tags
+     2. Replace decorative symbols with a comma
+     3. Remove consecutive duplicate punctuation
+     4. Remove line breaks
+     5. If it contains Chinese, replace English commas with Chinese commas
+     6. Keep quotes
+     """
+     if not instruction:
+         return instruction
+
+     # Remove line breaks
+     instruction = instruction.replace("\n", " ")
+
+     # Remove [] and {} tags
+     instruction = re.sub(r"\[.*?\]", "", instruction)
+     instruction = re.sub(r"\{.*?\}", "", instruction)
+
+     # Replace decorative symbols with a comma
+     decorative_chars = "【】《》()『』「」~-_"
+     for char in decorative_chars:
+         instruction = instruction.replace(char, ",")
+
+     # Collapse consecutive punctuation (keep only one)
+     instruction = re.sub(r'([,。!?,.!?;;])+', r'\1', instruction)
+
+     # Check whether the text contains Chinese characters
+     has_chinese = bool(re.search(r'[\u4e00-\u9fff]', instruction))
+
+     if has_chinese:
+         # Replace English commas with Chinese commas
+         instruction = instruction.replace(',', ',')
+
+     return instruction.strip()
+
+
+ def normalize_text(text: str) -> str:
+     """
+     Normalize text:
+     1. Remove [] and {} tags
+     2. Replace decorative symbols with a comma
+     3. Remove consecutive duplicate punctuation
+     4. Remove line breaks
+     5. Remove quotes
+     """
+     if not text:
+         return text
+
+     # Remove line breaks
+     text = text.replace("\n", " ")
+
+     # Remove [] and {} tags
+     text = re.sub(r"\[.*?\]", "", text)
+     text = re.sub(r"\{.*?\}", "", text)
+
+     # Replace decorative symbols with a comma
+     decorative_chars = "【】《》()『』「」~-_"
+     for char in decorative_chars:
+         text = text.replace(char, ",")
+
+     # Remove quotes (Chinese and English quotation marks)
+     quotes = ['"', '“', '”', "'", "‘", "’"]
+     for q in quotes:
+         text = text.replace(q, "")
+
+     # Collapse consecutive punctuation (keep only one)
+     text = re.sub(r'([,。!?,.!?;;])+', r'\1', text)
+
+     return text.strip()
+
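+
+ # Illustrative behavior of the normalizers above (a sketch, not exhaustive):
+ #   normalize_text('Hello... "world"')            -> 'Hello. world'
+ #   normalize_instruction("[tag] Speak slowly!!")  -> "Speak slowly!"
+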
+ logger = logging.get_logger(__name__)
+
+
+ AUDIO_PLACEHOLDER = "<|audio|>"
+
+
+ @dataclass
+ class Message:
+     def to_dict(self) -> Dict[str, Any]:
+         raise NotImplementedError
+
+
+ @dataclass
+ class UserMessage(Message):
+     text: Optional[str] = None
+     reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
+     instruction: Optional[str] = None
+     tokens: Optional[int] = None
+     quality: Optional[str] = None
+     sound_event: Optional[str] = None
+     ambient_sound: Optional[str] = None
+     language: Optional[str] = None
+
+     def __post_init__(self):
+         template = """<user_inst>
+ - Reference(s):
+ {reference}
+ - Instruction:
+ {instruction}
+ - Tokens:
+ {tokens}
+ - Quality:
+ {quality}
+ - Sound Event:
+ {sound_event}
+ - Ambient Sound:
+ {ambient_sound}
+ - Language:
+ {language}
+ - Text:
+ {text}
+ </user_inst>"""
+
+         audio_codes_list = []
+         if self.reference is None:
+             reference = "None"
+         elif isinstance(self.reference, List):
+             reference = []
+             for speaker_idx, speaker_reference in enumerate(self.reference):
+                 if speaker_reference is not None:
+                     reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
+             reference = "\n".join(reference)
+             audio_codes_list = [
+                 speaker_reference
+                 for speaker_reference in self.reference
+                 if speaker_reference is not None
+             ]
+         else:
+             raise TypeError("`reference` must be a list when it is not None.")
+
+         content = (
+             template.replace("{reference}", str(reference))
+             .replace("{instruction}", str(self.instruction))
+             .replace("{tokens}", str(self.tokens))
+             .replace("{quality}", str(self.quality))
+             .replace("{sound_event}", str(self.sound_event))
+             .replace("{ambient_sound}", str(self.ambient_sound))
+             .replace("{language}", str(self.language))
+             .replace("{text}", str(self.text))
+         )
+
+         self._content = content
+         self._audio_codes_list = audio_codes_list
+
+     def to_dict(self):
+         return {
+             "role": "user",
+             "content": self._content,
+             "audio_codes_list": self._audio_codes_list,
+         }
+
+
+ @dataclass
+ class AssistantMessage(Message):
+     audio_codes_list: List[Union[str, torch.Tensor]]
+     content: str = AUDIO_PLACEHOLDER
+
+     def to_dict(self):
+         return {
+             "role": "assistant",
+             "content": self.content,
+             "audio_codes_list": self.audio_codes_list,
+         }
+
+
+ USER_MESSAGE_FIELDS = (
+     "text",
+     "reference",
+     "instruction",
+     "tokens",
+     "quality",
+     "sound_event",
+     "ambient_sound",
+     "language",
+ )
+
+
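+ # A sketch of the message dicts the classes above produce (values illustrative):
+ #   UserMessage(text="Hi").to_dict()
+ #     -> {"role": "user", "content": "<user_inst>\n- Reference(s):\nNone\n...", "audio_codes_list": []}
+ #   AssistantMessage(audio_codes_list=["ref.wav"]).to_dict()
+ #     -> {"role": "assistant", "content": "<|audio|>", "audio_codes_list": ["ref.wav"]}
+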
218
+
219
+ class MossTTSDelayProcessor(ProcessorMixin):
220
+ tokenizer_class = "AutoTokenizer"
221
+ audio_tokenizer_class = "AutoModel"
222
+
223
+ tokenizer: PreTrainedTokenizerBase
224
+ audio_tokenizer: Any
225
+
226
+ def __init__(
227
+ self,
228
+ tokenizer: PreTrainedTokenizerBase,
229
+ audio_tokenizer: Any = None,
230
+ model_config: Optional[MossTTSDelayConfig] = None,
231
+ normalize_inputs: bool = False,
232
+ **kwargs,
233
+ ):
234
+ super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
235
+
236
+ # Explicit assignments for type-checkers; ProcessorMixin sets these too.
237
+ self.tokenizer = tokenizer
238
+ self.audio_tokenizer = audio_tokenizer
239
+ if model_config is None:
240
+ model_config = MossTTSDelayConfig()
241
+ self.model_config = model_config
242
+ self.normalize_inputs = normalize_inputs
243
+
244
+ self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
245
+ self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
246
+ self.newline_token_id = 198
247
+
248
+ def _id_to_token(token_id: int) -> str:
249
+ tok = tokenizer.convert_ids_to_tokens(int(token_id))
250
+ if isinstance(tok, list):
251
+ return tok[0] if len(tok) > 0 else ""
252
+ return cast(str, tok)
253
+
254
+ self.audio_user_slot_token = _id_to_token(
255
+ self.model_config.audio_user_slot_token_id
256
+ )
257
+ self.audio_assistant_gen_slot_token = _id_to_token(
258
+ self.model_config.audio_assistant_gen_slot_token_id
259
+ )
260
+ self.audio_assistant_delay_slot_token = _id_to_token(
261
+ self.model_config.audio_assistant_delay_slot_token_id
262
+ )
263
+ self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
264
+ self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
265
+
266
+ @classmethod
267
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
268
+ trust_remote_code = kwargs.pop("trust_remote_code", True)
269
+ kwargs.pop("_from_auto", None)
270
+
271
+ audio_tokenizer_name_or_path = kwargs.pop(
272
+ "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
273
+ )
274
+ normalize_inputs = kwargs.pop("normalize_inputs", False)
275
+
276
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
277
+ model_config = cast(
278
+ MossTTSDelayConfig,
279
+ AutoConfig.from_pretrained(
280
+ pretrained_model_name_or_path,
281
+ *args,
282
+ trust_remote_code=trust_remote_code,
283
+ **kwargs,
284
+ ),
285
+ )
286
+ tokenizer = AutoTokenizer.from_pretrained(
287
+ pretrained_model_name_or_path,
288
+ *args,
289
+ trust_remote_code=trust_remote_code,
290
+ **kwargs,
291
+ )
292
+ audio_tokenizer = AutoModel.from_pretrained(
293
+ audio_tokenizer_name_or_path,
294
+ trust_remote_code=trust_remote_code,
295
+ **kwargs,
296
+ )
297
+
298
+ return cls(
299
+ tokenizer=tokenizer,
300
+ audio_tokenizer=audio_tokenizer,
301
+ model_config=model_config,
302
+ normalize_inputs=normalize_inputs,
303
+ **kwargs,
304
+ )
305
+
306
+ def __call__(self, *args, **kwargs) -> BatchFeature:
307
+ conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
308
+ mode: str = kwargs.pop("mode", "generation")
309
+ apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
310
+ n_vq: Optional[int] = kwargs.pop("n_vq", None)
311
+
312
+ # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
313
+ kwargs.pop("return_tensors", None)
314
+ kwargs.pop("padding", None)
315
+ kwargs.pop("truncation", None)
316
+
317
+ """
318
+ mode only works when a Message is converted to a dict.
319
+ """
320
+
321
+ if mode not in {"generation", "continuation"}:
322
+ raise RuntimeError
323
+
324
+ if isinstance(conversations, (Message, Dict)):
325
+ conversations = [conversations]
326
+
327
+ truncation = False
328
+ if mode == "continuation":
329
+ truncation = True
330
+
331
+ input_ids_list = []
332
+ for conversation in conversations:
333
+ if isinstance(conversation, (Message, Dict)):
334
+ conversation = [conversation]
335
+
336
+ # Normalize early so downstream logic always deals with dict messages.
337
+ conversation = [self._normalize_message(m) for m in conversation]
338
+
339
+ if (mode == "generation") ^ (len(conversation) % 2 != 0):
340
+ raise ValueError
341
+
342
+ if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
343
+ raise ValueError
344
+
345
+ unified_codes = []
346
+ for message_idx, message in enumerate(conversation):
347
+ if apply_chat_template:
348
+ add_generation_prompt = (
349
+ mode == "generation" and message_idx == len(conversation) - 1
350
+ )
351
+ try:
352
+ content = self.tokenizer.apply_chat_template(
353
+ [{"role": message["role"], "content": message["content"]}],
354
+ add_generation_prompt=add_generation_prompt,
355
+ tokenize=False,
356
+ )
357
+ except TypeError:
358
+ try:
359
+ content = self.tokenizer.apply_chat_template(
360
+ [
361
+ {
362
+ "role": message["role"],
363
+ "content": message["content"],
364
+ }
365
+ ],
366
+ add_generation_prompt=add_generation_prompt,
367
+ )
368
+ except Exception:
369
+ logger.warning(
370
+ "apply_chat_template failed; fallback to raw content."
371
+ )
372
+ content = message["content"]
373
+ else:
374
+ content = message["content"]
375
+
376
+ if not isinstance(content, str):
377
+ content = str(content)
378
+
379
+ # Batch-encode all path-based references in one call when possible.
380
+ # This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
381
+ # instead of repeatedly calling it with batch=1.
382
+ raw_audio_items = message.get("audio_codes_list", [])
383
+
384
+ audio_codes_list: List[torch.Tensor] = []
385
+ if len(raw_audio_items) > 0:
386
+ encoded_items: List[Optional[torch.Tensor]] = [None] * len(
387
+ raw_audio_items
388
+ )
389
+ paths: List[str] = []
390
+ path_positions: List[int] = []
391
+
392
+ for idx, item in enumerate(raw_audio_items):
393
+ if isinstance(item, torch.Tensor):
394
+ if n_vq is not None and item.shape[1] != n_vq:
395
+ raise RuntimeError(
396
+ "audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
397
+ )
398
+ encoded_items[idx] = item
399
+ continue
400
+
401
+ if isinstance(item, (str, os.PathLike)):
402
+ paths.append(str(item))
403
+ path_positions.append(idx)
404
+ continue
405
+
406
+ raise TypeError(
407
+ "Each audio item must be a torch.Tensor of codes or a path-like string."
408
+ )
409
+
410
+ if len(paths) > 0:
411
+ encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
412
+ if len(encoded_from_paths) != len(paths):
413
+ raise RuntimeError(
414
+ "encode_audios_from_path returned an unexpected number of items."
415
+ )
416
+ for pos, codes in zip(path_positions, encoded_from_paths):
417
+ encoded_items[pos] = codes
418
+
419
+ audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
420
+ unified_codes.append(
421
+ self._get_unified_codes(
422
+ message["role"], content, audio_codes_list, truncation
423
+ )
424
+ )
425
+
426
+ unified_codes = torch.cat(unified_codes)
427
+ input_ids_list.append(unified_codes)
428
+
429
+ return BatchFeature(data=self._pad(input_ids_list))
430
+
431
+ @staticmethod
432
+ def build_user_message(
433
+ text: Optional[str] = None,
434
+ reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
435
+ instruction: Optional[str] = None,
436
+ tokens: Optional[int] = None,
437
+ quality: Optional[str] = None,
438
+ sound_event: Optional[str] = None,
439
+ ambient_sound: Optional[str] = None,
440
+ language: Optional[str] = None,
441
+ normalize: bool = False,
442
+ ) -> Dict:
443
+ if normalize:
444
+ if text is not None:
445
+ text = normalize_text(text)
446
+ if instruction is not None:
447
+ instruction = normalize_instruction(instruction)
448
+ if reference is not None and not isinstance(reference, list):
449
+ reference = [reference]
450
+ return UserMessage(
451
+ text=text,
452
+ reference=reference,
453
+ instruction=instruction,
454
+ tokens=tokens,
455
+ quality=quality,
456
+ sound_event=sound_event,
457
+ ambient_sound=ambient_sound,
458
+ language=language,
459
+ ).to_dict()
460
+
461
+ @staticmethod
462
+ def build_assistant_message(
463
+ audio_codes_list: List[Union[str, torch.Tensor]],
464
+ content: str = AUDIO_PLACEHOLDER,
465
+ ) -> Dict:
466
+ return AssistantMessage(
467
+ audio_codes_list=audio_codes_list,
468
+ content=content,
469
+ ).to_dict()
470
+
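+ # Usage sketch (editor's illustration, not part of the original diff).
+ # It assumes `processor` is an instance of this class and that __call__
+ # forwards conversations and `mode` as handled above; "speaker_ref.wav"
+ # is a hypothetical path:
+ #   user = MossTTSDelayProcessor.build_user_message(
+ #       text="Hello there", reference=["speaker_ref.wav"]
+ #   )
+ #   batch = processor([[user]], mode="generation")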
471
+ def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
472
+ if isinstance(message, Message):
473
+ return message.to_dict()
474
+ if not isinstance(message, dict):
475
+ raise TypeError("Each message must be a Message or dict.")
476
+ if "role" not in message:
477
+ raise ValueError("Message dict must include a 'role' field.")
478
+ if "content" in message and "audio_codes_list" in message:
479
+ return message
480
+ role = message["role"]
481
+ if role == "user":
482
+ kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
483
+ # Apply the processor's global normalize setting.
484
+ kwargs["normalize"] = self.normalize_inputs
485
+ return self.build_user_message(**kwargs)
486
+ if role == "assistant":
487
+ return self.build_assistant_message(
488
+ audio_codes_list=message.get("audio_codes_list", []),
489
+ content=message.get("content", AUDIO_PLACEHOLDER),
490
+ )
491
+ raise ValueError(f"Unsupported role: {role}")
492
+
493
+ def _pad(self, input_ids_list: List[torch.Tensor]):
494
+ device = input_ids_list[0].device
495
+ lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
496
+ pad_input_ids = torch.nn.utils.rnn.pad_sequence(
497
+ input_ids_list,
498
+ batch_first=True,
499
+ padding_value=self.model_config.audio_pad_code,
500
+ padding_side="left",
501
+ )
502
+ other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
503
+ 1
504
+ ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
505
+ pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
506
+ attention_mask = torch.zeros(
507
+ pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
508
+ )
509
+ attention_mask[~other_channel_mask] = 1
510
+ attention_mask = attention_mask.bool()
511
+ return {
512
+ "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
513
+ "attention_mask": attention_mask,
514
+ }
515
+
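+ # Padding sketch (editor's illustration): sequences are left-padded to a
+ # common length T; padded steps carry audio_pad_code on every channel
+ # except channel 0 (text), which is overwritten with pad_token_id, and
+ # attention_mask is False exactly on those padded steps.
+ # E.g. lengths [2, 3] pad to T = 3:
+ #   attention_mask = [[0, 1, 1],
+ #                     [1, 1, 1]]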
516
+ @staticmethod
517
+ def _replace_audio_placeholders(
518
+ content: str,
519
+ lengths: List[int],
520
+ n_vq: int,
521
+ gen_slot_token: str,
522
+ delay_slot_token: str,
523
+ audio_start_token: str,
524
+ audio_end_token: str,
525
+ ) -> str:
526
+ if n_vq < 1:
527
+ raise ValueError(f"n_vq must be >= 1, got {n_vq}")
528
+
529
+ num_placeholders = content.count(AUDIO_PLACEHOLDER)
530
+ if num_placeholders != len(lengths):
531
+ raise ValueError(
532
+ f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
533
+ f"does not match lengths ({len(lengths)})"
534
+ )
535
+
536
+ def build_audio_block(length: int) -> str:
537
+ if length < 0:
538
+ raise ValueError(f"length must be >= 0, got {length}")
539
+
540
+ if length == 0:
541
+ return f"{audio_start_token}{audio_end_token}"
542
+
543
+ step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
544
+ return f"{audio_start_token}{step_tokens}{audio_end_token}"
545
+
546
+ lengths_iter = iter(lengths)
547
+
548
+ def replacer(match: re.Match) -> str:
549
+ length = next(lengths_iter)
550
+ return build_audio_block(length)
551
+
552
+ result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
553
+
554
+ return result
555
+
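+ # Expansion sketch (editor's illustration): with n_vq = 3, a placeholder
+ # for an audio clip of length 2 expands to
+ #   <|audio_start|> gen gen delay delay <|audio_end|>
+ # i.e. `length` gen-slot tokens followed by n_vq - 1 delay-slot tokens,
+ # matching the extra steps introduced by the delay pattern below.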
556
+ @staticmethod
557
+ def _merge_consecutive_audio_placeholders(
558
+ content: str,
559
+ audio_codes_list: List[torch.Tensor],
560
+ ) -> Tuple[str, List[torch.Tensor]]:
561
+ matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
562
+ if len(matches) <= 1:
563
+ return content, audio_codes_list
564
+
565
+ if len(matches) != len(audio_codes_list):
566
+ raise ValueError(
567
+ "Audio placeholders do not match the provided audio codes list."
568
+ )
569
+
570
+ new_audio_codes_list = []
571
+ new_parts = []
572
+ last_pos = 0
573
+ i = 0
574
+ while i < len(matches):
575
+ j = i
576
+ while (
577
+ j + 1 < len(matches)
578
+ and content[matches[j].end() : matches[j + 1].start()].strip() == ""
579
+ ):
580
+ j += 1
581
+
582
+ new_parts.append(content[last_pos : matches[i].start()])
583
+ new_parts.append(AUDIO_PLACEHOLDER)
584
+ last_pos = matches[j].end()
585
+
586
+ if j == i:
587
+ new_audio_codes_list.append(audio_codes_list[i])
588
+ else:
589
+ new_audio_codes_list.append(
590
+ torch.cat(audio_codes_list[i : j + 1], dim=0)
591
+ )
592
+
593
+ i = j + 1
594
+
595
+ new_parts.append(content[last_pos:])
596
+ return "".join(new_parts), new_audio_codes_list
597
+
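+ # Merge sketch (editor's illustration): for content "a <|audio|> <|audio|> b"
+ # with code tensors [c1, c2], the whitespace-separated placeholders collapse
+ # to "a <|audio|> b" and the codes become [torch.cat([c1, c2], dim=0)].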
598
+ @staticmethod
599
+ def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
600
+ delayed_tokens = torch.full(
601
+ (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
602
+ pad_code,
603
+ device=codes.device,
604
+ dtype=codes.dtype,
605
+ )
606
+ for i in range(codes.shape[1]):
607
+ delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
608
+ return delayed_tokens
609
+
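+ # Delay sketch (editor's illustration): for codes of shape (T=3, n_vq=2)
+ # and pad code P, channel i is shifted down by i steps:
+ #   [[a0, b0],        [[a0,  P],
+ #    [a1, b1],   -->   [a1, b0],
+ #    [a2, b2]]         [a2, b1],
+ #                      [ P, b2]]
+ # Output length is T + n_vq - 1 = 4.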
610
+ @staticmethod
611
+ def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
612
+ tokens = torch.full(
613
+ (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
614
+ 0,
615
+ device=delay_codes.device,
616
+ dtype=delay_codes.dtype,
617
+ )
618
+ for i in range(delay_codes.shape[1]):
619
+ tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
620
+ return tokens
621
+
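+ # De-delay sketch (editor's illustration): this inverts apply_delay_pattern,
+ # recovering the original (T, n_vq) codes from a (T + n_vq - 1, n_vq) input
+ # via tokens[:, i] = delay_codes[i : i + T, i].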
622
+ def _get_unified_codes(
623
+ self,
624
+ role: str,
625
+ content: str,
626
+ audio_codes_list: List[torch.Tensor],
627
+ truncation: bool,
628
+ ) -> torch.Tensor:
629
+ """
630
+ At this point, `content` already carries the chat-template formatting.
631
+ """
632
+ if role == "user":
633
+ audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
634
+ else:
635
+ audio_gen_slot_token = self.audio_assistant_gen_slot_token
636
+ audio_delay_slot_token = self.audio_assistant_delay_slot_token
637
+
638
+ if len(audio_codes_list):
639
+ n_vq = audio_codes_list[0].shape[1]
640
+ else:
641
+ n_vq = self.model_config.n_vq
642
+
643
+ if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
644
+ content, audio_codes_list = self._merge_consecutive_audio_placeholders(
645
+ content, audio_codes_list
646
+ )
647
+ content = self._replace_audio_placeholders(
648
+ content=content,
649
+ lengths=[len(audio_codes) for audio_codes in audio_codes_list],
650
+ n_vq=n_vq,
651
+ gen_slot_token=audio_gen_slot_token,
652
+ delay_slot_token=audio_delay_slot_token,
653
+ audio_start_token=self.audio_start_token,
654
+ audio_end_token=self.audio_end_token,
655
+ )
656
+ text_codes = torch.tensor(
657
+ self.tokenizer.encode(content),
658
+ device=audio_codes_list[0].device if audio_codes_list else None,
659
+ )
660
+
661
+ audio_start_indices = torch.where(
662
+ text_codes == self.model_config.audio_start_token_id
663
+ )[0]
664
+ audio_end_indices = torch.where(
665
+ text_codes == self.model_config.audio_end_token_id
666
+ )[0]
667
+ if len(audio_start_indices) != len(audio_codes_list) or len(
668
+ audio_end_indices
669
+ ) != len(audio_codes_list):
670
+ raise ValueError(
671
+ "Audio placeholders do not match the provided audio codes list."
672
+ )
673
+
674
+ delay_audio_codes_list = []
675
+ if len(audio_codes_list) == 0:
676
+ delay_audio_codes_list = torch.full(
677
+ (len(text_codes), n_vq),
678
+ self.model_config.audio_pad_code,
679
+ device=text_codes.device,
680
+ dtype=text_codes.dtype,
681
+ )
682
+ else:
683
+ prefix_idx = 0
684
+ for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
685
+ audio_start_indices, audio_end_indices, audio_codes_list
686
+ ):
687
+ audio_start_idx = int(audio_start_idx_t.item())
688
+ audio_end_idx = int(audio_end_idx_t.item())
689
+ delay_audio_codes = self.apply_delay_pattern(
690
+ audio_codes, self.model_config.audio_pad_code
691
+ )
692
+ pad_codes = torch.full(
693
+ (audio_start_idx - prefix_idx + 1, n_vq),
694
+ self.model_config.audio_pad_code,
695
+ device=audio_codes.device,
696
+ dtype=audio_codes.dtype,
697
+ )
698
+ delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
699
+ prefix_idx = audio_end_idx
700
+
701
+ if truncation and n_vq > 1:
702
+ # With n_vq == 1 there are no trailing delay rows to drop
+ # (and slicing with -0 would wrongly drop everything).
+ delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
703
+ : -(n_vq - 1), :
704
+ ]
705
+ else:
706
+ last_audio_end_idx = int(audio_end_indices[-1].item())
707
+ pad_codes = torch.full(
708
+ (len(text_codes) - last_audio_end_idx, n_vq),
709
+ self.model_config.audio_pad_code,
710
+ device=audio_codes_list[0].device,
711
+ dtype=audio_codes_list[0].dtype,
712
+ )
713
+ delay_audio_codes_list.append(pad_codes)
714
+
715
+ delay_audio_codes_list = torch.cat(delay_audio_codes_list)
716
+
717
+ if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
718
+ text_codes = text_codes[: delay_audio_codes_list.shape[0]]
719
+
720
+ unified_codes = torch.cat(
721
+ [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
722
+ )
723
+ return unified_codes
724
+
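+ # Layout note (editor's illustration): each row of `unified_codes` is one
+ # step of width 1 + n_vq — column 0 holds a text token id and columns
+ # 1..n_vq hold the delay-patterned audio codes (audio_pad_code outside
+ # audio spans).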
725
+ def _parse_text_codes(self, start_length, text_codes):
726
+ text = cast(str, self.tokenizer.decode(text_codes))
727
+ prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
728
+ text = text[len(prefix) :]
729
+
730
+ # Escape the special tokens: they contain regex metacharacters ("|").
+ AUDIO_PATTERN = re.compile(
731
+ rf"(?:{re.escape(self.audio_start_token)})?"
732
+ rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
733
+ rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
734
+ rf"{re.escape(self.audio_end_token)}"
735
+ )
736
+
737
+ def normalize_audio_segments(text: str) -> str:
738
+ def repl(match: re.Match) -> str:
739
+ seg = match.group(0)
740
+ # Replace with <|audio|> if gen_slot is present in the segment;
741
+ if self.audio_assistant_gen_slot_token in seg:
742
+ return AUDIO_PLACEHOLDER
743
+ # Otherwise, remove it.
744
+ return ""
745
+
746
+ return AUDIO_PATTERN.sub(repl, text)
747
+
748
+ return normalize_audio_segments(text)
749
+
750
+ def _parse_audio_codes(self, start_length, audio_codes):
751
+ # De-delay back to [T', n_vq]
752
+ audio_codes = self.apply_de_delay_pattern(audio_codes)
753
+
754
+ # Rows that are all pad are separators between real audio segments.
755
+ is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
756
+ non_pad = ~is_pad
757
+ if not non_pad.any():
758
+ return []
759
+
760
+ idx = torch.nonzero(non_pad).squeeze(1)
761
+ breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
762
+ if breaks.numel() == 0:
763
+ segments_idx = [idx]
764
+ else:
765
+ segments_idx = torch.split(idx, breaks.tolist())
766
+
767
+ audio_codes_list = [audio_codes[s] for s in segments_idx]
768
+
769
+ # Batch-decode all audio segments together.
770
+ decoded_audio_list = self.decode_audio_codes(audio_codes_list)
771
+
772
+ # Keep codec causal context by decoding the whole first segment first,
773
+ # then trim at waveform level according to start_length ratio.
774
+ if (
775
+ start_length > 0
776
+ and len(audio_codes_list) > 0
777
+ and len(decoded_audio_list) > 0
778
+ ):
779
+ first_codes_length = audio_codes_list[0].shape[0]
780
+ if first_codes_length > 0:
781
+ trim_ratio = max(
782
+ 0.0, min(float(start_length) / float(first_codes_length), 1.0)
783
+ )
784
+ first_audio = decoded_audio_list[0]
785
+ if trim_ratio >= 1.0:
786
+ decoded_audio_list = decoded_audio_list[1:]
787
+ elif trim_ratio > 0.0:
788
+ trim_samples = int(first_audio.shape[-1] * trim_ratio)
789
+ decoded_audio_list[0] = first_audio[..., trim_samples:]
790
+
791
+ return decoded_audio_list
792
+
793
+ def decode(self, output: List[Tuple[int, torch.Tensor]]):
794
+ """
795
+ 1. A complete sequence of assistant generation ids is required here in every case;
797
+ 2. Truncation from an arbitrary starting position is supported.
797
+ """
798
+
799
+ generated_messages = []
800
+ for start_length, generation_ids in output:
801
+ content = self._parse_text_codes(start_length, generation_ids[:, 0])
802
+ audio_codes_list = self._parse_audio_codes(
803
+ start_length, generation_ids[:, 1:]
804
+ )
805
+ if content == "":
806
+ message = None
807
+ else:
808
+ message = AssistantMessage(
809
+ content=content,
810
+ audio_codes_list=cast(
811
+ List[Union[str, torch.Tensor]], audio_codes_list
812
+ ),
813
+ )
814
+ generated_messages.append(message)
815
+ return generated_messages
816
+
817
+ @staticmethod
818
+ def loudness_normalize(
819
+ wav: torch.Tensor,
820
+ target_dbfs: float = -20,
821
+ gain_range: Tuple[float, float] = (-3.0, 3.0),
822
+ ) -> torch.Tensor:
823
+ wav = wav.to(torch.float32)
824
+ if wav.numel() == 0:
825
+ return wav
826
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
827
+ gain = float(target_dbfs - current_dbfs)
828
+ gain = max(gain_range[0], min(gain, gain_range[1]))
829
+ factor = 10.0 ** (gain / 20.0)
830
+ return wav * factor
831
+
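+ # Worked example (editor's illustration): a full-scale sine has mean
+ # square 0.5, i.e. about -3 dBFS; targeting -20 dBFS asks for roughly
+ # -17 dB of gain, which the (-3, 3) range clamps to -3 dB, a factor of
+ # 10 ** (-3 / 20) ≈ 0.708.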
832
+ def _get_audio_tokenizer_device(self) -> torch.device:
833
+ """Best-effort device inference for `self.audio_tokenizer`.
834
+
835
+ Notes:
836
+ - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
837
+ - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
838
+ """
839
+
840
+ audio_tokenizer = getattr(self, "audio_tokenizer", None)
841
+ if audio_tokenizer is None:
842
+ logger.warning(
843
+ "audio_tokenizer is not set on processor. Using CPU as default."
844
+ )
845
+ return torch.device("cpu")
846
+
847
+ device_attr = getattr(audio_tokenizer, "device", None)
848
+ if isinstance(device_attr, torch.device):
849
+ return device_attr
850
+
851
+ try:
852
+ return next(audio_tokenizer.parameters()).device
853
+ except StopIteration:
854
+ # No parameters (shouldn't happen for real models); default to CPU.
855
+ logger.warning(
856
+ "No parameters found on audio_tokenizer. Using CPU as default."
857
+ )
858
+ return torch.device("cpu")
859
+
860
+ def encode_audios_from_wav(
861
+ self,
862
+ wav_list: List[torch.Tensor],
863
+ sampling_rate: int,
864
+ n_vq: Optional[int] = None,
865
+ ):
866
+ if self.audio_tokenizer is None:
867
+ raise RuntimeError("audio_tokenizer is not set on processor.")
868
+ audio_tokenizer = self.audio_tokenizer
869
+
870
+ if isinstance(wav_list, torch.Tensor):
871
+ wav_list = [wav_list]
872
+ wav_list_ = []
873
+ resample = False
874
+ if sampling_rate != self.model_config.sampling_rate:
875
+ resample = True
876
+ device = self._get_audio_tokenizer_device()
877
+ for wav in wav_list:
878
+ if wav.shape[0] > 1:
879
+ wav = torch.mean(wav, dim=0, keepdim=True)
880
+ if resample:
881
+ wav = torchaudio.functional.resample(
882
+ waveform=wav,
883
+ orig_freq=sampling_rate,
884
+ new_freq=self.model_config.sampling_rate,
885
+ )
886
+ wav = wav.to(device)
887
+ wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
888
+
889
+ # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
890
+ if hasattr(audio_tokenizer, "batch_encode"):
891
+ enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
892
+ audio_codes = enc.audio_codes # (NQ, B, T)
893
+ audio_codes_lengths = enc.audio_codes_lengths # (B,)
894
+ else:
895
+ # Fallback: use encode() with explicit padding.
896
+ max_len = max(int(wav.shape[-1]) for wav in wav_list_)
897
+ input_values = torch.zeros(
898
+ len(wav_list_), 1, max_len, device=device, dtype=torch.float32
899
+ )
900
+ padding_mask = torch.zeros(
901
+ len(wav_list_), max_len, device=device, dtype=torch.bool
902
+ )
903
+ for i, wav in enumerate(wav_list_):
904
+ this_len = int(wav.shape[-1])
905
+ input_values[i, 0, :this_len] = wav
906
+ padding_mask[i, :this_len] = True
907
+ enc = audio_tokenizer.encode(
908
+ input_values,
909
+ padding_mask=padding_mask,
910
+ num_quantizers=n_vq,
911
+ return_dict=True,
912
+ )
913
+ audio_codes = enc.audio_codes
914
+ audio_codes_lengths = enc.audio_codes_lengths
915
+
916
+ if audio_codes is None or audio_codes_lengths is None:
917
+ raise RuntimeError(
918
+ "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
919
+ )
920
+
921
+ # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
922
+ # and on CPU (so downstream text/audio packing remains device-agnostic).
923
+ codes_list: List[torch.Tensor] = []
924
+ for i in range(int(audio_codes.shape[1])):
925
+ length_i = int(audio_codes_lengths[i].item())
926
+ codes_i = (
927
+ audio_codes[:, i, :length_i]
928
+ .transpose(0, 1)
929
+ .contiguous()
930
+ .to(torch.long)
931
+ .cpu()
932
+ )
933
+ codes_list.append(codes_i)
934
+ return codes_list
935
+
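+ # Usage sketch (editor's illustration; "ref.wav" is a hypothetical file
+ # and `processor` an instance of this class):
+ #   wav, sr = torchaudio.load("ref.wav")
+ #   codes = processor.encode_audios_from_wav([wav], sampling_rate=sr)
+ #   codes[0].shape  # (T, n_vq), long, on CPU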
936
+ def encode_audios_from_path(
937
+ self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
938
+ ):
939
+ if isinstance(wav_path_list, str):
940
+ wav_path_list = [wav_path_list]
941
+
942
+ if len(wav_path_list) == 0:
943
+ raise ValueError("Empty wav_path_list")
944
+
945
+ # Load + (if needed) resample each wav independently, so callers can
946
+ # pass a heterogeneous batch of files while still benefiting from
947
+ # audio_tokenizer.batch_encode.
948
+ target_sr = int(self.model_config.sampling_rate)
949
+ wav_list: List[torch.Tensor] = []
950
+ for wav_path in wav_path_list:
951
+ wav, sr = torchaudio.load(wav_path)
952
+ if int(sr) != target_sr:
953
+ wav = torchaudio.functional.resample(
954
+ waveform=wav,
955
+ orig_freq=int(sr),
956
+ new_freq=target_sr,
957
+ )
958
+ wav_list.append(wav)
959
+
960
+ return self.encode_audios_from_wav(wav_list, target_sr, n_vq)
961
+
962
+ def decode_audio_codes(
963
+ self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
964
+ ):
965
+ if self.audio_tokenizer is None:
966
+ raise RuntimeError("audio_tokenizer is not set on processor.")
967
+ audio_tokenizer = self.audio_tokenizer
968
+
969
+ if isinstance(audio_tokens_list, torch.Tensor):
970
+ audio_tokens_list = [audio_tokens_list]
971
+ if len(audio_tokens_list) == 0:
972
+ return []
973
+
974
+ device = self._get_audio_tokenizer_device()
975
+
976
+ # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
977
+ codes_list = [
978
+ codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
979
+ for codes in audio_tokens_list
980
+ ]
981
+
982
+ if hasattr(audio_tokenizer, "batch_decode"):
983
+ dec = audio_tokenizer.batch_decode(codes_list)
984
+ audio = dec.audio # (B, C, T)
985
+ audio_lengths = dec.audio_lengths # (B,)
986
+ else:
987
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
988
+ nq = int(codes_list[0].shape[0])
989
+ max_t = max(int(c.shape[1]) for c in codes_list)
990
+ audio_codes = torch.zeros(
991
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
992
+ )
993
+ padding_mask = torch.zeros(
994
+ len(codes_list), max_t, device=device, dtype=torch.bool
995
+ )
996
+ for i, c in enumerate(codes_list):
997
+ t = int(c.shape[1])
998
+ audio_codes[:, i, :t] = c
999
+ padding_mask[i, :t] = True
1000
+ dec = audio_tokenizer.decode(
1001
+ audio_codes, padding_mask=padding_mask, return_dict=True
1002
+ )
1003
+ audio = dec.audio
1004
+ audio_lengths = dec.audio_lengths
1005
+
1006
+ if audio is None or audio_lengths is None:
1007
+ raise RuntimeError(
1008
+ "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
1009
+ )
1010
+
1011
+ # Return historical contract: list of 1D waveforms (T,)
1012
+ wav_list: List[torch.Tensor] = []
1013
+ for i in range(int(audio.shape[0])):
1014
+ length_i = int(audio_lengths[i].item())
1015
+ wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
1016
+ wav_list.append(wav)
1017
+ return wav_list
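+ # Round-trip sketch (editor's illustration; "ref.wav" is a hypothetical
+ # path): decode_audio_codes inverts encode_audios_from_path up to codec
+ # reconstruction error:
+ #   codes = processor.encode_audios_from_path("ref.wav")
+ #   wavs = processor.decode_audio_codes(codes)  # list of (T,) float32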
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "MossTTSDelayProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_tts.MossTTSDelayProcessor"
5
+ }
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|audio_start|>",
12
+ "<|audio_end|>",
13
+ "<|audio_user_slot|>",
14
+ "<|image_pad|>",
15
+ "<|audio_assistant_gen_slot|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb3c8fa82993d515469c2800cc455bff4aaa3c4fed9da1f2b0c0668c304f335a
3
+ size 11422691
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|audio_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|audio_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|audio_user_slot|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|audio_assistant_gen_slot|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|audio_assistant_delay_slot|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|audio_start|>",
224
+ "<|audio_end|>",
225
+ "<|audio_user_slot|>",
226
+ "<|image_pad|>",
227
+ "<|audio_assistant_gen_slot|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "processor_class": "AsteroidProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff