YWMditto committed
Commit 5b37e97 · 0 Parent(s)

OpenMOSS, MOSI.AI, and MOSS-TTS have been officially open-sourced!

.gitattributes ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,158 @@
---
license: apache-2.0
---
# MOSS-TTS Family

## Overview
MOSS‑TTS Family is an open‑source **speech and sound generation model family** from [MOSI.AI](https://mosi.cn/#hero) and the [OpenMOSS team](https://www.open-moss.com/). It is designed for **high fidelity**, **high expressiveness**, and **complex real‑world scenarios**, covering stable long‑form speech, multi‑speaker dialogue, voice/character design, environmental sound effects, and real‑time streaming TTS.


## Introduction

<p align="center">
<img src="https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_imgaes_demo/moss_tts_family_arch.jpeg" width="85%" />
</p>

When a single piece of audio needs to **sound like a real person**, **pronounce every word accurately**, **switch speaking styles across content**, **remain stable over tens of minutes**, and **support dialogue, role‑play, and real‑time interaction**, a single TTS model is often not enough. The **MOSS‑TTS Family** breaks the workflow into five production‑ready models that can be used independently or composed into a complete pipeline.

- **MOSS‑TTS**: The flagship, production-ready text-to-speech foundation model in the family, built to ship, scale, and power real-world voice applications beyond demos. Its core capability is high-fidelity zero-shot voice cloning, complemented by ultra-long speech generation, token-level duration control, multilingual and code-switched synthesis, and fine-grained Pinyin/phoneme pronunciation control. Together, these features make it a robust base model for scalable narration, dubbing, and voice-driven products.
- **MOSS‑TTSD**: A production-oriented long-form spoken dialogue generation model for creating highly expressive, multi-party conversational audio at scale. It supports continuous long-duration generation, flexible multi-speaker turn-taking control, and zero-shot voice cloning from short reference audio, enabling natural conversations with rich interaction dynamics. It is designed for real-world long-form content such as podcasts, audiobooks, commentary, dubbing, and entertainment dialogue.
- **MOSS‑VoiceGenerator**: An open-source voice design system that generates speaker timbres directly from free-form text descriptions, enabling fast creation of voices for characters, personalities, and emotions without requiring reference audio. It unifies timbre design, style control, and content synthesis in a single instruction-driven model, producing high-fidelity, emotionally expressive speech that feels naturally human. It can be used standalone for creative production, or as a voice design layer that improves integration and usability for downstream TTS systems.
- **MOSS‑SoundEffect**: A high-fidelity sound effect generation model built for real-world content creation, offering strong environmental richness, broad category coverage, and reliable duration controllability. Trained on large-scale, high-quality data, it generates consistent audio from text prompts across natural ambience, urban scenes, creatures, human actions, and music-like clips. It is well suited for film and game production, interactive experiences, and data synthesis pipelines.
- **MOSS‑TTS‑Realtime**: A context-aware, multi-turn streaming TTS foundation model designed for real-time voice agents. Unlike conventional TTS, which synthesizes each reply in isolation, it conditions generation on multi-turn dialogue history, including both textual and acoustic signals from prior user speech, so responses stay coherent, consistent, and natural across turns. With low-latency incremental synthesis and strong voice stability, it enables truly conversational, human-like real-time speech experiences.


## Released Models

| Model | Architecture | Size | Model Card | Hugging Face |
|---|---|---:|---|---|
| **MOSS-TTS** | MossTTSDelay | 8B | [moss_tts_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS) |
| | MossTTSLocal | 1.7B | [moss_tts_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer) |
| **MOSS‑TTSD‑V1.0** | MossTTSDelay | 8B | [moss_ttsd_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_ttsd_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTSD-v1.0) |
| **MOSS‑VoiceGenerator** | MossTTSDelay | 1.7B | [moss_voice_generator_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_voice_generator_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-Voice-Generator) |
| **MOSS‑SoundEffect** | MossTTSDelay | 8B | [moss_sound_effect_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_sound_effect_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-SoundEffect) |
| **MOSS‑TTS‑Realtime** | MossTTSRealtime | 1.7B | [moss_tts_realtime_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_realtime_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Realtime) |


# MOSS-SoundEffect

**MOSS-SoundEffect** is the **environmental sound & sound effect generation model** in the **MOSS‑TTS Family**. It generates ambient soundscapes and concrete sound effects directly from text descriptions, and is designed to complement speech content with immersive context in production workflows.


## 1. Overview

### 1.1 TTS Family Positioning

MOSS-SoundEffect is designed as an audio generation backbone for creating high-fidelity environmental and action sounds from text, serving both as an engine for scalable content pipelines and as a strong research baseline for controllable audio generation.

**Design goals**
* **Coverage & richness**: broad sound taxonomy with layered ambience and realistic texture
* **Composability**: easy integration into creative pipelines (games/film/tools) and synthetic data generation setups


### 1.2 Key Capabilities
MOSS‑SoundEffect focuses on **contextual audio completion** beyond speech, enabling creators and systems to enrich scenes with believable acoustic environments and action‑level cues.

**What it can generate**
- **Natural environments**: e.g., “fresh snow crunching under footsteps.”
- **Urban environments**: e.g., “a sports car roaring past on the highway.”
- **Animals & creatures**: e.g., “early morning park with birds chirping in a quiet atmosphere.”
- **Human actions**: e.g., “clear footsteps echoing on concrete at a steady rhythm.”

**Why it matters**
- Completes **scene immersion** for narrative content, film/TV, documentaries, games, and podcasts.
- Supports **voice agents** and interactive systems that need ambient context, not just speech.
- Acts as the **sound‑design layer** of the MOSS‑TTS Family’s end‑to‑end workflow.


### 1.3 Model Architecture
**MOSS-SoundEffect** employs the **MossTTSDelay** architecture (see [moss_tts_delay/README.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_delay/README.md)), reusing the same discrete-token generation backbone for audio synthesis. A text prompt (optionally with simple control tags such as **duration**) is tokenized and fed into the delay-pattern autoregressive model, which predicts **RVQ audio tokens** over time. The generated tokens are then decoded by the audio tokenizer/vocoder into high-fidelity sound effects, giving consistent quality and controllable length across diverse SFX categories.

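To make the delay pattern concrete, here is a minimal, illustrative sketch (not the model's actual implementation): each RVQ codebook is shifted one step later than the previous one, with `audio_pad_code` (1024) filling the vacated positions, so codebook *k* for frame *t* is predicted at step *t + k*.

```python
import torch

def apply_delay_pattern(codes: torch.Tensor, pad_code: int = 1024) -> torch.Tensor:
    """Shift codebook k right by k steps: [n_vq, T] -> [n_vq, T + n_vq - 1]."""
    n_vq, T = codes.shape
    delayed = torch.full((n_vq, T + n_vq - 1), pad_code, dtype=codes.dtype)
    for k in range(n_vq):
        delayed[k, k : k + T] = codes[k]
    return delayed

codes = torch.arange(8).reshape(2, 4)  # 2 codebooks x 4 frames
print(apply_delay_pattern(codes))
# tensor([[   0,    1,    2,    3, 1024],
#         [1024,    4,    5,    6,    7]])
```
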
### 1.4 Released Models

**Recommended decoding hyperparameters**

| Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
|---|---:|---:|---:|---:|
| **MOSS-SoundEffect** | 1.5 | 0.6 | 50 | 1.2 |


## 2. Quick Start

```python
import torch
import torchaudio
from pathlib import Path
from transformers import AutoModel, AutoProcessor

# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-SoundEffect"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path,
    trust_remote_code=True,
)
processor.audio_tokenizer = processor.audio_tokenizer.to(device)

text_1 = "雷声隆隆,雨声淅沥。"  # "Thunder rumbles; rain patters."
text_2 = "清晰脚步声在水泥地面回响,节奏稳定。"  # "Clear footsteps echo on concrete at a steady rhythm."

conversations = [
    [processor.build_user_message(ambient_sound=text_1)],
    [processor.build_user_message(ambient_sound=text_2)],
]

model = AutoModel.from_pretrained(
    pretrained_model_name_or_path,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=dtype,
).to(device)
model.eval()

batch_size = 1

save_dir = Path("inference_root")
save_dir.mkdir(exist_ok=True, parents=True)
sample_idx = 0
with torch.no_grad():
    for start in range(0, len(conversations), batch_size):
        batch_conversations = conversations[start : start + batch_size]
        batch = processor(batch_conversations, mode="generation")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=4096,
        )

        for message in processor.decode(outputs):
            audio = message.audio_codes_list[0]
            out_path = save_dir / f"sample{sample_idx}.wav"
            sample_idx += 1
            torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
```
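The snippet above does not set the recommended decoding hyperparameters from Section 1.4. A hedged sketch of passing them through, assuming `model.generate` forwards the `audio_*` keyword arguments named in that table (check the model card or the remote code for the authoritative signature):

```python
# Assumption: generate() accepts the audio_* sampling knobs from Section 1.4.
outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=4096,
    audio_temperature=1.5,
    audio_top_p=0.6,
    audio_top_k=50,
    audio_repetition_penalty=1.2,
)
```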

### Input Types

**UserMessage**

| Field | Type | Required | Description |
|---|---|---:|---|
| `ambient_sound` | `str` | Yes | Description of the environmental sound / sound effect to generate |
| `tokens` | `int` | No | Expected number of audio tokens; **1 s ≈ 12.5 tokens** (see the example below). |
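Since one second corresponds to roughly 12.5 audio tokens, a target duration converts directly into the optional `tokens` field. A small sketch (the `tokens` keyword comes from the table above; the prompt string is illustrative):

```python
# Request roughly 8 seconds of audio: 8 s * 12.5 tokens/s = 100 tokens.
target_seconds = 8
message = processor.build_user_message(
    ambient_sound="Thunder rumbling over steady rain on a tin roof.",
    tokens=int(target_seconds * 12.5),
)
```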
__init__.py ADDED
File without changes
added_tokens.json ADDED
@@ -0,0 +1,28 @@
{
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_response>": 151666,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_response>": 151665,
  "<|audio_end|>": 151653,
  "<|audio_pad|>": 151654,
  "<|audio_start|>": 151652,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656
}
chat_template.jinja ADDED
@@ -0,0 +1,4 @@
{% for message in messages %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content.get('type') == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}<|im_end|>
{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}
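The template above is a standard ChatML-style wrapper. A minimal rendering sketch with plain `jinja2` (illustrative only; in practice the tokenizer/processor applies this template):

```python
from jinja2 import Template

template = Template(open("chat_template.jinja").read())
messages = [{"role": "user", "content": "Thunder rumbling over steady rain."}]
print(template.render(messages=messages, add_generation_prompt=True))
# <|im_start|>user
# Thunder rumbling over steady rain.<|im_end|>
# <|im_start|>assistant
```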
config.json ADDED
@@ -0,0 +1,89 @@
{
  "model_type": "moss_tts_delay",
  "architectures": [
    "MossTTSDelayModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_moss_tts.MossTTSDelayConfig",
    "AutoModel": "modeling_moss_tts.MossTTSDelayModel"
  },
  "dtype": "bfloat16",
  "initializer_range": 0.02,
  "language_config": {
    "_name_or_path": "Qwen/Qwen3-8B",
    "architectures": [
      "Qwen3ForCausalLM"
    ],
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "eos_token_id": 151645,
    "pad_token_id": 151643,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 12288,
    "layer_types": [
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention",
      "full_attention"
    ],
    "max_position_embeddings": 40960,
    "max_window_layers": 36,
    "model_type": "qwen3",
    "num_attention_heads": 32,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000,
    "sliding_window": null,
    "use_cache": true,
    "use_sliding_window": false,
    "vocab_size": 155648
  },
  "n_vq": 16,
  "audio_vocab_size": 1024,
  "audio_user_slot_token_id": 151654,
  "audio_assistant_gen_slot_token_id": 151656,
  "audio_assistant_delay_slot_token_id": 151662,
  "audio_start_token_id": 151652,
  "audio_end_token_id": 151653,
  "audio_pad_code": 1024,
  "sampling_rate": 24000,
  "transformers_version": "4.57.1"
}
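For a quick sanity check, the shipped config can be loaded through the `auto_map` above; a sketch (assumes `trust_remote_code` can fetch `configuration_moss_tts.py` from the repo):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("OpenMOSS-Team/MOSS-SoundEffect", trust_remote_code=True)
print(cfg.model_type, cfg.n_vq, cfg.audio_vocab_size, cfg.sampling_rate)
# moss_tts_delay 16 1024 24000
```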
configuration_moss_tts.py ADDED
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MossTTSDelay model configuration """

from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen3 import Qwen3Config
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MossTTSDelayConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate a
    MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of
    [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        language_config (`Union[Qwen3Config, dict]`, *optional*):
            Configuration for the backbone language model (Qwen3).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        n_vq (`int`, *optional*, defaults to 32):
            Number of additional VQ (Vector Quantization) heads/channels for audio.
            Determines the number of codebooks used in the audio representation.
        audio_vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size for the audio tokens (codebooks 1 to N).
        audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
            The token ID used as a placeholder/slot for user-side audio inputs in the prompt.
        audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
            The token ID representing the generation slot for the assistant's audio output,
            acting as the trigger for the TTS generation process.
        audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
            The token ID used in the 'Delay Pattern' paradigm to represent the delayed/offset positions
            between different VQ channels.
        audio_start_token_id (`int`, *optional*, defaults to 151652):
            Special token ID used to denote the start of an audio sequence in the stream.
        audio_end_token_id (`int`, *optional*, defaults to 151653):
            Special token ID used to denote the end of an audio sequence (EOS for audio).
        audio_pad_code (`int`, *optional*, defaults to 1024):
            The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
    """
    model_type = "moss_tts_delay"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        language_config: Optional[Union[Qwen3Config, dict]] = None,
        initializer_range: float = 0.02,
        n_vq: int = 32,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_vocab_size: int = 1024,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_gen_slot_token_id: int = 151656,
        audio_assistant_delay_slot_token_id: int = 151662,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_pad_code: int = 1024,
        sampling_rate: int = 24000,
        **kwargs,
    ):
        if isinstance(language_config, dict):
            self.language_config = Qwen3Config(**language_config)
        elif language_config is None:
            self.language_config = Qwen3Config()
        else:
            self.language_config = language_config

        self.initializer_range = initializer_range
        self.n_vq = n_vq
        self.audio_vocab_size = audio_vocab_size
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
        self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_pad_code = audio_pad_code
        self.sampling_rate = sampling_rate

        # Mirror convenience attributes from the backbone config.
        self.hidden_size = self.language_config.hidden_size
        self.vocab_size = self.language_config.vocab_size
        self.pad_token_id = pad_token_id
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id

        super().__init__(**kwargs)

    def to_dict(self):
        output = super().to_dict()
        if hasattr(self.language_config, "to_dict"):
            output["language_config"] = self.language_config.to_dict()
        else:
            output["language_config"] = self.language_config
        return output
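A minimal instantiation of the class above, to show how `language_config` defaults and the mirrored convenience attributes behave (run from a checkout containing this file; `transformers` with Qwen3 support required):

```python
from configuration_moss_tts import MossTTSDelayConfig

cfg = MossTTSDelayConfig(n_vq=16)  # language_config defaults to Qwen3Config()
print(cfg.hidden_size == cfg.language_config.hidden_size)  # True
print(cfg.to_dict()["language_config"]["model_type"])      # qwen3
```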
inference_utils.py ADDED
@@ -0,0 +1,154 @@
import torch
import torch.nn.functional as F
from typing import Optional


def apply_top_k(logits, top_k):
    batch_size, vocab_size = logits.shape
    top_k = min(top_k, vocab_size)
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    filtered_logits = torch.full_like(logits, float("-inf"))
    batch_indices = torch.arange(batch_size).unsqueeze(-1)
    filtered_logits[batch_indices, top_k_indices] = top_k_values
    return filtered_logits


def apply_top_p(logits, top_p):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    batch_size = logits.shape[0]
    filtered_logits = logits.clone()
    for i in range(batch_size):
        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
        filtered_logits[i, indices_to_remove] = float("-inf")
    return filtered_logits


def apply_top_p_optimized(logits, top_p):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    logits[indices_to_remove] = float("-inf")
    return logits


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    Apply the repetition penalty independently for each H (VQ head).
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Case 1: regular [N, V] (text layer)
    if logits.dim() == 2:
        prev_tokens_flat = prev_tokens.reshape(-1)
        unique_tokens = torch.unique(prev_tokens_flat)

        token_logits = logits[:, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, unique_tokens] = token_logits
        return logits

    # Case 2: Delay Pattern audio [B, H, V]
    assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
    B, H, V = logits.shape

    for h in range(H):
        # prev_tokens_h: [B, T] or [B]
        prev_tokens_h = prev_tokens[..., h].reshape(-1)
        unique_tokens = torch.unique(prev_tokens_h)

        if unique_tokens.numel() == 0:
            continue

        token_logits = logits[:, h, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, h, unique_tokens] = token_logits

    return logits


def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    vocab_size = logits.size(-1)

    # ===== Repetition Penalty (before reshaping!) =====
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # ===== Only flatten after this, for top-k / top-p / multinomial =====
    original_shape = logits.shape
    reshaped_logits = logits.view(-1, vocab_size)

    if top_k is not None and top_k > 0:
        reshaped_logits = apply_top_k(reshaped_logits, top_k)

    if top_p is not None and top_p < 1.0:
        reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)

    probs = F.softmax(reshaped_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.view(original_shape[:-1])


def find_last_equal_C(tensor, C):
    """
    tensor: torch.Tensor of shape [batch_size, seq_len]
    C: scalar value to match
    Returns: torch.Tensor of shape [batch_size] with last indices
    """
    mask = (tensor == C).int()  # [batch_size, seq_len], 1 where tensor == C
    flipped_mask = mask.flip(dims=[1])  # Flip along sequence dimension
    flipped_indices = flipped_mask.argmax(dim=1)  # First match in flipped order
    seq_len = tensor.shape[1]
    last_indices = (seq_len - 1) - flipped_indices  # Convert to original indices

    # Handle rows with no match: set their index to -1
    actual_values = tensor[torch.arange(tensor.shape[0]), last_indices]
    no_match = actual_values != C
    last_indices[no_match] = -1

    return last_indices
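A short usage sketch for `sample_token` on delay-pattern logits (dummy tensors; shapes follow the docstrings above):

```python
import torch
from inference_utils import sample_token

logits = torch.randn(2, 16, 1025)           # [B, H, V]: 16 VQ heads, 1024 codes + pad
prev = torch.randint(0, 1024, (2, 10, 16))  # [B, T, H]: codes from earlier steps
next_codes = sample_token(
    logits, prev_tokens=prev,
    repetition_penalty=1.2, top_p=0.6, top_k=50,
)
print(next_codes.shape)  # torch.Size([2, 16]): one new code per VQ head
```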
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b7d7ef1f70796b5b7bbf1277c8d61ae0ec4879ac9f5a92ca06af3fcfb21cfa5
size 4932667368
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35f93f16e4a5cfa7b10a44fb233d48f6206dfd65f5f4ec16d4ef622e46e1051c
size 4915961640
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43305ea6285bb1eb55d3bb004c5dab05d9ebf3dd259b79f5e328f493f3f90a1f
size 4983069760
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8748f19c135e185febfc4964120b2246021f4bbb7104e11493248984f07fa79b
size 1879339648
model.safetensors.index.json ADDED
@@ -0,0 +1,471 @@
{
  "metadata": {
    "total_parameters": 8489841664,
    "total_size": 16979683328
  },
  "weight_map": {
    "emb_ext.0.weight": "model-00004-of-00004.safetensors",
    "emb_ext.1.weight": "model-00004-of-00004.safetensors",
    "emb_ext.10.weight": "model-00004-of-00004.safetensors",
    "emb_ext.11.weight": "model-00004-of-00004.safetensors",
    "emb_ext.12.weight": "model-00004-of-00004.safetensors",
    "emb_ext.13.weight": "model-00004-of-00004.safetensors",
    "emb_ext.14.weight": "model-00004-of-00004.safetensors",
    "emb_ext.15.weight": "model-00004-of-00004.safetensors",
    "emb_ext.16.weight": "model-00004-of-00004.safetensors",
    "emb_ext.17.weight": "model-00004-of-00004.safetensors",
    "emb_ext.18.weight": "model-00004-of-00004.safetensors",
    "emb_ext.19.weight": "model-00004-of-00004.safetensors",
    "emb_ext.2.weight": "model-00004-of-00004.safetensors",
    "emb_ext.20.weight": "model-00004-of-00004.safetensors",
    "emb_ext.21.weight": "model-00004-of-00004.safetensors",
    "emb_ext.22.weight": "model-00004-of-00004.safetensors",
    "emb_ext.23.weight": "model-00004-of-00004.safetensors",
    "emb_ext.24.weight": "model-00004-of-00004.safetensors",
    "emb_ext.25.weight": "model-00004-of-00004.safetensors",
    "emb_ext.26.weight": "model-00004-of-00004.safetensors",
    "emb_ext.27.weight": "model-00004-of-00004.safetensors",
    "emb_ext.28.weight": "model-00004-of-00004.safetensors",
    "emb_ext.29.weight": "model-00004-of-00004.safetensors",
    "emb_ext.3.weight": "model-00004-of-00004.safetensors",
    "emb_ext.30.weight": "model-00004-of-00004.safetensors",
    "emb_ext.31.weight": "model-00004-of-00004.safetensors",
    "emb_ext.4.weight": "model-00004-of-00004.safetensors",
    "emb_ext.5.weight": "model-00004-of-00004.safetensors",
    "emb_ext.6.weight": "model-00004-of-00004.safetensors",
    "emb_ext.7.weight": "model-00004-of-00004.safetensors",
    "emb_ext.8.weight": "model-00004-of-00004.safetensors",
    "emb_ext.9.weight": "model-00004-of-00004.safetensors",
    "language_model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.22.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "language_model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "language_model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "language_model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
342
+ "language_model.layers.33.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
343
+ "language_model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
344
+ "language_model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
345
+ "language_model.layers.33.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
346
+ "language_model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
347
+ "language_model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
348
+ "language_model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
349
+ "language_model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
350
+ "language_model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
351
+ "language_model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
352
+ "language_model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
353
+ "language_model.layers.34.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
354
+ "language_model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
355
+ "language_model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
356
+ "language_model.layers.34.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
357
+ "language_model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
358
+ "language_model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
359
+ "language_model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
360
+ "language_model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
361
+ "language_model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
362
+ "language_model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
363
+ "language_model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
364
+ "language_model.layers.35.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
365
+ "language_model.layers.35.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
366
+ "language_model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
367
+ "language_model.layers.35.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
368
+ "language_model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
369
+ "language_model.layers.35.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
370
+ "language_model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
371
+ "language_model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
372
+ "language_model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
373
+ "language_model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
374
+ "language_model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
375
+ "language_model.layers.4.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
376
+ "language_model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
377
+ "language_model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
378
+ "language_model.layers.4.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
379
+ "language_model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
380
+ "language_model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
381
+ "language_model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
382
+ "language_model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
383
+ "language_model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
384
+ "language_model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
385
+ "language_model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
386
+ "language_model.layers.5.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
387
+ "language_model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
388
+ "language_model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
389
+ "language_model.layers.5.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
390
+ "language_model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
391
+ "language_model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
392
+ "language_model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
393
+ "language_model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
394
+ "language_model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
395
+ "language_model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
396
+ "language_model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
397
+ "language_model.layers.6.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
398
+ "language_model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
399
+ "language_model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
400
+ "language_model.layers.6.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
401
+ "language_model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
402
+ "language_model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
403
+ "language_model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
404
+ "language_model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
405
+ "language_model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
406
+ "language_model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
407
+ "language_model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
408
+ "language_model.layers.7.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
409
+ "language_model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
410
+ "language_model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
411
+ "language_model.layers.7.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
412
+ "language_model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
413
+ "language_model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
414
+ "language_model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
415
+ "language_model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
416
+ "language_model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
417
+ "language_model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
418
+ "language_model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
419
+ "language_model.layers.8.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
420
+ "language_model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
421
+ "language_model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
422
+ "language_model.layers.8.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
423
+ "language_model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
424
+ "language_model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
425
+ "language_model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
426
+ "language_model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
427
+ "language_model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
428
+ "language_model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
429
+ "language_model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
430
+ "language_model.layers.9.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
431
+ "language_model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
432
+ "language_model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
433
+ "language_model.layers.9.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
434
+ "language_model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
435
+ "language_model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
436
+ "language_model.norm.weight": "model-00004-of-00004.safetensors",
437
+ "lm_heads.0.weight": "model-00004-of-00004.safetensors",
438
+ "lm_heads.1.weight": "model-00004-of-00004.safetensors",
439
+ "lm_heads.10.weight": "model-00004-of-00004.safetensors",
440
+ "lm_heads.11.weight": "model-00004-of-00004.safetensors",
441
+ "lm_heads.12.weight": "model-00004-of-00004.safetensors",
442
+ "lm_heads.13.weight": "model-00004-of-00004.safetensors",
443
+ "lm_heads.14.weight": "model-00004-of-00004.safetensors",
444
+ "lm_heads.15.weight": "model-00004-of-00004.safetensors",
445
+ "lm_heads.16.weight": "model-00004-of-00004.safetensors",
446
+ "lm_heads.17.weight": "model-00004-of-00004.safetensors",
447
+ "lm_heads.18.weight": "model-00004-of-00004.safetensors",
448
+ "lm_heads.19.weight": "model-00004-of-00004.safetensors",
449
+ "lm_heads.2.weight": "model-00004-of-00004.safetensors",
450
+ "lm_heads.20.weight": "model-00004-of-00004.safetensors",
451
+ "lm_heads.21.weight": "model-00004-of-00004.safetensors",
452
+ "lm_heads.22.weight": "model-00004-of-00004.safetensors",
453
+ "lm_heads.23.weight": "model-00004-of-00004.safetensors",
454
+ "lm_heads.24.weight": "model-00004-of-00004.safetensors",
455
+ "lm_heads.25.weight": "model-00004-of-00004.safetensors",
456
+ "lm_heads.26.weight": "model-00004-of-00004.safetensors",
457
+ "lm_heads.27.weight": "model-00004-of-00004.safetensors",
458
+ "lm_heads.28.weight": "model-00004-of-00004.safetensors",
459
+ "lm_heads.29.weight": "model-00004-of-00004.safetensors",
460
+ "lm_heads.3.weight": "model-00004-of-00004.safetensors",
461
+ "lm_heads.30.weight": "model-00004-of-00004.safetensors",
462
+ "lm_heads.31.weight": "model-00004-of-00004.safetensors",
463
+ "lm_heads.32.weight": "model-00004-of-00004.safetensors",
464
+ "lm_heads.4.weight": "model-00004-of-00004.safetensors",
465
+ "lm_heads.5.weight": "model-00004-of-00004.safetensors",
466
+ "lm_heads.6.weight": "model-00004-of-00004.safetensors",
467
+ "lm_heads.7.weight": "model-00004-of-00004.safetensors",
468
+ "lm_heads.8.weight": "model-00004-of-00004.safetensors",
469
+ "lm_heads.9.weight": "model-00004-of-00004.safetensors"
470
+ }
471
+ }
modeling_moss_tts.py ADDED
@@ -0,0 +1,515 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Modeling classes for MossTTSDelay."""
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+ from tqdm import tqdm
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.cache_utils import Cache
+ from transformers.models.qwen3 import Qwen3Model
+ from transformers import initialization as init
+
+ from .configuration_moss_tts import MossTTSDelayConfig
+ from .inference_utils import sample_token, find_last_equal_C
+
+ try:
+     from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
+ except Exception:
+     UserMessage = None
+     AssistantMessage = None
+     MossTTSDelayProcessor = None
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "MossTTSDelayConfig"
+
+
+ @dataclass
+ class MossTTSDelayOutputWithPast(ModelOutput):
+     """
+     Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Weighted sum of channel losses.
+         all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
+             Sum of losses for each sample and each channel before averaging.
+         all_token_nums (`torch.LongTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
+             Number of non-masked tokens per sample and channel.
+         sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
+             Loss per sample.
+         channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
+             Loss per channel (text head + vq heads).
+         logits (`List[torch.FloatTensor]`, *optional*):
+             List of prediction scores from each head.
+         past_key_values (`Cache`, *optional*):
+             Pre-computed hidden-states (key and values in the self-attention blocks).
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
+             Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer).
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
+             Tuple of torch.FloatTensor (one for each layer) of the attention weights.
+     """
+     loss: Optional[torch.FloatTensor] = None
+     all_sum_losses: Optional[torch.FloatTensor] = None
+     all_token_nums: Optional[torch.LongTensor] = None
+     sample_losses: Optional[torch.FloatTensor] = None
+     channel_losses: Optional[torch.FloatTensor] = None
+     logits: Optional[List[torch.FloatTensor]] = None
+     past_key_values: Optional[Cache] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+
+
+ class MossTTSDelayPreTrainedModel(PreTrainedModel):
+     config_class = MossTTSDelayConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Qwen3DecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn = True
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+
+     def _init_weights(self, module):
+         """
+         Transformers 5.0+ safe init:
+           - MUST use transformers.initialization helpers
+           - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
+         """
+         # Let HF handle its standard modules first (LayerNorm, Linear, Embedding, etc.)
+         super()._init_weights(module)
+
+         # Pick a std consistent with HF conventions.
+         # Prefer model/text config initializer_range if present.
+         std = None
+         if hasattr(self.config, "initializer_range"):
+             std = self.config.initializer_range
+         elif hasattr(self.config, "language_config") and hasattr(self.config.language_config, "initializer_range"):
+             std = self.config.language_config.initializer_range
+         else:
+             std = 0.02
+
+         # Initialize extra audio embeddings
+         if isinstance(module, nn.Embedding):
+             # Only touch our extra embeddings (avoid double touching LM's embeddings if not desired)
+             # If you prefer, you can skip this check and rely on super()._init_weights for all embeddings.
+             if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
+                 init.normal_(module.weight, mean=0.0, std=std)
+                 # If you later set padding_idx, you must explicitly zero it (and respect _is_hf_initialized!)
+                 # init.zeros_ will internally check param flags, but slicing needs manual care.
+
+         # Initialize multi-head projections you added
+         if isinstance(module, nn.Linear):
+             # For your lm_heads, super()._init_weights already covers typical Linear.
+             # This block is only needed if you have custom Linear variants later.
+             pass
+
+
+
+ MOSSTTS_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+     heads etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+     usage and behavior.
+
+     Parameters:
+         config ([`MossTTSDelayConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ @add_start_docstrings(
+     "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
+     MOSSTTS_START_DOCSTRING,
+ )
+ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
+     UserMessage = UserMessage
+     AssistantMessage = AssistantMessage
+     Processor = MossTTSDelayProcessor
+
+     def __init__(self, config: MossTTSDelayConfig):
+         super().__init__(config)
+         self.config = config
+
+         config.language_config.torch_dtype = config.torch_dtype
+
+         self.language_model = Qwen3Model(config.language_config)
+
+         # Audio VQ Embeddings (Extra channels)
+         # Note: input_ids[..., 0] uses Qwen's embedding.
+         # input_ids[..., 1:] use these extensions.
+         self.emb_ext = nn.ModuleList()
+         for vq_idx in range(self.config.n_vq):
+             # Add +1 for potential padding/special tokens logic if strictly required by upstream data prep
+             self.emb_ext.append(
+                 nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
+             )
+
+         # Multi-Head Prediction Layers
+         # Head 0: Main language head
+         # Head 1..N: Audio VQ heads
+         self.lm_heads = nn.ModuleList([
+             nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
+         ])
+         for vq_idx in range(self.config.n_vq):
+             self.lm_heads.append(
+                 nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
+             )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """
+         Computes the combined embeddings from text and multiple audio VQ channels.
+
+         Args:
+             input_ids: Shape (Batch, Seq_Len, 1 + n_vq)
+         """
+         # Base Text/Content Embedding
+         # input_ids[..., 0] is standard text or semantic tokens
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])
+
+         # Add VQ Embeddings
+         for i, embed_layer in enumerate(self.emb_ext):
+             # i corresponds to channel i+1 in input_ids
+             # We assume the data pipeline ensures indices are within range
+             inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])
+
+         return inputs_embeds
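+
+     # Illustrative sketch (not from the source): with n_vq = 2 codebooks, one
+     # step input_ids[b, t] = [text_id, vq0_id, vq1_id] is embedded as
+     #     E_text(text_id) + E_vq0(vq0_id) + E_vq1(vq1_id),
+     # so input_ids of shape (1, 4, 3) yields inputs_embeds of shape
+     # (1, 4, hidden_size): the per-channel embeddings are summed per position.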
+
+     def set_input_embeddings(self, value):
+         self.language_model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         # Returning a list of heads might break some HF utilities expecting a single head.
+         # However, for custom models, this is acceptable.
+         return self.lm_heads
+
+     @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
+     @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         hidden_out_layers: Optional[List[int]] = None,
+         channelwise_loss_weight: Optional[List[float]] = None,
+         **kwargs,
+     ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
+                 Indices of input sequence tokens in the vocabulary.
+                 Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
+                 Labels for computing the masked language modeling loss.
+             channelwise_loss_weight (`List[float]`, *optional*):
+                 Manual weights for summing losses across different heads (Text vs Audio channels).
+
+         Returns:
+         """
+
+         if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
+             raise ValueError("`input_ids` must have shape (batch_size, sequence_length, 1 + n_vq).")
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+         # 1. Prepare Embeddings
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings(input_ids)
+
+         # 2. Backbone Forward
+         # Qwen3Model outputs standard CausalLMOutputWithPast or similar
+         outputs = self.language_model(
+             input_ids=None,  # Passed via inputs_embeds
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,  # Always need hidden states for multi-head projection
+             return_dict=True,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         # 3. Handle specific layer outputs if requested (Delay Pattern often requires features from specific layers)
+         last_hidden_state = outputs.last_hidden_state
+         if hidden_out_layers is None:
+             # Default to using the last layer for all heads
+             # In some architectures (like MusicGen), different codebooks come from different transformer layers.
+             # Here we default to the final layer as per original code behavior [-1] * (n + 1).
+             hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
+         else:
+             # If hidden_out_layers is provided (e.g. [-1, -2, -3...]), fetch them from all_hidden_states
+             # Note: outputs.hidden_states includes embedding output at index 0 usually.
+             all_hs = outputs.hidden_states
+             hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]
+
+         # 4. Project to Logits (Multi-Head)
+         layer_logits = []
+         for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
+             logits = head(hs)
+             # Original code logic: Mask the last token index for audio heads (indices > 0)
+             # This implies the vocab size is (N+1) but the model shouldn't predict the (N+1)-th token
+             # (perhaps reserved for padding in the input but invalid for prediction).
+             if i > 0:
+                 logits[..., -1] = float("-inf")
+             layer_logits.append(logits)
+
+         # 5. Loss Calculation
+         loss = None
+         all_sum_losses = None
+         all_token_nums = None
+         sample_losses = None
+         channel_losses = None
+
+         if labels is not None:
+             # Ensure labels match input shape rank (B, S, C)
+             if labels.dim() != 3:
+                 raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")
+
+             batch_size = labels.size(0)
+             n_heads = len(layer_logits)
+
+             # Container for per-sample, per-channel losses
+             # Shape: [Batch, n_heads]
+             all_sum_losses_list = []
+
+             # Count valid tokens (not -100) per sample.
+             # Note: Assuming mask is consistent across channels or we take sum over dim 1 (seq)
+             # Usually strict masking means checking one channel or all.
+             # Original code: torch.sum(labels != -100, dim=1) -> [B, C]
+             all_token_nums = torch.sum(labels != -100, dim=1)
+
+             for i, logits in enumerate(layer_logits):
+                 # logits: [B, S, V]
+                 # cur_labels: [B, S]
+                 cur_labels = labels[..., i]
+
+                 # Flatten for CrossEntropy
+                 # logits: [B*S, V], labels: [B*S]
+                 loss_fct = CrossEntropyLoss(reduction='none')
+                 vocab_size = logits.size(-1)
+
+                 reshaped_logits = logits.view(-1, vocab_size)
+                 reshaped_labels = cur_labels.contiguous().view(-1)
+
+                 # Calculate loss per token
+                 per_token_loss = loss_fct(reshaped_logits, reshaped_labels)
+
+                 # Reshape back to [B, S] and sum over Sequence dimension to get per-sample loss
+                 per_token_loss = per_token_loss.view(batch_size, -1)
+                 per_sample_loss = torch.sum(per_token_loss, dim=-1)  # [B]
+
+                 all_sum_losses_list.append(per_sample_loss)
+
+             # Stack to [B, n_heads]
+             all_sum_losses = torch.stack(all_sum_losses_list, dim=1)
+
+             # Weighted Loss Aggregation
+             if channelwise_loss_weight is not None:
+                 if len(channelwise_loss_weight) != n_heads:
+                     raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")
+
+                 w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)
+
+                 # Sample losses: Weighted sum over channels per sample / Total weight
+                 # Normalize by token count per channel
+                 # Avoid division by zero with epsilon or mask
+                 token_counts_safe = all_token_nums.float().clamp(min=1.0)
+
+                 normalized_losses = all_sum_losses / token_counts_safe
+                 sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()
+
+                 # Channel losses: Sum over batch / Sum tokens over batch
+                 total_loss_per_channel = all_sum_losses.sum(dim=0)
+                 total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
+                 channel_losses = total_loss_per_channel / total_tokens_per_channel
+
+                 # Final scalar loss
+                 loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
+             else:
+                 # Default average if no weights provided
+                 total_tokens = all_token_nums.sum().float().clamp(min=1.0)
+                 loss = all_sum_losses.sum() / total_tokens
+                 channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)
+
+         return MossTTSDelayOutputWithPast(
+             loss=loss,
+             all_sum_losses=all_sum_losses,
+             all_token_nums=all_token_nums,
+             sample_losses=sample_losses,
+             channel_losses=channel_losses,
+             logits=layer_logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
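+
+     # Worked example (illustrative numbers, not from the source): with the text
+     # head plus n_vq = 2 audio heads and channelwise_loss_weight = [1.0, 0.5, 0.5],
+     # per-channel token-normalized losses [2.0, 4.0, 4.0] combine to
+     #     loss = (1.0 * 2.0 + 0.5 * 4.0 + 0.5 * 4.0) / (1.0 + 0.5 + 0.5) = 3.0,
+     # i.e. a weighted average of the per-channel mean losses.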
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         max_new_tokens: int = 1000,
+         text_temperature: float = 1.5,
+         text_top_p: float = 0.6,
+         text_top_k: int = 50,
+         audio_temperature: float = 1.5,
+         audio_top_p: float = 0.6,
+         audio_top_k: int = 50,
+         audio_repetition_penalty: float = 1.2
+     ):
+         # A non-positive temperature selects greedy decoding for that stream.
+         if text_temperature > 0:
+             text_do_sample = True
+         else:
+             text_temperature = 1
+             text_do_sample = False
+         if audio_temperature > 0:
+             audio_do_sample = True
+         else:
+             audio_temperature = 1
+             audio_do_sample = False
+
+         past_key_values = None
+         device = input_ids.device
+         current_input_ids = input_ids
+         current_attention_mask = attention_mask
+         batch_size, seq_len, n_vq = input_ids.shape
+         n_vq -= 1
+
+         generation_ids = input_ids[:]
+         is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+         audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
+         torch_int64_max = torch.iinfo(torch.int64).max
+         delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
+
+         is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
+         audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
+         audio_start_mask = is_continuation & (audio_start_indices != -1)
+         audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]
+
+         is_audio = audio_start_mask.clone()
+
+         pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
+         pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
+         pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
+
+         for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
+             outputs = self(
+                 input_ids=current_input_ids,
+                 attention_mask=current_attention_mask,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+             past_key_values = outputs.past_key_values
+
+             next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)]  # list, len = n_vq + 1, each [batch_size, vocab_size]
+             next_token_logits[0] = next_token_logits[0].clone()
+             next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
+             next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
+             is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
+             next_text_token[is_audio_eos] = self.config.audio_end_token_id
+             is_audio[is_audio_eos] = False
+             sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
+             next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
+             next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
+             if time_step == 0:
+                 next_token_logits[0][..., 151662] = float('-inf')  # hard-coded special-token id masked at the first step
+             if time_step <= n_vq:
+                 next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
+
+             next_text_token[sampling_text_mask] = sample_token(
+                 logits=next_token_logits[0][sampling_text_mask],
+                 top_p=text_top_p,
+                 top_k=text_top_k,
+                 do_sample=text_do_sample
+             )
+             is_audio[next_text_token == self.config.audio_start_token_id] = True
+             is_stopping[next_text_token == self.config.im_end_token_id] = True
+
+             next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
+
+             pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
+             post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
+             post_audio_mask[delayed_lengths == torch_int64_max] = True
+             sampling_audio_mask = pre_audio_mask & post_audio_mask
+             next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
+
+             if sampling_audio_mask.sum() > 0:
+                 audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask]  # torch.stack -> [batch_size, n_vq, vocab_size]
+                 audio_logits[..., self.config.audio_pad_code] = float('-inf')
+                 next_audio_tokens[sampling_audio_mask] = sample_token(
+                     logits=audio_logits,
+                     prev_tokens=generation_ids[:, :, 1:],
+                     repetition_penalty=audio_repetition_penalty,
+                     top_p=audio_top_p,
+                     top_k=audio_top_k,
+                     do_sample=audio_do_sample
+                 )
+
+             audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
+             audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
+             delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
+             delayed_lengths[delayed_lengths != torch_int64_max] += 1
+             delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
+
+             current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)  # [batch_size, 1, n_vq + 1]
+             current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
+             generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)  # [batch_size, seq_len, n_vq + 1]
+
+             if is_stopping.sum() == batch_size:
+                 break
+
+         start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
+         start_lengths = seq_len - start_indices
+
+         output = []
+         for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
+             output.append((start_length, cur_generation_ids[start_idx:]))
+
+         return output
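+
+
+ # Minimal usage sketch (assumptions: a local checkpoint directory, which is a
+ # hypothetical path here, and inputs already packed to shape
+ # [batch, seq_len, 1 + n_vq] by MossTTSDelayProcessor):
+ #
+ #     model = MossTTSDelayModel.from_pretrained("path/to/moss-tts").eval()
+ #     batch = processor(conversations, mode="generation")
+ #     results = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"])
+ #     # each result item is (prompt_length, generated [seq, 1 + n_vq] token grid)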
processing_moss_tts.py ADDED
@@ -0,0 +1,930 @@
+ # coding=utf-8
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
+ from dataclasses import dataclass
+ from pathlib import Path
+ import re
+ import torchaudio
+
+ from transformers import processing_utils
+
+ processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
+
+ import torch
+ from transformers import (
+     PreTrainedTokenizerBase,
+     BatchFeature,
+     ProcessorMixin,
+     logging,
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+ )
+
+ from .configuration_moss_tts import MossTTSDelayConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ AUDIO_PLACEHOLDER = "<|audio|>"
+
+
+ @dataclass
+ class Message:
+     def to_dict(self) -> Dict[str, Any]:
+         raise NotImplementedError
+
+
+ @dataclass
+ class UserMessage(Message):
+     text: Optional[str] = None
+     reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
+     instruction: Optional[str] = None
+     tokens: Optional[int] = None
+     quality: Optional[str] = None
+     sound_event: Optional[str] = None
+     ambient_sound: Optional[str] = None
+     language: Optional[str] = None
+
+     def __post_init__(self):
+         template = """<user_inst>
+ - Reference(s):
+ {reference}
+ - Instruction:
+ {instruction}
+ - Tokens:
+ {tokens}
+ - Quality:
+ {quality}
+ - Sound Event:
+ {sound_event}
+ - Ambient Sound:
+ {ambient_sound}
+ - Language:
+ {language}
+ - Text:
+ {text}
+ </user_inst>"""
+
+         audio_codes_list = []
+         if self.reference is None:
+             reference = "None"
+         elif isinstance(self.reference, List):
+             reference = []
+             for speaker_idx, speaker_reference in enumerate(self.reference):
+                 if speaker_reference is not None:
+                     reference.append(f"[S{speaker_idx+1}]:\n{AUDIO_PLACEHOLDER}")
+             reference = "\n".join(reference)
+             audio_codes_list = [
+                 speaker_reference
+                 for speaker_reference in self.reference
+                 if speaker_reference is not None
+             ]
+         else:
+             raise TypeError("`reference` must be a list when it is not None.")
+
+         content = (
+             template.replace("{reference}", str(reference))
+             .replace("{instruction}", str(self.instruction))
+             .replace("{tokens}", str(self.tokens))
+             .replace("{quality}", str(self.quality))
+             .replace("{sound_event}", str(self.sound_event))
+             .replace("{ambient_sound}", str(self.ambient_sound))
+             .replace("{language}", str(self.language))
+             .replace("{text}", str(self.text))
+         )
+
+         self._content = content
+         self._audio_codes_list = audio_codes_list
+
+     def to_dict(self):
+         return {
+             "role": "user",
+             "content": self._content,
+             "audio_codes_list": self._audio_codes_list,
+         }
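+
+     # Illustrative rendering (not from the source): UserMessage(text="Hello!",
+     # language="English").to_dict()["content"] fills the template above with
+     # unset fields stringified as "None", e.g. the tail reads
+     #     - Language:
+     #     English
+     #     - Text:
+     #     Hello!
+     #     </user_inst>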
+
+
+ @dataclass
+ class AssistantMessage(Message):
+     audio_codes_list: List[Union[str, torch.Tensor]]
+     content: str = AUDIO_PLACEHOLDER
+
+     def to_dict(self):
+         return {
+             "role": "assistant",
+             "content": self.content,
+             "audio_codes_list": self.audio_codes_list,
+         }
+
+
+ USER_MESSAGE_FIELDS = (
+     "text",
+     "reference",
+     "instruction",
+     "tokens",
+     "quality",
+     "sound_event",
+     "ambient_sound",
+     "language",
+ )
+
+
+ class MossTTSDelayProcessor(ProcessorMixin):
+     tokenizer_class = "AutoTokenizer"
+     audio_tokenizer_class = "AutoModel"
+
+     tokenizer: PreTrainedTokenizerBase
+     audio_tokenizer: Any
+
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizerBase,
+         audio_tokenizer: Any = None,
+         model_config: Optional[MossTTSDelayConfig] = None,
+         **kwargs,
+     ):
+         super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
+
+         # Explicit assignments for type-checkers; ProcessorMixin sets these too.
+         self.tokenizer = tokenizer
+         self.audio_tokenizer = audio_tokenizer
+         if model_config is None:
+             model_config = MossTTSDelayConfig()
+         self.model_config = model_config
+
+         self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+         self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+         self.newline_token_id = 198
+
+         def _id_to_token(token_id: int) -> str:
+             tok = tokenizer.convert_ids_to_tokens(int(token_id))
+             if isinstance(tok, list):
+                 return tok[0] if len(tok) > 0 else ""
+             return cast(str, tok)
+
+         self.audio_user_slot_token = _id_to_token(
+             self.model_config.audio_user_slot_token_id
+         )
+         self.audio_assistant_gen_slot_token = _id_to_token(
+             self.model_config.audio_assistant_gen_slot_token_id
+         )
+         self.audio_assistant_delay_slot_token = _id_to_token(
+             self.model_config.audio_assistant_delay_slot_token_id
+         )
+         self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
+         self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+         trust_remote_code = kwargs.pop("trust_remote_code", True)
+         kwargs.pop("_from_auto", None)
+
+         audio_tokenizer_name_or_path = kwargs.pop(
+             "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
+         )
+
+         pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+         model_config = cast(
+             MossTTSDelayConfig,
+             AutoConfig.from_pretrained(
+                 pretrained_model_name_or_path,
+                 *args,
+                 trust_remote_code=trust_remote_code,
+                 **kwargs,
+             ),
+         )
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path,
+             *args,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+         audio_tokenizer = AutoModel.from_pretrained(
+             audio_tokenizer_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         return cls(
+             tokenizer=tokenizer,
+             audio_tokenizer=audio_tokenizer,
+             model_config=model_config,
+             **kwargs,
+         )
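+
+     # Minimal loading sketch (illustrative; the model path is hypothetical, the
+     # codec_path default is the one defined above):
+     #     processor = MossTTSDelayProcessor.from_pretrained(
+     #         "path/to/moss-tts", codec_path="OpenMOSS-Team/MOSS-Audio-Tokenizer"
+     #     )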
+
+     def __call__(self, *args, **kwargs) -> BatchFeature:
+         conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
+         mode: str = kwargs.pop("mode", "generation")
+         apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
+         n_vq: Optional[int] = kwargs.pop("n_vq", None)
+
+         # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
+         kwargs.pop("return_tensors", None)
+         kwargs.pop("padding", None)
+         kwargs.pop("truncation", None)
+
+         # `mode` only takes effect when a Message is converted to a dict.
+
+         if mode not in {"generation", "continuation"}:
+             raise RuntimeError(f"Unsupported mode {mode!r}; expected 'generation' or 'continuation'.")
+
+         if isinstance(conversations, (Message, Dict)):
+             conversations = [conversations]
+
+         truncation = False
+         if mode == "continuation":
+             truncation = True
+
+         input_ids_list = []
+         for conversation in conversations:
+             if isinstance(conversation, (Message, Dict)):
+                 conversation = [conversation]
+
+             # Normalize early so downstream logic always deals with dict messages.
+             conversation = [self._normalize_message(m) for m in conversation]
+
+             if (mode == "generation") ^ (len(conversation) % 2 != 0):
+                 raise ValueError("Conversation length does not match the requested mode.")
+
+             if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
+                 raise ValueError("The final message role does not match the requested mode.")
+
+             unified_codes = []
+             for message_idx, message in enumerate(conversation):
+                 if apply_chat_template:
+                     add_generation_prompt = (
+                         mode == "generation" and message_idx == len(conversation) - 1
+                     )
+                     try:
+                         content = self.tokenizer.apply_chat_template(
+                             [{"role": message["role"], "content": message["content"]}],
+                             add_generation_prompt=add_generation_prompt,
+                             tokenize=False,
+                         )
+                     except TypeError:
+                         try:
+                             content = self.tokenizer.apply_chat_template(
+                                 [
+                                     {
+                                         "role": message["role"],
+                                         "content": message["content"],
+                                     }
+                                 ],
+                                 add_generation_prompt=add_generation_prompt,
+                             )
+                         except Exception:
+                             logger.warning(
+                                 "apply_chat_template failed; fallback to raw content."
+                             )
+                             content = message["content"]
+                 else:
+                     content = message["content"]
+
+                 if not isinstance(content, str):
+                     content = str(content)
+
+                 # Batch-encode all path-based references in one call when possible.
+                 # This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
+                 # instead of repeatedly calling it with batch=1.
+                 raw_audio_items = message.get("audio_codes_list", [])
+
+                 audio_codes_list: List[torch.Tensor] = []
+                 if len(raw_audio_items) > 0:
+                     encoded_items: List[Optional[torch.Tensor]] = [None] * len(
+                         raw_audio_items
+                     )
+                     paths: List[str] = []
+                     path_positions: List[int] = []
+
+                     for idx, item in enumerate(raw_audio_items):
+                         if isinstance(item, torch.Tensor):
+                             if n_vq is not None and item.shape[1] != n_vq:
+                                 raise RuntimeError(
+                                     "The audio codes' n_vq does not equal the parameter `n_vq`. You can set `n_vq` to None if you have already tokenized the wavs."
+                                 )
+                             encoded_items[idx] = item
+                             continue
+
+                         if isinstance(item, (str, os.PathLike)):
+                             paths.append(str(item))
+                             path_positions.append(idx)
+                             continue
+
+                         raise TypeError(
+                             "Each audio item must be a torch.Tensor of codes or a path-like string."
+                         )
+
+                     if len(paths) > 0:
+                         encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
+                         if len(encoded_from_paths) != len(paths):
+                             raise RuntimeError(
+                                 "encode_audios_from_path returned an unexpected number of items."
+                             )
+                         for pos, codes in zip(path_positions, encoded_from_paths):
+                             encoded_items[pos] = codes
+
+                     audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
+                 unified_codes.append(
+                     self._get_unified_codes(
+                         message["role"], content, audio_codes_list, truncation
+                     )
+                 )
+
+             unified_codes = torch.cat(unified_codes)
+             input_ids_list.append(unified_codes)
+
+         return BatchFeature(data=self._pad(input_ids_list))
+
+     @staticmethod
+     def build_user_message(
+         text: Optional[str] = None,
+         reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
+         instruction: Optional[str] = None,
+         tokens: Optional[int] = None,
+         quality: Optional[str] = None,
+         sound_event: Optional[str] = None,
+         ambient_sound: Optional[str] = None,
+         language: Optional[str] = None,
+     ) -> Dict:
+         if reference is not None and not isinstance(reference, list):
+             reference = [reference]
+         return UserMessage(
+             text=text,
+             reference=reference,
+             instruction=instruction,
+             tokens=tokens,
+             quality=quality,
+             sound_event=sound_event,
+             ambient_sound=ambient_sound,
+             language=language,
+         ).to_dict()
+
+     @staticmethod
+     def build_assistant_message(
+         audio_codes_list: List[Union[str, torch.Tensor]],
+         content: str = AUDIO_PLACEHOLDER,
+     ) -> Dict:
+         return AssistantMessage(
+             audio_codes_list=audio_codes_list,
+             content=content,
+         ).to_dict()
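+
+     # Illustrative message construction (the wav paths are hypothetical):
+     #     user = MossTTSDelayProcessor.build_user_message(
+     #         text="Hello!", reference=["ref_speaker1.wav"]
+     #     )
+     #     assistant = MossTTSDelayProcessor.build_assistant_message(
+     #         audio_codes_list=["target.wav"]
+     #     )
+     #     batch = processor([[user]], mode="generation")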
+
+     def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
+         if isinstance(message, Message):
+             return message.to_dict()
+         if not isinstance(message, dict):
+             raise TypeError("Each message must be a Message or dict.")
+         if "role" not in message:
+             raise ValueError("Message dict must include a 'role' field.")
+         if "content" in message and "audio_codes_list" in message:
+             return message
+         role = message["role"]
+         if role == "user":
+             kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
+             return self.build_user_message(**kwargs)
+         if role == "assistant":
+             return self.build_assistant_message(
+                 audio_codes_list=message.get("audio_codes_list", []),
+                 content=message.get("content", AUDIO_PLACEHOLDER),
+             )
+         raise ValueError(f"Unsupported role: {role}")
+
+     def _pad(self, input_ids_list: List[torch.Tensor]):
+         device = input_ids_list[0].device
+         lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
+         pad_input_ids = torch.nn.utils.rnn.pad_sequence(
+             input_ids_list,
+             batch_first=True,
+             padding_value=self.model_config.audio_pad_code,
+             padding_side="left",
+         )
+         other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
+             1
+         ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
+         pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
+         attention_mask = torch.zeros(
+             pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
+         )
+         attention_mask[~other_channel_mask] = 1
+         attention_mask = attention_mask.bool()
+         return {
+             "input_ids": pad_input_ids,  # [batch_size, seq_len, 1 + n_vq]
+             "attention_mask": attention_mask,
+         }
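+
+     # Illustrative: for two items of lengths 3 and 5, both are left-padded to
+     # length 5; the two leading steps of the shorter item get audio_pad_code on
+     # every channel and pad_token_id on the text channel, with attention masks
+     # [0, 0, 1, 1, 1] vs. [1, 1, 1, 1, 1].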
+
+     @staticmethod
+     def _replace_audio_placeholders(
+         content: str,
+         lengths: List[int],
+         n_vq: int,
+         gen_slot_token: str,
+         delay_slot_token: str,
+         audio_start_token: str,
+         audio_end_token: str,
+     ) -> str:
+         if n_vq < 1:
+             raise ValueError(f"n_vq must be >= 1, got {n_vq}")
+
+         num_placeholders = content.count(AUDIO_PLACEHOLDER)
+         if num_placeholders != len(lengths):
+             raise ValueError(
+                 f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
+                 f"does not match lengths ({len(lengths)})"
+             )
+
+         def build_audio_block(length: int) -> str:
+             if length < 0:
+                 raise ValueError(f"length must be >= 0, got {length}")
+
+             if length == 0:
+                 return f"{audio_start_token}{audio_end_token}"
+
+             step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
+             return f"{audio_start_token}{step_tokens}{audio_end_token}"
+
+         lengths_iter = iter(lengths)
+
+         def replacer(match: re.Match) -> str:
+             length = next(lengths_iter)
+             return build_audio_block(length)
+
+         result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
+
+         return result
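+
+     # Illustrative expansion: with lengths=[2] and n_vq=3, a single <|audio|>
+     # becomes audio_start + gen_slot * 2 + delay_slot * 2 + audio_end, i.e. a
+     # 2-frame code grid spread over 2 + (3 - 1) delayed decoding steps.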
+
+     @staticmethod
+     def _merge_consecutive_audio_placeholders(
+         content: str,
+         audio_codes_list: List[torch.Tensor],
+     ) -> Tuple[str, List[torch.Tensor]]:
+         matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
+         if len(matches) <= 1:
+             return content, audio_codes_list
+
+         if len(matches) != len(audio_codes_list):
+             raise ValueError(
+                 "Audio placeholders do not match the provided audio codes list."
+             )
+
+         new_audio_codes_list = []
+         new_parts = []
+         last_pos = 0
+         i = 0
+         while i < len(matches):
+             j = i
+             while (
+                 j + 1 < len(matches)
+                 and content[matches[j].end() : matches[j + 1].start()].strip() == ""
+             ):
+                 j += 1
+
+             new_parts.append(content[last_pos : matches[i].start()])
+             new_parts.append(AUDIO_PLACEHOLDER)
+             last_pos = matches[j].end()
+
+             if j == i:
+                 new_audio_codes_list.append(audio_codes_list[i])
+             else:
+                 new_audio_codes_list.append(
+                     torch.cat(audio_codes_list[i : j + 1], dim=0)
+                 )
+
+             i = j + 1
+
+         new_parts.append(content[last_pos:])
+         return "".join(new_parts), new_audio_codes_list
514
+
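+ # Worked example (illustrative): "<|audio|> <|audio|> ok" with code
+ # tensors of shapes (3, n_vq) and (5, n_vq) collapses to "<|audio|> ok"
+ # with a single concatenated (8, n_vq) tensor, because only whitespace
+ # separates the placeholders; any other text between them prevents merging.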
515
+ @staticmethod
516
+ def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
517
+ delayed_tokens = torch.full(
518
+ (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
519
+ pad_code,
520
+ device=codes.device,
521
+ dtype=codes.dtype,
522
+ )
523
+ for i in range(codes.shape[1]):
524
+ delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
525
+ return delayed_tokens
526
+
527
+ @staticmethod
528
+ def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
529
+ tokens = torch.full(
530
+ (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
531
+ 0,
532
+ device=delay_codes.device,
533
+ dtype=delay_codes.dtype,
534
+ )
535
+ for i in range(delay_codes.shape[1]):
536
+ tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
537
+ return tokens
538
+
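+ # Worked example (illustrative): for codes of shape (T=3, n_vq=3) with
+ # rows [a1 b1 c1], [a2 b2 c2], [a3 b3 c3] and pad code P,
+ # apply_delay_pattern yields T + n_vq - 1 = 5 rows:
+ #     [a1  P  P]
+ #     [a2 b1  P]
+ #     [a3 b2 c1]
+ #     [ P b3 c2]
+ #     [ P  P c3]
+ # i.e. channel i is shifted down by i steps; apply_de_delay_pattern
+ # inverts the shift and recovers the original (3, 3) codes.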
539
+ def _get_unified_codes(
540
+ self,
541
+ role: str,
542
+ content: str,
543
+ audio_codes_list: List[torch.Tensor],
544
+ truncation: bool,
545
+ ) -> torch.Tensor:
546
+ """
547
+ At this point, `content` has already been wrapped into the dialogue format.
548
+ """
549
+ if role == "user":
550
+ audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
551
+ truncation = False
552
+ else:
553
+ audio_gen_slot_token = self.audio_assistant_gen_slot_token
554
+ audio_delay_slot_token = self.audio_assistant_delay_slot_token
555
+
556
+ if len(audio_codes_list):
557
+ n_vq = audio_codes_list[0].shape[1]
558
+ else:
559
+ n_vq = self.model_config.n_vq
560
+
561
+ if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
562
+ content, audio_codes_list = self._merge_consecutive_audio_placeholders(
563
+ content, audio_codes_list
564
+ )
565
+ content = self._replace_audio_placeholders(
566
+ content=content,
567
+ lengths=[len(audio_codes) for audio_codes in audio_codes_list],
568
+ n_vq=n_vq,
569
+ gen_slot_token=audio_gen_slot_token,
570
+ delay_slot_token=audio_delay_slot_token,
571
+ audio_start_token=self.audio_start_token,
572
+ audio_end_token=self.audio_end_token,
573
+ )
574
+ text_codes = torch.tensor(
575
+ self.tokenizer.encode(content),
576
+ device=audio_codes_list[0].device if audio_codes_list else None,
577
+ )
578
+
579
+ audio_start_indices = torch.where(
580
+ text_codes == self.model_config.audio_start_token_id
581
+ )[0]
582
+ audio_end_indices = torch.where(
583
+ text_codes == self.model_config.audio_end_token_id
584
+ )[0]
585
+ if len(audio_start_indices) != len(audio_codes_list) or len(
586
+ audio_end_indices
587
+ ) != len(audio_codes_list):
588
+ raise ValueError(
589
+ "Audio placeholders do not match the provided audio codes list."
590
+ )
591
+
592
+ delay_audio_codes_list = []
593
+ if len(audio_codes_list) == 0:
594
+ # Keep this a one-element list so torch.cat below receives a sequence.
+ delay_audio_codes_list = [torch.full(
595
+ (len(text_codes), n_vq),
596
+ self.model_config.audio_pad_code,
597
+ device=text_codes.device,
598
+ dtype=text_codes.dtype,
599
+ )]
600
+ else:
601
+ prefix_idx = 0
602
+ for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
603
+ audio_start_indices, audio_end_indices, audio_codes_list
604
+ ):
605
+ audio_start_idx = int(audio_start_idx_t.item())
606
+ audio_end_idx = int(audio_end_idx_t.item())
607
+ delay_audio_codes = self.apply_delay_pattern(
608
+ audio_codes, self.model_config.audio_pad_code
609
+ )
610
+ pad_codes = torch.full(
611
+ (audio_start_idx - prefix_idx + 1, n_vq),
612
+ self.model_config.audio_pad_code,
613
+ device=audio_codes.device,
614
+ dtype=audio_codes.dtype,
615
+ )
616
+ delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
617
+ prefix_idx = audio_end_idx
618
+
619
+ if truncation:
620
+ delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
621
+ : -(n_vq - 1), :
622
+ ]
623
+ else:
624
+ last_audio_end_idx = int(audio_end_indices[-1].item())
625
+ pad_codes = torch.full(
626
+ (len(text_codes) - last_audio_end_idx, n_vq),
627
+ self.model_config.audio_pad_code,
628
+ device=audio_codes_list[0].device,
629
+ dtype=audio_codes_list[0].dtype,
630
+ )
631
+ delay_audio_codes_list.append(pad_codes)
632
+
633
+ delay_audio_codes_list = torch.cat(delay_audio_codes_list)
634
+
635
+ if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
636
+ text_codes = text_codes[: delay_audio_codes_list.shape[0]]
637
+
638
+ unified_codes = torch.cat(
639
+ [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
640
+ )
641
+ return unified_codes
642
+
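+ # Layout note (illustrative): the returned unified_codes has shape
+ # (seqlen, 1 + n_vq); column 0 carries the text token ids and columns
+ # 1..n_vq carry the delay-patterned audio codes, padded with
+ # audio_pad_code on text-only steps so both streams stay aligned per step.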
643
+ def _parse_text_codes(self, start_length, text_codes):
644
+ text = cast(str, self.tokenizer.decode(text_codes))
645
+ prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
646
+ text = text[len(prefix) :]
647
+
648
+ # re.escape keeps regex metacharacters such as "|" in the special tokens literal.
+ AUDIO_PATTERN = re.compile(
649
+ rf"(?:{re.escape(self.audio_start_token)})?"
650
+ rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
651
+ rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
652
+ rf"{re.escape(self.audio_end_token)}"
653
+ )
654
+
655
+ def normalize_audio_segments(text: str) -> str:
656
+ def repl(match: re.Match) -> str:
657
+ seg = match.group(0)
658
+ # Replace with <|audio|> if gen_slot is present in the segment;
659
+ if self.audio_assistant_gen_slot_token in seg:
660
+ return AUDIO_PLACEHOLDER
661
+ # Otherwise, remove it.
662
+ return ""
663
+
664
+ return AUDIO_PATTERN.sub(repl, text)
665
+
666
+ return normalize_audio_segments(text)
667
+
668
+ def _parse_audio_codes(self, start_length, audio_codes):
669
+ # De-delay back to [T', n_vq]
670
+ audio_codes = self.apply_de_delay_pattern(audio_codes)
671
+
672
+ # Rows that are all pad are separators between real audio segments.
673
+ is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
674
+ non_pad = ~is_pad
675
+ if not non_pad.any():
676
+ return []
677
+
678
+ idx = torch.nonzero(non_pad).squeeze(1)
679
+ breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
680
+ if breaks.numel() == 0:
681
+ segments_idx = [idx]
682
+ else:
683
+ segments_idx = torch.split(idx, breaks.tolist())
684
+
685
+ audio_codes_list = [audio_codes[s] for s in segments_idx]
686
+
687
+ # Batch-decode all audio segments together.
688
+ decoded_audio_list = self.decode_audio_codes(audio_codes_list)
689
+
690
+ # Keep codec causal context by decoding the whole first segment first,
691
+ # then trim at waveform level according to start_length ratio.
692
+ if (
693
+ start_length > 0
694
+ and len(audio_codes_list) > 0
695
+ and len(decoded_audio_list) > 0
696
+ ):
697
+ first_codes_length = audio_codes_list[0].shape[0]
698
+ if first_codes_length > 0:
699
+ trim_ratio = max(
700
+ 0.0, min(float(start_length) / float(first_codes_length), 1.0)
701
+ )
702
+ first_audio = decoded_audio_list[0]
703
+ if trim_ratio >= 1.0:
704
+ decoded_audio_list = decoded_audio_list[1:]
705
+ elif trim_ratio > 0.0:
706
+ trim_samples = int(first_audio.shape[-1] * trim_ratio)
707
+ decoded_audio_list[0] = first_audio[..., trim_samples:]
708
+
709
+ return decoded_audio_list
710
+
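+ # Worked example (illustrative): if the first segment has 100 code
+ # frames and start_length == 25, trim_ratio == 0.25, so the first
+ # quarter of the decoded waveform (the prompt audio) is trimmed off
+ # rather than re-decoding only the continuation frames.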
711
+ def decode(self, output: List[Tuple[int, torch.Tensor]]):
712
+ """
713
+ 1. This always requires the complete assistant generation ids.
714
+ 2. Truncation from an arbitrary position is supported.
715
+ """
716
+
717
+ generated_messages = []
718
+ for start_length, generation_ids in output:
719
+ content = self._parse_text_codes(start_length, generation_ids[:, 0])
720
+ audio_codes_list = self._parse_audio_codes(
721
+ start_length, generation_ids[:, 1:]
722
+ )
723
+ if content == "":
724
+ message = None
725
+ else:
726
+ message = AssistantMessage(
727
+ content=content,
728
+ audio_codes_list=cast(
729
+ List[Union[str, torch.Tensor]], audio_codes_list
730
+ ),
731
+ )
732
+ generated_messages.append(message)
733
+ return generated_messages
734
+
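+ # Usage sketch (illustrative; variable names are assumptions): each item
+ # pairs the prompt length with the full generation ids of shape
+ # (seqlen, 1 + n_vq), where column 0 is text and the rest delayed audio:
+ #     messages = processor.decode([(prompt_len, generation_ids)])
+ # Each entry is an AssistantMessage (or None when no text was produced)
+ # whose audio_codes_list holds the decoded 1-D waveforms.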
735
+ @staticmethod
736
+ def loudness_normalize(
737
+ wav: torch.Tensor,
738
+ target_dbfs: float = -20,
739
+ gain_range: tuple[float, float] = (-3.0, 3.0),
740
+ ) -> torch.Tensor:
741
+ wav = wav.to(torch.float32)
742
+ if wav.numel() == 0:
743
+ return wav
744
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
745
+ gain = float(target_dbfs - current_dbfs)
746
+ gain = max(gain_range[0], min(gain, gain_range[1]))
747
+ factor = 10.0 ** (gain / 20.0)
748
+ return wav * factor
749
+
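+ # Worked example (illustrative): for a waveform whose RMS level is
+ # -30 dBFS and target_dbfs = -20, the desired gain of +10 dB is clamped
+ # to +3 dB, so the waveform is scaled by 10 ** (3 / 20), roughly 1.41.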
750
+ def _get_audio_tokenizer_device(self) -> torch.device:
751
+ """Best-effort device inference for `self.audio_tokenizer`.
752
+
753
+ Notes:
754
+ - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
755
+ - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
756
+ """
757
+
758
+ audio_tokenizer = getattr(self, "audio_tokenizer", None)
759
+ if audio_tokenizer is None:
760
+ logger.warning(
761
+ "audio_tokenizer is not set on processor. Using CPU as default."
762
+ )
763
+ return torch.device("cpu")
764
+
765
+ device_attr = getattr(audio_tokenizer, "device", None)
766
+ if isinstance(device_attr, torch.device):
767
+ return device_attr
768
+
769
+ try:
770
+ return next(audio_tokenizer.parameters()).device
771
+ except StopIteration:
772
+ # No parameters (shouldn't happen for real models); default to CPU.
773
+ logger.warning(
774
+ "No parameters found on audio_tokenizer. Using CPU as default."
775
+ )
776
+ return torch.device("cpu")
777
+
778
+ def encode_audios_from_wav(
779
+ self,
780
+ wav_list: List[torch.Tensor],
781
+ sampling_rate: int,
782
+ n_vq: Optional[int] = None,
783
+ ):
784
+ if self.audio_tokenizer is None:
785
+ raise RuntimeError("audio_tokenizer is not set on processor.")
786
+ audio_tokenizer = self.audio_tokenizer
787
+
788
+ if isinstance(wav_list, torch.Tensor):
789
+ wav_list = [wav_list]
790
+ wav_list_ = []
791
+ resample = False
792
+ if sampling_rate != self.model_config.sampling_rate:
793
+ resample = True
794
+ device = self._get_audio_tokenizer_device()
795
+ for wav in wav_list:
796
+ if wav.shape[0] > 1:
797
+ wav = torch.mean(wav, dim=0, keepdim=True)
798
+ if resample:
799
+ wav = torchaudio.functional.resample(
800
+ waveform=wav,
801
+ orig_freq=sampling_rate,
802
+ new_freq=self.model_config.sampling_rate,
803
+ )
804
+ wav = wav.to(device)
805
+ wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
806
+
807
+ # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
808
+ if hasattr(audio_tokenizer, "batch_encode"):
809
+ enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
810
+ audio_codes = enc.audio_codes # (NQ, B, T)
811
+ audio_codes_lengths = enc.audio_codes_lengths # (B,)
812
+ else:
813
+ # Fallback: use encode() with explicit padding.
814
+ max_len = max(int(wav.shape[-1]) for wav in wav_list_)
815
+ input_values = torch.zeros(
816
+ len(wav_list_), 1, max_len, device=device, dtype=torch.float32
817
+ )
818
+ padding_mask = torch.zeros(
819
+ len(wav_list_), max_len, device=device, dtype=torch.bool
820
+ )
821
+ for i, wav in enumerate(wav_list_):
822
+ this_len = int(wav.shape[-1])
823
+ input_values[i, 0, :this_len] = wav
824
+ padding_mask[i, :this_len] = True
825
+ enc = audio_tokenizer.encode(
826
+ input_values,
827
+ padding_mask=padding_mask,
828
+ num_quantizers=n_vq,
829
+ return_dict=True,
830
+ )
831
+ audio_codes = enc.audio_codes
832
+ audio_codes_lengths = enc.audio_codes_lengths
833
+
834
+ if audio_codes is None or audio_codes_lengths is None:
835
+ raise RuntimeError(
836
+ "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
837
+ )
838
+
839
+ # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
840
+ # and on CPU (so downstream text/audio packing remains device-agnostic).
841
+ codes_list: List[torch.Tensor] = []
842
+ for i in range(int(audio_codes.shape[1])):
843
+ length_i = int(audio_codes_lengths[i].item())
844
+ codes_i = (
845
+ audio_codes[:, i, :length_i]
846
+ .transpose(0, 1)
847
+ .contiguous()
848
+ .to(torch.long)
849
+ .cpu()
850
+ )
851
+ codes_list.append(codes_i)
852
+ return codes_list
853
+
854
+ def encode_audios_from_path(
855
+ self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
856
+ ):
857
+ if isinstance(wav_path_list, str):
858
+ wav_path_list = [wav_path_list]
859
+
860
+ if len(wav_path_list) == 0:
861
+ raise ValueError("Empty wav_path_list")
862
+
863
+ # Load + (if needed) resample each wav independently, so callers can
864
+ # pass a heterogeneous batch of files while still benefiting from
865
+ # audio_tokenizer.batch_encode.
866
+ target_sr = int(self.model_config.sampling_rate)
867
+ wav_list: List[torch.Tensor] = []
868
+ for wav_path in wav_path_list:
869
+ wav, sr = torchaudio.load(wav_path)
870
+ if int(sr) != target_sr:
871
+ wav = torchaudio.functional.resample(
872
+ waveform=wav,
873
+ orig_freq=int(sr),
874
+ new_freq=target_sr,
875
+ )
876
+ wav_list.append(wav)
877
+
878
+ return self.encode_audios_from_wav(wav_list, target_sr, n_vq)
879
+
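+ # Usage sketch (illustrative; the path is an assumption): inputs are
+ # downmixed to mono, resampled to model_config.sampling_rate, and
+ # loudness-normalized before tokenization:
+ #     codes_list = processor.encode_audios_from_path("prompt.wav")
+ #     # -> list of CPU long tensors, each of shape (T, n_vq)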
880
+ def decode_audio_codes(
881
+ self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
882
+ ):
883
+ if self.audio_tokenizer is None:
884
+ raise RuntimeError("audio_tokenizer is not set on processor.")
885
+ audio_tokenizer = self.audio_tokenizer
886
+
887
+ if isinstance(audio_tokens_list, torch.Tensor):
888
+ audio_tokens_list = [audio_tokens_list]
889
+ if len(audio_tokens_list) == 0:
890
+ return []
891
+
892
+ device = self._get_audio_tokenizer_device()
893
+
894
+ # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
895
+ codes_list = [
896
+ codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
897
+ for codes in audio_tokens_list
898
+ ]
899
+
900
+ # Pad to (NQ, B, T) plus a padding mask, then decode as one batch.
901
+ nq = int(codes_list[0].shape[0])
902
+ max_t = max(int(c.shape[1]) for c in codes_list)
903
+ audio_codes = torch.zeros(
904
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
905
+ )
906
+ padding_mask = torch.zeros(
907
+ len(codes_list), max_t, device=device, dtype=torch.bool
908
+ )
909
+ for i, c in enumerate(codes_list):
910
+ t = int(c.shape[1])
911
+ audio_codes[:, i, :t] = c
912
+ padding_mask[i, :t] = True
913
+ dec = audio_tokenizer.decode(
914
+ audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
915
+ )
916
+ audio = dec.audio
917
+ audio_lengths = dec.audio_lengths
918
+
919
+ if audio is None or audio_lengths is None:
920
+ raise RuntimeError(
921
+ "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
922
+ )
923
+
924
+ # Return historical contract: list of 1D waveforms (T,)
925
+ wav_list: List[torch.Tensor] = []
926
+ for i in range(int(audio.shape[0])):
927
+ length_i = int(audio_lengths[i].item())
928
+ wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
929
+ wav_list.append(wav)
930
+ return wav_list
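+ # Round-trip sketch (illustrative): decode_audio_codes inverts
+ # encode_audios_from_wav at the codec level:
+ #     codes_list = processor.encode_audios_from_path("prompt.wav")
+ #     wavs = processor.decode_audio_codes(codes_list)
+ #     # each element is a 1-D float32 waveform at model_config.sampling_rate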
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "processor_class": "MossTTSDelayProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_moss_tts.MossTTSDelayProcessor"
5
+ }
6
+ }
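
The `auto_map` above lets `transformers` resolve the custom processor class directly from the repo. A minimal loading sketch (the repo id is a placeholder assumption):

```python
from transformers import AutoProcessor

# trust_remote_code is required because the processor class lives in
# processing_moss_tts.py inside the model repo (repo id is hypothetical).
processor = AutoProcessor.from_pretrained("OpenMOSS/MOSS-TTS", trust_remote_code=True)
```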
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|audio_start|>",
12
+ "<|audio_end|>",
13
+ "<|audio_user_slot|>",
14
+ "<|image_pad|>",
15
+ "<|audio_assistant_gen_slot|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb3c8fa82993d515469c2800cc455bff4aaa3c4fed9da1f2b0c0668c304f335a
3
+ size 11422691
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|audio_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|audio_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|audio_user_slot|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|audio_assistant_gen_slot|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|audio_assistant_delay_slot|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|audio_start|>",
224
+ "<|audio_end|>",
225
+ "<|audio_user_slot|>",
226
+ "<|image_pad|>",
227
+ "<|audio_assistant_gen_slot|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "processor_class": "AsteroidProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff