rulerman commited on
Commit
c76414e
·
verified ·
1 Parent(s): b806ee7
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|audio_end|>": 151653,
9
+ "<|audio_pad|>": 151654,
10
+ "<|audio_start|>": 151652,
11
+ "<|box_end|>": 151649,
12
+ "<|box_start|>": 151648,
13
+ "<|endoftext|>": 151643,
14
+ "<|file_sep|>": 151664,
15
+ "<|fim_middle|>": 151660,
16
+ "<|fim_pad|>": 151662,
17
+ "<|fim_prefix|>": 151659,
18
+ "<|fim_suffix|>": 151661,
19
+ "<|im_end|>": 151645,
20
+ "<|im_start|>": 151644,
21
+ "<|image_pad|>": 151655,
22
+ "<|object_ref_end|>": 151647,
23
+ "<|object_ref_start|>": 151646,
24
+ "<|quad_end|>": 151651,
25
+ "<|quad_start|>": 151650,
26
+ "<|repo_name|>": 151663,
27
+ "<|video_pad|>": 151656
28
+ }
chat_template.jinja ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {% for message in messages %}<|im_start|>{{ message['role'] }}
2
+ {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content.get('type') == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}<|im_end|>
3
+ {% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
4
+ {% endif %}
config.json ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "moss_tts_delay",
3
+ "architectures": [
4
+ "MossTTSDelayModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_moss_tts.MossTTSDelayConfig",
8
+ "AutoModel": "modeling_moss_tts.MossTTSDelayModel"
9
+ },
10
+ "dtype": "bfloat16",
11
+ "initializer_range": 0.02,
12
+ "language_config": {
13
+ "_name_or_path": "Qwen/Qwen3-8B",
14
+ "architectures": [
15
+ "Qwen3ForCausalLM"
16
+ ],
17
+ "attention_bias": false,
18
+ "attention_dropout": 0.0,
19
+ "bos_token_id": 151643,
20
+ "eos_token_id": 151645,
21
+ "pad_token_id": 151643,
22
+ "head_dim": 128,
23
+ "hidden_act": "silu",
24
+ "hidden_size": 4096,
25
+ "initializer_range": 0.02,
26
+ "intermediate_size": 12288,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention",
61
+ "full_attention",
62
+ "full_attention",
63
+ "full_attention"
64
+ ],
65
+ "max_position_embeddings": 40960,
66
+ "max_window_layers": 36,
67
+ "model_type": "qwen3",
68
+ "num_attention_heads": 32,
69
+ "num_hidden_layers": 36,
70
+ "num_key_value_heads": 8,
71
+ "rms_norm_eps": 1e-06,
72
+ "rope_scaling": null,
73
+ "rope_theta": 1000000,
74
+ "sliding_window": null,
75
+ "use_cache": true,
76
+ "use_sliding_window": false,
77
+ "vocab_size": 155648
78
+ },
79
+ "n_vq": 16,
80
+ "audio_vocab_size": 1024,
81
+ "audio_user_slot_token_id": 151654,
82
+ "audio_assistant_gen_slot_token_id": 151656,
83
+ "audio_assistant_delay_slot_token_id": 151662,
84
+ "audio_start_token_id": 151652,
85
+ "audio_end_token_id": 151653,
86
+ "audio_pad_code": 1024,
87
+ "sampling_rate": 24000,
88
+ "transformers_version": "4.57.1"
89
+ }
configuration_moss_tts.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MossTTSDelay model configuration """
16
+
17
+ from typing import Optional, Union
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+ from transformers.models.qwen3 import Qwen3Config
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class MossTTSDelayConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate an
    MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MossTTSDelay [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        language_config (`Union[Qwen3Config, dict]`, *optional*):
            Configuration for the backbone language model (Qwen3).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        n_vq (`int`, *optional*, defaults to 16):
            Number of additional VQ (Vector Quantization) heads/channels for audio.
            Determines the number of codebooks used in the audio representation.
        pad_token_id (`int`, *optional*, defaults to 151643):
            Token ID used for padding text sequences.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Token ID of the `<|im_start|>` chat-turn delimiter.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Token ID of the `<|im_end|>` chat-turn delimiter.
        audio_vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size for the audio tokens (codebooks 1 to N).
        audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
            The specific token ID used as a placeholder/slot for user-side audio inputs in the prompt.
        audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
            The specific token ID representing the generation slot for the assistant's audio output.
            Acting as the trigger for the TTS generation process.
        audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
            The token ID used in the 'Delay Pattern' paradigm to represent the delayed/offset positions
            between different VQ channels.
        audio_start_token_id (`int`, *optional*, defaults to 151652):
            Special token ID used to denote the start of an audio sequence in the stream.
        audio_end_token_id (`int`, *optional*, defaults to 151653):
            Special token ID used to denote the end of an audio sequence (EOS for audio).
        audio_pad_code (`int`, *optional*, defaults to 1024):
            The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
        sampling_rate (`int`, *optional*, defaults to 24000):
            Audio sampling rate in Hz associated with the audio codec.
    """
    model_type = "moss_tts_delay"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        language_config: Optional[Union[Qwen3Config, dict]] = None,
        initializer_range: float = 0.02,
        n_vq: int = 16,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_vocab_size: int = 1024,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_gen_slot_token_id: int = 151656,
        audio_assistant_delay_slot_token_id: int = 151662,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_pad_code: int = 1024,
        sampling_rate: int = 24000,
        **kwargs,
    ):
        # Accept the backbone config as an object, a plain dict (e.g. loaded
        # from config.json), or nothing (fall back to Qwen3 defaults).
        if isinstance(language_config, dict):
            self.language_config = Qwen3Config(**language_config)
        elif language_config is None:
            self.language_config = Qwen3Config()
        else:
            self.language_config = language_config

        self.initializer_range = initializer_range
        self.n_vq = n_vq
        self.audio_vocab_size = audio_vocab_size
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
        self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_pad_code = audio_pad_code
        self.sampling_rate = sampling_rate

        # Mirror the backbone dimensions on the top-level config so downstream
        # code can read them without reaching into language_config.
        self.hidden_size = self.language_config.hidden_size
        self.vocab_size = self.language_config.vocab_size
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id

        # PretrainedConfig.__init__ pops `pad_token_id` out of kwargs and
        # assigns the attribute itself; setting self.pad_token_id before the
        # super() call would be silently overwritten (with None when the kwarg
        # is absent). Route the value through kwargs so the parent applies it.
        kwargs.setdefault("pad_token_id", pad_token_id)
        super().__init__(**kwargs)

    def to_dict(self):
        """Serialize to a plain dict, expanding the nested language_config."""
        output = super().to_dict()
        if hasattr(self.language_config, "to_dict"):
            output["language_config"] = self.language_config.to_dict()
        else:
            output["language_config"] = self.language_config
        return output
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "temperature": 1.1,
4
+ "top_p": 0.9,
5
+ "top_k": 50,
6
+ "repetition_penalty": 1.1,
7
+ "max_new_tokens": 8192
8
+ }
inference_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import torch.nn.functional as F
4
+ from typing import Optional, List, Tuple
5
+ from tqdm import tqdm
6
+
7
+
8
def apply_top_k(logits, top_k):
    """Top-k filtering: keep the k largest logits per row, set the rest to -inf.

    Returns a new tensor; `logits` ([batch, vocab]) is left untouched.
    """
    n_rows, n_vocab = logits.shape
    k = min(top_k, n_vocab)
    values, indices = torch.topk(logits, k, dim=-1)
    # Start from an all -inf grid and scatter the surviving logits back into
    # their vocabulary positions.
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=indices, src=values)
    return masked
16
+
17
+
18
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering: keep the smallest set of tokens whose
    cumulative probability reaches `top_p`, set all other logits to -inf.

    Returns a new tensor; the input `logits` ([batch, vocab]) is not modified.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark tokens past the cumulative-mass threshold, then shift the mask
    # right by one so the token that crosses the threshold is still kept;
    # the highest-probability token therefore always survives.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the sorted-order mask back to vocabulary order in one shot
    # instead of the previous O(batch) Python loop over rows.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    filtered_logits = logits.clone()
    filtered_logits[indices_to_remove] = float("-inf")
    return filtered_logits
31
+
32
+
33
def apply_top_p_optimized(logits, top_p):
    """Vectorized nucleus (top-p) filtering.

    Marks every token outside the smallest probability-mass-`top_p` prefix
    and writes -inf into those positions. NOTE: mutates `logits` in place
    and returns the same tensor.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)

    # Tokens past the cumulative threshold are dropped; shifting the mask
    # right by one slot keeps the token that crosses the threshold, so the
    # top-1 token is always retained.
    drop_sorted = cumulative > top_p
    drop_sorted = drop_sorted.roll(shifts=1, dims=-1)
    drop_sorted[..., 0] = False

    # Map the mask from sorted order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool)
    drop.scatter_(dim=-1, index=sorted_idx, src=drop_sorted)

    logits.masked_fill_(drop, float("-inf"))
    return logits
49
+
50
+
51
def _penalize(token_logits: torch.Tensor, penalty: float) -> torch.Tensor:
    """CTRL-style repetition penalty on a slice of logits: positive logits
    are divided by `penalty`, non-positive logits are multiplied by it."""
    pos_mask = token_logits > 0
    token_logits[pos_mask] /= penalty
    token_logits[~pos_mask] *= penalty
    return token_logits


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    Apply the repetition penalty independently for each H (VQ head).
    Penalized entries of `logits` are updated in place (except for the
    no-op fast path); the tensor is also returned for convenience.
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Case 1: regular [N, V] (text layer) — penalize every token seen anywhere
    # in the history, for every row.
    if logits.dim() == 2:
        unique_tokens = torch.unique(prev_tokens.reshape(-1))
        logits[:, unique_tokens] = _penalize(logits[:, unique_tokens], penalty)
        return logits

    # Case 2: Delay Pattern audio [B, H, V] — each VQ head is penalized only
    # against its own channel's history.
    assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
    _, H, _ = logits.shape

    for h in range(H):
        # prev_tokens[..., h]: [B, T] or [B]
        unique_tokens = torch.unique(prev_tokens[..., h].reshape(-1))
        if unique_tokens.numel() == 0:
            continue
        logits[:, h, unique_tokens] = _penalize(logits[:, h, unique_tokens], penalty)

    return logits
98
+
99
+
100
def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
    temperature: float = 1.0,
):
    """Select next token(s) from `logits` ([..., V]).

    Order of operations: repetition penalty -> temperature -> (greedy argmax
    OR top-k / top-p filtering + multinomial sampling). Returns a LongTensor
    of shape `logits.shape[:-1]`.

    `temperature` is new but backward-compatible (default 1.0 is a no-op);
    it lets callers honor the `temperature` field in generation_config.json.
    """
    vocab_size = logits.size(-1)

    # ===== Repetition Penalty (before reshaping!) =====
    # The delay-pattern penalty needs the [B, H, V] structure intact.
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    # Temperature scaling; a positive scale never changes the argmax branch.
    if temperature is not None and temperature != 1.0:
        logits = logits / temperature

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # ===== Only flatten after this, for top-k / top-p / multinomial =====
    original_shape = logits.shape
    # reshape (not view): upstream ops may leave logits non-contiguous.
    reshaped_logits = logits.reshape(-1, vocab_size)

    if top_k is not None and top_k > 0:
        reshaped_logits = apply_top_k(reshaped_logits, top_k)

    if top_p is not None and top_p < 1.0:
        reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)

    probs = F.softmax(reshaped_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.view(original_shape[:-1])
135
+
136
+
137
def find_last_equal_C(tensor, C):
    """
    tensor: torch.Tensor of shape [batch_size, seq_len]
    C: scalar value to match
    Returns: torch.Tensor of shape [batch_size] holding, per row, the index
    of the last element equal to C, or -1 when the row contains no match.
    """
    batch_size, seq_len = tensor.shape
    hits = tensor == C  # [batch_size, seq_len] bool mask

    # Replace matching positions with their index and everything else with -1,
    # then a row-wise max yields the last matching index (or -1 if none).
    positions = torch.arange(seq_len, device=tensor.device).expand(batch_size, seq_len)
    candidates = torch.where(hits, positions, torch.full_like(positions, -1))
    return candidates.max(dim=1).values
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad8bb115b6c87c902c76e7a4ef90a6eee98041a87f741fcd163fbe81d855d87a
3
+ size 4932667368
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cb9641e2f25651a43bc7f681337d79ab57ca50d63355660787712fcee393c40
3
+ size 4915961640
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646b345809967a41308d95afcf5af233ae94de7997befac287017af602c59687
3
+ size 4983069760
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab0a0ed173b7a86a8cb16b82ca981ad0f1866c2c2d0ecb47c97e72b68ccbaba6
3
+ size 1879339648
model.safetensors.index.json ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 8355492864,
4
+ "total_size": 16710985728
5
+ },
6
+ "weight_map": {
7
+ "emb_ext.0.weight": "model-00004-of-00004.safetensors",
8
+ "emb_ext.1.weight": "model-00004-of-00004.safetensors",
9
+ "emb_ext.10.weight": "model-00004-of-00004.safetensors",
10
+ "emb_ext.11.weight": "model-00004-of-00004.safetensors",
11
+ "emb_ext.12.weight": "model-00004-of-00004.safetensors",
12
+ "emb_ext.13.weight": "model-00004-of-00004.safetensors",
13
+ "emb_ext.14.weight": "model-00004-of-00004.safetensors",
14
+ "emb_ext.15.weight": "model-00004-of-00004.safetensors",
15
+ "emb_ext.2.weight": "model-00004-of-00004.safetensors",
16
+ "emb_ext.3.weight": "model-00004-of-00004.safetensors",
17
+ "emb_ext.4.weight": "model-00004-of-00004.safetensors",
18
+ "emb_ext.5.weight": "model-00004-of-00004.safetensors",
19
+ "emb_ext.6.weight": "model-00004-of-00004.safetensors",
20
+ "emb_ext.7.weight": "model-00004-of-00004.safetensors",
21
+ "emb_ext.8.weight": "model-00004-of-00004.safetensors",
22
+ "emb_ext.9.weight": "model-00004-of-00004.safetensors",
23
+ "language_model.embed_tokens.weight": "model-00001-of-00004.safetensors",
24
+ "language_model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
25
+ "language_model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
26
+ "language_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
27
+ "language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
28
+ "language_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
29
+ "language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
30
+ "language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
31
+ "language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
32
+ "language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
33
+ "language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
34
+ "language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
35
+ "language_model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
36
+ "language_model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
37
+ "language_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
38
+ "language_model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
39
+ "language_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
40
+ "language_model.layers.1.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
41
+ "language_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
42
+ "language_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
43
+ "language_model.layers.1.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
44
+ "language_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
45
+ "language_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
46
+ "language_model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
47
+ "language_model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
48
+ "language_model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
49
+ "language_model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
50
+ "language_model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
51
+ "language_model.layers.10.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
52
+ "language_model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
53
+ "language_model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
54
+ "language_model.layers.10.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
55
+ "language_model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
56
+ "language_model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
57
+ "language_model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
58
+ "language_model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
59
+ "language_model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
60
+ "language_model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
61
+ "language_model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
62
+ "language_model.layers.11.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
63
+ "language_model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
64
+ "language_model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
65
+ "language_model.layers.11.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
66
+ "language_model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
67
+ "language_model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
68
+ "language_model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "language_model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
70
+ "language_model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
71
+ "language_model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
72
+ "language_model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "language_model.layers.12.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
74
+ "language_model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
75
+ "language_model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
76
+ "language_model.layers.12.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
77
+ "language_model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
78
+ "language_model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
79
+ "language_model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
80
+ "language_model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
81
+ "language_model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
82
+ "language_model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
83
+ "language_model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
84
+ "language_model.layers.13.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
85
+ "language_model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
86
+ "language_model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
87
+ "language_model.layers.13.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
88
+ "language_model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "language_model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
90
+ "language_model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
91
+ "language_model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
92
+ "language_model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
93
+ "language_model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
94
+ "language_model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
95
+ "language_model.layers.14.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
96
+ "language_model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
97
+ "language_model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
98
+ "language_model.layers.14.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
99
+ "language_model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
100
+ "language_model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
101
+ "language_model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
102
+ "language_model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
103
+ "language_model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
104
+ "language_model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
105
+ "language_model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
106
+ "language_model.layers.15.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
107
+ "language_model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
108
+ "language_model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
109
+ "language_model.layers.15.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
110
+ "language_model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
111
+ "language_model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
112
+ "language_model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
113
+ "language_model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
114
+ "language_model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
115
+ "language_model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
116
+ "language_model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
117
+ "language_model.layers.16.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
118
+ "language_model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
119
+ "language_model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
120
+ "language_model.layers.16.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
121
+ "language_model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
122
+ "language_model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
123
+ "language_model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
124
+ "language_model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
125
+ "language_model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
126
+ "language_model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
127
+ "language_model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
128
+ "language_model.layers.17.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
129
+ "language_model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
130
+ "language_model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
131
+ "language_model.layers.17.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
132
+ "language_model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
133
+ "language_model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
134
+ "language_model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
135
+ "language_model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
136
+ "language_model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
137
+ "language_model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
138
+ "language_model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
139
+ "language_model.layers.18.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
140
+ "language_model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
141
+ "language_model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
142
+ "language_model.layers.18.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
143
+ "language_model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
144
+ "language_model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
145
+ "language_model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
146
+ "language_model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
147
+ "language_model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
148
+ "language_model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
149
+ "language_model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
150
+ "language_model.layers.19.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
151
+ "language_model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
152
+ "language_model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
153
+ "language_model.layers.19.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
154
+ "language_model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
155
+ "language_model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
156
+ "language_model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
157
+ "language_model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
158
+ "language_model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
159
+ "language_model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
160
+ "language_model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
161
+ "language_model.layers.2.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
162
+ "language_model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
163
+ "language_model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
164
+ "language_model.layers.2.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
165
+ "language_model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
166
+ "language_model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
167
+ "language_model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
168
+ "language_model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
169
+ "language_model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
170
+ "language_model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
171
+ "language_model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
172
+ "language_model.layers.20.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
173
+ "language_model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
174
+ "language_model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
175
+ "language_model.layers.20.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
176
+ "language_model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
177
+ "language_model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
178
+ "language_model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
179
+ "language_model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
180
+ "language_model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
181
+ "language_model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
182
+ "language_model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
183
+ "language_model.layers.21.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
184
+ "language_model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
185
+ "language_model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
186
+ "language_model.layers.21.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
187
+ "language_model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
188
+ "language_model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
189
+ "language_model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
190
+ "language_model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
191
+ "language_model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
192
+ "language_model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
193
+ "language_model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
194
+ "language_model.layers.22.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
195
+ "language_model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
196
+ "language_model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
197
+ "language_model.layers.22.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
198
+ "language_model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
199
+ "language_model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
200
+ "language_model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
201
+ "language_model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
202
+ "language_model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
203
+ "language_model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
204
+ "language_model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "language_model.layers.23.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
206
+ "language_model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
207
+ "language_model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
208
+ "language_model.layers.23.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
209
+ "language_model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
210
+ "language_model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
211
+ "language_model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
212
+ "language_model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
213
+ "language_model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
214
+ "language_model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
215
+ "language_model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
216
+ "language_model.layers.24.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
217
+ "language_model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
218
+ "language_model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
219
+ "language_model.layers.24.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
220
+ "language_model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
221
+ "language_model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
222
+ "language_model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
223
+ "language_model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
224
+ "language_model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
225
+ "language_model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
226
+ "language_model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
227
+ "language_model.layers.25.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
228
+ "language_model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
229
+ "language_model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
230
+ "language_model.layers.25.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
231
+ "language_model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
232
+ "language_model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
233
+ "language_model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
234
+ "language_model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
235
+ "language_model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
236
+ "language_model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
237
+ "language_model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
238
+ "language_model.layers.26.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
239
+ "language_model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
240
+ "language_model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
241
+ "language_model.layers.26.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
242
+ "language_model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
243
+ "language_model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
244
+ "language_model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
245
+ "language_model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
246
+ "language_model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
247
+ "language_model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
248
+ "language_model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
249
+ "language_model.layers.27.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
250
+ "language_model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
251
+ "language_model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
252
+ "language_model.layers.27.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
253
+ "language_model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
254
+ "language_model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
255
+ "language_model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
256
+ "language_model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
257
+ "language_model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
258
+ "language_model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
259
+ "language_model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
260
+ "language_model.layers.28.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
261
+ "language_model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
262
+ "language_model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
263
+ "language_model.layers.28.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
264
+ "language_model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
265
+ "language_model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
266
+ "language_model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
267
+ "language_model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
268
+ "language_model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
269
+ "language_model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
270
+ "language_model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
271
+ "language_model.layers.29.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
272
+ "language_model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
273
+ "language_model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
274
+ "language_model.layers.29.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
275
+ "language_model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
276
+ "language_model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
277
+ "language_model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
278
+ "language_model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
279
+ "language_model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
280
+ "language_model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
281
+ "language_model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
282
+ "language_model.layers.3.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
283
+ "language_model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
284
+ "language_model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
285
+ "language_model.layers.3.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
286
+ "language_model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
287
+ "language_model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
288
+ "language_model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
289
+ "language_model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
290
+ "language_model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
291
+ "language_model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
292
+ "language_model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
293
+ "language_model.layers.30.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
294
+ "language_model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
295
+ "language_model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
296
+ "language_model.layers.30.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
297
+ "language_model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
298
+ "language_model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
299
+ "language_model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
300
+ "language_model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
301
+ "language_model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
302
+ "language_model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
303
+ "language_model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
304
+ "language_model.layers.31.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
305
+ "language_model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
306
+ "language_model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
307
+ "language_model.layers.31.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
308
+ "language_model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
309
+ "language_model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
310
+ "language_model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
311
+ "language_model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
312
+ "language_model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
313
+ "language_model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
314
+ "language_model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
315
+ "language_model.layers.32.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
316
+ "language_model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
317
+ "language_model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
318
+ "language_model.layers.32.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
319
+ "language_model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
320
+ "language_model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
321
+ "language_model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
322
+ "language_model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
323
+ "language_model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
324
+ "language_model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
325
+ "language_model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
326
+ "language_model.layers.33.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
327
+ "language_model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
328
+ "language_model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
329
+ "language_model.layers.33.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
330
+ "language_model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
331
+ "language_model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
332
+ "language_model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
333
+ "language_model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
334
+ "language_model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
335
+ "language_model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
336
+ "language_model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
337
+ "language_model.layers.34.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
338
+ "language_model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
339
+ "language_model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
340
+ "language_model.layers.34.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
341
+ "language_model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
342
+ "language_model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
343
+ "language_model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
344
+ "language_model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
345
+ "language_model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
346
+ "language_model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
347
+ "language_model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
348
+ "language_model.layers.35.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
349
+ "language_model.layers.35.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
350
+ "language_model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
351
+ "language_model.layers.35.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
352
+ "language_model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
353
+ "language_model.layers.35.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
354
+ "language_model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
355
+ "language_model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
356
+ "language_model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
357
+ "language_model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
358
+ "language_model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
359
+ "language_model.layers.4.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
360
+ "language_model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
361
+ "language_model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
362
+ "language_model.layers.4.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
363
+ "language_model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
364
+ "language_model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
365
+ "language_model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
366
+ "language_model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
367
+ "language_model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
368
+ "language_model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
369
+ "language_model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
370
+ "language_model.layers.5.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
371
+ "language_model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
372
+ "language_model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
373
+ "language_model.layers.5.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
374
+ "language_model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
375
+ "language_model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
376
+ "language_model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
377
+ "language_model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
378
+ "language_model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
379
+ "language_model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
380
+ "language_model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
381
+ "language_model.layers.6.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
382
+ "language_model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
383
+ "language_model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
384
+ "language_model.layers.6.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
385
+ "language_model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
386
+ "language_model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
387
+ "language_model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
388
+ "language_model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
389
+ "language_model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
390
+ "language_model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
391
+ "language_model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
392
+ "language_model.layers.7.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
393
+ "language_model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
394
+ "language_model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
395
+ "language_model.layers.7.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
396
+ "language_model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
397
+ "language_model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
398
+ "language_model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
399
+ "language_model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
400
+ "language_model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
401
+ "language_model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
402
+ "language_model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
403
+ "language_model.layers.8.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
404
+ "language_model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
405
+ "language_model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
406
+ "language_model.layers.8.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
407
+ "language_model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
408
+ "language_model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
409
+ "language_model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
410
+ "language_model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
411
+ "language_model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
412
+ "language_model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
413
+ "language_model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
414
+ "language_model.layers.9.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
415
+ "language_model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
416
+ "language_model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
417
+ "language_model.layers.9.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
418
+ "language_model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
419
+ "language_model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
420
+ "language_model.norm.weight": "model-00004-of-00004.safetensors",
421
+ "lm_heads.0.weight": "model-00004-of-00004.safetensors",
422
+ "lm_heads.1.weight": "model-00004-of-00004.safetensors",
423
+ "lm_heads.10.weight": "model-00004-of-00004.safetensors",
424
+ "lm_heads.11.weight": "model-00004-of-00004.safetensors",
425
+ "lm_heads.12.weight": "model-00004-of-00004.safetensors",
426
+ "lm_heads.13.weight": "model-00004-of-00004.safetensors",
427
+ "lm_heads.14.weight": "model-00004-of-00004.safetensors",
428
+ "lm_heads.15.weight": "model-00004-of-00004.safetensors",
429
+ "lm_heads.16.weight": "model-00004-of-00004.safetensors",
430
+ "lm_heads.2.weight": "model-00004-of-00004.safetensors",
431
+ "lm_heads.3.weight": "model-00004-of-00004.safetensors",
432
+ "lm_heads.4.weight": "model-00004-of-00004.safetensors",
433
+ "lm_heads.5.weight": "model-00004-of-00004.safetensors",
434
+ "lm_heads.6.weight": "model-00004-of-00004.safetensors",
435
+ "lm_heads.7.weight": "model-00004-of-00004.safetensors",
436
+ "lm_heads.8.weight": "model-00004-of-00004.safetensors",
437
+ "lm_heads.9.weight": "model-00004-of-00004.safetensors"
438
+ }
439
+ }
modeling_moss_tts.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Modeling classes for MossTTSDelay. """
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+ from tqdm import tqdm
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import CrossEntropyLoss
24
+
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import ModelOutput
27
+ from transformers.utils import (
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.cache_utils import Cache
34
+ from transformers.models.qwen3 import Qwen3Model
35
+ from transformers import initialization as init
36
+
37
+ from .configuration_moss_tts import MossTTSDelayConfig
38
+ from .inference_utils import sample_token, find_last_equal_C
39
+
40
# Processor/message helpers are optional at import time: processing_moss_tts may
# be absent in some deployments (e.g. weights-only checkouts), so fall back to
# None and let downstream code degrade gracefully instead of failing the import.
try:
    from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
except Exception:
    UserMessage = None
    AssistantMessage = None
    MossTTSDelayProcessor = None
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CONFIG_FOR_DOC = "MossTTSDelayConfig"
50
+
51
+
52
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Weighted sum of channel losses.
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of non-masked tokens per sample.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + vq heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores, one tensor per output head.
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of torch.FloatTensor (one for each layer) of the attention weights.
    """
    # Field order matters: ModelOutput exposes tuple-style iteration/indexing
    # in declaration order, so callers may rely on these positions.
    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
87
+
88
+
89
+
90
+
91
class MossTTSDelayPreTrainedModel(PreTrainedModel):
    """
    Base class wiring MossTTSDelay models into the Hugging Face ecosystem:
    declares the config class, attention-backend support flags, and a
    weight-initialization hook for the extra audio embeddings.
    """

    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Transformers 5.0+ safe init:
        - MUST use transformers.initialization helpers
        - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
        """
        # Let HF handle its standard modules first (LayerNorm, Linear, Embedding, etc.)
        super()._init_weights(module)

        # Pick a std consistent with HF conventions, preferring the top-level
        # config, then the nested language config. Using getattr(..., None)
        # (instead of hasattr) also guards against an attribute that exists but
        # is set to None, which would otherwise crash init.normal_(std=None).
        std = getattr(self.config, "initializer_range", None)
        if std is None:
            language_config = getattr(self.config, "language_config", None)
            std = getattr(language_config, "initializer_range", None)
        if std is None:
            std = 0.02

        # Initialize only the extra audio-VQ embeddings added by this model,
        # identified by their size (audio_vocab_size + 1), so the language
        # model's own token embedding is not touched twice.
        if isinstance(module, nn.Embedding):
            if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
                init.normal_(module.weight, mean=0.0, std=std)
                # NOTE(review): if a padding_idx is ever set on these
                # embeddings, its row must be explicitly zeroed here (and
                # _is_hf_initialized respected when slicing).

        # Custom nn.Linear variants would be handled here if added later;
        # the plain lm_heads are already covered by super()._init_weights.
135
+
136
+
137
+
138
# Shared docstring fragment injected into model classes via `add_start_docstrings`.
MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
153
+
154
+
155
@add_start_docstrings(
    "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
    MOSSTTS_START_DOCSTRING,
)
class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
    # Convenience handles so callers can build inputs through the model class.
    UserMessage = UserMessage
    AssistantMessage = AssistantMessage
    Processor = MossTTSDelayProcessor

    def __init__(self, config: MossTTSDelayConfig):
        """Build the Qwen3 backbone plus the extra audio-VQ embeddings/heads."""
        super().__init__(config)
        self.config = config

        # Keep the backbone dtype in sync with the top-level config.
        config.language_config.torch_dtype = config.torch_dtype

        self.language_model = Qwen3Model(config.language_config)

        # Audio VQ Embeddings (extra channels).
        # Note: input_ids[..., 0] uses Qwen's embedding;
        # input_ids[..., 1:] use these extensions, one per VQ codebook.
        self.emb_ext = nn.ModuleList()
        for vq_idx in range(self.config.n_vq):
            # +1 row for the padding/special code used by the upstream data prep.
            self.emb_ext.append(
                nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
            )

        # Multi-head prediction layers:
        #   head 0    -> main language (text) vocabulary
        #   head 1..N -> audio VQ vocabularies (audio_vocab_size + 1 each)
        self.lm_heads = nn.ModuleList([
            nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
        ])
        for vq_idx in range(self.config.n_vq):
            self.lm_heads.append(
                nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
            )

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Computes the combined embeddings from text and multiple audio VQ channels.

        Args:
            input_ids: Shape (Batch, Seq_Len, 1 + n_vq)

        Returns:
            torch.Tensor: summed embeddings of shape (Batch, Seq_Len, hidden_size).

        NOTE(review): this overrides the usual HF ``get_input_embeddings``
        contract (which returns the embedding *module*); internal callers here
        pass ``input_ids`` and expect a tensor.
        """
        # Base text/content embedding: input_ids[..., 0] is the text channel.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])

        # Add the VQ embeddings; emb_ext[i] corresponds to channel i+1.
        # Assumes the data pipeline keeps every index within range.
        for i, embed_layer in enumerate(self.emb_ext):
            inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])

        return inputs_embeds

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        # Returning a list of heads might break some HF utilities expecting a
        # single head; acceptable for this custom multi-head model.
        return self.lm_heads

    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the masked language modeling loss.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across different heads (Text vs Audio channels).

        Returns:
        """

        if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
            raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # 1. Prepare embeddings (text channel + summed VQ channels).
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # 2. Backbone forward. Qwen3Model returns a standard HF model output.
        outputs = self.language_model(
            input_ids=None,  # Passed via inputs_embeds
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Always need hidden states for multi-head projection
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        # 3. Select hidden states per head. By default every head reads the
        # final layer; `hidden_out_layers` lets callers route heads to other
        # layers (MusicGen-style per-codebook layers).
        last_hidden_state = outputs.last_hidden_state
        if hidden_out_layers is None:
            hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
        else:
            # Note: outputs.hidden_states usually includes the embedding output
            # at index 0.
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        # 4. Project to logits (multi-head).
        layer_logits = []
        for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            # For audio heads (i > 0) the last vocab index is reserved (used as
            # input padding) and must never be predicted, so mask it to -inf.
            if i > 0:
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        # 5. Loss calculation (per-sample and per-channel bookkeeping).
        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:
            # Labels must match the input rank (B, S, C).
            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")

            batch_size = labels.size(0)
            n_heads = len(layer_logits)

            # Per-sample, per-channel summed losses; stacked to [B, n_heads].
            all_sum_losses_list = []

            # Valid (!= -100) token counts per sample and channel: [B, C].
            all_token_nums = torch.sum(labels != -100, dim=1)

            for i, logits in enumerate(layer_logits):
                # logits: [B, S, V]; cur_labels: [B, S]
                cur_labels = labels[..., i]

                # Flatten for CrossEntropy: logits [B*S, V], labels [B*S].
                loss_fct = CrossEntropyLoss(reduction='none')
                vocab_size = logits.size(-1)

                reshaped_logits = logits.view(-1, vocab_size)
                reshaped_labels = cur_labels.contiguous().view(-1)

                # Per-token loss, then sum over the sequence dim -> [B].
                per_token_loss = loss_fct(reshaped_logits, reshaped_labels)

                per_token_loss = per_token_loss.view(batch_size, -1)
                per_sample_loss = torch.sum(per_token_loss, dim=-1)  # [B]

                all_sum_losses_list.append(per_sample_loss)

            # Stack to [B, n_heads].
            all_sum_losses = torch.stack(all_sum_losses_list, dim=1)

            # Weighted loss aggregation across channels.
            if channelwise_loss_weight is not None:
                if len(channelwise_loss_weight) != n_heads:
                    raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

                w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)

                # Per-sample: weighted mean over channels, each channel first
                # normalized by its own token count (clamped to avoid /0).
                token_counts_safe = all_token_nums.float().clamp(min=1.0)

                normalized_losses = all_sum_losses / token_counts_safe
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                # Per-channel: batch-summed loss over batch-summed token count.
                total_loss_per_channel = all_sum_losses.sum(dim=0)
                total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
                channel_losses = total_loss_per_channel / total_tokens_per_channel

                # Final scalar loss: weighted mean of channel losses.
                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                # Default: plain token-weighted average over all channels.
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens
                channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)

        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = None,
        text_temperature: float = 1.2,
        text_top_p: float = 0.9,
        text_top_k: int = 50,
        audio_temperature: Optional[float] = None,
        audio_top_p: Optional[float] = None,
        audio_top_k: Optional[int] = None,
        audio_repetition_penalty: Optional[float] = None,
    ):
        """Autoregressive delay-pattern decoding over text + n_vq audio channels.

        Args:
            input_ids: (batch, seq_len, 1 + n_vq) prompt tokens; channel 0 is text.
            attention_mask: (batch, seq_len) mask for the prompt.
            max_new_tokens / *_temperature / *_top_p / *_top_k / audio_repetition_penalty:
                sampling controls; audio-side values fall back to
                ``self.generation_config`` and then to hard-coded defaults.

        Returns:
            list of (start_length, generated_ids) tuples, one per batch item,
            where generated_ids is sliced from the last <|im_start|> + 3 tokens.
        """
        generation_config = getattr(self, "generation_config", None)

        def _cfg_value(name: str, default_value: Union[int, float]) -> Union[int, float]:
            # Read a sampling default from generation_config, if present.
            if generation_config is None:
                return default_value
            value = getattr(generation_config, name, None)
            if value is None:
                return default_value
            return value

        if max_new_tokens is None:
            try:
                max_new_tokens = int(_cfg_value("max_new_tokens", 1000))
            except (TypeError, ValueError):
                max_new_tokens = 1000
        if audio_temperature is None:
            try:
                audio_temperature = float(_cfg_value("temperature", 1.1))
            except (TypeError, ValueError):
                audio_temperature = 1.1
        if audio_top_p is None:
            try:
                audio_top_p = float(_cfg_value("top_p", 0.9))
            except (TypeError, ValueError):
                audio_top_p = 0.9
        if audio_top_k is None:
            try:
                audio_top_k = int(_cfg_value("top_k", 50))
            except (TypeError, ValueError):
                audio_top_k = 50
        if audio_repetition_penalty is None:
            try:
                audio_repetition_penalty = float(_cfg_value("repetition_penalty", 1.1))
            except (TypeError, ValueError):
                audio_repetition_penalty = 1.1

        # Temperature <= 0 means greedy decoding (temperature reset to 1).
        if text_temperature > 0:
            text_do_sample = True
        else:
            text_temperature = 1
            text_do_sample = False
        if audio_temperature > 0:
            audio_do_sample = True
        else:
            audio_temperature = 1
            audio_do_sample = False

        past_key_values = None
        device = input_ids.device
        current_input_ids = input_ids
        current_attention_mask = attention_mask
        batch_size, seq_len, n_vq = input_ids.shape
        n_vq -= 1  # channel 0 is text; the rest are VQ channels

        generation_ids = input_ids[:]
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Three phases per sample: 1) not in audio; 2) in audio, not yet
        # flushing the delay; 3) flushing the delay pattern.
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)  # 0 means phase 1
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)  # int64 max means phase 2

        # Continuation support: audio_start may already be inside input_ids.
        # NOTE: inputs that have already started the delay flush are not
        # supported. Both continuation and from-scratch generation are handled.
        is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
        audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Text-channel vocab masks: outside audio, forbid control/slot tokens;
        # inside audio, allow only the gen/delay slot tokens.
        pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
        pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False


        # NOTE: time_step does not necessarily equal the absolute output
        # position in the dialogue because of continuation inputs.
        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Last-position logits per head, pre-scaled by the matching temperature.
            next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)]  # List, len=n_vq+1, [batch_size, vocab_size]
            next_token_logits[0] = next_token_logits[0].clone()
            # 1. Text channel first.
            next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
            # The trailing delay-slot tokens and audio_end are forced, not
            # sampled; audio_start, each gen-slot token and the first delay-slot
            # token come from sampling.
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = self.config.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
            next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
            next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
            if time_step == 0:
                # Forbid token 151662 on the first step (<|fim_pad|> per the
                # repo's added_tokens.json — TODO confirm intent).
                next_token_logits[0][..., 151662] = float('-inf')
            if time_step <= n_vq:
                # Cannot emit <|im_end|> before a full delay flush could fit.
                next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')

            # No repetition penalty on the text channel.
            next_text_token[sampling_text_mask] = sample_token(
                logits=next_token_logits[0][sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample
            )
            is_audio[next_text_token == self.config.audio_start_token_id] = True
            # The only stopping condition: next_text_token == <|im_end|>.
            is_stopping[next_text_token == self.config.im_end_token_id] = True

            # 2. Audio channels. Positions outside the audio_start..audio_end
            # span stay at the pad code; only positions carrying real values
            # are filled.
            next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)

            # Masks depend on the distance from audio_start (delay pattern);
            # True means the position carries a real (sampled) value.
            pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
            post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
            post_audio_mask[delayed_lengths == torch_int64_max] = True
            sampling_audio_mask = pre_audio_mask & post_audio_mask
            next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code

            if sampling_audio_mask.sum() > 0:
                audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask]  # torch.stack -> [batch_size, n_vq, vocab_size]
                # The pad code must never be sampled.
                audio_logits[..., self.config.audio_pad_code] = float('-inf')
                next_audio_tokens[sampling_audio_mask] = sample_token(
                    logits=audio_logits,
                    prev_tokens=generation_ids[:, :, 1:],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )

            # Advance the per-sample state so audio_lengths / delayed_lengths
            # are valid for the next time step.
            audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
            audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
            delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
            delayed_lengths[delayed_lengths != torch_int64_max] += 1
            delayed_lengths[delayed_lengths > n_vq] = torch_int64_max

            current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)  # [batch_size, 1, n_vq + 1]
            current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
            generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)  # [batch_size, seq_len, n_vq + 1]

            if is_stopping.sum() == batch_size:
                break

        # Slice each sample from just after its last <|im_start|> (+3 skips the
        # role and newline tokens — presumably; verify against the template).
        start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        output = []
        for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
            output.append((start_length, cur_generation_ids[start_idx:]))

        return output
processing_moss_tts.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ import re
21
+ import torchaudio
22
+
23
+ from transformers import processing_utils
24
+
25
# Register "audio_tokenizer" as a processor attribute backed by a
# PreTrainedModel, so ProcessorMixin accepts it as a component. The mapping
# name differs across transformers versions, hence the hasattr fallbacks;
# on versions exposing neither mapping this is silently skipped.
if hasattr(processing_utils, "MODALITY_TO_BASE_CLASS_MAPPING"):
    processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = (
        "PreTrainedModel"
    )
elif hasattr(processing_utils, "AUTO_TO_BASE_CLASS_MAPPING"):
    processing_utils.AUTO_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = (
        "PreTrainedModel"
    )
33
+
34
+ import torch
35
+ from transformers import (
36
+ PreTrainedTokenizerBase,
37
+ BatchFeature,
38
+ ProcessorMixin,
39
+ logging,
40
+ AutoConfig,
41
+ AutoModel,
42
+ AutoTokenizer,
43
+ )
44
+
45
+ from .configuration_moss_tts import MossTTSDelayConfig
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ AUDIO_PLACEHOLDER = "<|audio|>"
52
+
53
+
54
@dataclass
class Message:
    """Abstract base for chat messages consumed by MossTTSDelayProcessor."""

    def to_dict(self) -> Dict[str, Any]:
        """Return the message in plain-dict form; subclasses must override."""
        raise NotImplementedError
58
+
59
+
60
@dataclass
class UserMessage(Message):
    """User turn for TTS prompting.

    Renders the fields into the `<user_inst>` prompt template and collects any
    speaker reference audios (paths or code tensors) for later encoding.
    """

    text: Optional[str] = None
    # Per-speaker reference audio: path, pre-encoded codes tensor, or None.
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None
    scene: Optional[str] = None

    def __post_init__(self):
        """Render the prompt text and gather reference audio items."""
        template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
None
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Scene:
{scene}
- Text:
{text}
</user_inst>"""

        audio_codes_list = []
        if self.reference is None:
            reference = "None"
        elif isinstance(self.reference, List):
            # One "[S{i}]" entry per speaker; a placeholder marks where the
            # corresponding audio codes belong.
            reference = []
            for speaker_idx, speaker_reference in enumerate(self.reference):
                if speaker_reference is None:
                    reference.append(f"[S{speaker_idx + 1}]: None")
                else:
                    reference.append(f"[S{speaker_idx + 1}]:\n{AUDIO_PLACEHOLDER}")
                    audio_codes_list.append(speaker_reference)
            reference = "\n".join(reference)
        else:
            raise TypeError("`reference` should be exactly a list when it is not None.")

        # NOTE(review): the template hardcodes "None" for Tokens, and {scene}
        # is replaced with "None" below, so the `tokens` and `scene` fields are
        # accepted but never rendered — confirm this is intended.
        content = (
            template.replace("{reference}", str(reference))
            .replace("{instruction}", str(self.instruction))
            .replace("{quality}", str(self.quality))
            .replace("{sound_event}", str(self.sound_event))
            .replace("{ambient_sound}", str(self.ambient_sound))
            .replace("{language}", str(self.language))
            .replace("{scene}", "None")
            .replace("{text}", str(self.text))
        )

        self._content = content
        self._audio_codes_list = audio_codes_list

    def to_dict(self):
        """Return the rendered user turn with its collected audio items."""
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list,
        }
129
+
130
+
131
@dataclass
class AssistantMessage(Message):
    """Assistant turn: generated audio codes plus the text placeholder that
    marks where they belong in the rendered conversation."""

    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        """Serialize this turn into the plain-dict message format."""
        payload = dict(role="assistant")
        payload["content"] = self.content
        payload["audio_codes_list"] = self.audio_codes_list
        return payload
142
+
143
+
144
# Ordered field names of UserMessage — presumably used elsewhere to validate
# or forward keyword arguments (usage not visible in this chunk).
USER_MESSAGE_FIELDS = (
    "text",
    "reference",
    "instruction",
    "tokens",
    "quality",
    "sound_event",
    "ambient_sound",
    "language",
    "scene",
)
155
+
156
+
157
+ class MossTTSDelayProcessor(ProcessorMixin):
158
+ tokenizer_class = "AutoTokenizer"
159
+ audio_tokenizer_class = "AutoModel"
160
+
161
+ tokenizer: PreTrainedTokenizerBase
162
+ audio_tokenizer: Any
163
+
164
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        audio_tokenizer: Any = None,
        model_config: Optional[MossTTSDelayConfig] = None,
        **kwargs,
    ):
        """Bind the text tokenizer, the audio codec model and the model config.

        Also resolves the special-token ids from `model_config` back into their
        string forms for template building.
        """
        super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)

        # Explicit assignments for type-checkers; ProcessorMixin sets these too.
        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        if model_config is None:
            model_config = MossTTSDelayConfig()
        self.model_config = model_config

        self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # Hard-coded newline token id — presumably "\n" in the Qwen vocab;
        # verify against the tokenizer.
        self.newline_token_id = 198

        def _id_to_token(token_id: int) -> str:
            # convert_ids_to_tokens may return a str or a list depending on
            # tokenizer implementation; normalize to a single string.
            tok = tokenizer.convert_ids_to_tokens(int(token_id))
            if isinstance(tok, list):
                return tok[0] if len(tok) > 0 else ""
            return cast(str, tok)

        self.audio_user_slot_token = _id_to_token(
            self.model_config.audio_user_slot_token_id
        )
        self.audio_assistant_gen_slot_token = _id_to_token(
            self.model_config.audio_assistant_gen_slot_token_id
        )
        self.audio_assistant_delay_slot_token = _id_to_token(
            self.model_config.audio_assistant_delay_slot_token_id
        )
        self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
        self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
201
+
202
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load tokenizer, model config, and the audio codec in one call.

        The codec model defaults to "OpenMOSS-Team/MOSS-Audio-Tokenizer" and
        can be overridden via the `codec_path` kwarg. `trust_remote_code`
        defaults to True because both repos ship custom code.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        kwargs.pop("_from_auto", None)

        audio_tokenizer_name_or_path = kwargs.pop(
            "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
        )

        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        model_config = cast(
            MossTTSDelayConfig,
            AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                *args,
                trust_remote_code=trust_remote_code,
                **kwargs,
            ),
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        audio_tokenizer = AutoModel.from_pretrained(
            audio_tokenizer_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        return cls(
            tokenizer=tokenizer,
            audio_tokenizer=audio_tokenizer,
            model_config=model_config,
            **kwargs,
        )
239
+
240
    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Convert conversations into padded multi-channel `input_ids`.

        Accepts a single message/conversation or a batch of conversations,
        renders each turn (optionally via the chat template), encodes any
        reference audio (tensors pass through; paths are batch-encoded), and
        returns a `BatchFeature` of padded unified code sequences.

        Keyword args consumed here: `mode` ("generation" | "continuation"),
        `apply_chat_template`, `n_vq`. `return_tensors` / `padding` /
        `truncation` are accepted and ignored (output is always torch tensors).
        """
        conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
        mode: str = kwargs.pop("mode", "generation")
        apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
        n_vq: Optional[int] = kwargs.pop("n_vq", None)

        # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
        kwargs.pop("return_tensors", None)
        kwargs.pop("padding", None)
        kwargs.pop("truncation", None)

        """
        mode only works when a Message is converted to a dict.
        """

        if mode not in {"generation", "continuation"}:
            raise RuntimeError

        # A single message/conversation becomes a one-element batch.
        if isinstance(conversations, (Message, Dict)):
            conversations = [conversations]

        # Continuation mode truncates the trailing assistant audio so the
        # model can keep generating from it.
        truncation = False
        if mode == "continuation":
            truncation = True

        input_ids_list = []
        for conversation in conversations:
            if isinstance(conversation, (Message, Dict)):
                conversation = [conversation]

            # Normalize early so downstream logic always deals with dict messages.
            conversation = [self._normalize_message(m) for m in conversation]

            # generation: odd number of turns, ending with a user turn;
            # continuation: even number of turns, ending with assistant.
            if (mode == "generation") ^ (len(conversation) % 2 != 0):
                raise ValueError

            if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
                raise ValueError

            unified_codes = []
            for message_idx, message in enumerate(conversation):
                if apply_chat_template:
                    add_generation_prompt = (
                        mode == "generation" and message_idx == len(conversation) - 1
                    )
                    # Older tokenizers reject tokenize=False; retry without it,
                    # and finally fall back to the raw content.
                    try:
                        content = self.tokenizer.apply_chat_template(
                            [{"role": message["role"], "content": message["content"]}],
                            add_generation_prompt=add_generation_prompt,
                            tokenize=False,
                        )
                    except TypeError:
                        try:
                            content = self.tokenizer.apply_chat_template(
                                [
                                    {
                                        "role": message["role"],
                                        "content": message["content"],
                                    }
                                ],
                                add_generation_prompt=add_generation_prompt,
                            )
                        except Exception:
                            logger.warning(
                                "apply_chat_template failed; fallback to raw content."
                            )
                            content = message["content"]
                else:
                    content = message["content"]

                if not isinstance(content, str):
                    content = str(content)

                # Batch-encode all path-based references in one call when possible.
                # This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
                # instead of repeatedly calling it with batch=1.
                raw_audio_items = message.get("audio_codes_list", [])

                audio_codes_list: List[torch.Tensor] = []
                if len(raw_audio_items) > 0:
                    # Preserve original ordering: tensors stay in place, paths
                    # are encoded together and written back by position.
                    encoded_items: List[Optional[torch.Tensor]] = [None] * len(
                        raw_audio_items
                    )
                    paths: List[str] = []
                    path_positions: List[int] = []

                    for idx, item in enumerate(raw_audio_items):
                        if isinstance(item, torch.Tensor):
                            # Pre-encoded codes: validate the VQ channel count
                            # if the caller supplied n_vq.
                            if n_vq is not None and item.shape[1] != n_vq:
                                raise RuntimeError(
                                    "audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
                                )
                            encoded_items[idx] = item
                            continue

                        if isinstance(item, (str, os.PathLike)):
                            paths.append(str(item))
                            path_positions.append(idx)
                            continue

                        raise TypeError(
                            "Each audio item must be a torch.Tensor of codes or a path-like string."
                        )

                    if len(paths) > 0:
                        encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
                        if len(encoded_from_paths) != len(paths):
                            raise RuntimeError(
                                "encode_audios_from_path returned an unexpected number of items."
                            )
                        for pos, codes in zip(path_positions, encoded_from_paths):
                            encoded_items[pos] = codes

                    audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
                unified_codes.append(
                    self._get_unified_codes(
                        message["role"], content, audio_codes_list, truncation
                    )
                )

            unified_codes = torch.cat(unified_codes)
            input_ids_list.append(unified_codes)

        return BatchFeature(data=self._pad(input_ids_list))
+ return BatchFeature(data=self._pad(input_ids_list))
364
+
365
@staticmethod
def build_user_message(
    text: Optional[str] = None,
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
    instruction: Optional[str] = None,
    tokens: Optional[int] = None,
    quality: Optional[str] = None,
    sound_event: Optional[str] = None,
    ambient_sound: Optional[str] = None,
    language: Optional[str] = None,
    scene: Optional[str] = None,
) -> Dict:
    """Build a user-turn message dict for the processor.

    A single (non-list) ``reference`` is wrapped into a one-element list
    before being handed to ``UserMessage``; all other fields pass through.
    """
    wrapped_reference = reference
    if wrapped_reference is not None and not isinstance(wrapped_reference, list):
        wrapped_reference = [wrapped_reference]
    message = UserMessage(
        text=text,
        reference=wrapped_reference,
        instruction=instruction,
        tokens=tokens,
        quality=quality,
        sound_event=sound_event,
        ambient_sound=ambient_sound,
        language=language,
        scene=scene,
    )
    return message.to_dict()
390
+
391
@staticmethod
def build_assistant_message(
    audio_codes_list: List[Union[str, torch.Tensor]],
    content: str = AUDIO_PLACEHOLDER,
) -> Dict:
    """Build an assistant-turn message dict wrapping audio codes.

    ``content`` defaults to a single audio placeholder; callers may pass a
    richer template containing one placeholder per audio segment.
    """
    message = AssistantMessage(
        content=content,
        audio_codes_list=audio_codes_list,
    )
    return message.to_dict()
400
+
401
def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
    """Coerce a ``Message`` instance or a plain dict into canonical form.

    Canonical messages carry both ``content`` and ``audio_codes_list``;
    anything else is rebuilt through the role-specific builders.

    Raises:
        TypeError: when *message* is neither a Message nor a dict.
        ValueError: when the dict has no role, or an unknown role.
    """
    if isinstance(message, Message):
        return message.to_dict()
    if not isinstance(message, dict):
        raise TypeError("Each message must be a Message or dict.")
    if "role" not in message:
        raise ValueError("Message dict must include a 'role' field.")
    # Already canonical: hand it back untouched.
    if "content" in message and "audio_codes_list" in message:
        return message
    role = message["role"]
    if role == "user":
        return self.build_user_message(
            **{field: message.get(field) for field in USER_MESSAGE_FIELDS}
        )
    if role == "assistant":
        return self.build_assistant_message(
            audio_codes_list=message.get("audio_codes_list", []),
            content=message.get("content", AUDIO_PLACEHOLDER),
        )
    raise ValueError(f"Unsupported role: {role}")
420
+
421
def _pad(self, input_ids_list: List[torch.Tensor]):
    """Left-pad a batch of [seqlen, n_vq] unified code tensors.

    Audio channels of padded positions keep ``audio_pad_code``; channel 0
    (the text channel) of padded positions is overwritten with the text
    ``pad_token_id``.

    Returns:
        dict with ``input_ids`` [batch_size, seqlen, n_vq] and boolean
        ``attention_mask`` [batch_size, seqlen] (True on real tokens).
    """
    device = input_ids_list[0].device
    lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
    pad_input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_list,
        batch_first=True,
        padding_value=self.model_config.audio_pad_code,
        padding_side="left",  # NOTE: padding_side requires torch >= 2.4
    )
    # True on the left-pad positions of each row.
    pad_lengths = (pad_input_ids.shape[1] - lengths).unsqueeze(1)
    other_channel_mask = pad_lengths > torch.arange(
        pad_input_ids.shape[1], device=device
    ).unsqueeze(0)
    pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
    # The real-token mask is simply the complement of the pad mask; no need
    # to allocate a float tensor, fill it, and cast to bool.
    attention_mask = ~other_channel_mask
    return {
        "input_ids": pad_input_ids,  # [batch_size, seqlen, n_vq]
        "attention_mask": attention_mask,
    }
443
+
444
@staticmethod
def _replace_audio_placeholders(
    content: str,
    lengths: List[int],
    n_vq: int,
    gen_slot_token: str,
    delay_slot_token: str,
    audio_start_token: str,
    audio_end_token: str,
) -> str:
    """Expand each audio placeholder in *content* into special tokens.

    The i-th placeholder becomes ``start + gen_slot * lengths[i] +
    delay_slot * (n_vq - 1) + end`` (just ``start + end`` for a
    zero-length segment), consumed strictly in order.

    Raises:
        ValueError: if ``n_vq < 1``, a length is negative, or the
            placeholder count differs from ``len(lengths)``.
    """
    if n_vq < 1:
        raise ValueError(f"n_vq must be >= 1, got {n_vq}")

    # Splitting on the literal placeholder yields one more piece than
    # there are placeholders, which doubles as the count check.
    pieces = content.split(AUDIO_PLACEHOLDER)
    num_placeholders = len(pieces) - 1
    if num_placeholders != len(lengths):
        raise ValueError(
            f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
            f"does not match lengths ({len(lengths)})"
        )

    def render_block(length: int) -> str:
        if length < 0:
            raise ValueError(f"length must be >= 0, got {length}")
        if length == 0:
            return f"{audio_start_token}{audio_end_token}"
        step_tokens = gen_slot_token * length + delay_slot_token * (n_vq - 1)
        return f"{audio_start_token}{step_tokens}{audio_end_token}"

    rebuilt = [pieces[0]]
    for length, trailing_text in zip(lengths, pieces[1:]):
        rebuilt.append(render_block(length))
        rebuilt.append(trailing_text)
    return "".join(rebuilt)
483
+
484
@staticmethod
def _merge_consecutive_audio_placeholders(
    content: str,
    audio_codes_list: List[torch.Tensor],
) -> Tuple[str, List[torch.Tensor]]:
    """Collapse runs of adjacent audio placeholders into a single one.

    Placeholders separated only by whitespace are treated as one logical
    audio segment: their code tensors are concatenated along dim 0 and the
    whole run is replaced by a single AUDIO_PLACEHOLDER in the returned
    content. Returns the rewritten content and the merged codes list.

    Raises:
        ValueError: when the placeholder count differs from the number of
            provided code tensors.
    """
    matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
    if len(matches) <= 1:
        # Zero or one placeholder: nothing can be merged.
        return content, audio_codes_list

    if len(matches) != len(audio_codes_list):
        raise ValueError(
            "Audio placeholders do not match the provided audio codes list."
        )

    new_audio_codes_list = []
    new_parts = []
    last_pos = 0
    i = 0
    while i < len(matches):
        # Grow the run [i, j] while the gap between consecutive
        # placeholders is whitespace-only.
        j = i
        while (
            j + 1 < len(matches)
            and content[matches[j].end() : matches[j + 1].start()].strip() == ""
        ):
            j += 1

        # Emit the text preceding the run, then one placeholder for it.
        new_parts.append(content[last_pos : matches[i].start()])
        new_parts.append(AUDIO_PLACEHOLDER)
        last_pos = matches[j].end()

        if j == i:
            # Singleton run: keep the original tensor.
            new_audio_codes_list.append(audio_codes_list[i])
        else:
            # Merged run: concatenate along the time axis (dim 0).
            new_audio_codes_list.append(
                torch.cat(audio_codes_list[i : j + 1], dim=0)
            )

        i = j + 1

    # Append the text after the final run.
    new_parts.append(content[last_pos:])
    return "".join(new_parts), new_audio_codes_list
525
+
526
+ @staticmethod
527
+ def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
528
+ delayed_tokens = torch.full(
529
+ (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
530
+ pad_code,
531
+ device=codes.device,
532
+ dtype=codes.dtype,
533
+ )
534
+ for i in range(codes.shape[1]):
535
+ delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
536
+ return delayed_tokens
537
+
538
+ @staticmethod
539
+ def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
540
+ tokens = torch.full(
541
+ (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
542
+ 0,
543
+ device=delay_codes.device,
544
+ dtype=delay_codes.dtype,
545
+ )
546
+ for i in range(delay_codes.shape[1]):
547
+ tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
548
+ return tokens
549
+
550
def _get_unified_codes(
    self,
    role: str,
    content: str,
    audio_codes_list: List[torch.Tensor],
    truncation: bool,
) -> torch.Tensor:
    """Pack one chat turn into unified codes of shape [T, 1 + n_vq].

    At this point `content` has already been rendered through the chat
    template. Column 0 holds text token ids; the remaining n_vq columns
    hold delay-patterned RVQ audio codes (audio_pad_code wherever no
    audio is present).
    """
    if role == "user":
        # User-side audio uses one slot token for both gen and delay
        # positions, and user turns are never truncated.
        audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
        truncation = False
    else:
        audio_gen_slot_token = self.audio_assistant_gen_slot_token
        audio_delay_slot_token = self.audio_assistant_delay_slot_token

    # Always follow model RVQ channels (e.g. n_vq=16) and truncate
    # tokenizer outputs (e.g. 32-layer RVQ) to the first n_vq layers.
    n_vq = self.model_config.n_vq
    normalized_audio_codes_list: List[torch.Tensor] = []
    for audio_codes in audio_codes_list:
        if audio_codes.dim() != 2:
            raise RuntimeError(
                f"Expect audio codes with rank 2, got {tuple(audio_codes.shape)}"
            )

        # Handle possible [NQ, T] layout.
        if audio_codes.shape[1] < n_vq and audio_codes.shape[0] >= n_vq:
            audio_codes = audio_codes.transpose(0, 1)

        if audio_codes.shape[1] < n_vq:
            raise RuntimeError(
                f"audio_codes channels ({audio_codes.shape[1]}) < model n_vq ({n_vq})"
            )

        # Keep only the first n_vq quantizer layers.
        audio_codes = audio_codes[:, :n_vq]
        normalized_audio_codes_list.append(audio_codes)

    audio_codes_list = normalized_audio_codes_list

    # Whitespace-separated placeholder runs refer to one audio segment.
    if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
        content, audio_codes_list = self._merge_consecutive_audio_placeholders(
            content, audio_codes_list
        )
    # Expand each placeholder into start/slot/end special tokens sized to
    # its segment length (plus n_vq - 1 trailing delay slots).
    content = self._replace_audio_placeholders(
        content=content,
        lengths=[len(audio_codes) for audio_codes in audio_codes_list],
        n_vq=n_vq,
        gen_slot_token=audio_gen_slot_token,
        delay_slot_token=audio_delay_slot_token,
        audio_start_token=self.audio_start_token,
        audio_end_token=self.audio_end_token,
    )
    text_codes = torch.tensor(
        self.tokenizer.encode(content),
        device=audio_codes_list[0].device if audio_codes_list else None,
    )

    # Locate the expanded audio regions in the encoded text channel.
    audio_start_indices = torch.where(
        text_codes == self.model_config.audio_start_token_id
    )[0]
    audio_end_indices = torch.where(
        text_codes == self.model_config.audio_end_token_id
    )[0]
    if len(audio_start_indices) != len(audio_codes_list) or len(
        audio_end_indices
    ) != len(audio_codes_list):
        raise ValueError(
            "Audio placeholders do not match the provided audio codes list."
        )

    delay_audio_codes_list = []
    if len(audio_codes_list) == 0:
        # Text-only turn: audio channels are pure padding.
        delay_audio_codes_list = torch.full(
            (len(text_codes), n_vq),
            self.model_config.audio_pad_code,
            device=text_codes.device,
            dtype=text_codes.dtype,
        )
    else:
        # Interleave pad rows (covering text spans) with delayed audio
        # codes so the audio channels line up with the text positions.
        prefix_idx = 0
        for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
            audio_start_indices, audio_end_indices, audio_codes_list
        ):
            audio_start_idx = int(audio_start_idx_t.item())
            audio_end_idx = int(audio_end_idx_t.item())
            delay_audio_codes = self.apply_delay_pattern(
                audio_codes, self.model_config.audio_pad_code
            )
            # Pad rows for the text between the previous audio end and
            # this segment's start token (inclusive of the start token).
            pad_codes = torch.full(
                (audio_start_idx - prefix_idx + 1, n_vq),
                self.model_config.audio_pad_code,
                device=audio_codes.device,
                dtype=audio_codes.dtype,
            )
            delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
            prefix_idx = audio_end_idx

        if truncation:
            # Drop the trailing n_vq - 1 delay rows of the last segment.
            delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
                : -(n_vq - 1), :
            ]
        else:
            # Pad out the remaining text after the last audio end token.
            last_audio_end_idx = int(audio_end_indices[-1].item())
            pad_codes = torch.full(
                (len(text_codes) - last_audio_end_idx, n_vq),
                self.model_config.audio_pad_code,
                device=audio_codes_list[0].device,
                dtype=audio_codes_list[0].dtype,
            )
            delay_audio_codes_list.append(pad_codes)

    delay_audio_codes_list = torch.cat(delay_audio_codes_list)

    # Under truncation the audio channels can be shorter than the text
    # channel; trim the text channel to match.
    if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
        text_codes = text_codes[: delay_audio_codes_list.shape[0]]

    unified_codes = torch.cat(
        [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
    )
    return unified_codes
672
+
673
def _parse_text_codes(self, start_length, text_codes):
    """Decode the text channel of a generation and strip the prompt.

    The prompt prefix (first *start_length* ids) is removed at string
    level; each generated audio span is collapsed to AUDIO_PLACEHOLDER,
    and empty/delay-only spans are deleted.
    """
    text = cast(str, self.tokenizer.decode(text_codes))
    prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
    text = text[len(prefix) :]

    # BUGFIX: the special tokens contain regex metacharacters — the '|'
    # in e.g. '<|audio_start|>' acted as alternation when interpolated
    # unescaped, so the pattern matched '<', '>' and bare token names
    # instead of the literal tokens. Escape every token.
    AUDIO_PATTERN = re.compile(
        rf"(?:{re.escape(self.audio_start_token)})?"
        rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
        rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
        rf"{re.escape(self.audio_end_token)}"
    )

    def normalize_audio_segments(text: str) -> str:
        def repl(match: re.Match) -> str:
            seg = match.group(0)
            # Replace with <|audio|> if gen_slot is present in the segment;
            if self.audio_assistant_gen_slot_token in seg:
                return AUDIO_PLACEHOLDER
            # Otherwise, remove it.
            return ""

        return AUDIO_PATTERN.sub(repl, text)

    return normalize_audio_segments(text)
697
+
698
def _parse_audio_codes(self, start_length, audio_codes):
    """Extract and decode the audio segments of one generated sequence.

    *audio_codes* is the delay-patterned [T, n_vq] audio channel block.
    Returns a list of decoded waveforms with the prompt-owned portion of
    the first segment (approximated via the start_length ratio) trimmed.
    """
    # De-delay back to [T', n_vq]
    audio_codes = self.apply_de_delay_pattern(audio_codes)

    # Rows that are all pad are separators between real audio segments.
    is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
    non_pad = ~is_pad
    if not non_pad.any():
        return []

    # Split non-pad row indices into runs of consecutive indices; each
    # run is one contiguous audio segment.
    idx = torch.nonzero(non_pad).squeeze(1)
    breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
    if breaks.numel() == 0:
        segments_idx = [idx]
    else:
        segments_idx = torch.split(idx, breaks.tolist())

    audio_codes_list = [audio_codes[s] for s in segments_idx]

    # Batch-decode all audio segments together.
    decoded_audio_list = self.decode_audio_codes(audio_codes_list)

    # Keep codec causal context by decoding the whole first segment first,
    # then trim at waveform level according to start_length ratio.
    if (
        start_length > 0
        and len(audio_codes_list) > 0
        and len(decoded_audio_list) > 0
    ):
        first_codes_length = audio_codes_list[0].shape[0]
        if first_codes_length > 0:
            # Fraction of the first segment that belonged to the prompt.
            trim_ratio = max(
                0.0, min(float(start_length) / float(first_codes_length), 1.0)
            )
            first_audio = decoded_audio_list[0]
            if trim_ratio >= 1.0:
                # The whole first segment was prompt audio; drop it.
                decoded_audio_list = decoded_audio_list[1:]
            elif trim_ratio > 0.0:
                trim_samples = int(first_audio.shape[-1] * trim_ratio)
                decoded_audio_list[0] = first_audio[..., trim_samples:]

    return decoded_audio_list
740
+
741
def decode(self, output: List[Tuple[int, torch.Tensor]]):
    """Convert raw generation outputs into assistant messages.

    Each item of *output* is ``(start_length, generation_ids)``: the full
    assistant sequence is always required, and truncation from an
    arbitrary start position is supported via *start_length*. Entries
    whose decoded text is empty yield ``None``.
    """
    generated_messages = []
    for start_length, generation_ids in output:
        content = self._parse_text_codes(start_length, generation_ids[:, 0])
        audio_codes_list = self._parse_audio_codes(
            start_length, generation_ids[:, 1:]
        )
        if not content:
            generated_messages.append(None)
            continue
        generated_messages.append(
            AssistantMessage(
                content=content,
                audio_codes_list=cast(
                    List[Union[str, torch.Tensor]], audio_codes_list
                ),
            )
        )
    return generated_messages
764
+
765
+ @staticmethod
766
+ def loudness_normalize(
767
+ wav: torch.Tensor,
768
+ target_dbfs: float = -20,
769
+ gain_range: tuple[float, float] = (-3.0, 3.0),
770
+ ) -> torch.Tensor:
771
+ wav = wav.to(torch.float32)
772
+ if wav.numel() == 0:
773
+ return wav
774
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
775
+ gain = float(target_dbfs - current_dbfs)
776
+ gain = max(gain_range[0], min(gain, gain_range[1]))
777
+ factor = 10.0 ** (gain / 20.0)
778
+ return wav * factor
779
+
780
def _get_audio_tokenizer_device(self) -> torch.device:
    """Best-effort device inference for `self.audio_tokenizer`.

    Notes:
        - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
        - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
    """
    tokenizer = getattr(self, "audio_tokenizer", None)
    if tokenizer is None:
        logger.warning(
            "audio_tokenizer is not set on processor. Using CPU as default."
        )
        return torch.device("cpu")

    # Prefer an explicit .device attribute when it is a real torch.device.
    maybe_device = getattr(tokenizer, "device", None)
    if isinstance(maybe_device, torch.device):
        return maybe_device

    # Otherwise fall back to the device of the first parameter.
    try:
        return next(tokenizer.parameters()).device
    except StopIteration:
        # Parameter-less module (shouldn't happen for real models).
        logger.warning(
            "No parameters found on audio_tokenizer. Using CPU as default."
        )
        return torch.device("cpu")
807
+
808
def encode_audios_from_wav(
    self,
    wav_list: List[torch.Tensor],
    sampling_rate: int,
    n_vq: Optional[int] = None,
):
    """Tokenize raw waveforms into RVQ code tensors.

    Each wav is mixed down to mono, resampled to the model sampling rate
    when needed, loudness-normalized, then encoded. Returns a list of
    CPU long tensors of shape (T, NQ), one per input waveform.

    Raises:
        RuntimeError: if no audio_tokenizer is attached, or it returns
            empty codes/lengths.
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    # Accept a single tensor as a one-element batch.
    if isinstance(wav_list, torch.Tensor):
        wav_list = [wav_list]
    wav_list_ = []
    resample = False
    if sampling_rate != self.model_config.sampling_rate:
        resample = True
    device = self._get_audio_tokenizer_device()
    for wav in wav_list:
        # Downmix to mono. NOTE(review): assumes wav is laid out as
        # (channels, samples) — confirm with callers.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        if resample:
            wav = torchaudio.functional.resample(
                waveform=wav,
                orig_freq=sampling_rate,
                new_freq=self.model_config.sampling_rate,
            )
        wav = wav.to(device)
        wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

    # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
    if hasattr(audio_tokenizer, "batch_encode"):
        enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
        audio_codes = enc.audio_codes  # (NQ, B, T)
        audio_codes_lengths = enc.audio_codes_lengths  # (B,)
    else:
        # Fallback: use encode() with explicit padding.
        max_len = max(int(wav.shape[-1]) for wav in wav_list_)
        input_values = torch.zeros(
            len(wav_list_), 1, max_len, device=device, dtype=torch.float32
        )
        padding_mask = torch.zeros(
            len(wav_list_), max_len, device=device, dtype=torch.bool
        )
        for i, wav in enumerate(wav_list_):
            this_len = int(wav.shape[-1])
            input_values[i, 0, :this_len] = wav
            padding_mask[i, :this_len] = True
        enc = audio_tokenizer.encode(
            input_values,
            padding_mask=padding_mask,
            num_quantizers=n_vq,
            return_dict=True,
        )
        audio_codes = enc.audio_codes
        audio_codes_lengths = enc.audio_codes_lengths

    if audio_codes is None or audio_codes_lengths is None:
        raise RuntimeError(
            "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
        )

    # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
    # and on CPU (so downstream text/audio packing remains device-agnostic).
    codes_list: List[torch.Tensor] = []
    for i in range(int(audio_codes.shape[1])):
        # Trim each item to its true length before the (NQ, T) -> (T, NQ)
        # transpose.
        length_i = int(audio_codes_lengths[i].item())
        codes_i = (
            audio_codes[:, i, :length_i]
            .transpose(0, 1)
            .contiguous()
            .to(torch.long)
            .cpu()
        )
        codes_list.append(codes_i)
    return codes_list
883
+
884
def encode_audios_from_path(
    self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
):
    """Load wav files, resample them to the model rate, and tokenize.

    Accepts one path or a list of paths; returns one (T, NQ) code tensor
    per file (see encode_audios_from_wav).

    Raises:
        ValueError: when the path list is empty.
    """
    if isinstance(wav_path_list, str):
        wav_path_list = [wav_path_list]

    if not wav_path_list:
        raise ValueError("Empty wav_path_list")

    # Load and, when needed, resample each wav independently, so callers
    # can pass a heterogeneous batch of files while still benefiting from
    # audio_tokenizer.batch_encode downstream.
    target_sr = int(self.model_config.sampling_rate)
    loaded: List[torch.Tensor] = []
    for path in wav_path_list:
        wav, source_sr = torchaudio.load(path)
        if int(source_sr) != target_sr:
            wav = torchaudio.functional.resample(
                waveform=wav,
                orig_freq=int(source_sr),
                new_freq=target_sr,
            )
        loaded.append(wav)

    return self.encode_audios_from_wav(loaded, target_sr, n_vq)
909
+
910
def decode_audio_codes(
    self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
):
    """Decode RVQ code tensors back into waveforms.

    Accepts one (T, NQ) tensor or a list of them; returns a list of
    float32 CPU waveforms, one per input.

    Raises:
        RuntimeError: if no audio_tokenizer is attached, the tokenizer
            returns empty outputs, an unexpected audio shape appears, or
            the tokenizer exposes neither decode() nor batch_decode().
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    # Accept a single tensor as a one-element batch.
    if isinstance(audio_tokens_list, torch.Tensor):
        audio_tokens_list = [audio_tokens_list]
    if len(audio_tokens_list) == 0:
        return []

    device = self._get_audio_tokenizer_device()

    # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
    codes_list = [
        codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
        for codes in audio_tokens_list
    ]

    # Align with legacy behavior: decode each sample with chunk_duration=8.0.
    # Streaming chunk decode currently supports batch_size=1 in MossAudioTokenizer.
    if hasattr(audio_tokenizer, "decode"):
        wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            try:
                dec = audio_tokenizer.decode(
                    codes,
                    return_dict=True,
                    chunk_duration=8.0,
                )
            except TypeError:
                # Compatibility fallback for tokenizers without chunk_duration arg.
                dec = audio_tokenizer.decode(
                    codes,
                    return_dict=True,
                )

            audio = dec.audio
            audio_lengths = dec.audio_lengths
            if audio is None:
                raise RuntimeError("audio_tokenizer.decode() returned empty audio.")

            # Without explicit lengths, fall back to the full output length.
            if audio_lengths is None:
                cur_len = int(audio.shape[-1])
            else:
                cur_len = int(audio_lengths[0].item())

            # Accept either (B, C, T) or (B, T) layouts from the tokenizer.
            if audio.ndim == 3:
                wav = audio[0, 0, :cur_len]
            elif audio.ndim == 2:
                wav = audio[0, :cur_len]
            else:
                raise RuntimeError(
                    f"Unexpected audio shape from decode: {tuple(audio.shape)}"
                )
            wav_list.append(wav.contiguous().to(torch.float32).cpu())
        return wav_list

    # Last resort: a batch_decode-only tokenizer.
    if hasattr(audio_tokenizer, "batch_decode"):
        dec = audio_tokenizer.batch_decode(codes_list)
        audio = dec.audio
        audio_lengths = dec.audio_lengths
        if audio is None or audio_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.batch_decode() returned empty outputs (audio/audio_lengths)."
            )
        wav_list = []
        for i in range(int(audio.shape[0])):
            # Strip per-item padding using the reported lengths.
            length_i = int(audio_lengths[i].item())
            wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
            wav_list.append(wav)
        return wav_list

    raise RuntimeError("audio_tokenizer has neither decode() nor batch_decode().")
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "MossTTSDelayProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_tts.MossTTSDelayProcessor"
5
+ }
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|audio_start|>",
12
+ "<|audio_end|>",
13
+ "<|audio_user_slot|>",
14
+ "<|image_pad|>",
15
+ "<|audio_assistant_gen_slot|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb3c8fa82993d515469c2800cc455bff4aaa3c4fed9da1f2b0c0668c304f335a
3
+ size 11422691
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|audio_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|audio_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|audio_user_slot|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|audio_assistant_gen_slot|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|audio_assistant_delay_slot|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|audio_start|>",
224
+ "<|audio_end|>",
225
+ "<|audio_user_slot|>",
226
+ "<|image_pad|>",
227
+ "<|audio_assistant_gen_slot|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "processor_class": "AsteroidProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff