schwarztgyt committed
Commit a724b39 · 1 Parent(s): 054b672

Upload voiceplus_qwen3_1.7B_tp8_rvq32_all_data_tacv3_max_lr_2e-4_min_2e-4_enhanced_lm_head_add_layer_norm_wd_0.1_from_pretrained_seqlen_14336_decay iter_0015000 model snapshot
README.md ADDED
@@ -0,0 +1,3 @@
+ ---
+ license: apache-2.0
+ ---
__init__.py ADDED
File without changes
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|audio_end|>": 151653,
+   "<|audio_pad|>": 151654,
+   "<|audio_start|>": 151652,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656
+ }
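These IDs are the special tokens that config.json and configuration_moss_tts.py below reference (e.g. <|audio_start|> = 151652 matches audio_start_token_id). A minimal sanity-check sketch, assuming the snapshot ships loadable tokenizer files (not all tokenizer files are visible in this commit view):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")  # assumed: run from the snapshot directory

# Should print 151652, 151653, 151654, matching added_tokens.json above
for token in ("<|audio_start|>", "<|audio_end|>", "<|audio_pad|>"):
    print(token, tok.convert_tokens_to_ids(token))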
chat_template.jinja ADDED
@@ -0,0 +1,4 @@
+ {% for message in messages %}<|im_start|>{{ message['role'] }}
+ {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content.get('type') == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}<|im_end|>
+ {% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
+ {% endif %}
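The template is a plain ChatML-style loop over messages. A minimal sketch rendering it standalone with jinja2 (illustration only, not repo code; assumes you run it from the snapshot directory):

from jinja2 import Template

template = Template(open("chat_template.jinja").read())
messages = [
    {"role": "user", "content": [{"type": "text", "text": "Read this aloud."}]},
]
prompt = template.render(messages=messages, add_generation_prompt=True)
print(prompt)
# <|im_start|>user
# Read this aloud.<|im_end|>
# <|im_start|>assistant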
config.json ADDED
@@ -0,0 +1,86 @@
+ {
+   "model_type": "moss_tts_delay",
+   "architectures": [
+     "MossTTSDelayModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_moss_tts.MossTTSDelayConfig",
+     "AutoModel": "modeling_moss_tts.MossTTSDelayModel"
+   },
+   "dtype": "bfloat16",
+   "initializer_range": 0.02,
+   "language_config": {
+     "_name_or_path": "Qwen/Qwen3-8B",
+     "architectures": [
+       "Qwen3ForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bos_token_id": 151643,
+     "eos_token_id": 151645,
+     "pad_token_id": 151643,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 2048,
+     "initializer_range": 0.02,
+     "intermediate_size": 6144,
+     "layer_types": [
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention"
+     ],
+     "max_position_embeddings": 40960,
+     "max_window_layers": 28,
+     "model_type": "qwen3",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 28,
+     "num_key_value_heads": 8,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": null,
+     "rope_theta": 1000000,
+     "sliding_window": null,
+     "use_cache": true,
+     "use_sliding_window": false,
+     "vocab_size": 155648
+   },
+   "n_vq": 32,
+   "audio_vocab_size": 1024,
+   "audio_user_slot_token_id": 151654,
+   "audio_assistant_gen_slot_token_id": 151656,
+   "audio_assistant_delay_slot_token_id": 151662,
+   "audio_start_token_id": 151652,
+   "audio_end_token_id": 151653,
+   "audio_pad_code": 1024,
+   "sampling_rate": 24000,
+   "transformers_version": "4.57.1",
+
+   "additional_mlp_ffn_hidden_size": 2048,
+   "local_ffn_hidden_size": 8960,
+   "local_hidden_size": 1536,
+   "local_num_layers": 4
+ }
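The "auto_map" above routes AutoConfig/AutoModel to the custom classes shipped in this snapshot, so loading requires trust_remote_code. A minimal loading sketch (values taken from the config above):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(".", trust_remote_code=True)  # snapshot directory
print(cfg.model_type)                    # moss_tts_delay
print(cfg.n_vq, cfg.sampling_rate)       # 32 24000
print(cfg.language_config.hidden_size)   # 2048 (Qwen3 backbone)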
configuration_moss_tts.py ADDED
@@ -0,0 +1,121 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """MossTTSDelay model configuration"""
+
+ from typing import Optional, Union
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.models.qwen3 import Qwen3Config
+
+ logger = logging.get_logger(__name__)
+
+
+ class MossTTSDelayConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate a
+     MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a configuration similar to that of the [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         language_config (`Union[Qwen3Config, dict]`, *optional*):
+             Configuration for the backbone language model (Qwen3).
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         n_vq (`int`, *optional*, defaults to 32):
+             Number of additional VQ (vector quantization) heads/channels for audio.
+             Determines the number of codebooks used in the audio representation.
+         audio_vocab_size (`int`, *optional*, defaults to 1024):
+             Vocabulary size for the audio tokens (codebooks 1 to N).
+         audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
+             The specific token ID used as a placeholder/slot for user-side audio inputs in the prompt.
+         audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
+             The specific token ID representing the generation slot for the assistant's audio output,
+             acting as the trigger for the TTS generation process.
+         audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
+             The token ID used in the delay-pattern paradigm to represent the delayed/offset positions
+             between different VQ channels.
+         audio_start_token_id (`int`, *optional*, defaults to 151652):
+             Special token ID used to denote the start of an audio sequence in the stream.
+         audio_end_token_id (`int`, *optional*, defaults to 151653):
+             Special token ID used to denote the end of an audio sequence (EOS for audio).
+         audio_pad_code (`int`, *optional*, defaults to 1024):
+             The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
+     """
+     model_type = "moss_tts_delay"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         language_config: Optional[Union[Qwen3Config, dict]] = None,
+         initializer_range: float = 0.02,
+         n_vq: int = 32,
+         pad_token_id: int = 151643,
+         im_start_token_id: int = 151644,
+         im_end_token_id: int = 151645,
+         audio_vocab_size: int = 1024,
+         audio_user_slot_token_id: int = 151654,
+         audio_assistant_gen_slot_token_id: int = 151656,
+         audio_assistant_delay_slot_token_id: int = 151662,
+         audio_start_token_id: int = 151652,
+         audio_end_token_id: int = 151653,
+         audio_pad_code: int = 1024,
+         sampling_rate: int = 24000,
+         additional_mlp_ffn_hidden_size: int = 2048,
+         local_ffn_hidden_size: int = 8960,
+         local_hidden_size: int = 1536,
+         local_num_layers: int = 4,
+         **kwargs,
+     ):
+         if isinstance(language_config, dict):
+             self.language_config = Qwen3Config(**language_config)
+         elif language_config is None:
+             self.language_config = Qwen3Config()
+         else:
+             self.language_config = language_config
+
+         self.initializer_range = initializer_range
+         self.n_vq = n_vq
+         self.audio_vocab_size = audio_vocab_size
+         self.audio_user_slot_token_id = audio_user_slot_token_id
+         self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
+         self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
+         self.audio_start_token_id = audio_start_token_id
+         self.audio_end_token_id = audio_end_token_id
+         self.audio_pad_code = audio_pad_code
+         self.sampling_rate = sampling_rate
+
+         self.hidden_size = self.language_config.hidden_size
+         self.vocab_size = self.language_config.vocab_size
+         self.pad_token_id = pad_token_id
+         self.im_start_token_id = im_start_token_id
+         self.im_end_token_id = im_end_token_id
+
+         self.additional_mlp_ffn_hidden_size = additional_mlp_ffn_hidden_size
+         self.local_ffn_hidden_size = local_ffn_hidden_size
+         self.local_hidden_size = local_hidden_size
+         self.local_num_layers = local_num_layers
+
+         super().__init__(**kwargs)
+
+     def to_dict(self):
+         output = super().to_dict()
+         if hasattr(self.language_config, "to_dict"):
+             output["language_config"] = self.language_config.to_dict()
+         else:
+             output["language_config"] = self.language_config
+         return output
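A minimal usage sketch for the class above (illustration only, assuming the snapshot directory is on the Python path):

from configuration_moss_tts import MossTTSDelayConfig

cfg = MossTTSDelayConfig()             # defaults: fresh Qwen3Config() backbone
print(cfg.model_type)                  # moss_tts_delay
print(cfg.n_vq, cfg.audio_pad_code)    # 32 1024
d = cfg.to_dict()
print(type(d["language_config"]))      # <class 'dict'> — backbone serialized by to_dict()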
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "transformers_version": "4.51.3"
+ }
inference_utils.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torchaudio
+ import torch.nn.functional as F
+ from typing import Optional, List, Tuple
+ from tqdm import tqdm
+
+
+ def apply_top_k(logits, top_k):
+     batch_size, vocab_size = logits.shape
+     top_k = min(top_k, vocab_size)
+     top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
+     filtered_logits = torch.full_like(logits, float("-inf"))
+     batch_indices = torch.arange(batch_size).unsqueeze(-1)
+     filtered_logits[batch_indices, top_k_indices] = top_k_values
+     return filtered_logits
+
+
+ def apply_top_p(logits, top_p):
+     probs = F.softmax(logits, dim=-1)
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+     sorted_indices_to_remove = cumulative_probs > top_p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+     batch_size = logits.shape[0]
+     filtered_logits = logits.clone()
+     for i in range(batch_size):
+         indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+         filtered_logits[i, indices_to_remove] = float("-inf")
+     return filtered_logits
+
+
+ def apply_top_p_optimized(logits, top_p):
+     probs = F.softmax(logits, dim=-1)
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+     sorted_indices_to_remove = cumulative_probs > top_p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+
+     indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
+         dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+     )
+
+     logits[indices_to_remove] = float("-inf")
+     return logits
+
+
+ def apply_repetition_penalty_delay_pattern(
+     logits: torch.Tensor,
+     prev_tokens: torch.LongTensor,
+     penalty: float,
+ ):
+     """
+     logits: [B, H, V] or [N, V]
+     prev_tokens: [B, T, H] or [N, T] or [B, H]
+
+     Apply the repetition penalty independently for each H (VQ head).
+     """
+     if penalty == 1.0 or prev_tokens is None:
+         return logits
+
+     vocab_size = logits.size(-1)
+
+     # Case 1: regular [N, V] (text layer)
+     if logits.dim() == 2:
+         prev_tokens_flat = prev_tokens.reshape(-1)
+         unique_tokens = torch.unique(prev_tokens_flat)
+
+         token_logits = logits[:, unique_tokens]
+         pos_mask = token_logits > 0
+         token_logits[pos_mask] /= penalty
+         token_logits[~pos_mask] *= penalty
+         logits[:, unique_tokens] = token_logits
+         return logits
+
+     # Case 2: Delay Pattern audio [B, H, V]
+     assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
+     B, H, V = logits.shape
+
+     for h in range(H):
+         # prev_tokens_h: [B, T] or [B]
+         prev_tokens_h = prev_tokens[..., h].reshape(-1)
+         unique_tokens = torch.unique(prev_tokens_h)
+
+         if unique_tokens.numel() == 0:
+             continue
+
+         token_logits = logits[:, h, unique_tokens]
+         pos_mask = token_logits > 0
+         token_logits[pos_mask] /= penalty
+         token_logits[~pos_mask] *= penalty
+         logits[:, h, unique_tokens] = token_logits
+
+     return logits
+
+
+ def sample_token(
+     logits,
+     prev_tokens: Optional[torch.LongTensor] = None,
+     repetition_penalty: float = 1.0,
+     top_p=None,
+     top_k=None,
+     do_sample=True,
+ ):
+     vocab_size = logits.size(-1)
+
+     # ===== Repetition Penalty (before reshaping!) =====
+     if prev_tokens is not None and repetition_penalty != 1.0:
+         logits = apply_repetition_penalty_delay_pattern(
+             logits,
+             prev_tokens,
+             repetition_penalty,
+         )
+
+     if not do_sample:
+         return torch.argmax(logits, dim=-1)
+
+     # ===== Only flatten after this, for top-k / top-p / multinomial =====
+     original_shape = logits.shape
+     reshaped_logits = logits.view(-1, vocab_size)
+
+     if top_k is not None and top_k > 0:
+         reshaped_logits = apply_top_k(reshaped_logits, top_k)
+
+     if top_p is not None and top_p < 1.0:
+         reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)
+
+     probs = F.softmax(reshaped_logits, dim=-1)
+     next_tokens = torch.multinomial(probs, num_samples=1)
+
+     return next_tokens.view(original_shape[:-1])
+
+
+ def find_last_equal_C(tensor, C):
+     """
+     tensor: torch.Tensor of shape [batch_size, seq_len]
+     C: scalar value to match
+     Returns: torch.Tensor of shape [batch_size] with last indices
+     """
+     mask = (tensor == C).int()  # Shape: [batch_size, seq_len], int tensor (1 where equal)
+     flipped_mask = mask.flip(dims=[1])  # Flip along sequence dimension
+     flipped_indices = flipped_mask.argmax(dim=1)  # First match in flipped order
+     seq_len = tensor.shape[1]
+     last_indices = (seq_len - 1) - flipped_indices  # Convert to original indices
+
+     # Optional: Handle cases with no C (set to -1), though problem assumes existence
+     actual_values = tensor[torch.arange(tensor.shape[0]), last_indices]
+     no_match = actual_values != C
+     last_indices[no_match] = -1
+
+     return last_indices
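A minimal sketch (not repo code) exercising sample_token above for one step of delay-pattern audio decoding, with dummy logits of shape [B, H, V]; the dimensions are assumptions drawn from config.json (n_vq = 32, audio_vocab_size = 1024 plus the pad code):

import torch
from inference_utils import sample_token

B, H, V = 2, 32, 1025                   # batch, VQ heads (n_vq), audio vocab (+pad code)
logits = torch.randn(B, H, V)
prev = torch.randint(0, V, (B, 7, H))   # previously generated codes, [B, T, H]

codes = sample_token(
    logits,
    prev_tokens=prev,
    repetition_penalty=1.1,             # applied per VQ head before flattening
    top_k=50,
    top_p=0.95,
)
print(codes.shape)                      # torch.Size([2, 32]) — one code per head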
merges.txt ADDED
The diff for this file is too large to render. See raw diff
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4de3d89661e92a9bf781117150ea7fb71de01d4bd8a80fd46379a5077645ae48
+ size 4999026432
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e91dd170b478f97349b5448f163f3ebd244d9847d79b7707af95dd97b8be5bea
+ size 1122255376
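The two entries above are git-lfs pointer files; the actual weight shards are fetched by LFS. Once downloaded, a shard can be inspected lazily with the safetensors library, without loading every tensor — a minimal sketch, assuming the shards sit in the current directory:

from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    names = list(f.keys())
    print(len(names), "tensors in shard 1")
    t = f.get_tensor(names[0])          # loads just this one tensor
    print(names[0], tuple(t.shape), t.dtype)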
model.safetensors.index.json ADDED
@@ -0,0 +1,564 @@
+ {
+   "metadata": {
+     "total_parameters": 9859825152,
+     "total_size": 6121212928
+   },
+   "weight_map": {
+     "layer_norm_before_lm_heads.0.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.1.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.10.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.11.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.12.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.13.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.14.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.15.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.16.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.17.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.18.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.19.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.2.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.20.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.21.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.22.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.23.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.24.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.25.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.26.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.27.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.28.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.29.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.3.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.30.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.31.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.32.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.4.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.5.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.6.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.7.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.8.weight": "model-00002-of-00002.safetensors",
+     "layer_norm_before_lm_heads.9.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.0.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.1.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.10.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.11.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.12.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.13.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.14.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.15.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.16.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.17.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.18.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.19.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.2.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.20.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.21.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.22.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.23.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.24.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.25.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.26.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.27.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.28.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.29.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.3.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.30.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.31.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.32.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.4.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.5.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.6.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.7.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.8.weight": "model-00002-of-00002.safetensors",
+     "lm_heads.9.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.0.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.0.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.0.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.1.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.1.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.1.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.10.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.10.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.10.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.11.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.11.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.11.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.12.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.12.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.12.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.13.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.13.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.13.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.14.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.14.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.14.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.15.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.15.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.15.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.16.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.16.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.16.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.17.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.17.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.17.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.18.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.18.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.18.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.19.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.19.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.19.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.2.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.2.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.2.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.20.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.20.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.20.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.21.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.21.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.21.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.22.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.22.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.22.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.23.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.23.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.23.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.24.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.24.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.24.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.25.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.25.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.25.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.26.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.26.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.26.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.27.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.27.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.27.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.28.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.28.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.28.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.29.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.29.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.29.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.3.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.3.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.3.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.30.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.30.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.30.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.31.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.31.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.31.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.32.down_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.32.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.32.up_proj.weight": "model-00002-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.4.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.4.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.4.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.5.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.5.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.5.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.6.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.6.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.6.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.7.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.7.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.7.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.8.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.8.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.8.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.9.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.9.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_to_speech_embedding_mlps.9.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "local_transformer.norm.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.0.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.1.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.10.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.11.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.12.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.13.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.14.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.15.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.16.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.17.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.18.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.19.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.2.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.20.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.21.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.22.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.23.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.24.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.25.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.26.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.27.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.28.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.29.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.3.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.30.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.31.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.32.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.4.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.5.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.6.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.7.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.8.weight": "model-00001-of-00002.safetensors",
+     "model.embedding_list.9.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.embed_tokens.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.24.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.25.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.language_model.layers.26.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
468
+ "model.language_model.layers.26.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
469
+ "model.language_model.layers.26.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
470
+ "model.language_model.layers.26.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
471
+ "model.language_model.layers.27.input_layernorm.weight": "model-00001-of-00002.safetensors",
472
+ "model.language_model.layers.27.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
473
+ "model.language_model.layers.27.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
474
+ "model.language_model.layers.27.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
475
+ "model.language_model.layers.27.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
476
+ "model.language_model.layers.27.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
477
+ "model.language_model.layers.27.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
478
+ "model.language_model.layers.27.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
479
+ "model.language_model.layers.27.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
480
+ "model.language_model.layers.27.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
481
+ "model.language_model.layers.27.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
482
+ "model.language_model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
483
+ "model.language_model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
484
+ "model.language_model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
485
+ "model.language_model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
486
+ "model.language_model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
487
+ "model.language_model.layers.3.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
488
+ "model.language_model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
489
+ "model.language_model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
490
+ "model.language_model.layers.3.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
491
+ "model.language_model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
492
+ "model.language_model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
493
+ "model.language_model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
494
+ "model.language_model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
495
+ "model.language_model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
496
+ "model.language_model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
497
+ "model.language_model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
498
+ "model.language_model.layers.4.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
499
+ "model.language_model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
500
+ "model.language_model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
501
+ "model.language_model.layers.4.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
502
+ "model.language_model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
503
+ "model.language_model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
504
+ "model.language_model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
505
+ "model.language_model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
506
+ "model.language_model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
507
+ "model.language_model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
508
+ "model.language_model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
509
+ "model.language_model.layers.5.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
510
+ "model.language_model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
511
+ "model.language_model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
512
+ "model.language_model.layers.5.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
513
+ "model.language_model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
514
+ "model.language_model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
515
+ "model.language_model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
516
+ "model.language_model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
517
+ "model.language_model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
518
+ "model.language_model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
519
+ "model.language_model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
520
+ "model.language_model.layers.6.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
521
+ "model.language_model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
522
+ "model.language_model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
523
+ "model.language_model.layers.6.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
524
+ "model.language_model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
525
+ "model.language_model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
526
+ "model.language_model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
527
+ "model.language_model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
528
+ "model.language_model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
529
+ "model.language_model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
530
+ "model.language_model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
531
+ "model.language_model.layers.7.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
532
+ "model.language_model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
533
+ "model.language_model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
534
+ "model.language_model.layers.7.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
535
+ "model.language_model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
536
+ "model.language_model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
537
+ "model.language_model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
538
+ "model.language_model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
539
+ "model.language_model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
540
+ "model.language_model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
541
+ "model.language_model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
542
+ "model.language_model.layers.8.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
543
+ "model.language_model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
544
+ "model.language_model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
545
+ "model.language_model.layers.8.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
546
+ "model.language_model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
547
+ "model.language_model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
548
+ "model.language_model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
549
+ "model.language_model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
550
+ "model.language_model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
551
+ "model.language_model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
552
+ "model.language_model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
553
+ "model.language_model.layers.9.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
554
+ "model.language_model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
555
+ "model.language_model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
556
+ "model.language_model.layers.9.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
557
+ "model.language_model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
558
+ "model.language_model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
559
+ "model.language_model.norm.weight": "model-00001-of-00002.safetensors",
560
+ "speech_embedding_to_local_mlp.down_proj.weight": "model-00001-of-00002.safetensors",
561
+ "speech_embedding_to_local_mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
562
+ "speech_embedding_to_local_mlp.up_proj.weight": "model-00001-of-00002.safetensors"
563
+ }
564
+ }
modeling_moss_tts.py ADDED
@@ -0,0 +1,743 @@
+ import os
+ import copy
+ import math
+ import logging
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ from tqdm import tqdm
+ from dataclasses import dataclass
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+ from transformers.utils import ModelOutput
+ from transformers.cache_utils import Cache, DynamicCache
+ from typing import Optional, List, Tuple, Union
+ from transformers.loss.loss_utils import ForCausalLMLoss
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3Attention, eager_attention_forward
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
+ from transformers.generation.logits_process import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
+ from transformers.masking_utils import create_causal_mask
+ 
+ from .inference_utils import find_last_equal_C
+ from .configuration_moss_tts import MossTTSDelayConfig
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class MossTTSRMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [..., dim]
+         norm = x.pow(2).mean(dim=-1, keepdim=True)
+         x = x * torch.rsqrt(norm + self.eps)
+         return x * self.weight
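For reference, this is the standard RMSNorm, y = x · rsqrt(mean(x²) + ε) · weight. A minimal editorial check (not part of the uploaded file):

# Editorial sketch: verify MossTTSRMSNorm against the formula above.
import torch
norm = MossTTSRMSNorm(dim=4)
x = torch.randn(2, 4)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
assert torch.allclose(norm(x), manual)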
+ 
+ 
+ class MossTTSMLP(nn.Module):
+     """
+     HF-style MLP adapter equivalent to Megatron's SwiGLU FFN:
+       in:  input_size
+       mid: ffn_hidden_size
+       out: output_size
+ 
+     Computes:
+       y = down( silu(gate(x)) * up(x) )
+ 
+     Optionally includes a pre-norm on input (common in Megatron blocks).
+     """
+     def __init__(
+         self,
+         input_size: int,
+         ffn_hidden_size: int,
+         output_size: int,
+         bias: bool = False,
+         prenorm: bool = False,
+         norm_eps: float = 1e-6,
+         use_rmsnorm: bool = True,
+     ):
+         super().__init__()
+ 
+         self.prenorm = prenorm
+         if prenorm:
+             if use_rmsnorm:
+                 self.norm = MossTTSRMSNorm(input_size, eps=norm_eps)
+             else:
+                 self.norm = nn.LayerNorm(input_size, eps=norm_eps)
+         else:
+             self.norm = None
+ 
+         # SwiGLU uses two projections to ffn_hidden_size: gate and up
+         self.gate_proj = nn.Linear(input_size, ffn_hidden_size, bias=bias)
+         self.up_proj = nn.Linear(input_size, ffn_hidden_size, bias=bias)
+ 
+         # down projection to output_size (note: output can differ from input)
+         self.down_proj = nn.Linear(ffn_hidden_size, output_size, bias=bias)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.norm is not None:
+             x = self.norm(x)
+ 
+         gate = self.gate_proj(x)
+         up = self.up_proj(x)
+         h = F.silu(gate) * up
+         y = self.down_proj(h)
+         return y
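Because the output size may differ from the input size, this adapter bridges the global and local hidden dimensions. A shape sketch under hypothetical sizes (editorial, not in the file):

# Editorial sketch: the adapter maps 2048-dim features to 512-dim ones.
import torch
mlp = MossTTSMLP(input_size=2048, ffn_hidden_size=1024, output_size=512)
y = mlp(torch.randn(2, 7, 2048))
assert y.shape == (2, 7, 512)  # y = down(silu(gate(x)) * up(x))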
+ 
+ def moss_tts_masked_embedding(embedding: nn.Embedding,
+                               input_ids: torch.LongTensor,
+                               ignore_index: int = -100) -> torch.Tensor:
+     """
+     Embed the positions of input_ids that are != ignore_index; positions equal to ignore_index produce all-zero vectors.
+ 
+     Args:
+         embedding: an nn.Embedding layer
+         input_ids: a LongTensor of arbitrary shape; it may contain ignore_index
+         ignore_index: the marker for positions to be ignored (default -100)
+ 
+     Returns:
+         embeddings: a tensor of shape (*input_ids.shape, embedding.embedding_dim)
+     """
+     # mask: True means embed normally, False means output zeros
+     mask = (input_ids != ignore_index)  # shape: [...]
+ 
+     # Temporarily replace illegal indices such as -100 so they never reach the embedding
+     safe_ids = input_ids.clone()
+     safe_ids[~mask] = 0
+ 
+     # Run the embedding normally
+     out = embedding(safe_ids)  # shape: [..., dim]
+ 
+     # Zero out the positions that correspond to ignore_index
+     out[~mask] = 0.0
+ 
+     return out
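A small editorial demonstration of the masking behaviour (values arbitrary): positions holding -100 come back as zero vectors instead of raising an index error.

# Editorial sketch: -100 positions yield zero rows.
import torch
import torch.nn as nn
emb = nn.Embedding(10, 3)
ids = torch.tensor([[1, -100, 4]])
out = moss_tts_masked_embedding(emb, ids)
assert out.shape == (1, 3, 3)
assert torch.all(out[0, 1] == 0)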
+ 
+ class MossTTSAttentionWithoutPositionalEmbedding(Qwen3Attention):
+     """Qwen3 multi-headed attention with the rotary position embeddings removed; the local transformer applies plain causal attention over the codebook channels."""
+ 
+     def __init__(self, config: MossTTSDelayConfig, layer_idx: int):
+         super().__init__(config, layer_idx)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+ 
+         query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ 
+         assert past_key_value is None
+ 
+         attention_interface = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                 logger.warning(
+                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                 )
+             else:
+                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ 
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             is_causal=True,
+             attention_mask=None,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             sliding_window=self.sliding_window,  # diff with Llama
+             **kwargs,
+         )
+ 
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+ 
+ class MossTTSLocalTransformer(Qwen3Model):
+     def __init__(self, config: MossTTSDelayConfig):
+         super().__init__(config)
+         del self.rotary_emb
+         del self.embed_tokens
+         for layer_idx in range(config.num_hidden_layers):
+             self.layers[layer_idx].self_attn = MossTTSAttentionWithoutPositionalEmbedding(config, layer_idx)
+         self.post_init()
+ 
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs,
+     ) -> BaseModelOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         use_cache = False  # the local transformer always runs without a KV cache
+         assert not use_cache
+ 
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ 
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+ 
+         # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+         if not isinstance(past_key_values, (type(None), Cache)):
+             raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+ 
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+ 
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+ 
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+ 
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+ 
+         # causal_mask = self._update_causal_mask(
+         #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         # )
+         mask_kwargs = {
+             "config": self.config,
+             "input_embeds": inputs_embeds,
+             "attention_mask": attention_mask,
+             "cache_position": cache_position,
+             "past_key_values": past_key_values,
+             "position_ids": position_ids,
+         }
+         causal_mask = create_causal_mask(**mask_kwargs)
+ 
+         hidden_states = inputs_embeds
+ 
+         # create position embeddings to be shared across the decoder layers
+         # position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ 
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+ 
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+ 
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask,
+                 position_ids=None,
+                 past_key_value=None,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=None,
+                 position_embeddings=None,
+                 **flash_attn_kwargs,
+             )
+ 
+             hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
+ 
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+ 
+         hidden_states = self.norm(hidden_states)
+ 
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+ 
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
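The local transformer runs over the short channel axis (up to 1 + Nq positions per timestep), which is why the rotary embeddings, token embeddings, and KV cache are stripped. A hedged usage sketch, assuming `local_transformer` is an instantiated MossTTSLocalTransformer:

# Editorial sketch: one pass over the channel axis; B*T sequences of 1+Nq slots.
import torch
D = local_transformer.config.hidden_size
hidden = torch.randn(4, 9, D)  # (B*T, 1+Nq, D) with a hypothetical Nq = 8
out = local_transformer(inputs_embeds=hidden)[0]
assert out.shape == hidden.shape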
+ 
+ @dataclass
+ class MossTTSOutputWithPast(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     loss_all: Optional[Tuple[torch.FloatTensor]] = None
+     logits_all: Optional[Tuple[torch.FloatTensor]] = None
+     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ 
+ 
+ @dataclass
+ class MossTTSGenerateDecoderOnlyOutput(ModelOutput):
+     sequences: torch.LongTensor = None
+     scores: Optional[Tuple[torch.FloatTensor]] = None
+     logits: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
+ 
+ 
+ class CustomMixin(GenerationMixin):  # TODO: correctness still to be verified
+     def _sample(
+         self,
+         input_ids: torch.LongTensor,  # (B, T, 1+Nq)
+         logits_processor: LogitsProcessorList,
+         stopping_criteria: StoppingCriteriaList,
+         generation_config: GenerationConfig,
+         synced_gpus: bool,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[MossTTSGenerateDecoderOnlyOutput, torch.LongTensor]:
+         # Extract configuration parameters
+         speech_pad_idx = self.config.audio_pad_code
+         device = input_ids.device
+         eos_token_id = generation_config.eos_token_id
+         output_attentions = generation_config.output_attentions
+         output_hidden_states = generation_config.output_hidden_states
+         output_scores = generation_config.output_scores
+         output_logits = generation_config.output_logits
+         return_dict_in_generate = generation_config.return_dict_in_generate
+         max_length = generation_config.max_length
+         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+         do_sample = generation_config.do_sample
+ 
+         # Initialize output tuples
+         scores = () if (return_dict_in_generate and output_scores) else None
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+ 
+         # Initialize tracking variables
+         batch_size, cur_len, channels = input_ids.shape  # channels = 1 + Nq
+         input_ids_length = cur_len
+         this_peer_finished = False
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)  # (B, )
+         base_length = input_ids.shape[1]
+         model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+         # model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ 
+         # Build the per-channel logits processors
+         if getattr(generation_config, "do_samples", None) is not None:
+             do_samples = generation_config.do_samples
+             realprocessor = [LogitsProcessorList() for _ in range(channels)]
+             for i, layer_config in enumerate(generation_config.layers):
+                 if not do_samples[i]:
+                     continue
+                 if layer_config.get("repetition_penalty") is not None and i != 0:  # the text channel skips repetition penalty
+                     realprocessor[i].append(RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty")))
+                 if layer_config.get("temperature") is not None:
+                     realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature")))
+                 if layer_config.get("top_k") is not None:
+                     realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k")))
+                 if layer_config.get("top_p") is not None:
+                     realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p")))
+         else:
+             # Fall back to uniform settings across all channels
+             do_samples = [do_sample for _ in range(channels)]
+             realprocessor = [logits_processor for _ in range(channels)]
+ 
+         pbar = tqdm()
+         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+             # Prepare model inputs
+             pbar.update()
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+             model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+             # Forward pass
+             outputs = self(**model_inputs, n_vq_for_inference=generation_config.n_vq_for_inference, return_dict=True, output_hidden_states=True)
+             model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
+ 
+             if synced_gpus and this_peer_finished:
+                 continue
+ 
+             global_trm_output_hidden_states = outputs.hidden_states[-1][:, -1, :]  # (B, D)
+             dtype = global_trm_output_hidden_states.dtype
+ 
+             local_trm_dim = self.local_transformer_config.hidden_size
+             local_transformer_inputs = torch.zeros(batch_size, 0, local_trm_dim).to(device).to(dtype)  # (B, 0 <= t <= Nq, D); running input sequence for the local transformer
+             current_local_transformer_input = self.speech_embedding_to_local_mlp(global_trm_output_hidden_states)  # (B, D); local-transformer input at the current timestep
+ 
+             next_tokens = []  # (1+Nq) * (B, )
+             # n_vq_for_inference = int(os.environ['N_VQ_FOR_INFERENCE'])
+             n_vq_for_inference = generation_config.n_vq_for_inference
+             for layer_index in range(min(channels, 1 + n_vq_for_inference)):
+                 local_transformer_inputs = torch.cat([local_transformer_inputs, current_local_transformer_input.unsqueeze(1)], dim=1)  # (B, t, D)
+                 local_transformer_outputs = self.local_transformer(
+                     input_ids=None,
+                     attention_mask=None,
+                     inputs_embeds=local_transformer_inputs  # (B, t <= 1+Nq, D)
+                 )[0]  # (B, t <= 1+Nq, D)
+                 local_transformer_outputs = self.layer_norm_before_lm_heads[layer_index](
+                     self.local_to_speech_embedding_mlps[layer_index](local_transformer_outputs)  # (B, t <= 1+Nq, D)
+                 )  # (B, t <= 1+Nq, D)
+ 
+                 next_token_logit = self.lm_heads[layer_index](local_transformer_outputs[:, -1, :])  # (B, V)
+                 if layer_index != 0:
+                     next_token_logit[:, speech_pad_idx] = -torch.inf
+                 next_token_score = realprocessor[layer_index](input_ids[..., layer_index], next_token_logit)  # (B, V)
+ 
+                 if do_samples[layer_index]:
+                     channel_ntk = torch.multinomial(nn.functional.softmax(next_token_score, dim=-1), num_samples=1).squeeze(1)  # (B, )
+                 else:
+                     channel_ntk = torch.argmax(next_token_score, dim=-1)  # (B, )
+ 
+                 next_tokens.append(channel_ntk)  # (1+Nq) * (B, )
+                 current_local_transformer_input = self.model.embedding_list[layer_index](channel_ntk)  # (B, D)
+                 current_local_transformer_input = self.speech_embedding_to_local_mlp(current_local_transformer_input)  # (B, D)
+ 
+             for layer_index in range(1 + n_vq_for_inference, channels):
+                 next_tokens.append(torch.zeros((batch_size, )).to(torch.int).to(device))
+             next_tokens = torch.stack(next_tokens, dim=-1)  # (B, 1+Nq)
+ 
+             if has_eos_stopping_criteria:
+                 for i in range(channels):
+                     pddp = eos_token_id if i == 0 else speech_pad_idx
+                     next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences)
+ 
+             input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)  # (B, T, 1+Nq)
+             if streamer is not None:
+                 streamer.put(next_tokens[:, 0].cpu())
+ 
+             stopping = stopping_criteria(input_ids[..., 0], scores)
+             unfinished_sequences = unfinished_sequences & ~stopping
+             this_peer_finished = unfinished_sequences.max() == 0
+ 
+             if return_dict_in_generate:
+                 if output_scores:
+                     raise NotImplementedError("output_scores is not supported by this multi-channel sampler.")
+                 if output_logits:
+                     raise NotImplementedError("output_logits is not supported by this multi-channel sampler.")
+                 if output_attentions:
+                     decoder_attentions += (outputs.attentions,)
+                 if output_hidden_states:
+                     decoder_hidden_states += (outputs.hidden_states,)
+ 
+             cur_len += 1
+             del outputs
+ 
+         if streamer is not None:
+             streamer.end()
+ 
+         if return_dict_in_generate:
+             return MossTTSGenerateDecoderOnlyOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+                 attentions=decoder_attentions,
+                 hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+         else:
+             start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
+             start_lengths = input_ids_length - start_indices - 1  # 0 for voice cloning; for continuation, the length of the prompt audio (excluding the audio_start_token)
+             output = []
+             for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, input_ids):
+                 output.append((start_length, cur_generation_ids[start_idx:]))
+ 
+             return output
+ 
+ 
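The sampler above pulls per-channel settings from `generation_config.do_samples` and `generation_config.layers` (a list of dicts with optional `repetition_penalty` / `temperature` / `top_k` / `top_p` entries). A sketch of what such a config could look like; all values are hypothetical:

# Editorial sketch: per-channel sampling configuration consumed by _sample.
generation_config.do_samples = [True] * 9  # 1 text channel + 8 RVQ audio channels
generation_config.layers = (
    [{"temperature": 0.7, "top_k": 50, "top_p": 0.95}]  # channel 0: text; repetition_penalty is skipped here
    + [{"temperature": 0.9, "top_k": 64, "top_p": 0.95, "repetition_penalty": 1.1}] * 8  # audio channels
)
generation_config.n_vq_for_inference = 8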
+ class MossTTSPretrainedModel(PreTrainedModel):
+     config_class = MossTTSDelayConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Qwen3DecoderLayer"]  # the backbone is built from Qwen3 decoder layers
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+ 
+ 
+ class MossTTSModel(MossTTSPretrainedModel):
+     def __init__(self, config: MossTTSDelayConfig):
+         super().__init__(config)
+         self.text_pad_idx = config.pad_token_id
+         self.speech_pad_idx = config.audio_pad_code
+         self.embedding_list = nn.ModuleList([])
+         self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx))
+         self.channels = 1 + config.n_vq
+         for _ in range(1, self.channels):
+             self.embedding_list.append(nn.Embedding(config.audio_vocab_size + 1, config.hidden_size, self.speech_pad_idx))
+ 
+         self.language_model = Qwen3Model(config.language_config)
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.embedding_list[0]
+ 
+     def set_input_embeddings(self, value: nn.Embedding):
+         self.embedding_list[0] = value
+ 
+     def _prepare_multi_modal_inputs(self, input_ids: torch.LongTensor, n_vq_for_inference: int, **kwargs) -> torch.FloatTensor:
+         """
+         Prepares multi-modal embeddings from input_ids of shape (batch_size, sequence_length, channels).
+         Channel 0 carries text tokens; channels 1 to channels-1 carry speech tokens padded with speech_pad_token.
+         """
+         batch_size, seq_length, channels = input_ids.shape
+         if channels != self.channels:
+             raise ValueError(f"Expected {self.channels} channels, got {channels}")
+ 
+         inputs_embeds = torch.zeros(batch_size, seq_length, self.config.hidden_size, device=input_ids.device, dtype=self.embedding_list[0].weight.dtype)
+         for i in range(min(channels, 1 + n_vq_for_inference)):
+             embed_layer = self.embedding_list[i]
+             channel_input = input_ids[..., i]
+             inputs_embeds += embed_layer(channel_input)
+ 
+         return inputs_embeds  # (B, T, D)
+ 
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,  # Shape: (batch_size, sequence_length, channels)
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+ 
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ 
+         if input_ids is not None:
+             inputs_embeds = self._prepare_multi_modal_inputs(input_ids, **kwargs)  # (B, T, D)
+ 
+         outputs = self.language_model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+         return outputs
+ 
+ 
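Each timestep thus carries 1 + Nq token ids whose per-channel embeddings are summed into a single vector before the Qwen3 backbone runs. A hedged shape sketch, assuming `model` is an instantiated MossTTSModel with a hypothetical n_vq = 8:

# Editorial sketch: (B, T, 1+Nq) ids collapse to (B, T, D) embeddings.
import torch
B, T = 1, 5
input_ids = torch.zeros(B, T, 1 + 8, dtype=torch.long)  # hypothetical all-zero ids
embeds = model._prepare_multi_modal_inputs(input_ids, n_vq_for_inference=8)
assert embeds.shape == (B, T, model.config.hidden_size)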
+ class MossTTSDelayModel(MossTTSPretrainedModel, CustomMixin):
+     _tied_weights_keys = []
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ 
+     def __init__(self, config: MossTTSDelayConfig):
+         super().__init__(config)
+         self.model = MossTTSModel(config)
+         self.channels = 1 + config.n_vq
+         self.weights = [1 for _ in range(self.channels)]
+         self._tied_weights_keys = [f"lm_heads.{i}.weight" for i in range(self.channels)]
+         self.vocab_size = config.vocab_size
+ 
+         local_transformer_config = copy.deepcopy(config.language_config)
+         local_transformer_config.num_hidden_layers = config.local_num_layers
+         local_transformer_config.hidden_size = config.local_hidden_size
+         local_transformer_config.intermediate_size = config.local_ffn_hidden_size
+         self.local_transformer_config = local_transformer_config
+         self.local_transformer = MossTTSLocalTransformer(self.local_transformer_config)
+ 
+         self.speech_embedding_to_local_mlp = MossTTSMLP(
+             input_size=config.hidden_size,
+             ffn_hidden_size=config.additional_mlp_ffn_hidden_size,
+             output_size=config.local_hidden_size
+         )
+         self.local_to_speech_embedding_mlps = nn.ModuleList([
+             MossTTSMLP(
+                 input_size=config.local_hidden_size,
+                 ffn_hidden_size=config.additional_mlp_ffn_hidden_size,
+                 output_size=config.hidden_size
+             )
+             for _ in range(self.channels)
+         ])
+ 
+         self.layer_norm_before_lm_heads = nn.ModuleList([
+             MossTTSRMSNorm(config.hidden_size)
+             for _ in range(self.channels)
+         ])
+ 
+         self.lm_heads = nn.ModuleList([])
+         self.lm_heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False))
+         for _ in range(1, self.channels):
+             self.lm_heads.append(nn.Linear(config.hidden_size, 1 + config.audio_vocab_size, bias=False))
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.model.embedding_list[0]
+ 
+     def can_generate(self):
+         return True
+ 
+     # def tie_weights(self):
+     #     ...
+     #     for i in range(self.config.channels):
+     #         self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i])
+ 
+     def set_input_embeddings(self, value):
+         self.model.embedding_list[0] = value
+ 
+     def get_output_embeddings(self):
+         return self.lm_heads[0]
+ 
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_heads[0] = new_embeddings
+ 
+     def set_decoder(self, decoder):
+         self.model = decoder
+ 
+     def get_decoder(self):
+         return self.model
+ 
+     def set_weights(self, weights):
+         self.weights = weights
+ 
+     def _prepare_shifted_audio_inputs(self, label_ids):  # (B, T, 1 + Nq); may contain -100
+         text_and_audio_label_embed_list = []  # Nq * (1, T, B, D)
+         for i in range(0, self.channels - 1):
+             text_and_audio_label_embed_list.append(
+                 moss_tts_masked_embedding(self.model.embedding_list[i], label_ids[:, :, i]).unsqueeze(0).transpose(1, 2)  # (B, T) -> (B, T, D) -> (1, B, T, D) -> (1, T, B, D)
+             )  # (1, T, B, D)
+         audio_label_embeds = torch.stack(text_and_audio_label_embed_list, dim=0)  # (Nq, 1, T, B, D)
+         audio_label_embeds = audio_label_embeds.contiguous()[:, 0, :, :, :].transpose(1, 2)  # (Nq, B, T, D)
+         return audio_label_embeds  # (Nq, B, T, D)
+ 
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,  # (B, T, 1 + Nq)
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,  # (B, T, 1 + Nq); TODO: labels are input_ids shifted by one position
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, MossTTSOutputWithPast]:
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,  # (B, T, 1 + Nq)
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+ 
+         if labels is not None:
+             local_transformer_inputs_from_global = outputs[0].unsqueeze(0)  # (1, B, T, D)
+             D_global = local_transformer_inputs_from_global.shape[-1]
+             local_transformer_inputs_from_speech_embeddings = self._prepare_shifted_audio_inputs(labels)  # (B, T, 1 + Nq) -> (Nq, B, T, D)
+             local_transformer_input_hidden_states = torch.cat([local_transformer_inputs_from_global, local_transformer_inputs_from_speech_embeddings], dim=0).contiguous()  # (1 + Nq, B, T, D)
+             local_transformer_input_hidden_states = self.speech_embedding_to_local_mlp(local_transformer_input_hidden_states)  # (1 + Nq, B, T, D)
+             N_channels, B, T, D_local = local_transformer_input_hidden_states.shape
+             local_transformer_input_hidden_states = local_transformer_input_hidden_states.permute(1, 2, 0, 3)  # (B, T, 1 + Nq, D)
+             local_transformer_input_hidden_states = local_transformer_input_hidden_states.reshape(B * T, N_channels, D_local)  # (batch_size=B * T, time=1+Nq, D)
+             local_transformer_output_hidden_states = self.local_transformer(  # TODO: positional encoding is disabled here
+                 input_ids=None,
+                 attention_mask=None,
+                 inputs_embeds=local_transformer_input_hidden_states  # (batch_size=B * T, time=1+Nq, D)
+             )[0]  # (batch_size=B * T, time=1+Nq, D)
+             after_lm_head_mlp_hidden_states = []  # (1+Nq) * (B*T, D)
+             for i in range(self.channels):
+                 after_lm_head_mlp_hidden_states.append(
+                     self.layer_norm_before_lm_heads[i](
+                         self.local_to_speech_embedding_mlps[i](
+                             local_transformer_output_hidden_states[:, i, :]  # (B*T, D)
+                         )
+                     )
+                 )  # (1+Nq) * (B*T, D)
+ 
+             after_lm_head_mlp_hidden_states = torch.stack(after_lm_head_mlp_hidden_states, dim=0)  # (1 + Nq, B*T, D)
+             after_lm_head_mlp_hidden_states = after_lm_head_mlp_hidden_states.reshape(N_channels, B, T, D_global)  # (1 + Nq, B, T, D)
+             logits_all = [lm_head(h_i) for lm_head, h_i in zip(self.lm_heads, after_lm_head_mlp_hidden_states)]  # (1+Nq) * (B, T, V)
+ 
+             loss_all = torch.empty(self.channels, device=device)  # (1 + Nq, )
+ 
+             for i in range(self.channels):
+                 vocab_size = self.config.vocab_size if i == 0 else self.config.audio_vocab_size
+                 loss_all[i] = ForCausalLMLoss(logits_all[i], labels[..., i], vocab_size, shift_labels=labels[..., i])  # (B, T, V), (B, T) => scalar
+             normalized_weights = [weight_i / sum(self.weights) for weight_i in self.weights]  # (1+Nq, )
+ 
+             total_loss = 0
+             for w, loss in zip(normalized_weights, loss_all):
+                 total_loss += w * loss
+         else:
+             total_loss = None
+             loss_all = None
+             logits_all = [None]
+ 
+         assert return_dict
+         if not return_dict:
+             output = (logits_all,) + outputs[1:]
+             return (total_loss, loss_all,) + output if total_loss is not None else output
+ 
+         return MossTTSOutputWithPast(
+             loss=total_loss,
+             logits=logits_all[0],
+             loss_all=loss_all,
+             logits_all=logits_all,  # (1+Nq) * (B, T, V)
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,  # L * (B, T, D)
+             attentions=outputs.attentions,
+         )
processing_moss_tts.py ADDED
@@ -0,0 +1,946 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
+ from dataclasses import dataclass
+ from pathlib import Path
+ import re
+ import torchaudio
+ 
+ from transformers import processing_utils
+ 
+ processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
+ 
+ import torch
+ from transformers import (
+     PreTrainedTokenizerBase,
+     BatchFeature,
+     ProcessorMixin,
+     logging,
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+ )
+ 
+ from .configuration_moss_tts import MossTTSDelayConfig
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ AUDIO_PLACEHOLDER = "<|audio|>"
+ 
+ 
+ @dataclass
+ class Message:
+     def to_dict(self) -> Dict[str, Any]:
+         raise NotImplementedError
+ 
+ 
+ @dataclass
+ class UserMessage(Message):
+     text: Optional[str] = None
+     reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
+     instruction: Optional[str] = None
+     tokens: Optional[int] = None
+     quality: Optional[str] = None
+     sound_event: Optional[str] = None
+     ambient_sound: Optional[str] = None
+     language: Optional[str] = None
+ 
+     def __post_init__(self):
+         template = """<user_inst>
+ - Reference(s):
+ {reference}
+ - Instruction:
+ {instruction}
+ - Tokens:
+ {tokens}
+ - Quality:
+ {quality}
+ - Sound Event:
+ {sound_event}
+ - Ambient Sound:
+ {ambient_sound}
+ - Language:
+ {language}
+ - Text:
+ {text}
+ </user_inst>"""
+ 
+         audio_codes_list = []
+         if self.reference is None:
+             reference = "None"
+         elif isinstance(self.reference, List):
+             reference = []
+             for speaker_idx, speaker_reference in enumerate(self.reference):
+                 if speaker_reference is not None:
+                     reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
+             reference = "\n".join(reference)
+             audio_codes_list = [
+                 speaker_reference
+                 for speaker_reference in self.reference
+                 if speaker_reference is not None
+             ]
+         else:
+             raise TypeError("`reference` must be a list when it is not None.")
+ 
+         content = (
+             template.replace("{reference}", str(reference))
+             .replace("{instruction}", str(self.instruction))
+             .replace("{tokens}", str(self.tokens))
+             .replace("{quality}", str(self.quality))
+             .replace("{sound_event}", str(self.sound_event))
+             .replace("{ambient_sound}", str(self.ambient_sound))
+             .replace("{language}", str(self.language))
+             .replace("{text}", str(self.text))
+         )
+ 
+         self._content = content
+         self._audio_codes_list = audio_codes_list
+ 
+     def to_dict(self):
+         return {
+             "role": "user",
+             "content": self._content,
+             "audio_codes_list": self._audio_codes_list,
+         }
+ 
+ 
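UserMessage renders its fields into a fixed `<user_inst>` block, with unset fields printed as "None". An editorial example of the rendered content for a minimal message (output abbreviated):

# Editorial sketch: rendering a minimal user message.
msg = UserMessage(text="Hello world", language="en")
print(msg.to_dict()["content"])
# <user_inst>
# - Reference(s):
# None
# - Instruction:
# None
# ...
# - Language:
# en
# - Text:
# Hello world
# </user_inst>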
+ @dataclass
+ class AssistantMessage(Message):
+     audio_codes_list: List[Union[str, torch.Tensor]]
+     content: str = AUDIO_PLACEHOLDER
+ 
+     def to_dict(self):
+         return {
+             "role": "assistant",
+             "content": self.content,
+             "audio_codes_list": self.audio_codes_list,
+         }
+ 
+ 
+ USER_MESSAGE_FIELDS = (
+     "text",
+     "reference",
+     "instruction",
+     "tokens",
+     "quality",
+     "sound_event",
+     "ambient_sound",
+     "language",
+ )
+ 
+ 
+ class MossTTSDelayProcessor(ProcessorMixin):
+     tokenizer_class = "AutoTokenizer"
+     audio_tokenizer_class = "AutoModel"
+ 
+     tokenizer: PreTrainedTokenizerBase
+     audio_tokenizer: Any
+ 
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizerBase,
+         audio_tokenizer: Any = None,
+         model_config: Optional[MossTTSDelayConfig] = None,
+         **kwargs,
+     ):
+         super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
+ 
+         # Explicit assignments for type-checkers; ProcessorMixin sets these too.
+         self.tokenizer = tokenizer
+         self.audio_tokenizer = audio_tokenizer
+         if model_config is None:
+             model_config = MossTTSDelayConfig()
+         self.model_config = model_config
+ 
+         self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+         self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+         self.newline_token_id = 198
+ 
+         def _id_to_token(token_id: int) -> str:
+             tok = tokenizer.convert_ids_to_tokens(int(token_id))
+             if isinstance(tok, list):
+                 return tok[0] if len(tok) > 0 else ""
+             return cast(str, tok)
+ 
+         self.audio_user_slot_token = _id_to_token(
+             self.model_config.audio_user_slot_token_id
+         )
+         self.audio_assistant_gen_slot_token = _id_to_token(
+             self.model_config.audio_assistant_gen_slot_token_id
+         )
+         self.audio_assistant_delay_slot_token = _id_to_token(
+             self.model_config.audio_assistant_delay_slot_token_id
+         )
+         self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
+         self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
+ 
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+         trust_remote_code = kwargs.pop("trust_remote_code", True)
+         kwargs.pop("_from_auto", None)
+ 
+         audio_tokenizer_name_or_path = kwargs.pop(
+             # "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
+             "codec_path", "/inspire/sj-ssd3/project/embodied-multimodality/public/ytgong/MOSS-TTS/MOSS-Audio-Tokenizer-snapshot"
+         )
+ 
+         pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+         model_config = cast(
+             MossTTSDelayConfig,
+             AutoConfig.from_pretrained(
+                 pretrained_model_name_or_path,
+                 *args,
+                 trust_remote_code=trust_remote_code,
+                 **kwargs,
+             ),
+         )
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path,
+             *args,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+         audio_tokenizer = AutoModel.from_pretrained(
+             audio_tokenizer_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+ 
+         return cls(
+             tokenizer=tokenizer,
+             audio_tokenizer=audio_tokenizer,
+             model_config=model_config,
+             **kwargs,
+         )
+ 
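A hedged loading sketch (the path and n_vq below are hypothetical; `codec_path` redirects the audio tokenizer source, whose active default above points at a local snapshot):

# Editorial sketch: load the processor and prepare a generation batch.
processor = MossTTSDelayProcessor.from_pretrained(
    "path/to/this/snapshot",
    codec_path="OpenMOSS-Team/MOSS-Audio-Tokenizer",  # the id in the commented-out default above
)
message = MossTTSDelayProcessor.build_user_message(text="Hello world", language="en")
batch = processor([message], mode="generation", n_vq=8)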
232
+ def __call__(self, *args, **kwargs) -> BatchFeature:
233
+ conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
234
+ mode: str = kwargs.pop("mode", "generation")
235
+ apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
236
+ n_vq: Optional[int] = kwargs.pop("n_vq", None)
237
+
238
+ # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
239
+ kwargs.pop("return_tensors", None)
240
+ kwargs.pop("padding", None)
241
+ kwargs.pop("truncation", None)
242
+
243
+ """
244
+ mode only works when a Message is converted to a dict.
245
+ """
246
+
247
+ if mode not in {"generation", "continuation"}:
248
+ raise RuntimeError
249
+
250
+ if isinstance(conversations, (Message, Dict)):
251
+ conversations = [conversations]
252
+
253
+ truncation = False
254
+ if mode == "continuation":
255
+ truncation = True
256
+
257
+ input_ids_list = []
258
+ for conversation in conversations:
259
+ if isinstance(conversation, (Message, Dict)):
260
+ conversation = [conversation]
261
+
262
+ # Normalize early so downstream logic always deals with dict messages.
263
+ conversation = [self._normalize_message(m) for m in conversation]
264
+
265
+ if (mode == "generation") ^ (len(conversation) % 2 != 0):
266
+ raise ValueError
267
+
268
+ if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
269
+ raise ValueError
270
+
271
+ unified_codes = []
272
+ for message_idx, message in enumerate(conversation):
273
+ if apply_chat_template:
274
+ add_generation_prompt = (
275
+ mode == "generation" and message_idx == len(conversation) - 1
276
+ )
277
+ try:
278
+ content = self.tokenizer.apply_chat_template(
279
+ [{"role": message["role"], "content": message["content"]}],
280
+ add_generation_prompt=add_generation_prompt,
281
+ tokenize=False,
282
+ )
283
+ except TypeError:
284
+ try:
285
+ content = self.tokenizer.apply_chat_template(
286
+ [
287
+ {
288
+ "role": message["role"],
289
+ "content": message["content"],
290
+ }
291
+ ],
292
+ add_generation_prompt=add_generation_prompt,
293
+ )
294
+ except Exception:
295
+ logger.warning(
296
+ "apply_chat_template failed; fallback to raw content."
297
+ )
298
+ content = message["content"]
299
+ else:
300
+ content = message["content"]
301
+
302
+ if not isinstance(content, str):
303
+ content = str(content)
304
+
305
+ # Batch-encode all path-based references in one call when possible.
306
+ # This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
307
+ # instead of repeatedly calling it with batch=1.
308
+ raw_audio_items = message.get("audio_codes_list", [])
309
+
310
+ audio_codes_list: List[torch.Tensor] = []
311
+ if len(raw_audio_items) > 0:
312
+ encoded_items: List[Optional[torch.Tensor]] = [None] * len(
313
+ raw_audio_items
314
+ )
315
+ paths: List[str] = []
316
+ path_positions: List[int] = []
317
+
318
+ for idx, item in enumerate(raw_audio_items):
319
+ if isinstance(item, torch.Tensor):
320
+ if n_vq is not None and item.shape[1] != n_vq:
321
+ raise RuntimeError(
322
+ "audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
323
+ )
324
+ encoded_items[idx] = item
325
+ continue
326
+
327
+ if isinstance(item, (str, os.PathLike)):
328
+ paths.append(str(item))
329
+ path_positions.append(idx)
330
+ continue
331
+
332
+ raise TypeError(
333
+ "Each audio item must be a torch.Tensor of codes or a path-like string."
334
+ )
335
+
336
+ if len(paths) > 0:
337
+ encoded_from_paths = self.encode_audios_from_path(paths, n_vq)  # List[torch.Tensor], one (T, NQ) per path
338
+ if len(encoded_from_paths) != len(paths):
339
+ raise RuntimeError(
340
+ "encode_audios_from_path returned an unexpected number of items."
341
+ )
342
+ for pos, codes in zip(path_positions, encoded_from_paths):
343
+ encoded_items[pos] = codes
344
+
345
+ audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
346
+ unified_codes.append(
347
+ self._get_unified_codes(
348
+ message["role"], content, audio_codes_list, truncation
349
+ )
350
+ )
351
+
352
+ unified_codes = torch.cat(unified_codes) # (T, Nq)
353
+ if mode == "generation":
354
+ audio_start_position_tokens = torch.zeros((1, unified_codes.shape[-1])).to(unified_codes.dtype).to(unified_codes.device) # (1, Nq)
355
+ audio_start_position_tokens[:, 0] = self.tokenizer.encode(self.audio_start_token)[0]
356
+ audio_start_position_tokens[:, 1:] = self.model_config.audio_pad_code
357
+ unified_codes = torch.cat([unified_codes, audio_start_position_tokens], dim=0) # (T + 1, Nq)
358
+ input_ids_list.append(unified_codes)
359
+
360
+ return BatchFeature(data=self._pad(input_ids_list))
361
+
362
+ @staticmethod
363
+ def build_user_message(
364
+ text: Optional[str] = None,
365
+ reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
366
+ instruction: Optional[str] = None,
367
+ tokens: Optional[int] = None,
368
+ quality: Optional[str] = None,
369
+ sound_event: Optional[str] = None,
370
+ ambient_sound: Optional[str] = None,
371
+ language: Optional[str] = None,
372
+ ) -> Dict:
373
+ if reference is not None and not isinstance(reference, list):
374
+ reference = [reference]
375
+ return UserMessage(
376
+ text=text,
377
+ reference=reference,
378
+ instruction=instruction,
379
+ tokens=tokens,
380
+ quality=quality,
381
+ sound_event=sound_event,
382
+ ambient_sound=ambient_sound,
383
+ language=language,
384
+ ).to_dict()
385
+
386
+ @staticmethod
387
+ def build_assistant_message(
388
+ audio_codes_list: List[Union[str, torch.Tensor]],
389
+ content: str = AUDIO_PLACEHOLDER,
390
+ ) -> Dict:
391
+ return AssistantMessage(
392
+ audio_codes_list=audio_codes_list,
393
+ content=content,
394
+ ).to_dict()
395
+
396
+ def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
397
+ if isinstance(message, Message):
398
+ return message.to_dict()
399
+ if not isinstance(message, dict):
400
+ raise TypeError("Each message must be a Message or dict.")
401
+ if "role" not in message:
402
+ raise ValueError("Message dict must include a 'role' field.")
403
+ if "content" in message and "audio_codes_list" in message:
404
+ return message
405
+ role = message["role"]
406
+ if role == "user":
407
+ kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
408
+ return self.build_user_message(**kwargs)
409
+ if role == "assistant":
410
+ return self.build_assistant_message(
411
+ audio_codes_list=message.get("audio_codes_list", []),
412
+ content=message.get("content", AUDIO_PLACEHOLDER),
413
+ )
414
+ raise ValueError(f"Unsupported role: {role}")
415
+
416
+ def _pad(self, input_ids_list: List[torch.Tensor]):
417
+ device = input_ids_list[0].device
418
+ lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
419
+ pad_input_ids = torch.nn.utils.rnn.pad_sequence(
420
+ input_ids_list,
421
+ batch_first=True,
422
+ padding_value=self.model_config.audio_pad_code,
423
+ padding_side="left",
424
+ )
425
+ other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
426
+ 1
427
+ ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
428
+ pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
429
+ attention_mask = torch.zeros(
430
+ pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
431
+ )
432
+ attention_mask[~other_channel_mask] = 1
433
+ attention_mask = attention_mask.bool()
434
+ return {
435
+ "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
436
+ "attention_mask": attention_mask,
437
+ }
438
+
439
+ @staticmethod
440
+ def _replace_audio_placeholders(
441
+ content: str,
442
+ lengths: List[int],
443
+ n_vq: int,
444
+ gen_slot_token: str,
445
+ delay_slot_token: str,
446
+ audio_start_token: str,
447
+ audio_end_token: str,
448
+ ) -> str:
449
+ if n_vq < 1:
450
+ raise ValueError(f"n_vq must be >= 1, got {n_vq}")
451
+
452
+ num_placeholders = content.count(AUDIO_PLACEHOLDER)
453
+ if num_placeholders != len(lengths):
454
+ raise ValueError(
455
+ f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
456
+ f"does not match lengths ({len(lengths)})"
457
+ )
458
+
459
+ def build_audio_block(length: int) -> str:
460
+ if length < 0:
461
+ raise ValueError(f"length must be >= 0, got {length}")
462
+
463
+ if length == 0:
464
+ return f"{audio_start_token}{audio_end_token}"
465
+
466
+ # step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
467
+ step_tokens = gen_slot_token * length
468
+ return f"{audio_start_token}{step_tokens}{audio_end_token}"
469
+
470
+ lengths_iter = iter(lengths)
471
+
472
+ def replacer(match: re.Match) -> str:
473
+ length = next(lengths_iter)
474
+ return build_audio_block(length)
475
+
476
+ result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
477
+
478
+ return result
479
+
480
+ @staticmethod
481
+ def _merge_consecutive_audio_placeholders(
482
+ content: str,
483
+ audio_codes_list: List[torch.Tensor],
484
+ ) -> Tuple[str, List[torch.Tensor]]:
485
+ matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
486
+ if len(matches) <= 1:
487
+ return content, audio_codes_list
488
+
489
+ if len(matches) != len(audio_codes_list):
490
+ raise ValueError(
491
+ "Audio placeholders do not match the provided audio codes list."
492
+ )
493
+
494
+ new_audio_codes_list = []
495
+ new_parts = []
496
+ last_pos = 0
497
+ i = 0
498
+ while i < len(matches):
499
+ j = i
500
+ while (
501
+ j + 1 < len(matches)
502
+ and content[matches[j].end() : matches[j + 1].start()].strip() == ""
503
+ ):
504
+ j += 1
505
+
506
+ new_parts.append(content[last_pos : matches[i].start()])
507
+ new_parts.append(AUDIO_PLACEHOLDER)
508
+ last_pos = matches[j].end()
509
+
510
+ if j == i:
511
+ new_audio_codes_list.append(audio_codes_list[i])
512
+ else:
513
+ new_audio_codes_list.append(
514
+ torch.cat(audio_codes_list[i : j + 1], dim=0)
515
+ )
516
+
517
+ i = j + 1
518
+
519
+ new_parts.append(content[last_pos:])
520
+ return "".join(new_parts), new_audio_codes_list
521
+
522
+ @staticmethod
523
+ def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
524
+ delayed_tokens = torch.full(
525
+ (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
526
+ pad_code,
527
+ device=codes.device,
528
+ dtype=codes.dtype,
529
+ )
530
+ for i in range(codes.shape[1]):
531
+ delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
532
+ return delayed_tokens
533
+
534
+ @staticmethod
535
+ def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
536
+ tokens = torch.full(
537
+ (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
538
+ 0,
539
+ device=delay_codes.device,
540
+ dtype=delay_codes.dtype,
541
+ )
542
+ for i in range(delay_codes.shape[1]):
543
+ tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
544
+ return tokens
545
+
546
+ def _get_unified_codes(
547
+ self,
548
+ role: str,
549
+ content: str,
550
+ audio_codes_list: List[torch.Tensor],
551
+ truncation: bool,
552
+ ) -> torch.Tensor:
553
+ """
554
+ At this point, `content` already carries the chat-template formatting.
555
+ """
556
+ if role == "user":
557
+ audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
558
+ else:
559
+ audio_gen_slot_token = self.audio_assistant_gen_slot_token
560
+ audio_delay_slot_token = self.audio_assistant_delay_slot_token
561
+
562
+ if len(audio_codes_list):
563
+ n_vq = audio_codes_list[0].shape[1]
564
+ else:
565
+ n_vq = self.model_config.n_vq
566
+
567
+ if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
568
+ content, audio_codes_list = self._merge_consecutive_audio_placeholders(
569
+ content, audio_codes_list
570
+ )
571
+ content = self._replace_audio_placeholders(
572
+ content=content,
573
+ lengths=[len(audio_codes) for audio_codes in audio_codes_list],
574
+ n_vq=n_vq,
575
+ gen_slot_token=audio_gen_slot_token,
576
+ delay_slot_token=audio_delay_slot_token,
577
+ audio_start_token=self.audio_start_token,
578
+ audio_end_token=self.audio_end_token,
579
+ )
580
+ text_codes = torch.tensor(
581
+ self.tokenizer.encode(content),
582
+ device=audio_codes_list[0].device if audio_codes_list else None,
583
+ )
584
+
585
+ audio_start_indices = torch.where(
586
+ text_codes == self.model_config.audio_start_token_id
587
+ )[0]
588
+ audio_end_indices = torch.where(
589
+ text_codes == self.model_config.audio_end_token_id
590
+ )[0]
591
+ if len(audio_start_indices) != len(audio_codes_list) or len(
592
+ audio_end_indices
593
+ ) != len(audio_codes_list):
594
+ raise ValueError(
595
+ "Audio placeholders do not match the provided audio codes list."
596
+ )
597
+
598
+ delay_audio_codes_list = []
599
+ assert len(audio_codes_list) <= 1
600
+ if len(audio_codes_list) == 0:
601
+ delay_audio_codes_list = torch.full(
602
+ (len(text_codes), n_vq),
603
+ self.model_config.audio_pad_code,
604
+ device=text_codes.device,
605
+ dtype=text_codes.dtype,
606
+ )
607
+ else:
608
+ prefix_idx = 0
609
+ for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
610
+ audio_start_indices, audio_end_indices, audio_codes_list
611
+ ):
612
+ audio_start_idx = int(audio_start_idx_t.item())
613
+ audio_end_idx = int(audio_end_idx_t.item())
614
+ # delay_audio_codes = self.apply_delay_pattern(
615
+ # audio_codes, self.model_config.audio_pad_code
616
+ # )
617
+ delay_audio_codes = audio_codes  # delay pattern disabled; use codes as-is
618
+ pad_codes = torch.full(
619
+ (audio_start_idx - prefix_idx + 1, n_vq),
620
+ self.model_config.audio_pad_code,
621
+ device=audio_codes.device,
622
+ dtype=audio_codes.dtype,
623
+ )
624
+ delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
625
+ prefix_idx = audio_end_idx
626
+
627
+ if truncation:
628
+ # delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
629
+ # : -(n_vq - 1), :
630
+ # ]
631
+ pass  # delay-tail trimming is disabled while the delay pattern is off
632
+ else:
633
+ last_audio_end_idx = int(audio_end_indices[-1].item())
634
+ pad_codes = torch.full(
635
+ (len(text_codes) - last_audio_end_idx, n_vq),
636
+ self.model_config.audio_pad_code,
637
+ device=audio_codes_list[0].device,
638
+ dtype=audio_codes_list[0].dtype,
639
+ )
640
+ delay_audio_codes_list.append(pad_codes)
641
+
642
+ delay_audio_codes_list = torch.cat(delay_audio_codes_list)
643
+
644
+ if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
645
+ text_codes = text_codes[: delay_audio_codes_list.shape[0]]
646
+
647
+ unified_codes = torch.cat(
648
+ [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
649
+ )
650
+ return unified_codes
651
+
652
+ def _parse_text_codes(self, start_length, text_codes):
653
+ text = cast(str, self.tokenizer.decode(text_codes))
654
+ prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
655
+ text = text[len(prefix) :]
656
+
657
+ AUDIO_PATTERN = re.compile(
658
+ rf"(?:{self.audio_start_token})?"
659
+ rf"(?:{self.audio_assistant_gen_slot_token})*"
660
+ rf"(?:{self.audio_assistant_delay_slot_token})*"
661
+ rf"{self.audio_end_token}"
662
+ )
663
+
664
+ def normalize_audio_segments(text: str) -> str:
665
+ def repl(match: re.Match) -> str:
666
+ seg = match.group(0)
667
+ # Replace the segment with AUDIO_PLACEHOLDER if it contains a gen slot;
668
+ if self.audio_assistant_gen_slot_token in seg:
669
+ return AUDIO_PLACEHOLDER
670
+ # Otherwise, remove it.
671
+ return ""
672
+
673
+ return AUDIO_PATTERN.sub(repl, text)
674
+
675
+ return normalize_audio_segments(text)
676
+
677
+ def _parse_audio_codes(self, start_length, audio_codes):
678
+ # De-delay back to [T', n_vq]
679
+ # audio_codes = self.apply_de_delay_pattern(audio_codes)
680
+
681
+ # Rows that are all pad are separators between real audio segments.
682
+ is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
683
+ non_pad = ~is_pad
684
+ if not non_pad.any():
685
+ return []
686
+
687
+ idx = torch.nonzero(non_pad).squeeze(1)
688
+ breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
689
+ if breaks.numel() == 0:
690
+ segments_idx = [idx]
691
+ else:
692
+ # Convert break indices into contiguous chunk sizes before splitting.
693
+ sizes = torch.diff(torch.cat([breaks.new_zeros(1), breaks, breaks.new_tensor([len(idx)])])).tolist()
694
+ segments_idx = torch.split(idx, sizes)
695
+
696
+ audio_codes_list = [audio_codes[s] for s in segments_idx]
697
+
698
+ # Batch-decode all audio segments together.
699
+ decoded_audio_list = self.decode_audio_codes(audio_codes_list)
700
+
701
+ # Keep codec causal context by decoding the whole first segment first,
702
+ # then trim at waveform level according to start_length ratio.
703
+ if (
704
+ start_length > 0
705
+ and len(audio_codes_list) > 0
706
+ and len(decoded_audio_list) > 0
707
+ ):
708
+ first_codes_length = audio_codes_list[0].shape[0]
709
+ if first_codes_length > 0:
710
+ trim_ratio = max(
711
+ 0.0, min(float(start_length) / float(first_codes_length), 1.0)
712
+ )
713
+ first_audio = decoded_audio_list[0]
714
+ if trim_ratio >= 1.0:
715
+ decoded_audio_list = decoded_audio_list[1:]
716
+ elif trim_ratio > 0.0:
717
+ trim_samples = int(first_audio.shape[-1] * trim_ratio)
718
+ decoded_audio_list[0] = first_audio[..., trim_samples:]
719
+
720
+ return decoded_audio_list
721
+
722
+ def decode(self, output: List[Tuple[int, torch.Tensor]]):
723
+ """
724
+ 1. A complete assistant generation id sequence is always required here;
725
+ 2. Truncation from an arbitrary start position is supported;
726
+ """
727
+
728
+ generated_messages = []
729
+ for start_length, generation_ids in output:
730
+ content = self._parse_text_codes(start_length, generation_ids[:, 0])
731
+ audio_codes_list = self._parse_audio_codes(
732
+ start_length, generation_ids[:, 1:]
733
+ )
734
+ if content == "":
735
+ message = None
736
+ else:
737
+ message = AssistantMessage(
738
+ content=content,
739
+ audio_codes_list=cast(
740
+ List[Union[str, torch.Tensor]], audio_codes_list
741
+ ),
742
+ )
743
+ generated_messages.append(message)
744
+ return generated_messages
745
+
746
+ @staticmethod
747
+ def loudness_normalize(
748
+ wav: torch.Tensor,
749
+ target_dbfs: float = -20,
750
+ gain_range: tuple[float, float] = (-3.0, 3.0),
751
+ ) -> torch.Tensor:
752
+ wav = wav.to(torch.float32)
753
+ if wav.numel() == 0:
754
+ return wav
755
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
756
+ gain = float(target_dbfs - current_dbfs)
757
+ gain = max(gain_range[0], min(gain, gain_range[1]))
758
+ factor = 10.0 ** (gain / 20.0)
759
+ return wav * factor
760
+
761
+ def _get_audio_tokenizer_device(self) -> torch.device:
762
+ """Best-effort device inference for `self.audio_tokenizer`.
763
+
764
+ Notes:
765
+ - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
766
+ - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
767
+ """
768
+
769
+ audio_tokenizer = getattr(self, "audio_tokenizer", None)
770
+ if audio_tokenizer is None:
771
+ logger.warning(
772
+ "audio_tokenizer is not set on processor. Using CPU as default."
773
+ )
774
+ return torch.device("cpu")
775
+
776
+ device_attr = getattr(audio_tokenizer, "device", None)
777
+ if isinstance(device_attr, torch.device):
778
+ return device_attr
779
+
780
+ try:
781
+ return next(audio_tokenizer.parameters()).device
782
+ except StopIteration:
783
+ # No parameters (shouldn't happen for real models); default to CPU.
784
+ logger.warning(
785
+ "No parameters found on audio_tokenizer. Using CPU as default."
786
+ )
787
+ return torch.device("cpu")
788
+
789
+ def encode_audios_from_wav(
790
+ self,
791
+ wav_list: List[torch.Tensor],
792
+ sampling_rate: int,
793
+ n_vq: Optional[int] = None,
794
+ ):
795
+ if self.audio_tokenizer is None:
796
+ raise RuntimeError("audio_tokenizer is not set on processor.")
797
+ audio_tokenizer = self.audio_tokenizer
798
+
799
+ if isinstance(wav_list, torch.Tensor):
800
+ wav_list = [wav_list]
801
+ wav_list_ = []
802
+ resample = False
803
+ if sampling_rate != self.model_config.sampling_rate:
804
+ resample = True
805
+ device = self._get_audio_tokenizer_device()
806
+ for wav in wav_list:
807
+ if wav.shape[0] > 1:
808
+ wav = torch.mean(wav, dim=0, keepdim=True)
809
+ if resample:
810
+ wav = torchaudio.functional.resample(
811
+ waveform=wav,
812
+ orig_freq=sampling_rate,
813
+ new_freq=self.model_config.sampling_rate,
814
+ )
815
+ wav = wav.to(device)
816
+ wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
817
+
818
+ # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
819
+ if hasattr(audio_tokenizer, "batch_encode"):
820
+ enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
821
+ audio_codes = enc.audio_codes # (NQ, B, T)
822
+ audio_codes_lengths = enc.audio_codes_lengths # (B,)
823
+ else:
824
+ # Fallback: use encode() with explicit padding.
825
+ max_len = max(int(wav.shape[-1]) for wav in wav_list_)
826
+ input_values = torch.zeros(
827
+ len(wav_list_), 1, max_len, device=device, dtype=torch.float32
828
+ )
829
+ padding_mask = torch.zeros(
830
+ len(wav_list_), max_len, device=device, dtype=torch.bool
831
+ )
832
+ for i, wav in enumerate(wav_list_):
833
+ this_len = int(wav.shape[-1])
834
+ input_values[i, 0, :this_len] = wav
835
+ padding_mask[i, :this_len] = True
836
+ enc = audio_tokenizer.encode(
837
+ input_values,
838
+ padding_mask=padding_mask,
839
+ num_quantizers=n_vq,
840
+ return_dict=True,
841
+ )
842
+ audio_codes = enc.audio_codes
843
+ audio_codes_lengths = enc.audio_codes_lengths
844
+
845
+ if audio_codes is None or audio_codes_lengths is None:
846
+ raise RuntimeError(
847
+ "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
848
+ )
849
+
850
+ # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
851
+ # and on CPU (so downstream text/audio packing remains device-agnostic).
852
+ codes_list: List[torch.Tensor] = []
853
+ for i in range(int(audio_codes.shape[1])):
854
+ length_i = int(audio_codes_lengths[i].item())
855
+ codes_i = (
856
+ audio_codes[:, i, :length_i]
857
+ .transpose(0, 1)
858
+ .contiguous()
859
+ .to(torch.long)
860
+ .cpu()
861
+ )
862
+ codes_list.append(codes_i)
863
+ return codes_list
864
+
865
+ def encode_audios_from_path(
866
+ self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
867
+ ):
868
+ if isinstance(wav_path_list, str):
869
+ wav_path_list = [wav_path_list]
870
+
871
+ if len(wav_path_list) == 0:
872
+ raise ValueError("Empty wav_path_list")
873
+
874
+ # Load + (if needed) resample each wav independently, so callers can
875
+ # pass a heterogeneous batch of files while still benefiting from
876
+ # audio_tokenizer.batch_encode.
877
+ target_sr = int(self.model_config.sampling_rate)
878
+ wav_list: List[torch.Tensor] = []
879
+ for wav_path in wav_path_list:
880
+ wav, sr = torchaudio.load(wav_path)
881
+ if int(sr) != target_sr:
882
+ wav = torchaudio.functional.resample(
883
+ waveform=wav,
884
+ orig_freq=int(sr),
885
+ new_freq=target_sr,
886
+ )
887
+ wav_list.append(wav)
888
+
889
+ return self.encode_audios_from_wav(wav_list, target_sr, n_vq)
890
+
891
+ def decode_audio_codes(
892
+ self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
893
+ ):
894
+ if self.audio_tokenizer is None:
895
+ raise RuntimeError("audio_tokenizer is not set on processor.")
896
+ audio_tokenizer = self.audio_tokenizer
897
+
898
+ if isinstance(audio_tokens_list, torch.Tensor):
899
+ audio_tokens_list = [audio_tokens_list]
900
+ if len(audio_tokens_list) == 0:
901
+ return []
902
+
903
+ device = self._get_audio_tokenizer_device()
904
+
905
+ # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
906
+ codes_list = [
907
+ codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
908
+ for codes in audio_tokens_list
909
+ ]
910
+
911
+ if hasattr(audio_tokenizer, "batch_decode"):
912
+ dec = audio_tokenizer.batch_decode(codes_list)
913
+ audio = dec.audio # (B, C, T)
914
+ audio_lengths = dec.audio_lengths # (B,)
915
+ else:
916
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
917
+ nq = int(codes_list[0].shape[0])
918
+ max_t = max(int(c.shape[1]) for c in codes_list)
919
+ audio_codes = torch.zeros(
920
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
921
+ )
922
+ padding_mask = torch.zeros(
923
+ len(codes_list), max_t, device=device, dtype=torch.bool
924
+ )
925
+ for i, c in enumerate(codes_list):
926
+ t = int(c.shape[1])
927
+ audio_codes[:, i, :t] = c
928
+ padding_mask[i, :t] = True
929
+ dec = audio_tokenizer.decode(
930
+ audio_codes, padding_mask=padding_mask, return_dict=True
931
+ )
932
+ audio = dec.audio
933
+ audio_lengths = dec.audio_lengths
934
+
935
+ if audio is None or audio_lengths is None:
936
+ raise RuntimeError(
937
+ "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
938
+ )
939
+
940
+ # Return historical contract: list of 1D waveforms (T,)
941
+ wav_list: List[torch.Tensor] = []
942
+ for i in range(int(audio.shape[0])):
943
+ length_i = int(audio_lengths[i].item())
944
+ wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
945
+ wav_list.append(wav)
946
+ return wav_list
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "processor_class": "MossTTSDelayProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_tts.MossTTSDelayProcessor"
5
+ }
6
+ }
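The auto_map above is what lets transformers resolve the custom processor class directly from the repository. A hypothetical end-to-end sketch (the local snapshot path and reference wav are placeholders; codec_path is the kwarg consumed by MossTTSDelayProcessor.from_pretrained in the diff above):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "/path/to/this/snapshot",  # placeholder: local snapshot directory
    trust_remote_code=True,  # the processor class lives in the repo itself
    codec_path="OpenMOSS-Team/MOSS-Audio-Tokenizer",  # audio tokenizer source
)
message = processor.build_user_message(
    text="Hello!", reference=["/path/to/ref.wav"]  # placeholder reference audio
)
batch = processor(message, mode="generation")
# batch["input_ids"]: (batch, seqlen, n_vq); batch["attention_mask"]: (batch, seqlen)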
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|audio_start|>",
12
+ "<|audio_end|>",
13
+ "<|audio_user_slot|>",
14
+ "<|image_pad|>",
15
+ "<|audio_assistant_gen_slot|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb3c8fa82993d515469c2800cc455bff4aaa3c4fed9da1f2b0c0668c304f335a
3
+ size 11422691
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|audio_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|audio_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|audio_user_slot|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|audio_assistant_gen_slot|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|audio_assistant_delay_slot|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|audio_start|>",
224
+ "<|audio_end|>",
225
+ "<|audio_user_slot|>",
226
+ "<|image_pad|>",
227
+ "<|audio_assistant_gen_slot|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "processor_class": "AsteroidProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
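The added_tokens_decoder above registers each audio control token (ids 151652-151656 and 151662) as a single vocabulary entry, so they encode and decode atomically. A small sanity-check sketch, assuming the same placeholder snapshot directory:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/this/snapshot")
assert tok.encode("<|audio_start|>") == [151652]
assert tok.encode("<|audio_end|>") == [151653]
assert tok.convert_ids_to_tokens(151656) == "<|audio_assistant_gen_slot|>"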
vocab.json ADDED
The diff for this file is too large to render. See raw diff